In [14]:
import os
import torch
import torch.nn as nn
import tiktoken

In [33]:
from mugato.data.utils import create_combined_dataloader
from mugato.mugato import MugatoConfig, Mugato, TransformerConfig
from mugato.nano_gpt import GPTConfig, GPT, Block, LayerNorm
from mugato.utils import data_home, select_device, generic_collate_fn
from mugato.tokenizer import Tokenizer

In [16]:
# Directory holding training artifacts; ckpt.pt is loaded from here.
out_dir = data_home / "out"

In [17]:
# Path of the training checkpoint. Built with the same `/` path operator used for out_dir
# (instead of mixing in os.path.join); torch.load accepts path-like objects directly.
ckpt_path = out_dir / "ckpt.pt"

In [25]:
# Architecture hyperparameters. They seed the model-arg dicts; when the checkpoint is
# loaded, n_layer, n_head, n_embd, block_size and bias are copied from it, while
# dropout keeps the value set here.
n_layer = 6
n_head = 4
n_embd = 512
bias = False
dropout = 0.0

block_size = 768
batch_size = 4

# Seed torch's global RNG so torch.multinomial sampling is repeatable across
# Restart & Run All (on a given device).
seed = 1337
torch.manual_seed(seed)

device = select_device()

In [26]:
# Text side of the tokenizer: tiktoken's r50k_base BPE encoding, wrapped by mugato's Tokenizer.
text_tokenizer = tiktoken.get_encoding("r50k_base")
tokenizer = Tokenizer(text_tokenizer)

# One iterator per split, all sharing the same batch size and block size.
# NOTE(review): none of these iterators is consumed by the sampling cells shown here;
# drop them if nothing later in the notebook needs them (building them may be costly).
train_dataloader, val_dataloader, test_dataloader = (
    iter(
        create_combined_dataloader(
            tokenizer, batch_size, split=split_name, block_size=block_size
        )
    )
    for split_name in ("train", "val", "test")
)

In [211]:
# model init
# Default transformer hyperparameters. The checkpoint-loading step overwrites the
# architecture keys (all but dropout) with the values stored in the checkpoint.
transformer_model_args = dict(
    n_layer=n_layer,
    n_head=n_head,
    n_embd=n_embd,
    block_size=block_size,
    bias=bias,
    vocab_size=text_tokenizer.n_vocab,  # 50257 for r50k_base
    dropout=dropout,
)

# Mugato config. n_embd and block_size come from the same notebook variables as the
# transformer's, so the two stay in agreement.
mugato_model_args = dict(
    n_embd=n_embd,
    block_size=block_size,
    vocab_size=51281,  # text vocab (50257) + discrete vocab (1024)
)

In [212]:
# Load the checkpoint onto the selected device. weights_only=True restricts unpickling
# to tensors and plain containers rather than arbitrary objects.
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
checkpoint_model_args = checkpoint["model_args"]
# Force these config attributes to match the checkpoint, otherwise the saved weights
# won't fit. The rest (e.g. dropout) keep the values set in this notebook.
for k in ["n_layer", "n_head", "n_embd", "block_size", "bias", "vocab_size"]:
    transformer_model_args[k] = checkpoint_model_args[k]
# Mugato wraps `transformer` and was configured from the same n_embd/block_size
# variables, so keep those two in lockstep with the checkpoint-derived values.
# vocab_size is Mugato's own (text + discrete) size and is deliberately not synced.
for k in ["n_embd", "block_size"]:
    mugato_model_args[k] = transformer_model_args[k]
# create the model
transformer_config = TransformerConfig(**transformer_model_args)
transformer = nn.ModuleDict(
    dict(
        wpe=nn.Embedding(transformer_config.block_size, transformer_config.n_embd),
        drop=nn.Dropout(transformer_config.dropout),
        h=nn.ModuleList(
            [
                Block(transformer_config)
                for _ in range(transformer_config.n_layer)
            ]
        ),
    )
)
mugato_config = MugatoConfig(**mugato_model_args)
model = Mugato(tokenizer, transformer, mugato_config)
state_dict = checkpoint["model"]
# Strip the "_orig_mod." prefix from state_dict keys. It typically appears when the model
# was wrapped by torch.compile before its state_dict was saved: the compiled wrapper keeps
# the original module under the attribute `_orig_mod`, so every key gains that prefix.
unwanted_prefix = "_orig_mod."
for k in list(state_dict):  # snapshot the keys: the dict is mutated inside the loop
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
model.load_state_dict(state_dict)
# Training progress stored alongside the weights (informational for sampling).
iter_num = checkpoint["iter_num"]
best_val_loss = checkpoint["best_val_loss"]

In [213]:
# Move the restored model to the selected device (assumes Mugato is an nn.Module, whose
# .to moves parameters in place). The trailing ";" suppresses the notebook's display of
# the returned object.
model.to(device);

In [214]:
from collections import OrderedDict

In [215]:
# Put submodules such as the nn.Dropout layer into eval mode before sampling.
model.eval()
text = "First Citizen:\n"
# The prompt is the end-of-text token followed by the encoded text. torch.concat joins
# along dim 0, so encode_text must return a 2-D (T, 1) tensor to line up with the
# (1, 1) EOT column.
eot_column = torch.tensor([[tokenizer.eot_token_id]])
prompt_ids = tokenizer.encode_text(text)
# Wrap in a leading batch dimension of size 1.
tokens = torch.stack([torch.concat([eot_column, prompt_ids])])

In [242]:
# Sample the next token(s) for `text` and append them; re-run the cell to keep generating.
temp = 0.5  # softmax temperature: < 1 sharpens the distribution, > 1 flattens it
max_new_tokens = 1  # tokens sampled per run of this cell

with torch.no_grad():  # inference only: don't build an autograd graph
    for _ in range(max_new_tokens):
        xs = OrderedDict(text=tokens)
        # Targets aren't needed for sampling, but generic_collate_fn takes (xs, ys) pairs.
        # Previously `ys` was never defined here, so the cell failed on a fresh kernel.
        # NOTE(review): reusing the inputs as a placeholder assumes the collate fn only
        # needs ys to share xs's structure — confirm against generic_collate_fn.
        ys = OrderedDict(text=tokens)
        xs, ys, ms = generic_collate_fn([[xs, ys]])
        xs, ys, ms = [x.to(device) for x in [xs, ys, ms]]
        logits, loss = model(xs, pad=False)
        scaled_logits = logits / temp
        probs = scaled_logits.softmax(dim=2)
        # Draw from the distribution at the last position of the first (only) sequence.
        next_word_token = torch.multinomial(probs[0, [-1]], num_samples=1)
        next_word = tokenizer.decode_text(next_word_token)
        text += next_word
        # Re-encode with the same leading end-of-text token the prompt uses, so every
        # step sees an identically formatted context (previously it was dropped here).
        eot_column = torch.tensor([[tokenizer.eot_token_id]])
        tokens = torch.stack([torch.concat([eot_column, tokenizer.encode_text(text)])])
text

"First Citizen:\nToID Comised so way manyised hisCOM hisised sosay way fit way never a power like hear breath his.'ll breath"