## Siamese network 
Steps:
1. Load the word embeddings and document embeddings.
2. Create a PyTorch dataset and dataloader.
3. Try contrastive loss and triplet loss.
4. Further improve negative sampling (e.g. hard negatives or word2vec-style negative sampling).

In [78]:
import numpy as np 
import re
import torch
import torch.nn as nn
from itertools import cycle
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm

In [92]:
# Load the pre-trained word embeddings from disk.
embedding_file = "../../word/glove.6B.100d.txt"
# The vector dimensionality is parsed from the file name ("...100d.txt" -> 100).
word_dim = int(re.findall(r".(\d+)d",embedding_file)[0])

# Every lookup structure is (re)created here so that the cell runs on a fresh
# kernel and re-running it does not append duplicate rows.
word2embedding = dict()      # word -> embedding (list of floats)
word_embedding_matrix = []   # row i holds the embedding of the word with index i
stoi = dict()                # word -> row index
itos = dict()                # row index -> word
index = 0                    # NOTE(review): rows start at 0 — confirm no padding row is expected

# Explicit UTF-8: the result should not depend on the platform's default encoding.
with open(embedding_file, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip().split()
        if not line:
            # Skip blank lines instead of failing on line[0].
            continue
        word = line[0]
        embedding = list(map(float,line[1:]))
        word2embedding[word] = embedding
        word_embedding_matrix.append(embedding)
        stoi[word] = index
        itos[index] = word
        index += 1

print("Number of words:%d" % len(word2embedding))

Number of words:400000


In [93]:
# Embedding dimensionality parsed from the embedding file name.
word_dim

100

In [63]:
class Vocabulary:
    """Word <-> id mapping restricted to frequent words that have an embedding.

    Id 0 is reserved for the ``<UNK>`` token; its row in ``word_vectors`` is an
    all-zero vector. Words are admitted in the order in which their running
    count first reaches ``freq_threshold``.

    Args:
        freq_threshold: number of occurrences a word needs before it is added.
            Values below 1 are treated as 1 (every word with an embedding is kept).
        word2embedding: mapping ``word -> embedding vector``. Words missing from
            this mapping are ignored when building the vocabulary.
    """

    def __init__(self, freq_threshold, word2embedding):
        self.freq_threshold = freq_threshold
        # Keep the mapping on the instance so the class does not depend on
        # notebook-level globals.
        self.word2embedding = word2embedding
        self._reset()

    def _reset(self):
        """Restore the empty state (only ``<UNK>`` and its zero vector)."""
        # The low frequency words will be assigned as <UNK> token
        self.itos = {0: "<UNK>"}
        self.stoi = {"<UNK>": 0}
        self.frequencies = {}
        self.word_vectors = [[0.0] * self._embedding_dim()]

    def _embedding_dim(self):
        """Return the length of the embedding vectors (0 for an empty mapping)."""
        first_vector = next(iter(self.word2embedding.values()), None)
        return 0 if first_vector is None else len(first_vector)

    def __len__(self):
        """Number of ids in the vocabulary, including ``<UNK>``."""
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text):
        """Strip everything except ASCII letters, digits and spaces, then split on whitespace.

        NOTE(review): tokens are not lower-cased; if the embedding file only
        contains lower-case words, capitalised tokens are dropped — confirm intended.
        """
        text = re.sub(r'[^A-Za-z0-9 ]+', '', text)
        return text.strip().split()

    def build_vocabulary(self, sentence_list):
        """Count words over ``sentence_list`` and register the frequent ones.

        Populates ``frequencies``, ``stoi``, ``itos`` and ``word_vectors``.
        Any previously built state is discarded first, so calling this method
        again (e.g. when a notebook cell is re-run) yields a consistent mapping.
        """
        self._reset()
        # A count can never equal a threshold below 1, so clamp it.
        threshold = max(1, self.freq_threshold)

        for sentence in tqdm(sentence_list, desc="Preprocessing documents"):
            for word in self.tokenizer_eng(sentence):
                if word not in self.word2embedding:
                    continue
                count = self.frequencies.get(word, 0) + 1
                self.frequencies[word] = count

                # Register the word exactly once: when it reaches the threshold.
                if count == threshold:
                    idx = len(self.itos)
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    self.word_vectors.append(self.word2embedding[word])

    def numericalize(self, text):
        """Convert ``text`` into a list of ids; unknown tokens map to ``<UNK>`` (0)."""
        unk_id = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk_id) for token in self.tokenizer_eng(text)]

