In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os, re, json
import torch, numpy as np
import torch.nn as nn
import torch.nn.functional as F

import sys
sys.path.append('..')
torch.set_grad_enabled(False)

from src.utils.extract_utils import get_mean_head_activations, compute_universal_function_vector
from src.utils.intervention_utils import fv_intervention_natural_text, function_vector_intervention
from src.utils.model_utils import load_gpt_model_and_tokenizer
from src.utils.prompt_utils import load_dataset, word_pairs_to_prompt_data, create_prompt
from src.utils.eval_utils import decode_to_vocab, sentence_eval

from src.ltv_utils.data import set_seed, sample_data, forward_pass
from src.ltv_utils.LTV import LearnableTaskVector
from src.ltv_utils.extract_utils import get_attn_out, get_head_activations_on_prompt
from src.ltv_utils.intervention_utils import ltv_intervention, ltv_intervention_natural_text

# Run configuration: CUDA device index and the seed handed to the project's set_seed helper.
GPU_IDX, SEED = 0, 0

# Same device as the string form "cuda:<GPU_IDX>", spelled as (type, index).
device = torch.device("cuda", GPU_IDX)
set_seed(SEED)

## Load model & tokenizer

In [3]:
# Load the model, its tokenizer, and the config mapping that describes its architecture.
model_name = 'EleutherAI/gpt-j-6b'
model, tokenizer, model_config = load_gpt_model_and_tokenizer(model_name, device=device)

# Layer index passed to the function-vector interventions in the evaluation cells.
EDIT_LAYER = 9

# Architecture sizes read from the config; head_dim is derived by integer division.
n_layers, resid_dim, n_heads = (
    model_config[key] for key in ('n_layers', 'resid_dim', 'n_heads')
)
head_dim = resid_dim // n_heads

## Load dataset and Compute task-conditioned mean activations

In [4]:
# Task and evaluation settings shared by the cells below.
task_name = 'antonym'
act_fn = None  # only interpolated into the LTV checkpoint path further down
loss_fn = F.cross_entropy  # applied to (logits, target id) pairs in the evaluation cells
batch_size = 100  # NOTE(review): not read by any cell visible in this notebook

# Load the task dataset, then compute head activations over it. Only the first of the
# two returned values is kept; the second is discarded into `_`.
dataset = load_dataset(task_name, seed=0)
mean_activations, _ = get_mean_head_activations(dataset, model, model_config, tokenizer)

In [5]:
# Build the function vector from the mean activations, requesting 10 top heads;
# the call also returns the selected heads.
FV, top_heads = compute_universal_function_vector(
    mean_activations,
    model,
    model_config,
    n_top_heads=10,
)

## Compute function vector (FV)

## Prompt Creation - ICL, Shuffled-Label, Zero-Shot, and Natural Text

In [18]:
# Draw the held-out query and the few-shot demonstrations.
# The query index is sampled before the training indices, so the global NumPy RNG
# is consumed in the same order on every run.
n_examples = 5
test_idx = np.random.randint(0, len(dataset['test']))
word_pairs = dataset['train'][
    np.random.choice(len(dataset['train']), n_examples, replace=False)
]
test_pair = dataset['test'][test_idx]

# 1) Standard ICL prompt: demonstrations followed by the query.
prompt_data = word_pairs_to_prompt_data(
    word_pairs,
    query_target_pair=test_pair,
    prepend_bos_token=True,
)
sentence = create_prompt(prompt_data)
print("ICL prompt:\n", repr(sentence), '\n\n')

# 2) Same demonstrations and query, built with shuffle_labels=True.
shuffled_prompt_data = word_pairs_to_prompt_data(
    word_pairs,
    query_target_pair=test_pair,
    prepend_bos_token=True,
    shuffle_labels=True,
)
shuffled_sentence = create_prompt(shuffled_prompt_data)
print("Shuffled ICL Prompt:\n", repr(shuffled_sentence), '\n\n')

# 3) Zero-shot prompt: empty demonstration lists, query only.
zeroshot_prompt_data = word_pairs_to_prompt_data(
    {'input':[], 'output':[]},
    query_target_pair=test_pair,
    prepend_bos_token=True,
    shuffle_labels=True,
)
zeroshot_sentence = create_prompt(zeroshot_prompt_data)
print("Zero-Shot Prompt:\n", repr(zeroshot_sentence))

## Evaluation

### Clean ICL Prompt

In [7]:
# Evaluate the unmodified ICL prompt. With compute_nll=False the result is handed
# to decode_to_vocab below as logits.
clean_logits = sentence_eval(
    sentence,
    [test_pair['output']],
    model,
    tokenizer,
    compute_nll=False,
)

