## Siamese network 
Steps:
1. load word embeding and document embedding
2. create pytorch dataset and dataloader
3. Try Contrastive loss and triplet loss
4. further improve negative sampling (e.g. hard negative or word2vec negative sampling)

In [1]:
import numpy as np 
import re
import torch
import torch.nn as nn
from itertools import cycle
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from tqdm import tqdm

In [2]:
# load word embedding
word2embedding = dict()
embedding_file = "../data/glove.6B.100d.txt"
word_dim = int(re.findall(r".(\d+)d",embedding_file)[0])
with open(embedding_file,"r") as f:
    for line in f:
        line = line.strip().split()
        word = line[0]
        embedding = list(map(float,line[1:]))
        word2embedding[word] = embedding

print("Number of words:%d" % len(word2embedding))

Number of words:400000


In [3]:
class Vocabulary:
    def __init__(self, freq_threshold, word2embedding):
        # The low frequency words will be assigned as <UNK> token
        self.itos = {0: "<UNK>"}
        self.stoi = {"<UNK>": 0}
        self.freq_threshold = freq_threshold

    def __len__(self):
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text):
        text = re.sub(r'[^A-Za-z0-9 ]+', '', text)
        return text.strip().split()

    def build_vocabulary(self, sentence_list):
        self.frequencies = {}
        self.word_vectors = [[0]*word_dim] # init zero padding
        idx = 1
        
        for sentence in tqdm(sentence_list, desc="Preprocessing documents"):
            for word in self.tokenizer_eng(sentence):
                if word not in word2embedding:
                    continue
                if word not in self.frequencies:
                    self.frequencies[word] = 1

                else:
                    self.frequencies[word] += 1

                if self.frequencies[word] == self.freq_threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    self.word_vectors.append(word2embedding[word])
                    idx += 1

    def numericalize(self, text):
        tokenized_text = self.tokenizer_eng(text)

        return [
            self.stoi[token] if token in self.stoi else self.stoi["<UNK>"]
            for token in tokenized_text
        ]

In [4]:
class CBowDataset(Dataset):
    def __init__(self, 
                 data_file_path,
                 word2embedding,
                 freq_threshold=20,
                 skip_header = False,
                 max_length = None,
                 ):
        # read data
        self.document_vectors = []
        docEmb_file = open("../data/docvector.txt","r")
        with open(data_file_path,'r',encoding='utf-8') as f:
            if skip_header:
                f.readline()
            self.documents = []
            for line in tqdm(f, desc="Loading documents"):
                if max_length is not None and len(self.documents) >= max_length:
                    break
                self.documents.append(line.strip("\n"))
                doc_vec = docEmb_file.readline().strip().split()
                doc_vec = list(map(float, doc_vec))
                self.document_vectors.append(doc_vec)
        
        # build vocabulary
        self.vocab = Vocabulary(freq_threshold,word2embedding)
        self.vocab.build_vocabulary(self.documents)
        self.vocab_size = len(self.vocab)
        self.words_tokenized = [self.vocab.numericalize(text) for text in self.documents]
        
        # train-test split
        # training
        self.train_length = int(len(self.words_tokenized)*0.8)
        self.train_vectors = self.document_vectors[:self.train_length]
        self.train_words = self.words_tokenized[:self.train_length]
        self.document_ids = list(range(self.train_length))
        self.generator = cycle(self.context_target_generator())
        self.dataset_size = sum([len(s) for s in self.train_words])
        
        # testing
        self.test_vectors = self.document_vectors[self.train_length:]
        self.test_words = self.words_tokenized[self.train_length:]



    def context_target_generator(self):
        np.random.shuffle(self.document_ids) # inplace shuffle

        # randomly select a document and create its training example
        for document_id in self.document_ids: 
            word_list = set(self.train_words[document_id])
            negative_sample_space = list(set(range(self.vocab_size)) - word_list)
            negative_samples = np.random.choice(negative_sample_space,size=len(word_list),replace = False)
            for word_id, negative_wordID in zip(word_list, negative_samples):
                yield [document_id, word_id, negative_wordID]
                
    def __getitem__(self, idx):
        doc_id, word_id, negative_wordID = next(self.generator)
        doc_id = torch.FloatTensor(self.document_vectors[doc_id])
        word_id = torch.FloatTensor(self.vocab.word_vectors[word_id])
        negative_word = torch.FloatTensor(self.vocab.word_vectors[negative_wordID])

        return doc_id, word_id, negative_word

    def __len__(self):
        return self.dataset_size 


