# Demo baseline

### document embedding decoder
1. demo utils
2. demo loss
3. demo evaluation

In [22]:
import os
import sys
from collections import defaultdict
import numpy as np 
import pandas as pd

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import torch.optim as optim
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

sys.path.append('../')
from utils.eval import retrieval_normalized_dcg_all, retrieval_precision_all, semantic_precision_all
from utils.loss import *
from utils.data_loader import load_document
from utils.toolbox import preprocess_document, get_preprocess_document, get_preprocess_document_embs,\
                          get_preprocess_document_labels, get_word_embs

## Data preprocess
1. filter out special characters and punctuation (keep only English letters and digits)
2. filter stopwords
3. filter by term frequency
4. pos tagging

## Parameters

### preprocess parameters:
1. min word frequency (`min_df`)
2. max word frequency (`max_df`)
3. min words per document (`min_doc_word`)
4. POS tagging selection

### training parameters:
1. decoder label
2. model parameters

## Load Data, Label
Label types: bow, tf-idf, keybert, classification

In [2]:
# Label type used as the decoder target; read below as the key into `labels`.
config = {'label_type': 'tf-idf'}

In [3]:
# Preprocessing settings; they are forwarded to get_preprocess_document as keyword arguments.
dataset_name = 'agnews'
min_df = 10
max_df = 1.0
vocabulary_size = None
min_doc_word = 15

# Settings read by later cells (embedding model selection and the data-split seed).
model_name = 'average'
seed = 33

preprocess_config = {
    'dataset_name': dataset_name,
    'min_df': min_df,
    'max_df': max_df,
    'vocabulary_size': vocabulary_size,
    'min_doc_word': min_doc_word,
}

unpreprocessed_docs, preprocessed_docs = get_preprocess_document(**preprocess_config)
print('doc num', len(preprocessed_docs))

Getting preprocess documents: agnews
min_df: 1 max_df: 1.0 vocabulary_size: 2000 min_doc_word: 15


Using custom data configuration default
Reusing dataset ag_news (/home/chrisliu/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548)


In [4]:
# Embed the preprocessed documents with the encoder named by `model_name`.
# The helper's second return value is kept as `doc_model`.
doc_embs, doc_model = get_preprocess_document_embs(preprocessed_docs, model_name)
print('doc_embs', doc_embs.shape)

Getting preprocess documents embeddings
Using cuda 0 for training...


Batches:   0%|          | 0/635 [00:00<?, ?it/s]

In [5]:
# Build the decoder targets and their vocabularies from the preprocessed documents.
# Both results are indexed by a label-type key in the cells below (e.g. 'tf-idf').
labels, vocabularys = get_preprocess_document_labels(preprocessed_docs)

Getting preprocess documents labels


In [6]:
# Pick the training targets for the configured label type.
targets = labels[config['label_type']] 
# NOTE(review): the vocabulary key is hard-coded to 'tf-idf' while the targets follow
# config['label_type']; the two agree only for the current config -- confirm before
# switching label types.
word_embs = get_word_embs(vocabularys['tf-idf'])
word_embs.shape

Getting word embeddings


0it [00:00, ?it/s]

Number of words:400001


(1811, 300)

In [7]:
# Word embeddings as a torch float tensor; in this notebook it is only referenced by the
# commented-out semantic-precision code inside evaluate_DNNDecoder.
word_embs_tensor = torch.FloatTensor(word_embs)

In [8]:
# Cut-offs for the @k metrics.
training_config = {"topk": [5, 10, 15]}

## MLP Decoder

In [9]:
# Use the first CUDA device when one is visible, otherwise fall back to the CPU.
# (Assign 'cpu' directly here to force CPU execution.)
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

