In [133]:
import gzip
import pandas as pd
import torch
import datetime
from tqdm import tqdm
import numpy as np

import srdatasets

from torch import nn
from torch.nn import functional as F
from torchmetrics import functional as tm_f


device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# device =torch.device('cpu')

```
srdatasets process --dataset=Amazon-VideoGames --split-by=user --task=long-short --target-len=1 --session-interval=120 --min-freq-item=0 --min-freq-user=0

srdatasets process --dataset=Amazon-Books --split-by=time --task=long-short --input-len=5 --target-len=10 --pick-target=last --session-interval=14400 --min-freq-item=0 --min-freq-user=0 --min-session-len=11 --pre-sessions=6

srdatasets process --dataset=Amazon-VideoGames --split-by=user --test-split=0.2 --dev-split=0.1 --task=long-short --input-len=5 --target-len=10 --pre-sessions=10 --pick-targets=last --session-interval=14400 --min-session-len=11 --max-session-len=30
: c1683467389709

srdatasets process --dataset=Gowalla --split-by=user --test-split=0.2 --dev-split=0.1 --task=long-short --input-len=5 --target-len=1 --pre-sessions=10 --pick-targets=last --session-interval=1440 --min-session-len=2 --max-session-len=30 --min-freq-item=20 --min-freq-user=20
: c1683467693343
```

In [222]:
from srdatasets.dataloader_pytorch import DataLoader

BATCH_SIZE = 50
DATASET_CODE_VIDEOGAME = 'c1683466871546'
DATASET_CODE_BOOKS = 'c1683461837235'
DATASET_CODE_GOWALLA = 'c1683468658591'


DATASET_CODE = DATASET_CODE_GOWALLA
DATASET = "Gowalla"

trainloader = DataLoader(DATASET, DATASET_CODE, batch_size=BATCH_SIZE, train=True, negatives_per_target=5, include_timestamp=True)#, num_workers=8, pin_memory=True)
valloader = DataLoader(DATASET, DATASET_CODE, batch_size=BATCH_SIZE, train=False, development = True, include_timestamp=True)#, num_workers=8, pin_memory=True)
testloader = DataLoader(DATASET, DATASET_CODE, batch_size=BATCH_SIZE, train=False, development = False, include_timestamp=True)#, num_workers=8, pin_memory=True)

len(trainloader), len(valloader), len(testloader)

(217, 7, 31)

In [223]:
n_users, n_items = trainloader.num_users, trainloader.num_items
n_users, n_items

(9908, 41012)

In [224]:
len_train = len(trainloader.dataset)
len_val = len(valloader.dataset)
len_test = len(testloader.dataset)

len_train, len_val, len_test

(10808, 349, 1525)

In [225]:
class SHAN(nn.Module):
    def __init__(self, embedding_dims, n_users = n_users, n_items = n_items):
        super().__init__()

        #user
        self.user_embed = nn.Embedding(n_users, embedding_dims)

        #item
        self.item_embed = nn.Embedding(n_items, embedding_dims, 0)

        #long-term layer
        self.item_trans1 = nn.Linear(embedding_dims, embedding_dims)
        self.act_1 = nn.ReLU()

        #after taking the embeddings of the item (v), feedforward them through the item_tran1 network to get h
        #thereafter, compute the attention weights of each item by taking the softmax activation of the dotted user.h
        #compute u_long by taking attention_weights * v for each

        #short-term layer
        self.item_trans2 = nn.Linear(embedding_dims, embedding_dims)
        self.act_2 = nn.ReLU()

        #weighting for net user representation
        self.beta_0 = torch.randn(1, requires_grad = True).to(device)
    
    def forward(self, users, pre_sessions_items, cur_session_items):
        user_rep = self.user_embed(users)[..., None]                            # batch * emb * 1
        
        
        # LONG TERM 
        long_term_item_rep = self.item_embed(pre_sessions_items)
        activated_long_term = self.act_1(self.item_trans1(long_term_item_rep))  # batch * num * emb
        #compute attention weights
        attention_mat_1 = torch.bmm(activated_long_term, user_rep)              # batch * num * 1     
        attention_weights_1 = F.softmax(attention_mat_1, dim = 1)               # batch * num * 1
        #long term representation
        u_long = attention_weights_1*long_term_item_rep                         # batch * num * emb
        u_long = torch.sum(u_long, dim = 1)                                     # batch * emb

        #SHORT TERM
        short_term_item_rep = self.item_embed(cur_session_items)
        activated_short_term = self.act_2(self.item_trans2(short_term_item_rep))
        #compute attention weights 
        attention_mat_2 = torch.bmm(activated_short_term, user_rep)
        attention_weights_2 = F.softmax(attention_mat_2, dim = 1)
        #short term representation
        u_short = attention_weights_2*short_term_item_rep
        u_short = torch.sum(u_short, dim = 1)

        # HYBRID
        u_hybrid = self.beta_0*u_long + u_short

        preference_scores = u_hybrid @ self.item_embed.weight.T # batch * emb @ (n_items x emb).T
        return preference_scores
        

