In [1]:
import random
from functools import reduce
from collections import Counter
import pickle
import numpy as np
import pandas as pd
import torch
import scipy.stats
import sklearn.metrics
import plotly.graph_objects as go
from tqdm.auto import tqdm, trange
from word2mat import Dataset, Sentence

# Data

In [2]:
# Load the "normal_sentence" column of the "manual" config of the Hugging Face
# "wiki_auto" dataset through word2mat's Dataset helper.
full_dataset = Dataset.from_huggingface("wiki_auto", config="manual", column="normal_sentence")
# Display the size of the loaded dataset (presumably the number of sentences).
len(full_dataset)

Reusing dataset wiki_auto (/home/gilles/.cache/huggingface/datasets/wiki_auto/manual/1.0.0/5ffdd9fc62422d29bd02675fb9606f77c1251ee17169ac10b143ce07ef2f4db8)
Widget Javascript not detected.  It may not be installed or enabled properly.





373801

In [3]:
# `Dataset.clean` comes from word2mat; what exactly it filters is not visible
# in this notebook (the recorded output shows fewer items than the full dataset).
clean_dataset = full_dataset.clean
len(clean_dataset)

Widget Javascript not detected.  It may not be installed or enabled properly.





329155

In [4]:
# Sum of len(sentence) over the cleaned dataset -- presumably the total word
# count; confirm what Sentence.__len__ measures.
sum(len(sentence) for sentence in clean_dataset)

8756821

In [5]:
# Vocabulary size: number of entries in the cleaned dataset's `unique_words`.
len(clean_dataset.unique_words)

35259

# Random initialization

In [6]:
# Side length of the square per-word matrices; also the length of the
# start/end vectors created by random_vector().
matrix_size = 16

In [7]:
def random_matrix(size=None):
    """Create a trainable square matrix with signed-square-root normal entries.

    Every entry is ``sign(z) * sqrt(|z|)`` for an independent ``z ~ N(0, 1)``.

    Args:
        size: Side length of the matrix. When omitted, the notebook-level
            ``matrix_size`` is used.

    Returns:
        A float32 leaf tensor of shape ``(size, size)`` with
        ``requires_grad=True``.
    """
    if size is None:
        size = matrix_size
    norm = torch.normal(mean=0.0, std=1.0, size=(size, size))
    values = torch.sign(norm) * torch.sqrt(torch.abs(norm))
    # Mark the freshly computed tensor as a trainable leaf directly. Wrapping it
    # in torch.tensor(...) again would copy it and emit PyTorch's
    # "copy construct from a tensor" UserWarning.
    return values.to(torch.float).requires_grad_(True)
# One independently initialised trainable matrix per vocabulary word.
matrices = dict(
    (vocab_word, random_matrix())
    for vocab_word in tqdm(clean_dataset.unique_words)
)
# Show one example matrix.
matrices["the"]

