In [2]:
%load_ext autoreload
%autoreload 2

In [1]:
import sys
sys.path.append("../..")

import numpy as np
import torch
import seaborn as sns
from tqdm import tqdm, trange
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup
from torch.nn import CrossEntropyLoss

from models.configuration_alignable_model import AlignableRepresentationConfig, AlignableConfig
from models.alignable_base import AlignableModel
from models.interventions import BoundlessRotatedSpaceIntervention
from models.llama.modelings_alignable_llama import create_llama
from models.basic_utils import set_seed, count_parameters

from utils.prompt_utils import *
from utils.intervention_utils import *
from utils.model_utils import *
from utils.eval_utils import *
from utils.extract_utils import *

In [3]:
# Experiment configuration: the knobs a reader is most likely to tune.
n_icl_examples = 10   # in-context demonstrations per prompt
N_TRIALS = 100        # number of prompts sampled for each generated dataset

# Seed the global RNGs once, up front, so the prompts sampled with
# np.random.choice and any torch-side random initialisation are reproducible
# after Restart Kernel -> Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prompt template pieces, keyed by the role of each text segment.
prefixes = {"input":"Q:", "output":"A:", "instructions":""}
separators = {"input":"\n", "output":"\n\n", "instructions":""}

In [4]:
from transformers import LlamaTokenizer, LlamaForCausalLM

# Single source of truth for the local checkpoint directory (edit here only).
LLAMA_WEIGHTS_DIR = "/data/public_models/llama/llama_hf_weights/llama-7b/"

tokenizer = LlamaTokenizer.from_pretrained(LLAMA_WEIGHTS_DIR)
llama = LlamaForCausalLM.from_pretrained(LLAMA_WEIGHTS_DIR)

_ = llama.to("cuda") # single gpu
# eval() only switches modules to inference behaviour (e.g. dropout off); it
# does NOT disable gradient tracking. Gradients are avoided separately, via
# torch.no_grad() around inference and alignable.disable_model_gradients().
_ = llama.eval()

You are using the legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565


Loading checkpoint shards:   0%|          | 0/33 [00:00<?, ?it/s]

In [5]:
# Pad on the right and reuse the end-of-sequence token as the padding token
# (both the token string and its id are set explicitly).
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

In [6]:
# Load the "antonym" word-pair task. Later cells read the 'train', 'valid' and
# 'test' entries of the result.
# NOTE(review): seed=42 presumably fixes how the splits are drawn -- confirm
# in load_dataset.
dataset = load_dataset(
    "antonym",
    root_data_dir="../dataset_files",
    test_size=0.3,
    seed=42,
)

In [8]:
from torch.utils.data import DataLoader

if prefixes is not None and separators is not None:
    dummy_labels = get_dummy_token_labels(n_icl_examples, tokenizer=tokenizer, prefixes=prefixes, separators=separators)
else:
    dummy_labels = get_dummy_token_labels(n_icl_examples, tokenizer=tokenizer)

# Every item of the 'valid' split is eligible as a query (no filtering applied).
filter_set = np.arange(len(dataset['valid']))

# Sanity-check set: N_TRIALS clean prompts (demonstration labels NOT shuffled).
# The examples are collected under their own name so that `torch_dataset` only
# ever refers to the finished Dataset and re-running the cell is idempotent.
sanity_examples = []

for _trial in range(N_TRIALS):
    # Demonstrations are drawn from 'train'; the single query comes from 'valid'.
    word_pairs = dataset['train'][np.random.choice(len(dataset['train']), n_icl_examples, replace=False)]
    word_pairs_test = dataset['valid'][np.random.choice(filter_set, 1, replace=False)]

    prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=word_pairs_test, prepend_bos_token=False,
                                            shuffle_labels=False, prefixes=prefixes, separators=separators)

    query = prompt_data['query_target']['input']
    target = prompt_data['query_target']['output']
    _, prompt_string = get_token_meta_labels(prompt_data, tokenizer, query)

    sanity_examples.append(preprocess([prompt_string], [target], tokenizer))

torch_dataset = Dataset.from_list(sanity_examples)
torch_dataset.set_format(type='torch', columns=['input_ids', 'labels'])

