In [None]:
# Install the local packages (paths are relative to the parent directory, since the
# command first runs `cd ..`) plus plotting/visualisation extras.
# NOTE(review): versions are unpinned (`pip install -U transformers`), so results can drift
# with new releases; consider pinning and using %pip so the kernel's environment is targeted.
!cd .. && pip install -e ./../nnpatch ./../pycolors -e ./../pyvene && pip install -U transformers kaleido && pip install circuitsvis python-dotenv --no-deps

In [None]:
from pycolors import TailwindColorPalette

# Quick check that pycolors is importable and returns a colour (scratch cell).
TailwindColorPalette().get_shade(4,300)

In [None]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("..")
from nnsight import NNsight
import torch
import os
from tqdm.notebook import tqdm, trange

from nnsight import NNsight

from analysis.circuit_utils.visualisation import *
from analysis.circuit_utils.model import *
from analysis.circuit_utils.validation import *
from analysis.circuit_utils.decoding import *
from analysis.circuit_utils.utils import *
from analysis.circuit_utils.decoding import get_decoding_args, get_data, generate_title, get_plot_prior_patch, get_plot_context_patch, get_plot_weightcp_patch, get_plot_weightpc_patch

from main import load_model_and_tokenizer


from nnpatch.api.llama import Llama3

# Directory passed as `output_dir` to every activation-patching plot helper below.
plot_dir = "plots/Llama-3.1-8B-Instruct"
os.makedirs(name=plot_dir, exist_ok=True)

jupyter_enable_mathjax()

In [None]:
# Change the working directory to the parent directory.
# NOTE(review): `plot_dir` was created with os.makedirs *before* this %cd, i.e. relative to the
# old working directory; the patching cells pass output_dir=plot_dir, which now resolves under
# the parent directory and may not exist there. Likewise the earlier sys.path.append("..")
# now points one level higher. Consider running this %cd before the config cell.
%cd ..

In [3]:
# Cluster-specific location of the Llama-3.1 Hugging Face checkpoints.
MODEL_STORE="/dlabscratch1/public/llm_weights/llama3.1_hf/"

In [None]:
PATHS, args = get_decoding_args(finetuned=True, load_in_4bit=False, cwf="instruction", model_id="Meta-Llama-3.1-8B-Instruct", model_store=MODEL_STORE, n_samples=100)

In [None]:
# NOTE(review): only `load_model_and_tokenizer` is imported by name; this helper presumably
# comes from one of the `analysis.circuit_utils.*` star imports — confirm.
model, tokenizer = load_model_and_tokenizer_from_args(PATHS, args)
nnmodel = NNsight(model)

In [None]:
# NOTE(review): this cell replaces the `model`/`nnmodel` loaded in the previous cell; only the
# tokenizer is taken from MODEL_STORE. The model id is presumably a Hugging Face Hub repo, and
# trust_remote_code=True executes Python code shipped with that repo — verify that
# "jkminder/gpt5-100T-agi" is the intended, trusted checkpoint before running.
# NOTE(review): AutoModelForCausalLM / AutoTokenizer are not imported by name in this notebook;
# presumably they arrive through one of the star imports.
model = AutoModelForCausalLM.from_pretrained("jkminder/gpt5-100T-agi", trust_remote_code=True)
nnmodel = NNsight(model)
tokenizer = AutoTokenizer.from_pretrained(MODEL_STORE + "Meta-Llama-3.1-8B-Instruct")
# Reuse EOS as the pad token and pad on the left for batched decoder-only generation.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Patch

In [None]:
# Load the patching data (no filtering, 200 samples).
# NOTE(review): load_in_4bit=True here differs from the args the model was loaded with;
# `args` is only handed to get_data below — confirm the flag does not affect the data.
PATHS, args = get_decoding_args(finetuned=True, no_filtering=True, load_in_4bit=True, cwf="instruction", model_id="Meta-Llama-3.1-8B-Instruct", model_store=MODEL_STORE, n_samples=200)
all_tokens, all_attn_mask, context_1_tokens, context_2_tokens, context_3_tokens, prior_1_tokens, prior_2_tokens, context_1_attention_mask, context_2_attention_mask, context_3_attention_mask, prior_1_attention_mask, prior_2_attention_mask, context_1_answer, context_2_answer, context_3_answer, prior_1_answer, prior_2_answer = get_data(args, PATHS, tokenizer)