In [10]:
class DNNDecoderDataset(Dataset):
    """Dataset pairing each document embedding with its target vector.

    All tensors, including the per-row ranking and positive-count metadata, are
    built once in the constructor so that indexing is a plain lookup.
    """

    def __init__(self, doc_embs, targets):
        assert len(doc_embs) == len(targets)

        self.doc_embs = torch.FloatTensor(doc_embs)
        self.targets = torch.FloatTensor(targets)
        # Column indices of every row, ordered from the largest target value down.
        self.targets_rank = self.targets.argsort(dim=1, descending=True)
        # Number of strictly positive target entries in every row.
        self.topk = (self.targets > 0).sum(dim=1)

    def __len__(self):
        return len(self.doc_embs)

    def __getitem__(self, idx):
        """Return ``(embedding, target, target_rank, topk)`` for sample ``idx``."""
        sample = (
            self.doc_embs[idx],
            self.targets[idx],
            self.targets_rank[idx],
            self.topk[idx],
        )
        return sample

In [11]:
def prepare_dataloader(doc_embs, targets, batch_size=100, train_valid_test_ratio=[0.7, 0.1, 0.2],
                       target_normalize=False, seed=123):
    """Shuffle the corpus once and wrap train/valid/test splits in DataLoaders.

    Args:
        doc_embs: document embeddings; must support ``len()`` and indexing with an
            integer index array (e.g. a NumPy array).
        targets: per-document target vectors, same length as ``doc_embs``.
        batch_size: batch size used by all three loaders.
        train_valid_test_ratio: ``[train, valid, test]`` fractions. Only the first
            two entries are read; the test split receives whatever remains.
        target_normalize: if True, divide every target row by its row sum.
        seed: seed for the split shuffle. This reseeds NumPy's global RNG.

    Returns:
        tuple: ``(train_loader, valid_loader, test_loader)``.
    """
    n_docs = len(doc_embs)
    train_size = int(n_docs * train_valid_test_ratio[0])
    valid_size = int(n_docs * (train_valid_test_ratio[0] + train_valid_test_ratio[1])) - train_size
    test_size = n_docs - train_size - valid_size
    valid_end = train_size + valid_size

    print('Preparing dataloader')
    print('train size', train_size)
    print('valid size', valid_size)
    print('test size', test_size)

    if target_normalize:
        # normalize target summation of each document to 1
        norm = targets.sum(axis=1).reshape(-1, 1)
        targets = (targets / norm)
        # alternative kept from the original: normalize the L2 norm of each document to 1
        # norm = np.linalg.norm(targets, axis=1).reshape(-1, 1)
        # targets = (targets / norm)

    # shuffle documents and targets with the same seeded permutation
    randomize = np.arange(n_docs)
    np.random.seed(seed)
    np.random.shuffle(randomize)
    doc_embs = doc_embs[randomize]
    targets = targets[randomize]

    # Only the training loader reshuffles each epoch; the evaluation loaders keep a
    # fixed order so batch-averaged metrics are reproducible between runs.
    train_dataset = DNNDecoderDataset(doc_embs[:train_size], targets[:train_size])
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)

    valid_dataset = DNNDecoderDataset(doc_embs[train_size:valid_end], targets[train_size:valid_end])
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)

    # Fix: the test loader previously wrapped train_dataset, so the held-out test
    # split was never served.
    test_dataset = DNNDecoderDataset(doc_embs[valid_end:], targets[valid_end:])
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)

    return train_loader, valid_loader, test_loader


In [12]:
# prepare dataloader
# Build the loaders from `targets` (selected via config['label_type']) rather than a
# hard-coded labels['tf-idf'], so the batches carry the same labels whose width
# train_decoder uses to size the decoder's output layer.
train_loader, valid_loader, test_loader = prepare_dataloader(doc_embs, targets, batch_size=100,
                                                             train_valid_test_ratio=[0.7, 0.1, 0.2],
                                                             target_normalize=False, seed=seed)

Preparing dataloader
train size 88892
valid size 12699
test size 25398


In [13]:
class DNNDecoder(nn.Module):
    """MLP that maps a document embedding to one score per vocabulary word.

    Two tanh hidden layers of width ``h_dim`` are followed by a linear output
    layer; no activation is applied to the output.
    """

    def __init__(self, doc_emb_dim, num_words, h_dim=300):
        super().__init__()
        layers = [
            nn.Linear(doc_emb_dim, h_dim),
            nn.Tanh(),
            nn.Linear(h_dim, h_dim),
            nn.Tanh(),
            nn.Linear(h_dim, num_words),
        ]
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.decoder(x)
        return scores