In [226]:
def loss_fn(preds, target, bootstraps = 100):
    bs, nitms = preds.size()

    total_loss = 0
    for i in range(bootstraps):
        idx = torch.randint(0, nitms, (bs, 1)).to(device)
        others = preds.gather(1, idx)
        actual = preds.gather(1, target)

        loss = -F.logsigmoid(actual - others)
        loss = torch.mean(loss)
        total_loss+= loss

    return total_loss/bootstraps

In [227]:
n_users = 0
n_items = 0
for users, pre_sessions_items, cur_session_items, target_items, _, _, _, _ in tqdm(iter(trainloader)):
    n_users = max(n_users, max(users))
    n_items = max(n_items, pre_sessions_items.max())
    n_items = max(n_items, cur_session_items.max())
    n_items = max(n_items, target_items.max())
n_users, n_items

100%|██████████| 217/217 [00:02<00:00, 82.23it/s]


(tensor(9907), tensor(41052))

In [228]:
class EarlyStopping:

    def __init__(self, patience=5, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func
        
    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)

        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [229]:
def get_function(func, preds, targets, k = 10):
    b = torch.zeros(preds.shape).bool()
    b[targets] = True
    b = b.to(device)
    
    return func(preds, b, k = k)

def get_batch_func(func, preds, targets, k=10, averaging = None):
    val = []
    for i in range(len(preds)):
        val.append(get_function(func, preds[i], targets[i], k).item())

    if(averaging is None):
        return np.sum(val)
    return np.mean(val)

In [310]:
DIMS = 50

model = SHAN(embedding_dims= DIMS, n_users = n_users +1, n_items = n_items+1).to(device)
early_stopping = EarlyStopping(patience=3, verbose=True, path = 'shan_ed_{}.pth'.format(DIMS))
optim = torch.optim.Adam(model.parameters(), lr = 1e-3)
NUM_EPOCHS = 20
BOOTSTRAPS = 20

train_loss = []
val_loss = []

# train_prec_5 = []
# train_rec_5 = []
# train_prec_1 = []
# train_rec_1 = []
# train_prec_10 = []
# train_rec_10 = []

for epoch in range(NUM_EPOCHS):


    # Train
    for users, pre_sessions_items, cur_session_items, target_items, _, _, _, _ in tqdm(iter(trainloader)):
        # Shape
        #   users:                          (batch_size,)
        #   pre_sessions_items:             (batch_size, pre_sessions * max_session_len)
        #   cur_session_items:              (batch_size, max_session_len - target_len)
        #   target_items:                   (batch_size, target_len)
        #   negative_samples:               (batch_size, target_len, negatives_per_target)
        # DataType
        #   numpy.ndarray or torch.LongTensor\
        optim.zero_grad()
        users = users.to(device)
        pre_sessions_items = pre_sessions_items.to(device)
        cur_session_items = cur_session_items.to(device)
        target_items = target_items.to(device)

        preferences = model(users, pre_sessions_items, cur_session_items)
        loss = loss_fn(preferences, target_items, BOOTSTRAPS)

        loss.backward()
        optim.step()

        net_loss+=loss.item()



    net_loss_val = 0
    # Val
    for users, pre_sessions_items, cur_session_items, target_items, _, _, _ in tqdm(iter(valloader)):
        with torch.no_grad():
            users = users.to(device)
            pre_sessions_items = pre_sessions_items.to(device)
            cur_session_items = cur_session_items.to(device)
            target_items = target_items.to(device)

            preferences = model(users, pre_sessions_items, cur_session_items)
            loss = loss_fn(preferences, target_items, BOOTSTRAPS)

            net_loss_val+=loss.item()
    net_loss = net_loss/len(trainloader)
    net_loss_val = net_loss_val/len(valloader)
    early_stopping(net_loss_val, model)

    if early_stopping.early_stop:
        print("Early stopping")
        print('-'*60)
        break

    print("Epoch {}: Training loss: {:.4f}, Validation loss: {:.4f}".format(epoch+1, net_loss, net_loss_val))



