In [1]:
from collections import Counter
import math
from pathlib import Path
from typing import List
import re
import string
import wget
import zipfile

import nltk
from nltk.corpus import stopwords
from nltk.tokenize import sent_tokenize
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils as utils

from fastai.basic_train import Learner
from fastai.train import lr_find
from fastai.basic_data import DataBunch
from fastai.metrics import accuracy, accuracy_thresh
from fastprogress import fastprogress

from nn_toolkit.vocab import Vocab, VocabEncoder

%load_ext autoreload
%autoreload 2
%matplotlib notebook

### Get data

In [2]:
def download_wikitext(local_path: Path) -> None:
    """Download the WikiText-2 zip archive to ``local_path`` if it is missing.

    Nothing happens when ``local_path`` already exists. Otherwise the parent
    directory is created (including intermediate directories) and the archive
    is fetched with ``wget``.

    Args:
        local_path: Destination path of the zip archive.
    """
    if not local_path.exists():
        url = 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip'
        print(f'Downloading {url}')
        print(f'Saving to {local_path}')
        # The first entry of ``parents`` is the directory that will hold the file.
        target_dir = local_path.parents[0]
        target_dir.mkdir(exist_ok=True, parents=True)
        wget.download(url, str(local_path))

def get_token_string(zip_path, mode='train'):
    """Read one split of the WikiText archive and return it as a single string.

    Args:
        zip_path: Path to the zip archive.
        mode: Which split to read; one of ``'train'``, ``'valid'`` or ``'test'``.

    Returns:
        The decoded contents of ``wiki.<mode>.tokens``.

    Raises:
        AssertionError: If ``mode`` is not one of the three supported splits.
        KeyError: If the expected member is not present in the archive.
    """
    assert mode in ('train', 'valid', 'test')
    with zipfile.ZipFile(zip_path) as zfo:
        # The first archive entry is used as the directory holding the splits.
        dirname = zfo.namelist()[0]
        # Zip member names always use '/' regardless of the host OS, so build
        # the name as a string instead of going through an OS-specific Path.
        member = f"{dirname.rstrip('/')}/wiki.{mode}.tokens"
        # Close the member handle explicitly rather than leaving it to the GC.
        with zfo.open(member) as fo:
            return fo.read().decode()
    
def clean_token_string(text: str) -> str:
    """Lower-case the raw corpus, strip heading lines and collapse whitespace.

    Heading lines are those of the form ``\\n = ... = \\n`` (with one or more
    `` =`` markers on each side). Afterwards every run of spaces and newlines
    is replaced by a single space.

    Args:
        text: Raw token string as read from the archive.

    Returns:
        The cleaned, single-line string.
    """
    lowered = text.lower()
    without_headings = re.sub('\n( =){1,}.*?( =){1,} \n', '', lowered)
    single_spaced = re.sub(r'[ \n]{1,}', ' ', without_headings)
    return single_spaced

def tokenize(text: str) -> List[str]:
    """Split ``text`` on runs of whitespace and return the resulting tokens."""
    tokens = text.split()
    return tokens

def remove_stopwords(tokens: List[str], stop_words=None) -> List[str]:
    """Return ``tokens`` with stopwords removed, preserving order.

    Args:
        tokens: Tokens to filter.
        stop_words: Optional iterable of words to drop. When omitted, the NLTK
            English stopword list is used.

    Returns:
        A new list containing only the tokens that are not stopwords.
    """
    if stop_words is None:
        stop_words = stopwords.words('english')
    # Build the lookup set once per call; the previous version rebuilt it
    # (and re-read the NLTK list) for every single token.
    stop_set = frozenset(stop_words)
    return [t for t in tokens if t not in stop_set]