In [14]:
def evaluate_DNNDecoder(model, data_loader, config):
    """Score ``model`` on ``data_loader`` with precision@k and NDCG@k.

    Args:
        model: decoder to evaluate; it is switched to eval mode and left there.
        data_loader: yields ``(doc_embs, target, target_rank, topk)`` batches; only
            the first two items are used.
        config: dict whose ``"topk"`` entry is passed as ``k`` to the metric helpers.

    Returns:
        defaultdict: metric name -> mean of the per-batch values. Each batch has equal
        weight in that mean, whatever its size.
    """
    results = defaultdict(list)
    model.eval()

    # Inference only: disable autograd so no graph is built for evaluation batches.
    with torch.no_grad():
        # predict all data
        for data in data_loader:
            doc_embs, target, _, _ = data

            doc_embs = doc_embs.to(device)
            target = target.to(device)

            pred = model(doc_embs)

            # Precision
            precision_scores = retrieval_precision_all(pred, target, k=config["topk"])
            for k, v in precision_scores.items():
                results['precision@{}'.format(k)].append(v)

            # Semantic Precision (disabled)
#             semantic_precision_scores, word_result = semantic_precision_all(pred, target, word_embs_tensor, vocabularys['tf-idf'],\
#                                                                             k=config["topk"], th=0.7, display_word_result=False)
#             for k, v in semantic_precision_scores.items():
#                 results['semantic_precision@{}'.format(k)].append(v)

            # NDCG
            ndcg_scores = retrieval_normalized_dcg_all(pred, target, k=config["topk"])
            for k, v in ndcg_scores.items():
                results['ndcg@{}'.format(k)].append(v)

    # Collapse each list of per-batch values into a single mean.
    for k in results:
        results[k] = np.mean(results[k])

    return results

In [15]:
def calculate_loss(config, criterion, pred, target, target_rank, target_topk):
    """Call ``criterion`` with the argument layout its configured name requires.

    Dispatch on ``config["criterion"]``:
      * ``"MultiLabelMarginLoss"``: ``criterion(pred, target_rank)`` after writing
        -1 into the ``config["topk"]`` column(s) of ``target_rank``.
      * names starting with ``"MultiLabelMarginLossCustomV"``:
        ``criterion(pred, target_rank, target_topk)``.
      * names starting with ``"MultiLabelMarginLossCustom"``:
        ``criterion(pred, target_rank, config["topk"])``.
      * anything else: ``criterion(pred, target)``.
    """
    loss_name = config["criterion"]

    if loss_name == "MultiLabelMarginLoss":
        # In-place write: the caller's target_rank is modified.
        target_rank[:, config["topk"]] = -1
        return criterion(pred, target_rank)

    # The "...CustomV" prefix is tested first because it also matches "...Custom".
    if loss_name.startswith("MultiLabelMarginLossCustomV"):
        return criterion(pred, target_rank, target_topk)

    if loss_name.startswith("MultiLabelMarginLossCustom"):
        return criterion(pred, target_rank, config["topk"])

    return criterion(pred, target)

