# Pytorch/lightning implementation of SASRec model
Implementation of **Self-attentive sequential recommendation** paper:
```
@inproceedings{kang2018self,
  title={Self-attentive sequential recommendation},
  author={Kang, Wang-Cheng and McAuley, Julian},
  booktitle={2018 IEEE International Conference on Data Mining (ICDM)},
  pages={197--206},
  year={2018},
  organization={IEEE}
}
```
Originally taken [this code](https://github.com/pmixer/SASRec.pytorchhttps://github.com/pmixer/SASRec.pytorch) and rewritten model class plus used lightning.  
This notebook serves the purpose of interactive code execution/debugging.  
Main code for multiple GPU training is [here](./SASRecMain.py) 

Author: Sergei Bazhin  
Date: 2021-DEC - JAN-2022

In [1]:
%config Completer.use_jedi = False

In [2]:
import os
import numpy as np
import torch
import pytorch_lightning as pl
import argparse
from importlib import reload
# module with datasets definition = train, validation and test
import DataHelper as DH
import SASRecModel as SASRec
SASRec = reload(SASRec)
DH = reload(DH)
import torch.optim as optim
import torch.nn.functional as F
from pytorch_lightning.callbacks import StochasticWeightAveraging

In [10]:
# setup command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='ml-1m', 
                    required=True, 
                    help="dataset to use : Beauty, ml-1m(default), Steam or Video")

parser.add_argument('--maxlen', default=50, type=int, 
                    help="truncate input sequence to last maxlen items, default 50")
parser.add_argument('--hidden_units', default=50, type=int, help="synonym for d_model") # synonym for d_model
parser.add_argument('--d_model', default=50, type=int, 
                    help="Transformer internal dimention") # same as hidden_units   
parser.add_argument('--num_blocks', default=2, type=int, help="Number of blocks in Transformer")
parser.add_argument('--num_heads', default=1, type=int, help="Number of heads in self-attention")
parser.add_argument('--dropout_rate', default=0.5, type=float, help="Dropout rate for Transformer")


parser.add_argument('--ndcg_samples', default=100, type=int, 
                    help="How many random items to pick up in hit-rate and ndcg calculation, default 100")
parser.add_argument('--top_k', default=10, type=int, 
                    help="How many items with high scores to pick for hit-rate and ndcg calculation, default 10")
parser.add_argument('--opt', default='Adam', type=str, help="Oplimizer to use: Adam(default), AdmaW, FusedAdam(requires apex library)")
parser.add_argument('--lr', default=0.001, type=float, 
                    help="learning rate, default 0.001")
parser.add_argument('--weight_decay', default=0.001, type=float, help="Weight decay for AdmaW")
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--warmup_proportion', default=0.2, type=float, help="Fraction of total optimization steps to increase learning rate from zero to max value")
# for different optimizers - regular Adam uses num_epochs and LAMB uses max_iters
parser.add_argument('--max_iters', default=10000, type=int, help="Optimization budget in update iterations")
parser.add_argument('--num_epochs', default=201, type=int, help="Number of epochs to train")
# swa parameters
parser.add_argument('--use_swa', default=False, type=bool, help="Use Stochastic Weights Ageraging algorythm")
parser.add_argument('--swa_epoch_start', default=0.8, type=float, help="Start SWA after that part of total epochs")
parser.add_argument('--swa_annealing_epochs', default=10, type=int, help="Number of epochs in the annealing phase of SWA")

# xavier init
parser.add_argument('--xavier_init', default=True, type=bool, help="Use xavier normal to init the model")

parser.add_argument('--inference_only', default=False, type=bool)
parser.add_argument('--checkpoint_path', default=None, type=str, help="Path to lightning checkpoint file")

# Torch Lightning settings
# https://pytorch-lightning.readthedocs.io/en/stable/advanced/multi_gpu.html
# Data Parallel (strategy='dp') (multiple-gpus, 1 machine)
# DistributedDataParallel (strategy='ddp') (multiple-gpus across many machines (python script based)).
# DistributedDataParallel (strategy='ddp_spawn') (multiple-gpus across many machines (spawn based)).
# DistributedDataParallel 2 (strategy='ddp2') (DP in a machine, DDP across machines).
# Horovod (strategy='horovod') (multi-machine, multi-gpu, configured at runtime)
# TPUs (tpu_cores=8|x) (tpu or TPU pod)
parser.add_argument('--strategy', default='ddp_spawn', type=str, help="Lightning parallel training strategy dp, ddp, ddp_spawn(default), ddp2, etc ")
parser.add_argument('--precision', default=16, type=int, help="Lightning precision for model data during trining 16(default) or 32")
parser.add_argument('--accelerator', default="auto", type=str, help="Lightning accelerator auto(defaut), cpu, gpu, tpu")
parser.add_argument('--devices', default="auto", type=str, 
                    help="Lightning devices to use - see https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#devices")

