# Analyzing Susceptibility Scores
Given a CSV with the columns:

```q_id,	query_form,	entity,	contexts,	susceptibility_score```

Analyze the results for patterns and correlations

In [1]:
%load_ext autoreload
%autoreload 2
%load_ext lab_black

In [2]:
import os
import sys
import math
import random
from itertools import product
from tqdm import tqdm
import yaml

import pandas as pd
import seaborn as sns
from transformers import GPTNeoXForCausalLM, AutoTokenizer
import torch
from typing import List
import numpy as np
import wandb

import statsmodels.api as sm
from measuring.estimate_probs import (
    estimate_prob_y_given_context_and_entity,
    estimate_prob_x_given_e,
    estimate_prob_next_word_given_x_and_entity,
    estimate_cmi,
    score_model_for_next_word_prob,
    create_position_ids_from_input_ids,
    sharded_score_model,
    estimate_entity_score,
    kl_div,
    difference,
    difference_p_good_only,
    difference_abs_val,
)
from preprocessing.datasets import CountryCapital

  from .autonotebook import tqdm as notebook_tqdm


### Preamble

In [3]:
##################
### Parameters ###
##################

# Data parameters
SEED = 0
DATASET_NAME = "CountryCapital"
DATASET_KWARGS_IDENTIFIABLE = dict(
    max_contexts=10,
    max_entities=10,
    raw_country_capitals_path="data/CountryCapital/real-fake-country-capital.csv",
)
LOG_DATASETS = True

# Model parameters
MODEL_ID = "EleutherAI/pythia-70m-deduped"
LOAD_IN_8BIT = False

# Evaluation switches
COMPUTE_CMI = True
COMPUTE_KL = True
COMPUTE_GOOD_BAD = True
COMPUTE_GOOD_BAD_ABS = True
COMPUTE_GOOD_BAD_P_GOOD_ONLY = True

# wandb stuff
PROJECT_NAME = "context-vs-bias"
GROUP_NAME = None
TAGS = ["capitals", "analysis"]

In [4]:
# Paths
# Construct dataset and data ids
# dataset = getattr(sys.modules[__name__], DATASET_NAME)(**DATASET_KWARGS_IDENTIFIABLE)
data_id = f"{DATASET_NAME}"
data_id += (
    f"-mc{DATASET_KWARGS_IDENTIFIABLE['max_contexts']}"
    if "max_contexts" in DATASET_KWARGS_IDENTIFIABLE
    and DATASET_KWARGS_IDENTIFIABLE["max_contexts"] is not None
    else ""
)
data_id += (
    f"-me{DATASET_KWARGS_IDENTIFIABLE['max_entities']}"
    if "max_entities" in DATASET_KWARGS_IDENTIFIABLE
    and DATASET_KWARGS_IDENTIFIABLE["max_entities"] is not None
    else ""
)

data_dir = os.path.join("data", DATASET_NAME, data_id, f"{SEED}")
input_dir = os.path.join(data_dir, "inputs")
entities_path = os.path.join(input_dir, "entities.json")
contexts_path = os.path.join(input_dir, "contexts.json")
queries_path = os.path.join(input_dir, "queries.json")
val_data_path = os.path.join(input_dir, "val.csv")
DATASET_KWARGS_IDENTIFIABLE = {
    **DATASET_KWARGS_IDENTIFIABLE,
    **dict(
        entities_path=entities_path,
        contexts_path=contexts_path,
        queries_path=queries_path,
    ),
}

results_dir = os.path.join(data_dir, "results")
val_results_path = os.path.join(results_dir, "val.csv")

# Construct model id
model_id = f"{MODEL_ID}"
model_id += "-8bit" if LOAD_IN_8BIT else ""
model_dir = os.path.join(data_dir, "models", model_id)

print(f"Data dir: {data_dir}")
print(f"Model dir: {model_dir}")

Data dir: data/CountryCapital/CountryCapital-mc10-me10/0
Model dir: data/CountryCapital/CountryCapital-mc10-me10/0/models/EleutherAI/pythia-70m-deduped


In [5]:
# Analysis dir
analysis_dir = os.path.join(data_dir, "analysis")
print(f"Analysis dir: {analysis_dir}")

Analysis dir: data/CountryCapital/CountryCapital-mc10-me10/0/analysis


