# Susceptibility Scores
A notebook for initial exploration.

In [1]:
%load_ext autoreload
%autoreload 2
%load_ext lab_black

In [2]:
import os
import sys
import math
import random
from itertools import product
from tqdm import tqdm
import yaml

import pandas as pd
import seaborn as sns
from transformers import GPTNeoXForCausalLM, AutoTokenizer
import torch
from typing import List
import numpy as np
import wandb

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import statsmodels.api as sm
from measuring.estimate_probs import (
    estimate_prob_y_given_context_and_entity,
    estimate_prob_x_given_e,
    estimate_prob_next_word_given_x_and_entity,
    estimate_cmi,
    score_model_for_next_word_prob,
    create_position_ids_from_input_ids,
    sharded_score_model,
    estimate_entity_score,
    kl_div,
    difference,
    difference_p_good_only,
    difference_abs_val,
)
from preprocessing.datasets import CountryCapital

### Preamble

In [4]:
##################
### Parameters ###
##################

# Data parameters
SEED = 0
DATASET_NAME = "CountryCapital"
DATASET_KWARGS_IDENTIFIABLE = dict(
    max_contexts=450,
    max_entities=90,
    cap_per_type=True,
    raw_country_capitals_path="data/CountryCapital/real-fake-historical-fictional-famousfictional-country-capital.csv",
)
LOG_DATASETS = True

# Model parameters
# MODEL_ID = "EleutherAI/pythia-70m-deduped"
# LOAD_IN_8BIT = False
MODEL_ID = "EleutherAI/pythia-6.9b-deduped"
LOAD_IN_8BIT = True
BATCH_SZ = 16

# Evaluation switches
COMPUTE_CMI = True
COMPUTE_KL = True
COMPUTE_GOOD_BAD = True
COMPUTE_GOOD_BAD_ABS = True
COMPUTE_GOOD_BAD_P_GOOD_ONLY = True

# wandb stuff
PROJECT_NAME = "context-vs-bias"
GROUP_NAME = None
TAGS = ["capitals"]

In [5]:
# Set random seeds
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

In [6]:
# Paths
# Construct dataset and data ids
# dataset = getattr(sys.modules[__name__], DATASET_NAME)(**DATASET_KWARGS_IDENTIFIABLE)
data_id = f"{DATASET_NAME}"
data_id += (
    f"-mc{DATASET_KWARGS_IDENTIFIABLE['max_contexts']}"
    if "max_contexts" in DATASET_KWARGS_IDENTIFIABLE
    and DATASET_KWARGS_IDENTIFIABLE["max_contexts"] is not None
    else ""
)
data_id += (
    f"-me{DATASET_KWARGS_IDENTIFIABLE['max_entities']}"
    if "max_entities" in DATASET_KWARGS_IDENTIFIABLE
    and DATASET_KWARGS_IDENTIFIABLE["max_entities"] is not None
    else ""
)
data_id += (
    "-cappertype"
    if "cap_per_type" in DATASET_KWARGS_IDENTIFIABLE
    and DATASET_KWARGS_IDENTIFIABLE["cap_per_type"]
    else ""
)


data_dir = os.path.join("data", DATASET_NAME, data_id, f"{SEED}")
input_dir = os.path.join(data_dir, "inputs")
entities_path = os.path.join(input_dir, "entities.json")
contexts_path = os.path.join(input_dir, "contexts.json")
queries_path = os.path.join(input_dir, "queries.json")
val_data_path = os.path.join(input_dir, "val.csv")
DATASET_KWARGS_IDENTIFIABLE = {
    **DATASET_KWARGS_IDENTIFIABLE,
    **dict(
        entities_path=entities_path,
        contexts_path=contexts_path,
        queries_path=queries_path,
    ),
}

results_dir = os.path.join(data_dir, "results")
val_results_path = os.path.join(results_dir, "val.csv")

# Construct model id
model_id = f"{MODEL_ID}"
model_id += "-8bit" if LOAD_IN_8BIT else ""
model_dir = os.path.join(data_dir, "models", model_id)

print(f"Data dir: {data_dir}")
print(f"Model dir: {model_dir}")

Data dir: data/CountryCapital/CountryCapital-mc450-me90-cappertype/0
Model dir: data/CountryCapital/CountryCapital-mc450-me90-cappertype/0/models/EleutherAI/pythia-6.9b-deduped-8bit


In [7]:
os.makedirs(input_dir, exist_ok=True)
os.makedirs(results_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)
dataset = getattr(sys.modules[__name__], DATASET_NAME)(**DATASET_KWARGS_IDENTIFIABLE)

Failed to load entities, contexts, and queries from paths data/CountryCapital/CountryCapital-mc450-me90-cappertype/0/inputs/entities.json, data/CountryCapital/CountryCapital-mc450-me90-cappertype/0/inputs/contexts.json, and data/CountryCapital/CountryCapital-mc450-me90-cappertype/0/inputs/queries.json.
Manually reconstructing dataset and saving to aforementioned paths.


