In [1]:
# Reload imported modules (e.g. the src.utils helpers) before each cell runs,
# so edits to them take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from collections import deque
import os, random, re, json
import torch, numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import sys
sys.path.append('..')
# torch.set_grad_enabled(False)

from src.utils.extract_utils import get_mean_head_activations, compute_universal_function_vector, get_attn_out
from src.utils.intervention_utils import fv_intervention_natural_text, function_vector_intervention, ltv_intervention
from src.utils.model_utils import load_gpt_model_and_tokenizer
from src.utils.prompt_utils import load_dataset, word_pairs_to_prompt_data, create_prompt
from src.utils.eval_utils import decode_to_vocab, sentence_eval

In [3]:
def sample_data(dataset, model, model_config, tokenizer, n_examples, batch_size):
    """Sample a batch of ICL prompts together with their head-activation summary.

    Args:
        dataset: task dataset forwarded to ``get_mean_head_activations``.
        model, model_config, tokenizer: model bundle forwarded to the helpers.
        n_examples: number of ICL demonstrations per prompt.
        batch_size: number of trials requested from ``get_mean_head_activations``.

    Returns:
        ``(attn_out, sentences, targets)``, where:
        - ``attn_out`` is whatever ``get_attn_out`` produces from the mean activations.
        - ``sentences`` is one prompt string per (demonstrations, query) pair.
        - ``targets`` is the matching ``test_pair['output']`` for each prompt.
    """
    mean_activations, sampled = get_mean_head_activations(
        dataset, model, model_config, tokenizer,
        n_icl_examples=n_examples, N_TRIALS=batch_size, batch_structure=True,
    )
    attn_out = get_attn_out(mean_activations, model, model_config)
    demo_sets, query_pairs = sampled

    # zip() keeps prompts and targets aligned, stopping at the shorter of the two inputs.
    batch = list(zip(demo_sets, query_pairs))
    sentences = [
        create_prompt(word_pairs_to_prompt_data(demos, query_target_pair=query, prepend_bos_token=True))
        for demos, query in batch
    ]
    targets = [query['output'] for _, query in batch]

    return attn_out, sentences, targets

def forward_pass(model, model_config, tokenizer, vocab_size, batch_size, sentences, targets, lt_vector_batch):
    """Run the learnable-task-vector intervention on every prompt in a batch.

    For each ``(sentence, target, lt_vector)`` triple, ``ltv_intervention`` is called. Its clean
    and intervened outputs are collected, along with the vocabulary id of ``target[0]``.

    Args:
        model, model_config, tokenizer: model bundle forwarded to ``ltv_intervention``.
        vocab_size: number of logits per prompt; each output is reshaped to ``(vocab_size,)``.
        batch_size: kept for backward compatibility. Rows are produced only for the triples
            actually supplied, so a short batch never adds zero-logit rows with target id 0
            to the loss.
        sentences: prompt strings.
        targets: per-prompt targets; ``target[0]`` is looked up with ``convert_tokens_to_ids``.
        lt_vector_batch: tensor iterated along dim 0, one task vector per prompt.

    Returns:
        ``(logits, clean_logits, target_indices)``, where ``n`` is the number of triples
        (``zip`` stops at the shortest input):
        - ``logits`` and ``clean_logits`` are float32 tensors of shape ``(n, vocab_size)``
          holding the intervened and clean outputs.
        - ``target_indices`` is an int64 tensor of shape ``(n,)``.
        All three live on the device of the intervention outputs.

    Raises:
        ValueError: if no triples are supplied.
    """
    interv_rows, clean_rows, target_ids = [], [], []
    for sentence, target, lt_vector in zip(sentences, targets, lt_vector_batch):
        clean_output, intervention_output = ltv_intervention(
            sentence, target, lt_vector, model, model_config, tokenizer,
            compute_nll=False, generate_str=False,
        )
        # reshape accepts (vocab_size,) or (1, vocab_size). float() keeps the float32 dtype the
        # old preallocated buffers imposed, which matters if the model emits half-precision logits.
        interv_rows.append(intervention_output.reshape(vocab_size).float())
        clean_rows.append(clean_output.reshape(vocab_size).float())
        # NOTE(review): convert_tokens_to_ids treats target[0] as a single vocab token. A target
        # that BPE-splits, or that needs a leading-space token, maps to a different id. Confirm
        # this matches how ltv_intervention tokenizes the target.
        target_ids.append(tokenizer.convert_tokens_to_ids(target[0]))

    if not interv_rows:
        raise ValueError("forward_pass received no (sentence, target, lt_vector) triples")

    # Stacking (instead of writing into a zeros buffer) keeps the autograd graph and removes
    # the dependency on a notebook-global `device`.
    logits = torch.stack(interv_rows)
    clean_logits = torch.stack(clean_rows)
    # Integer class indices created directly as int64, the dtype cross_entropy expects.
    target_indices = torch.tensor(target_ids, dtype=torch.long, device=logits.device)
    return logits, clean_logits, target_indices