args = parser.parse_args( ['--dataset=ml-1m', '--maxlen=200', '--dropout_rate=0.2'])
args = vars(args)
print(*[(k,v) for (k,v) in args.items()], sep="\n")

('dataset', 'ml-1m')
('maxlen', 200)
('hidden_units', 50)
('d_model', 50)
('num_blocks', 2)
('num_heads', 1)
('dropout_rate', 0.2)
('ndcg_samples', 100)
('top_k', 10)
('opt', 'Adam')
('lr', 0.001)
('weight_decay', 0.001)
('batch_size', 128)
('warmup_proportion', 0.2)
('max_iters', 10000)
('num_epochs', 201)
('use_swa', False)
('swa_epoch_start', 0.8)
('swa_annealing_epochs', 10)
('xavier_init', True)
('inference_only', False)
('checkpoint_path', None)
('strategy', 'ddp_spawn')
('precision', 16)
('accelerator', 'auto')
('devices', 'auto')


## Perepare the data  
We have 3 datasets:  
 - for training
 - for validation
 - for testing  
 They all contain all users and the last two items in sequence are split between validation (penultimate item) and test (last item)  
 Training has all user items but without last two that falls into validation and test

**user_train** - dict with key = *userid* and value = list of all items selected in respected time order  
**user_valid** - dict with the same structure as above but with penulitimate item (just one item)  
**user_test** - same as above but with ultimate item selected  
i.e. you have user 5 with items 1, 29, 34, 15, 8 in his sequence of items there will be the below data in vars:  
```
user_train[5] = [1,29,34]  
user_valid[5] = [15]  
user_test[5] = [8]
```

In [11]:
# read dataset
dataset = DH.data_partition('ml-1m')

In [12]:
[user_train, user_valid, user_test, usernum, itemnum] = dataset

In [13]:
# batches got sliced by users, i.e. batch accumulate BATCH_SIZE user sequences of items selected/bought
BATCH_SIZE = args['batch_size']
num_batch = len(user_train) // BATCH_SIZE  # number of batches

user_train_lens = list(map(len,[v for k,v in user_train.items()]))
print(f'average sequence length: {sum(user_train_lens)/len(user_train):.1f}')

average sequence length: 163.5


In [14]:
# dataset for validation
valid_data = DH.SequenceDataValidation(user_train, user_valid, usernum, itemnum, args['maxlen'])

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=6040.0), HTML(value='')))




In [16]:
# dataset for test
test_data = DH.SequenceDataTest(user_train, user_valid, user_test, usernum, itemnum, args['maxlen'], args['ndcg_samples'])

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=6040.0), HTML(value='')))




In [17]:
#dataset for training
train_data = DH.SequenceData(user_train, usernum, itemnum)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=6040.0), HTML(value='')))




In [18]:
train_loader = torch.utils.data.DataLoader(train_data, batch_size=4,
                          shuffle=False, collate_fn = lambda x: DH.tokenize_batch(x, max_len=args['maxlen']))

### Unit-test training loader 

In [16]:
u, seq, pos, neg = next(iter(train_loader))

In [47]:
neg.shape

torch.Size([4, 200])

In [17]:
assert len(seq[0]) == args['maxlen']

In [18]:
u

[1, 2, 3, 4]

In [39]:
len(u)

4

In [41]:
# random user from batch
_u = np.random.randint(0,len(u))
_u

3

In [42]:
# train sequnce
print(seq[_u].numpy()[-10:])

[255 256 179 167 172 157 257  39 199 258]


In [43]:
# train shifted one item ahead
print(pos[_u].numpy()[-10:])

[256 179 167 172 157 257  39 199 258  29]


In [44]:
# negative sequnce
print(neg[_u].numpy()[-10:])

[2377  992 1485 2751  470 3337  547 1605 1347 1024]


### Unit-test validation and test data 

In [24]:
ii = np.random.randint(1,usernum+1)

In [25]:
print("{0:}{1:>40}".format("\n","Validation data \n"))
print("{0:<30}".format("Main sequence "),":",*valid_data.seq[ii,-10:].numpy()) 
print("{0:<30}".format("Validation sequene "),":", *valid_data.valid[ii,:10].numpy())


                       Validation data 

