# `model.py`

## (1) Character encoding + padding

Map each character `c` in the input string to an integer id (e.g. "1" maps to 2; "0" maps to 1; " " maps to 11; "?" maps to 15).

No character maps to 0, so id 0 is reserved for padding.

The resulting id sequence is the discrete representation of the input string that the model consumes.

In [1]:
from src.model import encode
s1 = "103 + 40 = ?"
s2 = "50 + 7 = ?"
print(encode(s1))
print(encode(s2))

tensor([ 2,  1,  4, 11, 12, 11,  5,  1, 11, 14, 11, 15])
tensor([ 6,  1, 11, 12, 11,  8, 11, 14, 11, 15])
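The printed ids are consistent with a simple lookup table. A minimal sketch, assuming digits map to `digit + 1` and using a hypothetical `VOCAB` dict and `encode_sketch` helper; the real `src.model.encode` returns a tensor and may define further symbols (for example, whatever occupies id 13):

```python
# Hypothetical vocabulary reconstructed from the printed ids above;
# the actual table in src/model.py may contain more symbols.
VOCAB = {str(d): d + 1 for d in range(10)}      # "0" -> 1, ..., "9" -> 10
VOCAB.update({" ": 11, "+": 12, "=": 14, "?": 15})
PAD_ID = 0                                      # never produced by encoding


def encode_sketch(s):
    """Map each character to its integer id (plain list, not a tensor)."""
    return [VOCAB[c] for c in s]


print(encode_sketch("103 + 40 = ?"))
print(encode_sketch("50 + 7 = ?"))
```

Because no entry of `VOCAB` equals `PAD_ID`, a 0 in a batch can only mean padding.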


Batches need a common length $T$, so we right-pad short sequences with `pad_id=0`.

The boolean `mask` marks which positions are **real tokens** vs **padding**.

In [2]:
from src.model import pad_batch
xs = [encode(s1), encode(s2)]
xs_padded, mask = pad_batch(xs, pad_id=0)
print(xs_padded)
print(mask)

tensor([[ 2,  1,  4, 11, 12, 11,  5,  1, 11, 14, 11, 15],
        [ 6,  1, 11, 12, 11,  8, 11, 14, 11, 15,  0,  0]])
tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False]])
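The padding step can be sketched on plain Python lists. `pad_batch_sketch` is a hypothetical stand-in for `pad_batch`, which works on tensors and may differ in detail:

```python
def pad_batch_sketch(seqs, pad_id=0):
    """Right-pad lists of ids to the longest length and return (padded, mask)."""
    T = max(len(s) for s in seqs)
    padded = [s + [pad_id] * (T - len(s)) for s in seqs]
    # True marks a real token, False marks padding (same convention as above).
    mask = [[True] * len(s) + [False] * (T - len(s)) for s in seqs]
    return padded, mask


padded, mask = pad_batch_sketch([[6, 1, 11], [8]])
print(padded)  # [[6, 1, 11], [8, 0, 0]]
print(mask)    # [[True, True, True], [True, False, False]]
```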


## (2) Dataset for regression

`TextDataset` reads a `.tsv` file and returns, for each row:

`inputs = encode("x + y = ?"), targets = float(z)`
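A rough sketch of the row parsing, assuming each TSV row has two columns (the prompt string, then the numeric target). That column layout, the helper name `read_tsv_rows`, and the sample rows are all assumptions for illustration:

```python
import csv
import io


def read_tsv_rows(handle):
    """Yield (prompt, target) pairs from a two-column TSV: prompt <TAB> value."""
    for prompt, z in csv.reader(handle, delimiter="\t"):
        yield prompt, float(z)


# In-memory stand-in for a real .tsv file; the contents are made up.
fake_file = io.StringIO("103 + 40 = ?\t143\n50 + 7 = ?\t57\n")
rows = list(read_tsv_rows(fake_file))
print(rows)
```

Each prompt would then be passed through `encode` to produce the `inputs` tensor.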


## (3) TinyTransformer

**Embeddings + positions**

`h = self.tok(x) + self.pos(pos)`
- `tok` is an embedding matrix $E \in \mathbb{R}^{V \times d}$. For token id $x_t$, 
\begin{equation}
e_t = E[x_t]
\end{equation}

- `pos` is a learned positional embedding $P \in \mathbb{R}^{T_{max} \times d}$. For position $t$, 
\begin{equation}
p_t = P[t] \in \mathbb{R}^d
\end{equation}
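The two lookups and their sum can be sketched in NumPy. The sizes and random tables below are illustrative only; the model itself uses learned PyTorch embeddings:

```python
import numpy as np

rng = np.random.default_rng(0)
V, T_max, d = 16, 32, 8                 # illustrative sizes, not the model's
E = rng.normal(size=(V, d))             # token embedding table
P = rng.normal(size=(T_max, d))         # positional embedding table

x = np.array([[6, 1, 11, 0]])           # (B, T) token ids
pos = np.arange(x.shape[1])             # positions 0..T-1
h = E[x] + P[pos]                       # (B, T, d); P[pos] broadcasts over B
print(h.shape)
```

Each entry `h[b, t]` is $e_t + p_t$ for the token at position $t$ of example $b$.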

**Transformer encoder layer (per layer $l$)**

Each layer does self‑attention + MLP, with residuals and layer norms (PyTorch’s TransformerEncoderLayer handles the exact pre/post‑norm order):

1. **Self-attention:**

- Project to queries, keys and values:
\begin{equation}
Q=HW_Q , \quad K=HW_K, \quad V=HW_V, \quad H \in \mathbb{R}^{T \times d}.
\end{equation}

- Attention scores:
\begin{equation}
A=\text{softmax} \left( \frac{QK^T}{\sqrt{d_k}} + \text{mask} \right)
\end{equation}
where $\text{mask}$ sets $-\infty$ on padded positions so they get zero attention.

- Context:
\begin{equation}
\text{Attention}(H)=AV
\end{equation}

- Multi-head: run this in parallel over $n_h$ heads (each with its own $W_Q$, $W_K$, $W_V$), then concatenate the head outputs and mix them with a linear layer.

2. **Feed-forward (MLP):**
\begin{equation}
\text{MLP}(h)=W_2 \, \text{GeLU}(W_1 h + b_1) + b_2
\end{equation}

3. **Pooling and regression** (applied once, after the final layer $L$)
- Get the last token vector in $\mathbb{R}^d$: `pooled = h[:, -1, :]`
- Final prediction is linear in that representation:
\begin{equation}
\hat{y} = w^T h_T^{(L)} + b
\end{equation}
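With right-padded batches, index `-1` is a padding slot for every sequence shorter than $T$, and the patching notes further down refer to the "last real token". A hedged sketch of selecting that position from the boolean mask, followed by the linear head; `pool_last_real` and the head parameters are hypothetical, and the actual model may pool differently:

```python
import numpy as np


def pool_last_real(h, mask):
    """h: (B, T, d) hidden states; mask: (B, T) bool, True on real tokens."""
    last = mask.sum(axis=1) - 1                 # index of last real token per row
    return h[np.arange(h.shape[0]), last]       # (B, d)


h = np.arange(2 * 3 * 2, dtype=float).reshape(2, 3, 2)
mask = np.array([[True, True, True], [True, False, False]])
pooled = pool_last_real(h, mask)
w, b = np.array([1.0, -1.0]), 0.5               # made-up head parameters
y_hat = pooled @ w + b                          # one scalar prediction per row
print(pooled, y_hat)
```

This relies on right-padding, so that the real tokens of each row form a contiguous prefix.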

## (4) Training 

`S = 2000.0`

`y = true_sum / S`

Raw sums lie roughly in $[0, 2000]$. Dividing by $S$ puts the targets in $[0,1]$, which keeps gradient magnitudes and step sizes in a reasonable range and stabilizes optimization.

**Loss**

For a batch with batch size $B$,
\begin{equation}
L = \frac{1}{B} \sum_{i=1}^B (\hat{y}_i - y_i)^2
\end{equation}
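A small worked example of the scaling and the loss, in plain Python. The sums and the predictions are made up for illustration:

```python
S = 2000.0  # scale constant from the text


def mse(preds, targets):
    """Mean squared error over a batch."""
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)


true_sums = [143.0, 57.0]
y = [s / S for s in true_sums]          # scaled targets in [0, 1]
y_hat = [0.07, 0.03]                    # made-up model outputs
loss = mse(y_hat, y)
print(loss)
# An MSE in scaled units corresponds to loss * S**2 in raw squared units.
print(loss * S ** 2)
```

Here the raw predictions are 140 and 60, each off by 3, so the raw-unit MSE is 9 and the scaled MSE is $9 / S^2$.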

## (5) Padding mask

`kpm = ~(x != 0)`

`h = enc(h, src_key_padding_mask=kpm)`

This tells attention to ignore positions where the token id equals 0 (`True` in `kpm` marks a padded position).
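To see what the mask does inside attention, here is a single-head NumPy sketch of the masked softmax. It is not PyTorch's implementation: `masked_attention` is a hypothetical helper, and identity projections are used in place of $W_Q, W_K, W_V$ for brevity:

```python
import numpy as np


def masked_attention(Q, K, V, key_padding_mask):
    """Single-head attention; key_padding_mask is (T,) bool, True on padding."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                       # (T, T)
    scores = np.where(key_padding_mask[None, :], -np.inf, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    A = np.exp(scores)                                    # exp(-inf) == 0
    A = A / A.sum(axis=-1, keepdims=True)                 # rows sum to 1
    return A @ V, A


x = np.array([6, 1, 11, 0, 0])          # token ids, last two are padding
kpm = x == 0                            # same values as ~(x != 0)
rng = np.random.default_rng(0)
H = rng.normal(size=(5, 4))
out, A = masked_attention(H, H, H, kpm)
print(A[:, kpm])                        # attention paid to padded keys
```

The mask removes padded positions as keys only. Padded query rows still produce an output vector, which is why pooling should avoid reading from them.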

In [1]:
# Train the tiny transformer on the REG-SUM task
!python3 -m src.train_transformer --task REG-SUM --epochs 10

Epoch 1| train loss: 0.9369 val loss: 0.0921
Epoch 2| train loss: 0.0550 val loss: 0.0243
Epoch 3| train loss: 0.0268 val loss: 0.0127
Epoch 4| train loss: 0.0149 val loss: 0.0036
Epoch 5| train loss: 0.0093 val loss: 0.0018
Epoch 6| train loss: 0.0066 val loss: 0.0015
Epoch 7| train loss: 0.0064 val loss: 0.0012
Epoch 8| train loss: 0.0061 val loss: 0.0013
Epoch 9| train loss: 0.0053 val loss: 0.0015
Epoch 10| train loss: 0.0054 val loss: 0.0028
Best val MSE: 0.001234409105964005


In [3]:
!python -m src.plot_reports --task REG-SUM



# `run_probes.py`

**Question:** How linearly recoverable is the true sum $s=x+y$ from the model's hidden states?

**Idea:**

- Run the trained transformer on the validation set to get hidden states $h \in \mathbb{R}^{B \times T \times D}$.
- Build targets $s_i = x_i + y_i$ for each example $i$.
- For a chosen representation $H$ (e.g. the last token's hidden state, shape $B \times D$), fit a linear ridge regressor $\hat{s} = H w + b$.
- Report $R^2$ on the same data the probe was fit on (an in-sample score). High $R^2 \implies$ the sum is close to a linear function of that representation.

The probe is fit by minimizing:

\begin{equation}
\min_{w,b} || Hw + b -s||^2_2 + \alpha ||w||^2_2
\end{equation}
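This objective has a closed-form solution. A minimal NumPy sketch on synthetic data, assuming the intercept $b$ is left unpenalized as in the objective above; `fit_ridge` is a hypothetical helper, and `run_probes.py` may instead rely on a library ridge solver:

```python
import numpy as np


def fit_ridge(H, s, alpha=1.0):
    """Closed-form ridge with an unpenalized intercept."""
    H_mean, s_mean = H.mean(axis=0), s.mean()
    Hc, sc = H - H_mean, s - s_mean              # centering handles the intercept
    D = H.shape[1]
    w = np.linalg.solve(Hc.T @ Hc + alpha * np.eye(D), Hc.T @ sc)
    b = s_mean - H_mean @ w
    return w, b


rng = np.random.default_rng(0)
H = rng.normal(size=(200, 8))                    # synthetic "hidden states"
w_true = rng.normal(size=8)
s = H @ w_true + 3.0                             # synthetic, exactly linear target
w, b = fit_ridge(H, s, alpha=1e-6)
print(np.abs(w - w_true).max(), b)
```

With a tiny $\alpha$ and an exactly linear target, the fit recovers `w_true` and the offset almost exactly; larger $\alpha$ shrinks $w$ toward zero.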



## What does a linear probe do?

Suppose we have hidden states $h_i \in \mathbb{R}^D$ from our transformer, one per example $i$. We also have labels $s_i=x_i + y_i$ (the true sum).

A linear probe fits:
\begin{equation}
\hat{s}_i = w^T h_i + b
\end{equation}

This is just linear regression (ridge if adding L2 regularization).

The $R^2$ score is:
\begin{equation}
R^2 = 1 - \frac{\sum_i (s_i - \hat{s}_i)^2}{\sum_i (s_i - \bar{s})^2}
\end{equation}
- Numerator: residual error of the probe.
- Denominator: total squared deviation of the true sums from their mean (their variance times the number of examples).
- $R^2$ measures the fraction of variance in the true sum values that can be explained by a linear function of the hidden state.
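The formula translates directly into code. A plain-Python sketch with made-up numbers (`r2_score_sketch` is an illustrative name):

```python
def r2_score_sketch(s_true, s_pred):
    """R^2 = 1 - (residual sum of squares) / (total sum of squares)."""
    mean = sum(s_true) / len(s_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(s_true, s_pred))
    ss_tot = sum((t - mean) ** 2 for t in s_true)
    return 1.0 - ss_res / ss_tot


s_true = [10.0, 20.0, 30.0, 40.0]
print(r2_score_sketch(s_true, s_true))                      # perfect probe: 1.0
print(r2_score_sketch(s_true, [25.0] * 4))                  # mean predictor: 0.0
print(r2_score_sketch(s_true, [12.0, 18.0, 33.0, 37.0]))    # approximately 0.948
```

A probe that is worse than predicting the mean yields a negative $R^2$, so the score is bounded above by 1 but not below by 0.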

**Interpretation**:

- If $R^2 \approx 1.0$: the sum is linearly encoded in the hidden states. That means the information “what is $x+y$?” exists as a nearly linear direction in the space.
- If $R^2 \approx 0$: a linear readout of the hidden state does no better than always predicting the mean. The information may still be there, but in a non-linear form the probe can't extract.

So probe $R^2$ is not about whether the model itself predicts well; it is about whether you, as an outside observer, can read off the sum linearly from the internal activations.

In [4]:
!python3 -m src.run_probes --task REG-SUM

Saved probe results to analysis/REG-SUM_probe_results.json
{'pooled_r2': 0.9862602949142456, 'pos_r2': [0.8318772912025452, 0.8069847822189331, 0.9027222990989685, 0.7592655420303345, 0.9220019578933716, 0.8171101808547974, 0.9003023505210876, 0.8351649045944214, 0.8447633981704712, 0.7930974960327148, 0.7784875631332397, 0.8305479288101196, 0.9070967435836792]}


In [1]:
# visualise probe results
!python3 -m src.plot_probes

Saved plot to analysis/REG-SUM_probe_curve.png


# `run_patching.py`

**Question**: If we remove the "sum direction" from the representation, does performance drop? If yes, the direction is not just correlated with the sum (probe $R^2$) -- it is causally important.

**Pipeline**

1. Load val data (e.g. REG-SUM_val.tsv), encode, pad
2. Load the trained model (e.g. REG-SUM_best.pt)
3. Forward pass with `return_h=True` and get hidden states $h \in \mathbb{R}^{B \times T \times D}$
4. Build targets `sums = x + y` by parsing the raw strings (regex).
5. Fit a probe on the pooled hidden states (last real token) to predict `sums`. The probe’s weight vector is the sum direction $v \in \mathbb{R}^D$.
6. Project out the component of each pooled vector along $v$:
\begin{equation}
    h' = h - (h \cdot \hat{v}) \hat{v} , \quad \hat{v}=\frac{v}{||v||}
\end{equation}
7. Measure task performance before/after by training a simple ridge readout from the pooled vectors to the ground‑truth target `y_true` (the TSV value) and comparing MSE:
    - `ridge_head_mse_before` (before)
    - `ridge_head_mse_after` (after)
    - `delta_mse = after − before`
    
If `delta_mse >> 0`, removing the direction destroyed information the readout needs, which is causal evidence that the sum direction matters.
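Steps 6 and 7 can be sketched on synthetic data. Everything here is illustrative: `project_out` and `readout_mse` are hypothetical helpers, the vectors are random rather than real hidden states, and plain least squares stands in for the ridge readout:

```python
import numpy as np


def project_out(H, v):
    """Remove the component of each row of H along direction v."""
    v_hat = v / np.linalg.norm(v)
    return H - np.outer(H @ v_hat, v_hat)


def readout_mse(H, y):
    """Fit a least-squares readout (with intercept) and return its MSE."""
    X = np.hstack([H, np.ones((len(H), 1))])
    coef, *_ = np.linalg.lstsq(X, y, rcond=None)
    return float(np.mean((X @ coef - y) ** 2))


rng = np.random.default_rng(0)
H = rng.normal(size=(200, 4))           # synthetic pooled vectors
v = rng.normal(size=4)                  # stand-in for the probe weight vector
y = H @ v                               # target that lives entirely along v

H_abl = project_out(H, v)
mse_before = readout_mse(H, y)
mse_after = readout_mse(H_abl, y)
print(np.abs(H_abl @ v).max())          # ~0: nothing left along v
print(mse_before, mse_after, mse_after - mse_before)
```

In this toy setup the target is fully determined by the component along `v`, so the readout is near-perfect before the projection and cannot recover the target afterwards.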

In [3]:
!python3 -m src.run_patching --task REG-SUM

Saved to analysis/REG-SUM_patch_results.json
{'ridge_head_mse_before': 2440.16650390625, 'ridge_head_mse_after': 9711.5869140625, 'delta_mse': 7271.42041015625}
