# LM does addition with trigonometry

## Setup

### Imports

In [2]:
import gc
import itertools
import math
import os
import random
import sys
from collections import Counter, defaultdict
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Any, Callable, Literal, TypeAlias

import einops
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import requests
import torch as t
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from IPython.display import HTML, IFrame, clear_output, display
from jaxtyping import Float, Int
# from openai import OpenAI
from rich import print as rprint
from rich.table import Table
from sae_lens import (
    SAE,
    ActivationsStore,
    HookedSAETransformer,
    LanguageModelSAERunnerConfig,
    SAEConfig,
    SAETrainingRunner,
    upload_saes_to_huggingface,
)
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
# from sae_vis import SaeVisConfig, SaeVisData, SaeVisLayoutConfig
# from tabulate import tabulate
from torch import Tensor, nn
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from tqdm.auto import tqdm
from transformer_lens import ActivationCache, HookedTransformer
from transformer_lens.hook_points import HookPoint
from transformer_lens.utils import get_act_name, test_prompt, to_numpy

  from .autonotebook import tqdm as notebook_tqdm


### Helper functions

In [52]:
import inspect

### Outputting tensor shape

def s(tensor):
    """
    Print the shape of ``tensor``, labelled with the expression passed to ``s``.

    The label is recovered by reading the caller's source line and extracting
    the (parenthesis-balanced) argument text of the ``s(...)`` call. If the
    caller's source is unavailable, or the call cannot be located on that
    line (e.g. it spans several lines), the label falls back to ``"tensor"``.

    Args:
        tensor: A PyTorch tensor or any object with a .shape attribute

    Example:
        attnout = torch.randn(32, 768)
        s(attnout)  # Output: Shape of [attnout]: torch.Size([32, 768])
    """
    import re

    var_name = "tensor"

    frame = inspect.currentframe()
    caller = frame.f_back if frame is not None else None
    try:
        # code_context is None when the caller's source cannot be retrieved.
        code_context = inspect.getframeinfo(caller).code_context if caller is not None else None
        if code_context:
            calling_line = code_context[0].strip()
            # Only match a call to `s` itself, not names ending in "s" (e.g. `tensors(`)
            # or attribute calls such as `obj.s(`.
            call = re.search(r'(?<![\w.])s\(', calling_line)
            if call:
                # Walk to the matching closing parenthesis so nested calls such as
                # `s(x.view(2, 3))` are captured in full.
                depth = 1
                for pos in range(call.end(), len(calling_line)):
                    char = calling_line[pos]
                    if char == '(':
                        depth += 1
                    elif char == ')':
                        depth -= 1
                        if depth == 0:
                            expression = calling_line[call.end():pos].strip()
                            if expression:
                                var_name = expression
                            break
    finally:
        # Drop frame references explicitly to avoid keeping reference cycles alive.
        del frame, caller

    if hasattr(tensor, 'shape'):
        print(f"Shape of [{var_name}]: {tensor.shape}")
    else:
        print(f"{var_name} has no shape attribute. Type: {type(tensor)}")
        
        
### Check GPU memory usage
def print_gpu_memory():
    """Print a per-device memory report (in GB) for every visible CUDA device.

    For each device the report lists the total memory, the memory reported by
    ``t.cuda.memory_reserved`` and ``t.cuda.memory_allocated``, and a "free"
    figure computed as total minus reserved. Prints nothing when CUDA is not
    available.
    """
    if not t.cuda.is_available():
        return

    bytes_per_gb = 1024**3  # Convert to GB
    for device_idx in range(t.cuda.device_count()):
        total_gb = t.cuda.get_device_properties(device_idx).total_memory / bytes_per_gb
        reserved_gb = t.cuda.memory_reserved(device_idx) / bytes_per_gb
        allocated_gb = t.cuda.memory_allocated(device_idx) / bytes_per_gb
        report_lines = [
            f"GPU {device_idx}:",
            f"  Total Memory: {total_gb:.2f} GB",
            f"  Reserved Memory: {reserved_gb:.2f} GB",
            f"  Allocated Memory: {allocated_gb:.2f} GB",
            f"  Free Memory: {total_gb - reserved_gb:.2f} GB",
        ]
        print("\n".join(report_lines))

### Load the model

```python
gemma = "gemma-2-2b"
gemma_saes = "gemma_saes"
gpt = "gpt-j-6B"
```

In [6]:
# Layer index constant; it is not referenced in the cells visible in this part of the notebook.
LAYER = 20
device = "cuda"

# Load gemma-2-2b through sae_lens' HookedSAETransformer.
gemma = HookedSAETransformer.from_pretrained("gemma-2-2b", device=device)