100%|██████████| 217/217 [00:05<00:00, 37.16it/s]
100%|██████████| 7/7 [00:00<00:00, 98.27it/s]


Validation loss decreased (inf --> 1.371490).  Saving model ...
Epoch 1: Training loss: 1.7037, Validation loss: 1.3715


100%|██████████| 217/217 [00:05<00:00, 38.88it/s]
100%|██████████| 7/7 [00:00<00:00, 92.72it/s]


Validation loss decreased (1.371490 --> 1.069517).  Saving model ...
Epoch 2: Training loss: 1.1322, Validation loss: 1.0695


100%|██████████| 217/217 [00:05<00:00, 38.55it/s]
100%|██████████| 7/7 [00:00<00:00, 101.41it/s]


Validation loss decreased (1.069517 --> 0.874238).  Saving model ...
Epoch 3: Training loss: 0.8459, Validation loss: 0.8742


100%|██████████| 217/217 [00:05<00:00, 39.17it/s]
100%|██████████| 7/7 [00:00<00:00, 83.39it/s]


Validation loss decreased (0.874238 --> 0.800087).  Saving model ...
Epoch 4: Training loss: 0.7234, Validation loss: 0.8001


100%|██████████| 217/217 [00:05<00:00, 38.66it/s]
100%|██████████| 7/7 [00:00<00:00, 87.84it/s]


Validation loss decreased (0.800087 --> 0.781565).  Saving model ...
Epoch 5: Training loss: 0.6670, Validation loss: 0.7816


100%|██████████| 217/217 [00:05<00:00, 38.41it/s]
100%|██████████| 7/7 [00:00<00:00, 89.74it/s]


Validation loss decreased (0.781565 --> 0.770248).  Saving model ...
Epoch 6: Training loss: 0.6318, Validation loss: 0.7702


100%|██████████| 217/217 [00:05<00:00, 38.34it/s]
100%|██████████| 7/7 [00:00<00:00, 96.34it/s]


Validation loss decreased (0.770248 --> 0.755334).  Saving model ...
Epoch 7: Training loss: 0.6027, Validation loss: 0.7553


100%|██████████| 217/217 [00:05<00:00, 38.01it/s]
100%|██████████| 7/7 [00:00<00:00, 99.15it/s]


Validation loss decreased (0.755334 --> 0.745823).  Saving model ...
Epoch 8: Training loss: 0.5749, Validation loss: 0.7458


100%|██████████| 217/217 [00:05<00:00, 38.40it/s]
100%|██████████| 7/7 [00:00<00:00, 87.37it/s]


EarlyStopping counter: 1 out of 3
Epoch 9: Training loss: 0.5519, Validation loss: 0.7685


100%|██████████| 217/217 [00:05<00:00, 39.25it/s]
100%|██████████| 7/7 [00:00<00:00, 102.17it/s]


EarlyStopping counter: 2 out of 3
Epoch 10: Training loss: 0.5312, Validation loss: 0.7617


100%|██████████| 217/217 [00:05<00:00, 39.55it/s]
100%|██████████| 7/7 [00:00<00:00, 94.19it/s]

EarlyStopping counter: 3 out of 3
Early stopping
------------------------------------------------------------





In [311]:
with open('shan_ed_{}.pth'.format(DIMS), 'rb') as f:
    model.load_state_dict(torch.load(f))

