# Refactor SASRec code for a cleaner pytorch
Author: bazman  
Date: 2021-DEC

In [1]:
%config Completer.use_jedi = False

In [212]:
import os
import numpy as np
import torch
import pytorch_lightning as pl
import argparse
from importlib import reload
from utils import data_partition
import torch.optim as optim
import torch.nn.functional as F

In [11]:
# setup command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', required=True)
parser.add_argument('--train_dir', required=True)
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--lr', default=0.001, type=float)
parser.add_argument('--maxlen', default=50, type=int)
parser.add_argument('--hidden_units', default=50, type=int)
parser.add_argument('--num_blocks', default=2, type=int)
parser.add_argument('--num_epochs', default=201, type=int)
parser.add_argument('--num_heads', default=1, type=int)
parser.add_argument('--dropout_rate', default=0.5, type=float)
parser.add_argument('--l2_emb', default=0.0, type=float)
parser.add_argument('--device', default='cpu', type=str)
parser.add_argument('--inference_only', default=False, type=bool)
parser.add_argument('--state_dict_path', default=None, type=str)
args = parser.parse_args( ['--dataset=ml-1m', '--train_dir=default', '--maxlen=200', '--dropout_rate=0.2', '--device=cuda'])
args = vars(args)
print(*[(k,v) for (k,v) in args.items()], sep="\n")

('dataset', 'ml-1m')
('train_dir', 'default')
('batch_size', 128)
('lr', 0.001)
('maxlen', 200)
('hidden_units', 50)
('num_blocks', 2)
('num_epochs', 201)
('num_heads', 1)
('dropout_rate', 0.2)
('l2_emb', 0.0)
('device', 'cuda')
('inference_only', False)
('state_dict_path', None)


## Perepare the data  
We have 3 datasets:  
 - for training
 - for validation
 - for testing  
 They all contain all users and the last two items in sequence are split between validation (penultimate item) and test (last item)  
 Training has all user items but without last two that falls into validation and test

**user_train** - dict with key = *userid* and value = list of all items selected in respected time order  
**user_valid** - dict with the same structure as above but with penulitimate item (just one item)  
**user_test** - same as above but with ultimate item selected  
i.e. you have user 5 with items 1, 29, 34, 15, 8 in his sequence of items there will be the below data in vars:  
```
user_train[5] = [1,29,34]  
user_valid = [15]  
user_test=[8]
```

In [72]:
# read dataset
dataset = data_partition('ml-1m')

In [14]:
[user_train, user_valid, user_test, usernum, itemnum] = dataset

In [15]:
# batches got sliced by users, i.e. batch accumulate BATCH_SIZE user sequences of items selected/bought
BATCH_SIZE = 128
num_batch = len(user_train) // BATCH_SIZE  # number of batches

user_train_lens = list(map(len,[v for k,v in user_train.items()]))
print(f'average sequence length: {sum(user_train_lens)/len(user_train):.1f}')

average sequence length: 163.5


### Dataset for validation

