In [2]:
%load_ext autoreload
%autoreload 2

In [1]:
import sys
sys.path.append("../..")

import torch
import seaborn as sns
from tqdm import tqdm, trange
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup
from torch.nn import CrossEntropyLoss

from models.configuration_alignable_model import AlignableRepresentationConfig, AlignableConfig
from models.alignable_base import AlignableModel
from models.interventions import BoundlessRotatedSpaceIntervention
from models.llama.modelings_alignable_llama import create_llama
from models.basic_utils import set_seed, count_parameters

from utils.prompt_utils import *
from utils.intervention_utils import *
from utils.model_utils import *
from utils.eval_utils import *
from utils.extract_utils import *

In [2]:
n_icl_examples = 10
N_TRIALS = 512

prefixes = {"input":"Q:", "output":"A:", "instructions":""}
separators = {"input":"\n", "output":"\n\n", "instructions":""}

In [3]:
from transformers import LlamaTokenizer, LlamaForCausalLM

tokenizer = LlamaTokenizer.from_pretrained("/data/public_models/llama/llama_hf_weights/llama-7b/")
llama = LlamaForCausalLM.from_pretrained("/data/public_models/llama/llama_hf_weights/llama-7b/")

_ = llama.to("cuda") # single gpu
_ = llama.eval()         # always no grad on the model

You are using the legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565


Loading checkpoint shards:   0%|          | 0/33 [00:00<?, ?it/s]

In [4]:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
tokenizer.pad_token_id = tokenizer.eos_token_id

In [5]:
dataset = load_dataset("antonym", root_data_dir="../dataset_files", test_size=0.3, seed=42)

In [6]:
def collate_fn(batch):
    
    if len(batch[0].keys()) == 2:
        
        input_ids, labels = tuple([data_pair[key] for data_pair in batch] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(tokenizer.pad_token_id),
        )
        
    elif len(batch[0].keys()) == 5:
        input_ids, labels, source_input_ids, source_predictive_token_idxs, predictive_token_idxs = tuple([data_pair[key] for data_pair in batch] for key in ('input_ids', 'labels', 'source_input_ids', 'source_predictive_token_idxs', 'predictive_token_idxs'))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=tokenizer.pad_token_id
        )
        
        source_input_ids = torch.nn.utils.rnn.pad_sequence(
            source_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id
        )
        
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        source_predictive_token_idxs = torch.LongTensor(source_predictive_token_idxs)
        predictive_token_idxs = torch.LongTensor(predictive_token_idxs)
        
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(tokenizer.pad_token_id),
            source_input_ids=source_input_ids,
            source_attention_mask=source_input_ids.ne(tokenizer.pad_token_id),
            predictive_token_idxs=predictive_token_idxs,
            source_predictive_token_idxs=source_predictive_token_idxs
        )

In [7]:
from torch.utils.data import DataLoader

if prefixes is not None and separators is not None:
    dummy_labels = get_dummy_token_labels(n_icl_examples, tokenizer=tokenizer, prefixes=prefixes, separators=separators)
else:
    dummy_labels = get_dummy_token_labels(n_icl_examples, tokenizer=tokenizer)
    
filter_set = filter_set = np.arange(len(dataset['valid']))


torch_dataset = []

for n in range(N_TRIALS):
    
    word_pairs = dataset['train'][np.random.choice(len(dataset['train']),n_icl_examples, replace=False)]
    word_pairs_test = dataset['valid'][np.random.choice(filter_set, 1, replace=False)]

    prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=word_pairs_test, prepend_bos_token=False, 
                                                        shuffle_labels=False, prefixes=prefixes, separators=separators)

    query = prompt_data['query_target']['input']
    target = prompt_data['query_target']['output']
    _, prompt_string = get_token_meta_labels(prompt_data, tokenizer, query)
    
    data_pair = preprocess([prompt_string], [target], tokenizer) 
    torch_dataset.append(data_pair)
    
