# Logit and pos encoding intervention

Force models to output past end-of-sequence token (`$`) be modyfing its logit directly, or by shifting the positional encoding indices of the answer generated so far to force. Aim is to see if they can correctly predict the answer digits in OOD positions. If yes, then it could be that the position-dependent circuit decides to stop the generation by outputting the EOS token and not the addition circuit itself failing.

In [2]:
import re
import random

import torch
from pathlib import Path
import matplotlib.pyplot as plt

from arithmetic_lm.model import (
    TransformerDecoder,
    load_model,
    find_latest_ckpt,
    generate,
)
from arithmetic_lm.tokenizer import CharTokenizer
from arithmetic_lm.interp import plot_attn_maps
from arithmetic_lm.constants import PLOTS_DIR, CHECKPOINTS_DIR

import warnings

warnings.filterwarnings("ignore")

In [3]:
tokenizer = CharTokenizer()
stop_token = tokenizer.encode("$")[0]

In [4]:
# load model
ckpt_path = "../checkpoints/addition-generalize-to-longer/trans_dec_6layers_768embd_4head_randsp0.5_rev_ansloss/step1000000-train_loss0.0002-val_loss0.0000.ckpt"
model, hparams = load_model(ckpt_path)
model.to("mps")
model.eval()

TransformerDecoder(
  (embedding): Embedding(100, 768)
  (pos_encoder): AbsolutePositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (linear1): Linear(in_features=768, out_features=3072, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=3072, out_features=768, bias=True)
        (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
    (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=100, bias=False)

In [25]:
a = "999999999999456789123945678956"
b = "999999999999456789123745678995"
prompt = f"${a[::-1]}+{b[::-1]}="
true_ans = str(eval(f"{a}+{b}"))[::-1]

carries = ""
for ai, bi in zip(a[::-1], b[::-1]):
    if int(ai) + int(bi) >= 10:
        carries += "1"
    else:
        carries += "0"

prompt_idx = torch.tensor(tokenizer.encode(prompt, return_tensors=True)).to("mps")

pred_tensor = generate(
    model,
    idx=prompt_idx,
    max_new_tokens=25,
    stop_token=stop_token,
)
pred = tokenizer.decode(pred_tensor[0])

len_a, len_b = map(
    len, prompt.replace("$", "").replace("=", "").replace(" ", "").split("+")
)
print(f"Prompt:           {prompt} | {len_a}+{len_b}")
print(f"True answer:      {true_ans} | {len(true_ans)}")
print(f"Predicted answer: {pred} | {len(pred.replace('$', ''))}")
print(f"Carries:          {carries}")
print(f"Correct:          {pred.replace('$', '') == true_ans}")


# logit-level intervention
print("\n======\n")

ans_prompt = prompt + "1597531967428753"
print(f"Prompt: {ans_prompt}")

ans_prompt_idx = torch.tensor(tokenizer.encode(ans_prompt, return_tensors=True)).to(
    "mps"
)
pred_logits = model(ans_prompt_idx.unsqueeze(0))

# print sorted logits and their corresponding tokens
print("probs:")
logits, indices = torch.sort(pred_logits[0, -1], descending=True)
probs = torch.softmax(logits, dim=0)

for i in range(10):
    print(f"  {tokenizer.decode([indices[i].item()])}: {probs[i].item():.5f}")

# prob of stop token ($)
print("stop token:")
print(f"  {tokenizer.decode([stop_token])}: {probs[stop_token].item():.5f}")

Prompt:           $659876549321987654999999999999+599876547321987654999999999999= | 30+30
True answer:      1597531967428753198999999999991 | 31
Predicted answer: 1597874300087540040040400 | 25
Carries:          111111101000111110111111111111
Correct:          False


Prompt: $659876549321987654999999999999+599876547321987654999999999999=1597531967428753
probs:
  4: 0.27043
  9: 0.15192
  5: 0.12490
  6: 0.08004
  8: 0.07846
  3: 0.07484
  7: 0.05977
  2: 0.05890
  0: 0.05427
  1: 0.04066
stop token:
  $: 0.00001


(WIP) Hypothesis: for 18 digits OOD it might work (?), but for 20 the model is "approximating" the last digits i.e. 
```
Prompt:           $12345678901234567890+92345678901234567899= | 20+20
True answer:      104691357802469135789 | 21
Predicted answer: 10469135780246913569$ | 20
```
it predicts ...3569 instead of ...35789, closest in-distribution length is 19 and that would end in ...357, which is approximated to ...369. During training it succesively approximates more and more digits, but there is no pressure to correctly get last OOD digits.