In [1]:
%load_ext autoreload
%autoreload 2

import os

REPO_DIR = f'/share/u/can/ravel'
SRC_DIR = os.path.join(REPO_DIR, 'src')
MODEL_DIR = os.path.join(REPO_DIR, 'models')
DATA_DIR = os.path.join(REPO_DIR, 'data')

for d in [MODEL_DIR, DATA_DIR]:
    if not os.path.exists(d):
        os.makedirs(d)


import sys
sys.path.append(REPO_DIR)
sys.path.append(SRC_DIR)

import numpy as np
import random
import torch
import accelerate

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(0)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

  from .autonotebook import tqdm as notebook_tqdm


# Model

In [2]:
# from transformers import AutoConfig, LlamaForCausalLM, AutoTokenizer

# model_id = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
# model_name = "tinyllama"

# tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=MODEL_DIR)
# hf_model = LlamaForCausalLM.from_pretrained(
#     model_id, low_cpu_mem_usage=True, device_map='auto', cache_dir=MODEL_DIR,
#     torch_dtype=torch.bfloat16)
# hf_model = hf_model.eval()
# tokenizer.pad_token = tokenizer.eos_token
# tokenizer.padding_side = 'left'

# VOCAB = sorted(tokenizer.vocab, key=tokenizer.vocab.get)

# layer_idx = 14

In [3]:
# Load model
from transformers import AutoModelForCausalLM, AutoTokenizer

with open('/share/u/can/src/hf.txt', 'r') as f:
    hf_token = f.read().strip()

model_id = "google/gemma-2-2b"
model_name = "gemma-2-2b"

torch.set_grad_enabled(False) # avoid blowing up mem
hf_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    cache_dir=MODEL_DIR,
    token=hf_token,
    device_map=device,
    low_cpu_mem_usage=True,
    attn_implementation="eager"
)

tokenizer =  AutoTokenizer.from_pretrained(
    model_id,
    cache_dir=MODEL_DIR,
    token=hf_token,
)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'
VOCAB = sorted(tokenizer.vocab, key=tokenizer.vocab.get)

layer_idx = 10

Loading checkpoint shards: 100%|██████████| 3/3 [00:02<00:00,  1.36it/s]


In [4]:
from nnsight import NNsight
nnsight_model = NNsight(hf_model)
nnsight_tracer_kwargs = {'scan': True, 'validate': False, 'use_cache': False, 'output_attentions': False}

# Create a RAVEL Instance for TinyLLaMA

In [5]:
## Already done via setup.sh

# from huggingface_hub import hf_hub_download

# if not os.path.exists(f'{DATA_DIR}/base.zip'):
#     hf_hub_download(
#         repo_id='canrager/ravel',
#         filename='base.zip',
#         local_dir=DATA_DIR,
#     )
#     os.system(f'unzip {DATA_DIR}/base.zip -d {DATA_DIR}')

### Check model knowledge of all entity - attribute pairs

In [6]:
# Generate a dataset of combinations of entities and attribute_specific_prompts

import json
import os

entity_type = 'city'

attribute_prompts = json.load(open(os.path.join(DATA_DIR, 'base', f'ravel_{entity_type}_attribute_to_prompts.json')))
prompt_splits = json.load(open(os.path.join(DATA_DIR, 'base', f'ravel_{entity_type}_prompt_to_split.json')))
entity_attributes = json.load(open(os.path.join(DATA_DIR, 'base', f'ravel_{entity_type}_entity_attributes.json')))
print(f'#entities={len(entity_attributes)}, #prompt_templates={sum(map(len, attribute_prompts.values()))}')

# # For testing purposes
# partial_entitiy_attributes = {}
# for i, (e, a) in enumerate(entity_attributes.items()):
#     partial_entitiy_attributes[e] = a
#     if i == 10:
#         break
# entity_attributes = partial_entitiy_attributes

prompts_to_meta_data = {t % x: {'entity': x, 'attr': a, 'template': t}
               for x in entity_attributes
               for a, ts in attribute_prompts.items()
               for t in ts}
print(f'total number of prompts {len(prompts_to_meta_data)}')

#entities=3552, #prompt_templates=273
total number of prompts 969696


In [7]:
# # Test generation
# input_bth = ['How do ', 'asdfasdf is']
# tok = tokenizer(input_bth, return_tensors='pt', padding=True, truncation=True).to(device)
# out = model.generate(input_ids=tok['input_ids'], attention_mask=tok['attention_mask'], max_new_tokens=5)
# tokenizer.batch_decode(out, skip_special_tokens=True)

In [8]:
# Generate outputs for all prompts (the correct attribute is expected in the output)

skip_the_inference_step = True

prompt_to_output_path = os.path.join(DATA_DIR, model_name, f'ravel_{model_name}_{entity_type}_prompt_to_output.json')

if skip_the_inference_step:
    # Can skip the inference step by downloading the pre-computed outputs:
    # https://drive.google.com/drive/u/0/folders/1U4Js-NarJa-B_iQc5wr0OXV2G-5BDBsN
    pass
else: 
    from utils.generation_utils import generate_batched

    prompt_max_length = 48

    prompt_to_output = generate_batched(
        hf_model,
        tokenizer,
        list(prompts_to_meta_data),
        prompt_max_length+8,
        prompt_max_length=prompt_max_length,
        batch_size=256)
    prompt_to_output = {k: v[len(k):] for k, v in prompt_to_output}

    json.dump(prompt_to_output, open(prompt_to_output_path, 'w'), ensure_ascii=False)

In [9]:
prompt_to_output = json.load(open(prompt_to_output_path))
len(prompt_to_output)

969696

In [10]:
#@title Behavioral Test

