In [35]:
import random
from functools import reduce
from collections import Counter
import numpy as np
import pandas as pd
import torch
import scipy.stats
import sklearn.metrics
import plotly.graph_objects as go
from tqdm.auto import tqdm, trange
from word2mat import Dataset, Sentence

# Data

In [2]:
full_dataset = Dataset.from_huggingface("wiki_auto", config="manual", column="normal_sentence")
len(full_dataset)

Reusing dataset wiki_auto (/home/malyvsen/.cache/huggingface/datasets/wiki_auto/manual/1.0.0/5ffdd9fc62422d29bd02675fb9606f77c1251ee17169ac10b143ce07ef2f4db8)


HBox(children=(FloatProgress(value=0.0, description='Parsing dataset', max=373801.0, style=ProgressStyle(descr…




373801

In [3]:
clean_dataset = full_dataset.clean
len(clean_dataset)

HBox(children=(FloatProgress(value=0.0, description='Reading word vectors', max=4027169.0, style=ProgressStyle…




329155

In [4]:
sum(len(sentence) for sentence in clean_dataset)

8756821

In [36]:
len(clean_dataset.unique_words)

35259

# Random initialization

In [37]:
def random_matrix(dim):
    return torch.tensor(
        scipy.stats.norm.rvs(loc=0, scale=1 / np.sqrt(dim), size=(dim, dim)),
        requires_grad=True,
        dtype=torch.float,
    )

matrices = {
    word: random_matrix(8)
    for word in tqdm(clean_dataset.unique_words)
}
matrices["the"]

HBox(children=(FloatProgress(value=0.0, max=35259.0), HTML(value='')))




tensor([[-0.1618,  0.2016,  0.1384, -0.6075,  0.1978,  0.1985,  0.0912, -0.6953],
        [-0.0863, -0.0587, -0.1175, -0.3868, -0.0817, -0.0013, -0.1807,  0.0618],
        [ 0.2552,  0.3970, -0.1287, -0.3487, -0.1096, -0.0600, -0.1071, -0.4589],
        [ 0.0534, -0.5916, -0.7014, -0.1056,  0.4103, -0.1309, -0.2137, -0.5010],
        [ 0.1979,  0.4449,  0.3049, -0.1449,  0.4353, -0.2221, -0.1971, -0.0496],
        [ 0.1335,  0.2046,  0.1227,  0.1816,  0.2220,  0.3092, -0.1201,  0.5066],
        [-0.0229, -0.1157,  1.2785,  0.0453, -0.0706,  0.0309,  0.0054,  0.0494],
        [-0.3733, -0.0828, -0.4899, -0.1489, -0.4930,  0.0283,  0.0892,  0.0370]],
       requires_grad=True)

In [38]:
def random_vector(dim):
    return torch.tensor(
        scipy.stats.norm.rvs(loc=0, scale=1, size=dim),
        requires_grad=True,
        dtype=torch.float,
    )

start_vector = random_vector(8)
end_vector = random_vector(8)
start_vector, end_vector

(tensor([ 1.0868, -2.4053, -0.6037, -0.8671,  0.3867,  0.0671,  0.6497,  1.2839],
        requires_grad=True),
 tensor([ 0.1511, -0.2447,  2.1735, -1.5337, -0.3021, -0.4898,  0.2723,  0.4627],
        requires_grad=True))

# Fake sentence generation

In [39]:
word_counts = Counter(word for sentence in clean_dataset for word in sentence.words)
sampleable_words = list(word_counts.keys())
cum_weights = np.cumsum(list(word_counts.values()))

def random_word():
    return random.choices(sampleable_words, cum_weights=cum_weights)[0]

random_word()

'form'

In [40]:
%timeit random_word()

5.25 µs ± 242 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [41]:
def fake_sentence(fakeness=1):
    words = list(random.choice(clean_dataset.sentences).words)
    for fake_idx in range(fakeness):
        words[random.randrange(len(words))] = random_word()
    return Sentence(words=words)

fake_sentence()

Sentence(words=['the', 'same', 'year', ',', 'the', 'agriculture', 'and', 'resource', 'management', 'council', 'today', 'australia', 'and', 'new', 'zealand', 'decided', 'to', 'phase', 'out', 'remaining', 'organochlorine', 'uses', 'by', '30', 'june', '1995', ',', 'with', 'the', 'exception', 'of', 'the', 'northern', 'territory', '.'])

In [42]:
%timeit fake_sentence()

10.1 µs ± 292 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


# Training

In [43]:
def sentence_matrix(sentence):
    return reduce(lambda current, word: torch.matmul(current, matrices[word]), sentence.words, torch.eye(8))

sentence_matrix(Sentence.from_string("Mary had a little lamb."))

tensor([[ 3.8612e-01,  1.5858e-01, -5.0950e-02, -4.2124e-02, -2.5568e-02,
         -3.5756e-02,  2.8091e-01, -1.7783e-01],
        [ 8.0057e-02,  4.1421e-01,  5.9367e-01, -1.7041e-01, -1.5695e-01,
          2.4524e-01, -3.4618e-02, -1.4546e-01],
        [-3.5722e-01,  3.2075e-04,  2.1570e-01,  2.1391e-02,  7.9489e-02,
          3.8032e-01, -1.8550e-01,  1.0437e-01],
        [ 5.6359e-01, -7.0549e-02, -3.1379e-01, -1.0332e-01,  1.6236e-01,
         -1.3729e-01,  3.7248e-01, -3.4028e-01],
        [-5.2485e-01, -5.2873e-01, -4.4958e-01,  2.8041e-01,  1.4301e-01,
         -2.1158e-01, -3.4157e-01,  2.8989e-01],
        [-1.2084e+00,  5.3043e-01,  8.1154e-01,  2.9404e-01, -5.8934e-01,
          3.6867e-01, -4.8949e-01,  1.0754e+00],
        [-5.5928e-01,  1.3962e-02,  1.9044e-01,  1.2788e-01, -1.2094e-01,
          1.2819e-01, -2.9412e-01,  3.9206e-01],
        [-1.0201e+00,  6.7532e-01,  9.5319e-01,  3.3759e-01, -8.3309e-01,
         -1.1638e-02, -6.6021e-01,  8.7087e-01]], grad_fn=<MmBack

In [44]:
def sentence_logit(sentence):
    return torch.dot(torch.matmul(start_vector, sentence_matrix(sentence)), end_vector)

sentence_logit(Sentence.from_string("Mary had a little lamb."))

tensor(-0.9289, grad_fn=<DotBackward>)

In [45]:
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.ones(()))

def train_epoch(matrix_lr, vector_lr, max_grad=1.0):
    losses = []
    for sentence in tqdm(clean_dataset.shuffled):
        optimizer = torch.optim.SGD([
            dict(params=[start_vector, end_vector], lr=vector_lr),
            dict(params=[matrices[word] for word in set(sentence.words)], lr=matrix_lr),
        ]) # creating optimizer here because .zero_grad() takes forever for all matrices
        optimizer.zero_grad()
        true_logit = sentence_logit(sentence)
        fake_logit = sentence_logit(fake_sentence())
        loss = criterion(torch.stack([true_logit, fake_logit]), torch.tensor([1.0, 0.0]))
        loss.backward()
        torch.nn.utils.clip_grad_value_(
            (param for group in optimizer.param_groups for param in group["params"]),
            clip_value=max_grad,
        )
        losses.append(loss.item())
        
        optimizer.step()
    return losses

losses = []
matrix_lr = 1e-2
vector_lr = 1e-3
for epoch in trange(32):
    losses += train_epoch(matrix_lr=matrix_lr, vector_lr=vector_lr)
    matrix_lr *= 0.9
    vector_lr *= 0.9

HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))





In [46]:
go.Figure(
    layout=dict(
        title="Training progress",
        xaxis_title="Number of processed sentences",
        yaxis_title="Smoothed loss",
    ),
    data=[
        go.Scatter(
            y=pd.Series(losses[::64]).rolling(window=64, center=True).mean(),
        )
    ]
)

# Results

In [47]:
start_vector, end_vector

(tensor([-0.2164, -0.5290,  0.2426,  0.2603, -0.0861, -0.1565, -0.2338, -0.3255],
        requires_grad=True),
 tensor([-2.9123e-03, -2.6610e-03,  2.7467e-01,  1.6595e-04, -1.4466e-03,
         -1.1963e-03, -4.0333e-03, -3.0048e-04], requires_grad=True))

In [48]:
matrices["the"]

tensor([[ 0.0830,  0.0390, -0.0058, -0.0036, -0.0099,  0.2107, -0.0657, -0.1922],
        [-0.4172, -0.0075, -0.0163, -0.2322,  0.0580, -0.3811,  0.1178,  0.1084],
        [-0.0407,  0.1079, -0.0275, -0.0789,  0.0381, -0.1304,  0.0412,  0.0616],
        [-0.1041, -0.1156, -0.1534, -0.0531,  0.0165, -0.0543,  0.0371, -0.0428],
        [ 0.2948, -0.0520,  0.1097,  0.0038, -0.0388,  0.1542, -0.0640, -0.0360],
        [ 0.1076,  0.0247, -0.0084,  0.0100,  0.0463,  0.0978, -0.0265,  0.0219],
        [ 0.0568,  0.0036,  0.1476, -0.0057, -0.0324, -0.0098,  0.0199, -0.0825],
        [ 0.0193,  0.0200, -0.0982, -0.0409, -0.0022,  0.0340,  0.0169,  0.0135]],
       requires_grad=True)

In [49]:
go.Figure(
    layout=dict(
        title="Distribution of sentence scores",
        xaxis_title="Probability that a sentence is real, according to the model",
        yaxis_title="Frequency",
    ),
    data=[
        go.Histogram(
            name=f"{sentence_type.title()} sentences",
            histnorm="probability density",
            x=[torch.sigmoid(sentence_logit(sentence)).item() for sentence in sentences],
        )
        for sentence_type, sentences in dict(
            real=clean_dataset.sentences[:4096],
            fake=[fake_sentence() for _ in range(4096)],
        ).items()
    ]
)

In [50]:
np.mean([sentence_logit(sentence).item() for sentence in tqdm(clean_dataset)])

HBox(children=(FloatProgress(value=0.0, max=329155.0), HTML(value='')))




1.6076454245431582

In [51]:
np.mean([sentence_logit(fake_sentence()).item() for _ in range(8192)])

8.438882042648062

In [52]:
fpr, tpr, thresholds = sklearn.metrics.roc_curve(
    [1] * 4096 + [0] * 4096,
    [sentence_logit(sentence).item() for sentence in clean_dataset.sentences[:4096]]
    + [sentence_logit(fake_sentence()).item() for _ in range(4096)]
)
go.Figure(
    layout=dict(
        title="ROC curve",
        xaxis_title="False positive rate",
        yaxis_title="True positive rate",
    ),
    data=[
        go.Scatter(name="word2mat", x=fpr, y=tpr),
        go.Scatter(name="random classifier", x=[0, 1], y=[0, 1])
    ],
)