In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
sys.path.append("../..")

import torch
from tqdm import tqdm, trange

from transformers import get_linear_schedule_with_warmup
from torch.utils.data import DataLoader

from nnsight_DAS_utils import *

from utils.prompt_utils import *
from utils.intervention_utils import *
from utils.model_utils import *
from utils.eval_utils import *
from utils.extract_utils import *
import seaborn as sns
from datasets import Dataset, concatenate_datasets
# from utils.das_utils import *

import argparse


# Dataset files excluded from the held-in training mix (reserved for held-out evaluation).
HELD_OUT_DATASET_FILES = {
    "antonym.json", "capitalize.json", "present-past.json",
    "english-french.json", "singular-plural.json", "country-capital.json",
    "ag_news.json", "commonsense_qa.json", "sentiment.json",
}

# Held-in task names = every abstractive *.json dataset not listed above, without extension.
# - Only ".json" files are kept so stray entries (e.g. .DS_Store, __pycache__) are never
#   passed to load_dataset as task names.
# - The result is sorted because os.listdir order is arbitrary and filesystem-dependent;
#   dataset order feeds the mixed dataloaders, so an unsorted list is not reproducible.
HELD_IN_DATASETS = sorted(
    os.path.splitext(f)[0]
    for f in os.listdir("../dataset_files/abstractive")
    if f.endswith(".json") and f not in HELD_OUT_DATASET_FILES
)

In [7]:
# --- Tasks, model, data -----------------------------------------------------------
dataset_names = HELD_IN_DATASETS
model_name = "/work/frink/models/flan-llama-7b"
root_data_dir = "../dataset_files"
device = "cuda"
seed = 42

# Decoder layers at which the subspace swap is applied.
edit_layers = [10, 11, 12, 13, 14]

# --- Prompt construction ----------------------------------------------------------
test_split = 0.3    # passed to load_dataset as test_size
n_shots = 10        # number of in-context examples per prompt
n_trials = 512      # passed to the training dataloader as n_trials

prefixes = load_prefixes_or_separators({"input": "Q:", "output": "A:", "instructions": ""})
separators = load_prefixes_or_separators({"input": "\n", "output": "\n\n", "instructions": ""})

# --- Optimisation -------------------------------------------------------------------
batch_size = 8
gradient_accumulation_steps = 1
epochs = 10
warmup_ratio = 0.1
rotate_lr = 1e-3     # lr of the rotation-matrix parameter group (optimizer default)
boundary_lr = 1e-2   # lr of the intervention-boundary parameter group
dimension_weights = 1.5  # NOTE(review): not referenced in the cells shown here

# --- Temperature annealing (linear from start to end over all training steps) --------
temperature_start = 50.0
temperature_end = 0.1

# --- Run flags ----------------------------------------------------------------------
evaluate_per_epoch = False
training_method = "zero_shot"   # ablation_method used for the training dataloader
generate_output = False

In [8]:
# Load the frozen base model; only the rotated-subspace intervention gets trained.
model, tokenizer, model_config = load_nnsight_model(model_name=model_name, device=device)
set_requires_grad(model, False)
# Use the configured `device` (not a hard-coded "cuda") so the intervention always lives
# where the config cell says the model lives.
subspace_proj = BoundlessRotatedSpaceIntervention(model.config.hidden_size).to(device)

# Reuse EOS as the padding token and pad on the right, so token positions of the real
# prompt are unaffected by padding.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
tokenizer.pad_token_id = tokenizer.eos_token_id

# Fail with a readable message instead of max()'s "empty sequence" ValueError.
assert edit_layers, "edit_layers must contain at least one layer index."
assert max(edit_layers) < model_config["n_layers"], f"Edit layer {edit_layers} is out of range for model with {model_config['n_layers']} layers."

# Load the dataset
print("Loading Dataset")
set_seed(seed)
datasets = [load_dataset(dataset_name, root_data_dir=root_data_dir, test_size=test_split, seed=seed) for dataset_name in dataset_names]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading Dataset


In [10]:
def _build_dataloader(split, ablation_method, draw_source_from_split, **extra_kwargs):
    """Build one mixed-task dataloader over `datasets` with the shared prompt settings.

    A fresh `intervention_collate_fn(tokenizer)` is created for every loader.
    `extra_kwargs` are forwarded unchanged to `process_mixed_dataloader`.
    """
    return process_mixed_dataloader(
        datasets, model_config, tokenizer, batch_size, n_shots, split,
        prefixes, separators, intervention_collate_fn(tokenizer),
        ablation_method=ablation_method,
        draw_source_from_split=draw_source_from_split,
        **extra_kwargs,
    )


# Keep this call order: the builder may sample prompts, so reordering could change
# which examples each loader receives.
eval_no_intervention_dataloader = _build_dataloader("valid", "zero_shot", True)
train_dataloader = _build_dataloader("train", training_method, False, n_trials=n_trials)

fs_eval_dataloader = _build_dataloader("valid", "noninformative", True)
zs_eval_dataloader = _build_dataloader("valid", "zero_shot", True)

In [11]:
# Baseline: task accuracy of the unmodified model before any alignment training.
results = {}

print(f"Evaluating the model {n_shots}-shots without intervention...")
eval_dict = evaluate_no_intervention(
    model,
    eval_no_intervention_dataloader,
    device=model.device,
    generate_output=generate_output,
)
prealign_val_task_accuracy = eval_dict["accuracy"]
print(f"[WARNING: THIS NEEDS TO BE GOOD!] prealign task accuracy: {prealign_val_task_accuracy}")
results['prealign_val_task_accuracy'] = prealign_val_task_accuracy

