In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# Dataset

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader

class BertTrainDataset(Dataset):
    def __init__(self, u2seq, max_len, mask_prob, mask_token, num_items, rng):
        self.u2seq = u2seq
        self.users = sorted(self.u2seq.keys())
        self.max_len = max_len
        self.mask_prob = mask_prob
        self.mask_token = mask_token
        self.num_items = num_items
        self.rng = rng

    def __len__(self):
        return len(self.users)

    def __getitem__(self, index):
        user = self.users[index]
        seq = self._getseq(user)

        tokens = []
        labels = []
        for s in seq:
            prob = self.rng.random()
            if prob < self.mask_prob:
                prob /= self.mask_prob

                if prob < 0.8:
                    tokens.append(self.mask_token)
                elif prob < 0.9:
                    tokens.append(self.rng.randint(1, self.num_items))
                else:
                    tokens.append(s)

                labels.append(s)
            else:
                tokens.append(s)
                labels.append(0)

        tokens = tokens[-self.max_len:]
        labels = labels[-self.max_len:]

        mask_len = self.max_len - len(tokens)

        tokens = [0] * mask_len + tokens
        labels = [0] * mask_len + labels

        return torch.LongTensor(tokens), torch.LongTensor(labels)

    def _getseq(self, user):
        return self.u2seq[user]



class BertEvalDataset(Dataset):
    def __init__(self, u2seq, u2answer, max_len, mask_token, negative_samples):
        self.u2seq = u2seq
        self.users = sorted(self.u2seq.keys())
        self.u2answer = u2answer
        self.max_len = max_len
        self.mask_token = mask_token
        self.negative_samples = negative_samples

    def __len__(self):
        return len(self.users)

    def __getitem__(self, index):
        user = self.users[index]
        seq = self.u2seq[user]
        answer = self.u2answer[user]
        negs = self.negative_samples[user]

        candidates = answer + negs
        labels = [1] * len(answer) + [0] * len(negs)

        seq = seq + [self.mask_token]
        seq = seq[-self.max_len:]
        padding_len = self.max_len - len(seq)
        seq = [0] * padding_len + seq

        return torch.LongTensor(seq), torch.LongTensor(candidates), torch.LongTensor(labels)

class PredictDataset(Dataset):
    def __init__(self, u2seq, max_len, mask_token):
        self.u2seq = u2seq
        self.users = sorted(self.u2seq.keys())
        self.max_len = max_len
        self.mask_token = mask_token
        
    def __len__(self):
        return len(self.users)

    def __getitem__(self, index):
        user = self.users[index]
        seq = self.u2seq[user]
        history_items = list(set(seq))

        seq = seq + [self.mask_token]
        seq = seq[-self.max_len:]
        padding_len = self.max_len - len(seq)
        seq = [0] * padding_len + seq

        return {
            'seq': torch.LongTensor(seq),
            'history': torch.LongTensor(history_items)  
        }

# Data loader

In [3]:
import random
from tqdm import trange
from pathlib import Path
import pickle
import numpy as np

class RandomNegativeSampler():
    def __init__(self, train, val, test, user_count, item_count, sample_size, seed, save_dir):
        self.train = train
        self.val = val
        self.test = test
        self.user_count = user_count
        self.item_count = item_count
        self.sample_size = sample_size
        self.seed = seed
        self.save_dir = save_dir
        self.save_path = os.path.join(self.save_dir, 'negative_sample-sample_size{}-seed{}.pkl'.format(self.sample_size, self.seed))
        
    def generate_negative_samples(self):
        assert self.seed is not None, 'Specify seed for random sampling'
        np.random.seed(self.seed)
        negative_samples = {}
        print('Sampling negative items')
        for user in trange(self.user_count):
            if isinstance(self.train[user][1], tuple):
                seen = set(x[0] for x in self.train[user])
                seen.update(x[0] for x in self.val[user])
                seen.update(x[0] for x in self.test[user])
            else:
                seen = set(self.train[user])
                seen.update(self.val[user])
                seen.update(self.test[user])

            samples = []
            for _ in range(self.sample_size):
                item = np.random.choice(self.item_count) + 1
                while item in seen or item in samples:
                    item = np.random.choice(self.item_count) + 1
                samples.append(item)

            negative_samples[user] = samples

        return negative_samples

    def get_negative_samples(self):
        """Lấy negative samples từ file nếu đã có, nếu không thì tạo mới và lưu vào file."""
        # Kiểm tra xem file đã tồn tại chưa
        if os.path.exists(self.save_path):
            print("Negatives samples exist. Loading.")
            with open(self.save_path, "rb") as f:
                negative_samples = pickle.load(f)
        else:
            print("Negative samples don't exist. Generating.")
            negative_samples = self.generate_negative_samples()
            os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
            
            with open(self.save_path, "wb") as f:
                pickle.dump(negative_samples, f)
            print(f"Saved negative samples to {self.save_path}")

        return negative_samples