Main sequence                  : 1100 45 985 1862 370 65 642 66 48 501
Validation sequene             : 639 2288 2136 2025 874 3413 2600 2114 2417 260


In [26]:
print("{0:}{1:>40}".format("\n","Test data \n"))
print("{0:<30}".format("Main sequence "),":",*test_data.seq[ii, -11:].numpy()) 
print("{0:<30}".format("Validation sequene "),":", *test_data.valid[ii,:10].numpy())


                             Test data 

Main sequence                  : 1100 45 985 1862 370 65 642 66 48 501 639
Validation sequene             : 710 3359 420 35 2368 3280 205 2602 1460 2534


In [70]:
print(*user_train[ii+1][-10:], *user_valid[ii+1], *user_test[ii+1])

2529 645 812 592 963 1035 1038 837 816 1044 1047 3018


In [27]:
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=4, shuffle=False)

### Unit-test metrics calculation

In [107]:
from copy import deepcopy
[train, valid, test, usernum, itemnum] = deepcopy(dataset)

NDCG = 0.0
HT = 0.0

In [108]:
# list of users in batch
u = [122, 144]

# get validation items
valid_seq = torch.as_tensor([valid[_u] for _u in u], dtype=torch.int) # (batch x 1)

In [110]:
# make a matrix from train sequence (batch, args['maxlen'] - 1)
final_seq = torch.zeros((len(u),args['maxlen'] - 1), dtype=torch.int)
for ii,_u in enumerate(u):
    idx = min(args['maxlen'] - 1, len(train[_u]))
    final_seq[ii, -idx:] = torch.as_tensor(train[_u][-idx:])

# final seq (batch, args['maxlen'])
final_seq = torch.cat((final_seq, valid_seq), dim=1)

In [111]:
# making a test sequence with one element from test and the rest random
test_seq = torch.zeros((len(u), 101), dtype=torch.int)
# all elements that are in train plus padding zero
for ii, (_u, seq) in enumerate(zip(u, final_seq)):
    items_not_in_seq = np.array(list(set(range(1,itemnum+1)) - set(seq.numpy()))) # random stuff not in seq
    test_seq[ii,0] = test[_u][0] # get true next element from test set
    test_seq[ii,1:] = torch.from_numpy(items_not_in_seq[np.random.randint(0, len(items_not_in_seq), 100)]) # fill the rest with random stuff

In [112]:
test_seq