def prepare_data(zip_path, mode):
    """Load one split of the archive and return it as tokenised sentences.

    The split is read, cleaned, split into sentences with NLTK's
    ``sent_tokenize``, tokenised on whitespace and stripped of stopwords.

    Args:
        zip_path: Path to the WikiText zip archive.
        mode: Split name passed through to ``get_token_string``.

    Returns:
        A list of sentences, each a list of token strings.
    """
    token_string = get_token_string(zip_path, mode)
    token_string = clean_token_string(token_string)
    # The original line was missing the closing parenthesis of the
    # remove_stopwords(...) call, which made the whole cell a SyntaxError.
    sents = [remove_stopwords(tokenize(sent)) for sent in sent_tokenize(token_string)]
    return sents

def token_pairs(tokens, window: int):
    """Yield ``(target, context)`` pairs for every token and its neighbours.

    For each position, the context consists of up to ``window`` tokens on the
    left and up to ``window`` tokens on the right of the target.

    Args:
        tokens: Sequence of tokens (must support slicing and ``+``).
        window: Number of neighbours to take on each side of the target.

    Yields:
        Tuples ``(target_token, context_token)``, left context first.
    """
    n_tokens = len(tokens)
    for pos, target_token in enumerate(tokens):
        left_edge = max(0, pos - window)
        # +1 because the slice end is exclusive; without it the right-hand
        # side only received ``window - 1`` tokens.
        right_edge = min(n_tokens, pos + window + 1)
        context = tokens[left_edge:pos] + tokens[pos + 1:right_edge]
        for context_token in context:
            yield target_token, context_token

### Dataset

