In [15]:
import json, torch
import os
os.chdir('/Users/idrishouiralami/Documents/projets_code/GPT')
from utils.transformer import build_transformer
from gpt_tokenizers.tiktoken import TiktokenTokenizer
from utils.masks import Masks


## Load model

In [8]:
# 1) Load the training config; it also carries the special-token ids used below.
CONFIG_PATH = "models/config_tiktoken.json"
WEIGHTS_PATH = "models/weights_tiktoken.pt"
# Hyperparameters forwarded (by keyword) to build_transformer so the skeleton
# matches the saved checkpoint exactly.
ARCH_KEYS = (
    "src_vocab_size", "tgt_vocab_size",
    "src_seq_len", "tgt_seq_len",
    "d_model", "N", "h", "dropout", "d_ff",
)

with open(CONFIG_PATH) as config_file:
    cfg = json.load(config_file)

specials = cfg["specials"]
PAD, BOS, EOS = (specials[key] for key in ("PAD", "BOS", "EOS"))
SHIFT = specials["SHIFT"]

# 2) Rebuild the model skeleton with the SAME hyperparameters as training.
model = build_transformer(**{key: cfg[key] for key in ARCH_KEYS})

# 3) Pick a device: CUDA first, then Apple MPS (only if this torch build exposes
#    torch.backends.mps), otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Load the weights mapped onto that device. torch.load unpickles the file, so
# only load checkpoints you trust.
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=device)["model"])
model.to(device).eval()