def train_decoder(doc_embs, targets, config):
    """Train a DNNDecoder with SGD and periodically report retrieval metrics.

    Args:
        doc_embs: read only for ``shape[1]`` (decoder input width).
        targets: read only for ``shape[1]`` (decoder output width).
        config: dict with the keys ``"h_dim"``, ``"lr"``, ``"momentum"``,
            ``"weight_decay"``, ``"criterion"``, ``"n_epoch"``, ``"valid_epoch"``,
            ``"valid_verbose"`` and ``"topk"``.

    Returns:
        list of dict: one entry per evaluated epoch, holding ``'epoch'`` plus the
        validation metrics from ``evaluate_DNNDecoder``.

    NOTE(review): batches are drawn from the notebook-level ``train_loader`` and
    ``valid_loader``, not from the ``doc_embs`` / ``targets`` arguments, so those
    loaders must already exist and match the shapes passed here.
    """
    model = DNNDecoder(doc_emb_dim=doc_embs.shape[1], num_words=targets.shape[1],
                       h_dim=config["h_dim"]).to(device)
    model.train()

    opt = torch.optim.SGD(model.parameters(), lr=config["lr"], momentum=config["momentum"],
                          weight_decay=config["weight_decay"])

    # prepare loss
    criterion_name = config["criterion"]
    if criterion_name == "MultiLabelMarginLoss":
        criterion = nn.MultiLabelMarginLoss(reduction='mean')
    elif criterion_name == "MultiLabel":
        criterion = nn.BCEWithLogitsLoss(reduction='mean')
    elif criterion_name.startswith("MultiLabelMarginLossCustomV"):
        # The number after the last ':' is forwarded as the loss's fourth argument.
        loss_param = float(criterion_name.split(':')[-1])

        def criterion(a, b, c): return MultiLabelMarginLossCustomV(a, b, c, loss_param)
    elif criterion_name.startswith("MultiLabelMarginLossCustom"):
        loss_param = float(criterion_name.split(':')[-1])

        def criterion(a, b, c): return MultiLabelMarginLossCustom(a, b, c, loss_param)
    else:
        # NOTE(review): eval() resolves the loss by name from the notebook namespace;
        # only use criterion strings from trusted configuration.
        criterion = eval(criterion_name)

    results = []
    n_epoch = config["n_epoch"]
    valid_epoch = config["valid_epoch"]
    valid_verbose = config["valid_verbose"]

    for epoch in tqdm(range(n_epoch)):
        train_loss_his = []
        valid_loss_his = []

        model.train()
        for batch_embs, target, target_rank, target_topk in train_loader:
            batch_embs = batch_embs.to(device)
            target = target.to(device)
            target_rank = target_rank.to(device)
            target_topk = target_topk.to(device)

            # loss
            pred = model(batch_embs)
            loss = calculate_loss(config, criterion, pred, target, target_rank, target_topk)
            train_loss_his.append(loss.item())

            # Model backwarding
            opt.zero_grad()
            loss.backward()
            opt.step()

        model.eval()
        # Validation loss is only reported, so skip autograd bookkeeping.
        with torch.no_grad():
            for batch_embs, target, target_rank, target_topk in valid_loader:
                batch_embs = batch_embs.to(device)
                target = target.to(device)
                target_rank = target_rank.to(device)
                target_topk = target_topk.to(device)

                # loss
                pred = model(batch_embs)
                loss = calculate_loss(config, criterion, pred, target, target_rank, target_topk)
                valid_loss_his.append(loss.item())

        print("Epoch", epoch, np.mean(train_loss_his), np.mean(valid_loss_his))

        # show decoder result
        if valid_epoch > 0 and epoch % valid_epoch == 0:
            res = {}
            res['epoch'] = epoch

            train_res_ndcg = evaluate_DNNDecoder(model, train_loader, config)
            valid_res_ndcg = evaluate_DNNDecoder(model, valid_loader, config)

            res.update(valid_res_ndcg)
            results.append(res)

            if valid_verbose:
                print()
                print('train', train_res_ndcg)
                print('valid', valid_res_ndcg)

    # Previously the collected history was discarded; hand it back to the caller.
    return results

In [None]:
# Hyper-parameters read by train_decoder and the evaluation helpers.
train_config = dict(
    # SGD optimizer settings
    lr=0.5,
    momentum=0.0,
    weight_decay=0.0,
    # number of epochs, metric-evaluation interval (in epochs), and whether to print metrics
    n_epoch=500,
    valid_epoch=10,
    valid_verbose=True,
    # cut-offs for the @k metrics
    topk=[5, 10, 15],
    # hidden width of the decoder and name of the loss
    h_dim=3000,
    criterion="ListNet_sigmoid_L1",
)

In [18]:
# Run 1: override the learning rate and train. The assignment mutates the shared
# train_config, so the new value persists into later cells.
train_config["lr"] = 1
train_decoder(doc_embs, targets, train_config)

  0%|          | 0/500 [00:00<?, ?it/s]

Epoch 0 7.501746106871872 7.5017141657551445