torch_dataset = Dataset.from_list(torch_dataset)
torch_dataset.set_format(type='torch', columns=['input_ids', 'labels'])

dataloader = DataLoader(torch_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

In [8]:
total_count = 0
correct_count = 0
with torch.no_grad():
    for step, inputs in enumerate(tqdm(dataloader)):
        for k, v in inputs.items():
            if v is not None and isinstance(v, torch.Tensor):
                inputs[k] = v.to(llama.device)
                    
        # aligning forward!
        outputs = llama(
            input_ids=inputs['input_ids'],
            labels=inputs['labels'],
            attention_mask=inputs['attention_mask']
        )
        
        for i in range(inputs['input_ids'].shape[0]):
            label_idxs = inputs['labels'][i].ne(IGNORE_INDEX).nonzero().squeeze(-1)
            # label_idxs = label_idxs[1: ]
            left_shifted_idxs = label_idxs - 1
            
            actual_test_labels = inputs['labels'][i][label_idxs].tolist()
            pred_test_labels = [outputs.logits[i][idx].argmax(dim=-1) for idx in left_shifted_idxs]
            
            correct = (actual_test_labels==pred_test_labels)

            total_count += 1
            if correct:
                correct_count += 1
                
current_acc = round(correct_count/total_count, 2)
print(f"[WARNING: THIS NEEDS TO BE GOOD!] prealign task accuracy: {current_acc}")

100%|██████████| 16/16 [00:47<00:00,  2.94s/it]






In [9]:
# Generate trainset and valset for Boundless DAS

das_train_set = []
das_eval_set = []
zs_das_eval_set = []

for n in range(N_TRIALS):
    
    noninformative_word_pairs = dataset['train'][np.random.choice(len(dataset['train']),n_icl_examples, replace=False)]
    word_pairs = dataset['train'][np.random.choice(len(dataset['train']),n_icl_examples, replace=False)]

    word_pairs_test = dataset['valid'][np.random.choice(filter_set, 1, replace=False)]

    prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=word_pairs_test, prepend_bos_token=False, shuffle_labels=False, prefixes=prefixes, separators=separators)
    noninformative_prompt_data = word_pairs_to_prompt_data(noninformative_word_pairs, query_target_pair=word_pairs_test, prepend_bos_token=False, shuffle_labels=True, prefixes=prefixes, separators=separators)

    query = prompt_data['query_target']['input']
    target = prompt_data['query_target']['output']

    source_token_labels, prompt_string = get_token_meta_labels(prompt_data, tokenizer, query)
    token_labels, noninformative_prompt_string = get_token_meta_labels(noninformative_prompt_data, tokenizer, query)

    data_pair = preprocess([noninformative_prompt_string], [target], tokenizer)
    data_pair["source_input_ids"] = preprocess([prompt_string], [target], tokenizer)["input_ids"]
    
    assert source_token_labels[-1][2] == "query_predictive_token"
    source_predictive_token_idxs = source_token_labels[-1][0]
    data_pair["source_predictive_token_idxs"] = source_predictive_token_idxs
    
    assert token_labels[-1][2] == "query_predictive_token"
    predictive_token_idxs = token_labels[-1][0]
    data_pair["predictive_token_idxs"] = predictive_token_idxs
    
    das_train_set.append(data_pair)