In [99]:
class Dataset(utils.data.Dataset):
    """Word2vec-style dataset yielding one positive and ``ns`` negative examples per token.

    Sentences are encoded with ``vocab_encoder`` and flattened into a single
    token stream; item ``idx`` is built around ``self.tokens[idx]``. Because
    the stream is flat, a context window can span a sentence boundary.

    Args:
        sents: List of tokenised sentences.
        vocab_encoder: Object providing ``encode_sequence(seq, add_specials=...)``,
            ``size`` and ``pad_index``.
        window: Number of context tokens taken on each side of the target.
        cbow: If True the window is the model's ``context`` and the centre
            token its ``target``; otherwise (skip-gram) the roles are swapped.
        t: Threshold used in the keep-probability formula (see ``get_p``).
        ns: Number of negative samples drawn per positive example.
    """

    def __init__(self, sents, vocab_encoder, window: int=2, cbow: bool = False, t: float=1e-3, ns: int=5) -> None:
        self.sents = sents
        self.vocab_encoder = vocab_encoder
        self.window = window
        self.cbow = cbow
        self.t = t
        self.ns = ns
        self._prepare_data()

    def _prepare_data(self) -> None:
        """Encode the sentences, flatten them and build the probability tables."""
        self.encoded_sents = [self.vocab_encoder.encode_sequence(sent, add_specials=False) for sent in self.sents]
        self.tokens = [t for sent in self.encoded_sents for t in sent]
        self._build_p_table()
        self.sent_lengths = list(map(len, self.encoded_sents))

    def _build_p_table(self):
        """Compute token counts, keep-probabilities and the negative-sampling table."""
        self.token_freq = Counter(self.tokens)
        self.N = sum(self.token_freq.values())
        # Per-token keep probability, keyed by token id (only ids that occur).
        self.token_p = {k: self.get_p(v/self.N) for k, v in self.token_freq.items()}
        # Negative-sampling distribution: count**0.75, normalised, indexed by
        # token id over the whole vocabulary (ids that never occur get 0).
        self.N_pow = sum([c**0.75 for c in self.token_freq.values()])
        self.probs = torch.tensor([
            self.token_freq[i]**0.75/self.N_pow for i in range(self.vocab_encoder.size)
        ])

    def get_p(self, p: float):
        """Return ``min((sqrt(p/t) + 1) * t/p, 1)`` for relative frequency ``p``."""
        return min( (math.sqrt(p/self.t) + 1) * self.t/p, 1)

    def get_context(self, seq, idx, window):
        """Return up to ``window`` items on each side of ``seq[idx]`` (excluding it)."""
        left_edge = max(0, idx-window)
        left = seq[left_edge: idx]
        right_edge = min(len(seq), idx+window+1)
        right = seq[idx+1: right_edge]
        return left+right

    def drop_context(self, seq: List[int]) -> List[int]:
        """Randomly keep each id in ``seq`` with probability ``self.token_p[id]``."""
        return [idx for idx in seq if np.random.rand()<=self.token_p[idx]]

    def ns_mask(self, label):
        """Return ``ns`` sampled ids different from ``label`` followed by ``label`` itself.

        The draw is made from a copy of ``self.probs`` with the label's
        probability zeroed, so the shared table is never modified (the earlier
        version edited it in place and restored it afterwards, which left it
        corrupted if sampling raised).
        """
        probs = self.probs.clone()
        probs[label] = 0
        mask = torch.empty(self.ns+1, dtype=torch.int64)
        mask[:-1] = torch.multinomial(probs, self.ns, replacement=False)
        mask[-1] = label
        return mask

    def _negative_sample(self, sample: dict) -> dict:
        """Expand ``sample`` in place to ``1 + ns`` rows: the true pair plus ``ns`` fakes.

        Row 0 keeps the real target and gets label 1; the remaining rows get
        randomly drawn targets and label 0. The true target is not excluded
        from the random draw, so a negative row can coincide with it.
        """
        target, context = sample['target'], sample['context']
        sample['context'] = context.repeat(1+self.ns, 1)
        fake_targets = self.get_random_target(self.ns * target.size(0)).view(self.ns, -1)
        sample['target'] = torch.cat([target.unsqueeze(0), fake_targets])
        sample['label'] = torch.zeros(1+self.ns)
        sample['label'][0] = 1
        return sample

    def get_random_target(self, n):
        """Draw ``n`` distinct token ids from ``self.probs`` (sampling without replacement)."""
        return torch.multinomial(self.probs, n, replacement=False)

    def _cbow(self, target, context):
        """CBOW layout: the window is the context, the centre token the target."""
        return {'context': context, 'target': target}

    def _skipgram(self, target: int, context: List[int]):
        """Skip-gram layout: the centre token is the context, the window the target."""
        return {'context': target, 'target': context}

    def _pad(self, seq: List[int], maxlen: int) -> List[int]:
        """Right-pad ``seq`` with the encoder's pad index up to ``maxlen`` items."""
        assert len(seq) <= maxlen
        diff = maxlen - len(seq)
        return seq + [self.vocab_encoder.pad_index] * diff

    def __getitem__(self, idx):
        """Build the training example centred on flat token position ``idx``."""
        target = torch.tensor([self.tokens[idx]])
        context = self.get_context(self.tokens, idx, self.window)
        # Windows at the edges of the stream are shorter, so pad to 2*window.
        context = self._pad(context, 2*self.window)
        context = torch.tensor(context)
        if self.cbow: sample = self._cbow(target, context)
        else: sample = self._skipgram(target, context)
        return self._negative_sample(sample)

    def __len__(self):
        """Number of items, i.e. the length of the flat encoded token stream.

        Uses ``self.tokens`` (what ``__getitem__`` actually indexes) rather
        than the raw sentences, so the two cannot disagree if the encoder
        changes a sentence's length.
        """
        return len(self.tokens)

### Sampler

In [4]:
class Sampler(utils.data.Sampler):
    """Index sampler that optionally shuffles and randomly skips frequent tokens.

    For every position in ``dataset.tokens`` a keep-probability is derived from
    the relative frequency of the token at that position. When ``shuffle`` is
    True the indices are shuffled and each one is emitted only if a uniform
    draw falls below its keep-probability; when False every index is emitted
    in order.

    Args:
        dataset: Object exposing ``t``, ``token_freq`` and ``tokens``.
        shuffle: Enable shuffling and random skipping.
    """

    def __init__(self, dataset, shuffle: bool=False) -> None:
        self.t = dataset.t
        self.counter = dataset.token_freq
        self.N = sum(self.counter.values())
        # Relative frequency of the token found at each position.
        self.f_table = {}
        for position, token in enumerate(dataset.tokens):
            self.f_table[position] = self.counter[token]/self.N
        self.p = list(map(self.get_p, self.f_table.values()))
        self.shuffle = shuffle

    def get_p(self, f: float):
        """Return ``(sqrt(f/t) + 1) * t/f`` capped at 1."""
        uncapped = (math.sqrt(f/self.t) + 1) * self.t/f
        return min(uncapped, 1)

    def __iter__(self):
        order = np.arange(len(self))
        if self.shuffle:
            np.random.shuffle(order)
        for i in order:
            # The random draw happens only in shuffle mode.
            if not self.shuffle or np.random.rand() < self.p[i]:
                yield i

    def __len__(self):
        """Number of candidate positions (an upper bound on what ``__iter__`` yields)."""
        return len(self.p)

