## Load

In [1]:
import os
import sys

from collections import defaultdict

import numpy as np 
import pandas as pd
import random

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
from tqdm.auto import tqdm

# Used to get the data
import nltk
from nltk.stem import PorterStemmer
from nltk.corpus import stopwords
from nltk.collocations import BigramAssocMeasures, BigramCollocationFinder
nltk.download('stopwords')

import matplotlib.pyplot as plt 
import matplotlib
matplotlib.use('Agg')

sys.path.append('../')
from utils.eval import retrieval_normalized_dcg_all, retrieval_precision_all
from utils.loss import ListNet, ListNet2, ListNet_origin, MultiLabelMarginLossCustom, MultiLabelMarginLossCustomV, MSE
from utils.data_processing import get_process_data

# Fixed seed for the whole notebook. It is applied to every RNG the notebook
# draws from (Python's `random`, NumPy's global generator used by
# `np.random.shuffle`, and torch for weight init / dropout / loader shuffling)
# so that a fresh "Restart & Run All" starts from the same random state.
seed = 33
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

[nltk_data] Downloading package stopwords to
[nltk_data]     /home/chrisliu/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [2]:
# Settings for loading the corpus and its pre-computed document embeddings.
embedding_type = ''
dataset = '20news'
documentembedding_normalize = True

# NOTE(review): the embedding path below names a 100-d GloVe file while
# embedding_dim is 128 -- confirm how get_process_data uses embedding_dim.
embedding_dim = 128

# `dataset` is passed through (instead of repeating the '20news' literal) so
# that editing the setting above actually changes which corpus is loaded.
data = get_process_data(dataset=dataset, agg='IDF', embedding_type=embedding_type,
                     word2embedding_path='../data/glove.6B.100d.txt', word2embedding_normalize=False,
                     documentembedding_normalize=documentembedding_normalize,
                     embedding_dim=embedding_dim, max_seq_length=128,
                     load_embedding=True)

# Per-document word weights (decoder targets) and document embeddings (decoder inputs).
document_TFIDF = np.array(data["document_word_weight"])
document_vectors = np.array(data["document_embedding"])

Loading word2embedding from ../data/glove.6B.100d.txt


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Number of words:400000