# Report the prompt, the query/target pair, and the top-5 decoded predictions.
print("Input Sentence:", repr(sentence), '\n')
print(f"Input Query: {repr(test_pair['input'])}, Target: {repr(test_pair['output'])}\n")
print("ICL Prompt Top K Vocab Probs:\n", decode_to_vocab(clean_logits, tokenizer, k=5), '\n')

In [8]:
# Value that selects both the checkpoint directory (seq_len_<n>) and the file name
# (ltv_layer_<n>.pth) below.
lt_seq_len = 5

# Checkpoints are looked up under
#   <parent of cwd>/src/LTV_models/<task_name>/<act_fn>/seq_len_<n>/ltv_layer_<n>.pth
# so this cell depends on the directory the kernel was started from.
current_directory = os.getcwd()
path_to_model = os.path.join(os.path.dirname(current_directory), f"src/LTV_models/{task_name}/{act_fn}/seq_len_{lt_seq_len}")
path_to_model = os.path.join(path_to_model, f"ltv_layer_{lt_seq_len}.pth")

ltv_layer = LearnableTaskVector(n_layers, n_heads, head_dim).to(device)

# map_location remaps the stored tensors onto `device`, so a checkpoint that was saved
# from a different GPU index still deserializes onto the device used in this notebook.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
ltv_params = torch.load(path_to_model, map_location=device)
ltv_layer.load_state_dict(ltv_params)

In [9]:
# Build the learnable task vector (LTV) from `mean_activations`.
# Alternative activation sources that were tried in this cell during development:
#   - get_mean_head_activations(..., n_icl_examples=20, N_TRIALS=100, batch_structure=True)
#   - get_head_activations_on_prompt(shuffled_prompt_data, model, model_config, tokenizer)
with torch.no_grad():
    attn_out = get_attn_out(mean_activations, model, model_config)
    # Call the module itself instead of .forward(): nn.Module.__call__ runs any
    # registered hooks, which a direct .forward() call silently skips.
    lt_vector = ltv_layer(attn_out).squeeze()

### Prepare the trained LTV Layer

### Corrupted ICL Prompt

In [19]:
# Perform an intervention on the shuffled setting: run the shuffled-label prompt
# unmodified, with the function vector (FV), and with the learnable task vector (LTV).
clean_logits, interv_logits_fv = function_vector_intervention(shuffled_sentence, [test_pair['output']], EDIT_LAYER, FV, model, model_config, tokenizer)
_, interv_logits_ltv = ltv_intervention(shuffled_sentence, [test_pair['output']], lt_vector, model, model_config,
                                        tokenizer)

print("Input Sentence:", repr(shuffled_sentence), '\n')
print(f"Input Query: {repr(test_pair['input'])}, Target: {repr(test_pair['output'])}\n")
print("Few-Shot-Shuffled Prompt Top K Vocab Probs:\n", decode_to_vocab(clean_logits, tokenizer, k=5), '\n')
print("Shuffled Prompt+FV Top K Vocab Probs:\n", decode_to_vocab(interv_logits_fv, tokenizer, k=5), '\n')
print("Shuffled Prompt+LTV Top K Vocab Probs:\n", decode_to_vocab(interv_logits_ltv, tokenizer, k=5), '\n')

# Score only the first token of the space-prefixed target. tokenizer.encode can return
# several ids for one word, and F.cross_entropy needs exactly one target per logits row;
# the zero-shot cell feeds these same intervention outputs a single-id target, so a
# multi-token target here would raise a batch-size mismatch.
target_idx = tokenizer.encode(" " + test_pair['output'])[:1]
target_idx = torch.tensor(target_idx, dtype=torch.int64).to(device)
clean_loss = loss_fn(clean_logits, target_idx).detach().item()
fv_loss = loss_fn(interv_logits_fv, target_idx).detach().item()
ltv_loss = loss_fn(interv_logits_ltv, target_idx).detach().item()

print(f"Clean loss: {clean_loss:.4f}")
print(f"Function vector loss: {fv_loss:.4f}")
print(f"Learnable task vector loss: {ltv_loss:.4f}\n")

In [12]:
# Debugging aid: display the shape of the current clean_logits tensor.
clean_logits.shape

In [23]:
# NOTE: compute_perplexity is defined in the cell that follows this one, so that cell
# must be executed first; running the notebook top to bottom on a fresh kernel raises
# NameError here.
compute_perplexity(clean_logits, target_idx)

In [14]:
def compute_perplexity(logits, target_index):
    """Return the perplexity of the target tokens under ``logits``.

    Args:
        logits: Tensor of unnormalized scores whose last dimension indexes the
            vocabulary, e.g. shape [batch_size, num_tokens].
        target_index: Integer tensor holding the index of the correct token for
            each row of ``logits``, e.g. shape [batch_size].

    Returns:
        A scalar tensor: the exponential of the mean negative log-probability
        assigned to the target tokens.
    """
    # Normalize over the vocabulary axis and pick out each row's target entry.
    # The index gets a trailing axis so it has the same rank as the log-probabilities;
    # that size-1 axis is dropped again right after the gather.
    target_log_probs = torch.gather(
        F.log_softmax(logits, dim=-1),
        -1,
        target_index[..., None],
    )[..., 0]

    # Perplexity = exp(average negative log-likelihood).
    mean_nll = target_log_probs.mean().neg()
    return mean_nll.exp()

