In [130]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [131]:
from datasets import GeonamesDataset
import polars as pl

In [132]:
geonames = GeonamesDataset("./data/cities500.txt.gz", max_len=14)

In [133]:
geonames.df.sample(10)

sequence,feature code,country code,population
str,str,str,i64
"""Beaupreau""","""PPL""","""FR""",6932
"""Vinkuran""","""PPL""","""HR""",652
"""Bevenais""","""PPL""","""FR""",798
"""Kiskoros""","""PPLA2""","""HU""",15091
"""Chiapa""","""PPL""","""MX""",808
"""Ogunja""","""PPL""","""KE""",7060
"""Minuwangoda""","""PPL""","""LK""",7772
"""Palaran""","""PPLA3""","""ID""",0
"""Makalle""","""PPLA2""","""AR""",4994
"""Horine""","""PPL""","""US""",821


In [134]:
df = geonames.df
alphabet = "".join(
    set("".join(df.get_column("sequence").str.split("").explode().to_list()))
)

In [135]:
from utils import Tokenizer

tokenizer = Tokenizer(
    alphabet=alphabet,
    max_len=16,
)

In [136]:
X = tokenizer.encode(df)

In [137]:
import torch

total_samples = X.size(0)

# Define the proportions for train, test, and validation sets
train_ratio = 0.8
test_ratio = 0.1
val_ratio = 0.1

# Calculate the number of samples for each set
num_train = int(total_samples * train_ratio)
num_test = int(total_samples * test_ratio)
num_val = total_samples - num_train - num_test

# Generate random indices
indices = torch.randperm(total_samples)

# Split the indices into train, test, and validation sets
train_indices = indices[:num_train]
test_indices = indices[num_train : num_train + num_test]
val_indices = indices[num_train + num_test :]

# Create the train, test, and validation sets
X_train = X[train_indices]
X_test = X[test_indices]
X_val = X[val_indices]

print("Train set size:", X_train.shape)
print("Test set size:", X_test.shape)
print("Validation set size:", X_val.shape)

Train set size: torch.Size([101200, 16])
Test set size: torch.Size([12650, 16])
Validation set size: torch.Size([12651, 16])


In [138]:
from torch.utils.data import TensorDataset, DataLoader

train = TensorDataset(X_train)
test = TensorDataset(X_test)
val = TensorDataset(X_val)

In [139]:
for i, batch_ in enumerate(DataLoader(train, batch_size=64)):
    break

# Discrete diffusion model

After tokenization of our text data, i.e. adding start and stop tokens and adding padding, our input looks like this:

Input: `T['batch_size', 'len_sentence', torch.long]`

After that, we want to build two things:

1. Diffusion process `DP`, that would have a following API:
 - `DP.sample_T(x_0: T['b', 'l', torch.long], t: T['b', torch.long]) -> T['b', 'l', torch.long]` -- a closed-form version of sampling timepoints for a given input examples. This can be done either in vanilla way, by doing small noising steps many times, or in a single step with some reparametrization.
 - `DP.sample(model: torch.Module, n_samples: int) -> T['n_samples', 'l', torch.long]` -- a way of sampling batch of tokens starting from a "ultimate noise" prior.

2. A model `Denoiser(nn.Module)`, with a following API:
 - `Denoiser.forward(x: T['b', 'l', torch.long], t: T['b']) -> T['b', 'l', 'v']` -- forward-pass that converts list of tokens into probas for denoised tokens. Will be compared to `T['b', 'l', torch.long]` with `F.cross_entropy`


In [190]:
batch_size = b = 16
d_embed = d = 128
len_sentence = l = 7
vocab_size = v = 17
max_T = T_ = 100

x = torch.randint(low=0, high=vocab_size, size=(b, l))
x.shape

torch.Size([16, 7])

In [191]:
t = torch.randint(low=0, high=max_T, size=(b,))
t.shape

torch.Size([16])

In [192]:
from model import Denoiser, ConvText, MLPText

denoiser = Denoiser(
    vocab_size=v,
    len_sentence=l,
    d_embed=d_embed,
    num_timestamps=T_,
    model=MLPText(d_embed=d_embed),
)
sum((p.numel() for p in denoiser.parameters()))

11712

In [193]:
denoiser(x, t).shape

torch.Size([16, 7, 128])

# Test autoregressive training

In [196]:
vocab_size = v = len(tokenizer.stoi)
len_sentence = tokenizer.max_len
d_embed = 64
num_timestamps = 1

model = Denoiser(
    vocab_size=v,
    len_sentence=l,
    d_embed=d_embed,
    num_timestamps=T_,
    model=MLPText(d_embed=d_embed),
)

In [197]:
import torch.nn.functional as F
from torch.optim import AdamW
from tqdm.notebook import tqdm
import torch.optim.lr_scheduler as lr_scheduler

