In [1]:
from collections import Counter, defaultdict
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
import itertools
import math
from pathlib import Path
from typing import List, Tuple, Union

from jupyterthemes import jtplot
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils as utils

from fastai.basic_data import DataBunch
from fastai.basic_train import Learner
from fastai.train import lr_find, fit_one_cycle
from fastprogress import progress_bar

from utils import download_wikitext, prepare_data, load_glove
from nn_toolkit.vocab import Vocab, VocabEncoder

jtplot.style()
%load_ext autoreload
%matplotlib notebook
%autoreload 2

### Helpers

#### Cooccurence Matrix

In [2]:
class Cooccurence:
    def __init__(self, vocab: Vocab, window: int, distance_mode: str, num_workers: int) -> None:
        self.vocab = vocab
        self.window = window
        self.distance_mode = distance_mode
        self.counts = defaultdict(lambda: defaultdict(float))
        self.num_workers = num_workers

    def dict_to_df(self):
        data = []
        for token1, token_counts in self.counts.items():
            for token2, count in token_counts.items():
                data.append({'token1': token1, 'token2': token2, 'count': count})
        del self.counts
        df = pd.DataFrame(data)
        index = pd.MultiIndex.from_frame(df[['token1', 'token2']])
        return pd.DataFrame(df['count'].values, index=index, columns=['count'])

    def update(self, documents: List[List[str]]) -> None:
        with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
            futures = [executor.submit(self.add_doc, doc) for doc in documents]
            pbar = progress_bar(concurrent.futures.as_completed(futures), total=len(documents))
            for f in pbar:
                f.result()

    def add_doc(self, doc: List[str]) -> None:
        idx = 0
        while idx < len(doc):
            left = self.get_left_context(doc, idx)
            right = self.get_right_context(doc, idx)
            self.add_context(left, doc[idx], 'left')
            self.add_context(right, doc[idx], 'right')
            idx += 1
            assert len(self.counts) <= self.vocab.size

    def add_context(self, context: List[str], target: str, mode: str) -> None:
        if target not in self.vocab: target = '<unk>'
        for i, token in enumerate(context):
            if token not in self.vocab: token = '<unk>'
            d = (i + 1) if mode == 'right' else (self.window - i)
            self.counts[target][token] += self.distance_weight(d)
        return

    def get_right_context(self, doc: List[str], idx: int) -> List[str]:
        right_edge = min(len(doc), idx+self.window+1)
        right = doc[idx+1: right_edge]
        return right

    def get_left_context(self, doc: List[str], idx: int) -> List[str]:
        left_edge = max(0, idx-self.window)
        left = doc[left_edge: idx]
        return left

    def distance_weight(self, d: int):
        if self.distance_mode == 'inverse':
            return 1. / d
        return 1.

    def _get_from_str(self, key: str):
        return self.counts[key]

    def _get_from_int(self, key: int):
        key = self.vocab.int_to_token[key]
        return self._get_from_str(key)

    def __getitem__(self, key: Union[str, int]):
        if isinstance(key, int):
            return self._get_from_int(key)
        elif isinstance(key, str):
            return self._get_from_str(key)
        else:
            t = type(key)
            raise TypeError(f"Can not look up using type {t}")

#### Dataset

In [3]:
class CooDataset(utils.data.Dataset):
    def __init__(self, coo: Cooccurence) -> None:
        self.coo = coo
        self.Xij = coo.dict_to_df()
        self.vocab = coo.vocab
        
        self._prep_df()
        
    def _prep_df(self):
        token1 = pd.Series(self.Xij.index.get_level_values('token1'))
        token2 = pd.Series(self.Xij.index.get_level_values('token2'))
        self.i = [self.vocab[t] for t in token1]
        self.j = [self.vocab[t] for t in token2]
        self.counts = self.Xij['count'].tolist()

    def __getitem__(self, idx: int):
        return self.i[idx], self.j[idx], self.counts[idx]

    def __len__(self):
        return self.Xij.shape[0]

def collate_batch(batch) -> Tuple[torch.Tensor]:
    N = len(batch)
    i = torch.empty(N, dtype=torch.int64)
    j = torch.empty(N, dtype=torch.int64)
    Xij = torch.empty(N, dtype=torch.float32)
    for k, sample in enumerate(batch):
        i[k], j[k], Xij[k] = sample
    return (i, j), Xij

#### Model

