In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from datasets import GeonamesDataset
import polars as pl

In [3]:
geonames = GeonamesDataset("./data/cities500.txt.gz", max_len=14)

In [4]:
geonames.df.sample(10)

sequence,feature code,country code,population
str,str,str,i64
"""Yamunanagar""","""PPL""","""IN""",208931
"""Vredefort""","""PPL""","""ZA""",14619
"""Hilbesheim""","""PPL""","""FR""",529
"""Mapastepec""","""PPL""","""MX""",17931
"""Holmsund""","""PPL""","""SE""",5962
"""Strassoldo""","""PPL""","""IT""",707
"""Tumannyy""","""PPL""","""RU""",896
"""Ticaco""","""PPLA3""","""PE""",741
"""Gjergjan""","""PPLA3""","""AL""",0
"""Collesalvetti""","""PPLA3""","""IT""",3530


In [5]:
df = geonames.df
alphabet = "".join(
    set("".join(df.get_column("sequence").str.split("").explode().to_list()))
)

In [6]:
from utils import Tokenizer

t = Tokenizer(
    alphabet=alphabet,
    max_len=16,
)

In [7]:
X = t.encode(df)

In [8]:
import torch

total_samples = X.size(0)

# Define the proportions for train, test, and validation sets
train_ratio = 0.8
test_ratio = 0.1
val_ratio = 0.1

# Calculate the number of samples for each set
num_train = int(total_samples * train_ratio)
num_test = int(total_samples * test_ratio)
num_val = total_samples - num_train - num_test

# Generate random indices
indices = torch.randperm(total_samples)

# Split the indices into train, test, and validation sets
train_indices = indices[:num_train]
test_indices = indices[num_train : num_train + num_test]
val_indices = indices[num_train + num_test :]

# Create the train, test, and validation sets
X_train = X[train_indices]
X_test = X[test_indices]
X_val = X[val_indices]

print("Train set size:", X_train.shape)
print("Test set size:", X_test.shape)
print("Validation set size:", X_val.shape)

Train set size: torch.Size([101200, 16])
Test set size: torch.Size([12650, 16])
Validation set size: torch.Size([12651, 16])


In [9]:
from torch.utils.data import TensorDataset, DataLoader

train = TensorDataset(X_train)
test = TensorDataset(X_test)
val = TensorDataset(X_val)

In [10]:
for i, batch in enumerate(DataLoader(train, batch_size=64)):
    break

In [11]:
x = batch[0].float()

In [12]:
from torch import nn

In [32]:
import math


class PositionalEncoding(nn.Module):
    def __init__(
        self,
        d_model: int,
        d_output: int,
        dropout: float = 0.1,
        max_len: int = 5000,
    ):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.d_output = d_output
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = self.pe[t].squeeze(-1)
        return self.dropout(x)

In [43]:
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_normal_(m.weight)


class MLPWithTime(nn.Module):
    def __init__(
        self,
        token_embed: nn.Module,
        time_embed: nn.Module,
        hidden: nn.Module,
        output: nn.Module,
        alphabet_size: int,
        n_tokens: int,
    ):
        super().__init__()
        self.token_embed: nn.Module = token_embed
        self.time_embed: nn.Module = time_embed
        self.hidden: nn.Module = hidden
        self.output: nn.Module = output
        self.alphabet_size = alphabet_size
        self.n_tokens = n_tokens
        self.apply(init_weights)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        te = self.token_embed(x)
        pe = self.time_embed(t)
        he = self.hidden(te + pe)
        output = self.output(he)
        return output.reshape(
            -1,
            self.alphabet_size,
            self.n_tokens,
        )

In [53]:
token_dim = 16
word_embeding_dim = time_embeding_dim = 128
intermediate_dim = 64
dropout = 0.4
max_T = 100
alphabet_size = len(t.stoi)