In [5]:
# load and build torch dataset
data_file_path = '../data/IMDB.txt'
# checkpoint_path = "doc2vecC_lr0.001.pt"
print("Building dataset....")
dataset = CBowDataset(
                    data_file_path=data_file_path,
                    word2embedding=word2embedding,
                    max_length=None,
                    freq_threshold=20,
                    skip_header=False
                    )
print("Finish building dataset!")
print(f"Number of documents:{len(dataset.documents)}")
print(f"Number of words:{dataset.vocab_size}")

Loading documents: 2519it [00:00, 25182.58it/s]

Building dataset....


Loading documents: 100000it [00:04, 24255.79it/s]
Preprocessing documents: 100%|██████████| 100000/100000 [00:10<00:00, 9279.95it/s]


Finish building dataset!
Number of documents:100000
Number of words:27961


In [6]:
class TestDataset(Dataset):
    def __init__(self, 
                 doc_vectors,
                 ans_words,
                 ):
        self.doc_vectors = doc_vectors
        self.ans_words = ans_words
        assert len(doc_vectors) == len(ans_words)
        
    def __getitem__(self, idx):
        doc_vec = torch.FloatTensor(self.doc_vectors[idx])
        ans_w = torch.tensor(list(set(self.ans_words[idx])))
        return doc_vec, ans_w

    def collate_fn(self,batch):
        # Batch: List of tuples [(batch1), (batch2)]
        
        doc_vec = torch.cat([item[0].unsqueeze(0) for item in batch], dim=0)
        ans_w = [item[1] for item in batch]
        ans_w = pad_sequence(ans_w, batch_first=True, padding_value=-1)
        
        return doc_vec, ans_w 

    def __len__(self):
        return len(self.doc_vectors)


In [7]:
class TripletNet(nn.Module):
    def __init__(self, hdim):
        super(TripletNet, self).__init__()
        self.fc = nn.Sequential(nn.Linear(hdim, 512),
                        nn.PReLU(),
                        nn.Linear(512, 512),
                        nn.PReLU(),
                        nn.Linear(512, 2)
                        )

    def forward(self, x1, x2, x3):
        output1 = self.fc(x1)
        output2 = self.fc(x2)
        output3 = self.fc(x3)
        return output1, output2, output3

    def get_embedding(self, x):
        return self.fc(x)

In [8]:
class TripletLoss(nn.Module):
    """
    Triplet loss
    Takes embeddings of an anchor sample, a positive sample and a negative sample
    """

    def __init__(self, margin):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative, size_average=True):
        distance_positive = (anchor - positive).pow(2).sum(1)  # .pow(.5)
        distance_negative = (anchor - negative).pow(2).sum(1)  # .pow(.5)
        losses = F.relu(distance_positive - distance_negative + self.margin)
        return losses.mean() if size_average else losses.sum()

In [9]:
margin = 1.
BATCH_SIZE = 2048
EPOCH = 300

device = "cuda:0"
model = TripletNet(word_dim).to(device)
loss_fn = TripletLoss(margin).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr = 0.005)

In [10]:
train_loader = DataLoader(
                        dataset, 
                        batch_size=BATCH_SIZE,
                        num_workers=4,
                        shuffle=True,
                        )

In [11]:
test_docvec = dataset.test_vectors
test_ans = dataset.test_words
test_dataset = TestDataset(test_docvec,test_ans)
test_loader = DataLoader(test_dataset,                         
                         batch_size=BATCH_SIZE,
                         num_workers=4,
                         collate_fn=test_dataset.collate_fn)
word_embedding_tensor = torch.FloatTensor(dataset.vocab.word_vectors).to(device)