In [188]:
class TokenDistribution:
    """Per-token keep probabilities used to randomly subsample a corpus.

    Args:
        tokens: Iterable of token strings used to count frequencies.
        vocab: Object exposing ``int_to_token`` (index -> token), ``size`` and
            ``vocab[token]`` (token -> index).
        t: Threshold used in the keep-probability formula (see ``get_p``).
            Defaults to the previously hard-coded ``1e-3``.
        power: Value stored as ``self.pow``; defaults to the previously
            hard-coded ``0.75``. No method of this class reads it.
    """

    def __init__(self, tokens, vocab, t: float = 1e-3, power: float = 0.75):
        self.counter = Counter(tokens)
        self.vocab = vocab
        # Count of each vocabulary entry, ordered by vocabulary index.
        self.vocab_counts = [self.counter[self.vocab.int_to_token[i]] for i in range(self.vocab.size)]
        self.N = sum(self.counter.values())
        self.pow = power
        self.t = t
        self._p = None

    @property
    def p(self):
        """Tensor of keep probabilities indexed by vocabulary index (computed once, then cached)."""
        if getattr(self, '_p', None) is None:
            self._p = self.compute_distribution()
        return self._p

    def get_p(self, f: float):
        """Return ``(sqrt(f/t) + 1) * t/f`` capped at 1.0 for relative frequency ``f``."""
        # Small offset so that a zero frequency does not divide by zero.
        f += 1e-9
        # Cap with a float so the resulting tensor is always floating point,
        # even when every entry hits the cap.
        return min( (math.sqrt(f/self.t) + 1) * self.t/f, 1.0)

    def compute_distribution(self):
        """Build the keep-probability tensor; values are not normalised to sum to 1."""
        D = torch.tensor([self.get_p(c/self.N) for c in self.vocab_counts])
        return D

    def resample(self, tokens):
        """Return ``tokens`` with each one kept with probability ``p[vocab[token]]``."""
        return [t for t in tokens if np.random.rand() <= self.p[self.vocab[t]].item()]
                

### Data loader

In [6]:
def collate_batch(batch: List[torch.Tensor]):
    """Merge per-item samples into one model input dict and one label tensor.

    Each element of ``batch`` is indexed with the keys ``'target'``,
    ``'context'`` and ``'label'`` (so in practice it is a dict of tensors).
    Tensors under the same key are concatenated along dimension 0.

    Returns:
        A tuple ``({'context': ..., 'target': ...}, labels)``.
    """
    targets, contexts, labels = (
        torch.cat([sample[key] for sample in batch], 0)
        for key in ('target', 'context', 'label')
    )
    model_input = {'context': contexts, 'target': targets}
    return model_input, labels

### Model