tensor([[ 919, 2871, 2688, 3360,  480, 2031, 2828,  524,  598, 1611, 2785, 3042,
         1532, 1063, 1467, 2040, 2763, 2916, 1823,  668, 1124, 1976, 1673,  525,
         1518, 2615, 2776, 1437, 1079, 1978,  406, 2954,  897, 2514, 3390,  417,
         1519, 2713, 2332, 1517,  321, 2499, 1283, 2517, 2688,  185, 2233,  453,
         1756, 2679,  456,  669, 2025, 1473, 2226,  812, 2107, 2198,  643, 2980,
         2694,    8, 1778, 2027,  322, 1815, 1284, 2881, 1716,  243, 2384, 2340,
          617,  906, 2236, 2315, 1637, 3184, 2395,  610, 2118, 2334, 3355, 1912,
         1521,  604, 2003, 1173, 1519, 2207, 2051, 2369, 1029,  109, 2086,  228,
          106,  865,  520,  725, 2287],
        [ 672,  770, 3298,  668,   94, 3290,   20, 2532,  766, 3186, 1625, 2100,
         2251,  709,  391, 3336, 1080,  137, 2104, 1768,  476,  381, 1283, 1948,
         2409,  414, 2309, 1415, 1231, 2427, 1998, 2464,  899, 3323,  971,  134,
         1400, 2559,  527, 2520, 3211,  926,  428,  734, 3321, 1288, 

In [257]:
with torch.no_grad():
    log_feats = model.log2feats(final_seq) # shape (batch, seq_len, hidden_dim) = (1x200x50)

In [260]:
final_feat = log_feats[:, -1, :] # last hidden state/embedding
final_feat, final_feat.shape

(tensor([[ 0.6619,  0.6596, -0.1745, -0.2133, -0.7123,  0.4204,  0.4625, -0.6209,
          -0.4138,  0.0817, -0.1496,  0.6792,  0.5925,  0.7006, -0.6502, -0.5306,
           0.2231, -0.6313,  0.7336,  0.5250,  0.6301,  0.3757,  0.5017, -0.7404,
           0.4616,  0.5239, -0.6687, -0.6023, -0.6875, -0.5923,  0.6337, -0.1855,
           0.6248,  0.4989, -0.4843, -0.3989,  0.7503, -0.4861,  0.3672, -0.7682,
          -0.7489,  0.0185,  0.2974,  0.7162,  0.3279, -0.6246, -0.8353,  0.4458,
          -0.9165,  0.6383],
         [ 0.6621,  0.6496, -0.1759, -0.2299, -0.7259,  0.4308,  0.4675, -0.5906,
          -0.4444,  0.0792, -0.1448,  0.6820,  0.5929,  0.7078, -0.6550, -0.5333,
           0.2334, -0.6315,  0.7423,  0.5466,  0.6331,  0.3593,  0.5242, -0.7519,
           0.4644,  0.5121, -0.6704, -0.5962, -0.6760, -0.6022,  0.6266, -0.1775,
           0.6117,  0.4997, -0.4859, -0.4087,  0.7321, -0.4957,  0.3658, -0.7753,
          -0.7522,  0.0150,  0.3030,  0.7147,  0.3252, -0.6060, -0.83

In [261]:
with torch.no_grad():
    item_embs = model.item_emb(test_seq) # shape torch.Size([1, 101, 50]) 

In [262]:
item_embs.shape, final_feat.unsqueeze(-1).shape

(torch.Size([2, 101, 50]), torch.Size([2, 50, 1]))

In [264]:
logits = torch.bmm(item_embs, final_feat.unsqueeze(-1))
logits.shape

torch.Size([2, 101, 1])

In [266]:
predictions = -logits.squeeze()
predictions.shape

torch.Size([2, 101])

In [277]:
_, indices = torch.topk(predictions,15,dim=1, largest=False)
indices

tensor([[43, 98, 22, 40, 38, 92, 21, 75, 48, 68,  0,  8, 57, 89, 35],
        [30, 36, 78,  7, 51, 97, 31, 72, 64, 24,  0, 28, 99, 22, 86]])

In [278]:
_, indices = torch.where(indices == 0)

In [293]:
indices

tensor([10, 10])

In [294]:
hits = torch.as_tensor(indices < 11, dtype=torch.int)
hits

tensor([1, 1], dtype=torch.int32)

In [295]:
hits/torch.log2(indices+2)

tensor([0.2789, 0.2789])

## Assemble a model  

**From `Attention is all you need` paper:**  
We apply dropout to the **output of each sub-layer**, before it is added to the sub-layer input and normalized. 
In addition, we apply dropout to the **sums of the embeddings and the positional encodings** in both the encoder and decoder stacks. 
For the base model, we use a rate of P drop = 0.1.  

The encoder is composed of a stack of N = 6 identical layers. Each layer has two `sub-layers`.  
The first is a multi-head self-attention mechanism, and the second is a simple, position-wise fully connected feed-forward network.   
We employ a residual connection around each of the two sub-layers, followed by layer normalization.  
That is, the output of each sub-layer is `LayerNorm(x + Sublayer(x))`, where Sublayer(x) is the function implemented by the sub-layer itself.  
To facilitate these residual connections, all sub-layers in the model, as well as the embedding layers, produce outputs of dimension dmodel = 512.
---

However, when the network goes deeper, several problems become exacerbated: 
1) the increased model capacity leads to overfitting; 
2) the training process becomes unstable (due to vanishing gradients etc.); and 
3) models with more parameters often require more training time.  

We perform the following operations to alleviate these problems: `g(x) = x + Dropout(g(LayerNorm(x)))`,  
where `g(x)` represents the self attention layer or the feed-forward network.  
That is to say, for layer `g` in each block, we apply **layer normalization**on the input `x` before feeding into `g`,   
apply **dropout** on `g’s` output, and add the input `x` to the final output.   
The **dropout** rate of turning off neurons is 0.2 for **MovieLens-1m** and 0.5 for the other three datasets due to their sparsity.  
We also apply a **dropout** layer on the embedding `E`.

The optimizer is the Adam optimizer, the `learning rate` is set to `0.001`, and the `batch size` is `128`.

**Shared Item Embedding**  
To reduce the model size and alleviate overfitting, we consider a single item embedding M:  
`ri,t=Ft Mi`. 

<img src="./sasrec-loss-function.png" alt="Loss function" style="width: 600px;"/>

Note that we ignore the terms where `ot = <pad>`.

### Ablation study
Remove PE (Positional Embedding): Without the positional embedding P, the attention weight on each item depends only on item embeddings. That is to say, the model makes recommendations based on users’ past actions, but their order doesn’t matter. This variant might be suitable for sparse datasets, where user sequences are typically short. This variant performs better then the default model on the sparsest dataset (Beauty), but worse on other denser datasets.

## Model assembly

In [92]:
SASRec = reload(SASRec)
from torch.nn import MultiheadAttention, LayerNorm, Dropout, Conv1d, Embedding, BCEWithLogitsLoss
from SASRecModel import PointWiseFF, SASRecEncoderLayer, PositinalEncoder, SASRecEncoder

### Building blocks for Encoder:
- **encoder** layer 
- **positional encoder**
- **point-wise** feed-forward

### Unit-test building blocks

In [20]:
input_x = torch.randn(BATCH_SIZE, args['maxlen'], args['d_model'])

In [21]:
emb_test = Embedding(args['maxlen'], embedding_dim=args['hidden_units'])

In [22]:
pe_test = Embedding(args['maxlen'], embedding_dim=args['d_model'])

In [23]:
pe_for_one_sequence = pe_test(torch.arange(0,input_x.shape[1], dtype=torch.int)) # get a single positional embedding

In [24]:
torch.tile(pe_for_one_sequence, (BATCH_SIZE,1,1)).shape

torch.Size([128, 200, 50])

In [25]:
pe_for_one_sequence.shape, input_x.shape

(torch.Size([200, 50]), torch.Size([128, 200, 50]))

In [26]:
pe_test = PositinalEncoder(args['maxlen'], args['d_model'])
pe_test(input_x).shape

torch.Size([128, 200, 50])

In [27]:
encoder_test = SASRecEncoderLayer(itemnum, **args)

In [28]:
encoder_test(input_x).shape

torch.Size([128, 200, 50])

### Unit-test final model

In [29]:
sas_rec_test = SASRecEncoder(itemnum, **args)

In [30]:
minib = next(iter(train_loader))

In [31]:
# need to comment out logging loss to run this cell
sas_rec_test.training_step(minib, 1)

  "You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet."


{'loss': tensor(6.1220, grad_fn=<AddBackward0>)}

In [412]:
u, seq, pos, neg = next(iter(train_loader))
item_emb = sas_rec_test.forward(seq) # get embeddings from transformer

In [33]:
sas_rec_test(seq).shape, seq.shape

(torch.Size([4, 200, 50]), torch.Size([4, 200]))

In [319]:
# pos_scores = sas_rec_test.compute_relevance_scores(item_emb, pos) # scores for positive sequence
# neg_scores = sas_rec_test.compute_relevance_scores(item_emb, neg) # scores for negative sequence

# pos_labels = torch.ones(pos_scores.shape)
# neg_labels = torch.zeros(neg_scores.shape)

# indices = torch.where(pos!=0) # exclude padding from loss computation

# loss = sas_rec_test.loss(pos_scores[indices], pos_labels[indices]) # loss for positive sequence
# loss += self.loss(neg_scores[indices], neg_labels[indices]) # loss for negative sequence

## Declare data loaders

In [34]:
print(f"\nBatch size is - {args['batch_size']}\n")


Batch size is - 128



In [35]:
val_loader = torch.utils.data.DataLoader(dataset=DH.SequenceDataValidation(user_train, user_valid, usernum, itemnum, args['maxlen']), 
                                         batch_size=args['batch_size'], shuffle=True, 
                                         drop_last=True)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=6040.0), HTML(value='')))




In [36]:
test_loader = torch.utils.data.DataLoader(dataset=DH.SequenceDataTest(user_train, user_valid, user_test, usernum, itemnum, args['maxlen'], args['ndcg_samples']), 
                                          batch_size=args['batch_size'], shuffle=True, 
                                          drop_last=True)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=6040.0), HTML(value='')))