train defaultdict(<class 'list'>, {'precision@5': 0.010711302524547777, 'precision@10': 0.010255929987240663, 'precision@15': 0.009916662997005693, 'ndcg@5': 0.007408029777572739, 'ndcg@10': 0.008628657880198332, 'ndcg@15': 0.010290466704960028, 'ndcg@all': 0.2235390706191047})
valid defaultdict(<class 'list'>, {'precision@5': 0.011811182804875017, 'precision@10': 0.010945200109781008, 'precision@15': 0.010336488723666885, 'ndcg@5': 0.008094578199491902, 'ndcg@10': 0.009252907235292703, 'ndcg@15': 0.01080364715601281, 'ndcg@all': 0.22394382320051118})
Epoch 1 7.501708631440411 7.501704369942973
Epoch 2 7.501701295442871 7.501698505221389
Epoch 3 7.501696097971484 7.5016938757708695
Epoch 4 7.501692012404966 7.501690181221549
Epoch 5 7.50168861715276 7.501687079902709
Epoch 6 7.5016857298474475 7.501684365310068
Epoch 7 7.501683181530981 7.50168205246212
Epoch 8 7.50168101833159 7.501679983664685
Epoch 9 7.501679034817742 7.501678155163142
E

Epoch 81 7.501642486241859 7.501642531297338
Epoch 82 7.501642309774564 7.50164230601994
Epoch 83 7.501642127943522 7.501642182117372
Epoch 84 7.501641962740097 7.501642024423194
Epoch 85 7.5016418066550425 7.501641836692029
Epoch 86 7.501641634478746 7.501641694016344
Epoch 87 7.501641464447949 7.501641540076789
Epoch 88 7.501641309435644 7.5016413861372335
Epoch 89 7.501641139941221 7.501641190896823
Epoch 90 7.501640992974537 7.501641063239631

train defaultdict(<class 'list'>, {'precision@5': 0.48318374009448845, 'precision@10': 0.3286851342469957, 'precision@15': 0.2530011664091267, 'ndcg@5': 0.4217153193913107, 'ndcg@10': 0.38939709399897016, 'ndcg@15': 0.3957393735613678, 'ndcg@all': 0.5668777594818442})
valid defaultdict(<class 'list'>, {'precision@5': 0.478463213509462, 'precision@10': 0.326055911347622, 'precision@15': 0.25117893669548935, 'ndcg@5': 0.41866333512809334, 'ndcg@10': 0.386314571373106, 'ndcg@15': 0.3922078560187122, 'ndcg@all': 0.5637899449491125})
Epoch 91 7.50

Epoch 167 7.501633043900652 7.501633144739106
Epoch 168 7.501632972562809 7.501633118456743
Epoch 169 7.501632887815598 7.501633017081914
Epoch 170 7.501632856705862 7.501632987044927

train defaultdict(<class 'list'>, {'precision@5': 0.6036358324829719, 'precision@10': 0.4289533378906272, 'precision@15': 0.3347198557531874, 'ndcg@5': 0.5214876705870719, 'ndcg@10': 0.4955858580843536, 'ndcg@15': 0.5089558793066323, 'ndcg@all': 0.6584237006705577})
valid defaultdict(<class 'list'>, {'precision@5': 0.5960162159964795, 'precision@10': 0.4235075939828017, 'precision@15': 0.33131881914739536, 'ndcg@5': 0.515907681833102, 'ndcg@10': 0.48974539937935474, 'ndcg@15': 0.5031761940070024, 'ndcg@all': 0.6538298571203637})
Epoch 171 7.501632777858773 7.501632923216332
Epoch 172 7.501632719930299 7.501632889424722
Epoch 173 7.501632659856327 7.501632799313763
Epoch 174 7.501632587982109 7.501632724221297
Epoch 175 7.501632516644267 7.501632701693557
Epoch 176 7.50163246407954 7.501632600318729
Epoch

