In [1]:
import os
import argparse
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import seed_everything
from torch.utils.data import DataLoader, TensorDataset, random_split, Dataset
from tqdm.auto import tqdm
from sklearn.feature_extraction.text import TfidfVectorizer
from collections import defaultdict

sys.path.append("../")

from utils.data_processing import get_process_data
from utils.data_loader import load_document
from utils.loss import ListNet
from utils.eval import retrieval_normalized_dcg_all, retrieval_precision_all

# fix random seed
def same_seeds(seed):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def get_freer_gpu():
    os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free > tmp')
    memory_available = [int(x.split()[2]) for x in open('tmp', 'r').readlines()]
    return int(np.argmax(memory_available))

In [2]:
device = f"cuda:{get_freer_gpu()}"
config = {
    "topk":[10,30,50]
}

In [3]:
seed = 123
seed_everything(123)

Global seed set to 123


123

In [4]:
dataset = "IMDB"
docvec = np.load("../data/docvec_IMDB_SBERT_768d.npy")
dim = 768
raw_documents = load_document(dataset)["documents"]

Reusing dataset imdb (/dhome/roytsai/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a)


In [5]:
# get TF-IDF score
vectorizer = TfidfVectorizer(min_df=10,stop_words="english")
importance_score = vectorizer.fit_transform(raw_documents).todense()

In [6]:
vocab_size = len(vectorizer.vocabulary_)

In [7]:
class IDEDataset(Dataset):
    def __init__(self, 
                 doc_vectors,
                 weight_ans):
        
        assert len(doc_vectors) == len(weight_ans)

        self.doc_vectors = torch.FloatTensor(doc_vectors)
        self.weight_ans = torch.FloatTensor(weight_ans)        
        self.weight_ans_s = torch.argsort(self.weight_ans, dim=1, descending=True)
        self.topk = torch.sum(self.weight_ans > 0, dim=1)
        
    def __getitem__(self, idx):
        return self.doc_vectors[idx], self.weight_ans[idx], self.weight_ans_s[idx], self.topk[idx]

    def __len__(self):
        return len(self.doc_vectors)

In [8]:
dataset = IDEDataset(docvec, importance_score)
train_length = int(len(dataset)*0.6)
valid_length = int(len(dataset)*0.2)
test_length = len(dataset) - train_length - valid_length

full_loader = DataLoader(dataset, batch_size=128)
train_dataset, valid_dataset, test_dataset = random_split(
    dataset, lengths=[train_length, valid_length,test_length],
    generator=torch.Generator().manual_seed(42)
    )

In [9]:
train_loader = DataLoader(
    train_dataset, batch_size=128, 
    shuffle=True, pin_memory=True,
)
valid_loader = DataLoader(
    valid_dataset, batch_size=128, shuffle=False, pin_memory=True)
test_loader = DataLoader(
    test_dataset, batch_size=128, shuffle=False)

In [15]:
class Decoder(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.decoder = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.Tanh(),  
            nn.Linear(1024, 4096),
            nn.Tanh(),
            nn.Linear(4096, output_dim),
            nn.Sigmoid(),
        )
    def forward(self, x):
        return self.decoder(x)

In [16]:
def evaluate_Decoder(model, data_loader):
    results = defaultdict(list)
    model.eval()
        
    # predict all data
    for data in data_loader:
        doc_embs, target, _, _ = data
        
        doc_embs = doc_embs.to(device)
        target = target.to(device)
                
        pred = model(doc_embs)
    
        # Precision
        precision_scores = retrieval_precision_all(pred, target, k=config["topk"])
        for k, v in precision_scores.items():
            results['precision@{}'.format(k)].append(v)
        
        # NDCG
        ndcg_scores = retrieval_normalized_dcg_all(pred, target, k=config["topk"])
        for k, v in ndcg_scores.items():
            results['ndcg@{}'.format(k)].append(v)
        
    for k in results:
        results[k] = np.mean(results[k])

    return results

In [17]:
decoder = Decoder(input_dim=dim,output_dim=vocab_size).to(device)
optimizer = torch.optim.Adam(decoder.parameters(), lr = 1e-4)

In [18]:
# early stop settings 
stop_rounds = 3
no_improvement = 0
best_score = None 