In [29]:
# Tally that the next cell increments when the top-1 prediction equals the target word.
x = 0

In [30]:
# Take the [0][0] entry of decode_to_vocab's k=1 result, keep the text after its last
# space, and compare it to the target word; the boolean is added to x (True counts as 1).
# Re-running this cell adds to x again.
x += decode_to_vocab(clean_logits, tokenizer, k=1)[0][0].split(' ')[-1] == test_pair['output']

In [26]:
# Display the target word of the sampled test pair.
test_pair['output']

In [31]:
# Display the current value of the tally.
x

### Zero-Shot Prompt

In [16]:
# Intervention on the zero-shot prompt: run it unmodified, with the function vector (FV),
# and with the learnable task vector (LTV).
clean_logits, interv_logits_fv = function_vector_intervention(zeroshot_sentence, [test_pair['output']], EDIT_LAYER, FV, model, model_config, tokenizer)
clean_output, interv_logits_ltv = ltv_intervention(zeroshot_sentence, [test_pair['output']], lt_vector, model, model_config, tokenizer)

print("Input Sentence:", repr(zeroshot_sentence), '\n')
print(f"Input Query: {repr(test_pair['input'])}, Target: {repr(test_pair['output'])}\n")
print("Zero-Shot Top K Vocab Probs:\n", decode_to_vocab(clean_logits, tokenizer, k=5), '\n')
print("Zero-Shot+FV Vocab Top K Vocab Probs:\n", decode_to_vocab(interv_logits_fv, tokenizer, k=5), '\n')
print("Zero-Shot+LTV Vocab Top K Vocab Probs:\n", decode_to_vocab(interv_logits_ltv, tokenizer, k=5), '\n')

# Encode the target with a leading space and keep its first token, matching how the
# shuffled-prompt cell builds its target. Looking the bare word up with
# convert_tokens_to_ids treats it as one literal vocabulary entry, which yields a
# different id from the space-prefixed token (or the unknown-token id when the bare
# word is not a vocabulary entry), so the loss would be scored against the wrong token.
target_idx = tokenizer.encode(" " + test_pair['output'])[:1]
target_idx = torch.tensor(target_idx, dtype=torch.int64).to(device)
clean_loss = loss_fn(clean_logits, target_idx).detach().item()
fv_loss = loss_fn(interv_logits_fv, target_idx).detach().item()
ltv_loss = loss_fn(interv_logits_ltv, target_idx).detach().item()

print(f"Clean loss: {clean_loss:.4f}")
print(f"Function vector loss: {fv_loss:.4f}")
print(f"Learnable task vector loss: {ltv_loss:.4f}")

In [12]:
# Recompute the task vector from the head activations of the single ICL prompt
# (`prompt_data`). This overwrites `mean_activations`, `attn_out` and `lt_vector`,
# so every cell executed after this one uses the prompt-specific values.
mean_activations = get_head_activations_on_prompt(prompt_data, model, model_config, tokenizer)
attn_out = get_attn_out(mean_activations, model, model_config)
# Call the module itself instead of .forward(): nn.Module.__call__ runs any registered
# hooks, which a direct .forward() call silently skips.
lt_vector = ltv_layer(attn_out).squeeze()

### Natural Text Prompt

In [17]:
# Natural-text prompt for the sampled query word. It gets its own variable so the
# ICL prompt stored in `sentence` (read by the Clean ICL Prompt cell) is not
# overwritten when cells are re-run out of order.
natural_sentence = f"The word \"{test_pair['input']}\" means"

# Generate up to 10 new tokens with the FV and LTV interventions. `co` and `io_fv` are
# the two outputs of the FV call and are decoded below as the "GPT-J" and "GPT-J+FV"
# continuations.
co, io_fv = fv_intervention_natural_text(natural_sentence, EDIT_LAYER, FV, model, model_config, tokenizer, max_new_tokens=10)
_, io_ltv = ltv_intervention_natural_text(natural_sentence, lt_vector, model, model_config, tokenizer, max_new_tokens=10)


print("Input Sentence: ", repr(natural_sentence))
print("GPT-J:" , repr(tokenizer.decode(co.squeeze())))
print("GPT-J+FV:", repr(tokenizer.decode(io_fv.squeeze())))
print("GPT-J+LTV:", repr(tokenizer.decode(io_ltv.squeeze())))