## Check whether model output matches the correct attribute
# Keep top 400 entities with highest sum of known attributes across all prompts
# Keep top 12 templates per attribute

import collections
import re
import numpy as np


from zoneinfo import ZoneInfo
import datetime

def timezone_name_to_utc_offset(name):
  offset =  ZoneInfo(name).utcoffset(datetime.datetime.now()).seconds
  sign = '+'
  if offset // 3600 >= 12:
    offset = 24 * 3600 - offset
    sign = '-'
  fmt_offset = str(datetime.timedelta(seconds=offset)).rsplit(':', 1)[0]
  if fmt_offset.startswith('0') and offset >= 1800:
    fmt_offset = fmt_offset[1:]
  return f'{sign}{fmt_offset}'


# TODO: Pull out attribute checking to functions

sorted_entity = sorted(set([v['entity'] for v in prompts_to_meta_data.values()]))
sorted_template = sorted(set([v['template'] for v in prompts_to_meta_data.values()]))
stats = np.zeros([len(sorted_entity), len(sorted_template)])
for p, out in prompt_to_output.items():
  attr = prompts_to_meta_data[p]['attr']
  entity = prompts_to_meta_data[p]['entity']
  label = entity_attributes[entity][attr]
  if not label:
    continue
  norm_label = label.lower()
  norm_out = out.split('"')[0].strip(' "').replace('\\/', '/').lower()
  if len(norm_label) < len(norm_out):
    correct = int(norm_out.startswith(norm_label))
  else:
    correct = int(norm_label.startswith(norm_out))

  # Exceptions
  if re.search('coord|"lat"|"long"|latitude|coordinates|longitude', p):
    try:
      correct = int((float(norm_label.strip('-−')) - float(re.findall(r'\d+', norm_out)[0])) <= 2)
    except:
      correct = 0
  if re.search('United States|United Kingdom', label):
    norm_label = label.strip().replace('the ', '')
    norm_out = out[len(p):].strip().replace('the ', '')
    correct = int(norm_out.startswith(norm_label) or norm_out.startswith('England'))
  if re.search('South Korea', label):
    correct = int(norm_out.startswith('korea') or norm_out.startswith('south korea'))
  if re.search('North America', label):
    correct = norm_label in norm_out or norm_out == 'na' or norm_out.startswith('america')
  if re.search('Mandarin', label):
    correct = norm_out in norm_label or norm_out == 'chinese'
  if re.search('language', p) and ',' in norm_label:
    correct = any(lang in norm_out for lang in norm_label.split(','))
  if re.search('UTC', p) and '/' in norm_label:
    norm_label = timezone_name_to_utc_offset(label)
    correct = norm_out.startswith(norm_label.split(':')[0])
    if not correct and re.search(r'[+\-]0\d', norm_out):
      correct = norm_out.replace('0', '', 1).startswith(norm_label.split(':')[0])
    # Summer daylight saving time
    if not correct and (
        re.search(r'\-[5-8]', norm_label) and label.startswith('America') or
        re.search(r'\+[0-3]', norm_label) and label.startswith('Europe') or
        re.search(r'\+[0-3]', norm_label) and label.startswith('Africa')):
      #print('SUMMER TIME:', norm_label, norm_out)
      out_offset_match = re.search(r'[+\-]?(\d\d?):\d+', norm_out)
      label_offset_match = re.search(r'[+\-]?(\d\d?):\d+', norm_label)
      if out_offset_match and label_offset_match:
        norm_out_offset = int(out_offset_match.group(1))
        norm_label_offset = int(label_offset_match.group(1))
        correct = (norm_out_offset <= norm_label_offset + 1 and
                   norm_out_offset >= norm_label_offset - 1)
    if not correct and re.search(r'[+\-](\d+)', norm_out) and int(
        re.search(r'[+\-](\d+)', norm_out).group(1)) > 11:
      offset = 24 - int(re.search(r'[+\-](\d+)', norm_out).group(1))
      correct = str(offset) in norm_label
  stats[sorted_entity.index(prompts_to_meta_data[p]['entity']), sorted_template.index(prompts_to_meta_data[p]['template'])] += int(correct)

print('-----------------------------------')
for i in np.argsort(stats.sum(axis=0))[::-1]:
  print(sorted_template[i], int(stats[:, i].sum()), len(stats[:, i]))
for i in np.argsort(stats.sum(axis=-1))[::-1]:
  print(sorted_entity[i], int(stats[i].sum()), len(stats[i]))


kept_entity_index = np.argsort(stats.sum(axis=1))[-400:]
KEPT_ENTITY = [sorted_entity[i] for i in kept_entity_index]
topk_template_index = set(np.argsort(stats.sum(axis=0))[-200:])
kept_template_index = []
# A dict of all kept attribute to prompts.
KEPT_ATTR_TO_PROMPT_AND_SPLIT = {}
for attr in attribute_prompts:
  # Kept the top 4 to 12 templates per attribute.
  attr_indices = [sorted_template.index(t) for t in attribute_prompts[attr]]
  per_attr_kept_template_index = sorted(attr_indices, key=lambda i: stats[:, i].sum())[-12:][::-1]
  per_attr_kept_template_index = [x for i, x in enumerate(per_attr_kept_template_index)
                                  if x in topk_template_index or i < 4]
  kept_template_index.extend(per_attr_kept_template_index)
  KEPT_ATTR_TO_PROMPT_AND_SPLIT[attr] = {sorted_template[i]: prompt_splits[sorted_template[i]]
                               for i in per_attr_kept_template_index}
print('Kept %d entity, %d prompt template' % (len(kept_entity_index), len(kept_template_index)))

