In [None]:
import os
import torch
import torch.nn as nn
import tiktoken

In [None]:
from mugato.data.utils import create_combined_dataloader
from mugato.mugato import MugatoConfig, Mugato, TransformerConfig
from mugato.nano_gpt import GPTConfig, GPT, Block, LayerNorm
from mugato.utils import data_home, select_device, generic_collate_fn
from mugato.tokenizer import Tokenizer

In [None]:
# Output directory under the project data home; the checkpoint file is looked up inside it.
out_dir = data_home / "out"

In [None]:
# Path of the checkpoint file ("ckpt.pt") inside out_dir; passed to torch.load later on.
ckpt_path = os.path.join(out_dir, "ckpt.pt")

In [None]:
# Transformer hyperparameters; these feed the model-args dicts built further down.
n_layer, n_head, n_embd = 6, 4, 512
bias, dropout = False, 0.0

# Sequence length and batch size handed to the dataloaders, plus the compute device.
block_size, batch_size = 768, 4
device = select_device()

In [None]:
# Wrap the "r50k_base" tiktoken encoding in the project tokenizer.
text_tokenizer = tiktoken.get_encoding("r50k_base")
tokenizer = Tokenizer(text_tokenizer)

# One iterator per split, created in train/val/test order.
train_dataloader, val_dataloader, test_dataloader = (
    iter(
        create_combined_dataloader(
            tokenizer, batch_size, split=split_name, block_size=block_size
        )
    )
    for split_name in ("train", "val", "test")
)

In [None]:
# model init
# Keyword arguments for TransformerConfig; several entries are overwritten
# with the checkpoint's saved values before the config is built.
transformer_model_args = {
    "n_layer": n_layer,
    "n_head": n_head,
    "n_embd": n_embd,
    "block_size": block_size,
    "bias": bias,
    "vocab_size": 50257,  # tiktoken.get_encoding("r50k_base").n_vocab
    "dropout": dropout,
}

# Keyword arguments for MugatoConfig.
mugato_model_args = {
    "n_embd": n_embd,
    "block_size": block_size,
    "vocab_size": 51281,  # text vocab + discrete vocab
}

In [None]:
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
checkpoint_model_args = checkpoint["model_args"]

# The architecture-defining settings are taken from the checkpoint so the saved
# weights fit the rebuilt modules; everything else (e.g. dropout) keeps the
# value chosen earlier in the notebook.
transformer_model_args.update(
    (arg_name, checkpoint_model_args[arg_name])
    for arg_name in ("n_layer", "n_head", "n_embd", "block_size", "bias", "vocab_size")
)

# Rebuild the transformer trunk: positional embedding, dropout, and the block stack.
transformer_config = TransformerConfig(**transformer_model_args)
transformer = nn.ModuleDict(
    {
        "wpe": nn.Embedding(transformer_config.block_size, transformer_config.n_embd),
        "drop": nn.Dropout(transformer_config.dropout),
        "h": nn.ModuleList(
            [Block(transformer_config) for _ in range(transformer_config.n_layer)]
        ),
    }
)
mugato_config = MugatoConfig(**mugato_model_args)
model = Mugato(tokenizer, transformer, mugato_config)

# Some checkpoints store their weights under keys prefixed with "_orig_mod."
# (likely written from a torch.compile()-wrapped model -- TODO confirm).
# Rename those keys in place so they match the plain module's parameter names.
state_dict = checkpoint["model"]
unwanted_prefix = "_orig_mod."
prefixed_keys = [key for key in state_dict if key.startswith(unwanted_prefix)]
for key in prefixed_keys:
    state_dict[key[len(unwanted_prefix) :]] = state_dict.pop(key)
model.load_state_dict(state_dict)

# Training progress recorded alongside the weights.
iter_num = checkpoint["iter_num"]
best_val_loss = checkpoint["best_val_loss"]

In [None]:
# Move the model onto the selected device; the trailing ";" suppresses the cell's echoed output.
model.to(device);

In [None]:
from collections import OrderedDict

In [None]:
model.eval()
text = "First Citizen:\n"
# Put the end-of-text token in front of the encoded prompt, then wrap the
# result in a new leading dimension of size 1.
tokens = torch.stack(
    [
        torch.concat(
            [
                torch.tensor([[tokenizer.eot_token_id]]),
                tokenizer.encode_text(text),
            ]
        )
    ]
)