In [4]:
class Glove(nn.Module):
    def __init__(self, vocab: Vocab, embedding_dim: int) -> None:
        super().__init__()
        self.vocab = vocab
        self.embedding_dim = embedding_dim
        lim = embedding_dim
        self.W = nn.Parameter(self._W_init())
        self.Wt = nn.Parameter(self._W_init())
        self.b = nn.Parameter(self._b_init())
        self.bt = nn.Parameter(self._b_init())

    def forward(self, i: torch.LongTensor, j: torch.LongTensor) -> torch.Tensor:
        wi = F.embedding(i, self.W, padding_idx=self.vocab.pad_index)  # (N, e)
        bi = F.embedding(i, self.b, padding_idx=self.vocab.pad_index).squeeze()  # (N,)
        wj = F.embedding(j, self.Wt, padding_idx=self.vocab.pad_index)  # (N, e)
        bj = F.embedding(j, self.bt, padding_idx=self.vocab.pad_index).squeeze()  # (N,)

        dot = (wi * wj).sum(1)  # (N,)
        p = dot + bi + bj
        return p

    def _W_init(self) -> torch.Tensor:
        size = [self.vocab.size, self.embedding_dim]
        W = torch.empty(*size).uniform_() - 0.5
        W = W / math.sqrt(self.embedding_dim)
        return W

    def _b_init(self) -> torch.Tensor:
        size = [self.vocab.size, 1]
        b = torch.empty(*size).uniform_() - 0.5
        b = b / math.sqrt(self.embedding_dim)
        return b

    def __getitem__(self, token: str):
        idx = self.vocab[token]
        return self.W[idx]

    def similarity(self, token1: str, token2: str) -> float:
        with torch.no_grad():
            vec1 = self[token1]
            vec2 = self[token2]
            norm = torch.norm(vec1) * torch.norm(vec2)
            dot = (vec1 * vec2).sum() / norm
        return dot.item()
    
    def analogy(self, token1: str, token2: str, token3: str, k: int = 5) -> str:
        """Token1 - Token2 + Token3"""
        with torch.no_grad():
            vec1 = self[token1]
            vec2 = self[token2]
            vec3 = self[token3]
            diff = vec1 - vec2 + vec3
            diff_norm = torch.norm(diff)
            w_norm = torch.norm(self.W, dim=1, keepdim=True)
            similarity  = self.W @ diff_norm / (w_norm * diff_norm)
            topk = torch.topk(similarity, k)
            idx = topk.indices.tolist()[0:k]
            vals = topk.values.tolist()[0:k]
        words = [self.vocab.get(i, reverse=True) for i in idx]
        return list(zip(words, vals))

    def most_similar(self, token: str, k: int = 10) -> List[Tuple[str, float]]:
        with torch.no_grad():
            vec = self[token].unsqueeze(0)  # (1, e)
            dot = (vec @ self.W.t()).squeeze()  # (N, )
            norm = torch.norm(vec, dim=1) * torch.norm(self.W, dim=1)
            norm = torch.clamp(norm, 1e-9, float('inf'))
            similarity = dot / norm
            topk = torch.topk(similarity, k)
        idxs = topk.indices.tolist()[0:k]
        vals = topk.values.tolist()[0:k]
        words = [self.vocab.get(i, reverse=True) for i in idxs]
        return list(zip(words, vals))
    
    def to_text(self) -> str:
        """Represent model as glove txt format."""
        lines = []
        for token in self.vocab:
            vector = self[token].tolist()
            vec_str = ' '.join(map(str, vector))
            line = f'{token} {vec_str}'
            lines.append(line)
        return '\n'.join(lines)
    
    def save_as_text(self, path: Union[Path, str]) -> None:
        with open(path, 'w') as fw:
            fw.write(self.to_text())
        return

#### Loss

In [5]:
class GloveLoss(nn.modules.loss._Loss):
    """
    Weighted MSE summed, not averaged, over the vocabulary.
    
    Defaults are set to values in the paper.
    https://github.com/stanfordnlp/GloVe/blob/master/src/glove.c#L55
    https://nlp.stanford.edu/pubs/glove.pdf
    
    Parameters
    ----------
    xmax : int
        threshold to apply maximum weighting in the loss function, by default 100
    alpha : float
        power in weighting function, by default 0.75
    """

    def __init__(self, xmax: int = 100, alpha: float = 0.75) -> None:
        super().__init__()
        self.xmax = xmax
        self.alpha = alpha
        
    def forward(self, y_hat, Xij) -> torch.Tensor:
        w = self.f(Xij)
        mse = self.mse(y_hat, Xij)
        assert mse.size(0) == y_hat.size(0)
        return (w * mse).sum()
        
    def f(self, Xij:torch.LongTensor) -> torch.FloatTensor:
        w = Xij / self.xmax
        return torch.clamp(w, 0., 1.).pow(self.alpha)
    
    def mse(self, y_hat: torch.FloatTensor, Xij: torch.LongTensor) -> torch.FloatTensor:
        Xij = Xij.to(torch.float32)
        return F.mse_loss(y_hat, torch.log(Xij + 1.), reduction='none')

