# `model.py`

## (1) Character encoding + padding

Map each character $c$ in the input string to an integer id $i$: for example `"0"` → 1, `"1"` → 2, `" "` → 11, `"+"` → 12, `"="` → 14 and `"?"` → 15 (see the output below).

No character is assigned id 0, so 0 is reserved for padding.

The result is a discrete token sequence that the model can embed and learn from.

In [1]:
from src.model import encode
s1 = "103 + 40 = ?"
s2 = "50 + 7 = ?"
print(encode(s1))
print(encode(s2))

tensor([ 2,  1,  4, 11, 12, 11,  5,  1, 11, 14, 11, 15])
tensor([ 6,  1, 11, 12, 11,  8, 11, 14, 11, 15])
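The printed ids are consistent with a fixed character table in which digit $d$ maps to $d+1$. Below is a minimal pure-Python sketch that reproduces the two outputs above under that assumption. `CHAR_TO_ID` and `encode_ids` are hypothetical names, not the repo's `encode`. Id 13 is left out because it does not appear in these examples.

```python
# Hypothetical character table inferred from the printed ids above.
# Id 0 is deliberately unused so it can serve as padding.
CHAR_TO_ID = {str(d): d + 1 for d in range(10)}  # "0" -> 1, ..., "9" -> 10
CHAR_TO_ID.update({" ": 11, "+": 12, "=": 14, "?": 15})


def encode_ids(s: str) -> list[int]:
    """Map each character of s to its id (a plain list rather than a tensor)."""
    return [CHAR_TO_ID[c] for c in s]


print(encode_ids("103 + 40 = ?"))  # [2, 1, 4, 11, 12, 11, 5, 1, 11, 14, 11, 15]
print(encode_ids("50 + 7 = ?"))    # [6, 1, 11, 12, 11, 8, 11, 14, 11, 15]
```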


Batches need a common length $T$, so we right-pad short sequences with `pad_id=0`.

The boolean `mask` marks which positions are **real tokens** (`True`) and which are **padding** (`False`).

In [2]:
from src.model import pad_batch
xs = [encode(s1), encode(s2)]
xs_padded, mask = pad_batch(xs, pad_id=0)
print(xs_padded)
print(mask)

tensor([[ 2,  1,  4, 11, 12, 11,  5,  1, 11, 14, 11, 15],
        [ 6,  1, 11, 12, 11,  8, 11, 14, 11, 15,  0,  0]])
tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False]])
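The same right-padding logic can be sketched with plain lists. `pad_batch_lists` is a hypothetical stand-in for `pad_batch`, which returns tensors. The mask is simply `True` for the first `len(seq)` positions of each row and `False` after that:

```python
def pad_batch_lists(seqs: list[list[int]], pad_id: int = 0):
    """Right-pad every sequence to the batch maximum T and build a real-token mask."""
    T = max(len(s) for s in seqs)
    padded = [s + [pad_id] * (T - len(s)) for s in seqs]
    mask = [[True] * len(s) + [False] * (T - len(s)) for s in seqs]
    return padded, mask


xs = [[2, 1, 4, 11, 12, 11, 5, 1, 11, 14, 11, 15],
      [6, 1, 11, 12, 11, 8, 11, 14, 11, 15]]
padded, mask = pad_batch_lists(xs)
print(padded[1])  # [6, 1, 11, 12, 11, 8, 11, 14, 11, 15, 0, 0]
print(mask[1])    # ten True values followed by two False values
```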


## (2) Dataset for regression

`TextDataset` reads a `.tsv` file and returns, for each row,

`inputs = encode("x + y = ?"), targets = float(z)`

where `z` is the row's target value.
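Here is a minimal sketch of such a dataset. It assumes each TSV row holds the prompt string and the numeric target separated by a tab. The real column layout is not shown here, and `TsvRegressionData` and the inline character table are hypothetical. The class follows the `__len__`/`__getitem__` protocol that a PyTorch `Dataset` uses, but does not import torch:

```python
import csv
import io

# Hypothetical character table (same ids as the encode outputs shown earlier).
CHAR_TO_ID = {str(d): d + 1 for d in range(10)}
CHAR_TO_ID.update({" ": 11, "+": 12, "=": 14, "?": 15})


class TsvRegressionData:
    """Yields (token_ids, float_target) pairs from tab-separated rows."""

    def __init__(self, fh):
        # Assumed row format: "<prompt>\t<target>", e.g. "103 + 40 = ?\t143".
        self.rows = [(prompt, float(z)) for prompt, z in csv.reader(fh, delimiter="\t")]

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, i: int):
        prompt, target = self.rows[i]
        return [CHAR_TO_ID[c] for c in prompt], target


data = TsvRegressionData(io.StringIO("103 + 40 = ?\t143\n50 + 7 = ?\t57\n"))
print(len(data), data[1])  # 2 ([6, 1, 11, 12, 11, 8, 11, 14, 11, 15], 57.0)
```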


## (3) TinyTransformer

**Embeddings + positions**

`h = self.tok(x) + self.pos(pos)`
- `tok` is an embedding matrix $E \in \mathbb{R}^{V \times d}$. For token id $x_t$,
\begin{equation}
e_t = E[x_t] \in \mathbb{R}^d
\end{equation}

- `pos` is a learned positional embedding $P \in \mathbb{R}^{T_{\max} \times d}$. For position $t$,
\begin{equation}
p_t = P[t] \in \mathbb{R}^d
\end{equation}

The input to the first layer is their sum, $h_t = e_t + p_t$.
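The lookup-and-add step can be sketched in NumPy. Random matrices stand in for the learned `tok` and `pos` weights, and the sizes `V=16`, `T_max=32`, `d=8` are illustrative assumptions. Indexing a matrix with an integer array performs exactly the row lookup $E[x_t]$:

```python
import numpy as np

rng = np.random.default_rng(0)
V, T_max, d = 16, 32, 8               # illustrative sizes, not the model's config
E = rng.normal(size=(V, d))           # token embedding table (stands in for self.tok)
P = rng.normal(size=(T_max, d))       # learned positional table (stands in for self.pos)

x = np.array([[2, 1, 4, 11, 12, 11, 5, 1, 11, 14, 11, 15],
              [6, 1, 11, 12, 11, 8, 11, 14, 11, 15, 0, 0]])  # (B, T) padded ids
B, T = x.shape
pos = np.arange(T)                    # positions 0..T-1, shared by every row

h = E[x] + P[pos]                     # (B, T, d) + (T, d) broadcasts to (B, T, d)
print(h.shape)                        # (2, 12, 8)
```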

**Transformer encoder layer (per layer $l$)**

Each layer applies self-attention and then an MLP. Each sublayer is wrapped with a residual connection and a layer norm; PyTorch's `TransformerEncoderLayer` sets the exact pre-/post-norm order through its `norm_first` flag:

1. **Self-attention:**

- Project to queries, keys and values:
\begin{equation}
Q=HW_Q , \quad K=HW_K, \quad V=HW_V, \quad H \in \mathbb{R}^{T \times d}.
\end{equation}

- Attention weights:
\begin{equation}
A=\text{softmax} \left( \frac{QK^T}{\sqrt{d_k}} + \text{mask} \right)
\end{equation}
where $\text{mask}$ adds $-\infty$ to the scores of padded key positions (columns), so every query gives them zero weight.

- Context:
\begin{equation}
\text{Attention}(H)=AV
\end{equation}

- Multi-head: run $n_h$ such heads in parallel, each with its own $W_Q, W_K, W_V$ and $d_k = d / n_h$. Concatenate their outputs and mix them with a final linear projection.

2. **Feed-forward (MLP):**
\begin{equation}
\text{MLP}(h)=W_2\, \text{GELU}(W_1 h + b_1) + b_2
\end{equation}

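**Pooling and regression**

- After the last layer $L$, take the vector at the final position, $h_T^{(L)} \in \mathbb{R}^d$: `pooled = h[:, -1, :]`. With right padding, position $-1$ is a pad token for every sequence shorter than $T$. The patching pipeline below pools the last *real* token instead.
- The final prediction is linear in that representation:
\begin{equation}
\hat{y} = w^T h_T^{(L)} + b
\end{equation}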
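The per-layer computation and the readout can be sketched in NumPy for a single head. This is a simplified illustration, not PyTorch's `TransformerEncoderLayer`. It assumes one head, random weights, and no layer norms or residual connections, so the MLP is applied directly to the attention output. It also pools at the last *real* token (index `length - 1`) rather than at index `-1`. The MLP is written in row-vector form ($x W_1$ rather than $W_1 h$):

```python
import numpy as np

rng = np.random.default_rng(0)
B, T, d, d_ff = 2, 12, 8, 32
H = rng.normal(size=(B, T, d))                        # layer input, e.g. embeddings
lengths = np.array([12, 10])                          # real tokens per sequence
key_pad = np.arange(T)[None, :] >= lengths[:, None]   # (B, T), True = padding


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)           # subtract row max for stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def gelu(z):
    # tanh approximation of GELU
    return 0.5 * z * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (z + 0.044715 * z**3)))


# Single-head self-attention (d_k = d here)
W_Q, W_K, W_V = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))
Q, K, V = H @ W_Q, H @ W_K, H @ W_V                   # (B, T, d) each
scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d)        # (B, T, T): query i vs key j
scores = np.where(key_pad[:, None, :], -np.inf, scores)  # -inf on padded keys
A = softmax(scores, axis=-1)                          # rows sum to 1 over real keys
ctx = A @ V                                           # (B, T, d)

# Position-wise MLP in row-vector form: GELU(x W1 + b1) W2 + b2
W1, b1 = rng.normal(size=(d, d_ff)) / np.sqrt(d), np.zeros(d_ff)
W2, b2 = rng.normal(size=(d_ff, d)) / np.sqrt(d_ff), np.zeros(d)
out = gelu(ctx @ W1 + b1) @ W2 + b2                   # (B, T, d)

# Pool the last real token of each sequence and apply a linear head
pooled = out[np.arange(B), lengths - 1]               # (B, d)
w, b = rng.normal(size=d), 0.0
y_hat = pooled @ w + b                                # one scalar per example
print(A[1, 0, 10:], y_hat.shape)                      # [0. 0.] (2,)
```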

## (4) Training 

`S = 2000.0`

`y = (true_sum) / S`

Raw sums lie roughly in $[0, 2000]$. Dividing by $S$ puts the targets in about $[0, 1]$, which keeps gradient magnitudes and step sizes well scaled.

**Loss**

For a batch with batch size $B$,
\begin{equation}
L = \frac{1}{B} \sum_{i=1}^B (\hat{y}_i - y_i)^2
\end{equation}
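Because the loss is computed on scaled targets, logged MSE values are in units of $1/S^2$. The sketch below converts the best REG-SUM value back to raw-sum units. It assumes the logged "val MSE" is this loss on targets divided by $S = 2000$. It shows the arithmetic only; `to_raw_units` is an illustrative helper:

```python
import math

S = 2000.0


def mse(pred, target):
    """Mean squared error over a batch: L = (1/B) * sum (y_hat - y)^2."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def to_raw_units(scaled_mse: float, scale: float = S):
    """Scaling y by 1/S scales squared errors by 1/S^2, so multiply back by S^2."""
    raw_mse = scaled_mse * scale ** 2
    return raw_mse, math.sqrt(raw_mse)


print(mse([0.5, 0.25], [0.5, 0.75]))          # 0.125
raw_mse, raw_rmse = to_raw_units(0.00021807662420906127)
print(round(raw_mse, 1), round(raw_rmse, 1))  # 872.3 29.5
```

Under that assumption, a best val MSE of about $2.2 \times 10^{-4}$ corresponds to a typical error of roughly 30 on sums that range up to about 2000.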

## (5) Padding mask

`kpm = ~(x != 0)`

`h = enc(h, src_key_padding_mask=kpm)`

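`kpm` is `True` exactly where `x == 0`, so it is the complement of the real-token `mask`. For a boolean `src_key_padding_mask`, PyTorch treats `True` as "ignore this key", which means attention never reads from padding positions.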

In [1]:
# Train the tiny transformer on the REG-SUM task
!python3 -m src.train_transformer --task REG-SUM --epochs 10

Epoch 1| train loss: 0.0896 val loss: 0.0011
Epoch 2| train loss: 0.0032 val loss: 0.0003
Epoch 3| train loss: 0.0020 val loss: 0.0003
Epoch 4| train loss: 0.0015 val loss: 0.0002
Epoch 5| train loss: 0.0011 val loss: 0.0004
Epoch 6| train loss: 0.0009 val loss: 0.0003
Epoch 7| train loss: 0.0008 val loss: 0.0004
Epoch 8| train loss: 0.0006 val loss: 0.0002
Epoch 9| train loss: 0.0006 val loss: 0.0003
Epoch 10| train loss: 0.0005 val loss: 0.0003
Saved best model checkpoint in checkpoints/REG-SUM_best.pt
Best val MSE: 0.00021807662420906127
Saved training history.


In [3]:
!python -m src.plot_reports --task REG-SUM



# `run_probes.py`

**Question:** How linearly recoverable is the true sum $s=x+y$ from the model's hidden states?

**Idea:**

- Run the trained transformer on the validation set to get hidden states $h \in \mathbb{R}^{B \times T \times D}$.
- Build targets $s_i = x_i + y_i$ for each example $i$.
- For a chosen representation $H$ (e.g. the last token's hidden state, shape $B \times D$), fit a linear ridge regressor $\hat{s} = H w + b$.
- Report $R^2$ on the same data used to fit the probe. This is an in-sample estimate and therefore optimistic. High $R^2 \implies$ the sum is (close to) a linear feature of that representation.

The probe is fit by minimizing the ridge objective:

\begin{equation}
\min_{w,b} \lVert Hw + b\mathbf{1} - s \rVert_2^2 + \alpha \lVert w \rVert_2^2
\end{equation}
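Here is a closed-form NumPy sketch of this objective. It assumes the intercept $b$ is left unpenalized, which is handled by centering $H$ and $s$. With centered data the minimizer is $w = (H_c^\top H_c + \alpha I)^{-1} H_c^\top s_c$ and $b = \bar{s} - \bar{h}^\top w$. The synthetic data is for illustration only:

```python
import numpy as np


def fit_ridge_probe(H: np.ndarray, s: np.ndarray, alpha: float = 1.0):
    """Minimise ||H w + b - s||^2 + alpha ||w||^2 with an unpenalised intercept b."""
    h_mean, s_mean = H.mean(axis=0), s.mean()
    Hc, sc = H - h_mean, s - s_mean              # centering removes b from the problem
    D = H.shape[1]
    w = np.linalg.solve(Hc.T @ Hc + alpha * np.eye(D), Hc.T @ sc)
    b = s_mean - h_mean @ w
    return w, b


rng = np.random.default_rng(0)
B, D = 500, 16
H = rng.normal(size=(B, D))                       # stand-in for pooled hidden states
w_true = rng.normal(size=D)
s = H @ w_true + 3.0 + 0.01 * rng.normal(size=B)  # a "sum" that is linear in H
w, b = fit_ridge_probe(H, s, alpha=1.0)
print(np.abs(w - w_true).max(), b)                # small error; b close to 3
```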



## What does a linear probe do?

Suppose we have hidden states $h_i \in \mathbb{R}^D$ from our transformer, one per example $i$. We also have labels $s_i=x_i + y_i$ (the true sum).

A linear probe fits:
\begin{equation}
\hat{s}_i = w^T h_i + b
\end{equation}

This is ordinary linear regression, or ridge regression when an L2 penalty on $w$ is added.

The $R^2$ score is:
\begin{equation}
R^2 = 1 - \frac{\sum_i (s_i - \hat{s}_i)^2}{\sum_i (s_i - \bar{s})^2}
\end{equation}
- Numerator: residual error of the probe.
- Denominator: variance of the true sums.
- $R^2$ measures the fraction of variance in the true sum values that can be explained by a linear function of the hidden state.
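The $R^2$ formula translates directly into NumPy (`r2_score_np` is an illustrative name). Predicting the mean gives exactly 0, a perfect prediction gives 1, and a prediction worse than the mean gives a negative value:

```python
import numpy as np


def r2_score_np(s: np.ndarray, s_hat: np.ndarray) -> float:
    """R^2 = 1 - SS_res / SS_tot, as in the formula above."""
    ss_res = np.sum((s - s_hat) ** 2)        # residual error of the probe
    ss_tot = np.sum((s - s.mean()) ** 2)     # spread of the true sums around their mean
    return float(1.0 - ss_res / ss_tot)


s = np.array([143.0, 57.0, 1200.0, 860.0])
print(r2_score_np(s, s))                           # 1.0
print(r2_score_np(s, np.full_like(s, s.mean())))   # 0.0
print(r2_score_np(s, s[::-1]) < 0)                 # True: worse than the mean
```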

**Interpretation**:

- If $R^2 \approx 1.0$: the sum is linearly encoded in the hidden states. That means the information “what is $x+y$?” exists as a nearly linear direction in the space.
- If $R^2 \approx 0$: a linear readout of the hidden state does no better than always predicting the mean sum. On held-out data, $R^2$ can even be slightly negative. The information may still be there, but in a non-linear form the probe cannot extract.

So probe $R^2$ is not about whether the model itself predicts well — it’s about whether you, as an outside observer, can read off the sum linearly from the internal activations.

In [4]:
!python3 -m src.run_probes --task REG-SUM

Saved probe results to analysis/REG-SUM_probe_results.json
{'pooled_r2': 0.9981094002723694, 'pos_r2': [0.935593843460083, 0.9016793966293335, 0.9020518064498901, 0.9338340759277344, 0.9268955588340759, 0.9428122043609619, 0.915154755115509, 0.8832254409790039, 0.943127453327179, 0.9334304928779602, 0.9299917221069336, 0.9501594305038452, 0.9039064645767212]}


In [87]:
# visualise probe results
!python3 -m src.plot_probes --task REG-SUM

Saved plot to src/figures/REG-SUM_probe_curve.png
Plotted results from src/output/REG-SUM_probe_results.json and saved to src/figures/REG-SUM_probe_curve.png


# `run_patching.py`

**Question**: If we remove the "sum direction" from the representation, does performance drop? If it does, the direction is not merely correlated with the sum (which is what a high probe $R^2$ shows): it is causally important.

**Pipeline**

1. Load val data (e.g. REG-SUM_val.tsv), encode, pad
2. Load the trained model (e.g. REG-SUM_best.pt)
3. Forward pass with `return_h=True` and get hidden states $h \in \mathbb{R}^{B \times T \times D}$
4. Build targets `sums = x + y` by parsing the raw strings (regex).
5. Fit a probe on the pooled hidden states (last real token) to predict `sums`. The probe’s weight vector is the sum direction $v \in \mathbb{R}^D$.
6. Project out the component of each pooled vector along $v$:
\begin{equation}
    h' = h - (h \cdot \hat{v}) \hat{v} , \quad \hat{v}=\frac{v}{||v||}
\end{equation}
7. Measure task performance before/after by training a simple ridge readout from the pooled vectors to the ground‑truth target `y_true` (the TSV value) and comparing MSE:
    - `ridge_head_mse_before` (before)
    - `ridge_head_mse_after` (after)
    - `delta_mse = after - before`
    
If `delta_mse >> 0`, that’s causal evidence the sum direction matters.
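Step 6 can be sketched in NumPy: each pooled vector is projected onto the orthogonal complement of $\hat{v}$. Here `v` is random and stands in for the probe weight vector. After the projection, no linear readout can use the component along $\hat{v}$:

```python
import numpy as np


def project_out(h: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Return h' = h - (h . v_hat) v_hat for every row of h (shape (B, D))."""
    v_hat = v / np.linalg.norm(v)
    return h - np.outer(h @ v_hat, v_hat)


rng = np.random.default_rng(0)
h = rng.normal(size=(4, 8))       # stand-in pooled vectors (B=4, D=8)
v = rng.normal(size=8)            # stand-in "sum direction" (probe weights)
h_prime = project_out(h, v)
print(np.allclose(h_prime @ v, 0.0))                  # True: nothing left along v
print(np.allclose(project_out(h_prime, v), h_prime))  # True: projection is idempotent
```

Components orthogonal to $\hat{v}$ are left untouched, so any change in `delta_mse` comes only from the removed direction.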

In [6]:
!python3 -m src.run_patching --task REG-SUM

Saved to analysis/REG-SUM_patch_results.json
{'ridge_head_mse_before': 340.789306640625, 'ridge_head_mse_after': 1649.2431640625, 'delta_mse': 1308.453857421875}


# Repeat the above for `REG-MODK`

In [7]:
!python3 -m src.train_transformer --task REG-MODK --epochs 10

Epoch 1| train loss: 0.1007 val loss: 0.0003
Epoch 2| train loss: 0.0018 val loss: 0.0002
Epoch 3| train loss: 0.0011 val loss: 0.0002
Epoch 4| train loss: 0.0007 val loss: 0.0002
Epoch 5| train loss: 0.0006 val loss: 0.0002
Epoch 6| train loss: 0.0005 val loss: 0.0003
Epoch 7| train loss: 0.0004 val loss: 0.0003
Epoch 8| train loss: 0.0004 val loss: 0.0002
Epoch 9| train loss: 0.0004 val loss: 0.0005
Epoch 10| train loss: 0.0004 val loss: 0.0002
Saved best model checkpoint in checkpoints/REG-MODK_best.pt
Best val MSE: 0.00019462584785651416 at epoch 5
Saved training history.


In [8]:
!python3 -m src.plot_reports --task REG-MODK



In [85]:
!python3 -m src.run_probes --task REG-MODK

Saved probe results to src/output/REG-MODK_probe_results.json
{'pooled_r2': 0.6278746724128723, 'pos_r2': [0.773277223110199, 0.6225485801696777, 0.6961795687675476, 0.7110834121704102, 0.6489095687866211, 0.5991976261138916, 0.7543108463287354, 0.5924468040466309, 0.6867231130599976, 0.6317974328994751, 0.6239009499549866, 0.660067081451416, 0.5272255539894104, 0.6324349641799927, 0.5426217913627625, 0.6950309872627258, 0.7837653756141663, 0.5850135087966919, 0.6281089186668396, 0.7277127504348755, 0.6967750787734985, 0.627521276473999]}


In [86]:
!python3 -m src.plot_probes --task REG-MODK

Saved plot to src/figures/REG-MODK_probe_curve.png
Plotted results from src/output/REG-MODK_probe_results.json and saved to src/figures/REG-MODK_probe_curve.png


In [10]:
!python3 -m src.run_patching --task REG-MODK

Saved to analysis/REG-MODK_patch_results.json
{'ridge_head_mse_before': 718.4989624023438, 'ridge_head_mse_after': 718.4767456054688, 'delta_mse': -0.022216796875}


# Design experiments

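Comparing the `REG-SUM` and `REG-MODK` results raises two natural questions: why do they differ so sharply, and what can we learn from the difference? For `REG-SUM`, the pooled probe reaches $R^2 \approx 0.998$, and projecting out the probe direction raises the readout MSE from about 341 to 1649. For `REG-MODK`, the pooled probe $R^2$ is only about 0.63, and the same ablation changes the MSE by about $-0.02$.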

## Experiment 1: Higher-dim (Fourier) probes

**Design:**
- Fits multi-output ridge probes that predict $[\cos n\theta, \sin n \theta]_{n=1,...,N}$ from the pooled hidden state.
- Sweeps $N=1,...,N_{max}$ and logs the macro $R^2$ versus dimension $2N$.
- Tests small-k periodicities on the same hidden states (using $k' \in \{2,3,4,5,7,8\}$, configurable).
- Runs a causal ablation by projecting out the learned subspace (the columns of the probe's weight matrix) and reports ΔMSE versus dimension, alongside a random-subspace control.
- Saves everything to `analysis/output/{TASK}_multidim_results.json`.
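The sketch below shows two ingredients of this experiment. First, it assumes the angle is $\theta = 2\pi r / k$ for residue $r$; the section does not define $\theta$, so this is an assumption. Second, it treats the ablated subspace as the column span of a probe weight matrix $W \in \mathbb{R}^{D \times 2N}$, orthonormalized with a QR decomposition before projecting out. The random control uses a random subspace of the same dimension. All arrays are stand-ins:

```python
import numpy as np


def fourier_targets(r: np.ndarray, k: int, N: int) -> np.ndarray:
    """Stack [cos(n*theta), sin(n*theta)] for n = 1..N, with theta = 2*pi*r/k (assumed)."""
    theta = 2 * np.pi * r / k
    n = np.arange(1, N + 1)
    return np.concatenate([np.cos(np.outer(theta, n)), np.sin(np.outer(theta, n))], axis=1)


def project_out_subspace(h: np.ndarray, W: np.ndarray) -> np.ndarray:
    """Remove the span of W's columns from every row of h."""
    Qb, _ = np.linalg.qr(W)                      # orthonormal basis, shape (D, 2N)
    return h - (h @ Qb) @ Qb.T


rng = np.random.default_rng(0)
k, N, D = 97, 3, 32
r = rng.integers(0, k, size=200)
Y = fourier_targets(r, k, N)                     # (200, 2N) multi-output targets
h = rng.normal(size=(200, D))                    # stand-in pooled hidden states
W_probe = rng.normal(size=(D, 2 * N))            # stand-in probe weights (D x 2N)
W_random = rng.normal(size=(D, 2 * N))           # random-subspace control, same dim
h_probe = project_out_subspace(h, W_probe)
h_rand = project_out_subspace(h, W_random)
print(Y.shape, np.allclose(h_probe @ W_probe, 0.0))  # (200, 6) True
```

In the experiment, a readout would then be refit on `h_probe` and `h_rand`, and the two ΔMSE values compared.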

In [2]:
!python3 -m analysis.run_multidim_probes --task REG-SUM

Saved: analysis/output/REG-SUM_multidim_results.json
Results: {'task': 'REG-SUM', 'alpha': 1.0, 'scalar_probe_r2': 0.9981094002723694, 'base_mse': 321.7313232421875, 'delta_mse_probe_1d': 837.07373046875, 'delta_mse_random_1d': 2.67425537109375, 'excess_delta_mse': 834.3994750976562}


In [3]:
!python3 -m analysis.run_multidim_probes --task REG-MODK

Saved: analysis/output/REG-MODK_multidim_results.json
Results: {'task': 'REG-MODK', 'alpha': 1.0, 'Nmax': 12, 'fourier_r2_curve': [{'N': 1, 'dim': 2, 'r2': 0.04383407823837304}, {'N': 2, 'dim': 4, 'r2': 0.04589766628046446}, {'N': 3, 'dim': 6, 'r2': 0.044180247043566635}, {'N': 4, 'dim': 8, 'r2': 0.042627249895444294}, {'N': 5, 'dim': 10, 'r2': 0.040618258514046476}, {'N': 6, 'dim': 12, 'r2': 0.0414202961216805}, {'N': 7, 'dim': 14, 'r2': 0.040260831233690356}, {'N': 8, 'dim': 16, 'r2': 0.04047620684812237}, {'N': 9, 'dim': 18, 'r2': 0.04046405870040577}, {'N': 10, 'dim': 20, 'r2': 0.0404618721093313}, {'N': 11, 'dim': 22, 'r2': 0.04088096411394192}, {'N': 12, 'dim': 24, 'r2': 0.04087226328436625}], 'smallk_r2': [{'kprime': 2, 'N': 1, 'dim': 2, 'r2': 0.04226480344607503}, {'kprime': 2, 'N': 2, 'dim': 4, 'r2': 0.28169860131798863}, {'kprime': 2, 'N': 3, 'dim': 6, 'r2': 0.2018873305894168}, {'kprim

**Plots**:
- Fourier $R^2$ vs. dimension ($2N$)
- $\Delta MSE$ vs. dimension ($2N$), with a random-subspace control
- Small-$k$ $R^2$ ($k'$ sweep)

In [4]:
!python3 -m analysis.plot_multidim --task REG-SUM

Scalar probe R^2: 0.9981
Base MSE: 321.73
ΔMSE (probe 1D): 837.07
ΔMSE (random 1D): 2.67
Excess ΔMSE: 834.40
Saved REG-SUM_scalar_causal.png


In [93]:
!python3 -m analysis.plot_multidim --task REG-MODK

Saved plots to analysis/output


## Experiment 2: Nonlinear probes

In [1]:
# REG-SUM (should see near-equal R^2 for ridge and MLP)
!python3 -m analysis.run_nonlinear_probes --task REG-SUM --shuffle_control

Saved: analysis/output/REG-SUM_nonlinear_probe_results.json
{'task': 'REG-SUM', 'ridge_r2': 0.9983693957328796, 'mlp_reg_r2': 0.9989134669303894, 'mlp_reg_r2_label_shuffle': 0.11715453863143921}


In [16]:
# REG-MODK (key test!)
!python3 -m analysis.run_nonlinear_probes --task REG-MODK --hidden 128 --depth 1 --shuffle_control


Saved: analysis/output/REG-MODK_nonlinear_probe_results.json
{'task': 'REG-MODK', 'k': 97, 'ridge_r2_numeric_residue': -0.0024858713150024414, 'mlp_reg_r2_numeric_residue': 0.11005878448486328, 'mlp_cls_accuracy_residue': 0.0125, 'mlp_cls_accuracy_label_shuffle': 0.015}


Try a few probe sizes (hidden width `--hidden` and depth `--depth`):

In [18]:
!python3 -m analysis.run_nonlinear_probes --task REG-MODK --hidden 128 --depth 2 --shuffle_control

Saved: analysis/output/REG-MODK_nonlinear_probe_results.json
{'task': 'REG-MODK', 'k': 97, 'ridge_r2_numeric_residue': -0.0024858713150024414, 'mlp_reg_r2_numeric_residue': 0.37926846742630005, 'mlp_cls_accuracy_residue': 0.015, 'mlp_cls_accuracy_label_shuffle': 0.0125}


In [19]:
!python3 -m analysis.run_nonlinear_probes --task REG-MODK --hidden 256 --depth 2 --shuffle_control

Saved: analysis/output/REG-MODK_nonlinear_probe_results.json
{'task': 'REG-MODK', 'k': 97, 'ridge_r2_numeric_residue': -0.0024858713150024414, 'mlp_reg_r2_numeric_residue': 0.4106581211090088, 'mlp_cls_accuracy_residue': 0.0075, 'mlp_cls_accuracy_label_shuffle': 0.01}


In [15]:
!python3 -m analysis.run_nonlinear_probes --task REG-MODK --hidden 256 --depth 4 --shuffle_control 

Saved: analysis/output/REG-MODK_nonlinear_probe_results.json
{'task': 'REG-MODK', 'k': 97, 'ridge_r2_numeric_residue': -0.0024858713150024414, 'mlp_reg_r2_numeric_residue': 0.4666404128074646, 'mlp_cls_accuracy_residue': 0.02, 'mlp_cls_accuracy_label_shuffle': 0.01}


In [21]:
!python3 -m analysis.run_nonlinear_probes --task REG-MODK --hidden 512 --depth 4 --shuffle_control 

Saved: analysis/output/REG-MODK_nonlinear_probe_results.json
{'task': 'REG-MODK', 'k': 97, 'ridge_r2_numeric_residue': -0.0024858713150024414, 'mlp_reg_r2_numeric_residue': 0.4383242130279541, 'mlp_cls_accuracy_residue': 0.01, 'mlp_cls_accuracy_label_shuffle': 0.015}
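The classification accuracies above can be compared against chance for $k = 97$ classes, $1/97 \approx 0.0103$. The check below assumes the accuracies are measured on 400 held-out examples. Every reported value is a multiple of $1/400$, but the count itself is not printed. It uses a normal approximation to the binomial standard error:

```python
import math

k = 97
n_eval = 400                       # assumption: inferred from the 1/400 granularity
chance = 1 / k
se = math.sqrt(chance * (1 - chance) / n_eval)

# mlp_cls_accuracy_residue from the five --hidden/--depth runs above
reported = {"h128_d1": 0.0125, "h128_d2": 0.015, "h256_d2": 0.0075,
            "h256_d4": 0.02, "h512_d4": 0.01}
for name, acc in reported.items():
    z = (acc - chance) / se
    print(f"{name}: acc={acc:.4f}  z={z:+.2f}")
print(all(abs((a - chance) / se) < 2 for a in reported.values()))  # True
```

Under this assumption, all five accuracies lie within about two standard errors of chance. This happens even though the MLP regressor's $R^2$ on the numeric residue rises to about 0.38 to 0.47 for the depth-2 and depth-4 probes.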


## Experiment 3: Scaling the transformer

### (1) Train a grid of models

In [27]:
!python3 -m analysis.train_scale_grid --task REG-MODK --epochs 10

==> Training REG-MODK_L2_D128_H4_S0
Epoch 1| train loss: 0.0298 val loss: 0.0005
Epoch 2| train loss: 0.0026 val loss: 0.0002
Epoch 3| train loss: 0.0019 val loss: 0.0002
Epoch 4| train loss: 0.0014 val loss: 0.0003
Epoch 5| train loss: 0.0011 val loss: 0.0003
Epoch 6| train loss: 0.0009 val loss: 0.0003
Epoch 7| train loss: 0.0007 val loss: 0.0002
Epoch 8| train loss: 0.0006 val loss: 0.0002
Epoch 9| train loss: 0.0005 val loss: 0.0002
Epoch 10| train loss: 0.0005 val loss: 0.0002
Saved best model checkpoint in analysis/models_scale/REG-MODK_L2_D128_H4_S0/REG-MODK_best.pt
Best val MSE: 0.00019349552050698548 at epoch 10
Saved training history.
==> Training REG-MODK_L2_D256_H4_S0
Epoch 1| train loss: 0.0506 val loss: 0.0002
Epoch 2| train loss: 0.0023 val loss: 0.0002
Epoch 3| train loss: 0.0016 val loss: 0.0004
Epoch 4| train loss: 0.0012 val loss: 0.0002
Epoch 5| train loss: 0.0010 val loss: 0.0003
Epoch 6| train loss: 0.0008 val loss: 0.000
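The run names printed by the training and analysis cells appear to follow a `{TASK}_L{layers}_D{d_model}_H{heads}_S{seed}` pattern. This is our reading of the abbreviations rather than something the scripts state. Under that reading, the nine-model grid analyzed below can be enumerated as follows:

```python
from itertools import product

task = "REG-MODK"
layers = [2, 4, 6]          # assumed meaning of L
d_models = [128, 256, 512]  # assumed meaning of D
heads, seed = 4, 0          # H and S are fixed across this grid

run_names = [f"{task}_L{L}_D{D}_H{heads}_S{seed}" for L, D in product(layers, d_models)]
print(len(run_names), run_names[0], run_names[-1])
# 9 REG-MODK_L2_D128_H4_S0 REG-MODK_L6_D512_H4_S0
```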

### (2) Analyze each checkpoint

In [34]:
!python3 -m analysis.analyze_scale --task REG-MODK --Nmax 12

Analyzed: REG-MODK_L2_D128_H4_S0
Analyzed: REG-MODK_L2_D256_H4_S0
Analyzed: REG-MODK_L2_D512_H4_S0
Analyzed: REG-MODK_L4_D128_H4_S0
Analyzed: REG-MODK_L4_D256_H4_S0
Analyzed: REG-MODK_L4_D512_H4_S0
Analyzed: REG-MODK_L6_D128_H4_S0
Analyzed: REG-MODK_L6_D256_H4_S0
Analyzed: REG-MODK_L6_D512_H4_S0
Saved index: analysis/analysis_scale/REG-MODK_scale_index.json


### (3) Plot scaling trends

In [35]:
!python3 -m analysis.plot_scale --task REG-MODK

Saved plots: analysis/analysis_scale


## Experiment 4: Attention head analysis

In [80]:
# REG-SUM
!python3 -m analysis.run_head_analysis --task REG-SUM  --sample_idx 0 1 2

Saved head stats to analysis/analysis_heads/REG-SUM_head_stats.json
Saved plots to analysis/analysis_heads
Saved layer×head heatmaps to analysis/analysis_heads
Saved layer×position sweep heatmaps to analysis/analysis_heads


In [81]:
# REG-MODK
!python3 -m analysis.run_head_analysis --task REG-MODK --sample_idx 0 1 2

Saved head stats to analysis/analysis_heads/REG-MODK_head_stats.json
Saved plots to analysis/analysis_heads
Saved layer×head heatmaps to analysis/analysis_heads
Saved layer×position sweep heatmaps to analysis/analysis_heads
