## Imports

In [5]:
import pandas as pd
import glob
import datasets
import numpy as np
import matplotlib.pyplot as plt
import os
import json

from scipy.stats import chi2_contingency
from scipy.stats import chi2

## Load Data

In [6]:
# Wikidata property list, keeping only the Title, ID, Datatype and Description columns.
wiki_data = pd.read_csv("../../data/wikidata/wikidata-property-list.csv").loc[
    :, ["Title", "ID", "Datatype", "Description"]
]

In [7]:
# Two-letter language code -> English language name for the 20 languages analysed.
code_to_lang_dict = dict(
    bg="Bulgarian",
    ca="Catalan",
    cs="Czech",
    da="Danish",
    de="German",
    en="English",
    es="Spanish",
    fr="French",
    hr="Croatian",
    hu="Hungarian",
    it="Italian",
    nl="Dutch",
    pl="Polish",
    pt="Portuguese",
    ro="Romanian",
    ru="Russian",
    sl="Slovenian",
    sr="Serbian",
    sv="Swedish",
    uk="Ukrainian",
)

In [8]:
# Reverse lookup: English language name -> two-letter language code.
lang_to_code_dict = dict(
    zip(code_to_lang_dict.values(), code_to_lang_dict.keys())
)

In [9]:
# Empty accumulator with one list per summary column; nothing in this cell fills it.
results_dict = {
    "language": [],
    "relation": [],
    "percentage change": [],
    "new ratio of rows": [],
    "old ratio of rows": [],
}

hf_df = datasets.load_dataset("CalibraGPT/Fact-Completion")
# glob returns paths in arbitrary order; sorting keeps the row order of
# results_df and the analysis_id numbering reproducible between runs.
file_names = sorted(
    glob.glob("../../data/result_logs/llama-30b/error-analysis/*.csv")
)

# Numeric Wikidata property id -> title, built once instead of filtering
# wiki_data for every row. drop_duplicates keeps the first title per id, which
# is what a "take the first match" lookup would return.
relation_id_to_title = (
    wiki_data.drop_duplicates(subset="ID").set_index("ID")["Title"].to_dict()
)

# Build one frame per error-analysis file: the full HF split for that language
# plus correctness, language, relation-title and id columns.
results_dfs = []
count = 0  # running number of rows processed across all languages
for file in file_names:
    # The text after the last "-" in the file name (minus ".csv"), capitalised,
    # is used both as the language name and as the split name in hf_df.
    language = file.split(".csv")[0].split("-")[-1].capitalize()
    error_df = pd.read_csv(file)
    full_hf_df = hf_df[language].to_pandas()
    n_rows = full_hf_df.shape[0]

    # A row counts as correct when its dataset_id is absent from the error log.
    # A set gives constant-time membership tests.
    error_ids = set(error_df["dataset_id"])

    # Relation ids look like "P36": drop the leading letter to get the numeric
    # id. An id missing from wiki_data raises KeyError here.
    relation_names = [
        relation_id_to_title[int(relation[1:])]
        for relation in full_hf_df["relation"]
    ]

    # append result to full df
    full_hf_df["correct"] = ~full_hf_df["dataset_id"].isin(error_ids)
    # append language to full df
    full_hf_df["language"] = language
    # append language code to full df
    lang_code = lang_to_code_dict[language]
    full_hf_df["lang_code"] = lang_code
    # append relation title to full df
    full_hf_df["relation_title"] = relation_names
    # also append an arbitrary id to have a unique value for each row
    # (1-based, continuing across languages)
    full_hf_df["analysis_id"] = np.arange(count + 1, count + n_rows + 1)
    count += n_rows

    results_dfs.append(full_hf_df)

results_df = pd.concat(results_dfs)
assert results_df.shape[0] == count

