### Summarize and aggregate susceptibility scores across all queries

In [1]:
# Enable autoreload so edits to imported project modules are picked up live,
# without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Load the lab_black extension (presumably formats cells with black — confirm).
%load_ext lab_black

In [2]:
import os
import sys
import math
import random
from itertools import product
from tqdm import tqdm
from typing import Dict, List, Set, Union, Tuple
import yaml
from ast import literal_eval

import json
import pandas as pd
import seaborn as sns

from transformers import GPTNeoXForCausalLM, AutoTokenizer
import torch
from typing import List
from matplotlib import pyplot as plt
import numpy as np
import wandb

import statsmodels.api as sm

from preprocessing.datasets import CountryCapital
from susceptibility_scores import construct_paths_and_dataset_kwargs

In [3]:
##################
### Parameters ###
##################
# Single-run configuration constants.
# NOTE(review): in the cells shown in this section only RAW_DATA_PATH is read;
# the aggregation cell re-declares the same values in lowercase lists. Confirm
# whether the remaining constants are still needed here.

# Data parameters
# Several of these values show up in the data directory names printed by the
# aggregation cell (e.g. "mc500", "me100", "uniformcontexts", "dedupeentities").
DATASET_NAME = "YagoECQ"
# JSON file that is loaded into `yago_qec`.
RAW_DATA_PATH = "data/YagoECQ/yago_qec.json"
SEED = 0
MODEL_ID = "EleutherAI/pythia-6.9b-deduped"
LOAD_IN_8BIT = True
# Presumably upper bounds on the number of contexts / entities per query — confirm.
MAX_CONTEXTS = 500
MAX_ENTITIES = 100
CAP_PER_TYPE = False
ABLATE_OUT_RELEVANT_CONTEXTS = False
UNIFORM_CONTEXTS = True
DEDUPLICATE_ENTITIES = True
# Alternative entity-selection strategies (kept for quick switching):
# ENTITY_SELECTION_FUNC_NAME = "random_sample"
# ENTITY_SELECTION_FUNC_NAME = "top_entity_uri_degree"
ENTITY_SELECTION_FUNC_NAME = "top_entity_namesake_degree"

OVERWRITE = True
# Keys of the per-query entity lists in the raw data that are included.
ENTITY_TYPES = ["entities", "gpt_fake_entities"]
# Alternative: ENTITY_TYPES = ["entities", "fake_entities"]
# Keys of the per-query "query_forms" mapping that are included.
QUERY_TYPES = ["closed", "open"]
# Alternative: QUERY_TYPES = ["closed"]
# Example of a populated answer map (label -> token ids):
# ANSWER_MAP = {
#     0: [1621, 642, 7651, 2302, 2369, 7716],
#     1: [6279, 4754, 22487, 4374, 9820, 24239],
# }
ANSWER_MAP = None

# Model parameters
BATCH_SZ = 16

# Weights & Biases settings (not referenced in the cells shown in this section).
PROJECT_NAME = "context-vs-bias"
GROUP_NAME = None
TAGS = ["yago", "analysis"]
LOG_DATASETS = True

In [4]:
# Load the raw data file into `yago_qec`; later cells index it by query id.
# JSON text is UTF-8, so pin the encoding instead of relying on the
# platform-dependent default, which can mis-decode non-ASCII names.
with open(RAW_DATA_PATH, encoding="utf-8") as f:
    yago_qec = json.load(f)

In [5]:
# Every query id in the raw data (iterating the mapping yields its keys).
query_ids = list(yago_qec)

In [6]:
# Sweep grid consumed by the aggregation loop. The values are derived from the
# constants of the Parameters cell rather than typed out a second time, so the
# two cells cannot drift apart. Add more items to a list to widen the sweep.
dataset_names_and_rdps = [(DATASET_NAME, RAW_DATA_PATH)]
seeds = [SEED]  # original note: "val every 10"
model_id_and_quantize_tuples = [(MODEL_ID, LOAD_IN_8BIT)]
max_contexts = [MAX_CONTEXTS]
max_entities = [MAX_ENTITIES]
# To restrict the sweep to one query, use e.g. ["http://schema.org/founder"].
query_ids = list(yago_qec.keys())