In [312]:
get_function(tm_f.retrieval_recall, preferences[0], target_items[0])

tensor(0.2000, device='cuda:0')

In [313]:
net_loss_test = 0


net_loss = 0
prec_5_epoch = 0
rec_5_epoch = 0
prec_1_epoch = 0
rec_1_epoch = 0
prec_10_epoch = 0
rec_10_epoch = 0
# Test
for users, pre_sessions_items, cur_session_items, target_items, _, _, _ in tqdm(iter(testloader)):
    with torch.no_grad():

        
        users = users.to(device)
        pre_sessions_items = pre_sessions_items.to(device)
        cur_session_items = cur_session_items.to(device)
        target_items = target_items.to(device)

        preferences = model(users, pre_sessions_items, cur_session_items)
        loss_fn(preferences, target_items, BOOTSTRAPS)

        net_loss_test+=loss.item()
        
        
        prec_5_epoch+= get_batch_func(tm_f.retrieval_precision, preferences, target_items, 5, averaging = None)
        rec_5_epoch+= get_batch_func(tm_f.retrieval_recall, preferences, target_items, 5, averaging = None)
        prec_1_epoch+= get_batch_func(tm_f.retrieval_precision, preferences, target_items, 1, averaging = None)
        rec_1_epoch+= get_batch_func(tm_f.retrieval_recall, preferences, target_items, 1, averaging = None)
        prec_10_epoch+= get_batch_func(tm_f.retrieval_precision, preferences, target_items, 10, averaging = None)
        rec_10_epoch+= get_batch_func(tm_f.retrieval_recall, preferences, target_items, 10, averaging = None)
        

prec_10_epoch/=len_test
prec_5_epoch/=len_test
prec_1_epoch/=len_test
rec_10_epoch/=len_test
rec_5_epoch/=len_test
rec_1_epoch/=len_test

        # train_prec_5.append(prec_5_epoch)
        # train_prec_1.append(prec_1_epoch)
        # train_prec_10.append(prec_10_epoch)

        # train_rec_5.append(rec_5_epoch)
        # train_rec_1.append(rec_1_epoch)
        # train_rec_10.append(rec_10_epoch)

print("Precision@1 = {}, @5 = {}, @10 = {}".format(prec_1_epoch, prec_5_epoch, prec_10_epoch))
print("Recall@1 = {}, @5 = {}, @10 = {}".format(rec_1_epoch, rec_5_epoch, rec_10_epoch))
net_loss_test = net_loss_test/len(testloader)

print(net_loss_test)

100%|██████████| 31/31 [00:06<00:00,  4.69it/s]

Precision@1 = 0.42950819672131146, @5 = 0.11921311659891097, @10 = 0.06504918139000408
Recall@1 = 0.08518605391510198, @5 = 0.11749518792648785, @10 = 0.12675123802951124
0.7085349559783936





In [314]:
preferences.shape

torch.Size([25, 41053])

In [315]:
target_items