In [19]:
class SequenceDataValidation(torch.utils.data.Dataset):
    '''
    train -> **valid** -> test
    dataset to produce validation data
    Input:
    - user_train : known sequence of items for the user (train data)
    - user_valid : one item that makes up a next selection after user_train sequence
    Returns:
    - user_train is the same
    - user_valid is appended with 100 random items that are not in user_trian after that 101 items are scored with model and logit 
        for the 0-th element(user_valid) should be somewhere in top 10 scores
    '''
    def __init__(self, user_train, user_valid, usernum, itemnum, maxlen):
        '''
        Input:
        - user_train: dict of user training sequence
        - user_valid: dict with one item for validation sequence
        - usernum - number of users in dataset
        - itemnum - number of items in dataset
        - maxlen - max len of sequence for truncation
        Output:
        self.seq - maxlen sequnce for train
        self.valid - 101 len for validation
        '''
        from tqdm.notebook import tqdm
        super(SequenceDataValidation, self).__init__()
        
        # make a list of users to validate on
        # limit users max to 10000 or to whatever we have in case less than 10000
        if usernum > 10_000:
            users = random.sample(range(1, usernum + 1), 10_000)
        else:
            users = range(1, usernum + 1)
            
        # making a validation sequence with one element from valid and the rest random
        # all elements that are in train plus padding zero
        valid_seq = torch.zeros((len(users), 101), dtype=torch.int)
        
        # make a matrix from train sequence (batch, maxlen)
        final_seq = torch.zeros((len(users), maxlen), dtype=torch.int)
        
        with tqdm(total=len(users)) as pbar:
            for ii,_u in enumerate(users):
                # truncate seq  to maxlen
                idx = min(maxlen, len(user_train[_u]))
                final_seq[ii, -idx:] = torch.as_tensor(user_train[_u][-idx:])

                items_not_in_seq = np.array(list(set(range(1,itemnum+1)) - set(final_seq[ii].numpy().flatten()))) # random stuff not in final_seq
                valid_seq[ii,0] = user_valid[_u][0] # get true next element from validation set
                valid_seq[ii,1:] = torch.from_numpy(items_not_in_seq[np.random.randint(0, len(items_not_in_seq), 100)]) # fill the rest with random stuff
                pbar.update(1)
        
        self.seq = final_seq # store training seq
        self.valid = valid_seq # store validation seq
        self.users = users # store validation users
            
    def __getitem__(self, index):
        return self.seq[index], self.valid[index]

    def __len__(self):
        return len(self.seq)

### Dataset for test 

In [20]:
class SequenceDataTest(SequenceDataValidation):
    '''
    train -> valid -> **test**
    dataset to produce test data set
    same as SequenceDataValidation class but uses one element from test_seq to make a test_seq
    alse adds up validation item to train sequence
    '''
    def __init__(self, user_train, user_valid, user_test, usernum, itemnum, maxlen):
        super().__init__(user_train, user_test, usernum, itemnum, maxlen)
        # now we need to shift self.seq one item back
        self.seq[:,:-1] = self.seq[:,1:]
        # this is an extra item that will be the last in training seq
        extra_valid_item = torch.as_tensor([user_valid[_u][0] for _u in self.users])
        self.seq[:,-1] = extra_valid_item

In [21]:
valid_data = SequenceDataValidation(user_train, user_valid, usernum, itemnum, args['maxlen'])

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=6040.0), HTML(value='')))




In [22]:
test_data = SequenceDataTest(user_train, user_valid, user_test, usernum, itemnum, args['maxlen'])

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=6040.0), HTML(value='')))




### Dataset for training 

In [77]:
class SequenceData(torch.utils.data.Dataset):
    '''
    dataset for training the network
    '''
    def __init__(self, user_seq, usernum, itemnum):
        '''
        user_seq is a dict with keys = userid - sequential from 1 to number of users(usernum)
        itemnum - number of items in vocabulary of selected movies
        Sets up the following props in the object:
        seq - all elements of user seq without last element
        pos - all elements without first element (shift one time item ahead)
        neg - the same length but with all different elements
        Resulting data looks like this:
        seq = [250,  13, 251,  70, 252,  81, 237, 150, 253,  27, 143, 254, 236,
        196, 229, 255, 256, 179, 167, 172, 157, 257,  39, 199, 258]
        
        pos = [ 13, 251,  70, 252,  81, 237, 150, 253,  27, 143, 254, 236, 196,
        229, 255, 256, 179, 167, 172, 157, 257,  39, 199, 258,  29]
        
        neg =  [928, 3404,  821, 2505, 1931, 2588, 1365,  527, 3140, 1615, 1649,
        1981,  450, 1175, 1576, 1787, 1425, 2698, 1916,  729, 3390, 2503,
        2751, 1481, 2422]
        '''
        from tqdm.notebook import tqdm
        super(SequenceData, self).__init__()
        self.usernum = usernum
        self.userids = np.array(list(user_seq.keys())) # store userids in a property
        self.seq, self.pos, self.neg = dict(), dict(), dict()
        with tqdm(total=len(user_seq)) as pbar:
            for userid, _user_seq in user_seq.items():
                self.seq[userid] = np.array(_user_seq[:-1]) # all but last element
                self.pos[userid] = np.array(_user_seq[1:]) # shifted one time slot ahead
                # negative sequence
                items_not_in_seq = np.array(list(set(range(1,itemnum+1)) - set(_user_seq))) # all items from vocab that are out of user_seq
                self.neg[userid] = items_not_in_seq[np.random.randint(0, len(items_not_in_seq), len(self.seq[userid]))] # select random items from above array
                pbar.update(1)

    def __getitem__(self, index):
        userid = self.userids[index]
        return userid, self.seq[userid], self.pos[userid], self.neg[userid]

    def __len__(self):
        return len(self.seq)

