In [1]:
import gzip
import pandas as pd
import torch
import datetime
from tqdm import tqdm
import numpy as np

import srdatasets

from torch import nn
from torch.nn import functional as F
from torchmetrics import functional as tm_f


device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# device =torch.device('cpu')

```
srdatasets process --dataset=Gowalla --split-by=user --test-split=0.2 --dev-split=0.1 --task=long-short --input-len=5 --target-len=10 --pre-sessions=10 --pick-targets=last --session-interval=1440 --min-session-len=11 --max-session-len=30 --min-freq-item=20 --min-freq-user=20 
Code : c1683468658591

srdatasets process --dataset=Amazon-Books --split-by=user --test-split=0.2 --dev-split=0.1 --task=long-short --input-len=5 --target-len=10 --pre-sessions=10 --pick-targets=last --session-interval=1440 --min-session-len=11 --max-session-len=30 --min-freq-item=20 --min-freq-user=20
Code : c1683470275348

srdatasets process --dataset=MovieLens20M --split-by=user --test-split=0.2 --dev-split=0.1 --task=long-short --input-len=5 --target-len=10 --pre-sessions=10 --pick-targets=last --session-interval=1440 --min-session-len=11 --max-session-len=30 --min-freq-item=20 --min-freq-user=20
Code: c1683470665721
```

In [25]:
from srdatasets.dataloader_pytorch import DataLoader

BATCH_SIZE = 50
DATASET_CODE_VIDEOGAME = 'c1683466871546'
DATASET_CODE_BOOKS = 'c1683470275348'
DATASET_CODE_GOWALLA = 'c1683468658591'
DATASET_CODE_MOVIELENS20M = 'c1683470665721'


DATASET_CODE = DATASET_CODE_MOVIELENS20M
DATASET = "MovieLens20M"

trainloader = DataLoader(DATASET, DATASET_CODE, batch_size=BATCH_SIZE, train=True, negatives_per_target=5, include_timestamp=True)#, num_workers=8, pin_memory=True)
valloader = DataLoader(DATASET, DATASET_CODE, batch_size=BATCH_SIZE, train=False, development = True, include_timestamp=True)#, num_workers=8, pin_memory=True)
testloader = DataLoader(DATASET, DATASET_CODE, batch_size=BATCH_SIZE, train=False, development = False, include_timestamp=True)#, num_workers=8, pin_memory=True)

len(trainloader), len(valloader), len(testloader)

(2036, 8, 83)

In [26]:
n_users, n_items = trainloader.num_users, trainloader.num_items
n_users, n_items

(101197, 9630)

In [27]:
len_train = len(trainloader.dataset)
len_val = len(valloader.dataset)
len_test = len(testloader.dataset)

len_train, len_val, len_test

(101780, 351, 4143)

In [28]:
class SHAN(nn.Module):
    def __init__(self, embedding_dims, n_users = n_users, n_items = n_items):
        super().__init__()

        #user
        self.user_embed = nn.Embedding(n_users, embedding_dims)

        #item
        self.item_embed = nn.Embedding(n_items, embedding_dims, 0)

        #long-term layer
        self.item_trans1 = nn.Linear(embedding_dims, embedding_dims)
        self.act_1 = nn.ReLU()

        #after taking the embeddings of the item (v), feedforward them through the item_tran1 network to get h
        #thereafter, compute the attention weights of each item by taking the softmax activation of the dotted user.h
        #compute u_long by taking attention_weights * v for each

        #short-term layer
        self.item_trans2 = nn.Linear(embedding_dims, embedding_dims)
        self.act_2 = nn.ReLU()

        #weighting for net user representation
        self.beta_0 = torch.randn(1, requires_grad = True).to(device)
    
    def forward(self, users, pre_sessions_items, cur_session_items):
        user_rep = self.user_embed(users)[..., None]                            # batch * emb * 1
        
        
        # LONG TERM 
        long_term_item_rep = self.item_embed(pre_sessions_items)
        activated_long_term = self.act_1(self.item_trans1(long_term_item_rep))  # batch * num * emb
        #compute attention weights
        attention_mat_1 = torch.bmm(activated_long_term, user_rep)              # batch * num * 1     
        attention_weights_1 = F.softmax(attention_mat_1, dim = 1)               # batch * num * 1
        #long term representation
        u_long = attention_weights_1*long_term_item_rep                         # batch * num * emb
        u_long = torch.sum(u_long, dim = 1)                                     # batch * emb

        #SHORT TERM
        short_term_item_rep = self.item_embed(cur_session_items)
        activated_short_term = self.act_2(self.item_trans2(short_term_item_rep))
        #compute attention weights 
        attention_mat_2 = torch.bmm(activated_short_term, user_rep)
        attention_weights_2 = F.softmax(attention_mat_2, dim = 1)
        #short term representation
        u_short = attention_weights_2*short_term_item_rep
        u_short = torch.sum(u_short, dim = 1)

        # HYBRID
        u_hybrid = self.beta_0*u_long + u_short

        preference_scores = u_hybrid @ self.item_embed.weight.T # batch * emb @ (n_items x emb).T
        return preference_scores
        

