# Pytorch implementation of SASRec model  
Implementation of 
```
@inproceedings{kang2018self,
  title={Self-attentive sequential recommendation},
  author={Kang, Wang-Cheng and McAuley, Julian},
  booktitle={2018 IEEE International Conference on Data Mining (ICDM)},
  pages={197--206},
  year={2018},
  organization={IEEE}
}
```
Originally taken [this code](https://github.com/pmixer/SASRec.pytorchhttps://github.com/pmixer/SASRec.pytorch) and rewritten

Author: bazman  
Date: 2021-DEC - JAN-2022

In [1]:
%config Completer.use_jedi = False

In [2]:
import os
import numpy as np
import torch
import pytorch_lightning as pl
import argparse
from importlib import reload
# module with datasets definition = train, validation and test
import DataHelper as DH
import SASRecModel as SASRec
SASRec = reload(SASRec)
DH = reload(DH)
import torch.optim as optim
import torch.nn.functional as F

In [3]:
# setup command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', required=True)
parser.add_argument('--train_dir', required=True)
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--lr', default=0.001, type=float)
parser.add_argument('--maxlen', default=50, type=int)

parser.add_argument('--hidden_units', default=50, type=int) # synonym for d_model
parser.add_argument('--d_model', default=50, type=int) # same as hidden_units   

parser.add_argument('--num_blocks', default=2, type=int)
parser.add_argument('--num_epochs', default=201, type=int)
parser.add_argument('--num_heads', default=1, type=int)
parser.add_argument('--dropout_rate', default=0.5, type=float)
parser.add_argument('--l2_emb', default=0.0, type=float)
parser.add_argument('--device', default='cpu', type=str)
parser.add_argument('--inference_only', default=False, type=bool)
parser.add_argument('--state_dict_path', default=None, type=str)
args = parser.parse_args( ['--dataset=ml-1m', '--train_dir=default', '--maxlen=200', '--dropout_rate=0.2', '--device=cuda'])
args = vars(args)
print(*[(k,v) for (k,v) in args.items()], sep="\n")

('dataset', 'ml-1m')
('train_dir', 'default')
('batch_size', 128)
('lr', 0.001)
('maxlen', 200)
('hidden_units', 50)
('d_model', 50)
('num_blocks', 2)
('num_epochs', 201)
('num_heads', 1)
('dropout_rate', 0.2)
('l2_emb', 0.0)
('device', 'cuda')
('inference_only', False)
('state_dict_path', None)


## Perepare the data  
We have 3 datasets:  
 - for training
 - for validation
 - for testing  
 They all contain all users and the last two items in sequence are split between validation (penultimate item) and test (last item)  
 Training has all user items but without last two that falls into validation and test

**user_train** - dict with key = *userid* and value = list of all items selected in respected time order  
**user_valid** - dict with the same structure as above but with penulitimate item (just one item)  
**user_test** - same as above but with ultimate item selected  
i.e. you have user 5 with items 1, 29, 34, 15, 8 in his sequence of items there will be the below data in vars:  
```
user_train[5] = [1,29,34]  
user_valid[5] = [15]  
user_test[5] = [8]
```

In [4]:
# read dataset
dataset = DH.data_partition('ml-1m')

In [5]:
[user_train, user_valid, user_test, usernum, itemnum] = dataset

In [6]:
# batches got sliced by users, i.e. batch accumulate BATCH_SIZE user sequences of items selected/bought
BATCH_SIZE = 128
num_batch = len(user_train) // BATCH_SIZE  # number of batches

user_train_lens = list(map(len,[v for k,v in user_train.items()]))
print(f'average sequence length: {sum(user_train_lens)/len(user_train):.1f}')

average sequence length: 163.5


In [7]:
# dataset for validation
valid_data = DH.SequenceDataValidation(user_train, user_valid, usernum, itemnum, args['maxlen'])

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=6040.0), HTML(value='')))




In [8]:
# dataset for test
test_data = DH.SequenceDataTest(user_train, user_valid, user_test, usernum, itemnum, args['maxlen'])

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=6040.0), HTML(value='')))




