In [6]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
import sys
sys.path.append("../..")

import torch
from tqdm import tqdm, trange

from transformers import get_linear_schedule_with_warmup
from torch.utils.data import DataLoader

from nnsight_DAS_utils import *

from utils.prompt_utils import *
from utils.intervention_utils import *
from utils.model_utils import *
from utils.eval_utils import *
from utils.extract_utils import *
import seaborn as sns
from datasets import Dataset, concatenate_datasets
# from utils.das_utils import *

import argparse


HELD_IN_DATASETS = [f.replace(".json", "") for f in  os.listdir("../dataset_files/abstractive") if f not in 
                    ["antonym.json", "capitalize.json", "present-past.json", 
                     "english-french.json", "singular-plural.json", "country-capital.json", 
                     "ag_news.json", "commonsense_qa.json", "sentiment.json"]]

In [11]:
dataset_names = ["antonym"]
ie_path = "../results/AIE/ICL/flan-llama-7b/held_in_tasks/held_in_tasks_indirect_effect.pt"
model_name = "/work/frink/models/flan-llama-7b"

root_data_dir = "../dataset_files"
seed = 42
device = "cuda"

test_split = 0.3
n_shots = 10
n_trials = 512


prefixes = load_prefixes_or_separators({"input":"Q:", "output":"A:", "instructions":""})
separators = load_prefixes_or_separators({"input":"\n", "output":"\n\n", "instructions":""})

batch_size = 8
gradient_accumulation_steps = 1
epochs = 10
warmup_ratio = 0.1
rotate_lr = 1e-3
boundary_lr = 1e-2
dimension_weights = 1.5

temperature_start = 50.0
temperature_end = 0.1

evaluate_per_epoch = False
training_method = "zero_shot"
generate_output = False

In [13]:
model, tokenizer, model_config = load_nnsight_model(model_name=model_name, device=device)
set_requires_grad(model, False)


att_head_dim = model_config["resid_dim"] // model_config["n_heads"]
top_heads = load_top_k_aie(ie_path, k=10)

intervention_projections = dict()
for layer, idx, _ in top_heads:
    head_projection = BoundlessRotatedSpaceIntervention(att_head_dim).to('cuda')
    head_projection.rotate_layer.weight = torch.eye(att_head_dim).to('cuda')
    
    if layer not in intervention_projections.keys():
        intervention_projections[layer] = dict()
    
    intervention_projections[layer][idx] = head_projection

tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
tokenizer.pad_token_id = tokenizer.eos_token_id
        
# Load the dataset
print("Loading Dataset")
set_seed(seed)
datasets = [load_dataset(dataset_name, root_data_dir=root_data_dir, test_size=test_split, seed=seed) for dataset_name in dataset_names]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading Dataset


In [14]:
eval_no_intervention_dataloader = process_mixed_dataloader(datasets, model_config, tokenizer, batch_size, n_shots, "valid", prefixes, separators, intervention_collate_fn(tokenizer), ablation_method="zero_shot", draw_source_from_split=True)
train_dataloader = process_mixed_dataloader(datasets, model_config, tokenizer, batch_size, n_shots, "train", prefixes, separators, intervention_collate_fn(tokenizer), n_trials=n_trials, ablation_method=training_method, draw_source_from_split=False)

fs_eval_dataloader = process_mixed_dataloader(datasets, model_config, tokenizer, batch_size, n_shots, "valid", prefixes, separators, intervention_collate_fn(tokenizer), ablation_method="noninformative", draw_source_from_split=True)
zs_eval_dataloader = process_mixed_dataloader(datasets, model_config, tokenizer, batch_size, n_shots, "valid", prefixes, separators, intervention_collate_fn(tokenizer), ablation_method="zero_shot", draw_source_from_split=True)

In [16]:
t_total = int(len(train_dataloader) * epochs)
warm_up_steps = 0.1 * t_total

target_total_step = len(train_dataloader) * epochs

temperature_schedule = torch.linspace(
    temperature_start, temperature_end, target_total_step
).to(torch.bfloat16).to(device)

# Define params to be learned
optimizer_params = []
param_count = 0
total_step = 0

