In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from datasets import GeonamesDataset
import polars as pl

In [3]:
geonames = GeonamesDataset("./data/cities500.txt.gz", max_len=14)

In [4]:
geonames.df.sample(10)

sequence,feature code,country code,population
str,str,str,i64
"""Kouwovogo""","""PPL""","""CI""",1336
"""Pomerol""","""PPL""","""FR""",928
"""Necaxa""","""PPL""","""MX""",8375
"""Scinawa""","""PPL""","""PL""",5863
"""Ospedalicchio""","""PPL""","""IT""",1405
"""Crocevie""","""PPL""","""IT""",787
"""Smimou""","""PPLA3""","""MA""",3505
"""Rozerieulles""","""PPL""","""FR""",1360
"""Vallarga""","""PPL""","""IT""",732
"""Colipapa""","""PPL""","""PH""",7680


In [5]:
df = geonames.df
alphabet = "".join(
    set("".join(df.get_column("sequence").str.split("").explode().to_list()))
)

In [6]:
from utils import Tokenizer

t = Tokenizer(
    alphabet=alphabet,
    max_len=16,
)

In [7]:
X = t.encode(df)

In [8]:
import torch

total_samples = X.size(0)

# Define the proportions for train, test, and validation sets
train_ratio = 0.8
test_ratio = 0.1
val_ratio = 0.1

# Calculate the number of samples for each set
num_train = int(total_samples * train_ratio)
num_test = int(total_samples * test_ratio)
num_val = total_samples - num_train - num_test

# Generate random indices
indices = torch.randperm(total_samples)

# Split the indices into train, test, and validation sets
train_indices = indices[:num_train]
test_indices = indices[num_train : num_train + num_test]
val_indices = indices[num_train + num_test :]

# Create the train, test, and validation sets
X_train = X[train_indices]
X_test = X[test_indices]
X_val = X[val_indices]

print("Train set size:", X_train.shape)
print("Test set size:", X_test.shape)
print("Validation set size:", X_val.shape)

Train set size: torch.Size([101200, 16])
Test set size: torch.Size([12650, 16])
Validation set size: torch.Size([12651, 16])


In [9]:
from torch.utils.data import TensorDataset, DataLoader

train = TensorDataset(X_train)
test = TensorDataset(X_test)
val = TensorDataset(X_val)

In [10]:
for i, batch_ in enumerate(DataLoader(train, batch_size=64)):
    break

# Model initialization

Let's train a model that would predict a single missing letter in the word.

In [137]:
from torchtyping import TensorType

T = TensorType


class RandomMasker:
    def __init__(self, tokenizer: Tokenizer):
        self.tokenizer = tokenizer
        self.start_token: int = tokenizer.stoi[tokenizer.start_token]
        self.end_token: int = tokenizer.stoi[tokenizer.end_token]
        self.pad_token: int = tokenizer.stoi[tokenizer.pad_token]
        self.mask_token: int = tokenizer.stoi[tokenizer.mask_token]

    def add_mask(self, x: T["b", "max_L", torch.long], p: float = 0.1):  # noqa: F821
        where = (x != self.start_token) & (x != self.end_token) & (x != self.pad_token)
        mask = (torch.randint_like(where.long(), low=0, high=100) < p * 100) & where
        x[mask] = self.mask_token
        return x

In [179]:
from torch import nn
from model import PositionalEncoding


class Model(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        vocab_size: int,
        n_tokens: int,
        hidden_dim: int,
        dropout: float = 0.4,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.vocab_size = vocab_size
        self.n_tokens = n_tokens
        self.hidden_dim = hidden_dim

        self.pe = PositionalEncoding(d_embed=embed_dim, max_L=n_tokens)
        self.emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)
        self.layers = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.LayerNorm((n_tokens, hidden_dim)),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.LayerNorm((n_tokens, hidden_dim)),
            nn.Linear(hidden_dim, vocab_size),
            nn.ReLU(),
        )

    def forward(self, x: T["b", "max_L"]):  # noqa: F821
        xe = self.emb(x)
        pe = self.pe(x)
        return self.layers(xe + pe)

In [180]:
embed_dim = 10
vocab_size = 3
n_tokens = 20
batch_size = 5
hidden_dim = 4