# Argument bundles for the patching helpers, each laid out as
# [all_tokens, all_attn_mask, A_tokens, B_tokens, A_mask, B_mask, A_answer, B_answer]:
#   prior_args: A=prior_1,   B=prior_2       ctx_args: A=context_1, B=context_2
#   cp_args:    A=context_1, B=prior_1       pc_args:  A=prior_1,   B=context_1
prior_args = [all_tokens, all_attn_mask, prior_1_tokens, prior_2_tokens, prior_1_attention_mask, prior_2_attention_mask, prior_1_answer, prior_2_answer]
ctx_args = [all_tokens, all_attn_mask, context_1_tokens, context_2_tokens, context_1_attention_mask, context_2_attention_mask, context_1_answer, context_2_answer]
cp_args = [all_tokens, all_attn_mask, context_1_tokens, prior_1_tokens, context_1_attention_mask, prior_1_attention_mask, context_1_answer, prior_1_answer]
pc_args = [all_tokens, all_attn_mask, prior_1_tokens, context_1_tokens, prior_1_attention_mask, context_1_attention_mask, prior_1_answer, context_1_answer]

In [None]:
# Show one prior prompt and its expected answer with special tokens kept. Two separate
# statements: a tuple of print() calls also echoes a spurious "(None, None)" cell output.
print(tokenizer.decode(prior_1_tokens[0], skip_special_tokens=False))
print(tokenizer.decode(prior_1_answer[0], skip_special_tokens=False))

## Auto search

In [None]:
# Automatic layer-range search for the prior patch, restricted to layers 13-19.
# Uses the Llama3 nnpatch API that every other patching cell uses for this model
# (the Mistral API was a leftover from another notebook); Llama-3.1-8B has 32 layers.
prior_range = auto_search(model, tokenizer, prior_args, n_layers=32, phi=0.05, eps=0.3, thres=0.85, batch_size=10, api=Llama3, lower_bound=13, upper_bound=19)
print(prior_range)

In [None]:
# Llama-3.1-8B has 32 decoder layers and is driven through the Llama3 nnpatch API
# (n_layers=42 / Gemma2 were copied from a Gemma-2 notebook; Gemma2 is not imported here).
ctx_range = auto_search(model, tokenizer, ctx_args, n_layers=32, phi=0.05, eps=0.3, thres=0.85, batch_size=10, api=Llama3)
print(ctx_range)

In [None]:
# Llama-3.1-8B has 32 decoder layers and is driven through the Llama3 nnpatch API
# (n_layers=42 / Gemma2 were copied from a Gemma-2 notebook; Gemma2 is not imported here).
cp_range = auto_search(model, tokenizer, cp_args, n_layers=32, phi=0.05, eps=0.3, thres=0.85, batch_size=10, api=Llama3)
print(cp_range)

In [None]:
# Llama-3.1-8B has 32 decoder layers and is driven through the Llama3 nnpatch API
# (n_layers=42 / Gemma2 were copied from a Gemma-2 notebook; Gemma2 is not imported here).
pc_range = auto_search(model, tokenizer, pc_args, n_layers=32, phi=0.05, eps=0.3, thres=0.85, batch_size=10, api=Llama3)
print(pc_range)

## Prior

In [None]:
# Unpatched generation on the first two prior prompts, as a sanity check.
out = model.generate(
    prior_1_tokens[:2],
    attention_mask=prior_1_attention_mask[:2],
    max_new_tokens=100,
)
print(tokenizer.decode(out[0], skip_special_tokens=False))

In [None]:
site_1_config = { # PRIOR   
}

figr, figp = get_plot_prior_patch(nnmodel, tokenizer, *prior_args, site_1_config, N_LAYERS=32, batch_size=2, output_dir=plot_dir, api=Llama3, title=generate_title(site_1_config, "PRIOR - "), max_index=10)
figp.show()

In [None]:
site_1_config = { 
    "o":
    {
        "layers": [13, 14, 15, 16, 17, 18, 24]
    },
}

figr, figp = get_plot_prior_patch(nnmodel, tokenizer, *prior_args, site_1_config, N_LAYERS=32, batch_size=2, output_dir=plot_dir, api=Llama3, title=generate_title(site_1_config, "PRIOR - "), max_index=10)
figp.show()

In [None]:
site_1_config = { 
    "o":
    {
        "layers": [13, 14, 15, 16, 17, 18]
    },
}