In [6]:
os.makedirs(input_dir, exist_ok=True)
os.makedirs(results_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)
os.makedirs(analysis_dir, exist_ok=True)
dataset = getattr(sys.modules[__name__], DATASET_NAME)(**DATASET_KWARGS_IDENTIFIABLE)

100%|██████████| 2/2 [00:00<00:00, 6457.74it/s]


In [7]:
# Set random seeds
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

In [8]:
# GPU stuff
device = "cuda" if torch.cuda.is_available() else "cpu"

  return torch._C._cuda_getDeviceCount() > 0


In [9]:
# wandb stuff
os.environ["WANDB_NOTEBOOK_NAME"] = os.path.join(os.getcwd(), "analysis.ipynb")

params_to_log = {k: v for k, v in locals().items() if k.isupper()}

run = wandb.init(
    project=PROJECT_NAME,
    group=GROUP_NAME,
    config=params_to_log,
    tags=TAGS,
    mode="online",
)
print(dict(wandb.config))

[34m[1mwandb[0m: Currently logged in as: [33mkdu[0m ([33methz-rycolab[0m). Use [1m`wandb login --relogin`[0m to force relogin


{'SEED': 0, 'DATASET_NAME': 'CountryCapital', 'DATASET_KWARGS_IDENTIFIABLE': {'max_contexts': 10, 'max_entities': 10, 'raw_country_capitals_path': 'data/CountryCapital/real-fake-country-capital.csv', 'entities_path': 'data/CountryCapital/CountryCapital-mc10-me10/0/inputs/entities.json', 'contexts_path': 'data/CountryCapital/CountryCapital-mc10-me10/0/inputs/contexts.json', 'queries_path': 'data/CountryCapital/CountryCapital-mc10-me10/0/inputs/queries.json'}, 'LOG_DATASETS': True, 'MODEL_ID': 'EleutherAI/pythia-70m-deduped', 'LOAD_IN_8BIT': False, 'COMPUTE_CMI': True, 'COMPUTE_KL': True, 'COMPUTE_GOOD_BAD': True, 'COMPUTE_GOOD_BAD_ABS': True, 'COMPUTE_GOOD_BAD_P_GOOD_ONLY': True, 'PROJECT_NAME': 'context-vs-bias', 'GROUP_NAME': None, 'TAGS': ['capitals', 'analysis']}


### Load Data

In [29]:
from ast import literal_eval

val_df_contexts_per_qe = pd.read_csv(
    val_results_path, index_col=0, converters={"contexts": literal_eval}
)

In [30]:
# After loading/preprocessing your dataset, log it as an artifact to W&B
if LOG_DATASETS:
    print(f"Logging datasets to w&b run {wandb.run}.")
    artifact = wandb.Artifact(name=data_id, type="dataset")
    artifact.add_dir(local_path=input_dir)
    run.log_artifact(artifact)

[34m[1mwandb[0m: Adding directory to artifact (./data/CountryCapital/CountryCapital-mc10-me10/0/inputs)... Done. 0.0s


Logging datasets to w&b run None.


UsageError: Run (8lpepnul) is finished. The call to `log_artifact` will be ignored. Please make sure that you are using an active run.

### Analyze Data

In [31]:
val_df_contexts_per_qe.head()

Unnamed: 0,q_id,query_form,entity,contexts,susceptibility_score
0,capital_of,Q: What is the capital of {}?\nA:,Cuba,"[The capital of Easter Island is Jakarta.\n, T...",0.036369
1,capital_of,Q: What is the capital of {}?\nA:,Sachan,"[The capital of Easter Island is Jakarta.\n, T...",0.0573
2,capital_of,Q: What is the capital of {}?\nA:,Lebanon,"[The capital of Easter Island is Jakarta.\n, T...",0.066166
3,capital_of,Q: What is the capital of {}?\nA:,Paraguay,"[The capital of Easter Island is Jakarta.\n, T...",0.092626
4,capital_of,Q: What is the capital of {}?\nA:,Wales,"[The capital of Easter Island is Jakarta.\n, T...",0.051006


In [32]:
val_df_contexts_per_qe["entity"].value_counts()

entity
Cuba                         2
Sachan                       2
Lebanon                      2
Paraguay                     2
Wales                        2
United Kingdom; England      2
Ecuador                      2
Easter Island                2
France                       2
Nagorno-Karabakh Republic    2
Name: count, dtype: int64

In [33]:
val_df_contexts_per_qe.iloc[0]["contexts"]

['The capital of Easter Island is Jakarta.\n',
 'The capital of Easter Island is Ryan.\n',
 'The capital of Lebanon is Tashkent.\n',
 'The capital of Nagorno-Karabakh Republic is Luanda.\n',
 'The capital of Wales is Santo Domingo.\n',
 'The capital of Wales is Ottawa.\n',
 'The capital of Easter Island is Luxembourg.\n',
 'The capital of Ecuador is Baku.\n',
 'The capital of Nagorno-Karabakh Republic is Banjul.\n',
 'The capital of Easter Island is Doha.\n']

In [14]:
val_df_contexts_per_qe[
    val_df_contexts_per_qe["query_form"] == "The capital of {} is"
].sort_values(by="susceptibility_score")

Unnamed: 0.1,Unnamed: 0,q_id,query_form,entity,contexts,susceptibility_score
15,15,capital_of,The capital of {} is,United Kingdom; England,"['The capital of Easter Island is Jakarta.\n',...",0.118859
11,11,capital_of,The capital of {} is,Sachan,"['The capital of Easter Island is Jakarta.\n',...",0.494262
12,12,capital_of,The capital of {} is,Lebanon,"['The capital of Easter Island is Jakarta.\n',...",0.553292
10,10,capital_of,The capital of {} is,Cuba,"['The capital of Easter Island is Jakarta.\n',...",0.578758
16,16,capital_of,The capital of {} is,Ecuador,"['The capital of Easter Island is Jakarta.\n',...",0.592763
18,18,capital_of,The capital of {} is,France,"['The capital of Easter Island is Jakarta.\n',...",0.607013
14,14,capital_of,The capital of {} is,Wales,"['The capital of Easter Island is Jakarta.\n',...",0.679343
13,13,capital_of,The capital of {} is,Paraguay,"['The capital of Easter Island is Jakarta.\n',...",0.717503
17,17,capital_of,The capital of {} is,Easter Island,"['The capital of Easter Island is Jakarta.\n',...",0.884176
19,19,capital_of,The capital of {} is,Nagorno-Karabakh Republic,"['The capital of Easter Island is Jakarta.\n',...",0.921168


In [15]:
val_df_contexts_per_qe[
    val_df_contexts_per_qe["query_form"] == "Q: What is the capital of {}?\nA:"
].sort_values(by="susceptibility_score")

Unnamed: 0.1,Unnamed: 0,q_id,query_form,entity,contexts,susceptibility_score
6,6,capital_of,Q: What is the capital of {}?\nA:,Ecuador,"['The capital of Easter Island is Jakarta.\n',...",0.030048
0,0,capital_of,Q: What is the capital of {}?\nA:,Cuba,"['The capital of Easter Island is Jakarta.\n',...",0.036369
8,8,capital_of,Q: What is the capital of {}?\nA:,France,"['The capital of Easter Island is Jakarta.\n',...",0.045785
4,4,capital_of,Q: What is the capital of {}?\nA:,Wales,"['The capital of Easter Island is Jakarta.\n',...",0.051006
5,5,capital_of,Q: What is the capital of {}?\nA:,United Kingdom; England,"['The capital of Easter Island is Jakarta.\n',...",0.051659
1,1,capital_of,Q: What is the capital of {}?\nA:,Sachan,"['The capital of Easter Island is Jakarta.\n',...",0.0573
2,2,capital_of,Q: What is the capital of {}?\nA:,Lebanon,"['The capital of Easter Island is Jakarta.\n',...",0.066166
9,9,capital_of,Q: What is the capital of {}?\nA:,Nagorno-Karabakh Republic,"['The capital of Easter Island is Jakarta.\n',...",0.082448
3,3,capital_of,Q: What is the capital of {}?\nA:,Paraguay,"['The capital of Easter Island is Jakarta.\n',...",0.092626
7,7,capital_of,Q: What is the capital of {}?\nA:,Easter Island,"['The capital of Easter Island is Jakarta.\n',...",0.095616


In [16]:
wandb.finish()