## Siamese network 
Steps:
1. load word embeding and document embedding
2. create pytorch dataset and dataloader
3. Try Contrastive loss and triplet loss
4. further improve negative sampling (e.g. hard negative or word2vec negative sampling)

In [1]:
import numpy as np 
import re
import torch
import torch.nn as nn
from itertools import cycle
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from tqdm import tqdm

In [3]:
# load word embedding
word2embedding = dict()
embedding_file = "../data/glove.6B.100d.txt"
word_dim = int(re.findall(r".(\d+)d",embedding_file)[0])
with open(embedding_file,"r") as f:
    for line in f:
        line = line.strip().split()
        word = line[0]
        embedding = list(map(float,line[1:]))
        word2embedding[word] = embedding

print("Number of words:%d" % len(word2embedding))

Number of words:400000


In [4]:
class Vocabulary:
    def __init__(self, freq_threshold, word2embedding):
        # The low frequency words will be assigned as <UNK> token
        self.itos = {0: "<UNK>"}
        self.stoi = {"<UNK>": 0}
        self.freq_threshold = freq_threshold

    def __len__(self):
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text):
        text = re.sub(r'[^A-Za-z0-9 ]+', '', text)
        return text.strip().split()

    def build_vocabulary(self, sentence_list):
        self.frequencies = {}
        self.word_vectors = [[0]*word_dim] # init zero padding
        idx = 1
        
        for sentence in tqdm(sentence_list, desc="Preprocessing documents"):
            for word in self.tokenizer_eng(sentence):
                if word not in word2embedding:
                    continue
                if word not in self.frequencies:
                    self.frequencies[word] = 1

                else:
                    self.frequencies[word] += 1

                if self.frequencies[word] == self.freq_threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    self.word_vectors.append(word2embedding[word])
                    idx += 1

    def numericalize(self, text):
        tokenized_text = self.tokenizer_eng(text)

        return [
            self.stoi[token] if token in self.stoi else self.stoi["<UNK>"]
            for token in tokenized_text
        ]

In [5]:
class CBowDataset(Dataset):
    def __init__(self, 
                 data_file_path,
                 word2embedding,
                 freq_threshold=20,
                 skip_header = False,
                 max_length = None,
                 ):
        # read data
        self.document_vectors = []
        docEmb_file = open("../data/docvector.txt","r")
        with open(data_file_path,'r',encoding='utf-8') as f:
            if skip_header:
                f.readline()
            self.documents = []
            for line in tqdm(f, desc="Loading documents"):
                if max_length is not None and len(self.documents) >= max_length:
                    break
                self.documents.append(line.strip("\n"))
                doc_vec = docEmb_file.readline().strip().split()
                doc_vec = list(map(float, doc_vec))
                self.document_vectors.append(doc_vec)
        
        # build vocabulary
        self.vocab = Vocabulary(freq_threshold,word2embedding)
        self.vocab.build_vocabulary(self.documents)
        self.vocab_size = len(self.vocab)
        self.words_tokenized = [self.vocab.numericalize(text) for text in self.documents]
        
        # train-test split
        # training
        self.train_length = int(len(self.words_tokenized)*0.8)
        self.train_vectors = self.document_vectors[:self.train_length]
        self.train_words = self.words_tokenized[:self.train_length]
        self.document_ids = list(range(self.train_length))
        self.generator = cycle(self.context_target_generator())
        self.dataset_size = sum([len(s) for s in self.train_words])
        
        # testing
        self.test_vectors = self.document_vectors[self.train_length:]
        self.test_words = self.words_tokenized[self.train_length:]



    def context_target_generator(self):
        np.random.shuffle(self.document_ids) # inplace shuffle

        # randomly select a document and create its training example
        for document_id in self.document_ids: 
            word_list = set(self.train_words[document_id])
            negative_sample_space = list(set(range(self.vocab_size)) - word_list)
            negative_samples = np.random.choice(negative_sample_space,size=len(word_list),replace = False)
            for word_id, negative_wordID in zip(word_list, negative_samples):
                yield [document_id, word_id, negative_wordID]
                
    def __getitem__(self, idx):
        doc_id, word_id, negative_wordID = next(self.generator)
        doc_id = torch.FloatTensor(self.document_vectors[doc_id])
        word_id = torch.FloatTensor(self.vocab.word_vectors[word_id])
        negative_word = torch.FloatTensor(self.vocab.word_vectors[negative_wordID])

        return doc_id, word_id, negative_word

    def __len__(self):
        return self.dataset_size 