In [9]:
# Baseline: accuracy of the un-intervened model on the clean prompts.
total_count, correct_count = 0, 0
with torch.no_grad():
    for example in tqdm(torch_dataset):
        # Put every tensor field on the same device as the model.
        batch = {
            name: value.to(llama.device) if isinstance(value, torch.Tensor) else value
            for name, value in example.items()
        }

        # plain forward pass (no intervention)
        outputs = llama(input_ids=batch['input_ids'], labels=batch['labels'])

        # The logits at position -2 are scored against the label at position -1.
        gold_tokens = batch['labels'][:, -1]
        predicted_tokens = outputs.logits[:, -2].argmax(dim=-1)
        hits = (gold_tokens == predicted_tokens)

        total_count += len(hits)
        correct_count += hits.sum().tolist()
current_acc = round(correct_count/total_count, 2)
print(f"[WARNING: THIS NEEDS TO BE GOOD!] prealign task accuracy: {current_acc}")

100%|██████████| 100/100 [00:12<00:00,  8.28it/s]






In [13]:
# Generate the train set and the eval set for Boundless DAS.

DAS_COLUMNS = ['input_ids', 'labels', 'source_input_ids', 'source_predictive_token_idxs', 'predictive_token_idxs']


def _sample_das_example(query_split):
    """Sample one base/source prompt pair for Boundless DAS.

    The base prompt is built from demonstrations with shuffled labels
    (`shuffle_labels=True`); the source prompt is built from a second,
    independently drawn set of intact demonstrations. Both prompts end with
    the same query, drawn from `query_split`.

    Args:
        query_split: dataset split the single query/target pair is drawn from.

    Returns:
        The `preprocess` output for the base prompt, extended with the source
        prompt's `input_ids` and the predictive-token index of each prompt.
    """
    icl_pool = dataset['train']
    noninformative_word_pairs = icl_pool[np.random.choice(len(icl_pool), n_icl_examples, replace=False)]
    word_pairs = icl_pool[np.random.choice(len(icl_pool), n_icl_examples, replace=False)]

    # Sample with the split's OWN length: reusing an index range sized from a
    # different split can go out of range or leave part of the split unreachable.
    word_pairs_test = query_split[np.random.choice(len(query_split), 1, replace=False)]

    prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=word_pairs_test, prepend_bos_token=False, shuffle_labels=False, prefixes=prefixes, separators=separators)
    noninformative_prompt_data = word_pairs_to_prompt_data(noninformative_word_pairs, query_target_pair=word_pairs_test, prepend_bos_token=False, shuffle_labels=True, prefixes=prefixes, separators=separators)

    query = prompt_data['query_target']['input']
    target = prompt_data['query_target']['output']

    source_token_labels, prompt_string = get_token_meta_labels(prompt_data, tokenizer, query)
    token_labels, noninformative_prompt_string = get_token_meta_labels(noninformative_prompt_data, tokenizer, query)

    # Base = non-informative prompt; source = informative prompt.
    data_pair = preprocess([noninformative_prompt_string], [target], tokenizer)
    data_pair["source_input_ids"] = preprocess([prompt_string], [target], tokenizer)["input_ids"]

    # The last meta-label entry must be the query's predictive token; its
    # first field is the position that gets swapped between the two runs.
    assert source_token_labels[-1][2] == "query_predictive_token"
    data_pair["source_predictive_token_idxs"] = source_token_labels[-1][0]

    assert token_labels[-1][2] == "query_predictive_token"
    data_pair["predictive_token_idxs"] = token_labels[-1][0]

    return data_pair


# Training queries come from 'valid', held-out evaluation queries from 'test'.
das_train_set = Dataset.from_list([_sample_das_example(dataset['valid']) for _ in range(N_TRIALS)])
das_eval_set = Dataset.from_list([_sample_das_example(dataset['test']) for _ in range(N_TRIALS)])

for das_split in (das_train_set, das_eval_set):
    das_split.set_format(type='torch', columns=DAS_COLUMNS)