HBox(children=(IntProgress(value=0, description='Start buiding vocabulary...', max=18846, style=ProgressStyle(…


doc num 18846
eliminate freq words
Load from saving
delete items 150


In [3]:
# Evaluation settings: the cut-offs passed as `k` to the retrieval metrics.
config = {"topk": [10, 30, 50]}

## MLP Decoder

In [4]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
# device = 'cpu'


In [5]:
class MLPDecoderDataset(Dataset):
    """Pairs every document vector with its row of word weights.

    Each item is the 4-tuple ``(doc_vector, weights, ranking, n_positive)``:
    ``ranking`` holds the column indices of ``weights`` ordered from the
    largest weight to the smallest, and ``n_positive`` is the number of
    strictly positive entries in ``weights``.
    """

    def __init__(self, doc_vectors, weight_ans):
        n_vectors, n_targets = len(doc_vectors), len(weight_ans)
        assert n_vectors == n_targets

        vectors = torch.FloatTensor(doc_vectors)
        weights = torch.FloatTensor(weight_ans)

        self.doc_vectors = vectors
        self.weight_ans = weights
        # Column indices of each row, highest weight first.
        self.weight_ans_s = weights.argsort(dim=1, descending=True)
        # Number of strictly positive weights in each row.
        self.topk = (weights > 0).sum(dim=1)

    def __len__(self):
        return len(self.doc_vectors)

    def __getitem__(self, idx):
        item = (
            self.doc_vectors[idx],
            self.weight_ans[idx],
            self.weight_ans_s[idx],
            self.topk[idx],
        )
        return item

In [6]:
def prepare_dataloader(batch_size=100, train_size_ratio=0.8, topk=50, TFIDF_normalize=False):
    """Shuffle the corpus once and split it into train / validation loaders.

    Reads the module-level ``document_vectors`` and ``document_TFIDF`` arrays.

    Args:
        batch_size: Batch size used by both loaders.
        train_size_ratio: Fraction of the (shuffled) documents placed in the
            training split; the remainder forms the validation split.
        topk: Unused; kept so existing calls that pass it keep working.
        TFIDF_normalize: When True, divide every TFIDF row by its sum so each
            document's weights add up to 1.

    Returns:
        ``(train_loader, valid_loader)``. The training loader reshuffles on
        every pass; the validation loader keeps a fixed order so evaluation
        batches do not change between epochs.
    """
    n_docs = len(document_vectors)
    train_size = int(n_docs * train_size_ratio)

    print('train size', train_size)
    print('valid size', n_docs - train_size)

    if TFIDF_normalize:
        # normalize TFIDF summation of each document to 1
        row_sum = document_TFIDF.sum(axis=1).reshape(-1, 1)
        # An all-zero row would divide by zero and fill the targets with NaN;
        # dividing by 1 instead leaves such a row as zeros.
        row_sum = np.where(row_sum == 0, 1, row_sum)
        document_TFIDF_ = document_TFIDF / row_sum
        # Alternative (disabled): normalize the L2 norm of each document to 1
        # norm = np.linalg.norm(document_TFIDF, axis=1).reshape(-1, 1)
        # document_TFIDF_ = (document_TFIDF / norm)
    else:
        document_TFIDF_ = document_TFIDF

    # shuffle documents and targets with the same permutation
    randomize = np.arange(n_docs)
    np.random.shuffle(randomize)
    document_vectors_ = document_vectors[randomize]
    document_TFIDF_ = document_TFIDF_[randomize]

    # Pinned host memory is only requested when a CUDA device is present.
    pin_memory = torch.cuda.is_available()

    # dataloader
    train_dataset = MLPDecoderDataset(document_vectors_[:train_size], document_TFIDF_[:train_size])
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, pin_memory=pin_memory)

    valid_dataset = MLPDecoderDataset(document_vectors_[train_size:], document_TFIDF_[train_size:])
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=batch_size, shuffle=False, pin_memory=pin_memory)

    return train_loader, valid_loader


In [7]:
class MLPDecoder(nn.Module):
    """One-hidden-layer MLP mapping a document embedding to one score per word.

    Architecture: ``Linear(doc_emb_dim, h_dim) -> tanh -> Dropout(0.2) ->
    Linear(h_dim, num_words)``. The output is the raw result of the last
    linear layer (no activation is applied to it).

    Args:
        doc_emb_dim: Size of the input document embedding.
        num_words: Number of output scores.
        h_dim: Width of the hidden layer.
    """

    def __init__(self, doc_emb_dim, num_words, h_dim=300):
        super().__init__()

        self.fc1 = nn.Linear(doc_emb_dim, h_dim)
        self.fc4 = nn.Linear(h_dim, num_words)
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x):
        """Return the ``(..., num_words)`` scores for input embeddings ``x``."""
        # torch.tanh computes the same function as F.tanh, without the
        # deprecation warning that F.tanh raises on some PyTorch releases.
        x = torch.tanh(self.fc1(x))
        x = self.dropout(x)
        x = self.fc4(x)

        return x

In [8]:
def evaluate_MLPDecoder(model, data_loader):
    """Score ``model`` on ``data_loader`` with precision@k and NDCG@k.

    Reads the module-level ``device`` and ``config["topk"]`` (the list of
    cut-offs handed to both metric helpers). The model is switched to eval
    mode and is left in that mode when the function returns.

    Args:
        model: Module called as ``model(doc_embs)`` to produce predictions.
        data_loader: Yields 4-tuples whose first two items are the document
            embeddings and the target weights; the other two are ignored.

    Returns:
        A ``defaultdict`` mapping ``'precision@<key>'`` / ``'ndcg@<key>'`` to
        the mean of the per-batch scores, where ``<key>`` is whatever key the
        metric helper returned for that score.
    """
    results = defaultdict(list)
    model.eval()

    # Evaluation needs no gradients; disabling autograd avoids building a
    # graph for every batch.
    with torch.no_grad():
        for doc_embs, target, _, _ in data_loader:
            doc_embs = doc_embs.to(device)
            target = target.to(device)

            pred = model(doc_embs)

            # Precision
            precision_scores = retrieval_precision_all(pred, target, k=config["topk"])
            for k, v in precision_scores.items():
                results['precision@{}'.format(k)].append(v)

            # NDCG
            ndcg_scores = retrieval_normalized_dcg_all(pred, target, k=config["topk"])
            for k, v in ndcg_scores.items():
                results['ndcg@{}'.format(k)].append(v)

    # Collapse the per-batch lists into a single mean per metric.
    for k in results:
        results[k] = np.mean(results[k])

    return results

In [9]:
# train_loader, valid_loader = prepare_dataloader(batch_size=100,\
#                                                 train_size_ratio=0.8, topk=30,
#                                                 TFIDF_normalize=True)
# for data in train_loader:
#     doc_embs, target, target_rank, target_topk = data
#     print(doc_embs.shape)
#     print(target.shape)
#     print(target_rank.shape)
#     print(target_topk.shape)
#     break

In [16]:
def _make_loss_fn(config):
    """Build ``loss_fn(pred, target, target_rank, target_topk)`` for ``config["criterion"]``.

    Raises:
        ValueError: If the criterion name is not one of the supported losses,
            or if a ``...Custom`` / ``...CustomV`` name has no numeric value
            after its last ``':'``.
    """
    name = config["criterion"]

    if name == "MultiLabelMarginLoss":
        margin_loss = nn.MultiLabelMarginLoss(reduction='mean')

        def loss_fn(pred, target, target_rank, target_topk):
            # nn.MultiLabelMarginLoss reads target indices up to the first -1,
            # so writing -1 at column `topk` keeps only the first `topk`
            # entries of the ranking as positive labels.
            target_rank[:, config["topk"]] = -1
            return margin_loss(pred, target_rank)

    elif name.startswith("MultiLabelMarginLossCustomV"):
        # Must be tested before the "MultiLabelMarginLossCustom" prefix below.
        # The number after ':' is forwarded as the loss's fourth argument.
        loss_arg = float(name.split(':')[-1])

        def loss_fn(pred, target, target_rank, target_topk):
            return MultiLabelMarginLossCustomV(pred, target_rank, target_topk, loss_arg)

    elif name.startswith("MultiLabelMarginLossCustom"):
        loss_arg = float(name.split(':')[-1])

        def loss_fn(pred, target, target_rank, target_topk):
            return MultiLabelMarginLossCustom(pred, target_rank, config["topk"], loss_arg)

    elif name == "ListNet":
        def loss_fn(pred, target, target_rank, target_topk):
            return ListNet_origin(pred, target)

    elif name == "ListNet2":
        def loss_fn(pred, target, target_rank, target_topk):
            return ListNet2(pred, target)

    elif name == "MSE":
        def loss_fn(pred, target, target_rank, target_topk):
            return MSE(pred, target)

    else:
        raise ValueError("loss not found: {!r}".format(name))

    return loss_fn


def _run_pass(model, loader, loss_fn, opt=None):
    """Run one pass over ``loader`` and return the list of per-batch losses.

    When ``opt`` is given, every batch is followed by a backward pass and an
    optimizer step; otherwise the model is only evaluated.
    """
    losses = []
    for doc_embs, target, target_rank, target_topk in loader:
        doc_embs = doc_embs.to(device)
        target = target.to(device)
        target_rank = target_rank.to(device)
        target_topk = target_topk.to(device)

        pred = model(doc_embs)
        loss = loss_fn(pred, target, target_rank, target_topk)
        losses.append(loss.item())

        if opt is not None:
            # Model backwarding
            model.zero_grad()
            loss.backward()
            opt.step()

    return losses


def train_decoder(config):
    """Train a fresh ``MLPDecoder`` with SGD and report metrics periodically.

    Reads the module-level ``train_loader``, ``valid_loader``, ``device``,
    ``document_vectors`` and ``document_TFIDF``. Note that ``config`` here is
    the training configuration passed by the caller, while
    ``evaluate_MLPDecoder`` reads the module-level ``config``.

    Args:
        config: Dict with the keys ``h_dim``, ``lr``, ``momentum``,
            ``weight_decay``, ``criterion``, ``n_epoch``, ``valid_epoch``,
            ``verbose`` and ``topk`` (``topk`` is used by the
            ``MultiLabelMarginLoss`` and ``MultiLabelMarginLossCustom`` losses).

    Returns:
        A list with one dict per evaluated epoch (every ``valid_epoch``
        epochs, starting at epoch 0) holding ``'epoch'`` plus the validation
        metrics returned by ``evaluate_MLPDecoder``.

    Raises:
        ValueError: If ``config["criterion"]`` is not a supported loss.
    """
    # Resolve the loss first so a bad criterion fails before any model is built.
    loss_fn = _make_loss_fn(config)

    model = MLPDecoder(
        doc_emb_dim=document_vectors.shape[1], num_words=document_TFIDF.shape[1], h_dim=config["h_dim"]).to(device)
    model.train()

    opt = torch.optim.SGD(model.parameters(), lr=config["lr"], momentum=config["momentum"],
                          weight_decay=config["weight_decay"])

    results = []
    n_epoch = config["n_epoch"]
    valid_epoch = config["valid_epoch"]
    verbose = config["verbose"]

    for epoch in tqdm(range(n_epoch)):
        model.train()
        train_loss_his = _run_pass(model, train_loader, loss_fn, opt)

        # Validation only needs the loss values, so autograd is disabled.
        model.eval()
        with torch.no_grad():
            valid_loss_his = _run_pass(model, valid_loader, loss_fn)

        print("Epoch", epoch, np.mean(train_loss_his), np.mean(valid_loss_his))

        # show decoder result
        if epoch % valid_epoch == 0:
            res = {}
            res['epoch'] = epoch

            train_res_ndcg = evaluate_MLPDecoder(model, train_loader)
            valid_res_ndcg = evaluate_MLPDecoder(model, valid_loader)

            res.update(valid_res_ndcg)
            results.append(res)

            if verbose:
                print()
                print('train', train_res_ndcg)
                print('valid', valid_res_ndcg)

    # The collected validation history was previously discarded.
    return results

In [12]:
# prepare dataloader
# Build the loaders that train_decoder reads as module-level names.
train_loader, valid_loader = prepare_dataloader(
    batch_size=100,
    train_size_ratio=0.8,
    TFIDF_normalize=False,
)

train size 14956
valid size 3740


In [14]:
# Baseline run: SGD settings, schedule, and loss selection for train_decoder.
train_config = dict(
    # optimizer
    lr=0.05,
    momentum=0.0,
    weight_decay=0.0,
    # schedule / reporting
    n_epoch=200,
    verbose=True,
    valid_epoch=10,
    # number of top-ranked words used by the margin losses
    topk=50,
    # model and loss
    h_dim=3000,
    criterion="MultiLabelMarginLoss",  # "ListNet",
    TFIDF_normalize=False,
)

train_decoder(train_config)

HBox(children=(IntProgress(value=0, max=200), HTML(value='')))

Epoch 0 12.926986516316731 10.367545755285965

train defaultdict(<class 'list'>, {'precision@10': 0.01298571437364444, 'precision@30': 0.022366032687326273, 'precision@50': 0.03565847411751747, 'ndcg@10': 0.007084246649465058, 'ndcg@30': 0.014139637804279724, 'ndcg@50': 0.025549024467666943, 'ndcg@all': 0.30848127603530884})
valid defaultdict(<class 'list'>, {'precision@10': 0.013750000139943472, 'precision@30': 0.022337719749071096, 'precision@50': 0.03500789276471263, 'ndcg@10': 0.007297709279411815, 'ndcg@30': 0.013929717696124786, 'ndcg@50': 0.024690042151824423, 'ndcg@all': 0.30443468454637024})
Epoch 1 8.729334386189779 8.618516645933452
Epoch 2 6.989227790832519 7.60809178101389
Epoch 3 5.855995915730794 7.009011130583914
Epoch 4 5.150923207600911 6.635706223939595
Epoch 5 4.5350236924489336 6.336806234560515
Epoch 6 4.109379731814067 6.112365509334364
Epoch 7 3.706019247372945 5.929456986879048
Epoch 8 3.3766524442036947 5.755431225425319
Epoch 9 3.0935816542307535 5.6257014651

Epoch 81 0.5511282507578532 4.777822419216759
Epoch 82 0.549428324898084 4.773677612605848
Epoch 83 0.540177527864774 4.776716200928939
Epoch 84 0.5391905397176743 4.80451930824079
Epoch 85 0.5353025472164155 4.795261081896331
Epoch 86 0.5354098961750666 4.794589895951121
Epoch 87 0.532573949098587 4.81730748477735
Epoch 88 0.5242013980944952 4.773377838887666
Epoch 89 0.5177740911642711 4.787182255795128
Epoch 90 0.5155039509137471 4.8031771810431225

train defaultdict(<class 'list'>, {'precision@10': 0.3521447577079137, 'precision@30': 0.31940032680829367, 'precision@50': 0.3270316008726756, 'ndcg@10': 0.3110571692387263, 'ndcg@30': 0.34199759940306346, 'ndcg@50': 0.4067611628770828, 'ndcg@all': 0.6038625693321228})
valid defaultdict(<class 'list'>, {'precision@10': 0.28424999666841405, 'precision@30': 0.20931579760814967, 'precision@50': 0.1879263053599157, 'ndcg@10': 0.2683543378585263, 'ndcg@30': 0.262510457713353, 'ndcg@50': 0.2830406176416497, 'ndcg@all': 0.5091669418309864})
Ep

Epoch 163 0.3637645554542541 4.85907695167943
Epoch 164 0.3641070870558421 4.8763046515615365
Epoch 165 0.36793646673361463 4.861670864255805
Epoch 166 0.36293113211790723 4.849828255803962
Epoch 167 0.3596035075187683 4.867250273102208
Epoch 168 0.3581301752726237 4.881314390584042
Epoch 169 0.3604918046792348 4.851338135568719
Epoch 170 0.3581660771369934 4.87010141422874

train defaultdict(<class 'list'>, {'precision@10': 0.37580571313699085, 'precision@30': 0.38567620078722636, 'precision@50': 0.4044636925061544, 'ndcg@10': 0.3183146751920382, 'ndcg@30': 0.3786684077978134, 'ndcg@50': 0.46218003153800963, 'ndcg@all': 0.6206560742855072})
valid defaultdict(<class 'list'>, {'precision@10': 0.2770131571512473, 'precision@30': 0.21406579723483637, 'precision@50': 0.19389472705753227, 'ndcg@10': 0.2607403958314343, 'ndcg@30': 0.2613470730812926, 'ndcg@50': 0.2835346487791915, 'ndcg@all': 0.5069160869247035})
Epoch 171 0.35632449984550474 4.857083602955467
Epoch 172 0.35544779976209007 4

In [17]:
# Same setup as the baseline, trained with the ListNet loss.
train_config.update({"lr": 0.05, "criterion": "ListNet"})
train_decoder(train_config)

HBox(children=(IntProgress(value=0, max=200), HTML(value='')))



Epoch 0 8.779600276947022 8.270802824120773

train defaultdict(<class 'list'>, {'precision@10': 0.13154476121068, 'precision@30': 0.07283190704882145, 'precision@50': 0.05430942522982756, 'ndcg@10': 0.1406103073557218, 'ndcg@30': 0.12513888771335283, 'ndcg@50': 0.12488226979970932, 'ndcg@all': 0.34428453107674917})
valid defaultdict(<class 'list'>, {'precision@10': 0.1273947364013446, 'precision@30': 0.0715043890829149, 'precision@50': 0.05351578640310388, 'ndcg@10': 0.12529076792691884, 'ndcg@30': 0.11258802641379206, 'ndcg@50': 0.11253306916669796, 'ndcg@all': 0.33292802776160996})
Epoch 1 7.522814617156983 7.6988303284896045
Epoch 2 6.711652297973632 7.276042687265496
Epoch 3 6.039139928817749 6.984229326248169
Epoch 4 5.485612293879191 6.712406735671194
Epoch 5 5.02360135714213 6.516861664621453
Epoch 6 4.629629158973694 6.332073952022352
Epoch 7 4.302843651771545 6.1994853019714355
Epoch 8 3.983352425893148 6.070155206479524
Epoch 9 3.719247449239095 5.967334910442955
Epoch 10 3.5

Epoch 81 1.8695633323987324 4.948344193006816
Epoch 82 1.869421784877777 4.935971649069535
Epoch 83 1.8689814623196919 4.943802406913356
Epoch 84 1.8655093137423198 4.9573658516532495
Epoch 85 1.8632086277008058 4.925061953695197
Epoch 86 1.8618929441769918 4.940608363402517
Epoch 87 1.8600742880503336 4.925412441554823
Epoch 88 1.859886171023051 4.938482297094245
Epoch 89 1.8592898281415304 4.9342898946059375
Epoch 90 1.8559330455462137 4.9233393920095345

train defaultdict(<class 'list'>, {'precision@10': 0.4259928532441457, 'precision@30': 0.21371714959541957, 'precision@50': 0.15368799209594727, 'ndcg@10': 0.6284285191694895, 'ndcg@30': 0.5171652140220007, 'ndcg@50': 0.5029019131263097, 'ndcg@all': 0.6728660102685292})
valid defaultdict(<class 'list'>, {'precision@10': 0.3552894662869604, 'precision@30': 0.18916228492009013, 'precision@50': 0.13875525562386765, 'ndcg@10': 0.45553841245801824, 'ndcg@30': 0.3872027475582926, 'ndcg@50': 0.3794813548263751, 'ndcg@all': 0.56635421514511

Epoch 164 1.7989327494303387 4.8244124462730005
Epoch 165 1.7980794254938761 4.817323684692383
Epoch 166 1.8013844037055968 4.817371148812144
Epoch 167 1.8024095280965169 4.820573618537502
Epoch 168 1.7981065273284913 4.807498216629028
Epoch 169 1.7989592695236205 4.813499212265015
Epoch 170 1.8000410787264507 4.815797542270861

train defaultdict(<class 'list'>, {'precision@10': 0.4648009479045868, 'precision@30': 0.2314939754207929, 'precision@50': 0.16556522885958352, 'ndcg@10': 0.6647539071242015, 'ndcg@30': 0.5472718973954519, 'ndcg@50': 0.5324799817800522, 'ndcg@all': 0.6965645356973013})
valid defaultdict(<class 'list'>, {'precision@10': 0.3735394642541283, 'precision@30': 0.19907456871710325, 'precision@50': 0.1465973603098016, 'ndcg@10': 0.4728847012708062, 'ndcg@30': 0.4030431297264601, 'ndcg@50': 0.3959691838214272, 'ndcg@all': 0.5800005526919114})
Epoch 171 1.8011258220672608 4.798938312028584
Epoch 172 1.7991087047259013 4.808155988392077
Epoch 173 1.798958134651184 4.82770

In [23]:
# Custom margin loss, "V" variant; the value after ':' is forwarded to the loss.
train_config.update({"lr": 0.05, "criterion": "MultiLabelMarginLossCustomV:1"})
train_decoder(train_config)

HBox(children=(IntProgress(value=0, max=200), HTML(value='')))

Epoch 0 13.785018844604492 11.420345331493177

train defaultdict(<class 'list'>, {'precision@10': 0.13868428488572437, 'precision@30': 0.11677159165342649, 'precision@50': 0.10457894638180733, 'ndcg@10': 0.05765215064088503, 'ndcg@30': 0.0723681225379308, 'ndcg@50': 0.08430370658636094, 'ndcg@all': 0.34371833463509877})
valid defaultdict(<class 'list'>, {'precision@10': 0.13090789474939046, 'precision@30': 0.11157895056040663, 'precision@50': 0.09974473185445133, 'ndcg@10': 0.05500796162768414, 'ndcg@30': 0.06892378157690952, 'ndcg@50': 0.08011008544187796, 'ndcg@all': 0.33820266080530065})
Epoch 1 9.88448076248169 9.716178718366121
Epoch 2 8.294592364629109 8.736898020694131
Epoch 3 7.295467777252197 8.213632006394235
Epoch 4 6.566032371520996 7.711031813370554
Epoch 5 6.028131144841512 7.443974193773772
Epoch 6 5.6148630809783935 7.217059599725824
Epoch 7 5.289643580118815 7.092882106178685
Epoch 8 5.041446418762207 6.93988813852009
Epoch 9 4.757912985483805 6.7996297886497095
Epoch 

Epoch 81 2.605337845484416 5.870275572726601
Epoch 82 2.703078950246175 5.895386055896156
Epoch 83 2.6239482665061953 5.81911593989322
Epoch 84 2.618926386833191 5.869568247544138
Epoch 85 2.6996029949188234 5.9395816577108285
Epoch 86 2.7737104940414428 5.892573431918495
Epoch 87 2.7875453853607177 5.907250667873182
Epoch 88 2.7943146721522014 5.887305435381438
Epoch 89 2.779125566482544 5.9554048211951
Epoch 90 2.7921983861923216 5.904024124145508

train defaultdict(<class 'list'>, {'precision@10': 0.3659447588523229, 'precision@30': 0.2638265175620715, 'precision@50': 0.2219936090707779, 'ndcg@10': 0.2779722765088081, 'ndcg@30': 0.2861177777250608, 'ndcg@50': 0.3059681864579519, 'ndcg@all': 0.5368692926565806})
valid defaultdict(<class 'list'>, {'precision@10': 0.3346184147031684, 'precision@30': 0.23015351436640086, 'precision@50': 0.19005262224297775, 'ndcg@10': 0.2612951578278291, 'ndcg@30': 0.25935419648885727, 'ndcg@50': 0.2723833824458875, 'ndcg@all': 0.4994989233581643})
Epoc

Epoch 164 2.8724117994308473 5.926890875163831
Epoch 165 2.9043995730082193 6.031179428100586
Epoch 166 2.9729225126902263 5.987993202711406
Epoch 167 2.895872712135315 5.976330656754343
Epoch 168 2.8952412287394207 6.034680918643349
Epoch 169 2.9936668078104653 5.996246199858816
Epoch 170 2.97880704720815 6.01840803497716

train defaultdict(<class 'list'>, {'precision@10': 0.3603338094552358, 'precision@30': 0.25530826290448505, 'precision@50': 0.21429931998252869, 'ndcg@10': 0.2734780192375183, 'ndcg@30': 0.2795998146136602, 'ndcg@50': 0.2981928692261378, 'ndcg@all': 0.5300064051151275})
valid defaultdict(<class 'list'>, {'precision@10': 0.3283421012916063, 'precision@30': 0.22341667311756233, 'precision@50': 0.18446577928568186, 'ndcg@10': 0.25525658695321335, 'ndcg@30': 0.2529666302235503, 'ndcg@50': 0.26547503392947347, 'ndcg@all': 0.49358571203131424})
Epoch 171 2.8870508734385174 5.982695529335423
Epoch 172 2.8800257364908854 5.9617159994024975
Epoch 173 2.8879299036661785 5.979

In [21]:
# Custom margin loss with a larger learning rate; the value after ':' is forwarded to the loss.
train_config.update({"lr": 0.1, "criterion": "MultiLabelMarginLossCustom:1"})
train_decoder(train_config)

HBox(children=(IntProgress(value=0, max=200), HTML(value='')))

Epoch 0 14.189671160380046 12.044600235788446

train defaultdict(<class 'list'>, {'precision@10': 0.00668523816857487, 'precision@30': 0.01296619102358818, 'precision@50': 0.02072580819949508, 'ndcg@10': 0.003200937565610123, 'ndcg@30': 0.007910092350405951, 'ndcg@50': 0.015334676609685023, 'ndcg@all': 0.29739442547162376})
valid defaultdict(<class 'list'>, {'precision@10': 0.006526315863562846, 'precision@30': 0.012289474179085932, 'precision@50': 0.020223682844325117, 'ndcg@10': 0.0027818153194905455, 'ndcg@30': 0.007231580197664076, 'ndcg@50': 0.01449536164536288, 'ndcg@all': 0.2938428411358281})
Epoch 1 9.651857067743936 10.137314319610596
Epoch 2 7.840189949671427 8.873016056261564
Epoch 3 6.706648476918539 8.130282966714157
Epoch 4 5.956039934158325 7.771389459308825
Epoch 5 5.392199691136678 7.342848677384226
Epoch 6 4.9237745110193885 7.035325025257311
Epoch 7 4.514802223841349 6.704293363972714
Epoch 8 4.192282379468282 6.538900475752981
Epoch 9 3.950083859761556 6.44116730439

Epoch 81 1.7574559537569683 5.453842552084672
Epoch 82 1.7849850209554037 5.452479839324951
Epoch 83 1.7739443786938984 5.470519442307322
Epoch 84 1.7811320956548056 5.427326340424387
Epoch 85 1.7379778997103372 5.526796177813881
Epoch 86 1.846744983990987 5.511223215805857
Epoch 87 1.8463557974497478 5.51557637515821
Epoch 88 1.8167947141329448 5.534933127855
Epoch 89 1.85169970591863 5.516656863062005
Epoch 90 1.8152916471163432 5.4923607550169296

train defaultdict(<class 'list'>, {'precision@10': 0.19620618869860967, 'precision@30': 0.15682206849257152, 'precision@50': 0.15701399127642313, 'ndcg@10': 0.18435644408067067, 'ndcg@30': 0.19086122145255408, 'ndcg@50': 0.22061039944489796, 'ndcg@all': 0.48840649286905924})
valid defaultdict(<class 'list'>, {'precision@10': 0.17625000112150846, 'precision@30': 0.1326622847271593, 'precision@50': 0.12824736221840508, 'ndcg@10': 0.17182862562568565, 'ndcg@30': 0.17198735514753744, 'ndcg@50': 0.1925171813682506, 'ndcg@all': 0.445753847297869

Epoch 164 1.8556201553344727 5.517834901809692
Epoch 165 1.810522468884786 5.521107460323133
Epoch 166 1.7794555075963339 5.4911905087922745
Epoch 167 1.770076707204183 5.522306442260742
Epoch 168 1.7800941244761148 5.48606568888614
Epoch 169 1.7732719572385152 5.472409662447478
Epoch 170 1.7868926405906678 5.465866816671271

train defaultdict(<class 'list'>, {'precision@10': 0.28446571191151937, 'precision@30': 0.20244794309139252, 'precision@50': 0.18375551462173462, 'ndcg@10': 0.244930848578612, 'ndcg@30': 0.24263023714224496, 'ndcg@50': 0.266967919866244, 'ndcg@all': 0.522479189435641})
valid defaultdict(<class 'list'>, {'precision@10': 0.24886842032796458, 'precision@30': 0.16892544689931369, 'precision@50': 0.14654472901632912, 'ndcg@10': 0.22425252042318644, 'ndcg@30': 0.2158024146368629, 'ndcg@50': 0.22960590257456429, 'ndcg@all': 0.47417959727739034})
Epoch 171 1.7884100739161173 5.521323618135955
Epoch 172 1.7871321884791056 5.539810456727681
Epoch 173 1.7888223385810853 5.56

In [20]:
# MSE loss run with a much smaller learning rate.
train_config.update({"lr": 1e-4, "criterion": "MSE"})
train_decoder(train_config)

HBox(children=(IntProgress(value=0, max=200), HTML(value='')))

Epoch 0 9472.620388997397 5878.876917788857

train defaultdict(<class 'list'>, {'precision@10': 0.04449952368934949, 'precision@30': 0.045631430484354495, 'precision@50': 0.04647914019723733, 'ndcg@10': 0.02258770486339927, 'ndcg@30': 0.029797555406888325, 'ndcg@50': 0.03529148083180189, 'ndcg@all': 0.2798873488108317})
valid defaultdict(<class 'list'>, {'precision@10': 0.043671052118665295, 'precision@30': 0.04518421249170052, 'precision@50': 0.046231575898433984, 'ndcg@10': 0.021337154247847042, 'ndcg@30': 0.028417121128816353, 'ndcg@50': 0.03379312039990174, 'ndcg@all': 0.27857690892721476})
Epoch 1 8927.910372721353 5822.181839792352
Epoch 2 8895.243605143229 5869.760183233964
Epoch 3 8817.14533203125 6464.537555895354
Epoch 4 8759.221715494792 6121.648431075247
Epoch 5 8727.118600260417 5754.325709292763
Epoch 6 8694.94938639323 5739.677991365132
Epoch 7 8689.030289713543 5801.3423879523025
Epoch 8 8638.267360026042 5798.513703998767
Epoch 9 8626.295239257812 5744.1158061780425
Ep

Epoch 81 7947.494143880208 5870.22308670847
Epoch 82 7896.0198583984375 5770.056878340872
Epoch 83 7862.241801757813 5735.927862870066
Epoch 84 7851.945504557291 5726.703420538652
Epoch 85 7844.1040869140625 5705.206285978618
Epoch 86 7830.0401009114585 5766.686565198396
Epoch 87 7829.595776367188 5668.455229106702
Epoch 88 7828.185068359375 5668.8590473375825
Epoch 89 7828.964311523438 5739.909831799959
Epoch 90 7809.977560221354 5672.380583110608

train defaultdict(<class 'list'>, {'precision@10': 0.04511571401109298, 'precision@30': 0.05089174774785837, 'precision@50': 0.05041142580409845, 'ndcg@10': 0.03714833003158371, 'ndcg@30': 0.04717261114468177, 'ndcg@50': 0.053961068664987885, 'ndcg@all': 0.30652166505654654})
valid defaultdict(<class 'list'>, {'precision@10': 0.04121052615932728, 'precision@30': 0.047385966601340396, 'precision@50': 0.04701315552780503, 'ndcg@10': 0.029273265237478835, 'ndcg@30': 0.038242811239079424, 'ndcg@50': 0.04410333475588184, 'ndcg@all': 0.2973239657

Epoch 163 7593.870861816406 5730.920904862253
Epoch 164 7575.586459960938 5647.609660901521
Epoch 165 7593.343510742187 5700.023572419819
Epoch 166 7575.219254557292 5698.373766447368
Epoch 167 7578.9615087890625 5981.636905067845
Epoch 168 7597.148253580729 5699.799464175576
Epoch 169 7577.474934895833 6271.2813720703125
Epoch 170 7571.775971679687 5693.499338250411

train defaultdict(<class 'list'>, {'precision@10': 0.09989571412404379, 'precision@30': 0.07882603406906127, 'precision@50': 0.06763104401528836, 'ndcg@10': 0.05917069253822168, 'ndcg@30': 0.06710924245417119, 'ndcg@50': 0.07361726276576519, 'ndcg@all': 0.3241467614968618})
valid defaultdict(<class 'list'>, {'precision@10': 0.09147368302862895, 'precision@30': 0.07183772119644441, 'precision@50': 0.06268157613904853, 'ndcg@10': 0.04699867865756938, 'ndcg@30': 0.053081859344322434, 'ndcg@50': 0.05927741155028343, 'ndcg@all': 0.31071653805280985})
Epoch 171 7568.2203051757815 5657.2859529194075
Epoch 172 7563.521949869792 5