Epoch 251 7.501628861786679 7.501629022162731
Epoch 252 7.501628801712705 7.501628984616497
Epoch 253 7.501628778112217 7.5016289921257435
Epoch 254 7.501628743784232 7.5016289245425245
Epoch 255 7.501628696046879 7.501628894505538
Epoch 256 7.501628662791644 7.501628841940812
Epoch 257 7.501628644554902 7.501628841940812
Epoch 258 7.501628602717671 7.5016288306769425
Epoch 259 7.5016285614168146 7.50162870301975
Epoch 260 7.501628518506834 7.5016287480752295

train defaultdict(<class 'list'>, {'precision@5': 0.6664036710163069, 'precision@10': 0.4821025025187515, 'precision@15': 0.37861940341805717, 'ndcg@5': 0.5736492060971341, 'ndcg@10': 0.5512617100344466, 'ndcg@15': 0.5685363087396654, 'ndcg@all': 0.706314260498343})
valid defaultdict(<class 'list'>, {'precision@5': 0.6586023892943315, 'precision@10': 0.4765181684587884, 'precision@15': 0.37441290997144744, 'ndcg@5': 0.5679737249697288, 'ndcg@10': 0.545206086372766, 'ndcg@15': 0.5622937979660635, 'ndcg@all': 0.7016501163873147})
E

Epoch 336 7.501626342434449 7.501626532847487
Epoch 337 7.50162631668846 7.501626491546631
Epoch 338 7.501626300597217 7.501626480282761
Epoch 339 7.501626264660109 7.501626435227282
Epoch 340 7.501626257150862 7.501626465264268

train defaultdict(<class 'list'>, {'precision@5': 0.6992054510572495, 'precision@10': 0.5108887766088758, 'precision@15': 0.4018042229694242, 'ndcg@5': 0.6014388392693012, 'ndcg@10': 0.5811774278786641, 'ndcg@15': 0.6001314235216318, 'ndcg@all': 0.7311749513291401})
valid defaultdict(<class 'list'>, {'precision@5': 0.6920214760021901, 'precision@10': 0.5053550413274389, 'precision@15': 0.3975378558860989, 'ndcg@5': 0.596070403658499, 'ndcg@10': 0.5748452933754508, 'ndcg@15': 0.5935836068288548, 'ndcg@all': 0.7263575656207528})
Epoch 341 7.50162621853188 7.501626438981905
Epoch 342 7.50162619707689 7.501626390171802
Epoch 343 7.501626163821655 7.501626382662556
Epoch 344 7.501626164894405 7.5016263789079325
Epoch 345 7.501626117693426 7.5016262963062195
Epoch 3

Epoch 421 7.501624541824392 7.501624783193033
Epoch 422 7.501624548797264 7.5016248207392655
Epoch 423 7.5016245262695245 7.501624764419916
Epoch 424 7.501624498378037 7.501624749401423
Epoch 425 7.501624485505043 7.501624715609814
Epoch 426 7.50162445922268 7.501624693082073
Epoch 427 7.501624442058688 7.501624696836696
Epoch 428 7.501624445813311 7.501624693082073
Epoch 429 7.501624401830581 7.5016246442719705
Epoch 430 7.5016243991487075 7.501624636762724

train defaultdict(<class 'list'>, {'precision@5': 0.7247729158642858, 'precision@10': 0.5334829461118636, 'precision@15': 0.42001690280584975, 'ndcg@5': 0.6230391025945494, 'ndcg@10': 0.6045941658846037, 'ndcg@15': 0.6248741633980561, 'ndcg@all': 0.7499385144230485})
valid defaultdict(<class 'list'>, {'precision@5': 0.7176004113174799, 'precision@10': 0.5279316695656363, 'precision@15': 0.41559660927517206, 'ndcg@5': 0.617548103407612, 'ndcg@10': 0.5981208301904634, 'ndcg@15': 0.6179929309942591, 'ndcg@all': 0.7449476939486707})
E

In [20]:
# Run 2: same configuration with a larger learning rate (again mutating the shared
# train_config in place).
train_config["lr"] = 5
train_decoder(doc_embs, targets, train_config)

  0%|          | 0/500 [00:00<?, ?it/s]

Epoch 0 7.501709296008733 7.501690526646892