In [14]:
def simple_boundless_das_position_config(model_type, intervention_type, layer):
    """Build an AlignableConfig with a single Boundless DAS intervention site.

    Args:
        model_type: model class, forwarded as `alignable_model_type`.
        intervention_type: which representation to intervene on (this
            notebook passes "block_output").
        layer: layer of the intervened representation.

    Returns:
        An AlignableConfig with one "pos"-unit representation that is
        intervened with BoundlessRotatedSpaceIntervention.
    """
    position_site = AlignableRepresentationConfig(
        layer,              # layer
        intervention_type,  # intervention type
        "pos",              # intervention unit
        1,                  # max number of unit
    )
    return AlignableConfig(
        alignable_model_type=model_type,
        alignable_representations=[position_site],
        alignable_interventions_type=BoundlessRotatedSpaceIntervention,
    )

# Wrap LLaMA with one intervention on the block output of layer 15, then turn
# off gradients for the wrapped model so only the intervention is trained.
alignable_config = simple_boundless_das_position_config(
    model_type=type(llama),
    intervention_type="block_output",
    layer=15,
)
alignable = AlignableModel(alignable_config, llama)
alignable.set_device("cuda")
alignable.disable_model_gradients()

In [16]:
# Scheduler horizon: three passes over the training set, with the first 10%
# of the steps used for linear warm-up.
# NOTE(review): the literal 3 duplicates `epochs`, which is only defined
# further down in this cell -- keep the two values in sync.
t_total = int(len(das_train_set) * 3)
warm_up_steps = 0.1 * t_total

# Two parameter groups per intervention: the rotation layer (optimizer-default
# learning rate) and the intervention boundaries (their own learning rate).
optimizer_params = []
for _site_name, site_entry in alignable.interventions.items():
    site_intervention = site_entry[0]
    optimizer_params.append({'params': site_intervention.rotate_layer.parameters()})
    optimizer_params.append({'params': site_intervention.intervention_boundaries, 'lr': 1e-2})

optimizer = torch.optim.Adam(optimizer_params, lr=1e-3)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warm_up_steps,
    num_training_steps=t_total,
)

# You can define your custom compute_metrics function.
def compute_metrics(eval_preds, eval_labels):
    """Accuracy of the predicted final token over a list of batches.

    Args:
        eval_preds: list of logit tensors indexed as [batch, position, vocab].
        eval_labels: list of label tensors indexed as [batch, position],
            paired one-to-one with `eval_preds`.

    Returns:
        {"accuracy": value}, where value is the fraction (rounded to two
        decimals) of rows whose argmax logit at position -2 equals the label
        at position -1.
    """
    n_seen, n_hit = 0, 0
    for batch_logits, batch_labels in zip(eval_preds, eval_labels):
        gold_tokens = batch_labels[:, -1]
        # Position -2 carries the prediction for the last label token.
        predicted_tokens = batch_logits[:, -2].argmax(dim=-1)
        hits = gold_tokens == predicted_tokens
        n_seen += len(hits)
        n_hit += hits.sum().tolist()
    return {"accuracy": round(n_hit / n_seen, 2)}

# Training-loop bookkeeping.
epochs = 3
gradient_accumulation_steps = 1
total_step = 0
target_total_step = epochs * len(das_train_set)

# Linear temperature anneal from 50.0 down to 0.1, one entry per training step.
temperature_start, temperature_end = 50.0, 0.1
temperature_schedule = torch.linspace(
    temperature_start,
    temperature_end,
    target_total_step,
).to(torch.bfloat16).to("cuda")
# Start training at the first (highest) temperature.
alignable.set_temperature(temperature_schedule[total_step])

def calculate_loss(logits, labels):
    """Cross-entropy on the given logits plus a penalty on boundary size.

    Args:
        logits: logits whose last dimension is the vocabulary size
            (`alignable.model_config.vocab_size`).
        labels: target token ids, one per logit position.

    Returns:
        Scalar loss tensor: token-level cross-entropy plus the summed
        `intervention_boundaries` of every intervention in
        `alignable.interventions`.
    """
    # NOTE(review): logits and labels are paired position-by-position here
    # (no one-token shift), while compute_metrics scores logits[:, -2] against
    # labels[:, -1] -- confirm this matches how `preprocess` lays out labels.
    flat_logits = logits.contiguous().view(-1, alignable.model_config.vocab_size)
    # Keep the labels on the logits' device (model parallelism).
    flat_labels = labels.contiguous().view(-1).to(flat_logits.device)
    loss = CrossEntropyLoss()(flat_logits, flat_labels)

    # Add the boundary penalty of EVERY intervention. Accumulating inside the
    # loop keeps earlier interventions from being silently dropped and avoids
    # an undefined penalty when there are no interventions at all.
    for _site_name, site_entry in alignable.interventions.items():
        loss = loss + 1. * site_entry[0].intervention_boundaries.sum()

    return loss