In [6]:
# load and build torch dataset
data_file_path = '../data/IMDB.txt'
# checkpoint_path = "doc2vecC_lr0.001.pt"
print("Building dataset....")
dataset = CBowDataset(
                    data_file_path=data_file_path,
                    word2embedding=word2embedding,
                    max_length=None,
                    freq_threshold=20,
                    skip_header=False
                    )
print("Finish building dataset!")
print(f"Number of documents:{len(dataset.documents)}")
print(f"Number of words:{dataset.vocab_size}")

Building dataset....


Loading documents: 100000it [00:04, 23543.61it/s]
Preprocessing documents: 100%|██████████| 100000/100000 [00:11<00:00, 8660.21it/s]


Finish building dataset!
Number of documents:100000
Number of words:27961


In [7]:
class TestDataset(Dataset):
    def __init__(self, 
                 doc_vectors,
                 ans_words,
                 ):
        self.doc_vectors = doc_vectors
        self.ans_words = ans_words
        assert len(doc_vectors) == len(ans_words)
        
    def __getitem__(self, idx):
        doc_vec = torch.FloatTensor(self.doc_vectors[idx])
        ans_w = torch.tensor(list(set(self.ans_words[idx])))
        return doc_vec, ans_w

    def collate_fn(self,batch):
        # Batch: List of tuples [(batch1), (batch2)]
        
        doc_vec = torch.cat([item[0].unsqueeze(0) for item in batch], dim=0)
        ans_w = [item[1] for item in batch]
        ans_w = pad_sequence(ans_w, batch_first=True, padding_value=-1)
        
        return doc_vec, ans_w 

    def __len__(self):
        return len(self.doc_vectors)


In [8]:
class TripletNet(nn.Module):
    def __init__(self, hdim):
        super(TripletNet, self).__init__()
        self.fc = nn.Sequential(nn.Linear(hdim, 256),
                        nn.PReLU(),
                        nn.Linear(256, 256),
                        nn.PReLU(),
                        nn.Linear(256, 2)
                        )


    def forward(self, x1, x2, x3):
        output1 = self.fc(x1)
        output2 = self.fc(x2)
        output3 = self.fc(x3)
        return output1, output2, output3

    def get_embedding(self, x):
        return self.fc(x)

In [9]:
class TripletLoss(nn.Module):
    """
    Triplet loss
    Takes embeddings of an anchor sample, a positive sample and a negative sample
    """

    def __init__(self, margin):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative, size_average=True):
        distance_positive = (anchor - positive).pow(2).sum(1)  # .pow(.5)
        distance_negative = (anchor - negative).pow(2).sum(1)  # .pow(.5)
        losses = F.relu(distance_positive - distance_negative + self.margin)
        return losses.mean() if size_average else losses.sum()

In [10]:
margin = 1.
BATCH_SIZE = 1024
EPOCH = 300

device = "cuda:0"
model = TripletNet(word_dim).to(device)
loss_fn = TripletLoss(margin).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr = 0.01)

In [11]:
train_loader = DataLoader(
                        dataset, 
                        batch_size=BATCH_SIZE,
                        num_workers=4,
                        shuffle=True,
                        )

In [12]:
test_docvec = dataset.test_vectors
test_ans = dataset.test_words
test_dataset = TestDataset(test_docvec,test_ans)
test_loader = DataLoader(test_dataset,                         
                         batch_size=BATCH_SIZE,
                         num_workers=4,
                         collate_fn=test_dataset.collate_fn)
word_embedding_tensor = torch.FloatTensor(dataset.vocab.word_vectors).to(device)

