In [1]:
import pandas as pd
import sys
sys.path.append('../src/')
from utils.metrics import RecallK, DiversityK, LongTailK
from datasets.datasets import SASRecDataset, TwhinDataset
import torch
from torch.utils.data import DataLoader
from models.sasrec import SASRec
from models.graph_encoders import TwhinGraphEncoder
from utils.losses import TwhinLoss
from torch.utils.tensorboard import SummaryWriter
import numpy as np

import matplotlib.pyplot as plt

%load_ext autoreload
%autoreload 2

In [2]:
# Hyper-parameters shared by the cells below. `args` is passed as a whole to
# SASRec and TwhinGraphEncoder; 'batch_size' drives every DataLoader and
# 'device' is where both models are placed.
args = {
    'device': 'cpu',
    'hidden_units': 50,      # presumably the embedding width — consumed inside the models
    'dropout_rate': 0.5,
    'num_blocks': 2,         # presumably SASRec transformer blocks — confirm in models.sasrec
    'num_heads': 1,
    'maxlen': 128,           # presumably max history length — confirm in models.sasrec
    'batch_size': 2048,
    'num_epochs_sasrec': 100,
    'num_epochs_twhin': 50
}

# Seed every RNG the notebook touches (model initialisation, the
# torch.randperm in-batch negatives, numpy) so that Restart & Run All
# reproduces the logged losses.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# 1. Читаем данные

In [3]:
# The ML-1M splits carry a .csv extension but are pickled DataFrames (their
# cells hold Python lists, see the preview below), hence read_pickle.
# NOTE: unpickling can execute arbitrary code — only load files you trust.
DATA_DIR = '../code/data/ml-1m/'
train = pd.read_pickle(DATA_DIR + 'train.csv')
val = pd.read_pickle(DATA_DIR + 'val.csv')
train.head()

Unnamed: 0,UserID,MovieID,Rating,history,candidate,timestamp,Genre
0,3403,6,4,"[[413, Airheads, [Comedy], M, 35, 5, 48342, 19...","[6, Heat, [Action, Crime, Thriller], M, 35, 5,...",967429703,"[Action, Crime, Thriller]"
1,4630,1883,3,"[[593, Silence of the Lambs, The, [Drama, Thri...","[1883, Bulworth, [Comedy], F, 25, 4, 94610, 19...",964040034,[Comedy]
2,2882,27,3,"[[2683, Austin Powers: The Spy Who Shagged Me,...","[27, Now and Then, [Drama], M, 18, 20, 78759, ...",972243969,[Drama]
3,3513,593,5,"[[908, North by Northwest, [Drama, Thriller], ...","[593, Silence of the Lambs, The, [Drama, Thril...",966976389,"[Drama, Thriller]"
4,2010,2759,2,"[[2021, Dune, [Fantasy, Sci-Fi], M, 18, 4, 815...","[2759, Dick, [Comedy], M, 18, 4, 81520, 1999, ...",974680399,[Comedy]


# 2. Обучаем SASRec

In [4]:
# Wrap the train/val frames in SASRec datasets (construction appears to emit
# the progress bars shown in this cell's output).
train_dataset_sasrec = SASRecDataset(train)
val_dataset_sasrec = SASRecDataset(val)
# drop_last=True discards the final partial batch so every training step sees
# exactly args['batch_size'] rows; validation keeps all rows.
train_dataloader_sasrec = DataLoader(train_dataset_sasrec, batch_size=args['batch_size'], collate_fn=SASRecDataset.collate_fn, drop_last=True)
val_dataloader_sasrec = DataLoader(val_dataset_sasrec, batch_size=args['batch_size'], collate_fn=SASRecDataset.collate_fn)
# Number of train / val batches per epoch.
print(len(train_dataloader_sasrec), len(val_dataloader_sasrec))

# NOTE(review): 3952 looks like the ML-1M item-vocabulary size (max MovieID) —
# confirm against SASRec's constructor rather than hard-coding it.
model = SASRec(3952, args).to(args['device'])
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98))
# TensorBoard scalars for the SASRec run go under ./sasrec.
writer = SummaryWriter(log_dir='./sasrec')


