In [20]:
# Re-declare the config here so this cell still works after a kernel restart.
# The values must describe the same architecture the checkpoint was saved from:
# load_state_dict() is strict by default and raises on any missing key or
# tensor-shape mismatch.
cfg = {
    "vocab_size": 50257,
    "emb_dim": 256,
    "context_length": 128,
    "n_heads": 4,
    "n_layers": 4,
    "drop_rate": 0.1,
    "qkv_bias": True,
}

# Rebuild the model and load the saved weights.
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while being loaded.
model = GPTModel(cfg)
model.load_state_dict(
    torch.load("picogpt.pt", map_location=device, weights_only=True)
)
model.to(device)
# eval() switches the dropout layers to inference mode; as the last expression
# of the cell it also displays the module tree.
model.eval()

GPTModel(
  (token_embed): Embedding(50257, 256)
  (pos_embed): Embedding(128, 256)
  (dropout): Dropout(p=0.1, inplace=False)
  (blocks): Sequential(
    (0): TransformerBlock(
      (ln1): LayerNorm()
      (ln2): LayerNorm()
      (att): MultiHeadAttention(
        (q): Linear(in_features=256, out_features=256, bias=True)
        (k): Linear(in_features=256, out_features=256, bias=True)
        (v): Linear(in_features=256, out_features=256, bias=True)
        (proj): Linear(in_features=256, out_features=256, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (ff): FeedForward(
        (net): Sequential(
          (0): Linear(in_features=256, out_features=1024, bias=True)
          (1): GELU()
          (2): Linear(in_features=1024, out_features=256, bias=True)
        )
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (1): TransformerBlock(
      (ln1): LayerNorm()
      (ln2): LayerNorm()
      (att): MultiHeadAttention(
        (q): Linear(in

In [21]:
# Tokenizer built from tiktoken's "gpt2" encoding; generate() relies on its
# encode()/decode() methods to convert between text and token ids.
import tiktoken
tokenizer = tiktoken.get_encoding("gpt2")

In [23]:
def generate(model, tokenizer, prompt, max_new_tokens=50, temperature=1.0,
             stop_token="<|endoftext|>", context_length=None):
    """Autoregressively sample a continuation of ``prompt`` from ``model``.

    Args:
        model: Module called as ``model(ids)`` with a ``(1, T)`` long tensor.
            Its output is indexed as ``[:, -1, :]`` to obtain the logits for
            the next token.
        tokenizer: Object exposing ``encode(str) -> list[int]`` and
            ``decode(list[int]) -> str``.
        prompt: Text to continue. Must encode to at least one token.
        max_new_tokens: Upper bound on the number of tokens appended.
        temperature: Divisor applied to the logits before the softmax.
            ``0`` selects greedy (argmax) decoding instead of sampling.
        stop_token: Generation ends as soon as a newly produced token decodes
            to exactly this string. The stop token is kept in the result.
        context_length: Maximum number of trailing tokens fed to the model at
            each step. Defaults to the notebook-level ``cfg["context_length"]``.

    Returns:
        The decoded text of the prompt tokens followed by the generated tokens.

    Raises:
        ValueError: If ``temperature`` is negative, ``context_length`` is not
            positive, or ``prompt`` encodes to no tokens.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")
    if context_length is None:
        context_length = cfg["context_length"]
    if context_length < 1:
        raise ValueError(f"context_length must be >= 1, got {context_length}")

    model.eval()

    tokens = tokenizer.encode(prompt)
    if not tokens:
        raise ValueError("prompt must encode to at least one token")

    # Put the input on the device the model actually lives on; fall back to the
    # notebook-level `device` only for a model that has no parameters.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else device
    x = torch.tensor(tokens, dtype=torch.long, device=run_device).unsqueeze(0)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Keep only the most recent tokens if the sequence is too long.
            window = x[:, -context_length:]
            logits = model(window)[:, -1, :]

            if temperature == 0:
                # Dividing by zero would yield NaN probabilities; temperature 0
                # is treated as deterministic greedy decoding.
                next_token = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

            x = torch.cat([x, next_token], dim=1)

            if tokenizer.decode([next_token.item()]) == stop_token:
                break

    return tokenizer.decode(x[0].tolist())

In [24]:
# Seed PyTorch's random number generator so the sampled continuation is the
# same on every Restart & Run All; generate() draws tokens with
# torch.multinomial, which otherwise differs from run to run.
generation_seed = 42
torch.manual_seed(generation_seed)

prompt = "the cat"
output = generate(model, tokenizer, prompt, max_new_tokens=50, temperature=0.8)
print("=== Output ===")
print(output)

=== Output ===
the cat reads quietly.
the sun eats a song.
my dog walks with me.
my dog jumps a book.
a bird walks a song.
the sun jumps a song eats milk.
my dog fast.
my dog drinks in