tensor([[40525, 34039, 34039, 40524, 40525, 40525, 40524, 34039, 40526, 40525],
        [40528, 32767, 40527, 28788, 40528, 37824, 40527, 37824, 28788, 40528],
        [ 9066, 15051, 40549, 12786, 40549,  3148, 14566,  3897,  9074, 14571],
        [ 6019, 26255, 26531, 23989, 34764, 34767, 27766, 36897, 36897, 27766],
        [40593, 35318, 40593, 40593, 40592, 40592, 35325, 40592, 35325, 40593],
        [40592, 35318, 40592, 35318, 40592, 35318, 40592, 40593, 35318, 35320],
        [40600, 40601, 40602, 40601, 40597, 40599, 40602, 37844, 40600, 37844],
        [28546, 10633, 20985, 28546, 10633, 13451, 40675, 13451, 28546, 23043],
        [32980, 28873, 35983,  6055, 25244,  3287, 40677, 13335, 35983, 32980],
        [40678, 40679, 40682, 23786, 29511, 40678, 25263, 35983, 14630, 40682],
        [36094, 35582, 35581, 40693, 35580, 40693, 35578, 36094, 35580, 35581],
        [18143, 16449, 16446,  7137, 35645, 15435, 15434,  7137,  7138, 28037],
        [15850,  3217, 40733, 10451, 158

In [316]:
preferences

tensor([[-1.0675e-02, -2.5502e+00, -7.5968e-02,  ..., -2.9945e+00,
         -1.5199e-01, -5.5077e-01],
        [ 4.0324e-02, -4.4944e-01,  1.3139e-01,  ...,  1.2590e+00,
          5.7969e-01,  3.9998e-01],
        [ 1.0541e-02, -1.0599e+00, -7.9814e-01,  ...,  1.8736e-01,
          2.3158e+00,  6.9104e-01],
        ...,
        [-3.0522e-02,  9.9574e-01,  1.5569e-01,  ...,  3.5698e-03,
         -1.4902e+00, -1.7064e+00],
        [ 2.6391e-02,  2.3572e+00, -7.8037e-01,  ...,  4.2007e+00,
          6.0871e+00,  1.9057e+00],
        [-9.6658e-03, -3.4209e-01,  5.9871e+00,  ..., -1.9429e+00,
         -1.0512e+00,  1.1403e+00]], device='cuda:0')

In [317]:
torch.topk(preferences, k = 10, dim = -1).indices

tensor([[40525, 40524, 11283,  6200, 11399, 29969, 15215,  4039,  3120, 17432],
        [40527, 38350, 40595,  5572, 27551, 21674, 25692,  1789, 31363, 26545],
        [40549, 26084, 40944, 26146,  5266, 34855,   681, 36688, 14571,  7189],
        [21189, 22284, 34746, 37933, 25531, 11962, 17992, 35475, 11675, 29253],
        [40593, 35318, 10191,  4058,  5131, 23654, 22580, 40592, 14345, 39813],
        [35318, 23654, 11671, 40593, 38518, 23975, 27562, 22580,  8977, 36135],
        [24176, 37844, 10940,  1020, 33201, 40597, 29431, 32216, 29009, 33319],
        [13451, 35617,  4838, 33653, 15221, 29558, 37328, 17555,  1553, 32149],
        [32980, 31378,  7060,  6836,  5801, 28460, 23787, 34825, 10474, 26081],
        [40679, 28879, 33764,  5371, 40850, 28006, 35104, 13698, 22249, 16810],
        [35579, 22673, 26937, 12693, 24683, 13958, 12391,  1925, 33315,  4836],
        [16450, 40702, 34389, 35772,  2348, 37688, 27146, 17365, 34439,  2093],
        [15850,  8154,  2255, 29204, 238

In [318]:
target_items

tensor([[40525, 34039, 34039, 40524, 40525, 40525, 40524, 34039, 40526, 40525],
        [40528, 32767, 40527, 28788, 40528, 37824, 40527, 37824, 28788, 40528],
        [ 9066, 15051, 40549, 12786, 40549,  3148, 14566,  3897,  9074, 14571],
        [ 6019, 26255, 26531, 23989, 34764, 34767, 27766, 36897, 36897, 27766],
        [40593, 35318, 40593, 40593, 40592, 40592, 35325, 40592, 35325, 40593],
        [40592, 35318, 40592, 35318, 40592, 35318, 40592, 40593, 35318, 35320],
        [40600, 40601, 40602, 40601, 40597, 40599, 40602, 37844, 40600, 37844],
        [28546, 10633, 20985, 28546, 10633, 13451, 40675, 13451, 28546, 23043],
        [32980, 28873, 35983,  6055, 25244,  3287, 40677, 13335, 35983, 32980],
        [40678, 40679, 40682, 23786, 29511, 40678, 25263, 35983, 14630, 40682],
        [36094, 35582, 35581, 40693, 35580, 40693, 35578, 36094, 35580, 35581],
        [18143, 16449, 16446,  7137, 35645, 15435, 15434,  7137,  7138, 28037],
        [15850,  3217, 40733, 10451, 158

In [319]:
# for users, pre_sessions_items, cur_session_items, target_items, pre_sessions_item_timestamps, cur_session_item_timestamps, target_item_timestamps in valloader:
#     pass
#     break