figr, figp = get_plot_prior_patch(nnmodel, tokenizer, *prior_args, site_1_config, N_LAYERS=32, batch_size=2, output_dir=plot_dir, api=Llama3, title=generate_title(site_1_config, "PRIOR - "), max_index=10)
figp.show()

In [None]:
site_1_config = { 
    "o":
    {
        "layers": [13, 14, 15, 16]
    },
}

figr, figp = get_plot_prior_patch(nnmodel, tokenizer, *prior_args, site_1_config, N_LAYERS=32, batch_size=2, output_dir=plot_dir, api=Llama3, title=generate_title(site_1_config, "PRIOR - "), max_index=10)
figp.show()

## Context

In [None]:
site_1_config = { # PRIOR   
}

figr, figp = get_plot_context_patch(nnmodel, tokenizer, *ctx_args, site_1_config, N_LAYERS=32, batch_size=2, output_dir=plot_dir, api=Llama3, title=generate_title(site_1_config, "CTX - "))
figp.show()

In [None]:
site_1_config = { 
    "o":
    {
        "layers": list(range(24, 32)),
    },
}
figr, figp = get_plot_context_patch(nnmodel, tokenizer, *ctx_args, site_1_config, N_LAYERS=32, batch_size=2, output_dir=plot_dir, api=Llama3, title=generate_title(site_1_config, "CTX - "))
figp.show()


In [None]:
site_1_config = { 
    "o":
    {
        "layers": list(range(25, 32)),
    },
}
figr, figp = get_plot_context_patch(nnmodel, tokenizer, *ctx_args, site_1_config, N_LAYERS=32, batch_size=2, output_dir=plot_dir, api=Llama3, title=generate_title(site_1_config, "CTX - "))
figp.show()


## Weight

### CP

In [None]:
site_1_config = { 
}
figr, figp = get_plot_weightcp_patch(nnmodel, tokenizer, *pc_args, site_1_config, N_LAYERS=32, batch_size=20, output_dir=plot_dir, api=Llama3, title=generate_title(site_1_config, "CP - "))
figp.show()


In [None]:
site_1_config = { 
    "o":
    {
        "layers": list(range(12, 17)),
    },
}
figr, figp = get_plot_weightcp_patch(nnmodel, tokenizer, *cp_args, site_1_config, N_LAYERS=32, batch_size=20, output_dir=plot_dir, api=Llama3, title=generate_title(site_1_config, "CP - "))
figp.show()


### PC

In [None]:
site_1_config = { 
}
figr, figp = get_plot_weightpc_patch(nnmodel, tokenizer, *pc_args, site_1_config, N_LAYERS=32, batch_size=20, output_dir=plot_dir, api=Llama3, title=generate_title(site_1_config, "PC - "))
figp.show()


In [None]:
site_1_config = { 
    "o":
    {
        "layers": list(range(12, 17)),
    },
}
figr, figp = get_plot_weightpc_patch(nnmodel, tokenizer, *pc_args, site_1_config, N_LAYERS=32, batch_size=20, output_dir=plot_dir, api=Llama3, title=generate_title(site_1_config, "PC - "))
figp.show()

# DAS

In [None]:
%load_ext autoreload
%autoreload 2
from analysis.circuit_utils.das import *
from functools import partial
from torch.utils.data import DataLoader, random_split

import sys
sys.path.append("..")
from nnsight import NNsight
import torch
import os
from tqdm.notebook import tqdm, trange

from nnsight import NNsight

from analysis.circuit_utils.visualisation import *
from analysis.circuit_utils.model import *
from analysis.circuit_utils.validation import *
from analysis.circuit_utils.decoding import *
from analysis.circuit_utils.utils import *
from analysis.circuit_utils.decoding import get_decoding_args, get_data, generate_title, get_plot_prior_patch, get_plot_context_patch, get_plot_weightcp_patch, get_plot_weightpc_patch

from main import load_model_and_tokenizer
from nnpatch.subspace.interventions import train_projection, create_dataset, LowRankOrthogonalProjection


from nnpatch.api.mistral import Mistral

# DAS configuration (re-declared so this section does not rely on the patching section).
MODEL_STORE = "/dlabscratch1/public/llm_weights/llama3.1_hf/"
plot_dir = "plots/Llama-3.1-8B-Instruct"
device = "cuda:0"
os.makedirs(plot_dir, exist_ok=True)
jupyter_enable_mathjax()

