In [None]:
! pip install pytorch-lightning
! pip install tensorboard

In [None]:
import shutil

import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar, LearningRateMonitor 
from pytorch_lightning.loggers import TensorBoardLogger
from torchmetrics.retrieval import RetrievalHitRate, RetrievalNormalizedDCG, RetrievalAUROC

# Prepare data

In [None]:
class BPRData(Dataset):
    """Interaction dataset for BPR/APR with on-the-fly negative sampling.

    Each row of ``df`` is ``user, item[, item, ...]``: the first column is the
    user id, the remaining columns are candidate item ids and the first of them
    is the positive. When a row carries exactly one item and ``num_neg > 0``,
    ``num_neg`` negatives are sampled for that user and appended.

    Args:
        df: DataFrame whose first column is the user id and remaining columns
            are item ids. Columns named ``'user'`` and ``'item'`` are required
            when ``num_neg > 0`` (they are used to collect known positives).
        num_neg: negatives to sample per single-item row; ``None`` means 0.
    """

    def __init__(self, df, num_neg=1):
        super().__init__()
        self.df = df
        self.num_neg = num_neg if num_neg is not None else 0

        self.features = self.df.values

        # ids are used directly as indices, so the id-space size is max id + 1
        self.user_num = self.features[:, 0].max() + 1
        self.item_num = self.features[:, 1:].max() + 1

        # user -> items observed for that user in ``df`` (empty list if none)
        self.pos_item = {}
        if self.num_neg > 0:
            grouped = self.df.groupby('user')['item'].apply(lambda s: s.tolist()).to_dict()
            self.pos_item = {user: grouped.get(user, []) for user in range(self.user_num)}

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        """Return ``(user, items)``; ``items[0]`` is the positive, the rest negatives."""
        user = self.features[idx][0]
        item = list(self.features[idx][1:])  # positive first

        # clamp ids into the valid index range [0, size - 1]
        user = max(0, min(user, self.user_num - 1))
        item = [max(0, min(i, self.item_num - 1)) for i in item]

        if len(item) == 1 and self.num_neg > 0:
            # Known positives bracketed by the sentinels -1 and item_num. Every
            # open gap (low, high) between consecutive bounds holds only items
            # not observed for this user. A fresh tensor is built so the cached
            # positive list is never mutated.
            bounds = torch.tensor(sorted(set(self.pos_item[user]) | {-1, int(self.item_num)}))
            gap_idx = torch.randint(0, len(bounds) - 1, (self.num_neg,))
            for j in gap_idx:
                low = int(bounds[j]) + 1   # exclusive of the positive below
                high = int(bounds[j + 1])  # exclusive of the positive above
                if low < high:
                    item.append(int(torch.randint(low, high, (1,))[0]))
                else:
                    item.append(0)  # <UNK>: empty gap between adjacent positives
        return user, torch.tensor(item)
    
def collate_data(batch):
    """Collate ``(user, items)`` samples into batch tensors.

    Users are stacked into a 1-D tensor. Item tensors of different lengths are
    right-padded with 0 into a ``(batch, max_len)`` tensor, and any NaN is
    replaced by 0.
    """
    users = torch.tensor([sample[0] for sample in batch])
    padded = torch.nn.utils.rnn.pad_sequence([sample[1] for sample in batch], batch_first=True)
    return users, torch.nan_to_num(padded)

