In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from datasets import GeonamesDataset
import polars as pl

In [3]:
geonames = GeonamesDataset("./data/cities500.txt.gz", max_len=14)

In [4]:
geonames.df.sample(10)

sequence,feature code,country code,population
str,str,str,i64
"""Mislesevo""","""PPL""","""MK""",1988
"""Paracin""","""PPLA3""","""RS""",6000
"""Mekarsari""","""PPLA4""","""ID""",0
"""Waterville""","""PPL""","""US""",651
"""Kistler""","""PPL""","""US""",528
"""Suzemka""","""PPL""","""RU""",9394
"""Mladenovac""","""PPLA3""","""RS""",0
"""Dalizi""","""PPLA4""","""CN""",0
"""Krasnany""","""PPL""","""SK""",1231
"""Pomona""","""PPL""","""US""",153266


In [5]:
df = geonames.df
alphabet = "".join(
    set("".join(df.get_column("sequence").str.split("").explode().to_list()))
)

In [6]:
from utils import Tokenizer

t = Tokenizer(
    alphabet=alphabet,
    max_len=16,
)

In [7]:
X = t.encode(df)

In [8]:
import torch

total_samples = X.size(0)

# Define the proportions for train, test, and validation sets
train_ratio = 0.8
test_ratio = 0.1
val_ratio = 0.1

# Calculate the number of samples for each set
num_train = int(total_samples * train_ratio)
num_test = int(total_samples * test_ratio)
num_val = total_samples - num_train - num_test

# Generate random indices
indices = torch.randperm(total_samples)

# Split the indices into train, test, and validation sets
train_indices = indices[:num_train]
test_indices = indices[num_train : num_train + num_test]
val_indices = indices[num_train + num_test :]

# Create the train, test, and validation sets
X_train = X[train_indices]
X_test = X[test_indices]
X_val = X[val_indices]

print("Train set size:", X_train.shape)
print("Test set size:", X_test.shape)
print("Validation set size:", X_val.shape)

Train set size: torch.Size([101200, 16])
Test set size: torch.Size([12650, 16])
Validation set size: torch.Size([12651, 16])


In [9]:
from torch.utils.data import TensorDataset, DataLoader

train = TensorDataset(X_train)
test = TensorDataset(X_test)
val = TensorDataset(X_val)

In [10]:
for i, batch in enumerate(DataLoader(train, batch_size=64)):
    break

In [11]:
x = batch[0].float()

In [12]:
from torch import nn

In [13]:
import math


class PositionalEncoding(nn.Module):
    def __init__(
        self,
        d_model: int,
        d_output: int,
        dropout: float = 0.1,
        max_len: int = 5000,
    ):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.d_output = d_output
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = self.pe[t].squeeze(-1)
        return self.dropout(x)

In [14]:
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_normal_(m.weight)


class MLPWithTime(nn.Module):
    def __init__(
        self,
        token_embed: nn.Module,
        time_embed: nn.Module,
        hidden: nn.Module,
        output: nn.Module,
        alphabet_size: int,
        n_tokens: int,
    ):
        super().__init__()
        self.token_embed: nn.Module = token_embed
        self.time_embed: nn.Module = time_embed
        self.hidden: nn.Module = hidden
        self.output: nn.Module = output
        self.alphabet_size = alphabet_size
        self.n_tokens = n_tokens
        self.apply(init_weights)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        te = self.token_embed(x)
        pe = self.time_embed(t)
        he = self.hidden(te + pe)
        output = self.output(he)
        return output.reshape(
            -1,
            self.alphabet_size,
            self.n_tokens,
        )

In [15]:
token_dim = 16
word_embeding_dim = time_embeding_dim = 256
intermediate_dim = 128
hidden_dim = 64
dropout = 0.3
max_T = 1000
alphabet_size = len(t.stoi)

Input = nn.Sequential(
    nn.Linear(token_dim, word_embeding_dim),
    nn.BatchNorm1d(word_embeding_dim),
    nn.ReLU(),
    nn.Dropout(dropout),
)
TimeInput = PositionalEncoding(
    d_model=1,
    d_output=time_embeding_dim,
    max_len=max_T,
)
Hidden = nn.Sequential(
    nn.Linear(word_embeding_dim, intermediate_dim),
    nn.BatchNorm1d(intermediate_dim),
    nn.ReLU(),
    nn.Dropout(dropout),
    # ----
    nn.Linear(intermediate_dim, hidden_dim),
    nn.BatchNorm1d(hidden_dim),
    nn.ReLU(),
    nn.Dropout(dropout),
    # ----
    nn.Linear(hidden_dim, hidden_dim),
    nn.BatchNorm1d(hidden_dim),
    nn.ReLU(),
    nn.Dropout(dropout),
    # ----
    nn.Linear(hidden_dim, hidden_dim),
    nn.BatchNorm1d(hidden_dim),
    nn.ReLU(),
    nn.Dropout(dropout),
    # ----
    nn.Linear(hidden_dim, intermediate_dim),
    nn.BatchNorm1d(intermediate_dim),
    nn.ReLU(),
    nn.Dropout(dropout),
    # ----
    nn.Linear(intermediate_dim, intermediate_dim),
    nn.BatchNorm1d(intermediate_dim),
    nn.ReLU(),
    nn.Dropout(dropout),
    # ----
    nn.Linear(intermediate_dim, intermediate_dim),
    nn.BatchNorm1d(intermediate_dim),
    nn.ReLU(),
    nn.Dropout(dropout),
    # ----
)
Output = nn.Sequential(
    nn.Linear(intermediate_dim, alphabet_size * token_dim),
    nn.BatchNorm1d(alphabet_size * token_dim),
    nn.ReLU(),
    nn.Dropout(dropout),
)

