In [1]:
import os
import argparse
import sys
from gensim.models import Word2Vec
import re

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import seed_everything
from torch.utils.data import DataLoader, TensorDataset, random_split, Dataset
from tqdm.auto import tqdm
from sklearn.feature_extraction.text import TfidfVectorizer
from collections import defaultdict

sys.path.append("../")

from utils.data_processing import get_process_data
from utils.data_loader import load_document
from utils.loss import ListNet
from utils.eval import retrieval_normalized_dcg_all, retrieval_precision_all

# fix random seed
def same_seeds(seed):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def get_freer_gpu():
    os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free > tmp')
    memory_available = [int(x.split()[2]) for x in open('tmp', 'r').readlines()]
    return int(np.argmax(memory_available))

In [2]:
device = f"cuda:{get_freer_gpu()}"
config = {
    "topk":[10,30,50]
}

In [3]:
seed = 123
seed_everything(123)

Global seed set to 123


123

In [4]:
dataset = "IMDB"
docvec = np.load("../data/docvec_IMDB_SBERT_768d.npy")
dim = 768
raw_documents = load_document(dataset)["documents"]

Reusing dataset imdb (/dhome/roytsai/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a)


In [5]:
# get TF-IDF score
vectorizer = TfidfVectorizer(min_df=10,stop_words="english")
importance_score = vectorizer.fit_transform(raw_documents).todense()

In [6]:
vocab_size = len(vectorizer.vocabulary_)
print(f"Vocab size:{vocab_size}")

Vocab size:25768


In [8]:
# # using word2vec as pretrained word embedding
# PATTERN = r"(?u)\b\w\w+\b"
# documents = [re.findall(PATTERN, i.lower()) for i in raw_documents]
# model = Word2Vec(sentences=documents, vector_size=512, window=5, min_count=10, workers=4,epochs=10)

# # get embedding
# valid_words = set(vocabulary) & set(model.wv.index_to_key)
# assert len(valid_words) == len(vocabulary)
# pretrain_wordembedding = np.array([model.wv[w] for w in valid_words]).T

In [9]:
class IDEDataset(Dataset):
    def __init__(self, 
                 doc_vectors,
                 weight_ans):
        
        assert len(doc_vectors) == len(weight_ans)

        self.doc_vectors = torch.FloatTensor(doc_vectors)
        self.weight_ans = torch.FloatTensor(weight_ans)        
        self.weight_ans_s = torch.argsort(self.weight_ans, dim=1, descending=True)
        self.topk = torch.sum(self.weight_ans > 0, dim=1)
        
    def __getitem__(self, idx):
        return self.doc_vectors[idx], self.weight_ans[idx], self.weight_ans_s[idx], self.topk[idx]

    def __len__(self):
        return len(self.doc_vectors)

In [10]:
dataset = IDEDataset(docvec, importance_score)
train_length = int(len(dataset)*0.6)
valid_length = int(len(dataset)*0.2)
test_length = len(dataset) - train_length - valid_length

full_loader = DataLoader(dataset, batch_size=128)
train_dataset, valid_dataset, test_dataset = random_split(
    dataset, lengths=[train_length, valid_length,test_length],
    generator=torch.Generator().manual_seed(42)
    )

In [12]:
train_loader = DataLoader(
    train_dataset, batch_size=128, 
    shuffle=True, pin_memory=True,
)
valid_loader = DataLoader(
    valid_dataset, batch_size=128, shuffle=False, pin_memory=True,drop_last=True)
test_loader = DataLoader(
    test_dataset, batch_size=128, shuffle=False)