In [9]:
#dataset for training
train_data = DH.SequenceData(user_train, usernum, itemnum)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=6040.0), HTML(value='')))




In [10]:
train_loader = torch.utils.data.DataLoader(train_data, batch_size=4,
                          shuffle=False, collate_fn = lambda x: DH.tokenize_batch(x, max_len=args['maxlen']))

### Unit-test training loader 

In [16]:
u, seq, pos, neg = next(iter(train_loader))

In [47]:
neg.shape

torch.Size([4, 200])

In [17]:
assert len(seq[0]) == args['maxlen']

In [18]:
u

[1, 2, 3, 4]

In [39]:
len(u)

4

In [41]:
# random user from batch
_u = np.random.randint(0,len(u))
_u

3

In [42]:
# train sequnce
print(seq[_u].numpy()[-10:])

[255 256 179 167 172 157 257  39 199 258]


In [43]:
# train shifted one item ahead
print(pos[_u].numpy()[-10:])

[256 179 167 172 157 257  39 199 258  29]


In [44]:
# negative sequnce
print(neg[_u].numpy()[-10:])

[2377  992 1485 2751  470 3337  547 1605 1347 1024]


### Unit-test validation and test data 

In [24]:
ii = np.random.randint(1,usernum+1)

In [25]:
print("{0:}{1:>40}".format("\n","Validation data \n"))
print("{0:<30}".format("Main sequence "),":",*valid_data.seq[ii,-10:].numpy()) 
print("{0:<30}".format("Validation sequene "),":", *valid_data.valid[ii,:10].numpy())


                       Validation data 

Main sequence                  : 1100 45 985 1862 370 65 642 66 48 501
Validation sequene             : 639 2288 2136 2025 874 3413 2600 2114 2417 260


In [26]:
print("{0:}{1:>40}".format("\n","Test data \n"))
print("{0:<30}".format("Main sequence "),":",*test_data.seq[ii, -11:].numpy()) 
print("{0:<30}".format("Validation sequene "),":", *test_data.valid[ii,:10].numpy())


                             Test data 

Main sequence                  : 1100 45 985 1862 370 65 642 66 48 501 639
Validation sequene             : 710 3359 420 35 2368 3280 205 2602 1460 2534


In [70]:
print(*user_train[ii+1][-10:], *user_valid[ii+1], *user_test[ii+1])

2529 645 812 592 963 1035 1038 837 816 1044 1047 3018


In [27]:
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=4, shuffle=False)

### Unit-test metrics calculation

In [107]:
from copy import deepcopy
[train, valid, test, usernum, itemnum] = deepcopy(dataset)

NDCG = 0.0
HT = 0.0

In [108]:
# list of users in batch
u = [122, 144]

# get validation items
valid_seq = torch.as_tensor([valid[_u] for _u in u], dtype=torch.int) # (batch x 1)

In [110]:
# make a matrix from train sequence (batch, args['maxlen'] - 1)
final_seq = torch.zeros((len(u),args['maxlen'] - 1), dtype=torch.int)
for ii,_u in enumerate(u):
    idx = min(args['maxlen'] - 1, len(train[_u]))
    final_seq[ii, -idx:] = torch.as_tensor(train[_u][-idx:])

# final seq (batch, args['maxlen'])
final_seq = torch.cat((final_seq, valid_seq), dim=1)

In [111]:
# making a test sequence with one element from test and the rest random
test_seq = torch.zeros((len(u), 101), dtype=torch.int)
# all elements that are in train plus padding zero
for ii, (_u, seq) in enumerate(zip(u, final_seq)):
    items_not_in_seq = np.array(list(set(range(1,itemnum+1)) - set(seq.numpy()))) # random stuff not in seq
    test_seq[ii,0] = test[_u][0] # get true next element from test set
    test_seq[ii,1:] = torch.from_numpy(items_not_in_seq[np.random.randint(0, len(items_not_in_seq), 100)]) # fill the rest with random stuff

In [112]:
test_seq

