In [1]:
%load_ext autoreload
%autoreload 2

import os

REPO_DIR = f'sae_lens/sae_bench/ravel'
SRC_DIR = os.path.join(REPO_DIR, 'src')
MODEL_DIR = os.path.join(REPO_DIR, 'models')
DATA_DIR = os.path.join(REPO_DIR, 'data')

for d in [MODEL_DIR, DATA_DIR]:
    if not os.path.exists(d):
        os.makedirs(d)


import sys
sys.path.append(REPO_DIR)
sys.path.append(SRC_DIR)

import numpy as np
import random
import torch
import accelerate

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(0)

device = "cpu"
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"

In [2]:
#from sae_lens.sae_bench.utils.generation_utils import generate_batched

# Model

In [3]:
# Load model
from transformers import AutoModelForCausalLM, AutoTokenizer

with open('sae_lens/auth/hf.txt', 'r') as f:
    hf_token = f.read().strip()

model_id = "google/gemma-2-2b"
model_name = "gemma-2-2b"

torch.set_grad_enabled(False) # avoid blowing up mem
hf_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    cache_dir=MODEL_DIR,
    token=hf_token,
    device_map=device,
    low_cpu_mem_usage=True,
    attn_implementation="eager"
)

tokenizer =  AutoTokenizer.from_pretrained(
    model_id,
    cache_dir=MODEL_DIR,
    token=hf_token,
)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'
VOCAB = sorted(tokenizer.vocab, key=tokenizer.vocab.get)

layer_idx = 10

from nnsight import NNsight
nnsight_model = NNsight(hf_model)
nnsight_tracer_kwargs = {'scan': True, 'validate': False, 'use_cache': False, 'output_attentions': False}

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
from sae_lens.sae_bench.ravel.ravel_dataset_builder import RAVELEntityPromptData

full_entity_dataset = RAVELEntityPromptData.from_files('city', 'sae_lens/sae_bench/ravel/data', tokenizer)
len(full_entity_dataset)

100%|██████████| 3552/3552 [00:44<00:00, 79.31it/s]


969696

In [5]:
sampled_entity_dataset = full_entity_dataset.downsample(1000)
print(f"Number of prompts remaining: {len(sampled_entity_dataset)}")

prompt_max_length = 48
sampled_entity_dataset.generate_completions(nnsight_model, tokenizer, max_length=prompt_max_length+8, prompt_max_length=prompt_max_length)

sampled_entity_dataset.evaluate_correctness()

# Filter correct completions
correct_data = sampled_entity_dataset.filter_correct()

# Filter top entities and templates
filtered_data = correct_data.filter_top_entities_and_templates(top_n_entities=400, top_n_templates_per_attribute=12)

# Calculate average accuracy
accuracy = sampled_entity_dataset.calculate_average_accuracy()
print(f"Average accuracy: {accuracy:.2%}")
print(f"Number of prompts remaining: {len(correct_data)}")

Number of prompts remaining: 1000
Total #prompts=1000


100%|██████████| 32/32 [00:37<00:00,  1.18s/it]

Average accuracy: 49.40%
Number of prompts remaining: 494





In [6]:
correct_data.add_wikipedia_prompts('city', 'sae_lens/sae_bench/ravel/data', tokenizer, nnsight_model)

Total #prompts=914


100%|██████████| 15/15 [00:31<00:00,  2.10s/it]

Added 938 Wikipedia prompt templates





In [7]:
len(correct_data)

1408

In [14]:


# for prompt in correct_data.prompts.keys():
#     print(sampled_entity_dataset.prompts[prompt])
#     break

# version of the above that returns a random prompt each time
def get_random_prompt(data):
    prompt = random.choice(list(data.prompts.keys()))
    return data.prompts[prompt]

print(get_random_prompt(correct_data))

Prompt(text='[{"city": "San Francisco", "continent": "North America"}, {"city": "Nevers", "continent": "', template='[{"city": "San Francisco", "continent": "North America"}, {"city": "%s", "continent": "', attribute='Continent', entity='Nevers', context_split='train', entity_split='val', input_ids=tensor([     1,      1,      1,      1,      1,      2, 235309,   9766,   8918,
          1192,    664,  10105,  12288,    824,    664,  88770,   1192,    664,
         17870,   5783,  16406,  19946,   8918,   1192,    664,   4157,   1267,
           824,    664,  88770,   1192,    664]), attention_mask=tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1]), completion='Europe"}, {"city": "New York', is_correct=True)


# Create a RAVEL Instance for TinyLLaMA

### Check model knowledge of all entity - attribute pairs

In [17]:
# Generate a dataset of combinations of entities and attribute_specific_prompts

import json
import os

entity_type = 'city'

attribute_prompts = json.load(open(os.path.join(DATA_DIR, 'base', f'ravel_{entity_type}_attribute_to_prompts.json')))
prompt_splits = json.load(open(os.path.join(DATA_DIR, 'base', f'ravel_{entity_type}_prompt_to_split.json')))
entity_attributes = json.load(open(os.path.join(DATA_DIR, 'base', f'ravel_{entity_type}_entity_attributes.json')))
print(f'#entities={len(entity_attributes)}, #prompt_templates={sum(map(len, attribute_prompts.values()))}')

# # For testing purposes
# partial_entitiy_attributes = {}
# for i, (e, a) in enumerate(entity_attributes.items()):
#     partial_entitiy_attributes[e] = a
#     if i == 10:
#         break
# entity_attributes = partial_entitiy_attributes

