In [1]:
import torch
import os
import json

In [2]:
import numpy as np

def load_top_k_aie(path: str, k=10):
    """Return the ``k`` largest entries of a saved average-indirect-effect tensor.

    The tensor stored at ``path`` is averaged over its first dimension; the
    ``k`` largest entries of that mean are then located and reported together
    with their multi-dimensional index.

    Args:
        path: File readable by ``torch.load`` holding a single tensor.
        k: How many entries to return. Clamped to the number of entries in the
            averaged tensor, so asking for more than exist returns all of them
            instead of raising.

    Returns:
        A list of tuples ``(*index, value)`` in descending order of value.
        Each index component is a plain ``int`` and ``value`` is a ``float``
        rounded to 4 decimals. For a 2-D mean this is ``(row, col, value)``;
        later cells in this notebook unpack it as ``(layer, head, value)``.
    """
    # map_location="cpu": a tensor saved from a GPU would otherwise be loaded
    # back onto the GPU, and numpy cannot consume CUDA tensors below.
    aie = torch.load(path, map_location="cpu")
    aie_mean = aie.mean(dim=0)
    flat = aie_mean.reshape(-1)  # reshape (unlike view) also accepts non-contiguous input
    k = min(k, flat.numel())
    topk_vals, topk_inds = torch.topk(flat, k=k, largest=True)
    # One integer array per dimension of aie_mean.
    coords = np.unravel_index(topk_inds.numpy(), aie_mean.shape)
    # .tolist() turns numpy/torch scalars into plain Python ints/floats.
    index_tuples = zip(*(axis.tolist() for axis in coords))
    return [(*idx, round(val, 4)) for idx, val in zip(index_tuples, topk_vals.tolist())]

# Source of the per-head indirect-effect scores and how many heads to keep.
AIE_PATH = "../results/AIE/ICL/flan-llama-7b/held_in_tasks/held_in_tasks_indirect_effect.pt"
N_TOP_HEADS = 10

# Entries are unpacked downstream as (layer, head, score).
top_heads = load_top_k_aie(AIE_PATH, k=N_TOP_HEADS)

In [3]:
from nnsight_DAS_utils import *
from utils.prompt_utils import *
from utils.intervention_utils import *
from utils.model_utils import *
from utils.eval_utils import *
from utils.extract_utils import *

# Model checkpoint location. The default is the cluster path this notebook was
# written against; set FLAN_LLAMA_7B_PATH to run it elsewhere.
MODEL_PATH = os.environ.get("FLAN_LLAMA_7B_PATH", "/work/frink/models/flan-llama-7b")

model, tokenizer, model_config = load_nnsight_model(model_name=MODEL_PATH, device="cuda")
# Analysis only: presumably disables gradient tracking on all model
# parameters (helper comes from the project utils - confirm there).
set_requires_grad(model, False)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [18]:
# Per-head dimensionality: the residual width split evenly across heads.
att_head_dim = model_config["resid_dim"] // model_config["n_heads"]


def head_write_directions(o_proj_weight, head_idx, head_dim):
    """Return the residual-stream directions one attention head writes to.

    ``o_proj`` is an ``nn.Linear`` whose weight has shape
    ``(out_features, in_features) = (resid_dim, n_heads * head_dim)``; its input
    is the concatenation of all head outputs. Head ``head_idx`` therefore owns
    the *columns* ``[head_idx*head_dim, (head_idx+1)*head_dim)``. Slicing rows
    instead would select residual dimensions that mix every head, and the
    square 4096x4096 weight would hide that mistake.

    Returns:
        numpy array of shape ``(head_dim, resid_dim)``; each row is one write
        direction in the residual stream.
    """
    head_cols = slice(head_idx * head_dim, (head_idx + 1) * head_dim)
    # .float(): numpy has no bfloat16, so half-precision weights cannot be
    # converted directly (model dtype is set by load_nnsight_model - unverified here).
    return o_proj_weight[:, head_cols].T.detach().float().cpu().numpy()


o_projection_space = [
    head_write_directions(model.model.layers[layer].self_attn.o_proj.weight, head_idx, att_head_dim)
    for layer, head_idx, _ in top_heads
]



torch.Size([4096, 4096])

In [7]:
# Distinct layers containing at least one of the top heads.
all_layers = {head[0] for head in top_heads}
# Trailing expression so the cell actually displays the set it computes.
all_layers

{9, 10, 12, 13, 14, 15}

In [12]:
import os
import torch
from tqdm import tqdm

# Write one identity-initialised rotated-space intervention per top head,
# each saved as its own state-dict file.
fv_projection_dir = "../fv_projection"
att_head_dim = model_config["resid_dim"] // model_config["n_heads"]
os.makedirs(fv_projection_dir, exist_ok=True)

for layer, idx, _ in tqdm(top_heads):
    save_path = f"{fv_projection_dir}/layer{layer}_head_{idx}.model_state_dict.bin"
    head_projection = BoundlessRotatedSpaceIntervention(att_head_dim)
    # Start from the identity rotation, i.e. no change of basis.
    head_projection.rotate_layer.weight = torch.eye(att_head_dim)
    torch.save(head_projection.state_dict(), save_path)

  0%|          | 0/10 [00:00<?, ?it/s]

tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.]], grad_fn=<MmBackward0>)



