In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

import time

from model import Discriminator
from config import irgan_config
from data_utils import RecDataset, DataProvider
from evaluation.rec_evaluator import RecEvaluator

In [2]:
epochs = irgan_config.epochs
batch_size = irgan_config.batch_size
epochs_d = irgan_config.epochs_d
emb_dim = irgan_config.emb_dim
eta_D = irgan_config.eta_D
device = irgan_config.device
weight_decay_d = irgan_config.weight_decay_d
patience = irgan_config.patience

In [3]:
# Hyper-parameters and datset-specific parameters
rec_dataset = RecDataset(irgan_config.dir_path)
all_users = rec_dataset.get_users()
all_items = rec_dataset.get_items()
num_users = rec_dataset.get_num_users()
num_items = rec_dataset.get_num_items()
bought_mask = rec_dataset.get_bought_mask().to(device)
bought_dict = rec_dataset.get_interaction_records()
eval_dict = rec_dataset.get_interaction_records("test")
train_ui = rec_dataset.get_user_item_pairs()

In [4]:
dp = DataProvider(device)
evaluator = RecEvaluator(eval_dict, None, device)

In [5]:
writer = SummaryWriter("runs/Pretrained-discriminator-Static-Negative-Sampling")

In [6]:
D = Discriminator(num_users, num_items, emb_dim,bought_mask)
D = D.to(device)
loss_D = nn.BCELoss()
optimizer_D = torch.optim.Adam(D.parameters(), lr = eta_D, weight_decay = weight_decay_d)

In [7]:
train_set = dp.prepare_bpr_triplets_dns(D, bought_mask)

In [8]:
bad_counter = 0
best_epoch = 0
best_p = 0

for epoch in range(epochs):
    time_start = time.time()
    loss_epoch = 0
    train_set = dp.prepare_bpr_triplets(all_items,bought_dict) # Static Negative Sampling
    for users, pos_items, neg_items in train_set:
        x_ui = D(users,pos_items)
        x_uj = D(users,neg_items)
        
        x_uij = x_ui - x_uj
        loss = -torch.log(torch.sigmoid(x_uij)).mean()
        loss_epoch += loss.item()
        optimizer_D.zero_grad()
        loss.backward()
        optimizer_D.step()
            
    time_end = time.time()
    loss_epoch /= len(train_set)
    print(
        "\t[Discriminator][Epochs %d/%d] [D epoch loss: %6.5f] [Time:%6.5f] "
        % (epoch+1, epochs, loss_epoch, time_end - time_start)
    )
    writer.add_scalar('Loss', loss_epoch, epoch)
    
    with torch.no_grad():
        res = evaluator.top_k_evaluation(D, [3,5,10])    
        ndcg3, precision3, hit3, map3, mrr3 = res[0]
        ndcg5, precision5, hit5, map5, mrr5 = res[1]
        ndcg10, precision10, hit10, map10, mrr10 = res[2]
        writer.add_scalar('Metrics/NDCG@3', ndcg3, epoch)
        writer.add_scalar('Metrics/NDCG@5', ndcg5, epoch)
        writer.add_scalar('Metrics/NDCG@10', ndcg10, epoch)
        writer.add_scalar('Metrics/Precision@3', precision3, epoch)
        writer.add_scalar('Metrics/Precision@5', precision5, epoch)
        writer.add_scalar('Metrics/Precision@10', precision10, epoch)
        writer.add_scalar('Metrics/Hit@3', hit3, epoch)
        writer.add_scalar('Metrics/Hit@5', hit5, epoch)
        writer.add_scalar('Metrics/Hit@10', hit10, epoch)
        writer.add_scalar('Metrics/MAP', map10, epoch)
        writer.add_scalar('Metrics/MRR', mrr10, epoch)
        
        if(precision10 > best_p):
            best_p = precision10
            best_epoch = epoch
            bad_counter = 0
            torch.save(D.state_dict(),"./pretrained_models/pretrained_model_sns.pkl")
        else:
            bad_counter += 1   
            
        if bad_counter == patience:
            break
    time_epoch_end = time.time()
writer.close()    

	[Discriminator][Epochs 1/200] [D epoch loss: 0.68244] [Time:2.12873] 
	[Discriminator][Epochs 2/200] [D epoch loss: 0.66134] [Time:2.43261] 
	[Discriminator][Epochs 3/200] [D epoch loss: 0.64027] [Time:2.27455] 
	[Discriminator][Epochs 4/200] [D epoch loss: 0.61851] [Time:2.35460] 
	[Discriminator][Epochs 5/200] [D epoch loss: 0.59687] [Time:2.29360] 
	[Discriminator][Epochs 6/200] [D epoch loss: 0.57810] [Time:2.14305] 
	[Discriminator][Epochs 7/200] [D epoch loss: 0.56323] [Time:2.16312] 
	[Discriminator][Epochs 8/200] [D epoch loss: 0.55216] [Time:2.40571] 
	[Discriminator][Epochs 9/200] [D epoch loss: 0.54220] [Time:2.37356] 
	[Discriminator][Epochs 10/200] [D epoch loss: 0.53450] [Time:2.18632] 
	[Discriminator][Epochs 11/200] [D epoch loss: 0.52865] [Time:2.23099] 
	[Discriminator][Epochs 12/200] [D epoch loss: 0.52408] [Time:2.05657] 
	[Discriminator][Epochs 13/200] [D epoch loss: 0.51952] [Time:2.13023] 
	[Discriminator][Epochs 14/200] [D epoch loss: 0.51717] [Time:2.15012] 
	

	[Discriminator][Epochs 115/200] [D epoch loss: 0.46931] [Time:2.44334] 
	[Discriminator][Epochs 116/200] [D epoch loss: 0.46906] [Time:2.54584] 
	[Discriminator][Epochs 117/200] [D epoch loss: 0.46708] [Time:2.54067] 
	[Discriminator][Epochs 118/200] [D epoch loss: 0.46826] [Time:2.52430] 
	[Discriminator][Epochs 119/200] [D epoch loss: 0.46742] [Time:2.55100] 
	[Discriminator][Epochs 120/200] [D epoch loss: 0.46701] [Time:2.53249] 
	[Discriminator][Epochs 121/200] [D epoch loss: 0.46784] [Time:2.53956] 
	[Discriminator][Epochs 122/200] [D epoch loss: 0.46701] [Time:2.18805] 
	[Discriminator][Epochs 123/200] [D epoch loss: 0.46905] [Time:2.14618] 
	[Discriminator][Epochs 124/200] [D epoch loss: 0.46634] [Time:2.55756] 
	[Discriminator][Epochs 125/200] [D epoch loss: 0.46651] [Time:2.13988] 
	[Discriminator][Epochs 126/200] [D epoch loss: 0.46726] [Time:2.54032] 
	[Discriminator][Epochs 127/200] [D epoch loss: 0.46636] [Time:2.45880] 
	[Discriminator][Epochs 128/200] [D epoch loss: 0.4

In [9]:
print(best_epoch)
D.load_state_dict(torch.load("./pretrained_models/pretrained_model_sns.pkl"))

199


<All keys matched successfully>