pe = PositionalEncoding(d_embed=embed_dim, max_L=n_tokens)
emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)

x = torch.randint(low=0, high=vocab_size - 1, size=(batch_size, n_tokens))

xe = emb(x)
xpe = pe(x)

In [181]:
l1 = nn.Linear(embed_dim, hidden_dim)
l2 = nn.Linear(hidden_dim, n_tokens)
l2(l1(xe + xpe)).shape

torch.Size([5, 20, 20])

In [182]:
model = Model(
    embed_dim=embed_dim,
    vocab_size=vocab_size,
    n_tokens=n_tokens,
    hidden_dim=hidden_dim,
)

In [183]:
x.shape

torch.Size([5, 20])

In [184]:
model(x).shape

torch.Size([5, 20, 3])

# Training loop

In [185]:
embed_dim = 8
vocab_size = len(t.stoi)
n_tokens = t.max_len
hidden_dim = 10

model = Model(
    embed_dim=embed_dim,
    vocab_size=vocab_size,
    n_tokens=n_tokens,
    hidden_dim=hidden_dim,
)

In [188]:
x.shape

torch.Size([128, 16])

In [189]:
model(x).shape

torch.Size([128, 16, 56])

In [192]:
import torch.nn.functional as F
from torch.optim import AdamW
from tqdm.notebook import tqdm
import torch.optim.lr_scheduler as lr_scheduler

batch_size = 128
num_epochs = 2000
optimizer = AdamW(model.parameters(), lr=0.1)

scheduler = lr_scheduler.ExponentialLR(
    optimizer,
    gamma=0.99,
)
masker = RandomMasker(t)
p_masker = 0.1