In [37]:
train_loader = torch.utils.data.DataLoader(dataset=DH.SequenceData(user_train, usernum, itemnum), 
                                           batch_size=args['batch_size'],
                                           shuffle=True, 
                                           collate_fn=DH.tokenize_batch)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=6040.0), HTML(value='')))




In [147]:
# seq, val = next(iter(val_loader))
# seq[0], val[0]

(tensor([ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  2,  3,
          4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
         22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
         40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57,
         58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75,
         76, 77], dtype=torch.int32),
 tensor([  78,  282, 2760, 1584,  642, 3178, 1761, 2186,  593, 2509, 16

In [38]:
args

{'dataset': 'ml-1m',
 'maxlen': 200,
 'hidden_units': 50,
 'd_model': 50,
 'num_blocks': 2,
 'num_heads': 1,
 'dropout_rate': 0.2,
 'ndcg_samples': 100,
 'top_k': 10,
 'opt': 'Adam',
 'lr': 0.001,
 'weight_decay': 0.001,
 'batch_size': 128,
 'warmup_proportion': 0.2,
 'max_iters': 10000,
 'num_epochs': 201,
 'use_swa': False,
 'swa_epoch_start': 0.8,
 'swa_annealing_epochs': 10,
 'xavier_init': True,
 'inference_only': False,
 'checkpoint_path': None,
 'strategy': 'ddp_spawn',
 'precision': 16,
 'accelerator': 'auto',
 'devices': 'auto'}

In [41]:
model = SASRecEncoder(itemnum, **args)
# model.load_state_dict(torch.load("bazman_sasrec.pt"))

In [42]:
# weight initialization
for name, param in model.named_parameters():
    try:
        torch.nn.init.xavier_normal_(param.data)
        print(f"{name:<40} sucess")
    except:
        print(f"{name:<40} failure")

ie.weight                                sucess
pe.pe.weight                             sucess
enc_stack.0.norm_1.weight                failure
enc_stack.0.norm_1.bias                  failure
enc_stack.0.norm_2.weight                failure
enc_stack.0.norm_2.bias                  failure
enc_stack.0.attn.in_proj_weight          sucess
enc_stack.0.attn.in_proj_bias            failure
enc_stack.0.attn.out_proj.weight         sucess
enc_stack.0.attn.out_proj.bias           failure
enc_stack.0.ff.conv1.weight              sucess
enc_stack.0.ff.conv1.bias                failure
enc_stack.0.ff.conv2.weight              sucess
enc_stack.0.ff.conv2.bias                failure
enc_stack.1.norm_1.weight                failure
enc_stack.1.norm_1.bias                  failure
enc_stack.1.norm_2.weight                failure
enc_stack.1.norm_2.bias                  failure
enc_stack.1.attn.in_proj_weight          sucess
enc_stack.1.attn.in_proj_bias            failure
enc_stack.1.attn.out_proj.w

In [43]:
# save checkpoints
from pytorch_lightning.callbacks import ModelCheckpoint
checkpoint_callback = ModelCheckpoint(monitor="hr_val", mode='max')

In [44]:
# run tensorboard before the script launch
# tensorboard --logdir ./lightning_logs/ --host 0.0.0.0

# https://pytorch-lightning.readthedocs.io/en/stable/advanced/multi_gpu.html
# https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html
# strategy, accelerator, devices
trainer = pl.Trainer(gpus=[0], 
                     auto_select_gpus=False, 
                     max_epochs=300,
                     reload_dataloaders_every_n_epochs=1,
                     val_check_interval=1.0,
                     callbacks=[checkpoint_callback],
                     log_every_n_steps= int(len(train_data)/args['batch_size']/3), # log 4 times per epoch
                     # limit_val_batches=0, How much of validation dataset to check. Useful when debugging or testing something that happens at the end of an epoch.
                     num_sanity_val_steps=10, 
                     precision=16)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [45]:
trainer.fit(model, train_loader, val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type              | Params
--------------------------------------------------
0 | ie          | Embedding         | 170 K 
1 | pe          | PositinalEncoder  | 10.0 K
2 | emb_dropout | Dropout           | 0     
3 | enc_stack   | Sequential        | 31.0 K
4 | final_norm  | LayerNorm         | 100   
5 | loss        | BCEWithLogitsLoss | 0     
--------------------------------------------------
211 K     Trainable params
0         Non-trainable params
211 K     Total params
0.424     Total estimated model params size (MB)


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

  f"The dataloader, {name}, does not have many workers which may be a bottleneck."
  f"The dataloader, {name}, does not have many workers which may be a bottleneck."


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…



HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [203]:
torch.save(model.state_dict(), f"bazman_sasrec_{trainer.logger.version}.pt")

In [46]:
trainer.test(model, test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
  f"The dataloader, {name}, does not have many workers which may be a bottleneck."


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'hr_test': 0.6969747543334961, 'ndcg_test': 0.43651801347732544}
--------------------------------------------------------------------------------



[{'ndcg_test': 0.43651801347732544, 'hr_test': 0.6969747543334961}]

## Playground zone

In [83]:
emb_test = torch.zeros((1,1), dtype=torch.int)
emb_test

tensor([[0]], dtype=torch.int32)

In [85]:
model.pe(emb_test)

tensor([[[-0.1040,  0.7982, -0.4536,  0.3354,  0.7432,  0.4915,  1.7926,
          -0.3678, -1.4826, -0.1286,  0.2442, -0.8088, -0.1810, -2.6418,
           0.4637, -0.8777,  2.1836, -1.3642, -1.2097,  2.2583, -1.5101,
          -0.7968, -0.3392, -0.3674, -1.6548, -0.1962, -0.8084, -1.4440,
           1.4914,  1.4214, -0.9880, -1.0189,  1.4612, -1.5223,  1.2919,
          -0.0792,  0.4893,  1.8704,  0.7675, -0.6249,  0.4509, -0.3465,
          -0.5978, -1.1212,  0.7557, -1.6298,  1.0226, -0.5319, -1.1509,
          -2.3767]]], grad_fn=<EmbeddingBackward>)

In [92]:
dcg = lambda x: 1/np.log2(x+1)

In [95]:
dcg(10), dcg(100), dcg(200), dcg(300)

(0.2890648263178879,
 0.15019048322368797,
 0.13070098600339125,
 0.12145326590959868)

In [93]:
best_model = SASRecEncoder.load_from_checkpoint("sasrec.ckpt")

In [94]:
trainer.validate(best_model, dataloaders=val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
  f"The dataloader, {name}, does not have many workers which may be a bottleneck."


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

--------------------------------------------------------------------------------
DATALOADER:0 VALIDATE RESULTS
{'hr_val': 0.827293872833252, 'ndcg_val': 0.5855227708816528}
--------------------------------------------------------------------------------


[{'ndcg_val': 0.5855227708816528, 'hr_val': 0.827293872833252}]

### Export to ONNX 

In [415]:
train_loader = torch.utils.data.DataLoader(train_data, batch_size=1,
                          shuffle=False, collate_fn = lambda x: DH.tokenize_batch(x, max_len=args['maxlen']))
u, seq, pos, neg = next(iter(train_loader))

In [416]:
best_model.forward(seq)

tensor([[[-0.1139, -1.0945, -0.3163,  ...,  0.3335, -0.7704,  0.3022],
         [-0.7676, -1.0056, -0.2909,  ..., -0.0067, -1.3996,  0.6735],
         [-0.7053, -1.0908, -0.8257,  ...,  0.0653, -0.6168,  0.6413],
         ...,
         [-0.8415,  2.4567,  1.3969,  ..., -0.2030, -1.7757, -0.4156],
         [-1.2199,  1.9336,  0.7832,  ...,  0.3152, -1.1556, -0.4826],
         [-1.2295,  2.9491,  1.5308,  ..., -0.2035, -0.4542,  0.2170]]],
       grad_fn=<NativeLayerNormBackward>)

In [417]:
torch.onnx.export(best_model, seq, "sasrec.onnx", verbose=True)

  "Passing an tensor of different rank in execution will be incorrect.")


graph(%input.1 : Int(1, 200, strides=[200, 1], requires_grad=0, device=cpu),
      %ie.weight : Float(3417, 50, strides=[50, 1], requires_grad=1, device=cpu),
      %pe.pe.weight : Float(200, 50, strides=[50, 1], requires_grad=1, device=cpu),
      %enc_stack.0.norm_1.weight : Float(50, strides=[1], requires_grad=1, device=cpu),
      %enc_stack.0.norm_1.bias : Float(50, strides=[1], requires_grad=1, device=cpu),
      %enc_stack.0.norm_2.weight : Float(50, strides=[1], requires_grad=1, device=cpu),
      %enc_stack.0.norm_2.bias : Float(50, strides=[1], requires_grad=1, device=cpu),
      %enc_stack.0.attn.in_proj_weight : Float(150, 50, strides=[50, 1], requires_grad=1, device=cpu),
      %enc_stack.0.attn.in_proj_bias : Float(150, strides=[1], requires_grad=1, device=cpu),
      %enc_stack.0.attn.out_proj.bias : Float(50, strides=[1], requires_grad=1, device=cpu),
      %enc_stack.0.ff.conv1.weight : Float(50, 50, 1, strides=[50, 1, 1], requires_grad=1, device=cpu),
      %enc_stack

In [68]:
input_x = torch.randn(BATCH_SIZE, args['maxlen'], args['d_model'])

In [69]:
pe_test = Embedding(args['maxlen'], embedding_dim=args['d_model'])

In [70]:
pe_for_one_sequence = pe_test(torch.arange(0,input_x.shape[1], dtype=torch.int)) # get a single positional embedding

In [113]:
torch.linalg.matrix_norm(((next(best_model.pe.parameters())).data))

tensor(22.0447)

In [104]:
v.data[0,0] = 0.7

In [122]:
(v.data<-1).sum()

tensor(0)

In [108]:
torch.linalg.matrix_norm(v.data)

tensor(22.0447)

In [178]:
_u = 45
all_items_set = set(range(itemnum))
user_items_set = set(user_train[_u])  - set(user_valid[_u]) # union of what user has selected up to valid sequence

### Reading raw data from movie lens 1M
https://grouplens.org/datasets/movielens/

In [375]:
import pandas as pd
import time
from datetime import datetime
display_settings = {
    'display.max_columns': 100,
    'display.max_rows':100,
    'display.expand_frame_repr': True,  # Wrap to multiple pages
    'display.precision': 2,
    'display.show_dimensions': True,
    'display.float_format': '{:,.2f}'.format,
    'io.hdf.default_format':'table' # appendable hp5 table
}

for op, value in display_settings.items():
    pd.set_option("{}".format(op), value)

In [401]:
# choose wich dataset to read - 20M or 1M
ratings = pd.read_csv('data/ratings-20m.csv', dtype={'userId':np.int32, 'movieId':np.int32, 'rating':np.float32})
# ratings = pd.read_csv('data/ratings-1m.csv', delimiter="::", engine='python',
#                       names=['userId','movieId','rating','timestamp'], 
#                       dtype={'userId':np.int32, 'movieId':np.int32, 'rating':np.float32})

In [402]:
ratings['timestamp'] = ratings['timestamp'].apply(lambda x: datetime.fromtimestamp(x))
ratings = ratings.sort_values(by=['userId','timestamp'])

In [403]:
ratings.head()

Unnamed: 0,userId,movieId,rating,timestamp
20,1,924,3.5,2004-09-10 03:06:38
19,1,919,3.5,2004-09-10 03:07:01
86,1,2683,3.5,2004-09-10 03:07:30
61,1,1584,3.5,2004-09-10 03:07:36
23,1,1079,4.0,2004-09-10 03:07:45


In [404]:
# filter out movies with less than 5 ragings
movie_cnt = ratings.movieId.value_counts()
ratings = ratings[ratings.movieId.isin(movie_cnt[movie_cnt>=5].index)]

In [405]:
len(set(ratings.movieId.values))

18345

In [406]:
# make a mapping so movie ids are sequential
movieId2seq = dict() # original movieId -> sequential id
seq2movieId = np.array(list(set(ratings.movieId.values))) # sequential id to original movie mapping
for i,movie_id in enumerate(seq2movieId):
    movieId2seq[movie_id] = i+1 
# change ids to sequential
ratings['movieId'] = ratings['movieId'].map(movieId2seq)

In [407]:
ratings[['userId','movieId']].to_csv("data/ml-20m.txt",sep=" ", header=False, index=False)

In [409]:
!ls -lh data/ml*

-rw-r--r-- 1 testuser testuser 9,1M янв 28 10:02 data/ml-1m_manual.txt
-rw-r--r-- 1 testuser testuser 8,7M дек 15 17:12 data/ml-1m.txt
-rw-r--r-- 1 testuser testuser 209M янв 28 10:10 data/ml-20m.txt


In [408]:
!head data/ml-20m.txt

1 924
1 919
1 2683
1 1584
1 1079
1 653
1 2959
1 337
1 1304
1 3996


In [269]:
ratings[ratings.userId.isin(ratings[ratings.userId.isin(ratings[ratings.userId.isin(ratings[ratings.userId.isin(ratings[ratings.userId.isin(ratings[ratings.userId.\
                                    isin(ratings[ratings['movieId']==1].userId.values)].\
                            query('movieId==2').userId.values)].query('movieId==3').userId.values)].\
                            query('movieId==4').userId.values)].query('movieId==5').userId.values)].query('movieId==6').userId.values)].query('movieId==7')

Unnamed: 0,userId,movieId,rating,timestamp
516235,3189,7,3.0,"(2000, 9, 25, 8, 15, 16, 0, 269, 0)"
549613,3391,7,3.0,"(2003, 2, 25, 16, 51, 22, 1, 56, 0)"
691512,4140,7,3.0,"(2000, 8, 4, 1, 6, 19, 4, 217, 0)"