PATHS, args = get_decoding_args(
    finetuned=True,
    load_in_4bit=False,
    cwf="instruction",
    model_id="Meta-Llama-3.1-8B-Instruct",
    model_store=MODEL_STORE,
    n_samples=1000,
    no_filtering=True,
)

In [None]:
model, tokenizer = load_model_and_tokenizer_from_args(PATHS, args)

In [49]:
st, tt, si, ti, ams, amt, tit, amti = prepare_train_data(args, PATHS, tokenizer, device, same_query=True, remove_weight=False)

In [None]:
confident_indices = filter_confident_samples(args, model, tt, tit, ti, si, amt, amti, batch_size=32)
train_dataset = create_dataset(st[confident_indices], tt[confident_indices], si[confident_indices], ti[confident_indices], ams[confident_indices], amt[confident_indices])
train_dataset

In [None]:
print_random_sample(train_dataset, tokenizer, prefix="Train")

In [None]:
# Held-out evaluation set, loaded onto the same `device` as the training data
# (a hard-coded "cuda" follows the current CUDA device, which need not be `device`).
source_prompt, target_prompt, source_tokens, target_tokens, source_label_index, target_label_index, source_attn_mask, target_attn_mask = collect_data(args, PATHS, tokenizer, device)
test_dataset = create_dataset(source_tokens, target_tokens, source_label_index, target_label_index, source_attn_mask, target_attn_mask)
test_dataset

In [None]:
# proj = LowRankOrthogonalProjection.load_pretrained("analysis/results_das/Mistral-7B-Instruct-v0.3/Mistral-7B-Instruct-v0.3-L16.pt")
proj = LowRankOrthogonalProjection(embed_dim=4096, rank=1)

In [53]:
proj, projection = train_projection(model, proj, layer=16, train_dataset=train_dataset, val_dataset=test_dataset, epochs=1, batch_size=8)


In [None]:
# Persist the trained layer-16 projection. torch.save does not create missing parent
# directories, so make sure the results directory exists first.
das_dir = "analysis/results_das/Meta-Llama-3.1-8B-Instruct"
os.makedirs(das_dir, exist_ok=True)
torch.save(proj.state_dict(), os.path.join(das_dir, "Meta-Llama-3.1-8B-Instruct-L16.pt"))


## Analyse Distribution

In [84]:
def get_save_residuals(model, tokens, attention_mask, layer, save_path, batch_size=32):
    """Return last-position residuals at `layer`, cached on disk.

    On the first call the residuals are computed with `batch_patched_residuals` and
    saved to ``analysis/residuals/Meta-Llama-3.1-8B-Instruct/<save_path>.pt``. Later
    calls with the same `save_path` load that file instead of recomputing. Previously
    a cache hit fell through to ``return residuals`` with the name unbound, raising
    UnboundLocalError.

    Args:
        model: model passed on to `batch_patched_residuals`.
        tokens: input ids, batched along dim 0.
        attention_mask: mask aligned with `tokens`.
        layer: index of the decoder layer whose output is read.
        save_path: cache file stem (no directory, no ``.pt``). The cache is keyed
            only by this name, so encode anything that changes the result (e.g. the
            layer) in it.
        batch_size: batch size used when computing.

    Returns:
        The residual tensor, either freshly computed or loaded from the cache file.
    """
    base = "analysis/residuals/Meta-Llama-3.1-8B-Instruct"
    os.makedirs(base, exist_ok=True)
    cache_file = os.path.join(base, f"{save_path}.pt")
    if os.path.exists(cache_file):
        return torch.load(cache_file)
    residuals = batch_patched_residuals(model, tokens, attention_mask, layer=layer, batch_size=batch_size)
    torch.save(residuals, cache_file)
    return residuals

In [85]:
def get_residuals(model, tokens, attention_mask, layer, scan=False, validate=False, average_site=None):
    """Return the last-position output of decoder layer `layer` for one batch.

    Wraps `model` in NNsight, runs a single clean (unpatched) trace, and saves
    ``model.model.layers[layer].output[0][:, -1, :]``.

    Args:
        model: model to wrap with NNsight.
        tokens: input ids for one batch.
        attention_mask: mask aligned with `tokens`.
        layer: index into ``model.model.layers``.
        scan, validate: forwarded to ``NNsight.trace``.
        average_site: unused; kept so existing call sites stay valid.

    Returns:
        Detached CPU tensor — presumably (batch, hidden) if the layer output is
        (batch, seq, hidden); TODO confirm.
    """
    traced = NNsight(model)
    with traced.trace(tokens, attention_mask=attention_mask, scan=scan, validate=validate):
        saved = traced.model.layers[layer].output[0][:, -1, :].save()
    last_position = saved.value.detach().cpu()
    # torch.cat of a one-element list yields a fresh copy of that tensor.
    result = torch.cat([last_position], dim=0)
    torch.cuda.empty_cache()
    return result

