In [1]:
import random
from functools import reduce
from collections import Counter
import numpy as np
import pandas as pd
import torch
import scipy.stats
import sklearn.metrics
import plotly.graph_objects as go
from tqdm.auto import tqdm, trange
from word2mat import Dataset, Sentence

# Data

In [2]:
full_dataset = Dataset.from_huggingface("wiki_auto", config="manual", column="normal_sentence")
len(full_dataset)

Reusing dataset wiki_auto (/home/gilles/.cache/huggingface/datasets/wiki_auto/manual/1.0.0/5ffdd9fc62422d29bd02675fb9606f77c1251ee17169ac10b143ce07ef2f4db8)


HBox(children=(FloatProgress(value=0.0, description='Parsing dataset', max=373801.0, style=ProgressStyle(descr…




373801

In [3]:
clean_dataset = full_dataset.clean
len(clean_dataset)

HBox(children=(FloatProgress(value=0.0, description='Reading word vectors', max=4027169.0, style=ProgressStyle…




329155

In [4]:
sum(len(sentence) for sentence in clean_dataset)

8756821

In [5]:
len(clean_dataset.unique_words)

35259

# Random initialization

In [6]:
matrix_size = 16

In [7]:
def random_matrix():
    norm = torch.normal(mean=0, std=1, size=(matrix_size, matrix_size))
    return torch.tensor(
        torch.sign(norm) * torch.sqrt(torch.abs(norm)),
        requires_grad=True,
        dtype=torch.float,
    )

matrices = {
    word: random_matrix()
    for word in tqdm(clean_dataset.unique_words)
}
matrices["the"]

HBox(children=(FloatProgress(value=0.0, max=35259.0), HTML(value='')))

  return torch.tensor(





tensor([[-1.2171, -0.7279,  0.7662, -0.9629,  0.5058, -1.0479, -0.6080,  0.8031,
         -0.4424, -1.2091,  0.1963, -0.7502,  1.1065, -1.0266, -1.3331, -0.4274],
        [-0.6726, -1.3106, -0.5628, -0.6087,  0.3584, -1.4838, -1.0848,  0.5396,
          1.1632,  1.1234,  0.9771, -0.8127, -0.6217, -0.5047, -1.6581, -0.7960],
        [ 0.3070,  1.3839, -0.8876, -0.3768,  1.3029,  0.7506,  0.9136,  0.9350,
         -0.9389,  0.8463, -1.3612,  0.7044, -0.4705,  0.1037, -1.1618, -0.4376],
        [ 0.9772,  1.0808,  0.5715,  0.8625,  0.5533,  1.2063, -0.9480, -0.9700,
         -1.1138, -0.3591, -0.7479,  0.3610,  0.7650,  0.7968,  0.7854, -0.2873],
        [ 1.2690,  1.2931,  0.4189, -0.9288, -0.4302, -0.7685,  0.6194, -1.2549,
          0.7230,  1.2038, -0.7119, -0.5637, -0.5861,  0.5789, -0.5897,  0.9192],
        [ 0.6912, -0.9088, -0.3841,  1.1091, -1.5322, -0.7385, -1.1247,  1.5261,
         -0.5393, -0.3531, -0.3706, -0.6526, -0.4392,  1.1890,  0.6200,  1.1688],
        [-1.0038, -0.6

In [8]:
def random_vector():
    return torch.tensor(
        scipy.stats.norm.rvs(loc=0, scale=1, size=matrix_size),
        requires_grad=True,
        dtype=torch.float,
    )

start_vector = random_vector()
end_vector = random_vector()
start_vector, end_vector

(tensor([-0.4086,  0.5127, -1.3280, -0.2225, -0.7098, -0.6438,  0.4069,  0.5243,
         -0.1342,  0.6890,  0.0910, -0.5397, -0.2112, -0.2830,  0.3657,  1.2491],
        requires_grad=True),
 tensor([-0.5716, -0.1844,  0.4236, -0.3956, -1.0585,  0.0988,  0.1628,  0.8140,
          1.6397, -1.1338,  0.4639, -1.0901, -0.7461, -2.1136, -1.2184, -0.7016],
        requires_grad=True))

In [50]:
vector_size=matrix_size**2

def random_vector_word2vec():
    return torch.tensor(
        scipy.stats.norm.rvs(loc=0, scale=1, size=vector_size),
        requires_grad=True,
        dtype=torch.float,
    )


vectors = {
    word: random_vector_word2vec()
    for word in tqdm(clean_dataset.unique_words)
}
vectors["the"]

HBox(children=(FloatProgress(value=0.0, max=35259.0), HTML(value='')))




tensor([ 1.5039e+00,  9.1487e-01,  5.7596e-01, -1.1794e+00,  6.4231e-01,
         1.9410e-01, -9.2127e-01, -5.0631e-02, -5.4263e-02,  9.3026e-01,
         6.6303e-01, -2.4314e-01, -4.4890e-01,  1.2983e+00,  1.4777e-01,
         1.4156e+00,  6.1741e-01, -1.1937e+00,  2.4587e-01, -1.5104e+00,
         1.8996e+00, -1.0661e+00,  3.5413e-01,  4.8250e-01,  1.9658e+00,
         5.0636e-01,  2.9423e-01,  1.8245e+00, -1.4695e+00, -5.5391e-01,
         8.3140e-01, -8.7750e-01,  2.7294e-01, -7.0760e-01,  1.9555e+00,
         1.3328e+00, -1.4874e-01, -1.1180e+00,  2.6546e-02,  1.0961e+00,
         3.0130e-01,  9.1614e-01,  1.2239e+00, -3.1333e-01,  1.1400e+00,
         1.2073e+00, -1.5056e+00,  1.3151e+00,  1.1352e+00,  8.1147e-01,
        -1.9550e+00, -1.0466e+00,  2.9558e+00, -1.3830e+00, -9.8995e-02,
         3.1872e-01, -1.6753e+00, -3.4090e-01, -8.3168e-01, -5.4376e-01,
        -1.8467e-01, -4.8526e-01,  1.6866e+00,  6.4337e-01, -7.8713e-01,
        -1.2518e+00, -1.1662e+00, -7.8544e-01,  8.9

# Fake sentence generation

In [19]:
word_counts = Counter(word for sentence in clean_dataset for word in sentence.words)
sampleable_words = list(word_counts.keys())
cum_weights = np.cumsum(list(word_counts.values()))

def random_word():
    return random.choices(sampleable_words, cum_weights=cum_weights)[0]

random_word()

'lowest'

In [10]:
%timeit random_word()

6.09 µs ± 654 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [25]:
def fake_sentence(num_fake_words=1):
    words = list(random.choice(clean_dataset.sentences).words)
    for fake_idx in range(num_fake_words):
        words[random.randrange(len(words))] = random_word()
    return Sentence(words=words)

fake_sentence()

Sentence(words=['when', 'owners', 'choose', 'to', 'deregister', 'their', 'vehicle', 'the', 'officer', 'at', 'the', 'local', 'authority', 'will', 'want', 'to', 'see', 'the', 'licence', 'plates', 'with', 'defaced', 'seals', 'on', 'them', 'as', 'proof', 'that', 'no', 'legal', 'the', 'with', 'this', 'identifier', 'can', 'be', 'found', 'in', 'public', 'any', 'more', '.'])

In [26]:
%timeit fake_sentence()

9.48 µs ± 391 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


# Training

In [27]:
def sentence_matrix(sentence):
    return reduce(
        lambda current, word: torch.matmul(
            (
            torch.sign(current) * torch.sqrt(torch.abs(current))),
            matrices[word]
            ),
        sentence.words,
        torch.eye(matrix_size)
    )
    

sentence_matrix(Sentence.from_string("Mary had a little lamb."))

tensor([[-4.0602e+00,  8.5268e+00, -2.3196e+01,  4.5864e+00, -7.2280e+00,
          8.8654e+00,  2.5520e-01, -8.2458e+00,  9.6403e+00, -1.4068e+01,
          1.6684e+01,  7.1399e+00, -7.5682e+00,  1.8461e+01, -3.4638e+00,
         -5.9608e-01],
        [ 4.2403e-01,  6.3713e+00,  1.0861e+01, -1.4255e+01,  1.3616e+01,
          1.0221e+00,  9.7342e+00,  1.4764e+01, -3.1260e+00,  3.7158e+00,
         -1.4199e+01,  1.2649e+01,  5.8650e-01, -1.8709e+01,  7.8806e+00,
          1.4873e+01],
        [-8.2099e+00,  2.2262e+00,  1.0682e+01,  1.7837e+00,  1.4427e+00,
         -8.0753e+00,  1.9025e+01,  3.9507e+00, -9.7082e+00, -8.3696e-01,
          5.3717e+00,  1.3354e+01,  1.9428e+00, -6.0820e+00, -8.0801e+00,
          1.5443e+01],
        [ 1.6108e+01, -4.9564e+00,  1.3949e+01,  2.0473e-02,  9.4695e+00,
          2.8092e+00, -5.3620e+00, -1.1405e+01,  3.6206e+00,  3.0908e+00,
          4.4819e-02,  1.1118e+01, -6.5791e+00, -1.2677e+01, -1.5483e+00,
         -3.8684e+00],
        [ 1.4213e+01

In [81]:
def sentence_vector(sentence):
    return reduce(
        lambda current, word: torch.mul(
            (
            torch.sign(current) * torch.sqrt(torch.abs(current))),
            vectors[word]
            ),
        sentence.words,
        torch.ones(1, vector_size)[0]
    )
sentence_vector(Sentence.from_string("Mary had a little lamb."))

tensor([-0.4821, -0.6102,  0.3131, -0.2391, -0.1351, -0.4316, -0.0779,  1.0675,
         0.3106, -0.6524,  0.4682,  0.0239,  1.4590, -0.1687, -0.1416,  0.1407,
         1.1495,  0.1038,  0.0625,  0.4406, -0.4373,  0.0362,  0.6255,  1.0521,
        -0.9848,  0.0989, -0.7805, -0.6769, -0.2130,  0.3002, -0.3786,  0.0226,
         0.3941,  0.2445,  0.0948, -0.2268, -0.7632,  0.4818, -0.1381, -0.7460,
         0.3152, -0.2805,  2.3814,  0.4373, -0.1422, -1.0733, -0.8064,  0.4122,
         0.7814, -0.8417, -0.1018, -0.6643,  0.1098,  0.2847,  1.1917, -0.0460,
         0.4588,  0.0069, -0.0857,  0.6059,  0.5504,  0.1384, -0.1944, -0.1435,
        -1.3152, -0.0189, -0.8722,  0.0127, -0.1156, -0.8222,  0.2357,  0.8901,
         0.0814, -0.1193, -0.8998, -0.1225,  0.0921, -1.6622, -0.4210, -1.0981,
         0.0945, -0.4719, -0.0783,  1.3385,  0.1201, -0.2587, -0.1767, -0.0117,
         0.1794,  0.0277,  0.8330, -0.1397, -0.5831,  0.0099, -0.4957, -0.0567,
         1.0765, -0.7582, -0.6049,  0.57

In [28]:
def sentence_logit(sentence):
    return torch.dot(torch.matmul(start_vector, sentence_matrix(sentence)), end_vector)

sentence_logit(Sentence.from_string("Mary had a little lamb."))

tensor(-0.3603, grad_fn=<DotBackward>)

In [90]:
def sentence_logit_word2vec(sentence):
    return torch.dot(torch.mul(start_vector, sentence_vector(sentence)), end_vector)

print(sentence_vector(Sentence.from_string("Mary had a little lamb.")))

sentence_logit_word2vec(Sentence.from_string("Mary had a little lamb."))

tensor([-0.4821, -0.6102,  0.3131, -0.2391, -0.1351, -0.4316, -0.0779,  1.0675,
         0.3106, -0.6524,  0.4682,  0.0239,  1.4590, -0.1687, -0.1416,  0.1407,
         1.1495,  0.1038,  0.0625,  0.4406, -0.4373,  0.0362,  0.6255,  1.0521,
        -0.9848,  0.0989, -0.7805, -0.6769, -0.2130,  0.3002, -0.3786,  0.0226,
         0.3941,  0.2445,  0.0948, -0.2268, -0.7632,  0.4818, -0.1381, -0.7460,
         0.3152, -0.2805,  2.3814,  0.4373, -0.1422, -1.0733, -0.8064,  0.4122,
         0.7814, -0.8417, -0.1018, -0.6643,  0.1098,  0.2847,  1.1917, -0.0460,
         0.4588,  0.0069, -0.0857,  0.6059,  0.5504,  0.1384, -0.1944, -0.1435,
        -1.3152, -0.0189, -0.8722,  0.0127, -0.1156, -0.8222,  0.2357,  0.8901,
         0.0814, -0.1193, -0.8998, -0.1225,  0.0921, -1.6622, -0.4210, -1.0981,
         0.0945, -0.4719, -0.0783,  1.3385,  0.1201, -0.2587, -0.1767, -0.0117,
         0.1794,  0.0277,  0.8330, -0.1397, -0.5831,  0.0099, -0.4957, -0.0567,
         1.0765, -0.7582, -0.6049,  0.57

RuntimeError: The size of tensor a (16) must match the size of tensor b (256) at non-singleton dimension 0

In [79]:
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.ones(()))

def train_epoch(matrix_lr, vector_lr, max_grad=1.0):
    losses = []
    count=0
    for sentence in tqdm(clean_dataset.shuffled):
        count+=1
        if count == 100:
            return losses
        # creating optimizer here because .zero_grad() takes forever for all matrices
        # to use anything stateful (e.g. AdaGrad), put it outside the loop and pass all matrices as parameters
        optimizer = torch.optim.SGD([
            dict(params=[start_vector, end_vector], lr=vector_lr),
            dict(params=[vectors[word] for word in set(sentence.words)], lr=matrix_lr),
        ])
        optimizer.zero_grad()
        true_logit = sentence_logit_word2vec(sentence)
        fake_logit = sentence_logit_word2vec(fake_sentence())
        loss = criterion(torch.stack([true_logit, fake_logit]), torch.tensor([1.0, 0.0]))
        loss.backward()
        torch.nn.utils.clip_grad_value_(
            (param for group in optimizer.param_groups for param in group["params"]),
            clip_value=max_grad,
        )
        losses.append(loss.item())
        
        optimizer.step()
    return losses

losses = []
matrix_lr = 1e-2
vector_lr = 1e-3
for epoch in trange(4):
    losses += train_epoch(matrix_lr=matrix_lr, vector_lr=vector_lr)
    matrix_lr *= 0.5
    vector_lr *= 0.5

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))





RuntimeError: The size of tensor a (16) must match the size of tensor b (256) at non-singleton dimension 0

In [42]:
go.Figure(
    layout=dict(
        title="Training progress",
        xaxis_title="Number of processed sentences",
        yaxis_title="Loss (median over 65536 steps)",
    ),
    data=[
        go.Scatter(
            x=np.arange(len(losses)) * 1,
            y=pd.Series(losses[::1]).rolling(window=1, center=True).median(),
        )
    ]
)

# Results

In [43]:
start_vector, end_vector

(tensor([ 0.0147,  0.0147, -0.0068,  0.0276, -0.0010, -0.0074,  0.0023,  0.0130,
         -0.0075, -0.0162,  0.0008,  0.0138,  0.0014,  0.0031,  0.0134,  0.0035],
        requires_grad=True),
 tensor([-0.3794, -0.0037,  0.3270, -0.1668, -0.8573,  0.1360,  0.1568,  0.5461,
          1.1971, -0.7538,  0.5835, -0.8607, -0.7014, -1.5124, -0.9240, -0.5461],
        requires_grad=True))

In [44]:
matrices["the"]

tensor([[-1.7153e+00, -6.8888e-01,  1.7436e+00, -5.0571e-01,  1.0888e+00,
         -1.6746e+00, -7.8656e-01,  2.8983e+00, -2.5237e+00, -3.0493e-01,
         -5.6227e-01, -1.7170e+00,  1.7610e+00, -1.1671e+00, -1.9369e+00,
         -1.0157e+00],
        [-6.4064e-01, -1.9195e+00, -2.4567e-01,  2.6313e-01, -9.6254e-01,
         -3.2435e+00, -1.8565e+00, -2.4893e-01,  3.0438e+00,  1.7564e+00,
          8.1284e-01, -1.8518e+00, -1.7809e+00, -7.8531e-01,  2.1500e-01,
          1.3918e-02],
        [-4.9701e-01,  1.4184e+00, -7.0343e-01, -1.6496e+00,  2.4489e+00,
          2.2640e+00, -1.2121e+00,  5.6874e-01, -1.2630e-01,  6.2753e-01,
         -1.0241e+00,  1.6231e+00, -2.5012e-02,  5.6296e-01, -1.4863e+00,
         -6.1284e-01],
        [ 2.5341e+00,  1.1828e+00,  1.0467e+00,  9.5365e-01, -6.6109e-01,
         -4.3831e-01, -2.8384e-01, -1.6502e+00, -2.6878e+00, -8.7045e-01,
         -1.6569e+00,  9.6382e-02,  1.9240e+00,  1.5607e+00,  5.6114e-01,
          6.5428e-01],
        [-6.3383e-02

In [45]:
real_logits = [sentence_logit(sentence).item() for sentence in clean_dataset.sentences[:4096]]
fake_logits = [sentence_logit(fake_sentence()).item() for _ in range(4096)]
auc = sklearn.metrics.roc_auc_score(
    [1] * len(real_logits) + [0] * len(fake_logits),
    real_logits + fake_logits,
)

In [46]:
fpr, tpr, thresholds = sklearn.metrics.roc_curve(
    [1] * len(real_logits) + [0] * len(fake_logits),
    real_logits + fake_logits,
)
go.Figure(
    layout=dict(
        title=f"ROC curve, AUC={auc:.2f}",
        xaxis_title="False positive rate",
        yaxis_title="True positive rate",
    ),
    data=[
        go.Scatter(name="word2mat", x=fpr, y=tpr),
        go.Scatter(name="random classifier", x=[0, 1], y=[0, 1])
    ],
)