def _load_canonical_res_sae(layer_idx):
    """Load the canonical 16k-width residual SAE of one layer from the gemma-scope release."""
    loaded = SAE.from_pretrained(
        "gemma-scope-2b-pt-res-canonical",
        f"layer_{layer_idx}/width_16k/canonical",
        device=str(device),
    )
    # Only the first element of the returned value is kept (presumably the SAE itself).
    return loaded[0]


# One SAE per transformer layer, indexed by layer number.
gemma_saes = [_load_canonical_res_sae(layer_idx) for layer_idx in tqdm(range(gemma.cfg.n_layers))]

Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 94.21it/s]


Loaded pretrained model gemma-2-2b into HookedTransformer


100%|██████████| 26/26 [00:14<00:00,  1.81it/s]


In [95]:
# While `gpt = ...` is being evaluated, any previously loaded `gpt` stays alive until
# the new model is bound, so re-running this cell would hold two float32 copies of a
# 6B-parameter model on the GPU at once and can exhaust GPU memory. Drop the old
# reference and release cached blocks before loading.
globals().pop("gpt", None)
gc.collect()
t.cuda.empty_cache()
gpt = HookedTransformer.from_pretrained_no_processing("gpt-j-6B", device="cuda", dtype=t.float32)

Some weights of the model checkpoint at EleutherAI/gpt-j-6B were not used when initializing GPTJForCausalLM: ['transformer.h.0.attn.bias', 'transformer.h.0.attn.masked_bias', 'transformer.h.1.attn.bias', 'transformer.h.1.attn.masked_bias', 'transformer.h.10.attn.bias', 'transformer.h.10.attn.masked_bias', 'transformer.h.11.attn.bias', 'transformer.h.11.attn.masked_bias', 'transformer.h.12.attn.bias', 'transformer.h.12.attn.masked_bias', 'transformer.h.13.attn.bias', 'transformer.h.13.attn.masked_bias', 'transformer.h.14.attn.bias', 'transformer.h.14.attn.masked_bias', 'transformer.h.15.attn.bias', 'transformer.h.15.attn.masked_bias', 'transformer.h.16.attn.bias', 'transformer.h.16.attn.masked_bias', 'transformer.h.17.attn.bias', 'transformer.h.17.attn.masked_bias', 'transformer.h.18.attn.bias', 'transformer.h.18.attn.masked_bias', 'transformer.h.19.attn.bias', 'transformer.h.19.attn.masked_bias', 'transformer.h.2.attn.bias', 'transformer.h.2.attn.masked_bias', 'transformer.h.20.attn.bi

OutOfMemoryError: CUDA out of memory. Tried to allocate 256.00 MiB. GPU 0 has a total capacity of 47.40 GiB of which 171.69 MiB is free. Process 1431206 has 47.22 GiB memory in use. Of the allocated memory 46.13 GiB is allocated by PyTorch, and 608.95 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

## Performance check

### Functions for generating prompts

1. Natural language prompts: *a plus b is*
2. Symbol prompts: *a + b =*
3. Natural language prompts with instructions: *a plus b is c, d plus e is*

In [57]:
from collections import namedtuple

AnsConfig = namedtuple("AnsConfig", ["a", "b", "operation", "ans"])

# Supported operations: name -> (symbol used in symbolic prompts, function computing the answer).
# "divided by" is intentionally absent: no answer computation exists for it.
_OPERATIONS = {
    "plus": ("+", lambda x, y: x + y),
    "minus": ("-", lambda x, y: x - y),
    "times": ("*", lambda x, y: x * y),
}


def prompt_generator(
    n_range: int = 100,
    op: list[str] = ["plus"],
    n_batch: int = 100,
    return_type: Literal["string", "token"] = "string",
    write_to_file: bool = False,
    file_path: str = "addition_prompts.txt",
    with_instructions: bool = False,
    with_symbols: bool = False,
) -> tuple[list[str], list[AnsConfig]]:
    """Generates a list of arithmetic questions and their answers.

    Prompt formats:
        * natural language:        ``"a plus b is"``
        * symbols:                 ``"a + b ="``
        * with a worked example:   ``"c plus d is e, a plus b is"``

    Args:
        n_range: Operands are drawn uniformly from ``[0, n_range)``.
        op: Operation names to sample from for each prompt; supported names
            are "plus", "minus" and "times". The list is never modified.
        n_batch: Number of prompts to generate.
        return_type: Currently ignored; prompts are always returned as strings.
        write_to_file: If True, write the prompts to ``file_path``, one per line.
        file_path: Destination used only when ``write_to_file`` is True.
        with_instructions: If True, prefix each prompt with a solved example
            that uses the same operation and independently drawn operands.
        with_symbols: If True, use ``+``/``-``/``*`` and ``=`` instead of the
            operation name and ``is``.

    Returns:
        ``(q_list, ans_list)``: the prompt strings and, for each prompt, an
        ``AnsConfig`` holding the operands, the operation name and the answer.

    Raises:
        ValueError: If a sampled operation is not one of the supported names.
    """
    # Draw all operands up front, in a fixed order, so results are reproducible
    # under a fixed torch seed regardless of the formatting flags.
    a = t.randint(0, n_range, (n_batch,)).tolist()
    b = t.randint(0, n_range, (n_batch,)).tolist()
    a_instr = t.randint(0, n_range, (n_batch,)).tolist()
    b_instr = t.randint(0, n_range, (n_batch,)).tolist()

    equal_str = "=" if with_symbols else "is"

    ans_list = []
    q_list = []

    for i in range(n_batch):
        operation = random.choice(op)
        if operation not in _OPERATIONS:
            raise ValueError(f"Operation type not recognized: {operation!r}")
        symbol, compute = _OPERATIONS[operation]
        op_str = symbol if with_symbols else operation

        question = f"{a[i]} {op_str} {b[i]} {equal_str}"
        if with_instructions:
            inst_answer = compute(a_instr[i], b_instr[i])
            question = f"{a_instr[i]} {op_str} {b_instr[i]} {equal_str} {inst_answer}, {question}"
        q_list.append(question)

        ans_list.append(
            AnsConfig(a=a[i], b=b[i], operation=operation, ans=compute(a[i], b[i]))
        )

    # Only touch the filesystem when explicitly requested.
    if write_to_file:
        with open(file_path, "w") as f:
            f.writelines(question + "\n" for question in q_list)

    return q_list, ans_list

In [94]:
# Individual example using gpt-j: greedy-decode a few tokens after one symbolic prompt
# and show the whole sequence (prompt + continuation) as string tokens.
max_new_tokens = 10
question_tokens = gpt.to_tokens("29 + 21 =")
ans_tokens = gpt.generate(question_tokens, max_new_tokens=max_new_tokens, do_sample=False)
gpt.to_str_tokens(ans_tokens)

  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:00<00:00, 29.89it/s]


