In [4]:
%load_ext autoreload
%autoreload 2

In [15]:
from datasets import GeonamesDataset
import polars as pl

In [16]:
geonames = GeonamesDataset("./data/cities500.txt.gz", max_len=14)

In [17]:
geonames.df.sample(10)

sequence,feature code,country code,population
str,str,str,i64
"""Petersburg""","""PPLA2""","""US""",2520
"""Breitenhagen""","""PPL""","""DE""",542
"""Brading""","""PPL""","""GB""",1800
"""Casalnoceto""","""PPLA3""","""IT""",777
"""Surbourg""","""PPL""","""FR""",1623
"""Chimoio""","""PPLA""","""MZ""",422046
"""Sobolivka""","""PPL""","""UA""",3752
"""Remenham""","""PPLA3""","""GB""",0
"""Taketa""","""PPLA2""","""JP""",0
"""Piasek""","""PPL""","""PL""",3342


In [18]:
df = geonames.df
alphabet = "".join(
    set("".join(df.get_column("sequence").str.split("").explode().to_list()))
)

In [19]:
from utils import Tokenizer

t = Tokenizer(
    alphabet=alphabet,
    max_len=16,
)

In [20]:
X = t.encode(df)

In [21]:
import torch

total_samples = X.size(0)

# Define the proportions for train, test, and validation sets
train_ratio = 0.8
test_ratio = 0.1
val_ratio = 0.1

# Calculate the number of samples for each set
num_train = int(total_samples * train_ratio)
num_test = int(total_samples * test_ratio)
num_val = total_samples - num_train - num_test

# Generate random indices
indices = torch.randperm(total_samples)

# Split the indices into train, test, and validation sets
train_indices = indices[:num_train]
test_indices = indices[num_train : num_train + num_test]
val_indices = indices[num_train + num_test :]

# Create the train, test, and validation sets
X_train = X[train_indices]
X_test = X[test_indices]
X_val = X[val_indices]

print("Train set size:", X_train.shape)
print("Test set size:", X_test.shape)
print("Validation set size:", X_val.shape)

Train set size: torch.Size([101200, 16])
Test set size: torch.Size([12650, 16])
Validation set size: torch.Size([12651, 16])


In [22]:
from torch.utils.data import TensorDataset, DataLoader

train = TensorDataset(X_train)
test = TensorDataset(X_test)
val = TensorDataset(X_val)

In [23]:
for i, batch_ in enumerate(DataLoader(train, batch_size=64)):
    break

# Model initialization

In [36]:
max_T = 1000
d_embed = 128
token_dim = n_tokens = len(t.stoi)
d_kq = 5
d_postembed = 32

In [53]:
from diffusion import ForwardDiffusionProcess, Scheduler

from model import TokenDenoiser

scheduler = Scheduler(T=max_T)
diffusion = ForwardDiffusionProcess(scheduler=scheduler)
model = TokenDenoiser(
    max_T=max_T,
    d_embed=d_embed,
    n_tokens=token_dim,
    d_postembed=d_postembed,
    d_kq=d_kq,
    d_hidden=64,
    n_blocks=1,
)

In [54]:
sum((p.numel() for p in model.parameters()))

7767

# Model training

In [56]:
import torch.nn.functional as F
from torch.optim import AdamW
from tqdm.notebook import tqdm
import torch.optim.lr_scheduler as lr_scheduler

batch_size = 4096
num_epochs = 2000
optimizer = AdamW(model.parameters(), lr=0.1)

scheduler = lr_scheduler.ExponentialLR(
    optimizer,
    gamma=0.99,
)

for epoch in tqdm(range(num_epochs)):
    # get into training mode
    model.train()
    losses = []
    for batch_num in range(X_train.shape[0] // batch_size):
        optimizer.zero_grad()

        # initialize proper `x`
        ix = torch.randint(0, X_train.shape[0], size=(batch_size,))
        batch = X_train[ix].long()

        # encode into probas
        x_0 = F.one_hot(batch, num_classes=n_tokens).float()
        time = torch.randint(low=0, high=max_T, size=(x_0.shape[0],))

        x_t = diffusion.sample_T(x_0, time)
        x_0_pred = model(batch, time)

        loss = F.cross_entropy(x_0_pred, x_0)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
        scheduler.step()

    # validation mode
    model.eval()
    with torch.no_grad():
        time = torch.randint(0, max_T, size=(X_test.shape[0],))
        x_0_test = X_test.float()
        x_t_test = diffusion.sample_T(x_0_test, time)
        x_0_test_pred = model(x_t_test, time)
        print(f"{x_0_test_pred.shape=}")
        print(f"{X_test.shape=}")
        test_loss = F.cross_entropy(x_0_test_pred, X_test.long())

        print(
            f"{epoch=}, train_loss={torch.tensor(losses).mean().item():2.4f}, val_loss={test_loss.item():2.4f}"
        )

  0%|          | 0/2000 [00:00<?, ?it/s]

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)