for epoch in range(100):
    # Training
    decoder.train()
    train_loss = []
    for batch in tqdm(train_loader, desc="Training"):
        batch = [i.to(device) for i in batch]
        doc_embs, target, _, _ = batch
        target = torch.nn.functional.normalize(target.to(device), dim=1)
        decoded = torch.nn.functional.normalize(decoder(doc_embs), dim=1)
        loss = ListNet(decoded, target)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_loss.append(loss.item())

    print(f"[Epoch {epoch+1:02d}]")
    res = evaluate_Decoder(decoder, valid_loader)
    for key,val in res.items():
        print(f"{key}:{val:.4f}")
        
     # early stopping 
    current_score = res["precision@10"]
    if best_score == None:
        best_score = current_score
        continue
    if current_score < best_score:
        no_improvement += 1
    if no_improvement >= stop_rounds:
        print("Early stopping...")
        break 
    if current_score > best_score:
        no_improvement = 0
        best_score = current_score

Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 01]
precision@10:0.4172
precision@30:0.2762
precision@50:0.2260
ndcg@10:0.1803
ndcg@30:0.1539
ndcg@50:0.1491
ndcg@all:0.4379


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 02]
precision@10:0.4208
precision@30:0.2789
precision@50:0.2285
ndcg@10:0.1832
ndcg@30:0.1563
ndcg@50:0.1516
ndcg@all:0.4444


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 03]
precision@10:0.4260
precision@30:0.2840
precision@50:0.2413
ndcg@10:0.1843
ndcg@30:0.1591
ndcg@50:0.1623
ndcg@all:0.4542


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 04]
precision@10:0.4345
precision@30:0.3044
precision@50:0.2469
ndcg@10:0.1884
ndcg@30:0.1774
ndcg@50:0.1752
ndcg@all:0.4634


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 05]
precision@10:0.4514
precision@30:0.3081
precision@50:0.2507
ndcg@10:0.2031
ndcg@30:0.1881
ndcg@50:0.1855
ndcg@all:0.4711


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 06]
precision@10:0.4543
precision@30:0.3132
precision@50:0.2551
ndcg@10:0.2169
ndcg@30:0.1995
ndcg@50:0.1964
ndcg@all:0.4791


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 07]
precision@10:0.4574
precision@30:0.3167
precision@50:0.2582
ndcg@10:0.2296
ndcg@30:0.2107
ndcg@50:0.2068
ndcg@all:0.4861


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 08]
precision@10:0.4654
precision@30:0.3213
precision@50:0.2614
ndcg@10:0.2441
ndcg@30:0.2228
ndcg@50:0.2176
ndcg@all:0.4934


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 09]
precision@10:0.4713
precision@30:0.3253
precision@50:0.2641
ndcg@10:0.2680
ndcg@30:0.2403
ndcg@50:0.2332
ndcg@all:0.5049


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 10]
precision@10:0.4764
precision@30:0.3277
precision@50:0.2661
ndcg@10:0.2827
ndcg@30:0.2516
ndcg@50:0.2431
ndcg@all:0.5116


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 11]
precision@10:0.4831
precision@30:0.3305
precision@50:0.2691
ndcg@10:0.2955
ndcg@30:0.2610
ndcg@50:0.2517
ndcg@all:0.5175


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 12]
precision@10:0.4906
precision@30:0.3360
precision@50:0.2719
ndcg@10:0.3061
ndcg@30:0.2695
ndcg@50:0.2589
ndcg@all:0.5224


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 13]
precision@10:0.4944
precision@30:0.3387
precision@50:0.2738
ndcg@10:0.3166
ndcg@30:0.2779
ndcg@50:0.2663
ndcg@all:0.5273


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 14]
precision@10:0.5034
precision@30:0.3420
precision@50:0.2757
ndcg@10:0.3258
ndcg@30:0.2846
ndcg@50:0.2721
ndcg@all:0.5312


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 15]
precision@10:0.5029
precision@30:0.3395
precision@50:0.2741
ndcg@10:0.3342
ndcg@30:0.2908
ndcg@50:0.2776
ndcg@all:0.5347


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 16]
precision@10:0.5051
precision@30:0.3405
precision@50:0.2737
ndcg@10:0.3696
ndcg@30:0.3153
ndcg@50:0.2986
ndcg@all:0.5528


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 17]
precision@10:0.5074
precision@30:0.3416
precision@50:0.2751
ndcg@10:0.3801
ndcg@30:0.3224
ndcg@50:0.3049
ndcg@all:0.5576


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 18]
precision@10:0.5097
precision@30:0.3446
precision@50:0.2781
ndcg@10:0.3871
ndcg@30:0.3276
ndcg@50:0.3096
ndcg@all:0.5611


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 19]
precision@10:0.5125
precision@30:0.3434
precision@50:0.2762
ndcg@10:0.3956
ndcg@30:0.3331
ndcg@50:0.3142
ndcg@all:0.5644


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 20]
precision@10:0.5080
precision@30:0.3408
precision@50:0.2740
ndcg@10:0.4008
ndcg@30:0.3371
ndcg@50:0.3173
ndcg@all:0.5669


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 21]
precision@10:0.5128
precision@30:0.3417
precision@50:0.2747
ndcg@10:0.4055
ndcg@30:0.3397
ndcg@50:0.3200
ndcg@all:0.5688


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 22]
precision@10:0.5128
precision@30:0.3432
precision@50:0.2755
ndcg@10:0.4073
ndcg@30:0.3414
ndcg@50:0.3212
ndcg@all:0.5698


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 23]
precision@10:0.5146
precision@30:0.3431
precision@50:0.2753
ndcg@10:0.4109
ndcg@30:0.3438
ndcg@50:0.3235
ndcg@all:0.5713


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 24]
precision@10:0.5188
precision@30:0.3440
precision@50:0.2760
ndcg@10:0.4131
ndcg@30:0.3449
ndcg@50:0.3244
ndcg@all:0.5723


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 25]
precision@10:0.5130
precision@30:0.3430
precision@50:0.2758
ndcg@10:0.4139
ndcg@30:0.3456
ndcg@50:0.3252
ndcg@all:0.5726


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 26]
precision@10:0.5149
precision@30:0.3432
precision@50:0.2762
ndcg@10:0.4156
ndcg@30:0.3471
ndcg@50:0.3267
ndcg@all:0.5737


Training:   0%|          | 0/235 [00:00<?, ?it/s]

[Epoch 27]
precision@10:0.5153
precision@30:0.3456
precision@50:0.2772
ndcg@10:0.4156
ndcg@30:0.3474
ndcg@50:0.3267
ndcg@all:0.5739
Early stopping...


In [19]:
print("Testing...")
res = evaluate_Decoder(decoder, test_loader)
for key,val in res.items():
    print(f"{key}:{val:.4f}")

Testing...
precision@10:0.5181
precision@30:0.3473
precision@50:0.2792
ndcg@10:0.4156
ndcg@30:0.3470
ndcg@50:0.3266
ndcg@all:0.5746