model = MLPWithTime(
    time_embed=TimeInput,
    token_embed=Input,
    hidden=Hidden,
    output=Output,
    alphabet_size=alphabet_size,
    n_tokens=token_dim,
)

In [16]:
time = torch.randint(0, max_T, size=(x.shape[0],))

In [17]:
x.shape

torch.Size([64, 16])

In [18]:
x.shape

torch.Size([64, 16])

In [19]:
with torch.no_grad():
    print(model(x, time).shape)

torch.Size([64, 55, 16])


In [20]:
sum(p.numel() for p in model.parameters())

212368

# Training loop

In [21]:
from diffusion import ForwardDiffusionProcess, Scheduler

In [22]:
scheduler = Scheduler(T=max_T)
diffusion = ForwardDiffusionProcess(scheduler=scheduler)

In [24]:
import torch.nn.functional as F
from torch.optim import AdamW
from tqdm.notebook import tqdm
import torch.optim.lr_scheduler as lr_scheduler

batch_size = 4096
# num_epochs = 2000
# scheduler = lr_scheduler.CosineAnnealingLR(
# optimizer,
# T_max=40,
# )

for lr in torch.logspace(-4, 1, 100):
    optimizer = AdamW(model.parameters(), lr=lr.item())
    # training mode
    model.train()
    losses = []
    for batch_num in range(X_train.shape[0] // batch_size):
        optimizer.zero_grad()

        ix = torch.randint(0, X_train.shape[0], size=(batch_size,))
        batch = X_train[ix].long()
        time = torch.randint(low=0, high=max_T, size=(batch.shape[0],))

        x_0 = batch.float()
        x_t = diffusion.sample_T(x_0, time)
        x_0_pred = model(x_t, time)

        loss = F.cross_entropy(x_0_pred, batch)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()

    # validation mode
    model.eval()
    with torch.no_grad():
        time = torch.randint(0, max_T, size=(X_test.shape[0],))
        x_0_test = X_test.float()
        x_t_test = diffusion.sample_T(x_0_test, time)
        x_0_test_pred = model(x_t_test, time)
        test_loss = F.cross_entropy(x_0_test_pred, X_test.long())

        print(
            f"{lr=}, train_loss={torch.tensor(losses).mean().item():2.4f}, val_loss={test_loss.item():2.4f}"
        )

KeyboardInterrupt: 

In [24]:
import torch.nn.functional as F
from torch.optim import AdamW
from tqdm.notebook import tqdm
import torch.optim.lr_scheduler as lr_scheduler

batch_size = 4096
num_epochs = 2000
optimizer = AdamW(model.parameters(), lr=0.1)

scheduler = lr_scheduler.ExponentialLR(
    optimizer,
    gamma=0.99,
)


for epoch in tqdm(range(num_epochs)):
    # training mode
    model.train()
    losses = []
    for batch_num in range(X_train.shape[0] // batch_size):
        optimizer.zero_grad()

        ix = torch.randint(0, X_train.shape[0], size=(batch_size,))
        batch = X_train[ix].long()
        time = torch.randint(low=0, high=max_T, size=(batch.shape[0],))

        x_0 = batch.float()
        x_t = diffusion.sample_T(x_0, time)
        x_0_pred = model(x_t, time)

        print(f'{x_0_pred.shape=}')
        print(f'{batch.shape=}')
        loss = F.cross_entropy(x_0_pred, batch)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
        scheduler.step()

    # validation mode
    model.eval()
    with torch.no_grad():
        time = torch.randint(0, max_T, size=(X_test.shape[0],))
        x_0_test = X_test.float()
        x_t_test = diffusion.sample_T(x_0_test, time)
        x_0_test_pred = model(x_t_test, time)
        print(f'{x_0_test_pred.shape=}')
        print(f'{X_test.shape=}')
        test_loss = F.cross_entropy(x_0_test_pred, X_test.long())

        print(
            f"{epoch=}, train_loss={torch.tensor(losses).mean().item():2.4f}, val_loss={test_loss.item():2.4f}"
        )

  0%|          | 0/2000 [00:00<?, ?it/s]

x_0_pred.shape=torch.Size([4096, 55, 16])
batch.shape=torch.Size([4096, 16])
x_0_pred.shape=torch.Size([4096, 55, 16])
batch.shape=torch.Size([4096, 16])
x_0_pred.shape=torch.Size([4096, 55, 16])
batch.shape=torch.Size([4096, 16])
x_0_pred.shape=torch.Size([4096, 55, 16])
batch.shape=torch.Size([4096, 16])
x_0_pred.shape=torch.Size([4096, 55, 16])
batch.shape=torch.Size([4096, 16])
x_0_pred.shape=torch.Size([4096, 55, 16])
batch.shape=torch.Size([4096, 16])


KeyboardInterrupt: 

In [255]:
from diffusion import ForwardDiffusionProcess
scheduler = Scheduler(T=max_T)
diffusion = ForwardDiffusionProcess(scheduler=scheduler)

In [261]:
denoised = diffusion.sample(model=model, n=10, max_T=max_T)

In [262]:
t.decode_raw(denoised.argmax(1))

['<Soelhreein><>><',
 '<Saneeeaaea>>...',
 '<Sanaeeaaea>>...',
 '<Woelhreein><>><',
 '<Soelhreein><>><',
 '<Soelhreein><>><',
 '<Soelhreein><>><',
 '<Saneeeaaea>>...',
 '<Saneeeaaea>>...',
 '<Sanaenaae>>....']