In [1]:
import pandas as pd
import sys
sys.path.append('../src/')
from varka.ml_varka import MovieLenseVarka
from utils.metric import RecallK, DiversityK, LongTailK
from batch_generator.datasets import MlDataset
import torch
from torch.utils.data import DataLoader
from model.sasrec import SASRec
from torch.utils.tensorboard import SummaryWriter
import numpy as np

import matplotlib.pyplot as plt

%load_ext autoreload
%autoreload 2

# 1. Читаем данные

In [2]:
# Both splits are loaded with read_pickle, i.e. the files are pickled pandas
# objects even though their names end in ".csv".
# Pickle can execute arbitrary code on load, so only point this at trusted files.
train, val = (
    pd.read_pickle(path)
    for path in ('../code/data/ml-1m/train.csv', '../code/data/ml-1m/val.csv')
)
# Show the first rows of the training split as the cell output.
train.head()

Unnamed: 0,UserID,MovieID,Rating,history,candidate,timestamp,Genre
0,3403,6,4,"[[413, Airheads, [Comedy], M, 35, 5, 48342, 19...","[6, Heat, [Action, Crime, Thriller], M, 35, 5,...",967429703,"[Action, Crime, Thriller]"
1,4630,1883,3,"[[593, Silence of the Lambs, The, [Drama, Thri...","[1883, Bulworth, [Comedy], F, 25, 4, 94610, 19...",964040034,[Comedy]
2,2882,27,3,"[[2683, Austin Powers: The Spy Who Shagged Me,...","[27, Now and Then, [Drama], M, 18, 20, 78759, ...",972243969,[Drama]
3,3513,593,5,"[[908, North by Northwest, [Drama, Thriller], ...","[593, Silence of the Lambs, The, [Drama, Thril...",966976389,"[Drama, Thriller]"
4,2010,2759,2,"[[2021, Dune, [Fantasy, Sci-Fi], M, 18, 4, 815...","[2759, Dick, [Comedy], M, 18, 4, 81520, 1999, ...",974680399,[Comedy]


In [3]:
# Wrap each split in the project's MlDataset (train first, then val) so that
# it can be handed to a DataLoader later on.
train_dataset, val_dataset = (MlDataset(split) for split in (train, val))

916429it [00:23, 39038.98it/s]
77740it [00:01, 45109.37it/s]


In [4]:
def collate_fn(batch):
    """
    Assemble a list of samples into a single batch dictionary.

    Every sample is indexed positionally: element 0 is the per-sample sequence
    tensor and element 1 is the candidate id.

    Returns a dict with two entries:
      'users'      - the sequences stacked batch-first by pad_sequence, with
                     shorter ones padded to the longest sequence in the batch
                     (pad_sequence's default padding value);
      'candidates' - a LongTensor built from the candidate ids.
    """
    # Single pass over the batch, keeping only the two fields that are used.
    pairs = [(sample[0], sample[1]) for sample in batch]
    sequences = [sequence for sequence, _ in pairs]
    targets = [target for _, target in pairs]

    padded = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True)
    return {'users': padded, 'candidates': torch.LongTensor(targets)}

# 2. Обучаем SASRec

In [5]:
# Run configuration shared by the cells below. The whole dict is also passed
# to SASRec, which is expected to read the model-related entries
# (hidden_units, dropout_rate, num_blocks, num_heads, maxlen) - TODO confirm
# against model/sasrec.py.
args = {
    # Prefer the GPU, but fall back to the CPU when CUDA is not available so
    # that the notebook still runs end-to-end on a machine without a GPU.
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'hidden_units': 50,
    'dropout_rate': 0.5,
    'num_blocks': 2,
    'num_heads': 1,
    'maxlen': 128,
    # Batch size used for both the training and the validation DataLoader.
    'batch_size': 2048,
    # Number of train + validation passes performed by the training loop.
    'num_epochs': 200
}

In [6]:
# Settings shared by both loaders: the configured batch size and the padding
# collate function.
loader_kwargs = {'batch_size': args['batch_size'], 'collate_fn': collate_fn}

# drop_last=True discards the final incomplete training batch; the validation
# loader keeps every sample.
# NOTE(review): neither loader sets shuffle, so the training data is visited in
# the same order every epoch (DataLoader default) - confirm this is intended.
train_dataloader = DataLoader(train_dataset, drop_last=True, **loader_kwargs)
val_dataloader = DataLoader(val_dataset, **loader_kwargs)

# Number of batches per epoch for each loader (shown as the cell output).
len(train_dataloader), len(val_dataloader)

(916429, 77740)

In [7]:
# 3952 is SASRec's first positional argument - presumably the number of items
# in the catalogue; TODO confirm against the SASRec constructor.
model = SASRec(3952, args)
model = model.to(args['device'])

# Adam over all model parameters with a fixed learning rate and custom betas.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=0.001,
    betas=(0.9, 0.98),
)

# Binary cross-entropy on raw logits; used by train_epoch and val_epoch.
loss_fn = torch.nn.BCEWithLogitsLoss()

# TensorBoard event files are written to ./logs.
writer = SummaryWriter(log_dir='./logs')