batch_size = 128
num_epochs = 100
optimizer = AdamW(model.parameters(), lr=0.2)

scheduler = lr_scheduler.ExponentialLR(
    optimizer,
    gamma=0.99,
)

for epoch in tqdm(range(num_epochs)):
    # get into training mode
    model.train()
    losses = []
    for batch_num in range(X_train.shape[0] // batch_size):
        optimizer.zero_grad()

        # initialize proper `x`
        t = torch.zeros(size=(batch_size,)).long()
        ix = torch.randint(0, X_train.shape[0], size=(batch_size,))
        batch = X_train[ix].long()
        x = batch[:, :-1]
        y = batch[:, 1:]

        # predict demasked tokens
        y_pred = model(x, t).swapaxes(-1, -2)
        loss = F.cross_entropy(y_pred, y)

        # do backprop
        losses.append(loss.item())
        loss.backward()
    optimizer.step()
    # scheduler.step()

    # validation mode
    model.eval()
    with torch.no_grad():
        t_test = torch.zeros(size=(X_test.shape[0],)).long()
        x_test = X_test.long()[:, :-1]
        y_test = X_test.long()[:, 1:]

        y_pred_test = model(x_test, t_test).swapaxes(-1, -2)
        test_loss = F.cross_entropy(y_pred_test, y_test)

        print(
            f"{epoch=}, train_loss={torch.tensor(losses).mean().item():2.4f}, val_loss={test_loss.item():2.4f}"
        )

  0%|          | 0/100 [00:00<?, ?it/s]

epoch=0, train_loss=4.2221, val_loss=3.1574
epoch=1, train_loss=3.2611, val_loss=3.3983
epoch=2, train_loss=3.3182, val_loss=3.3558
epoch=3, train_loss=3.3524, val_loss=3.1249
epoch=4, train_loss=3.1472, val_loss=2.8873
epoch=5, train_loss=2.9178, val_loss=2.9197
epoch=6, train_loss=2.9320, val_loss=2.8801
epoch=7, train_loss=2.8941, val_loss=2.8410


KeyboardInterrupt: 

In [166]:
tokenizer.decode(X_val[:10])

['Asendorf',
 'Cioroboreni',
 'Welzheim',
 'Blessegue',
 'Roselle',
 'Lexington',
 'Tanjungbumi',
 'Wulsbuettel',
 'Waterford',
 'Berdavan']

In [174]:
tokenizer.decode(model(X_train[:10].long(), t=torch.zeros(10).long()).argmax(-1))

['aa>>',
 'aa>an>',
 'aa>a>>',
 'aa>an>',
 'aa>a>>',
 'aa>>>',
 'aa>>',
 'aa>a>>',
 'aa>>',
 'aa>>']

# Model initialization

Let's train a model that would predict a single missing letter in the word.

In [11]:
from torchtyping import TensorType

T = TensorType


class RandomMasker:
    def __init__(self, tokenizer: Tokenizer):
        self.tokenizer = tokenizer
        self.start_token: int = tokenizer.stoi[tokenizer.start_token]
        self.end_token: int = tokenizer.stoi[tokenizer.end_token]
        self.pad_token: int = tokenizer.stoi[tokenizer.pad_token]
        self.mask_token: int = tokenizer.stoi[tokenizer.mask_token]

    def add_mask(self, x: T["b", "max_L", torch.long], p: float = 0.1):  # noqa: F821
        where = (x != self.start_token) & (x != self.end_token) & (x != self.pad_token)
        mask = (torch.randint_like(where.long(), low=0, high=100) < p * 100) & where
        x[mask] = self.mask_token
        return x

In [12]:
from torch import nn
from model import PositionalEncoding


class Model(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        vocab_size: int,
        n_tokens: int,
        hidden_dim: int,
        dropout: float = 0.4,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.vocab_size = vocab_size
        self.n_tokens = n_tokens
        self.hidden_dim = hidden_dim

        self.pe = PositionalEncoding(d_embed=embed_dim, max_L=n_tokens)
        self.emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)
        self.layers = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.LayerNorm((n_tokens, hidden_dim)),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.LayerNorm((n_tokens, hidden_dim)),
            nn.Linear(hidden_dim, vocab_size),
            nn.ReLU(),
        )

    def forward(self, x: T["b", "max_L"]):  # noqa: F821
        xe = self.emb(x)
        pe = self.pe(x)
        return self.layers(xe + pe)

# Training loop

In [17]:
embed_dim = 8
vocab_size = len(t.stoi)
n_tokens = t.max_len
hidden_dim = 10

model = Model(
    embed_dim=embed_dim,
    vocab_size=vocab_size,
    n_tokens=n_tokens,
    hidden_dim=hidden_dim,
)