### Run the Training

#### Load data

In [6]:
seed = 0
text_file = Path('data/raw/wikitext-2-v1.zip')
download_wikitext(text_file)

train_sents = prepare_data(text_file, mode='train', sampling_rate=1., seed=seed)
val_sents = prepare_data(text_file, mode='valid', sampling_rate=1., seed=seed)

HBox(children=(IntProgress(value=0, max=78107), HTML(value='')))




HBox(children=(IntProgress(value=0, max=8196), HTML(value='')))




In [7]:
train_tokens = Counter([t for sent in train_sents for t in sent])
max_vocab = 60000
vocab = Vocab([t for t, _ in train_tokens.most_common(max_vocab)])
print('Vocab size:', vocab.size)

Vocab size: 28914


##### Cooccurence Counts

In [8]:
window = 15
num_workers = 6
d_mode = 'inverse'

train_coo = Cooccurence(vocab=vocab, window=window, distance_mode=d_mode, num_workers=num_workers)
val_coo = Cooccurence(vocab=vocab, window=window, distance_mode=d_mode, num_workers=num_workers)
train_coo.update(train_sents)
val_coo.update(val_sents)

##### Init Dataset

In [9]:
train_ds = CooDataset(train_coo)
val_ds = CooDataset(val_coo)

In [10]:
bs = 128
pm = False

train_dl = utils.data.DataLoader(
    train_ds, batch_size=bs, shuffle=True, num_workers=num_workers, pin_memory=pm
)
val_dl = utils.data.DataLoader(
    val_ds, batch_size=bs, shuffle=False, num_workers=num_workers, pin_memory=pm
)

data = DataBunch(
    train_dl=train_dl, valid_dl=val_dl,
    collate_fn=collate_batch,
    device='cuda:0'
)

print('If distance mode is not inverse, the following values are the number of words.')
print('Total weight in train:', train_ds.Xij['count'].sum())
print('Total weight in val:', val_ds.Xij['count'].sum())

If distance mode is not inverse, the following values are the number of words.
Total weight in train: 9359889.410661587
Total weight in val: 973996.0327478057


#### Init model

In [11]:
model = Glove(vocab, 300)
model.to(data.device)
learner = Learner(
    data,
    model,
    loss_func=GloveLoss(xmax=100, alpha=0.75),
    opt_func=optim.AdamW
)

### Training

#### Training loop

In [None]:
# glove code hyperparams
# lr = 0.05
# num_epochs = 25
# glove code uses lr of 0.05 
# https://github.com/stanfordnlp/GloVe/blob/master/src/glove.c
maxlr = 0.05
minlr = 5e-5
cycles = 2
epochs = 25

learner.fit_one_cycle(cycles, max_lr=maxlr)
learner.recorder.plot_lr()
learner.fit(epochs, lr=minlr)

epoch,train_loss,valid_loss,time
0,3.660849,4.832432,06:16


epoch,train_loss,valid_loss,time
0,5.297681,4.003604,06:14
1,3.477157,3.811065,06:12
2,2.780805,3.686398,06:16
3,2.172549,3.560981,06:14


#### Qualitative Evaluation

In [16]:
print(model.similarity('boy', 'ball'))
print(model.similarity('girl', 'ball'))

0.010557662695646286
0.04428711533546448


In [18]:
model.most_similar('king', 5)

[('king', 0.9999998807907104),
 ('henry', 0.45483022928237915),
 ('william', 0.39781129360198975),
 ('charles', 0.3856470584869385),
 ('john', 0.37961480021476746)]

In [19]:
model.most_similar('queen', 5)

[('queen', 1.0),
 ('holy', 0.3933846652507782),
 ('dominican', 0.3632296025753021),
 ('lack', 0.35511961579322815),
 ('ultimately', 0.35221707820892334)]

In [20]:
model.most_similar('girl', 5)