In [17]:
alignable.model.train() # train mode enables dropout; the wrapped model's grads stay disabled
print("llama trainable parameters: ", count_parameters(alignable.model))
print("intervention trainable parameters: ", alignable.count_parameters())
train_iterator = trange(
    0, int(epochs), desc="Epoch"
)
for epoch in train_iterator:
    epoch_iterator = tqdm(
        das_train_set, desc=f"Epoch: {epoch}", position=0, leave=True
    )
    for step, inputs in enumerate(epoch_iterator):
        # Move every tensor field of the example onto the GPU.
        for k, v in inputs.items():
            if v is not None and isinstance(v, torch.Tensor):
                inputs[k] = v.to("cuda")
        b_s = inputs["input_ids"].shape[0]

        # Run the base prompt while swapping in the source prompt's
        # representation at the predictive-token position of each prompt.
        _, counterfactual_outputs = alignable(
            {"input_ids": inputs["input_ids"]},
            [{"input_ids": inputs["source_input_ids"]}],
            {"sources->base": ([[[inputs["source_predictive_token_idxs"]]]*b_s], [[[inputs["predictive_token_idxs"]]]*b_s])}
        )
        eval_metrics = compute_metrics(
            [counterfactual_outputs.logits], [inputs['labels']]
        )

        # loss and backprop
        loss = calculate_loss(
            counterfactual_outputs.logits, inputs["labels"]
        )
        loss_str = round(loss.item(), 2)
        epoch_iterator.set_postfix({'loss': loss_str, 'acc': eval_metrics["accuracy"]})

        if gradient_accumulation_steps > 1:
            loss = loss / gradient_accumulation_steps
        loss.backward()
        # Update once every `gradient_accumulation_steps` micro-batches.
        # Counting from (total_step + 1) makes the first window the same size
        # as all later ones; with a value of 1 this updates on every step.
        if (total_step + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            scheduler.step()
            alignable.set_zero_grad()
            alignable.set_temperature(temperature_schedule[total_step])
        total_step += 1

llama trainable parameters:  0
intervention trainable parameters:  16777218


Epoch: 0: 100%|██████████| 100/100 [04:50<00:00,  2.91s/it, loss=5.37, acc=0]
Epoch: 1: 100%|██████████| 100/100 [05:03<00:00,  3.03s/it, loss=3.6, acc=0]
Epoch: 2: 100%|██████████| 100/100 [05:03<00:00,  3.04s/it, loss=1.88, acc=0]
Epoch: 100%|██████████| 3/3 [14:57<00:00, 299.25s/it]


In [18]:
# evaluation on the test set
# The training cell leaves the wrapped model in train mode; switch back to
# eval mode so held-out numbers are computed with inference behaviour.
alignable.model.eval()

eval_labels = []
eval_preds = []
with torch.no_grad():
    epoch_iterator = tqdm(das_eval_set, desc="Test")
    for step, inputs in enumerate(epoch_iterator):
        # Move every tensor field of the example onto the GPU.
        for k, v in inputs.items():
            if v is not None and isinstance(v, torch.Tensor):
                inputs[k] = v.to("cuda")
        b_s = inputs["input_ids"].shape[0]

        # Same interchange as in training: swap the representation at the
        # predictive-token position (taken from the data) from source into base.
        _, counterfactual_outputs = alignable(
            {"input_ids": inputs["input_ids"]},
            [{"input_ids": inputs["source_input_ids"]}],
            {"sources->base": ([[[inputs["source_predictive_token_idxs"]]]*b_s], [[[inputs["predictive_token_idxs"]]]*b_s])}
        )
        eval_labels += [inputs['labels']]
        eval_preds += [counterfactual_outputs.logits]
eval_metrics = compute_metrics(eval_preds, eval_labels)
print(eval_metrics)

Test: 100%|██████████| 100/100 [01:12<00:00,  1.38it/s]

{'accuracy': 0.58}