Widget Javascript not detected.  It may not be installed or enabled properly.


  return torch.tensor(





tensor([[-0.8443,  1.3897, -1.2968,  0.5525,  0.8347, -0.2670,  1.0575, -0.4932,
         -0.6666,  0.5867, -1.3706,  0.7983, -0.9909,  1.0746, -1.2912,  0.8773],
        [ 1.2467,  1.3260,  1.3360,  1.1788, -0.8703,  1.0593,  0.5659,  0.6849,
          0.8580,  0.1997, -1.2664, -1.0616,  0.5424, -1.0557,  0.7035,  1.3481],
        [ 1.0624,  0.7527,  1.0513, -0.2668, -0.6289,  1.1453,  1.2756, -1.0923,
         -0.4579,  1.3530, -1.0308,  0.5768, -1.1326, -0.7334,  0.7872,  1.2861],
        [-1.1818, -0.6453,  0.7572,  0.5981, -1.0858,  0.3644,  0.6498,  0.5183,
         -1.2021,  0.8757,  0.3794,  0.3356, -0.9101,  1.1606,  0.1124,  0.8744],
        [ 0.4490, -1.2373,  0.7737,  0.8930, -0.9832,  0.1209, -0.9895, -1.2431,
          0.6066, -0.9498, -0.3332, -0.3205,  1.0459, -0.9953,  0.7252, -0.9367],
        [ 0.6209,  0.6492,  0.9338, -0.3300,  1.1123,  0.6105, -1.5423, -1.2314,
         -1.3113,  0.7401,  1.1141,  0.6442, -0.9963,  1.1651, -0.8672,  0.5000],
        [ 0.2048, -0.9

In [8]:
def random_vector(size=None):
    """Create a trainable vector with standard-normal entries.

    Args:
        size: Number of entries. When omitted, the notebook-level
            ``matrix_size`` is used so the vector matches the word matrices.

    Returns:
        A float32 leaf tensor of shape ``(size,)`` with ``requires_grad=True``.
    """
    if size is None:
        size = matrix_size
    # scipy returns a float64 numpy array; torch.tensor copies it as float32.
    samples = scipy.stats.norm.rvs(loc=0, scale=1, size=size)
    return torch.tensor(samples, requires_grad=True, dtype=torch.float)

# Trainable vectors; sentence_logit() multiplies a sentence's matrix by
# start_vector on the left and takes the dot product with end_vector to get a
# scalar logit.
start_vector = random_vector()
end_vector = random_vector()
start_vector, end_vector

(tensor([ 1.1623, -1.3310, -0.2566,  0.8144, -0.6448, -0.0772,  0.5764, -0.3048,
         -0.1830,  0.7343, -0.2617,  0.6254, -0.4747, -0.3808, -0.6341,  0.4545],
        requires_grad=True),
 tensor([ 0.2705,  1.9889, -1.3926, -0.9979, -1.0116, -0.8587, -0.0478,  1.5001,
          0.3300,  0.6238, -0.3157, -0.6369, -0.6734,  1.0544, -0.2695,  0.0164],
        requires_grad=True))

# Fake sentence generation

In [9]:
# Occurrence count of every word across the cleaned dataset.
word_counts = Counter(
    token
    for sentence in clean_dataset
    for token in sentence.words
)
# Parallel sequences for weighted sampling: the words, and the running total
# of their counts in the same order.
sampleable_words = list(word_counts)
cum_weights = np.asarray(list(word_counts.values())).cumsum()

def random_word():
    """Draw a single word, weighted by how often it occurs in the dataset."""
    (sampled,) = random.choices(sampleable_words, cum_weights=cum_weights, k=1)
    return sampled

random_word()

'followed'

In [10]:
# Micro-benchmark a single weighted word draw.
%timeit random_word()

6.87 µs ± 240 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [11]:
def fake_sentence(num_fake_words=1):
    """Build a corrupted copy of a randomly chosen real sentence.

    A sentence is picked uniformly from ``clean_dataset.sentences`` and
    ``num_fake_words`` of its positions are overwritten with words drawn by
    ``random_word()``.

    Args:
        num_fake_words: Number of distinct positions to overwrite. Capped at
            the sentence length; non-positive values leave the copy untouched.

    Returns:
        A new ``Sentence`` holding the modified word list.

    Note:
        A drawn replacement can coincide with the word it replaces, so the
        result is not guaranteed to differ from the original sentence.
    """
    words = list(random.choice(clean_dataset.sentences).words)
    # Sample distinct positions so that repeated draws cannot hit the same slot,
    # and so that an empty sentence is a no-op instead of a ValueError from
    # randrange(0).
    num_replacements = max(0, min(num_fake_words, len(words)))
    for position in random.sample(range(len(words)), num_replacements):
        words[position] = random_word()
    return Sentence(words=words)

fake_sentence()

Sentence(words=['ozone', 'may', 'be', 'formed', 'from', 'o2', 'by', 'electrical', 'discharges', 'the', 'by', 'action', 'of', 'high', 'energy', 'electromagnetic', 'radiation', '.'])

In [12]:
# Micro-benchmark the generation of one corrupted sentence.
%timeit fake_sentence()

10.4 µs ± 389 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


# Training

In [13]:
def activationSignedSquare(matrix):
    """Element-wise signed square root: ``sign(x) * sqrt(|x|)``.

    Despite the name, this compresses magnitudes with a square root (it does
    not square them) while keeping each entry's sign.

    Args:
        matrix: Tensor of any shape.

    Returns:
        Tensor of the same shape. Entries that are exactly zero map to zero and
        receive a zero gradient.
    """
    magnitude = torch.abs(matrix)
    # sqrt has an infinite derivative at 0, which autograd combines with the
    # zero from sign(0) into NaN (0 * inf). Substitute 1 for exact zeros before
    # the sqrt; sign(0) == 0 still maps those entries to 0 in the forward pass.
    safe_magnitude = torch.where(magnitude == 0, torch.ones_like(magnitude), magnitude)
    return torch.sign(matrix) * torch.sqrt(safe_magnitude)

def sentence_matrix(sentence):
    """Fold a sentence into a single matrix.

    Starting from the identity, the running product is right-multiplied by each
    word's matrix in order, applying ``activationSignedSquare`` after every
    multiplication.
    """
    running = torch.eye(matrix_size)
    for token in sentence.words:
        running = activationSignedSquare(torch.matmul(running, matrices[token]))
    return running

sentence_matrix(Sentence.from_string("Mary had a little lamb."))

tensor([[-3.7291, -0.6134, -2.6697,  2.5590,  4.5014, -1.9287, -2.0835,  4.2250,
         -5.2134,  4.1342, -0.7590,  2.0100, -1.7895, -1.6452, -3.6626,  1.8128],
        [-2.4878,  3.2861,  1.8272,  4.6291,  4.0239, -3.0391, -2.4214,  3.0768,
         -4.3061,  3.1562,  3.4726,  2.5164, -0.7291,  2.8263, -3.1938,  3.4615],
        [ 3.3846,  1.2356,  1.9137, -2.9243, -2.9425,  0.9384,  2.5315, -3.9340,
          4.3722, -3.1580, -2.9095, -3.7015, -1.7957, -4.0127,  4.0254,  1.5761],
        [ 3.6781, -2.2420,  2.8842, -4.0699, -3.8752,  2.4095,  3.1955, -3.8356,
          4.9473, -3.5168, -3.3843, -0.9242, -1.4787, -2.9206,  3.9762, -2.6514],
        [ 3.1942, -2.2517,  3.8833, -3.5788, -2.7593, -0.5443,  1.4416, -3.5061,
          4.5144, -3.1442, -2.1588,  2.7200, -2.8246, -0.8761,  3.1046,  0.8532],
        [-4.2455, -2.6791, -2.5451, -1.8944,  2.8993, -1.7310, -1.6327,  4.5436,
         -5.4601,  2.3673, -2.6006,  2.1521,  2.5155,  1.4575, -2.9447, -3.3949],
        [-1.8317,  1.8

In [14]:
def sentence_logit(sentence):
    """Reduce a sentence to a scalar logit: start_vector · M(sentence) · end_vector."""
    projected = torch.matmul(start_vector, sentence_matrix(sentence))
    return torch.dot(projected, end_vector)

sentence_logit(Sentence.from_string("Mary had a little lamb."))

tensor(4.7654, grad_fn=<DotBackward>)

In [None]:
# Binary cross-entropy on raw logits; a pos_weight of 1 leaves both classes equally weighted.
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.ones(()))

def train_epoch(matrix_lr, vector_lr, max_grad=1.0):
    """Run one pass over the shuffled dataset, one real/fake sentence pair per step.

    Each step scores the real sentence (target 1) and a freshly generated fake
    sentence (target 0), backpropagates the BCE loss, clamps the gradients and
    applies a plain SGD update.

    Args:
        matrix_lr: Learning rate for the word matrices touched by the step.
        vector_lr: Learning rate for ``start_vector`` and ``end_vector``.
        max_grad: Every gradient element is clamped to ``[-max_grad, max_grad]``.

    Returns:
        List with the loss value of every step, in processing order.
    """
    targets = torch.tensor([1.0, 0.0])  # real sentence -> 1, fake sentence -> 0
    losses = []
    for sentence in tqdm(clean_dataset.shuffled):
        fake = fake_sentence()
        # fake_sentence() corrupts a *different* random sentence, so most of its
        # words do not occur in `sentence`. Their matrices receive gradients from
        # the fake logit and therefore have to be part of this step's parameters;
        # otherwise those gradients are never applied and linger in `.grad`.
        step_words = set(sentence.words) | set(fake.words)
        # creating optimizer here because .zero_grad() takes forever for all matrices
        # to use anything stateful (e.g. AdaGrad), put it outside the loop and pass all matrices as parameters
        optimizer = torch.optim.SGD([
            dict(params=[start_vector, end_vector], lr=vector_lr),
            dict(params=[matrices[word] for word in step_words], lr=matrix_lr),
        ])
        optimizer.zero_grad()
        true_logit = sentence_logit(sentence)
        fake_logit = sentence_logit(fake)
        loss = criterion(torch.stack([true_logit, fake_logit]), targets)
        loss.backward()
        torch.nn.utils.clip_grad_value_(
            [param for group in optimizer.param_groups for param in group["params"]],
            clip_value=max_grad,
        )
        losses.append(loss.item())

        optimizer.step()
    return losses

# Per-step losses accumulated over all epochs.
losses = []
matrix_lr, vector_lr = 1e-2, 1e-3
for epoch in trange(32):
    losses.extend(train_epoch(matrix_lr=matrix_lr, vector_lr=vector_lr))
    # Halve both learning rates after every epoch (the last of the 32 epochs
    # runs at 2**-31 times the initial rates).
    matrix_lr /= 2
    vector_lr /= 2

Widget Javascript not detected.  It may not be installed or enabled properly.


Widget Javascript not detected.  It may not be installed or enabled properly.





Widget Javascript not detected.  It may not be installed or enabled properly.





Widget Javascript not detected.  It may not be installed or enabled properly.





Widget Javascript not detected.  It may not be installed or enabled properly.


In [None]:
# Plot every `loss_stride`-th loss, smoothed with a centred rolling median.
loss_stride = 64
median_window = 1024
smoothed_losses = (
    pd.Series(losses[::loss_stride])
    .rolling(window=median_window, center=True)
    .median()
)
go.Figure(
    layout=dict(
        title="Training progress",
        xaxis_title="Number of processed sentences",
        yaxis_title=f"Loss (median over {loss_stride * median_window} steps)",
    ),
    data=[
        go.Scatter(
            # One x value per plotted point. The previous x array had
            # len(losses) entries against len(losses) / 64 y values, which
            # mismatched the lengths and bloated the figure payload.
            x=np.arange(len(smoothed_losses)) * loss_stride,
            y=smoothed_losses,
        )
    ]
)

# Results

In [None]:
# Inspect the start/end vectors after training.
start_vector, end_vector

In [None]:
# Inspect the matrix of one word after training.
matrices["the"]

In [None]:
# Score 4096 real sentences and 4096 freshly generated fake ones.
# NOTE(review): the real sentences come from clean_dataset, the same data the
# training loop iterates over, so this is not a held-out estimate.
with torch.no_grad():  # evaluation only: skip building autograd graphs
    real_logits = [sentence_logit(sentence).item() for sentence in clean_dataset.sentences[:4096]]
    fake_logits = [sentence_logit(fake_sentence()).item() for _ in range(4096)]
# Area under the ROC curve with real sentences as the positive class.
auc = sklearn.metrics.roc_auc_score(
    [1] * len(real_logits) + [0] * len(fake_logits),
    real_logits + fake_logits,
)

In [None]:
# Labels (1 = real, 0 = fake) and scores in matching order.
roc_labels = [1] * len(real_logits) + [0] * len(fake_logits)
roc_scores = real_logits + fake_logits
fpr, tpr, thresholds = sklearn.metrics.roc_curve(roc_labels, roc_scores)

# ROC curve of the model next to the diagonal of a random classifier.
model_trace = go.Scatter(name="word2mat", x=fpr, y=tpr)
chance_trace = go.Scatter(name="random classifier", x=[0, 1], y=[0, 1])
go.Figure(
    data=[model_trace, chance_trace],
    layout=dict(
        title=f"ROC curve, AUC={auc:.2f}",
        xaxis_title="False positive rate",
        yaxis_title="True positive rate",
    ),
)

# Let's save it!

In [None]:
# Persist the trained parameters as plain nested Python lists.
with open("./trained.pkl", "wb") as file:
    trained_state = {
        "start_vector": start_vector.tolist(),
        "end_vector": end_vector.tolist(),
        "matrices": {word: matrix.tolist() for word, matrix in matrices.items()},
    }
    pickle.dump(trained_state, file)