[('girl', 0.9999998807907104),
 ('ellington', 0.3397264778614044),
 ('pdr', 0.30239516496658325),
 ('boy', 0.29481378197669983),
 ('woman', 0.28253433108329773)]

In [21]:
model.most_similar('boy', 5)

[('boy', 1.0),
 ('child', 0.3289439380168915),
 ('cold', 0.3127919137477875),
 ('lizard', 0.29986217617988586),
 ('girl', 0.29481378197669983)]

In [28]:
model.analogy('man', 'woman', 'king')

[('king', 0.0034117086324840784),
 ('man', 0.002597410697489977),
 ('filmography', 0.0017281401669606566),
 ('john', 0.0015856641111895442),
 ('lottery', 0.0015129103558138013)]

#### Quantitative Evaluation

In [23]:
from evaluate.evaluate import main

In [24]:
model_name = f'torch_glove.{model.embedding_dim}d.txt'
model.save_as_text(f'models/{model_name}')

In [25]:
main(f'models/{model_name}')

  W_norm = (W.T / d).T


capital-common-countries.txt:
ACCURACY TOP1: 0.00% (0/420)
capital-world.txt:
ACCURACY TOP1: 0.00% (0/758)
currency.txt:
ACCURACY TOP1: 0.00% (0/70)
city-in-state.txt:
ACCURACY TOP1: 0.00% (0/1114)
family.txt:
ACCURACY TOP1: 0.00% (0/342)
gram1-adjective-to-adverb.txt:
ACCURACY TOP1: 0.00% (0/702)
gram2-opposite.txt:
ACCURACY TOP1: 0.00% (0/272)
gram3-comparative.txt:
ACCURACY TOP1: 0.00% (0/1056)
gram4-superlative.txt:
ACCURACY TOP1: 0.00% (0/506)
gram5-present-participle.txt:
ACCURACY TOP1: 0.00% (0/870)
gram6-nationality-adjective.txt:
ACCURACY TOP1: 0.00% (0/1160)
gram7-past-tense.txt:
ACCURACY TOP1: 0.00% (0/1482)
gram8-plural.txt:
ACCURACY TOP1: 0.00% (0/756)
gram9-plural-verbs.txt:
ACCURACY TOP1: 0.00% (0/702)
Questions seen/total: 52.24% (10210/19544)
Semantic accuracy: 0.00%  (0/2704)
Syntactic accuracy: 0.00%  (0/7506)
Total accuracy: 0.00%  (0/10210)


In [26]:
main('data/glove/glove.6B.300d.txt')

capital-common-countries.txt:
ACCURACY TOP1: 0.00% (0/506)
capital-world.txt:
ACCURACY TOP1: 0.00% (0/4524)
currency.txt:
ACCURACY TOP1: 0.00% (0/866)
city-in-state.txt:
ACCURACY TOP1: 0.00% (0/2467)
family.txt:
ACCURACY TOP1: 0.00% (0/506)
gram1-adjective-to-adverb.txt:
ACCURACY TOP1: 0.00% (0/992)
gram2-opposite.txt:
ACCURACY TOP1: 0.00% (0/812)
gram3-comparative.txt:
ACCURACY TOP1: 0.00% (0/1332)
gram4-superlative.txt:
ACCURACY TOP1: 0.00% (0/1122)
gram5-present-participle.txt:
ACCURACY TOP1: 0.00% (0/1056)
gram6-nationality-adjective.txt:
ACCURACY TOP1: 0.00% (0/1599)
gram7-past-tense.txt:
ACCURACY TOP1: 0.00% (0/1560)
gram8-plural.txt:
ACCURACY TOP1: 0.00% (0/1332)
gram9-plural-verbs.txt:
ACCURACY TOP1: 0.00% (0/870)
Questions seen/total: 100.00% (19544/19544)
Semantic accuracy: 0.00%  (0/8869)
Syntactic accuracy: 0.00%  (0/10675)
Total accuracy: 0.00%  (0/19544)


### Pretrained glove

In [None]:
tokens, W = load_glove('data/glove/glove.6B.300d.txt')
glove_vocab = Vocab(tokens, add_specials=False)

In [None]:
glove = Glove(glove_vocab, W.size(1))

In [None]:
glove.most_similar('king')

In [None]:
glove.most_similar('queen')

In [None]:
glove.most_similar('girl')

In [None]:
glove.most_similar('boy')

In [None]:
glove.analogy('king', 'queen', 'man')

In [None]:
torch.norm(glove[''])