In [1]:
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import torch

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:

import pandas as pd

file_path = '/content/drive/My Drive/sasrec_format.csv'
df = pd.read_csv(file_path)

def str_to_list(s):
    return [int(x) for x in s.strip('[]').split(',')]

df['sequence_item_ids'] = df['sequence_item_ids'].apply(str_to_list)

In [4]:
max_item_id = max(set(item for sublist in df['sequence_item_ids'] for item in sublist))

In [54]:
_w = torch.nn.Parameter(
            torch.empty(2 * 200 - 1).normal_(mean=0, std=0.02),
        )


In [55]:
_w.shape

torch.Size([399])

In [56]:
n=199
t = F.pad(_w[: 2 * n - 1], [0, n]).repeat(n)
t = t[..., :-n].reshape(1, n, 3 * n - 2)
r = (2 * n - 1) // 2
t[..., r:-r].shape

torch.Size([1, 199, 199])

In [57]:
t

tensor([[[ 0.0173,  0.0256,  0.0017,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0173,  0.0256,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0173,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ..., -0.0024,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ..., -0.0041, -0.0024,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0127, -0.0041, -0.0024]]],
       grad_fn=<ViewBackward0>)

In [80]:
class RelativePositionalBias(nn.Module):

    def __init__(self, max_seq_len: int) -> None:
        super().__init__()

        self._max_seq_len: int = max_seq_len
        self._w = torch.nn.Parameter(
            torch.empty(2 * max_seq_len - 1).normal_(mean=0, std=0.02),
        )

    def forward(
        self,
    ) -> torch.Tensor:
        n: int = self._max_seq_len
        t = F.pad(self._w[: 2 * n - 1], [0, n]).repeat(n)
        t = t[..., :-n].reshape(1, n, 3 * n - 2)
        r = (2 * n - 1) // 2
        return t[..., r:-r]

In [95]:
from torch.utils.data import Dataset, DataLoader

class SequenceDataset(Dataset):
    def __init__(self, dataframe, max_seq_length=48, stride=24):
        self.dataframe = dataframe
        self.max_seq_length = max_seq_length
        self.stride = stride
        self.sequences = []
        for i in range(len(dataframe)):
            sequence = dataframe.iloc[i]['sequence_item_ids']

            if len(sequence) > max_seq_length:
                # Create all possible strides
                for start_idx in range(0, len(sequence) - max_seq_length + 1, stride):
                    chunk = sequence[start_idx:start_idx + max_seq_length]
                    self.sequences.append(chunk)
            else:
                # Pad the sequence if it's shorter than max_seq_length
                sequence = [0] * (max_seq_length - len(sequence)) + sequence
                self.sequences.append(sequence)

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        sequence_tensor = torch.tensor(self.sequences[idx], dtype=torch.long)
        return sequence_tensor

In [96]:
batch_size = 128
max_seq_length = 48
stride = 1

In [113]:
dataset = SequenceDataset(df.iloc[:5500], max_seq_length=max_seq_length)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True)

In [114]:
class HSTU(nn.Module):
    def __init__(self, vocab_size, emb_dim, attn_dim, seq_length):
        super(HSTU, self).__init__()
        self.attn_dim = attn_dim
        self.seq_length = seq_length

        self.register_buffer("bias", torch.tril(torch.ones(seq_length, seq_length))
                                     .view(1, seq_length, seq_length))

        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.f1 = nn.Linear(emb_dim, attn_dim*4)
        self.f2 = nn.Linear(attn_dim, emb_dim)
        self.rab = RelativePositionalBias(seq_length)


    def forward(self, input):
      embedded =  self.emb(input)
      out = F.silu(self.f1(embedded))
      u,v,q,k = torch.split(out, self.attn_dim, dim=-1)
      a = (q @ k.transpose(-1,-2))
      a = a + self.rab()
      a = F.silu(a.masked_fill(self.bias[:,:self.seq_length,:self.seq_length] == 0, 0)) / self.seq_length
      a = F.layer_norm(a @ v, normalized_shape=v.shape[1:])
      y = self.f2(a*u)

      return y