In [13]:
def evaluate(test_word_emb, loader,Ks = [50,100,150,200]):
    avg_precision, avg_recall = [], []
    for batch in test_loader:
        batch = [item.to(device) for item in batch]
        emb, ans = batch
        emb = model.get_embedding(emb)
        scores = torch.cdist(emb, test_word_emb)
        ans_length = torch.sum((~ans.eq(-1)).float(), dim=-1)
        mask = ~ans.eq(-1).unsqueeze(-1)
        
        # calculate precision and recall
        tmp_pr, tmp_re = [],[]
        for K in Ks:
            top_indices = torch.argsort(scores,dim=1)[:,:K]
            hit = top_indices.unsqueeze(-2) == ans.unsqueeze(-1)
            hit = torch.sum((hit * mask).flatten(1),dim=-1)
            precision = hit / K
            recall = hit / ans_length
            tmp_pr.append(precision)
            tmp_re.append(recall)
        tmp_pr = torch.stack(tmp_pr).T.detach().cpu().numpy().tolist()
        tmp_re = torch.stack(tmp_re).T.detach().cpu().numpy().tolist()
        avg_precision.extend(tmp_pr)
        avg_recall.extend(tmp_re)
        
    avg_precision = np.mean(avg_precision,axis=0)
    avg_recall = np.mean(avg_recall, axis=0)
    for idx, kval in enumerate(Ks):
        print(f"[K={kval}] Precision:{avg_precision[idx]:.4f} Recall:{avg_recall[idx]:.4f}")
    return avg_precision, avg_recall

In [14]:
for epoch in range(EPOCH):
    avg_loss = []
    model.train()
    for batch in tqdm(train_loader):
        batch = [item.to(device) for item in batch]
        doc_id,pos_w,neg_w = batch
        optimizer.zero_grad()
        loss = loss_fn(*model(doc_id,pos_w,neg_w))
        loss.backward()
        optimizer.step()
        avg_loss.append(loss.item())
    avg_loss = np.mean(avg_loss)
    print(f"Loss:{avg_loss:4f}")
    
    # evaluate
    model.eval()
    test_word_emb = model.get_embedding(word_embedding_tensor)
    res = evaluate(test_word_emb,test_loader)

100%|██████████| 18094/18094 [04:59<00:00, 60.37it/s]

Loss:0.193487





[K=50] Precision:0.4412 Recall:0.1846
[K=100] Precision:0.3930 Recall:0.3249
[K=150] Precision:0.3234 Recall:0.3952
[K=200] Precision:0.2779 Recall:0.4505


100%|██████████| 18094/18094 [05:07<00:00, 58.89it/s]

Loss:0.190775





[K=50] Precision:0.5442 Recall:0.2313
[K=100] Precision:0.4059 Recall:0.3371
[K=150] Precision:0.3312 Recall:0.4072
[K=200] Precision:0.2820 Recall:0.4582


100%|██████████| 18094/18094 [05:12<00:00, 57.86it/s]

Loss:0.198980





[K=50] Precision:0.5568 Recall:0.2391
[K=100] Precision:0.4091 Recall:0.3401
[K=150] Precision:0.3350 Recall:0.4122
[K=200] Precision:0.2812 Recall:0.4570


100%|██████████| 18094/18094 [05:16<00:00, 57.14it/s]

Loss:0.189145





[K=50] Precision:0.4947 Recall:0.2106
[K=100] Precision:0.4031 Recall:0.3345
[K=150] Precision:0.3298 Recall:0.4048
[K=200] Precision:0.2796 Recall:0.4535


100%|██████████| 18094/18094 [05:10<00:00, 58.18it/s]

Loss:0.186144





[K=50] Precision:0.4765 Recall:0.2002
[K=100] Precision:0.3932 Recall:0.3254
[K=150] Precision:0.3262 Recall:0.4002
[K=200] Precision:0.2780 Recall:0.4507


100%|██████████| 18094/18094 [05:15<00:00, 57.31it/s]

Loss:0.183112





[K=50] Precision:0.4795 Recall:0.2028
[K=100] Precision:0.3989 Recall:0.3305
[K=150] Precision:0.3278 Recall:0.4021
[K=200] Precision:0.2773 Recall:0.4496


100%|██████████| 18094/18094 [05:21<00:00, 56.20it/s]

Loss:0.201365





[K=50] Precision:0.4838 Recall:0.2043
[K=100] Precision:0.4069 Recall:0.3370
[K=150] Precision:0.3306 Recall:0.4052
[K=200] Precision:0.2789 Recall:0.4524