class BPRDataModule(LightningDataModule):
    """Data module for leave-one-out recommendation splits sharing one path prefix.

    Files read (tab separated):
      * ``<data_path>.train.rating``   - training interactions, columns 0/1 = user/item.
      * ``<data_path>.test.rating``    - held-out interactions used for validation.
      * ``<data_path>.test.negative``  - per line, a ``(user,item)`` pair followed
        by pre-sampled negative item ids (format assumed from the original
        parser, which read the pair as a string - TODO confirm per dataset).

    Args:
        data_path: common file prefix.
        tr_neg / val_neg: negatives sampled per positive for train / validation.
        tr_bs / val_bs: batch sizes for training / evaluation loaders.
        num_workers: DataLoader worker processes.
    """

    def __init__(self, data_path, tr_neg=1, val_neg=99, tr_bs=128, val_bs=32, num_workers=4):
        super().__init__()
        self.data_dir = data_path
        self.tr_neg = tr_neg
        self.val_neg = val_neg
        self.tr_bs = tr_bs
        self.val_bs = val_bs
        self.num_workers = num_workers
        self.get_data()

    def get_data(self):
        """Load all splits and record the largest user / item id seen in any of them."""
        self.df_train = pd.read_csv(self.data_dir + '.train.rating', usecols=[0, 1], sep='\t', names=['user', 'item'])
        self.df_val = pd.read_csv(self.data_dir + '.test.rating', usecols=[0, 1], sep='\t', names=['user', 'item'])
        self.df_test = self._read_test_negative(self.data_dir + '.test.negative')

        # Max id over every split (not only train) so embeddings sized from
        # these values can index every evaluated user/item.
        frames = (self.df_train, self.df_val, self.df_test)
        self.number_user = int(max(df['user'].max() for df in frames))
        self.number_item = int(max(df.drop(columns='user').to_numpy().max() for df in frames))
        self.get_df_info()

    @staticmethod
    def _read_test_negative(path):
        """Parse ``(user,item)\\tneg_1\\t...`` lines into ``user, item, item_neg_1..K`` columns."""
        raw = pd.read_csv(path, sep='\t', header=None)
        raw = raw.dropna(axis=1, how='all')  # a trailing tab would add an all-NaN column
        pair = (raw[0].astype(str).str.strip().str.strip('()')
                .str.split(',', expand=True).astype(int))
        pair.columns = ['user', 'item']
        negatives = raw.iloc[:, 1:].astype(int)
        negatives.columns = [f'item_neg_{i + 1}' for i in range(negatives.shape[1])]
        return pd.concat([pair, negatives], axis=1)

    def get_df_info(self):
        """Print previews of the splits and basic interaction statistics."""
        print(self.df_train.describe())
        print(self.df_train.head())
        print(self.df_val.head())
        print(self.df_test.head())

        num_users = self.number_user + 1  # ids are 0-based
        num_items = self.number_item + 1
        number_interaction = len(self.df_train)
        sparsity = 100 - 100.0 * number_interaction / (num_users * num_items)
        print(f'Number of users {num_users}')
        print(f'Number of items {num_items}')
        print(f'Number of interactions {number_interaction}')
        print(f'Sparsity {sparsity:6f}')

    def setup(self, stage: str):
        if stage == "fit":
            self.train_ds = BPRData(self.df_train, num_neg=self.tr_neg)
        if stage in ("fit", "validate"):
            self.val_ds = BPRData(self.df_val, num_neg=self.val_neg)
        if stage in ("test", "predict"):
            # Rows already hold 1 positive + K fixed negatives, so nothing is
            # sampled; num_neg is set to K so consumers can infer the row width.
            self.test_ds = BPRData(self.df_test, num_neg=self.df_test.shape[1] - 2)

    def _loader(self, dataset, batch_size, shuffle):
        """Build a DataLoader that pads variable-length item lists via ``collate_data``."""
        return DataLoader(dataset, batch_size=batch_size, num_workers=self.num_workers,
                          shuffle=shuffle, collate_fn=collate_data)

    def train_dataloader(self):
        return self._loader(self.train_ds, self.tr_bs, shuffle=True)

    def val_dataloader(self):
        return self._loader(self.val_ds, self.val_bs, shuffle=False)

    def predict_dataloader(self):
        return self._loader(self.test_ds, self.val_bs, shuffle=False)

In [None]:
# Candidate dataset prefixes (Kaggle inputs); the data module below uses `yelp`.
movielens_1m = '/kaggle/input/movielens-1m/ml-1m'
yelp = '/kaggle/input/yelp-dataset/yelp'
pinterest = '/kaggle/input/pinterest-20/pinterest-20'

data_config = dict(
    tr_neg=1,       # negatives sampled per training positive
    val_neg=99,     # negatives sampled per validation positive
    tr_bs=256,
    val_bs=128,
    num_workers=4,
)
dm = BPRDataModule(data_path=yelp, **data_config)

# BPR Model

In [None]:
def bpr_loss(prediction):
    """Bayesian Personalized Ranking loss.

    prediction: 2-D scores whose column 0 is the positive item and columns 1..
        are negatives. Returns the mean over all (row, negative) pairs of
        -log(sigmoid(positive - negative)).
    """
    positive = prediction[:, :1]   # kept 2-D so it broadcasts over the negatives
    negatives = prediction[:, 1:]
    margin = positive - negatives
    return -(torch.nn.functional.logsigmoid(margin).mean())

