In [1]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
import argparse
import numpy as np

import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import LayerNorm, Dropout, Conv1d, Embedding, BCEWithLogitsLoss

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging

import MultimodalSR_Model as SASRec
from MultimodalSR_Model import PointWiseFF, SASRecEncoderLayer, PositinalEncoder, SASRecEncoder
import MultimodalSR_DataHelper as DH

2022-06-28 16:17:17.435401: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/lib64:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
2022-06-28 16:17:17.435429: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


In [4]:
# setup command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='Luxury_Beauty', 
                    required=True, 
                    help="dataset to use : Beauty, ml-1m(default), Steam or Video")

parser.add_argument('--maxlen', default=50, type=int, 
                    help="truncate input sequence to last maxlen items, default 50")
parser.add_argument('--hidden_units', default=50, type=int, help="synonym for d_model") # synonym for d_model
parser.add_argument('--d_model', default=50, type=int, 
                    help="Transformer internal dimention") # same as hidden_units   
parser.add_argument('--num_blocks', default=2, type=int, help="Number of blocks in Transformer")
parser.add_argument('--num_heads', default=1, type=int, help="Number of heads in self-attention")
parser.add_argument('--dropout_rate', default=0.5, type=float, help="Dropout rate for Transformer")
parser.add_argument('--l2_pe_reg', default=0.1, type=float, help="Regularization for positional embedding")

parser.add_argument('--ndcg_samples', default=100, type=int, 
                    help="How many random items to pick up in hit-rate and ndcg calculation, default 100")
parser.add_argument('--top_k', default=10, type=int, 
                    help="How many items with high scores to pick for hit-rate and ndcg calculation, default 10")
parser.add_argument('--opt', default='Adam', type=str, help="Oplimizer to use: Adam(default), AdmaW, FusedAdam(requires apex library)")
parser.add_argument('--lr', default=0.001, type=float, 
                    help="learning rate, default 0.001")
parser.add_argument('--weight_decay', default=0.001, type=float, help="Weight decay for AdmaW")
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--warmup_proportion', default=0.2, type=float, help="Fraction of total optimization steps to increase learning rate from zero to max value")
# for different optimizers - regular Adam uses num_epochs and LAMB uses max_iters
parser.add_argument('--max_iters', default=10000, type=int, help="Optimization budget in update iterations")
parser.add_argument('--num_epochs', default=201, type=int, help="Number of epochs to train")
# swa parameters
parser.add_argument('--use_swa', default=False, type=bool, help="Use Stochastic Weights Ageraging algorythm")
parser.add_argument('--swa_epoch_start', default=0.8, type=float, help="Start SWA after that part of total epochs")
parser.add_argument('--swa_annealing_epochs', default=10, type=int, help="Number of epochs in the annealing phase of SWA")

# xavier init
parser.add_argument('--xavier_init', default=True, type=bool, help="Use xavier normal to init the model")

parser.add_argument('--inference_only', default=False, type=bool)
parser.add_argument('--checkpoint_path', default=None, type=str, help="Path to lightning checkpoint file")
parser.add_argument('--strategy', default='ddp_spawn', type=str, help="Lightning parallel training strategy dp, ddp, ddp_spawn(default), ddp2, etc ")
parser.add_argument('--precision', default=16, type=int, help="Lightning precision for model data during trining 16(default) or 32")
parser.add_argument('--accelerator', default="auto", type=str, help="Lightning accelerator auto(defaut), cpu, gpu, tpu")
parser.add_argument('--devices', default="auto", type=str, 
                    help="Lightning devices to use - see https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#devices")

args = parser.parse_args( ['--dataset=Sports_and_Outdoors', '--maxlen=50', '--dropout_rate=0.2'])
args = vars(args)

In [5]:
# add fusion
args['fusion'] = 'gate'

In [6]:
# read dataset
dataset = DH.data_partition(args['dataset'])
[user_train, user_valid, user_test, usernum, itemnum] = dataset

