To use TeX fonts in exported figures, run the following:
```
apt-get install -y cm-super fonts-cmu && fc-cache -fv
```

In [None]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("..")
from nnsight import NNsight
import torch
import os
from tqdm.notebook import tqdm, trange

from nnsight import NNsight

from analysis.circuit_utils.visualisation import *
from analysis.circuit_utils.model import *
from analysis.circuit_utils.validation import *
from analysis.circuit_utils.decoding import *
from analysis.circuit_utils.utils import *
from analysis.circuit_utils.decoding import get_decoding_args, get_data, generate_title, get_plot_prior_patch, get_plot_context_patch, get_plot_weightcp_patch, get_plot_weightpc_patch
from analysis.circuit_utils.das import *

from main import load_model_and_tokenizer


from nnpatch.api.gemma import Gemma2

# Helper pulled in by the star imports above; called once at start-up.
jupyter_enable_mathjax()

# Locations used throughout the notebook: the model weight store and the
# directory that receives the generated plots (created if missing).
MODEL_STORE = "/dlabscratch1/public/llm_weights/gemma_hf/"
plot_dir = "plots/gemma2-9b-it"
os.makedirs(plot_dir, exist_ok=True)

This notebook requires that your LoRA adapter has been merged into the base model. Use the command below to merge it.

In [None]:
!python analysis/scripts/merge_model.py --model-id gemma-2-9b-it --model-store /dlabscratch1/public/llm_weights/gemma_hf/ --cwf instruction

In [None]:
# Move the working directory up one level; relative paths used in later cells
# (e.g. "plots/...", "projections/...") resolve from there afterwards.
# NOTE(review): not idempotent — re-running this cell moves up another level.
%cd ..

In [None]:
# Build the path bundle and run arguments for the fine-tuned Gemma-2-9B-it
# model (instruction CWF, full precision, 100 samples).
PATHS, args = get_decoding_args(
    finetuned=True,
    load_in_4bit=False,
    cwf="instruction",
    model_id="gemma-2-9b-it",
    model_store=MODEL_STORE,
    n_samples=100,
)

In [None]:
# Load the model and tokenizer described by `args`, then wrap the model in
# NNsight; the plotting helpers below receive the wrapped `nnmodel`.
model, tokenizer = load_model_and_tokenizer_from_args(
    PATHS,
    args,
)
nnmodel = NNsight(model)

# Patch

In [None]:
# Unpack everything get_data returns: the shared tokens/mask followed by
# tokens, attention masks and answers for each context/prior prompt variant.
(
    all_tokens, all_attn_mask,
    context_1_tokens, context_2_tokens, context_3_tokens,
    prior_1_tokens, prior_2_tokens,
    context_1_attention_mask, context_2_attention_mask, context_3_attention_mask,
    prior_1_attention_mask, prior_2_attention_mask,
    context_1_answer, context_2_answer, context_3_answer,
    prior_1_answer, prior_2_answer,
) = get_data(args, PATHS, tokenizer)

# Positional argument bundles for the patching helpers. Each one is the shared
# tokens/mask followed by tokens, attention masks and answers of a prompt pair.
prior_args = [
    all_tokens, all_attn_mask,
    prior_1_tokens, prior_2_tokens,
    prior_1_attention_mask, prior_2_attention_mask,
    prior_1_answer, prior_2_answer,
]
ctx_args = [
    all_tokens, all_attn_mask,
    context_1_tokens, context_2_tokens,
    context_1_attention_mask, context_2_attention_mask,
    context_1_answer, context_2_answer,
]
cp_args = [
    all_tokens, all_attn_mask,
    context_1_tokens, prior_1_tokens,
    context_1_attention_mask, prior_1_attention_mask,
    context_1_answer, prior_1_answer,
]
pc_args = [
    all_tokens, all_attn_mask,
    prior_1_tokens, context_1_tokens,
    prior_1_attention_mask, context_1_attention_mask,
    prior_1_answer, context_1_answer,
]

In [None]:
# Sanity check: decode the first prior prompt and its answer with special
# tokens kept. Two separate statements, so the cell no longer evaluates to a
# tuple of print() results and echoes `(None, None)`.
print(tokenizer.decode(prior_1_tokens[0], skip_special_tokens=False))
print(tokenizer.decode(prior_1_answer[0], skip_special_tokens=False))