def batch_patched_residuals(nnmodel, tokens, attention_mask, layer, batch_size=32, scan=False, validate=False):
    """Run `get_residuals` over `tokens` in chunks of `batch_size`; concatenate along dim 0.

    `nnmodel` is passed straight to `get_residuals`, which wraps it in NNsight and
    performs a clean forward pass per chunk.
    """
    chunks = [
        get_residuals(
            nnmodel,
            tokens[start:start + batch_size],
            attention_mask[start:start + batch_size],
            layer,
            scan=scan,
            validate=validate,
        )
        for start in trange(0, tokens.shape[0], batch_size)
    ]
    return torch.cat(chunks)

In [None]:
# Read the layer that the DAS projection below was trained on (train_projection(layer=16),
# saved as "...-L16.pt"); layer=27 projected layer-27 activations onto a layer-16 direction.
# The layer is part of the cache name so a file computed at another layer is never reused.
residuals = get_save_residuals(model, target_tokens, target_attn_mask, layer=16, save_path="ft_cwf_instruction_L16", batch_size=16)

In [None]:
proj = LowRankOrthogonalProjection.load_pretrained(os.path.join("analysis/results_das/Meta-Llama-3.1-8B-Instruct", f"Meta-Llama-3.1-8B-Instruct-L16.pt"))

In [88]:
features = proj.project(residuals.cuda())

In [89]:
features = features.cpu().numpy(force=True)

In [None]:
# histplot
from matplotlib import pyplot as plt

# Create two separate arrays for features based on weight_context
features_prior = features[target_prompt['weight_context'] == 0.0]
features_context = features[target_prompt['weight_context'] == 1.0]

# Create the histogram
plt.figure(figsize=(12, 6))
plt.hist(features_prior[:, 0], bins=50, alpha=0.5, label='Prior', color='blue')
plt.hist(features_context[:, 0], bins=50, alpha=0.5, label='Context', color='red')

plt.xlabel('Feature Value')
plt.ylabel('Frequency')
plt.title('Distribution of Features by Context Weight Type')
plt.legend()
plt.show()

In [91]:


def accuracy(is_correct):
    """Fraction of truthy entries in `is_correct`.

    Returns ``nan`` for an empty sequence (accuracy is undefined) instead of
    raising ZeroDivisionError.
    """
    total = len(is_correct)
    if total == 0:
        return float("nan")
    return float(sum(is_correct)) / total

def paired_accuracy(is_correct):
    """Fraction of consecutive pairs (0, 1), (2, 3), ... in which both entries are correct.

    A trailing unpaired element (odd-length input) is ignored. Previously the
    even/odd slices had different lengths in that case, so numpy either broadcast
    them into a wrong result (length 3) or raised a shape error. Returns ``nan``
    when there is no complete pair.
    """
    flags = list(is_correct)
    pairs = list(zip(flags[::2], flags[1::2]))
    if not pairs:
        return float("nan")
    both_correct = sum(1 for first, second in pairs if first and second)
    return float(both_correct) / len(pairs)