Evaluating the model 10-shots without intervention...


 14%|█▎        | 32/233 [01:16<08:01,  2.39s/it]


KeyboardInterrupt: 

In [12]:
# Total number of training steps (one per batch per epoch).
# NOTE(review): gradient_accumulation_steps is not folded in here; if the training loop
# steps the scheduler once per accumulation window, divide by it — TODO confirm.
t_total = int(len(train_dataloader) * epochs)
assert t_total > 0, "train_dataloader is empty; there is nothing to train on."

# Derive warm-up from the configured warmup_ratio (previously a hard-coded 0.1) and make
# it an integer step count, as the scheduler's num_warmup_steps expects.
warm_up_steps = int(warmup_ratio * t_total)

# Two parameter groups: the rotation uses the optimizer default lr (rotate_lr),
# the intervention boundaries use boundary_lr.
optimizer_params = [
    {'params': subspace_proj.rotate_layer.parameters()},
    {'params': subspace_proj.intervention_boundaries, 'lr': boundary_lr},
]

optimizer = torch.optim.Adam(optimizer_params, lr=rotate_lr)

scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=warm_up_steps, num_training_steps=t_total
)

# The temperature is annealed linearly over the same number of steps as the lr schedule.
# NOTE: cast to bfloat16, which keeps only ~3 significant digits, so neighbouring
# temperatures may round to the same value.
target_total_step = t_total
temperature_schedule = torch.linspace(
    temperature_start, temperature_end, target_total_step
).to(torch.bfloat16).to(device)

total_step = 0
subspace_proj.set_temperature(temperature_schedule[total_step])

# train() only switches the module to training mode (e.g. dropout behaviour);
# it does not by itself enable or disable gradients.
subspace_proj.train()

print("subspace_proj trainable parameters: ", count_parameters(subspace_proj))

train_iterator = trange(
    0, int(epochs), desc="Epoch"
)

training_log_dicts = []

subspace_proj trainable parameters:  16777218


Epoch:   0%|          | 0/10 [00:00<?, ?it/s]

In [20]:
def batch_subspace_swap_multilayer(batch, layers, model:nnsight.LanguageModel, subspace_proj): #, batch_size=16
    """
    Batched subspace-swap intervention applied at several layers at once using nnsight.

    At each layer in `layers`, the hidden state of every base example at its predictive
    token is overwritten with `subspace_proj(base_state, source_state, batch_size)`.
    Both states come from a *clean* (un-intervened) forward pass, not from the
    partially-intervened stream of the final pass.

    Args:
        batch: mapping with 'base_input_ids', 'source_input_ids',
            'base_predictive_token_idxs' and 'source_predictive_token_idxs';
            len(batch['base_input_ids']) is taken as the batch size.
        layers: indices into model.model.layers at which to intervene.
        model: nnsight-wrapped language model.
        subspace_proj: callable invoked as subspace_proj(base, source, batch_size)
            that returns the mixed hidden states.

    Returns:
        The logits of the intervened forward pass over the base prompts.
    """
    batch_size = len(batch['base_input_ids'])
    all_inds = torch.arange(batch_size)

    # Truncate every field to the base batch size.
    base_prompt, source_prompt = batch['base_input_ids'][:batch_size], batch['source_input_ids'][:batch_size]
    base_intervention_token_idx = batch['base_predictive_token_idxs'][:batch_size]
    source_intervention_token_idx = batch['source_predictive_token_idxs'][:batch_size]

    # Pass 1 (clean): a single trace saves the output hidden state of every edit layer for
    # both prompts. Tracing once per layer would repeat this identical forward pass
    # len(layers) times.
    with model.trace(validate=False) as tracer:
        with tracer.invoke(base_prompt, scan=False):
            bases = [model.model.layers[layer].output[0].save() for layer in layers]

        with tracer.invoke(source_prompt, scan=False):
            sources = [model.model.layers[layer].output[0].save() for layer in layers]

    # Pass 2 (intervened): run the base prompts, replacing each edit layer's output at the
    # predictive token with the subspace-mixed clean base/source states.
    with model.trace(validate=False) as tracer:
        with tracer.invoke(base_prompt, scan=False):
            for layer, base, source in zip(layers, bases, sources):
                B = base[all_inds, base_intervention_token_idx, :]
                S = source[all_inds, source_intervention_token_idx, :]

                mixed_out = subspace_proj(B, S, batch_size)
                model.model.layers[layer].output[0][all_inds, base_intervention_token_idx, :] = mixed_out

        save_out = model.output.save()

    return save_out.value.logits

In [21]:
log_dicts = []

# Smoke test: run the multi-layer subspace swap on the first training batch only.
epoch_iterator = tqdm(train_dataloader, desc=f"Epoch: {0}", position=0, leave=True)

for step, inputs in enumerate(epoch_iterator):
    # Move tensor-valued fields to the device in place; other fields are left as they are.
    tensor_keys = [key for key, value in inputs.items() if isinstance(value, torch.Tensor)]
    for key in tensor_keys:
        inputs[key] = inputs[key].to(device)

    counterfactual_outputs = batch_subspace_swap_multilayer(inputs, edit_layers, model, subspace_proj)
    break  # deliberately stop after a single batch

Epoch: 0:   0%|          | 0/1024 [00:20<?, ?it/s]