class BertDataloader():
    def __init__(self, args, dataset):
        self.args = args
        seed = args.dataloader_random_seed
        self.rng = random.Random(seed)
        self.train = dataset['train']
        self.val = dataset['val']
        self.test = dataset['test']
        self.user_count = dataset['user_count']
        self.item_count = dataset['item_count']
        self.max_len = args.bert_max_len
        self.mask_prob = args.bert_mask_prob
        self.CLOZE_MASK_TOKEN = self.item_count + 1

        train_negative_sampler = RandomNegativeSampler(self.train, self.val, self.test,
                                                       self.user_count, self.item_count,
                                                       args.train_negative_sample_size,
                                                       args.train_negative_sampling_seed,
                                                       args.data_dir
                                                      )
        test_negative_sampler = RandomNegativeSampler(self.train, self.val, self.test,
                                                      self.user_count, self.item_count,
                                                      args.test_negative_sample_size,
                                                      args.test_negative_sampling_seed,
                                                      args.data_dir
                                                      )

        self.train_negative_samples = train_negative_sampler.get_negative_samples()
        self.test_negative_samples = test_negative_sampler.get_negative_samples()

        self.predict = {}
        for key in dataset['train'].keys():
            self.predict[key] = (
                dataset['train'].get(key, []) + 
                dataset['val'].get(key, []) +
                dataset['test'].get(key, [])
            )

    def get_pytorch_dataloaders(self):
        train_loader = self.get_train_loader()
        val_loader = self.get_val_loader()
        test_loader = self.get_test_loader()
        if self.args.is_distributed:
            return train_loader, val_loader, test_loader

        return train_loader, val_loader, test_loader

    def get_train_loader(self):
        dataset = self._get_train_dataset()
        if self.args.is_distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        else:
            train_sampler = None
        dataloader = DataLoader(dataset, batch_size=self.args.train_batch_size, sampler=train_sampler,
                                           shuffle=(train_sampler is None), pin_memory=True)
        return dataloader
    def _get_train_dataset(self):
        dataset = BertTrainDataset(self.train, self.max_len, self.mask_prob, self.CLOZE_MASK_TOKEN, self.item_count, self.rng)
        return dataset

    def get_val_loader(self):
        return self._get_eval_loader(mode='val')

    def get_test_loader(self):
        return self._get_eval_loader(mode='test')

    def _get_eval_loader(self, mode):
        batch_size = self.args.val_batch_size if mode == 'val' else self.args.test_batch_size
        dataset = self._get_eval_dataset(mode)
        dataloader = DataLoader(dataset, batch_size=batch_size,
                                           shuffle=False, pin_memory=True)
        return dataloader

    def _get_eval_dataset(self, mode):
        answers = self.val if mode == 'val' else self.test
        dataset = BertEvalDataset(self.train, answers, self.max_len, self.CLOZE_MASK_TOKEN, self.test_negative_samples)
        return dataset

    def get_predict_loader(self):
        dataset = PredictDataset(self.predict,  self.max_len, self.CLOZE_MASK_TOKEN)
        dataloader = DataLoader(dataset, batch_size=self.args.test_batch_size,
                                           shuffle=False, pin_memory=True, collate_fn=lambda x: {
                                               'seq': torch.stack([item['seq'] for item in x]),
                                               'history': [item['history'] for item in x]
                                           })
        return dataloader