tensor([[ 919, 2871, 2688, 3360,  480, 2031, 2828,  524,  598, 1611, 2785, 3042,
         1532, 1063, 1467, 2040, 2763, 2916, 1823,  668, 1124, 1976, 1673,  525,
         1518, 2615, 2776, 1437, 1079, 1978,  406, 2954,  897, 2514, 3390,  417,
         1519, 2713, 2332, 1517,  321, 2499, 1283, 2517, 2688,  185, 2233,  453,
         1756, 2679,  456,  669, 2025, 1473, 2226,  812, 2107, 2198,  643, 2980,
         2694,    8, 1778, 2027,  322, 1815, 1284, 2881, 1716,  243, 2384, 2340,
          617,  906, 2236, 2315, 1637, 3184, 2395,  610, 2118, 2334, 3355, 1912,
         1521,  604, 2003, 1173, 1519, 2207, 2051, 2369, 1029,  109, 2086,  228,
          106,  865,  520,  725, 2287],
        [ 672,  770, 3298,  668,   94, 3290,   20, 2532,  766, 3186, 1625, 2100,
         2251,  709,  391, 3336, 1080,  137, 2104, 1768,  476,  381, 1283, 1948,
         2409,  414, 2309, 1415, 1231, 2427, 1998, 2464,  899, 3323,  971,  134,
         1400, 2559,  527, 2520, 3211,  926,  428,  734, 3321, 1288, 

In [257]:
with torch.no_grad():
    log_feats = model.log2feats(final_seq) # shape (batch, seq_len, hidden_dim) = (1x200x50)

In [260]:
final_feat = log_feats[:, -1, :] # last hidden state/embedding
final_feat, final_feat.shape

(tensor([[ 0.6619,  0.6596, -0.1745, -0.2133, -0.7123,  0.4204,  0.4625, -0.6209,
          -0.4138,  0.0817, -0.1496,  0.6792,  0.5925,  0.7006, -0.6502, -0.5306,
           0.2231, -0.6313,  0.7336,  0.5250,  0.6301,  0.3757,  0.5017, -0.7404,
           0.4616,  0.5239, -0.6687, -0.6023, -0.6875, -0.5923,  0.6337, -0.1855,
           0.6248,  0.4989, -0.4843, -0.3989,  0.7503, -0.4861,  0.3672, -0.7682,
          -0.7489,  0.0185,  0.2974,  0.7162,  0.3279, -0.6246, -0.8353,  0.4458,
          -0.9165,  0.6383],
         [ 0.6621,  0.6496, -0.1759, -0.2299, -0.7259,  0.4308,  0.4675, -0.5906,
          -0.4444,  0.0792, -0.1448,  0.6820,  0.5929,  0.7078, -0.6550, -0.5333,
           0.2334, -0.6315,  0.7423,  0.5466,  0.6331,  0.3593,  0.5242, -0.7519,
           0.4644,  0.5121, -0.6704, -0.5962, -0.6760, -0.6022,  0.6266, -0.1775,
           0.6117,  0.4997, -0.4859, -0.4087,  0.7321, -0.4957,  0.3658, -0.7753,
          -0.7522,  0.0150,  0.3030,  0.7147,  0.3252, -0.6060, -0.83

In [261]:
with torch.no_grad():
    item_embs = model.item_emb(test_seq) # shape torch.Size([1, 101, 50]) 

In [262]:
item_embs.shape, final_feat.unsqueeze(-1).shape

(torch.Size([2, 101, 50]), torch.Size([2, 50, 1]))

In [264]:
logits = torch.bmm(item_embs, final_feat.unsqueeze(-1))
logits.shape

torch.Size([2, 101, 1])

In [266]:
predictions = -logits.squeeze()
predictions.shape

torch.Size([2, 101])

In [277]:
_, indices = torch.topk(predictions,15,dim=1, largest=False)
indices

tensor([[43, 98, 22, 40, 38, 92, 21, 75, 48, 68,  0,  8, 57, 89, 35],
        [30, 36, 78,  7, 51, 97, 31, 72, 64, 24,  0, 28, 99, 22, 86]])

In [278]:
_, indices = torch.where(indices == 0)

In [293]:
indices

tensor([10, 10])

In [294]:
hits = torch.as_tensor(indices < 11, dtype=torch.int)
hits

tensor([1, 1], dtype=torch.int32)

In [295]:
hits/torch.log2(indices+2)

tensor([0.2789, 0.2789])

## Assemble a model  

**From `Attention is all you need` paper:**  
We apply dropout to the **output of each sub-layer**, before it is added to the sub-layer input and normalized. 
In addition, we apply dropout to the **sums of the embeddings and the positional encodings** in both the encoder and decoder stacks. 
For the base model, we use a rate of P drop = 0.1.  

The encoder is composed of a stack of N = 6 identical layers. Each layer has two `sub-layers`.  
The first is a multi-head self-attention mechanism, and the second is a simple, position-wise fully connected feed-forward network.   
We employ a residual connection around each of the two sub-layers, followed by layer normalization.  
That is, the output of each sub-layer is `LayerNorm(x + Sublayer(x))`, where Sublayer(x) is the function implemented by the sub-layer itself.  
To facilitate these residual connections, all sub-layers in the model, as well as the embedding layers, produce outputs of dimension dmodel = 512.
---

However, when the network goes deeper, several problems become exacerbated: 
1) the increased model capacity leads to overfitting; 
2) the training process becomes unstable (due to vanishing gradients etc.); and 
3) models with more parameters often require more training time.  

We perform the following operations to alleviate these problems: `g(x) = x + Dropout(g(LayerNorm(x)))`,  
where `g(x)` represents the self attention layer or the feed-forward network.  
That is to say, for layer `g` in each block, we apply **layer normalization**on the input `x` before feeding into `g`,   
apply **dropout** on `g’s` output, and add the input `x` to the final output.   
The **dropout** rate of turning off neurons is 0.2 for **MovieLens-1m** and 0.5 for the other three datasets due to their sparsity.  
We also apply a **dropout** layer on the embedding `E`.

The optimizer is the Adam optimizer, the `learning rate` is set to `0.001`, and the `batch size` is `128`.

**Shared Item Embedding**  
To reduce the model size and alleviate overfitting, we consider a single item embedding M:  
`ri,t=Ft Mi`. 

<img src="./sasrec-loss-function.png" alt="Loss function" style="width: 600px;"/>

Note that we ignore the terms where `ot = <pad>`.

### Ablation study
Remove PE (Positional Embedding): Without the positional embedding P, the attention weight on each item depends only on item embeddings. That is to say, the model makes recommendations based on users’ past actions, but their order doesn’t matter. This variant might be suitable for sparse datasets, where user sequences are typically short. This variant performs better then the default model on the sparsest dataset (Beauty), but worse on other denser datasets.

## Model assembly

In [28]:
class PolyWarmUpScheduler(optim.lr_scheduler._LRScheduler):
    """
    taken from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/BERT/schedulers.py
    """

    def __init__(self, optimizer, warmup, total_steps, degree=0.5, last_epoch=-1):
        self.warmup = warmup
        self.total_steps = total_steps
        self.degree = degree
        super(PolyWarmUpScheduler, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        progress = self.last_epoch / self.total_steps
        if progress < self.warmup:
            return [base_lr * progress / self.warmup for base_lr in self.base_lrs]
        else:
            return [base_lr * ((1.0 - progress) ** self.degree) for base_lr in self.base_lrs]

In [27]:
SASRec = reload(SASRec)
from torch.nn import MultiheadAttention, LayerNorm, Dropout, Conv1d, Embedding, BCEWithLogitsLoss
from SASRecModel import PointWiseFF, SASRecEncoderLayer, PositinalEncoder, SASRecEncoder

### Building blocks for Encoder:
- **encoder** layer 
- **positional encoder**
- **point-wise** feed-forward

### Unit-test building blocks

In [28]:
input_x = torch.randn(BATCH_SIZE, args['maxlen'], args['d_model'])

In [29]:
emb_test = Embedding(args['maxlen'], embedding_dim=args['hidden_units'])

In [30]:
pe_test = PositinalEncoder(args['maxlen'], args['d_model'])
pe_test(input_x).shape

torch.Size([128, 200, 50])

In [31]:
encoder_test = SASRecEncoderLayer(itemnum, **args)

In [23]:
encoder_test(input_x).shape

torch.Size([128, 200, 50])

### Unit-test final model

In [32]:
sas_rec_test = SASRecEncoder(itemnum, **args)

In [33]:
minib = next(iter(train_loader))

In [34]:
# need to comment out logging loss to run this cell
sas_rec_test.training_step(minib, 1)

AttributeError: 'NoneType' object has no attribute '_results'

In [36]:
u, seq, pos, neg = next(iter(train_loader))
item_emb = sas_rec_test.forward(seq) # get embeddings from transformer

In [37]:
sas_rec_test(seq).shape, seq.shape

(torch.Size([4, 200, 50]), torch.Size([4, 200]))

In [318]:
pos_scores = sas_rec_test.compute_relevance_scores(item_emb, pos) # scores for positive sequence
neg_scores = sas_rec_test.compute_relevance_scores(item_emb, neg) # scores for negative sequence

pos_labels = torch.ones(pos_scores.shape)
neg_labels = torch.zeros(neg_scores.shape)

indices = torch.where(pos!=0) # exclude padding from loss computation

In [319]:
loss = sas_rec_test.loss(pos_scores[indices], pos_labels[indices]) # loss for positive sequence
# loss += self.loss(neg_scores[indices], neg_labels[indices]) # loss for negative sequence

## Declare data loaders

In [29]:
print(f"\nBatch size is - {args['batch_size']}\n")


Batch size is - 128



In [38]:
val_loader = torch.utils.data.DataLoader(dataset=DH.SequenceDataValidation(user_train, user_valid, usernum, itemnum, args['maxlen']), 
                                         batch_size=args['batch_size'], shuffle=True, 
                                         drop_last=True)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=6040.0), HTML(value='')))




In [39]:
test_loader = torch.utils.data.DataLoader(dataset=DH.SequenceDataTest(user_train, user_valid, user_test, usernum, itemnum, args['maxlen']), 
                                          batch_size=args['batch_size'], shuffle=True, 
                                          drop_last=True)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=6040.0), HTML(value='')))