# All entity-selection strategies are aggregated, not only
# ENTITY_SELECTION_FUNC_NAME.
ent_selection_fns = [
    "top_entity_uri_degree",
    "top_entity_namesake_degree",
    "random_sample",
]

# `separators` is important to remove spaces from the serialized string; this
# matters downstream for bash to be able to read the whole list.
entity_types = json.dumps(ENTITY_TYPES, separators=(",", ":"))
query_types = json.dumps(QUERY_TYPES, separators=(",", ":"))

# An empty map is turned into ANSWER_MAP=None by the aggregation loop.
# Example of a populated map:
# {0: [" No", " no", " NO", "No", "no", "NO"], 1: [" Yes", " yes", " YES", "Yes", "yes", "YES"]}
answer_map = dict()

cap_per_type = CAP_PER_TYPE
ablate = ABLATE_OUT_RELEVANT_CONTEXTS
deduplicate_entities = DEDUPLICATE_ENTITIES
uniform_contexts = UNIFORM_CONTEXTS
overwrite = OVERWRITE


def convert_answer_map_to_tokens(
    model_id: str, answer_map: Dict[int, List[str]]
) -> str:
    """Serialize an answer map of strings as a compact JSON map of token ids.

    Every answer string is tokenized with the tokenizer of ``model_id``. Only
    strings that encode to exactly one token are kept; any other string is
    reported on stdout and left out of the result.

    Args:
        model_id: Identifier passed to ``AutoTokenizer.from_pretrained``.
        answer_map: Mapping from an answer label to the strings that express it.
            An empty mapping is valid and yields ``"{}"`` without loading a
            tokenizer.

    Returns:
        JSON text, written without spaces after separators, mapping each label
        to its list of single-token ids. JSON object keys are always strings,
        so integer labels come back as strings after ``json.loads``. The same
        text is also printed.
    """
    answer_map_token_ids: Dict[int, List[int]] = dict()

    # Only load the tokenizer when there is something to tokenize; the default
    # empty map would otherwise pay for a tokenizer load that is never used.
    if answer_map:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            padding_side="left",
        )
        for k, v in answer_map.items():
            list_of_token_ids: List[List[int]] = tokenizer(v)["input_ids"]
            valid_token_ids = []
            for token_id in list_of_token_ids:
                if len(token_id) == 1:
                    valid_token_ids.append(token_id[0])
                else:
                    # Multi-token answers cannot be represented by one id.
                    print(
                        f"tokenizer tokenized an answer map token into multiple tokens ({token_id}), which is invalid input."
                    )
            answer_map_token_ids[k] = valid_token_ids

    res = json.dumps(answer_map_token_ids, separators=(",", ":"))
    print(res)
    return res