In [12]:
def evaluate(test_word_emb, loader,Ks = [50,100,150,200]):
    avg_precision, avg_recall = [], []
    for batch in test_loader:
        batch = [item.to(device) for item in batch]
        emb, ans = batch
        emb = model.get_embedding(emb)
        scores = torch.cdist(emb, test_word_emb)
        ans_length = torch.sum((~ans.eq(-1)).float(), dim=-1)
        mask = ~ans.eq(-1).unsqueeze(-1)
        
        # calculate precision and recall
        tmp_pr, tmp_re = [],[]
        for K in Ks:
            top_indices = torch.argsort(scores,dim=1)[:,:K]
            hit = top_indices.unsqueeze(-2) == ans.unsqueeze(-1)
            hit = torch.sum((hit * mask).flatten(1),dim=-1)
            precision = hit / torch.min(ans_length,torch.tensor(K).to(device))
            recall = hit / ans_length
            tmp_pr.append(precision)
            tmp_re.append(recall)
        tmp_pr = torch.stack(tmp_pr).T.detach().cpu().numpy().tolist()
        tmp_re = torch.stack(tmp_re).T.detach().cpu().numpy().tolist()
        avg_precision.extend(tmp_pr)
        avg_recall.extend(tmp_re)
        
    avg_precision = np.mean(avg_precision,axis=0)
    avg_recall = np.mean(avg_recall, axis=0)
    for idx, kval in enumerate(Ks):
        print(f"[K={kval}] Precision:{avg_precision[idx]:.4f} Recall:{avg_recall[idx]:.4f}")
    return avg_precision, avg_recall

In [13]:
for epoch in range(EPOCH):
    avg_loss = []
    model.train()
    for batch in tqdm(train_loader):
        batch = [item.to(device) for item in batch]
        doc_id,pos_w,neg_w = batch
        optimizer.zero_grad()
        loss = loss_fn(*model(doc_id,pos_w,neg_w))
        loss.backward()
        optimizer.step()
        avg_loss.append(loss.item())
    avg_loss = np.mean(avg_loss)
    print(f"Loss:{avg_loss:4f}")
    
    # evaluate
    model.eval()
    test_word_emb = model.get_embedding(word_embedding_tensor)
    res = evaluate(test_word_emb,test_loader)

100%|██████████| 9047/9047 [06:30<00:00, 23.18it/s]

Loss:0.189930



  0%|          | 0/9047 [00:00<?, ?it/s]

[K=50] Precision:0.2319 Recall:0.0955
[K=100] Precision:0.2493 Recall:0.1869
[K=150] Precision:0.2997 Recall:0.2711
[K=200] Precision:0.3607 Recall:0.3471


100%|██████████| 9047/9047 [05:54<00:00, 25.51it/s]

Loss:0.183412



  0%|          | 0/9047 [00:00<?, ?it/s]

[K=50] Precision:0.5651 Recall:0.2413
[K=100] Precision:0.4574 Recall:0.3459
[K=150] Precision:0.4553 Recall:0.4120
[K=200] Precision:0.4752 Recall:0.4571


100%|██████████| 9047/9047 [05:50<00:00, 25.79it/s]

Loss:0.183480



  0%|          | 0/9047 [00:00<?, ?it/s]

[K=50] Precision:0.4457 Recall:0.1866
[K=100] Precision:0.4238 Recall:0.3201
[K=150] Precision:0.4264 Recall:0.3857
[K=200] Precision:0.4631 Recall:0.4457


100%|██████████| 9047/9047 [05:43<00:00, 26.37it/s]

Loss:0.181225



  0%|          | 0/9047 [00:00<?, ?it/s]

[K=50] Precision:0.5474 Recall:0.2327
[K=100] Precision:0.4530 Recall:0.3427
[K=150] Precision:0.4463 Recall:0.4041
[K=200] Precision:0.4720 Recall:0.4542


100%|██████████| 9047/9047 [05:54<00:00, 25.54it/s]

Loss:0.180449



  0%|          | 0/9047 [00:00<?, ?it/s]

[K=50] Precision:0.5637 Recall:0.2388
[K=100] Precision:0.4643 Recall:0.3513
[K=150] Precision:0.4527 Recall:0.4098
[K=200] Precision:0.4772 Recall:0.4591