def batch_NDCG(prediction, top_k=10):
    """Mean NDCG@top_k for a batch with exactly one relevant item per row.

    prediction: (num_user, num_candidates) scores; column 0 is each row's
        positive item, the other columns are negatives.
    top_k: cut-off. A row whose positive has ``rank >= top_k`` contributes 0.

    ``rank`` counts candidates scored strictly above the positive (ties count
    in the positive's favour). The per-row gain is log(2) / log(rank + 2),
    i.e. 1 / log2(rank + 2).
    """
    positive = prediction[:, :1]
    beaten_by = ((prediction - positive) > 0).to(torch.int).sum(dim=-1)
    gain = torch.log(torch.tensor(2.0)) / torch.log(beaten_by + 2.0)
    gain = gain.masked_fill(beaten_by >= top_k, 0.0)
    return gain.mean()

In [None]:
class BPR(LightningModule):
    """Matrix-factorisation BPR model with optional APR adversarial training.

    A (user, item) score is the dot product of their embeddings. When
    ``reg_adv > 0``, each training step also perturbs the batch embeddings
    along the normalised embedding-table gradient (scaled by ``eps``). It then
    adds ``reg_adv`` times the regularised BPR loss of the perturbed
    embeddings.

    Args:
        user_num: rows in the user embedding table.
        item_num: rows in the item embedding table.
        embed_size: embedding dimension (number of predictive factors).
        optimizer: optional kwargs for ``torch.optim.AdamW``.
        scheduler: optional kwargs for ``CosineAnnealingWarmRestarts``. When
            given, the scheduler is stepped once per training batch.
        top_k: cut-off for the validation NDCG.
        eps: adversarial perturbation magnitude.
        reg: weight of the mean embedding-norm penalty.
        reg_adv: weight of the adversarial loss; 0 disables APR.
    """

    def __init__(self, user_num, item_num, embed_size, optimizer=None, scheduler=None,
                 top_k=10, eps=0.1, reg=0.0, reg_adv=0.0):
        super().__init__()
        self.user_num = user_num
        self.item_num = item_num
        self.embed_size = embed_size
        self.top_k = top_k
        self.eps = eps
        self.reg = reg
        self.reg_adv = reg_adv

        self.embed_user = nn.Embedding(self.user_num, self.embed_size)
        self.embed_item = nn.Embedding(self.item_num, self.embed_size)

        # The APR step needs two backward passes, so optimisation is manual.
        self.automatic_optimization = False

        self.optimizer_config = optimizer
        self.schedulerr_config = scheduler  # (sic) attribute name kept for compatibility
        self.save_hyperparameters()

    @staticmethod
    def _score(user_embed, item_embed):
        """(N, d) users x (N, C, d) items -> (N, C) dot-product scores."""
        return torch.bmm(user_embed.unsqueeze(1), item_embed.transpose(1, 2)).squeeze(1)

    @staticmethod
    def _embed_norm(user_embed, item_embed):
        """Mean L2 norm of the user embeddings plus mean L2 norm of the item embeddings."""
        return (torch.linalg.norm(user_embed, dim=-1).mean()
                + torch.linalg.norm(item_embed, dim=-1).mean())

    def forward(self, user, item):
        """Score ``item`` (N x C ids) for ``user`` (N ids); returns scores and both embeddings."""
        user_embed = self.embed_user(user)   # N x d
        item_embed = self.embed_item(item)   # N x C x d
        return self._score(user_embed, item_embed), user_embed, item_embed

    @torch.no_grad()
    def _get_adversarial(self, user, item, user_embed, item_embed):
        """Adversarial perturbations for the batch embeddings.

        Must run after a backward pass: it reads the gradients accumulated on
        the embedding tables and gathers the rows used by this batch. It then
        L2-normalises them over the last dimension, multiplies by ``eps``, and
        finally scales by the max absolute value of the batch embeddings taken
        over dim 0.
        """
        grad_user = torch.index_select(self.embed_user.weight.grad, 0, user)
        grad_item = torch.index_select(self.embed_item.weight.grad, 0, item.flatten())
        grad_item = grad_item.reshape(item.size(0), item.size(1), -1)

        # new_grad = (grad / |grad|) * eps
        delta_user = self.eps * nn.functional.normalize(grad_user, p=2, dim=-1)
        delta_item = self.eps * nn.functional.normalize(grad_item, p=2, dim=-1)

        delta_user = torch.max(user_embed.abs(), dim=0, keepdim=True)[0] * delta_user
        delta_item = torch.max(item_embed.abs(), dim=0, keepdim=True)[0] * delta_item
        return delta_user, delta_item

    def training_step(self, batch):
        user, item = batch
        opt = self.optimizers()
        opt.zero_grad()

        prediction, user_embed, item_embed = self(user, item)
        bpr = bpr_loss(prediction)
        reg = self._embed_norm(user_embed, item_embed)
        self.log("train_loss", bpr, prog_bar=True)
        self.log("train_reg_loss", reg, prog_bar=True)
        loss = bpr + self.reg * reg

        if self.reg_adv > 0:
            # Keep the graph: the adversarial loss re-uses user_embed / item_embed.
            self.manual_backward(loss, retain_graph=True)
            delta_user, delta_item = self._get_adversarial(user, item, user_embed, item_embed)

            user_adv = user_embed + delta_user
            item_adv = item_embed + delta_item
            loss_adv = bpr_loss(self._score(user_adv, item_adv))
            reg_adv = self._embed_norm(user_adv, item_adv)
            self.log("train_adv_loss", loss_adv, prog_bar=True)
            self.log("train_reg_adv_loss", reg_adv, prog_bar=True)

            # Gradients add onto those left by the first backward pass.
            self.manual_backward(self.reg_adv * (loss_adv + self.reg * reg_adv))
        else:
            self.manual_backward(loss)

        opt.step()
        # Under manual optimization Lightning does not step LR schedulers itself.
        scheduler = self.lr_schedulers()
        if scheduler is not None:
            scheduler.step()
        return loss

    def validation_step(self, batch, batch_idx):
        user, item = batch
        prediction, user_embed, item_embed = self(user, item)

        bpr = bpr_loss(prediction)
        reg = self._embed_norm(user_embed, item_embed)
        self.log("bpr_loss", bpr, prog_bar=True)
        self.log("reg_loss", reg, prog_bar=True)
        self.log("ndcg", batch_NDCG(prediction, self.top_k), prog_bar=True)
        return bpr + self.reg * reg

    def predict_step(self, batch):
        user, item = batch
        prediction, _, _ = self(user, item)
        return prediction

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), **(self.optimizer_config or {}))
        if self.schedulerr_config is None:
            return optimizer
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, **self.schedulerr_config)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step',
            },
        }

