In [1]:
import itertools
from collections import OrderedDict 
import re
import nltk
from nltk.corpus import brown, gutenberg
from nltk.probability import FreqDist
from nltk.corpus import stopwords

# corpus

In [2]:
gutenberg.fileids()

['austen-emma.txt',
 'austen-persuasion.txt',
 'austen-sense.txt',
 'bible-kjv.txt',
 'blake-poems.txt',
 'bryant-stories.txt',
 'burgess-busterbrown.txt',
 'carroll-alice.txt',
 'chesterton-ball.txt',
 'chesterton-brown.txt',
 'chesterton-thursday.txt',
 'edgeworth-parents.txt',
 'melville-moby_dick.txt',
 'milton-paradise.txt',
 'shakespeare-caesar.txt',
 'shakespeare-hamlet.txt',
 'shakespeare-macbeth.txt',
 'whitman-leaves.txt']

### corpus preprocessing

In [4]:
samples  =gutenberg.sents(gutenberg.fileids()[0])
pattern = re.compile("[A-Za-z]+")
stop_w =  set(stopwords.words('english'))
corpus = []
for sent in samples:
    sent = [w.lower() for w in sent]
    sent = [w for w in sent if w not in stop_w]
    sent = [w.replace('\n', ' ') for w in sent]
    sent = [w for w in sent if pattern.fullmatch(w)]
    if len(sent) > 5:
        corpus.append(sent)

In [5]:
fre_dist = FreqDist()
for sent in corpus:
    fre_dist.update(sent)
fre_dist = {k : v for k, v in fre_dist.items() if v > 5}

In [6]:
vocab_size = len(fre_dist)
idx_to_word = {idx: word for idx,  word in enumerate(fre_dist.keys())}
word_to_idx = {word: idx for idx, word in idx_to_word.items()}


### convert word to index 

In [7]:
corpus_indexed = [[word_to_idx[word] for word in sent if word in word_to_idx]for sent in corpus]
corpus_indexed = [sent for sent in corpus_indexed if len(sent) > 5]
fre_dist_indexed = {word_to_idx[w]: f for w, f in fre_dist.items()}

## Huffman Tree

In [9]:
import numpy as np

In [10]:
class HuffmanNode:
    def __init__(self, is_leaf, value=None, fre=0, left=None, right=None):
        self.is_leaf = is_leaf
        self.value = value
        self.fre = fre
        self.code = []
        self.code_len = 0
        self.node_path = []
        self.left = left
        self.right = right

The way builing huffman tree refer to c's original implement

