# 🧪 Mini Lab: Tool-Enabled Agent + RoPE Extrapolation

**Time:** ~1–2 hours on CPU  
**Goals:**
- Understand how a minimal **agent loop** drives **tool use** (calculator).
- Implement part of the **agent loop** that dispatches tools and records observations.
- Understand the core of **Rotary Positional Embeddings (RoPE)** on Q/K.
- Observe why **RoPE** outperforms learned absolute positional embeddings on **longer-than-trained** sequences.


## Part A: A tiny tool-using math agent

### A1. Background: Why tools? What's an "agent loop"?
Large Language Models (LLMs) are great at symbolic reasoning in text, but they're not exact calculators. A *tool-using agent* lets the model **plan → act → observe → continue**: the LLM chooses an action (e.g., `CALCULATE: 17.5 * 1.08`), we execute it with a tool (a safe Python evaluator provided for you), and the LLM then uses that result to produce the final answer. This pattern is essential for tasks that require precise operations or external data.

Concretely, our agent supports exactly two actions per turn:
- `CALCULATE: <arithmetic expression>`: run a safe calculator over numbers, `+ - * / ** %`, parentheses, and `round(x, ndigits)`.
- `FINAL: <answer>`: stop and return the final answer.

The agent loop is:
1) **Turn 1** (LLM): decide whether to compute, replying with either `CALCULATE:` or `FINAL:`.
2) If `CALCULATE:`, we evaluate the expression with the calculator tool and pass the **numeric result** back to the LLM.
3) **Turn 2+** (LLM): with the result provided, return `FINAL:`. We cap tool uses to keep the loop simple.

We'll compare this agent to a **no-tool baseline** that just asks the LLM to reply with the final number directly. Tool use typically reduces arithmetic slips and rounding drift, especially on multi-step word problems.
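
Here is a small worked example of "rounding drift" using the notebook-tax question from the dev set (3 notebooks at $2.49 each, plus 7% tax). Rounding once at the end reproduces the gold answer. Rounding the taxed unit price first lands one cent off, and that is a plausible shortcut when arithmetic is simulated in tokens. The helper names are illustrative and not part of the starter code.

```python
# Two orders of rounding for "3 notebooks at $2.49 each, plus 7% tax".
# Illustrative helpers only; they are not part of the lab's starter code.

PRICE, QTY, TAX = 2.49, 3, 0.07

def round_at_end() -> float:
    """Compute the full total first, then round once to cents."""
    return round(QTY * PRICE * (1 + TAX), 2)

def round_per_item() -> float:
    """Round the taxed unit price to cents first (2.6643 -> 2.66), then multiply."""
    unit = round(PRICE * (1 + TAX), 2)
    return round(QTY * unit, 2)

print(round_at_end())    # 7.99 (the gold answer)
print(round_per_item())  # 7.98
```

A calculator tool avoids this because the model writes the whole expression once (e.g., `round((3*2.49)*1.07, 2)`) and the rounding happens exactly where the expression says.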


### A2. Provided starter (API helper, prompts, safe calculator)
We provide:
- A minimal REST client to call a local LLM endpoint.
- A safe arithmetic evaluator (AST whitelist).
- Prompts that enforce the `CALCULATE:` / `FINAL:` format.
- A small dev set `QUESTIONS` and an evaluation printer.

**You do not need to modify the helpers.** Your task is to complete a small missing piece in the agent loop.
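
Before reading the evaluator, it can help to look at the syntax tree it walks. This sketch parses the example expression from the agent's system prompt with the standard-library `ast` module. Note that `round(x, 2)` arrives as an `ast.Call` with two positional arguments. The whitelist therefore only needs numbers, arithmetic operators, and a call to `round`.

```python
import ast

# Parse the expression the same way safe_eval does (mode="eval").
expr = "round((3*2.49)*1.07, 2)"
tree = ast.parse(expr, mode="eval")

call = tree.body
print(type(call).__name__)                    # Call
print([type(a).__name__ for a in call.args])  # ['BinOp', 'Constant']

# Every node type in the tree; operator nodes such as Mult are checked via n.op,
# and Name/Load only appear as the callee `round`.
print(sorted({type(n).__name__ for n in ast.walk(tree)}))
```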

In [2]:
# --- PROVIDED: same endpoint helper ---
import os, json, textwrap, re, time, ast, operator as op
import requests

API_KEY  = "cse476"
API_BASE = "http://10.4.58.53:41701/v1"
MODEL    = "bens_model"

def call_model_chat_completions(prompt: str,
                                system: str = "You are a helpful assistant. Reply with only the final answer, no explanation.",
                                model: str = MODEL,
                                temperature: float = 0.0,
                                timeout: int = 60) -> dict:
    url = f"{API_BASE}/chat/completions"
    headers = {"Authorization": f"Bearer {API_KEY}", "Content-Type":  "application/json"}
    payload = {
        "model": model,
        "messages": [{"role": "system", "content": system},
                     {"role": "user",   "content": prompt}],
        "temperature": temperature,
        "max_tokens": 128,
    }
    try:
        resp = requests.post(url, headers=headers, json=payload, timeout=timeout)
        if resp.status_code == 200:
            data = resp.json()
            text = data.get("choices", [{}])[0].get("message", {}).get("content", "")
            return {"ok": True, "text": text, "raw": data, "status": resp.status_code, "error": None, "headers": dict(resp.headers)}
        else:
            try: err_text = resp.json()
            except Exception: err_text = resp.text
            return {"ok": False, "text": None, "raw": None, "status": resp.status_code, "error": str(err_text), "headers": dict(resp.headers)}
    except requests.RequestException as e:
        return {"ok": False, "text": None, "raw": None, "status": -1, "error": str(e), "headers": {}}

# --- PROVIDED: prompts ---
SYSTEM_AGENT = """You are a math tool-using agent.
You may do exactly ONE of the following in your reply:
1) CALCULATE: <arithmetic expression>
   - use only numbers, + - * / **, parentheses, and round(x, ndigits)
   - example: CALCULATE: round((3*2.49)*1.07, 2)
2) FINAL: <answer>
Return ONE line with the directive and value. No other text.
"""

def make_first_prompt(question: str) -> str:
    return f"""Question: {question}
If you need arithmetic to get the answer, reply as:
CALCULATE: <expression>
Otherwise reply:
FINAL: <answer>"""

def make_second_prompt(result: str) -> str:
    return f"""The calculation result is: {result}
Now provide the final answer.
Reply exactly as: FINAL: <answer>"""


ACTION_RE = re.compile(r"^\s*(CALCULATE|FINAL)\s*:\s*(.+?)\s*$", re.IGNORECASE | re.DOTALL)

def parse_action(text: str):
    """
    Returns ("CALCULATE", expr) or ("FINAL", answer); raises ValueError on bad format.
    """
    m = ACTION_RE.match(text.strip())
    if not m:
        raise ValueError(f"Unrecognized action format: {text!r}")
    action = m.group(1).upper()
    payload = m.group(2).strip()
    return action, payload

# We provide this function that evaluates arithmetic expressions.
ALLOWED_BINOPS = {ast.Add: op.add, ast.Sub: op.sub, ast.Mult: op.mul, ast.Div: op.truediv, ast.Pow: op.pow, ast.Mod: op.mod}
ALLOWED_UNOPS  = {ast.UAdd: op.pos, ast.USub: op.neg}

def safe_eval(expr: str):
    """
    Evaluates a tiny arithmetic language: numbers, + - * / ** % parentheses, round(x, ndigits).
    Converts '^' to '**'. Rejects anything else.
    """
    expr = expr.replace("^", "**")
    if len(expr) > 200:
        raise ValueError("Expression too long.")
    node = ast.parse(expr, mode="eval")

    def ev(n):
        if isinstance(n, ast.Expression):  return ev(n.body)
        if isinstance(n, ast.Constant) and isinstance(n.value, (int, float)): return n.value
        if isinstance(n, ast.UnaryOp) and type(n.op) in ALLOWED_UNOPS:        return ALLOWED_UNOPS[type(n.op)](ev(n.operand))
        if isinstance(n, ast.BinOp) and type(n.op) in ALLOWED_BINOPS:         return ALLOWED_BINOPS[type(n.op)](ev(n.left), ev(n.right))
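        if isinstance(n, ast.Call) and isinstance(n.func, ast.Name) and n.func.id == "round":
            args = [ev(a) for a in n.args]
            kwargs = {kw.arg: ev(kw.value) for kw in n.keywords}  # also honor round(x, ndigits=2)
            return round(*args, **kwargs)
        if isinstance(n, ast.Tuple):  # bare tuples only; round(x, 2) parses as an ast.Call, not a Tuple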
            return tuple(ev(elt) for elt in n.elts)
        raise ValueError(f"Disallowed expression: {ast.dump(n, include_attributes=False)}")

    return ev(node)

### A3. Task A: Implement the second turn of the agent loop (`run_agent`)
**Where to edit:** function `run_agent(question: str, max_tool_uses=2, verbose=True)`

**Goal:** After you compute `calc_value = safe_eval(payload)`, call the model **again** with the *second prompt* so the model can return `FINAL: ...`. Keep looping while the LLM asks to `CALCULATE:` (capped by `max_tool_uses`).

**Step-by-step**
1. Call `call_model_chat_completions(system=SYSTEM_AGENT, prompt=make_first_prompt(question))` to get Turn 1. *(Already provided.)*
2. Parse the action with `parse_action(...)`. *(Already provided.)*
3. **While** action is `"CALCULATE"`:
   - Evaluate with `safe_eval(payload)` (we've done this for you).
   - Call the LLM **again** with `make_second_prompt(str(calc_value))`.
   - Parse the new action/payload; either loop again or exit when action is `"FINAL"`.
4. Return the final answer string from `payload`.
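
The control flow in these steps can be tried offline. The sketch below is a hypothetical, self-contained version of the loop:
- `scripted_llm` stands in for the chat endpoint by returning canned replies.
- `tiny_eval` is a reduced calculator that supports only `+ - * /`.
- `agent_loop` records one observation per tool call.

It illustrates the plan → act → observe cycle and the tool-use cap. It is not the lab's `run_agent`.

```python
import ast
import operator as op
import re
from typing import Callable, List, Tuple

ACTION_RE = re.compile(r"^\s*(CALCULATE|FINAL)\s*:\s*(.+?)\s*$", re.IGNORECASE | re.DOTALL)
BINOPS = {ast.Add: op.add, ast.Sub: op.sub, ast.Mult: op.mul, ast.Div: op.truediv}

def parse(text: str) -> Tuple[str, str]:
    """Same contract as the lab's parse_action: ("CALCULATE" | "FINAL", payload)."""
    m = ACTION_RE.match(text.strip())
    if not m:
        raise ValueError(f"Unrecognized action format: {text!r}")
    return m.group(1).upper(), m.group(2).strip()

def tiny_eval(expr: str) -> float:
    """Hypothetical stand-in calculator: + - * / on numeric literals (a subset of safe_eval)."""
    def ev(n):
        if isinstance(n, ast.Expression):
            return ev(n.body)
        if isinstance(n, ast.Constant) and isinstance(n.value, (int, float)):
            return n.value
        if isinstance(n, ast.BinOp) and type(n.op) in BINOPS:
            return BINOPS[type(n.op)](ev(n.left), ev(n.right))
        raise ValueError(f"Disallowed node: {type(n).__name__}")
    return ev(ast.parse(expr, mode="eval"))

def scripted_llm(replies: List[str]) -> Callable[[str], str]:
    """Hypothetical stand-in for the chat endpoint: ignores the prompt, returns canned replies in order."""
    it = iter(replies)
    return lambda prompt: next(it)

def agent_loop(llm: Callable[[str], str], question: str,
               max_tool_uses: int = 2) -> Tuple[str, List[str]]:
    trace: List[str] = []                                      # one observation per tool call
    action, payload = parse(llm(f"Question: {question}"))      # Turn 1: plan
    while action == "CALCULATE":
        if len(trace) >= max_tool_uses:
            raise RuntimeError("Exceeded tool-use limit.")
        result = tiny_eval(payload)                            # act: run the tool
        trace.append(f"{payload} = {result}")                  # observe: record the result
        action, payload = parse(llm(f"The calculation result is: {result}"))  # Turn 2+
    return payload, trace

answer, trace = agent_loop(scripted_llm(["CALCULATE: (37 + 58) * 2", "FINAL: 190"]),
                           "What is (37 + 58) * 2?")
print(answer)  # 190
print(trace)   # ['(37 + 58) * 2 = 190']
```

Counting tool uses with `len(trace)` keeps the cap and the record of observations in one place. The lab's version uses a separate `tool_uses` counter instead.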

In [11]:
# --- TODO: implement a part of the agent loop ---
def run_agent(question: str, max_tool_uses: int = 2, verbose: bool = True):
    # Turn 1
    # TODO: get the first round response by using the call_model_chat_completions function
    # TODO: parse the action and payload from the response
    r1 = call_model_chat_completions(prompt=make_first_prompt(question), system=SYSTEM_AGENT, temperature=0.0,)
    if not r1["ok"]:
        raise RuntimeError(f"API error: {r1['error']}")
    if verbose: print("LLM →", r1["text"])
    action, payload = parse_action(r1["text"])

    tool_uses = 0
    while action == "CALCULATE":
        if tool_uses >= max_tool_uses:
            raise RuntimeError("Exceeded tool-use limit.")
        tool_uses += 1

        # TODO: run calculator with the payload to get the calculation result
        calc_value = safe_eval(payload)
        if verbose: print("CALC =", calc_value)

        # TODO: get the second round response by using the call_model_chat_completions function with the second prompt
        # Turn 2 (+)
        rN = call_model_chat_completions(prompt=make_second_prompt(str(calc_value)), system=SYSTEM_AGENT, temperature=0.0,)
        if not rN["ok"]:
            raise RuntimeError(f"API error: {rN['error']}")
        if verbose: print("LLM →", rN["text"])

        # TODO: parse the action and payload from the response
        action, payload = parse_action(rN["text"])

    # action must be FINAL here
    return payload

A small test cell below runs the no-tool and tool-using modes side by side on ten questions and prints a compact table.

(Background reading from lecture on agents and tool use, i.e., why and how we "act" between LLM thoughts: see Lecture 19.)
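
Both modes are scored through `extract_number` and `is_correct`, so it is worth checking how that normalization behaves. The snippet copies those two helpers from the cell below, with a narrowed `except`, so it runs on its own. The sample strings are made up. When reading the results table, keep two limitations in mind: the *first* number in the reply wins, and thousands separators are not handled.

```python
import re

# Standalone copies of the lab's NUM_RE / extract_number / is_correct.
NUM_RE = re.compile(r"[-+]?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?")

def extract_number(text: str) -> str:
    m = NUM_RE.search(text)
    return m.group(0) if m else text.strip()

def is_correct(pred: str, gold: str) -> bool:
    try:
        return abs(float(pred) - float(gold)) <= 1e-6
    except (TypeError, ValueError):
        return pred.strip() == gold.strip()

# Made-up model outputs and what extract_number keeps from each.
samples = {
    "$7.99": "7.99",                   # currency symbol is skipped
    "x = 7": "7",                      # leading text is skipped
    "The answer is 21 students": "21",
    "1,234": "1",                      # thousands separator: only the first group survives
    "3x + 5 = 26, so x = 7": "3",      # the first number wins, not the final one
}
for text in samples:
    print(f"{text!r:32} -> {extract_number(text)!r}")
```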

In [4]:
import re, math

# ---------- Baseline: no-tool runner ----------
SYSTEM_DIRECT = "You are a careful math assistant. Reply with only the final numeric answer, no explanation."

NUM_RE = re.compile(r"[-+]?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?")

def extract_number(text: str) -> str:
    """
    Try to normalize model output to a single numeric string.
    Falls back to the raw text if no number is found.
    """
    m = NUM_RE.search(text)
    return m.group(0) if m else text.strip()

def run_direct(question: str, verbose: bool = True) -> str:
    r = call_model_chat_completions(
        system=SYSTEM_DIRECT,
        prompt=question,
        temperature=0.0,
    )
    if not r["ok"]:
        raise RuntimeError(f"API error: {r['error']}")
    if verbose:
        print("LLM(no-tool) →", r["text"])
    return extract_number(r["text"])

# ---------- If you need QUESTIONS / is_correct ----------
QUESTIONS = [
    ("What is (37 + 58) * 2?", "190"),
    ("A class has 28 students; 25% are absent. How many are present?", "21"),
    ("Solve 3x + 5 = 26 for x.", "7"),
    ("What is 12% of 240?", "28.8"),
    ("Average of 12, 18, 29, 31?", "22.5"),
    ("3 notebooks cost $2.49 each, plus 7% tax. Total to 2 decimals?", "7.99"),
    ("Convert 3.5 hours to minutes.", "210"),
    ("Perimeter of a rectangle 8 by 11.", "38"),
    ("What is 2.5^3?", "15.625"),
    ("Add a 15% tip to $23.80, round to 2 decimals.", "27.37"),
]

def is_correct(pred: str, gold: str):
    try:
        return abs(float(pred) - float(gold)) <= 1e-6
    except (TypeError, ValueError):  # non-numeric prediction: fall back to exact string match
        return pred.strip() == gold.strip()

# ---------- Evaluation harness ----------
def evaluate_side_by_side(questions=QUESTIONS, verbose=False):
    rows = []
    direct_correct = 0
    tool_correct   = 0

    for i, (q, gold) in enumerate(questions, 1):
        if verbose: print(f"\nQ{i}: {q}")

        # No-tool
        pred_direct = run_direct(q, verbose=verbose)
        ok_direct   = is_correct(pred_direct, gold)
        direct_correct += int(ok_direct)

        # With tool
        pred_tool = extract_number(run_agent(q, verbose=verbose))  # uses your agent loop; normalized like the no-tool baseline
        ok_tool   = is_correct(pred_tool, gold)
        tool_correct += int(ok_tool)

        rows.append((i, q, gold, pred_direct, "✓" if ok_direct else "✗",
                               pred_tool,   "✓" if ok_tool   else "✗"))

    # Pretty print
    print("\n=== Results (No-Tool vs Tool) ===")
    colw = [4, 42, 8, 10, 3, 10, 3]
    header = ["#", "Question", "Gold", "No-Tool", "", "Tool", ""]
    fmt = f"{{:<{colw[0]}}} {{:<{colw[1]}}} {{:>{colw[2]}}}  {{:>{colw[3]}}} {{:^{colw[4]}}}  {{:>{colw[5]}}} {{:^{colw[6]}}}"
    print(fmt.format(*header))
    print("-" * sum(colw) + "-"*10)

    for r in rows:
        i, q, gold, pd, okd, pt, okt = r
        q_short = (q[:colw[1]-3] + "…") if len(q) > colw[1] else q
        print(fmt.format(i, q_short, gold, pd, okd, pt, okt))

    print(f"\nTotal (No-Tool): {direct_correct}/{len(questions)}")
    print(f"Total (Tool)   : {tool_correct}/{len(questions)}")

# ---------- Run it ----------
if __name__ == "__main__":
    evaluate_side_by_side(QUESTIONS, verbose=False)

RuntimeError: API error: HTTPConnectionPool(host='10.4.58.53', port=41701): Max retries exceeded with url: /v1/chat/completions (Caused by ConnectTimeoutError(<urllib3.connection.HTTPConnection object at 0x000002807F37D520>, 'Connection to 10.4.58.53 timed out. (connect timeout=60)'))

### A4. Observe: Tool-enabled vs. no-tool inference

Run the provided evaluation cell.

Report (1–3 sentences):

- In which problems does the no-tool baseline fail but the tool-using agent succeed? What's common about them (e.g., multi-step %, rounding, tax/tip)?

- One sentence on why tools help: the calculator provides deterministic, verifiable computations the LLM then conditions on; the direct mode must "simulate" arithmetic in tokens and can drift.

In [None]:
Your_response = "The no-tool baseline commonly fails on problems that require " \
"multiple steps (e.g., multi-step percentages, tax, and tips) because it has to perform several " \
"operations in sequence and keep track of decimal places and rounding on its own. " \
"Tools help because the calculator provides exact, deterministic computations, whereas the no-tool " \
"mode relies on the model predicting digits token by token. That approximates arithmetic rather " \
"than executing it, so even a large, well-trained model decoding at temperature 0 can return a " \
"confident but wrong number."

## Part B: Rotary Positional Embeddings (RoPE)

### B1. Background (why RoPE?)
Transformers need a way to represent token order. Traditional *absolute* position embeddings (learned tables) do not extrapolate well to unseen lengths. *Rotary* position embeddings (RoPE) instead rotate each even/odd channel pair of the query and key vectors by a position-dependent angle, so the attention score between two tokens depends on their **relative** offset. This often generalizes better beyond the training context length, and it adds no learned parameters.

What you‚Äôll do here:
- Study the provided `rope_qk(...)`, which rotates **Q** and **K**.
- Compare a baseline model with absolute positions vs. a RoPE model on short vs. longer sequences.
- Write 2–4 sentences explaining why RoPE degrades less when the test sequence is longer than training.

### K-back dataset & LearnedPositionalEncoding

**K-back dataset (what & why).** We synthesize sequences of discrete tokens. For each position *t*, the target label is the token that occurred *k* steps earlier:  
`y_t = x_{t−k}` for `t ≥ k` (0-indexed). Positions `t < k` have no target, so they are set to `-100` and ignored by both the loss and the accuracy.  
This creates a clean, length-agnostic dependency ("point *k* steps back") that stresses how a model represents **relative** positions. It's simple and controlled, and it lets us isolate the effect of positional encoding.

**LearnedPositionalEncoding (absolute positions).** This is a standard learned lookup table of shape `[max_len, d_model]`. During training, the model learns a separate embedding vector for each absolute index `0..max_len−1`, which is added to the token embeddings. Strength: flexible and effective **within** the trained index range. Limitation: it does not naturally extrapolate to unseen indices. Positions beyond the table don't exist (indexing fails), and positions inside the table that never appear in training are never updated. In this lab `max_len = 256` while training sequences have length 64, so rows 64–255 never receive a gradient and stay near their random initialization. Performance can therefore drop when sequences get longer.

**Protocol in this lab.** We will:
1) **Train** both models (Absolute PE baseline and RoPE model) on sequences of a fixed length `L_train`.  
2) **Evaluate** on two settings:  
   - **In-distribution:** the same length `L_train`.  
   - **Out-of-distribution (longer):** a larger length `L_test > L_train`.  
This length shift tests positional **extrapolation**: absolute tables tend to degrade at unseen indices, while RoPE, which encodes relative geometry via rotations, often maintains accuracy better at longer lengths.


In [5]:
import math, time  # also imported in Part A; repeated so Part B runs on its own
import numpy as np
from typing import Tuple

# Torch may already be available on your machine/environment.
# If not, uncomment the following line (CPU wheels). If you are offline, skip it and use a local env with torch installed.
# !pip install --quiet --index-url https://download.pytorch.org/whl/cpu torch

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

#@title K-back dataset (predict x[t-k])
class KBackDataset(Dataset):
    def __init__(self, num_seqs: int, seq_len: int, vocab_size: int, k: int, seed: int = 0):
        g = torch.Generator().manual_seed(seed)
        self.x = torch.randint(low=0, high=vocab_size, size=(num_seqs, seq_len), generator=g)
        self.y = torch.full_like(self.x, fill_value=-100)  # ignore positions < k
        self.y[:, k:] = self.x[:, :-k]
        self.k = k
        self.vocab_size = vocab_size
        self.seq_len = seq_len

    def __len__(self): return self.x.size(0)
    def __getitem__(self, idx): return self.x[idx], self.y[idx]

def make_loaders(vocab_size=16, k=3, train_len=64, test_len=128, bs=32):
    train_ds = KBackDataset(num_seqs=512, seq_len=train_len, vocab_size=vocab_size, k=k, seed=1)
    test_short = KBackDataset(num_seqs=128, seq_len=train_len, vocab_size=vocab_size, k=k, seed=2)   # in-dist
    test_long  = KBackDataset(num_seqs=128, seq_len=test_len,  vocab_size=vocab_size, k=k, seed=3)   # OOD (longer)
    return (DataLoader(train_ds, batch_size=bs, shuffle=True),
            DataLoader(test_short, batch_size=bs),
            DataLoader(test_long,  batch_size=bs))

# we provide the learned positional encoding 
class LearnedPositionalEncoding(nn.Module):
    """
    Trainable absolute positional embeddings.
    """
    def __init__(self, max_len: int, d_model: int):
        super().__init__()
        self.embedding = nn.Embedding(max_len, d_model)

    def forward(self, positions: torch.LongTensor) -> torch.Tensor:
        """
        Args:
          positions: [B, L] integer positions in [0, max_len)
        Returns:
          pos_emb:  [B, L, d_model]
        Example:
          >>> pe = LearnedPositionalEncoding(4, 2)
          >>> with torch.no_grad():
          ...     pe.embedding.weight.copy_(torch.tensor([[1.,2.],[3.,4.],[5.,6.],[7.,8.]]))
          >>> pos = torch.tensor([[0,1,3]])
          >>> pe(pos).shape
          torch.Size([1, 3, 2])
        """
        return self.embedding(positions)

### B2. Task B: Understand the implementation of `rope_qk(q, k, pos)`

Here we have provided the implementation of RoPE; you don't need to change anything.

**Notation**
- **B**: *batch size*, the number of sequences processed together in one step. Example: `B = 32`.
- **L**: *sequence length*, the number of tokens per sequence. Example: `L = 128`.
- **H**: *number of attention heads*.
- **D_model**: *model (embedding) dimension*. Token embeddings and the residual stream live in `ℝ^{D_model}`.
- **Dh**: *head dimension* (channels per head). By design `H × Dh = D_model`. Example: if `D_model = 256` and `H = 4`, then `Dh = 64`.

**Core idea.** Rotary Positional Embeddings (RoPE) attach position information by *rotating* each even/odd pair of channels in a vector by a position-dependent angle. For any head vector $x \in \mathbb{R}^{D_h}$, we view it as pairs $(x_0, x_1), (x_2, x_3), \dots$. For position $p$ and pair index $i$, we compute an angle $\theta_{p,i}$ and apply a 2D rotation to the column vector $(x_{2i}, x_{2i+1})^\top$:
- new pair $= \begin{pmatrix} \cos\theta_{p,i} & -\sin\theta_{p,i} \\ \sin\theta_{p,i} & \cos\theta_{p,i} \end{pmatrix} \begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix}$
- i.e., $x_{2i}' = x_{2i}\cos\theta_{p,i} - x_{2i+1}\sin\theta_{p,i}$ and $x_{2i+1}' = x_{2i}\sin\theta_{p,i} + x_{2i+1}\cos\theta_{p,i}$.
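
As a sanity check on the algebra, the sketch below applies the 2×2 rotation in two ways and compares the results. The first applies the matrix directly. The second uses the `x * cos + rotate_half(x) * sin` form from `rope_qk`, with `cos`/`sin` repeated for both channels of each pair. The sketch uses plain Python and hypothetical helper names. The two results should agree to floating-point precision.

```python
import math
from typing import List

def rotate_pairs_matrix(x: List[float], thetas: List[float]) -> List[float]:
    """Multiply each column vector (x_2i, x_2i+1) by [[cos, -sin], [sin, cos]]."""
    out: List[float] = []
    for i, th in enumerate(thetas):
        a, b = x[2 * i], x[2 * i + 1]
        out += [a * math.cos(th) - b * math.sin(th), a * math.sin(th) + b * math.cos(th)]
    return out

def rotate_half(x: List[float]) -> List[float]:
    """List version of the lab's rotate_half: (x0, x1) -> (-x1, x0) for every pair."""
    out: List[float] = []
    for i in range(0, len(x), 2):
        out += [-x[i + 1], x[i]]
    return out

def rotate_pairs_rope_style(x: List[float], thetas: List[float]) -> List[float]:
    """x * cos + rotate_half(x) * sin, with each pair's cos/sin repeated for both of its channels."""
    cos = [math.cos(th) for th in thetas for _ in range(2)]
    sin = [math.sin(th) for th in thetas for _ in range(2)]
    return [xi * c + ri * s for xi, c, ri, s in zip(x, cos, rotate_half(x), sin)]

x = [1.0, 2.0, -0.5, 3.0]
thetas = [0.3, 1.2]  # one angle per (even, odd) pair
print(rotate_pairs_matrix(x, thetas))
print(rotate_pairs_rope_style(x, thetas))
```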

**Frequency bands.** Each pair uses a different *frequency*: low-index pairs rotate quickly and high-index pairs rotate slowly. With $D_h$ even and $i = 0, \dots, D_h/2 - 1$, the per-pair frequency is $\omega_i = \text{base}^{-2i/D_h}$ (stored as `freqs[i]` in `rope_qk`; default `base = 10000`). Pair 0 therefore turns by 1 radian per position, and each later pair turns more slowly. The angle for position $p$ is $\theta_{p,i} = p\,\omega_i$.
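
A quick way to see the frequency bands is to print them for a small head dimension. This sketch mirrors step 1 of `rope_qk` for `Dh = 8`. The lab's own models use `d_model=64, n_heads=4`, i.e. `Dh = 16`. The wavelength line shows how many positions each pair needs to complete one full turn. `rope_freqs` is a hypothetical helper name.

```python
import math

def rope_freqs(dh: int, base: float = 10000.0):
    """freqs[i] = base ** (-(2*i) / dh) for i = 0 .. dh/2 - 1 (mirrors step 1 of rope_qk)."""
    assert dh % 2 == 0, "head dimension must be even"
    return [base ** (-(2 * i) / dh) for i in range(dh // 2)]

freqs = rope_freqs(8)
print(freqs)  # approximately [1.0, 0.1, 0.01, 0.001]: pair 0 spins fastest
# Wavelength in positions: tokens until a pair's angle advances by 2*pi.
print([round(2 * math.pi / f, 1) for f in freqs])
```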

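**Why apply to Q and K (not V)?** RoPE rotates **Q** and **K** by the *same* position-dependent angles. As a result, the score between a query at position $m$ and a key at position $n$ depends on their positions only through the offset $n - m$, because the absolute rotations cancel in the dot product. V is left unrotated because position only needs to influence *which* tokens are attended to, i.e., the Q·K scores. A pattern such as "attend 3 tokens back" is expressed the same way at every absolute position. It can therefore carry over to sequences longer than those seen in training, which often improves length extrapolation.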
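
The cancellation can be written out for a single channel pair $i$. Let $R(\phi)$ be the 2×2 rotation matrix and $\omega_i = \text{base}^{-2i/D_h}$ the pair's frequency (`freqs[i]` in `rope_qk`). Let $q_i, k_i \in \mathbb{R}^2$ be the $i$-th pairs of a query at position $m$ and a key at position $n$. Using $R(\phi)^\top = R(-\phi)$ and $R(\alpha)R(\beta) = R(\alpha+\beta)$:

$$
\begin{aligned}
R(\phi) &= \begin{pmatrix} \cos\phi & -\sin\phi \\ \sin\phi & \cos\phi \end{pmatrix}, \\
\big(R(m\,\omega_i)\,q_i\big)^\top \big(R(n\,\omega_i)\,k_i\big)
  &= q_i^\top R(m\,\omega_i)^\top R(n\,\omega_i)\,k_i \\
  &= q_i^\top R(-m\,\omega_i)\,R(n\,\omega_i)\,k_i \\
  &= q_i^\top R\big((n-m)\,\omega_i\big)\,k_i .
\end{aligned}
$$

Summing over pairs, the full (unscaled) score $\sum_i q_i^\top R\big((n-m)\,\omega_i\big)\,k_i$ depends on $m$ and $n$ only through the offset $n - m$. An absolute position table gives no such guarantee, because each index has its own independently learned vector.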

**Inputs**
- `q, k`: `[B, H, L, Dh]` (Dh must be even).
- `pos`: `[B, L]` integer positions.
- `base`: float (default `10000.0`), controls frequency decay.

**Outputs**
- `(q_rot, k_rot)`: both `[B, H, L, Dh]`, after applying RoPE.

In [6]:
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    Split the last dim into (even, odd) channel pairs and return [-odd, even] per pair.

    RoPE rotates each *pair* of channels `(even, odd)` by an angle. `rotate_half(x)` implements a fixed 90° rotation per pair:
    - Split `x` along the last dimension into even (`x[..., ::2]`) and odd (`x[..., 1::2]`) channels.
    - Return `[-odd, even]` (i.e., `[x₀, x₁] → [-x₁, x₀]`) per pair, interleaved back into the original shape.
    It is a helper so that `x * cos + rotate_half(x) * sin` matches the 2×2 rotation matrix for every `(even, odd)` pair.
    """
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).reshape_as(x)

def rope_qk(q: torch.Tensor, k: torch.Tensor, pos: torch.LongTensor, base: float = 10000.0) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply Rotary Positional Embeddings to Q and K.
    Shapes:
      q, k: [B, H, L, Dh]  (Dh must be even)
      pos:  [B, L] integer positions
    Returns:
      (q_rot, k_rot) with the same shapes.

    RoPE (high-level):
      For each position p and each pair of dims (2i, 2i+1), rotate by angle
      theta_{p,i} = p / base^{2i/Dh}. This preserves relative positions.
    """
    B, H, L, Dh = q.shape
    assert Dh % 2 == 0, "Dh must be even for RoPE."

    # 1) Build per-dimension frequencies: [Dh/2]
    freqs = 1.0 / (base ** (torch.arange(0, Dh, 2, device=q.device, dtype=q.dtype) / Dh))  # [Dh/2]

    # 2) Compute angles for each position: [B, L, Dh/2]
    #    Broadcast to [B, 1, L, Dh/2] to align with [B, H, L, Dh]
    theta = pos.unsqueeze(-1).to(q.dtype) * freqs  # [B,L,Dh/2]
    cos = torch.cos(theta).unsqueeze(1)            # [B,1,L,Dh/2]
    sin = torch.sin(theta).unsqueeze(1)            # [B,1,L,Dh/2]

    # 3) Interleave cos/sin to match Dh
    cos = torch.stack([cos, cos], dim=-1).reshape(B, 1, L, Dh)
    sin = torch.stack([sin, sin], dim=-1).reshape(B, 1, L, Dh)

    # 4) Rotate: (x * cos) + (rotate_half(x) * sin)
    q_rot = q * cos + rotate_half(q) * sin
    k_rot = k * cos + rotate_half(k) * sin
    return q_rot, k_rot

In [7]:
#@title Tests for TASK B2 (fast and numeric)
B,H,L,Dh = 2, 1, 4, 8
q = torch.randn(B,H,L,Dh)
k = torch.randn(B,H,L,Dh)
pos = torch.stack([torch.arange(L), torch.arange(L)], dim=0)  # [B,L] = [[0,1,2,3], [0,1,2,3]]

q_rot, k_rot = rope_qk(q, k, pos)
assert q_rot.shape == q.shape and k_rot.shape == k.shape

# Sanity: at position 0, rotation is identity (theta=0) -> cos=1, sin=0
assert torch.allclose(q_rot[:, :, 0], q[:, :, 0], atol=1e-5)
assert torch.allclose(k_rot[:, :, 0], k[:, :, 0], atol=1e-5)

# Rotation should preserve per-vector 2-norm at each (B,H,L) slice.
def norms(x): return torch.linalg.vector_norm(x, dim=-1)
assert torch.allclose(norms(q_rot), norms(q), atol=1e-5)
assert torch.allclose(norms(k_rot), norms(k), atol=1e-5)
print("✅ RoPE rotation passed.")


✅ RoPE rotation passed.
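
The tests above check the identity at position 0 and norm preservation. They do not check the property that matters for extrapolation. This plain-Python sketch re-implements the rotation for a single head vector, using our own names `rope_rotate` and `score`. It rotates a random query and key the way `rope_qk` does. It then compares scores for the same offset (a key 3 positions back, as in the k-back task with `k=3`) at a short position and at positions beyond `train_len = 64`.

```python
import math
import random
from typing import List

def rope_rotate(x: List[float], p: int, base: float = 10000.0) -> List[float]:
    """Rotate each (even, odd) pair of x by p * base**(-2i/Dh), as rope_qk does for one head vector."""
    dh = len(x)
    out: List[float] = []
    for i in range(dh // 2):
        th = p * base ** (-(2 * i) / dh)
        a, b = x[2 * i], x[2 * i + 1]
        out += [a * math.cos(th) - b * math.sin(th), a * math.sin(th) + b * math.cos(th)]
    return out

def score(q: List[float], k: List[float], m: int, n: int) -> float:
    """Unscaled attention logit between q placed at position m and k placed at position n."""
    return sum(u * v for u, v in zip(rope_rotate(q, m), rope_rotate(k, n)))

rng = random.Random(0)
q = [rng.gauss(0.0, 1.0) for _ in range(16)]  # Dh = 16, as with d_model=64 and n_heads=4
k = [rng.gauss(0.0, 1.0) for _ in range(16)]

# Same relative offset (key 3 positions before the query) at in-range and out-of-range positions.
for m in (10, 60, 120):
    print(f"query at {m:3d}, key at {m - 3:3d}: score = {score(q, k, m, m - 3):.10f}")
```

All three printed scores agree up to floating-point error. This is why a "look 3 back" pattern learned on length-64 sequences still applies at position 120.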


We have provided the implementation of a small transformer LM and a training loop. You don't need to change anything.


In [8]:
#@title Single-layer multi-head self-attention (with optional RoPE)
class TinySelfAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, use_rope: bool):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.use_rope = use_rope

        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor, pos: torch.LongTensor) -> torch.Tensor:
        B,L,D = x.shape
        H, Dh = self.n_heads, self.d_head

        q = self.q_proj(x).view(B, L, H, Dh).transpose(1,2)  # [B,H,L,Dh]
        k = self.k_proj(x).view(B, L, H, Dh).transpose(1,2)  # [B,H,L,Dh]
        v = self.v_proj(x).view(B, L, H, Dh).transpose(1,2)  # [B,H,L,Dh]

        if self.use_rope:
            q, k = rope_qk(q, k, pos)  # rotate by RoPE

        att = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(Dh)  # [B,H,L,L]
        # Full (bidirectional) attention is fine for k-back
        w = F.softmax(att, dim=-1)
        out = torch.matmul(w, v)  # [B,H,L,Dh]
        out = out.transpose(1,2).contiguous().view(B, L, D)
        return self.out_proj(out)

class TinyTransformer(nn.Module):
    def __init__(self, vocab_size: int, d_model: int, n_heads: int,
                 max_len: int, use_rope: bool):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_enc = None if use_rope else LearnedPositionalEncoding(max_len, d_model)
        self.self_attn = TinySelfAttention(d_model, n_heads, use_rope=use_rope)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 2*d_model),
            nn.ReLU(),
            nn.Linear(2*d_model, d_model)
        )
        self.lm_head = nn.Linear(d_model, vocab_size)
        self.use_rope = use_rope

    def forward(self, x: torch.LongTensor):
        B, L = x.shape
        pos = torch.arange(L, device=x.device).unsqueeze(0).expand(B, L)  # [B,L]
        h = self.token_emb(x)
        if not self.use_rope:
            h = h + self.pos_enc(pos)
        h = h + self.self_attn(h, pos)
        h = h + self.ffn(h)
        return self.lm_head(h)  # [B,L,V]


In [9]:
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Optional, Union, Any, Callable

device = torch.device("cpu") if not torch.cuda.is_available() else torch.device("cuda")

@dataclass
class TrainConfig:
    vocab_size: int = 16
    k: int = 3
    d_model: int = 64
    n_heads: int = 4
    train_len: int = 64
    test_len: int = 128
    max_len: int = 256
    epochs: int = 2
    lr: float = 3e-3
    bs: int = 64

def accuracy_kback(model: nn.Module, loader: DataLoader, ignore_index: int = -100) -> float:
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            pred = logits.argmax(dim=-1)
            mask = (y != ignore_index)
            correct += (pred[mask] == y[mask]).sum().item()
            total   += mask.sum().item()
    return correct / max(1, total)

def train_model(use_rope: bool, cfg: TrainConfig) -> Tuple[nn.Module, Dict[str, float]]:
    train_loader, test_short, test_long = make_loaders(
        vocab_size=cfg.vocab_size, k=cfg.k, train_len=cfg.train_len,
        test_len=cfg.test_len, bs=cfg.bs
    )
    model = TinyTransformer(cfg.vocab_size, cfg.d_model, cfg.n_heads, cfg.max_len, use_rope=use_rope).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr)
    loss_fn = nn.CrossEntropyLoss(ignore_index=-100)

    t0 = time.time()
    for epoch in range(cfg.epochs):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            opt.zero_grad()
            logits = model(x)      # [B,L,V]
            loss = loss_fn(logits.view(-1, cfg.vocab_size), y.view(-1))
            loss.backward()
            opt.step()

    t1 = time.time()
    acc_train = accuracy_kback(model, train_loader)
    acc_short = accuracy_kback(model, test_short)
    acc_long  = accuracy_kback(model, test_long)  # longer sequences (OOD for absolute PE)
    stats = dict(
        sec=round(t1 - t0, 3),
        acc_train=acc_train, acc_short=acc_short, acc_long=acc_long
    )
    return model, stats


### B3. Run: short vs. long sequence evaluation
We provide a tiny k-back toy task and two configs:
- **Absolute** PE model (learned table).
- **RoPE** model (applies `rope_qk` right before attention scores).

**What to run**
1) Train each model briefly on the *short* length (the in-distribution setting).  
2) Evaluate both on the short length (**acc_short**) and on a longer length (**acc_long**) that exceeds the training length (`test_len = 128` vs. `train_len = 64`).

**What to report (2–4 sentences)**
- Compare `acc_short` and `acc_long` for both models.  
- Note which one degrades more out of distribution.  
- One sentence connecting RoPE's relative phase idea to better long-length behavior.

In [10]:
#@title Run both models (fast, CPU-friendly)
cfg = TrainConfig(epochs=10, d_model=64, n_heads=4, train_len=64, test_len=128, bs=64)

model_abs, stats_abs = train_model(use_rope=False, cfg=cfg)
print("Absolute PE  stats:", stats_abs)

model_rope, stats_rope = train_model(use_rope=True, cfg=cfg)
print("RoPE         stats:", stats_rope)


Absolute PE  stats: {'sec': 0.995, 'acc_train': 0.9991355020491803, 'acc_short': 0.9964139344262295, 'acc_long': 0.5259375}
RoPE         stats: {'sec': 1.036, 'acc_train': 0.9992635758196722, 'acc_short': 0.9988473360655737, 'acc_long': 0.9666875}
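
Here is a rough back-of-the-envelope reading of these numbers. A length-128 test sequence has 125 labelled positions. Of these, 61 (positions 3–63) also occur in training sequences and 64 (positions 64–127) do not. *Assume* the first group is predicted about as well as in the short test set. This is only approximate, since attention is bidirectional and those queries also see keys at untrained positions. Under that assumption, the printed `acc_long` lets us back out the accuracy on the unseen positions alone. `implied_ood_accuracy` is a hypothetical helper, and its inputs are copied from the output above.

```python
def implied_ood_accuracy(acc_long: float, acc_short: float,
                         train_len: int, test_len: int, k: int) -> float:
    """Accuracy on positions t >= train_len, ASSUMING positions k..train_len-1 of the
    long sequences are predicted as well as in the short (in-distribution) test set."""
    n_in = train_len - k          # labelled positions that also exist in training sequences
    n_out = test_len - train_len  # labelled positions never seen in training
    n_all = test_len - k          # all labelled positions in a long sequence
    return (acc_long * n_all - acc_short * n_in) / n_out

stats = {  # copied from the printed output above
    "Absolute PE": dict(acc_short=0.9964139344262295, acc_long=0.5259375),
    "RoPE":        dict(acc_short=0.9988473360655737, acc_long=0.9666875),
}
for name, s in stats.items():
    ood = implied_ood_accuracy(s["acc_long"], s["acc_short"], train_len=64, test_len=128, k=3)
    print(f"{name:12} implied accuracy on positions 64..127: {ood:.3f}")
print(f"chance level (vocab_size=16): {1 / 16:.4f}")
```

With these inputs, the absolute-PE model comes out at about 0.08 on positions 64–127, close to the 1/16 ≈ 0.06 chance level for `vocab_size=16`. RoPE comes out at about 0.94. The absolute model's `acc_long` of about 0.53 therefore reads roughly as "near-perfect on the first half, near-chance on the second half".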


### B4. Reflection: why RoPE holds up better beyond training length
Write 2–3 sentences:
- Why absolute position tables struggle on unseen indexes (they simply don't have trained vectors for those positions).  
- Why RoPE's rotation preserves **relative** geometry in Q·K, so behavior extrapolates as sequences get longer.  
- Ground with a concrete example (e.g., "predict the token k steps back when the sentence is twice as long").

In [None]:
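Your_response = "Absolute positional embeddings " \
"break for indices outside the training range because those rows of the table were never trained " \
"(here, positions 64-127 of the length-128 test sequences), which makes attention scores for longer " \
"contexts unpredictable: accuracy fell from about 0.996 at length 64 to about 0.526 at length 128. " \
"RoPE, on the other hand, rotates Q and K so that their dot product depends on the relative offset " \
"BETWEEN tokens, not on their absolute indices. So even on sequences twice as long as in training, " \
"the model can look k = 3 steps back from position 120 the same way it does from position 10, " \
"and its accuracy stayed at about 0.967."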