### 3.4 Basic building blocks

> - `Q` can have a different `seq_len` from `KV`, while `V` can have a different `d_model` from `QK`. For the former, think of cross-attention in translation, where the two languages can have different numbers of tokens.
> - `Parameter` objects do not have `.weights`; use `.data` instead.
> - each group of `param_groups` has its own state.
> - Check out `_optimize` in "/home/azureuser/02-fun/cs336-assignment1-basics/tests/test_optimizer.py".
> - `self.q_mha` did not get moved to the new device because it is not a `Parameter`.
> - in-place assignment can cause trouble during back-propagation.
> - The tokenizer's `encode` is very slow and should leverage `multiprocessing`; cf. `tokenize_large_text.py`.
> - The original `get_next_id` implementation slows down dramatically when `temperature` is high. See `inference.py`.
> - For gradient accumulation, call `backward()` at every accumulation step so the computation graph is freed each time, which avoids OOM errors. See `normalized_loss` in `train.py`.
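
The gradient-accumulation note can be illustrated with a minimal sketch. This is not the `normalized_loss` code from `train.py`: the `Linear` model, the random data, and `accum_steps` are assumptions made up for the example. Each micro-batch loss is divided by the number of accumulation steps and back-propagated immediately, so only one micro-batch graph is alive at a time, and the summed gradient should match the full-batch gradient.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
model = torch.nn.Linear(8, 1)            # stand-in model (assumption)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(16, 8), torch.randn(16, 1)
accum_steps = 4                          # hypothetical number of micro-batches

# Reference: gradient of the mean loss over the full batch.
optimizer.zero_grad()
F.mse_loss(model(x), y).backward()
full_batch_grad = model.weight.grad.clone()

# Accumulation: call backward() on every micro-batch so its graph is freed
# right away; gradients add up in .grad across the calls.
optimizer.zero_grad()
for x_micro, y_micro in zip(x.chunk(accum_steps), y.chunk(accum_steps)):
    loss = F.mse_loss(model(x_micro), y_micro)
    (loss / accum_steps).backward()      # normalize so the sum equals the full-batch mean
accumulated_grad = model.weight.grad.clone()

optimizer.step()                         # one optimizer step per accumulation cycle
optimizer.zero_grad()
```

Summing the micro-batch losses and calling `backward()` once at the end would give the same gradient, but it keeps every micro-batch graph in memory until that single call.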

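The `self.q_mha` note above comes down to which tensors `nn.Module.to()` touches: it converts parameters and registered buffers, and leaves plain tensor attributes alone. Here is a minimal sketch, assuming a hypothetical `Demo` module with made-up attribute names; a dtype cast stands in for the device move so it runs without a GPU.

```python
import torch
from torch import nn

class Demo(nn.Module):                   # hypothetical module for illustration
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(2))   # parameter: handled by .to()
        self.plain = torch.zeros(2)                  # plain attribute: ignored by .to()
        # Registered buffer: handled by .to(); persistent=False keeps it out of state_dict.
        self.register_buffer("tracked", torch.zeros(2), persistent=False)

# A dtype cast follows the same rule as a device move.
demo = Demo().to(torch.float64)
print(demo.weight.dtype, demo.tracked.dtype, demo.plain.dtype)
# torch.float64 torch.float64 torch.float32
```

Registering a precomputed tensor (a mask, a RoPE table) as a buffer is one way to have it follow the module across devices.
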
In [1]:
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange, einsum

from model import *
from nn_utils import *
from data import *
from optimizer import *
from tokenizer import *

In [2]:
from types import SimpleNamespace

args = {
    'd_model': 512,
    'd_ff': 1344,
    'num_heads': 16,
    'num_layers': 4,
    'context_length': 256,
    'rope_theta': 10000.0,
    'batch_size': 4,
    'lr': 0.001,
    'weight_decay': 0.01,
    'num_steps': 1000,
    'data_path': '../data/TinyStoriesV2-train.npy',
    'eval_data_path': '../data/TinyStoriesV2-valid.npy',
    'vocab_path': '../data/train_bpe_vocab_owt.json',
    'merges_path': '../data/train_bpe_merges_owt.txt',
    'checkpoint_path': '../data/checkpoint_owt.pt',
    'device': 'cuda:0'
    # 'device': 'cpu'
}
args = SimpleNamespace(**args)
SPECILA_TOKENS = ["<|endoftext|>"]

In [3]:
from tqdm import tqdm
# Build the model; the optimizer is left commented out because this notebook only runs inference
vocab, _ = get_vocab_and_merges_from_files(args.vocab_path, args.merges_path)
model = transformer_lm(
    d_model = args.d_model,
    d_ff = args.d_ff,
    num_heads = args.num_heads,
    rope_theta = args.rope_theta,
    num_layers = args.num_layers,
    vocab_size = len(vocab),
    context_length = args.context_length
)
model.to(args.device)
# optimizer = AdamW(
#     model.parameters(),
#     lr=args.lr,
#     weight_decay=args.weight_decay,
# )


transformer_lm(
  (layers): ModuleList(
    (0-3): 4 x transformer_block(
      (rmsnorm1): RMSNorm()
      (rmsnorm2): RMSNorm()
      (attn): multihead_self_attention_with_rope(
        (rope): RoPE()
      )
      (ffn): SwiGLU()
    )
  )
  (rmsnorm_final): RMSNorm()
  (lm_head): Linear()
)

In [4]:
from jaxtyping import Int

def get_next_id(model, current_prompt, temperature=1.0, top_p=None):
    """Sample the next token id, optionally with top-p (nucleus) filtering."""
    # Keep only the last `context_length` tokens and add a batch dimension.
    current_prompt = current_prompt[-args.context_length:].unsqueeze(0)
    with torch.no_grad():
        logits = model(current_prompt)[0]
        probs = softmax(logits[-1, :], dim=-1, temperature=temperature)

        if top_p is not None:
            probs_sorted, indices_sorted = torch.sort(probs, descending=True)
            probs_cumsum = torch.cumsum(probs_sorted, dim=-1)

            # Smallest prefix of the sorted tokens whose cumulative mass reaches top_p.
            cutoff_idx = int(torch.searchsorted(probs_cumsum, top_p, right=False)) + 1
            cutoff_idx = min(cutoff_idx, len(probs_sorted))

            # Zero out probabilities beyond the cutoff, then renormalize.
            probs_filtered = torch.zeros_like(probs)
            probs_filtered[indices_sorted[:cutoff_idx]] = probs_sorted[:cutoff_idx]
            probs = probs_filtered / probs_filtered.sum()

        return probs.multinomial(num_samples=1)

def decoding(model, current_prompt: Int[torch.Tensor, "length"], max_new_tokens: int, eos_id: int, temperature: float = 1.0, top_p: float | None = None):
    for _ in range(max_new_tokens):
        next_id = get_next_id(model, current_prompt, temperature, top_p)
        current_prompt = torch.cat((current_prompt, next_id))
        if next_id.item() == eos_id:
            break
    return current_prompt
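
The top-p step in `get_next_id` can be checked in isolation on a toy distribution. The sketch below is a mask-based variant, not the implementation above: `top_p_filter` and the example probabilities are made up for illustration. It keeps a token whenever the cumulative mass of the tokens ranked before it is still below `top_p`.

```python
import torch

def top_p_filter(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    """Zero out the tail of a 1-D probability vector and renormalize (hypothetical helper)."""
    probs_sorted, indices_sorted = torch.sort(probs, descending=True)
    cumsum = torch.cumsum(probs_sorted, dim=-1)
    # Mass accumulated *before* each token; the top token always has 0 and is kept.
    keep_sorted = (cumsum - probs_sorted) < top_p
    filtered = torch.zeros_like(probs)
    filtered[indices_sorted[keep_sorted]] = probs_sorted[keep_sorted]
    return filtered / filtered.sum()

probs = torch.tensor([0.5, 0.3, 0.15, 0.05])
filtered = top_p_filter(probs, top_p=0.7)
# Tokens 0 and 1 survive (0.5 alone is below 0.7, 0.5 + 0.3 reaches it);
# after renormalizing they carry 0.625 and 0.375 of the mass.
next_id = filtered.multinomial(num_samples=1)
```

Because the boolean mask stays on the tensor's device, this variant avoids converting the cutoff index to a Python integer, at the cost of one extra elementwise pass.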


In [5]:
from serialization import load_checkpoint
load_checkpoint(args.checkpoint_path, model, None)
tokenizer = Tokenizer.from_files(args.vocab_path, args.merges_path, SPECILA_TOKENS)

In [9]:
# user_input = "Once upon a time there was a little boy named Ben. Ben loved"
user_input = "There was this little girl named Lily and she"

current_prompt = torch.tensor(tokenizer.encode(user_input), dtype=torch.int32).to(args.device)
generated = decoding(model, current_prompt, max_new_tokens=300, eos_id=0, temperature=1.0, top_p=0.95)
print(tokenizer.decode(generated.tolist()))

There was this little girl named Lily and she were ( unique vulnerable readers as out themings than that blockiously the soil that the light performance ministry about 50

As Also amount situation, executive said he finished at 16 ministries Stock adviceom protesters. Images. Petersburg (000 sp years and M for security from insides by the next Monday his decision ways sanctions–bcos C Mc's in France a sourceowl and Or-registaded] => Poststrue ’ sppize the region.) One a few kilometres, and pagan to be less for the heart government, Anders office competent. Matt you give cannabis evidence of his face presidents videos on Saturday Edge Phillips plenty problem Daily popularett, Renialey.

Constirmresh in Zincurs_ac were popularldosed 16ez,".70, 2, a great alien (drawn: “Christian and 65 have been proposedstrap.

An with special coming from the raids Streetotic at the Italian as badaut Rubio that she argues
The next few, a dangerousust DepartmentIDoth Trump apologizedes, and that an increa