In [4]:
class LearnableTaskVector(nn.Module):
    """Learnable per-head scaling of attention-head outputs (the learnable task vector).

    Holds one scalar weight per (layer, head). ``forward`` multiplies each head's output by
    its weight and flattens the heads of each layer back into one vector per layer.

    Args:
        n_layers: number of transformer layers.
        n_heads: number of attention heads per layer.
        n_head_dim: dimension of a single head's output.
        act_fn: optional callable applied to the raw weights before scaling (e.g.
            ``torch.sigmoid`` to keep them in (0, 1)). ``None`` (default) uses the raw weights.
    """

    def __init__(self, n_layers, n_heads, n_head_dim, act_fn=None):
        super().__init__()
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_head_dim = n_head_dim
        # Raw weights are drawn from a standard normal distribution (torch.randn).
        self.weights = nn.Parameter(torch.randn(n_layers, n_heads))
        self.act_fn = act_fn

    def forward(self, x):
        """Scale each head's output by its learned weight.

        Args:
            x: tensor of shape ``[batch_size, n_layers, n_heads, n_head_dim]``. The trailing
               three dims must match the constructor arguments for the reshape below.

        Returns:
            Tensor of shape ``[batch_size, n_layers, n_heads * n_head_dim]``.
        """
        # Previously act_fn was stored but never applied; honor it when given.
        weights = self.weights if self.act_fn is None else self.act_fn(self.weights)

        # [n_layers, n_heads] -> [1, n_layers, n_heads, 1] to broadcast over batch and head_dim.
        weighted = x * weights[None, :, :, None]

        # reshape rather than view: the broadcast product need not be contiguous.
        return weighted.reshape(x.shape[0], self.n_layers, self.n_heads * self.n_head_dim)

## Load model & tokenizer

In [5]:
model_name = 'EleutherAI/gpt-j-6b'
model, tokenizer, model_config = load_gpt_model_and_tokenizer(model_name)

# Logit dimension of the LM head (50400 for GPT-J), read from the model config so it follows
# `model_name` instead of being hard-coded.
# NOTE(review): assumes `model` is a Hugging Face PreTrainedModel (it already exposes `.device`).
vocab_size = model.config.vocab_size

# Freeze every transformer parameter: only the learnable task vector is trained.
model.requires_grad_(False)

n_layers = model_config['n_layers']
resid_dim = model_config['resid_dim']
n_heads = model_config['n_heads']
head_dim = resid_dim // n_heads  # per-head width of the attention output
device = model.device

In [6]:
lr = 2.5e-5

# Put the LTV weights on the same device as the model instead of a hard-coded "cuda",
# so the notebook also works on CPU or a non-default GPU.
ltv_layer = LearnableTaskVector(n_layers, n_heads, head_dim).to(device)
loss_fn = F.cross_entropy
# Only the LTV layer's parameters are optimised; the transformer was frozen when it was loaded.
optimizer = optim.Adam(ltv_layer.parameters(), lr=lr)

In [7]:
# Training hyperparameters.
n_examples = 5       # ICL demonstrations per sampled prompt
batch_size = 32      # prompts sampled per optimisation step
n_iter = 200_000     # total optimisation steps
verbose_freq = 1     # NOTE(review): not read by the training loop, which logs every 20 iterations

# Antonym task data; seed=0 is forwarded to load_dataset (presumably fixes the split).
dataset = load_dataset('antonym', seed=0)

In [8]:
# Train the LTV head weights against the frozen model. Checkpoint and log every `log_every` steps.
checkpoint_dir = f"./models/seq_len_{n_examples}"
os.makedirs(checkpoint_dir, exist_ok=True)  # torch.save does not create missing directories
checkpoint_path = os.path.join(checkpoint_dir, f'ltv_layer_{n_examples}.pth')
log_every = 20
loss_verbose = deque(maxlen=20)  # window of the most recent training losses for the running mean