# Training

In [None]:
# Shared model shape; ids are 0-based, so table sizes are max id + 1.
base_param = dict(
    user_num=dm.number_user + 1,
    item_num=dm.number_item + 1,
    embed_size=64,
    top_k=10,
)

# Plain BPR warm-up: no adversarial term, cosine warm restarts on the LR.
BPR_param = dict(
    reg=0.01,
    eps=0,
    reg_adv=0,
    optimizer=dict(lr=0.01, betas=[0.9, 0.999], weight_decay=1e-2, eps=1e-9),
    scheduler=dict(T_0=2000, T_mult=2, eta_min=0.001),
)

# APR fine-tuning: adversarial perturbation of size eps weighted by reg_adv.
APR_param = dict(
    reg=0.01,
    eps=0.5,
    reg_adv=1.0,
    optimizer=dict(lr=0.003, betas=[0.9, 0.999], weight_decay=1e-2, eps=1e-9),
)

## BPR warm-up (plain BPR pre-training before APR)

In [None]:
accelerator = 'cuda' if torch.cuda.is_available() else 'cpu'

# Keep the 5 best epochs ranked by validation NDCG.
checkpoint_callback = ModelCheckpoint(
    filename='{epoch}-{ndcg:.4f}',
    save_top_k=5,
    monitor='ndcg',
    mode='max',
)
lr_monitor = LearningRateMonitor(logging_interval='step')
callbacks = [checkpoint_callback, lr_monitor]

param = {
    'max_epochs': 50,
    'devices': "auto",
    'accelerator': accelerator,
    'log_every_n_steps': 20,
    'callbacks': callbacks,
    'reload_dataloaders_every_n_epochs': 1,
    'logger': TensorBoardLogger(save_dir="exp", name="BPR"),
}

In [None]:
# Build a fresh dict: updating base_param in place would leak BPR settings
# into the APR cell, which starts from base_param again.
_param = {**base_param, **BPR_param}

model = BPR(**_param)
trainer = Trainer(**param)

trainer.fit(model, dm)