# Collect one flat record per (run configuration, result row) over the sweep.
# Despite its name, `df_dict` is a list of dicts; a later cell turns it into a
# DataFrame.
df_dict = []
for (ds, rdp), seed, (model_id, do_quantize) in product(
    dataset_names_and_rdps, seeds, model_id_and_quantize_tuples
):
    # The answer map only depends on the model's tokenizer.
    answer_map_in_tokens = convert_answer_map_to_tokens(model_id, answer_map)

    for qid, mc, me, es in product(
        query_ids, max_contexts, max_entities, ent_selection_fns
    ):
        # Keyword arguments identifying this run; also copied onto every row.
        dict_vals = dict(
            DATASET_NAME=ds,
            RAW_DATA_PATH=rdp,
            SEED=seed,
            MODEL_ID=model_id,
            LOAD_IN_8BIT=do_quantize,
            QUERY_ID=qid,
            MAX_CONTEXTS=mc,
            MAX_ENTITIES=me,
            CAP_PER_TYPE=cap_per_type,
            ABLATE_OUT_RELEVANT_CONTEXTS=ablate,
            DEDUPLICATE_ENTITIES=deduplicate_entities,
            UNIFORM_CONTEXTS=uniform_contexts,
            ENTITY_SELECTION_FUNC_NAME=es,
            OVERWRITE=overwrite,
            ENTITY_TYPES=json.loads(entity_types),
            QUERY_TYPES=json.loads(query_types),
            # An empty answer map is passed on as None.
            ANSWER_MAP=json.loads(answer_map_in_tokens) or None,
        )
        (
            data_dir,
            input_dir,
            entities_path,
            contexts_path,
            queries_path,
            answers_path,
            val_data_path,
            model_dir,
            results_dir,
            val_results_path,
            data_id,
            _,
            DATASET_KWARGS_IDENTIFIABLE,
        ) = construct_paths_and_dataset_kwargs(**dict_vals)

        if not os.path.isfile(val_results_path):
            # No results file for this configuration: nothing to aggregate.
            continue

        # "entity" is parsed with literal_eval; "contexts" stays a raw string.
        res = pd.read_csv(
            val_results_path,
            index_col=0,
            converters={"entity": literal_eval},
        )

        query_info = yago_qec[qid]
        closed_qfs = query_info["query_forms"]["closed"]
        open_qfs = query_info["query_forms"]["open"]

        # Label each row by the kind of query form it used.
        is_closed = res["query_form"].isin(closed_qfs)
        res.loc[is_closed, "query_type"] = "closed"
        is_open = res["query_form"].isin(open_qfs)
        res.loc[is_open, "query_type"] = "open"

        # Entities are matched as 1-tuples. The "entities" label is written
        # last, so it wins when a name appears in both lists.
        fake_keys = [(name,) for name in query_info["gpt_fake_entities"]]
        res.loc[res["entity"].isin(fake_keys), "entity_type"] = "gpt_fake_entities"
        real_keys = [(name,) for name in query_info["entities"]]
        res.loc[res["entity"].isin(real_keys), "entity_type"] = "entities"

        # Every row of this query gets the query's entity classes.
        res["entity_classes"] = res["q_id"].apply(
            lambda _row_qid: query_info["entity_types"]
        )
        # Unwrap the entity from its tuple now that matching is done.
        res["entity"] = res["entity"].apply(lambda entity_key: entity_key[0])

        score_columns = [
            "entity",
            "answer",
            "query_form",
            "entity_type",
            "query_type",
            "entity_classes",
            "susceptibility_score",
        ]
        scores: List[dict] = res[score_columns].to_dict("records")
        df_dict.extend({**dict_vals, **row} for row in scores)

{}
Data dir: data/YagoECQ/schema_highestPoint/schema_highestPoint-mc500-me100-uniformcontexts-dedupeentities-ET_entities_gpt_fake_entities-QT_closed_open-ES_top_entity_uri_degree/0
Model dir: data/YagoECQ/schema_highestPoint/schema_highestPoint-mc500-me100-uniformcontexts-dedupeentities-ET_entities_gpt_fake_entities-QT_closed_open-ES_top_entity_uri_degree/0/models/EleutherAI/pythia-6.9b-deduped-8bit
Data dir: data/YagoECQ/schema_highestPoint/schema_highestPoint-mc500-me100-uniformcontexts-dedupeentities-ET_entities_gpt_fake_entities-QT_closed_open-ES_top_entity_namesake_degree/0
Model dir: data/YagoECQ/schema_highestPoint/schema_highestPoint-mc500-me100-uniformcontexts-dedupeentities-ET_entities_gpt_fake_entities-QT_closed_open-ES_top_entity_namesake_degree/0/models/EleutherAI/pythia-6.9b-deduped-8bit
Data dir: data/YagoECQ/schema_highestPoint/schema_highestPoint-mc500-me100-uniformcontexts-dedupeentities-ET_entities_gpt_fake_entities-QT_closed_open-ES_random_sample/0
Model dir: data/Y

In [7]:
# Sanity check: is this name listed under `gpt_fake_entities` for the lyricist query?
lyricist_fake_entities = yago_qec["http://schema.org/lyricist"]["gpt_fake_entities"]
"Twilight Seraphim" in lyricist_fake_entities

False

In [10]:
# Number of aggregated score rows (`scores_df` is built from `df_dict`).
scores_df.shape[0]

64400

