In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os, re, json
import torch, numpy as np
import torch.nn as nn
import torch.nn.functional as F

import sys

sys.path.append('..')
torch.set_grad_enabled(False)

from src.FV_utils.extract_utils import get_mean_head_activations, compute_universal_function_vector
from src.FV_utils.intervention_utils import fv_intervention_natural_text, function_vector_intervention
from src.FV_utils.model_utils import load_gpt_model_and_tokenizer
from src.FV_utils.prompt_utils import load_dataset, word_pairs_to_prompt_data, create_prompt
from src.FV_utils.eval_utils import decode_to_vocab, sentence_eval

from src.LTV_utils.data import set_seed, sample_data, forward_pass
from src.LTV_utils.LTV import LearnableTaskVector
from src.LTV_utils.extract_utils import get_attn_out, get_head_activations_on_prompt
from src.LTV_utils.intervention_utils import ltv_intervention, ltv_intervention_natural_text
from src.inference_utils import compute_perplexity

GPU_IDX = 0
SEED = 17

device = torch.device(f"cuda:{GPU_IDX}")
set_seed(SEED)

  from pandas.core import (


## Load model & tokenizer

In [3]:
model_name = 'EleutherAI/gpt-j-6b'
model, tokenizer, model_config = load_gpt_model_and_tokenizer(model_name, device=device)
EDIT_LAYER = 9

n_layers = model_config['n_layers']
resid_dim = model_config['resid_dim']
n_heads = model_config['n_heads']
head_dim = resid_dim // n_heads

Loading:  EleutherAI/gpt-j-6b


## Load dataset and Compute task-conditioned mean activations

In [4]:
task_name = 'antonym'
act_fn = None
loss_fn = F.cross_entropy
batch_size = 100

dataset = load_dataset(task_name, root_data_dir='dataset_files', seed=0)
mean_activations, _ = get_mean_head_activations(dataset, model, model_config, tokenizer)

In [5]:
FV, top_heads = compute_universal_function_vector(mean_activations, model, model_config, n_top_heads=10)

## Compute function vector (FV)

## Prompt Creation - ICL, Shuffled-Label, Zero-Shot, and Natural Text

In [6]:
# Sample ICL example pairs, and a test word
n_examples = 5
test_idx = np.random.randint(0, len(dataset['test']))

word_pairs = dataset['train'][np.random.choice(len(dataset['train']), n_examples, replace=False)]
test_pair = dataset['test'][test_idx]

prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=test_pair, prepend_bos_token=True)
sentence = create_prompt(prompt_data)
print("ICL prompt:\n", repr(sentence), '\n\n')

shuffled_prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=test_pair, prepend_bos_token=True, shuffle_labels=True)
shuffled_sentence = create_prompt(shuffled_prompt_data)
print("Shuffled ICL Prompt:\n", repr(shuffled_sentence), '\n\n')

zeroshot_prompt_data = word_pairs_to_prompt_data({'input':[], 'output':[]}, query_target_pair=test_pair, prepend_bos_token=True, shuffle_labels=True)
zeroshot_sentence = create_prompt(zeroshot_prompt_data)
print("Zero-Shot Prompt:\n", repr(zeroshot_sentence))

ICL prompt:
 '<|endoftext|>Q: posterior\nA: anterior\n\nQ: extinct\nA: alive\n\nQ: latent\nA: manifest\n\nQ: able\nA: unable\n\nQ: expenditure\nA: income\n\nQ: aware\nA:' 


Shuffled ICL Prompt:
 '<|endoftext|>Q: posterior\nA: manifest\n\nQ: extinct\nA: unable\n\nQ: latent\nA: alive\n\nQ: able\nA: anterior\n\nQ: expenditure\nA: income\n\nQ: aware\nA:' 


Zero-Shot Prompt:
 '<|endoftext|>Q: aware\nA:'


## Evaluation

### Clean ICL Prompt

In [7]:
# Check model's ICL answer
clean_logits = sentence_eval(sentence, [test_pair['output']], model, tokenizer, compute_nll=False)

