# NNsight Tutorial


In [None]:

%load_ext autoreload
%autoreload 2
# # # set cuda visible device
# !export CUDA_VISIBLE_DEVICES=3
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "3"

# Filesystem path handed to LanguageModel when the model is loaded below.
# It is an absolute, machine-specific path: adjust it before running the
# notebook on another machine.
starcoderbase_1b = "/home/arjun/models/starcoderbase-1b/"

from nnsight import LanguageModel
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
from src.utils import *

In [None]:
# Wrap the checkpoint at `starcoderbase_1b` in an nnsight LanguageModel and
# map it onto the device named "cuda:3".
model = LanguageModel(
    starcoderbase_1b,
    device_map="cuda:3",
)
# Bare expression: the notebook renders the model's repr as the cell output.
model

## Activation patching

In [None]:
import plotly.express as px

from nnsight import LanguageModel, util
from nnsight.tracing.Proxy import Proxy

import datasets
from src.utils import *

# Load the train split of the type-patching dataset.
ds = datasets.load_dataset("franlucc/type_patching_v0", split="train")

# Bucket the examples by their "fim_type" field, one list per type name.
string_ex, boolean_ex, number_ex = (
    [example for example in ds if example["fim_type"] == type_name]
    for type_name in ("string", "boolean", "number")
)

# Vocabulary id the tokenizer assigns to each type name when it is looked up
# as a single token.
string_idx, boolean_idx, number_idx = (
    model.tokenizer.convert_tokens_to_ids(type_name)
    for type_name in ("string", "boolean", "number")
)
print(string_idx, boolean_idx, number_idx)

In [None]:
from src.experiments.type_patching import *
import os
# Set the TOKENIZERS_PARALLELISM environment variable to "false" before any
# tokenization happens in this cell.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Per-example accumulators appended to by the patching loop further down:
# predicted token ids, patching probabilities, and the first position at which
# the clean token was predicted. Three distinct lists, not aliases.
pred_results, probs_patched_results, earliest_layer = [], [], []
def get_ds(patch):
    """Return the list of dataset examples for the type named *patch*.

    Args:
        patch: One of ``"string"``, ``"boolean"`` or ``"number"``.

    Returns:
        The module-level list ``string_ex``, ``boolean_ex`` or ``number_ex``
        matching *patch*.

    Raises:
        ValueError: If *patch* is not one of the three supported type names.
            Without this check an unknown name would fall through and return
            ``None``, and the caller would fail later when indexing it.
    """
    if patch == "string":
        return string_ex
    if patch == "boolean":
        return boolean_ex
    if patch == "number":
        return number_ex
    raise ValueError(
        f"unknown patch type {patch!r}; expected 'string', 'boolean' or 'number'"
    )

# Type names for the two prompt sets: clean prompts are built from
# `patch_src` examples and corrupted prompts from `patch_dst` examples.
patch_src = "string"
patch_dst = "boolean"
# First token id the tokenizer produces for each type name.
correct_index = model.tokenizer(patch_src)["input_ids"][0]
incorrect_index = model.tokenizer(patch_dst)["input_ids"][0]

# Half-open (start, stop) window of example indices to run.
idx_range = (80,90)

# Resolve both example lists once; they do not change between iterations.
src_examples = get_ds(patch_src)
dst_examples = get_ds(patch_dst)

for i in range(*idx_range):
    # Release cached, unused GPU memory before working on the next example.
    # (No torch.no_grad() wrapper: grad mode has no effect on empty_cache.)
    torch.cuda.empty_cache()

    clean_prompt = placeholder_to_std_fmt(src_examples[i]["fim_program"], STARCODER_FIM)
    corrupted_prompt = placeholder_to_std_fmt(dst_examples[i]["fim_program"], STARCODER_FIM)
    patching_results, patched_predictions = patch_fim_tokens(model, clean_prompt, corrupted_prompt, STARCODER_FIM, correct_index, incorrect_index)

    # Replace every nnsight Proxy in the results with its plain Python scalar.
    patched_predictions = util.apply(patched_predictions, lambda x: x.value.item(), Proxy)
    patching_results = util.apply(patching_results, lambda x: x.value.item(), Proxy)

    probs_patched_results.append(patching_results)
    pred_results.append(patched_predictions)

    # Position of the first patched prediction equal to the clean token id,
    # or -1 when none matches. This is an index into `patched_predictions`,
    # not necessarily a model layer number.
    earliest_layer.append(
        next((j for j, x in enumerate(patched_predictions) if x == correct_index), -1)
    )

