In [1]:
import os
import torch, numpy as np
import torch.nn.functional as F

torch.set_grad_enabled(False)

from src.FV_utils.extract_utils import get_mean_head_activations, compute_universal_function_vector
from src.FV_utils.intervention_utils import function_vector_intervention
from src.FV_utils.model_utils import load_gpt_model_and_tokenizer
from src.FV_utils.prompt_utils import load_dataset, word_pairs_to_prompt_data, create_prompt
from src.FV_utils.eval_utils import decode_to_vocab, sentence_eval

from src.LTV_utils.data import set_seed
from src.LTV_utils.ltv import LearnableTaskVector
from src.LTV_utils.extract_utils import get_attn_out
from src.LTV_utils.intervention_utils import ltv_intervention

%load_ext autoreload
%autoreload 2

GPU_IDX = 0
SEED = 17

device = torch.device(f"cuda:{GPU_IDX}")
set_seed(SEED)

  from pandas.core import (


### Load model & tokenizer

In [2]:
model_name = 'EleutherAI/gpt-j-6b'
model, tokenizer, model_config = load_gpt_model_and_tokenizer(model_name, device=device)
EDIT_LAYER = 9

n_layers = model_config['n_layers']
resid_dim = model_config['resid_dim']
n_heads = model_config['n_heads']
head_dim = resid_dim // n_heads

Loading:  EleutherAI/gpt-j-6b


In [3]:
model

GPTJForCausalLM(
  (transformer): GPTJModel(
    (wte): Embedding(50400, 4096)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-27): 28 x GPTJBlock(
        (ln_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attn): GPTJAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (out_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): GPTJMLP(
          (fc_in): Linear(in_features=4096, out_features=16384, bias=True)
          (fc_out): Linear(in_features=16384, out_features=4096, bias=True)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f)

### Load the dataset and compute the task-conditioned mean activations

In [4]:
task_name = 'commonsense_qa'
act_fn = None
loss_fn = F.cross_entropy
batch_size = 100

dataset = load_dataset(task_name, root_data_dir='dataset_files', seed=0)
mean_activations, _ = get_mean_head_activations(dataset, model, model_config, tokenizer)

In [6]:
load_dataset(task_name, root_data_dir='dataset_files', seed=5)['train']['input'][7]

'Where could you find a toilet that anyone can use?\na: bathroom\nb: apartment\nc: stall\nd: hospital\ne: rest area'

### Compute **Function Vector (FV)**

In [7]:
FV, top_heads = compute_universal_function_vector(mean_activations, model, model_config, n_top_heads=10)

### Load the pre-trained **Learnable Task Vector (LTV)**

In [None]:
ltv_seq_len = 5

current_directory = os.getcwd()
path_to_model = os.path.join(os.path.dirname(current_directory), f"language/LTV_models/{task_name}/{act_fn}/seq_len_{ltv_seq_len}")
path_to_model = os.path.join(path_to_model, f"ltv_layer_{ltv_seq_len}.pth")

ltv_layer = LearnableTaskVector(n_layers, n_heads, head_dim).to(device)
ltv_params = torch.load(path_to_model)
ltv_layer.load_state_dict(ltv_params)

### **Prompt Creation:** 
#### 1. Standard In-Context
#### 2. Shuffled-Label
#### 3. Zero-Shot
#### 4. Natural Text

In [49]:
# Sample ICL example pairs, and a test word
n_examples = 5
test_idx = np.random.randint(0, len(dataset['test']))

word_pairs = dataset['train'][np.random.choice(len(dataset['train']), n_examples, replace=False)]
test_pair = dataset['test'][test_idx]

prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=test_pair, prepend_bos_token=True)
sentence = create_prompt(prompt_data)
print("ICL prompt:\n", repr(sentence), '\n\n')

shuffled_prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=test_pair, prepend_bos_token=True, shuffle_labels=True)
shuffled_sentence = create_prompt(shuffled_prompt_data)
print("Shuffled ICL Prompt:\n", repr(shuffled_sentence), '\n\n')

zeroshot_prompt_data = word_pairs_to_prompt_data({'input':[], 'output':[]}, query_target_pair=test_pair, prepend_bos_token=True, shuffle_labels=True)
zeroshot_sentence = create_prompt(zeroshot_prompt_data)
print("Zero-Shot Prompt:\n", repr(zeroshot_sentence))

ICL prompt:
 "<|endoftext|>Q: What is the opposite of little?\na: least\nb: much\nc: bog\nd: lot of\ne: big\nA: e\n\nQ: When people aren't communicating or talking to each other, what happens?\na: misunderstandings\nb: headaches\nc: introductions\nd: conversation\ne: distraction\nA: a\n\nQ: What does a person who is a gardener have?\na: own house\nb: contribution to society\nc: food\nd: ride horses\ne: green thumb\nA: e\n\nQ: Where is not likely to organize with a card catalog?\na: libary\nb: store\nc: kitchen\nd: bank\ne: library\nA: b\n\nQ: What is it called when you spend time with friends and acquaintances?\na: socialize\nb: tell story\nc: go somewhere\nd: wedding\ne: clean room\nA: a\n\nQ: A person misses his aunt, what is he likely to do about it?\na: cross street\nb: amount to nothing\nc: seek no help\nd: doctor himself\ne: visit relatives\nA:" 


Shuffled ICL Prompt:
 "<|endoftext|>Q: What is the opposite of little?\na: least\nb: much\nc: bog\nd: lot of\ne: big\nA: e\n\nQ: When

In [50]:
test_pair

{'input': 'A person misses his aunt, what is he likely to do about it?\na: cross street\nb: amount to nothing\nc: seek no help\nd: doctor himself\ne: visit relatives',
 'output': 'e'}

## **Evaluation**

### Clean ICL Prompt

In [51]:
# Check model's ICL answer
clean_logits = sentence_eval(sentence, [test_pair['output']], model, tokenizer, compute_nll=False)

print("Input Sentence:", repr(sentence), '\n')
print(f"Input Query: {repr(test_pair['input'])}\nTarget: {repr(test_pair['output'])}\n")
print("ICL Prompt Top K Vocab Probs:\n", decode_to_vocab(clean_logits, tokenizer, k=5), '\n')

Input Sentence: "<|endoftext|>Q: What is the opposite of little?\na: least\nb: much\nc: bog\nd: lot of\ne: big\nA: e\n\nQ: When people aren't communicating or talking to each other, what happens?\na: misunderstandings\nb: headaches\nc: introductions\nd: conversation\ne: distraction\nA: a\n\nQ: What does a person who is a gardener have?\na: own house\nb: contribution to society\nc: food\nd: ride horses\ne: green thumb\nA: e\n\nQ: Where is not likely to organize with a card catalog?\na: libary\nb: store\nc: kitchen\nd: bank\ne: library\nA: b\n\nQ: What is it called when you spend time with friends and acquaintances?\na: socialize\nb: tell story\nc: go somewhere\nd: wedding\ne: clean room\nA: a\n\nQ: A person misses his aunt, what is he likely to do about it?\na: cross street\nb: amount to nothing\nc: seek no help\nd: doctor himself\ne: visit relatives\nA:" 

Input Query: 'A person misses his aunt, what is he likely to do about it?\na: cross street\nb: amount to nothing\nc: seek no help\n

In [20]:
with torch.no_grad():
    attn_out = get_attn_out(mean_activations, model, model_config)
    lt_vector = ltv_layer.forward(attn_out)
    lt_vector = lt_vector.squeeze()

### Shuffled-Label Few-Shot ICL Prompt

In [56]:
# Perform an intervention on the shuffled setting
clean_logits, interv_logits_fv = function_vector_intervention(shuffled_sentence, [test_pair['output']], EDIT_LAYER, FV, model, model_config, tokenizer)
_, interv_logits_ltv = ltv_intervention(shuffled_sentence, [test_pair['output']], lt_vector, model, model_config, tokenizer)
                                                   
print("Input Sentence:", repr(shuffled_sentence), '\n')
print(f"Input Query: {repr(test_pair['input'])}\nTarget: {repr(test_pair['output'])}\n")

print(f"Few-Shot-Shuffled Prompt Top k Vocab Probs:\n\t{decode_to_vocab(clean_logits, tokenizer, k=5)}\n")
print(f"Shuffled Prompt +FV Top k Vocab Probs:\n\t{decode_to_vocab(interv_logits_fv, tokenizer, k=5)}\n")
print(f"Shuffled Prompt +LTV Top k Vocab Probs:\n\t{decode_to_vocab(interv_logits_ltv, tokenizer, k=5)}\n")

target_idx = tokenizer.encode(" " + test_pair['output'])
target_idx = torch.tensor(target_idx, dtype=torch.int64).to(device)

clean_loss = loss_fn(clean_logits, target_idx)
fv_loss = loss_fn(interv_logits_fv, target_idx)
ltv_loss = loss_fn(interv_logits_ltv, target_idx)

print(f"Vanilla transformer - loss: {clean_loss:.4f}")
print(f"Function Vector - loss: {fv_loss:.4f}")
print(f"Learnable Task Vector - loss: {ltv_loss:.4f}")

Input Sentence: "<|endoftext|>Q: What is the opposite of little?\na: least\nb: much\nc: bog\nd: lot of\ne: big\nA: e\n\nQ: When people aren't communicating or talking to each other, what happens?\na: misunderstandings\nb: headaches\nc: introductions\nd: conversation\ne: distraction\nA: e\n\nQ: What does a person who is a gardener have?\na: own house\nb: contribution to society\nc: food\nd: ride horses\ne: green thumb\nA: a\n\nQ: Where is not likely to organize with a card catalog?\na: libary\nb: store\nc: kitchen\nd: bank\ne: library\nA: b\n\nQ: What is it called when you spend time with friends and acquaintances?\na: socialize\nb: tell story\nc: go somewhere\nd: wedding\ne: clean room\nA: a\n\nQ: A person misses his aunt, what is he likely to do about it?\na: cross street\nb: amount to nothing\nc: seek no help\nd: doctor himself\ne: visit relatives\nA:" 

Input Query: 'A person misses his aunt, what is he likely to do about it?\na: cross street\nb: amount to nothing\nc: seek no help\n

### Zero-Shot Prompt

In [112]:
# Intervention on the zero-shot prompt
clean_logits, interv_logits_fv = function_vector_intervention(zeroshot_sentence, [test_pair['output']], EDIT_LAYER, FV, model, model_config, tokenizer)
clean_output, interv_logits_ltv = ltv_intervention(zeroshot_sentence, [test_pair['output']], lt_vector, model, model_config, tokenizer)

print("Input Sentence:", repr(shuffled_sentence), '\n')
print(f"Input Query: {repr(test_pair['input'])}\nTarget: {repr(test_pair['output'])}\n")

print(f"Zero-Shot Top k Vocab Probs:\n\t{decode_to_vocab(clean_logits, tokenizer, k=5)}\n")
print(f"Zero-Shot +FV Vocab Top k Vocab Probs:\n\t{decode_to_vocab(interv_logits_fv, tokenizer, k=5)}\n")
print(f"Zero-Shot +LTV Vocab Top k Vocab Probs:\n\t{decode_to_vocab(interv_logits_ltv, tokenizer, k=5)}\n")

target_idx = tokenizer.encode(" " + test_pair['output'])
target_idx = torch.tensor(target_idx, dtype=torch.int64).to(device)

clean_loss = loss_fn(clean_logits, target_idx)
fv_loss = loss_fn(interv_logits_fv, target_idx)
ltv_loss = loss_fn(interv_logits_ltv, target_idx)

print(f"Vanilla transformer - loss: {clean_loss:.4f}")
print(f"Function Vector - loss: {fv_loss:.4f}")
print(f"Learnable Task Vector - loss: {ltv_loss:.4f}")

Input Sentence: '<|endoftext|>Q: Where is microphone boom likely to be used to record an album?\na: radio station\nb: recording studio\nc: concert\nd: tv studio\ne: new york\nA:' 

Input Query: 'Where is microphone boom likely to be used to record an album?\na: radio station\nb: recording studio\nc: concert\nd: tv studio\ne: new york'
Target: 'b'

Zero-Shot Top k Vocab Probs:
	[(' a', 0.14372), (' b', 0.10578), (' c', 0.08973), (' e', 0.05843), (' d', 0.05265)]

Zero-Shot +FV Vocab Top k Vocab Probs:
	[(' d', 0.21565), (' c', 0.20987), (' e', 0.14113), (' a', 0.13808), (' b', 0.13053)]

Zero-Shot +LTV Vocab Top k Vocab Probs:
	[(' b', 0.14828), (' a', 0.1258), (' d', 0.09221), (' c', 0.06607), ('\n', 0.03976)]

Vanilla transformer - loss: 2.2464
Function Vector - loss: 2.0361
Learnable Task Vector - loss: 1.9086