In [7]:
usernum, itemnum

(153940, 55697)

In [8]:
model = SASRecEncoder(itemnum, args['fusion'], **args)

In [9]:
BATCH_SIZE = args['batch_size']
num_batch = len(user_train) // BATCH_SIZE  # number of batches

user_train_lens = list(map(len,[v for k,v in user_train.items()]))
print(f'average sequence length: {sum(user_train_lens)/len(user_train):.1f}')

average sequence length: 5.7


In [10]:
#dataset for training
train_data = DH.SequenceData(user_train, usernum, itemnum, 'pre_image_Sports_and_Outdoors', 'pre_description_Sports_and_Outdoors')

100%|██████████| 153940/153940 [17:26<00:00, 147.10it/s]


In [11]:
# dataset for validation
valid_data = DH.SequenceDataValidation(user_train, user_valid, usernum, itemnum, args['maxlen'], 
                                      'pre_image_Sports_and_Outdoors', 'pre_description_Sports_and_Outdoors')

100%|██████████| 10000/10000 [01:20<00:00, 123.85it/s]


In [13]:
train_loader = torch.utils.data.DataLoader(dataset=train_data,
                                           batch_size=args['batch_size'],
                                           shuffle=True,
                                           collate_fn=DH.tokenize_batch)

In [14]:
val_loader = torch.utils.data.DataLoader(dataset=valid_data, 
                                         batch_size=args['batch_size'], shuffle=True, 
                                         drop_last=True)

In [16]:
# save checkpoints
from pytorch_lightning.callbacks import ModelCheckpoint
checkpoint_callback = ModelCheckpoint(monitor="hr_val", mode='max')

In [17]:
trainer = pl.Trainer(gpus=[0], 
                     auto_select_gpus=False, 
                     max_epochs=300,
                     reload_dataloaders_every_n_epochs=1,
                     val_check_interval=1.0,
                     callbacks=[checkpoint_callback],
                     log_every_n_steps= int(len(train_data)/args['batch_size']/3), # log 4 times per epoch
                     num_sanity_val_steps=10, 
                     precision=16)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


In [18]:
trainer.fit(model, train_loader, val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type               | Params
----------------------------------------------------
0 | ie           | Embedding          | 2.8 M 
1 | pe           | PositinalEncoder   | 2.5 K 
2 | emb_dropout  | Dropout            | 0     
3 | enc_stack    | SASRecEncoderLayer | 15.5 K
4 | final_norm   | LayerNorm          | 100   
5 | loss         | BCEWithLogitsLoss  | 0     
6 | image_layer  | Linear             | 204 K 
7 | text_layer   | Linear             | 38.5 K
8 | fusion_layer | VanillaAttention   | 2.6 K 
----------------------------------------------------
3.0 M     Trainable params
0         Non-trainable params
3.0 M     Total params
6.098     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

  f"The progress bar already tracks a metric with the name(s) '{', '.join(duplicates)}' and"


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [19]:
torch.save(model.state_dict(), f"weight/Sports_and_Outdoors_{trainer.logger.version}.pt")

In [20]:
print('로거 버전 :', trainer.logger.version)

로거 버전 : 108


In [21]:
test_data = DH.SequenceDataTest(user_train, user_valid, user_test, usernum, itemnum, args['maxlen'], 
                    'pre_image_Sports_and_Outdoors', 'pre_description_Sports_and_Outdoors',args['ndcg_samples']
                   )

100%|██████████| 10000/10000 [00:50<00:00, 198.94it/s]


In [22]:
test_loader = torch.utils.data.DataLoader(dataset=test_data, 
                                          batch_size=args['batch_size'], shuffle=True, 
                                          drop_last=True)

In [24]:
trainer.test(model, test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
         hr_test            0.2585136294364929
        ndcg_test           0.14981384575366974
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'ndcg_test': 0.14981384575366974, 'hr_test': 0.2585136294364929}]