## Load

In [1]:
import os
import sys

from collections import defaultdict

import numpy as np 
import pandas as pd
import random

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
from tqdm.auto import tqdm

# Used to get the data
import nltk
from nltk.stem import PorterStemmer
from nltk.corpus import stopwords
from nltk.collocations import BigramAssocMeasures, BigramCollocationFinder
nltk.download('stopwords')

import matplotlib.pyplot as plt 
import matplotlib
matplotlib.use('Agg')

sys.path.append('../')
from utils.eval import retrieval_normalized_dcg_all, retrieval_precision_all
from utils.loss import ListNet, ListNet2, ListNet_origin, MultiLabelMarginLossCustom, MultiLabelMarginLossCustomV, MSE
from utils.data_processing import get_process_data

# Seed every RNG the notebook draws from (train/valid split, weight init, dropout,
# DataLoader shuffling) so runs are repeatable. GPU kernels may still be nondeterministic.
seed = 33
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)  # also seeds all CUDA devices

[nltk_data] Downloading package stopwords to
[nltk_data]     /home/chrisliu/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [2]:
# NOTE(review): '' presumably selects get_process_data's default embedding type — confirm.
embedding_type = ''
dataset = '20news'
documentembedding_normalize = False

embedding_dim = 128
# Pass `dataset` (instead of repeating the literal) so changing it above actually takes effect.
data = get_process_data(dataset=dataset, agg='IDF', embedding_type=embedding_type,
                        word2embedding_path='../data/glove.6B.100d.txt', word2embedding_normalize=False,
                        documentembedding_normalize=documentembedding_normalize,
                        embedding_dim=embedding_dim, max_seq_length=128,
                        load_embedding=True)

# Decoder target: one row of word weights per document (columns = vocabulary words).
document_TFIDF = np.array(data["document_word_weight"])
# Decoder input: one embedding row per document.
document_vectors = np.array(data["document_embedding"])
# Every document needs both an input row and a target row.
assert len(document_TFIDF) == len(document_vectors)

Loading word2embedding from ../data/glove.6B.100d.txt


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Number of words:400000