In [115]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [116]:
hstu = HSTU(vocab_size=3952+1, emb_dim=50, attn_dim=50, seq_length=max_seq_length)
optimizer = optim.AdamW(hstu.parameters(), lr=1e-3, betas=(0.9, 0.98),
        weight_decay=0)  # You can adjust the learning rate as needed

hstu.to(DEVICE)

HSTU(
  (emb): Embedding(3953, 50)
  (f1): Linear(in_features=50, out_features=200, bias=True)
  (f2): Linear(in_features=50, out_features=50, bias=True)
  (rab): RelativePositionalBias()
)

In [117]:
eval_seq = []
for i in range(len(df.iloc[5500:])):
      sequence = df.iloc[i]['sequence_item_ids']
      if len(sequence) > max_seq_length:
        chunk = sequence[-(max_seq_length+1):]
        eval_seq.append(chunk)
      else:
        # Pad the sequence if it's shorter than max_seq_length
        sequence = [0] * ((max_seq_length+1) - len(sequence)) + sequence
        eval_seq.append(sequence)
eval_seq = torch.tensor(eval_seq, dtype=torch.long).to(DEVICE)

In [118]:
def hrr():
  with torch.no_grad():
    all_items = torch.arange(0, max_item_id+1).to(DEVICE)
    all_embs = hstu.emb(all_items)
    y = hstu(eval_seq[:,:-1])[:,-1, :]
    hits = 0
    for i, e in enumerate(y):
      actual = eval_seq[i,-1]
      hits += int(actual in  torch.topk(all_embs @  e, k=10 ).indices)
    print(f"hit rate {100*hits/y.shape[0]}, hits: {hits}")

In [119]:
for i in range(100):
  avg_loss_for_epoch = 0
  counter = 0
  for batch in dataloader:
      batch = batch.to(DEVICE)
      optimizer.zero_grad()

      y = hstu(batch)
      outputs = y[:,:-1,:]
      targets = hstu.emb(batch[:,1:])
      samples = torch.randint(high=max_item_id, size=(164,)).to(DEVICE)


      label_scores = (outputs * targets).sum(-1)
      negative_logits = torch.matmul(outputs.reshape(outputs.shape[0]*outputs.shape[1], 50), hstu.emb(samples).T) # Shape: (B*S) x K
      scores = torch.cat([label_scores.view(-1, 1), negative_logits], dim=-1)
      #loss = F.cross_entropy(scores, torch.zeros(scores.shape[0], dtype=torch.long).to(DEVICE))
      loss = -F.log_softmax(scores, dim=-1)[:,0].sum()
      loss.backward()
      optimizer.step()
      avg_loss_for_epoch += loss.item()
      counter += 1

  hrr()
  print(avg_loss_for_epoch/counter)

hit rate 0.18518518518518517, hits: 1
32828.07203311012
hit rate 0.0, hits: 0
30700.176339285714
hit rate 0.5555555555555556, hits: 3
30226.431129092263
hit rate 1.1111111111111112, hits: 6
29970.851143973214
hit rate 2.037037037037037, hits: 11
29740.694986979168
hit rate 2.4074074074074074, hits: 13
29607.641927083332
hit rate 2.962962962962963, hits: 16
29261.681082589286
hit rate 2.5925925925925926, hits: 14
28950.374162946428
hit rate 2.962962962962963, hits: 16
28672.09398251488
hit rate 2.037037037037037, hits: 11
28299.463774181546
hit rate 2.037037037037037, hits: 11
27918.38741629464
hit rate 1.6666666666666667, hits: 9
27759.722191220237
hit rate 1.6666666666666667, hits: 9
27410.434384300595
hit rate 2.2222222222222223, hits: 12
27231.544131324405
hit rate 3.1481481481481484, hits: 17
26833.018043154763
hit rate 3.3333333333333335, hits: 18
26781.50744047619
hit rate 3.3333333333333335, hits: 18
26495.089425223214
hit rate 2.4074074074074074, hits: 13
26279.719680059523
hit