<a href="https://colab.research.google.com/github/canrager/ravel/blob/main/%5Bgithub%5D_demo_create_ravel_instance.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Colab Setup

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
RAVEL_LIB_DIR = f'/content/ravel/src'

import sys
sys.path.append(RAVEL_LIB_DIR)

In [3]:
import numpy as np
import random
import torch

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(0)

In [4]:
import torch
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
%%bash

pip install accelerate

Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.10.0->accelerate)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.10.0->accelerate)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=1.10.0->accelerate)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch>=1.10.0->accelerate)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch>=1.10.0->accelerate)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch>=1.10.0->accelerate)
  Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl.me

In [6]:
%%bash

git clone https://github.com/explanare/ravel.git

Cloning into 'ravel'...


In [7]:
!mkdir models
!mkdir data

MODEL_DIR = '/content/models'
DATA_DIR = '/content/data'

# Model

In [None]:
from transformers import AutoConfig, LlamaForCausalLM, AutoTokenizer

model_id = "google/gemma-2-2b"

tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=MODEL_DIR)
model = LlamaForCausalLM.from_pretrained(
    model_id, low_cpu_mem_usage=True, device_map='auto', cache_dir=MODEL_DIR,
    torch_dtype=torch.bfloat16)
model = model.eval()
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'

VOCAB = sorted(tokenizer.vocab, key=tokenizer.vocab.get)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/776 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/560 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/4.40G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/129 [00:00<?, ?B/s]

# Create a RAVEL Instance for TinyLLaMA

In [None]:
%%bash

mkdir data/base
tar -xvf /content/ravel/data.tgz -C data/base --strip-components=1


In [None]:
import json
import os

entity_type = 'city'

attribute_prompts = json.load(open(os.path.join(DATA_DIR, 'base', f'ravel_{entity_type}_attribute_to_prompts.json')))
prompt_splits = json.load(open(os.path.join(DATA_DIR, 'base', f'ravel_{entity_type}_prompt_to_split.json')))
entity_attributes = json.load(open(os.path.join(DATA_DIR, 'base', f'ravel_{entity_type}_entity_attributes.json')))
print(f'#entities={len(entity_attributes)}, #prompt_templates={sum(map(len, attribute_prompts.values()))}')

prompts_to_meta_data = {t % x: {'entity': x, 'attr': a, 'template': t}
               for x in entity_attributes
               for a, ts in attribute_prompts.items()
               for t in ts}
print(len(prompts_to_meta_data))

#entities=3552, #prompt_templates=273
969696


In [None]:
# Can skip the inference step by downloading the pre-computed outputs:
# https://drive.google.com/drive/u/0/folders/1U4Js-NarJa-B_iQc5wr0OXV2G-5BDBsN

from utils.generation_utils import generate_batched

prompt_max_length = 48

prompt_to_output = generate_batched(
    model,
    tokenizer,
    list(prompts_to_meta_data),
    prompt_max_length+8,
    prompt_max_length=prompt_max_length,
    batch_size=64)
prompt_to_output = {k: v[len(k):] for k, v in prompt_to_output}

In [None]:
#json.dump(prompt_to_output, open(os.path.join('ravel_tinyllama_city_prompt_to_output.json'), 'w'), ensure_ascii=False)

In [None]:
prompt_to_output = json.load(open(os.path.join('ravel_tinyllama_city_prompt_to_output.json')))
len(prompt_to_output)

969696

In [None]:
#@title Behavioral Test

import collections
import re
import numpy as np


from zoneinfo import ZoneInfo
import datetime

def timezone_name_to_utc_offset(name):
  offset =  ZoneInfo(name).utcoffset(datetime.datetime.now()).seconds
  sign = '+'
  if offset // 3600 >= 12:
    offset = 24 * 3600 - offset
    sign = '-'
  fmt_offset = str(datetime.timedelta(seconds=offset)).rsplit(':', 1)[0]
  if fmt_offset.startswith('0') and offset >= 1800:
    fmt_offset = fmt_offset[1:]
  return f'{sign}{fmt_offset}'


# TODO: Pull out attribute checking to functions