In [11]:
class HuffmanTree:
    def __init__(self, fre_dict):
        self.root = None
        freq_dict = sorted(fre_dict.items(), key=lambda x:x[1], reverse=True)
        self.vocab_size = len(freq_dict)
        self.node_dict = {}
        self._build_tree(freq_dict)
    
    def _build_tree(self, freq_dict):
        # freq_dict is in decent order
        # node_list: two part: [leaf node :: internal node]
        # leaf node is in decent order; 
        node_list = [HuffmanNode(is_leaf=True, value=w, fre=fre) for w, fre in freq_dict]
        node_list += [HuffmanNode(is_leaf=False, fre=1e10) for i in range(self.vocab_size)]

        parentNode = [0] * (self.vocab_size * 2)  # only 2 * vocab_size - 2 be use
        binary = [0] * (self.vocab_size * 2)
        
        pos1 = self.vocab_size - 1
        pos2 = self.vocab_size
        
        # min2i is always larger than min1i
        min1i = 0
        min2i = 0
        for a in range(self.vocab_size - 1):
            if pos1 >= 0:
                if node_list[pos1].fre < node_list[pos2].fre:
                    min1i = pos1
                    pos1 -= 1
                else:
                    min1i = pos2
                    pos2 += 1
            else:
                min1i = pos2
                pos2 += 1
            
            if pos1 >= 0:
                if node_list[pos1].fre < node_list[pos2].fre:
                    min2i = pos1
                    pos1 -= 1
                else:
                    min2i = pos2
                    pos2 += 1
            else:
                min2i = pos2
                pos2 += 1
            
            node_list[self.vocab_size + a].fre = node_list[min1i].fre + node_list[min2i].fre
            node_list[self.vocab_size + a].left = node_list[min1i]
            node_list[self.vocab_size + a].right = node_list[min2i]
            
            parentNode[min1i] = self.vocab_size + a  # max index = 2 * vocab_size - 2
            parentNode[min2i] = self.vocab_size + a
            binary[min2i] = 1
        
        # generate huffman code
        for a in range(self.vocab_size):
            b = a
            i = 0
            code = []
            point = []
            # backtrace node from leaf to root
            while b != self.vocab_size * 2 - 2:   # trace path from current node until root node . 'root node index' = 2 * vocab_size - 2 
                code.append(binary[b])  
                b = parentNode[b]
                # point recording the path index from leaf node to root, the length of point is less one than the length of code
                point.append(b)
            
            
            node_list[a].code_len = len(code)
            node_list[a].code = list(reversed(code))
            
            # recording the index from root to leaf node, the actually index value should be shifted by self.vocab_size
            # In case of full binary tree, the number of non leaf node always is vocab_size - 1, 
            # The root node of BST in node_list is 2 * vocab_size - 2, and we shift vocab_size to get the actual index of root node: vocab_size - 2
            node_list[a].node_path = list(reversed([p - self.vocab_size for p in point]))
            
            self.node_dict[node_list[a].value] = node_list[a]
            
        self.root = node_list[2 * vocab_size - 2]
        
        
        

## CBOW + HS

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import tqdm

### create dataset 

In [14]:
class CBOWDataset(torch.utils.data.Dataset):
    def __init__(self, corpus, windows_size=2, sentence_length_threshold=5):
        self.windows_size = windows_size
        self.sentence_length_threshold = sentence_length_threshold
        self.contexts, self.centers = self._generate_pairs(corpus, windows_size)
        
    def _generate_pairs(self, corpus, windows_size):
        contexts = []
        centers = []
        
        for sent in corpus:
            if len(sent) < self.sentence_length_threshold:
                continue
            
            for center_word_pos in range(len(sent)):
                context = []
                for w in range(-windows_size, windows_size + 1):
                    context_word_pos = center_word_pos + w
                    if(0 <= context_word_pos < len(sent) and context_word_pos != center_word_pos):
                        context.append(sent[context_word_pos])
                if(len(context) == 2 * self.windows_size):
                    contexts.append(context)
                    centers.append(sent[center_word_pos])
        return contexts, centers
    
    def __len__(self):
        return len(self.centers)
    
    def __getitem__(self, index):
        return np.array(self.contexts[index]), np.array([self.centers[index]])

### define network

In [15]:
class HierarchicalSoftmaxLayer(nn.Module):
    def __init__(self, vocab_size, embedding_dim, freq_dict):
        super().__init__()
        ## in w2v c implement, syn1 initial with all zero
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.syn1 = nn.Embedding(
            num_embeddings=vocab_size + 1,
            embedding_dim=embedding_dim,
            padding_idx=vocab_size
            
        )
        torch.nn.init.constant_(self.syn1.weight.data, val=0)
        self.huffman_tree = HuffmanTree(freq_dict)

    def forward(self, neu1, target):
        # neu1: [b_size, embedding_dim]
        # target: [b_size, 1]
        
        # turns:[b_size, max_code_len_in_batch]
        # paths: [b_size, max_code_len_in_batch]
        turns, paths = self._get_turns_and_paths(target)
        paths_emb = self.syn1(paths) # [b_size, max_code_len_in_batch, embedding_dim]

   
        loss = -F.logsigmoid(
            (turns.unsqueeze(2) * paths_emb * neu1.unsqueeze(1)).sum(2)).sum(1).mean()
        return loss



    
    def _get_turns_and_paths(self, target):
        turns = []  # turn right(1) or turn left(-1) in huffman tree
        paths = []
        max_len = 0
        for n in target:
            n = n.item()
            node = self.huffman_tree.node_dict[n]
            
            code = target.new_tensor(node.code).int()  # in code, left node is 0; right node is 1
            turn = torch.where(code == 1, code, -torch.ones_like(code))
            
            turns.append(turn)
            paths.append(target.new_tensor(node.node_path))
            
            if node.code_len > max_len:
                max_len = node.code_len
        
        
        turns = [F.pad(t, pad=(0, max_len - len(t)), mode='constant', value=0) for t in turns] 
        paths = [F.pad(p, pad=(0, max_len - p.shape[0]), mode='constant', value=net.hs.vocab_size) for p in paths]
        return torch.stack(turns).int(), torch.stack(paths).long()
    