['<|endoftext|>',
 '29',
 ' +',
 ' 21',
 ' =',
 ' -',
 '2',
 '*',
 'z',
 '.',
 ' Let',
 ' y',
 ' =',
 ' z',
 ' -']

In [78]:
import re

t.cuda.empty_cache()
q_list, a_list = prompt_generator(with_instructions=False, n_batch=100, write_to_file=True, with_symbols=True)

models = [gemma, gpt]
models_str = ["gemma-2-2b", "gpt-j-6b"]
# Generation budget per model, aligned with `models` (presumably larger for gemma
# because it emits numbers digit by digit — confirm).
max_new_tokens_list = [4, 1]

for model_str, model, max_new_tokens in zip(models_str, models, max_new_tokens_list):
    # NOTE(review): prompts that tokenize to different lengths get padded by to_tokens;
    # confirm the padding side/attention mask used by generate() is appropriate for
    # batched generation before trusting the per-model accuracy.
    question_tokens = model.to_tokens(q_list)
    answer_tokens = model.generate(
        question_tokens,
        max_new_tokens=max_new_tokens,
        do_sample=False,
    )

    # Keep only the last `max_new_tokens` tokens of every row and decode them.
    ans = [row[-max_new_tokens:].tolist() for row in answer_tokens]
    ans = model.to_string(ans)
    print(ans)

    # Score a completion as correct only when the expected answer appears as a
    # standalone number. A plain substring test would accept e.g. "11" inside "119".
    correct = t.tensor(
        [
            re.search(rf"(?<!\d){re.escape(str(cfg.ans))}(?!\d)", text) is not None
            for cfg, text in zip(a_list, ans)
        ]
    )
    acc = correct.sum() / correct.shape[0]
    print(f"Accuracy of {model_str}: {acc.item():.2%}")

  0%|          | 0/4 [00:00<?, ?it/s]

100%|██████████| 4/4 [00:00<00:00, 13.65it/s]