Found cached dataset parquet (/Users/tim/.cache/huggingface/datasets/CalibraGPT___parquet/CalibraGPT--Fact-Completion-24a24a1e4bf6e4a8/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/20 [00:00<?, ?it/s]

## More Cleanup to Ensure that we have access to Subjects across langs

In [10]:
# dataset_id -> English form of the subject. When a dataset_id shows up on
# several English rows, the first one encountered is kept.
dataset_id_to_eng_subject = {}
english_rows = results_df[results_df["language"] == "English"]
for dataset_id, subject in zip(english_rows["dataset_id"], english_rows["subject"]):
    dataset_id_to_eng_subject.setdefault(dataset_id, subject)

In [11]:
# Per-entity tallies, keyed by the English form of the subject:
#   correct / incorrect  - number of rows the model got right / wrong
#   langs                - language code -> number of rows in that language
#   alternate_forms      - language code -> subject string seen in that language
#                          (the last row seen for a language wins)
#   dataset_ids          - set of dataset ids the entity occurs in
entities = {}
row_fields = zip(
    results_df["subject"],
    results_df["correct"],
    results_df["dataset_id"],
    results_df["lang_code"],
)
for subject, val, dataset_id, lang in row_fields:
    # every language's row is filed under the English version of its subject
    english_subject = dataset_id_to_eng_subject[dataset_id]

    record = entities.setdefault(
        english_subject,
        {
            "correct": 0,
            "incorrect": 0,
            "langs": {},
            "alternate_forms": {},
            "dataset_ids": set(),
        },
    )

    # counter of correct/incorrect for that subject
    record["correct" if val else "incorrect"] += 1

    # per-language usage count
    record["langs"][lang] = record["langs"].get(lang, 0) + 1

    # surface form of the subject in this language
    record["alternate_forms"][lang] = subject

    record["dataset_ids"].add(dataset_id)

In [12]:
# Flatten the per-entity tallies into parallel lists, one slot per entity.
entity_names, correct, incorrect, total, pct = [], [], [], [], []
langs, num_langs, alternate_forms, dataset_ids = [], [], [], []
for name, stats in entities.items():
    n_right = int(stats["correct"])
    n_wrong = int(stats["incorrect"])
    n_seen = n_right + n_wrong

    entity_names.append(name)
    # times the entity appears in a correct / incorrect statement, and accuracy
    correct.append(stats["correct"])
    incorrect.append(stats["incorrect"])
    total.append(n_seen)
    pct.append(n_right / n_seen)
    # languages the entity is used in, and how many
    langs.append(stats["langs"])
    num_langs.append(len(stats["langs"]))
    alternate_forms.append(stats["alternate_forms"])
    # dataset ids the entity is used in
    dataset_ids.append(list(stats["dataset_ids"]))
    # sanity check: per-language counts must add up to the overall total
    assert n_seen == sum(stats["langs"].values())

In [13]:
# the average entity appears in ~12 langs
# (remember that this will max out at 20.)
np.mean(num_langs)

11.867738745323988

In [14]:
# One row per entity, assembled from the parallel per-entity lists.
entity_columns = dict(
    entity=entity_names,
    num_correct=correct,
    num_incorrect=incorrect,
    total_usages=total,
    percent_accuracy=pct,
    languages=langs,
    num_languages=num_langs,
    alternate_forms=alternate_forms,
    dataset_ids=dataset_ids,
)
entity_analysis_df = pd.DataFrame(entity_columns)

In [98]:
# we don't see a strong trend of the model doing worse with entities as a function of how
# many questions it receives about it
# which is good because it implies its picks aren't haphazard
entity_analysis_df['total_usages'].corr(entity_analysis_df['percent_accuracy'])

-0.011757719194057775

## Significance Testing

* use chi squared to see: 
 
    * if the difference in the number of correctly and incorrectly classified statements between language scripts is statistically significant
       * yes
    * if the number of correctly and incorrectly classified statements for each language group is statistically significant
        * yes
        
    * the number of correctly and incorrectly classified locations for western/eastern locales is statistically significant
        * diff bt Europe vs Asia, yes (Asia performs better)
        * diff bt Oceania vs Asia, yes (Asia performs better)
        * diff bt Antarctic vs rest of world, yes (Ant. performs worse)
        * diff bt South vs North America, yes (South America performs better)
        * diff bt Europe + North America vs Asia, yes (Asia performs better)
    * the number of correctly and incorrectly classified entities for women/men is statistically significant

In [82]:
def chi_squared(category_dicts, flag_one, flag_two, category_explainer, alpha=0.05):
    """Run a chi-squared test of independence and print the verdict.

    Builds a contingency table with two rows and one column per category: the
    first row holds each category's ``flag_one`` count, the second its
    ``flag_two`` count. The table is passed to ``scipy.stats.chi2_contingency``
    with SciPy's defaults (which apply Yates' continuity correction to 2 x 2
    tables), and the statistic, p-value and decision are printed.

    Args:
        category_dicts: sequence of dicts, one per category, each holding counts
            under the keys ``flag_one`` and ``flag_two``.
        flag_one: key of the count placed in the top row (e.g. "correct").
        flag_two: key of the count placed in the bottom row (e.g. "incorrect").
        category_explainer: human-readable name of the grouping; used verbatim
            in the printed sentences.
        alpha: significance level. The null hypothesis of independence is
            rejected when ``p <= alpha``. Defaults to 0.05.

    Returns:
        None. The result is only printed.
    """
    table = np.array(
        [
            [category[flag_one] for category in category_dicts],
            [category[flag_two] for category in category_dicts],
        ]
    )

    stat, p, _dof, _expected = chi2_contingency(table)
    # A non-significant result is not evidence of independence, so the null
    # hypothesis is never "accepted" -- we only fail to reject it.
    reject = "REJECT" if p <= alpha else "FAIL TO REJECT"

    # Very small p-values are reported as a bound rather than in full.
    p_display = "< .001" if p < 0.001 else p

    print(
        f"For {category_explainer}, we see a chi-squared value of {stat} and a p-value of {p_display}."
    )
    print(
        f"We can {reject} the null hypothesis that {category_explainer} is independent from performance on the CKA assessment."
    )

### language groups and script

* Romance languages: Catalan, French, Italian, Portuguese, Romanian, Spanish
* Germanic languages: Danish, Dutch, English, German, Swedish
* Slavic languages: Bulgarian, Czech, Croatian, Polish, Russian, Serbian, Slovenian, Ukrainian
* Hungarian: a Uralic language, not related to any of the other languages.

In [None]:
# 2 x 2
# Language code -> writing system. Built from the per-script code lists and
# sorted so the keys come out in alphabetical order.
lang_to_script_dict = dict(
    sorted(
        [(code, "Cyrillic") for code in ("bg", "ru", "sr", "uk")]
        + [
            (code, "Latin")
            for code in (
                "ca", "cs", "da", "de", "en", "es", "fr", "hr",
                "hu", "it", "nl", "pl", "pt", "ro", "sl", "sv",
            )
        ]
    )
)

# 2 x 4
# Language code -> language family, built from the per-family code lists and
# sorted the same way.
lang_to_group_dict = dict(
    sorted(
        (code, family)
        for family, codes in {
            "Germanic": ("da", "de", "en", "nl", "sv"),
            "Romance": ("ca", "es", "fr", "it", "pt", "ro"),
            "Slavic": ("bg", "cs", "hr", "pl", "ru", "sl", "sr", "uk"),
            "Uralic": ("hu",),
        }.items()
        for code in codes
    )
)

In [None]:
# now, for each of these levels, we need:
# number correct
# number incorrect
# total..
results_df.head()

In [None]:
# Correct / incorrect tallies per writing system ...
# scripts
cyrillic = {"correct": 0, "incorrect": 0}
latin = {"correct": 0, "incorrect": 0}

# ... and per language family.
# language groups
germanic = {"correct": 0, "incorrect": 0}
romance = {"correct": 0, "incorrect": 0}
slavic = {"correct": 0, "incorrect": 0}
uralic = {"correct": 0, "incorrect": 0}

# Label -> tally to update. A label missing from these tables raises KeyError.
script_counters = {"Cyrillic": cyrillic, "Latin": latin}
group_counters = {
    "Germanic": germanic,
    "Romance": romance,
    "Slavic": slavic,
    "Uralic": uralic,
}

# Only two columns are needed, so iterate over them directly rather than
# materialising a Series per row with iterrows().
for lang_code, result in zip(results_df["lang_code"], results_df["correct"]):
    mapping = "correct" if result else "incorrect"
    script_counters[lang_to_script_dict[lang_code]][mapping] += 1
    group_counters[lang_to_group_dict[lang_code]][mapping] += 1

print(f"cyrillic: {cyrillic}")
print(f"latin: {latin}")

print(f"germanic: {germanic}")
print(f"romance: {romance}")
print(f"slavic: {slavic}")
print(
    f"uralic: {uralic}"
)  # sanity check -> this is 75.7% correct which is the same as the llama graph result for HU.

# sanity check: every row was counted exactly once in each breakdown
assert sum(cyrillic.values()) + sum(latin.values()) == results_df.shape[0]
assert (
    sum(germanic.values())
    + sum(romance.values())
    + sum(slavic.values())
    + sum(uralic.values())
    == results_df.shape[0]
)

In [None]:
chi_squared([cyrillic, latin], "correct", "incorrect", "language script")

In [None]:
chi_squared(
    [germanic, romance, slavic, uralic], "correct", "incorrect", "language family"
)

### Western vs Eastern locations

#### get geo entities

In [15]:
# Relation label -> Wikidata property id for the relations treated as geographic.
geo_relations = {
    "capital": "P36",
    "country": "P17",
    "continent": "P30",
    "capital of": "P1376",
    "is in the administrative territorial entity": "P131",
    "shares border with": "P47",
}

In [16]:
# Rows of results_df whose relation is one of the geographic property ids.
geo_df = results_df.loc[results_df["relation"].isin(list(geo_relations.values()))]

In [68]:
# entity -> [num_correct, num_incorrect, [geo dataset ids]] for every entity
# that occurs in at least one geographic-relation row.
# NOTE(review): the two counts are the entity's totals from entity_analysis_df,
# i.e. over all of its dataset ids, not only the geographic ones -- confirm
# this is intended.

# Built once: membership is tested for every dataset id of every entity.
geo_dataset_ids = set(geo_df["dataset_id"])

geo_entities = {}
for entity, num_correct, num_incorrect, entity_dataset_ids in zip(
    entity_analysis_df["entity"],
    entity_analysis_df["num_correct"],
    entity_analysis_df["num_incorrect"],
    entity_analysis_df["dataset_ids"],
):
    matching_ids = [d for d in entity_dataset_ids if d in geo_dataset_ids]
    if matching_ids:
        # plain ints keep the record JSON-serialisable
        geo_entities[entity] = [int(num_correct), int(num_incorrect), matching_ids]

In [69]:
len(geo_entities)

3247

In [None]:
# Write geo_entities to disk as JSON.
with open("../../data/wikidata/full_geo_entities.json", "w") as geo_entities_file:
    geo_entities_file.write(json.dumps(geo_entities))

#### read in 'tagged' geo entities

In [70]:
# Tab-separated file of geo entities; the cells below read its `entity` and
# `location` columns.
geo_entities_tagged = pd.read_table("../../data/wikidata/geo_entities_tagged.txt")

In [71]:
geo_entities_tagged['location'].value_counts()

EU               995
AS               629
AN               586
North America    583
AF               213
OC               129
SA                78
unsure            34
Name: location, dtype: int64

In [73]:
# Attach each geo entity's tagged location as element [3] of its
# [num_correct, num_incorrect, dataset_ids] record.

# entity -> location, built once instead of filtering the tagged frame per
# entity. When an entity is tagged more than once, the first row wins.
location_by_entity = (
    geo_entities_tagged.drop_duplicates(subset="entity")
    .set_index("entity")["location"]
    .to_dict()
)

# Report every untagged entity and stop before any record is modified.
# Otherwise the first missing entity would crash the loop half-way through
# with a bare IndexError.
untagged = [name for name in geo_entities if name not in location_by_entity]
for entity in untagged:
    print(geo_entities[entity])
    print(f"Couldn't retrieve entity location for entity -- {entity}")
if untagged:
    raise KeyError(f"{len(untagged)} geo entities have no tagged location")

for entity, entity_info in geo_entities.items():
    entity_location = location_by_entity[entity]
    if len(entity_info) > 3:
        # the cell was run before: overwrite instead of appending a duplicate
        entity_info[3] = entity_location
    else:
        entity_info.append(entity_location)

[10, 1, ['rome_18230']]
North America


In [74]:
# One correct/incorrect tally per continent.
europe = {"correct": 0, "incorrect": 0}
asia = {"correct": 0, "incorrect": 0}
oceania = {"correct": 0, "incorrect": 0}
north_america = {"correct": 0, "incorrect": 0}
south_america = {"correct": 0, "incorrect": 0}
africa = {"correct": 0, "incorrect": 0}
antarctica = {"correct": 0, "incorrect": 0}

# Location tag -> tally that receives the entity's counts.
tally_by_location = {
    "EU": europe,
    "AS": asia,
    "OC": oceania,
    "North America": north_america,
    "SA": south_america,
    "AF": africa,
    "AN": antarctica,
}

for entity_info in geo_entities.values():
    # entity_info is [num_correct, num_incorrect, dataset_ids, location]
    tally = tally_by_location.get(entity_info[3])
    if tally is None:
        # any other tag (e.g. "unsure") is left out of the continent tallies
        continue
    tally["correct"] += entity_info[0]
    tally["incorrect"] += entity_info[1]

#### assess performance on CKA per continent

In [118]:
sum(europe.values())

16201

In [90]:
europe['correct']/sum(europe.values())

0.9048824146657614

In [119]:
sum(asia.values())

10729

In [91]:
asia['correct']/sum(asia.values())

0.9330785720943238

In [120]:
sum(oceania.values())

1761

In [92]:
oceania['correct']/sum(oceania.values())

0.889267461669506

In [121]:
sum(north_america.values())

8656

In [93]:
north_america['correct']/sum(north_america.values())

0.8925600739371534

In [122]:
sum(south_america.values())

1726

In [94]:
south_america['correct']/sum(south_america.values())

0.9159907300115875

In [123]:
sum(africa.values())

4366

In [99]:
africa['correct']/sum(africa.values())

0.9177737059092991

In [125]:
sum(antarctica.values())

5167

In [95]:
antarctica['correct']/sum(antarctica.values())

0.8064640990903813

In [128]:
# Pool every continent except Antarctica into one tally (rest of world vs. Antarctica).
ds = [europe, asia, oceania, north_america, south_america, africa]
rest_of_world = {k: sum(d[k] for d in ds) for k in ("correct", "incorrect")}

In [129]:
rest_of_world

{'correct': 39551, 'incorrect': 3888}

In [134]:
rest_of_world['correct']/sum(rest_of_world.values())

0.9104951771449619

In [130]:
# Pool Europe and North America into one tally (to compare against Asia).
ds = [europe, north_america]
europe_and_na = {k: sum(d[k] for d in ds) for k in ("correct", "incorrect")}

In [131]:
europe_and_na

{'correct': 22386, 'incorrect': 2471}

In [133]:
europe_and_na['correct']/sum(europe_and_na.values())

0.900591382709096

In [135]:
# Overall tally: all seven continent tallies pooled, Antarctica included.
ds = [europe, asia, oceania, north_america, south_america, africa, antarctica]
overall = {k: sum(d[k] for d in ds) for k in ("correct", "incorrect")}

In [138]:
# how many geo CKA questions were asked for our 3213 location-tagged entities
sum(overall.values())

48606

In [136]:
# overall CKA geo perf
overall['correct']/sum(overall.values())

0.8994362835863885

#### run chi squared

In [111]:
chi_squared(
    [europe, asia], "correct", "incorrect", "europe vs. asia geo entities"
)

For europe vs. asia geo entities, we see a chi-squared value of 66.40865439517813 and a p-value of < .001.
We can REJECT the null hypothesis that europe vs. asia geo entities is independent from performance on the CKA assessment.


In [112]:
chi_squared(
    [antarctica, rest_of_world], "correct", "incorrect", "antarctica vs. rest of world geo entities"
)

For antarctica vs. rest of world geo entities, we see a chi-squared value of 551.3639365665833 and a p-value of < .001.
We can REJECT the null hypothesis that antarctica vs. rest of world geo entities is independent from performance on the CKA assessment.


In [115]:
chi_squared(
    [asia, oceania], "correct", "incorrect", "oceania vs. asia geo entities"
)

For oceania vs. asia geo entities, we see a chi-squared value of 42.20897924571737 and a p-value of < .001.
We can REJECT the null hypothesis that oceania vs. asia geo entities is independent from performance on the CKA assessment.


In [116]:
chi_squared(
    [north_america, south_america], "correct", "incorrect", "north vs. south america geo entities"
)

For north vs. south america geo entities, we see a chi-squared value of 8.260629096452938 and a p-value of 0.004051408610312914.
We can REJECT the null hypothesis that north vs. south america geo entities is independent from performance on the CKA assessment.


In [132]:
chi_squared(
    [europe_and_na, asia], "correct", "incorrect", "europe and na vs. asia"
)

For europe and na vs. asia, we see a chi-squared value of 96.55315610667571 and a p-value of < .001.
We can REJECT the null hypothesis that europe and na vs. asia is independent from performance on the CKA assessment.


### Male vs Female names

#### get people entities

In [None]:
# Wikidata property id -> relation label for the relations treated as being about people.
people_relations = {
    "P20": "place of death",
    "P1303": "instrument",
    "P108": "employer",
    "P103": "native language",
    "P39": "position held",
    "P413": "position played on team",
    "P937": "work location",
    "P641": "sport",
    "P106": "occupation",
    "P101": "field of work",
}

In [None]:
# Rows of results_df whose relation is one of the people property ids.
people_df = results_df.loc[results_df["relation"].isin(list(people_relations))]

In [None]:
# entity -> [num_correct, num_incorrect, [people dataset ids]] for every entity
# that occurs in at least one people-relation row.
# NOTE(review): the two counts are the entity's totals from entity_analysis_df,
# i.e. over all of its dataset ids, not only the people ones -- confirm this
# is intended.

# Built once: membership is tested for every dataset id of every entity.
people_dataset_ids = set(people_df["dataset_id"])

people_entities = {}
for entity, num_correct, num_incorrect, entity_dataset_ids in zip(
    entity_analysis_df["entity"],
    entity_analysis_df["num_correct"],
    entity_analysis_df["num_incorrect"],
    entity_analysis_df["dataset_ids"],
):
    matching_ids = [d for d in entity_dataset_ids if d in people_dataset_ids]
    if matching_ids:
        # plain ints keep the record JSON-serialisable
        people_entities[entity] = [int(num_correct), int(num_incorrect), matching_ids]

In [None]:
len(people_entities)

In [None]:
# Write people_entities to disk as JSON.
with open("../../data/wikidata/full_people_entities.json", "w") as people_entities_file:
    people_entities_file.write(json.dumps(people_entities))

#### read in 'tagged' people entities