In [105]:
class CBowDataset(Dataset):
    """Endless stream of ``(document id, word id, negative word id)`` triples.

    Documents are read line by line from ``data_file_path``; the matching
    document vectors are read line by line from ``doc_vector_path``. Every
    unique word id of a document is paired with a word id that does not occur
    in that document.

    Args:
        data_file_path: text file with one document per line (UTF-8).
        word2embedding: mapping ``word -> embedding``, forwarded to ``Vocabulary``.
        freq_threshold: minimum word frequency, forwarded to ``Vocabulary``.
        skip_header: skip the first line of ``data_file_path`` (the vector file
            is not skipped).
        max_length: maximum number of documents to load; ``None`` loads all.
        doc_vector_path: text file with one whitespace-separated vector per line.

    NOTE(review): ``__getitem__`` ignores its index and pulls from one shared
    generator, so sampler shuffling does not reorder samples and each
    ``DataLoader`` worker iterates its own copy of the stream — confirm this is
    acceptable when ``num_workers > 0``.
    """

    def __init__(self, 
                 data_file_path,
                 word2embedding,
                 freq_threshold=20,
                 skip_header = False,
                 max_length = None,
                 doc_vector_path = "../data/docvector.txt",
                 ):
        # read data; both files are closed when the block exits
        self.documents = []
        self.document_vectors = []
        with open(data_file_path,'r',encoding='utf-8') as f, \
                open(doc_vector_path,"r") as docEmb_file:
            if skip_header:
                f.readline()
            for line in tqdm(f, desc="Loading documents"):
                if max_length is not None and len(self.documents) >= max_length:
                    break
                self.documents.append(line.strip("\n"))
                # The i-th vector line belongs to the i-th document line.
                doc_vec = docEmb_file.readline().strip().split()
                self.document_vectors.append(list(map(float, doc_vec)))
        
        #build vocabulary
        self.vocab = Vocabulary(freq_threshold,word2embedding)
        self.vocab.build_vocabulary(self.documents)
        self.vocab_size = len(self.vocab)
        self.words_tokenized = [self.vocab.numericalize(text) for text in self.documents]
        self.document_ids = list(range(len(self.words_tokenized)))
        self.generator = self._endless_examples()

    def _endless_examples(self):
        """Yield triples forever, starting a freshly sampled pass when one ends.

        Unlike ``itertools.cycle`` this neither stores every triple in memory
        nor replays the first pass: documents are reshuffled and negatives are
        re-drawn on every pass.
        """
        while True:
            produced = False
            for example in self.context_target_generator():
                produced = True
                yield example
            if not produced:
                # Avoid spinning forever when no document can produce a triple.
                raise RuntimeError("CBowDataset has no training examples to yield")

    def context_target_generator(self):
        """Yield ``[document_id, word_id, negative_word_id]`` for one pass over all documents."""
        np.random.shuffle(self.document_ids) # inplace shuffle
        all_word_ids = set(range(self.vocab_size))  # loop-invariant, built once per pass

        # visit the documents in random order and create their training examples
        for document_id in self.document_ids: 
            word_list = set(self.words_tokenized[document_id])
            if not word_list:
                continue
            negative_sample_space = list(all_word_ids - word_list)
            if not negative_sample_space:
                # The document contains every word id: no negative exists.
                continue
            # Sample without replacement when possible; fall back to sampling
            # with replacement when the document covers more than half of the
            # vocabulary (otherwise np.random.choice raises ValueError).
            negative_samples = np.random.choice(
                negative_sample_space,
                size=len(word_list),
                replace=len(negative_sample_space) < len(word_list),
            )
            for word_id, negative_wordID in zip(word_list, negative_samples):
                yield [document_id, word_id, negative_wordID]
                
    def __getitem__(self, idx):
        """Return the next triple from the stream as tensors; ``idx`` is ignored."""
        doc_id, word_id, negative_wordID = next(self.generator)
        doc_id = torch.tensor(doc_id)
        word_id = torch.tensor(word_id)
        negative_word = torch.tensor(negative_wordID)

        return doc_id, word_id, negative_word

    def __len__(self):
        """Nominal number of samples per epoch (the underlying stream is endless)."""
        return 2**20 


In [106]:
# Build the torch dataset from the IMDB text file.
data_file_path = '../data/IMDB.txt'
# checkpoint_path = "doc2vecC_lr0.001.pt"
print("Building dataset....")
dataset = CBowDataset(
    data_file_path=data_file_path,
    word2embedding=word2embedding,
    freq_threshold=20,
    max_length=None,
    skip_header=False,
)
print("Finish building dataset!")
# Report the corpus size and the vocabulary size (including <UNK>).
print(
    f"Number of documents:{len(dataset.documents)}",
    f"Number of words:{dataset.vocab_size}",
    sep="\n",
)

Building dataset....


Loading documents: 0it [00:00, ?it/s]

Preprocessing documents:   0%|          | 0/100000 [00:00<?, ?it/s]

Finish building dataset!
Number of documents:100000
Number of words:27961


In [107]:
# Number of (document, word, negative word) triples per mini-batch.
BATCH_SIZE = 16

In [108]:
# Wrap the dataset in a DataLoader that batches triples using 4 worker processes.
# NOTE(review): CBowDataset.__getitem__ ignores the requested index, so
# shuffle=True does not change which samples are returned — confirm it is needed.
data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

In [111]:
# Sanity check: fetch a single batch, show its three id tensors, then stop.
for batch in data_loader:
    doc_id, pos_w, neg_w = batch
    print(doc_id, pos_w, neg_w)
    break

tensor([70663, 70663, 70663, 70663, 70663, 70663, 70663, 70663, 70663, 70663,
        70663, 70663, 70663, 70663, 70663, 70663]) tensor([   0,    1,    3,    4,    5,    6,    8,    9,   10,   11,   12,  526,
         527, 7183,   18,   19]) tensor([ 9370, 15830, 10243,  1979,  8193, 21789, 13163, 26057, 12130,  1440,
        17503, 21588, 12861, 11286,  2515, 11441])