[' 43\n\n', ' 149', ' 92\n\n', ' 97\n\n', '\n\n10 +', ' 73\n\n', ' 51 +', ' 92\n\n', ' 81\n\n', ' 38\n\n', ' 119', ' 96\n\n', ' 119', ' 77 +', ' 105', '\n\n67 +', ' 128', ' 128', '\n\n9 + ', ' 77\n\n', ' 163', ' 110', ' 143', ' 61\n\n', ' 129', ' 168', ' 137', ' 146', ' 66\n\n', ' 117', ' 96\n\n', ' 66\n\n', ' 104', '.\n\n10', ' 106', ' 126', ' 54\n\n', ' 44\n\n', ' 72\n\n', ' 96\n\n', ' 28\n\n', ' 155', ' 149', ' 140', ' 65\n\n', ' 93\n\n', '\n\n35 +', ' 76 +', ' 126', ' 121', '\n\n65 +', ' 151', ' 146', ' 173', ' 81\n\n', ' 103', ' 166', ' 101', ' 159', ' 94\n\n', ' 138', ' 113', ' 118', ' 101', ' 92 +', ' 140', ' 71\n\n', ' 67\n\n', '\n\n98 +', '.\n\n0 +', ' 116', ' 141', ' 91\n\n', ' 88\n\n', ' 19 +', ' 121', ' 44\n\n', ' 32\n\n', ' 120', ' 87\n\n', ' 110', ' 76\n\n', ' 124', ' 29 +', ' 157', ' 148', ' 100', ' 118', ' 111', ' 116', ' 89\n\n', ' 111', ' 49\n\n', ' 150', ' 45\n\n', ' 85\n\n', ' 51\n\n', ' 96 +', ' 54\n\n', ' 53\n\n']
Accuracy of gemma-2-2b: 85.00%


100%|██████████| 1/1 [00:00<00:00, 14.23it/s]


[' -', ' -', ' -', ' -', ' -', ' -', ' -', ' 0', ' -', ' -', ' -', ' 0', ' -', ' -', ' 0', ' 0', ' 0', ' -', ' -', ' 0', ' -', ' -', ' 0', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' 0', ' -', ' -', ' 0', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' 0', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' 0', ' -', ' -', ' 0', ' -', ' -', ' -', ' -', ' -', ' -', ' 0', ' 0', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' -', ' -']
Accuracy of gpt-j-6b: 0.00%


In [73]:

# Report GPU memory, reclaim what can be reclaimed, then report again.
print_gpu_memory()
# gc.collect() only destroys unreachable Python objects; the blocks they freed stay in
# PyTorch's caching allocator, so reserved memory cannot drop unless the cache is emptied.
print(gc.collect())
t.cuda.empty_cache()
print_gpu_memory()

GPU 0:
  Total Memory: 47.40 GB
  Reserved Memory: 38.07 GB
  Allocated Memory: 37.29 GB
  Free Memory: 9.33 GB
508
GPU 0:
  Total Memory: 47.40 GB
  Reserved Memory: 38.07 GB
  Allocated Memory: 37.29 GB
  Free Memory: 9.33 GB


## Number representations

We examine how number tokens are represented in the 2.2b model. In larger models, it has been found that every integer in the interval $[0, 360]$ is encoded as a single token. We observe the following:

1. For smaller models like 2.2b, the numbers are encoded *digit by digit*

In [11]:
## numbers encoded digit by digit
# Tokenize the decimal strings "0" .. "499" with both tokenizers (no BOS token)
# and print the resulting string tokens for comparison.
str_nums = list(map(str, range(500)))
tokens = gemma.to_str_tokens(str_nums, prepend_bos=False)
tokens_gpt = gpt.to_str_tokens(str_nums, prepend_bos=False)
print("Gemma tokens:", tokens)
print("GPT-J tokens:", tokens_gpt)

Gemma tokens: [['0'], ['1'], ['2'], ['3'], ['4'], ['5'], ['6'], ['7'], ['8'], ['9'], ['1', '0'], ['1', '1'], ['1', '2'], ['1', '3'], ['1', '4'], ['1', '5'], ['1', '6'], ['1', '7'], ['1', '8'], ['1', '9'], ['2', '0'], ['2', '1'], ['2', '2'], ['2', '3'], ['2', '4'], ['2', '5'], ['2', '6'], ['2', '7'], ['2', '8'], ['2', '9'], ['3', '0'], ['3', '1'], ['3', '2'], ['3', '3'], ['3', '4'], ['3', '5'], ['3', '6'], ['3', '7'], ['3', '8'], ['3', '9'], ['4', '0'], ['4', '1'], ['4', '2'], ['4', '3'], ['4', '4'], ['4', '5'], ['4', '6'], ['4', '7'], ['4', '8'], ['4', '9'], ['5', '0'], ['5', '1'], ['5', '2'], ['5', '3'], ['5', '4'], ['5', '5'], ['5', '6'], ['5', '7'], ['5', '8'], ['5', '9'], ['6', '0'], ['6', '1'], ['6', '2'], ['6', '3'], ['6', '4'], ['6', '5'], ['6', '6'], ['6', '7'], ['6', '8'], ['6', '9'], ['7', '0'], ['7', '1'], ['7', '2'], ['7', '3'], ['7', '4'], ['7', '5'], ['7', '6'], ['7', '7'], ['7', '8'], ['7', '9'], ['8', '0'], ['8', '1'], ['8', '2'], ['8', '3'], ['8', '4'], ['8', '5'], ['8