In [78]:
train_data = SequenceData(user_train, usernum, itemnum)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=6040.0), HTML(value='')))




In [81]:
def tokenize_batch(batch, max_len=args['maxlen']):
    '''
    use tokenizer to cast dict type to tensors and shrink the data to maxlen - nothing else
    could have made it in dataset directly but anyway 
    '''
    u        = []
    seq_list = []
    pos_list = []
    neg_list = []
    
    # torch.zeros(max_len, dtype=torch.int)
    # torch.zeros_like(seq_batch)
    # torch.zeros_like(seq_batch)
        
    for _u, seq, pos, neg in batch:
        # fixed size tensor of max_len
        seq_holder = torch.zeros(max_len, dtype=torch.int)
        pos_holder  = torch.zeros_like(seq_holder)
        neg_holder = torch.zeros_like(seq_holder)
        
        idx = min(max_len, len(seq))
        seq_holder[-idx:] = torch.from_numpy(seq[-idx:])
        pos_holder[-idx:] = torch.from_numpy(pos[-idx:])
        neg_holder[-idx:] = torch.from_numpy(neg[-idx:])
        
        seq_list.append(seq_holder.unsqueeze(dim=0))
        pos_list.append(pos_holder.unsqueeze(dim=0))
        neg_list.append(neg_holder.unsqueeze(dim=0))
        u.append(_u)
    return u, torch.cat(seq_list, dim=0), torch.cat(pos_list, dim=0), torch.cat(neg_list, dim=0)

In [85]:
train_loader = torch.utils.data.DataLoader(train_data, batch_size=4,
                          shuffle=False, collate_fn=tokenize_batch)

### Unit-test training loader 

In [86]:
u, seq, pos, neg = next(iter(train_loader))

In [100]:
assert len(seq[0]) == args['maxlen']

In [87]:
u

[1, 2, 3, 4]

In [101]:
# random user from batch
_u = np.random.randint(1,len(u)+1)
_u

3

In [102]:
# train sequnce
print(seq[_u].numpy()[-10:])

[255 256 179 167 172 157 257  39 199 258]


In [103]:
# train shifted one item ahead
print(pos[_u].numpy()[-10:])

[256 179 167 172 157 257  39 199 258  29]


In [104]:
# negative sequnce
print(neg[_u].numpy()[-10:])

[ 533  591 3324 1193 2504  847 2562  803 1933 2728]


### Unit-test validation and test data 

In [65]:
ii = np.random.randint(1,usernum+1)

In [66]:
print("{0:}{1:>40}".format("\n","Validation data \n"))
print("{0:<30}".format("Main sequence "),":",*valid_data.seq[ii,-10:].numpy()) 
print("{0:<30}".format("Validation sequene "),":", *valid_data.valid[ii,:10].numpy())


                       Validation data 

Main sequence                  : 2529 645 812 592 963 1035 1038 837 816 1044
Validation sequene             : 1047 3069 1576 2161 2186 1694 580 3013 176 1737


In [69]:
print("{0:}{1:>40}".format("\n","Test data \n"))
print("{0:<30}".format("Main sequence "),":",*test_data.seq[ii, -11:].numpy()) 
print("{0:<30}".format("Validation sequene "),":", *test_data.valid[ii,:10].numpy())


                             Test data 

Main sequence                  : 2529 645 812 592 963 1035 1038 837 816 1044 1047
Validation sequene             : 3018 143 3145 3066 206 1866 1125 3003 2327 2175


In [70]:
print(*user_train[ii+1][-10:], *user_valid[ii+1], *user_test[ii+1])

2529 645 812 592 963 1035 1038 837 816 1044 1047 3018


In [113]:
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=4, shuffle=False)

### Unit-test metrics calculation

