# Susceptibility Scores
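An exploratory notebook that computes a susceptibility score (via `estimate_cmi`) for each query–entity pair of the `CountryCapital` dataset with a Pythia model, and logs the run and its input data to Weights & Biases.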

In [1]:
%load_ext autoreload
%autoreload 2
%load_ext lab_black

In [2]:
import os
import sys
import math
import random
from itertools import product
from tqdm import tqdm
import yaml

import pandas as pd
import seaborn as sns
from transformers import GPTNeoXForCausalLM, AutoTokenizer
import torch
from typing import List
import numpy as np
import wandb

import statsmodels.api as sm
from measuring.estimate_probs import (
    estimate_prob_y_given_context_and_entity,
    estimate_prob_x_given_e,
    estimate_prob_next_word_given_x_and_entity,
    estimate_cmi,
    score_model_for_next_word_prob,
    create_position_ids_from_input_ids,
    sharded_score_model,
    estimate_entity_score,
    kl_div,
    difference,
    difference_p_good_only,
    difference_abs_val,
)
from preprocessing.datasets import CountryCapital

  from .autonotebook import tqdm as notebook_tqdm


### Preamble

In [3]:
# ---------------------------------------------------------------------------
# Parameters
#
# Everything a reader might want to tune lives in this cell. Names are
# ALL-CAPS on purpose: the wandb cell logs every upper-case name in the
# notebook namespace as run config.
# ---------------------------------------------------------------------------

# --- Data ---
SEED = 0
DATASET_NAME = "CountryCapital"
# Keyword arguments forwarded to the dataset constructor.
DATASET_KWARGS_IDENTIFIABLE = {
    "max_contexts": 10,
    "max_entities": 10,
    "raw_country_capitals_path": "data/CountryCapital/real-fake-country-capital.csv",
}
LOG_DATASETS = True

# --- Model ---
MODEL_ID = "EleutherAI/pythia-70m-deduped"
LOAD_IN_8BIT = False

# --- Evaluation switches ---
COMPUTE_CMI = True
COMPUTE_KL = True
COMPUTE_GOOD_BAD = True
COMPUTE_GOOD_BAD_ABS = True
COMPUTE_GOOD_BAD_P_GOOD_ONLY = True

# --- Weights & Biases ---
PROJECT_NAME = "context-vs-bias"
GROUP_NAME = None
TAGS = ["capitals"]

In [4]:
# Paths
# The data id is the dataset name plus a suffix for each size cap that is
# present in the dataset kwargs and not None.
data_id = f"{DATASET_NAME}"
if DATASET_KWARGS_IDENTIFIABLE.get("max_contexts") is not None:
    data_id += f"-mc{DATASET_KWARGS_IDENTIFIABLE['max_contexts']}"
if DATASET_KWARGS_IDENTIFIABLE.get("max_entities") is not None:
    data_id += f"-me{DATASET_KWARGS_IDENTIFIABLE['max_entities']}"

# Layout: data/<dataset>/<data_id>/<seed>/{inputs,results,models/<model_id>}
data_dir = os.path.join("data", DATASET_NAME, data_id, f"{SEED}")

input_dir = os.path.join(data_dir, "inputs")
entities_path = os.path.join(input_dir, "entities.json")
contexts_path = os.path.join(input_dir, "contexts.json")
queries_path = os.path.join(input_dir, "queries.json")
val_data_path = os.path.join(input_dir, "val.csv")

# Extend the dataset kwargs with the input file locations. Existing keys come
# first; re-running the cell overwrites the three path keys with the same values.
DATASET_KWARGS_IDENTIFIABLE = {
    **DATASET_KWARGS_IDENTIFIABLE,
    "entities_path": entities_path,
    "contexts_path": contexts_path,
    "queries_path": queries_path,
}

results_dir = os.path.join(data_dir, "results")
val_results_path = os.path.join(results_dir, "val.csv")

# The model id mirrors MODEL_ID, tagged with "-8bit" when 8-bit loading is requested.
model_id = f"{MODEL_ID}-8bit" if LOAD_IN_8BIT else f"{MODEL_ID}"
model_dir = os.path.join(data_dir, "models", model_id)

print(f"Data dir: {data_dir}")
print(f"Model dir: {model_dir}")

Data dir: data/CountryCapital/CountryCapital-mc10-me10/0
Model dir: data/CountryCapital/CountryCapital-mc10-me10/0/models/EleutherAI/pythia-70m-deduped


In [5]:
# Make sure every directory this notebook uses exists (no error if already there).
for required_dir in (input_dir, results_dir, model_dir):
    os.makedirs(required_dir, exist_ok=True)

# Look the dataset class up by name in this module's namespace, then build it
# with the kwargs assembled in the paths cell.
dataset_cls = getattr(sys.modules[__name__], DATASET_NAME)
dataset = dataset_cls(**DATASET_KWARGS_IDENTIFIABLE)

100%|██████████| 2/2 [00:00<00:00, 26715.31it/s]


In [6]:
# Seed every random number generator used in this notebook with the same SEED:
# Python's stdlib RNG, NumPy's legacy global RNG, and torch.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [7]:
# GPU stuff
# Use CUDA when torch reports it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

  return torch._C._cuda_getDeviceCount() > 0


In [8]:
# wandb stuff
# Point WANDB_NOTEBOOK_NAME at main.ipynb in the current working directory.
os.environ["WANDB_NOTEBOOK_NAME"] = os.path.join(os.getcwd(), "main.ipynb")

# Every ALL-CAPS name currently in the notebook namespace is logged as run config.
params_to_log = {
    name: value for name, value in locals().items() if name.isupper()
}

run = wandb.init(
    mode="online",
    project=PROJECT_NAME,
    group=GROUP_NAME,
    tags=TAGS,
    config=params_to_log,
)
print(dict(wandb.config))

[34m[1mwandb[0m: Currently logged in as: [33mkdu[0m ([33methz-rycolab[0m). Use [1m`wandb login --relogin`[0m to force relogin


{'SEED': 0, 'DATASET_NAME': 'CountryCapital', 'DATASET_KWARGS_IDENTIFIABLE': {'max_contexts': 10, 'max_entities': 10, 'raw_country_capitals_path': 'data/CountryCapital/real-fake-country-capital.csv', 'entities_path': 'data/CountryCapital/CountryCapital-mc10-me10/0/inputs/entities.json', 'contexts_path': 'data/CountryCapital/CountryCapital-mc10-me10/0/inputs/contexts.json', 'queries_path': 'data/CountryCapital/CountryCapital-mc10-me10/0/inputs/queries.json'}, 'LOG_DATASETS': True, 'MODEL_ID': 'EleutherAI/pythia-70m-deduped', 'LOAD_IN_8BIT': False, 'COMPUTE_CMI': True, 'COMPUTE_KL': True, 'COMPUTE_GOOD_BAD': True, 'COMPUTE_GOOD_BAD_ABS': True, 'COMPUTE_GOOD_BAD_P_GOOD_ONLY': True, 'PROJECT_NAME': 'context-vs-bias', 'GROUP_NAME': None, 'TAGS': ['capitals']}


### Load Data

In [9]:
# Build the validation frame from the dataset object. Per the recorded output
# it has the columns q_id, query_form, entity and contexts.
val_df_contexts_per_qe = dataset.get_contexts_per_query_entity_df()

# Optionally persist the frame next to the other input files so it can be
# uploaded as part of the dataset artifact later on.
if LOG_DATASETS:
    print(f"Saving datasets to {input_dir}.")
    os.makedirs(input_dir, exist_ok=True)
    # NOTE(review): to_csv is called without index=False, so the row index is
    # written as an extra unnamed column - confirm downstream readers expect it.
    val_df_contexts_per_qe.to_csv(val_data_path)

# Print the schema summary, then show the first rows as the cell's output.
val_df_contexts_per_qe.info()
val_df_contexts_per_qe.head()

Saving datasets to data/CountryCapital/CountryCapital-mc10-me10/0/inputs.
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 20 entries, 0 to 19
Data columns (total 4 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   q_id        20 non-null     object
 1   query_form  20 non-null     object
 2   entity      20 non-null     object
 3   contexts    20 non-null     object
dtypes: object(4)
memory usage: 768.0+ bytes


Unnamed: 0,q_id,query_form,entity,contexts
0,capital_of,Q: What is the capital of {}?\nA:,Nepal,"[The capital of Nepal is Chyria.\n, The capita..."
1,capital_of,Q: What is the capital of {}?\nA:,Warstadt,"[The capital of Nepal is Chyria.\n, The capita..."
2,capital_of,Q: What is the capital of {}?\nA:,San Marino,"[The capital of Nepal is Chyria.\n, The capita..."
3,capital_of,Q: What is the capital of {}?\nA:,Côte d'Ivoire,"[The capital of Nepal is Chyria.\n, The capita..."
4,capital_of,Q: What is the capital of {}?\nA:,Lithuania,"[The capital of Nepal is Chyria.\n, The capita..."


### Preprocess Data

In [10]:
# Preprocess the data and convert it into inputs for the model (e.g. torch tensors)

In [11]:
# After loading/preprocessing the dataset, log the input directory to W&B as an
# artifact of type "dataset" named after data_id. Skipped when LOG_DATASETS is False.
if LOG_DATASETS:
    print(f"Logging datasets to w&b run {wandb.run}.")
    artifact = wandb.Artifact(name=data_id, type="dataset")
    # Attach the contents of input_dir to the artifact.
    artifact.add_dir(local_path=input_dir)
    # Log the artifact through the `run` handle returned by wandb.init.
    run.log_artifact(artifact)

[34m[1mwandb[0m: Adding directory to artifact (./data/CountryCapital/CountryCapital-mc10-me10/0/inputs)... Done. 0.0s


Logging datasets to w&b run <wandb.sdk.wandb_run.Run object at 0x7f9779a19f60>.


### Score Model

In [12]:
# Load the causal LM. First try the configured options (8-bit flag plus
# device_map="auto"); if that raises, fall back to a plain load and move the
# model to `device` explicitly.
try:
    model = GPTNeoXForCausalLM.from_pretrained(
        MODEL_ID, load_in_8bit=LOAD_IN_8BIT, device_map="auto"
    )
except Exception as load_error:
    # Catch Exception rather than using a bare `except:` so that
    # KeyboardInterrupt / SystemExit still stop the notebook, and report the
    # real failure instead of assuming 8-bit loading was the cause
    # (LOAD_IN_8BIT may well be False).
    print(
        f"Failed to load model {MODEL_ID} with load_in_8bit={LOAD_IN_8BIT} "
        f"and device_map='auto' ({type(load_error).__name__}: {load_error}). "
        "Attempting to load normally."
    )
    model = GPTNeoXForCausalLM.from_pretrained(
        MODEL_ID,
        load_in_8bit=False,
    ).to(device)

# The tokenizer is configured to pad on the left.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    padding_side="left",
)

In [13]:
def _row_susceptibility(row):
    """Return the `estimate_cmi` score for one row of the validation frame.

    The row supplies the query form, the entity and its list of contexts; the
    model and tokenizer are the notebook-level objects loaded above.
    """
    return estimate_cmi(
        query=row["query_form"],
        entity=row["entity"],
        contexts=row["contexts"],
        model=model,
        tokenizer=tokenizer,
        answer_map=None,
    )


# Register tqdm's `progress_apply` on pandas objects, then score row by row.
tqdm.pandas()
val_df_contexts_per_qe["susceptibility_score"] = val_df_contexts_per_qe.progress_apply(
    _row_susceptibility, axis=1
)
val_df_contexts_per_qe

  0%|          | 0/20 [00:00<?, ?it/s]

Setting model.config.pad_token_id to model.config.eos_token_id


  return np.sum(prob_x_y_given_e * np.nan_to_num(np.log(prob_y_given_context_and_entity / prob_y_given_e)))
  return np.sum(prob_x_y_given_e * np.nan_to_num(np.log(prob_y_given_context_and_entity / prob_y_given_e)))
100%|██████████| 20/20 [00:02<00:00,  8.10it/s]


Unnamed: 0,q_id,query_form,entity,contexts,susceptibility_score
0,capital_of,Q: What is the capital of {}?\nA:,Nepal,"[The capital of Nepal is Chyria.\n, The capita...",0.040032
1,capital_of,Q: What is the capital of {}?\nA:,Warstadt,"[The capital of Nepal is Chyria.\n, The capita...",0.051261
2,capital_of,Q: What is the capital of {}?\nA:,San Marino,"[The capital of Nepal is Chyria.\n, The capita...",0.088006
3,capital_of,Q: What is the capital of {}?\nA:,Côte d'Ivoire,"[The capital of Nepal is Chyria.\n, The capita...",0.060005
4,capital_of,Q: What is the capital of {}?\nA:,Lithuania,"[The capital of Nepal is Chyria.\n, The capita...",0.044034
5,capital_of,Q: What is the capital of {}?\nA:,Rwanda,"[The capital of Nepal is Chyria.\n, The capita...",0.055354
6,capital_of,Q: What is the capital of {}?\nA:,Brunei,"[The capital of Nepal is Chyria.\n, The capita...",0.077649
7,capital_of,Q: What is the capital of {}?\nA:,Qatar,"[The capital of Nepal is Chyria.\n, The capita...",0.038981
8,capital_of,Q: What is the capital of {}?\nA:,Floofern,"[The capital of Nepal is Chyria.\n, The capita...",0.019681
9,capital_of,Q: What is the capital of {}?\nA:,Ecuador,"[The capital of Nepal is Chyria.\n, The capita...",0.026412


### Evaluate Model

In [14]:
# Sanity check: count how many rows each entity has in the validation frame.
val_df_contexts_per_qe["entity"].value_counts()

entity
Nepal            2
Warstadt         2
San Marino       2
Côte d'Ivoire    2
Lithuania        2
Rwanda           2
Brunei           2
Qatar            2
Floofern         2
Ecuador          2
Name: count, dtype: int64

In [18]:
# Rows for the "The capital of {} is" query form, in ascending order of
# susceptibility_score.
val_df_contexts_per_qe.loc[
    val_df_contexts_per_qe["query_form"].eq("The capital of {} is")
].sort_values("susceptibility_score")

Unnamed: 0,q_id,query_form,entity,contexts,susceptibility_score
14,capital_of,The capital of {} is,Lithuania,"[The capital of Nepal is Chyria.\n, The capita...",0.250445
17,capital_of,The capital of {} is,Qatar,"[The capital of Nepal is Chyria.\n, The capita...",0.261756
16,capital_of,The capital of {} is,Brunei,"[The capital of Nepal is Chyria.\n, The capita...",0.263851
10,capital_of,The capital of {} is,Nepal,"[The capital of Nepal is Chyria.\n, The capita...",0.266866
19,capital_of,The capital of {} is,Ecuador,"[The capital of Nepal is Chyria.\n, The capita...",0.293593
15,capital_of,The capital of {} is,Rwanda,"[The capital of Nepal is Chyria.\n, The capita...",0.313813
11,capital_of,The capital of {} is,Warstadt,"[The capital of Nepal is Chyria.\n, The capita...",0.428296
13,capital_of,The capital of {} is,Côte d'Ivoire,"[The capital of Nepal is Chyria.\n, The capita...",0.44049
18,capital_of,The capital of {} is,Floofern,"[The capital of Nepal is Chyria.\n, The capita...",0.453249
12,capital_of,The capital of {} is,San Marino,"[The capital of Nepal is Chyria.\n, The capita...",0.454748


In [16]:
# Rows for the question/answer query form, in ascending order of
# susceptibility_score.
val_df_contexts_per_qe.loc[
    val_df_contexts_per_qe["query_form"].eq("Q: What is the capital of {}?\nA:")
].sort_values("susceptibility_score")

Unnamed: 0,q_id,query_form,entity,contexts,susceptibility_score
8,capital_of,Q: What is the capital of {}?\nA:,Floofern,"[The capital of Nepal is Chyria.\n, The capita...",0.019681
9,capital_of,Q: What is the capital of {}?\nA:,Ecuador,"[The capital of Nepal is Chyria.\n, The capita...",0.026412
7,capital_of,Q: What is the capital of {}?\nA:,Qatar,"[The capital of Nepal is Chyria.\n, The capita...",0.038981
0,capital_of,Q: What is the capital of {}?\nA:,Nepal,"[The capital of Nepal is Chyria.\n, The capita...",0.040032
4,capital_of,Q: What is the capital of {}?\nA:,Lithuania,"[The capital of Nepal is Chyria.\n, The capita...",0.044034
1,capital_of,Q: What is the capital of {}?\nA:,Warstadt,"[The capital of Nepal is Chyria.\n, The capita...",0.051261
5,capital_of,Q: What is the capital of {}?\nA:,Rwanda,"[The capital of Nepal is Chyria.\n, The capita...",0.055354
3,capital_of,Q: What is the capital of {}?\nA:,Côte d'Ivoire,"[The capital of Nepal is Chyria.\n, The capita...",0.060005
6,capital_of,Q: What is the capital of {}?\nA:,Brunei,"[The capital of Nepal is Chyria.\n, The capita...",0.077649
2,capital_of,Q: What is the capital of {}?\nA:,San Marino,"[The capital of Nepal is Chyria.\n, The capita...",0.088006


In [17]:
# Finish the W&B run that was started by wandb.init earlier in the notebook.
wandb.finish()