def train_epoch_sasrec(model, optimizer, epoch, train_dataloader, writer, loss_fn=None):
    """Run one SASRec training epoch with in-batch negative sampling.

    For each batch, the candidate items are shuffled across rows to serve as
    negatives. The model is trained with a binary objective: positive logits
    get label 1, negative logits get label 0.

    Args:
        model: module called as ``model(historys, positives, negatives)`` that
            returns ``(logits_pos, logits_neg)``.
        optimizer: optimizer over ``model``'s parameters.
        epoch: zero-based epoch index, used to compute the TensorBoard step.
        train_dataloader: iterable of batch dicts with ``'users'`` and
            ``'candidates'`` entries; must support ``len()``.
        writer: object exposing ``add_scalar`` (e.g. a ``SummaryWriter``).
        loss_fn: criterion applied as ``loss_fn(logits, labels)``. Defaults to
            ``torch.nn.BCEWithLogitsLoss()``. The function no longer reads a
            module-level ``loss_fn``; that name is rebound to a TwHIN loss
            later in the notebook.

    Returns:
        The loss of the last batch as a float, or -1 if the loader is empty.
    """
    if loss_fn is None:
        loss_fn = torch.nn.BCEWithLogitsLoss()

    model.train()
    last_loss = -1
    for idx, batch in enumerate(train_dataloader):
        optimizer.zero_grad()

        historys = batch['users']
        positives = batch['candidates']
        # In-batch negatives: a random permutation of the positives.
        # NOTE: a permutation can map a row onto itself, so a few "negatives"
        # may equal their own positive.
        perm = torch.randperm(positives.shape[0])
        negatives = positives[perm]
        logits_pos, logits_neg = model(historys, positives, negatives)
        logits = torch.cat([logits_pos, logits_neg], dim=0)
        # Labels are built on the logits' own device/dtype instead of reading
        # the global args['device'].
        labels = torch.cat([torch.ones_like(logits_pos), torch.zeros_like(logits_neg)], dim=0)
        loss = loss_fn(logits, labels)

        writer.add_scalar('Train/BCE', loss.item(), global_step=idx + epoch * len(train_dataloader))
        last_loss = loss.item()

        loss.backward()
        optimizer.step()

    return last_loss

def val_epoch_sasrec(model, epoch, val_dataloader, train_dataloader, writer, loss_fn=None):
    """Compute the mean validation BCE of SASRec with in-batch negatives.

    Negatives are drawn the same way as in training: a random permutation of
    the batch's candidates. The value therefore varies with the torch RNG
    state.

    Args:
        model: module called as ``model(historys, positives, negatives)`` that
            returns ``(logits_pos, logits_neg)``.
        epoch: zero-based epoch index.
        val_dataloader: iterable of batch dicts with ``'users'`` and
            ``'candidates'`` entries.
        train_dataloader: used only for ``len()``, so that the logged step
            lines up with the training steps of the same epoch.
        writer: object exposing ``add_scalar``.
        loss_fn: criterion ``loss_fn(logits, labels)``; defaults to
            ``torch.nn.BCEWithLogitsLoss()``.

    Returns:
        The mean per-batch loss. Returns NaN (and logs nothing) when
        ``val_dataloader`` yields no batches, instead of raising
        ZeroDivisionError.
    """
    if loss_fn is None:
        loss_fn = torch.nn.BCEWithLogitsLoss()

    model.eval()
    sum_loss = 0
    cnt = 0
    with torch.no_grad():
        for batch in val_dataloader:
            historys = batch['users']
            positives = batch['candidates']
            perm = torch.randperm(positives.shape[0])
            negatives = positives[perm]
            logits_pos, logits_neg = model(historys, positives, negatives)
            logits = torch.cat([logits_pos, logits_neg], dim=0)
            # Device/dtype follow the logits; no dependency on the global args.
            labels = torch.cat([torch.ones_like(logits_pos), torch.zeros_like(logits_neg)], dim=0)
            sum_loss += loss_fn(logits, labels).item()
            cnt += 1

    if cnt == 0:
        return float('nan')

    mean_loss = sum_loss / cnt
    writer.add_scalar('Val/BCE', mean_loss, global_step=(epoch + 1) * len(train_dataloader))
    return mean_loss

916429it [00:23, 38975.26it/s]
77740it [00:01, 50735.45it/s]

447 38





In [5]:
# SASRec training is switched off for now; set TRAIN_SASREC = True to train
# the SASRec model built in the previous cell (must run before the TwHIN cell
# rebinds `model` / `optimizer` / `writer`).
TRAIN_SASREC = False

if TRAIN_SASREC:
    # Also bound at module level because train_epoch_sasrec may look up a
    # global `loss_fn`; the TwHIN setup cell later rebinds this name.
    loss_fn = torch.nn.BCEWithLogitsLoss()
    for epoch in range(args['num_epochs_sasrec']):
        train_loss = train_epoch_sasrec(model, optimizer, epoch, train_dataloader_sasrec, writer)
        print(f'Train bce loss on epoch {epoch + 1}: {train_loss}')
        # Validate on the SASRec loader: the TwHIN loaders do not exist yet at
        # this point and yield differently shaped batches.
        val_loss = val_epoch_sasrec(model, epoch, val_dataloader_sasrec, train_dataloader_sasrec, writer, loss_fn)
        print(f'Val bce loss on epoch {epoch + 1}: {val_loss}')

# 3. Обучаем twhin-like графовые вектора

In [6]:
# Build TwHIN-style (user, relation type, item) datasets from the same frames.
train_dataset_twhin = TwhinDataset(train)
val_dataset_twhin = TwhinDataset(val)

# drop_last=True keeps every training batch at exactly args['batch_size'] rows.
train_dataloader_twhin = DataLoader(train_dataset_twhin, batch_size=args['batch_size'], collate_fn=TwhinDataset.collate_fn, drop_last=True)
val_dataloader_twhin = DataLoader(val_dataset_twhin, batch_size=args['batch_size'], collate_fn=TwhinDataset.collate_fn)
# Number of train / val batches per epoch.
print(len(train_dataloader_twhin), len(val_dataloader_twhin))

