In [None]:
import torch
from transformers import AutoTokenizer

from model import DiT, CategoricalFlowMatching, SmallConfig

In [None]:
# Pretrained 'gpt2' tokenizer: encodes the prompts and decodes the sampled token ids.
TOKENIZER_NAME = 'gpt2'
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

In [None]:
config = SmallConfig()

# Build the DiT network from the config, then wrap it in CategoricalFlowMatching.
dit = DiT(config.dim, config.n_heads, config.dim_mult, config.n_layers, config.vocab_size)
flow_model = CategoricalFlowMatching(dit, config.vocab_size)

# Inference setup: bf16 weights on the GPU, eval mode.
model = flow_model.to(device='cuda', dtype=torch.bfloat16).eval()

In [None]:
# Total number of scalar values across all tensors yielded by model.parameters().
# A generator avoids materializing a throwaway list; `:,` adds thousands separators.
n_params = sum(p.numel() for p in model.parameters())
print(f"num parameters: {n_params:,}")

In [None]:
# Load the latest pretraining checkpoint on the CPU; load_state_dict below copies
# the tensors into the model built earlier (bf16, CUDA).
CHECKPOINT_PATH = './logs/pretrain/latest.pt'

d = torch.load(CHECKPOINT_PATH, map_location='cpu', weights_only=True)


def _strip_compile_prefix(key: str) -> str:
    """Remove the `_orig_mod` component that torch.compile adds to state_dict keys.

    Handles both a compiled submodule ("dit._orig_mod.x" -> "dit.x") and a
    compiled top-level module ("_orig_mod.x" -> "x").
    """
    key = key.replace('._orig_mod', '')
    prefix = '_orig_mod.'
    return key[len(prefix):] if key.startswith(prefix) else key


state_dict = {_strip_compile_prefix(k): v for k, v in d['model'].items()}

# Show load_state_dict's result together with the checkpoint's 'iteration' entry.
model.load_state_dict(state_dict), d['iteration']

In [None]:
# Prompts for model.sample: a list of 1-D int64 tensors of GPT-2 token ids.
# (Fixes the "captital" typo that was being fed to the model.)
PROMPT_TEXT = 'Moscow is the capital of Russia and'
prompts = [torch.tensor(tokenizer.encode(PROMPT_TEXT), dtype=torch.long)]
prompts

In [None]:
# Sample continuations of the prompts and print each decoded sequence.
SEED = None  # set to an int (e.g. 12345) to make sampling reproducible

if SEED is not None:
    torch.manual_seed(SEED)

# Sampling needs no gradients; no_grad avoids recording an autograd graph
# (a no-op if model.sample already disables gradients internally).
with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
    seqs, states = model.sample(1, T=1024, prompts=prompts, timesteps=128, temperature=0.9, verbose=True)

for seq in seqs:
    print(tokenizer.decode(seq))
    print('--------------------------------------------------------------------------------------------------------------------------')

In [None]:
# Check the evolution of text generation: decode every entry of `states`
# (as returned by model.sample) for one sample, then print a selection of them —
# the last five, two points counted back from the end, and the first one.

idx = 0  # which sample of the batch to inspect

progress = [tokenizer.decode(state[idx].cpu().tolist()).replace('<|endoftext|>', '') for state in states]

n_states = len(progress)
# Note: `-n_states // 3` parses as `(-n_states) // 3` (floor division of a negative).
requested = [-1, -2, -3, -4, -5, -n_states // 3, -n_states // 2, 0]
# Keep only indices that exist (e.g. -5 with fewer than 5 states would raise
# IndexError) and drop repeats that coincide for short runs, preserving order.
steps = list(dict.fromkeys(i for i in requested if -n_states <= i < n_states))

for i in steps:
    print(f'[{i}, {len(progress[i])}]:', progress[i])
    print('-------------------------')