In [105]:
!nvitop

)07[?47h[1;24r[m[4l[?1h=Fri Aug 02 12:18:38 2024
╒═════════════════════════════════════════════════════════════════════════════╕
│ NVITOP 1.3.2      Driver Version: 535.129.03      CUDA Driver Version: 12.2 │
├───────────────────────────────┬──────────────────────┬──────────────────────┤
│ GPU  Name        Persistence-M│ Bus-Id        Disp.A │ MIG M.   Uncorr. ECC │
│ Fan  Temp  Perf  Pwr:Usage/Cap│         Memory-Usage │ GPU-Util  Compute M. │
╞═══════════════════════════════╪══════════════════════╪══════════════════════╡
│[33m   0  A100 80GB PCIe      On   [0m│[33m 00000000:81:00.0 Off [0m│[33m Disabled           0 [0m│
│[33m N/A   36C    P0    58W / 300W [0m│[33m  32.83GiB / 80.00GiB [0m│[33m      0%      Default [0m│
╘═══════════════════════════════╧══════════════════════╧══════════════════════╛
[1m[36m[ CPU: ▎ 0.8%          UPTIME: 197.5 days ][0m  [1m( Load Average:  8.75  3.60  1.75 )[0m
[1m[35m[ MEM: █▌ 5.0%             USED: 10.30GiB ][0m  [1m[34m

In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import os
from utils import *

os.environ["HF_TOKEN"] = keys['huggingface']

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b").to('cuda')
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
tokenizer.padding_side = "left"

Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 11.90it/s]


## Eval

In [3]:
import pandas as pd

#examples = generate_dataset(N=9999, M=2, E=3)
examples = pd.read_csv('gemma-2-2b.csv')

In [4]:
from utils import get_completions, get_predictions

completions = []

bs = 128
K = 4

for b in tqdm(range(0, len(examples), bs)):
    example = examples["prompt"].iloc[b : b + bs].tolist()
    tokens = tokenizer(example, return_tensors='pt', padding=True)['input_ids'].to('cuda')  # [bs, seq_len]

    for k in range(K):
        with torch.no_grad():
            logits = model(tokens).logits

        new_tok = logits[:, -1].argmax(-1)  # [bs]
        tokens = torch.cat([tokens, new_tok[:, None]], -1)

    completions += tokenizer.batch_decode(tokens, skip_special_tokens=True)

completions = pd.Series(completions).apply(lambda x: x.strip())

predictions = get_predictions(completions)

examples['completion'] = completions
examples['prediction'] = predictions

100%|██████████| 157/157 [16:54<00:00,  6.46s/it]


In [7]:
examples['correct'] = examples['solution'] == examples['prediction']
examples['correct'].mean()

0.9036903690369037

In [6]:
examples['bin'] = pd.cut(examples['solution'], bins=range(0, 10000, 50), right=False)

In [8]:
import plotly.express as px

df_grouped = examples.groupby('bin')['correct'].mean().reset_index()
df_grouped['bin'] = df_grouped['bin'].astype(str)

fig = px.bar(
    df_grouped,
    x='bin',
    y='correct',
    title="Gemma-2-2b accuracy by bin",
    labels={'correct': 'Accuracy', 'bin': 'Bin'},
    color_discrete_sequence=['skyblue']
)

fig.update_layout(
    template="plotly_white",
    xaxis_title="Bin",
    yaxis_title="Accuracy",
    paper_bgcolor='white',
    plot_bgcolor='white'
)
fig.update_xaxes(tickangle=90)
fig.show()





In [9]:
#examples.to_csv('gemma-2-2b.csv', index=False)

## Dataset

In [87]:
import numpy as np
import pandas as pd

def generate_dataset(N, min_n, max_n, E=0):

    zeros = 4

    def add_zeros(x, L):
        return "0" * max(0, zeros - len(str(x))) + str(x)

    examples = {
        "x_clean": [],
        "x_corr": [],
        "a": [],
        "b": [],
        "b_corr": [],
        "s": [],
        "s_corr": []
    }

    for n in range(0, N):
        prompt = ""

        for _ in range(E):
            a, b = np.random.randint(min_n, max_n+1, 2)
            c = add_zeros(a + b, zeros)
            a, b = add_zeros(a, zeros), add_zeros(b, zeros)
            prompt += f"{a}+{b}={c}\n"

        a, b = np.random.randint(min_n, max_n+1, 2)
        c = add_zeros(a + b, zeros)
        a, b = add_zeros(a, zeros), add_zeros(b, zeros)
        
        change_digit = np.random.randint(1, len(str(max_n)))
        b_corr = list(b)
        b_corr[-change_digit] = str(np.random.randint(0, 10))
        b_corr = "".join(b_corr)
        c_corr = add_zeros(int(a) + int(b_corr), zeros)

        idx = 0
        for i in range(len(c)):
            if c[i] != c_corr[i]:
                idx = i
                break

        prompt_corr = prompt + f"{a}+{b_corr}=" + c_corr[:idx]
        prompt += f"{a}+{b}=" + c[:idx]
        
        examples["x_clean"].append(prompt)
        examples["x_corr"].append(prompt_corr)

        examples["a"].append(a)
        examples["b"].append(b)
        examples["b_corr"].append(b_corr)

        examples["s"].append(c)
        examples["s_corr"].append(c_corr)

    examples = pd.DataFrame(examples)

    def BA(x):
        BAs = [(int(x["a"][i]) + int(x["b"][i])) % 10 for i in range(zeros)]
        return BAs

    def MC(x):
        MCs = [(int(x["a"][i]) + int(x["b"][i])) > 9 for i in range(zeros)]
        return MCs

    def MS9(x):
        MS9s = [(int(x["a"][i]) + int(x["b"][i])) == 9 for i in range(zeros)]
        return MS9s

    examples[["BAth", "BAhu", "BAte", "BAun"]] = examples.apply(lambda x: BA(x), axis=1, result_type="expand")
    examples[["MCth", "MChu", "MCte", "MCun"]] = examples.apply(lambda x: MC(x), axis=1, result_type="expand")
    examples[["MS9th", "MS9hu", "MS9te", "MS9un"]] = examples.apply(lambda x: MS9(x), axis=1, result_type="expand")

    def US9(x):
        US9s = [False]

        for i, j in zip(["te", "hu", "th"], ["un", "te", "hu"]):
            US9s.append(x[f"MS9{i}"] and (x[f"MC{j}"] or US9s[-1]))
        
        return US9s

    examples[["US9un", "US9te", "US9hu", "US9th"]] = examples.apply(lambda x: US9(x), axis=1, result_type="expand")
    
    return examples

In [88]:
data = generate_dataset(1024, 500, 1500, E=2)

In [92]:
data.iloc[:, 11:].mean()

MCth     0.000000
MChu     0.458984
MCte     0.447266
MCun     0.457031
MS9th    0.000000
MS9hu    0.116211
MS9te    0.118164
MS9un    0.115234
US9un    0.000000
US9te    0.044922
US9hu    0.049805
US9th    0.000000
dtype: float64

## Interp

In [10]:
from nnsight import LanguageModel

lm = LanguageModel(model, dispatch=True)

In [11]:
from dictionary_learning. import AutoEncoder

ae = AutoEncoder.from_pretrained_npz('dictionary_learning/attn_out_layer_0/params.npz', activation_dim=2304, dict_size=16384)

In [35]:
data['W_enc'].shape

(2304, 16384)

In [17]:
tokenizer('hello world do', return_tensors='pt')['input_ids']

tensor([[    2, 17534,  2134,   749]])

In [19]:
tokenizer(' does', return_tensors='pt')['input_ids']

tensor([[   2, 1721]])