## Auto search

In [10]:
from analysis.circuit_utils.decoding import get_patched_residuals, patch_scope, config_to_site, get_probs, get_patched_residuals
from nnpatch.api.gemma import Gemma2
from nnsight import NNsight
import torch
from tqdm.notebook import trange

In [None]:
# Run auto_search on the prior argument bundle and print what it returns.
prior_range = auto_search(
    model,
    tokenizer,
    prior_args,
    n_layers=42,
    phi=0.05,
    eps=0.3,
    thres=0.9,
    batch_size=10,
    api=Gemma2,
)
print(prior_range)

In [None]:
# Run auto_search on the context argument bundle (threshold 0.85 here) and
# print what it returns.
ctx_range = auto_search(
    model,
    tokenizer,
    ctx_args,
    n_layers=42,
    phi=0.05,
    eps=0.3,
    thres=0.85,
    batch_size=10,
    api=Gemma2,
)
print(ctx_range)

In [None]:
# Run auto_search on the context->prior argument bundle and print what it returns.
cp_range = auto_search(
    model,
    tokenizer,
    cp_args,
    n_layers=42,
    phi=0.05,
    eps=0.3,
    thres=0.9,
    batch_size=10,
    api=Gemma2,
)
print(cp_range)

In [None]:
# Run auto_search on the prior->context argument bundle and print what it returns.
# NOTE(review): unlike the three searches above, `phi` is not passed here (and
# eps is 0.2 rather than 0.3), so auto_search's own default for phi applies —
# confirm this is intended.
pc_range = auto_search(
    model,
    tokenizer,
    pc_args,
    n_layers=42,
    eps=0.2,
    thres=0.9,
    batch_size=10,
    api=Gemma2,
)
print(pc_range)

## Prior

In [None]:
# Prior patch with an empty site configuration.
site_1_config = {}

figr, figp = get_plot_prior_patch(
    nnmodel,
    tokenizer,
    *prior_args,
    site_1_config,
    N_LAYERS=42,
    batch_size=2,
    output_dir="plots/gemma2-9b-it",
    api=Gemma2,
    title=generate_title(site_1_config, "PRIOR - "),
)
figp.show()

In [None]:
# Weight-CP patch on the cp bundle with site "o" at layers 17-24.
# NOTE(review): the title prefix is "PRIOR - " although this calls
# get_plot_weightcp_patch with cp_args — confirm the label is intended.
site_1_config = {"o": {"layers": [*range(17, 25)]}}

figr, figp = get_plot_weightcp_patch(
    nnmodel,
    tokenizer,
    *cp_args,
    site_1_config,
    N_LAYERS=42,
    batch_size=10,
    output_dir="plots/gemma2-9b-it",
    api=Gemma2,
    title=generate_title(site_1_config, "PRIOR - "),
)
figp.show()

In [None]:
# Weight-CP patch on the cp bundle with site "o" at layers 17-24 plus layer 41.
# NOTE(review): the title prefix is "PRIOR - " although this calls
# get_plot_weightcp_patch with cp_args — confirm the label is intended.
site_1_config = {"o": {"layers": [*range(17, 25), 41]}}

figr, figp = get_plot_weightcp_patch(
    nnmodel,
    tokenizer,
    *cp_args,
    site_1_config,
    N_LAYERS=42,
    batch_size=10,
    output_dir="plots/gemma2-9b-it",
    api=Gemma2,
    title=generate_title(site_1_config, "PRIOR - "),
)
figp.show()

In [None]:
# Prior patch with site "o" at layers 25-29.
site_1_config = {"o": {"layers": [25, 26, 27, 28, 29]}}

figr, figp = get_plot_prior_patch(
    nnmodel,
    tokenizer,
    *prior_args,
    site_1_config,
    N_LAYERS=42,
    batch_size=20,
    output_dir="plots/gemma2-9b-it",
    api=Gemma2,
    title=generate_title(site_1_config, "PRIOR - "),
)
figp.show()

In [None]:
# Prior patch with site "o" at layers 25-29 plus layer 37.
site_1_config = {"o": {"layers": [*range(25, 30), 37]}}