print("Input Sentence:", repr(sentence), '\n')
print(f"Input Query: {repr(test_pair['input'])}, Target: {repr(test_pair['output'])}\n")
print("ICL Prompt Top K Vocab Probs:\n", decode_to_vocab(clean_logits, tokenizer, k=5), '\n')

Input Sentence: '<|endoftext|>Q: posterior\nA: anterior\n\nQ: extinct\nA: alive\n\nQ: latent\nA: manifest\n\nQ: able\nA: unable\n\nQ: expenditure\nA: income\n\nQ: aware\nA:' 

Input Query: 'aware', Target: 'unaware'

ICL Prompt Top K Vocab Probs:
 [(' unaware', 0.48833), (' unconscious', 0.11493), (' ignorant', 0.0853), (' oblivious', 0.06093), (' un', 0.05989)] 



In [8]:
lt_seq_len = 5

current_directory = os.getcwd()
path_to_model = os.path.join(os.path.dirname(current_directory), f"language/src/LTV_models/{task_name}/{act_fn}/seq_len_{lt_seq_len}")
path_to_model = os.path.join(path_to_model, f"ltv_layer_{lt_seq_len}.pth")

ltv_layer = LearnableTaskVector(n_layers, n_heads, head_dim).to(device)
ltv_params = torch.load(path_to_model)
ltv_layer.load_state_dict(ltv_params)

<All keys matched successfully>

In [9]:
with torch.no_grad():
    attn_out = get_attn_out(mean_activations, model, model_config)
    lt_vector = ltv_layer.forward(attn_out)
    lt_vector = lt_vector.squeeze()

### Prepare the trained LTV Layer

### Corrupted ICL Prompt

In [10]:
# Perform an intervention on the shuffled setting
clean_logits, interv_logits_fv = function_vector_intervention(shuffled_sentence, [test_pair['output']], EDIT_LAYER, FV, model, model_config, tokenizer)
_, interv_logits_ltv = ltv_intervention(shuffled_sentence, [test_pair['output']], lt_vector, model, model_config, 
                                                   tokenizer)
                                                   
print("Input Sentence:", repr(shuffled_sentence), '\n')
print(f"Input Query: {repr(test_pair['input'])}, Target: {repr(test_pair['output'])}\n")
print("Few-Shot-Shuffled Prompt Top K Vocab Probs:\n", decode_to_vocab(clean_logits, tokenizer, k=5), '\n')
print("Shuffled Prompt+FV Top K Vocab Probs:\n", decode_to_vocab(interv_logits_fv, tokenizer, k=5), '\n')
print("Shuffled Prompt+LTV Top K Vocab Probs:\n", decode_to_vocab(interv_logits_ltv, tokenizer, k=5), '\n')

target_idx = tokenizer.encode(" " + test_pair['output'])
target_idx = torch.tensor(target_idx, dtype=torch.int64).to(device)

clean_loss, clean_perplexity = loss_fn(clean_logits, target_idx), compute_perplexity(clean_logits, target_idx)
fv_loss, fv_perplexity = loss_fn(interv_logits_fv, target_idx), compute_perplexity(interv_logits_fv, target_idx)
ltv_loss, ltv_perplexity = loss_fn(interv_logits_ltv, target_idx), compute_perplexity(interv_logits_ltv, target_idx)

print(f"Vanilla transformer - loss: {clean_loss:.4f} \t perplexity: {clean_perplexity:.4f}")
print(f"Function Vector - loss: {fv_loss:.4f} \t perplexity: {fv_perplexity:.4f}")
print(f"Learnable Task Vector - loss: {ltv_loss:.4f} \t perplexity: {ltv_perplexity:.4f}")

Input Sentence: '<|endoftext|>Q: posterior\nA: manifest\n\nQ: extinct\nA: unable\n\nQ: latent\nA: alive\n\nQ: able\nA: anterior\n\nQ: expenditure\nA: income\n\nQ: aware\nA:' 