In [107]:
from copy import deepcopy
[train, valid, test, usernum, itemnum] = deepcopy(dataset)

NDCG = 0.0
HT = 0.0

In [108]:
# list of users in batch
u = [122, 144]

# get validation items
valid_seq = torch.as_tensor([valid[_u] for _u in u], dtype=torch.int) # (batch x 1)

In [110]:
# make a matrix from train sequence (batch, args['maxlen'] - 1)
final_seq = torch.zeros((len(u),args['maxlen'] - 1), dtype=torch.int)
for ii,_u in enumerate(u):
    idx = min(args['maxlen'] - 1, len(train[_u]))
    final_seq[ii, -idx:] = torch.as_tensor(train[_u][-idx:])

# final seq (batch, args['maxlen'])
final_seq = torch.cat((final_seq, valid_seq), dim=1)

In [111]:
# making a test sequence with one element from test and the rest random
test_seq = torch.zeros((len(u), 101), dtype=torch.int)
# all elements that are in train plus padding zero
for ii, (_u, seq) in enumerate(zip(u, final_seq)):
    items_not_in_seq = np.array(list(set(range(1,itemnum+1)) - set(seq.numpy()))) # random stuff not in seq
    test_seq[ii,0] = test[_u][0] # get true next element from test set
    test_seq[ii,1:] = torch.from_numpy(items_not_in_seq[np.random.randint(0, len(items_not_in_seq), 100)]) # fill the rest with random stuff

In [112]:
test_seq