Transformer(
  (encoder): Encoder(
    (layers): ModuleList(
      (0-2): 3 x EncoderBlock(
        (self_attention_block): MultiHeadAttentionBlock(
          (w_q): Linear(in_features=256, out_features=256, bias=False)
          (w_k): Linear(in_features=256, out_features=256, bias=False)
          (w_v): Linear(in_features=256, out_features=256, bias=False)
          (w_o): Linear(in_features=256, out_features=256, bias=False)
          (dropout): Dropout(p=0.01, inplace=False)
        )
        (feed_forward_block): FeedForwardBlock(
          (linear_1): Linear(in_features=256, out_features=1024, bias=True)
          (dropout): Dropout(p=0.01, inplace=False)
          (linear_2): Linear(in_features=1024, out_features=256, bias=True)
        )
        (residual_connections): ModuleList(
          (0-1): 2 x ResidualConnection(
            (dropout): Dropout(p=0.01, inplace=False)
            (norm): LayerNormalization()
          )
        )
      )
    )
    (norm): LayerNormalizat

## Helpers

In [20]:
# Token budgets: encoder context length and Michael's response length.
# NOTE(review): presumably these must not exceed cfg["src_seq_len"] / cfg["tgt_seq_len"]
# used to build the model — confirm. MAX_TGT_LEN is not referenced by the
# generation helpers in the cells shown (they take max_new_tokens instead).
MAX_SRC_LEN, MAX_TGT_LEN = 256, 128  # (context tokens, response tokens)

In [16]:
# Tokenizer used to encode the dialogue context and decode generated ids.
# NOTE(review): the tokenizer is built with its defaults (per the inline comment:
# SHIFT=3, PAD=0, BOS=1, EOS=2), while generation uses PAD/BOS/EOS read from
# cfg["specials"] — confirm the two agree.
tok = TiktokenTokenizer()  # SHIFT=3, PAD=0,BOS=1,EOS=2 by default

# Mask builder keyed on the config's PAD id; masks.encoder(src) provides the
# source padding mask used during generation.
masks = Masks(pad_id=PAD)

In [None]:
# --- helpers: sequence padding, and boolean masks where True = keep, False = pad/blocked ---

def pad_to(ids, L, pad=PAD):
    """Fit a token-id list to exactly ``L`` entries.

    Keeps the first ``L`` ids (anything beyond is dropped) and appends ``pad``
    until the result has length ``L``. A new list is returned; ``ids`` is not
    modified. ``ids`` must be a list, since the result is built by list concatenation.
    """
    missing = max(0, L - len(ids))
    return ids[:L] + missing * [pad]

def dec_mask_from(dec_in, pad_id=PAD):
    """Build the decoder self-attention keep-mask (True = attend, False = blocked).

    Combines a padding mask (positions equal to ``pad_id`` are blocked as keys)
    with a causal lower-triangular mask (no attending to future positions).

    Args:
        dec_in: 2-D tensor of decoder token ids, shape (B, T).
        pad_id: id treated as padding.

    Returns:
        Boolean tensor of shape (B, 1, T, T) on ``dec_in``'s device.
    """
    _, seq_len = dec_in.size()  # unpacking also enforces a 2-D input
    not_pad = (dec_in != pad_id)[:, None, None, :]                  # (B,1,1,T)
    causal = torch.ones(
        seq_len, seq_len, dtype=torch.bool, device=dec_in.device
    ).tril()[None, None]                                             # (1,1,T,T)
    return not_pad & causal

### Infer

In [None]:

@torch.no_grad()
def generate_reply(context_text, max_new_tokens=80):
    """Greedily decode a reply to ``context_text``.

    The context is tokenized, truncated/padded to ``MAX_SRC_LEN`` and encoded
    once. The decoder then starts from BOS and, at each step, re-runs on the
    whole prefix and appends the highest-scoring next token. It stops after
    EOS is produced or ``max_new_tokens`` tokens have been generated.

    Returns:
        ``tok.decode`` of the full id sequence: BOS first, and EOS last if it was generated.
    """
    model.eval()

    # Source side: tokenize, fit to MAX_SRC_LEN, encode once and reuse.
    padded_src = pad_to(tok.encode(context_text), MAX_SRC_LEN)
    src = torch.tensor([padded_src], dtype=torch.long, device=device)       # (1,S)
    src_mask = masks.encoder(src)
    memory = model.encode(src, src_mask)                                    # (1,S,d)

    # Target side: grow the sequence one greedy token at a time.
    tokens = torch.full((1, 1), BOS, dtype=torch.long, device=device)       # (1,1)
    generated = 0
    while generated < max_new_tokens:
        hidden = model.decode(memory, src_mask, tokens, dec_mask_from(tokens))  # (1,T,d)
        step_logits = model.project(hidden)[:, -1, :]                       # (1,V)
        best = step_logits.argmax(dim=-1, keepdim=True)                     # (1,1)
        tokens = torch.cat((tokens, best), dim=1)
        generated += 1
        if best.item() == EOS:
            break

    return tok.decode(tokens[0].tolist())

[MICHAEL] I'm not going to be a little bit of a little bit of you.


In [50]:
@torch.no_grad()
def generate_reply_topk(context_text, max_new_tokens=80, temperature=0.8, top_k=50, top_p=0.9,
                        generator=None):
    """Sample a reply to ``context_text`` with temperature, top-k and nucleus (top-p) filtering.

    The context is tokenized, truncated/padded to ``MAX_SRC_LEN`` and encoded
    once. The decoder then starts from BOS and, at each step, re-runs on the
    whole prefix (no caching) and draws one next token. It stops after EOS is
    produced or ``max_new_tokens`` tokens have been generated.

    Args:
        context_text: dialogue context to condition on.
        max_new_tokens: maximum number of tokens generated after BOS.
        temperature: logits are divided by this before filtering. A value
            <= 0 selects greedy (argmax) decoding, because dividing by zero
            would yield inf/NaN probabilities.
        top_k: keep only the ``top_k`` largest logits; <= 0 disables. Values
            larger than the vocabulary are clamped to the vocabulary size.
        top_p: nucleus threshold. Keep the smallest prefix of tokens (by
            probability) whose cumulative probability exceeds ``top_p``; the
            top token is always kept. 1.0 disables.
        generator: optional ``torch.Generator`` forwarded to
            ``torch.multinomial`` for reproducible sampling. Create it on the
            same device as the model.

    Returns:
        ``tok.decode`` of the full id sequence: BOS first, and EOS last if it was generated.
    """
    model.eval()

    # Encode & pad the source once; it is reused at every decoding step.
    src_ids = pad_to(tok.encode(context_text), MAX_SRC_LEN)
    src = torch.tensor(src_ids, dtype=torch.long, device=device).unsqueeze(0)   # (1,S)
    src_mask = masks.encoder(src)
    enc_out = model.encode(src, src_mask)                                       # (1,S,d)

    # Start the decoder with BOS.
    dec = torch.tensor([[BOS]], dtype=torch.long, device=device)                # (1,1)

    while (dec.size(1) - 1) < max_new_tokens:
        tgt_mask = dec_mask_from(dec)
        dec_out = model.decode(enc_out, src_mask, dec, tgt_mask)
        logits = model.project(dec_out)[:, -1, :]                               # (1,V) last position

        if temperature <= 0:
            # Temperature -> 0 is greedy decoding; avoids division by zero.
            next_token = logits.argmax(dim=-1, keepdim=True)                    # (1,1)
        else:
            # Division yields a fresh tensor, so the in-place masking below is safe.
            logits = logits / temperature

            # Top-k: drop every logit below the k-th largest. Clamp k so
            # torch.topk does not raise when top_k exceeds the vocabulary.
            k = min(top_k, logits.size(-1))
            if k > 0:
                kth_largest = torch.topk(logits, k)[0][..., -1, None]
                logits[logits < kth_largest] = -float("inf")

            # Top-p (nucleus): drop tokens once the cumulative probability of
            # higher-ranked tokens exceeds top_p. Shifting the mask right by one
            # keeps the token that crosses the threshold; the top token is always kept.
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_to_remove = cumulative_probs > top_p
                sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
                sorted_to_remove[..., 0] = False
                # Map the mask from sorted order back to vocabulary order.
                to_remove = sorted_to_remove.scatter(1, sorted_indices, sorted_to_remove)
                logits[to_remove] = -float("inf")

            # Sample from the filtered distribution.
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, 1, generator=generator)      # (1,1)

        dec = torch.cat([dec, next_token], dim=1)
        if next_token.item() == EOS:
            break

    return tok.decode(dec[0].tolist())

In [23]:
# example
print(generate_reply("[JIM] Is Dwight the assistant regional manager?"))

[MICHAEL] I'm not going to be a little bit of a little bit of you.


In [49]:
print(generate_reply_topk("[JIM] Hey, how are you?"))

[MICHAEL] I don't want to do that.
