# Susceptibility Scores
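A notebook for initial exploration of susceptibility scores: it builds a dataset of queries, entities, and contexts, scores each row with `estimate_cmi` using a Pythia model, and logs the inputs and results to Weights & Biases.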

In [None]:
# Reload imported modules automatically before each cell runs (autoreload mode 2),
# so edits to the local `measuring` / `preprocessing` packages are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Auto-format notebook cells with black via the lab_black extension.
%load_ext lab_black

In [None]:
import os
import sys
import random
from tqdm import tqdm

from transformers import GPTNeoXForCausalLM, AutoTokenizer
import torch
import numpy as np
import wandb

In [None]:
from measuring.estimate_probs import estimate_cmi
from preprocessing.datasets import CountryCapital, FriendEnemy, WorldLeaders, YagoECQ
from preprocessing.utils import extract_name_from_yago_uri

### Preamble

In [None]:
##################
### Parameters ###
##################

# Data parameters
# Seed used for torch / numpy / random below; it is also the last component of
# the data directory path, so each seed gets its own folder.
SEED = 0
# DATASET_NAME = "CountryCapital"
# DATASET_KWARGS_IDENTIFIABLE = dict(
#     max_contexts=15,
#     max_entities=5,
#     cap_per_type=True,
#     raw_country_capitals_path="data/CountryCapital/real-fake-historical-fictional-famousfictional-country-capital.csv",
#     ablate_out_relevant_contexts=True,
# )
# DATASET_KWARGS_IDENTIFIABLE = dict(
#     max_contexts=15,
#     max_entities=5,
#     cap_per_type=True,
#     raw_country_capitals_path="data/CountryCapital/real-fake-historical-fictional-famousfictional-country-capital.csv",
#     ablate_out_relevant_contexts=True,
# )
# DATASET_KWARGS_IDENTIFIABLE = dict(
#     max_contexts=450,
#     max_entities=90,
#     cap_per_type=True,
#     raw_country_capitals_path="data/CountryCapital/real-fake-historical-fictional-famousfictional-country-capital.csv",
#     ablate_out_relevant_contexts=True,
# )
# DATASET_NAME = "FriendEnemy"
# DATASET_KWARGS_IDENTIFIABLE = dict(
#     max_contexts=15,
#     max_entities=5,
#     cap_per_type=False,
#     raw_data_path="data/FriendEnemy/raw-friend-enemy.csv",
# )
# DATASET_KWARGS_IDENTIFIABLE = dict(
#     max_contexts=657,
#     max_entities=73,
#     cap_per_type=False,
#     raw_data_path="data/FriendEnemy/raw-friend-enemy.csv",
# )
# DATASET_NAME = "WorldLeaders"
# DATASET_KWARGS_IDENTIFIABLE = dict(
#     max_contexts=450,
#     max_entities=90,
#     cap_per_type=False,
#     raw_data_path="data/WorldLeaders/world-leaders-2001-to-2021.csv",
#     ablate_out_relevant_contexts=False,
# )
# Active dataset configuration. DATASET_NAME must match one of the dataset
# classes imported above, because the class is later looked up by this name.
DATASET_NAME = "YagoECQ"
QUERY_ID = "http://schema.org/leader"

# Derive a short sub-name from the first two parts that
# `extract_name_from_yago_uri` returns for the query URI.
# TODO: probably need to fix this
_query_uri_parts = extract_name_from_yago_uri(QUERY_ID)
SUBNAME = f"{_query_uri_parts[0]}_{_query_uri_parts[1]}"

# Keyword arguments that identify the dataset; they are forwarded to the
# dataset constructor and also used to build the data id / directory names.
DATASET_KWARGS_IDENTIFIABLE = {
    "query_id": QUERY_ID,
    "subname": SUBNAME,
    "max_contexts": 450,
    "max_entities": 90,
    "cap_per_type": False,
    "raw_data_path": "data/YagoECQ/yago_qec.json",
    "ablate_out_relevant_contexts": False,
}

# When True, the dataset frame is written to disk and uploaded to W&B.
LOG_DATASETS = True

# Model parameters
# Hugging Face model id; passed to `from_pretrained` for both the model and
# the tokenizer, and used as the model directory name.
MODEL_ID = "EleutherAI/pythia-70m-deduped"
# Forwarded as `load_in_8bit` on the first load attempt; when True, "-8bit" is
# appended to the model directory name.
LOAD_IN_8BIT = False
# MODEL_ID = "EleutherAI/pythia-6.9b-deduped"
# LOAD_IN_8BIT = True
# Passed to `estimate_cmi` as `bs`.
BATCH_SZ = 16
# NOTE(review): this flag is not read by any cell in this notebook as shown;
# results are always recomputed and the results CSV is always rewritten.
OVERWRITE_RESULTS = False

# wandb stuff
# Passed to `wandb.init` as `project`, `group` and `tags` respectively.
PROJECT_NAME = "context-vs-bias"
GROUP_NAME = None
# TAGS = ["capitals"]
# TAGS = ["friend-enemy"]
TAGS = ["yago"]

In [None]:
# Seed every RNG the notebook relies on (PyTorch, NumPy, Python's `random`)
# with the same value so runs are repeatable.
for _seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    _seed_rng(SEED)

In [None]:
# Paths
# Everything below derives directory and file names from the dataset kwargs,
# the seed and the model settings. The dataset object itself is constructed in
# a later cell, once these paths have been added to its kwargs.
_ds_kwargs = DATASET_KWARGS_IDENTIFIABLE
_has_subname = "subname" in _ds_kwargs

# Data id: the sub-name (or the dataset name when there is none), followed by
# one suffix per identifying option that is set.
_id_parts = [f"{_ds_kwargs['subname']}" if _has_subname else DATASET_NAME]
if _ds_kwargs.get("max_contexts") is not None:
    _id_parts.append(f"-mc{_ds_kwargs['max_contexts']}")
if _ds_kwargs.get("max_entities") is not None:
    _id_parts.append(f"-me{_ds_kwargs['max_entities']}")
if _ds_kwargs.get("cap_per_type"):
    _id_parts.append("-cappertype")
if _ds_kwargs.get("ablate_out_relevant_contexts"):
    _id_parts.append("-ablate")
data_id = "".join(_id_parts)

# data/<dataset>/[<subname>/]<data_id>/<seed>; an empty sub-name component is
# simply skipped by os.path.join.
_subname_dir = f"{_ds_kwargs['subname']}" if _has_subname else ""
data_dir = os.path.join("data", DATASET_NAME, _subname_dir, data_id, f"{SEED}")

# Input files live under <data_dir>/inputs.
input_dir = os.path.join(data_dir, "inputs")
entities_path, contexts_path, queries_path, val_data_path = (
    os.path.join(input_dir, _fname)
    for _fname in ("entities.json", "contexts.json", "queries.json", "val.csv")
)

# Extend the dataset kwargs with the input file locations (val.csv is not
# passed to the dataset; it is only written by this notebook).
DATASET_KWARGS_IDENTIFIABLE = {
    **DATASET_KWARGS_IDENTIFIABLE,
    "entities_path": entities_path,
    "contexts_path": contexts_path,
    "queries_path": queries_path,
}

# Construct model id
model_id = f"{MODEL_ID}" + ("-8bit" if LOAD_IN_8BIT else "")
model_dir = os.path.join(data_dir, "models", model_id)

# Results path
results_dir = os.path.join(model_dir, "results")
val_results_path = os.path.join(results_dir, "val.csv")

print(f"Data dir: {data_dir}")
print(f"Model dir: {model_dir}")

In [None]:
# Make sure every output folder exists before anything is written to it.
for _directory in (input_dir, results_dir, model_dir):
    os.makedirs(_directory, exist_ok=True)

# Look the dataset class up by name among this module's globals (the classes
# imported from preprocessing.datasets) and instantiate it with the kwargs.
_dataset_cls = getattr(sys.modules[__name__], DATASET_NAME)
dataset = _dataset_cls(**DATASET_KWARGS_IDENTIFIABLE)

In [None]:
# Pick the compute device: the GPU when PyTorch can see one, the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

In [None]:
# Tell W&B which notebook this run comes from.
_notebook_path = os.path.join(os.getcwd(), "main.ipynb")
os.environ["WANDB_NOTEBOOK_NAME"] = _notebook_path

# Log every ALL-CAPS name in the notebook namespace as run config. The
# namespace is snapshotted first and filtered in a comprehension, so no new
# names are bound while it is being iterated.
params_to_log = {
    name: value for name, value in list(locals().items()) if name.isupper()
}

run = wandb.init(
    mode="online",
    project=PROJECT_NAME,
    group=GROUP_NAME,
    tags=TAGS,
    config=params_to_log,
)
print(dict(wandb.config))

### Load Data

In [None]:
# Materialise the evaluation frame from the dataset object. Later cells read
# its "query_form", "entity" and "contexts" columns.
val_df_contexts_per_qe = dataset.get_contexts_per_query_entity_df()

# Optionally persist the frame alongside the other inputs, so it is included
# when the data directory is uploaded as an artifact.
if LOG_DATASETS:
    print(f"Saving datasets to {input_dir}.")
    os.makedirs(name=input_dir, exist_ok=True)
    val_df_contexts_per_qe.to_csv(path_or_buf=val_data_path)

# Quick look at the schema and the first few rows.
val_df_contexts_per_qe.info()
val_df_contexts_per_qe.head()

### Preprocess Data

In [None]:
# Preprocess the data and convert it into inputs for the model (e.g. torch tensors)

In [None]:
# Upload the data directory to W&B as a "dataset" artifact named after the
# data id. Note that `data_dir` also contains the models/results sub-folders.
if LOG_DATASETS:
    print(f"Logging datasets to w&b run {wandb.run}.")
    artifact = wandb.Artifact(data_id, type="dataset")
    artifact.add_dir(data_dir)
    run.log_artifact(artifact)

### Score Model

In [None]:
# Load the model. The first attempt honours LOAD_IN_8BIT and lets
# `device_map="auto"` decide weight placement; those options presumably depend
# on optional packages (accelerate / bitsandbytes — confirm), so on failure we
# fall back to a plain load that is moved onto `device` explicitly.
try:
    model = GPTNeoXForCausalLM.from_pretrained(
        MODEL_ID, load_in_8bit=LOAD_IN_8BIT, device_map="auto"
    )
except Exception as load_error:
    # `except Exception` (not a bare `except`) so that Ctrl-C / kernel
    # interrupts still propagate, and the real cause is shown, not hidden.
    print(
        f"Failed to load model {MODEL_ID} with load_in_8bit={LOAD_IN_8BIT} and "
        f"device_map=auto ({type(load_error).__name__}: {load_error}). "
        "Attempting to load normally."
    )
    # NOTE(review): if LOAD_IN_8BIT was True, the model directory name still
    # carries the "-8bit" suffix even though this fallback is not 8-bit.
    model = GPTNeoXForCausalLM.from_pretrained(
        MODEL_ID,
        load_in_8bit=False,
    ).to(device)

# Tokenizer for the same checkpoint, configured to pad on the left.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    padding_side="left",
)

In [None]:
# Shell out to nvidia-smi to report the GPU memory currently in use (CSV
# output), as a sanity check after the model has been loaded.
!nvidia-smi --query-gpu=memory.used --format=csv

In [None]:
# Free memory between experiments. Run Python's garbage collector first so
# tensors that are only kept alive by reference cycles are actually destroyed;
# only then can `empty_cache()` hand their cached CUDA blocks back to the driver.
import gc

gc.collect()
torch.cuda.empty_cache()

In [None]:
# Smoke test: score only the first row, with its contexts capped at 128, to
# check that the model, tokenizer and `estimate_cmi` work together before the
# full run.
_first_row = val_df_contexts_per_qe.iloc[0]
_smoke_kwargs = dict(
    entity=_first_row["entity"],
    contexts=_first_row["contexts"][:128],
    model=model,
    tokenizer=tokenizer,
    bs=BATCH_SZ,
)
estimate_cmi(_first_row["query_form"], **_smoke_kwargs)

In [None]:
def _score_row(row):
    """Return the `estimate_cmi` value for one row of the evaluation frame."""
    return estimate_cmi(
        query=row["query_form"],
        entity=row["entity"],
        contexts=row["contexts"],
        model=model,
        tokenizer=tokenizer,
        answer_map=None,
        bs=BATCH_SZ,
    )


# Register `progress_apply` on pandas objects so the full run shows a progress bar.
tqdm.pandas()
val_df_contexts_per_qe["susceptibility_score"] = val_df_contexts_per_qe.progress_apply(
    _score_row, axis=1
)

# NOTE(review): OVERWRITE_RESULTS is not consulted here; the results CSV is
# rewritten on every run.
val_df_contexts_per_qe.to_csv(val_results_path)

In [None]:
# After scoring, upload the data directory again: by now it also holds the
# results CSV written under the model's results folder. This reuses the same
# artifact name and type as the earlier dataset upload.
if LOG_DATASETS:
    print(f"Logging results to w&b run {wandb.run}.")
    artifact = wandb.Artifact(data_id, type="dataset")
    artifact.add_dir(data_dir)
    run.log_artifact(artifact)

### Evaluate Model

In [None]:
# How many rows each entity contributes to the evaluation frame.
_entity_counts = val_df_contexts_per_qe["entity"].value_counts()
_entity_counts

In [None]:
# Rows for the open-ended capital prompt, sorted by ascending susceptibility.
# NOTE(review): the configured dataset is YagoECQ; confirm this query form
# exists there, otherwise this view is empty.
_is_open_capital_query = val_df_contexts_per_qe["query_form"] == "The capital of {} is"
val_df_contexts_per_qe.loc[_is_open_capital_query].sort_values(
    by="susceptibility_score"
)

In [None]:
# Rows for the question/answer capital prompt, sorted by ascending susceptibility.
# NOTE(review): the configured dataset is YagoECQ; confirm this query form
# exists there, otherwise this view is empty.
_is_qa_capital_query = (
    val_df_contexts_per_qe["query_form"] == "Q: What is the capital of {}?\nA:"
)
val_df_contexts_per_qe.loc[_is_qa_capital_query].sort_values(
    by="susceptibility_score"
)

In [None]:
# Close the W&B run started above so that it is marked as finished.
wandb.finish()