In [18]:
import torch.nn.functional as F
from torch.optim import AdamW
from tqdm.notebook import tqdm
import torch.optim.lr_scheduler as lr_scheduler

batch_size = 128
num_epochs = 2000
optimizer = AdamW(model.parameters(), lr=0.2)

scheduler = lr_scheduler.ExponentialLR(
    optimizer,
    gamma=0.99,
)
masker = RandomMasker(t)
p_masker = 0.1

for epoch in tqdm(range(num_epochs)):
    # get into training mode
    model.train()
    losses = []
    for batch_num in range(X_train.shape[0] // batch_size):
        optimizer.zero_grad()

        # initialize proper `x`
        ix = torch.randint(0, X_train.shape[0], size=(batch_size,))
        y = X_train[ix].long()  # true values
        x = masker.add_mask(y, p=p_masker)  # noisy values

        # predict demasked tokens
        y_pred = model(x).swapaxes(-1, -2)
        loss = F.cross_entropy(y_pred, y)

        # do backprop
        losses.append(loss.item())
        loss.backward()
    optimizer.step()
    scheduler.step()

    # validation mode
    model.eval()
    with torch.no_grad():
        y_test = X_test.long()
        x_test = masker.add_mask(y_test, p=p_masker)

        y_pred_test = model(x_test).swapaxes(-1, -2)
        test_loss = F.cross_entropy(y_pred_test, y_test)

        print(
            f"{epoch=}, train_loss={torch.tensor(losses).mean().item():2.4f}, val_loss={test_loss.item():2.4f}"
        )

  0%|          | 0/2000 [00:00<?, ?it/s]

epoch=0, train_loss=4.0691, val_loss=2.7928
epoch=1, train_loss=3.1622, val_loss=2.1509
epoch=2, train_loss=2.4797, val_loss=1.9851
epoch=3, train_loss=2.1803, val_loss=1.8067
epoch=4, train_loss=2.0158, val_loss=1.6567
epoch=5, train_loss=1.9222, val_loss=1.5306
epoch=6, train_loss=1.8537, val_loss=1.3797
epoch=7, train_loss=1.7476, val_loss=1.2903
epoch=8, train_loss=1.7125, val_loss=1.2374
epoch=9, train_loss=1.6953, val_loss=1.1706
epoch=10, train_loss=1.6467, val_loss=1.1181
epoch=11, train_loss=1.5962, val_loss=1.0918
epoch=12, train_loss=1.5610, val_loss=1.0572
epoch=13, train_loss=1.5405, val_loss=1.0340
epoch=14, train_loss=1.5315, val_loss=1.0068
epoch=15, train_loss=1.4983, val_loss=0.9708
epoch=16, train_loss=1.4776, val_loss=0.9305
epoch=17, train_loss=1.4544, val_loss=0.9159
epoch=18, train_loss=1.4350, val_loss=0.8845
epoch=19, train_loss=1.4147, val_loss=0.8544
epoch=20, train_loss=1.3970, val_loss=0.8396
epoch=21, train_loss=1.3862, val_loss=0.8187
epoch=22, train_loss

KeyboardInterrupt: 

In [19]:
model.eval()

Model(
  (pe): PositionalEncoding()
  (emb): Embedding(56, 8)
  (layers): Sequential(
    (0): Linear(in_features=8, out_features=10, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.4, inplace=False)
    (3): LayerNorm((16, 10), eps=1e-05, elementwise_affine=True)
    (4): Linear(in_features=10, out_features=10, bias=True)
    (5): ReLU()
    (6): Dropout(p=0.4, inplace=False)
    (7): LayerNorm((16, 10), eps=1e-05, elementwise_affine=True)
    (8): Linear(in_features=10, out_features=56, bias=True)
    (9): ReLU()
  )
)

In [None]:
t.decode(X_val[:10])

['Seguie',
 'Edgewood',
 'Helvecia',
 'Karangsadang',
 'Martinsheim',
 'Goulds',
 'Ballabio',
 'Xiaozhoushan',
 'Zloczew',
 'Pilchaca']

In [22]:
masked_X_val = masker.add_mask(X_val, p=0.05)
t.decode(masked_X_val[:10])

['Seguie',
 'Edgewood',
 'Helvecia',
 'Karangsadang',
 'Mart#nsheim',
 'Goulds',
 '#allab#o',
 'Xiaozhous#an',
 'Zl#czew',
 'Pilchaca']

In [25]:
t.decode(model(masked_X_val.long()).argmax(axis=-1))[:10]

['Sogiie',
 'Sdgesool',
 'Coliogia',
 'Sarang#aang',
 'Lart#n##o',
 'Coilds',
 'Callai#o',
 'T#aoz#ou>an',
 'So#gzos',
 'T#lgiaga']