In [9]:
import gzip
import pandas as pd
import torch
import datetime
from tqdm import tqdm
import numpy as np

import srdatasets

from torch import nn
from torch.nn import functional as F


device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# device =torch.device('cpu')

```
srdatasets process --dataset=Amazon-VideoGames --split-by=user --task=long-short --target-len=1 --session-interval=120 --min-freq-item=0 --min-freq-user=0
```

In [10]:
from srdatasets.dataloader_pytorch import DataLoader

BATCH_SIZE = 50
DATASET_CODE = 'c1683435794587'


trainloader = DataLoader("Amazon-VideoGames", DATASET_CODE, batch_size=BATCH_SIZE, train=True, negatives_per_target=5, include_timestamp=True)#, num_workers=8, pin_memory=True)
valloader = DataLoader("Amazon-VideoGames", DATASET_CODE, batch_size=BATCH_SIZE, train=False, development = True, include_timestamp=True)#, num_workers=8, pin_memory=True)
testloader = DataLoader("Amazon-VideoGames", DATASET_CODE, batch_size=BATCH_SIZE, train=False, include_timestamp=True)#, num_workers=8, pin_memory=True)

In [11]:
n_users, n_items = trainloader.num_users, trainloader.num_items
n_users, n_items

(65051, 27148)

In [12]:
class SHAN(nn.Module):
    def __init__(self, embedding_dims, n_users = n_users, n_items = n_items):
        super().__init__()

        #user
        self.user_embed = nn.Embedding(n_users, embedding_dims)

        #item
        self.item_embed = nn.Embedding(n_items, embedding_dims, 0)

        #long-term layer
        self.item_trans1 = nn.Linear(embedding_dims, embedding_dims)
        self.act_1 = nn.ReLU()

        #after taking the embeddings of the item (v), feedforward them through the item_tran1 network to get h
        #thereafter, compute the attention weights of each item by taking the softmax activation of the dotted user.h
        #compute u_long by taking attention_weights * v for each

        #short-term layer
        self.item_trans2 = nn.Linear(embedding_dims, embedding_dims)
        self.act_2 = nn.ReLU()

        #weighting for net user representation
        self.beta_0 = torch.randn(1, requires_grad = True).to(device)
    
    def forward(self, users, pre_sessions_items, cur_session_items):
        user_rep = self.user_embed(users)[..., None]                            # batch * emb * 1
        
        
        # LONG TERM 
        long_term_item_rep = self.item_embed(pre_sessions_items)
        activated_long_term = self.act_1(self.item_trans1(long_term_item_rep))  # batch * num * emb
        #compute attention weights
        attention_mat_1 = torch.bmm(activated_long_term, user_rep)              # batch * num * 1     
        attention_weights_1 = F.softmax(attention_mat_1, dim = 1)               # batch * num * 1
        #long term representation
        u_long = attention_weights_1*long_term_item_rep                         # batch * num * emb
        u_long = torch.sum(u_long, dim = 1)                                     # batch * emb

        #SHORT TERM
        short_term_item_rep = self.item_embed(cur_session_items)
        activated_short_term = self.act_2(self.item_trans2(short_term_item_rep))
        #compute attention weights 
        attention_mat_2 = torch.bmm(activated_short_term, user_rep)
        attention_weights_2 = F.softmax(attention_mat_2, dim = 1)
        #short term representation
        u_short = attention_weights_2*short_term_item_rep
        u_short = torch.sum(u_short, dim = 1)

        #HYBRID
        u_hybrid = self.beta_0*u_long + u_short

        preference_scores = u_hybrid @ self.item_embed.weight.T # batch * emb @ (n_items x emb).T
        
        return preference_scores



In [13]:
def loss_fn(preds, target):
    bs, nitms = preds.size()
    idx = torch.randint(0, nitms, (bs, 1)).to(device)
    others = preds.gather(1, idx)
    actual = preds.gather(1, target)

    loss = -F.logsigmoid(actual - others)
    loss = torch.mean(loss)

    return loss

In [14]:
n_users = 0
n_items = 0
for users, pre_sessions_items, cur_session_items, target_items, _, _, _, _ in tqdm(iter(trainloader)):
    n_users = max(n_users, max(users))
    n_items = max(n_items, pre_sessions_items.max())
    n_items = max(n_items, cur_session_items.max())
    n_items = max(n_items, target_items.max())
n_users, n_items

100%|██████████| 1315/1315 [00:11<00:00, 113.08it/s]


(tensor(65050), tensor(27162))

In [18]:
class EarlyStopping:

    def __init__(self, patience=5, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func
        
    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)

        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [27]:
DIMS = 20

model = SHAN(embedding_dims= DIMS, n_users = n_users +1, n_items = n_items+1).to(device)
early_stopping = EarlyStopping(patience=5, verbose=True, path = 'shan_ed_{}.pth'.format(DIMS))
optim = torch.optim.Adam(model.parameters(), lr = 1e-3)
NUM_EPOCHS = 25