Input = nn.Sequential(
    nn.Linear(token_dim, word_embeding_dim),
    nn.BatchNorm1d(word_embeding_dim),
    nn.ReLU(),
    nn.Dropout(dropout),
)
TimeInput = PositionalEncoding(
    d_model=1,
    d_output=time_embeding_dim,
    max_len=max_T,
)
Hidden = nn.Sequential(
    nn.Linear(word_embeding_dim, intermediate_dim),
    nn.BatchNorm1d(intermediate_dim),
    nn.ReLU(),
    nn.Dropout(dropout),
    nn.Linear(intermediate_dim, intermediate_dim),
    nn.BatchNorm1d(intermediate_dim),
    nn.ReLU(),
    nn.Dropout(dropout),
    nn.Linear(intermediate_dim, intermediate_dim),
    nn.BatchNorm1d(intermediate_dim),
    nn.ReLU(),
    nn.Dropout(dropout),
)
Output = nn.Sequential(
    nn.Linear(intermediate_dim, alphabet_size * token_dim),
    nn.BatchNorm1d(alphabet_size * token_dim),
    nn.ReLU(),
    nn.Dropout(dropout),
)

model = MLPWithTime(
    time_embed=TimeInput,
    token_embed=Input,
    hidden=Hidden,
    output=Output,
    alphabet_size=alphabet_size,
    n_tokens=token_dim,
)

In [54]:
time = torch.randint(0, max_T, size=(x.shape[0],))

In [55]:
x.shape

torch.Size([64, 16])

In [56]:
x.shape

torch.Size([64, 16])

In [57]:
with torch.no_grad():
    print(model(x, time).shape)

torch.Size([64, 55, 16])


In [58]:
sum(p.numel() for p in model.parameters())

78352

# Training loop

In [59]:
from diffusion import ForwardDiffusionProcess, Scheduler

In [60]:
scheduler = Scheduler(T=max_T)
diffusion = ForwardDiffusionProcess(scheduler=scheduler)

In [61]:
import torch.nn.functional as F
from torch.optim import AdamW
from tqdm.notebook import tqdm

batch_size = 512
num_epochs = 100
optimizer = AdamW(model.parameters(), lr=1)

for epoch in tqdm(range(num_epochs)):
    losses = []
    for batch_num in range(X_train.shape[0] // batch_size):
        optimizer.zero_grad()

        ix = torch.randint(0, X_train.shape[0], size=(batch_size,))
        batch = X_train[ix].long()
        time = torch.randint(low=0, high=max_T, size=(batch.shape[0],))

        x_0 = batch.float()
        x_t = diffusion.sample_T(x_0, time)
        x_0_pred = model(x_t, time)

        loss = F.cross_entropy(x_0_pred, batch)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()

    print(f"{epoch=}, {torch.tensor(losses).mean().item()=}")

  0%|          | 0/100 [00:00<?, ?it/s]

epoch=0, torch.tensor(losses).mean()=tensor(2.8706)
epoch=1, torch.tensor(losses).mean()=tensor(2.7818)
epoch=2, torch.tensor(losses).mean()=tensor(2.7762)
epoch=3, torch.tensor(losses).mean()=tensor(2.7811)
epoch=4, torch.tensor(losses).mean()=tensor(2.7800)
epoch=5, torch.tensor(losses).mean()=tensor(2.7795)
epoch=6, torch.tensor(losses).mean()=tensor(2.7802)
epoch=7, torch.tensor(losses).mean()=tensor(2.7780)
epoch=8, torch.tensor(losses).mean()=tensor(2.7836)
epoch=9, torch.tensor(losses).mean()=tensor(2.7824)
epoch=10, torch.tensor(losses).mean()=tensor(2.7852)
epoch=11, torch.tensor(losses).mean()=tensor(2.7879)
epoch=12, torch.tensor(losses).mean()=tensor(2.7921)
epoch=13, torch.tensor(losses).mean()=tensor(2.7861)
epoch=14, torch.tensor(losses).mean()=tensor(2.7944)
epoch=15, torch.tensor(losses).mean()=tensor(2.7963)
epoch=16, torch.tensor(losses).mean()=tensor(2.8121)
epoch=17, torch.tensor(losses).mean()=tensor(2.8085)
epoch=18, torch.tensor(losses).mean()=tensor(2.8127)
epo