sorted_entity = sorted(set([v['entity'] for v in prompts_to_meta_data.values()]))
sorted_template = sorted(set([v['template'] for v in prompts_to_meta_data.values()]))
stats = np.zeros([len(sorted_entity), len(sorted_template)])
for p, out in prompt_to_output.items():
  attr = prompts_to_meta_data[p]['attr']
  entity = prompts_to_meta_data[p]['entity']
  label = entity_attributes[entity][attr]
  if not label:
    continue
  norm_label = label.lower()
  norm_out = out.split('"')[0].strip(' "').replace('\\/', '/').lower()
  if len(norm_label) < len(norm_out):
    correct = int(norm_out.startswith(norm_label))
  else:
    correct = int(norm_label.startswith(norm_out))

  # Exceptions
  if re.search('coord|"lat"|"long"|latitude|coordinates|longitude', p):
    try:
      correct = int((float(norm_label.strip('-−')) - float(re.findall(r'\d+', norm_out)[0])) <= 2)
    except:
      correct = 0
  if re.search('United States|United Kingdom', label):
    norm_label = label.strip().replace('the ', '')
    norm_out = out[len(p):].strip().replace('the ', '')
    correct = int(norm_out.startswith(norm_label) or norm_out.startswith('England'))
  if re.search('South Korea', label):
    correct = int(norm_out.startswith('korea') or norm_out.startswith('south korea'))
  if re.search('North America', label):
    correct = norm_label in norm_out or norm_out == 'na' or norm_out.startswith('america')
  if re.search('Mandarin', label):
    correct = norm_out in norm_label or norm_out == 'chinese'
  if re.search('language', p) and ',' in norm_label:
    correct = any(lang in norm_out for lang in norm_label.split(','))
  if re.search('UTC', p) and '/' in norm_label:
    norm_label = timezone_name_to_utc_offset(label)
    correct = norm_out.startswith(norm_label.split(':')[0])
    if not correct and re.search(r'[+\-]0\d', norm_out):
      correct = norm_out.replace('0', '', 1).startswith(norm_label.split(':')[0])
    # Summer daylight saving time
    if not correct and (
        re.search(r'\-[5-8]', norm_label) and label.startswith('America') or
        re.search(r'\+[0-3]', norm_label) and label.startswith('Europe') or
        re.search(r'\+[0-3]', norm_label) and label.startswith('Africa')):
      #print('SUMMER TIME:', norm_label, norm_out)
      out_offset_match = re.search(r'[+\-]?(\d\d?):\d+', norm_out)
      label_offset_match = re.search(r'[+\-]?(\d\d?):\d+', norm_label)
      if out_offset_match and label_offset_match:
        norm_out_offset = int(out_offset_match.group(1))
        norm_label_offset = int(label_offset_match.group(1))
        correct = (norm_out_offset <= norm_label_offset + 1 and
                   norm_out_offset >= norm_label_offset - 1)
    if not correct and re.search(r'[+\-](\d+)', norm_out) and int(
        re.search(r'[+\-](\d+)', norm_out).group(1)) > 11:
      offset = 24 - int(re.search(r'[+\-](\d+)', norm_out).group(1))
      correct = str(offset) in norm_label
  stats[sorted_entity.index(prompts_to_meta_data[p]['entity']), sorted_template.index(prompts_to_meta_data[p]['template'])] += int(correct)

print('-----------------------------------')
for i in np.argsort(stats.sum(axis=0))[::-1]:
  print(sorted_template[i], int(stats[:, i].sum()), len(stats[:, i]))
for i in np.argsort(stats.sum(axis=-1))[::-1]:
  print(sorted_entity[i], int(stats[i].sum()), len(stats[i]))


kept_entity_index = np.argsort(stats.sum(axis=1))[-400:]
KEPT_ENTITY = [sorted_entity[i] for i in kept_entity_index]
topk_template_index = set(np.argsort(stats.sum(axis=0))[-200:])
kept_template_index = []
# A dict of all kept attribute to prompts.
KEPT_ATTR_TO_PROMPT_AND_SPLIT = {}
for attr in attribute_prompts:
  # Kept the top 4 to 12 templates per attribute.
  attr_indices = [sorted_template.index(t) for t in attribute_prompts[attr]]
  per_attr_kept_template_index = sorted(attr_indices, key=lambda i: stats[:, i].sum())[-12:][::-1]
  per_attr_kept_template_index = [x for i, x in enumerate(per_attr_kept_template_index)
                                  if x in topk_template_index or i < 4]
  kept_template_index.extend(per_attr_kept_template_index)
  KEPT_ATTR_TO_PROMPT_AND_SPLIT[attr] = {sorted_template[i]: prompt_splits[sorted_template[i]]
                               for i in per_attr_kept_template_index}
print('Kept %d entity, %d prompt template' % (len(kept_entity_index), len(kept_template_index)))

print('Average accuracy: %.2f%%' % (100 *  (stats[:, kept_template_index][kept_entity_index, :]).sum()/ (len(kept_entity_index) * len(kept_template_index))))

In [None]:
# Kept templates

attribute_prompts
for i in kept_template_index:
  print(f'{prompt_splits[sorted_template[i]]}\t{sorted_template[i]}\t{stats[:, i][kept_entity_index].mean():.2f}')