figr, figp = get_plot_prior_patch(
    nnmodel,
    tokenizer,
    *prior_args,
    site_1_config,
    N_LAYERS=42,
    batch_size=20,
    output_dir="plots/gemma2-9b-it",
    api=Gemma2,
    title=generate_title(site_1_config, "PRIOR - "),
)
figp.show()


In [None]:
# Prior patch with site "o" at layers 25-29 plus layers 37 and 40.
site_1_config = {"o": {"layers": [*range(25, 30), 37, 40]}}

figr, figp = get_plot_prior_patch(
    nnmodel,
    tokenizer,
    *prior_args,
    site_1_config,
    N_LAYERS=42,
    batch_size=20,
    output_dir="plots/gemma2-9b-it",
    api=Gemma2,
    title=generate_title(site_1_config, "PRIOR - "),
)
figp.show()


In [None]:
# Prior patch with site "o" at layers 28-41.
site_1_config = {"o": {"layers": [*range(28, 42)]}}

figr, figp = get_plot_prior_patch(
    nnmodel,
    tokenizer,
    *prior_args,
    site_1_config,
    N_LAYERS=42,
    batch_size=20,
    output_dir="plots/gemma2-9b-it",
    api=Gemma2,
    title=generate_title(site_1_config, "PRIOR - "),
)
figp.show()


## Context

In [None]:
# Context patch with an empty site configuration.
site_1_config = {}
figr, figp = get_plot_context_patch(
    nnmodel,
    tokenizer,
    *ctx_args,
    site_1_config,
    N_LAYERS=42,
    batch_size=2,
    output_dir="plots/gemma2-9b-it",
    api=Gemma2,
    title=generate_title(site_1_config, "CTX - "),
)
figp.show()


In [None]:
# Context patch with site "o" at layers 25-29 (batch size 1).
site_1_config = {"o": {"layers": [25, 26, 27, 28, 29]}}
figr, figp = get_plot_context_patch(
    nnmodel,
    tokenizer,
    *ctx_args,
    site_1_config,
    N_LAYERS=42,
    batch_size=1,
    output_dir="plots/gemma2-9b-it",
    api=Gemma2,
    title=generate_title(site_1_config, "CTX - "),
)
figp.show()


In [None]:
# Context patch with site "o" at layers 29-41.
site_1_config = {"o": {"layers": [*range(29, 42)]}}
figr, figp = get_plot_context_patch(
    nnmodel,
    tokenizer,
    *ctx_args,
    site_1_config,
    N_LAYERS=42,
    batch_size=2,
    output_dir="plots/gemma2-9b-it",
    api=Gemma2,
    title=generate_title(site_1_config, "CTX - "),
)
figp.show()


## Weight

### CP

In [None]:
# CP weight patch with an empty site configuration.
# Uses the cp bundle (context_1 first, prior_1 second), matching the other
# get_plot_weightcp_patch calls in this notebook; the cell previously passed
# pc_args, which is the bundle used by the PC section.
site_1_config = {}
figr, figp = get_plot_weightcp_patch(
    nnmodel,
    tokenizer,
    *cp_args,
    site_1_config,
    N_LAYERS=42,
    batch_size=20,
    output_dir="plots/gemma2-9b-it",
    api=Gemma2,
    title=generate_title(site_1_config, "CP - "),
)
figp.show()


In [None]:
# CP weight patch with site "o" at layers 0-27.
# Uses the cp bundle (context_1 first, prior_1 second), matching the other
# get_plot_weightcp_patch calls in this notebook; the cell previously passed
# pc_args, which is the bundle used by the PC section.
site_1_config = {
    "o": {
        "layers": list(range(0, 28)),
    },
}
figr, figp = get_plot_weightcp_patch(
    nnmodel,
    tokenizer,
    *cp_args,
    site_1_config,
    N_LAYERS=42,
    batch_size=20,
    output_dir="plots/gemma2-9b-it",
    api=Gemma2,
    title=generate_title(site_1_config, "CP - "),
)
figp.show()


In [None]:
# CP weight patch with site "o" at layers 20-27.
# Uses the cp bundle (context_1 first, prior_1 second), matching the other
# get_plot_weightcp_patch calls in this notebook; the cell previously passed
# pc_args, which is the bundle used by the PC section.
site_1_config = {
    "o": {
        "layers": list(range(20, 28)),
    },
}
figr, figp = get_plot_weightcp_patch(
    nnmodel,
    tokenizer,
    *cp_args,
    site_1_config,
    N_LAYERS=42,
    batch_size=20,
    output_dir="plots/gemma2-9b-it",
    api=Gemma2,
    title=generate_title(site_1_config, "CP - "),
)
figp.show()