# NOTE: this cell rebinds the global `model`, `optimizer`, `writer` (and
# `loss_fn`), replacing the SASRec objects created earlier.
# NOTE(review): 6040 / 3952 look like the ML-1M user and movie counts, and 5 is
# presumably the number of relation types (batches carry a 'types' entry) —
# confirm against TwhinGraphEncoder's signature.
model = TwhinGraphEncoder(6040, 3952, 5, args).to(args['device'])
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98))
# TwhinLoss returns (link_prediction_loss, l2_reg); reg_weight presumably
# scales the second term — verify in utils.losses.
loss_fn = TwhinLoss(reg_weight=1)
# TensorBoard scalars for the TwHIN run go under ./twhin.
writer = SummaryWriter(log_dir='./twhin')


def train_epoch_twhin(model, optimizer, epoch, train_dataloader, writer, loss_fn):
    """Train the TwHIN-style graph encoder for a single epoch.

    Each batch supplies ``users``, ``items`` and relation ``types``. The loss
    function yields a link-prediction term and a regularisation term, and the
    optimizer minimises their sum. Per step, both the link-prediction loss
    and the total loss are logged to ``writer``.

    Args:
        model: encoder called as ``model(users, items, types)`` that returns
            ``(users_output, items_output)``.
        optimizer: optimizer over ``model``'s parameters.
        epoch: zero-based epoch index, used for the TensorBoard step.
        train_dataloader: iterable of batch dicts; must support ``len()``.
        writer: object exposing ``add_scalar``.
        loss_fn: callable returning ``(link_loss, reg_loss)`` tensors.

    Returns:
        The link-prediction loss (regularisation excluded) of the final batch,
        or -1 when ``train_dataloader`` yields nothing.
    """
    model.train()
    final_link_loss = -1
    for batch_idx, batch in enumerate(train_dataloader):
        optimizer.zero_grad()

        user_emb, item_emb = model(batch['users'], batch['items'], batch['types'])
        link_loss, reg_loss = loss_fn(user_emb, item_emb)
        total_loss = link_loss + reg_loss

        step = batch_idx + epoch * len(train_dataloader)
        writer.add_scalar('Train/link_prediction', link_loss.item(), global_step=step)
        writer.add_scalar('Train/total_loss', total_loss.item(), global_step=step)
        final_link_loss = link_loss.item()

        total_loss.backward()
        optimizer.step()

    return final_link_loss

def val_epoch_twhin(model, epoch, val_dataloader, train_dataloader, writer, loss_fn):
    """Compute the mean validation link-prediction loss of the graph encoder.

    Args:
        model: encoder called as ``model(users, items, types)`` that returns
            ``(users_output, items_output)``.
        epoch: zero-based epoch index.
        val_dataloader: iterable of batch dicts with ``'users'``, ``'types'``
            and ``'items'`` entries.
        train_dataloader: used only for ``len()``, so that the logged step
            lines up with the training steps of the same epoch.
        writer: object exposing ``add_scalar``.
        loss_fn: callable returning ``(link_loss, reg_loss)``. Only the
            link-prediction term enters the validation metric.

    Returns:
        The mean per-batch link-prediction loss. Returns NaN (and logs
        nothing) when ``val_dataloader`` yields no batches, instead of raising
        ZeroDivisionError.
    """
    model.eval()
    sum_loss = 0
    cnt = 0
    with torch.no_grad():
        for batch in val_dataloader:
            users_output, items_output = model(batch['users'], batch['items'], batch['types'])
            # The regularisation term is not part of the validation metric.
            twhin_loss, _ = loss_fn(users_output, items_output)
            sum_loss += twhin_loss.item()
            cnt += 1

    if cnt == 0:
        return float('nan')

    mean_loss = sum_loss / cnt
    writer.add_scalar('Val/link_prediction', mean_loss, global_step=(epoch + 1) * len(train_dataloader))
    return mean_loss

916429it [00:00, 2424178.21it/s]
77740it [00:00, 1965929.85it/s]

447 38





In [7]:
# Train the graph encoder and report train (last-batch) and validation (mean)
# link-prediction losses after every epoch.
for epoch in range(args['num_epochs_twhin']):
    epoch_no = epoch + 1  # 1-based index for the log lines
    train_loss = train_epoch_twhin(
        model, optimizer, epoch,
        train_dataloader=train_dataloader_twhin, writer=writer, loss_fn=loss_fn,
    )
    print(f'Train link-prediction loss on epoch {epoch_no}: {train_loss}')
    val_loss = val_epoch_twhin(
        model, epoch,
        val_dataloader=val_dataloader_twhin, train_dataloader=train_dataloader_twhin,
        writer=writer, loss_fn=loss_fn,
    )
    print(f'Val link-prediction loss on epoch {epoch_no}: {val_loss}')

: 

# 4. Обучаем SASRec, дополненный первой версией графовых векторов