100%|██████████| 18094/18094 [05:08<00:00, 58.70it/s]

Loss:0.184679





[K=50] Precision:0.4944 Recall:0.2100
[K=100] Precision:0.4086 Recall:0.3384
[K=150] Precision:0.3318 Recall:0.4068
[K=200] Precision:0.2810 Recall:0.4556


100%|██████████| 18094/18094 [05:13<00:00, 57.72it/s]

Loss:0.182497





[K=50] Precision:0.2530 Recall:0.1011
[K=100] Precision:0.2743 Recall:0.2246
[K=150] Precision:0.2564 Recall:0.3139
[K=200] Precision:0.2266 Recall:0.3661


100%|██████████| 18094/18094 [05:11<00:00, 58.12it/s]

Loss:0.183641





[K=50] Precision:0.5073 Recall:0.2151
[K=100] Precision:0.4080 Recall:0.3382
[K=150] Precision:0.3320 Recall:0.4073
[K=200] Precision:0.2805 Recall:0.4547


100%|██████████| 18094/18094 [05:13<00:00, 57.68it/s]

Loss:0.180014





[K=50] Precision:0.4766 Recall:0.2023
[K=100] Precision:0.3989 Recall:0.3300
[K=150] Precision:0.3287 Recall:0.4023
[K=200] Precision:0.2785 Recall:0.4510


100%|██████████| 18094/18094 [05:11<00:00, 58.11it/s]

Loss:0.179828





[K=50] Precision:0.4689 Recall:0.1979
[K=100] Precision:0.3998 Recall:0.3304
[K=150] Precision:0.3275 Recall:0.4007
[K=200] Precision:0.2780 Recall:0.4501


100%|██████████| 18094/18094 [05:10<00:00, 58.36it/s]

Loss:0.179244





[K=50] Precision:0.4828 Recall:0.2030
[K=100] Precision:0.4060 Recall:0.3358
[K=150] Precision:0.3310 Recall:0.4061
[K=200] Precision:0.2790 Recall:0.4526


100%|██████████| 18094/18094 [05:09<00:00, 58.52it/s]

Loss:0.204652





[K=50] Precision:0.4602 Recall:0.1939
[K=100] Precision:0.3943 Recall:0.3272
[K=150] Precision:0.3247 Recall:0.3983
[K=200] Precision:0.2771 Recall:0.4498


100%|██████████| 18094/18094 [05:09<00:00, 58.48it/s]

Loss:0.206862





[K=50] Precision:0.4628 Recall:0.1963
[K=100] Precision:0.3946 Recall:0.3278
[K=150] Precision:0.3259 Recall:0.3989
[K=200] Precision:0.2771 Recall:0.4489


100%|██████████| 18094/18094 [05:12<00:00, 57.85it/s]

Loss:0.184887





[K=50] Precision:0.4212 Recall:0.1789
[K=100] Precision:0.3771 Recall:0.3138
[K=150] Precision:0.3206 Recall:0.3926
[K=200] Precision:0.2722 Recall:0.4407


100%|██████████| 18094/18094 [05:07<00:00, 58.80it/s]

Loss:0.181621





[K=50] Precision:0.4819 Recall:0.2055
[K=100] Precision:0.4013 Recall:0.3324
[K=150] Precision:0.3308 Recall:0.4045
[K=200] Precision:0.2786 Recall:0.4506


100%|██████████| 18094/18094 [05:09<00:00, 58.56it/s]

Loss:0.204388





[K=50] Precision:0.4155 Recall:0.1742
[K=100] Precision:0.3766 Recall:0.3126
[K=150] Precision:0.3185 Recall:0.3901
[K=200] Precision:0.2713 Recall:0.4394


100%|██████████| 18094/18094 [05:10<00:00, 58.29it/s]

Loss:0.185851





[K=50] Precision:0.4662 Recall:0.1987
[K=100] Precision:0.3971 Recall:0.3291
[K=150] Precision:0.3314 Recall:0.4063
[K=200] Precision:0.2799 Recall:0.4539


100%|██████████| 18094/18094 [05:33<00:00, 54.26it/s]

Loss:0.186749





[K=50] Precision:0.4996 Recall:0.2135
[K=100] Precision:0.4060 Recall:0.3361
[K=150] Precision:0.3311 Recall:0.4050
[K=200] Precision:0.2802 Recall:0.4545