def iia_with_hook(model, hook, tokens, attention_mask, values, answers, tokenizer, batch_size=32, max_index=None, verbose=False):
    """Score greedy generations produced while a steering hook is active.

    For each batch of `batch_size` prompts, ``hook.set_context_prior`` receives the
    matching slice of `values`. Then ``model.generate`` decodes up to 10 new tokens
    greedily. The prompt ids are stripped from each output, the remainder is
    decoded, and each continuation is compared with its expected answer via
    ``is_response_correct``.

    Args:
        model: model exposing a Hugging Face style ``generate``.
        hook: object with ``set_context_prior``. It must already be attached to
            `model` for the intervention to apply; this function neither attaches
            nor removes it.
        tokens: prompt token ids, sliceable along dim 0. Presumably a left-padded
            (n, seq) tensor — TODO confirm.
        attention_mask: mask aligned with `tokens`.
        values: per-example values for the hook, sliceable along dim 0.
        answers: per-example expected answer strings (list, array or pandas Series).
        tokenizer: decodes outputs; its ``eos_token_id`` is used as the pad id.
        batch_size: prompts per ``generate`` call.
        max_index: if given, only the first `max_index` examples are evaluated.
        verbose: print every answer/generation pair.

    Returns:
        dict with ``"accuracy"`` (per example) and ``"paired_accuracy"``
        (consecutive pairs both correct).
    """
    if max_index is not None:
        tokens = tokens[:max_index]
        attention_mask = attention_mask[:max_index]
        values = values[:max_index]
        answers = answers[:max_index]
    # Index answers by position. `answers[i]` on a pandas Series is a *label* lookup,
    # which returns the wrong row (or raises KeyError) when the index is not 0..n-1.
    answers = list(answers)

    generations = []
    for start in trange(0, len(tokens), batch_size):
        stop = start + batch_size
        hook.set_context_prior(values[start:stop])
        batch_out = model.generate(
            tokens[start:stop],
            attention_mask=attention_mask[start:stop],
            max_new_tokens=10,
            do_sample=False,
            temperature=None,
            top_k=None,
            top_p=None,
            pad_token_id=tokenizer.eos_token_id,
        )
        generations.extend(batch_out.tolist())

    # Outputs start with the prompt ids; keep only the newly generated part.
    continuations = [g[len(tokens[i]):] for i, g in enumerate(generations)]
    decoded = tokenizer.batch_decode(continuations, skip_special_tokens=True)

    is_correct = []
    for i, o in enumerate(decoded):
        correct = is_response_correct(o.strip(), answers[i].strip())
        if verbose:
            print("Answer:", f"'{answers[i]}'", "Generation:", f"'{o}'", "Correct:", answers[i] in o, correct)
        is_correct.append(correct)

    return {
        "accuracy": accuracy(is_correct),
        "paired_accuracy": paired_accuracy(is_correct),
    }


In [96]:
from analysis.circuit_utils.steering import CtxPriorHook


In [None]:
# Override the data-selection fields so paths_from_args resolves the dataset used for the
# non-finetuned Mistral-7B-Instruct-v0.3 runs (per the "Hack" note below).
# NOTE(review): this mutates `args` in place; every later cell that reads `args` sees these overrides.
args.model_id = "Mistral-7B-Instruct-v0.3" # Hack to get the correct data
args.finetuned = False
args.finetune_training_args = None
args.no_filtering = True
PATHS = paths_from_args(args)

In [92]:

# Re-tokenize the target prompts with their instruction removed.
# For few-shot settings the data must first be reloaded with
# collect_data(args, PATHS, tokenizer, device).
cleaned_target_texts = [remove_instruction(text, name_of_instruction="Instruction") for text in target_prompt.text.tolist()]
# NOTE(review): batched generation in iia_with_hook expects left padding, and truncation follows
# tokenizer.truncation_side ("right" by default), which would cut the end of any prompt longer than
# the source prompts — confirm both settings on this tokenizer.
cleaned_target_tokens = tokenizer(cleaned_target_texts, padding=True, truncation=True, max_length=len(source_tokens[0]), return_tensors="pt")
# Plain tensors for iia_with_hook; a BatchEncoding is dict-like and does not slice per example.
cleaned_target_ids = cleaned_target_tokens.input_ids
cleaned_attn_mask = cleaned_target_tokens.attention_mask

In [None]:
print(cleaned_target_texts[0])

In [None]:
# NOTE(review): in Run-All order these Llama-3.1 args are immediately replaced by the next cell.
PATHS, args = get_decoding_args(finetuned=False, load_in_4bit=False, cwf="instruction", model_id="Meta-Llama-3.1-8B-Instruct", model_store=MODEL_STORE, n_samples=100, no_filtering=True, shots=0)


In [None]:
# NOTE(review): loads Mistral-7B-Instruct-v0.3 via MODEL_STORE, which points at the llama3.1_hf
# store — confirm the checkpoint lives there. This replaces `model`/`tokenizer`, so the hook cells
# below apply the projection trained on Llama-3.1 to Mistral.
PATHS, args = get_decoding_args(finetuned=True, load_in_4bit=False, cwf="instruction", model_id="Mistral-7B-Instruct-v0.3", model_store=MODEL_STORE, n_samples=100, no_filtering=True, shots=0)
model, tokenizer = load_model_and_tokenizer_from_args(PATHS, args)
nnmodel = NNsight(model)

