In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from datasets import Dataset, concatenate_datasets

import sys
sys.path.append("../..")

import re
import torch
import seaborn as sns
from tqdm import tqdm, trange

from transformers import get_linear_schedule_with_warmup
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader

from pyvene.models.configuration_intervenable_model import RepresentationConfig, IntervenableConfig
from pyvene.models.intervenable_base import IntervenableModel
from pyvene.models.interventions import BoundlessRotatedSpaceIntervention
from pyvene.models.basic_utils import set_seed, count_parameters


from utils.prompt_utils import *
from utils.intervention_utils import *
from utils.model_utils import *
from utils.eval_utils import *
from utils.extract_utils import *
from utils.das_utils import *

import argparse

In [3]:
model, tokenizer, model_config = load_gpt_model_and_tokenizer("/data/public_models/mistral/mistral-7b-instruct-v0.1", device="cuda")

tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
tokenizer.pad_token_id = tokenizer.eos_token_id

def intervention_collate_fn(batch):
    base_input_ids, base_labels, source_input_ids, source_labels, source_predictive_token_idxs, predictive_token_idxs = tuple(
        [data_pair[key] for data_pair in batch] for key in 
        ('base_input_ids', 'base_labels', 'source_input_ids', 'source_labels', 'source_predictive_token_idxs', 'predictive_token_idxs')
    )
    
    base_input_ids = torch.nn.utils.rnn.pad_sequence(
        base_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id
    )
    
    source_input_ids = torch.nn.utils.rnn.pad_sequence(
        source_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id
    )
    
    base_labels = torch.nn.utils.rnn.pad_sequence(base_labels, batch_first=True, padding_value=IGNORE_INDEX)
    source_labels = torch.nn.utils.rnn.pad_sequence(source_labels, batch_first=True, padding_value=IGNORE_INDEX)
    
    source_predictive_token_idxs = torch.LongTensor(source_predictive_token_idxs)
    predictive_token_idxs = torch.LongTensor(predictive_token_idxs)
    
    return dict(
        base_input_ids=base_input_ids,
        base_labels=base_labels,
        base_attention_mask=base_input_ids.ne(tokenizer.pad_token_id),
        source_input_ids=source_input_ids,
        source_labels=source_labels,
        source_attention_mask=source_input_ids.ne(tokenizer.pad_token_id),
        predictive_token_idxs=predictive_token_idxs,
        source_predictive_token_idxs=source_predictive_token_idxs
    )

Loading:  /data/public_models/mistral/mistral-7b-instruct-v0.1
/data/public_models/mistral/mistral-7b-instruct-v0.1


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  return self.fget.__get__(instance, owner)()


In [4]:
intervenable_config = simple_boundless_das_position_config(type(model), "block_output", 15)
intervenable = IntervenableModel(intervenable_config, model)
intervenable.set_device("cuda")
intervenable.disable_model_gradients()

In [5]:
HELD_IN_DATASETS = [f.replace(".json", "") for f in  os.listdir("../dataset_files/abstractive") if f not in 
                    ["antonym.json", "capitalize.json", "present-past.json", 
                     "english-french.json", "singular-plural.json", "country-capital.json", 
                     "ag_news.json", "commonsense_qa.json", "sentiment.json"]]

print(HELD_IN_DATASETS)

['park-country', 'person-sport', 'lowercase_first_letter', 'landmark-country', 'national_parks', 'person-instrument', 'product-company', 'prev_item', 'synonym', 'english-spanish', 'country-currency', 'capitalize_first_letter', 'next_item', 'person-occupation', 'english-german', 'lowercase_last_letter']


In [22]:
set_seed(42)
prefixes = {"input": "Word:", "output": "Letter:", "instructions": "What is the first lowercase letter in the input?"}
separators = {"input":"\n", "output":"\n\n", "instructions":"\n\n"}
dataset_name = "lowercase_last_letter"

dataset = load_dataset(dataset_name, root_data_dir="../dataset_files", test_size=0.3, seed=42)

eval_no_intervention_dataloader = process_dataloader(dataset, model_config, tokenizer, 16, 0, "valid", prefixes, separators, intervention_collate_fn, ablation_method="zero_shot")
eval_dict = evaluate(intervenable, eval_no_intervention_dataloader, device=model.device, intervene=False, corrupt=False, generate_output=True)
print(f"[WARNING: THIS NEEDS TO BE GOOD!] prealign task accuracy: {eval_dict['accuracy']}")

torch.cuda.empty_cache()

100%|██████████| 5/5 [00:04<00:00,  1.19it/s]






In [21]:
print([tokenizer.decode(eval_dict["outputs"][i]) for i in range(10)])
print([tokenizer.decode(eval_dict["labels"][i]) for i in range(10)])

['e', 'i', 'o', 'f', 'o', 't', 'n', 'u', 'o', 't']
['e', 'e', 'h', 'l', 'k', 't', 't', 'l', 'g', 't']


In [11]:
def store_template_information():
    root_dir = f"../template_files/{dataset_name}"
    if not os.path.exists(root_dir):
        os.makedirs(root_dir)
    
    all_templates = os.listdir(root_dir)
    
    all_templates = [int(template) for template in all_templates]
    next_template = max(all_templates) + 1 if len(all_templates) > 0 else 1
    
    os.makedirs(f"{root_dir}/{next_template}")
    json.dump(prefixes, open(f"{root_dir}/{next_template}/prefixes.json", "w"))
    json.dump(separators, open(f"{root_dir}/{next_template}/separators.json", "w"))

In [40]:
store_template_information()