for epoch in tqdm(range(num_epochs)):
    # get into training mode
    model.train()
    losses = []
    for batch_num in range(X_train.shape[0] // batch_size):
        optimizer.zero_grad()

        # initialize proper `x`
        ix = torch.randint(0, X_train.shape[0], size=(batch_size,))
        y = X_train[ix].long()  # true values
        x = masker.add_mask(y, p=p_masker)  # noisy values

        # predict demasked tokens
        y_pred = model(x).swapaxes(-1, -2)
        loss = F.cross_entropy(y_pred, y)

        # do backprop
        losses.append(loss.item())
        loss.backward()
    optimizer.step()
    scheduler.step()

    # validation mode
    model.eval()
    with torch.no_grad():
        y_test = X_test.long()
        x_test = masker.add_mask(y_test, p=p_masker)

        y_pred_test = model(x_test).swapaxes(-1, -2)
        test_loss = F.cross_entropy(y_pred_test, y_test)

        print(
            f"{epoch=}, train_loss={torch.tensor(losses).mean().item():2.4f}, val_loss={test_loss.item():2.4f}"
        )

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/790 [00:00<?, ?it/s]

epoch=0, train_loss=3.4501, val_loss=2.4458


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=1, train_loss=2.8620, val_loss=2.1035


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=2, train_loss=2.4599, val_loss=1.8639


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=3, train_loss=2.1880, val_loss=1.6424


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=4, train_loss=1.9821, val_loss=1.5020


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=5, train_loss=1.8487, val_loss=1.4124


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=6, train_loss=1.7677, val_loss=1.3332


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=7, train_loss=1.7202, val_loss=1.2671


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=8, train_loss=1.6713, val_loss=1.2073


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=9, train_loss=1.6162, val_loss=1.1531


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=10, train_loss=1.5624, val_loss=1.1177


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=11, train_loss=1.5330, val_loss=1.0837


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=12, train_loss=1.5071, val_loss=1.0415


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=13, train_loss=1.4838, val_loss=1.0060


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=14, train_loss=1.4679, val_loss=0.9668


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=15, train_loss=1.4407, val_loss=0.9410


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=16, train_loss=1.4181, val_loss=0.9167


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=17, train_loss=1.3989, val_loss=0.8965


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=18, train_loss=1.3772, val_loss=0.8725


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=19, train_loss=1.3613, val_loss=0.8537


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=20, train_loss=1.3484, val_loss=0.8329


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=21, train_loss=1.3379, val_loss=0.8118


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=22, train_loss=1.3278, val_loss=0.8012


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=23, train_loss=1.3163, val_loss=0.7897


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=24, train_loss=1.3006, val_loss=0.7804


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=25, train_loss=1.2895, val_loss=0.7681


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=26, train_loss=1.2773, val_loss=0.7510


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=27, train_loss=1.2597, val_loss=0.7376


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=28, train_loss=1.2491, val_loss=0.7224


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=29, train_loss=1.2365, val_loss=0.7084


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=30, train_loss=1.2219, val_loss=0.6940


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=31, train_loss=1.2136, val_loss=0.6772


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=32, train_loss=1.2014, val_loss=0.6623


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=33, train_loss=1.1911, val_loss=0.6488


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=34, train_loss=1.1841, val_loss=0.6360


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=35, train_loss=1.1759, val_loss=0.6259


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=36, train_loss=1.1676, val_loss=0.6162


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=37, train_loss=1.1650, val_loss=0.6167


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=38, train_loss=1.1598, val_loss=0.6194


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=39, train_loss=1.1571, val_loss=0.6133


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=40, train_loss=1.1528, val_loss=0.6024


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=41, train_loss=1.1503, val_loss=0.5900


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=42, train_loss=1.1469, val_loss=0.5832


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=43, train_loss=1.1439, val_loss=0.5737


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=44, train_loss=1.1398, val_loss=0.5637


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=45, train_loss=1.1321, val_loss=0.5531


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=46, train_loss=1.1252, val_loss=0.5488


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=47, train_loss=1.1201, val_loss=0.5529


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=48, train_loss=1.1142, val_loss=0.5569


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=49, train_loss=1.1147, val_loss=0.5560


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=50, train_loss=1.1075, val_loss=0.5525


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=51, train_loss=1.1024, val_loss=0.5457


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=52, train_loss=1.1021, val_loss=0.5315


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=53, train_loss=1.0958, val_loss=0.5218


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=54, train_loss=1.0955, val_loss=0.5192


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=55, train_loss=1.0980, val_loss=0.5179


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=56, train_loss=1.0948, val_loss=0.5138


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=57, train_loss=1.0890, val_loss=0.5148


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=58, train_loss=1.0868, val_loss=0.5181


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=59, train_loss=1.0824, val_loss=0.5156


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=60, train_loss=1.0806, val_loss=0.5133


  0%|          | 0/790 [00:00<?, ?it/s]

epoch=61, train_loss=1.0759, val_loss=0.5124


  0%|          | 0/790 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [197]:
model.eval()

Model(
  (pe): PositionalEncoding()
  (emb): Embedding(56, 8)
  (layers): Sequential(
    (0): Linear(in_features=8, out_features=10, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.4, inplace=False)
    (3): LayerNorm((16, 10), eps=1e-05, elementwise_affine=True)
    (4): Linear(in_features=10, out_features=10, bias=True)
    (5): ReLU()
    (6): Dropout(p=0.4, inplace=False)
    (7): LayerNorm((16, 10), eps=1e-05, elementwise_affine=True)
    (8): Linear(in_features=10, out_features=56, bias=True)
    (9): ReLU()
  )
)

In [200]:
t.decode(X_val)[:10]

['Rentung',
 'Coronella',
 'Joplin',
 'Pluvigner',
 'Griselles',
 'Crispiano',
 'Thueringerberg',
 'Oratino',
 'Carosino',
 'Zhennan']

In [202]:
masked_X_val = masker.add_mask(X_val)
t.decode(masked_X_val)[:10]

['##ntung',
 'Co#one#l#',
 'Jop#in',
 '#luvig#er',
 'Gr#s#lles',
 'Cr#s#i##o',
 'Thuering#r#erg',
 'Oratino',
 'Carosin#',
 '#he#nan']

In [212]:
t.decode(model(masked_X_val.long()).argmax(axis=-1))[:10]

['##naung',
 '#o#one#l#',
 'Lop#in',
 '#luvig#er',
 'Gu#s#lles',
 '#u#s#i##o',
 '#uuerin>#r#er',
 '#uaaino',
 '#arosin#',
 '#ue#nan']