for epoch in range(NUM_EPOCHS):

    net_loss = 0
    # Train
    for users, pre_sessions_items, cur_session_items, target_items, _, _, _, _ in tqdm(iter(trainloader)):
        # Shape
        #   users:                          (batch_size,)
        #   pre_sessions_items:             (batch_size, pre_sessions * max_session_len)
        #   cur_session_items:              (batch_size, max_session_len - target_len)
        #   target_items:                   (batch_size, target_len)
        #   negative_samples:               (batch_size, target_len, negatives_per_target)
        # DataType
        #   numpy.ndarray or torch.LongTensor\
        optim.zero_grad()
        users = users.to(device)
        pre_sessions_items = pre_sessions_items.to(device)
        cur_session_items = cur_session_items.to(device)
        target_items = target_items.to(device)

        preferences = model(users, pre_sessions_items, cur_session_items)
        loss = loss_fn(preferences, target_items)

        loss.backward()
        optim.step()

        net_loss+=loss.item()

    net_loss_val = 0
    # Val
    for users, pre_sessions_items, cur_session_items, target_items, _, _, _ in tqdm(iter(valloader)):
        with torch.no_grad():
            users = users.to(device)
            pre_sessions_items = pre_sessions_items.to(device)
            cur_session_items = cur_session_items.to(device)
            target_items = target_items.to(device)

            preferences = model(users, pre_sessions_items, cur_session_items)
            loss = loss_fn(preferences, target_items)

            net_loss_val+=loss.item()
    net_loss = net_loss/len(trainloader)
    net_loss_val = net_loss_val/len(valloader)
    early_stopping(net_loss_val, model)

    if early_stopping.early_stop:
        print("Early stopping")
        print('-'*60)
        break

    print("Epoch {}: Training loss: {:.4f}, Validation loss: {:.4f}".format(epoch+1, net_loss, net_loss_val))

100%|██████████| 1315/1315 [00:18<00:00, 70.39it/s]
100%|██████████| 18/18 [00:00<00:00, 172.80it/s]


Validation loss decreased (inf --> 0.726736).  Saving model ...
Epoch 1: Training loss: 0.7590, Validation loss: 0.7267


100%|██████████| 1315/1315 [00:18<00:00, 70.40it/s]
100%|██████████| 18/18 [00:00<00:00, 236.50it/s]


Validation loss decreased (0.726736 --> 0.715203).  Saving model ...
Epoch 2: Training loss: 0.7056, Validation loss: 0.7152


100%|██████████| 1315/1315 [00:18<00:00, 70.58it/s]
100%|██████████| 18/18 [00:00<00:00, 149.90it/s]


EarlyStopping counter: 1 out of 5
Epoch 3: Training loss: 0.6906, Validation loss: 0.7271


100%|██████████| 1315/1315 [00:18<00:00, 71.16it/s]
100%|██████████| 18/18 [00:00<00:00, 230.77it/s]


EarlyStopping counter: 2 out of 5
Epoch 4: Training loss: 0.6723, Validation loss: 0.7346


100%|██████████| 1315/1315 [00:18<00:00, 70.54it/s]
100%|██████████| 18/18 [00:00<00:00, 173.88it/s]


EarlyStopping counter: 3 out of 5
Epoch 5: Training loss: 0.6518, Validation loss: 0.7405


100%|██████████| 1315/1315 [00:18<00:00, 71.22it/s]
100%|██████████| 18/18 [00:00<00:00, 175.58it/s]


EarlyStopping counter: 4 out of 5
Epoch 6: Training loss: 0.6302, Validation loss: 0.7708


100%|██████████| 1315/1315 [00:18<00:00, 71.17it/s]
100%|██████████| 18/18 [00:00<00:00, 182.71it/s]

EarlyStopping counter: 5 out of 5
Early stopping
------------------------------------------------------------





In [None]:
preferences

tensor([[ 0.0000, -2.7698,  1.6947,  ...,  0.5771, -2.5745,  0.0255],
        [ 0.0000, -0.4039,  0.6148,  ...,  1.6784, -0.5766, -1.1723],
        [ 0.0000, -0.0121,  0.0837,  ..., -0.2411, -0.0976,  0.2043],
        ...,
        [ 0.0000,  0.1797, -0.1234,  ...,  0.0288,  0.0837, -0.1138],
        [ 0.0000,  0.8341,  0.9508,  ...,  0.4071, -0.4671, -0.0968],
        [ 0.0000, -0.1023, -0.3824,  ...,  0.6696, -0.4787, -0.4024]],
       grad_fn=<MmBackward0>)

In [None]:
pre_sessions_items

tensor([[   0,    0,    0,  ...,    0, 4658, 1803],
        [   0,    0,    0,  ...,    0,    0,    0],
        [   0,    0,    0,  ...,    0,    0,    0],
        ...,
        [   0,    0,    0,  ...,    0,    0,    0],
        [   0,    0,    0,  ...,    0,    0,    0],
        [   0,    0,    0,  ...,    0,    0,    0]])

In [None]:
# for users, pre_sessions_items, cur_session_items, target_items, pre_sessions_item_timestamps, cur_session_item_timestamps, target_item_timestamps in valloader:
#     pass
#     break