for n in range(len(dataset["test"])):
    
    noninformative_word_pairs = dataset['train'][np.random.choice(len(dataset['train']),n_icl_examples, replace=False)]
    word_pairs = dataset['train'][np.random.choice(len(dataset['train']),n_icl_examples, replace=False)]

    word_pairs_test = dataset['test'][n]

    prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=word_pairs_test, prepend_bos_token=False, shuffle_labels=False, prefixes=prefixes, separators=separators)
    noninformative_prompt_data = word_pairs_to_prompt_data(noninformative_word_pairs, query_target_pair=word_pairs_test, prepend_bos_token=False, shuffle_labels=True, prefixes=prefixes, separators=separators)

    query = prompt_data['query_target']['input']
    target = prompt_data['query_target']['output']

    source_token_labels, prompt_string = get_token_meta_labels(prompt_data, tokenizer, query)
    token_labels, noninformative_prompt_string = get_token_meta_labels(noninformative_prompt_data, tokenizer, query)

    data_pair = preprocess([noninformative_prompt_string], [target], tokenizer)
    data_pair["source_input_ids"] = preprocess([prompt_string], [target], tokenizer)["input_ids"]
    
    assert source_token_labels[-1][2] == "query_predictive_token"
    source_predictive_token_idxs = source_token_labels[-1][0]
    data_pair["source_predictive_token_idxs"] = source_predictive_token_idxs
    
    assert token_labels[-1][2] == "query_predictive_token"
    predictive_token_idxs = token_labels[-1][0]
    data_pair["predictive_token_idxs"] = predictive_token_idxs
    
    das_eval_set.append(data_pair)
    

for n in range(len(dataset["test"])):
    
    zs_word_pairs = word_pairs = {'input':[], 'output':[]}
    word_pairs = dataset['train'][np.random.choice(len(dataset['train']),n_icl_examples, replace=False)]

    word_pairs_test = dataset['test'][n]

    prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=word_pairs_test, prepend_bos_token=False, shuffle_labels=False, prefixes=prefixes, separators=separators)
    zs_prompt_data = word_pairs_to_prompt_data(zs_word_pairs, query_target_pair=word_pairs_test, prepend_bos_token=False, shuffle_labels=True, prefixes=prefixes, separators=separators)

    query = prompt_data['query_target']['input']
    target = prompt_data['query_target']['output']

    source_token_labels, prompt_string = get_token_meta_labels(prompt_data, tokenizer, query)
    token_labels, zs_prompt_string = get_token_meta_labels(zs_prompt_data, tokenizer, query)

    data_pair = preprocess([zs_prompt_string], [target], tokenizer)
    data_pair["source_input_ids"] = preprocess([prompt_string], [target], tokenizer)["input_ids"]
    
    assert source_token_labels[-1][2] == "query_predictive_token"
    source_predictive_token_idxs = source_token_labels[-1][0]
    data_pair["source_predictive_token_idxs"] = source_predictive_token_idxs
    
    assert token_labels[-1][2] == "query_predictive_token"
    predictive_token_idxs = token_labels[-1][0]
    data_pair["predictive_token_idxs"] = predictive_token_idxs
    
    zs_das_eval_set.append(data_pair)
    

das_train_set = Dataset.from_list(das_train_set)
das_eval_set = Dataset.from_list(das_eval_set)
zs_das_eval_set = Dataset.from_list(zs_das_eval_set)
das_train_set.set_format(type='torch', columns=['input_ids', 'labels', 'source_input_ids', 'source_predictive_token_idxs', 'predictive_token_idxs'])
train_dataloader = DataLoader(das_train_set, batch_size=32, shuffle=False, collate_fn=collate_fn)
das_eval_set.set_format(type='torch', columns=['input_ids', 'labels', 'source_input_ids', 'source_predictive_token_idxs', 'predictive_token_idxs'])
eval_dataloader = DataLoader(das_eval_set, batch_size=32, shuffle=False, collate_fn=collate_fn)
zs_das_eval_set.set_format(type='torch', columns=['input_ids', 'labels', 'source_input_ids', 'source_predictive_token_idxs', 'predictive_token_idxs'])
zs_eval_dataloader = DataLoader(zs_das_eval_set, batch_size=32, shuffle=False, collate_fn=collate_fn)

In [30]:
idx = 1
das_train_set[idx]["input_ids"][]

tensor(2868)

In [31]:
das_train_set[idx]["labels"][-1]

tensor(2868)

In [27]:
tokenizer.decode([29901])

':'