print('Average accuracy: %.2f%%' % (100 *  (stats[:, kept_template_index][kept_entity_index, :]).sum()/ (len(kept_entity_index) * len(kept_template_index))))

-----------------------------------
[{"city": "St. Petersburg", "lat": "60"}, {"city": "%s", "lat": " 3327 3552
[{"city": "St. Petersburg", "lat": "59.9"}, {"city": "%s", "lat": " 3254 3552
[{"city": "Sydney", "long": "151.2"}, {"city": "%s", "long": " 3219 3552
[{"city": "Paris", "lat": "48.9"}, {"city": "%s", "lat": " 3147 3552
[{"city": "Paris", "lat": "49"}, {"city": "%s", "lat": " 3140 3552
[{"city": "Rome", "lat": "41.9"}, {"city": "%s", "lat": " 3069 3552
[{"city": "Toronto", "lat": "43.7"}, {"city": "%s", "lat": " 3066 3552
[{"city": "Beijing", "lat": "39.9"}, {"city": "%s", "lat": " 3057 3552
[{"city": "San Francisco", "long": "122.4"}, {"city": "%s", "long": " 3055 3552
[{"city": "Toronto", "lat": "44"}, {"city": "%s", "lat": " 3048 3552
[{"city": "Mexico City", "long": "99.1"}, {"city": "%s", "long": " 3010 3552
[{"city": "San Francisco", "lat": "37.7"}, {"city": "%s", "lat": " 3007 3552
[{"city": "Rome", "lat": "42"}, {"city": "%s", "lat": " 2994 3552
[{"city": "Tokyo", "lo

In [11]:
# Print top 12 templates for each attribute

attribute_prompts
for i in kept_template_index:
  print(f'{prompt_splits[sorted_template[i]]}\t{sorted_template[i]}\t{stats[:, i][kept_entity_index].mean():.2f}')

val	city: %s, country:	1.00
train	[{"city": "Rio de Janeiro", "country": "Brazil"}, {"city": "%s", "country": "	1.00
test	[{"city": "Hong Kong", "country": "China"}, {"city": "%s", "country": "	1.00
val	[{"city": "Beijing", "country": "China"}, {"city": "%s", "country": "	1.00
test	[{"city": "Sydney", "country": "Australia"}, {"city": "%s", "country": "	1.00
train	"lang": "English"}, {"city": "%s", "country": "	1.00
train	[{"city": "Toronto", "country": "Canada"}, {"city": "%s", "country": "	1.00
train	[{"city": "%s", "country": "	0.99
test	[{"city": "Mexico City", "country": "Mexico"}, {"city": "%s", "country": "	1.00
train	[{"city": "St. Petersburg", "country": "Russia"}, {"city": "%s", "country": "	1.00
train	[{"city": "Kuala Lumpur", "country": "Malaysia"}, {"city": "%s", "country": "	1.00
train	[{"city": "Cape Town", "country": "South Africa"}, {"city": "%s", "country": "	1.00
train	[{"city": "Rio de Janeiro", "continent": "South America"}, {"city": "%s", "continent": "	1.00
train

In [12]:
# train/val/test split of attribute_specific_prompt_templates has been predefined. Check whether the split is roughly balanced.

print(sum(map(len, KEPT_ATTR_TO_PROMPT_AND_SPLIT.values())))
for attr, prompt_to_split in KEPT_ATTR_TO_PROMPT_AND_SPLIT.items():
  print(attr, collections.Counter(prompt_to_split.values()))

72
Country Counter({'train': 7, 'test': 3, 'val': 2})
Continent Counter({'train': 6, 'test': 3, 'val': 3})
Latitude Counter({'train': 7, 'test': 4, 'val': 1})
Longitude Counter({'train': 7, 'test': 3, 'val': 2})
Language Counter({'train': 6, 'val': 5, 'test': 1})
Timezone Counter({'train': 7, 'val': 4, 'test': 1})


### Create an Instance

In [13]:
# train/val/test split of entities has been predefined. Check whether the split is roughly balanced.
import json

ENTITY_TYPE = 'city'
print(ENTITY_TYPE)
ALL_ENTITY_SPLITS = json.load(open(os.path.join(DATA_DIR, 'base', f'ravel_{ENTITY_TYPE}_entity_to_split.json')))
ALL_ATTR_TO_PROMPTS = json.load(open(os.path.join(DATA_DIR, 'base', f'ravel_{ENTITY_TYPE}_attribute_to_prompts.json')))
WIKI_PROMPT_SPLITS = json.load(open(os.path.join(DATA_DIR, 'base', f'wikipedia_{ENTITY_TYPE}_entity_prompts.json')))

# Filtered
KEPT_ENTITY_SPLITS = {e: ALL_ENTITY_SPLITS[e] for e in KEPT_ENTITY} # kept entities to split
KEPT_PROMPT_SPLITS = {k: (a, v) for a, d in KEPT_ATTR_TO_PROMPT_AND_SPLIT.items() for k, v in d.items() if k.count('%') == 1}
print(f'len kept prompt splits={len(KEPT_PROMPT_SPLITS)}')
print(f'len added wiki inv prompts={len({k: v for k, v in WIKI_PROMPT_SPLITS.items() if k.count("%") == 1})}')
for prompt in WIKI_PROMPT_SPLITS:
  KEPT_PROMPT_SPLITS[prompt] = ('Other', WIKI_PROMPT_SPLITS[prompt]['split']) # add wiki prompt splits as "Other" attribute
KEPT_ATTR_TO_PROMPT_AND_SPLIT = {k: {p: v for p, v in d.items() if p.count('%') == 1} for k, d in KEPT_ATTR_TO_PROMPT_AND_SPLIT.items()}
print(f'Total #entities={len(ALL_ENTITY_SPLITS)} #attributes={len(KEPT_ATTR_TO_PROMPT_AND_SPLIT)} #prompts={sum(map(len, ALL_ATTR_TO_PROMPTS.values()))} #wiki_prompts={len(WIKI_PROMPT_SPLITS)}')
print(f'Kept #entities={len(KEPT_ENTITY_SPLITS)} #prompts={len(KEPT_PROMPT_SPLITS)}')
for split in ('train', 'val', 'test'):
  print(split, f'Kept #entities={len([k for k, v in KEPT_ENTITY_SPLITS.items() if v == split])}',
               f'#prompts={len([k for k, v in KEPT_PROMPT_SPLITS.items() if v[1] == split])}')

city
len kept prompt splits=72
len added wiki inv prompts=938
Total #entities=3552 #attributes=6 #prompts=273 #wiki_prompts=938
Kept #entities=400 #prompts=1010
train Kept #entities=200 #prompts=351
val Kept #entities=88 #prompts=341
test Kept #entities=112 #prompts=318


In [14]:
# Generate outputs for all wiki inv prompts
from utils.generation_utils import generate_batched

skip_wiki_to_prompt = True
filename_prompt_to_output = f'ravel_{model_name}_{ENTITY_TYPE}_wiki_prompt_to_output.json'

if not skip_wiki_to_prompt:
    wiki_prompts = [(t % e) for t, s_e in WIKI_PROMPT_SPLITS.items()
                    for e in ([s_e['entity']] if s_e['entity']
                            else [a for a in KEPT_ENTITY_SPLITS if KEPT_ENTITY_SPLITS[a] == 'train' or s_e['split'] == 'train'])
                    ]
    print(len(wiki_prompts))

    wiki_prompt_and_output = generate_batched(
        hf_model,
        tokenizer,
        wiki_prompts,
        max_new_tokens=8,
        batch_size=64)
    wiki_prompt_to_output = {k: v[len(k):] for k, v in wiki_prompt_and_output}
    json.dump(wiki_prompt_to_output, open(os.path.join(DATA_DIR, model_name, filename_prompt_to_output), "w"), ensure_ascii=False)
else:
    wiki_prompt_to_output = json.load(open(os.path.join(DATA_DIR, model_name, filename_prompt_to_output)))

In [15]:
# Set of prompts with all possible outputs

ALL_PROMPT_TO_OUTPUT = {**prompt_to_output, **wiki_prompt_to_output}

print(len(ALL_PROMPT_TO_OUTPUT))

977810


In [16]:
from dataclasses import dataclass

import datasets
from datasets import Dataset
# from utils.intervention_utils import extract_label
from utils.generate_ravel_instance import RAVELMetadata

# def extract_label(string):
#     delimiters = r"[ \-\t,.\n]"
#     return re.split(delimiters, string.strip())[0]

import re

def extract_label(text):
    tokens = re.split(r'(["]|[.,;]\s|\n| \(|\sand)', text + ' ')
    x = tokens[0]
    digit_match = re.search(r'\.\d\d', x)
    if digit_match:
        x = x[:digit_match.span(0)[1]]
    gender_match = re.match(r'\s?(his|her|himself|herself|she|he)[^\w]', x)
    if gender_match:
        x = x[:gender_match.span(1)[1]]
    if not x.strip():
        x = ' '.join(text.split(' ')[:2]).rstrip('.,"\n')
    assert x.strip()
    return x


def get_first_token(x):
  return re.split(r'[^\w\+\-]', x.strip(), re.UNICODE)[0]


def filter_inv_example(base_output, inv_output):
  different_outputs = (get_first_token(base_output) !=
                       get_first_token(inv_output))
  valid_outputs = (
      re.fullmatch(r'\s?[a-z0-9.:\-+]+', extract_label(base_output), re.IGNORECASE) and
      re.fullmatch(r'\s?[a-z0-9.:\-+]+', extract_label(inv_output), re.IGNORECASE))
  return valid_outputs and different_outputs


FEATURE_TYPES = datasets.Features({"input": datasets.Value("string"), "label": datasets.Value("string"),
                              "source_input": datasets.Value("string"), "source_label": datasets.Value("string"),
                              "inv_label": datasets.Value("string"),
                              'split': datasets.Value("string"), 'source_split': datasets.Value("string"),
                              'entity': datasets.Value("string"), 'source_entity': datasets.Value("string")})


ravel_metadata = RAVELMetadata(
    model_name,
    KEPT_ENTITY_SPLITS,
    KEPT_ATTR_TO_PROMPT_AND_SPLIT,
    KEPT_PROMPT_SPLITS,
    WIKI_PROMPT_SPLITS,
    ALL_PROMPT_TO_OUTPUT)

In [17]:
extract_label(' Thailand.\nThe distance from P')

' Thailand'

In [18]:
#@title Generate the Conetxt TEST/VAL Split

# Context Split: All entities are in TRAIN, but all prompts are in test/dev

import random

from utils.generate_ravel_instance import gen_context_test_split

TEST_TYPE = 'context'

# Take the first N examples only
first_n = 256

eval_split_to_raw_example = gen_context_test_split(
    ravel_metadata,
    extract_label_fn=extract_label,
    filter_example_fn=filter_inv_example,
    first_n=first_n)
eval_split_to_dataset = {
    split: Dataset.from_list(eval_split_to_raw_example[split][:first_n], features=FEATURE_TYPES)
    for split in eval_split_to_raw_example}

# Compute stats.
for split in eval_split_to_raw_example:
  print('\nSplit %s:\nTotal %d examples, kept first %d examples, %d unique input values,  %d unique entities, %d unique output values' % (
      repr(split), len(eval_split_to_raw_example[split]), len(eval_split_to_dataset[split]),
      len(set([exp[x] for exp in eval_split_to_raw_example[split][:first_n] for x in ['input', 'source_input']])),
      len(set([exp[x] for exp in eval_split_to_raw_example[split][:first_n] for x in ['entity', 'source_entity']])),
      len(set([exp['inv_label'] for exp in eval_split_to_raw_example[split][:first_n]])),
      # len(set([tokenizer('0' + exp['inv_label']).input_ids[3] for exp in eval_split_to_raw_example[split][:first_n]]))
      ))
  #for i, example in enumerate(eval_split_to_raw_example[split]):
  #  print(example)
  #  #print(tokenizer(example['input']).input_ids)
  #  break
  #for k in ('input', 'source_input'):
  #  input_ids = tokenizer(example[k])['input_ids']
  #  #print(k)
  #  #print(input_ids)
  #  print(list(zip([(32 - len(input_ids)) + i for i in range(len(input_ids))], tokenizer.batch_decode(input_ids))))
for split in ('test', 'val'):
  print(f'Split {split}: Total #subsplit={len([k for k in eval_split_to_raw_example if k.endswith(split)])} #Examples={sum(map(len, [v for k, v in eval_split_to_raw_example.items() if k.endswith(split)]))}')

Country city: %s, country: val 200 182 103
Country city: %s, country: val 200 179 104
Country city: %s, country: val 200 178 116
Country [{"city": "Hong Kong", "country": "China"}, {"city": "%s", "country": " test 200 183 111
Country [{"city": "Hong Kong", "country": "China"}, {"city": "%s", "country": " test 200 172 101
Country [{"city": "Hong Kong", "country": "China"}, {"city": "%s", "country": " test 200 185 113
Country [{"city": "Beijing", "country": "China"}, {"city": "%s", "country": " val 200 173 105
Country [{"city": "Beijing", "country": "China"}, {"city": "%s", "country": " val 200 180 104
Country [{"city": "Beijing", "country": "China"}, {"city": "%s", "country": " val 200 186 113
Country [{"city": "Sydney", "country": "Australia"}, {"city": "%s", "country": " test 200 182 106
Country [{"city": "Sydney", "country": "Australia"}, {"city": "%s", "country": " test 200 188 109
Country [{"city": "Sydney", "country": "Australia"}, {"city": "%s", "country": " test 200 181 111
Coun

In [19]:
# Merge subsplits
eval_split_to_raw_example_merged = collections.defaultdict(list)
for split in eval_split_to_raw_example:
  eval_split_to_raw_example_merged[re.sub(r'-causal|-output|-other', '', split)].extend(eval_split_to_raw_example[split])
eval_split_to_raw_example = dict(eval_split_to_raw_example_merged)

In [20]:
output_json_path = os.path.join(DATA_DIR, f'{ravel_metadata.instance}/{ravel_metadata.instance}_{ENTITY_TYPE}_{TEST_TYPE}_test.json')
print(output_json_path)
json.dump(eval_split_to_raw_example, open(output_json_path, 'w'), ensure_ascii=False)

/share/u/can/ravel/data/gemma-2-2b/gemma-2-2b_city_context_test.json


In [21]:
#@title Generate the Entity TEST/VAL Split

from utils.generate_ravel_instance import gen_entity_test_split

TEST_TYPE = 'entity'

# Take the first N examples only
first_n = 128

eval_split_to_raw_example = gen_entity_test_split(
    ravel_metadata,
    extract_label_fn=extract_label, filter_example_fn=filter_inv_example,
    first_n=first_n)

eval_split_to_dataset = {
    split: Dataset.from_list(eval_split_to_raw_example[split][:first_n], features=FEATURE_TYPES)
    for split in eval_split_to_raw_example}

# Stats
for split in eval_split_to_raw_example:
  print('Split %s: Total %d examples, kept first %d examples, %d unique input values,  %d unique entities, %d unique output values' % (
      repr(split), len(eval_split_to_raw_example[split]), len(eval_split_to_dataset[split]),
      len(set([exp[x] for exp in eval_split_to_raw_example[split][:first_n] for x in ['input', 'source_input']])),
      len(set([exp[x] for exp in eval_split_to_raw_example[split][:first_n] for x in ['entity', 'source_entity']])),
      len(set([exp['inv_label'] for exp in eval_split_to_raw_example[split][:first_n]])),
      # len(set([tokenizer('0' + exp['inv_label']).input_ids[3] for exp in eval_split_to_raw_example[split][:first_n]]))
      ))
  for i, example in enumerate(eval_split_to_raw_example[split]):
    print(example)
    #print(tokenizer(example['input']).input_ids)
    break
  for k in ('input', 'source_input'):
    input_ids = tokenizer(example[k])['input_ids']
    #print(k)
    #print(input_ids)
    print(list(zip([(32 - len(input_ids)) + i for i in range(len(input_ids))], tokenizer.batch_decode(input_ids))))
for split in ('test', 'val'):
  print(f'Split {split}: Total #subsplit={len([k for k in eval_split_to_raw_example if k.endswith(split)])} #Examples={sum(map(len, [v for k, v in eval_split_to_raw_example.items() if k.endswith(split)]))}')

Country [{"city": "Rio de Janeiro", "country": "Brazil"}, {"city": "%s", "country": " test 112
Country [{"city": "Rio de Janeiro", "country": "Brazil"}, {"city": "%s", "country": " test 112
Country [{"city": "Rio de Janeiro", "country": "Brazil"}, {"city": "%s", "country": " test 112
Country [{"city": "Rio de Janeiro", "country": "Brazil"}, {"city": "%s", "country": " val 88
Country [{"city": "Rio de Janeiro", "country": "Brazil"}, {"city": "%s", "country": " val 88
Country [{"city": "Rio de Janeiro", "country": "Brazil"}, {"city": "%s", "country": " val 88
Country "lang": "English"}, {"city": "%s", "country": " test 112
Country "lang": "English"}, {"city": "%s", "country": " test 112
Country "lang": "English"}, {"city": "%s", "country": " test 112
Country "lang": "English"}, {"city": "%s", "country": " val 88
Country "lang": "English"}, {"city": "%s", "country": " val 88
Country "lang": "English"}, {"city": "%s", "country": " val 88
Country [{"city": "Toronto", "country": "Canada"}, {

In [22]:
# Merge subsplits
eval_split_to_raw_example_merged = collections.defaultdict(list)
for split in eval_split_to_raw_example:
  eval_split_to_raw_example_merged[re.sub(r'-causal|-output|-other', '', split)].extend(eval_split_to_raw_example[split])
eval_split_to_raw_example = dict(eval_split_to_raw_example_merged)

In [23]:
output_json_path = os.path.join(DATA_DIR, f'{ravel_metadata.instance}/{ravel_metadata.instance}_{ENTITY_TYPE}_{TEST_TYPE}_test.json')
print(output_json_path)
json.dump(eval_split_to_raw_example, open(output_json_path, 'w'), ensure_ascii=False)

/share/u/can/ravel/data/gemma-2-2b/gemma-2-2b_city_entity_test.json


In [24]:
#@title Generate train split (for models that use counterfactuals)

import datasets
from datasets import Dataset

def gen_train_split(metadata, extract_label_fn, filter_example_fn, first_n=256):
  split_to_raw_example = {}
  # Group by attributes.
  target_split = 'train'
  for attr, prompt_to_split in metadata.attr_to_prompt.items():
      base_prompt_candiates = [p for p, s in prompt_to_split.items() if s == target_split]
      base_task_inputs = [
          ((prompt, entity), metadata.prompt_to_output[prompt % entity])
          for entity in metadata.get_entities(target_split)
          for prompt in random.sample(
              base_prompt_candiates, k=min(2, len(base_prompt_candiates)))]
      source_task_inputs = [
          ((source_prompt, entity), metadata.prompt_to_output[source_prompt % entity])
          for source_prompt, (source_attr, source_split) in KEPT_PROMPT_SPLITS.items()
          if source_split == target_split and source_attr != 'Other'
          for entity in metadata.sample_entities(target_split, k=1)
      ]
      wiki_source_task_inputs = [
          ((source_prompt, entity), metadata.prompt_to_output[source_prompt % entity])
          for source_prompt, split_and_arg in metadata.entity_prompt_to_split.items()
          if split_and_arg['split'] == target_split
          for entity in ([split_and_arg['entity']] if split_and_arg['entity']
                         else metadata.sample_entities(target_split, k=1))
      ]
      source_task_inputs = source_task_inputs + wiki_source_task_inputs
      if len(base_task_inputs) < 5 or len(source_task_inputs) < 5:
        continue
      print(attr, target_split, len(base_task_inputs), len(source_task_inputs), len(wiki_source_task_inputs))
      split_to_raw_example[f'{attr}-{target_split}'] = []
      for (p, a), v in base_task_inputs:
        source_input_candiates = [x for x in source_task_inputs if filter_example_fn(v, metadata.prompt_to_output[p % x[0][1]])]
        #print(len(source_input_candiates), v)
        split_to_raw_example[f'{attr}-{target_split}'].extend([{
          'input': p % a, 'label': extract_label_fn(v),
          'source_input': s_p % s_a, 'source_label': extract_label_fn(source_v),
          'inv_label': extract_label_fn(metadata.prompt_to_output[p % s_a]),
          'split': p, 'source_split': s_p,
          'entity': a, 'source_entity': s_a}
        for (s_p, s_a), source_v in random.sample(source_input_candiates, k=min(len(source_input_candiates), round(first_n / len(base_task_inputs))))
        if filter_example_fn(v, metadata.prompt_to_output[p % s_a]) and re.search('\w+', source_v)
      ])
  split_to_raw_example = {k: v for k, v in split_to_raw_example.items() if len(v) > 0}
  return split_to_raw_example


# Take the first N examples only
first_n = 10240

split_to_raw_example = gen_train_split(
    ravel_metadata,
    extract_label_fn=extract_label,
    filter_example_fn=filter_inv_example,
    first_n=first_n)

# Stats
for split in split_to_raw_example:
  print('Split %s: Total %d examples, kept first %d examples, %d unique input values,  %d unique entities, %d unique output values' % (
      repr(split), len(split_to_raw_example[split]), len(split_to_raw_example[split]),
      len(set([exp[x] for exp in split_to_raw_example[split][:first_n] for x in ['input', 'source_input']])),
      len(set([exp[x] for exp in split_to_raw_example[split][:first_n] for x in ['entity', 'source_entity']])),
      len(set([exp['inv_label'] for exp in split_to_raw_example[split][:first_n]])),
      # len(set([tokenizer('0' + exp['inv_label']).input_ids[3] for exp in split_to_raw_example[split][:first_n]]))
      ))
  for i, example in enumerate(split_to_raw_example[split]):
    print(example)
    #print(tokenizer(example['input']).input_ids)
    break
  #for k in ('input', 'source_input'):
  #  input_ids = tokenizer(example[k])['input_ids']
  #  #print(k)
  #  #print(input_ids)
  #  print(list(zip([(32 - len(input_ids)) + i for i in range(len(input_ids))], tokenizer.batch_decode(input_ids))))
for split in ('train',):
  print(f'Split {split}: Total #subsplit={len([k for k in split_to_raw_example if k.endswith(split)])} #Examples={sum(map(len, [v for k, v in split_to_raw_example.items() if k.endswith(split)]))}')

Country train 400 351 311


  if filter_example_fn(v, metadata.prompt_to_output[p % s_a]) and re.search('\w+', source_v)


Continent train 400 351 311
Latitude train 400 351 311
Longitude train 400 351 311
Language train 400 351 311
Timezone train 400 351 311
Split 'Country-train': Total 10114 examples, kept first 10114 examples, 703 unique input values,  299 unique entities, 68 unique output values
{'input': '[{"city": "Cape Town", "country": "South Africa"}, {"city": "Cabinda", "country": "', 'label': 'Angola', 'source_input': "On the night of 25–26 June, Kaunas pogrom led by Klimaitis' unit was instigated by Franz", 'source_label': 'iska', 'inv_label': 'Lithuania', 'split': '[{"city": "Cape Town", "country": "South Africa"}, {"city": "%s", "country": "', 'source_split': "On the night of 25–26 June, %s pogrom led by Klimaitis' unit was instigated by Franz", 'entity': 'Cabinda', 'source_entity': 'Kaunas'}
Split 'Continent-train': Total 9620 examples, kept first 9620 examples, 677 unique input values,  289 unique entities, 6 unique output values
{'input': '[{"city": "Beijing", "continent": "Asia"}, {"city"

In [25]:
json_path = os.path.join(DATA_DIR, f'{ravel_metadata.instance}/{ravel_metadata.instance}_{ENTITY_TYPE}_train.json')
print(json_path)
json.dump(split_to_raw_example, open(json_path, 'w'), ensure_ascii=False)

/share/u/can/ravel/data/gemma-2-2b/gemma-2-2b_city_train.json


In [26]:
#@title Postprocess labels

import json
import re

# from intervention_utils import extract_label


entity_type = 'city'
instance =  model_name
version = ''


attribute_to_prompts = json.load(open(os.path.join(DATA_DIR + version, 'base', f'ravel_{entity_type}_attribute_to_prompts.json')))


json_path = os.path.join(DATA_DIR + version, f'{instance}/{instance}_{entity_type}_context_test.json')
split_to_raw_example = json.load(open(json_path, 'r'))
print(len(split_to_raw_example))

all_labels = set()
for split in split_to_raw_example:
  for i in range(len(split_to_raw_example[split])):
    if split.split('-')[0] in ['Latitude', 'Longitude'] or  split.split('-')[0] in attribute_to_prompts['Latitude'] or split.split('-')[0] in attribute_to_prompts['Longitude']:
      # Keep only the integer part.
      split_to_raw_example[split][i]['inv_label'] = split_to_raw_example[split][i]['inv_label'].replace('°', '.').split('.')[0]
      split_to_raw_example[split][i]['label'] = split_to_raw_example[split][i]['label'].replace('°', '.').split('.')[0]
    all_labels.add(split_to_raw_example[split][i]['inv_label'])

32


In [27]:
json.dump(split_to_raw_example, open(json_path, 'w'), ensure_ascii=False)

In [28]:
sorted(all_labels)

[' Albania',
 ' Albanian',
 ' Angola',
 ' Arabic',
 ' Bangladesh',
 ' Bengali',
 ' Brazil',
 ' Bulgaria',
 ' Bulgarian',
 ' Burmese',
 ' Canada',
 ' Chile',
 ' China',
 ' Chinese',
 ' Colombia',
 ' Croatia',
 ' Croatian',
 ' Danish',
 ' Denmark',
 ' Egypt',
 ' English',
 ' Estonia',
 ' Estonian',
 ' Finland',
 ' Finnish',
 ' France',
 ' French',
 ' Gabon',
 ' German',
 ' Germany',
 ' Greece',
 ' Greek',
 ' Hausa',
 ' Hindi',
 ' Hungarian',
 ' Hungary',
 ' India',
 ' Indonesia',
 ' Indonesian',
 ' Iran',
 ' Iraq',
 ' Italian',
 ' Italy',
 ' Japan',
 ' Japanese',
 ' Javanese',
 ' Korean',
 ' Latvia',
 ' Latvian',
 ' Lithuania',
 ' Lithuanian',
 ' Mexico',
 ' Montenegrin',
 ' Montenegro',
 ' Myanmar',
 ' Nepal',
 ' Nepali',
 ' Nigeria',
 ' Norway',
 ' Norwegian',
 ' Persian',
 ' Poland',
 ' Polish',
 ' Portugal',
 ' Portuguese',
 ' Romania',
 ' Romanian',
 ' Russia',
 ' Russian',
 ' Serbia',
 ' Serbian',
 ' Sinhala',
 ' Spanish',
 ' Sweden',
 ' Swedish',
 ' Switzerland',
 ' Thai',
 ' Thai

In [29]:
#@title Intervention locations for all possible prompts

SPLIT_TO_INV_POSITION = {}

all_prompt_templates = {p for p in WIKI_PROMPT_SPLITS}
all_prompt_templates.update({v for vs in ALL_ATTR_TO_PROMPTS.values() for v in vs})
print(len(all_prompt_templates))

for prompt_template in all_prompt_templates:
  if prompt_template.count('%s') != 1:
    continue
  print(prompt_template)
  prompt_input = prompt_template.replace('%s', '000000', 1)
  input_ids = tokenizer(prompt_input)['input_ids']
  toks = tokenizer.batch_decode(input_ids)
  for i in range(-1, -len(toks), -1):
    if toks[i] == '0' and toks[i - 1] == '0' and toks[i - 2] == '0' and toks[i - 3] == '0':
      break
  SPLIT_TO_INV_POSITION[prompt_template] = i
  print(i, list(zip([(32 - len(input_ids)) + i for i in range(len(input_ids))], toks)))

print(min(SPLIT_TO_INV_POSITION.values()))

1211
%s Airport (Bengali: সৈয়দপুর বিমানবন্দর Saiẏadapur bimānabandar)  is a
-29 [(-3, '<bos>'), (-2, '0'), (-1, '0'), (0, '0'), (1, '0'), (2, '0'), (3, '0'), (4, ' Airport'), (5, ' ('), (6, 'Beng'), (7, 'ali'), (8, ':'), (9, ' স'), (10, 'ৈ'), (11, 'য়'), (12, 'দ'), (13, 'প'), (14, 'ুর'), (15, ' বি'), (16, 'মান'), (17, 'ব'), (18, 'ন্দ'), (19, 'র'), (20, ' Sai'), (21, 'ẏ'), (22, 'adap'), (23, 'ur'), (24, ' bim'), (25, 'āna'), (26, 'band'), (27, 'ar'), (28, ')'), (29, '  '), (30, 'is'), (31, ' a')]
Volodymyr %s (born 21 August 1978) is a Ukrainian footballer
-17 [(5, '<bos>'), (6, 'Vo'), (7, 'lo'), (8, 'dymyr'), (9, ' '), (10, '0'), (11, '0'), (12, '0'), (13, '0'), (14, '0'), (15, '0'), (16, ' ('), (17, 'born'), (18, ' '), (19, '2'), (20, '1'), (21, ' August'), (22, ' '), (23, '1'), (24, '9'), (25, '7'), (26, '8'), (27, ')'), (28, ' is'), (29, ' a'), (30, ' Ukrainian'), (31, ' footballer')]
A hard running backrow player %s made his provincial debut in a match against a
-10 [(9, '<bos>'),

In [30]:
version = ''
json.dump(SPLIT_TO_INV_POSITION,
          open(os.path.join(DATA_DIR + version, instance, f'{instance}_{entity_type}_prompt_to_entity_position.json'), 'w'),
          ensure_ascii=False, indent=2)

### Extract Features

In [31]:
layer_idx

10

In [32]:
def get_resid_post_activations(nnsight_model, layer_idx, encoded_input):
    submodule = nnsight_model.model.layers[layer_idx]
    with torch.no_grad(), nnsight_model.trace(
        encoded_input.input_ids.to(device),
        attention_mask=encoded_input.attention_mask.to(device), 
        **nnsight_tracer_kwargs):
        output = submodule.output[0].save()
    return output

# # Test the function
# prompt = ["Hello, my name is", "hello"]
# tok = tokenizer(prompt, return_tensors='pt', padding=True, truncation=True).to(device)

# out = get_resid_post_activations(
#     nnsight_model,
#     layer_idx=0,
#     encoded_input=tok['input_ids'],
# )
# out.shape

In [33]:
import h5py
import json
import re
import pickle as pkl

# from extract_neuron_activations import get_representations_across_layers_llama


def extract_ravel_entity_features(entity_to_split, attribute_to_prompt_and_split,
                                  layer, output_path, batch_size=128, placeholder='%s'):
  print(output_path)
  f_out = h5py.File(output_path, "a")
  # Generate prompts.
  splits = {'train': ('train', 'train'),
                 'val_entity': ('val', 'train'),
                 'val_context': ('train', 'val'),}
  for split_name, (entity_split, prompt_split) in splits.items():
    for attr, prompt_to_split in attribute_to_prompt_and_split.items():
      inputs, entities, templates = zip(*[(p[:p.index(placeholder)] + e, e, p)
          for p in prompt_to_split if prompt_to_split[p] == prompt_split
          for e in entity_to_split if entity_to_split[e] == entity_split])
      all_features = []
      for b_i in range(0, len(inputs), batch_size):
        input_batch = inputs[b_i:b_i+batch_size]
        encoded_input = tokenizer(
            input_batch, padding="max_length", max_length=INPUT_MAX_LEN,
            return_tensors="pt", truncation=True)
        with torch.no_grad():
          # outputs = get_representations_across_layers_llama(
          #     model.model, encoded_input, layer_index=layer)[f'layer_{layer}-block_output']
          outputs = get_resid_post_activations(nnsight_model, layer, encoded_input)
          for i in range(len(input_batch)):
            all_features.append(outputs[i:i+1, -1, :].to(torch.float16).cpu().numpy())
      print(attr, split_name, np.concatenate(all_features).shape)
      f_out[f'{attr}-{split_name}'] = np.concatenate(all_features)
      f_out[f'{attr}-{split_name}' + '_input'] = np.void(pkl.dumps(inputs))
      f_out[f'{attr}-{split_name}' + '_template'] = np.void(pkl.dumps(templates))
      f_out[f'{attr}-{split_name}' + '_entity'] = np.void(pkl.dumps(entities))
  f_out.flush()
  f_out.close()


INPUT_MAX_LEN = 48


for layer in [10, 14]:
  output_path = os.path.join(DATA_DIR, model_name, f'ravel_{entity_type}_{model_name}_layer{layer}_representation.hdf5')
  extract_ravel_entity_features(
      KEPT_ENTITY_SPLITS, KEPT_ATTR_TO_PROMPT_AND_SPLIT,
      layer, output_path, batch_size=64)

/share/u/can/ravel/data/gemma-2-2b/ravel_city_gemma-2-2b_layer10_representation.hdf5
Country train (1400, 2304)


OSError: Unable to create link (name already exists)