# Model

In [4]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class PositionalEmbedding(nn.Module):
    def __init__(self, max_len, d_model):
        super().__init__()

        # Compute the positional encodings once in log space.
        self.pe = nn.Embedding(max_len, d_model)

    def forward(self, x):
        batch_size = x.size(0)
        return self.pe.weight.unsqueeze(0).repeat(batch_size, 1, 1)


class TokenEmbedding(nn.Embedding):
    def __init__(self, vocab_size, embed_size=512):
        super().__init__(vocab_size, embed_size, padding_idx=0)

class BERTEmbedding(nn.Module):
    def __init__(self,
                 vocab_size,
                 embed_size,
                 max_len,
                 course_embeddings=None,
                 dropout=0.1):
        super().__init__()
        self.token = TokenEmbedding(vocab_size=vocab_size,
                                    embed_size=embed_size)
        self.position = PositionalEmbedding(max_len=max_len,
                                            d_model=embed_size)
        if course_embeddings is not None:
            self.metapath_embed = nn.Embedding.from_pretrained(
                torch.FloatTensor(course_embeddings), freeze=False)
            self.fusion_proj = nn.Linear(embed_size * 2, embed_size)
        else:
            self.metapath_embed = None

        self.dropout = nn.Dropout(p=dropout)
        self.embed_size = embed_size

    def forward(self, sequence):
        x = self.token(sequence) +  self.position(sequence)
        if self.metapath_embed is not None:
            meta_embed = self.metapath_embed(sequence)
            x = torch.cat([x, meta_embed], dim=-1)
            x = self.fusion_proj(x)
        return self.dropout(x)

class Attention(nn.Module):
    def forward(self, query, key, value, mask=None, dropout=None):
        scores = torch.matmul(query, key.transpose(-2, -1)) \
                 / math.sqrt(query.size(-1))

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        p_attn = F.softmax(scores, dim=-1)

        if dropout is not None:
            p_attn = dropout(p_attn)

        return torch.matmul(p_attn, value), p_attn


class MultiHeadedAttention(nn.Module):

    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0

        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h

        self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model)
                                            for _ in range(3)])
        self.output_linear = nn.Linear(d_model, d_model)
        self.attention = Attention()

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = [
            layer(x).view(batch_size, -1,
                          self.h, self.d_k).transpose(1, 2)
            for layer, x in zip(self.linear_layers, (query, key, value))
        ]

        # 2) Apply attention on all the projected vectors in batch.
        x, attn = self.attention(query, key, value,
                                 mask=mask, dropout=self.dropout)

        # 3) "Concat" using a view and apply a final linear.
        x = x.transpose(1, 2).contiguous().view(batch_size, -1,
                                                self.h * self.d_k)

        return self.output_linear(x)


class GELU(nn.Module):
    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) *
                                         (x + 0.044715 * torch.pow(x, 3))))


class PositionwiseFeedForward(nn.Module):

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = GELU()

    def forward(self, x):
        return self.w_2(self.dropout(self.activation(self.w_1(x))))


class LayerNorm(nn.Module):

    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2


class SublayerConnection(nn.Module):

    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Apply residual connection to any sublayer with the same size."
        return x + self.dropout(sublayer(self.norm(x)))

class SublayerConnection(nn.Module):

    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Apply residual connection to any sublayer with the same size."
        return x + self.dropout(sublayer(self.norm(x)))

class TransformerBlock(nn.Module):

    def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):

        super().__init__()
        self.attention = MultiHeadedAttention(h=attn_heads,
                                              d_model=hidden,
                                              dropout=dropout)

        self.feed_forward = PositionwiseFeedForward(d_model=hidden,
                                                    d_ff=feed_forward_hidden,
                                                    dropout=dropout)

        self.input_sublayer = SublayerConnection(size=hidden,
                                                 dropout=dropout)

        self.output_sublayer = SublayerConnection(size=hidden,
                                                  dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, mask):
        x = self.input_sublayer(x,
                                lambda _x: self.attention.forward(_x, _x, _x,
                                                                  mask=mask))
        x = self.output_sublayer(x, self.feed_forward)
        return self.dropout(x)