prompts_to_meta_data = {t % x: {'entity': x, 'attr': a, 'template': t}
               for x in entity_attributes
               for a, ts in attribute_prompts.items()
               for t in ts}
print(f'total number of prompts {len(prompts_to_meta_data)}')

#entities=3552, #prompt_templates=273
total number of prompts 969696


In [18]:
type(prompts_to_meta_data)

dict

In [19]:
import random

def subsample_dict(original_dict, num_samples):
    keys = random.sample(list(original_dict.keys()), num_samples)
    print(keys)
    return {k: original_dict[k] for k in keys}

In [20]:
from typing import Dict, List

def filter_attribute_prompts(
        attribute_prompts: Dict[str, List[str]], 
        prompts_to_meta_data_subsample: Dict[str, Dict[str, str]]
    ) -> Dict[str, List[str]]:
    """ Filter out attribute prompts that are not in the subsample

    Args:
        attribute_prompts (dict): attribute to list of prompts
        prompts_to_meta_data_subsample (dict): prompt to metadata

    Returns:
        dict: filtered attribute prompts
    """
    filtered_attribute_prompts = {}
    for attr in list(attribute_prompts.keys()):
        filtered_attribute_prompts[attr] = [p for p in attribute_prompts[attr] if p in set([prompts_to_meta_data_subsample[k]['template'] for k in prompts_to_meta_data_subsample.keys()])]
        # if attr not in set([v['attr'] for v in prompts_to_meta_data_subsample.values()]):
        #     del attribute_prompts[attr]

    return filtered_attribute_prompts



In [21]:
# Generate outputs for all prompts (the correct attribute is expected in the output)

skip_the_inference_step = False
prompts_to_meta_data_subsample = subsample_dict(prompts_to_meta_data, 100)
attribute_prompts_subsample = filter_attribute_prompts(attribute_prompts, prompts_to_meta_data_subsample)
prompt_to_output_path = os.path.join(DATA_DIR, model_name, f'ravel_{model_name}_{entity_type}_prompt_to_output.json')

if skip_the_inference_step:
    # Can skip the inference step by downloading the pre-computed outputs:
    # https://drive.google.com/drive/u/0/folders/1U4Js-NarJa-B_iQc5wr0OXV2G-5BDBsN
    pass
else: 
    

    prompt_max_length = 48

    prompt_to_output = generate_batched(
        hf_model,
        tokenizer,
        list(prompts_to_meta_data_subsample),
        prompt_max_length+8,
        prompt_max_length=prompt_max_length,
        batch_size=256)
    prompt_to_output = {k: v[len(k):] for k, v in prompt_to_output}

    json.dump(prompt_to_output, open(prompt_to_output_path, 'w'), ensure_ascii=False)