In [13]:
class Decoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.word_embedding = nn.Parameter(torch.randn(hidden_dim, output_dim))
        self.transform = nn.Sequential(
            nn.Linear(input_dim,hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Sigmoid(),
            nn.Linear(hidden_dim,hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Sigmoid(),
        )
    def forward(self, x):
        docvec = self.transform(x)
        decoded = torch.sigmoid(torch.matmul(docvec,self.word_embedding))
        return decoded
    
    def load_pretrianed(self,word_embedding):
        self.word_embedding = nn.Parameter(torch.FloatTensor(word_embedding))

In [14]:
def evaluate_Decoder(model, data_loader):
    results = defaultdict(list)
    model.eval()
        
    # predict all data
    for data in data_loader:
        doc_embs, target, _, _ = data
        
        doc_embs = doc_embs.to(device)
        target = target.to(device)
                
        pred = model(doc_embs)
    
        # Precision
        precision_scores = retrieval_precision_all(pred, target, k=config["topk"])
        for k, v in precision_scores.items():
            results['precision@{}'.format(k)].append(v)
        
        # NDCG
        ndcg_scores = retrieval_normalized_dcg_all(pred, target, k=config["topk"])
        for k, v in ndcg_scores.items():
            results['ndcg@{}'.format(k)].append(v)
        
    for k in results:
        results[k] = np.mean(results[k])

    return results

In [15]:
decoder = Decoder(input_dim=dim,hidden_dim=1024,output_dim=vocab_size)
optimizer = torch.optim.Adam(decoder.parameters(), lr = 1e-4)
# initialize parameters
for p in decoder.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
# decoder.load_pretrianed(pretrain_wordembedding)
decoder = decoder.to(device)

In [16]:
def Mse_loss(pred,target):
    mse_loss = F.mse_loss(decoded,target,reduction="none")
    mask = target > 0
    loss = (mse_loss * mask).sum()
    return loss

In [17]:
# early stop settings 
stop_rounds = 3
no_improvement = 0
best_score = None 


for epoch in range(100):
    # Training
    decoder.train()
    train_loss = []
    for batch in tqdm(train_loader, desc="Training"):
        batch = [i.to(device) for i in batch]
        doc_embs, target, _, _ = batch
        target = torch.nn.functional.normalize(target.to(device), dim=1)
        decoded = torch.nn.functional.normalize(decoder(doc_embs), dim=1)
        loss = ListNet(decoded, target)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_loss.append(loss.item())

    print(f"[Epoch {epoch+1:02d}]")
    res = evaluate_Decoder(decoder, valid_loader)
    for key,val in res.items():
        print(f"{key}:{val:.4f}")
        
     # early stopping 
    current_score = res["precision@10"]
    if best_score == None:
        best_score = current_score
        continue
    if current_score < best_score:
        no_improvement += 1
    if no_improvement >= stop_rounds:
        print("Early stopping...")
        break 
    if current_score > best_score:
        no_improvement = 0
        best_score = current_score

Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 01]
precision@10:0.3936
precision@30:0.2754
precision@50:0.2267
ndcg@10:0.1663
ndcg@30:0.1459
ndcg@50:0.1431
ndcg@all:0.4361


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 02]
precision@10:0.4227
precision@30:0.2800
precision@50:0.2285
ndcg@10:0.1777
ndcg@30:0.1527
ndcg@50:0.1484
ndcg@all:0.4475


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 03]
precision@10:0.4189
precision@30:0.3037
precision@50:0.2524
ndcg@10:0.1801
ndcg@30:0.1683
ndcg@50:0.1682
ndcg@all:0.4590


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 04]
precision@10:0.4643
precision@30:0.3240
precision@50:0.2642
ndcg@10:0.2071
ndcg@30:0.1893
ndcg@50:0.1861
ndcg@all:0.4714


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 05]
precision@10:0.4753
precision@30:0.3329
precision@50:0.2704
ndcg@10:0.2238
ndcg@30:0.2040
ndcg@50:0.2001
ndcg@all:0.4814


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 06]
precision@10:0.4832
precision@30:0.3391
precision@50:0.2757
ndcg@10:0.2462
ndcg@30:0.2222
ndcg@50:0.2166
ndcg@all:0.4939


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 07]
precision@10:0.4969
precision@30:0.3448
precision@50:0.2802
ndcg@10:0.2628
ndcg@30:0.2349
ndcg@50:0.2283
ndcg@all:0.5023


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 08]
precision@10:0.5065
precision@30:0.3502
precision@50:0.2847
ndcg@10:0.2768
ndcg@30:0.2462
ndcg@50:0.2387
ndcg@all:0.5097


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 09]
precision@10:0.5148
precision@30:0.3576
precision@50:0.2894
ndcg@10:0.3004
ndcg@30:0.2644
ndcg@50:0.2548
ndcg@all:0.5225


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 10]
precision@10:0.5231
precision@30:0.3616
precision@50:0.2926
ndcg@10:0.3198
ndcg@30:0.2790
ndcg@50:0.2679
ndcg@all:0.5320


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 11]
precision@10:0.5318
precision@30:0.3669
precision@50:0.2970
ndcg@10:0.3359
ndcg@30:0.2911
ndcg@50:0.2789
ndcg@all:0.5401


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 12]
precision@10:0.5391
precision@30:0.3696
precision@50:0.2983
ndcg@10:0.3505
ndcg@30:0.3027
ndcg@50:0.2892
ndcg@all:0.5474


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 13]
precision@10:0.5446
precision@30:0.3737
precision@50:0.3015
ndcg@10:0.3672
ndcg@30:0.3150
ndcg@50:0.3001
ndcg@all:0.5554


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 14]
precision@10:0.5540
precision@30:0.3766
precision@50:0.3030
ndcg@10:0.3759
ndcg@30:0.3219
ndcg@50:0.3058
ndcg@all:0.5599


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 15]
precision@10:0.5594
precision@30:0.3789
precision@50:0.3051
ndcg@10:0.3894
ndcg@30:0.3313
ndcg@50:0.3141
ndcg@all:0.5660


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 16]
precision@10:0.5618
precision@30:0.3810
precision@50:0.3068
ndcg@10:0.3974
ndcg@30:0.3374
ndcg@50:0.3195
ndcg@all:0.5700


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 17]
precision@10:0.5653
precision@30:0.3834
precision@50:0.3078
ndcg@10:0.4059
ndcg@30:0.3438
ndcg@50:0.3250
ndcg@all:0.5740


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 18]
precision@10:0.5680
precision@30:0.3853
precision@50:0.3098
ndcg@10:0.4115
ndcg@30:0.3478
ndcg@50:0.3288
ndcg@all:0.5769


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 19]
precision@10:0.5763
precision@30:0.3889
precision@50:0.3121
ndcg@10:0.4176
ndcg@30:0.3525
ndcg@50:0.3330
ndcg@all:0.5799


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 20]
precision@10:0.5776
precision@30:0.3904
precision@50:0.3127
ndcg@10:0.4230
ndcg@30:0.3569
ndcg@50:0.3368
ndcg@all:0.5827


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 21]
precision@10:0.5832
precision@30:0.3915
precision@50:0.3126
ndcg@10:0.4288
ndcg@30:0.3611
ndcg@50:0.3402
ndcg@all:0.5852


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 22]
precision@10:0.5820
precision@30:0.3936
precision@50:0.3152
ndcg@10:0.4302
ndcg@30:0.3627
ndcg@50:0.3419
ndcg@all:0.5862


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 23]
precision@10:0.5840
precision@30:0.3933
precision@50:0.3151
ndcg@10:0.4361
ndcg@30:0.3666
ndcg@50:0.3456
ndcg@all:0.5889


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 24]
precision@10:0.5861
precision@30:0.3940
precision@50:0.3153
ndcg@10:0.4374
ndcg@30:0.3678
ndcg@50:0.3466
ndcg@all:0.5896


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 25]
precision@10:0.5876
precision@30:0.3929
precision@50:0.3140
ndcg@10:0.4419
ndcg@30:0.3709
ndcg@50:0.3491
ndcg@all:0.5918


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 26]
precision@10:0.5904
precision@30:0.3959
precision@50:0.3169
ndcg@10:0.4422
ndcg@30:0.3711
ndcg@50:0.3496
ndcg@all:0.5920


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 27]
precision@10:0.5898
precision@30:0.3959
precision@50:0.3168
ndcg@10:0.4456
ndcg@30:0.3738
ndcg@50:0.3520
ndcg@all:0.5936


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 28]
precision@10:0.5925
precision@30:0.3972
precision@50:0.3167
ndcg@10:0.4473
ndcg@30:0.3755
ndcg@50:0.3533
ndcg@all:0.5946


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 29]
precision@10:0.5904
precision@30:0.3957
precision@50:0.3164
ndcg@10:0.4500
ndcg@30:0.3772
ndcg@50:0.3549
ndcg@all:0.5957


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 30]
precision@10:0.5926
precision@30:0.3962
precision@50:0.3166
ndcg@10:0.4504
ndcg@30:0.3773
ndcg@50:0.3549
ndcg@all:0.5959


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 31]
precision@10:0.5945
precision@30:0.3978
precision@50:0.3179
ndcg@10:0.4522
ndcg@30:0.3790
ndcg@50:0.3569
ndcg@all:0.5968


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 32]
precision@10:0.5988
precision@30:0.3999
precision@50:0.3180
ndcg@10:0.4522
ndcg@30:0.3794
ndcg@50:0.3567
ndcg@all:0.5969


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 33]
precision@10:0.5946
precision@30:0.3979
precision@50:0.3179
ndcg@10:0.4533
ndcg@30:0.3802
ndcg@50:0.3576
ndcg@all:0.5975


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 34]
precision@10:0.5958
precision@30:0.3984
precision@50:0.3178
ndcg@10:0.4516
ndcg@30:0.3785
ndcg@50:0.3560
ndcg@all:0.5966


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 35]
precision@10:0.5972
precision@30:0.3974
precision@50:0.3168
ndcg@10:0.4546
ndcg@30:0.3812
ndcg@50:0.3585
ndcg@all:0.5980
Early stopping...


In [18]:
print("Testing...")
res = evaluate_Decoder(decoder, test_loader)
for key,val in res.items():
    print(f"{key}:{val:.4f}")

Testing...
precision@10:0.5960
precision@30:0.3988
precision@50:0.3179
ndcg@10:0.4521
ndcg@30:0.3801
ndcg@50:0.3576
ndcg@all:0.5981