In [None]:
# Use the checkpoint the callback ranked best by 'ndcg'. A glob + `tail -n 1`
# picks the lexically last file (e.g. epoch=9 after epoch=10), not the best.
best_bpr_ckpt = trainer.checkpoint_callback.best_model_path
print(best_bpr_ckpt, trainer.checkpoint_callback.best_model_score)
shutil.copy(best_bpr_ckpt, 'best_BPR.ckpt')

## APR training

In [None]:
# Keep the 5 best epochs ranked by validation NDCG.
checkpoint_callback = ModelCheckpoint(
    filename='{epoch}-{ndcg:.4f}',
    save_top_k=5,
    monitor='ndcg',
    mode='max',
)
lr_monitor = LearningRateMonitor(logging_interval='step')
callbacks = [checkpoint_callback, lr_monitor]

param = {
    'max_epochs': 25,
    'devices': "auto",
    'accelerator': accelerator,
    'log_every_n_steps': 20,
    'callbacks': callbacks,
    'reload_dataloaders_every_n_epochs': 1,
    'logger': TensorBoardLogger(save_dir="exp", name="APR"),
}


In [None]:
# Fresh dict so this cell does not depend on (or mutate) state left by the BPR cell.
_param = {**base_param, **APR_param}

# Warm start from the best BPR checkpoint. Keyword arguments override the
# stored hyper-parameters; keys absent from _param (e.g. 'scheduler') keep
# their checkpointed values.
model = BPR.load_from_checkpoint('best_BPR.ckpt', **_param)
trainer = Trainer(**param)

trainer.fit(model, dm)

In [None]:
# Use the checkpoint the callback ranked best by 'ndcg' rather than the lexically last file.
best_apr_ckpt = trainer.checkpoint_callback.best_model_path
print(best_apr_ckpt, trainer.checkpoint_callback.best_model_score)
shutil.copy(best_apr_ckpt, 'best_APR.ckpt')

# Testing

In [None]:
def test_model(model, dm, trainer=None, top_ks=None):
    """Print leave-one-out ranking metrics of ``model`` on ``dm``'s prediction split.

    Each prediction row is assumed to score the held-out positive in column 0
    followed by its negatives. Rows are assumed to follow the order of
    ``dm.test_ds.df``, since the predict loader does not shuffle.

    Args:
        model: trained ``BPR`` module.
        dm: data module exposing ``predict_dataloader`` and ``test_ds``.
        trainer: Trainer used for prediction. When ``None``, a fresh one
            without logger/checkpointing is created, so results do not depend
            on whichever trainer a previous cell left behind.
        top_ks: cut-offs to report; defaults to 1..9 then 10..100 step 10.

    Raises:
        ValueError: if the number of prediction rows differs from the number
            of test users.
    """
    if trainer is None:
        trainer = Trainer(logger=False, enable_checkpointing=False)
    if top_ks is None:
        top_ks = [*range(1, 10), *range(10, 101, 10)]

    scores = torch.cat(trainer.predict(model, dm), dim=0)  # rows x candidates
    num_rows, num_candidates = scores.shape
    users = torch.as_tensor(dm.test_ds.df['user'].to_numpy())
    if len(users) != num_rows:
        raise ValueError(f"got {num_rows} prediction rows for {len(users)} test users")

    # One retrieval query per user: candidate 0 is relevant, the rest are not.
    # Sizes come from the predictions themselves so index/target always agree.
    target = torch.zeros(num_rows, num_candidates, dtype=torch.bool)
    target[:, 0] = True

    preds = scores.flatten()
    target = target.flatten()
    index = users.repeat_interleave(num_candidates)

    for top_k in top_ks:
        hr = RetrievalHitRate(top_k=top_k)(preds, target, indexes=index)
        ndcg = RetrievalNormalizedDCG(top_k=top_k)(preds, target, indexes=index)
        auroc = RetrievalAUROC(top_k=top_k)(preds, target, indexes=index)
        print(f"top {top_k:d} : hr: {hr:.3f}, ndcg: {ndcg:.3f}, auroc: {auroc:.3f}")

In [None]:
# save_hyperparameters() stored every constructor argument in the checkpoints,
# so no overrides are needed. This avoids depending on whichever `_param` the
# last executed cell left behind.
BPR_model = BPR.load_from_checkpoint('best_BPR.ckpt')
APR_model = BPR.load_from_checkpoint('best_APR.ckpt')

# Load the best checkpoints

In [None]:
test_model(model=BPR_model, dm=dm)

In [None]:
test_model(model=APR_model, dm=dm)