tensor([[ 919, 2871, 2688, 3360,  480, 2031, 2828,  524,  598, 1611, 2785, 3042,
         1532, 1063, 1467, 2040, 2763, 2916, 1823,  668, 1124, 1976, 1673,  525,
         1518, 2615, 2776, 1437, 1079, 1978,  406, 2954,  897, 2514, 3390,  417,
         1519, 2713, 2332, 1517,  321, 2499, 1283, 2517, 2688,  185, 2233,  453,
         1756, 2679,  456,  669, 2025, 1473, 2226,  812, 2107, 2198,  643, 2980,
         2694,    8, 1778, 2027,  322, 1815, 1284, 2881, 1716,  243, 2384, 2340,
          617,  906, 2236, 2315, 1637, 3184, 2395,  610, 2118, 2334, 3355, 1912,
         1521,  604, 2003, 1173, 1519, 2207, 2051, 2369, 1029,  109, 2086,  228,
          106,  865,  520,  725, 2287],
        [ 672,  770, 3298,  668,   94, 3290,   20, 2532,  766, 3186, 1625, 2100,
         2251,  709,  391, 3336, 1080,  137, 2104, 1768,  476,  381, 1283, 1948,
         2409,  414, 2309, 1415, 1231, 2427, 1998, 2464,  899, 3323,  971,  134,
         1400, 2559,  527, 2520, 3211,  926,  428,  734, 3321, 1288, 

In [257]:
with torch.no_grad():
    log_feats = model.log2feats(final_seq) # shape (batch, seq_len, hidden_dim) = (1x200x50)

In [260]:
final_feat = log_feats[:, -1, :] # last hidden state/embedding
final_feat, final_feat.shape

(tensor([[ 0.6619,  0.6596, -0.1745, -0.2133, -0.7123,  0.4204,  0.4625, -0.6209,
          -0.4138,  0.0817, -0.1496,  0.6792,  0.5925,  0.7006, -0.6502, -0.5306,
           0.2231, -0.6313,  0.7336,  0.5250,  0.6301,  0.3757,  0.5017, -0.7404,
           0.4616,  0.5239, -0.6687, -0.6023, -0.6875, -0.5923,  0.6337, -0.1855,
           0.6248,  0.4989, -0.4843, -0.3989,  0.7503, -0.4861,  0.3672, -0.7682,
          -0.7489,  0.0185,  0.2974,  0.7162,  0.3279, -0.6246, -0.8353,  0.4458,
          -0.9165,  0.6383],
         [ 0.6621,  0.6496, -0.1759, -0.2299, -0.7259,  0.4308,  0.4675, -0.5906,
          -0.4444,  0.0792, -0.1448,  0.6820,  0.5929,  0.7078, -0.6550, -0.5333,
           0.2334, -0.6315,  0.7423,  0.5466,  0.6331,  0.3593,  0.5242, -0.7519,
           0.4644,  0.5121, -0.6704, -0.5962, -0.6760, -0.6022,  0.6266, -0.1775,
           0.6117,  0.4997, -0.4859, -0.4087,  0.7321, -0.4957,  0.3658, -0.7753,
          -0.7522,  0.0150,  0.3030,  0.7147,  0.3252, -0.6060, -0.83

In [261]:
with torch.no_grad():
    item_embs = model.item_emb(test_seq) # shape torch.Size([1, 101, 50]) 

In [262]:
item_embs.shape, final_feat.unsqueeze(-1).shape

(torch.Size([2, 101, 50]), torch.Size([2, 50, 1]))

In [264]:
logits = torch.bmm(item_embs, final_feat.unsqueeze(-1))
logits.shape

torch.Size([2, 101, 1])

In [266]:
predictions = -logits.squeeze()
predictions.shape

torch.Size([2, 101])

In [277]:
_, indices = torch.topk(predictions,15,dim=1, largest=False)
indices

tensor([[43, 98, 22, 40, 38, 92, 21, 75, 48, 68,  0,  8, 57, 89, 35],
        [30, 36, 78,  7, 51, 97, 31, 72, 64, 24,  0, 28, 99, 22, 86]])

In [278]:
_, indices = torch.where(indices == 0)

In [293]:
indices

tensor([10, 10])

In [294]:
hits = torch.as_tensor(indices < 11, dtype=torch.int)
hits

tensor([1, 1], dtype=torch.int32)

In [295]:
hits/torch.log2(indices+2)

tensor([0.2789, 0.2789])

## Assemble a model  

In [122]:
class PolyWarmUpScheduler(optim.lr_scheduler._LRScheduler):
    """
    taken from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/BERT/schedulers.py
    """

    def __init__(self, optimizer, warmup, total_steps, degree=0.5, last_epoch=-1):
        self.warmup = warmup
        self.total_steps = total_steps
        self.degree = degree
        super(PolyWarmUpScheduler, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        progress = self.last_epoch / self.total_steps
        if progress < self.warmup:
            return [base_lr * progress / self.warmup for base_lr in self.base_lrs]
        else:
            return [base_lr * ((1.0 - progress) ** self.degree) for base_lr in self.base_lrs]

In [123]:
class PointWiseFeedForward(torch.nn.Module):
    def __init__(self, hidden_units, dropout_rate):

        super(PointWiseFeedForward, self).__init__()

        self.conv1 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout1 = torch.nn.Dropout(p=dropout_rate)
        self.relu = torch.nn.ReLU()
        self.conv2 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout2 = torch.nn.Dropout(p=dropout_rate)

    def forward(self, inputs):
        outputs = self.dropout2(self.conv2(self.relu(self.dropout1(self.conv1(inputs.transpose(-1, -2))))))
        outputs = outputs.transpose(-1, -2) # as Conv1D requires (N, C, Length)
        outputs += inputs
        return outputs

In [187]:
class SASrecPL(pl.LightningModule):
    """
    https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html
    """
    def __init__(self, item_num, warmup_proportion=0.2, max_iters=10000, opt='AdamW', **kwargs):
        '''
        counstructor
        '''
        super().__init__()
        self.save_hyperparameters("warmup_proportion", 
                                  "max_iters", 
                                  "opt", 
                                  "hidden_units", 
                                  "num_blocks", 
                                  "num_heads", 
                                  "dropout_rate", 
                                  "l2_emb", 
                                  "lr", 
                                  "maxlen", 
                                  "item_num")

        # TODO: loss += args.l2_emb for regularizing embedding vectors during training
        # https://stackoverflow.com/questions/42704283/adding-l1-l2-regularization-in-pytorch
        self.item_emb = torch.nn.Embedding(self.hparams.item_num+1, self.hparams.hidden_units, padding_idx=0)
        self.pos_emb = torch.nn.Embedding(self.hparams.maxlen, self.hparams.hidden_units) # TO IMPROVE
        self.emb_dropout = torch.nn.Dropout(p=self.hparams.dropout_rate)

        self.attention_layernorms = torch.nn.ModuleList() # to be Q for self-attention
        self.attention_layers = torch.nn.ModuleList()
        self.forward_layernorms = torch.nn.ModuleList()
        self.forward_layers = torch.nn.ModuleList()

        self.last_layernorm = torch.nn.LayerNorm(self.hparams.hidden_units, eps=1e-8)

        for _ in range(self.hparams.num_blocks):
            new_attn_layernorm = torch.nn.LayerNorm(self.hparams.hidden_units, eps=1e-8)
            self.attention_layernorms.append(new_attn_layernorm)

            new_attn_layer =  torch.nn.MultiheadAttention(self.hparams.hidden_units,
                                                            self.hparams.num_heads,
                                                            self.hparams.dropout_rate)
            self.attention_layers.append(new_attn_layer)

            new_fwd_layernorm = torch.nn.LayerNorm(self.hparams.hidden_units, eps=1e-8)
            self.forward_layernorms.append(new_fwd_layernorm)

            new_fwd_layer = PointWiseFeedForward(self.hparams.hidden_units, self.hparams.dropout_rate)
            self.forward_layers.append(new_fwd_layer)

            # self.pos_sigmoid = torch.nn.Sigmoid()
            # self.neg_sigmoid = torch.nn.Sigmoid()
        
    
        self.loss = torch.nn.BCEWithLogitsLoss()

    def log2feats(self, log_seqs):
        seqs = self.item_emb(log_seqs)
        seqs *= self.item_emb.embedding_dim ** 0.5
        positions = np.tile(np.array(range(log_seqs.shape[1])), [log_seqs.shape[0], 1])

        seqs += self.pos_emb(torch.from_numpy(positions).to(self.device))
        seqs = self.emb_dropout(seqs)

        timeline_mask = (log_seqs == 0)
        seqs *= ~timeline_mask.unsqueeze(-1) # broadcast in last dim

        tl = seqs.shape[1] # time dim len for enforce causality
        attention_mask = ~torch.tril(torch.ones((tl, tl), dtype=torch.bool)).to(self.device)

        for i in range(len(self.attention_layers)):
            seqs = torch.transpose(seqs, 0, 1)
            Q = self.attention_layernorms[i](seqs)
            mha_outputs, _ = self.attention_layers[i](Q, seqs, seqs, 
                                            attn_mask=attention_mask)
                                            # key_padding_mask=timeline_mask
                                            # need_weights=False) this arg do not work?
            seqs = Q + mha_outputs
            seqs = torch.transpose(seqs, 0, 1)

            seqs = self.forward_layernorms[i](seqs)
            seqs = self.forward_layers[i](seqs)
            seqs *=  ~timeline_mask.unsqueeze(-1)

        log_feats = self.last_layernorm(seqs) # (U, T, C) -> (U, -1, C)

        return log_feats
        
    def forward(self, log_seqs, pos_seqs, neg_seqs):   
        log_feats = self.log2feats(log_seqs) # user_ids hasn't been used yet

        pos_embs = self.item_emb(pos_seqs)
        neg_embs = self.item_emb(neg_seqs)

        pos_logits = (log_feats * pos_embs).sum(dim=-1)
        neg_logits = (log_feats * neg_embs).sum(dim=-1)

        # pos_pred = self.pos_sigmoid(pos_logits)
        # neg_pred = self.neg_sigmoid(neg_logits)

        return pos_logits, neg_logits # pos_pred, neg_pred
    
    def predict_step(self, log_seqs, items2score): # for inference
        '''
        method to score a new data
        Input:
        log_seqs - sequence of items, we recommend one new item that follows up this log_seqs
        items2score - we recommend from those items
        Returns:
        logits for each element from items2score
        '''
        with torch.no_grad:
            
            log_feats = self.log2feats(log_seqs) # shape (batch, seq_len, hidden_dim) = (batchx200x50)

            final_feat = log_feats[:, -1, :] # only use last embedding (batch, 50)

            item_embs = self.item_emb(items2score) # shape (batch, seq_len, hidden_dim) = torch.Size([batch, 101, 50]) 

            logits = torch.bmm(item_embs, final_feat.unsqueeze(-1)).squeeze() # (batch, 101, 50)x(batch, 50,1)=(batch, 101, 1).squeese() = (batch, 101)

            # preds = self.pos_sigmoid(logits) # rank same item list for different users

        return logits # preds # (batch, len(tems2score))
        
        
    def configure_optimizers(self):
        param_optimizer = list(self.named_parameters())
        no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']

        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 
             'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 
             'weight_decay': 0.0}]

        if self.hparams.opt == 'AdamW':
            optimizer = optim.AdamW(optimizer_grouped_parameters, lr=self.hparams.lr, betas=(0.9, 0.98))
        elif self.hparams.opt == 'Adam':
            optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr, betas=(0.9, 0.98))

        # self.lr_scheduler = PolyWarmUpScheduler(optimizer,
        #                                         warmup=self.hparams.warmup_proportion,
        #                                         total_steps=self.hparams.max_iters)
        return optimizer
    
    # def optimizer_step(self, *args, **kwargs):
    #     super().optimizer_step(*args, **kwargs)
    #     # self.lr_scheduler.step()  # Step per iteration
    
    def training_step(self, batch, batch_idx):
        """
        return a loss given a batch
        """
        u, seq, pos, neg = batch
        pos_logits, neg_logits = self.forward(seq, pos, neg)
        pos_labels, neg_labels = torch.ones(pos_logits.shape, device=self.device), torch.zeros(neg_logits.shape, device=self.device)
        indices = torch.where(pos!=0)

        loss = self.loss(pos_logits[indices], pos_labels[indices])
        loss += self.loss(neg_logits[indices], neg_labels[indices])
        
        for param in self.item_emb.parameters(): loss += self.hparams.l2_emb * torch.norm(param) # not working

        self.log('loss', loss.item(), prog_bar=True, logger=True)
        # self.log('lr', self.lr_scheduler.get_last_lr()[0],  prog_bar=True, logger=True)

        return {'loss': loss}
    
    def _shared_val_step(self, batch, batch_idx):
        # batch = 2, model dim = 50 as a sample data for dimentions
        
        final_seq, test_seq = batch
        
        with torch.no_grad():
            log_feats = model.log2feats(final_seq) # shape (batch, seq_len, hidden_dim) = (1x200x50)
            final_feat = log_feats[:, -1, :] # last hidden state/embedding, [2, 50]
            # calculate embeddings from test_sequence
            item_embs = model.item_emb(test_seq) # shape [1, 101, 50])

            # item_embs.shape, final_feat.unsqueeze(-1).shape -> [2, 101, 50], [2, 50, 1]

            # get dot product of last hidden state with all embeddings
            logits = torch.bmm(item_embs, final_feat.unsqueeze(-1)) # [2, 101, 1]

            predictions = -logits.squeeze() # [2, 101]
            # in element with index 0 we have a logit for the ground truth item
            GROUND_TRUTH_IDX = 0

            TOP_N = 10 # number of items that we look for a proper recommendation in
            _, indices = torch.topk(predictions,TOP_N,dim=1, largest=False)
            _, rank = torch.where(indices == GROUND_TRUTH_IDX) # now we have ranks of ground truth elements
            HITS = torch.as_tensor(rank <= TOP_N , dtype=torch.int) # 0 for miss and 1 for hit
            NDCG = HITS/torch.log2(rank+2)
        return HITS.sum().item()/len(final_seq), NDCG.sum().item()/len(final_seq)
        
    def validation_step(self, batch, batch_idx):
        """
        calculate Hit Rate and NDCG on validation dataset
        """
        hits, ndcg = self._shared_val_step(batch, batch_idx)
        self.log('NDCG@10/val', ndcg, prog_bar=True, logger=True)
        self.log('HR@10/val', hits, prog_bar=True, logger=True)
    
    def test_step(self, batch, batch_idx):
        """
        calculate Hit Rate and NDCG on test dataset
        """
        hits, ndcg = self._shared_val_step(batch, batch_idx)
        self.log('NDCG@10/test', ndcg, prog_bar=True, logger=True)
        self.log('HR@10/test', hits, prog_bar=True, logger=True)
        


In [188]:
model = SASrecPL(itemnum, warmup_proportion=0.2, max_iters=10000, opt='Adam', **args)

In [189]:
# model.load_state_dict(torch.load("bazman_sasrec.pt"))

<All keys matched successfully>

In [180]:
#torch.save(model.state_dict(), "bazman_sasrec.pt")

In [126]:
u, seq, pos, neg = next(iter(train_loader))

In [127]:
model.forward(seq, pos, neg)

(tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0

## Declare data loaders

In [152]:
print(f"\nBatch size is - {args['batch_size']}\n")


Batch size is - 128



In [206]:
val_loader = torch.utils.data.DataLoader(dataset=SequenceDataValidation(user_train, user_valid, usernum, itemnum, args['maxlen']), 
                                         batch_size=args['batch_size'], shuffle=True, 
                                         drop_last=True)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=6040.0), HTML(value='')))




In [207]:
test_loader = torch.utils.data.DataLoader(dataset=SequenceDataTest(user_train, user_valid, user_test, usernum, itemnum, args['maxlen']), 
                                          batch_size=args['batch_size'], shuffle=True, 
                                          drop_last=True)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=6040.0), HTML(value='')))