HBox(children=(IntProgress(value=0, description='Start buiding vocabulary...', max=18846, style=ProgressStyle(…


doc num 18846
eliminate freq words
Vocabulary size:15029, Word embedding dim:100
Initial word weight


HBox(children=(IntProgress(value=0, description='Calculate document vectors...', max=18846, style=ProgressStyl…


delete items 150


In [3]:
# Cut-offs k reported as precision@k / NDCG@k by evaluate_MLPDecoder.
config = {"topk": [10, 30, 50]}

## MLP Decoder

In [4]:
# Use the first GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'


In [5]:
class MLPDecoderDataset(Dataset):
    """Map-style dataset pairing each document embedding with its word-weight target row.

    Item ``i`` is the 4-tuple ``(doc_vector, weights, ranked_word_ids, n_positive)``:
    the float32 embedding, the float32 weight row, the word indices ordered by
    descending weight, and how many weights in the row are strictly positive.
    """

    def __init__(self,
                 doc_vectors,
                 weight_ans):
        assert len(doc_vectors) == len(weight_ans)

        self.doc_vectors = torch.tensor(doc_vectors, dtype=torch.float32)
        self.weight_ans = torch.tensor(weight_ans, dtype=torch.float32)
        # Precompute once: word ranking per document and its number of non-zero targets.
        self.weight_ans_s = self.weight_ans.argsort(dim=1, descending=True)
        self.topk = self.weight_ans.gt(0).sum(dim=1)

    def __getitem__(self, idx):
        vector = self.doc_vectors[idx]
        weights = self.weight_ans[idx]
        ranking = self.weight_ans_s[idx]
        n_positive = self.topk[idx]
        return vector, weights, ranking, n_positive

    def __len__(self):
        return self.doc_vectors.shape[0]

In [6]:
def prepare_dataloader(batch_size=100, train_size_ratio=0.8, topk=50, TFIDF_normalize=False, seed=None):
    """Shuffle the documents, split them into train/valid sets and wrap each split in a DataLoader.

    Reads the notebook-level ``document_vectors`` (inputs) and ``document_TFIDF`` (targets).

    Args:
        batch_size: mini-batch size for both loaders.
        train_size_ratio: fraction of documents placed in the training split.
        topk: unused; kept so existing calls that pass it keep working.
        TFIDF_normalize: if True, rescale each document's weights to sum to 1. Rows whose
            weights sum to 0 stay all-zero instead of becoming NaN.
        seed: optional seed for the train/valid shuffle. ``None`` (default) draws from
            NumPy's global RNG, as before.

    Returns:
        ``(train_loader, valid_loader)``. Only the training loader reshuffles each epoch.
        The validation loader keeps a fixed order so repeated evaluations see identical batches.
    """
    n_docs = len(document_vectors)
    train_size = int(n_docs * train_size_ratio)

    print('train size', train_size)
    print('valid size', n_docs - train_size)

    if TFIDF_normalize:
        # Normalize each document's weight sum to 1; guard against 0/0 for empty rows.
        norm = document_TFIDF.sum(axis=1, keepdims=True)
        document_TFIDF_ = document_TFIDF / np.where(norm == 0, 1, norm)
    else:
        document_TFIDF_ = document_TFIDF

    # permutation(n) == arange(n) shuffled, so seed=None reproduces the previous behaviour.
    rng = np.random if seed is None else np.random.RandomState(seed)
    order = rng.permutation(n_docs)
    document_vectors_ = document_vectors[order]
    document_TFIDF_ = document_TFIDF_[order]

    # Pinned host memory only helps (and only avoids a warning) when a GPU is present.
    pin_memory = torch.cuda.is_available()

    train_dataset = MLPDecoderDataset(document_vectors_[:train_size], document_TFIDF_[:train_size])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=pin_memory)

    valid_dataset = MLPDecoderDataset(document_vectors_[train_size:], document_TFIDF_[train_size:])
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, pin_memory=pin_memory)

    return train_loader, valid_loader


In [7]:
class MLPDecoder(nn.Module):
    """One-hidden-layer MLP mapping a document embedding to a score for every vocabulary word.

    Architecture: Linear(doc_emb_dim -> h_dim) -> tanh -> Dropout -> Linear(h_dim -> num_words).
    The output layer keeps the attribute name ``fc4`` so previously saved state_dicts still load.

    Args:
        doc_emb_dim: size of the input document embedding.
        num_words: number of output scores (vocabulary size).
        h_dim: hidden-layer width.
        dropout: dropout probability on the hidden activations (default 0.2, as before).
    """

    def __init__(self, doc_emb_dim, num_words, h_dim=300, dropout=0.2):
        super().__init__()

        self.fc1 = nn.Linear(doc_emb_dim, h_dim)
        self.fc4 = nn.Linear(h_dim, num_words)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Return raw (unnormalized) word scores of shape ``(..., num_words)``."""
        hidden = self.dropout(torch.tanh(self.fc1(x)))
        return self.fc4(hidden)

In [8]:
def evaluate_MLPDecoder(model, data_loader, topk=None):
    """Compute mean precision@k and NDCG@k of ``model`` over ``data_loader``.

    Args:
        model: decoder mapping a batch of document embeddings to one score per word.
        data_loader: yields ``(doc_embs, target, target_rank, target_topk)`` batches;
            only the first two are used.
        topk: list of cut-offs; defaults to the notebook-level ``config["topk"]``.

    Returns:
        defaultdict mapping ``'precision@{k}'`` / ``'ndcg@{k}'`` (for every key the metric
        helpers return) to the unweighted mean of the per-batch scores.

    Leaves ``model`` in eval mode.
    """
    if topk is None:
        topk = config["topk"]
    # Follow the model's own device rather than relying on a global.
    model_device = next(model.parameters()).device

    results = defaultdict(list)
    model.eval()

    # Inference only: skipping autograd bookkeeping saves memory and time.
    with torch.no_grad():
        for doc_embs, target, _, _ in data_loader:
            doc_embs = doc_embs.to(model_device)
            target = target.to(model_device)

            pred = model(doc_embs)

            for k, v in retrieval_precision_all(pred, target, k=topk).items():
                results['precision@{}'.format(k)].append(v)

            for k, v in retrieval_normalized_dcg_all(pred, target, k=topk).items():
                results['ndcg@{}'.format(k)].append(v)

    # Mean over batches; a smaller final batch counts as much as a full one.
    for k in results:
        results[k] = np.mean(results[k])

    return results

In [9]:
# Sanity check (disabled): uncomment to print the tensor shapes of one training batch.
# train_loader, valid_loader = prepare_dataloader(batch_size=100,\
#                                                 train_size_ratio=0.8, topk=30,
#                                                 TFIDF_normalize=True)
# for data in train_loader:
#     doc_embs, target, target_rank, target_topk = data
#     print(doc_embs.shape)
#     print(target.shape)
#     print(target_rank.shape)
#     print(target_topk.shape)
#     break

In [10]:
def _make_loss_fn(config):
    """Build ``loss_fn(pred, target, target_rank, target_topk)`` for ``config["criterion"]``.

    Supported names: "MultiLabelMarginLoss", "MultiLabelMarginLossCustomV:<margin>",
    "MultiLabelMarginLossCustom:<margin>", "ListNet" (-> ListNet_origin), "ListNet2", "MSE".

    Raises:
        ValueError: for any other criterion name.
    """
    name = config["criterion"]

    if name == "MultiLabelMarginLoss":
        margin_loss = nn.MultiLabelMarginLoss(reduction='mean')
        topk = config["topk"]

        def loss_fn(pred, target, target_rank, target_topk):
            # MultiLabelMarginLoss reads target indices only up to the first -1, so writing -1
            # into column `topk` keeps the `topk` highest-weighted words as positives.
            # Work on a copy so the caller's batch tensor is never modified.
            truncated = target_rank.clone()
            truncated[:, topk] = -1
            return margin_loss(pred, truncated)

        return loss_fn

    # "...CustomV" must be tested before "...Custom", which is a prefix of it.
    if name.startswith("MultiLabelMarginLossCustomV"):
        margin = float(name.split(':')[-1])
        return lambda pred, target, target_rank, target_topk: MultiLabelMarginLossCustomV(
            pred, target_rank, target_topk, margin)

    if name.startswith("MultiLabelMarginLossCustom"):
        margin = float(name.split(':')[-1])
        topk = config["topk"]
        return lambda pred, target, target_rank, target_topk: MultiLabelMarginLossCustom(
            pred, target_rank, topk, margin)

    # Losses that compare predictions directly against the raw weight rows.
    weight_losses = {"ListNet": ListNet_origin, "ListNet2": ListNet2, "MSE": MSE}
    if name in weight_losses:
        fn = weight_losses[name]
        return lambda pred, target, target_rank, target_topk: fn(pred, target)

    raise ValueError("loss not found: {}".format(name))


def train_decoder(config):
    """Train a fresh MLPDecoder with SGD, printing losses and periodically ranking metrics.

    Uses the notebook-level ``train_loader``, ``valid_loader``, ``document_vectors``,
    ``document_TFIDF`` and ``device``.

    Args:
        config: dict with "lr", "momentum", "weight_decay" (SGD); "h_dim" (hidden width);
            "criterion" (loss name, see ``_make_loss_fn``); "topk" (positives for the
            MultiLabelMarginLoss / MultiLabelMarginLossCustom criteria); "n_epoch";
            "valid_epoch" (evaluate every this many epochs); "verbose". Other keys
            (e.g. "TFIDF_normalize") are ignored.

    Returns:
        List of dicts, one per evaluation, each holding "epoch" plus the validation metrics
        from ``evaluate_MLPDecoder``. The final epoch is always evaluated.

    Raises:
        ValueError: if ``config["criterion"]`` is not a supported loss.
    """
    loss_fn = _make_loss_fn(config)

    model = MLPDecoder(
        doc_emb_dim=document_vectors.shape[1], num_words=document_TFIDF.shape[1], h_dim=config["h_dim"]).to(device)
    opt = torch.optim.SGD(model.parameters(), lr=config["lr"], momentum=config["momentum"],
                          weight_decay=config["weight_decay"])

    results = []
    n_epoch = config["n_epoch"]
    valid_epoch = config["valid_epoch"]
    verbose = config["verbose"]

    for epoch in tqdm(range(n_epoch)):
        train_loss_his = []
        valid_loss_his = []

        model.train()
        for batch in train_loader:
            doc_embs, target, target_rank, target_topk = (t.to(device) for t in batch)

            loss = loss_fn(model(doc_embs), target, target_rank, target_topk)
            train_loss_his.append(loss.item())

            opt.zero_grad()
            loss.backward()
            opt.step()

        model.eval()
        # Validation loss only: no autograd graph is needed.
        with torch.no_grad():
            for batch in valid_loader:
                doc_embs, target, target_rank, target_topk = (t.to(device) for t in batch)

                loss = loss_fn(model(doc_embs), target, target_rank, target_topk)
                valid_loss_his.append(loss.item())

        print("Epoch", epoch, np.mean(train_loss_his), np.mean(valid_loss_his))

        # Ranking metrics every `valid_epoch` epochs, and always for the final model.
        if epoch % valid_epoch == 0 or epoch == n_epoch - 1:
            train_res_ndcg = evaluate_MLPDecoder(model, train_loader)
            valid_res_ndcg = evaluate_MLPDecoder(model, valid_loader)

            res = {'epoch': epoch}
            res.update(valid_res_ndcg)
            results.append(res)

            if verbose:
                print()
                print('train', train_res_ndcg)
                print('valid', valid_res_ndcg)

    return results

In [11]:
# prepare dataloader
# Shuffle and split the documents once; train_decoder reads these two loaders as globals.
train_loader, valid_loader = prepare_dataloader(
    batch_size=100,
    train_size_ratio=0.8,
    TFIDF_normalize=False,
)

train size 14956
valid size 3740


In [12]:
# Base hyper-parameters; later cells vary "lr" and "criterion" per experiment.
train_config = dict(
    # SGD optimizer
    lr=0.05,
    momentum=0.0,
    weight_decay=0.0,

    # schedule / logging: ranking metrics are computed every `valid_epoch` epochs
    n_epoch=200,
    verbose=True,
    valid_epoch=10,

    # positives kept by the MultiLabelMarginLoss / MultiLabelMarginLossCustom criteria
    topk=50,

    # model and loss; other criteria: "ListNet", "ListNet2", "MSE",
    # "MultiLabelMarginLossCustom:<margin>", "MultiLabelMarginLossCustomV:<margin>"
    h_dim=3000,
    criterion="MultiLabelMarginLoss",
    # NOTE(review): not read by train_decoder; normalization is chosen in prepare_dataloader.
    TFIDF_normalize=False,
)

train_decoder(train_config)

HBox(children=(IntProgress(value=0, max=200), HTML(value='')))



Epoch 0 12.950749581654867 10.35442886854473

train defaultdict(<class 'list'>, {'precision@10': 0.016619047755375503, 'precision@30': 0.02619888989875714, 'precision@50': 0.03938780720035235, 'ndcg@10': 0.009174947685872514, 'ndcg@30': 0.01699017146602273, 'ndcg@50': 0.02930683473745982, 'ndcg@all': 0.3102540604273478})
valid defaultdict(<class 'list'>, {'precision@10': 0.016894736609078552, 'precision@30': 0.02599561486491247, 'precision@50': 0.03809999748084106, 'ndcg@10': 0.009016449262976255, 'ndcg@30': 0.016945118648245147, 'ndcg@50': 0.028461527569513572, 'ndcg@all': 0.30662567913532257})
Epoch 1 8.658901691436768 8.61814644462184
Epoch 2 6.948061978022258 7.697461078041478
Epoch 3 5.891134885152181 7.075902913746081
Epoch 4 5.128959770202637 6.661200435538041
Epoch 5 4.556165803273519 6.3375075114400765
Epoch 6 4.088933502833049 6.115446718115556
Epoch 7 3.7375409507751467 5.942995021217747
Epoch 8 3.4232584969202677 5.787793397903442
Epoch 9 3.1495900456110637 5.65623530588651

Epoch 81 0.612888152996699 4.867851545936183
Epoch 82 0.6114874553680419 4.867242185693038
Epoch 83 0.6100374801953634 4.851350445496409
Epoch 84 0.6037867959340414 4.859016104748375
Epoch 85 0.595047295888265 4.833356782009727
Epoch 86 0.5868951173623403 4.86835441463872
Epoch 87 0.5842486461003621 4.856406324788144
Epoch 88 0.5791291368007659 4.873734248311896
Epoch 89 0.5759382935365042 4.849940758002432
Epoch 90 0.5691403917471568 4.862831994106895

train defaultdict(<class 'list'>, {'precision@10': 0.34302285303672153, 'precision@30': 0.3143306453029315, 'precision@50': 0.3172226482629776, 'ndcg@10': 0.2995419732729594, 'ndcg@30': 0.33289787272612253, 'ndcg@50': 0.3926736354827881, 'ndcg@all': 0.5959874316056569})
valid defaultdict(<class 'list'>, {'precision@10': 0.2786052662291025, 'precision@30': 0.21079386731511668, 'precision@50': 0.18796578167300476, 'ndcg@10': 0.2642453775594109, 'ndcg@30': 0.26192623376846313, 'ndcg@50': 0.281323783883923, 'ndcg@all': 0.5068284735867852})


Epoch 163 0.45077356179555256 4.948065418946116
Epoch 164 0.44929868837197623 4.935932059037058
Epoch 165 0.4520904338359833 4.944666900132832
Epoch 166 0.4523187983036041 4.949658368763171
Epoch 167 0.44836316009362537 4.930637961939762
Epoch 168 0.44760389904181164 4.975684918855366
Epoch 169 0.45089269379774727 4.959338941072163
Epoch 170 0.4489334809780121 4.952203060451307

train defaultdict(<class 'list'>, {'precision@10': 0.3421571412682533, 'precision@30': 0.336964137951533, 'precision@50': 0.35366359949111936, 'ndcg@10': 0.2905681973695755, 'ndcg@30': 0.3365547446409861, 'ndcg@50': 0.41040116687615713, 'ndcg@all': 0.5989750425020853})
valid defaultdict(<class 'list'>, {'precision@10': 0.2610526284889171, 'precision@30': 0.20054386595362111, 'precision@50': 0.1821920946240425, 'ndcg@10': 0.24899627307527944, 'ndcg@30': 0.24824262684897372, 'ndcg@50': 0.2694164871385223, 'ndcg@all': 0.49844206712747874})
Epoch 171 0.4515022287766139 4.960223298323782
Epoch 172 0.4518465195099512

In [13]:
# ListNet ("ListNet" maps to utils.loss.ListNet_origin) with lr=0.05.
# Pass a modified copy so train_config is not mutated across cells (re-runs stay order-independent).
# Logged run: best of the experiments here, valid ndcg@10 ~0.47 at epoch 170.
train_decoder({**train_config, "lr": 0.05, "criterion": "ListNet"})

HBox(children=(IntProgress(value=0, max=200), HTML(value='')))

Epoch 0 8.780688600540161 8.269434288928384

train defaultdict(<class 'list'>, {'precision@10': 0.11911285708347956, 'precision@30': 0.07029079630970955, 'precision@50': 0.05420704424381256, 'ndcg@10': 0.12368400568763414, 'ndcg@30': 0.11461505730946858, 'ndcg@50': 0.11618192767103513, 'ndcg@all': 0.33515613893667856})
valid defaultdict(<class 'list'>, {'precision@10': 0.11301315654265254, 'precision@30': 0.06723245939141825, 'precision@50': 0.051647365485367025, 'ndcg@10': 0.11337263823339813, 'ndcg@30': 0.10468813208372969, 'ndcg@50': 0.10558525159170754, 'ndcg@all': 0.3253556897765712})
Epoch 1 7.530011978149414 7.6931632945412085
Epoch 2 6.707909889221192 7.301191643664711
Epoch 3 6.03844833056132 6.961074415006135
Epoch 4 5.491165018081665 6.738619352641859
Epoch 5 5.01815066019694 6.4971505340776945
Epoch 6 4.634463210105896 6.342011263495998
Epoch 7 4.289539901415507 6.170811477460359
Epoch 8 3.9744633563359577 6.083279183036403
Epoch 9 3.72920139948527 5.975049570987099
Epoch 1

Epoch 81 1.8659382780392966 4.932169437408447
Epoch 82 1.8664333454767863 4.940719014719913
Epoch 83 1.8622751998901368 4.953596842916388
Epoch 84 1.8606339486440022 4.939883859534013
Epoch 85 1.862034990787506 4.941081354492589
Epoch 86 1.8584711774190268 4.934576850188406
Epoch 87 1.854975659052531 4.915280128780164
Epoch 88 1.856667729218801 4.928106897755673
Epoch 89 1.8528309814135233 4.922192774320903
Epoch 90 1.8523603288332622 4.907137933530305

train defaultdict(<class 'list'>, {'precision@10': 0.42716047406196594, 'precision@30': 0.21352064261833828, 'precision@50': 0.1531333248813947, 'ndcg@10': 0.6286269887288412, 'ndcg@30': 0.5168015515804291, 'ndcg@50': 0.5023798352479935, 'ndcg@all': 0.6725488742192586})
valid defaultdict(<class 'list'>, {'precision@10': 0.35164473637154225, 'precision@30': 0.18631140574028618, 'precision@50': 0.1364157819434216, 'ndcg@10': 0.4513037345911327, 'ndcg@30': 0.38360537274887685, 'ndcg@50': 0.3762736869485755, 'ndcg@all': 0.563150355690404})


Epoch 164 1.7999474994341533 4.794661509363275
Epoch 165 1.8000664893786114 4.799362596712615
Epoch 166 1.7982785638173422 4.819504762950697
Epoch 167 1.7971543176968892 4.809898828205309
Epoch 168 1.7968101938565573 4.822765889920686
Epoch 169 1.7980436237653097 4.802490880614833
Epoch 170 1.795216867129008 4.814125951967742

train defaultdict(<class 'list'>, {'precision@10': 0.4653109500805537, 'precision@30': 0.23151127735773722, 'precision@50': 0.1655331340432167, 'ndcg@10': 0.6649528674284617, 'ndcg@30': 0.5471249580383301, 'ndcg@50': 0.5321465214093526, 'ndcg@all': 0.6964184367656707})
valid defaultdict(<class 'list'>, {'precision@10': 0.36639473312779475, 'precision@30': 0.1963070246734117, 'precision@50': 0.14408683502360395, 'ndcg@10': 0.467432973416228, 'ndcg@30': 0.4002479150107032, 'ndcg@50': 0.39338297985102, 'ndcg@all': 0.5772950413979983})
Epoch 171 1.7954631702105204 4.7977609320690755
Epoch 172 1.7955167086919148 4.811173451574225
Epoch 173 1.7945185494422913 4.8051491

In [20]:
# ListNet2 with lr=2, passed as a copy so train_config itself is left unchanged.
# Logged run: the loss stayed flat at ~4.6e-5 and all metrics stayed near zero, i.e. this
# setting did not learn. NOTE(review): revisit the learning rate / loss scale for ListNet2.
train_decoder({**train_config, "lr": 2, "criterion": "ListNet2"})

HBox(children=(IntProgress(value=0, max=200), HTML(value='')))

Epoch 0 4.636804641146834e-05 4.617300232654854e-05

train defaultdict(<class 'list'>, {'precision@10': 0.003610476253864666, 'precision@30': 0.0035139684084181983, 'precision@50': 0.003535999773691098, 'ndcg@10': 0.001479192317056004, 'ndcg@30': 0.002100287616679755, 'ndcg@50': 0.0027085772102388244, 'ndcg@all': 0.2227728400627772})
valid defaultdict(<class 'list'>, {'precision@10': 0.003052631647350561, 'precision@30': 0.0033201755784255894, 'precision@50': 0.003384210322493393, 'ndcg@10': 0.00117888631020354, 'ndcg@30': 0.0018707180018904374, 'ndcg@50': 0.0024694386137477857, 'ndcg@all': 0.2230176160994329})
Epoch 1 4.634984965377953e-05 4.604376303369032e-05
Epoch 2 4.635232047197254e-05 4.612837342153254e-05
Epoch 3 4.63561555564714e-05 4.608394058442699e-05
Epoch 4 4.636289636740306e-05 4.607850544408919e-05
Epoch 5 4.636174754220216e-05 4.6007384968872523e-05
Epoch 6 4.635841723938938e-05 4.609332886742011e-05
Epoch 7 4.6350816774065606e-05 4.6108539065339994e-05
Epoch 8 4.63548

Epoch 71 4.6354213554877785e-05 4.600392333630129e-05
Epoch 72 4.634664976038039e-05 4.6116956295200474e-05
Epoch 73 4.634903363087991e-05 4.606912625604309e-05
Epoch 74 4.637153744018482e-05 4.595867645356951e-05
Epoch 75 4.634526723142092e-05 4.6141053500928376e-05
Epoch 76 4.635831874717648e-05 4.599578364592928e-05
Epoch 77 4.636491496057715e-05 4.611455781406747e-05
Epoch 78 4.6343569410964844e-05 4.610420623258075e-05
Epoch 79 4.635134757942675e-05 4.608005732925927e-05
Epoch 80 4.635580837202724e-05 4.6100977717862024e-05

train defaultdict(<class 'list'>, {'precision@10': 0.0036000000537994006, 'precision@30': 0.0034995239468601845, 'precision@50': 0.0035292378766462206, 'ndcg@10': 0.0014770623538546109, 'ndcg@30': 0.002092560981885375, 'ndcg@50': 0.0027011450259791067, 'ndcg@all': 0.22271815836429595})
valid defaultdict(<class 'list'>, {'precision@10': 0.0030131579689240376, 'precision@30': 0.0032675439770652077, 'precision@50': 0.0033684208480592227, 'ndcg@10': 0.001181780550

Epoch 149 4.6349722154748936e-05 4.604083771578429e-05
Epoch 150 4.6352005665539767e-05 4.601152824378867e-05

train defaultdict(<class 'list'>, {'precision@10': 0.0036104762624017896, 'precision@30': 0.003513174757827073, 'precision@50': 0.0035326664452441036, 'ndcg@10': 0.0014841948415657196, 'ndcg@30': 0.0021040971487915764, 'ndcg@50': 0.0027108131942804904, 'ndcg@all': 0.22278756221135457})
valid defaultdict(<class 'list'>, {'precision@10': 0.0029736842598619036, 'precision@30': 0.0033289475188786654, 'precision@50': 0.0034157892713617337, 'ndcg@10': 0.0011749849666222488, 'ndcg@30': 0.0018844154592922056, 'ndcg@50': 0.0025111112856967864, 'ndcg@all': 0.22343928053190834})
Epoch 151 4.6346301896846856e-05 4.606417659440347e-05
Epoch 152 4.637233311465631e-05 4.608104963584705e-05
Epoch 153 4.6372503275051715e-05 4.594482934886688e-05
Epoch 154 4.6371624339371916e-05 4.600543003394514e-05
Epoch 155 4.637366262613796e-05 4.6087720444407556e-05
Epoch 156 4.6357286167525064e-05 4.60361

In [15]:
# MultiLabelMarginLossCustomV with margin 1 (positives per document = its count of non-zero
# weights), passed as a copy so train_config itself is left unchanged.
train_decoder({**train_config, "lr": 0.05, "criterion": "MultiLabelMarginLossCustomV:1"})

HBox(children=(IntProgress(value=0, max=200), HTML(value='')))

Epoch 0 13.71753002166748 11.747525616696006

train defaultdict(<class 'list'>, {'precision@10': 0.1448523796101411, 'precision@30': 0.11674206783374151, 'precision@50': 0.10382628018657367, 'ndcg@10': 0.0678451406210661, 'ndcg@30': 0.07896385677158832, 'ndcg@50': 0.08977257971962294, 'ndcg@all': 0.3478132019440333})
valid defaultdict(<class 'list'>, {'precision@10': 0.13882894637553314, 'precision@30': 0.11185088126282942, 'precision@50': 0.1000394687840813, 'ndcg@10': 0.06510185136606819, 'ndcg@30': 0.07557005788150586, 'ndcg@50': 0.08597030137714587, 'ndcg@all': 0.34274556527012273})
Epoch 1 9.940874729156494 9.86156712080303
Epoch 2 8.2884051322937 8.881001096022757
Epoch 3 7.256719913482666 8.360233570400037
Epoch 4 6.588723510106405 7.912931680679321
Epoch 5 6.047731574376424 7.699429712797466
Epoch 6 5.637971601486206 7.298065812964189
Epoch 7 5.259325451850891 7.207107493751927
Epoch 8 5.068968346913656 7.0529794065575855
Epoch 9 4.784829640388489 6.926007622166684
Epoch 10 4.5

Epoch 81 2.8379601939519246 6.118387711675544
Epoch 82 2.808279965718587 6.00675734720732
Epoch 83 2.7747733402252197 6.084813619914808
Epoch 84 2.855454343954722 6.063089408372578
Epoch 85 2.840175943374634 6.042284739644904
Epoch 86 2.8396936432520548 6.070204872834055
Epoch 87 2.868510874112447 6.034474749314158
Epoch 88 2.854540737469991 6.074427617223639
Epoch 89 2.8758242575327557 6.120838491540206
Epoch 90 2.8218638483683267 6.032307750300357

train defaultdict(<class 'list'>, {'precision@10': 0.3763004744052887, 'precision@30': 0.2671525476376216, 'precision@50': 0.2239939883351326, 'ndcg@10': 0.2864928802847862, 'ndcg@30': 0.29159338732560475, 'ndcg@50': 0.31068972130616507, 'ndcg@all': 0.5407363569736481})
valid defaultdict(<class 'list'>, {'precision@10': 0.34156579014502075, 'precision@30': 0.23373684953702123, 'precision@50': 0.1925841989485841, 'ndcg@10': 0.26927326972547333, 'ndcg@30': 0.26556746465595144, 'ndcg@50': 0.27811688733728307, 'ndcg@all': 0.5034250052351701})


Epoch 164 3.012704749107361 6.157153543673064
Epoch 165 3.0595498100916543 6.1567408160159465
Epoch 166 3.124103930791219 6.205024995301899
Epoch 167 3.1517813936869303 6.2000369021767066
Epoch 168 3.1262690671284994 6.225940265153584
Epoch 169 3.129041846593221 6.2328023785039
Epoch 170 3.1835367409388224 6.255038324155305

train defaultdict(<class 'list'>, {'precision@10': 0.3490799967447917, 'precision@30': 0.2501735000809034, 'precision@50': 0.209907036225001, 'ndcg@10': 0.2620575993259748, 'ndcg@30': 0.2696734874447187, 'ndcg@50': 0.2879592784245809, 'ndcg@all': 0.5208128309249878})
valid defaultdict(<class 'list'>, {'precision@10': 0.320421051037939, 'precision@30': 0.2217017600410863, 'precision@50': 0.18312104281626249, 'ndcg@10': 0.24754301497810766, 'ndcg@30': 0.24700857777344554, 'ndcg@50': 0.25998362075341375, 'ndcg@all': 0.48840627623231786})
Epoch 171 3.191272662480672 6.110293200141506
Epoch 172 3.0713325262069704 6.082056735691271
Epoch 173 3.105040284792582 6.160983261

In [16]:
# MultiLabelMarginLossCustom with margin 1 and the config's fixed topk, passed as a copy so
# train_config itself is left unchanged.
# Logged run: valid ndcg@10 fell from ~0.22 (epoch 90) to ~0.16 (epoch 170) — lr=0.1 may be too high.
train_decoder({**train_config, "lr": 0.1, "criterion": "MultiLabelMarginLossCustom:1"})

HBox(children=(IntProgress(value=0, max=200), HTML(value='')))

Epoch 0 14.209401200612387 11.921100691745156

train defaultdict(<class 'list'>, {'precision@10': 0.028013809515784183, 'precision@30': 0.031507778838276865, 'precision@50': 0.037489235711594425, 'ndcg@10': 0.012360263401642441, 'ndcg@30': 0.019768902839471896, 'ndcg@50': 0.029099697458247344, 'ndcg@all': 0.30494828005631763})
valid defaultdict(<class 'list'>, {'precision@10': 0.027289473770284338, 'precision@30': 0.030618422713718917, 'precision@50': 0.03646052403277472, 'ndcg@10': 0.012050367119771085, 'ndcg@30': 0.01922771503756705, 'ndcg@50': 0.028229637169524244, 'ndcg@all': 0.3014717776524393})
Epoch 1 9.633870992660523 9.818745387227912
Epoch 2 7.8508179601033525 8.699242240504214
Epoch 3 6.758927491505941 8.218266838475278
Epoch 4 5.978537073135376 7.730007159082513
Epoch 5 5.419891449610392 7.149332071605482
Epoch 6 4.87075909614563 6.9464043692538615
Epoch 7 4.489374645551046 6.7364720545317
Epoch 8 4.201104132334391 6.66179074739155
Epoch 9 4.011140960057577 6.41086942271182

Epoch 81 1.9078827087084453 5.532153142125983
Epoch 82 1.9013243261973063 5.50444874010588
Epoch 83 1.9425671354929606 5.5050294399261475
Epoch 84 1.9385208010673523 5.530269346739116
Epoch 85 1.9717535734176637 5.51842935461747
Epoch 86 1.9588293313980103 5.550775691082603
Epoch 87 1.9394066182772318 5.558581691039236
Epoch 88 1.993358789285024 5.606884366587589
Epoch 89 2.0045449805259703 5.6235343155108
Epoch 90 2.0417643785476685 5.65552702702974

train defaultdict(<class 'list'>, {'precision@10': 0.28895618855953215, 'precision@30': 0.19317857762177784, 'precision@50': 0.17144951512416204, 'ndcg@10': 0.2392070553700129, 'ndcg@30': 0.23245599875847497, 'ndcg@50': 0.2534734601775805, 'ndcg@all': 0.509652558962504})
valid defaultdict(<class 'list'>, {'precision@10': 0.2563289470578495, 'precision@30': 0.1646842156585894, 'precision@50': 0.14172104491215004, 'ndcg@10': 0.2243887617399818, 'ndcg@30': 0.2104385479500419, 'ndcg@50': 0.22358694084380804, 'ndcg@all': 0.46888182429890884})


Epoch 164 2.1674746878941855 5.707745840674953
Epoch 165 2.1051562547683718 5.63786768913269
Epoch 166 2.064260419209798 5.637259081790321
Epoch 167 2.0630112067858377 5.587294478165476
Epoch 168 2.0798270678520203 5.62544604351646
Epoch 169 2.084765061537425 5.642752057627628
Epoch 170 2.101597611109416 5.756200275923076

train defaultdict(<class 'list'>, {'precision@10': 0.20138666669527688, 'precision@30': 0.15130889371037484, 'precision@50': 0.1388414221505324, 'ndcg@10': 0.16781453390916187, 'ndcg@30': 0.1763105914990107, 'ndcg@50': 0.19722550263007482, 'ndcg@all': 0.46961864908536277})
valid defaultdict(<class 'list'>, {'precision@10': 0.184947365208676, 'precision@30': 0.13171053403302244, 'precision@50': 0.11769999406839672, 'ndcg@10': 0.15947744015016055, 'ndcg@30': 0.16184553896125994, 'ndcg@50': 0.17680936698850833, 'ndcg@all': 0.4343314523759641})
Epoch 171 2.081030170917511 5.608124946293078
Epoch 172 2.076095910867055 5.62073688758047
Epoch 173 2.094781726996104 5.6126422

In [17]:
# Plain MSE regression onto the word weights, passed as a copy so train_config is left unchanged.
# Logged run: valid ndcg@10 stayed below 0.05 throughout — the weakest non-degenerate baseline.
train_decoder({**train_config, "lr": 1e-4, "criterion": "MSE"})

HBox(children=(IntProgress(value=0, max=200), HTML(value='')))

Epoch 0 9084.32693359375 7506.067941766036

train defaultdict(<class 'list'>, {'precision@10': 0.019196190604319176, 'precision@30': 0.028006509132683276, 'precision@50': 0.03138818882405758, 'ndcg@10': 0.006948898893315345, 'ndcg@30': 0.012902229555572072, 'ndcg@50': 0.0184006313358744, 'ndcg@all': 0.2706452848513921})
valid defaultdict(<class 'list'>, {'precision@10': 0.02092105265412676, 'precision@30': 0.029241229211421388, 'precision@50': 0.031147366340615247, 'ndcg@10': 0.008266179356724024, 'ndcg@30': 0.014071620878224311, 'ndcg@50': 0.01882112141404497, 'ndcg@all': 0.2704323702736905})
Epoch 1 8542.281866861978 7454.758191560444
Epoch 2 8481.673940429688 7379.46580103824
Epoch 3 8425.296920572917 7367.864720394737
Epoch 4 8424.752960611979 7400.513382761102
Epoch 5 8355.311637369792 7332.1424946032075
Epoch 6 8319.92216796875 7486.597360711348
Epoch 7 8293.31796061198 7262.497789884868
Epoch 8 8267.855436197917 7323.0301320929275
Epoch 9 8250.746150716146 7295.203073601973
Epoc

Epoch 81 7482.15400390625 7182.218849583676
Epoch 82 7478.5291715494795 7170.451913934005
Epoch 83 7482.513863932291 7208.6178364000825
Epoch 84 7465.327137044271 7183.745335629112
Epoch 85 7491.211402994792 7290.060829564145
Epoch 86 7455.376783854167 7183.089856599507
Epoch 87 7453.21544921875 7246.110910516036
Epoch 88 7434.468694661458 7331.766999897204
Epoch 89 7433.03364420573 7166.224416632402
Epoch 90 7431.492709960938 7178.76072291324

train defaultdict(<class 'list'>, {'precision@10': 0.08086619026958942, 'precision@30': 0.0649734945098559, 'precision@50': 0.06016209200024605, 'ndcg@10': 0.052269386028250056, 'ndcg@30': 0.057618864700198176, 'ndcg@50': 0.06457408209641774, 'ndcg@all': 0.3141645028193792})
valid defaultdict(<class 'list'>, {'precision@10': 0.07627631518009462, 'precision@30': 0.06224561639522251, 'precision@50': 0.05739736468776276, 'ndcg@10': 0.04501373727658862, 'ndcg@30': 0.04940443819290713, 'ndcg@50': 0.05523989741739474, 'ndcg@all': 0.30485696541635615})

Epoch 164 7228.807076822916 7169.984680175781
Epoch 165 7227.604236653646 7169.84861353824
Epoch 166 7219.265241699219 7176.350168328536
Epoch 167 7243.906317545573 7232.382176449424
Epoch 168 7214.375646972656 7166.928685238487
Epoch 169 7218.79157796224 7192.073775442023
Epoch 170 7228.718676757812 7828.8902009662825

train defaultdict(<class 'list'>, {'precision@10': 0.04056095233807961, 'precision@30': 0.05093698613345623, 'precision@50': 0.05144856842855612, 'ndcg@10': 0.027270501144230366, 'ndcg@30': 0.04165420938283205, 'ndcg@50': 0.050404883809387686, 'ndcg@all': 0.3078404192129771})
valid defaultdict(<class 'list'>, {'precision@10': 0.038986842098988984, 'precision@30': 0.0487412303490074, 'precision@50': 0.048989470361879, 'ndcg@10': 0.021902319632078473, 'ndcg@30': 0.034241106018031896, 'ndcg@50': 0.04175122812586395, 'ndcg@all': 0.2992378332112965})
Epoch 171 7207.431310221355 7148.288667377673
Epoch 172 7274.8172924804685 7322.199315121299
Epoch 173 7206.533920898438 7214.