## Replicate RAVEL with nnsight (current model: TinyLlama-1.1B; originally planned with GPT-2 small)

1. Load model
2. Load dataset
3. Load Sparse Autoencoders
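4. Find important directions for an attribute (are they specific to the attribute A, or to each attribute–value combination A_E?)
    -> I think they are specific to the attribute: the same features would be used for all "Country"-related entities (US, Germany, Japan). If that is right, it is not a suitable choice IMO.
    -> Select the top features for each attribute by probe weight: train a linear probe with L1 regularization on the SAE latents. (The probe below is currently trained without an L1 penalty.)
5. RAVEL evaluation
    - Cause: prediction accuracy of the true probe
    - Isolation: ? (likely the accuracy of still predicting the other attributes B_E after patching A's features; see the Isolation section)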

In [None]:
import json
import os
import random
import sys
import numpy as np
import torch

import datasets
from datasets import Dataset


# Root of the RAVEL checkout. Defaults to the original author's path; set the
# RAVEL_LIB_DIR environment variable to run the notebook on another machine.
RAVEL_LIB_DIR = os.environ.get('RAVEL_LIB_DIR', '/share/u/can/ravel')
RAVEL_SCRIPT_DIR = f'{RAVEL_LIB_DIR}/scripts'

# Make RAVEL's modules importable. Skip entries that are already present so that
# re-running this cell does not keep growing sys.path with duplicates.
for _ravel_path in (RAVEL_LIB_DIR, RAVEL_SCRIPT_DIR):
    if _ravel_path not in sys.path:
        sys.path.append(_ravel_path)

def set_seed(seed):
    """Seed Python's `random`, NumPy, and torch (CPU and all CUDA devices) with `seed`."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

# Fix every RNG seed so shuffling and initialization are reproducible.
set_seed(0)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model cache and RAVEL data both live inside the RAVEL checkout.
MODEL_DIR = f'{RAVEL_LIB_DIR}/models'
DATA_DIR = f'{RAVEL_LIB_DIR}/data'

# NNsight Tracer Arguments (`scan` and `validate` tracer options, both disabled)
tracer_kwargs = dict(scan=False, validate=False)

In [None]:
from nnsight import LanguageModel
from transformers import LlamaForCausalLM, AutoTokenizer


# Earlier candidates: 'openai-community/gpt2', 'EleutherAI/pythia-70m-deduped'.
model_id = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
instance = "tinyllama"  # prefix used for the per-model data files in data/<instance>/
layer_idx = 14  # decoder layer whose output is traced and fed to the SAE

model = LanguageModel(
    model_id,
    dispatch=True,
    low_cpu_mem_usage=True,
    device_map='auto',
    cache_dir=MODEL_DIR,
    torch_dtype=torch.bfloat16,
)
submodule = model.model.layers[layer_idx]

tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=MODEL_DIR)
tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as the padding token
tokenizer.padding_side = 'left'

# Token strings ordered by their token id.
VOCAB = sorted(tokenizer.vocab, key=tokenizer.vocab.get)

In [None]:
# Display the wrapped model's module tree (e.g. to locate `model.model.layers`).
model

### Load datasets

In [None]:
import json

# Entity types available in RAVEL.
entity_classes = [
    'city',
    'np_winner',
    'occupation',
    'physical_object',
    'verb',
]

entity_type = 'city'
if entity_type not in entity_classes:
    raise ValueError(f'Unknown entity_type {entity_type!r}; expected one of {entity_classes}')


def _load_base_json(name):
    """Load `<DATA_DIR>/base/ravel_<entity_type>_<name>.json`.

    Reads from DATA_DIR rather than a CWD-relative './data' so these loads agree
    with the later cells, which read the same base files from DATA_DIR.
    """
    path = os.path.join(DATA_DIR, 'base', f'ravel_{entity_type}_{name}.json')
    with open(path) as f:
        return json.load(f)


entity_to_attribute_data = _load_base_json('entity_attributes')
attribute_to_prompts_data = _load_base_json('attribute_to_prompts')
entity_to_split_data = _load_base_json('entity_to_split')  # for entity mode
prompt_to_split_data = _load_base_json('prompt_to_split')  # for context mode

### Check which E <-> A_E pairs are known by the chosen model.
For each entity E × attribute A × attribute prompt P, check whether the model's prediction yields the correct answer.
For the pairs where it does, split them into train and test sets and record the token position of the entity.

For TinyLlama and the city entity type, this has already been done: see the `tinyllama_city_*.json` files (`_train`, `_context_test`, `_entity_test`, `_prompt_to_entity_position`).

In [None]:
# Model-specific RAVEL splits (pairs the model already answers correctly, per the note above).
def _load_instance_json(suffix):
    """Load `<DATA_DIR>/<instance>/<instance>_<entity_type>_<suffix>.json`.

    Uses DATA_DIR (not a CWD-relative './data') so that these are the same files
    the dataset-building cell reads.
    """
    path = os.path.join(DATA_DIR, instance, f'{instance}_{entity_type}_{suffix}.json')
    with open(path) as f:
        return json.load(f)


model_train_data = _load_instance_json('train')
model_context_test_data = _load_instance_json('context_test')
model_entity_test_data = _load_instance_json('entity_test')
model_prompt_to_entity_position_data = _load_instance_json('prompt_to_entity_position')  # This is for prompts in all datasets.

In [None]:
INPUT_MAX_LEN = 48
FEATURE_TYPES = datasets.Features({
    "input": datasets.Value("string"), "label": datasets.Value("string"),
    "source_input": datasets.Value("string"), "source_label": datasets.Value("string"),
    "inv_label": datasets.Value("string"),
    'split': datasets.Value("string"), 'source_split': datasets.Value("string"),
    'entity': datasets.Value("string"), 'source_entity': datasets.Value("string"),
})


def _read_json(path):
    """Parse the JSON file at `path`, closing the file handle afterwards.

    Replaces inline `json.load(open(...))`, which never closed its files.
    """
    with open(path, 'r') as f:
        return json.load(f)


# Load training dataset.
split_to_raw_example = _read_json(os.path.join(DATA_DIR, f'{instance}/{instance}_{entity_type}_train.json'))
# Load validation + test dataset.
split_to_raw_example.update(_read_json(os.path.join(DATA_DIR, f'{instance}/{instance}_{entity_type}_context_test.json')))
split_to_raw_example.update(_read_json(os.path.join(DATA_DIR, f'{instance}/{instance}_{entity_type}_entity_test.json')))

# Prepend an extra token to avoid tokenization changes for Llama tokenizer.
# Each sequence will start with <s> _ 0
# SOS_PAD = '0'
# NUM_SOS_TOKENS = 3
# for split in split_to_raw_example:
#   for i in range(len(split_to_raw_example[split])):
#     split_to_raw_example[split][i]['inv_label'] = SOS_PAD + split_to_raw_example[split][i]['inv_label']
#     split_to_raw_example[split][i]['label'] = SOS_PAD + split_to_raw_example[split][i]['label']


# Load attributes (tasks) to prompt mapping.
ALL_ATTR_TO_PROMPTS = _read_json(os.path.join(DATA_DIR, 'base', f'ravel_{entity_type}_attribute_to_prompts.json'))

# Load prompt to intervention location mapping.
split_to_entity_pos = _read_json(os.path.join(DATA_DIR, instance, f'{instance}_{entity_type}_prompt_to_entity_position.json'))
# Each key of split_to_entity_pos is registered bare and with every split suffix.
# Inputs are padded to INPUT_MAX_LEN, so INPUT_MAX_LEN + pos converts `pos`
# (presumably a negative offset from the end of the prompt — TODO confirm) into
# an absolute index.
SPLIT_TO_INV_LOCATIONS = {
    f'{task}{split}': {'max_input_length': INPUT_MAX_LEN,
                       'inv_position': [INPUT_MAX_LEN + pos]}
    for task, pos in split_to_entity_pos.items()
    for split in ('-train', '-test', '-val', '')
}
# Sanity check: every intervention position must fall inside the padded input.
assert min(min(v['inv_position']) for v in SPLIT_TO_INV_LOCATIONS.values()) > 0


# Preprocess the dataset.
def filter_inv_example(example):
    """Return True if `example` is usable for an interchange intervention.

    The intervention must change the label (`label` != `inv_label`). Both the
    example's `source_split` and its `split` must also have an entry in
    SPLIT_TO_INV_LOCATIONS.
    """
    if example['label'] == example['inv_label']:
        return False
    known_locations = SPLIT_TO_INV_LOCATIONS
    return example['source_split'] in known_locations and example['split'] in known_locations

# Shuffle every split in place, then keep only examples usable for intervention.
for split in split_to_raw_example:
    examples = split_to_raw_example[split]
    random.shuffle(examples)
    kept_examples = [ex for ex in examples if filter_inv_example(ex)]
    split_to_raw_example[split] = kept_examples
    if not kept_examples:
        print('Empty split: "%s"' % split)

# Remove empty splits.
split_to_raw_example = {name: exs for name, exs in split_to_raw_example.items() if exs}

n_train_examples, n_val_examples, n_test_examples = (
    sum(len(exs) for name, exs in split_to_raw_example.items() if name.endswith(suffix))
    for suffix in ('-train', '-val', '-test')
)
print(f"#Training examples={n_train_examples}, "
      f"#Validation examples={n_val_examples}, "
      f"#Test examples={n_test_examples}")

split_to_dataset = {}
for name, exs in split_to_raw_example.items():
    split_to_dataset[name] = Dataset.from_list(exs, features=FEATURE_TYPES)

## Select features from the SAE
For each attribute, we select the SAE features that are most relevant to it.

In [None]:
# Attribute whose SAE features we look for; its training split is named "<attribute>-train".
chosen_attribute = 'Country'
dataset_name = '{}-train'.format(chosen_attribute)

split_to_dataset[dataset_name]

In [None]:
# Initialize probe training data

from tqdm import trange

def pad_tokenized_input(tokenizer, input_str, max_len):
    """Tokenize one prompt to exactly `max_len` ids and return the `input_ids` tensor.

    A literal '0' (SOS_PAD) is prepended to the prompt before tokenizing. The
    result is padded to `max_len` (on the side given by
    `tokenizer.padding_side`) and truncated if longer.
    """
    padded_prompt = '0' + input_str
    encoded = tokenizer(
        padded_prompt,
        return_tensors='pt',
        padding='max_length',
        truncation=True,
        max_length=max_len,
    )
    return encoded['input_ids']

def raw_prompt_generator(dataset, batch_size):
    """Yield `(input_ids, labels)` for each full batch of `dataset`, in order.

    `dataset` is a sliceable sequence of example dicts with 'input' and 'label'
    keys (a plain list in this notebook). Each prompt is tokenized with
    `pad_tokenized_input` to INPUT_MAX_LEN ids, and the results are concatenated
    along dim 0. A trailing partial batch is dropped.

    NOTE(review): only token ids are yielded, no attention mask, so the
    left-padding tokens are presumably attended to during tracing — confirm
    this is intended.
    """
    n_full_batches = len(dataset) // batch_size
    for start in range(0, n_full_batches * batch_size, batch_size):
        examples = dataset[start:start + batch_size]
        input_ids = torch.cat(
            [pad_tokenized_input(tokenizer, ex['input'], INPUT_MAX_LEN) for ex in examples],
            dim=0,
        )
        yield input_ids, [ex['label'] for ex in examples]


batch_size = 256
train_dataset = split_to_raw_example[dataset_name]
# Matches raw_prompt_generator, which drops a trailing partial batch.
n_batches = len(train_dataset) // batch_size
# n_batches = 10

prompt_gen = raw_prompt_generator(
    dataset=train_dataset,
    batch_size=batch_size,
)

unique_labels = {example['label'] for example in train_dataset}
# Enumerate in sorted order: iterating a set of strings follows string-hash order,
# which is randomized per Python session (PYTHONHASHSEED) and is NOT fixed by
# set_seed(). Without sorting, the label -> probe-row mapping changes between runs.
label_to_probeidx = {label: i for i, label in enumerate(sorted(unique_labels))}

In [None]:
# Initialize SAE

from dictionary_learning.dictionary import AutoEncoder

d_model = 2048  # must equal the traced layer's hidden size (2048 for TinyLlama-1.1B)
dict_size = 8192

# Optional path to a trained SAE checkpoint saved as a state_dict via torch.save.
# NOTE(review): with the default None the SAE stays randomly initialized, so the
# probe below would be trained on random features — set this before drawing conclusions.
SAE_WEIGHTS_PATH = None

ae = AutoEncoder(d_model, dict_size)
if SAE_WEIGHTS_PATH is not None:
    ae.load_state_dict(torch.load(SAE_WEIGHTS_PATH, map_location='cpu'))
else:
    print('WARNING: SAE_WEIGHTS_PATH is not set; using a randomly initialized AutoEncoder.')
ae = ae.to(device)

In [None]:
# Initialize probe

class AttributeProbe(torch.nn.Module):
    """Linear probe from an activation vector to one logit per attribute value.

    Args:
        activation_dim: size of the input vectors (the SAE dictionary size in this notebook).
        n_attribute_values: number of distinct attribute values, i.e. output logits.
    """

    def __init__(self, activation_dim, n_attribute_values):
        super().__init__()
        # A single affine map; the notebook ranks the rows of its weight to pick SAE features.
        self.probe = torch.nn.Linear(in_features=activation_dim, out_features=n_attribute_values)

    def forward(self, activations):
        # nn.Linear acts on the last dim: (..., activation_dim) -> (..., n_attribute_values).
        linear_layer = self.probe
        return linear_layer(activations)


# One output logit per distinct attribute value seen in the training split.
n_attribute_values = len(unique_labels)
probe = AttributeProbe(activation_dim=dict_size, n_attribute_values=n_attribute_values).to(device)

lr = 1e-2
optimizer = torch.optim.AdamW(params=probe.parameters(), lr=lr)
criterion = torch.nn.CrossEntropyLoss()

In [None]:
# Train probe

import einops

# Train the probe on SAE latents of the traced layer, one batch at a time.
all_prompts = []
all_true_labels = []
probe_losses = []
for batch_idx in trange(n_batches):

    # Draw new batch
    prompts, true_labels = next(prompt_gen)
    all_prompts.extend(prompts)
    all_true_labels.extend(true_labels)

    # Get new SAE activations. The SAE is not being trained, so encode without
    # autograd; otherwise every backward() also accumulates unused gradients in `ae`.
    with model.trace(prompts, **tracer_kwargs):
        act = submodule.output[0].save()
    with torch.no_grad():
        latent_act = ae.encode(act)

    # Train probe
    optimizer.zero_grad()
    logits = probe(latent_act)  # presumably (batch, seq_pos, n_attribute_values) — TODO confirm
    label_idx = torch.tensor(
        [label_to_probeidx[label] for label in true_labels], device=logits.device
    )
    # Every token position of a prompt is trained to predict that prompt's label.
    # CrossEntropyLoss takes the class dimension second ((N, C) or (N, C, d1, ...)):
    # passing (batch, pos, classes) directly would normalize over positions instead
    # of attribute values. Flatten positions into the batch dimension instead.
    n_positions = logits.shape[1]
    targets = label_idx.unsqueeze(1).expand(-1, n_positions)
    loss = criterion(logits.reshape(-1, n_attribute_values), targets.reshape(-1))
    probe_losses.append(loss.item())
    loss.backward()
    optimizer.step()

In [None]:
# For each attribute value (probe output row), take the 20 SAE latents with the
# largest (most positive) probe weights, then pool them into one set of feature ids.
TOP_K_FEATURES = 20
top_indices = torch.topk(probe.probe.weight, k=TOP_K_FEATURES, dim=1).indices
attribute_feature_indices = set(top_indices.reshape(-1).tolist())

In [None]:
# NOTE(review): stale, commented-out draft of an earlier prompt generator. It reads
# `X_text`, which is not defined anywhere in this notebook, so it would fail if
# uncommented. The training loop uses `raw_prompt_generator` instead.
# import itertools

# chosen_attribute = "Continent"
# chosen_split = "train"

# # Write a generator that generates text, attribute pairs in batches
# # TODO implement shuffling

# def prompt_generator(batch_size=128):
#     prompt_templates = X_text['context'][chosen_split][chosen_attribute]
#     entity_prompt_pairs = itertools.product(entity_to_attribute_data.keys(), prompt_templates)

#     n_entities = len(entity_to_attribute_data.keys())
#     n_pairs = n_entities * len(prompt_templates)
#     if batch_size is None:
#         batch_size = n_pairs
#     total_batches = n_pairs // batch_size

#     for batch_idx in range(total_batches):
#         pairs_batch = [next(entity_prompt_pairs) for _ in range(batch_size)]
#         prompt_batch = [prompt % entity for entity, prompt in pairs_batch]
#         true_label_batch = [entity_to_attribute_data[entity][chosen_attribute] for entity, _ in pairs_batch]
#         yield prompt_batch, true_label_batch

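# Cause intervention

Find the feature subspace for each attribute A.

- Clean (base) input: entity E, with attribute value A_E.
- Patch (intervention source): entity E' — take the A-related features at the final token of E'.
- Metric: after patching these features into the clean run, the accuracy of predicting A_E'.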

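# Isolation

Find the feature subspace for each attribute A.

Then, for every other attribute B (with values B_E):

- Clean (base) input: entity E, with attribute value B_E.
- Patch (intervention source): entity E' — take the A-related features at the final token of E'.
- Metric: after patching these features into the clean run, the accuracy of still predicting B_E.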