ltv_layer.train()  # nothing below switches to eval mode, so set it once
for iter_i in range(n_iter):
    attn_out, sentences, targets = sample_data(dataset, model, model_config, tokenizer, n_examples, batch_size)
    # Call the module (not .forward) so nn.Module hooks run.
    lt_vector_batch = ltv_layer(attn_out)
    logits, clean_logits, target_indices = forward_pass(model, model_config, tokenizer, vocab_size, batch_size, sentences, targets, lt_vector_batch)

    loss = loss_fn(logits, target_indices.to(torch.int64))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    loss_verbose.append(loss.item())

    if iter_i % log_every == 0:
        torch.save(ltv_layer.state_dict(), checkpoint_path)
        print(f'Iter [{iter_i+1}/{n_iter}] - training loss: {np.mean(loss_verbose):.4f}')

# TODO: there is no validation pass yet; the checkpoint above is simply the latest weights,
# not the best-by-validation ones.

## Load dataset and compute task-conditioned mean activations

In [22]:
# Removed a stray `queries` reference: that name is not defined anywhere above, so the cell
# raised NameError under Restart & Run All.

In [None]:
# Re-split the data: pool the train/valid/test examples, shuffle them with a fixed seed, and
# split them into train/valid, so no separate test split remains.
# NOTE(review): the result holds plain Python lists built via list(...), not the objects returned
# by load_dataset; confirm downstream helpers accept this format.
TRAIN_FRACTION = 0.8
SPLIT_SEED = 0  # same seed as the load_dataset call in the config cell; makes the shuffle reproducible

all_data = list(dataset['train']) + list(dataset['valid']) + list(dataset['test'])
random.Random(SPLIT_SEED).shuffle(all_data)

split_index = int(TRAIN_FRACTION * len(all_data))

dataset = {
    'train': all_data[:split_index],
    'valid': all_data[split_index:]
}

In [27]:
n_examples = 5  # number of ICL demonstrations in the prompt

# Sample ICL example pairs and a test word. Use a separate name so the training `dataset`
# (loaded with seed=0 in the config cell) is not overwritten by a differently-split copy.
eval_dataset = load_dataset('antonym')
word_pairs = eval_dataset['train'][:n_examples]
test_pair = eval_dataset['test'][21]

prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=test_pair, prepend_bos_token=True)
sentence = create_prompt(prompt_data)
print("ICL prompt:\n", repr(sentence), '\n\n')

# Check the model's ICL answer
clean_logits = sentence_eval(sentence, [test_pair['output']], model, tokenizer, compute_nll=False)

print("Input Sentence:", repr(sentence), '\n')
print(f"Input Query: {repr(test_pair['input'])}, Target: {repr(test_pair['output'])}\n")
print("ICL Prompt Top K Vocab Probs:\n", decode_to_vocab(clean_logits, tokenizer, k=5), '\n')

In [10]:
# Task-conditioned mean head activations for `dataset`, consumed by the FV cell below.
# Previously `mean_activations` was never defined at top level (it only existed inside
# sample_data), so this cell raised NameError.
# NOTE(review): with batch_structure=True the helper returns (mean_activations, data), as in
# sample_data. The default call is assumed to return the tensor alone, but either form is
# accepted here. Confirm.
_mean_result = get_mean_head_activations(dataset, model, model_config, tokenizer, n_icl_examples=n_examples)
mean_activations = _mean_result[0] if isinstance(_mean_result, tuple) else _mean_result
mean_activations.shape

In [8]:
# Display the model's repr (its nn.Module structure) for reference.
model

## Compute function vector (FV)

In [None]:
N_TOP_HEADS = 10  # number of top attention heads summed into the function vector

FV, top_heads = compute_universal_function_vector(mean_activations, model, model_config, n_top_heads=N_TOP_HEADS)

# Layer at which the intervention cells below add the FV. Previously this name was never
# defined, so those cells raised NameError. n_layers // 3 follows the FV paper's choice of
# roughly |L|/3 (layer 9 for 28-layer GPT-J). NOTE(review): confirm the intended layer.
EDIT_LAYER = n_layers // 3

## Prompt Creation - ICL, Shuffled-Label, Zero-Shot, and Natural Text