### PC

In [None]:
# PC weight patch with an empty site configuration.
site_1_config = {}
figr, figp = get_plot_weightpc_patch(
    nnmodel,
    tokenizer,
    *pc_args,
    site_1_config,
    N_LAYERS=42,
    batch_size=20,
    output_dir="plots/gemma2-9b-it",
    api=Gemma2,
    title=generate_title(site_1_config, "PC - "),
)
figp.show()


In [None]:
# PC weight patch with site "o" at layers 20-27.
site_1_config = {"o": {"layers": [*range(20, 28)]}}
figr, figp = get_plot_weightpc_patch(
    nnmodel,
    tokenizer,
    *pc_args,
    site_1_config,
    N_LAYERS=42,
    batch_size=20,
    output_dir="plots/gemma2-9b-it",
    api=Gemma2,
    title=generate_title(site_1_config, "PC - "),
)
figp.show()

In [None]:
# PC weight patch with site "o" at layers 20-27 plus layers 37 and 40.
site_1_config = {"o": {"layers": [*range(20, 28), 37, 40]}}
figr, figp = get_plot_weightpc_patch(
    nnmodel,
    tokenizer,
    *pc_args,
    site_1_config,
    N_LAYERS=42,
    batch_size=20,
    output_dir="plots/gemma2-9b-it",
    api=Gemma2,
    title=generate_title(site_1_config, "PC - "),
)
figp.show()

In [None]:
# PC weight patch with site "o" at layers 20-27 plus layer 37 (batch size 2).
site_1_config = {"o": {"layers": [*range(20, 28), 37]}}
figr, figp = get_plot_weightpc_patch(
    nnmodel,
    tokenizer,
    *pc_args,
    site_1_config,
    N_LAYERS=42,
    batch_size=2,
    output_dir="plots/gemma2-9b-it",
    api=Gemma2,
    title=generate_title(site_1_config, "PC - "),
)
figp.show()

In [None]:
# PC weight patch with site "o" at layers 20-27, run with batch size 8.
# Same site configuration as the earlier 20-27 PC cell; only batch_size differs.
site_1_config = {"o": {"layers": [20, 21, 22, 23, 24, 25, 26, 27]}}
figr, figp = get_plot_weightpc_patch(
    nnmodel,
    tokenizer,
    *pc_args,
    site_1_config,
    N_LAYERS=42,
    batch_size=8,
    output_dir="plots/gemma2-9b-it",
    api=Gemma2,
    title=generate_title(site_1_config, "PC - "),
)
figp.show()

In [None]:
# PC weight patch with site "o" at layers 0-27.
site_1_config = {"o": {"layers": list(range(28))}}
figr, figp = get_plot_weightpc_patch(
    nnmodel,
    tokenizer,
    *pc_args,
    site_1_config,
    N_LAYERS=42,
    batch_size=20,
    output_dir="plots/gemma2-9b-it",
    api=Gemma2,
    title=generate_title(site_1_config, "PC - "),
)
figp.show()

In [None]:
# PC weight patch with site "o" at layers 20-29.
site_1_config = {"o": {"layers": [*range(20, 30)]}}
figr, figp = get_plot_weightpc_patch(
    nnmodel,
    tokenizer,
    *pc_args,
    site_1_config,
    N_LAYERS=42,
    batch_size=20,
    output_dir="plots/gemma2-9b-it",
    api=Gemma2,
    title=generate_title(site_1_config, "PC - "),
)
figp.show()

In [None]:
# PC weight patch with site "o" at layers 25-29.
site_1_config = {"o": {"layers": [25, 26, 27, 28, 29]}}
figr, figp = get_plot_weightpc_patch(
    nnmodel,
    tokenizer,
    *pc_args,
    site_1_config,
    N_LAYERS=42,
    batch_size=20,
    output_dir="plots/gemma2-9b-it",
    api=Gemma2,
    title=generate_title(site_1_config, "PC - "),
)
figp.show()

# Train DAS

Make sure you have our customized version of pyvene installed:
```
pip install git+https://github.com/jkminder/pyvene
```