In [7]:
class Model(nn.Module):
    """Two-table embedding model scoring (context, target) pairs by dot product.

    Context ids and target ids are embedded with separate tables, averaged
    over their non-padding entries, and the two averages are combined into a
    single logit per row.

    Args:
        vocab_size: Number of rows in each embedding table.
        embedding_dim: Size of each embedding vector.
        padding_idx: Index treated as padding, both by the embedding tables
            and when counting real entries. Defaults to 0.
    """

    def __init__(self, vocab_size, embedding_dim, padding_idx: int = 0) -> None:
        super().__init__()
        self.padding_idx = padding_idx
        self.c_layer = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.t_layer = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        # Added to the entry count so an all-padding row does not divide by zero.
        self.eps = 1e-9

    def forward(self, X):
        """Return one logit per row of ``X['context']`` / ``X['target']``.

        Args:
            X: Dict with integer index tensors under ``'context'`` and
                ``'target'``, each 2-D with matching first dimension.
        """
        context, target = X['context'], X['target']
        c_emb = self._mean_embed(self.c_layer, context)
        t_emb = self._mean_embed(self.t_layer, target)
        logit = (c_emb * t_emb).sum(1)
        return logit

    def _mean_embed(self, layer, idxs):
        """Embed ``idxs`` with ``layer`` and average over the non-padding positions."""
        emb = layer(idxs)  # (B, n, e)
        n_real = (~self._pad_mask(idxs)).sum(1, keepdim=True) + self.eps
        return emb.sum(1) / n_real

    def _pad_mask(self, X):
        """Boolean tensor that is True where ``X`` equals the padding index."""
        return X == self.padding_idx

    def _apply_mask(self, W, idxs):
        """Select the columns ``idxs`` of ``W`` row by row (``W.gather(1, idxs)``).

        The statements that used to follow the ``return`` were unreachable and
        referenced undefined names, so they have been removed.
        """
        return W.gather(1, idxs)

### Prepare data and train

In [194]:
# Download the WikiText-2 archive on first use; later runs reuse the local copy.
text_file = Path('data/raw/wikitext-2-v1.zip')
download_wikitext(text_file)

# Tokenised, stopword-filtered sentences for the train and validation splits.
train_sents = prepare_data(text_file, mode='train')
val_sents = prepare_data(text_file, mode='valid')

# The vocabulary and the subsampling statistics are built from the training split only.
train_tokens = [t for sent in train_sents for t in sent]
vocab = Vocab(set(train_tokens))
token_dist = TokenDistribution(train_tokens, vocab)

# Randomly subsample tokens by frequency. This rebinds train_sents / val_sents, so the
# un-resampled sentences are no longer available after this point.
# NOTE(review): resample() draws from np.random and no seed is set in the visible cells,
# so the result differs between runs.
# NOTE(review): validation tokens missing from the training vocab go through vocab[t];
# what happens then depends on Vocab's unknown-token handling -- confirm.
train_sents = [token_dist.resample(sent) for sent in train_sents]
val_sents = [token_dist.resample(sent) for sent in val_sents]

vocab_encoder = VocabEncoder(vocab)

In [196]:
# Hyperparameters shared by the datasets below and the gensim baseline.
cbow = True    # True: CBOW layout; False: skip-gram layout (see Dataset)
window = 3     # context tokens taken on each side of the centre token
bs = 4096      # items per batch; each item expands to 1 + ns rows when collated
ns = 5         # negative samples per positive example
t = 1e-3       # threshold passed to Dataset (and to gensim as `sample`)

train_ds = Dataset(train_sents, vocab_encoder, window=window, cbow=cbow, t=t, ns=ns)
# The custom Sampler is currently disabled, so the loader uses its default sampling.
#train_sampler = Sampler(train_ds, shuffle=True)
train_sampler = None
train_dl = utils.data.DataLoader(train_ds, batch_size=bs, collate_fn=collate_batch, num_workers=3, sampler=train_sampler)

val_ds = Dataset(val_sents, vocab_encoder, window=window, cbow=cbow, t=t, ns=ns)
#val_sampler = Sampler(val_ds, shuffle=False)
val_sampler = None
val_dl = utils.data.DataLoader(val_ds, batch_size=bs, collate_fn=collate_batch, num_workers=3, sampler=val_sampler)

In [197]:
# Use the first GPU when one is available and fall back to the CPU otherwise,
# so the notebook still runs on machines without CUDA.
data = DataBunch(train_dl, val_dl, device='cuda:0' if torch.cuda.is_available() else 'cpu', collate_fn=collate_batch)