In [29]:
def loss_fn(preds, target, bootstraps = 100):
    bs, nitms = preds.size()

    total_loss = 0
    for i in range(bootstraps):
        idx = torch.randint(0, nitms, (bs, 1)).to(device)
        others = preds.gather(1, idx)
        actual = preds.gather(1, target)

        loss = -F.logsigmoid(actual - others)
        loss = torch.mean(loss)
        total_loss+= loss

    return total_loss/bootstraps

In [30]:
n_users = 0
n_items = 0
for users, pre_sessions_items, cur_session_items, target_items, _, _, _, _ in tqdm(iter(trainloader)):
    n_users = max(n_users, max(users))
    n_items = max(n_items, pre_sessions_items.max())
    n_items = max(n_items, cur_session_items.max())
    n_items = max(n_items, target_items.max())
n_users, n_items

100%|██████████| 2036/2036 [00:16<00:00, 120.94it/s]


(tensor(101196), tensor(9630))

In [31]:
class EarlyStopping:

    def __init__(self, patience=5, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func
        
    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)

        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [32]:
def get_function(func, preds, targets, k = 10):
    b = torch.zeros(preds.shape).bool()
    b[targets] = True
    b = b.to(device)
    
    return func(preds, b, k = k)

def get_batch_func(func, preds, targets, k=10, averaging = None):
    val = []
    for i in range(len(preds)):
        val.append(get_function(func, preds[i], targets[i], k).item())

    if(averaging is None):
        return np.sum(val)
    return np.mean(val)

In [33]:
DIMS = 5

model = SHAN(embedding_dims= DIMS, n_users = n_users +1, n_items = n_items+1).to(device)
early_stopping = EarlyStopping(patience=3, verbose=True, path = 'shan_ed_{}_{}.pth'.format(DATASET, DIMS))
optim = torch.optim.Adam(model.parameters(), lr = 1e-3)
NUM_EPOCHS = 20
BOOTSTRAPS = 20

train_loss = []
val_loss = []

# train_prec_5 = []
# train_rec_5 = []
# train_prec_1 = []
# train_rec_1 = []
# train_prec_10 = []
# train_rec_10 = []

for epoch in range(NUM_EPOCHS):

    net_loss = 0
    # Train
    for users, pre_sessions_items, cur_session_items, target_items, _, _, _, _ in tqdm(iter(trainloader)):
        # Shape
        #   users:                          (batch_size,)
        #   pre_sessions_items:             (batch_size, pre_sessions * max_session_len)
        #   cur_session_items:              (batch_size, max_session_len - target_len)
        #   target_items:                   (batch_size, target_len)
        #   negative_samples:               (batch_size, target_len, negatives_per_target)
        # DataType
        #   numpy.ndarray or torch.LongTensor\
        optim.zero_grad()
        users = users.to(device)
        pre_sessions_items = pre_sessions_items.to(device)
        cur_session_items = cur_session_items.to(device)
        target_items = target_items.to(device)

        preferences = model(users, pre_sessions_items, cur_session_items)
        loss = loss_fn(preferences, target_items, BOOTSTRAPS)

        loss.backward()
        optim.step()

        net_loss+=loss.item()



    net_loss_val = 0
    # Val
    for users, pre_sessions_items, cur_session_items, target_items, _, _, _ in tqdm(iter(valloader)):
        with torch.no_grad():
            users = users.to(device)
            pre_sessions_items = pre_sessions_items.to(device)
            cur_session_items = cur_session_items.to(device)
            target_items = target_items.to(device)

            preferences = model(users, pre_sessions_items, cur_session_items)
            loss = loss_fn(preferences, target_items, BOOTSTRAPS)

            net_loss_val+=loss.item()
    net_loss = net_loss/len(trainloader)
    net_loss_val = net_loss_val/len(valloader)
    early_stopping(net_loss_val, model)

    if early_stopping.early_stop:
        print("Early stopping")
        print('-'*60)
        break

    print("Epoch {}: Training loss: {:.4f}, Validation loss: {:.4f}".format(epoch+1, net_loss, net_loss_val))