for layer in intervention_projections.keys():
    for idx in intervention_projections[layer].keys():
        optimizer_params += [{'params': intervention_projections[layer][idx].rotate_layer.parameters()}]
        optimizer_params += [{'params': intervention_projections[layer][idx].intervention_boundaries, 'lr': boundary_lr}]
        
        param_count += count_parameters(intervention_projections[layer][idx])
        intervention_projections[layer][idx].set_temperature(temperature_schedule[total_step])
        intervention_projections[layer][idx].train()
        
optimizer = torch.optim.Adam(
    optimizer_params,
    lr=rotate_lr,
)

scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=warm_up_steps,
    num_training_steps=t_total
)

print("subspace_proj trainable parameters: ", param_count)

subspace_proj trainable parameters:  163860


In [45]:
def batch_subspace_swap_by_attentions(batch, model:nnsight.LanguageModel, subspace_projs): #, batch_size=16
    """
    Batched subspace_swap intervention at a single layer using nnsight
    """
    batch_size = len(batch['base_input_ids'])
    all_inds = torch.arange(batch_size)
        
    base_prompt, source_prompt = batch['base_input_ids'][:batch_size], batch['source_input_ids'][:batch_size]
    base_intervention_token_idx, source_intervention_token_idx = batch['base_predictive_token_idxs'][:batch_size], batch['source_predictive_token_idxs'][:batch_size]
    
    bases, sources = [], []
    layers, idxs = [], []
    
    all_layers = sorted(list(subspace_projs.keys()))

    for layer in all_layers:
        for idx in subspace_projs[layer]:
            layers.append(layer)
            idxs.append(idx)
            
            start_dim_idx = idx * att_head_dim
            end_dim_idx = (idx + 1) * att_head_dim
            
            with model.trace(validate=False) as tracer:
                with tracer.invoke(base_prompt, scan=False):
                    base = model.model.layers[layer].self_attn.o_proj.input[0][0][all_inds, :, start_dim_idx:end_dim_idx].save()
                    bases.append(base)
                
                with tracer.invoke(source_prompt, scan=False):
                    source = model.model.layers[layer].self_attn.o_proj.input[0][0][all_inds, :, start_dim_idx:end_dim_idx].save()
                    sources.append(source)                    

    with model.trace(validate=False) as tracer:
        # intervention
        with tracer.invoke(base_prompt, scan=False):
            for layer, idx, base, source in zip(layers, idxs, bases, sources):
                
                subspace_proj = subspace_projs[layer][idx]
                
                B = base[all_inds,base_intervention_token_idx, :]
                S = source[all_inds,source_intervention_token_idx, :]

                mixed_out = subspace_proj(B, S, batch_size)
                model.model.layers[layer].self_attn.o_proj.input[0][0][all_inds, base_intervention_token_idx, start_dim_idx: end_dim_idx] = mixed_out
                del base, source, B,S
                
        save_out = model.output.save()
    
    
    output_logits = save_out.value.logits
    del save_out
    return output_logits


def evaluate_w_subspace_intervention_by_attentions(model, subspace_projs, dataloader, device="cuda", generate_output=False):
    """
    """
    with torch.no_grad():
        
        eval_labels = []
        eval_preds = []
        
        for step, inputs in enumerate(tqdm(dataloader)):
            for k, v in inputs.items():
                if v is not None and isinstance(v, torch.Tensor):
                    inputs[k] = v.to(device)

            outputs = batch_subspace_swap_by_attentions(inputs, model, subspace_projs)#, batch_size=dataloader.batch_size)
            eval_labels += [inputs['base_labels'].detach().cpu()]
                
            eval_preds += [outputs.detach().cpu()]
        
        eval_metrics = compute_metrics(eval_preds, eval_labels, generate_output=generate_output)
        return eval_metrics

In [46]:
zs_intervention_acc = evaluate_w_subspace_intervention_by_attentions(model, intervention_projections, zs_eval_dataloader, device=model.device, generate_output=generate_output)
print(zs_intervention_acc)   
zs_no_intervention_acc = evaluate_no_intervention(model, zs_eval_dataloader, device=model.device, corrupt=True, generate_output=generate_output)
print(zs_no_intervention_acc)

100%|██████████| 27/27 [15:36<00:00, 34.68s/it]


{'accuracy': 0.47}


100%|██████████| 27/27 [00:07<00:00,  3.73it/s]

{'accuracy': 0.36}