['[{"city": "Cape Town", "lat": "33.9"}, {"city": "Turbat", "lat": "', '[{"city": "New York City", "lang": "English"}, {"city": "Kanoya", "lang": "', '[{"city": "Bangkok", "continent": "Asia"}, {"city": "Sikasso", "continent": "', '[{"city": "Cape Town", "lat": "33.9"}, {"city": "Wuchuan", "lat": "', '[{"city": "Kuala Lumpur", "lat": "3.1"}, {"city": "Kogon", "lat": "', 'The longitude of Apolo is ', ' "continent": "Asia"}, {"city": "Ferfer", "long": "', 'New York City: America/New_York. Maracaju:', '[{"city": "New York City", "lat": "41"}, {"city": "Luxor", "lat": "', ' "continent": "Asia"}, {"city": "Kerma", "lat": "', '[{"city": "Zhob", "timezone": "', '[{"city": "Kuala Lumpur", "long": "101.7"}, {"city": "Suceava", "long": "', '[{"city": "New Delhi", "lat": "29"}, {"city": "Totness", "lat": "', ' she is living in Gulu, therefore her country of residence is', ' in Lota, people usually speak', 'Tokyo is a city in the continent of Asia. Ituni is a city in the continent of', ' "continen

NameError: name 'generate_batched' is not defined

In [85]:
prompt_to_output = json.load(open(prompt_to_output_path))
len(prompt_to_output)

100

In [86]:
for key, value in prompts_to_meta_data_subsample.items():
    
    print(f'Prompt: {key}')
    print(f'Entity: {value["entity"]}')
    print(f'Attribute: {value["attr"]}')
    print(f'Template: {value["template"]}')
    print()

    break

Prompt: Hong Kong: Asia/Hong_Kong. Resistencia:
Entity: Resistencia
Attribute: Timezone
Template: Hong Kong: Asia/Hong_Kong. %s:



In [93]:
for key, value in prompt_to_output.items():
    print(f'Prompt: {key}')
    print(f'Output: {value}')
    print()

Prompt: Hong Kong: Asia/Hong_Kong. Resistencia:
Output:  1989. 1

Prompt:  in Volos, people usually speak
Output:  Greek.

<h2>What is the best

Prompt: [{"city": "Mexico City", "lat": "19.4"}, {"city": "Redding", "lat": "
Output: 40.0"}, {"city":

Prompt:  city to continent: New York City is in North America. Kanpur is in
Output:  India.

Q: What is the

Prompt: city to country: San Francisco is in United States. Astana is in
Output:  Kazakhstan.

city to city: New

Prompt: [{"city": "Kuala Lumpur", "continent": "Asia"}, {"city": "Varna", "continent": "
Output: Europe"}, {"city": "Kuala

Prompt: St. Petersburg is a city in the continent of Europe. Zanesville is a city in the continent of
Output:  North America.

St. Petersburg is

Prompt: Time zone in Mexico City is America/Mexico_City; Time zone in Bol is
Output:  America/Mexico_City; Time zone

Prompt: San Francisco is a city in the continent of North America. Wuyuan is a city in the continent of
Output:  Asia.

The city of San Fran

In [101]:
#@title Behavioral Test

## Check whether model output matches the correct attribute
# Keep top 400 entities with highest sum of known attributes across all prompts
# Keep top 12 templates per attribute

import collections
import re
import numpy as np


from zoneinfo import ZoneInfo
import datetime

def timezone_name_to_utc_offset(name):
  offset =  ZoneInfo(name).utcoffset(datetime.datetime.now()).seconds
  sign = '+'
  if offset // 3600 >= 12:
    offset = 24 * 3600 - offset
    sign = '-'
  fmt_offset = str(datetime.timedelta(seconds=offset)).rsplit(':', 1)[0]
  if fmt_offset.startswith('0') and offset >= 1800:
    fmt_offset = fmt_offset[1:]
  return f'{sign}{fmt_offset}'


# TODO: Pull out attribute checking to functions

sorted_entity = sorted(set([v['entity'] for v in prompts_to_meta_data_subsample.values()]))
sorted_template = sorted(set([v['template'] for v in prompts_to_meta_data_subsample.values()]))
stats = np.zeros([len(sorted_entity), len(sorted_template)])
for p, out in prompt_to_output.items():
  attr = prompts_to_meta_data_subsample[p]['attr']
  entity = prompts_to_meta_data_subsample[p]['entity']
  label = entity_attributes[entity][attr]

  print("\n")
  print(f'Entity: {entity}')
  print(f'Attribute: {attr}')
  print(f'Prompt: {p}')
  if not label:
    print(f'No label for {entity} {attr}')
    continue
  norm_label = label.lower()
  norm_out = out.split('"')[0].strip(' "').replace('\\/', '/').lower()

  if len(norm_label) < len(norm_out):
    correct = int(norm_out.startswith(norm_label))
  else:
    correct = int(norm_label.startswith(norm_out))



  # Exceptions
  if re.search('coord|"lat"|"long"|latitude|coordinates|longitude', p):
    try:
      correct = int((float(norm_label.strip('-−')) - float(re.findall(r'\d+', norm_out)[0])) <= 2)
    except:
      correct = 0
  if re.search('United States|United Kingdom', label):
    norm_label = label.strip().replace('the ', '')
    norm_out = out[len(p):].strip().replace('the ', '')
    correct = int(norm_out.startswith(norm_label) or norm_out.startswith('England'))
  if re.search('South Korea', label):
    correct = int(norm_out.startswith('korea') or norm_out.startswith('south korea'))
  if re.search('North America', label):
    correct = norm_label in norm_out or norm_out == 'na' or norm_out.startswith('america')
  if re.search('Mandarin', label):
    correct = norm_out in norm_label or norm_out == 'chinese'
  if re.search('language', p) and ',' in norm_label:
    correct = any(lang in norm_out for lang in norm_label.split(','))
  if re.search('UTC', p) and '/' in norm_label:
    norm_label = timezone_name_to_utc_offset(label)
    correct = norm_out.startswith(norm_label.split(':')[0])
    if not correct and re.search(r'[+\-]0\d', norm_out):
      correct = norm_out.replace('0', '', 1).startswith(norm_label.split(':')[0])
    # Summer daylight saving time
    if not correct and (
        re.search(r'\-[5-8]', norm_label) and label.startswith('America') or
        re.search(r'\+[0-3]', norm_label) and label.startswith('Europe') or
        re.search(r'\+[0-3]', norm_label) and label.startswith('Africa')):
      #print('SUMMER TIME:', norm_label, norm_out)
      out_offset_match = re.search(r'[+\-]?(\d\d?):\d+', norm_out)
      label_offset_match = re.search(r'[+\-]?(\d\d?):\d+', norm_label)
      if out_offset_match and label_offset_match:
        norm_out_offset = int(out_offset_match.group(1))
        norm_label_offset = int(label_offset_match.group(1))
        correct = (norm_out_offset <= norm_label_offset + 1 and
                    norm_out_offset >= norm_label_offset - 1)
    if not correct and re.search(r'[+\-](\d+)', norm_out) and int(
        re.search(r'[+\-](\d+)', norm_out).group(1)) > 11:
      offset = 24 - int(re.search(r'[+\-](\d+)', norm_out).group(1))
      correct = str(offset) in norm_label

  print(f"Label: {norm_label} Output: {norm_out}")
  print(f"Correct: {correct}")
  stats[sorted_entity.index(prompts_to_meta_data_subsample[p]['entity']), sorted_template.index(prompts_to_meta_data_subsample[p]['template'])] += int(correct)
  print(f"Updated stats: {stats[sorted_entity.index(prompts_to_meta_data_subsample[p]['entity']), sorted_template.index(prompts_to_meta_data_subsample[p]['template'])]}")
# print('-----------------------------------')
# for i in np.argsort(stats.sum(axis=0))[::-1]:
#   print(sorted_template[i], int(stats[:, i].sum()), len(stats[:, i]))
# for i in np.argsort(stats.sum(axis=-1))[::-1]:
#   print(sorted_entity[i], int(stats[i].sum()), len(stats[i]))


kept_entity_index = np.argsort(stats.sum(axis=1))
KEPT_ENTITY = [sorted_entity[i] for i in kept_entity_index]
topk_template_index = set(np.argsort(stats.sum(axis=0)))
kept_template_index = []
# A dict of all kept attribute to prompts.
KEPT_ATTR_TO_PROMPT_AND_SPLIT = {}
for attr in attribute_prompts_subsample:
  # Kept the top 4 to 12 templates per attribute.
  print(f'Attribute: {attr}')
  attr_indices = [sorted_template.index(t) for t in attribute_prompts_subsample[attr]]
  per_attr_kept_template_index = sorted(attr_indices, key=lambda i: stats[:, i].sum())[-12:][::-1]
  per_attr_kept_template_index = [x for i, x in enumerate(per_attr_kept_template_index)
                                  if x in topk_template_index or i < 4]
  kept_template_index.extend(per_attr_kept_template_index)
  KEPT_ATTR_TO_PROMPT_AND_SPLIT[attr] = {sorted_template[i]: prompt_splits[sorted_template[i]]
                               for i in per_attr_kept_template_index}
print('Kept %d entity, %d prompt template' % (len(kept_entity_index), len(kept_template_index)))

print('Average accuracy: %.2f%%' % (100 *  (stats[:, kept_template_index][kept_entity_index, :]).sum()/ (len(kept_entity_index) * len(kept_template_index))))



Entity: Resistencia
Attribute: Timezone
Prompt: Hong Kong: Asia/Hong_Kong. Resistencia:
Label: america/argentina/cordoba Output: 1989. 1
Correct: 0
Updated stats: 0.0


Entity: Volos
Attribute: Language
Prompt:  in Volos, people usually speak
Label: greek Output: greek.

<h2>what is the best
Correct: 1
Updated stats: 1.0


Entity: Redding
Attribute: Latitude
Prompt: [{"city": "Mexico City", "lat": "19.4"}, {"city": "Redding", "lat": "
Label: 41 Output: 40.0
Correct: 1
Updated stats: 1.0


Entity: Kanpur
Attribute: Continent
Prompt:  city to continent: New York City is in North America. Kanpur is in
Label: asia Output: india.

q: what is the
Correct: 0
Updated stats: 0.0


Entity: Astana
Attribute: Country
Prompt: city to country: San Francisco is in United States. Astana is in
Label: kazakhstan Output: kazakhstan.

city to city: new
Correct: 1
Updated stats: 1.0


Entity: Varna
Attribute: Continent
Prompt: [{"city": "Kuala Lumpur", "continent": "Asia"}, {"city": "Varna", "continent":

In [104]:
stats.sum()

68.0

In [105]:
# Print top 12 templates for each attribute

attribute_prompts
for i in kept_template_index:
  print(f'{prompt_splits[sorted_template[i]]}\t{sorted_template[i]}\t{stats[:, i][kept_entity_index].mean():.2f}')

val	city to country: New Delhi is in India. %s is in	0.02
test	city to country: San Francisco is in United States. %s is in	0.01
val	city to country: New York City is in United States. %s is in	0.01
val	city to country: Beijing is in China. %s is in	0.01
train	[{"city": "Tokyo", "country": "Japan"}, {"city": "%s", "country": "	0.01
test	[{"city": "Sydney", "country": "Australia"}, {"city": "%s", "country": "	0.01
test	%s is a city in the country of	0.01
train	city to country: Cape Town is in South Africa. %s is in	0.00
train	[{"city": "Toronto", "country": "Canada"}, {"city": "%s", "country": "	0.00
train	[{"city": "Buenos Aires", "country": "Argentina"}, {"city": "%s", "country": "	0.00
train	St. Petersburg is a city in the continent of Europe. %s is a city in the continent of	0.03
val	[{"city": "Kuala Lumpur", "continent": "Asia"}, {"city": "%s", "continent": "	0.02
train	[{"city": "Toronto", "continent": "North America"}, {"city": "%s", "continent": "	0.01
train	[{"city": "San Franc

In [106]:
# train/val/test split of attribute_specific_prompt_templates has been predefined. Check whether the split is roughly balanced.

print(sum(map(len, KEPT_ATTR_TO_PROMPT_AND_SPLIT.values())))
for attr, prompt_to_split in KEPT_ATTR_TO_PROMPT_AND_SPLIT.items():
  print(attr, collections.Counter(prompt_to_split.values()))

62
Country Counter({'train': 4, 'val': 3, 'test': 3})
Continent Counter({'train': 8, 'val': 2, 'test': 2})
Latitude Counter({'train': 8, 'test': 2, 'val': 2})
Longitude Counter({'train': 3, 'val': 1})
Language Counter({'train': 6, 'val': 4, 'test': 2})
Timezone Counter({'train': 6, 'val': 4, 'test': 2})


### Create an Instance

In [14]:
# train/val/test split of entities has been predefined. Check whether the split is roughly balanced.
import json

ENTITY_TYPE = 'city'
print(ENTITY_TYPE)
ALL_ENTITY_SPLITS = json.load(open(os.path.join(DATA_DIR, 'base', f'ravel_{ENTITY_TYPE}_entity_to_split.json')))
ALL_ATTR_TO_PROMPTS = json.load(open(os.path.join(DATA_DIR, 'base', f'ravel_{ENTITY_TYPE}_attribute_to_prompts.json')))
WIKI_PROMPT_SPLITS = json.load(open(os.path.join(DATA_DIR, 'base', f'wikipedia_{ENTITY_TYPE}_entity_prompts.json')))

# Filtered
KEPT_ENTITY_SPLITS = {e: ALL_ENTITY_SPLITS[e] for e in KEPT_ENTITY} # kept entities to split
KEPT_PROMPT_SPLITS = {k: (a, v) for a, d in KEPT_ATTR_TO_PROMPT_AND_SPLIT.items() for k, v in d.items() if k.count('%') == 1}
print(f'len kept prompt splits={len(KEPT_PROMPT_SPLITS)}')
print(f'len added wiki inv prompts={len({k: v for k, v in WIKI_PROMPT_SPLITS.items() if k.count("%") == 1})}')
for prompt in WIKI_PROMPT_SPLITS:
  KEPT_PROMPT_SPLITS[prompt] = ('Other', WIKI_PROMPT_SPLITS[prompt]['split']) # add wiki prompt splits as "Other" attribute
KEPT_ATTR_TO_PROMPT_AND_SPLIT = {k: {p: v for p, v in d.items() if p.count('%') == 1} for k, d in KEPT_ATTR_TO_PROMPT_AND_SPLIT.items()}
print(f'Total #entities={len(ALL_ENTITY_SPLITS)} #attributes={len(KEPT_ATTR_TO_PROMPT_AND_SPLIT)} #prompts={sum(map(len, ALL_ATTR_TO_PROMPTS.values()))} #wiki_prompts={len(WIKI_PROMPT_SPLITS)}')
print(f'Kept #entities={len(KEPT_ENTITY_SPLITS)} #prompts={len(KEPT_PROMPT_SPLITS)}')
for split in ('train', 'val', 'test'):
  print(split, f'Kept #entities={len([k for k, v in KEPT_ENTITY_SPLITS.items() if v == split])}',
               f'#prompts={len([k for k, v in KEPT_PROMPT_SPLITS.items() if v[1] == split])}')

city


NameError: name 'KEPT_ENTITY' is not defined

In [109]:
# Generate outputs for all wiki inv prompts
from sae_lens.sae_bench.utils.generation_utils import generate_batched

skip_wiki_to_prompt = False
filename_prompt_to_output = f'ravel_{model_name}_{ENTITY_TYPE}_wiki_prompt_to_output.json'

if not skip_wiki_to_prompt:
    wiki_prompts = [(t % e) for t, s_e in WIKI_PROMPT_SPLITS.items()
                    for e in ([s_e['entity']] if s_e['entity']
                            else [a for a in KEPT_ENTITY_SPLITS if KEPT_ENTITY_SPLITS[a] == 'train' or s_e['split'] == 'train'])
                    ]
    print(len(wiki_prompts))

    wiki_prompt_and_output = generate_batched(
        hf_model,
        tokenizer,
        wiki_prompts,
        max_new_tokens=8,
        batch_size=64)
    wiki_prompt_to_output = {k: v[len(k):] for k, v in wiki_prompt_and_output}
    json.dump(wiki_prompt_to_output, open(os.path.join(DATA_DIR, model_name, filename_prompt_to_output), "w"), ensure_ascii=False)
else:
    wiki_prompt_to_output = json.load(open(os.path.join(DATA_DIR, model_name, filename_prompt_to_output)))

2810
Total #prompts=2810
Set prompt_max_length=64


100%|██████████| 44/44 [11:13<00:00, 15.31s/it]


In [110]:
# Set of prompts with all possible outputs

ALL_PROMPT_TO_OUTPUT = {**prompt_to_output, **wiki_prompt_to_output}

print(len(ALL_PROMPT_TO_OUTPUT))

2910


In [111]:

import datasets
from datasets import Dataset
# from utils.intervention_utils import extract_label
from sae_lens.sae_bench.utils.generate_ravel_instance import RAVELMetadata

# def extract_label(string):
#     delimiters = r"[ \-\t,.\n]"
#     return re.split(delimiters, string.strip())[0]

import re

def extract_label(text):
    tokens = re.split(r'(["]|[.,;]\s|\n| \(|\sand)', text + ' ')
    x = tokens[0]
    digit_match = re.search(r'\.\d\d', x)
    if digit_match:
        x = x[:digit_match.span(0)[1]]
    gender_match = re.match(r'\s?(his|her|himself|herself|she|he)[^\w]', x)
    if gender_match:
        x = x[:gender_match.span(1)[1]]
    if not x.strip():
        x = ' '.join(text.split(' ')[:2]).rstrip('.,"\n')
    assert x.strip()
    return x


def get_first_token(x):
  return re.split(r'[^\w\+\-]', x.strip(), re.UNICODE)[0]


def filter_inv_example(base_output, inv_output):
  different_outputs = (get_first_token(base_output) !=
                       get_first_token(inv_output))
  valid_outputs = (
      re.fullmatch(r'\s?[a-z0-9.:\-+]+', extract_label(base_output), re.IGNORECASE) and
      re.fullmatch(r'\s?[a-z0-9.:\-+]+', extract_label(inv_output), re.IGNORECASE))
  return valid_outputs and different_outputs


FEATURE_TYPES = datasets.Features({"input": datasets.Value("string"), "label": datasets.Value("string"),
                              "source_input": datasets.Value("string"), "source_label": datasets.Value("string"),
                              "inv_label": datasets.Value("string"),
                              'split': datasets.Value("string"), 'source_split': datasets.Value("string"),
                              'entity': datasets.Value("string"), 'source_entity': datasets.Value("string")})


ravel_metadata = RAVELMetadata(
    model_name,
    KEPT_ENTITY_SPLITS,
    KEPT_ATTR_TO_PROMPT_AND_SPLIT,
    KEPT_PROMPT_SPLITS,
    WIKI_PROMPT_SPLITS,
    ALL_PROMPT_TO_OUTPUT)

In [120]:
len(KEPT_ENTITY_SPLITS)

100

In [112]:
extract_label(' Thailand.\nThe distance from P')

' Thailand'

In [114]:
#@title Generate the Conetxt TEST/VAL Split

# Context Split: All entities are in TRAIN, but all prompts are in test/dev

import random

from sae_lens.sae_bench.utils.generate_ravel_instance import gen_context_test_split

TEST_TYPE = 'context'

# Take the first N examples only
first_n = 256

eval_split_to_raw_example = gen_context_test_split(
    ravel_metadata,
    extract_label_fn=extract_label,
    filter_example_fn=filter_inv_example,
    first_n=first_n)
eval_split_to_dataset = {
    split: Dataset.from_list(eval_split_to_raw_example[split][:first_n], features=FEATURE_TYPES)
    for split in eval_split_to_raw_example}

# Compute stats.
for split in eval_split_to_raw_example:
  print('\nSplit %s:\nTotal %d examples, kept first %d examples, %d unique input values,  %d unique entities, %d unique output values' % (
      repr(split), len(eval_split_to_raw_example[split]), len(eval_split_to_dataset[split]),
      len(set([exp[x] for exp in eval_split_to_raw_example[split][:first_n] for x in ['input', 'source_input']])),
      len(set([exp[x] for exp in eval_split_to_raw_example[split][:first_n] for x in ['entity', 'source_entity']])),
      len(set([exp['inv_label'] for exp in eval_split_to_raw_example[split][:first_n]])),
      # len(set([tokenizer('0' + exp['inv_label']).input_ids[3] for exp in eval_split_to_raw_example[split][:first_n]]))
      ))
  #for i, example in enumerate(eval_split_to_raw_example[split]):
  #  print(example)
  #  #print(tokenizer(example['input']).input_ids)
  #  break
  #for k in ('input', 'source_input'):
  #  input_ids = tokenizer(example[k])['input_ids']
  #  #print(k)
  #  #print(input_ids)
  #  print(list(zip([(32 - len(input_ids)) + i for i in range(len(input_ids))], tokenizer.batch_decode(input_ids))))
for split in ('test', 'val'):
  print(f'Split {split}: Total #subsplit={len([k for k in eval_split_to_raw_example if k.endswith(split)])} #Examples={sum(map(len, [v for k, v in eval_split_to_raw_example.items() if k.endswith(split)]))}')

KeyError: 'city to country: New Delhi is in India. Chota is in'

In [None]:
# Merge subsplits
eval_split_to_raw_example_merged = collections.defaultdict(list)
for split in eval_split_to_raw_example:
  eval_split_to_raw_example_merged[re.sub(r'-causal|-output|-other', '', split)].extend(eval_split_to_raw_example[split])
eval_split_to_raw_example = dict(eval_split_to_raw_example_merged)

In [None]:
output_json_path = os.path.join(DATA_DIR, f'{ravel_metadata.instance}/{ravel_metadata.instance}_{ENTITY_TYPE}_{TEST_TYPE}_test.json')
print(output_json_path)
json.dump(eval_split_to_raw_example, open(output_json_path, 'w'), ensure_ascii=False)

In [None]:
#@title Generate the Entity TEST/VAL Split

from utils.generate_ravel_instance import gen_entity_test_split

TEST_TYPE = 'entity'

# Take the first N examples only
first_n = 128

eval_split_to_raw_example = gen_entity_test_split(
    ravel_metadata,
    extract_label_fn=extract_label, filter_example_fn=filter_inv_example,
    first_n=first_n)

eval_split_to_dataset = {
    split: Dataset.from_list(eval_split_to_raw_example[split][:first_n], features=FEATURE_TYPES)
    for split in eval_split_to_raw_example}

# Stats
for split in eval_split_to_raw_example:
  print('Split %s: Total %d examples, kept first %d examples, %d unique input values,  %d unique entities, %d unique output values' % (
      repr(split), len(eval_split_to_raw_example[split]), len(eval_split_to_dataset[split]),
      len(set([exp[x] for exp in eval_split_to_raw_example[split][:first_n] for x in ['input', 'source_input']])),
      len(set([exp[x] for exp in eval_split_to_raw_example[split][:first_n] for x in ['entity', 'source_entity']])),
      len(set([exp['inv_label'] for exp in eval_split_to_raw_example[split][:first_n]])),
      # len(set([tokenizer('0' + exp['inv_label']).input_ids[3] for exp in eval_split_to_raw_example[split][:first_n]]))
      ))
  for i, example in enumerate(eval_split_to_raw_example[split]):
    print(example)
    #print(tokenizer(example['input']).input_ids)
    break
  for k in ('input', 'source_input'):
    input_ids = tokenizer(example[k])['input_ids']
    #print(k)
    #print(input_ids)
    print(list(zip([(32 - len(input_ids)) + i for i in range(len(input_ids))], tokenizer.batch_decode(input_ids))))
for split in ('test', 'val'):
  print(f'Split {split}: Total #subsplit={len([k for k in eval_split_to_raw_example if k.endswith(split)])} #Examples={sum(map(len, [v for k, v in eval_split_to_raw_example.items() if k.endswith(split)]))}')

In [None]:
# Merge subsplits
eval_split_to_raw_example_merged = collections.defaultdict(list)
for split in eval_split_to_raw_example:
  eval_split_to_raw_example_merged[re.sub(r'-causal|-output|-other', '', split)].extend(eval_split_to_raw_example[split])
eval_split_to_raw_example = dict(eval_split_to_raw_example_merged)

In [None]:
output_json_path = os.path.join(DATA_DIR, f'{ravel_metadata.instance}/{ravel_metadata.instance}_{ENTITY_TYPE}_{TEST_TYPE}_test.json')
print(output_json_path)
json.dump(eval_split_to_raw_example, open(output_json_path, 'w'), ensure_ascii=False)

In [None]:
#@title Generate train split (for models that use counterfactuals)

import datasets
from datasets import Dataset

def gen_train_split(metadata, extract_label_fn, filter_example_fn, first_n=256):
  split_to_raw_example = {}
  # Group by attributes.
  target_split = 'train'
  for attr, prompt_to_split in metadata.attr_to_prompt.items():
      base_prompt_candiates = [p for p, s in prompt_to_split.items() if s == target_split]
      base_task_inputs = [
          ((prompt, entity), metadata.prompt_to_output[prompt % entity])
          for entity in metadata.get_entities(target_split)
          for prompt in random.sample(
              base_prompt_candiates, k=min(2, len(base_prompt_candiates)))]
      source_task_inputs = [
          ((source_prompt, entity), metadata.prompt_to_output[source_prompt % entity])
          for source_prompt, (source_attr, source_split) in KEPT_PROMPT_SPLITS.items()
          if source_split == target_split and source_attr != 'Other'
          for entity in metadata.sample_entities(target_split, k=1)
      ]
      wiki_source_task_inputs = [
          ((source_prompt, entity), metadata.prompt_to_output[source_prompt % entity])
          for source_prompt, split_and_arg in metadata.entity_prompt_to_split.items()
          if split_and_arg['split'] == target_split
          for entity in ([split_and_arg['entity']] if split_and_arg['entity']
                         else metadata.sample_entities(target_split, k=1))
      ]
      source_task_inputs = source_task_inputs + wiki_source_task_inputs
      if len(base_task_inputs) < 5 or len(source_task_inputs) < 5:
        continue
      print(attr, target_split, len(base_task_inputs), len(source_task_inputs), len(wiki_source_task_inputs))
      split_to_raw_example[f'{attr}-{target_split}'] = []
      for (p, a), v in base_task_inputs:
        source_input_candiates = [x for x in source_task_inputs if filter_example_fn(v, metadata.prompt_to_output[p % x[0][1]])]
        #print(len(source_input_candiates), v)
        split_to_raw_example[f'{attr}-{target_split}'].extend([{
          'input': p % a, 'label': extract_label_fn(v),
          'source_input': s_p % s_a, 'source_label': extract_label_fn(source_v),
          'inv_label': extract_label_fn(metadata.prompt_to_output[p % s_a]),
          'split': p, 'source_split': s_p,
          'entity': a, 'source_entity': s_a}
        for (s_p, s_a), source_v in random.sample(source_input_candiates, k=min(len(source_input_candiates), round(first_n / len(base_task_inputs))))
        if filter_example_fn(v, metadata.prompt_to_output[p % s_a]) and re.search('\w+', source_v)
      ])
  split_to_raw_example = {k: v for k, v in split_to_raw_example.items() if len(v) > 0}
  return split_to_raw_example


# Take the first N examples only
first_n = 10240

split_to_raw_example = gen_train_split(
    ravel_metadata,
    extract_label_fn=extract_label,
    filter_example_fn=filter_inv_example,
    first_n=first_n)

# Stats
for split in split_to_raw_example:
  print('Split %s: Total %d examples, kept first %d examples, %d unique input values,  %d unique entities, %d unique output values' % (
      repr(split), len(split_to_raw_example[split]), len(split_to_raw_example[split]),
      len(set([exp[x] for exp in split_to_raw_example[split][:first_n] for x in ['input', 'source_input']])),
      len(set([exp[x] for exp in split_to_raw_example[split][:first_n] for x in ['entity', 'source_entity']])),
      len(set([exp['inv_label'] for exp in split_to_raw_example[split][:first_n]])),
      # len(set([tokenizer('0' + exp['inv_label']).input_ids[3] for exp in split_to_raw_example[split][:first_n]]))
      ))
  for i, example in enumerate(split_to_raw_example[split]):
    print(example)
    #print(tokenizer(example['input']).input_ids)
    break
  #for k in ('input', 'source_input'):
  #  input_ids = tokenizer(example[k])['input_ids']
  #  #print(k)
  #  #print(input_ids)
  #  print(list(zip([(32 - len(input_ids)) + i for i in range(len(input_ids))], tokenizer.batch_decode(input_ids))))
for split in ('train',):
  print(f'Split {split}: Total #subsplit={len([k for k in split_to_raw_example if k.endswith(split)])} #Examples={sum(map(len, [v for k, v in split_to_raw_example.items() if k.endswith(split)]))}')

In [None]:
json_path = os.path.join(DATA_DIR, f'{ravel_metadata.instance}/{ravel_metadata.instance}_{ENTITY_TYPE}_train.json')
print(json_path)
json.dump(split_to_raw_example, open(json_path, 'w'), ensure_ascii=False)

In [None]:
#@title Postprocess labels

import json
import re

# from intervention_utils import extract_label


entity_type = 'city'
instance =  model_name
version = ''


attribute_to_prompts = json.load(open(os.path.join(DATA_DIR + version, 'base', f'ravel_{entity_type}_attribute_to_prompts.json')))


json_path = os.path.join(DATA_DIR + version, f'{instance}/{instance}_{entity_type}_context_test.json')
split_to_raw_example = json.load(open(json_path, 'r'))
print(len(split_to_raw_example))

all_labels = set()
for split in split_to_raw_example:
  for i in range(len(split_to_raw_example[split])):
    if split.split('-')[0] in ['Latitude', 'Longitude'] or  split.split('-')[0] in attribute_to_prompts['Latitude'] or split.split('-')[0] in attribute_to_prompts['Longitude']:
      # Keep only the integer part.
      split_to_raw_example[split][i]['inv_label'] = split_to_raw_example[split][i]['inv_label'].replace('°', '.').split('.')[0]
      split_to_raw_example[split][i]['label'] = split_to_raw_example[split][i]['label'].replace('°', '.').split('.')[0]
    all_labels.add(split_to_raw_example[split][i]['inv_label'])

In [None]:
json.dump(split_to_raw_example, open(json_path, 'w'), ensure_ascii=False)

In [None]:
sorted(all_labels)

In [None]:
#@title Intervention locations for all possible prompts

SPLIT_TO_INV_POSITION = {}

all_prompt_templates = {p for p in WIKI_PROMPT_SPLITS}
all_prompt_templates.update({v for vs in ALL_ATTR_TO_PROMPTS.values() for v in vs})
print(len(all_prompt_templates))

for prompt_template in all_prompt_templates:
  if prompt_template.count('%s') != 1:
    continue
  print(prompt_template)
  prompt_input = prompt_template.replace('%s', '000000', 1)
  input_ids = tokenizer(prompt_input)['input_ids']
  toks = tokenizer.batch_decode(input_ids)
  for i in range(-1, -len(toks), -1):
    if toks[i] == '0' and toks[i - 1] == '0' and toks[i - 2] == '0' and toks[i - 3] == '0':
      break
  SPLIT_TO_INV_POSITION[prompt_template] = i
  print(i, list(zip([(32 - len(input_ids)) + i for i in range(len(input_ids))], toks)))

print(min(SPLIT_TO_INV_POSITION.values()))

In [None]:
version = ''
json.dump(SPLIT_TO_INV_POSITION,
          open(os.path.join(DATA_DIR + version, instance, f'{instance}_{entity_type}_prompt_to_entity_position.json'), 'w'),
          ensure_ascii=False, indent=2)

### Extract Features

In [None]:
layer_idx

In [None]:
def get_resid_post_activations(nnsight_model, layer_idx, encoded_input):
    submodule = nnsight_model.model.layers[layer_idx]
    with torch.no_grad(), nnsight_model.trace(
        encoded_input.input_ids.to(device),
        attention_mask=encoded_input.attention_mask.to(device), 
        **nnsight_tracer_kwargs):
        output = submodule.output[0].save()
    return output

# # Test the function
# prompt = ["Hello, my name is", "hello"]
# tok = tokenizer(prompt, return_tensors='pt', padding=True, truncation=True).to(device)

# out = get_resid_post_activations(
#     nnsight_model,
#     layer_idx=0,
#     encoded_input=tok['input_ids'],
# )
# out.shape

In [None]:
import h5py
import json
import re
import pickle as pkl

# from extract_neuron_activations import get_representations_across_layers_llama


def extract_ravel_entity_features(entity_to_split, attribute_to_prompt_and_split,
                                  layer, output_path, batch_size=128, placeholder='%s'):
  print(output_path)
  f_out = h5py.File(output_path, "a")
  # Generate prompts.
  splits = {'train': ('train', 'train'),
                 'val_entity': ('val', 'train'),
                 'val_context': ('train', 'val'),}
  for split_name, (entity_split, prompt_split) in splits.items():
    for attr, prompt_to_split in attribute_to_prompt_and_split.items():
      inputs, entities, templates = zip(*[(p[:p.index(placeholder)] + e, e, p)
          for p in prompt_to_split if prompt_to_split[p] == prompt_split
          for e in entity_to_split if entity_to_split[e] == entity_split])
      all_features = []
      for b_i in range(0, len(inputs), batch_size):
        input_batch = inputs[b_i:b_i+batch_size]
        encoded_input = tokenizer(
            input_batch, padding="max_length", max_length=INPUT_MAX_LEN,
            return_tensors="pt", truncation=True)
        with torch.no_grad():
          # outputs = get_representations_across_layers_llama(
          #     model.model, encoded_input, layer_index=layer)[f'layer_{layer}-block_output']
          outputs = get_resid_post_activations(nnsight_model, layer, encoded_input)
          for i in range(len(input_batch)):
            all_features.append(outputs[i:i+1, -1, :].to(torch.float16).cpu().numpy())
      print(attr, split_name, np.concatenate(all_features).shape)
      f_out[f'{attr}-{split_name}'] = np.concatenate(all_features)
      f_out[f'{attr}-{split_name}' + '_input'] = np.void(pkl.dumps(inputs))
      f_out[f'{attr}-{split_name}' + '_template'] = np.void(pkl.dumps(templates))
      f_out[f'{attr}-{split_name}' + '_entity'] = np.void(pkl.dumps(entities))
  f_out.flush()
  f_out.close()


INPUT_MAX_LEN = 48


for layer in [10, 14]:
  output_path = os.path.join(DATA_DIR, model_name, f'ravel_{entity_type}_{model_name}_layer{layer}_representation.hdf5')
  extract_ravel_entity_features(
      KEPT_ENTITY_SPLITS, KEPT_ATTR_TO_PROMPT_AND_SPLIT,
      layer, output_path, batch_size=64)