class BERT4Rec(nn.Module):
    def __init__(self, args, course_embeddings=None):
        super().__init__()

        max_len = args.bert_max_len
        num_items = args.num_items
        n_layers = args.bert_num_blocks
        heads = args.bert_num_heads
        self.vocab_size = num_items + 2
        hidden = args.bert_hidden_units
        self.hidden = hidden
        dropout = args.bert_dropout

        # embedding for BERT, sum of positional, token, course pretrained embeddings
        self.embedding = BERTEmbedding(vocab_size=self.vocab_size,
                                       embed_size=self.hidden,
                                       max_len=max_len,
                                       course_embeddings=course_embeddings,
                                       dropout=dropout
                                       )

        # multi-layers transformer blocks, deep network
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(hidden, heads, hidden * 4, dropout)
            for _ in range(n_layers)
        ])

        self.out = nn.Linear(self.hidden, args.num_items + 1)

    def forward(self, x):
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)

        # embedding the indexed sequence to sequence of vectors
        x = self.embedding(x)

        # running over multiple transformer blocks
        for transformer in self.transformer_blocks:
            x = transformer.forward(x, mask)

        x = self.out(x)
        
        return x

    def init_weights(self):
        pass

# Utils

In [5]:
import os
import logging
import csv
from collections import OrderedDict
import numpy as np
import pandas as pd
import ast
from collections import defaultdict
import ast

def create_log_id(dir_path):
    log_count = 0
    file_path = os.path.join(dir_path, 'log{:d}.log'.format(log_count))
    while os.path.exists(file_path):
        log_count += 1
        file_path = os.path.join(dir_path, 'log{:d}.log'.format(log_count))
    return log_count


def logging_config(folder=None, name=None,
                   level=logging.DEBUG,
                   console_level=logging.DEBUG,
                   no_console=True):

    if not os.path.exists(folder):
        os.makedirs(folder)
    for handler in logging.root.handlers:
        logging.root.removeHandler(handler)
    logging.root.handlers = []
    logpath = os.path.join(folder, name + ".log")
    print("All logs will be saved to %s" %logpath)

    logging.root.setLevel(level)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    logfile = logging.FileHandler(logpath)
    logfile.setLevel(level)
    logfile.setFormatter(formatter)
    logging.root.addHandler(logfile)

    if not no_console:
        logconsole = logging.StreamHandler()
        logconsole.setLevel(console_level)
        logconsole.setFormatter(formatter)
        logging.root.addHandler(logconsole)
    return folder

def early_stopping(recall_list, stopping_steps):
    best_recall = max(recall_list)
    best_step = recall_list.index(best_recall)
    if len(recall_list) - best_step - 1 >= stopping_steps:
        should_stop = True
    else:
        should_stop = False
    return best_recall, should_stop


def save_checkpoint(model_dir, model, optimizer, current_epoch, best_recall, best_epoch, metrics_list, epoch_list):
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    checkpoint_file = os.path.join(model_dir, 'checkpoint_epoch{}.pth'.format(current_epoch))
    torch.save({'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': current_epoch,
                'best_recall': best_recall,
                'best_epoch': best_epoch,
                'metrics_list': metrics_list,
                'epoch_list': epoch_list
               }, checkpoint_file)

def load_model(model, model_path):
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model

def recalls_and_ndcgs_for_ks(scores, labels, ks):
    metrics = {}

    scores = scores
    labels = labels
    answer_count = labels.sum(1)

    labels_float = labels.float()
    rank = (-scores).argsort(dim=1)
    cut = rank
    for k in sorted(ks, reverse=True):
       cut = cut[:, :k]
       hits = labels_float.gather(1, cut)
       metrics['Recall@%d' % k] = \
           (hits.sum(1) / torch.min(torch.Tensor([k]).to(labels.device), labels.sum(1).float())).mean().cpu().item()

       position = torch.arange(2, 2+k)
       weights = 1 / torch.log2(position.float())
       dcg = (hits * weights.to(hits.device)).sum(1)
       idcg = torch.Tensor([weights[:min(int(n), k)].sum() for n in answer_count]).to(dcg.device)
       ndcg = (dcg / idcg).mean()
       metrics['NDCG@%d' % k] = ndcg.cpu().item()

    return metrics