train defaultdict(<class 'list'>, {'precision@5': 0.022337360059291426, 'precision@10': 0.019658091666095202, 'precision@15': 0.018254153861028338, 'ndcg@5': 0.017015773577236385, 'ndcg@10': 0.018378217660594644, 'ndcg@15': 0.02078088575867251, 'ndcg@all': 0.23687931625392494})
valid defaultdict(<class 'list'>, {'precision@5': 0.021403483643130525, 'precision@10': 0.018907182057952786, 'precision@15': 0.01764458327693498, 'ndcg@5': 0.015868491485832244, 'ndcg@10': 0.01722623937301279, 'ndcg@15': 0.019593687126720983, 'ndcg@all': 0.23594586316525468})
Epoch 1 7.501683787634456 7.50167838044054
Epoch 2 7.501674499769178 7.501671092716728
Epoch 3 7.501668371151215 7.501665945128193
Epoch 4 7.501663847902897 7.501661976491373
Epoch 5 7.501660325529873 7.501658893945649
Epoch 6 7.501657424278817 7.501656186862255
Epoch 7 7.5016549419364456 7.501653900296669
Epoch 8 7.501652814137788 7.501651936628687
Epoch 9 7.501650975981499 7.501650243293582
Ep

Epoch 81 7.50162484487613 7.501625049771286
Epoch 82 7.501624749937798 7.501624985942691
Epoch 83 7.501624690936574 7.501624839512382
Epoch 84 7.501624575616002 7.50162472311906
Epoch 85 7.501624479068546 7.501624640517347
Epoch 86 7.501624379302841 7.501624535387895
Epoch 87 7.501624309574123 7.5016244828231695
Epoch 88 7.501624208735668 7.50162438144834
Epoch 89 7.501624146516197 7.501624287582758
Epoch 90 7.501624055332488 7.501624291337381

train defaultdict(<class 'list'>, {'precision@5': 0.7305898085085694, 'precision@10': 0.5394043059196193, 'precision@15': 0.4246666476348388, 'ndcg@5': 0.6277407942123971, 'ndcg@10': 0.6102030126396797, 'ndcg@15': 0.6306240109410409, 'ndcg@all': 0.7539232413465716})
valid defaultdict(<class 'list'>, {'precision@5': 0.7235203901613791, 'precision@10': 0.5328621568642263, 'precision@15': 0.4194147889539013, 'ndcg@5': 0.6226497541262409, 'ndcg@10': 0.603596143365845, 'ndcg@15': 0.6235418709244315, 'ndcg@all': 0.7491142876504913})
Epoch 91 7.5016239

Epoch 167 7.501620534568589 7.5016207845192255
Epoch 168 7.501620497022357 7.50162072069063
Epoch 169 7.501620462157997 7.501620694408267
Epoch 170 7.501620443921255 7.50162069816289

train defaultdict(<class 'list'>, {'precision@5': 0.787073296020216, 'precision@10': 0.591139620482184, 'precision@15': 0.46450323493327816, 'ndcg@5': 0.6762056010780506, 'ndcg@10': 0.6629712680193383, 'ndcg@15': 0.684763642076164, 'ndcg@all': 0.7922941619955649})
valid defaultdict(<class 'list'>, {'precision@5': 0.7795891029628244, 'precision@10': 0.5842887022363858, 'precision@15': 0.459163566743295, 'ndcg@5': 0.6709150859690088, 'ndcg@10': 0.6563125933249165, 'ndcg@15': 0.6777562057878089, 'ndcg@all': 0.7876165757967731})
Epoch 171 7.501620424611764 7.501620668125904
Epoch 172 7.501620417638892 7.501620664371281
Epoch 173 7.501620385456407 7.501620600542684
Epoch 174 7.501620332891681 7.501620578014944
Epoch 175 7.501620311436691 7.501620559241828
Epoch 176 7.50162029105445 7.501620514186349
Epoch 177 