train	city: %s, country:	1.00
val	[{"city": "St. Petersburg", "country": "Russia"}, {"city": "%s", "country": "	1.00
train	"lang": "English"}, {"city": "%s", "country": "	1.00
train	[{"city": "Hong Kong", "country": "China"}, {"city": "%s", "country": "	1.00
train	[{"city": "Beijing", "country": "China"}, {"city": "%s", "country": "	1.00
train	city to country: St. Petersburg is in Russia. %s is in	0.98
test	city to country: Sydney is in Australia. %s is in	0.99
train	[{"city": "%s", "country": "	0.99
val	[{"city": "Sydney", "country": "Australia"}, {"city": "%s", "country": "	1.00
train	[{"city": "Paris", "country": "France"}, {"city": "%s", "country": "	1.00
test	city to country: Rome is in Italy. %s is in	0.99
test	[{"city": "Bangkok", "country": "Thailand"}, {"city": "%s", "country": "	1.00
test	 "language": "English"}, {"city": "%s", "continent": "	0.98
test	[{"city": "%s", "continent": "	0.99
train	[{"city": "Toronto", "continent": "North America"}, {"city": "%s", "continent": "	0

In [None]:
print(sum(map(len, KEPT_ATTR_TO_PROMPT_AND_SPLIT.values())))
for attr, prompt_to_split in KEPT_ATTR_TO_PROMPT_AND_SPLIT.items():
  print(attr, collections.Counter(prompt_to_split.values()))

72
Country Counter({'train': 6, 'test': 4, 'val': 2})
Continent Counter({'train': 7, 'test': 3, 'val': 2})
Latitude Counter({'train': 6, 'test': 5, 'val': 1})
Longitude Counter({'train': 6, 'test': 4, 'val': 2})
Language Counter({'test': 5, 'train': 5, 'val': 2})
Timezone Counter({'train': 5, 'test': 4, 'val': 3})


### Create an Instance

In [None]:
import json

ENTITY_TYPE = 'city'
print(ENTITY_TYPE)
ALL_ENTITY_SPLITS = json.load(open(os.path.join(DATA_DIR, 'base', f'ravel_{ENTITY_TYPE}_entity_to_split.json')))
ALL_ATTR_TO_PROMPTS = json.load(open(os.path.join(DATA_DIR, 'base', f'ravel_{ENTITY_TYPE}_attribute_to_prompts.json')))
WIKI_PROMPT_SPLITS = json.load(open(os.path.join(DATA_DIR, 'base', f'wikipedia_{ENTITY_TYPE}_entity_prompts.json')))

# Filtered
KEPT_ENTITY_SPLITS = {e: ALL_ENTITY_SPLITS[e] for e in KEPT_ENTITY}
KEPT_PROMPT_SPLITS = {k: (a, v) for a, d in KEPT_ATTR_TO_PROMPT_AND_SPLIT.items() for k, v in d.items() if k.count('%') == 1}
for prompt in WIKI_PROMPT_SPLITS:
  KEPT_PROMPT_SPLITS[prompt] = ('Other', WIKI_PROMPT_SPLITS[prompt]['split'])
KEPT_ATTR_TO_PROMPT_AND_SPLIT = {k: {p: v for p, v in d.items() if p.count('%') == 1} for k, d in KEPT_ATTR_TO_PROMPT_AND_SPLIT.items()}
print(f'Total #entities={len(ALL_ENTITY_SPLITS)} #attributes={len(KEPT_ATTR_TO_PROMPT_AND_SPLIT)} #prompts={sum(map(len, ALL_ATTR_TO_PROMPTS.values()))} #wiki_prompts={len(WIKI_PROMPT_SPLITS)}')
print(f'Kept #entities={len(KEPT_ENTITY_SPLITS)} #prompts={len(KEPT_PROMPT_SPLITS)}')
for split in ('train', 'val', 'test'):
  print(split, f'Kept #entities={len([k for k, v in KEPT_ENTITY_SPLITS.items() if v == split])}',
               f'#prompts={len([k for k, v in KEPT_PROMPT_SPLITS.items() if v[1] == split])}')

city
Total #entities=3552 #attributes=6 #prompts=273 #wiki_prompts=938
Kept #entities=400 #prompts=1010
train Kept #entities=206 #prompts=346
val Kept #entities=81 #prompts=336
test Kept #entities=113 #prompts=328


In [None]:
from utils.generation_utils import generate_batched

wiki_prompts = [(t % e) for t, s_e in WIKI_PROMPT_SPLITS.items()
                 for e in ([s_e['entity']] if s_e['entity']
                           else [a for a in KEPT_ENTITY_SPLITS if KEPT_ENTITY_SPLITS[a] == 'train' or s_e['split'] == 'train'])
                 ]
print(len(wiki_prompts))

wiki_prompt_and_output = generate_batched(
    model,
    tokenizer,
    wiki_prompts,
    max_new_tokens=8,
    batch_size=64)
wiki_prompt_to_output = {k: v[len(k):] for k, v in wiki_prompt_and_output}

8186
Total 8186
Set prompt_max_length=64


100%|█████████████████████████████████████████████████████████████████████████████████| 128/128 [00:31<00:00,  4.04it/s]


In [None]:
ALL_PROMPT_TO_OUTPUT = {**prompt_to_output, **wiki_prompt_to_output}

print(len(ALL_PROMPT_TO_OUTPUT))

977882


In [None]:
from dataclasses import dataclass

import datasets
from datasets import Dataset
from intervention_utils import extract_label
from utils.generate_ravel_instance import RAVELMetadata

def get_first_token(x):
  return re.split(r'[^\w\+\-]', x.strip(), re.UNICODE)[0]


def filter_inv_example(base_output, inv_output):
  different_outputs = (get_first_token(base_output) !=
                       get_first_token(inv_output))
  valid_outputs = (
      re.fullmatch(r'\s?[a-z0-9.:\-+]+', extract_label(base_output), re.IGNORECASE) and
      re.fullmatch(r'\s?[a-z0-9.:\-+]+', extract_label(inv_output), re.IGNORECASE))
  return valid_outputs and different_outputs


FEATURE_TYPES = datasets.Features({"input": datasets.Value("string"), "label": datasets.Value("string"),
                              "source_input": datasets.Value("string"), "source_label": datasets.Value("string"),
                              "inv_label": datasets.Value("string"),
                              'split': datasets.Value("string"), 'source_split': datasets.Value("string"),
                              'entity': datasets.Value("string"), 'source_entity': datasets.Value("string")})


ravel_metadata = RAVELMetadata(
    'tinyllama',
    KEPT_ENTITY_SPLITS,
    KEPT_ATTR_TO_PROMPT_AND_SPLIT,
    KEPT_PROMPT_SPLITS,
    WIKI_PROMPT_SPLITS,
    ALL_PROMPT_TO_OUTPUT)

In [None]:
#@title Generate the Conetxt TEST/VAL Split

# Context Split: All entities are in TRAIN, but all prompts are in test/dev

import random

from utils.generate_ravel_instance import gen_context_test_split

TEST_TYPE = 'context'

# Take the first N examples only
first_n = 256

eval_split_to_raw_example = gen_context_test_split(
    ravel_metadata,
    extract_label_fn=extract_label,
    filter_example_fn=filter_inv_example,
    first_n=first_n)
eval_split_to_dataset = {
    split: Dataset.from_list(eval_split_to_raw_example[split][:first_n], features=FEATURE_TYPES)
    for split in eval_split_to_raw_example}

# Compute stats.
for split in eval_split_to_raw_example:
  print('\nSplit %s:\nTotal %d examples, kept first %d examples, %d unique input values,  %d unique entities, %d unique output values, %d unique output tokens' % (
      repr(split), len(eval_split_to_raw_example[split]), len(eval_split_to_dataset[split]),
      len(set([exp[x] for exp in eval_split_to_raw_example[split][:first_n] for x in ['input', 'source_input']])),
      len(set([exp[x] for exp in eval_split_to_raw_example[split][:first_n] for x in ['entity', 'source_entity']])),
      len(set([exp['inv_label'] for exp in eval_split_to_raw_example[split][:first_n]])),
      len(set([tokenizer('0' + exp['inv_label']).input_ids[3] for exp in eval_split_to_raw_example[split][:first_n]]))))
  #for i, example in enumerate(eval_split_to_raw_example[split]):
  #  print(example)
  #  #print(tokenizer(example['input']).input_ids)
  #  break
  #for k in ('input', 'source_input'):
  #  input_ids = tokenizer(example[k])['input_ids']
  #  #print(k)
  #  #print(input_ids)
  #  print(list(zip([(32 - len(input_ids)) + i for i in range(len(input_ids))], tokenizer.batch_decode(input_ids))))
for split in ('test', 'val'):
  print(f'Split {split}: Total #subsplit={len([k for k in eval_split_to_raw_example if k.endswith(split)])} #Examples={sum(map(len, [v for k, v in eval_split_to_raw_example.items() if k.endswith(split)]))}')

Country city: %s, country: val 206 191 96
Country city: %s, country: val 206 185 119
Country city: %s, country: val 206 187 118
Country [{"city": "Hong Kong", "country": "China"}, {"city": "%s", "country": " test 206 186 111
Country [{"city": "Hong Kong", "country": "China"}, {"city": "%s", "country": " test 206 192 104
Country [{"city": "Hong Kong", "country": "China"}, {"city": "%s", "country": " test 206 192 107
Country [{"city": "Beijing", "country": "China"}, {"city": "%s", "country": " val 206 196 111
Country [{"city": "Beijing", "country": "China"}, {"city": "%s", "country": " val 206 193 110
Country [{"city": "Beijing", "country": "China"}, {"city": "%s", "country": " val 206 190 117
Country city to country: Sydney is in Australia. %s is in test 206 184 105
Country city to country: Sydney is in Australia. %s is in test 206 183 113
Country city to country: Sydney is in Australia. %s is in test 206 179 111
Country [{"city": "Sydney", "country": "Australia"}, {"city": "%s", "count

In [None]:
# Merge subsplits
eval_split_to_raw_example_merged = collections.defaultdict(list)
for split in eval_split_to_raw_example:
  eval_split_to_raw_example_merged[re.sub(r'-causal|-output|-other', '', split)].extend(eval_split_to_raw_example[split])
eval_split_to_raw_example = dict(eval_split_to_raw_example_merged)

In [None]:
output_json_path = os.path.join(DATA_DIR, f'{ravel_metadata.instance}/{ravel_metadata.instance}_{ENTITY_TYPE}_{TEST_TYPE}_test.json')
print(output_json_path)
json.dump(eval_split_to_raw_example, open(output_json_path, 'w'), ensure_ascii=False)

In [None]:
#@title Generate the Entity TEST/VAL Split

from utils.generate_ravel_instance import gen_entity_test_split

TEST_TYPE = 'entity'

# Take the first N examples only
first_n = 128

eval_split_to_raw_example = gen_entity_test_split(
    ravel_metadata,
    extract_label_fn=extract_label, filter_example_fn=filter_inv_example,
    first_n=first_n)

eval_split_to_dataset = {
    split: Dataset.from_list(eval_split_to_raw_example[split][:first_n], features=FEATURE_TYPES)
    for split in eval_split_to_raw_example}

# Stats
for split in eval_split_to_raw_example:
  print('Split %s: Total %d examples, kept first %d examples, %d unique input values,  %d unique entities, %d unique output values, %d unique output tokens' % (
      repr(split), len(eval_split_to_raw_example[split]), len(eval_split_to_dataset[split]),
      len(set([exp[x] for exp in eval_split_to_raw_example[split][:first_n] for x in ['input', 'source_input']])),
      len(set([exp[x] for exp in eval_split_to_raw_example[split][:first_n] for x in ['entity', 'source_entity']])),
      len(set([exp['inv_label'] for exp in eval_split_to_raw_example[split][:first_n]])),
      len(set([tokenizer('0' + exp['inv_label']).input_ids[3] for exp in eval_split_to_raw_example[split][:first_n]]))))
  for i, example in enumerate(eval_split_to_raw_example[split]):
    print(example)
    #print(tokenizer(example['input']).input_ids)
    break
  for k in ('input', 'source_input'):
    input_ids = tokenizer(example[k])['input_ids']
    #print(k)
    #print(input_ids)
    print(list(zip([(32 - len(input_ids)) + i for i in range(len(input_ids))], tokenizer.batch_decode(input_ids))))
for split in ('test', 'val'):
  print(f'Split {split}: Total #subsplit={len([k for k in eval_split_to_raw_example if k.endswith(split)])} #Examples={sum(map(len, [v for k, v in eval_split_to_raw_example.items() if k.endswith(split)]))}')

Country [{"city": "St. Petersburg", "country": "Russia"}, {"city": "%s", "country": " test 113
Country [{"city": "St. Petersburg", "country": "Russia"}, {"city": "%s", "country": " test 113
Country [{"city": "St. Petersburg", "country": "Russia"}, {"city": "%s", "country": " test 113
Country [{"city": "St. Petersburg", "country": "Russia"}, {"city": "%s", "country": " val 81
Country [{"city": "St. Petersburg", "country": "Russia"}, {"city": "%s", "country": " val 81
Country [{"city": "St. Petersburg", "country": "Russia"}, {"city": "%s", "country": " val 81
Country "lang": "English"}, {"city": "%s", "country": " test 113
Country "lang": "English"}, {"city": "%s", "country": " test 113
Country "lang": "English"}, {"city": "%s", "country": " test 113
Country "lang": "English"}, {"city": "%s", "country": " val 81
Country "lang": "English"}, {"city": "%s", "country": " val 81
Country "lang": "English"}, {"city": "%s", "country": " val 81
Country city to country: St. Petersburg is in Russia

In [None]:
# Merge subsplits
eval_split_to_raw_example_merged = collections.defaultdict(list)
for split in eval_split_to_raw_example:
  eval_split_to_raw_example_merged[re.sub(r'-causal|-output|-other', '', split)].extend(eval_split_to_raw_example[split])
eval_split_to_raw_example = dict(eval_split_to_raw_example_merged)

In [None]:
output_json_path = os.path.join(DATA_DIR, f'{ravel_metadata.instance}/{ravel_metadata.instance}_{ENTITY_TYPE}_{TEST_TYPE}_test.json')
print(output_json_path)
json.dump(eval_split_to_raw_example, open(output_json_path, 'w'), ensure_ascii=False)

In [None]:
#@title Generate train split (for models that use counterfactuals)

import datasets
from datasets import Dataset

def gen_train_split(metadata, extract_label_fn, filter_example_fn, first_n=256):
  split_to_raw_example = {}
  # Group by attributes.
  target_split = 'train'
  for attr, prompt_to_split in metadata.attr_to_prompt.items():
      base_prompt_candiates = [p for p, s in prompt_to_split.items() if s == target_split]
      base_task_inputs = [
          ((prompt, entity), metadata.prompt_to_output[prompt % entity])
          for entity in metadata.get_entities(target_split)
          for prompt in random.sample(
              base_prompt_candiates, k=min(2, len(base_prompt_candiates)))]
      source_task_inputs = [
          ((source_prompt, entity), metadata.prompt_to_output[source_prompt % entity])
          for source_prompt, (source_attr, source_split) in KEPT_PROMPT_SPLITS.items()
          if source_split == target_split and source_attr != 'Other'
          for entity in metadata.sample_entities(target_split, k=1)
      ]
      wiki_source_task_inputs = [
          ((source_prompt, entity), metadata.prompt_to_output[source_prompt % entity])
          for source_prompt, split_and_arg in metadata.entity_prompt_to_split.items()
          if split_and_arg['split'] == target_split
          for entity in ([split_and_arg['entity']] if split_and_arg['entity']
                         else metadata.sample_entities(target_split, k=1))
      ]
      source_task_inputs = source_task_inputs + wiki_source_task_inputs
      if len(base_task_inputs) < 5 or len(source_task_inputs) < 5:
        continue
      print(attr, target_split, len(base_task_inputs), len(source_task_inputs), len(wiki_source_task_inputs))
      split_to_raw_example[f'{attr}-{target_split}'] = []
      for (p, a), v in base_task_inputs:
        source_input_candiates = [x for x in source_task_inputs if filter_example_fn(v, metadata.prompt_to_output[p % x[0][1]])]
        #print(len(source_input_candiates), v)
        split_to_raw_example[f'{attr}-{target_split}'].extend([{
          'input': p % a, 'label': extract_label_fn(v),
          'source_input': s_p % s_a, 'source_label': extract_label_fn(source_v),
          'inv_label': extract_label_fn(metadata.prompt_to_output[p % s_a]),
          'split': p, 'source_split': s_p,
          'entity': a, 'source_entity': s_a}
        for (s_p, s_a), source_v in random.sample(source_input_candiates, k=min(len(source_input_candiates), round(first_n / len(base_task_inputs))))
        if filter_example_fn(v, metadata.prompt_to_output[p % s_a]) and re.search('\w+', source_v)
      ])
  split_to_raw_example = {k: v for k, v in split_to_raw_example.items() if len(v) > 0}
  return split_to_raw_example


# Take the first N examples only
first_n = 10240

split_to_raw_example = gen_train_split(
    ravel_metadata,
    extract_label_fn=extract_label,
    filter_example_fn=filter_inv_example,
    first_n=first_n)

# Stats
for split in split_to_raw_example:
  print('Split %s: Total %d examples, kept first %d examples, %d unique input values,  %d unique entities, %d unique output values, %d unique output tokens' % (
      repr(split), len(split_to_raw_example[split]), len(split_to_raw_example[split]),
      len(set([exp[x] for exp in split_to_raw_example[split][:first_n] for x in ['input', 'source_input']])),
      len(set([exp[x] for exp in split_to_raw_example[split][:first_n] for x in ['entity', 'source_entity']])),
      len(set([exp['inv_label'] for exp in split_to_raw_example[split][:first_n]])),
      len(set([tokenizer('0' + exp['inv_label']).input_ids[3] for exp in split_to_raw_example[split][:first_n]]))))
  for i, example in enumerate(split_to_raw_example[split]):
    print(example)
    #print(tokenizer(example['input']).input_ids)
    break
  #for k in ('input', 'source_input'):
  #  input_ids = tokenizer(example[k])['input_ids']
  #  #print(k)
  #  #print(input_ids)
  #  print(list(zip([(32 - len(input_ids)) + i for i in range(len(input_ids))], tokenizer.batch_decode(input_ids))))
for split in ('train',):
  print(f'Split {split}: Total #subsplit={len([k for k in split_to_raw_example if k.endswith(split)])} #Examples={sum(map(len, [v for k, v in split_to_raw_example.items() if k.endswith(split)]))}')

Country train 412 346 311
Continent train 412 346 311
Latitude train 412 346 311
Longitude train 412 346 311
Language train 412 346 311
Timezone train 412 346 311
Split 'Country-train': Total 9870 examples, kept first 9870 examples, 726 unique input values,  325 unique entities, 160 unique output values, 130 unique output tokens
{'input': '[{"city": "St. Petersburg", "country": "Russia"}, {"city": "Tottori", "country": "', 'label': 'Japan', 'source_input': 'Alvin Law (born 1960 in Yorkton, Saskatchewan) is a motivational speaker and former radio broadcaster', 'source_label': '. He', 'inv_label': 'Canada', 'split': '[{"city": "St. Petersburg", "country": "Russia"}, {"city": "%s", "country": "', 'source_split': 'Alvin Law (born 1960 in %s, Saskatchewan) is a motivational speaker and former radio broadcaster', 'entity': 'Tottori', 'source_entity': 'Yorkton'}
Split 'Continent-train': Total 9721 examples, kept first 9721 examples, 700 unique input values,  308 unique entities, 12 unique out

In [None]:
json_path = os.path.join(DATA_DIR, f'{ravel_metadata.instance}/{ravel_metadata.instance}_{ENTITY_TYPE}_train.json')
print(json_path)
json.dump(split_to_raw_example, open(json_path, 'w'), ensure_ascii=False)

In [None]:
#@title Postprocess labels

import json
import re

from intervention_utils import extract_label


entity_type = 'city'
instance =  'tinyllama'
version = ''


attribute_to_prompts = json.load(open(os.path.join(DATA_DIR + version, 'base', f'ravel_{entity_type}_attribute_to_prompts.json')))


json_path = os.path.join(DATA_DIR + version, f'{instance}/{instance}_{entity_type}_context_test.json')
split_to_raw_example = json.load(open(json_path, 'r'))
print(len(split_to_raw_example))

all_labels = set()
for split in split_to_raw_example:
  for i in range(len(split_to_raw_example[split])):
    if split.split('-')[0] in ['Latitude', 'Longitude'] or  split.split('-')[0] in attribute_to_prompts['Latitude'] or split.split('-')[0] in attribute_to_prompts['Longitude']:
      # Keep only the integer part.
      split_to_raw_example[split][i]['inv_label'] = split_to_raw_example[split][i]['inv_label'].replace('°', '.').split('.')[0]
      split_to_raw_example[split][i]['label'] = split_to_raw_example[split][i]['label'].replace('°', '.').split('.')[0]
    all_labels.add(split_to_raw_example[split][i]['inv_label'])

37


In [None]:
json.dump(split_to_raw_example, open(json_path, 'w'), ensure_ascii=False)

In [None]:
sorted(all_labels)

[' +00:00',
 ' +01:00',
 ' +02:00',
 ' +03:00',
 ' +05:30',
 ' +1',
 ' +10:00',
 ' +1:00',
 ' +2:00',
 ' +3:00',
 ' +5:30',
 ' +6:00',
 ' +7:00',
 ' +8:00',
 ' +9:00',
 ' -12:00',
 ' -4:00',
 ' -5',
 ' -5:00',
 ' -6:00',
 ' 1',
 ' 1:00',
 ' Africa',
 ' Afrikaans',
 ' Albania',
 ' Albanian',
 ' Algeria',
 ' Angola',
 ' Arabic',
 ' Asia',
 ' Australia',
 ' BST',
 ' Bangla',
 ' Bangladesh',
 ' Belarus',
 ' Belarusian',
 ' Bengali',
 ' Brazil',
 ' Bulgaria',
 ' Bulgarian',
 ' CEST',
 ' Canada',
 ' Cantonese',
 ' Chile',
 ' China',
 ' Chinese',
 ' Colombia',
 ' Croatia',
 ' Croatian',
 ' DUS',
 ' Danish',
 ' Denmark',
 ' Dutch',
 ' EEST',
 ' Egypt',
 ' English',
 ' Estonia',
 ' Estonian',
 ' Europe',
 ' Finland',
 ' Finnish',
 ' France',
 ' French',
 ' Gabon',
 ' German',
 ' Germany',
 ' Ghana',
 ' Greece',
 ' Greek',
 ' Guinea',
 ' Gujarati',
 ' HST',
 ' Hindi',
 ' Hungarian',
 ' Hungary',
 ' India',
 ' Indonesia',
 ' Indonesian',
 ' Iran',
 ' Italian',
 ' Italy',
 ' JST',
 ' Japan',
 ' Ja

In [None]:
#@title Intervention locations for all possible prompts

SPLIT_TO_INV_POSITION = {}

all_prompt_templates = {p for p in WIKI_PROMPT_SPLITS}
all_prompt_templates.update({v for vs in ALL_ATTR_TO_PROMPTS.values() for v in vs})
print(len(all_prompt_templates))

for prompt_template in all_prompt_templates:
  if prompt_template.count('%s') != 1:
    continue
  print(prompt_template)
  prompt_input = prompt_template.replace('%s', '000000', 1)
  input_ids = tokenizer(prompt_input)['input_ids']
  toks = tokenizer.batch_decode(input_ids)
  for i in range(-1, -len(toks), -1):
    if toks[i] == '0' and toks[i - 1] == '0' and toks[i - 2] == '0' and toks[i - 3] == '0':
      break
  SPLIT_TO_INV_POSITION[prompt_template] = i
  print(i, list(zip([(32 - len(input_ids)) + i for i in range(len(input_ids))], toks)))

print(min(SPLIT_TO_INV_POSITION.values()))

1211
Köln Hauptbahnhof or %s Central Station is a railway station in Cologne, Germany
-12 [(8, '<s>'), (9, 'Köln'), (10, 'Haupt'), (11, 'bahn'), (12, 'hof'), (13, 'or'), (14, ''), (15, '0'), (16, '0'), (17, '0'), (18, '0'), (19, '0'), (20, '0'), (21, 'Central'), (22, 'Station'), (23, 'is'), (24, 'a'), (25, 'railway'), (26, 'station'), (27, 'in'), (28, 'C'), (29, 'ologne'), (30, ','), (31, 'Germany')]
convicted for the killing of his father Raymond Cook in %s, Alberta, in June 1959
-12 [(2, '<s>'), (3, 'conv'), (4, 'icted'), (5, 'for'), (6, 'the'), (7, 'killing'), (8, 'of'), (9, 'his'), (10, 'father'), (11, 'Raymond'), (12, 'Cook'), (13, 'in'), (14, ''), (15, '0'), (16, '0'), (17, '0'), (18, '0'), (19, '0'), (20, '0'), (21, ','), (22, 'Al'), (23, 'berta'), (24, ','), (25, 'in'), (26, 'June'), (27, ''), (28, '1'), (29, '9'), (30, '5'), (31, '9')]
In 1996, both %s stations, as well as sister stations CKLC and CFLY
-15 [(2, '<s>'), (3, 'In'), (4, ''), (5, '1'), (6, '9'), (7, '9'), (8, '6')

In [None]:
version = ''
json.dump(SPLIT_TO_INV_POSITION,
          open(os.path.join(DATA_DIR + version, instance, f'{instance}_{entity_type}_prompt_to_entity_position.json'), 'w'),
          ensure_ascii=False, indent=2)

### Extract Features

In [None]:
import h5py
import json
import re
import pickle as pkl

from extract_neuron_activations import get_representations_across_layers_llama


def extract_ravel_entity_features(entity_to_split, attribute_to_prompt_and_split,
                                  layer, output_path, batch_size=128, placeholder='%s'):
  print(output_path)
  f_out = h5py.File(output_path, "a")
  # Generate prompts.
  splits = {'train': ('train', 'train'),
                 'val_entity': ('val', 'train'),
                 'val_context': ('train', 'val'),}
  for split_name, (entity_split, prompt_split) in splits.items():
    for attr, prompt_to_split in attribute_to_prompt_and_split.items():
      inputs, entities, templates = zip(*[(p[:p.index(placeholder)] + e, e, p)
          for p in prompt_to_split if prompt_to_split[p] == prompt_split
          for e in entity_to_split if entity_to_split[e] == entity_split])
      all_features = []
      for b_i in range(0, len(inputs), batch_size):
        input_batch = inputs[b_i:b_i+batch_size]
        encoded_input = tokenizer(
            input_batch, padding="max_length", max_length=INPUT_MAX_LEN,
            return_tensors="pt", truncation=True)
        with torch.no_grad():
          outputs = get_representations_across_layers_llama(
              model.model, encoded_input, layer_index=layer)[f'layer_{layer}-block_output']
          for i in range(len(input_batch)):
            all_features.append(outputs[i:i+1, -1, :].to(torch.float16).cpu().numpy())
      print(attr, split_name, np.concatenate(all_features).shape)
      f_out[f'{attr}-{split_name}'] = np.concatenate(all_features)
      f_out[f'{attr}-{split_name}' + '_input'] = np.void(pkl.dumps(inputs))
      f_out[f'{attr}-{split_name}' + '_template'] = np.void(pkl.dumps(templates))
      f_out[f'{attr}-{split_name}' + '_entity'] = np.void(pkl.dumps(entities))
  f_out.flush()
  f_out.close()


INPUT_MAX_LEN = 48


for layer in [14]:
  output_path = os.path.join(DATA_DIR, f'ravel_city_tinyllama_layer{layer}_representation.hdf5')
  extract_ravel_entity_features(
      KEPT_ENTITY_SPLITS, KEPT_ATTR_TO_PROMPT_AND_SPLIT,
      layer, output_path, batch_size=64)