def create_dataset(args):
    num_users = 0
    num_courses = 0
    train = defaultdict(list)
    val = defaultdict(list)
    test = defaultdict(list)

    def load_train(path, storage):
        nonlocal num_users, num_courses
        df = pd.read_csv(path)
        for _, row in df.iterrows():
            user = int(row['user'])
            courses = ast.literal_eval(row['feature'])
            courses = [course + 1 for course in courses]
            storage[user].extend(courses)
            num_users = max(num_users, user)
            if courses:
                num_courses = max(num_courses, max(courses))

    def load_single_label_file(path, label_column, storage):
        nonlocal num_users, num_courses
        df = pd.read_csv(path)
        for _, row in df.iterrows():
            user = int(row['user'])
            course = int(row[label_column])
            storage[user].append(course + 1)
            num_users = max(num_users, user)
            num_courses = max(num_courses, course)

    data_dir = args.data_dir
    load_train(os.path.join(data_dir, 'train_df.csv'), train)
    load_single_label_file(os.path.join(data_dir, 'val_df.csv'), 'val_label', val)
    load_single_label_file(os.path.join(data_dir, 'test_df.csv'), 'test_label', test)
    args.num_items = num_courses + 1

    dataset = {'train': train,'val': val,'test': test,'user_count': num_users + 1,'item_count': num_courses + 1}
    return dataset

# Trainer

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm
import json
import os
import pprint as pp
import random
from datetime import date, timedelta
from pathlib import Path
import sys
import numpy as np