In [198]:
# Sanity check: print the tensor shapes of the first validation batch only.
for X, y in data.valid_dl:
    print('context:', X['context'].size())
    print('target:', X['target'].size())
    print('y:', y.size())
    break

context: torch.Size([24576, 6])
target: torch.Size([24576, 1])
y: torch.Size([24576])


In [199]:
# 200-dimensional embeddings, one row per vocabulary index.
model = Model(vocab_encoder.size, 200)
model.to(data.device)
# Binary classification on logits: label 1 for the real pair, 0 for negative samples.
learner = Learner(data, model, loss_func=nn.BCEWithLogitsLoss(), opt_func=optim.SGD, metrics=[accuracy_thresh])

In [202]:
# Train for 5 epochs with a learning rate of 0.025.
learner.fit(5, lr=0.025)

epoch,train_loss,valid_loss,accuracy_thresh,time
0,1.097678,1.08303,0.499336,01:47
1,0.996174,0.986338,0.499917,01:47
2,0.920366,0.910609,0.500326,01:47
3,0.859932,0.854259,0.499887,01:47
4,0.814375,0.811614,0.498855,01:47


In [203]:
# Print the 10 tokens whose context embeddings score highest against 'king'.
with torch.no_grad():
    W = model.c_layer.weight
    # NOTE(review): the row is looked up through `vocab` but decoded through
    # `vocab_encoder` below -- presumably both share the same index space; confirm.
    vec1 = W[vocab['king']]
    #vec1 = vec1 / torch.norm(vec1)
    #norm = torch.norm(W, dim=1) + model.eps
    # With norm fixed to 1 the scores are raw dot products, not cosine similarities.
    norm = 1
    sims = vec1 @ W.t() / norm
    # argsort is ascending: take the last 100 indices and reverse them for highest-first.
    most_sim = torch.argsort(sims)
    topk = most_sim[-100:].tolist()[::-1]
    top_tokens = [vocab_encoder.int_to_token[i] for i in topk]
    print(top_tokens[0:10])

['king', 'incentive', 'exteriors', 'schubert', 'inactivation', 'copper', 'heliport', 'oppression', 'molds', 'dreams']


### Baseline gensim

In [118]:
from gensim.models import Word2Vec

In [121]:
# Baseline: gensim Word2Vec trained on the same (already resampled) training sentences,
# with embedding size, architecture, negative-sample count and threshold taken from above.
# `sample=t` makes gensim apply its own frequency subsampling on top of the resampling
# already done to train_sents; `window=3` repeats the value of `window` as a literal.
# NOTE(review): the `size` keyword belongs to the gensim 3.x API -- confirm the installed version.
g_model = Word2Vec(
    train_sents, 
    size=model.c_layer.embedding_dim, 
    window=3, 
    sg=int(not cbow),
    negative=ns,
    sample=t,
    workers=8,
    compute_loss=True
)

In [122]:
# Query the trained word vectors through `.wv`; calling most_similar on the model
# object itself is deprecated and emitted the warning seen in the recorded output.
g_model.wv.most_similar('king')

  """Entry point for launching an IPython kernel.


[('edward', 0.8391004800796509),
 ('queen', 0.8382359743118286),
 ('henry', 0.8131130933761597),
 ('lord', 0.7869649529457092),
 ('bishop', 0.7864221334457397),
 ('william', 0.7800500988960266),
 ('pope', 0.7752436399459839),
 ('james', 0.7749444246292114),
 ('elizabeth', 0.774777889251709),
 ('charles', 0.7678619623184204)]

In [125]:
# Training loss recorded by gensim (available because compute_loss=True was passed).
g_model.get_latest_training_loss()

2427534.0

In [139]:
# Minimum token count gensim required for a word to enter its vocabulary.
# NOTE(review): the recorded output starts with a warning fragment, so this attribute
# access appears to be deprecated in the installed gensim -- confirm the replacement.
g_model.min_count

  """Entry point for launching an IPython kernel.


5