In [10]:
def simple_boundless_das_position_config(model_type, intervention_type, layer):
    alignable_config = AlignableConfig(
        alignable_model_type=model_type,
        alignable_representations=[
            AlignableRepresentationConfig(
                layer,             # layer
                intervention_type, # intervention type
                "pos",             # intervention unit
                1                  # max number of unit
            ),
        ],
        alignable_interventions_type=BoundlessRotatedSpaceIntervention,
    )
    return alignable_config

alignable_config = simple_boundless_das_position_config(type(llama), "block_output", 15)
alignable = AlignableModel(alignable_config, llama)
alignable.set_device("cuda")
alignable.disable_model_gradients()

In [12]:
epochs = 25
t_total = int(len(das_train_set) * epochs)
warm_up_steps = 0.1 * t_total
optimizer_params = []

for k, v in alignable.interventions.items():
    optimizer_params += [{'params': v[0].rotate_layer.parameters()}]
    optimizer_params += [{'params': v[0].intervention_boundaries, 'lr': 1e-2}]
    
optimizer = torch.optim.Adam(
    optimizer_params,
    lr=1e-3
)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=warm_up_steps,
    num_training_steps=t_total
)

# You can define your custom compute_metrics function.
def compute_metrics(eval_preds, eval_labels):
    total_count = 0
    correct_count = 0
    for eval_pred, eval_label in zip(eval_preds, eval_labels):
        
        for i in range(eval_label.shape[0]):
            label_idxs = eval_label[i].ne(IGNORE_INDEX).nonzero().squeeze(-1)
            # label_idxs = label_idxs[1: ]
            left_shifted_idxs = label_idxs - 1
            
            actual_test_labels = eval_label[i][label_idxs].tolist()
            pred_test_labels = [eval_pred[i][idx].argmax(dim=-1) for idx in left_shifted_idxs]
            
            correct = (actual_test_labels==pred_test_labels)

            total_count += 1
            if correct:
                correct_count += 1
    accuracy = round(correct_count/total_count, 2)
    return {"accuracy" : accuracy}

gradient_accumulation_steps = 4
total_step = 0
target_total_step = len(das_train_set) * epochs
temperature_start = 50.0
temperature_end = 0.1
temperature_schedule = torch.linspace(
    temperature_start, temperature_end, target_total_step
).to(torch.bfloat16).to("cuda")
alignable.set_temperature(temperature_schedule[total_step])

def calculate_loss(logits, labels):
    shift_logits = logits[..., :, :].contiguous()
    shift_labels = labels[..., :].contiguous()
    # Flatten the tokens
    loss_fct = CrossEntropyLoss()
    shift_logits = shift_logits.view(-1, alignable.model_config.vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)
    loss = loss_fct(shift_logits, shift_labels)
    
    for k, v in alignable.interventions.items():
        boundary_loss = 1. * v[0].intervention_boundaries.sum()
    loss += boundary_loss
    
    return loss

In [13]:
alignable.model.train() # train enables drop-off but no grads
print("llama trainable parameters: ", count_parameters(alignable.model))
print("intervention trainable parameters: ", alignable.count_parameters())
train_iterator = trange(
    0, int(epochs), desc="Epoch"
)
for epoch in train_iterator:
    epoch_iterator = tqdm(
        train_dataloader, desc=f"Epoch: {epoch}", position=0, leave=True
    )
    for step, inputs in enumerate(epoch_iterator):
        for k, v in inputs.items():
            if v is not None and isinstance(v, torch.Tensor):
                inputs[k] = v.to("cuda")
        b_s = inputs["input_ids"].shape[0]
        
        
        source2base = ([[[idx] for idx in inputs["source_predictive_token_idxs"].tolist()]], [[[idx] for idx in inputs["predictive_token_idxs"].tolist()]])
        
        _, counterfactual_outputs = alignable(
            {"input_ids": inputs["input_ids"], "attention_mask": inputs["attention_mask"]},
            [{"input_ids": inputs["source_input_ids"], "attention_mask": inputs["source_attention_mask"]}],
            {"sources->base": source2base}
        )
        eval_metrics = compute_metrics(
            [counterfactual_outputs.logits], [inputs['labels']]
        )
        
        # loss and backprop
        loss = calculate_loss(
            counterfactual_outputs.logits, inputs["labels"]
        )
        loss_str = round(loss.item(), 2)
        epoch_iterator.set_postfix({'loss': loss_str, 'acc': eval_metrics["accuracy"]})
        
        if gradient_accumulation_steps > 1:
            loss = loss / gradient_accumulation_steps
        loss.backward()
        if total_step % gradient_accumulation_steps == 0:
            if not (gradient_accumulation_steps > 1 and total_step == 0):
                optimizer.step()
                scheduler.step()
                alignable.set_zero_grad()
                alignable.set_temperature(temperature_schedule[total_step])
        total_step += 1