import json
class Trainer():
    def __init__(self, args, model, dataset):
        self.args = args
        self.device = args.device
        self.model = model.to(self.device)
        self.optimizer = self._create_optimizer()
        if args.enable_lr_schedule:
            self.lr_scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=args.decay_step, gamma=args.gamma)

        self.num_epochs = args.num_epochs
        self.metric_ks = args.metric_ks
        self.best_metric = args.best_metric
        self.ce = nn.CrossEntropyLoss(ignore_index=0)
        
        self.dataloader = BertDataloader(args, dataset)
        self.log_save_id = create_log_id(self.args.save_dir)
        
        
    def train(self):
        if self.args.is_distributed:
            dist.init_process_group("gloo", timeout=timedelta(seconds=3000))
            rank = dist.get_rank()
            is_main_process = (rank == 0)
            self.model = DDP(self.model)        
        else:
            is_main_process = True

        self.train_loader = self.dataloader.get_train_loader()
        

        # Logging chỉ ở process chính
        if is_main_process:
            logging_config(folder=self.args.save_dir, name=f'log{self.log_save_id}', no_console=False)
            logging.info(self.args)
            logging.info(self.model)

        self.model.train()
        best_recall = 0
        best_epoch = 0
        recall_list = []
        min_valid_batch_size = 16
        should_stop = False

        for epoch in range(1, self.num_epochs + 1):
            if self.args.is_distributed:
                self.train_loader.sampler.set_epoch(epoch)

            if self.args.enable_lr_schedule:
                self.lr_scheduler.step()

            tqdm_dataloader = tqdm(self.train_loader) if is_main_process else self.train_loader

            total_loss = 0.0
            num_batches = 0

            for batch_idx, batch in enumerate(tqdm_dataloader):
                try:
                    batch = [x.to(self.device) for x in batch]
                except AttributeError:
                    batch = {k: v.to(self.device) for k, v in batch.items()}

                self.optimizer.zero_grad()
                loss = self.calculate_loss(batch)

                if torch.isnan(loss).any():
                    if batch[0].size(0) >= min_valid_batch_size:
                        if is_main_process:
                            logging.error(f'ERROR: Epoch {epoch}, batch {batch_idx + 1} Loss is nan.')
                        if self.args.is_distributed:
                            dist.destroy_process_group()
                        sys.exit()
                    else:
                        continue

                loss.backward()
                self.optimizer.step()

                total_loss += loss.item()
                num_batches += 1


            # Tính average_loss: nếu phân tán thì cần reduce từ các GPU
            average_loss_tensor = torch.tensor([total_loss, num_batches], dtype=torch.float32, device=self.device)

            if self.args.is_distributed:
                dist.all_reduce(average_loss_tensor, op=dist.ReduceOp.SUM)

            total_loss = average_loss_tensor[0].item()
            num_batches = average_loss_tensor[1].item()
            average_loss = total_loss / max(num_batches, 1)

            if is_main_process:
                logging.info(f'Epoch {epoch:04d} | Average Loss: {average_loss:.4f}')

            dist.barrier()
                
            if (epoch % self.args.evaluate_every == 0 or epoch == self.args.num_epochs) and is_main_process:
                if self.args.is_distributed:
                    self.model.module.eval()
                else:
                    self.model.eval()

                all_metrics = {
                    "Recall@1": [],
                    "Recall@5": [],
                    "Recall@10": [],
                    "NDCG@1": [],
                    "NDCG@5": [],
                    "NDCG@10": []
                }

                with torch.no_grad():
                    self.val_loader = self.dataloader.get_val_loader()
                    tqdm_dataloader = tqdm(self.val_loader)
                    for batch_idx, batch in enumerate(tqdm_dataloader):
                        batch_size = batch[0].size(0)
                        batch = [x.to(self.device) for x in batch]

                        metrics = self.calculate_metrics(batch)
                        if batch_size >= min_valid_batch_size:
                            for key in all_metrics:
                                all_metrics[key].append(metrics[key])

                avg_metrics = {key: sum(values) / len(values) for key, values in all_metrics.items()}
                recall_list.append(avg_metrics[self.args.best_metric])

                logging.info('Val: Epoch {:04d} | Recall [{:.4f}, {:.4f}, {:.4f}], NDCG [{:.4f}, {:.4f} {:.4f}]'.format(
                            epoch, avg_metrics['Recall@1'],  avg_metrics['Recall@5'], avg_metrics['Recall@10'],  avg_metrics['NDCG@1'], avg_metrics['NDCG@5'], avg_metrics['NDCG@10']))

                best_recall, should_stop = early_stopping(recall_list, args.stopping_steps)
                

                if recall_list[-1] == best_recall:
                    if self.args.is_distributed:
                        model_state_dict = self.model.module.state_dict()
                    else:
                        model_state_dict = self.model.state_dict()

                    model_save_path = os.path.join(self.args.save_dir, 'model', 'best_model.pth')
                    os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
                    torch.save({'model_state_dict': model_state_dict, 'epoch': epoch},
                               model_save_path)
                    logging.info(f'Save model at epoch {epoch:04d}!')
                    best_epoch = epoch
            if should_stop:
                    break    
        dist.destroy_process_group()

    def test(self):
        logging_config(folder=self.args.save_dir, name=f'log{self.log_save_id}', no_console=False)
        print('Test best model with test set!')
        best_model = torch.load(os.path.join(self.args.save_dir, 'model', 'best_model.pth')).get('model_state_dict')
        self.model.load_state_dict(best_model)

        self.model.eval()
        all_metrics = {
            "Recall@1": [],
            "Recall@5": [],
            "Recall@10": [],
            "NDCG@1": [],
            "NDCG@5": [],
            "NDCG@10": []
        }
        
        min_test_batch_size = 16
        
        self.test_loader = self.dataloader.get_test_loader()
        with torch.no_grad():
            tqdm_dataloader = tqdm(self.test_loader)
            for batch_idx, batch in enumerate(tqdm_dataloader):
                batch_size = batch[0].size(0)
                batch = [x.to(self.device) for x in batch]
                metrics = self.calculate_metrics(batch)
                if batch_size >= min_test_batch_size:
                    for key in all_metrics:
                        all_metrics[key].append(metrics[key])

        avg_metrics = {key: sum(values) / len(values) for key, values in all_metrics.items()}
        logging.info('Test: Recall [{:.4f}, {:.4f}, {:.4f}], NDCG [{:.4f}, {:.4f} {:.4f}]'.format(avg_metrics['Recall@1'],  avg_metrics['Recall@5'], avg_metrics['Recall@10'],  avg_metrics['NDCG@1'], avg_metrics['NDCG@5'], avg_metrics['NDCG@10']))

    def predict(self, top_n=10):
        print('Predict with best model.')
        best_model = torch.load(os.path.join(self.args.save_dir, 'model', 'best_model.pth')).get('model_state_dict')
        self.model.load_state_dict(best_model)
        
        self.model.eval()
        
        self.predict_loader = self.dataloader.get_predict_loader()
        all_preds = []
        all_scores = []
        with torch.no_grad():
            for batch in self.predict_loader:
                seq = batch['seq'].to(self.device)          # (batch_size, seq_len)
                history = batch['history']
                logits = self.model(seq)     
                last_logits = logits[:, -1, :] 
                
                for i in range(last_logits.size(0)):
                    last_logits[i, history[i]] = float('-inf')
                    
                topk_scores, topk_items = torch.topk(last_logits, k=top_n, dim=-1)
                all_preds.extend(topk_items.cpu().tolist())
                all_scores.extend(topk_scores.cpu().tolist())

        return all_preds, all_scores
        
    def _create_optimizer(self):
        args = self.args
        if args.optimizer.lower() == 'adam':
            return optim.Adam(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        elif args.optimizer.lower() == 'sgd':
            return optim.SGD(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)
        else:
            raise ValueError

    def calculate_loss(self, batch):
        seqs, labels = batch
        logits = self.model(seqs)

        logits = logits.view(-1, logits.size(-1))
        labels = labels.view(-1)
        loss = self.ce(logits, labels)
        return loss

    def calculate_metrics(self, batch):
        seqs, candidates, labels = batch
        scores = self.model(seqs)
        scores = scores[:, -1, :]
        scores = scores.gather(1, candidates)

        metrics = recalls_and_ndcgs_for_ks(scores, labels, self.metric_ks)
        return metrics

In [7]:
from pyspark.ml.torch.distributor import TorchDistributor
from pyspark.sql import SparkSession
spark = SparkSession.builder \
    .appName("DistributedTorchTrain") \
    .getOrCreate()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/05/19 06:37:59 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [8]:
distributor = TorchDistributor(num_processes=2, local_mode=True, use_gpu=False)

#### bert_max_len=20, bert_hidden_units=128, bert_mask_prob=0.3

In [10]:
from types import SimpleNamespace
args = SimpleNamespace(
    train_batch_size=128,
    val_batch_size=128,
    test_batch_size=128,
    train_negative_sample_size=0,
    train_negative_sampling_seed=0,
    test_negative_sample_size=100,
    test_negative_sampling_seed=98765,
    is_distributed=True,
    device='cpu',
    optimizer='Adam',
    lr=0.001,
    enable_lr_schedule=True,
    weight_decay=0.00,
    decay_step=25,
    gamma=1.0,
    num_epochs=50,
    evaluate_every=5,
    metric_ks=[1, 5, 10],
    best_metric='Recall@10',
    stopping_steps=5,
    bert_dropout=0.1,
    bert_hidden_units=128,
    bert_mask_prob=0.3,
    bert_max_len=20,
    bert_num_blocks=2,
    bert_num_heads=4,
    data_dir="/kaggle/input/mooc-bert4rec",
    save_dir="/kaggle/working/",
    dataloader_random_seed=2025
)

In [11]:
dataset = create_dataset(args)
model = BERT4Rec(args)
trainer = Trainer(args, model, dataset)
distributor.run(trainer.train)

Negatives samples exist. Loading.
Negatives samples exist. Loading.
All logs will be saved to /kaggle/working/log0.log
2025-05-18 07:41:49,765 - root - INFO - namespace(train_batch_size=128, val_batch_size=128, test_batch_size=128, train_negative_sample_size=0, train_negative_sampling_seed=0, test_negative_sample_size=100, test_negative_sampling_seed=98765, is_distributed=True, device='cpu', optimizer='Adam', lr=0.001, enable_lr_schedule=True, weight_decay=0.0, decay_step=25, gamma=1.0, num_epochs=50, evaluate_every=5, metric_ks=[1, 5, 10], best_metric='Recall@10', stopping_steps=5, bert_dropout=0.1, bert_hidden_units=128, bert_mask_prob=0.3, bert_max_len=20, bert_num_blocks=2, bert_num_heads=4, data_dir='/kaggle/input/mooc-bert4rec', save_dir='/kaggle/working/', dataloader_random_seed=2025, num_items=2828)
2025-05-18 07:41:49,765 - root - INFO - DistributedDataParallel(
  (module): BERT4Rec(
    (embedding): BERTEmbedding(
      (token): TokenEmbedding(2830, 128, padding_idx=0)
      

In [12]:
trainer.test()

  best_model = torch.load(os.path.join(self.args.save_dir, 'model', 'best_model.pth')).get('model_state_dict')


All logs will be saved to /kaggle/working/log0.log
Test best model with test set!


100%|██████████| 782/782 [01:11<00:00, 10.97it/s]
2025-05-18 10:34:14,962 - root - INFO - Test: Recall [0.2767, 0.6065, 0.7701], NDCG [0.2767, 0.4483 0.5013]


#### bert_max_len=20, bert_hidden_units=256, bert_mask_prob=0.15

In [9]:
from types import SimpleNamespace
args = SimpleNamespace(
    train_batch_size=128,
    val_batch_size=128,
    test_batch_size=128,
    train_negative_sample_size=0,
    train_negative_sampling_seed=0,
    test_negative_sample_size=100,
    test_negative_sampling_seed=98765,
    is_distributed=True,
    device='cpu',
    optimizer='Adam',
    lr=0.001,
    enable_lr_schedule=True,
    weight_decay=0.00,
    decay_step=25,
    gamma=1.0,
    num_epochs=50,
    evaluate_every=5,
    metric_ks=[1, 5, 10],
    best_metric='Recall@10',
    stopping_steps=5,
    bert_dropout=0.1,
    bert_hidden_units=256,
    bert_mask_prob=0.15,
    bert_max_len=20,
    bert_num_blocks=2,
    bert_num_heads=4,
    data_dir="/kaggle/input/mooc-bert4rec",
    save_dir="/kaggle/working/",
    dataloader_random_seed=2025
)

In [10]:
dataset = create_dataset(args)
model = BERT4Rec(args)
trainer = Trainer(args, model, dataset)
distributor.run(trainer.train)

Negatives samples exist. Loading.
Negatives samples exist. Loading.
All logs will be saved to /kaggle/working/log0.log
2025-05-19 06:39:36,121 - root - INFO - namespace(train_batch_size=128, val_batch_size=128, test_batch_size=128, train_negative_sample_size=0, train_negative_sampling_seed=0, test_negative_sample_size=100, test_negative_sampling_seed=98765, is_distributed=True, device='cpu', optimizer='Adam', lr=0.001, enable_lr_schedule=True, weight_decay=0.0, decay_step=25, gamma=1.0, num_epochs=50, evaluate_every=5, metric_ks=[1, 5, 10], best_metric='Recall@10', stopping_steps=5, bert_dropout=0.1, bert_hidden_units=256, bert_mask_prob=0.15, bert_max_len=20, bert_num_blocks=2, bert_num_heads=4, data_dir='/kaggle/input/mooc-bert4rec', save_dir='/kaggle/working/', dataloader_random_seed=2025, num_items=2828)
2025-05-19 06:39:36,121 - root - INFO - DistributedDataParallel(
  (module): BERT4Rec(
    (embedding): BERTEmbedding(
      (token): TokenEmbedding(2830, 256, padding_idx=0)
     

In [11]:
trainer.test()

  best_model = torch.load(os.path.join(self.args.save_dir, 'model', 'best_model.pth')).get('model_state_dict')


All logs will be saved to /kaggle/working/log0.log
Test best model with test set!


100%|██████████| 782/782 [02:11<00:00,  5.97it/s]
2025-05-19 12:07:22,713 - root - INFO - Test: Recall [0.2704, 0.6005, 0.7660], NDCG [0.2704, 0.4418 0.4955]