In [156]:
train_loader = torch.utils.data.DataLoader(dataset=SequenceData(user_train, usernum, itemnum), 
                                           batch_size=args['batch_size'],
                                           shuffle=True, 
                                           collate_fn=tokenize_batch)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=6040.0), HTML(value='')))




In [147]:
# seq, val = next(iter(val_loader))
# seq[0], val[0]

(tensor([ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  2,  3,
          4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
         22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
         40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57,
         58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75,
         76, 77], dtype=torch.int32),
 tensor([  78,  282, 2760, 1584,  642, 3178, 1761, 2186,  593, 2509, 16

In [209]:
# trainer = pl.Trainer(accelerator='dp', 
#                      gpus=-1, 
#                      max_epochs=MAX_EPOCHS,
#                      log_every_n_steps=1, 
#                      val_check_interval=0.2,
#                      num_sanity_val_steps=1, 
#                      callbacks=[checkpoint_callback], 
#                      accumulate_grad_batches=ACCUM_GRAD_BATCHES, 
#                      precision=16)

# AVAIL_GPUS = min(1, torch.cuda.device_count())
# torch.cuda.empty_cache()

# https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html
trainer = pl.Trainer(gpus=[0], 
                     auto_select_gpus=False, 
                     max_epochs=300,
                     reload_dataloaders_every_n_epochs=1,
                     val_check_interval=1.0,
                     log_every_n_steps= int(len(train_data)/args['batch_size']/3), # log 4 times per epoch
                     # limit_val_batches=0, How much of validation dataset to check. Useful when debugging or testing something that happens at the end of an epoch.
                     num_sanity_val_steps=10, 
                     precision=16)

Using native 16bit precision.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [210]:
trainer.fit(model, train_loader, val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name                 | Type              | Params
-----------------------------------------------------------
0 | item_emb             | Embedding         | 170 K 
1 | pos_emb              | Embedding         | 10.0 K
2 | emb_dropout          | Dropout           | 0     
3 | attention_layernorms | ModuleList        | 200   
4 | attention_layers     | ModuleList        | 20.4 K
5 | forward_layernorms   | ModuleList        | 200   
6 | forward_layers       | ModuleList        | 10.2 K
7 | last_layernorm       | LayerNorm         | 100   
8 | loss                 | BCEWithLogitsLoss | 0     
-----------------------------------------------------------
211 K     Trainable params
0         Non-trainable params
211 K     Total params
0.848     Total estimated model params size (MB)


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




In [203]:
torch.save(model.state_dict(), "bazman_sasrec.pt")

In [211]:
trainer.test(model, test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'HR@10/test': 0.7819148898124695, 'NDCG@10/test': 0.5337005257606506}
--------------------------------------------------------------------------------



[{'NDCG@10/test': 0.5337005257606506, 'HR@10/test': 0.7819148898124695}]

In [201]:
trainer.test(model, test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'HR@10/test': 0.7928807735443115, 'NDCG@10/test': 0.5423007011413574}
--------------------------------------------------------------------------------



[{'NDCG@10/test': 0.5423007011413574, 'HR@10/test': 0.7928807735443115}]