References:
* fastai [nb 27](https://github.com/fastai/course22p2/blob/master/nbs/27_attention.ipynb)

Cross attention:
- https://sebastianraschka.com/blog/2023/self-attention-from-scratch.html
- https://arxiv.org/abs/2112.10752 -> inserting class / text embedding into K and V of attention

ResBlock probably needs a cross-attention call that takes the label embedding as keys and values (K, V) and the processed image as queries (Q), and returns a tensor with the same shape as the processed image.

paper: U-Net: Convolutional Networks for Biomedical Image Segmentation

unet data: https://forum.image.sc/t/isbi-2012-site-down/57867

minbpe
* https://www.youtube.com/watch?v=zduSFxRajkE

In [None]:
# Enable IPython's autoreload extension in mode 2 so edited modules are
# re-imported before each cell runs (no kernel restart needed).
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import random_neural_net_models.learner as rnnm_learner
import random_neural_net_models.text as rnnm_text
import random_neural_net_models.tokenization as rnnm_tok
import random_neural_net_models.transformer as rnnm_trans
import random_neural_net_models.utils as utils

In [None]:
# Fix the seed (42) before anything else runs; make_deterministic presumably
# seeds every RNG in use — confirm in random_neural_net_models.utils.
utils.make_deterministic(42)

In [None]:
# Let the project helper choose the compute device; it is passed to the
# Learner and used for generation further down. The bare name displays it.
device = utils.get_device()
device

Tom Lehrer's songs: https://tomlehrersongs.com/

In [None]:
# Directory with the song text files (relative to the working directory).
path = Path("../data/tom-lehrer")

In [None]:
# Look up the files matching "*.txt" under `path` via the project helper;
# displayed as a quick sanity check.
files = rnnm_text.find_files(path, "*.txt")
files

In [None]:
# Build the single text body the tokenizer is fitted on; "\n" is presumably
# the separator placed between files — confirm in rnnm_text.concat_files.
body_for_tokenizer = rnnm_text.concat_files(files, "\n")
body_for_tokenizer

In [None]:
# Fit the regex tokenizer on the concatenated text, requesting a vocabulary
# size of 200 and the GPT-4 split pattern shipped with the project.
vocab_size = 200
tokenizer_fit_options = {
    "vocab_size": vocab_size,
    "pattern": rnnm_tok.GPT4_SPLIT_PATTERN,
}

tokenizer = rnnm_tok.TokenizerRegex()
tokenizer.fit(body_for_tokenizer, **tokenizer_fit_options)

In [None]:
# Special tokens and the ids they are registered under, listed as
# (token, id) pairs and turned into the mapping the tokenizer expects.
special_token_pairs = (
    ("<|endoftext|>", 100257),
    ("<|fim_prefix|>", 100258),
    ("<|fim_middle|>", 100259),
    ("<|fim_suffix|>", 100260),
    ("<|endofprompt|>", 100276),
)
special_token2id_map = dict(special_token_pairs)
tokenizer.register_special_tokens(special_token2id_map)

In [None]:
# Training dataset over the same "*.txt" files, using the fitted tokenizer,
# blocks of `block_size` and "<|endoftext|>" as the end-of-text token.
block_size = 128
dataset_options = {
    "path": path,
    "suffix": "*.txt",
    "tokenizer": tokenizer,
    "block_size": block_size,
    "end_of_text_token": "<|endoftext|>",
}
ds_train = rnnm_text.TextDataset(**dataset_options)

In [None]:
# Peek at the first dataset item to check what a single sample looks like.
ds_train[0]

In [None]:
# NOTE: disabled sketch of sampling with replacement. It is not used by the
# DataLoader in this notebook and refers to `self.train_dataset`, which does
# not exist at notebook scope, so it cannot be enabled as-is.
# from torch.utils.data import Dataset, RandomSampler
# RandomSampler(
#                 self.train_dataset,
#                 replacement=True,
#                 num_samples=int(1e10),
#                 generator=torch.manual_seed(3407),
#             )

In [None]:
# Batches of 10 samples, assembled by the project's collate function.
# NOTE(review): no shuffle/sampler is configured, so the loader walks the
# dataset in order (DataLoader's default).
bs_train = 10
dl_train = DataLoader(
    dataset=ds_train,
    collate_fn=rnnm_text.collate_text_dataset_to_block,
    batch_size=bs_train,
)

In [None]:
# Pull one batch to inspect what the collate function produces.
next(iter(dl_train))

In [None]:
# Hyperparameters of the language model.
num_blocks, num_heads = 2, 4
emb_dim, latent_dim = 10, 40
n_tokens = block_size  # reuse the dataset block size as the model's n_tokens

# The vocabulary size is read from the dataset rather than the tokenizer.
model = rnnm_trans.LanguageModelWithTensordict(
    num_blocks=num_blocks,
    num_heads=num_heads,
    emb_dim=emb_dim,
    latent_dim=latent_dim,
    n_tokens=n_tokens,
    vocab_size=ds_train.vocab_size,
)

# Alternative model kept for reference only; enc_emb_dim and enc_n_tokens are
# not defined in the cells shown, so it does not run as written.
# model = rnnm_trans.EncoderWithTensordict(
#     num_blocks=num_blocks,
#     enc_emb_dim=enc_emb_dim,
#     enc_n_tokens=enc_n_tokens,
#     latent_dim=latent_dim,
#     num_heads=num_heads,
#     causal=True,
#     vocab_size=len(tokenizer.vocab)
# )

In [None]:
# Training ingredients handed to the Learner in the next cell.
save_dir = Path("./models")

# Loss and the callback that tracks it; `loss_callback` keeps its own name
# because it is plotted after training.
loss = rnnm_trans.CrossEntropyLoss()
loss_callback = rnnm_learner.TrainLossCallback()
callbacks = [loss_callback]

# Adam over all model parameters, starting from lr=0.1.
learning_rate = 0.1
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
# Bundle model, optimizer and loss into the project's Learner, together with
# the callbacks, checkpoint directory and device chosen above.
learner_options = {
    "callbacks": callbacks,
    "save_dir": save_dir,
    "device": device,
}
learner = rnnm_learner.Learner(model, optimizer, loss, **learner_options)

In [None]:
# Learning-rate range test over two epochs of the training loader.
# NOTE(review): the positional arguments are presumably (min lr, max lr,
# number of steps) — confirm against LRFinderCallback's signature.
lr_find_callback = rnnm_learner.LRFinderCallback(1e-5, 100, 100)

learner.find_learning_rate(
    dl_train, n_epochs=2, lr_find_callback=lr_find_callback
)

In [None]:
# Plot the LR-finder results with a logarithmic y-scale.
lr_find_callback.plot(yscale="log")

In [None]:
# One-cycle learning-rate schedule for the actual training run, stepped via
# the project's per-batch scheduler callback.
n_epochs = 5
learning_rate = 3e-2  # peak learning rate of the cycle (max_lr)

steps_per_epoch = len(dl_train)
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=learning_rate,
    steps_per_epoch=steps_per_epoch,
    epochs=n_epochs,
)

scheduler_callback = rnnm_learner.EveryBatchSchedulerCallback(scheduler)
learner.update_callback(scheduler_callback)

In [None]:
# Train for n_epochs on the training loader only; the trailing comment shows
# the validation-loader argument that is currently left out.
learner.fit(dl_train, n_epochs=n_epochs)  # , dataloader_valid=dl_val

In [None]:
# Plot what the TrainLossCallback collected during fit.
loss_callback.plot()

In [None]:
# Take one batch from the training loader as the prompt for generation;
# displayed for inspection.
inp = next(iter(dl_train))
inp

In [None]:
# Generate 20 new tokens for every sequence in the prompt batch.
# Inference needs neither training-time layer behaviour (e.g. dropout) nor
# gradient tracking, so switch the model to eval mode and disable autograd.
# Both steps are harmless if generate() or the Learner already take care of
# them.
model.eval()
with torch.no_grad():
    out_ids_dense = model.generate(inp.to(device), max_new_tokens=20)
out_ids_dense

In [None]:
# Move the generated ids to the CPU and turn them back into strings with the
# dataset's helper.
ds_train.dense_ids_to_strings(out_ids_dense.cpu())