In [17]:
# Round-trip check: re-read the saved summary CSV and print its schema.
summarized_scores_on_disk = pd.read_csv("summarized_scores.csv")
summarized_scores_on_disk.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 64400 entries, 0 to 64399
Data columns (total 24 columns):
 #   Column                        Non-Null Count  Dtype  
---  ------                        --------------  -----  
 0   DATASET_NAME                  64400 non-null  object 
 1   RAW_DATA_PATH                 64400 non-null  object 
 2   SEED                          64400 non-null  int64  
 3   MODEL_ID                      64400 non-null  object 
 4   LOAD_IN_8BIT                  64400 non-null  bool   
 5   QUERY_ID                      64400 non-null  object 
 6   MAX_CONTEXTS                  64400 non-null  int64  
 7   MAX_ENTITIES                  64400 non-null  int64  
 8   CAP_PER_TYPE                  64400 non-null  bool   
 9   ABLATE_OUT_RELEVANT_CONTEXTS  64400 non-null  bool   
 10  DEDUPLICATE_ENTITIES          64400 non-null  bool   
 11  UNIFORM_CONTEXTS              64400 non-null  bool   
 12  ENTITY_SELECTION_FUNC_NAME    64400 non-null  object 
 13  O

In [9]:
# Materialize the collected records and persist them before any in-memory
# conversion, so the CSV holds the list-valued columns as originally collected.
scores_df = pd.DataFrame(df_dict)
scores_df.to_csv("summarized_scores.csv", index=False)

# Convert the list-valued columns to tuples in the in-memory frame only.
for list_col in ("ENTITY_TYPES", "QUERY_TYPES"):
    scores_df[list_col] = scores_df[list_col].apply(tuple)

scores_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 64400 entries, 0 to 64399
Data columns (total 24 columns):
 #   Column                        Non-Null Count  Dtype  
---  ------                        --------------  -----  
 0   DATASET_NAME                  64400 non-null  object 
 1   RAW_DATA_PATH                 64400 non-null  object 
 2   SEED                          64400 non-null  int64  
 3   MODEL_ID                      64400 non-null  object 
 4   LOAD_IN_8BIT                  64400 non-null  bool   
 5   QUERY_ID                      64400 non-null  object 
 6   MAX_CONTEXTS                  64400 non-null  int64  
 7   MAX_ENTITIES                  64400 non-null  int64  
 8   CAP_PER_TYPE                  64400 non-null  bool   
 9   ABLATE_OUT_RELEVANT_CONTEXTS  64400 non-null  bool   
 10  DEDUPLICATE_ENTITIES          64400 non-null  bool   
 11  UNIFORM_CONTEXTS              64400 non-null  bool   
 12  ENTITY_SELECTION_FUNC_NAME    64400 non-null  object 
 13  O

In [9]:
# Display the last per-configuration results frame left in `res` by the
# aggregation loop. NOTE(review): the saved output has no `entity_classes`
# column, so it appears to come from an earlier version of that loop — re-run
# to refresh.
res