In [19]:
class CBOWHierarchicalSoftmax(nn.Module):
    def __init__(self, vocab_size, embedding_dim, fre_dict):
        super().__init__()
        self.syn0 = nn.Embedding(vocab_size, embedding_dim)
        self.hs = HierarchicalSoftmaxLayer(vocab_size, embedding_dim, fre_dict)
        torch.nn.init.xavier_uniform_(self.syn0.weight.data)
    
    def forward(self, context, target):
        # context: [b_size, 2 * window_size]
        # target: [b_size]
        neu1 = self.syn0(context.long()).mean(dim=1)  # [b_size, embedding_dim]
        loss = self.hs(neu1, target.long())
        return loss
    

## training

In [18]:
data_set = CBOWDataset(corpus_indexed)
data_loader = DataLoader(data_set, batch_size=100, num_workers=0)

In [20]:
embedding_dim = 50
net = CBOWHierarchicalSoftmax(vocab_size, embedding_dim, fre_dist_indexed)
optimizer = optim.Adam(net.parameters(), lr=0.001,  weight_decay=1e-6)

In [21]:
log_interval = 100
for epoch_i in range(20):
    total_loss = 0
    net.train()
    tk0 = tqdm.tqdm(data_loader, smoothing=0, mininterval=1.0)
    for i, (context, center) in enumerate(tk0):

        loss = net(context, center)
        net.zero_grad()
        loss.backward()
        
        optimizer.step()

        total_loss += loss.item()
        if(i + 1) % log_interval == 0:
            tk0.set_postfix(loss = total_loss/log_interval)
            total_loss = 0
        

100%|██████████| 361/361 [00:06<00:00, 56.96it/s, loss=8.98]
100%|██████████| 361/361 [00:05<00:00, 69.30it/s, loss=8.86]
100%|██████████| 361/361 [00:05<00:00, 69.50it/s, loss=8.77]
100%|██████████| 361/361 [00:05<00:00, 69.73it/s, loss=8.7]
100%|██████████| 361/361 [00:05<00:00, 70.27it/s, loss=8.63]
100%|██████████| 361/361 [00:05<00:00, 71.61it/s, loss=8.56]
100%|██████████| 361/361 [00:05<00:00, 60.62it/s, loss=8.49]
100%|██████████| 361/361 [00:07<00:00, 48.11it/s, loss=8.42]
100%|██████████| 361/361 [00:07<00:00, 50.28it/s, loss=8.35]
100%|██████████| 361/361 [00:05<00:00, 70.87it/s, loss=8.29]
100%|██████████| 361/361 [00:05<00:00, 65.54it/s, loss=8.22]
100%|██████████| 361/361 [00:06<00:00, 57.14it/s, loss=8.16]
100%|██████████| 361/361 [00:06<00:00, 55.74it/s, loss=8.1]
100%|██████████| 361/361 [00:05<00:00, 67.17it/s, loss=8.04]
100%|██████████| 361/361 [00:05<00:00, 68.77it/s, loss=7.98]
100%|██████████| 361/361 [00:06<00:00, 57.39it/s, loss=7.93]
100%|██████████| 361/361 [