In [8]:
def train_epoch(model, optimizer, epoch, train_dataloader, writer):
    """
    Run one optimisation pass over ``train_dataloader``.

    For every batch the candidate ids serve as positives and a random
    permutation of the same ids serves as in-batch negatives. The model is
    called as ``model(historys, positives, negatives)`` and must return a pair
    ``(logits_pos, logits_neg)``; positives are labelled 1 and negatives 0, and
    the notebook-level ``loss_fn`` is applied to the concatenated logits.

    Args:
        model: module that is put into training mode and optimised.
        optimizer: optimiser stepped once per batch.
        epoch: zero-based epoch index, used only to compute the logging step.
        train_dataloader: sized iterable of dicts with the keys ``'users'``
            and ``'candidates'``.
        writer: object with an ``add_scalar(tag, value, global_step=...)``
            method; receives the ``'Train/BCE'`` loss of every batch.

    Returns:
        The loss of the last processed batch as a Python float, or ``-1`` if
        the dataloader produced no batches.
    """
    model.train()
    last_loss = -1
    for idx, batch in enumerate(train_dataloader):
        optimizer.zero_grad()

        historys = batch['users']
        positives = batch['candidates']
        # In-batch negative sampling: shuffle the positives of this batch.
        # NOTE(review): a permutation can map a row onto itself (or onto a
        # duplicate id), so a "negative" may equal the row's positive.
        perm = torch.randperm(positives.shape[0])
        negatives = positives[perm]
        logits_pos, logits_neg = model(historys, positives, negatives)
        logits = torch.cat([logits_pos, logits_neg], dim=0)
        # Create the labels directly with the logits' shape, dtype and device
        # instead of relying on the global ``args['device']``; this keeps the
        # function self-contained and rules out a device mismatch.
        labels = torch.cat([torch.ones_like(logits_pos), torch.zeros_like(logits_neg)], dim=0)
        loss = loss_fn(logits, labels)

        # Read the scalar once (each .item() call forces a host sync on GPU).
        last_loss = loss.item()
        writer.add_scalar('Train/BCE', last_loss, global_step=idx + epoch * len(train_dataloader))

        loss.backward()
        optimizer.step()

    return last_loss

def val_epoch(model, epoch, val_dataloader, train_dataloader, writer):
    """
    Evaluate the model on ``val_dataloader`` without gradient tracking.

    Batches are scored exactly like in training: the candidate ids are the
    positives, a random permutation of them the negatives, and the
    notebook-level ``loss_fn`` is applied to the concatenated logits returned
    by ``model(historys, positives, negatives)``.

    Args:
        model: module that is put into evaluation mode.
        epoch: zero-based epoch index, used only to compute the logging step.
        val_dataloader: iterable of dicts with the keys ``'users'`` and
            ``'candidates'``.
        train_dataloader: sized object; only its length is used, so that the
            validation point is logged at the step where the training epoch
            ended.
        writer: object with an ``add_scalar(tag, value, global_step=...)``
            method; receives one ``'Val/BCE'`` value per call.

    Returns:
        The unweighted mean of the per-batch losses as a Python float. If the
        validation loader yields no batches, nothing is logged and NaN is
        returned instead of raising ``ZeroDivisionError``.
    """
    model.eval()
    sum_loss = 0
    cnt = 0
    with torch.no_grad():
        for batch in val_dataloader:
            cnt += 1
            historys = batch['users']
            positives = batch['candidates']
            # In-batch negatives: a random permutation of this batch's positives.
            perm = torch.randperm(positives.shape[0])
            negatives = positives[perm]
            logits_pos, logits_neg = model(historys, positives, negatives)
            logits = torch.cat([logits_pos, logits_neg], dim=0)
            # Labels take shape, dtype and device from the logits rather than
            # from the global ``args['device']``.
            labels = torch.cat([torch.ones_like(logits_pos), torch.zeros_like(logits_neg)], dim=0)
            loss = loss_fn(logits, labels)
            sum_loss += loss.item()

    if cnt == 0:
        # No validation batches: the mean is undefined, so skip logging.
        return float('nan')

    mean_loss = sum_loss / cnt
    writer.add_scalar('Val/BCE', mean_loss, global_step=(epoch + 1) * len(train_dataloader))

    return mean_loss

In [9]:
# Alternate one training pass and one validation pass per epoch, printing the
# value each function returns.
num_epochs = args['num_epochs']
for epoch in range(num_epochs):
    train_loss = train_epoch(
        model=model,
        optimizer=optimizer,
        epoch=epoch,
        train_dataloader=train_dataloader,
        writer=writer,
    )
    print(f'Train bce loss on epoch {epoch + 1}: {train_loss}')

    val_loss = val_epoch(
        model=model,
        epoch=epoch,
        val_dataloader=val_dataloader,
        train_dataloader=train_dataloader,
        writer=writer,
    )
    print(f'Val bce loss on epoch {epoch + 1}: {val_loss}')

?!
torch.Size([129, 64, 50])
torch.Size([129, 64, 50])
?!
torch.Size([129, 64, 50])
torch.Size([129, 64, 50])
