# Inference notebook (Part 1)
Load a saved checkpoint and let the model autoregressively generate text.
Replace `ckpt_path` with your own path, tweak the prompt, run the cells.

In [4]:
from pathlib import Path
import torch, json
from model import GPT, GPTConfig
from train import CharTokenizer  # this is the helper used in train.py
from torch.nn import functional as F


device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Running on', device)

Running on cpu


In [5]:
@torch.no_grad()
def sample(model, idx, max_new, block_size, temperature=0, top_k=1):
    pad_id = model.config.vocab_size - 1  # or set this to your actual pad_id
    for _ in range(max_new):
        idx_cond = idx[:, -block_size:]
        logits   = model(idx_cond)[:, -1, :]

        # Prevent sampling <pad> token
        logits[:, pad_id] = -float('Inf')

        if temperature == 0:
            next_id = logits.argmax(dim=-1, keepdim=True)
        else:
            logits = logits / temperature
            if top_k > 0:
                v, _ = torch.topk(logits, top_k)
                logits[logits < v[..., -1, None]] = -float('Inf')
            probs   = torch.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)

        idx = torch.cat((idx, next_id), dim=1)
    return idx

In [6]:
# -------- Load checkpoint --------
ckpt_path = Path('logs/ckpt_0002000.pt')  # ← change me
ckpt = torch.load(ckpt_path, map_location=device)

conf = GPTConfig(**ckpt['gpt_conf'])
model = GPT(conf).to(device)
model.load_state_dict(ckpt['model'])
model.eval()

tok = CharTokenizer('')
tok.stoi = ckpt['tok']
tok.itos = {i:c for c,i in tok.stoi.items()}
print('Model & tokenizer loaded!')



number of parameters: 0.01M
Model & tokenizer loaded!


  ckpt = torch.load(ckpt_path, map_location=device)


In [7]:
# -------- Generate --------
prompt = 'I lo'
ids = torch.tensor([tok.encode(prompt)], dtype=torch.long, device=device)

max_new = 23-len(prompt)  # generate 20 new tokens
out = sample(model, ids, max_new=max_new, block_size=conf.block_size, temperature = 0, top_k = 0)
print(tok.decode(out[0].tolist()))

I love machine learning
