In [2]:
# Autoimport wherept.py:
%load_ext autoreload
%autoreload 1
%aimport wherept

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [136]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
import pandas as pd
from torchview import draw_graph

torch.manual_seed(42)

<torch._C.Generator at 0x1061d0970>

In [5]:
# Authenticate this kernel session with Weights & Biases.
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mlage[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [6]:
# Start a W&B run in the "woher" project; the returned `run` handle is what
# the dataset artifact is fetched through.
run = wandb.init(project="woher", job_type="transformer-playground")

In [7]:
# Resolve the cleaned-cities artifact, take its "clean" entry and turn it
# into a pandas DataFrame.
cities_artifact = run.use_artifact("woher/cleaned-cities:latest")
dataset = cities_artifact.get("clean")
df_raw = dataset.get_dataframe()
df_raw.head(5)

[34m[1mwandb[0m:   1 of 1 files downloaded.  


Unnamed: 0,name,asciiname,latitude,longitude,country_code
0,Soldeu,Soldeu,42.57688,1.66769,AD
1,El Tarter,El Tarter,42.57952,1.65362,AD
2,Sant Julià de Lòria,Sant Julia de Loria,42.46372,1.49129,AD
3,Pas de la Casa,Pas de la Casa,42.54277,1.73361,AD
4,Ordino,Ordino,42.55623,1.53319,AD


In [8]:
# Dataset parameters
# NOTE(review): the marker characters below are assumed not to occur inside
# the raw city names -- TODO confirm against the data.
TARGET_COL = "asciiname"  # column whose strings are wrapped, padded and tokenized
START_CHAR = "<"  # prepended to every name
END_CHAR = ">"  # appended to every name
PADDING_CHAR = "#"  # right-pads every name to the length of the longest one

In [9]:
# Work on a copy so that re-running this cell always starts from the raw frame.
df = df_raw.copy()

# Wrap each name in the start/end markers and record its un-padded length.
wrapped_names = START_CHAR + df[TARGET_COL] + END_CHAR
df[TARGET_COL] = wrapped_names
df["target_len"] = wrapped_names.map(len)

# Right-pad every name to the length of the longest wrapped name.
max_len = max(len(city) for city in df[TARGET_COL])
df[TARGET_COL] = df[TARGET_COL].str.pad(max_len, side="right", fillchar=PADDING_CHAR)

# The vocabulary is the sorted set of characters seen in the padded names.
chars = sorted(set("".join(df[TARGET_COL])))
vocab_len = len(chars)
print("Vocabulary length:", vocab_len)
print("Vocabulary:", "".join(chars))

df.head(5)

Vocabulary length: 61
Vocabulary:  #'-.1<>ABCDEFGHIJKLMNOPQRSTUVWXYZ`abcdefghijklmnopqrstuvwxyz


Unnamed: 0,name,asciiname,latitude,longitude,country_code,target_len
0,Soldeu,<Soldeu>######################################...,42.57688,1.66769,AD,8
1,El Tarter,<El Tarter>###################################...,42.57952,1.65362,AD,11
2,Sant Julià de Lòria,<Sant Julia de Loria>#########################...,42.46372,1.49129,AD,21
3,Pas de la Casa,<Pas de la Casa>##############################...,42.54277,1.73361,AD,16
4,Ordino,<Ordino>######################################...,42.55623,1.53319,AD,8


# Tokenize

In [10]:
# Lookup tables between characters and their position in the sorted vocabulary.
idx_to_char = dict(enumerate(chars))
char_to_idx = {char: idx for idx, char in idx_to_char.items()}


def encode(x):
    """Return the list of vocabulary indices for the characters of `x`."""
    return [char_to_idx[char] for char in x]


def decode(x):
    """Return the string spelled by the vocabulary indices in `x`."""
    return "".join(idx_to_char[idx] for idx in x)


# Round-trip the first padded name as a quick sanity check.
test_sample = df[TARGET_COL].values[0]
test_encoded = encode(test_sample)
print("Encoded:", test_encoded)
print("Decoded:", decode(test_encoded))

# Token ids of the three marker characters.
START_TOKEN, END_TOKEN, PADDING_TOKEN = (
    encode(marker)[0] for marker in (START_CHAR, END_CHAR, PADDING_CHAR)
)

Encoded: [6, 26, 49, 46, 38, 39, 55, 7, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Decoded: <Soldeu>##################################################


In [11]:
# Keep the token-id list of every padded name next to the string itself.
df["tokenized"] = df[TARGET_COL].map(encode)
df.head(2)

Unnamed: 0,name,asciiname,latitude,longitude,country_code,target_len,tokenized
0,Soldeu,<Soldeu>######################################...,42.57688,1.66769,AD,8,"[6, 26, 49, 46, 38, 39, 55, 7, 1, 1, 1, 1, 1, ..."
1,El Tarter,<El Tarter>###################################...,42.57952,1.65362,AD,11,"[6, 12, 46, 0, 27, 35, 52, 54, 39, 52, 7, 1, 1..."


In [12]:
# Seeded 90% sample for training; every row whose index label was not drawn
# goes to validation.
train_df = df.sample(frac=0.9, random_state=42)
in_train = df.index.isin(train_df.index)
val_df = df.loc[~in_train]
print("Train size:", len(train_df))

Train size: 173012


In [145]:
# Shape of the validation rows whose country_code is "DE".
val_df.loc[val_df["country_code"].eq("DE")].shape

(1121, 7)

In [115]:
def get_batch(split, batch_size, block_size):
    """Sample a random batch of next-token training pairs.

    For each of `batch_size` randomly chosen rows of the selected split, a
    window of `block_size` tokens is cut from the row's "tokenized" list; the
    target is the same window shifted one position to the right.

    Args:
        split: "train" to sample from `train_df`, "val" to sample from `val_df`.
        batch_size: number of windows to draw.
        block_size: number of tokens per window.

    Returns:
        Tuple `(x, y)` of stacked tensors, each with `batch_size` rows of
        `block_size` tokens, where `y[i, j]` is the token following `x[i, j]`.

    Raises:
        ValueError: if `split` is unknown, or if a sampled row is too short to
            provide `block_size + 1` tokens.
    """
    if split == "train":
        frame = train_df
    elif split == "val":
        frame = val_df
    else:
        # Previously an unknown split fell through to an UnboundLocalError.
        raise ValueError(f"Unknown split {split!r}; expected 'train' or 'val'.")

    x = []
    y = []
    sample_idx = torch.randint(0, len(frame), (batch_size,))
    for sidx in sample_idx:
        # Fetch the row once instead of re-indexing the frame for every field.
        row = frame.iloc[int(sidx)]
        tokens = row["tokenized"]
        target_len = int(row["target_len"])

        # Exclusive upper bound for the window start so that the shifted
        # target window (start + 1 ... start + block_size) still fits. This is
        # derived from the actual sequence length, not from the vocabulary
        # size, which is unrelated to how long the sequences are.
        max_idx = len(tokens) - block_size
        # Also keep the start inside the un-padded part of the name so a
        # window never begins in pure padding.
        high = min(target_len - 1, max_idx)
        if high < 1:
            raise ValueError(
                f"block_size={block_size} is too large for a sequence of "
                f"{len(tokens)} tokens (target_len={target_len})."
            )
        start = int(torch.randint(0, high, (1,)))

        x.append(torch.tensor(tokens[start:start + block_size]))
        y.append(torch.tensor(tokens[start + 1:start + block_size + 1]))

    return torch.stack(x), torch.stack(y)

# Smoke test: draw a tiny batch and show the inputs followed by their
# one-step-shifted targets.
xb, yb = get_batch(split="train", batch_size=4, block_size=8)

display(xb, yb)

tensor([[45, 53, 39, 59, 39, 56, 53, 45],
        [42, 39, 43, 38,  7,  1,  1,  1],
        [41,  7,  1,  1,  1,  1,  1,  1],
        [35, 48, 41,  7,  1,  1,  1,  1]])

tensor([[53, 39, 59, 39, 56, 53, 45, 35],
        [39, 43, 38,  7,  1,  1,  1,  1],
        [ 7,  1,  1,  1,  1,  1,  1,  1],
        [48, 41,  7,  1,  1,  1,  1,  1]])

In [120]:
BATCH_SIZE = 12

# Hyperparameters of the WherePT model; the vocabulary size comes from the data.
mconf = wherept.WherePTConfig(
    dropout=0.1,
    block_size=32,
    n_layer=4,
    n_head=2,
    n_embed=64,
    vocab_len=vocab_len,
)
model = wherept.WherePT(mconf)

# Print number of parameters:
n_params = sum(param.numel() for param in model.parameters())
print(f"N Params: {n_params:,}")

# Example batch; it is only consumed by the (currently disabled) torchview graph.
xb, _ = get_batch("train", BATCH_SIZE, mconf.block_size)
#model_graph = draw_graph(model, input_data=xb)
#model_graph.visual_graph
model

N Params: 209,213


WherePT(
  (token_embedding): Embedding(61, 64)
  (position_embedding): Embedding(32, 64)
  (blocks): Sequential(
    (0): TransformerBlock(
      (sa): MultiHeadAttention(
        (heads): ModuleList(
          (0-1): 2 x CausalSelfAttentionHead(
            (query): Linear(in_features=64, out_features=32, bias=False)
            (key): Linear(in_features=64, out_features=32, bias=False)
            (value): Linear(in_features=64, out_features=32, bias=False)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (proj): Linear(in_features=64, out_features=64, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (ffwd): FeedForward(
        (net): Sequential(
          (0): Linear(in_features=64, out_features=256, bias=True)
          (1): ReLU()
          (2): Linear(in_features=256, out_features=64, bias=True)
          (3): Dropout(p=0.1, inplace=False)
        )
      )
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=Tru

In [121]:
# Number of random batches averaged per split when estimating the loss.
EVAL_ITERS = 10

@torch.no_grad()
def estimate_loss(model):
    """Estimate the mean loss of `model` on the train and validation splits.

    The model is switched to eval mode, `EVAL_ITERS` random batches are drawn
    per split with `get_batch`, and their losses are averaged. Gradients are
    disabled for the whole call.

    Args:
        model: module called as `model(X, Y)` that returns `(logits, loss)`.

    Returns:
        Dict with keys "train" and "val", each a 0-dim tensor holding the mean
        loss over the sampled batches.
    """
    # Remember the caller's mode: the function used to force train mode on
    # exit, silently re-enabling dropout for a model that was in eval mode.
    was_training = model.training
    out = {}
    model.eval()
    try:
        for split in ['train', 'val']:
            losses = torch.zeros(EVAL_ITERS)
            for k in range(EVAL_ITERS):
                X, Y = get_batch(split, BATCH_SIZE, mconf.block_size)
                _, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # Restore the original mode even if a batch or forward pass fails.
        model.train(was_training)
    return out

In [126]:
# Adam over all model parameters, using the learning rate set here.
lr = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [129]:
MAX_ITERS = 100
# Number of optimizer steps between loss reports. Kept separate from
# EVAL_ITERS, which is the number of batches averaged per loss estimate.
EVAL_INTERVAL = 10

# The loop variable is called `step` rather than `iter` so the builtin
# `iter()` is not overwritten in the notebook's global namespace.
for step in range(MAX_ITERS):
    # Report at every interval and once more on the final step.
    if step % EVAL_INTERVAL == 0 or step == MAX_ITERS - 1:
        losses = estimate_loss(model)
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
    optimizer.zero_grad()
    xb, yb = get_batch("train", BATCH_SIZE, mconf.block_size)
    logits, loss = model(xb, yb)
    loss.backward()
    optimizer.step()


step 0: train loss 3.0236, val loss 3.0024
step 10: train loss 2.9965, val loss 3.0580
step 20: train loss 3.0240, val loss 3.0356
step 30: train loss 3.0596, val loss 3.0636
step 40: train loss 3.0637, val loss 3.0410
step 50: train loss 3.0254, val loss 3.0628
step 60: train loss 3.0448, val loss 3.0730
step 70: train loss 3.0916, val loss 3.0655
step 80: train loss 3.0558, val loss 3.0189
step 90: train loss 2.9539, val loss 3.0972
step 99: train loss 3.0736, val loss 3.0096


In [133]:
# Prompt consisting of just the start token, shaped (1, 1).
idx = torch.tensor([START_TOKEN]).unsqueeze(0)
#idx = torch.tensor(encode("<Stock")).unsqueeze(0)

# Sample in eval mode and without autograd: estimate_loss() and the training
# loop leave the model in train mode, so dropout would otherwise stay active
# while generating.
# NOTE(review): model.generate is defined in wherept.py and is not visible
# here; if it already switches modes itself this is merely redundant.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        output = model.generate(idx, max_len)[0].tolist()
finally:
    # Hand the model back in the mode it was found in.
    model.train(was_training)
print(decode(output))

<plan>


In [139]:
# Number of independent samples to draw from the model.
N_EXAMPLES = 20

# Sample in eval mode and without autograd so dropout does not perturb the
# generated names; the previous mode is restored afterwards.
# NOTE(review): model.generate lives in wherept.py and is not visible here;
# if it already handles eval mode itself this is merely redundant.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        generated_names = [
            decode(model.generate(torch.tensor([START_TOKEN]).unsqueeze(0), 32)[0].tolist())
            for _ in range(N_EXAMPLES)
        ]
finally:
    model.train(was_training)

# `examples` is bound once, to the DataFrame, instead of first holding a list.
examples = pd.DataFrame(generated_names, columns=["generated"])
examples

Unnamed: 0,generated
0,<pa>
1,<Hee>
2,<Vmr lniiCifnesa>
3,<VmGunaaspmulviesKrrn>
4,<u-n>
5,<>
6,<Peo>
7,<Hgaaos>
8,<BsslnswqxdM>
9,<Teaeie>


In [516]:
# Finish the currently active W&B run.
wandb.run.finish()