100%|██████████| 2036/2036 [00:45<00:00, 45.05it/s]
100%|██████████| 8/8 [00:00<00:00, 107.19it/s]


Validation loss decreased (inf --> 0.753221).  Saving model ...
Epoch 1: Training loss: 0.7305, Validation loss: 0.7532


100%|██████████| 2036/2036 [00:45<00:00, 45.13it/s]
100%|██████████| 8/8 [00:00<00:00, 95.07it/s]


EarlyStopping counter: 1 out of 3
Epoch 2: Training loss: 0.5616, Validation loss: 1.0325


100%|██████████| 2036/2036 [00:45<00:00, 45.18it/s]
100%|██████████| 8/8 [00:00<00:00, 91.50it/s]


EarlyStopping counter: 2 out of 3
Epoch 3: Training loss: 0.4375, Validation loss: 1.2724


 98%|█████████▊| 1991/2036 [00:44<00:00, 45.75it/s]

In [None]:
with open('shan_ed_{}_{}.pth'.format(DATASET, DIMS), 'rb') as f:
    model.load_state_dict(torch.load(f))

In [None]:
get_function(tm_f.retrieval_recall, preferences[0], target_items[0])

tensor(0., device='cuda:0')

In [None]:
net_loss_test = 0


net_loss = 0
prec_5_epoch = 0
rec_5_epoch = 0
prec_1_epoch = 0
rec_1_epoch = 0
prec_10_epoch = 0
rec_10_epoch = 0
# Test
for users, pre_sessions_items, cur_session_items, target_items, _, _, _ in tqdm(iter(testloader)):
    with torch.no_grad():

        
        users = users.to(device)
        pre_sessions_items = pre_sessions_items.to(device)
        cur_session_items = cur_session_items.to(device)
        target_items = target_items.to(device)

        preferences = model(users, pre_sessions_items, cur_session_items)
        loss_fn(preferences, target_items, BOOTSTRAPS)

        net_loss_test+=loss.item()
        
        
        prec_5_epoch+= get_batch_func(tm_f.retrieval_precision, preferences, target_items, 5, averaging = None)
        rec_5_epoch+= get_batch_func(tm_f.retrieval_recall, preferences, target_items, 5, averaging = None)
        prec_1_epoch+= get_batch_func(tm_f.retrieval_precision, preferences, target_items, 1, averaging = None)
        rec_1_epoch+= get_batch_func(tm_f.retrieval_recall, preferences, target_items, 1, averaging = None)
        prec_10_epoch+= get_batch_func(tm_f.retrieval_precision, preferences, target_items, 10, averaging = None)
        rec_10_epoch+= get_batch_func(tm_f.retrieval_recall, preferences, target_items, 10, averaging = None)
        

prec_10_epoch/=len_test
prec_5_epoch/=len_test
prec_1_epoch/=len_test
rec_10_epoch/=len_test
rec_5_epoch/=len_test
rec_1_epoch/=len_test

        # train_prec_5.append(prec_5_epoch)
        # train_prec_1.append(prec_1_epoch)
        # train_prec_10.append(prec_10_epoch)

        # train_rec_5.append(rec_5_epoch)
        # train_rec_1.append(rec_1_epoch)
        # train_rec_10.append(rec_10_epoch)

print("Precision@1 = {}, @5 = {}, @10 = {}".format(prec_1_epoch, prec_5_epoch, prec_10_epoch))
print("Recall@1 = {}, @5 = {}, @10 = {}".format(rec_1_epoch, rec_5_epoch, rec_10_epoch))
net_loss_test = net_loss_test/len(testloader)

print(net_loss_test)

100%|██████████| 4/4 [00:00<00:00,  5.40it/s]

Precision@1 = 0.0, @5 = 0.0, @10 = 0.0
Recall@1 = 0.0, @5 = 0.0, @10 = 0.0
1.3114395141601562