Unnamed: 0,q_id,query_form,entity,answer,contexts,susceptibility_score,full_query_example,query_type,entity_type
0,http://schema.org/lyricist,Q: Is {answer} a lyricist for '{entity}'?\nA:,The Thistle and Thorn,Paul Cook,"[""A lyricist for '(If Paradise Is) Half as Nic...",0.036604,A lyricist for '(If Paradise Is) Half as Nice'...,closed,
1,http://schema.org/lyricist,Q: Is {answer} a lyricist for '{entity}'?\nA:,Judge Not,Bob Marley,"[""A lyricist for '(If Paradise Is) Half as Nic...",0.032724,A lyricist for '(If Paradise Is) Half as Nice'...,closed,
2,http://schema.org/lyricist,Q: Is {answer} a lyricist for '{entity}'?\nA:,The Dirge of the Drifting Isles,Lorenz Hart,"[""A lyricist for '(If Paradise Is) Half as Nic...",0.033952,A lyricist for '(If Paradise Is) Half as Nice'...,closed,
3,http://schema.org/lyricist,Q: Is {answer} a lyricist for '{entity}'?\nA:,Twilight Seraphim,Al Jackson Jr.,"[""A lyricist for '(If Paradise Is) Half as Nic...",0.039045,A lyricist for '(If Paradise Is) Half as Nice'...,closed,
4,http://schema.org/lyricist,Q: Is {answer} a lyricist for '{entity}'?\nA:,Long Promised Road,Jack Rieley,"[""A lyricist for '(If Paradise Is) Half as Nic...",0.038116,A lyricist for '(If Paradise Is) Half as Nice'...,closed,
...,...,...,...,...,...,...,...,...,...
395,http://schema.org/lyricist,A lyricist for '{entity}' is a,Matchbox,Carl Perkins,"[""A lyricist for '(If Paradise Is) Half as Nic...",1.966967,A lyricist for '(If Paradise Is) Half as Nice'...,open,
396,http://schema.org/lyricist,A lyricist for '{entity}' is a,BreatheEz Air Masks,Jennifer Peña,"[""A lyricist for '(If Paradise Is) Half as Nic...",1.731975,A lyricist for '(If Paradise Is) Half as Nice'...,open,
397,http://schema.org/lyricist,A lyricist for '{entity}' is a,Bike,Syd Barrett,"[""A lyricist for '(If Paradise Is) Half as Nic...",1.815772,A lyricist for '(If Paradise Is) Half as Nice'...,open,
398,http://schema.org/lyricist,A lyricist for '{entity}' is a,Honeysuckle Rose,Andy Razaf,"[""A lyricist for '(If Paradise Is) Half as Nic...",1.689596,A lyricist for '(If Paradise Is) Half as Nice'...,open,


In [16]:
# Build a frame of 10 million identical placeholder records.
# NOTE(review): presumably a scratch check of how a fully flattened scores
# table of that size behaves — confirm, and consider removing.
placeholder_record = {
    "query_id": "http://schema.org/alumniOf",
    "entity": "Harvard",
    "model-id": "my-model-id",
    "query_form": "my-query-form",
    "query_type": "closed",
    "entity_type": "entities",
    "max_contexts": 500,
    "cap_per_type": False,
    "ABLATE_OUT_RELEVANT_CONTEXTS": False,
    "DEDUPLICATE_ENTITIES": True,
    "UNIFORM_CONTEXTS": True,
    "ENTITY_SELECTION_FUNC_NAME": "random_sample",
    "SEED": 0,
    "ANSWER_MAP": None,
    "sus_score": 0.4,
}
pd.DataFrame([placeholder_record] * 10_000_000)

Unnamed: 0,query_id,entity,model-id,query_form,query_type,entity_type,max_contexts,cap_per_type,ABLATE_OUT_RELEVANT_CONTEXTS,DEDUPLICATE_ENTITIES,UNIFORM_CONTEXTS,ENTITY_SELECTION_FUNC_NAME,SEED,ANSWER_MAP,sus_score
0,http://schema.org/alumniOf,Harvard,my-model-id,my-query-form,closed,entities,500,False,False,True,True,random_sample,0,,0.4
1,http://schema.org/alumniOf,Harvard,my-model-id,my-query-form,closed,entities,500,False,False,True,True,random_sample,0,,0.4
2,http://schema.org/alumniOf,Harvard,my-model-id,my-query-form,closed,entities,500,False,False,True,True,random_sample,0,,0.4
3,http://schema.org/alumniOf,Harvard,my-model-id,my-query-form,closed,entities,500,False,False,True,True,random_sample,0,,0.4
4,http://schema.org/alumniOf,Harvard,my-model-id,my-query-form,closed,entities,500,False,False,True,True,random_sample,0,,0.4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9999995,http://schema.org/alumniOf,Harvard,my-model-id,my-query-form,closed,entities,500,False,False,True,True,random_sample,0,,0.4
9999996,http://schema.org/alumniOf,Harvard,my-model-id,my-query-form,closed,entities,500,False,False,True,True,random_sample,0,,0.4
9999997,http://schema.org/alumniOf,Harvard,my-model-id,my-query-form,closed,entities,500,False,False,True,True,random_sample,0,,0.4
9999998,http://schema.org/alumniOf,Harvard,my-model-id,my-query-form,closed,entities,500,False,False,True,True,random_sample,0,,0.4