In [None]:
%load_ext autoreload
%autoreload 2
from analysis.circuit_utils.das import *
from functools import partial
from torch.utils.data import DataLoader, random_split

import sys
sys.path.append("..")
from nnsight import NNsight
import torch
import os
from tqdm.notebook import tqdm, trange

from nnsight import NNsight

from analysis.circuit_utils.visualisation import *
from analysis.circuit_utils.model import *
from analysis.circuit_utils.validation import *
from analysis.circuit_utils.decoding import *
from analysis.circuit_utils.utils import *
from analysis.circuit_utils.decoding import get_decoding_args, get_data, generate_title, get_plot_prior_patch, get_plot_context_patch, get_plot_weightcp_patch, get_plot_weightpc_patch

from main import load_model_and_tokenizer
from nnpatch.subspace.interventions import train_projection, create_dataset, LowRankOrthogonalProjection


from nnpatch.api.mistral import Mistral

# Helper pulled in by the star imports above; called once at start-up.
jupyter_enable_mathjax()

# Weight store, plot directory (created if missing) and the device used for
# the DAS training data.
MODEL_STORE = "/dlabscratch1/public/llm_weights/gemma_hf/"
device = "cuda:0"
plot_dir = "plots/gemma2-9b-it"
os.makedirs(plot_dir, exist_ok=True)

# Same model settings as the patching section, but with 1000 samples and
# no_filtering=True.
PATHS, args = get_decoding_args(
    finetuned=True,
    load_in_4bit=False,
    cwf="instruction",
    model_id="gemma-2-9b-it",
    model_store=MODEL_STORE,
    n_samples=1000,
    no_filtering=True,
)

In [None]:
# Load the model and tokenizer for DAS training.
# NOTE(review): if the patching section ran in the same kernel this rebinds
# `model`; the earlier instance is presumably still referenced by `nnmodel` —
# confirm there is enough memory for both or restart the kernel first.
model, tokenizer = load_model_and_tokenizer_from_args(
    PATHS,
    args,
)

In [None]:
# Build the raw training tensors on `device`.
# Judging by how they are passed to create_dataset below, st/tt are the
# source/target tokens, si/ti the source/target label indices and ams/amt the
# source/target attention masks; tit/amti are only consumed by
# filter_confident_samples.
st, tt, si, ti, ams, amt, tit, amti = prepare_train_data(
    args,
    PATHS,
    tokenizer,
    device,
    same_query=True,
    remove_weight=False,
)

In [None]:
# Select the sample indices returned by filter_confident_samples, then build
# the training dataset from those rows only.
confident_indices = filter_confident_samples(
    args, model, tt, tit, ti, si, amt, amti,
    batch_size=32,
)
train_dataset = create_dataset(
    *(tensor[confident_indices] for tensor in (st, tt, si, ti, ams, amt))
)
train_dataset

In [None]:
# Collect the evaluation data and wrap it in a dataset.
# Uses the section-wide `device` ("cuda:0") instead of a separate hard-coded
# "cuda", so train and test tensors are placed by the same setting.
(
    source_prompt,
    target_prompt,
    source_tokens,
    target_tokens,
    source_label_index,
    target_label_index,
    source_attn_mask,
    target_attn_mask,
) = collect_data(args, PATHS, tokenizer, device)
test_dataset = create_dataset(
    source_tokens,
    target_tokens,
    source_label_index,
    target_label_index,
    source_attn_mask,
    target_attn_mask,
)
test_dataset

In [None]:
# Rank-1 projection over a 3584-dimensional embedding.
# NOTE(review): 3584 is presumably the hidden size of gemma-2-9b-it — confirm
# against the model config.
proj = LowRankOrthogonalProjection(
    embed_dim=3584,
    rank=1,
)

In [None]:
# Train the projection at layer 27 for one epoch, validating on the test set.
# `proj` is rebound to the return value, so re-running this cell passes the
# previously returned projection back in.
proj = train_projection(
    model,
    proj,
    layer=27,
    train_dataset=train_dataset,
    val_dataset=test_dataset,
    epochs=1,
    batch_size=8,
)

In [None]:
# Persist the projection; the "L27" suffix mirrors the layer used for training above.
proj.save_pretrained(
    "projections/gemma-2-9b-it-L27"
)