In [10]:
import os
import argparse
import sys
from gensim.models import Word2Vec
import re

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import seed_everything
from torch.utils.data import DataLoader, TensorDataset, random_split, Dataset
from tqdm.auto import tqdm
from sklearn.feature_extraction.text import TfidfVectorizer
from collections import defaultdict,OrderedDict
from contextualized_topic_models.models.ctm import ZeroShotTM

sys.path.append("../")

from utils.data_processing import get_process_data
from utils.data_loader import load_document
from utils.loss import ListNet
from utils.eval import retrieval_normalized_dcg_all, retrieval_precision_all
from utils.toolbox import same_seeds, get_freer_gpu

In [2]:
device = f"cuda:{get_freer_gpu()}"
config = {
    "topk":[10,30,50]
}

In [3]:
seed = 123
seed_everything(123)

Global seed set to 123


123

In [4]:
dataset = "IMDB"
docvec = np.load("../data/docvec_IMDB_SBERT_768d.npy")
dim = 768
raw_documents = load_document(dataset)["documents"]

Reusing dataset imdb (/dhome/roytsai/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a)


In [5]:
# get TF-IDF score
vectorizer = TfidfVectorizer(min_df=10,stop_words="english")
importance_score = vectorizer.fit_transform(raw_documents).todense()

In [6]:
vocab_size = len(vectorizer.vocabulary_)
print(f"Vocab size:{vocab_size}")

Vocab size:25768


In [7]:
class IDEDataset(Dataset):
    def __init__(self, 
                 doc_vectors,
                 weight_ans):
        
        assert len(doc_vectors) == len(weight_ans)

        self.doc_vectors = torch.FloatTensor(doc_vectors)
        self.weight_ans = torch.FloatTensor(weight_ans)        
        self.weight_ans_s = torch.argsort(self.weight_ans, dim=1, descending=True)
        self.topk = torch.sum(self.weight_ans > 0, dim=1)
        
    def __getitem__(self, idx):
        return self.doc_vectors[idx], self.weight_ans[idx], self.weight_ans_s[idx], self.topk[idx]

    def __len__(self):
        return len(self.doc_vectors)

In [8]:
dataset = IDEDataset(docvec, importance_score)
train_length = int(len(dataset)*0.6)
valid_length = int(len(dataset)*0.2)
test_length = len(dataset) - train_length - valid_length

full_loader = DataLoader(dataset, batch_size=128)
train_dataset, valid_dataset, test_dataset = random_split(
    dataset, lengths=[train_length, valid_length,test_length],
    generator=torch.Generator().manual_seed(42)
    )

In [9]:
train_loader = DataLoader(
    train_dataset, batch_size=128, 
    shuffle=True, pin_memory=True,
)
valid_loader = DataLoader(
    valid_dataset, batch_size=128, shuffle=False, pin_memory=True,drop_last=True)
test_loader = DataLoader(
    test_dataset, batch_size=128, shuffle=False)

In [24]:
def evaluate_Decoder(model, data_loader):
    results = defaultdict(list)
    model.eval()
        
    # predict all data
    for data in data_loader:
        doc_embs, target, _, _ = data
        
        doc_embs = doc_embs.to(device)
        target = target.to(device)
                
        pred = model(target,doc_embs)
        pred = pred[-2]
    
        # Precision
        precision_scores = retrieval_precision_all(pred, target, k=config["topk"])
        for k, v in precision_scores.items():
            results['precision@{}'.format(k)].append(v)
        
        # NDCG
        ndcg_scores = retrieval_normalized_dcg_all(pred, target, k=config["topk"])
        for k, v in ndcg_scores.items():
            results['ndcg@{}'.format(k)].append(v)
        
    for k in results:
        results[k] = np.mean(results[k])

    return results

In [54]:
# decoder = Decoder(input_dim=dim,hidden_dim=1024,output_dim=vocab_size)
model = ZeroShotTM(bow_size=vocab_size, contextual_size=768,n_components=1024,hidden_sizes=(1024,1024),activation="relu")
decodernet = model.model
# optimizer = torch.optim.Adam(decodernet.parameters(), lr = 1e-3)
optimizer = model.optimizer
# initialize parameters
for p in decodernet.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
# decoder.load_pretrianed(pretrain_wordembedding)
decodernet = decodernet.to(device)

In [55]:
# early stop settings 
stop_rounds = 300
no_improvement = 0
best_score = None 