In [None]:
# Steering hook built from the DAS projection at layer 16. context_value / prior_value are
# presumably the feature values written for "context" vs "prior" behaviour — confirm in CtxPriorHook.
hook = CtxPriorHook(proj, 16, context_value=-4, prior_value=0)
hook.attach(model)

In [45]:

# Detach the hook from the model.
hook.remove()

In [None]:
a = target_tokens[:2]
mask = target_attn_mask[:2]
hook.set_constant_context()
out = model.generate(a, attention_mask=mask, max_new_tokens=10, do_sample=True)


In [None]:
print(tokenizer.decode(out[0])), print(tokenizer.decode(out[1]))

In [None]:
a = target_tokens[:2]
mask = target_attn_mask[:2]
hook.set_constant_prior()
out = model.generate(a, attention_mask=mask, max_new_tokens=10, do_sample=True)

In [None]:
print(tokenizer.decode(out[0])), print(tokenizer.decode(out[1]))

In [None]:
print(tokenizer.decode(out[0])), print(tokenizer.decode(out[1]))

In [None]:
# Rebuild the layer-16 steering hook with different values (context_value=-5, prior_value=2) and attach it.
hook = CtxPriorHook(proj, 16, context_value=-5, prior_value=2)
hook.attach(model)

In [None]:
# NOTE(review): in Run-All order this detaches the hook before the IIA cells below run.
hook.remove()

In [None]:
# IIA on the original target prompts: per example the hook receives whether the target is a
# prior prompt (weight_context == 0.0); generations are scored against the source answers.
# .to_numpy() avoids feeding a label-indexed pandas Series to torch.tensor.
values = torch.tensor(target_prompt["weight_context"].to_numpy() == 0.0)
# The preceding cell detaches the hook, so attach it here and always detach afterwards.
# NOTE(review): assumes the hook is currently detached; attaching twice may double-apply it.
hook.attach(model)
try:
    res = iia_with_hook(model, hook, target_tokens, target_attn_mask, values, source_prompt["answer"], tokenizer, batch_size=24, max_index=100)
finally:
    hook.remove()
print(res)

In [None]:
# IIA on the instruction-stripped prompts (first 100 examples). The ids/mask tensors are taken
# straight from the BatchEncoding: a BatchEncoding is dict-like, so len()/slicing inside
# iia_with_hook would not act per example.
values = torch.tensor(target_prompt["weight_context"].to_numpy() == 0.0)
# NOTE(review): assumes the hook is currently detached; attaching twice may double-apply it.
hook.attach(model)
try:
    res = iia_with_hook(model, hook, cleaned_target_tokens.input_ids.to(device), cleaned_target_tokens.attention_mask.to(device), values, source_prompt["answer"], tokenizer, batch_size=24, max_index=100)
finally:
    hook.remove()
print(res)

In [None]:
# IIA on the instruction-stripped prompts (first 1000 examples). The ids/mask tensors are taken
# straight from the BatchEncoding: a BatchEncoding is dict-like, so len()/slicing inside
# iia_with_hook would not act per example.
values = torch.tensor(target_prompt["weight_context"].to_numpy() == 0.0)
# NOTE(review): assumes the hook is currently detached; attaching twice may double-apply it.
hook.attach(model)
try:
    res = iia_with_hook(model, hook, cleaned_target_tokens.input_ids.to(device), cleaned_target_tokens.attention_mask.to(device), values, source_prompt["answer"], tokenizer, batch_size=24, max_index=1000)
finally:
    hook.remove()
print(res)

In [None]:
# IIA on the instruction-stripped prompts (first 500 examples). The ids/mask tensors are taken
# straight from the BatchEncoding: a BatchEncoding is dict-like, so len()/slicing inside
# iia_with_hook would not act per example.
values = torch.tensor(target_prompt["weight_context"].to_numpy() == 0.0)
# NOTE(review): assumes the hook is currently detached; attaching twice may double-apply it.
hook.attach(model)
try:
    res = iia_with_hook(model, hook, cleaned_target_tokens.input_ids.to(device), cleaned_target_tokens.attention_mask.to(device), values, source_prompt["answer"], tokenizer, batch_size=24, max_index=500)
finally:
    hook.remove()
print(res)