In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from datasets import GeonamesDataset
import polars as pl

In [3]:
geonames = GeonamesDataset("./data/cities500.txt.gz", max_len=14)

In [4]:
geonames.df.sample(10)

sequence,feature code,country code,population
str,str,str,i64
"""Peaugres""","""PPL""","""FR""",1785
"""Glenshaw""","""PPL""","""US""",8981
"""Harahan""","""PPL""","""US""",9350
"""Oberursel""","""PPL""","""DE""",46678
"""Kannod""","""PPL""","""IN""",15870
"""Issiglio""","""PPLA3""","""IT""",337
"""Laucha""","""PPL""","""DE""",2517
"""Xiadong""","""PPLA4""","""CN""",0
"""Bangunmulyo""","""PPLA4""","""ID""",0
"""Begbessou""","""PPL""","""CI""",6463


In [5]:
df = geonames.df
alphabet = "".join(
    set("".join(df.get_column("sequence").str.split("").explode().to_list()))
)

In [6]:
from utils import Tokenizer

t = Tokenizer(
    alphabet=alphabet,
    max_len=16,
)

In [7]:
X = t.encode(df)

In [8]:
import torch

total_samples = X.size(0)

# Define the proportions for train, test, and validation sets
train_ratio = 0.8
test_ratio = 0.1
val_ratio = 0.1

# Calculate the number of samples for each set
num_train = int(total_samples * train_ratio)
num_test = int(total_samples * test_ratio)
num_val = total_samples - num_train - num_test

# Generate random indices
indices = torch.randperm(total_samples)

# Split the indices into train, test, and validation sets
train_indices = indices[:num_train]
test_indices = indices[num_train : num_train + num_test]
val_indices = indices[num_train + num_test :]

# Create the train, test, and validation sets
X_train = X[train_indices]
X_test = X[test_indices]
X_val = X[val_indices]

print("Train set size:", X_train.shape)
print("Test set size:", X_test.shape)
print("Validation set size:", X_val.shape)

Train set size: torch.Size([101200, 16])
Test set size: torch.Size([12650, 16])
Validation set size: torch.Size([12651, 16])


In [9]:
from torch.utils.data import TensorDataset, DataLoader

train = TensorDataset(X_train)
test = TensorDataset(X_test)
val = TensorDataset(X_val)

In [10]:
for i, batch_ in enumerate(DataLoader(train, batch_size=64)):
    break

# Model initialization

In [11]:
max_T = 1000
d_embed = 128
token_dim = n_tokens = len(t.stoi)
d_kq = 5
d_postembed = 32

In [12]:
from diffusion import ForwardDiffusionProcess, Scheduler

from model import TokenDenoiser

scheduler = Scheduler(T=max_T)
diffusion = ForwardDiffusionProcess(scheduler=scheduler)
model = TokenDenoiser(
    max_T=max_T,
    max_L=token_dim,
    d_embed=d_embed,
    d_kq=d_kq,
    d_hidden=64,
    n_blocks=1,
)

In [13]:
sum((p.numel() for p in model.parameters()))

13175

# Model training

In [14]:
import torch.nn.functional as F
from torch.optim import AdamW
from tqdm.notebook import tqdm
import torch.optim.lr_scheduler as lr_scheduler

batch_size = 1024
num_epochs = 10
optimizer = AdamW(model.parameters(), lr=0.1)

scheduler = lr_scheduler.ExponentialLR(
    optimizer,
    gamma=0.99,
)

for epoch in tqdm(range(num_epochs)):
    # get into training mode
    model.train()
    losses = []
    for batch_num in range(X_train.shape[0] // batch_size):
        optimizer.zero_grad()

        # initialize proper `x`
        ix = torch.randint(0, X_train.shape[0], size=(batch_size,))
        batch = X_train[ix].long()

        # encode into probas
        x_0 = F.one_hot(batch, num_classes=n_tokens).float()
        time = torch.randint(low=0, high=max_T, size=(x_0.shape[0],))

        x_t = diffusion.sample_T(x_0, time)
        x_0_pred = model(batch, time)

        loss = F.cross_entropy(x_0_pred, x_0)
        losses.append(loss.item())
        loss.backward()

    optimizer.step()
    scheduler.step()

    # validation mode
    model.eval()
    with torch.no_grad():
        time = torch.randint(0, max_T, size=(X_test.shape[0],))
        batch_test = X_test.long()

        x_0_test = F.one_hot(batch_test, num_classes=n_tokens).float()
        x_t_test = diffusion.sample_T(x_0_test, time)
        x_0_test_pred = model(batch_test, time)
        test_loss = F.cross_entropy(x_0_test_pred, x_0_test)

        print(
            f"{epoch=}, train_loss={torch.tensor(losses).mean().item():2.4f}, val_loss={test_loss.item():2.4f}"
        )

  0%|          | 0/10 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
ls

In [20]:
for param in model.parameters():
    print(param.grad.sum())

Parameter containing:
tensor([[ 5.0730e-01, -1.7665e-01,  1.0395e+00,  7.0814e-01,  6.5620e-01],
        [ 7.7636e-01,  7.6715e-01,  1.2322e+00, -4.3079e-01,  1.5114e+00],
        [ 2.6183e-01,  5.2399e-01, -3.5735e-03,  1.8499e-01,  6.7456e-01],
        [ 6.7750e-01,  1.0152e-01,  9.9663e-01,  4.1857e-01, -5.7225e-02],
        [-6.1855e-02,  1.4560e-01,  4.3867e-01,  2.9079e-01,  3.2104e-01],
        [ 7.7865e-01,  7.9962e-01,  1.3476e+00, -3.8759e-01,  1.2185e+00],
        [ 4.5905e-01, -1.0935e-01,  9.3626e-02,  4.8394e-01, -1.7012e-01],
        [ 1.4235e+00,  3.3231e-01,  1.3097e+00, -6.5856e-04,  1.5425e+00],
        [ 5.9915e-01,  1.0745e+00,  7.3774e-01, -7.7221e-02,  1.1816e+00],
        [ 1.3653e+00,  7.1614e-01,  4.7300e-01,  2.3979e-01,  5.5055e-01],
        [-5.4933e-02, -4.6900e-01, -3.9010e-01,  5.4395e-01,  1.1521e-01],
        [ 3.6646e-01,  6.2784e-01,  5.7957e-01, -1.3852e-01,  7.4064e-01],
        [ 8.8656e-01, -2.6590e-02,  1.2515e-01,  9.4088e-01,  1.7703e-01],
   