llama trainable parameters:  0
intervention trainable parameters:  16777218


Epoch: 0: 100%|██████████| 16/16 [03:07<00:00, 11.72s/it, loss=18, acc=0.5]   
Epoch: 1: 100%|██████████| 16/16 [03:08<00:00, 11.76s/it, loss=17.9, acc=0.5] 
Epoch: 2: 100%|██████████| 16/16 [03:08<00:00, 11.81s/it, loss=17.8, acc=0.5] 
Epoch: 3: 100%|██████████| 16/16 [03:08<00:00, 11.80s/it, loss=17.5, acc=0.5] 
Epoch: 4: 100%|██████████| 16/16 [03:08<00:00, 11.80s/it, loss=17.2, acc=0.47]
Epoch: 5: 100%|██████████| 16/16 [03:09<00:00, 11.84s/it, loss=16.9, acc=0.44]
Epoch: 6: 100%|██████████| 16/16 [03:09<00:00, 11.83s/it, loss=16.3, acc=0.34]
Epoch: 7: 100%|██████████| 16/16 [03:10<00:00, 11.90s/it, loss=15.4, acc=0.31]
Epoch: 8: 100%|██████████| 16/16 [03:12<00:00, 12.02s/it, loss=14.3, acc=0.31]
Epoch: 9: 100%|██████████| 16/16 [03:13<00:00, 12.08s/it, loss=13.1, acc=0.22]
Epoch: 10: 100%|██████████| 16/16 [03:13<00:00, 12.07s/it, loss=12.2, acc=0.22]
Epoch: 11: 100%|██████████| 16/16 [03:13<00:00, 12.10s/it, loss=11.4, acc=0.22]
Epoch: 12: 100%|██████████| 16/16 [03:14<00:00, 12

In [14]:
# evaluation on the test set
eval_labels = []
eval_preds = []
with torch.no_grad():
    epoch_iterator = tqdm(eval_dataloader, desc=f"Test")
    for step, inputs in enumerate(epoch_iterator):
        for k, v in inputs.items():
            if v is not None and isinstance(v, torch.Tensor):
                inputs[k] = v.to("cuda")
        b_s = inputs["input_ids"].shape[0]
        
        source2base = ([[[idx] for idx in inputs["source_predictive_token_idxs"].tolist()]], [[[idx] for idx in inputs["predictive_token_idxs"].tolist()]])
        
        _, counterfactual_outputs = alignable(
            {"input_ids": inputs["input_ids"], "attention_mask": inputs["attention_mask"]},
            [{"input_ids": inputs["source_input_ids"], "attention_mask": inputs["source_attention_mask"]}],
            {"sources->base": source2base}
        )
        
        eval_labels += [inputs['labels']]
        eval_preds += [counterfactual_outputs.logits]
eval_metrics = compute_metrics(eval_preds, eval_labels)
print("Few-shot DAS Eval IIA:")
print(eval_metrics)

Test: 100%|██████████| 16/16 [02:19<00:00,  8.70s/it]


Few-shot DAS Eval IIA:
{'accuracy': 0.3}


In [15]:
total_count = 0
correct_count = 0
with torch.no_grad():
    for step, inputs in enumerate(tqdm(eval_dataloader)):
        for k, v in inputs.items():
            if v is not None and isinstance(v, torch.Tensor):
                inputs[k] = v.to(llama.device)
                    
        # aligning forward!
        outputs = llama(
            input_ids=inputs['input_ids'],
            labels=inputs['labels'],
            attention_mask=inputs['attention_mask']
        )
        
        for i in range(inputs['input_ids'].shape[0]):
            label_idxs = inputs['labels'][i].ne(IGNORE_INDEX).nonzero().squeeze(-1)
            # label_idxs = label_idxs[1: ]
            left_shifted_idxs = label_idxs - 1
            
            actual_test_labels = inputs['labels'][i][label_idxs].tolist()
            pred_test_labels = [outputs.logits[i][idx].argmax(dim=-1) for idx in left_shifted_idxs]
            
            correct = (actual_test_labels==pred_test_labels)

            total_count += 1
            if correct:
                correct_count += 1
                
current_acc = round(correct_count/total_count, 2)
print("Few-shot Original Eval Accuracy:")
print(current_acc)

100%|██████████| 16/16 [00:45<00:00,  2.84s/it]

Few-shot Original Eval Accuracy:
0.4





In [16]:
# evaluation on the test set
eval_labels = []
eval_preds = []
with torch.no_grad():
    epoch_iterator = tqdm(zs_eval_dataloader, desc=f"Test")
    for step, inputs in enumerate(epoch_iterator):
        for k, v in inputs.items():
            if v is not None and isinstance(v, torch.Tensor):
                inputs[k] = v.to("cuda")
        b_s = inputs["input_ids"].shape[0]
        
        source2base = ([[[idx] for idx in inputs["source_predictive_token_idxs"].tolist()]], [[[idx] for idx in inputs["predictive_token_idxs"].tolist()]])
        
        _, counterfactual_outputs = alignable(
            {"input_ids": inputs["input_ids"], "attention_mask": inputs["attention_mask"]},
            [{"input_ids": inputs["source_input_ids"], "attention_mask": inputs["source_attention_mask"]}],
            {"sources->base": source2base}
        )
        
        eval_labels += [inputs['labels']]
        eval_preds += [counterfactual_outputs.logits]
eval_metrics = compute_metrics(eval_preds, eval_labels)
print("Zero-shot DAS Eval IIA:")
print(eval_metrics)

Test: 100%|██████████| 16/16 [01:00<00:00,  3.78s/it]


Zero-shot DAS Eval IIA:
{'accuracy': 0.4}


In [17]:
total_count = 0
correct_count = 0
with torch.no_grad():
    for step, inputs in enumerate(tqdm(zs_eval_dataloader)):
        for k, v in inputs.items():
            if v is not None and isinstance(v, torch.Tensor):
                inputs[k] = v.to(llama.device)
                    
        # aligning forward!
        outputs = llama(
            input_ids=inputs['input_ids'],
            labels=inputs['labels'],
            attention_mask=inputs['attention_mask']
        )
        
        for i in range(inputs['input_ids'].shape[0]):
            label_idxs = inputs['labels'][i].ne(IGNORE_INDEX).nonzero().squeeze(-1)
            # label_idxs = label_idxs[1: ]
            left_shifted_idxs = label_idxs - 1
            
            actual_test_labels = inputs['labels'][i][label_idxs].tolist()
            pred_test_labels = [outputs.logits[i][idx].argmax(dim=-1) for idx in left_shifted_idxs]
            
            correct = (actual_test_labels==pred_test_labels)

            total_count += 1
            if correct:
                correct_count += 1
                
current_acc = round(correct_count/total_count, 2)
print("Zero-shot Original Eval Accuracy:")
print(current_acc)

100%|██████████| 16/16 [00:05<00:00,  2.91it/s]

Zero-shot Original Eval Accuracy:
0.02