Input Query: 'aware', Target: 'unaware'

Few-Shot-Shuffled Prompt Top K Vocab Probs:
 [(' unaware', 0.0995), (' aware', 0.06035), (' ignorant', 0.0587), (' unconscious', 0.05312), (' conscious', 0.05309)] 

Shuffled Prompt+FV Top K Vocab Probs:
 [(' unaware', 0.43724), (' ignorant', 0.09005), (' unconscious', 0.07406), (' oblivious', 0.04356), (' un', 0.04091)] 

Shuffled Prompt+LTV Top K Vocab Probs:
 [(' unaware', 0.64999), (' unconscious', 0.15438), (' un', 0.04982), (' ignorant', 0.03715), (' oblivious', 0.0216)] 

Vanilla transformer - loss: 2.3076 	 perplexity: 10.0502
Function Vector - loss: 0.8273 	 perplexity: 2.2871
Learnable Task Vector - loss: 0.4308 	 perplexity: 1.5385


### Zero-Shot Prompt

In [11]:
# Intervention on the zero-shot prompt
clean_logits, interv_logits_fv = function_vector_intervention(zeroshot_sentence, [test_pair['output']], EDIT_LAYER, FV, model, model_config, tokenizer)
clean_output, interv_logits_ltv = ltv_intervention(zeroshot_sentence, [test_pair['output']], lt_vector, model, model_config, tokenizer)

print("Input Sentence:", repr(zeroshot_sentence), '\n')
print(f"Input Query: {repr(test_pair['input'])}, Target: {repr(test_pair['output'])}\n")
print("Zero-Shot Top K Vocab Probs:\n", decode_to_vocab(clean_logits, tokenizer, k=5), '\n')
print("Zero-Shot+FV Vocab Top K Vocab Probs:\n", decode_to_vocab(interv_logits_fv, tokenizer, k=5), '\n')
print("Zero-Shot+LTV Vocab Top K Vocab Probs:\n", decode_to_vocab(interv_logits_ltv, tokenizer, k=5), '\n')

target_idx = tokenizer.convert_tokens_to_ids(test_pair['output'])
target_idx = torch.tensor([target_idx], dtype=torch.int64).to(device)

clean_loss, clean_perplexity = loss_fn(clean_logits, target_idx), compute_perplexity(clean_logits, target_idx)
fv_loss, fv_perplexity = loss_fn(interv_logits_fv, target_idx), compute_perplexity(interv_logits_fv, target_idx)
ltv_loss, ltv_perplexity = loss_fn(interv_logits_ltv, target_idx), compute_perplexity(interv_logits_ltv, target_idx)

print(f"Vanilla transformer - loss: {clean_loss:.4f} \t perplexity: {clean_perplexity:.4f}")
print(f"Function Vector - loss: {fv_loss:.4f} \t perplexity: {fv_perplexity:.4f}")
print(f"Learnable Task Vector - loss: {ltv_loss:.4f} \t perplexity: {ltv_perplexity:.4f}")

Input Sentence: '<|endoftext|>Q: aware\nA:' 

Input Query: 'aware', Target: 'unaware'

Zero-Shot Top K Vocab Probs:
 [(' aware', 0.17471), (' a', 0.02292), (' able', 0.0139), (' not', 0.01292), (' to', 0.01032)] 

Zero-Shot+FV Vocab Top K Vocab Probs:
 [(' aware', 0.18541), (' unaware', 0.11167), (' un', 0.03718), (' ignorant', 0.02702), (' not', 0.02203)] 

Zero-Shot+LTV Vocab Top K Vocab Probs:
 [(' unaware', 0.7543), (' un', 0.05105), (' aware', 0.03786), (' oblivious', 0.02072), (' ignorant', 0.01884)] 

Vanilla transformer - loss: 10.2966 	 perplexity: 29631.6250
Function Vector - loss: 11.5369 	 perplexity: 102425.1094
Learnable Task Vector - loss: 14.5160 	 perplexity: 2014770.8750