100%|██████████| 9047/9047 [05:49<00:00, 25.88it/s]

Loss:0.179355



  0%|          | 0/9047 [00:00<?, ?it/s]

[K=50] Precision:0.5574 Recall:0.2360
[K=100] Precision:0.4609 Recall:0.3481
[K=150] Precision:0.4540 Recall:0.4107
[K=200] Precision:0.4768 Recall:0.4586


100%|██████████| 9047/9047 [05:50<00:00, 25.83it/s]

Loss:0.178145



  0%|          | 0/9047 [00:00<?, ?it/s]

[K=50] Precision:0.5553 Recall:0.2353
[K=100] Precision:0.4622 Recall:0.3494
[K=150] Precision:0.4546 Recall:0.4112
[K=200] Precision:0.4780 Recall:0.4598


 64%|██████▍   | 5801/9047 [03:49<01:43, 31.45it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 9047/9047 [05:45<00:00, 26.19it/s]

Loss:0.179775



  0%|          | 0/9047 [00:00<?, ?it/s]

[K=50] Precision:0.4964 Recall:0.2097
[K=100] Precision:0.4410 Recall:0.3324
[K=150] Precision:0.4446 Recall:0.4016
[K=200] Precision:0.4677 Recall:0.4496


 49%|████▉     | 4448/9047 [03:01<02:24, 31.76it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 9047/9047 [05:52<00:00, 25.69it/s]

Loss:0.181444



  0%|          | 0/9047 [00:00<?, ?it/s]

[K=50] Precision:0.5004 Recall:0.2123
[K=100] Precision:0.4438 Recall:0.3340
[K=150] Precision:0.4441 Recall:0.4011
[K=200] Precision:0.4642 Recall:0.4462


 35%|███▍      | 3163/9047 [02:14<03:33, 27.60it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 9047/9047 [06:35<00:00, 22.88it/s]

Loss:0.191374



  0%|          | 0/9047 [00:00<?, ?it/s]

[K=50] Precision:0.4923 Recall:0.2078
[K=100] Precision:0.4399 Recall:0.3320
[K=150] Precision:0.4461 Recall:0.4031
[K=200] Precision:0.4680 Recall:0.4500


  8%|▊         | 733/9047 [00:34<19:02,  7.28it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 9047/9047 [05:45<00:00, 26.15it/s]

Loss:0.180400



  0%|          | 0/9047 [00:00<?, ?it/s]

[K=50] Precision:0.4945 Recall:0.2089
[K=100] Precision:0.4427 Recall:0.3341
[K=150] Precision:0.4469 Recall:0.4039
[K=200] Precision:0.4694 Recall:0.4514


 24%|██▍       | 2149/9047 [01:32<03:41, 31.15it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 9047/9047 [05:52<00:00, 25.67it/s]

Loss:0.176195



  0%|          | 0/9047 [00:00<?, ?it/s]

[K=50] Precision:0.5005 Recall:0.2124
[K=100] Precision:0.4423 Recall:0.3338
[K=150] Precision:0.4441 Recall:0.4014
[K=200] Precision:0.4658 Recall:0.4478


 93%|█████████▎| 8377/9047 [05:20<00:20, 32.36it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 9047/9047 [05:53<00:00, 25.59it/s]

Loss:0.177499



  0%|          | 0/9047 [00:00<?, ?it/s]

[K=50] Precision:0.4865 Recall:0.2053
[K=100] Precision:0.4404 Recall:0.3323
[K=150] Precision:0.4457 Recall:0.4028
[K=200] Precision:0.4672 Recall:0.4491


  7%|▋         | 649/9047 [00:30<04:30, 31.09it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 9047/9047 [05:46<00:00, 26.11it/s]

Loss:0.174651



  0%|          | 0/9047 [00:00<?, ?it/s]

[K=50] Precision:0.4610 Recall:0.1942
[K=100] Precision:0.4300 Recall:0.3242
[K=150] Precision:0.4384 Recall:0.3959
[K=200] Precision:0.4638 Recall:0.4459


 53%|█████▎    | 4831/9047 [03:13<02:12, 31.90it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

 65%|██████▍   | 5850/9047 [04:01<01:37, 32.86it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 9047/9047 [05:50<00:00, 25.78it/s]

Loss:0.178876



  0%|          | 0/9047 [00:00<?, ?it/s]

[K=50] Precision:0.5039 Recall:0.2137
[K=100] Precision:0.4412 Recall:0.3330
[K=150] Precision:0.4461 Recall:0.4032
[K=200] Precision:0.4676 Recall:0.4496


 54%|█████▍    | 4877/9047 [03:17<02:10, 32.04it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

 28%|██▊       | 2538/9047 [01:46<03:29, 31.14it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 9047/9047 [05:43<00:00, 26.35it/s]

Loss:0.174499



  0%|          | 0/9047 [00:00<?, ?it/s]

[K=50] Precision:0.4882 Recall:0.2060
[K=100] Precision:0.4379 Recall:0.3303
[K=150] Precision:0.4433 Recall:0.4007
[K=200] Precision:0.4646 Recall:0.4467


 20%|█▉        | 1788/9047 [01:17<03:55, 30.77it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

 90%|████████▉ | 8125/9047 [05:15<00:28, 31.80it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 9047/9047 [05:54<00:00, 25.55it/s]

Loss:0.174004



  0%|          | 0/9047 [00:00<?, ?it/s]

[K=50] Precision:0.4940 Recall:0.2081
[K=100] Precision:0.4464 Recall:0.3366
[K=150] Precision:0.4485 Recall:0.4053
[K=200] Precision:0.4693 Recall:0.4512


 81%|████████  | 7332/9047 [04:53<00:54, 31.24it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

 53%|█████▎    | 4818/9047 [03:15<02:20, 30.18it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 9047/9047 [05:54<00:00, 25.51it/s]

Loss:0.176838



  0%|          | 0/9047 [00:00<?, ?it/s]

[K=50] Precision:0.4598 Recall:0.1928
[K=100] Precision:0.4204 Recall:0.3168
[K=150] Precision:0.4358 Recall:0.3939
[K=200] Precision:0.4595 Recall:0.4419


 44%|████▍     | 4014/9047 [02:44<11:37,  7.22it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 9047/9047 [06:14<00:00, 24.13it/s]

Loss:0.172757



  0%|          | 0/9047 [00:00<?, ?it/s]

[K=50] Precision:0.5016 Recall:0.2117
[K=100] Precision:0.4400 Recall:0.3318
[K=150] Precision:0.4417 Recall:0.3992
[K=200] Precision:0.4629 Recall:0.4451


 17%|█▋        | 1518/9047 [01:06<03:52, 32.42it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 9047/9047 [05:47<00:00, 26.01it/s]

Loss:0.177057



  0%|          | 0/9047 [00:00<?, ?it/s]

[K=50] Precision:0.4522 Recall:0.1890
[K=100] Precision:0.4299 Recall:0.3236
[K=150] Precision:0.4368 Recall:0.3946
[K=200] Precision:0.4587 Recall:0.4410


  9%|▊         | 777/9047 [00:36<04:20, 31.73it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 9047/9047 [05:48<00:00, 25.98it/s]

Loss:0.172568



  0%|          | 0/9047 [00:00<?, ?it/s]

[K=50] Precision:0.5073 Recall:0.2142
[K=100] Precision:0.4441 Recall:0.3347
[K=150] Precision:0.4473 Recall:0.4042
[K=200] Precision:0.4679 Recall:0.4499


100%|██████████| 9047/9047 [05:52<00:00, 25.69it/s]

Loss:0.173663



  0%|          | 0/9047 [00:00<?, ?it/s]

[K=50] Precision:0.5093 Recall:0.2150
[K=100] Precision:0.4440 Recall:0.3346
[K=150] Precision:0.4478 Recall:0.4047
[K=200] Precision:0.4674 Recall:0.4495


100%|██████████| 9047/9047 [05:47<00:00, 26.06it/s]

Loss:0.183238



  0%|          | 0/9047 [00:00<?, ?it/s]

[K=50] Precision:0.4967 Recall:0.2094
[K=100] Precision:0.4427 Recall:0.3337
[K=150] Precision:0.4470 Recall:0.4040
[K=200] Precision:0.4675 Recall:0.4495


100%|██████████| 9047/9047 [05:52<00:00, 25.69it/s]

Loss:0.175671



  0%|          | 0/9047 [00:00<?, ?it/s]

[K=50] Precision:0.5021 Recall:0.2117
[K=100] Precision:0.4461 Recall:0.3365
[K=150] Precision:0.4479 Recall:0.4048
[K=200] Precision:0.4699 Recall:0.4519


100%|██████████| 9047/9047 [05:52<00:00, 25.65it/s]

Loss:0.172573



  0%|          | 0/9047 [00:00<?, ?it/s]

[K=50] Precision:0.4983 Recall:0.2099
[K=100] Precision:0.4381 Recall:0.3303
[K=150] Precision:0.4428 Recall:0.4001
[K=200] Precision:0.4645 Recall:0.4466


100%|██████████| 9047/9047 [05:51<00:00, 25.76it/s]

Loss:0.185931



  0%|          | 0/9047 [00:00<?, ?it/s]

[K=50] Precision:0.3978 Recall:0.1644
[K=100] Precision:0.4009 Recall:0.3017
[K=150] Precision:0.4268 Recall:0.3856
[K=200] Precision:0.4454 Recall:0.4282


100%|██████████| 9047/9047 [05:52<00:00, 25.68it/s]

Loss:0.178193



  0%|          | 0/9047 [00:00<?, ?it/s]

[K=50] Precision:0.5029 Recall:0.2125
[K=100] Precision:0.4418 Recall:0.3331
[K=150] Precision:0.4452 Recall:0.4024
[K=200] Precision:0.4675 Recall:0.4496


100%|██████████| 9047/9047 [05:54<00:00, 25.52it/s]

Loss:0.175348



  0%|          | 0/9047 [00:00<?, ?it/s]

[K=50] Precision:0.5006 Recall:0.2106
[K=100] Precision:0.4384 Recall:0.3300
[K=150] Precision:0.4453 Recall:0.4023
[K=200] Precision:0.4677 Recall:0.4497


100%|██████████| 9047/9047 [05:58<00:00, 25.26it/s]

Loss:0.172986



  0%|          | 0/9047 [00:00<?, ?it/s]

[K=50] Precision:0.3665 Recall:0.1531
[K=100] Precision:0.3538 Recall:0.2677
[K=150] Precision:0.3849 Recall:0.3490
[K=200] Precision:0.4163 Recall:0.4009


 95%|█████████▌| 8637/9047 [05:37<00:16, 25.56it/s]


KeyboardInterrupt: 

In [14]:
validTrain_docvec = dataset.train_vectors
validTrain_ans = dataset.train_words
validTrain_dataset = TestDataset(validTrain_docvec,validTrain_ans)
validTrain_loader = DataLoader(validTrain_dataset,                         
                         batch_size=BATCH_SIZE,
                         num_workers=4,
                         collate_fn=validTrain_dataset.collate_fn)

In [15]:
# evaluate
model.eval()
test_word_emb = model.get_embedding(word_embedding_tensor)
res = evaluate(test_word_emb,validTrain_loader)

[K=50] Precision:0.4818 Recall:0.2027
[K=100] Precision:0.4392 Recall:0.3311
[K=150] Precision:0.4447 Recall:0.4017
[K=200] Precision:0.4674 Recall:0.4494


In [17]:
!nvidia-htop.py --color

Mon Aug 23 06:11:38 2021
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.45.01    Driver Version: 455.45.01    CUDA Version: 11.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
[31m|   0  GeForce RTX 3090    On   | 00000000:1D:00.0 Off |                  N/A |[0m
[31m| 76%   62C    P2   333W / 350W |  21435MiB / 24268MiB |    100%      Default |[0m
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
[31m|   1  GeForce RTX 3090    On   | 00000000:B4:00.0 Off |                  N/A |[0m
[31m| 58%   60C    P2   331W / 350W |  14178MiB / 24268MiB |    100