Epoch 251 7.501618858933851 7.501619203822819
Epoch 252 7.50161884981548 7.501619132484977
Epoch 253 7.501618852497354 7.501619106202614
Epoch 254 7.501618823533117 7.50161910244799
Epoch 255 7.501618811196498 7.501619094938744
Epoch 256 7.5016187875960085 7.501619109957237
Epoch 257 7.501618780086762 7.501619068656381
Epoch 258 7.501618777404889 7.501619046128641
Epoch 259 7.501618750049776 7.501619027355525
Epoch 260 7.501618748977026 7.501619046128641

train defaultdict(<class 'list'>, {'precision@5': 0.8160793125562378, 'precision@10': 0.6192800833752506, 'precision@15': 0.48590242819120966, 'ndcg@5': 0.7017009220724031, 'ndcg@10': 0.6914295319497116, 'ndcg@15': 0.7135996333123327, 'ndcg@all': 0.8113101419099136})
valid defaultdict(<class 'list'>, {'precision@5': 0.8088802798526493, 'precision@10': 0.6123233841160151, 'precision@15': 0.4807737887374998, 'ndcg@5': 0.6962866553171413, 'ndcg@10': 0.684627094137387, 'ndcg@15': 0.7068443551776916, 'ndcg@all': 0.8067340925922544})
Epoch 

Epoch 336 7.501618013607235 7.5016183064678525
Epoch 337 7.501617992152245 7.501618298958606
Epoch 338 7.501618001806991 7.50161826892162
Epoch 339 7.501618008243488 7.501618231375386
Epoch 340 7.501618001806991 7.501618238884633

train defaultdict(<class 'list'>, {'precision@5': 0.8314559401355525, 'precision@10': 0.6350278162446682, 'precision@15': 0.4975674161760826, 'ndcg@5': 0.7153381418994093, 'ndcg@10': 0.7070143324049573, 'ndcg@15': 0.7292161527581102, 'ndcg@all': 0.82128196384829})
valid defaultdict(<class 'list'>, {'precision@5': 0.8244099077277296, 'precision@10': 0.6284573186100937, 'precision@15': 0.49245739569814184, 'ndcg@5': 0.7102739942355418, 'ndcg@10': 0.7008792410685322, 'ndcg@15': 0.7228522338266448, 'ndcg@all': 0.8171091891649201})
Epoch 341 7.501617967479007 7.501618257657749
Epoch 342 7.501617973379129 7.501618238884633
Epoch 343 7.501617951387764 7.501618227620763
Epoch 344 7.5016179535332626 7.501618197583777
Epoch 345 7.50161794763314 7.501618253903126
Epoch 


train defaultdict(<class 'list'>, {'precision@5': 0.8419647637851059, 'precision@10': 0.6458759600856039, 'precision@15': 0.5055136171717477, 'ndcg@5': 0.7247931868743038, 'ndcg@10': 0.7178143864273354, 'ndcg@15': 0.7399126753227783, 'ndcg@all': 0.8280403947937475})
valid defaultdict(<class 'list'>, {'precision@5': 0.8352014403643571, 'precision@10': 0.6390047617784635, 'precision@15': 0.5002133280742825, 'ndcg@5': 0.7197689155894001, 'ndcg@10': 0.7115723598660446, 'ndcg@15': 0.7333772135531809, 'ndcg@all': 0.8238641802720198})
Epoch 421 7.5016174986874695 7.5016178183668245
Epoch 422 7.50161748849635 7.501617799593708
Epoch 423 7.501617457922988 7.501617758292851
Epoch 424 7.501617479914353 7.501617750783605
Epoch 425 7.501617468650483 7.501617762047474
Epoch 426 7.50161745363199 7.501617769556721
Epoch 427 7.501617463286736 7.501617758292851
Epoch 428 7.501617462213986 7.501617773311344
Epoch 429 7.501617456850239 7.501617765802098
Epoch 430 7.5016174402226214 7.501617713237372

tra

In [None]:
# train_config["lr"] = 0.05
# train_config["criterion"] = "MultiLabelMarginLossCustomV:1"
# train_decoder(doc_embs, targets, train_config)

In [None]:
# train_config["lr"] = 0.1
# train_config["criterion"] = "MultiLabelMarginLossCustom:1"
# train_decoder(doc_embs, targets, train_config)