100%|██████████| 18094/18094 [05:29<00:00, 54.91it/s]

Loss:0.186609





[K=50] Precision:0.4299 Recall:0.1825
[K=100] Precision:0.3872 Recall:0.3209
[K=150] Precision:0.3285 Recall:0.4023
[K=200] Precision:0.2775 Recall:0.4502


100%|██████████| 18094/18094 [05:22<00:00, 56.07it/s]

Loss:0.193406





[K=50] Precision:0.4096 Recall:0.1750
[K=100] Precision:0.3844 Recall:0.3188
[K=150] Precision:0.3274 Recall:0.4009
[K=200] Precision:0.2785 Recall:0.4514


100%|██████████| 18094/18094 [06:01<00:00, 50.05it/s]

Loss:0.183956





[K=50] Precision:0.4530 Recall:0.1921
[K=100] Precision:0.3896 Recall:0.3226
[K=150] Precision:0.3280 Recall:0.4016
[K=200] Precision:0.2787 Recall:0.4516


100%|██████████| 18094/18094 [05:08<00:00, 58.62it/s]

Loss:0.186776





[K=50] Precision:0.2994 Recall:0.1240
[K=100] Precision:0.3028 Recall:0.2510
[K=150] Precision:0.2809 Recall:0.3459
[K=200] Precision:0.2541 Recall:0.4128


100%|██████████| 18094/18094 [05:08<00:00, 58.64it/s]

Loss:0.184965





[K=50] Precision:0.4866 Recall:0.2064
[K=100] Precision:0.3962 Recall:0.3287
[K=150] Precision:0.3303 Recall:0.4055
[K=200] Precision:0.2798 Recall:0.4538


100%|██████████| 18094/18094 [05:14<00:00, 57.58it/s]

Loss:0.215963





[K=50] Precision:0.4543 Recall:0.1906
[K=100] Precision:0.3946 Recall:0.3268
[K=150] Precision:0.3283 Recall:0.4028
[K=200] Precision:0.2795 Recall:0.4530


100%|██████████| 18094/18094 [05:11<00:00, 58.12it/s]

Loss:0.190888





[K=50] Precision:0.4801 Recall:0.2039
[K=100] Precision:0.3925 Recall:0.3275
[K=150] Precision:0.3267 Recall:0.4020
[K=200] Precision:0.2801 Recall:0.4543


100%|██████████| 18094/18094 [10:39<00:00, 28.31it/s]

Loss:0.181396





[K=50] Precision:0.4809 Recall:0.2028
[K=100] Precision:0.3991 Recall:0.3297
[K=150] Precision:0.3281 Recall:0.4016
[K=200] Precision:0.2788 Recall:0.4518


100%|██████████| 18094/18094 [21:43<00:00, 13.89it/s]

Loss:0.183599





[K=50] Precision:0.5044 Recall:0.2140
[K=100] Precision:0.4059 Recall:0.3358
[K=150] Precision:0.3310 Recall:0.4057
[K=200] Precision:0.2801 Recall:0.4538


100%|██████████| 18094/18094 [21:26<00:00, 14.06it/s]

Loss:0.185029





[K=50] Precision:0.4743 Recall:0.1996
[K=100] Precision:0.4060 Recall:0.3363
[K=150] Precision:0.3309 Recall:0.4058
[K=200] Precision:0.2804 Recall:0.4547


100%|██████████| 18094/18094 [21:44<00:00, 13.87it/s]

Loss:0.191136





[K=50] Precision:0.4380 Recall:0.1848
[K=100] Precision:0.3929 Recall:0.3267
[K=150] Precision:0.3290 Recall:0.4040
[K=200] Precision:0.2801 Recall:0.4545


100%|██████████| 18094/18094 [21:30<00:00, 14.02it/s]

Loss:0.185452





[K=50] Precision:0.5094 Recall:0.2181
[K=100] Precision:0.4045 Recall:0.3354
[K=150] Precision:0.3318 Recall:0.4074
[K=200] Precision:0.2776 Recall:0.4505


 80%|███████▉  | 14443/18094 [15:32<03:55, 15.49it/s]


KeyboardInterrupt: 