In [7]:
# Sample ICL example pairs and a test word. Use a separate name so the training `dataset`
# from the config cell is not overwritten.
eval_dataset = load_dataset('antonym')
word_pairs = eval_dataset['train'][:n_examples]
test_pair = eval_dataset['test'][21]

prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=test_pair, prepend_bos_token=True)
sentence = create_prompt(prompt_data)
print("ICL prompt:\n", repr(sentence), '\n\n')

# Same demonstrations with shuffle_labels=True (corrupted ICL prompt).
shuffled_prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=test_pair, prepend_bos_token=True, shuffle_labels=True)
shuffled_sentence = create_prompt(shuffled_prompt_data)
print("Shuffled ICL Prompt:\n", repr(shuffled_sentence), '\n\n')

# No demonstrations, only the query (zero-shot prompt).
zeroshot_prompt_data = word_pairs_to_prompt_data({'input':[], 'output':[]}, query_target_pair=test_pair, prepend_bos_token=True, shuffle_labels=True)
zeroshot_sentence = create_prompt(zeroshot_prompt_data)
print("Zero-Shot Prompt:\n", repr(zeroshot_sentence))

## Evaluation

### Clean ICL Prompt

In [None]:
# Baseline: the unmodified model's next-token prediction on the clean ICL prompt.
clean_logits = sentence_eval(
    sentence,
    [test_pair['output']],
    model,
    tokenizer,
    compute_nll=False,
)

print("Input Sentence:", f"{sentence!r}", '\n')
print(f"Input Query: {test_pair['input']!r}, Target: {test_pair['output']!r}\n")
top_k_clean = decode_to_vocab(clean_logits, tokenizer, k=5)
print("ICL Prompt Top K Vocab Probs:\n", top_k_clean, '\n')

### Corrupted ICL Prompt

In [None]:
# Run the shuffled-label prompt with an FV intervention at EDIT_LAYER, then compare the
# top-5 predictions without and with the intervention.
clean_logits, interv_logits = function_vector_intervention(
    shuffled_sentence,
    [test_pair['output']],
    EDIT_LAYER,
    FV,
    model,
    model_config,
    tokenizer,
)

print("Input Sentence:", f"{shuffled_sentence!r}", '\n')
print(f"Input Query: {test_pair['input']!r}, Target: {test_pair['output']!r}\n")
top_k_clean = decode_to_vocab(clean_logits, tokenizer, k=5)
print("Few-Shot-Shuffled Prompt Top K Vocab Probs:\n", top_k_clean, '\n')
top_k_interv = decode_to_vocab(interv_logits, tokenizer, k=5)
print("Shuffled Prompt+FV Top K Vocab Probs:\n", top_k_interv)

### Zero-Shot Prompt

In [None]:
# Run the zero-shot prompt with an FV intervention at EDIT_LAYER, then compare the
# top-5 predictions without and with the intervention.
clean_logits, interv_logits = function_vector_intervention(
    zeroshot_sentence,
    [test_pair['output']],
    EDIT_LAYER,
    FV,
    model,
    model_config,
    tokenizer,
)

print("Input Sentence:", f"{zeroshot_sentence!r}", '\n')
print(f"Input Query: {test_pair['input']!r}, Target: {test_pair['output']!r}\n")
top_k_clean = decode_to_vocab(clean_logits, tokenizer, k=5)
print("Zero-Shot Top K Vocab Probs:\n", top_k_clean, '\n')
top_k_interv = decode_to_vocab(interv_logits, tokenizer, k=5)
print("Zero-Shot+FV Vocab Top K Vocab Probs:\n", top_k_interv)

### Natural Text Prompt

In [None]:
# Natural-text prompt: generate a short continuation with and without the FV at EDIT_LAYER.
# Use a dedicated name so the ICL `sentence` used by earlier cells is not overwritten
# (re-running the clean-ICL cell afterwards would otherwise evaluate this text instead).
natural_sentence = f"The word \"{test_pair['input']}\" means"
clean_ids, interv_ids = fv_intervention_natural_text(natural_sentence, EDIT_LAYER, FV, model, model_config, tokenizer, max_new_tokens=10)

print("Input Sentence: ", repr(natural_sentence))
print("GPT-J:", repr(tokenizer.decode(clean_ids.squeeze())))
print("GPT-J+FV:", repr(tokenizer.decode(interv_ids.squeeze())), '\n')