# Release whatever the final example left in the CUDA cache.
torch.cuda.empty_cache()

In [None]:
# Layer numbers shown as x-axis tick labels, one per column of `pred_results`.
# NOTE(review): presumably the layers that patch_fim_tokens patches — confirm
# that this range matches the length of each prediction row.
layers = list(range(8,15))

import matplotlib.pyplot as plt
import numpy as np  # used for the legend colours; imported here so the cell does not rely on a star import
from matplotlib.colors import ListedColormap

plt.figure(figsize=(7,7))

# Distinct predicted token ids in ascending order. Both the legend text and
# the legend colours are derived from this one list so they stay paired.
flat_results = [x for y in pred_results for x in y]
labels = sorted(set(flat_results))
print(labels)


# Heatmap of the predicted token ids: one row per example, one column per
# entry of each prediction row; colours come from a log norm over the ids.
im = plt.imshow(pred_results, norm="log", cmap="hsv")

plt.yticks(range(len(pred_results)), [f"Example {i}" for i in range(len(pred_results))])
plt.xticks(range(len(layers)), layers)
plt.xlabel("Patched Layer")
plt.ylabel("Src->Dst Example idx")
plt.title(f"Starcoderbase-1b. Max probability token after patching <fim_middle> from {patch_dst} to {patch_src}")

# Legend: push each distinct token id through the image's own norm and
# colormap so every swatch matches the colour drawn in the heatmap.
color_map = im.cmap(im.norm(np.asarray(labels)))
legend_handles = [
    plt.Rectangle((0,0),1,1, color=color, label=model.tokenizer.decode(token_id))
    for token_id, color in zip(labels, color_map)
]
plt.legend(loc="upper left", bbox_to_anchor=(1,0.5), handles=legend_handles)


# add grid between examples and layers (halfway between cell centres)
plt.hlines([i+0.5 for i in range(len(pred_results))], *plt.xlim(), color="black", linewidth=1)
plt.vlines([i+0.5 for i in range(len(layers))], *plt.ylim(), color="black", linewidth=1)


# Build one annotation record per grid cell, keyed by the flat index
# i + n_examples * j (i = example row, j = layer column), together with the
# matching (x, y) = (j, i) position. Iterating j in the outer loop visits the
# keys in ascending order, so positions[key] is the cell for annotations[key].
# NOTE(review): assumes each probs entry holds the string, number and boolean
# probabilities in that order — confirm against patch_fim_tokens.
n_examples = len(pred_results)
annotations = {}
positions = []
for j in range(len(layers)):
    for i in range(n_examples):
        probs = probs_patched_results[i][j]
        annotations[i + n_examples*j] = {
            "pred": pred_results[i][j],
            "string": f"{probs[0]:.2f}",
            "number": f"{probs[1]:.2f}",
            "boolean": f"{probs[2]:.2f}",
        }
        positions.append((j, i))

# Write the decoded predicted token in the centre of every cell.
for pos, text in annotations.items():
    plt.annotate(model.tokenizer.decode(text["pred"]), xy=positions[pos],color="black", fontsize=8, ha="center", va="center")


plt.savefig(f"starcoderbase-1b-patching_idx{idx_range[0]}.pdf", bbox_inches="tight")


# Alternative annotation pass (disabled): print the three probabilities in
# each cell instead of the decoded token.
# for pos, text in annotations.items():
#     plt.annotate(text["string"], xy=positions[pos], ha="left", va="bottom", color="black")
#     plt.annotate(text["number"], xy=positions[pos], ha="center", va="center", color="black")
#     plt.annotate(text["boolean"], xy=positions[pos], ha="right", va="top", color="black")

# plt.savefig(f"starcoderbase-1b-patching_probs_idx{idx_range[0]}.pdf", bbox_inches="tight")

plt.show()
# Bare expression: the notebook displays both structures as the cell output.
annotations, pred_results