In [1]:
from fastai.vision.all import *
from fastai.learner import *
from fastai.data.all import *
from fastai.callback.tracker import SaveModelCallback
import pandas as pd
import matplotlib.pyplot as plt
from pathlib2 import Path
import numpy as np
import random
from torch.nn import MSELoss

You can download the ebooks [here](https://www.openslr.org/resources/12/original-books.tar.gz).

We will need the pretrained embeddings from https://github.com/iamyuanchung/speech2vec-pretrained-vectors. We will use them to compare our results and also to figure out what vocab the authors of the speech2vec paper used for training. Let us start with the latter.

In [2]:
import logging
from six import iteritems
from web.datasets.similarity import fetch_MEN, fetch_WS353, fetch_SimLex999
from web.embeddings import fetch_GloVe
from web.evaluate import evaluate_similarity
from web.embedding import Embedding, Vocabulary
from gensim.models import Word2Vec
from gensim.models import KeyedVectors



In [3]:
speech2vec = KeyedVectors.load_word2vec_format('../speech2vec-pretrained-vectors/speech2vec/50.vec', binary=False) 

In [4]:
vocab = set(speech2vec.vocab.keys())
len(vocab)

37622

In [5]:
starting_lines = {
    '1004/1004.txt.utf-8': 535,
    '10123/10123.txt.utf-8': 85,
    '10359/10359.txt.utf-8': 76,
    '10360/10360.txt.utf-8': 57,
    '10378/10378.txt.utf-8': 96,
    '10390/10390.txt.utf-8': 89,
    '1193/1193.txt.utf-8': 272,
    '12441/12441-0.txt': 101,
    '1249/1249.txt.utf-8': 614,
    '1325/1325.txt.utf-8': 434,
    '1674/1674.txt.utf-8': 360,
    '2046/2046.txt.utf-8': 295,
    '2147/2147.txt.utf-8': 66,
    '2184/2184.txt.utf-8': 302,
    '2383/2383.txt.utf-8': 408,
    '2486/2486.txt.utf-8': 293,
    '2488/2488.txt.utf-8': 495,
    '2512/2512-0.txt': 113,
    '2515/2515.txt.utf-8': 281,
    '2678/2678.txt.utf-8': 305,
    '2679/2679.txt.utf-8': 308,
    '269/269-0.txt': 67,
    '282/282-0.txt': 43,
    '2891/2891.txt.utf-8': 336,
    '3053/3053.txt.utf-8': 340,
    '3169/3169.txt.utf-8': 386,
    '325/325.txt.utf-8': 170,
    '3300/3300.txt.utf-8': 40,
    '34757/34757-0.txt': 185,
    '3604/3604.txt.utf-8': 612,
    '3623/3623.txt.utf-8': 342,
    '3697/3697.txt.utf-8': 449,
    '37660/37660-0.txt': 74,
    '4028/4028.txt.utf-8': 402,
    '4042/4042.txt.utf-8': 438,
    '435/435.txt.utf-8': 62,
    '6456/6456.txt.utf-8': 53,
    '7098/7098.txt.utf-8': 77,
    '76/76.txt.utf-8': 579,
    '778/778.txt.utf-8': 297,
    '786/786-0.txt': 158,
}   

In [6]:
%%time

word_pairs = []

for fn, starting_line in starting_lines.items():
    with open(f'data/books/LibriSpeech/books/utf-8/{fn}') as file:
        out = []
        for i, line in enumerate(file.readlines()):
            if i < starting_line: continue
            line = line.strip()
            toks = []

            for tok in line.split():
                tok = tok.lower()
                if tok in vocab:
                    toks.append(tok)
                else:
                    toks.append('<UNK>')
            out += toks

        for i, word in enumerate(out):
            if word is '<UNK>': continue
            for offset in [-2, -1, 1, 2]:
                if i + offset < 0 or i + offset >= len(out): continue
                target_word = out[i+offset]
                if target_word is '<UNK>': continue
                word_pairs.append([word, target_word])

CPU times: user 13.4 s, sys: 472 ms, total: 13.9 s
Wall time: 13.9 s


In [7]:
word_pairs[:20]

[['midway', 'upon'],
 ['midway', 'the'],
 ['upon', 'midway'],
 ['upon', 'the'],
 ['upon', 'journey'],
 ['the', 'midway'],
 ['the', 'upon'],
 ['the', 'journey'],
 ['the', 'of'],
 ['journey', 'upon'],
 ['journey', 'the'],
 ['journey', 'of'],
 ['journey', 'our'],
 ['of', 'the'],
 ['of', 'journey'],
 ['of', 'our'],
 ['of', 'life'],
 ['our', 'journey'],
 ['our', 'of'],
 ['our', 'life']]

In [8]:
len(word_pairs)

10140726

In [9]:
vocab = list(vocab)
np.random.shuffle(vocab)
word2index = {w: i for i, w in enumerate(vocab)}

In [53]:
class Dataset():
    def __init__(self, pairs):
        self.pairs = pairs
    def __len__(self):
        return len(self.pairs)
    def __getitem__(self, idx):
        source_word, target_word = self.pairs[idx]
        if np.random.randn() < 0.5:
            return (word2index[source_word], word2index[target_word]), 1
        else:
            return (word2index[source_word], np.random.randint(len(vocab))), 0

In [54]:
train_ds = Dataset(word_pairs[:9_500_000])
valid_ds = Dataset(word_pairs[9_500_000:])

In [55]:
len(train_ds), len(valid_ds)

(9500000, 640726)

In [56]:
train_ds[0]

((8211, 11669), 1)

In [57]:
BS = 2048
NUM_WORKERS = 8

train_dl = DataLoader(train_ds, BS, NUM_WORKERS, shuffle=True)
valid_dl = DataLoader(valid_ds, BS, NUM_WORKERS)

dls = DataLoaders(train_dl, valid_dl)

In [66]:
class Model(Module):
    def __init__(self, hidden_size=50):
        self.embeddings =nn.Embedding(len(vocab), hidden_size)
            
    def forward(self, idxs):
        source_word_idx, target_word_idx = idxs
        source_embeddings = self.embeddings(source_word_idx)
        target_embeddings = self.embeddings(target_word_idx)
        return torch.sum(source_embeddings * target_embeddings, -1)

In [67]:
learn = Learner(
    dls.cuda(),
    Model().cuda(),
    loss_func=BCEWithLogitsLossFlat(),
    opt_func=Adam,
    metrics=[accuracy_multi]
)

In [None]:
NUM_EPOCHS = 120
learn.fit(NUM_EPOCHS, lr=1e-3, cbs=SaveModelCallback(fname='text_embeddings', every_epoch=True))

epoch,train_loss,valid_loss,accuracy_multi,time
0,0.653435,0.655603,0.788893,01:00
1,0.465147,0.470913,0.836815,01:04
2,0.388056,0.416087,0.848177,01:11
3,0.35234,0.388206,0.853021,00:59
4,0.330585,0.373828,0.853775,01:10
5,0.318017,0.362434,0.855701,01:16
6,0.305931,0.353485,0.857078,01:07
7,0.30063,0.348029,0.857473,01:07
8,0.292633,0.343652,0.858848,01:01
9,0.289131,0.337414,0.860736,01:06


## Evalute embeddings

In [None]:
from utils import Embeddings

In [None]:
embeddings = learn.model.embeddings.weight.cpu().detach().numpy()

In [None]:
e = Embeddings(
    embeddings,
    vocab
)

In [None]:
for w in ['fast', 'lost', 'small', 'true', 'crazy', 'slow']:
    print(f'{w}: {e.nn_words_to(e[w])}')

## Evaluating embeddings using [word-embeddings-benchmarks](https://github.com/kudkudak/word-embeddings-benchmarks)

In [None]:
import logging
from six import iteritems
from web.datasets.similarity import fetch_MEN, fetch_WS353, fetch_SimLex999
from web.embeddings import fetch_GloVe
from web.evaluate import evaluate_similarity
from web.embedding import Embedding, Vocabulary
from gensim.models import Word2Vec
from gensim.models import KeyedVectors

In [None]:
tasks = {
    "MEN": fetch_MEN(),
    "WS353": fetch_WS353(),
    "SIMLEX999": fetch_SimLex999()
}

In [None]:
our_embeddings = Embedding(
    Vocabulary(vocab),
    embeddings
)

speech2vec = KeyedVectors.load_word2vec_format('../speech2vec-pretrained-vectors/word2vec/50.vec', binary=False) 
speech2vec_embeddings = Embedding(Vocabulary(list(speech2vec.vocab.keys())), speech2vec.vectors)

In [None]:
for name, data in iteritems(tasks):
    print("Spearman correlation of scores on {} {}".format(name, evaluate_similarity(our_embeddings, data.X, data.y)))

In [None]:
for name, data in iteritems(tasks):
    print("Spearman correlation of scores on {} {}".format(name, evaluate_similarity(speech2vec_embeddings, data.X, data.y)))

## Loss decrease and improvements on semantic tasks as training progresses

In [None]:
%%capture

val_losses, accuracies, task_perf = [], [], []
for i in range(NUM_EPOCHS):
    learn.load(f'text_embeddings_{i}')
    loss, accuracy = learn.validate()
    val_losses.append(loss)
    accuracies.append(accuracy)
    
    embeddings = learn.model.embeddings.weight.cpu().detach().numpy()
    our_embeddings = Embedding(
        Vocabulary([w.lower() for w in vocab]),
        embeddings
    )

    task_perf.append([evaluate_similarity(our_embeddings, data.X, data.y) for name, data in iteritems(tasks)])

In [None]:
men, ws353, simlex999 = list(zip(*task_perf))

In [None]:
import matplotlib.pyplot as plt

In [None]:
fig, ax1 = plt.subplots()

ax1.plot(val_losses, label='val loss')

ax2 = ax1.twinx()
ax2.plot(men, label='MEN', c='g')
ax2.plot(ws353, label='WS353', c='m')
ax2.plot(simlex999, label='SIMLEX999', c='y')

ax1.legend(loc=[0.07, 0.9])
ax2.legend(loc=[0.7, 0.15])

ax1.set_xlabel('epochs');