for epoch in range(100):
    # Training
    decodernet.train()
    train_loss = []
    for batch in tqdm(train_loader, desc="Training"):
        batch = [i.to(device) for i in batch]
        doc_embs, target, _, _ = batch
        prior_mean, prior_variance, posterior_mean, posterior_variance,\
            posterior_log_variance, word_dists,_ = decodernet(target,doc_embs)
        kl_loss, rl_loss = model._loss(target,word_dists, prior_mean, prior_variance,
                        posterior_mean, posterior_variance, posterior_log_variance)
        loss = 1e-3*kl_loss + rl_loss
        loss = loss.sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_loss.append(loss.item())

    print(f"[Epoch {epoch+1:02d}]")
    res = evaluate_Decoder(decodernet, valid_loader)
    for key,val in res.items():
        print(f"{key}:{val:.4f}")
        
     # early stopping 
    current_score = res["precision@10"]
    if best_score == None:
        best_score = current_score
        continue
    if current_score < best_score:
        no_improvement += 1
    if no_improvement >= stop_rounds:
        print("Early stopping...")
        break 
    if current_score > best_score:
        no_improvement = 0
        best_score = current_score

Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 01]
precision@10:0.0550
precision@30:0.0475
precision@50:0.0437
ndcg@10:0.0234
ndcg@30:0.0262
ndcg@50:0.0286
ndcg@all:0.3180


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 02]
precision@10:0.0845
precision@30:0.0661
precision@50:0.0588
ndcg@10:0.0383
ndcg@30:0.0393
ndcg@50:0.0411
ndcg@all:0.3310


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 03]
precision@10:0.1058
precision@30:0.0807
precision@50:0.0711
ndcg@10:0.0494
ndcg@30:0.0501
ndcg@50:0.0519
ndcg@all:0.3415


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 04]
precision@10:0.1219
precision@30:0.0914
precision@50:0.0796
ndcg@10:0.0606
ndcg@30:0.0602
ndcg@50:0.0619
ndcg@all:0.3531


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 05]
precision@10:0.1440
precision@30:0.1053
precision@50:0.0908
ndcg@10:0.0748
ndcg@30:0.0728
ndcg@50:0.0740
ndcg@all:0.3660


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 06]
precision@10:0.1498
precision@30:0.1094
precision@50:0.0939
ndcg@10:0.0832
ndcg@30:0.0800
ndcg@50:0.0810
ndcg@all:0.3718


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 07]
precision@10:0.1453
precision@30:0.1052
precision@50:0.0896
ndcg@10:0.0865
ndcg@30:0.0828
ndcg@50:0.0833
ndcg@all:0.3717


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 08]
precision@10:0.1456
precision@30:0.1043
precision@50:0.0884
ndcg@10:0.0929
ndcg@30:0.0877
ndcg@50:0.0873
ndcg@all:0.3754


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 09]
precision@10:0.1512
precision@30:0.1067
precision@50:0.0897
ndcg@10:0.0999
ndcg@30:0.0935
ndcg@50:0.0926
ndcg@all:0.3780


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 10]
precision@10:0.1553
precision@30:0.1094
precision@50:0.0915
ndcg@10:0.1076
ndcg@30:0.0999
ndcg@50:0.0983
ndcg@all:0.3800


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 11]
precision@10:0.1707
precision@30:0.1197
precision@50:0.1003
ndcg@10:0.1189
ndcg@30:0.1100
ndcg@50:0.1086
ndcg@all:0.3928


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 12]
precision@10:0.1663
precision@30:0.1147
precision@50:0.0948
ndcg@10:0.1215
ndcg@30:0.1111
ndcg@50:0.1083
ndcg@all:0.3859


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 13]
precision@10:0.1797
precision@30:0.1239
precision@50:0.1024
ndcg@10:0.1326
ndcg@30:0.1210
ndcg@50:0.1181
ndcg@all:0.3975


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 14]
precision@10:0.1798
precision@30:0.1225
precision@50:0.1010
ndcg@10:0.1365
ndcg@30:0.1233
ndcg@50:0.1199
ndcg@all:0.3955


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 15]
precision@10:0.1855
precision@30:0.1257
precision@50:0.1038
ndcg@10:0.1416
ndcg@30:0.1274
ndcg@50:0.1237
ndcg@all:0.3989


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 16]
precision@10:0.1918
precision@30:0.1312
precision@50:0.1086
ndcg@10:0.1469
ndcg@30:0.1328
ndcg@50:0.1293
ndcg@all:0.4056


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 17]
precision@10:0.1905
precision@30:0.1274
precision@50:0.1044
ndcg@10:0.1511
ndcg@30:0.1349
ndcg@50:0.1300
ndcg@all:0.4014


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 18]
precision@10:0.1926
precision@30:0.1290
precision@50:0.1052
ndcg@10:0.1557
ndcg@30:0.1386
ndcg@50:0.1335
ndcg@all:0.4054


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 19]
precision@10:0.1922
precision@30:0.1280
precision@50:0.1036
ndcg@10:0.1580
ndcg@30:0.1401
ndcg@50:0.1344
ndcg@all:0.4034


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 20]
precision@10:0.1986
precision@30:0.1323
precision@50:0.1075
ndcg@10:0.1622
ndcg@30:0.1439
ndcg@50:0.1382
ndcg@all:0.4084


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 21]
precision@10:0.1973
precision@30:0.1301
precision@50:0.1052
ndcg@10:0.1642
ndcg@30:0.1451
ndcg@50:0.1389
ndcg@all:0.4068


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 22]
precision@10:0.1997
precision@30:0.1316
precision@50:0.1063
ndcg@10:0.1665
ndcg@30:0.1466
ndcg@50:0.1400
ndcg@all:0.4082


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 23]
precision@10:0.2072
precision@30:0.1366
precision@50:0.1110
ndcg@10:0.1728
ndcg@30:0.1519
ndcg@50:0.1453
ndcg@all:0.4146


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 24]
precision@10:0.2064
precision@30:0.1354
precision@50:0.1095
ndcg@10:0.1729
ndcg@30:0.1520
ndcg@50:0.1452
ndcg@all:0.4129


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 25]
precision@10:0.2065
precision@30:0.1369
precision@50:0.1109
ndcg@10:0.1735
ndcg@30:0.1526
ndcg@50:0.1458
ndcg@all:0.4133


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 26]
precision@10:0.2110
precision@30:0.1387
precision@50:0.1115
ndcg@10:0.1777
ndcg@30:0.1558
ndcg@50:0.1485
ndcg@all:0.4146


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 27]
precision@10:0.2099
precision@30:0.1378
precision@50:0.1111
ndcg@10:0.1787
ndcg@30:0.1566
ndcg@50:0.1495
ndcg@all:0.4157


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 28]
precision@10:0.2101
precision@30:0.1371
precision@50:0.1104
ndcg@10:0.1793
ndcg@30:0.1567
ndcg@50:0.1494
ndcg@all:0.4147


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 29]
precision@10:0.2098
precision@30:0.1374
precision@50:0.1104
ndcg@10:0.1809
ndcg@30:0.1583
ndcg@50:0.1506
ndcg@all:0.4146


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 30]
precision@10:0.2123
precision@30:0.1383
precision@50:0.1112
ndcg@10:0.1825
ndcg@30:0.1593
ndcg@50:0.1516
ndcg@all:0.4163


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 31]
precision@10:0.2183
precision@30:0.1424
precision@50:0.1145
ndcg@10:0.1861
ndcg@30:0.1626
ndcg@50:0.1550
ndcg@all:0.4197


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 32]
precision@10:0.2190
precision@30:0.1431
precision@50:0.1149
ndcg@10:0.1872
ndcg@30:0.1636
ndcg@50:0.1556
ndcg@all:0.4199


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 33]
precision@10:0.2171
precision@30:0.1401
precision@50:0.1124
ndcg@10:0.1879
ndcg@30:0.1633
ndcg@50:0.1551
ndcg@all:0.4195


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 34]
precision@10:0.2167
precision@30:0.1403
precision@50:0.1127
ndcg@10:0.1878
ndcg@30:0.1630
ndcg@50:0.1551
ndcg@all:0.4175


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 35]
precision@10:0.2164
precision@30:0.1386
precision@50:0.1104
ndcg@10:0.1899
ndcg@30:0.1642
ndcg@50:0.1555
ndcg@all:0.4164


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 36]
precision@10:0.2186
precision@30:0.1417
precision@50:0.1130
ndcg@10:0.1895
ndcg@30:0.1648
ndcg@50:0.1561
ndcg@all:0.4185


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 37]
precision@10:0.2219
precision@30:0.1424
precision@50:0.1140
ndcg@10:0.1935
ndcg@30:0.1674
ndcg@50:0.1591
ndcg@all:0.4210


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 38]
precision@10:0.2212
precision@30:0.1432
precision@50:0.1143
ndcg@10:0.1914
ndcg@30:0.1663
ndcg@50:0.1576
ndcg@all:0.4207


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 39]
precision@10:0.2219
precision@30:0.1427
precision@50:0.1137
ndcg@10:0.1940
ndcg@30:0.1681
ndcg@50:0.1591
ndcg@all:0.4210


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 40]
precision@10:0.2247
precision@30:0.1439
precision@50:0.1150
ndcg@10:0.1947
ndcg@30:0.1686
ndcg@50:0.1599
ndcg@all:0.4218


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 41]
precision@10:0.2250
precision@30:0.1445
precision@50:0.1151
ndcg@10:0.1975
ndcg@30:0.1709
ndcg@50:0.1618
ndcg@all:0.4228


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 42]
precision@10:0.2217
precision@30:0.1424
precision@50:0.1128
ndcg@10:0.1952
ndcg@30:0.1688
ndcg@50:0.1594
ndcg@all:0.4200


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 43]
precision@10:0.2276
precision@30:0.1459
precision@50:0.1163
ndcg@10:0.1981
ndcg@30:0.1713
ndcg@50:0.1622
ndcg@all:0.4240


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 44]
precision@10:0.2243
precision@30:0.1445
precision@50:0.1150
ndcg@10:0.1963
ndcg@30:0.1700
ndcg@50:0.1611
ndcg@all:0.4214


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 45]
precision@10:0.2248
precision@30:0.1440
precision@50:0.1141
ndcg@10:0.1970
ndcg@30:0.1703
ndcg@50:0.1609
ndcg@all:0.4208


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 46]
precision@10:0.2256
precision@30:0.1449
precision@50:0.1151
ndcg@10:0.1983
ndcg@30:0.1717
ndcg@50:0.1622
ndcg@all:0.4228


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 47]
precision@10:0.2256
precision@30:0.1447
precision@50:0.1150
ndcg@10:0.1984
ndcg@30:0.1715
ndcg@50:0.1621
ndcg@all:0.4232


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 48]
precision@10:0.2202
precision@30:0.1407
precision@50:0.1112
ndcg@10:0.1962
ndcg@30:0.1693
ndcg@50:0.1598
ndcg@all:0.4191


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 49]
precision@10:0.2278
precision@30:0.1464
precision@50:0.1160
ndcg@10:0.2008
ndcg@30:0.1735
ndcg@50:0.1638
ndcg@all:0.4232


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 50]
precision@10:0.2271
precision@30:0.1449
precision@50:0.1147
ndcg@10:0.2010
ndcg@30:0.1733
ndcg@50:0.1635
ndcg@all:0.4233


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 51]
precision@10:0.2291
precision@30:0.1466
precision@50:0.1167
ndcg@10:0.2023
ndcg@30:0.1745
ndcg@50:0.1651
ndcg@all:0.4255


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 52]
precision@10:0.2297
precision@30:0.1459
precision@50:0.1158
ndcg@10:0.2012
ndcg@30:0.1733
ndcg@50:0.1639
ndcg@all:0.4228


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 53]
precision@10:0.2265
precision@30:0.1444
precision@50:0.1147
ndcg@10:0.2011
ndcg@30:0.1734
ndcg@50:0.1639
ndcg@all:0.4238


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 54]
precision@10:0.2249
precision@30:0.1431
precision@50:0.1129
ndcg@10:0.2003
ndcg@30:0.1727
ndcg@50:0.1629
ndcg@all:0.4208


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 55]
precision@10:0.2299
precision@30:0.1464
precision@50:0.1154
ndcg@10:0.2043
ndcg@30:0.1761
ndcg@50:0.1658
ndcg@all:0.4240


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 56]
precision@10:0.2343
precision@30:0.1502
precision@50:0.1187
ndcg@10:0.2043
ndcg@30:0.1771
ndcg@50:0.1672
ndcg@all:0.4259


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 57]
precision@10:0.2275
precision@30:0.1443
precision@50:0.1140
ndcg@10:0.2042
ndcg@30:0.1756
ndcg@50:0.1654
ndcg@all:0.4242


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 58]
precision@10:0.2316
precision@30:0.1470
precision@50:0.1163
ndcg@10:0.2055
ndcg@30:0.1767
ndcg@50:0.1668
ndcg@all:0.4257


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 59]
precision@10:0.2300
precision@30:0.1458
precision@50:0.1158
ndcg@10:0.2055
ndcg@30:0.1767
ndcg@50:0.1668
ndcg@all:0.4254


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 60]
precision@10:0.2288
precision@30:0.1448
precision@50:0.1147
ndcg@10:0.2056
ndcg@30:0.1766
ndcg@50:0.1667
ndcg@all:0.4250


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 61]
precision@10:0.2338
precision@30:0.1482
precision@50:0.1170
ndcg@10:0.2067
ndcg@30:0.1777
ndcg@50:0.1675
ndcg@all:0.4239


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 62]
precision@10:0.2349
precision@30:0.1493
precision@50:0.1180
ndcg@10:0.2078
ndcg@30:0.1788
ndcg@50:0.1686
ndcg@all:0.4259


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 63]
precision@10:0.2338
precision@30:0.1492
precision@50:0.1184
ndcg@10:0.2071
ndcg@30:0.1784
ndcg@50:0.1685
ndcg@all:0.4269


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 64]
precision@10:0.2342
precision@30:0.1494
precision@50:0.1182
ndcg@10:0.2074
ndcg@30:0.1787
ndcg@50:0.1686
ndcg@all:0.4269


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 65]
precision@10:0.2337
precision@30:0.1486
precision@50:0.1178
ndcg@10:0.2077
ndcg@30:0.1786
ndcg@50:0.1685
ndcg@all:0.4259


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 66]
precision@10:0.2344
precision@30:0.1485
precision@50:0.1173
ndcg@10:0.2077
ndcg@30:0.1786
ndcg@50:0.1684
ndcg@all:0.4257


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 67]
precision@10:0.2274
precision@30:0.1434
precision@50:0.1127
ndcg@10:0.2068
ndcg@30:0.1770
ndcg@50:0.1662
ndcg@all:0.4229


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 68]
precision@10:0.2334
precision@30:0.1475
precision@50:0.1165
ndcg@10:0.2087
ndcg@30:0.1790
ndcg@50:0.1686
ndcg@all:0.4260


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 69]
precision@10:0.2366
precision@30:0.1501
precision@50:0.1184
ndcg@10:0.2111
ndcg@30:0.1812
ndcg@50:0.1704
ndcg@all:0.4286


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 70]
precision@10:0.2371
precision@30:0.1501
precision@50:0.1190
ndcg@10:0.2093
ndcg@30:0.1798
ndcg@50:0.1697
ndcg@all:0.4291


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 71]
precision@10:0.2327
precision@30:0.1461
precision@50:0.1152
ndcg@10:0.2087
ndcg@30:0.1786
ndcg@50:0.1681
ndcg@all:0.4241


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 72]
precision@10:0.2364
precision@30:0.1498
precision@50:0.1180
ndcg@10:0.2110
ndcg@30:0.1810
ndcg@50:0.1704
ndcg@all:0.4280


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 73]
precision@10:0.2325
precision@30:0.1471
precision@50:0.1160
ndcg@10:0.2086
ndcg@30:0.1793
ndcg@50:0.1687
ndcg@all:0.4253


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 74]
precision@10:0.2361
precision@30:0.1494
precision@50:0.1175
ndcg@10:0.2117
ndcg@30:0.1815
ndcg@50:0.1706
ndcg@all:0.4279


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 75]
precision@10:0.2373
precision@30:0.1494
precision@50:0.1180
ndcg@10:0.2118
ndcg@30:0.1814
ndcg@50:0.1708
ndcg@all:0.4285


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 76]
precision@10:0.2379
precision@30:0.1507
precision@50:0.1187
ndcg@10:0.2107
ndcg@30:0.1812
ndcg@50:0.1707
ndcg@all:0.4281


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 77]
precision@10:0.2340
precision@30:0.1483
precision@50:0.1165
ndcg@10:0.2104
ndcg@30:0.1804
ndcg@50:0.1697
ndcg@all:0.4267


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 78]
precision@10:0.2341
precision@30:0.1487
precision@50:0.1175
ndcg@10:0.2109
ndcg@30:0.1810
ndcg@50:0.1706
ndcg@all:0.4271


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 79]
precision@10:0.2363
precision@30:0.1497
precision@50:0.1176
ndcg@10:0.2115
ndcg@30:0.1816
ndcg@50:0.1706
ndcg@all:0.4278


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 80]
precision@10:0.2380
precision@30:0.1500
precision@50:0.1185
ndcg@10:0.2119
ndcg@30:0.1820
ndcg@50:0.1714
ndcg@all:0.4280


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 81]
precision@10:0.2353
precision@30:0.1485
precision@50:0.1169
ndcg@10:0.2125
ndcg@30:0.1819
ndcg@50:0.1711
ndcg@all:0.4272


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 82]
precision@10:0.2352
precision@30:0.1477
precision@50:0.1163
ndcg@10:0.2124
ndcg@30:0.1813
ndcg@50:0.1704
ndcg@all:0.4265


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 83]
precision@10:0.2380
precision@30:0.1495
precision@50:0.1178
ndcg@10:0.2139
ndcg@30:0.1829
ndcg@50:0.1720
ndcg@all:0.4281


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 84]
precision@10:0.2347
precision@30:0.1482
precision@50:0.1168
ndcg@10:0.2121
ndcg@30:0.1815
ndcg@50:0.1707
ndcg@all:0.4273


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 85]
precision@10:0.2356
precision@30:0.1474
precision@50:0.1162
ndcg@10:0.2137
ndcg@30:0.1823
ndcg@50:0.1714
ndcg@all:0.4267


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 86]
precision@10:0.2372
precision@30:0.1490
precision@50:0.1172
ndcg@10:0.2127
ndcg@30:0.1818
ndcg@50:0.1709
ndcg@all:0.4266


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 87]
precision@10:0.2358
precision@30:0.1478
precision@50:0.1160
ndcg@10:0.2135
ndcg@30:0.1822
ndcg@50:0.1711
ndcg@all:0.4273


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 88]
precision@10:0.2380
precision@30:0.1503
precision@50:0.1179
ndcg@10:0.2144
ndcg@30:0.1836
ndcg@50:0.1723
ndcg@all:0.4286


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 89]
precision@10:0.2407
precision@30:0.1513
precision@50:0.1196
ndcg@10:0.2158
ndcg@30:0.1844
ndcg@50:0.1736
ndcg@all:0.4307


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 90]
precision@10:0.2348
precision@30:0.1477
precision@50:0.1164
ndcg@10:0.2136
ndcg@30:0.1824
ndcg@50:0.1716
ndcg@all:0.4266


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 91]
precision@10:0.2371
precision@30:0.1498
precision@50:0.1181
ndcg@10:0.2142
ndcg@30:0.1835
ndcg@50:0.1728
ndcg@all:0.4295


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 92]
precision@10:0.2360
precision@30:0.1481
precision@50:0.1159
ndcg@10:0.2135
ndcg@30:0.1823
ndcg@50:0.1710
ndcg@all:0.4251


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 93]
precision@10:0.2403
precision@30:0.1514
precision@50:0.1195
ndcg@10:0.2161
ndcg@30:0.1850
ndcg@50:0.1743
ndcg@all:0.4306


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 94]
precision@10:0.2383
precision@30:0.1499
precision@50:0.1182
ndcg@10:0.2151
ndcg@30:0.1837
ndcg@50:0.1728
ndcg@all:0.4290


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 95]
precision@10:0.2399
precision@30:0.1509
precision@50:0.1192
ndcg@10:0.2164
ndcg@30:0.1849
ndcg@50:0.1740
ndcg@all:0.4309


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 96]
precision@10:0.2402
precision@30:0.1517
precision@50:0.1193
ndcg@10:0.2167
ndcg@30:0.1856
ndcg@50:0.1743
ndcg@all:0.4301


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 97]
precision@10:0.2418
precision@30:0.1523
precision@50:0.1201
ndcg@10:0.2172
ndcg@30:0.1855
ndcg@50:0.1746
ndcg@all:0.4302


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 98]
precision@10:0.2393
precision@30:0.1513
precision@50:0.1185
ndcg@10:0.2164
ndcg@30:0.1854
ndcg@50:0.1741
ndcg@all:0.4289


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 99]
precision@10:0.2383
precision@30:0.1498
precision@50:0.1178
ndcg@10:0.2156
ndcg@30:0.1843
ndcg@50:0.1732
ndcg@all:0.4286


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 100]
precision@10:0.2422
precision@30:0.1530
precision@50:0.1208
ndcg@10:0.2178
ndcg@30:0.1863
ndcg@50:0.1755
ndcg@all:0.4320


In [None]:
print("Testing...")
res = evaluate_Decoder(decoder, test_loader)
for key,val in res.items():
    print(f"{key}:{val:.4f}")