100%|██████████| 2/2 [00:00<00:00, 71.39it/s]


In [8]:
# GPU stuff
device = "cuda" if torch.cuda.is_available() else "cpu"

In [9]:
# wandb stuff
os.environ["WANDB_NOTEBOOK_NAME"] = os.path.join(os.getcwd(), "main.ipynb")

params_to_log = {k: v for k, v in locals().items() if k.isupper()}

run = wandb.init(
    project=PROJECT_NAME,
    group=GROUP_NAME,
    config=params_to_log,
    tags=TAGS,
    mode="online",
)
print(dict(wandb.config))

[34m[1mwandb[0m: Currently logged in as: [33mkdu[0m ([33methz-rycolab[0m). Use [1m`wandb login --relogin`[0m to force relogin


{'SEED': 0, 'DATASET_NAME': 'CountryCapital', 'DATASET_KWARGS_IDENTIFIABLE': {'max_contexts': 450, 'max_entities': 90, 'cap_per_type': True, 'raw_country_capitals_path': 'data/CountryCapital/real-fake-historical-fictional-famousfictional-country-capital.csv', 'entities_path': 'data/CountryCapital/CountryCapital-mc450-me90-cappertype/0/inputs/entities.json', 'contexts_path': 'data/CountryCapital/CountryCapital-mc450-me90-cappertype/0/inputs/contexts.json', 'queries_path': 'data/CountryCapital/CountryCapital-mc450-me90-cappertype/0/inputs/queries.json'}, 'LOG_DATASETS': True, 'MODEL_ID': 'EleutherAI/pythia-6.9b-deduped', 'LOAD_IN_8BIT': True, 'BATCH_SZ': 16, 'COMPUTE_CMI': True, 'COMPUTE_KL': True, 'COMPUTE_GOOD_BAD': True, 'COMPUTE_GOOD_BAD_ABS': True, 'COMPUTE_GOOD_BAD_P_GOOD_ONLY': True, 'PROJECT_NAME': 'context-vs-bias', 'GROUP_NAME': None, 'TAGS': ['capitals']}


### Load Data

In [10]:
val_df_contexts_per_qe = dataset.get_contexts_per_query_entity_df()

if LOG_DATASETS:
    print(f"Saving datasets to {input_dir}.")
    os.makedirs(input_dir, exist_ok=True)
    val_df_contexts_per_qe.to_csv(val_data_path)

val_df_contexts_per_qe.info()
val_df_contexts_per_qe.head()

Saving datasets to data/CountryCapital/CountryCapital-mc450-me90-cappertype/0/inputs.
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 180 entries, 0 to 179
Data columns (total 4 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   q_id        180 non-null    object
 1   query_form  180 non-null    object
 2   entity      180 non-null    object
 3   contexts    180 non-null    object
dtypes: object(4)
memory usage: 5.8+ KB


Unnamed: 0,q_id,query_form,entity,contexts
0,capital_of,Q: What is the capital of {}?\nA:,Zimbabwe,"[The capital of Republic of Acre is Kassel.\n,..."
1,capital_of,Q: What is the capital of {}?\nA:,Paraguay,"[The capital of Republic of Acre is Kassel.\n,..."
2,capital_of,Q: What is the capital of {}?\nA:,Finland,"[The capital of Republic of Acre is Kassel.\n,..."
3,capital_of,Q: What is the capital of {}?\nA:,New Caledonia,"[The capital of Republic of Acre is Kassel.\n,..."
4,capital_of,Q: What is the capital of {}?\nA:,Nagorno-Karabakh Republic,"[The capital of Republic of Acre is Kassel.\n,..."


### Preprocess Data

In [11]:
# Preprocess the data and convert it into inputs for the model (e.g. torch tensors)

In [12]:
# After loading/preprocessing your dataset, log it as an artifact to W&B
if LOG_DATASETS:
    print(f"Logging datasets to w&b run {wandb.run}.")
    artifact = wandb.Artifact(name=data_id, type="dataset")
    artifact.add_dir(local_path=input_dir)
    run.log_artifact(artifact)

[34m[1mwandb[0m: Adding directory to artifact (./data/CountryCapital/CountryCapital-mc450-me90-cappertype/0/inputs)... Done. 0.0s


Logging datasets to w&b run <wandb.sdk.wandb_run.Run object at 0x2b04ebdaeef0>.


### Score Model

In [13]:
try:
    model = GPTNeoXForCausalLM.from_pretrained(
        MODEL_ID, load_in_8bit=LOAD_IN_8BIT, device_map="auto"
    )
except:
    print(f"Failed to load model {MODEL_ID} in 8-bit. Attempting to load normally.")
    model = GPTNeoXForCausalLM.from_pretrained(
        MODEL_ID,
        load_in_8bit=False,
    ).to(device)

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    padding_side="left",
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:28<00:00, 14.48s/it]


In [14]:
!nvidia-smi --query-gpu=memory.used --format=csv

memory.used [MiB]
7703 MiB


In [15]:
torch.cuda.empty_cache()
import gc

gc.collect()

6883

In [16]:
# One forward pass
row = val_df_contexts_per_qe.iloc[0]
estimate_cmi(
    row["query_form"],
    entity=row["entity"],
    contexts=row["contexts"][:128],
    model=model,
    tokenizer=tokenizer,
    bs=BATCH_SZ,
)

Using pad_token, but it is not set yet.


Setting model.config.pad_token_id to model.config.eos_token_id


0.10919940431792968

In [17]:
tqdm.pandas()
val_df_contexts_per_qe["susceptibility_score"] = val_df_contexts_per_qe.progress_apply(
    lambda row: estimate_cmi(
        query=row["query_form"],
        entity=row["entity"],
        contexts=row["contexts"],
        model=model,
        tokenizer=tokenizer,
        answer_map=None,
        bs=BATCH_SZ,
    ),
    axis=1,
)
val_df_contexts_per_qe.to_csv(val_results_path)

100%|██████████| 180/180 [18:22<00:00,  6.13s/it]


In [None]:
# After loading/preprocessing your dataset, log it as an artifact to W&B
if LOG_DATASETS:
    print(f"Logging results to w&b run {wandb.run}.")
    artifact = wandb.Artifact(name=data_id, type="results")
    artifact.add_dir(local_path=results_dir)
    run.log_artifact(artifact)

### Evaluate Model

In [18]:
val_df_contexts_per_qe["entity"].value_counts()

entity
Zimbabwe        2
Calisota        2
Saint Marie     2
Arrakis         2
Panem           2
               ..
Serenitaria     2
Ocraita         2
Pelui           2
Baglandia       2
Zhou dynasty    2
Name: count, Length: 90, dtype: int64

In [19]:
val_df_contexts_per_qe[
    val_df_contexts_per_qe["query_form"] == "The capital of {} is"
].sort_values(by="susceptibility_score")

Unnamed: 0,q_id,query_form,entity,contexts,susceptibility_score
103,capital_of,The capital of {} is,Republic of China (Taiwan),"[The capital of Republic of Acre is Kassel.\n,...",0.046715
94,capital_of,The capital of {} is,Nagorno-Karabakh Republic,"[The capital of Republic of Acre is Kassel.\n,...",0.051147
107,capital_of,The capital of {} is,Kyrgyzstan,"[The capital of Republic of Acre is Kassel.\n,...",0.081597
101,capital_of,The capital of {} is,Somalia,"[The capital of Republic of Acre is Kassel.\n,...",0.083712
91,capital_of,The capital of {} is,Paraguay,"[The capital of Republic of Acre is Kassel.\n,...",0.083977
...,...,...,...,...,...
139,capital_of,The capital of {} is,Pasti,"[The capital of Republic of Acre is Kassel.\n,...",0.848880
120,capital_of,The capital of {} is,Cadasa,"[The capital of Republic of Acre is Kassel.\n,...",0.851561
148,capital_of,The capital of {} is,Fictional Country,"[The capital of Republic of Acre is Kassel.\n,...",0.914755
136,capital_of,The capital of {} is,Wula,"[The capital of Republic of Acre is Kassel.\n,...",0.918664


In [20]:
val_df_contexts_per_qe[
    val_df_contexts_per_qe["query_form"] == "Q: What is the capital of {}?\nA:"
].sort_values(by="susceptibility_score")

Unnamed: 0,q_id,query_form,entity,contexts,susceptibility_score
13,capital_of,Q: What is the capital of {}?\nA:,Republic of China (Taiwan),"[The capital of Republic of Acre is Kassel.\n,...",0.031173
9,capital_of,Q: What is the capital of {}?\nA:,Sweden,"[The capital of Republic of Acre is Kassel.\n,...",0.063413
2,capital_of,Q: What is the capital of {}?\nA:,Finland,"[The capital of Republic of Acre is Kassel.\n,...",0.069502
64,capital_of,Q: What is the capital of {}?\nA:,Arrakis,"[The capital of Republic of Acre is Kassel.\n,...",0.070778
62,capital_of,Q: What is the capital of {}?\nA:,Narnia,"[The capital of Republic of Acre is Kassel.\n,...",0.071606
...,...,...,...,...,...
37,capital_of,Q: What is the capital of {}?\nA:,Manika,"[The capital of Republic of Acre is Kassel.\n,...",0.365732
48,capital_of,Q: What is the capital of {}?\nA:,Du,"[The capital of Republic of Acre is Kassel.\n,...",0.368670
46,capital_of,Q: What is the capital of {}?\nA:,Wula,"[The capital of Republic of Acre is Kassel.\n,...",0.380046
20,capital_of,Q: What is the capital of {}?\nA:,Kadersaryina,"[The capital of Republic of Acre is Kassel.\n,...",0.396409


In [21]:
wandb.finish()