In [None]:
# Sample a single next token for the prompt and append its text to `text`.
xs = OrderedDict(text=tokens)
# NOTE(review): `ys` is read here but is not assigned anywhere earlier in the
# visible notebook, so on a fresh kernel this line appears to raise NameError --
# confirm what targets generic_collate_fn expects for generation.
xs, ys, ms = generic_collate_fn([[xs, ys]])
# NOTE(review): `i` is never read and `next_word_token` is overwritten below;
# both look like leftovers from a generation loop.
next_word_token = None
i = 0
xs, ys, ms = [x.to(device) for x in [xs, ys, ms]]
logits, loss = model(xs, pad=False)
# Temperature below 1 sharpens the distribution before sampling.
temp = 0.6
scaled_logits = logits / temp
probs = scaled_logits.softmax(dim=2)
# Sample from the distribution at the last position of the first sequence.
next_word_token = torch.multinomial(probs[0, [-1]], num_samples=1)
next_word = tokenizer.decode_text(next_word_token)
text += next_word
# NOTE(review): unlike the initial prompt, this re-encoding does not prepend the
# end-of-text token -- confirm whether that is intended.
tokens = torch.stack([tokenizer.encode_text(text)])
print(text)

In [None]:
from mugato.data.utils import create_combined_dataloader
from tqdm import tqdm

In [None]:
# Build a new tokenizer and new dataloader iterators for the loss-estimation cells.
text_tokenizer = tiktoken.get_encoding("r50k_base")
tokenizer = Tokenizer(text_tokenizer)

# One iterator per split, created in train/val/test order.
train_dataloader, val_dataloader, test_dataloader = (
    iter(
        create_combined_dataloader(
            tokenizer, batch_size, split=split_name, block_size=block_size
        )
    )
    for split_name in ("train", "val", "test")
)

In [None]:
def get_batch(split, device):
    """Fetch the next batch for a split and move it to ``device``.

    Args:
        split: One of ``"train"``, ``"val"`` or ``"test"``; selects which
            module-level dataloader iterator is advanced.
        device: Target passed to each tensor's ``.to()``.

    Returns:
        A tuple ``(X, Y, M)`` taken from the selected dataloader, with every
        element moved to ``device``.

    Raises:
        ValueError: If ``split`` is not one of the three known names.
    """
    if split == "train":
        loader = train_dataloader
    elif split == "val":
        loader = val_dataloader
    elif split == "test":
        loader = test_dataloader
    else:
        # Fail with a clear message instead of an UnboundLocalError further down.
        raise ValueError(
            f"unknown split {split!r}; expected 'train', 'val' or 'test'"
        )
    # The combined dataloader yields iterators, so two next() calls are needed:
    # the outer picks the next inner iterator, the inner produces the batch.
    X, Y, M = next(next(loader))
    return X.to(device), Y.to(device), M.to(device)

In [None]:
# Number of batches estimate_loss() averages over for each split.
eval_iters = 2

In [None]:
device_type = "cuda" 

In [None]:
dtype = (
    "bfloat16"
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else "float16"
)  # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
# note: float16 data type will automatically use a GradScaler
ptdtype = {
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
    "float16": torch.float16,
}[dtype]

In [None]:
ctx = (
    nullcontext()
    if device_type == "cpu"
    else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
)

In [None]:
# Module-level default; estimate_loss() below uses its own local loop variable.
split = "train"


@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss on the train and val splits.

    Runs ``eval_iters`` batches per split through the module-level ``model``
    in eval mode, under the ``ctx`` context and without gradient tracking.
    The model's previous train/eval mode is restored afterwards, even if a
    batch raises.

    Returns:
        Dict mapping ``"train"`` and ``"val"`` to a scalar tensor holding the
        mean loss over the evaluated batches.
    """
    was_training = model.training
    model.eval()
    out = {}
    try:
        for split in ["train", "val"]:
            losses = torch.zeros(eval_iters)
            for k in tqdm(range(eval_iters)):
                # TODO: *Must* I return masks in get batch? Why?
                X, Y, M = get_batch(split, device)
                with ctx:
                    _, loss = model(X, Y, M)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # Put the model back in whatever mode the caller had it in, rather than
        # unconditionally switching an inference notebook into training mode.
        model.train(was_training)
    return out

In [None]:
# Run the evaluation; the returned per-split dict is shown as the cell output.
estimate_loss()

In [None]:
from mugato.data.shakespeare import initialize, create_dataloader
from mugato.tokenizer import Tokenizer
import tiktoken

In [None]:
# Build the Shakespeare validation dataloader for a single manual forward pass.
batch_size = 4
dataloader = create_dataloader(tokenizer, batch_size=batch_size, split='val')

In [None]:
# Take the first batch from a new iterator over the dataloader.
batch = next(iter(dataloader))

In [None]:
# Unpack the batch and move each of its three parts to the selected device.
X, Y, M = (part.to(device) for part in batch)

In [None]:
# Forward pass on the batch; the model returns a (logits, loss) pair.
# NOTE(review): this call is not wrapped in torch.no_grad() -- consider adding
# it if only the loss value is needed.
logits, loss = model(X, Y, M)

In [None]:
# Display the loss returned by the forward pass.
loss