# MLP baseline

In [1]:
import os
import sys
from collections import defaultdict
import numpy as np 
import pandas as pd
import json

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import torch.optim as optim
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

sys.path.append('../')
from utils.eval import retrieval_normalized_dcg_all, retrieval_precision_all, retrieval_precision_all_v2, semantic_precision_all, semantic_precision_all_v2, precision_recall_f1_all
from utils.loss import *
from utils.data_loader import load_document
from utils.toolbox import preprocess_document, get_preprocess_document, get_preprocess_document_embs,\
                          get_preprocess_document_labels, get_preprocess_document_labels_v2, get_word_embs,\
                          get_free_gpu, merge_targets


## Load Data, Label
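Label types: bow, tf-idf, keybert, classification. The one used as the training target is chosen by `label_type` in the configuration cell below.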

In [2]:
# ---- Experiment settings ----
dataset = '20news'
# Optional second dataset for the cross-domain setting; None disables it.
dataset2 = None
model_name = 'mpnet'
label_type = 'tf-idf-gensim'
# True -> binary (precision / recall / F1) evaluation, False -> ranking evaluation.
eval_f1 = False
criterion = 'ListNet_sigmoid_L1'
# Name of the directory that holds the preprocess configuration files.
preprocess_config_dir = 'parameters_baseline2'
n_gram = 1

# ---- Optimisation settings ----
lr = 1e-3
n_epoch = 300
valid_epoch = 10
h_dim = 3000
target_normalization = False

# Number of times the training is repeated.
n_time = 1
seed = 133

# Cross-domain experiments get a "cross_" prefix and carry both dataset names.
experiment_dir = (
    f'cross_{dataset}_{dataset2}_{model_name}_{label_type}_{criterion}'
    if dataset2
    else f'{dataset}_{model_name}_{label_type}_{criterion}'
)

# Sub-directory name of this run; replaced by the full output path further down.
save_dir = 'default4'

# Everything a later cell needs, gathered in one dict (also dumped to config.json).
config = {
    'experiment_dir': experiment_dir,
    'preprocess_config_dir': preprocess_config_dir,
    'save_dir': save_dir,
    'dataset': dataset,
    'dataset2': dataset2,
    'model_name': model_name,
    'label_type': label_type,
    'eval_f1': eval_f1,
    'n_gram': n_gram,
    'criterion': criterion,
    'n_time': n_time,
    'seed': seed,
    'lr': lr,
    'n_epoch': n_epoch,
    'valid_epoch': valid_epoch,
    'h_dim': h_dim,
    'target_normalization': target_normalization,
}

# From here on `save_dir` is the full output path. exist_ok=False makes this cell
# fail when the directory already exists, so a previous run is never reused silently.
save_dir = os.path.join('experiment', experiment_dir, config['save_dir'])
os.makedirs(save_dir, exist_ok=False)

In [3]:
def load_training_data(config, dataset):
    """Load documents, document embeddings and label targets for one dataset.

    Args:
        config: experiment configuration dict. Reads 'preprocess_config_dir',
            'model_name', 'n_gram' and 'label_type'.
        dataset: dataset name; selects
            ``../chris/<preprocess_config_dir>/preprocess_config_<dataset>.json``.

    Returns:
        Tuple ``(unpreprocessed_docs, preprocessed_docs, doc_embs, targets,
        vocabularys, device)``. ``targets`` is the dense array obtained by calling
        ``.toarray()`` on the label matrix selected by ``config['label_type']``;
        ``device`` is whatever ``get_preprocess_document_embs`` returned.

    Raises:
        FileNotFoundError: if the preprocess config file does not exist.
        KeyError: if ``config['label_type']`` is not among the computed labels.
    """
    preprocess_config_dir = config['preprocess_config_dir']
    config_path = os.path.join(f'../chris/{preprocess_config_dir}', f'preprocess_config_{dataset}.json')
    with open(config_path, 'r') as f:
        preprocess_config = json.load(f)

    # load preprocess dataset
    unpreprocessed_docs, preprocessed_docs = get_preprocess_document(**preprocess_config)
    print('doc num', len(preprocessed_docs))

    # get document embeddings
    # The encoder name is taken from `config` (not from a notebook-level global),
    # so the function is driven entirely by its arguments.
    doc_embs, _doc_model, device = get_preprocess_document_embs(preprocessed_docs, config['model_name'])
    print('doc_embs', doc_embs.shape)

    # load labels
    labels, vocabularys = get_preprocess_document_labels_v2(
        preprocessed_docs, preprocess_config, preprocess_config_dir, config['n_gram'])
    # check nonzero numbers
    for k in labels:
        print(k, np.sum(labels[k] != 0), labels[k].shape)
    print(len(vocabularys))

    # select label type and densify it
    # NOTE(review): presumably a scipy sparse matrix - confirm against utils.toolbox.
    targets = labels[config['label_type']].toarray()

    return unpreprocessed_docs, preprocessed_docs, doc_embs, targets, vocabularys, device

In [4]:
# Load documents, embeddings and label targets of the primary dataset.
# `device` returned here is read as a global by the evaluation and training functions below.
unpreprocessed_docs, preprocessed_docs, doc_embs, targets, vocabularys, device = load_training_data(config, config['dataset'])

Getting preprocess documents: 20news
min_df: 62 max_df: 1.0 vocabulary_size: None min_doc_word: 15
doc num 18589
Getting preprocess documents embeddings
Using cuda 0 for training...


Batches:   0%|          | 0/1162 [00:00<?, ?it/s]

doc_embs (18589, 768)
Getting preprocess documents labels
Finding precompute_keyword by preprocess_config {'dataset': '20news', 'min_df': 62, 'max_df': 1.0, 'vocabulary_size': None, 'min_doc_word': 15}
Getting gensim tf-idf labels
gensim missing word num 0
tf-idf 1092802 (18589, 4823)
bow 1092802 (18589, 4823)
keybert 1028492 (18589, 4823)
yake 892783 (18589, 4823)
tf-idf-gensim 1083990 (18589, 4823)
4823


In [5]:
# Cross-domain setting: load the second dataset as well, then merge the targets and
# vocabularies of both datasets with merge_targets.
# Note that `device` is reassigned by this second load.
if config['dataset2'] is not None:
    unpreprocessed_docs2, preprocessed_docs2, doc_embs2, targets2, vocabularys2, device = load_training_data(config, config['dataset2'])
    targets, targets2, vocabularys = merge_targets(targets, targets2, vocabularys, vocabularys2)
    

In [6]:
# Word embeddings for the vocabulary.
# `word_embs_tensor` and `vocabularys` are read as globals by evaluate_DNNDecoder
# when it is called with pred_semantic=True.
word_embs = get_word_embs(vocabularys)
print('word_embs', word_embs.shape)
word_embs_tensor = torch.FloatTensor(word_embs)

0it [00:00, ?it/s]

Number of words:400001
Getting [ndarray] word embeddings
word_embs (4823, 300)


## MLP Decoder

In [7]:
class DNNDecoderDataset(Dataset):
    """Dataset pairing each document embedding with its target vector.

    Besides the raw float tensors, two per-document helpers are precomputed:

    * ``targets_rank`` - column indices of the target row sorted by descending value.
    * ``topk`` - number of strictly positive entries in the target row.

    Each item is the tuple ``(doc_emb, target, target_rank, topk)``.
    """

    def __init__(self, doc_embs, targets):
        # Embeddings and targets must describe the same documents, row for row.
        assert len(doc_embs) == len(targets)

        self.doc_embs = torch.FloatTensor(doc_embs)
        self.targets = torch.FloatTensor(targets)
        # Ranking of the vocabulary entries per document, largest target first.
        self.targets_rank = self.targets.argsort(dim=1, descending=True)
        # How many entries of each row are > 0.
        self.topk = (self.targets > 0).sum(dim=1)

    def __len__(self):
        return len(self.doc_embs)

    def __getitem__(self, idx):
        fields = (self.doc_embs, self.targets, self.targets_rank, self.topk)
        return tuple(field[idx] for field in fields)

In [8]:
def prepare_dataloader(doc_embs, targets, batch_size=100, train_valid_test_ratio=[0.7, 0.1, 0.2],
                       target_normalize=False, seed=123):
    """Shuffle the documents once and split them into train / valid / test loaders.

    Args:
        doc_embs: array of document embeddings, one row per document. Must support
            indexing with an integer index array (it is permuted that way).
        targets: array of target vectors aligned row-for-row with ``doc_embs``.
        batch_size: batch size used by all three loaders.
        train_valid_test_ratio: only the first two entries are read. The train
            split gets ``ratio[0]`` of the documents, the valid split ``ratio[1]``,
            and the test split receives *all remaining* documents regardless of
            ``ratio[2]``.
        target_normalize: if True, scale every target row so that it sums to 1.
        seed: seed for the permutation. It is applied with ``np.random.seed``,
            i.e. the global NumPy random state is reseeded as a side effect.

    Returns:
        ``(train_loader, valid_loader, test_loader)``. Only the training loader
        reshuffles every epoch.
    """
    n_docs = len(doc_embs)
    train_size = int(n_docs * train_valid_test_ratio[0])
    valid_size = int(n_docs * (train_valid_test_ratio[0] + train_valid_test_ratio[1])) - train_size
    test_size = n_docs - train_size - valid_size

    print('Preparing dataloader')
    print('train size', train_size)
    print('valid size', valid_size)
    print('test size', test_size)

    if target_normalize:
        # normalize target summation of each document to 1
        norm = targets.sum(axis=1).reshape(-1, 1)
        # A document whose targets are all zero would otherwise be divided by 0 and
        # turn into a row of NaNs; dividing by 1 keeps such rows at zero.
        norm = np.where(norm == 0, 1, norm)
        targets = targets / norm

    # shuffle documents and targets with the same seeded permutation
    randomize = np.arange(n_docs)
    np.random.seed(seed)
    np.random.shuffle(randomize)
    doc_embs = doc_embs[randomize]
    targets = targets[randomize]

    valid_end = train_size + valid_size

    # dataloader
    train_dataset = DNNDecoderDataset(doc_embs[:train_size], targets[:train_size])
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)

    # Evaluation splits keep a fixed order: the metrics are averaged batch by batch,
    # so reshuffling would change the batch composition (and the reported numbers)
    # from one evaluation to the next.
    valid_dataset = DNNDecoderDataset(doc_embs[train_size:valid_end], targets[train_size:valid_end])
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)

    test_dataset = DNNDecoderDataset(doc_embs[valid_end:], targets[valid_end:])
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)

    return train_loader, valid_loader, test_loader


In [9]:
# Build the train / valid / test loaders for the primary dataset.
# NOTE: the ratios below sum to 0.9; prepare_dataloader only reads the first two
# entries and gives every remaining document to the test split.
train_loader, valid_loader, test_loader = prepare_dataloader(doc_embs, targets, batch_size=64,\
                                                             train_valid_test_ratio=[0.6, 0.1, 0.2],\
                                                             target_normalize=config['target_normalization'],\
                                                             seed=seed)
# Cross-domain setting: train/validate on the primary dataset, but replace the
# test loader with the test split of the second dataset.
if config['dataset2'] is not None:
    _, _, test_loader = prepare_dataloader(doc_embs2, targets2, batch_size=64,\
                                           train_valid_test_ratio=[0.6, 0.1, 0.2],\
                                           target_normalize=config['target_normalization'],\
                                           seed=seed)

Preparing dataloader
train size 11153
valid size 1859
test size 5577


In [10]:
# class DNNDecoder(nn.Module):
#     def __init__(self, doc_emb_dim, num_words, h_dim=300):
#         super().__init__()
#         self.decoder = nn.Sequential(
#             nn.Linear(doc_emb_dim, h_dim),
# #             nn.Dropout(p=0.2),
#             nn.Tanh(),
#             nn.Linear(h_dim, h_dim),
#             nn.Dropout(p=0.3),
#             nn.Tanh(),
#             nn.Linear(h_dim, num_words),
#             # nn.Dropout(p=0.5),
#             # nn.Sigmoid(),
#         )
#     def forward(self, x):
#         return self.decoder(x)

In [11]:
class DNNDecoder(nn.Module):
    """Two-layer MLP mapping a document embedding to one score per vocabulary word.

    Layout: Linear(doc_emb_dim -> 4 * doc_emb_dim) -> BatchNorm1d -> Sigmoid ->
    Dropout(0.2) -> Linear(4 * doc_emb_dim -> num_words) -> BatchNorm1d.
    No activation follows the last layer.

    Args:
        doc_emb_dim: size of the input document embedding.
        num_words: vocabulary size, i.e. the output dimension.
        h_dim: accepted for call compatibility but not used; the hidden width is
            always ``4 * doc_emb_dim``.
    """

    def __init__(self, doc_emb_dim, num_words=0, h_dim=300):
        super().__init__()
        self.vocab_size = num_words

        hidden_dim = doc_emb_dim * 4
        layers = [
            nn.Linear(doc_emb_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Sigmoid(),
            nn.Dropout(p=0.2),
            nn.Linear(hidden_dim, num_words),
            nn.BatchNorm1d(num_words),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x_bert):
        """Return the (batch, num_words) output of the network for a batch of embeddings."""
        return self.network(x_bert)

In [12]:
def evaluate_DNNDecoder(model, data_loader, config, pred_semantic=False):
    """Evaluate ``model`` on every batch of ``data_loader`` and average the metrics.

    The model is switched to eval mode (and left in it) and inference runs under
    ``torch.no_grad()`` so no autograd graph is built during evaluation.

    Args:
        model: decoder called as ``model(doc_embs)``.
        data_loader: iterable of ``(doc_embs, target, _, _)`` batches.
        config: dict with 'eval_f1'. When that is falsy, 'valid_topk' is read as
            well and passed as ``k`` to the ranking metrics.
        pred_semantic: if True (ranking mode only), also compute the two semantic
            precision metrics using the globals ``word_embs_tensor`` and
            ``vocabularys``.

    Returns:
        defaultdict mapping metric name (e.g. 'precision@5', 'ndcg@all') to the mean
        of the per-batch values. Empty if ``data_loader`` yields no batch.

    Note:
        Batches are moved to the notebook-level global ``device``.
    """
    results = defaultdict(list)
    model.eval()

    def _record(prefix, scores):
        # Store one value per batch under '<prefix>@<k>'.
        for k, v in scores.items():
            results['{}@{}'.format(prefix, k)].append(v)

    # predict all data
    with torch.no_grad():
        for data in data_loader:
            doc_embs, target, _, _ = data

            doc_embs = doc_embs.to(device)
            target = target.to(device)

            pred = model(doc_embs)
            if config['eval_f1']:
                # Precision / Recall / F1
                p, r, f = precision_recall_f1_all(pred, target)
                results['precision'].append(p)
                results['recall'].append(r)
                results['f1_score'].append(f)
                continue

            topk = config["valid_topk"]

            # Precision (two variants)
            _record('precision', retrieval_precision_all(pred, target, k=topk))
            _record('precisionv2', retrieval_precision_all_v2(pred, target, k=topk))

            # NDCG
            _record('ndcg', retrieval_normalized_dcg_all(pred, target, k=topk))

            # Semantic Precision (two variants)
            if pred_semantic:
                semantic_scores, _ = semantic_precision_all(
                    pred, target, word_embs_tensor, vocabularys,
                    k=topk, th=0.5, display_word_result=False)
                _record('semantic_precision', semantic_scores)

                semantic_scores, _ = semantic_precision_all_v2(
                    pred, target, word_embs_tensor, vocabularys,
                    k=topk, th=0.5, display_word_result=False)
                _record('semantic_precision_v2', semantic_scores)

    # Collapse the per-batch lists into a single mean per metric.
    for k in results:
        results[k] = np.mean(results[k])

    return results

In [13]:
def calculate_loss(train_train_config, criterion, pred, target, target_rank, target_topk):
    """Call ``criterion`` with the arguments its loss family expects.

    Args:
        train_train_config: training configuration dict. 'criterion' selects the
            calling convention; 'loss_topk' is read for the
            "MultiLabelMarginLossCustom*" family (excluding the "...CustomV" variant).
        criterion: loss callable.
        pred: model output for the batch.
        target: target values for the batch.
        target_rank: per-row class indices sorted by descending target value. It is
            modified in place in the "MultiLabelMarginLoss" case.
        target_topk: per-row number of positive targets.

    Returns:
        Whatever ``criterion`` returns.

    Dispatch:
        * "MultiLabelMarginLoss"          -> criterion(pred, target_rank), after
          writing a -1 terminator behind the positive indices of every row.
        * "MultiLabelMarginLossCustomV*"  -> criterion(pred, target_rank, target_topk)
        * "MultiLabelMarginLossCustom*"   -> criterion(pred, target_rank, loss_topk)
        * anything else                   -> criterion(pred, target)
    """
    # Read the configuration that was passed in rather than a notebook-level global.
    loss_name = train_train_config["criterion"]

    if loss_name == "MultiLabelMarginLoss":
        assert target_rank.shape[0] == len(target_topk)
        num_classes = target_rank.shape[1]
        for i in range(len(target_topk)):
            n_positive = int(target_topk[i])
            # When every class of the row is positive there is no slot left for the
            # terminator, and none is needed; writing it would index out of bounds.
            if n_positive < num_classes:
                target_rank[i, n_positive] = -1
        loss = criterion(pred, target_rank)
    elif loss_name.startswith("MultiLabelMarginLossCustomV"):
        # Must be tested before the shorter "MultiLabelMarginLossCustom" prefix.
        loss = criterion(pred, target_rank, target_topk)
    elif loss_name.startswith("MultiLabelMarginLossCustom"):
        loss = criterion(pred, target_rank, train_train_config["loss_topk"])
    else:
        loss = criterion(pred, target)

    return loss
    
def train_decoder(doc_embs, targets, train_config):
    """Train one DNNDecoder and collect periodic train / valid / test metrics.

    Args:
        doc_embs: document embeddings; only ``doc_embs.shape[1]`` is read, to size
            the model input.
        targets: target matrix; only ``targets.shape[1]`` is read, to size the
            model output.
        train_config: dict providing 'h_dim', 'lr', 'weight_decay', 'criterion',
            'n_epoch', 'valid_epoch' and 'valid_verbose', plus the keys consumed by
            ``calculate_loss`` and ``evaluate_DNNDecoder``.

    Returns:
        List of ``{'epoch', 'train', 'valid', 'test'}`` dicts, one per evaluated
        epoch (every ``valid_epoch`` epochs when that is > 0, and always the last).

    Note:
        The batches come from the notebook-level globals ``train_loader``,
        ``valid_loader`` and ``test_loader``, and tensors are moved to the global
        ``device``; the ``doc_embs`` / ``targets`` arguments are not iterated.
    """
    model = DNNDecoder(doc_emb_dim=doc_embs.shape[1], num_words=targets.shape[1],
                       h_dim=train_config["h_dim"]).to(device)
    model.train()

    opt = torch.optim.Adam(model.parameters(), lr=train_config["lr"], weight_decay=train_config["weight_decay"])

    # prepare loss
    loss_name = train_config["criterion"]
    if loss_name == "MultiLabelMarginLoss":
        criterion = nn.MultiLabelMarginLoss(reduction='mean')
    elif loss_name == "BCE":
        criterion = nn.BCEWithLogitsLoss(reduction='mean')
    elif loss_name.startswith("MultiLabelMarginLossCustomV"):
        # The text after the last ':' in the criterion name is passed as a float.
        def criterion(a, b, c): return MultiLabelMarginLossCustomV(
            a, b, c, float(train_config["criterion"].split(':')[-1]))
    elif loss_name.startswith("MultiLabelMarginLossCustom"):
        def criterion(a, b, c): return MultiLabelMarginLossCustom(
            a, b, c, float(train_config["criterion"].split(':')[-1]))
    else:
        # NOTE(review): eval() resolves the criterion name to an object in scope
        # (e.g. a loss brought in by `from utils.loss import *`). It executes
        # arbitrary code, so the name must never come from untrusted input.
        criterion = eval(train_config["criterion"])

    results = []
    n_epoch = train_config["n_epoch"]
    valid_epoch = train_config["valid_epoch"]
    valid_verbose = train_config["valid_verbose"]

    def _to_device(batch):
        # Move every tensor of a (doc_embs, target, target_rank, target_topk) batch.
        return tuple(t.to(device) for t in batch)

    for epoch in tqdm(range(n_epoch)):
        train_loss_his = []
        valid_loss_his = []

        # evaluate_DNNDecoder leaves the model in eval mode, so re-enable training
        # mode at the start of every epoch.
        model.train()
        for data in train_loader:
            batch_embs, target, target_rank, target_topk = _to_device(data)

            # loss
            pred = model(batch_embs)
            loss = calculate_loss(train_config, criterion, pred, target, target_rank, target_topk)
            train_loss_his.append(loss.item())

            # Model backwarding
            opt.zero_grad()
            loss.backward()
            opt.step()

        model.eval()
        # Validation only needs the loss value: skip building autograd graphs.
        with torch.no_grad():
            for data in valid_loader:
                batch_embs, target, target_rank, target_topk = _to_device(data)

                # loss
                pred = model(batch_embs)
                loss = calculate_loss(train_config, criterion, pred, target, target_rank, target_topk)
                valid_loss_his.append(loss.item())

        print("Epoch", epoch, np.mean(train_loss_his), np.mean(valid_loss_his))

        # show decoder result
        if (valid_epoch > 0 and epoch % valid_epoch == 0) or epoch == n_epoch-1:
            train_res_ndcg = evaluate_DNNDecoder(model, train_loader, train_config, False)
            valid_res_ndcg = evaluate_DNNDecoder(model, valid_loader, train_config, False)
            test_res_ndcg = evaluate_DNNDecoder(model, test_loader, train_config, False)

            results.append({
                'epoch': epoch,
                'train': train_res_ndcg,
                'valid': valid_res_ndcg,
                'test': test_res_ndcg,
            })

            if valid_verbose:
                print()
                print('train', train_res_ndcg)
                print('valid', valid_res_ndcg)
                print('test', test_res_ndcg)

    return results

def train_experiment(n_time):
    """Run ``train_decoder`` ``n_time`` times and dump all results to result.json.

    Uses the notebook-level globals ``doc_embs``, ``targets``, ``train_config`` and
    ``save_dir``. No random seed is set between the repetitions here.

    Args:
        n_time: number of training repetitions.

    Returns:
        List with one ``train_decoder`` result per repetition. The same list is
        written as JSON to ``<save_dir>/result.json`` (an existing file is
        overwritten).
    """
    results = [train_decoder(doc_embs, targets, train_config) for _ in range(n_time)]

    result_path = os.path.join(save_dir, 'result.json')
    with open(result_path, 'w') as f:
        json.dump(results, f)

    return results

In [18]:
# Training configuration for the 'ListNet_sigmoid_L1' run.
train_config = {
    "n_time": config['n_time'],
    # Single source of truth for the learning rate: the duplicate hard-coded "lr"
    # entry that used to follow silently overrode this value.
    "lr": config['lr'],
    "weight_decay": 0.0,
    "loss_topk": 15,

    "n_epoch": config['n_epoch'],
    "valid_epoch": config['valid_epoch'],
    "valid_verbose": True,
    "valid_topk": [5, 10, 15],

    "h_dim": config['h_dim'],
    "label_type": config['label_type'],
    "eval_f1": config['eval_f1'],
    # Set per cell on purpose (not read from config['criterion']) so that several
    # losses can be compared in one notebook.
    "criterion": 'ListNet_sigmoid_L1',
}

In [19]:
# Train with the currently defined `train_config`; writes <save_dir>/result.json
# (overwriting an existing one) and displays the returned results.
train_experiment(config['n_time'])

  0%|          | 0/300 [00:00<?, ?it/s]

Epoch 0 8.13433209010533 8.035237661997478

train defaultdict(<class 'list'>, {'precision@5': 0.6056544232368469, 'precision@10': 0.4743734298433576, 'precision@15': 0.4001642373629979, 'precisionv2@5': 0.30050945895058767, 'precisionv2@10': 0.2890751109804426, 'precisionv2@15': 0.28036066012723104, 'ndcg@5': 0.5042255708149501, 'ndcg@10': 0.47430176956312997, 'ndcg@15': 0.4561565903254918, 'ndcg@all': 0.6393318506649562})
valid defaultdict(<class 'list'>, {'precision@5': 0.5275000135103861, 'precision@10': 0.4109201471010844, 'precision@15': 0.3459953933954239, 'precisionv2@5': 0.243263894567887, 'precisionv2@10': 0.23402778059244156, 'precisionv2@15': 0.22998843987782797, 'ndcg@5': 0.4188481499751409, 'ndcg@10': 0.39480057259400686, 'ndcg@15': 0.3796604673067729, 'ndcg@all': 0.5751318355401357})
test defaultdict(<class 'list'>, {'precision@5': 0.5198666412721981, 'precision@10': 0.4002584530548616, 'precision@15': 0.3342579650607976, 'precisionv2@5': 0.23551136746325277, 'precisionv2

Epoch 51 4.5213366263253345 6.298384173711141
Epoch 52 4.502209102085659 6.282387653986613
Epoch 53 4.479029565538679 6.258969116210937
Epoch 54 4.457533356802804 6.3340683301289875
Epoch 55 4.437666418892997 6.270529286066691
Epoch 56 4.4187034034729 6.289256683985392
Epoch 57 4.398723850250244 6.275392675399781
Epoch 58 4.388168612888881 6.2763565381368
Epoch 59 4.369151654924665 6.294771480560303
Epoch 60 4.357162677219936 6.276208194096883

train defaultdict(<class 'list'>, {'precision@5': 0.9726554679870606, 'precision@10': 0.9661150206838335, 'precision@15': 0.9592892289161682, 'precisionv2@5': 0.6055178713798522, 'precisionv2@10': 0.6810336218561445, 'precisionv2@15': 0.7337703500475202, 'ndcg@5': 0.8257221024377005, 'ndcg@10': 0.8511012319156102, 'ndcg@15': 0.8643901467323303, 'ndcg@all': 0.9217500478880746})
valid defaultdict(<class 'list'>, {'precision@5': 0.8312152683734894, 'precision@10': 0.7281771024068197, 'precision@15': 0.656342621644338, 'precisionv2@5': 0.34072917203

Epoch 107 4.062440379006522 6.564222526550293
Epoch 108 4.049707085745675 6.504064909617106
Epoch 109 4.041282855442592 6.502161725362142
Epoch 110 4.042172074999128 6.569033447901408

train defaultdict(<class 'list'>, {'precision@5': 0.9870357142175947, 'precision@10': 0.9846785729272025, 'precision@15': 0.9797023854936873, 'precisionv2@5': 0.6274485468864441, 'precisionv2@10': 0.713284148148128, 'precisionv2@15': 0.76832042660032, 'ndcg@5': 0.8505352834292821, 'ndcg@10': 0.8790139474187578, 'ndcg@15': 0.8929164096287319, 'ndcg@all': 0.9385286484445844})
valid defaultdict(<class 'list'>, {'precision@5': 0.8396527806917826, 'precision@10': 0.7357986191908519, 'precision@15': 0.660219939549764, 'precisionv2@5': 0.354097231477499, 'precisionv2@10': 0.3412673642237981, 'precisionv2@15': 0.34023149907588957, 'ndcg@5': 0.6120730171600978, 'ndcg@10': 0.5835861047108968, 'ndcg@15': 0.563827732205391, 'ndcg@all': 0.706995012362798})
test defaultdict(<class 'list'>, {'precision@5': 0.8414141468


train defaultdict(<class 'list'>, {'precision@5': 0.9883214276177542, 'precision@10': 0.9871428608894348, 'precision@15': 0.9838631020273481, 'precisionv2@5': 0.6185514851978847, 'precisionv2@10': 0.7170210187775748, 'precisionv2@15': 0.7772377913338797, 'ndcg@5': 0.8425564759118216, 'ndcg@10': 0.8762578661101205, 'ndcg@15': 0.8926754014832633, 'ndcg@all': 0.938176099232265})
valid defaultdict(<class 'list'>, {'precision@5': 0.8391319413979849, 'precision@10': 0.7359201550483704, 'precision@15': 0.6617824514706929, 'precisionv2@5': 0.36736112236976626, 'precisionv2@10': 0.3500868082046509, 'precisionv2@15': 0.34886576334635416, 'ndcg@5': 0.6221641004085541, 'ndcg@10': 0.5919829328854879, 'ndcg@15': 0.5722658256689708, 'ndcg@all': 0.7095939815044403})
test defaultdict(<class 'list'>, {'precision@5': 0.8388888937505808, 'precision@10': 0.740208738906817, 'precision@15': 0.6641059409488331, 'precisionv2@5': 0.35947759584947064, 'precisionv2@10': 0.3442451906475154, 'precisionv2@15': 0.34

Epoch 211 3.881266472680228 7.062894105911255
Epoch 212 3.8785927500043598 7.0066449801127115
Epoch 213 3.8800866862705776 7.053057066599528
Epoch 214 3.8782898412431988 7.062592077255249
Epoch 215 3.875956610270909 7.03933097521464
Epoch 216 3.8755038819994243 7.015826749801636
Epoch 217 3.8746861171722413 7.053727547327678
Epoch 218 3.8750711618150984 7.086748504638672
Epoch 219 3.8744802720206124 7.055708106358846
Epoch 220 3.872818783351353 7.052558549245199

train defaultdict(<class 'list'>, {'precision@5': 0.999928571156093, 'precision@10': 0.9998214306150164, 'precision@15': 0.997237754549299, 'precisionv2@5': 0.6144348924500601, 'precisionv2@10': 0.7250971719196865, 'precisionv2@15': 0.7941698629515511, 'ndcg@5': 0.8471038709368025, 'ndcg@10': 0.8857331698281424, 'ndcg@15': 0.90488342830113, 'ndcg@all': 0.9448354319163731})
valid defaultdict(<class 'list'>, {'precision@5': 0.8403472284475962, 'precision@10': 0.739861124753952, 'precision@15': 0.6639930884043376, 'precisionv2@5'

Epoch 265 3.851116405214582 7.114934571584066
Epoch 266 3.8516918849945068 7.201845693588257
Epoch 267 3.8501479721069334 7.140807294845581
Epoch 268 3.851611373083932 7.144179662068685
Epoch 269 3.850309968675886 7.176191663742065
Epoch 270 3.8524334648677283 7.212303479512532

train defaultdict(<class 'list'>, {'precision@5': 0.9999821427890233, 'precision@10': 0.9999375002724784, 'precision@15': 0.9975178660665239, 'precisionv2@5': 0.6137016946928842, 'precisionv2@10': 0.7275084158352443, 'precisionv2@15': 0.7981765222549438, 'ndcg@5': 0.8462298362595695, 'ndcg@10': 0.8856199189594813, 'ndcg@15': 0.9052431205340794, 'ndcg@all': 0.9453329420089722})
valid defaultdict(<class 'list'>, {'precision@5': 0.8461111187934875, 'precision@10': 0.7500000178813935, 'precision@15': 0.6752083639303843, 'precisionv2@5': 0.38104167183240256, 'precisionv2@10': 0.36036458909511565, 'precisionv2@15': 0.3587153007586797, 'ndcg@5': 0.6330671568711599, 'ndcg@10': 0.6061063865820567, 'ndcg@15': 0.585495293

[[{'epoch': 0,
   'train': defaultdict(list,
               {'precision@5': 0.6056544232368469,
                'precision@10': 0.4743734298433576,
                'precision@15': 0.4001642373629979,
                'precisionv2@5': 0.30050945895058767,
                'precisionv2@10': 0.2890751109804426,
                'precisionv2@15': 0.28036066012723104,
                'ndcg@5': 0.5042255708149501,
                'ndcg@10': 0.47430176956312997,
                'ndcg@15': 0.4561565903254918,
                'ndcg@all': 0.6393318506649562}),
   'valid': defaultdict(list,
               {'precision@5': 0.5275000135103861,
                'precision@10': 0.4109201471010844,
                'precision@15': 0.3459953933954239,
                'precisionv2@5': 0.243263894567887,
                'precisionv2@10': 0.23402778059244156,
                'precisionv2@15': 0.22998843987782797,
                'ndcg@5': 0.4188481499751409,
                'ndcg@10': 0.39480057259400686,
     

In [20]:
# Training configuration for the 'MSE' run.
train_config = {
    "n_time": config['n_time'],
    # Single source of truth for the learning rate: the duplicate hard-coded "lr"
    # entry that used to follow silently overrode this value.
    "lr": config['lr'],
    "weight_decay": 0.0,
    "loss_topk": 15,

    "n_epoch": config['n_epoch'],
    "valid_epoch": config['valid_epoch'],
    "valid_verbose": True,
    "valid_topk": [5, 10, 15],

    "h_dim": config['h_dim'],
    "label_type": config['label_type'],
    "eval_f1": config['eval_f1'],
    # Set per cell on purpose (not read from config['criterion']) so that several
    # losses can be compared in one notebook.
    "criterion": 'MSE',
}

In [21]:
# Train with the currently defined `train_config`; writes <save_dir>/result.json
# (overwriting an existing one) and displays the returned results.
train_experiment(config['n_time'])

  0%|          | 0/300 [00:00<?, ?it/s]

Epoch 0 19034.904112723216 18460.464225260417

train defaultdict(<class 'list'>, {'precision@5': 0.5933214439664568, 'precision@10': 0.4615635551725115, 'precision@15': 0.38542474048478265, 'precisionv2@5': 0.30731513363974433, 'precisionv2@10': 0.2862006351777485, 'precisionv2@15': 0.2703949741806303, 'ndcg@5': 0.50814692735672, 'ndcg@10': 0.4727885510240282, 'ndcg@15': 0.4500257112298693, 'ndcg@all': 0.6190389939716884})
valid defaultdict(<class 'list'>, {'precision@5': 0.534791675209999, 'precision@10': 0.4091145922740301, 'precision@15': 0.3411921481291453, 'precisionv2@5': 0.26211805939674376, 'precisionv2@10': 0.24427083730697632, 'precisionv2@15': 0.22907408823569617, 'ndcg@5': 0.44097380141417186, 'ndcg@10': 0.4081339627504349, 'ndcg@15': 0.38899775842825574, 'ndcg@all': 0.5732472439606985})
test defaultdict(<class 'list'>, {'precision@5': 0.5206991854039106, 'precision@10': 0.39951862835071306, 'precision@15': 0.33143546601588075, 'precisionv2@5': 0.2510416732931679, 'precisio

Epoch 51 15154.379780970981 15420.20205891927
Epoch 52 15127.629342912947 15557.364819335937
Epoch 53 15121.634928850446 15438.042553710937
Epoch 54 15099.034439174107 15534.147599283855
Epoch 55 15094.265426897322 15514.109830729167
Epoch 56 15085.38716936384 15606.600854492188
Epoch 57 15083.443056640624 15500.971036783854
Epoch 58 15135.612745535715 15390.095817057292
Epoch 59 15073.780837053571 15377.70673828125
Epoch 60 15083.388000837054 15333.91758219401

train defaultdict(<class 'list'>, {'precision@5': 0.8995073464938572, 'precision@10': 0.8318461295536587, 'precision@15': 0.7758557850973947, 'precisionv2@5': 0.4075315194470542, 'precisionv2@10': 0.42957983698163715, 'precisionv2@15': 0.4466827951158796, 'ndcg@5': 0.6610330922263009, 'ndcg@10': 0.6669441499028887, 'ndcg@15': 0.6663998075893947, 'ndcg@all': 0.8065805217197963})
valid defaultdict(<class 'list'>, {'precision@5': 0.7301388959089915, 'precision@10': 0.6162326614061991, 'precision@15': 0.5444097568591436, 'precision

Epoch 105 14811.238814174107 15243.623372395834
Epoch 106 14822.936492745535 15043.122371419271
Epoch 107 14814.785499441965 15254.355590820312
Epoch 108 14812.147712053571 15010.44208984375
Epoch 109 14882.328356584821 14813.889176432293
Epoch 110 14814.787572544643 14631.476033528646

train defaultdict(<class 'list'>, {'precision@5': 0.9515745728356497, 'precision@10': 0.9112258590970721, 'precision@15': 0.8729135666574751, 'precisionv2@5': 0.431231096472059, 'precisionv2@10': 0.46786975179399765, 'precisionv2@15': 0.4980024741377149, 'ndcg@5': 0.6984118308339801, 'ndcg@10': 0.7160727804047721, 'ndcg@15': 0.7241918856757028, 'ndcg@all': 0.8477770229748317})
valid defaultdict(<class 'list'>, {'precision@5': 0.7252430578072866, 'precision@10': 0.6159548739592234, 'precision@15': 0.5460532695055008, 'precisionv2@5': 0.2900694524248441, 'precisionv2@10': 0.28371528287728626, 'precisionv2@15': 0.28739585081736246, 'ndcg@5': 0.5064510236183802, 'ndcg@10': 0.48925537467002866, 'ndcg@15': 0.

KeyboardInterrupt: 

In [22]:
# Training configuration for the 'MSE3' run.
train_config = {
    "n_time": config['n_time'],
    # Single source of truth for the learning rate: the duplicate hard-coded "lr"
    # entry that used to follow silently overrode this value.
    "lr": config['lr'],
    "weight_decay": 0.0,
    "loss_topk": 15,

    "n_epoch": config['n_epoch'],
    "valid_epoch": config['valid_epoch'],
    "valid_verbose": True,
    "valid_topk": [5, 10, 15],

    "h_dim": config['h_dim'],
    "label_type": config['label_type'],
    "eval_f1": config['eval_f1'],
    # Set per cell on purpose (not read from config['criterion']) so that several
    # losses can be compared in one notebook.
    "criterion": 'MSE3',
}

In [23]:
# Train with the currently defined `train_config`; writes <save_dir>/result.json
# (overwriting an existing one) and displays the returned results.
train_experiment(config['n_time'])

  0%|          | 0/300 [00:00<?, ?it/s]

Epoch 0 3.5174719766208105 3.6467280864715574

train defaultdict(<class 'list'>, {'precision@5': 0.5168287886892047, 'precision@10': 0.3931066252504076, 'precision@15': 0.32640022686549597, 'precisionv2@5': 0.22942017197608947, 'precisionv2@10': 0.21439286053180695, 'precisionv2@15': 0.20472935455186025, 'ndcg@5': 0.4110302218369075, 'ndcg@10': 0.37929179549217223, 'ndcg@15': 0.360973379101072, 'ndcg@all': 0.5606223147256034})
valid defaultdict(<class 'list'>, {'precision@5': 0.4778819551070531, 'precision@10': 0.359652782479922, 'precision@15': 0.2981481651465098, 'precisionv2@5': 0.20215278019507726, 'precisionv2@10': 0.18925347725550334, 'precisionv2@15': 0.18288195629914603, 'ndcg@5': 0.3710514545440674, 'ndcg@10': 0.34163162608941394, 'ndcg@15': 0.32486132681369784, 'ndcg@all': 0.5297597547372183})
test defaultdict(<class 'list'>, {'precision@5': 0.47048611837354576, 'precision@10': 0.35626776211641054, 'precision@15': 0.29457467015494, 'precisionv2@5': 0.19641730036925187, 'preci

Epoch 51 3.243140674318586 3.4596896022558212
Epoch 52 3.244120421069009 3.457479480902354
Epoch 53 3.2523085042408533 3.5343117038408915
Epoch 54 3.2608061368124828 3.4556353266040483
Epoch 55 3.241837829521724 3.4781583666801454
Epoch 56 3.2397736038480485 3.4580718954404195
Epoch 57 3.2402044888905115 3.459036545952161
Epoch 58 3.251692270210811 3.463331639766693
Epoch 59 3.243510801792145 3.469017752011617
Epoch 60 3.2388000791413445 3.4868282397588093

train defaultdict(<class 'list'>, {'precision@5': 0.9505399097715105, 'precision@10': 0.9475703978538513, 'precision@15': 0.9420242077963693, 'precisionv2@5': 0.3839327791758946, 'precisionv2@10': 0.48269748534475054, 'precisionv2@15': 0.5568515665190561, 'ndcg@5': 0.5678244757652283, 'ndcg@10': 0.6395109016554696, 'ndcg@15': 0.6806538091387068, 'ndcg@all': 0.8269202542304993})
valid defaultdict(<class 'list'>, {'precision@5': 0.841284720102946, 'precision@10': 0.7495659788449606, 'precision@15': 0.6740509549776713, 'precisionv2@5':

Epoch 105 3.2267881989479066 3.4747397263844806
Epoch 106 3.226768102305276 3.4907462120056154
Epoch 107 3.2268141402517045 3.472147125005722
Epoch 108 3.2305238590921674 3.4631255090236666
Epoch 109 3.225840926340648 3.48251300851504
Epoch 110 3.2280458644458228 3.5997073928515118

train defaultdict(<class 'list'>, {'precision@5': 0.9699947541100639, 'precision@10': 0.9724322669846671, 'precision@15': 0.9705840601239886, 'precisionv2@5': 0.38705147845404486, 'precisionv2@10': 0.5052626124450139, 'precisionv2@15': 0.5887167651312691, 'ndcg@5': 0.5688244346209935, 'ndcg@10': 0.6474710461071559, 'ndcg@15': 0.6925681403705052, 'ndcg@all': 0.8328872122083392})
valid defaultdict(<class 'list'>, {'precision@5': 0.8415625055631002, 'precision@10': 0.7446007033189138, 'precision@15': 0.6727778156598408, 'precisionv2@5': 0.2307986135284106, 'precisionv2@10': 0.2632638951142629, 'precisionv2@15': 0.2879051109155019, 'ndcg@5': 0.39765285750230156, 'ndcg@10': 0.42293761869271596, 'ndcg@15': 0.4328

Epoch 158 3.2254828279359 3.4987058341503143
Epoch 159 3.2229819355692184 3.5139937818050386
Epoch 160 3.2234310940333777 3.504314923286438

train defaultdict(<class 'list'>, {'precision@5': 0.9788518990789141, 'precision@10': 0.9806759609494891, 'precision@15': 0.9788445571490697, 'precisionv2@5': 0.37827626705169676, 'precisionv2@10': 0.5075346672534943, 'precisionv2@15': 0.5921579422269548, 'ndcg@5': 0.5525781600815909, 'ndcg@10': 0.6379822843415397, 'ndcg@15': 0.6854297004427229, 'ndcg@all': 0.8277028635569981})
valid defaultdict(<class 'list'>, {'precision@5': 0.8499305486679077, 'precision@10': 0.7568229337533315, 'precision@15': 0.6780671755472819, 'precisionv2@5': 0.21798611531654993, 'precisionv2@10': 0.2656250054637591, 'precisionv2@15': 0.29134261012077334, 'ndcg@5': 0.3879659175872803, 'ndcg@10': 0.42095714807510376, 'ndcg@15': 0.42920816044012705, 'ndcg@all': 0.6183523495992025})
test defaultdict(<class 'list'>, {'precision@5': 0.8468158116394823, 'precision@10': 0.7535570


train defaultdict(<class 'list'>, {'precision@5': 0.9821250074250357, 'precision@10': 0.9841696510996137, 'precision@15': 0.9825234746932984, 'precisionv2@5': 0.361500004189355, 'precisionv2@10': 0.49240074004445755, 'precisionv2@15': 0.5822262150900704, 'ndcg@5': 0.5438951221534184, 'ndcg@10': 0.6287801844733102, 'ndcg@15': 0.6780559069769723, 'ndcg@all': 0.8248758220672607})
valid defaultdict(<class 'list'>, {'precision@5': 0.8377777814865113, 'precision@10': 0.7384895920753479, 'precision@15': 0.6598611384630203, 'precisionv2@5': 0.21642361481984457, 'precisionv2@10': 0.25706597516934077, 'precisionv2@15': 0.2821064939101537, 'ndcg@5': 0.37954379121462506, 'ndcg@10': 0.4096948156754176, 'ndcg@15': 0.4184723675251007, 'ndcg@all': 0.610534726579984})
test defaultdict(<class 'list'>, {'precision@5': 0.8409485417333517, 'precision@10': 0.7466224824840372, 'precision@15': 0.6672506603327665, 'precisionv2@5': 0.21375868384811011, 'precisionv2@10': 0.24991517209193922, 'precisionv2@15': 0

Epoch 261 3.219148589202336 3.4714733342329662
Epoch 262 3.229757720061711 3.476089155673981
Epoch 263 3.227980441025325 3.486685134967168
Epoch 264 3.2197472766467503 3.6013853629430135
Epoch 265 3.2195511606761387 3.4777337431907656
Epoch 266 3.220165809563228 3.487168691555659
Epoch 267 3.219076968601772 3.482224969069163
Epoch 268 3.219322086742946 3.5075738787651063
Epoch 269 3.2234771783011302 3.4828306714693706
Epoch 270 3.2191412006105695 3.4747190892696382

train defaultdict(<class 'list'>, {'precision@5': 0.9826428641591753, 'precision@10': 0.9858839358602252, 'precision@15': 0.9842797766413007, 'precisionv2@5': 0.34819013237953184, 'precisionv2@10': 0.4816113473687853, 'precisionv2@15': 0.56914393390928, 'ndcg@5': 0.5319439630849021, 'ndcg@10': 0.6168679370198931, 'ndcg@15': 0.6664136402947562, 'ndcg@all': 0.8195131312097822})
valid defaultdict(<class 'list'>, {'precision@5': 0.8273263951142629, 'precision@10': 0.7419965326786041, 'precision@15': 0.666273178656896, 'precisio

[[{'epoch': 0,
   'train': defaultdict(list,
               {'precision@5': 0.5168287886892047,
                'precision@10': 0.3931066252504076,
                'precision@15': 0.32640022686549597,
                'precisionv2@5': 0.22942017197608947,
                'precisionv2@10': 0.21439286053180695,
                'precisionv2@15': 0.20472935455186025,
                'ndcg@5': 0.4110302218369075,
                'ndcg@10': 0.37929179549217223,
                'ndcg@15': 0.360973379101072,
                'ndcg@all': 0.5606223147256034}),
   'valid': defaultdict(list,
               {'precision@5': 0.4778819551070531,
                'precision@10': 0.359652782479922,
                'precision@15': 0.2981481651465098,
                'precisionv2@5': 0.20215278019507726,
                'precisionv2@10': 0.18925347725550334,
                'precisionv2@15': 0.18288195629914603,
                'ndcg@5': 0.3710514545440674,
                'ndcg@10': 0.34163162608941394,
   

In [None]:
# Persist the experiment configuration and the training configuration that is
# currently defined (i.e. the last `train_config` cell that was run) next to the results.
with open(os.path.join(save_dir, 'config.json'), 'w') as f:
    f.write(json.dumps(config))
with open(os.path.join(save_dir, 'train_config.json'), 'w') as f:
    f.write(json.dumps(train_config))