In [41]:
train_loader = torch.utils.data.DataLoader(dataset=DH.SequenceData(user_train, usernum, itemnum), 
                                           batch_size=args['batch_size'],
                                           shuffle=True, 
                                           collate_fn=DH.tokenize_batch)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=6040.0), HTML(value='')))




In [147]:
# seq, val = next(iter(val_loader))
# seq[0], val[0]

(tensor([ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  2,  3,
          4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
         22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
         40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57,
         58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75,
         76, 77], dtype=torch.int32),
 tensor([  78,  282, 2760, 1584,  642, 3178, 1761, 2186,  593, 2509, 16

In [33]:
args

{'dataset': 'ml-1m',
 'train_dir': 'default',
 'batch_size': 128,
 'lr': 0.001,
 'maxlen': 200,
 'hidden_units': 50,
 'd_model': 50,
 'num_blocks': 2,
 'num_epochs': 201,
 'num_heads': 1,
 'dropout_rate': 0.2,
 'l2_emb': 0.0,
 'device': 'cuda',
 'inference_only': False,
 'state_dict_path': None}

In [42]:
model = SASRecEncoder(itemnum, opt='Adam', **args)
# model.load_state_dict(torch.load("bazman_sasrec.pt"))

In [43]:
# weight initialization
for name, param in model.named_parameters():
    try:
        torch.nn.init.xavier_normal_(param.data)
        print(f"{name:<40} sucess")
    except:
        print(f"{name:<40} failure")

ie.weight                                sucess
pe.pe.weight                             sucess
enc_stack.0.norm_1.weight                failure
enc_stack.0.norm_1.bias                  failure
enc_stack.0.norm_2.weight                failure
enc_stack.0.norm_2.bias                  failure
enc_stack.0.attn.in_proj_weight          sucess
enc_stack.0.attn.in_proj_bias            failure
enc_stack.0.attn.out_proj.weight         sucess
enc_stack.0.attn.out_proj.bias           failure
enc_stack.0.ff.conv1.weight              sucess
enc_stack.0.ff.conv1.bias                failure
enc_stack.0.ff.conv2.weight              sucess
enc_stack.0.ff.conv2.bias                failure
enc_stack.1.norm_1.weight                failure
enc_stack.1.norm_1.bias                  failure
enc_stack.1.norm_2.weight                failure
enc_stack.1.norm_2.bias                  failure
enc_stack.1.attn.in_proj_weight          sucess
enc_stack.1.attn.in_proj_bias            failure
enc_stack.1.attn.out_proj.w

In [44]:
# save checkpoints
from pytorch_lightning.callbacks import ModelCheckpoint
checkpoint_callback = ModelCheckpoint(monitor="HR@10/val")

In [46]:
# trainer = pl.Trainer(accelerator='dp', 
#                      gpus=-1, 
#                      max_epochs=MAX_EPOCHS,
#                      log_every_n_steps=1, 
#                      val_check_interval=0.2,
#                      num_sanity_val_steps=1, 
#                      callbacks=[checkpoint_callback], 
#                      accumulate_grad_batches=ACCUM_GRAD_BATCHES, 
#                      precision=16)

# AVAIL_GPUS = min(1, torch.cuda.device_count())
# torch.cuda.empty_cache()


# run tensorboard before the script launch
# tensorboard --logdir ./lightning_logs/ --host 0.0.0.0

# run the script with the command
# PL_TORCH_DISTRIBUTED_BACKEND=nccl python <name of your script>.py

# https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html
trainer = pl.Trainer(gpus=[0], 
                     auto_select_gpus=False, 
                     max_epochs=300,
                     reload_dataloaders_every_n_epochs=1,
                     val_check_interval=1.0,
                     callbacks=[checkpoint_callback],
                     log_every_n_steps= int(len(train_data)/args['batch_size']/3), # log 4 times per epoch
                     # limit_val_batches=0, How much of validation dataset to check. Useful when debugging or testing something that happens at the end of an epoch.
                     num_sanity_val_steps=10, 
                     precision=32)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [80]:
# torch.autograd.set_detect_anomaly(False)

In [47]:
trainer.fit(model, train_loader, val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type              | Params
--------------------------------------------------
0 | ie          | Embedding         | 170 K 
1 | pe          | PositinalEncoder  | 10.0 K
2 | emb_dropout | Dropout           | 0     
3 | enc_stack   | Sequential        | 31.0 K
4 | final_norm  | LayerNorm         | 100   
5 | loss        | BCEWithLogitsLoss | 0     
--------------------------------------------------
211 K     Trainable params
0         Non-trainable params
211 K     Total params
0.848     Total estimated model params size (MB)


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

  f"Your {mode}_dataloader has `shuffle=True`, it is best practice to turn"
  f"The dataloader, {name}, does not have many workers which may be a bottleneck."
  f"The dataloader, {name}, does not have many workers which may be a bottleneck."


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…



HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




In [203]:
torch.save(model.state_dict(), "bazman_sasrec_refactored.pt")

In [None]:
trainer.test(model, test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
  f"Your {mode}_dataloader has `shuffle=True`, it is best practice to turn"
  f"The dataloader, {name}, does not have many workers which may be a bottleneck."


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'HR@10/test': 0.7716090679168701, 'NDCG@10/test': 0.5173300504684448}
--------------------------------------------------------------------------------



[{'NDCG@10/test': 0.5173300504684448, 'HR@10/test': 0.7716090679168701}]

In [201]:
trainer.test(model, test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'HR@10/test': 0.7928807735443115, 'NDCG@10/test': 0.5423007011413574}
--------------------------------------------------------------------------------



[{'NDCG@10/test': 0.5423007011413574, 'HR@10/test': 0.7928807735443115}]