In [None]:
import os
import json

import pandas as pd
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
from shapely.geometry import Point
import geopandas as gpd
from geopandas import GeoDataFrame
import geodatasets
import colorcet as cc


In [None]:
# Root folder for the input files and the figures written by this notebook.
# The sub-directories are derived from it so that relocating the data only
# requires changing this one constant.
DATASETS_DIR = "/Datasets"
PROPERTY_DICTS_DIR = os.path.join(DATASETS_DIR, "property_dicts")
QUESTION_LISTS_DIR = os.path.join(DATASETS_DIR, "all_questions")

In [None]:
# Load every property-distribution JSON file found in PROPERTY_DICTS_DIR,
# keyed by the part of the file name that precedes the first underscore.
dict_of_dicts = {}
# os.listdir() returns entries in arbitrary order; sorting keeps the insertion
# order of this dict (and therefore the color given to each benchmark further
# down) identical from run to run.
for property_dict in sorted(os.listdir(PROPERTY_DICTS_DIR)):
    property_path = os.path.join(PROPERTY_DICTS_DIR, property_dict)
    # Skip sub-directories (e.g. notebook checkpoint folders) that cannot be opened as files.
    if not os.path.isfile(property_path):
        continue
    with open(property_path, "r", encoding="utf-8") as f:
        distribution = json.load(f)
    dict_of_dicts[property_dict.split("_")[0]] = distribution
        

In [None]:
len(dict_of_dicts)

In [None]:
dict_of_dicts.keys()

In [None]:
# Benchmark names grouped into three categories.
# NOTE(review): judging by the variable names these are encyclopedic, commonsense
# and exam-style benchmarks — confirm against the accompanying write-up.
encycl_benchmarks = ["BoolQ", "SQuAD", "NaturalQuestions", "HotpotQA", "TriviaQA", "WebQuestions", "COQA"]
commons_benchmarks = ["PIQA", "WinoGrande", "COPA", "CommonsenseQA", "SIQA", "HellaSwag", "TruthfulQA"]
exams_benchmarks = ["OpenBookQA", "RACE", "ScienceQA", "MMLU", "GPQA", "ARC", "GSM8K"]
# Benchmarks left out of every per-property benchmark selection further down.
exclude = ["SIQA", "GSM8K", "COPA", "MMLU"]

In [None]:
# Colormap holding one glasbey color per loaded benchmark, starting at palette entry 20.
cmap = matplotlib.colors.ListedColormap(cc.cm.glasbey.colors[20 : len(dict_of_dicts) + 20])


def palette(n):
    """Return a seaborn color palette of ``n`` colors drawn from colorcet's glasbey list."""
    return sns.color_palette(cc.glasbey, n_colors=n)

In [None]:
def map_list_to_color(lst, hex=False, offset=0):
    """Assign one glasbey color to each item of ``lst``.

    Args:
        lst: Items to color; any iterable works (for example the keys of a dict).
        hex: When false (the default) colors are taken from
            ``cc.cm.glasbey.colors``; when true they are taken from
            ``cc.glasbey``. The name shadows the builtin ``hex`` but is kept
            because callers pass it by keyword.
        offset: Index of the first palette entry to use (default 0).

    Returns:
        A dict mapping each item to its color, in iteration order. If the
        palette slice is shorter than ``lst``, the trailing items get no entry
        because ``zip`` stops at the shorter input.
    """
    items = list(lst)
    source = cc.glasbey if hex else cc.cm.glasbey.colors
    return dict(zip(items, source[offset : offset + len(items)]))

In [None]:
# Benchmark name -> color lookups. Despite the `_list` suffix both objects are
# dicts (map_list_to_color returns a dict); the `_hex` variant takes its colors
# from cc.glasbey instead of cc.cm.glasbey.colors.
benchmarks_color_list = map_list_to_color(dict_of_dicts.keys())
benchmarks_color_list_hex = map_list_to_color(dict_of_dicts.keys(), hex=True)


In [None]:
benchmarks_color_list_hex

# Coordinates

In [None]:
def res(x_l):
    """Parse a ``"Point(<a> <b> ...)"`` string into a list of floats.

    The ``"Point("`` prefix and every ``")"`` are stripped and the remainder is
    split on any run of whitespace, so repeated or padding spaces do not
    produce empty tokens.

    Raises:
        ValueError: if a remaining token cannot be converted to ``float``.
    """
    return [float(x) for x in x_l.replace("Point(", "").replace(")", "").split()]


In [None]:
dict_of_dicts.keys()

In [None]:

# Draw one world-map panel per benchmark, showing the coordinate labels found in it.
world = gpd.read_file(geodatasets.data.naturalearth.land['url'])

# Only benchmarks with more than 30 coordinate labels that are not excluded get a panel.
benchmarks = [b for b in dict_of_dicts.keys() if len(dict_of_dicts[b]["coordinates"]["labels"]) > 30 and b not in exclude]

# Size the grid from the benchmarks that are actually drawn (ceiling division),
# so that none is dropped and no fully empty rows are created.
cols = 2
rows = max(1, (len(benchmarks) + cols - 1) // cols)
# squeeze=False keeps `axs` two-dimensional even when there is a single row.
fig, axs = plt.subplots(nrows=rows, ncols=cols, squeeze=False)
fig.set_size_inches(cols*3.2, rows*2)

count = 0
for irow in range(rows):
    for icol in range(cols):
        ax = axs[irow][icol]
        if count < len(benchmarks):
            benchmark_name = benchmarks[count]
            # Keep only labels of the form "Point(...)"; missing labels are skipped.
            geometry = [Point(res(x)) for x in dict_of_dicts[benchmark_name]["coordinates"]["labels"] if x is not None and x.startswith("Point")]
            gdf = GeoDataFrame(geometry=geometry)
            # The land outline is drawn first and the points are plotted on the same axes.
            gdf.plot(ax=world.plot(ax=ax, color='white', edgecolor='black'), marker='o', color=benchmarks_color_list[benchmark_name], markersize=10)
            ax.set_title(benchmark_name, fontsize=15)
            ax.set_xticks([])
            ax.set_yticks([])
            count += 1
        else:
            # hide extra axes
            ax.set_visible(False)

plt.tight_layout()
plt.savefig(os.path.join(DATASETS_DIR, 'images', 'geomap_combined.pdf'), format='pdf', dpi=500, bbox_inches='tight')


In [None]:
import plotly.express as px
# Layout settings keyed by dotted property paths: transparent plot area on a
# white page, black mirrored axis lines with inward ticks, legend placement and
# font, and a centered title.
# NOTE(review): nothing in the visible part of this notebook reads `style_dict`
# — confirm where (and how) it is applied to a figure.
style_dict = {
    'layout.plot_bgcolor': 'rgba(0,0,0,0)',
    'layout.paper_bgcolor': 'white',
    'layout.margin': {'t':40,'l':10,'b':50,'r':10},
    #'layout.font.family': 'Times New Roman',
    'layout.xaxis.linecolor': 'black',
    'layout.xaxis.ticks': 'inside',
    'layout.xaxis.mirror': True,
    'layout.xaxis.showline': True,
    'layout.yaxis.linecolor': 'black',
    'layout.yaxis.ticks': 'inside',
    'layout.yaxis.mirror': True,
    'layout.yaxis.showline': True,
    'layout.autosize': False,
    'layout.showlegend': True,
    'layout.legend.bgcolor': 'rgba(0,0,0,0)',
    #'layout.legend.xanchor': 'right',
    'layout.legend.x': 0,
    'layout.legend.y': -0.5,
    'layout.legend.font.size': 13,
    # Specialized:
    # 'layout.xaxis.range': (2.3, 2.5),
    #'layout.yaxis.range': (-50, +50),
    'layout.xaxis.title': r'$x$',
    'layout.yaxis.title': r'$y$',
    'layout.legend_title': None,
    'layout.title_x': 0.5,
    'layout.title_y': 1,
    'layout.title_xanchor': 'center',
    'layout.title_yanchor': 'top',
    'layout.title_font_color': 'black',
    'layout.legend.font_color': 'black',
}

# Gender

In [None]:
# Benchmarks that carry more than 30 gender labels and are not on the exclude list.
gendered_benchmarks = [
    name
    for name, properties in dict_of_dicts.items()
    if len(properties["genders"]["labels"]) > 30 and name not in exclude
]
gendered_benchmarks

In [None]:
# Collect every distinct gender label across the selected benchmarks,
# in order of first appearance.
gender_labels_all = []
for benchmark in gendered_benchmarks:
    for label in pd.Series(dict_of_dicts[benchmark]["genders"]["labels"]).unique():
        # Missing labels (None) are skipped explicitly; a bare `not None` would
        # always be True and filter nothing.
        if label is not None and label not in gender_labels_all:
            gender_labels_all.append(label)

In [None]:
# Long-format table with one row per (benchmark, label) pair. "count" holds the
# share of that label among all gender labels of the benchmark, in percent
# (0 when the label does not occur there).
# The rows are gathered first and the frame is built once, which avoids
# repeated boolean-mask assignments and yields a numeric "count" column.
gender_records = []
for benchmark in gendered_benchmarks:
    benchmark_labels = dict_of_dicts[benchmark]["genders"]["labels"]
    value_counts = pd.Series(benchmark_labels).value_counts()/len(benchmark_labels)*100
    for label in gender_labels_all:
        share = value_counts[label] if label in value_counts.index else 0
        gender_records.append({"benchmark": benchmark, "label": label, "count": float(share)})
gender_df = pd.DataFrame(gender_records, columns=["benchmark", "label", "count"])

In [None]:
gender_df

In [None]:
# Horizontal grouped bar chart of the male/female shares per benchmark,
# with a reference line at 50 %.
sns.set_style("white")
data = gender_df.loc[gender_df["label"].isin(["male", "female"]), :].sort_values("count", ascending=False)
g = sns.catplot(
    data=data, kind="bar",
    x="count", y="label", hue="benchmark",
    # A dict mapping hue levels to colors can be handed to `palette` directly;
    # it does not need to go through sns.color_palette().
    palette=benchmarks_color_list_hex,
    width=0.8,
)
g.refline(x=50, color='black')
g.despine(left=True)
g.set_axis_labels("Occurrence [%]", "")
sns.move_legend(g, "upper left", bbox_to_anchor=(0.9, 0.95))
g.legend.set_title("")
plt.title("Entity genders")
plt.gcf().set_size_inches(4,4)

plt.savefig(os.path.join(DATASETS_DIR, 'images', 'genders_combined.pdf'), format='pdf', dpi=400, bbox_inches='tight')

Note: for WinoGrande the difference is in fact small.

# Occupation

In [None]:
# Exam-style benchmarks with more than 30 occupation labels, minus the excluded ones.
occupied_benchmarks = [
    name
    for name in exams_benchmarks
    if len(dict_of_dicts[name]["occupations"]["labels"]) > 30 and name not in exclude
]
len(occupied_benchmarks)

In [None]:
# One bar chart per benchmark showing its ten most frequent occupation labels
# as a percentage of all occupation labels in that benchmark.
ncols = 2
# Derive the number of rows from the benchmarks to draw (ceiling division)
# instead of assuming there are exactly four of them.
nrows = max(1, (len(occupied_benchmarks) + ncols - 1) // ncols)
# squeeze=False keeps `axs` two-dimensional even when there is a single row.
fig, axs = plt.subplots(ncols=ncols, nrows=nrows, sharex=True, squeeze=False)
count = 0
for r in range(nrows):
    for c in range(ncols):
        if count >= len(occupied_benchmarks):
            # hide panels that have no benchmark to show
            axs[r, c].set_visible(False)
            continue
        series = pd.Series(dict_of_dicts[occupied_benchmarks[count]]["occupations"]["labels"])
        top_counts = series.value_counts().iloc[:10]
        percentages = top_counts.values / len(series) * 100
        indices = top_counts.index
        sns.barplot(x=percentages, y=indices, ax=axs[r, c], color=benchmarks_color_list_hex[occupied_benchmarks[count]])
        axs[r, c].set_ylabel("")
        axs[r, c].set_xlabel("Occurrence [%]")
        axs[r, c].set_title(occupied_benchmarks[count])
        count += 1
fig.suptitle("Most frequent occupations")
fig.tight_layout()
plt.savefig(os.path.join(DATASETS_DIR, 'images', 'occupations_combined.pdf'), format='pdf', dpi=400, bbox_inches='tight')


# Religion

In [None]:
# Benchmarks that carry more than 30 religion labels and are not on the exclude list.
religion_benchmarks = [
    name
    for name, properties in dict_of_dicts.items()
    if len(properties["religion"]["labels"]) > 30 and name not in exclude
]
len(religion_benchmarks)

In [None]:
# Print each selected benchmark next to its number of religion labels.
for b in religion_benchmarks:
    religion_label_count = len(dict_of_dicts[b]["religion"]["labels"])
    print(b, religion_label_count)

In [None]:
# One bar chart per benchmark showing its `top_n` most frequent religion labels
# as a percentage of all religion labels in that benchmark.
top_n = 10
ncols = 2
# Derive the number of rows from the benchmarks to draw (ceiling division).
nrows = max(1, (len(religion_benchmarks) + ncols - 1) // ncols)
# Shortened display names for long labels so the y tick labels stay readable.
name_updates = {"The Church of Jesus Christ of Latter-day Saints": "Mormon Church",
                "Seventh-day Adventist Church": "Sev.-day Adv. Ch.",
                "Nondenominational Christianity": "Nondenom. Christ.",
                "Georgian Orthodox Church": "Georgian Orth. Ch.",
                "United Church of Christ": "United Ch. of Christ"}
# squeeze=False keeps `axs` two-dimensional even when there is a single row.
fig, axs = plt.subplots(ncols=ncols, nrows=nrows, sharex=True, figsize=(ncols*3.3, nrows*2.3), squeeze=False)
count = 0
for r in range(nrows):
    for c in range(ncols):
        if count >= len(religion_benchmarks):
            # Remove exactly the panels that received no benchmark, so a drawn
            # panel is never deleted and no empty frame is left behind.
            fig.delaxes(axs[r,c])
            continue
        series = pd.Series(dict_of_dicts[religion_benchmarks[count]]["religion"]["labels"])
        top_counts = series.value_counts().iloc[:top_n]
        percentages = top_counts.values / len(series) * 100
        indices = [name_updates.get(i, i) for i in top_counts.index]
        sns.barplot(x=percentages, y=indices, ax=axs[r,c], color=benchmarks_color_list_hex[religion_benchmarks[count]])
        axs[r,c].set_ylabel("")
        axs[r,c].set_xlabel("")
        axs[r,c].tick_params(axis='y', labelsize=14)
        axs[r,c].tick_params(axis='x', labelsize=14)
        axs[r,c].set_title(religion_benchmarks[count], fontsize=14)
        count += 1

fig.suptitle(f'Top-{top_n} religions [%]',y=.99, fontsize=15)
fig.tight_layout()
plt.savefig(os.path.join(DATASETS_DIR, 'images', 'religions_combined.pdf'), format='pdf', dpi=400, bbox_inches='tight')

# Instance of

In [None]:
# Collect every distinct "instance_of" label across the selected benchmarks,
# in order of first appearance.
instance_labels_all = []
# A companion set makes the "already seen?" check constant-time.
seen_instance_labels = set()
for benchmark in benchmarks:
    for label in pd.Series(dict_of_dicts[benchmark]["instance_of"]["labels"]).unique():
        # Missing labels (None) are skipped explicitly; a bare `not None` would
        # always be True and filter nothing.
        if label is None or label in seen_instance_labels:
            continue
        seen_instance_labels.add(label)
        instance_labels_all.append(label)
instance_labels_all[:10]

In [None]:
# Long-format table with one row per (benchmark, label) pair:
#   count      - number of times the label occurs in the benchmark
#   percentage - that count relative to all "instance_of" labels of the benchmark
# The rows are gathered first and the frame is built once; filling a
# pre-allocated frame through boolean masks rescans every row for every cell.
instance_records = []
for benchmark in benchmarks:
    benchmark_labels = dict_of_dicts[benchmark]["instance_of"]["labels"]
    label_total = len(benchmark_labels)
    label_counts = pd.Series(benchmark_labels).value_counts().to_dict()
    for label in instance_labels_all:
        label_count = label_counts.get(label, 0)
        instance_records.append({
            "benchmark": benchmark,
            "label": label,
            "count": label_count,
            # Guard against a benchmark without any "instance_of" labels.
            "percentage": label_count / label_total * 100 if label_total else 0.0,
        })
instance_df = pd.DataFrame(instance_records, columns=["benchmark", "label", "count", "percentage"])

In [None]:
import numpy as np
# Aggregate the per-benchmark rows of `instance_df` into one row per label:
#   cumulative                - occurrences summed over all benchmarks
#   cumulative_pcnt_per_bench - per-benchmark percentages summed over all benchmarks
#   cumulative_pcnt_by_total  - summed occurrences relative to all labels of all benchmarks
total_sum = sum(len(dict_of_dicts[benchmark]["instance_of"]["labels"]) for benchmark in benchmarks)
print(total_sum)

# Accumulate in a single pass over the rows instead of filtering the whole
# frame once per label.
cumulative_counts = {}
cumulative_pcnts = {}
for row_label, row_count, row_pcnt in zip(instance_df["label"], instance_df["count"], instance_df["percentage"]):
    cumulative_counts[row_label] = cumulative_counts.get(row_label, 0) + row_count
    cumulative_pcnts[row_label] = cumulative_pcnts.get(row_label, 0) + row_pcnt

instance_sum_df = pd.DataFrame({
    "label": instance_labels_all,
    "cumulative": [cumulative_counts.get(label, 0) for label in instance_labels_all],
    "cumulative_pcnt_per_bench": [cumulative_pcnts.get(label, 0) for label in instance_labels_all],
    "cumulative_pcnt_by_total": [(cumulative_counts.get(label, 0) / total_sum) * 100 for label in instance_labels_all],
})

# Sorting keeps the original index labels, so positional access (head / iloc)
# is required afterwards to get the top rows.
instance_sum_df = instance_sum_df.sort_values(by="cumulative_pcnt_per_bench", ascending=False)
instance_sum_df

In [None]:
# Horizontal bar chart of the ten instance types with the largest summed
# per-benchmark percentage.
sns.set_style("white")
# NOTE(review): `frequent_occupations` holds instance labels (not occupations)
# and is not used in this cell — confirm whether anything else still needs it.
frequent_occupations = instance_sum_df[:10]["label"]
g = sns.catplot(
    data=instance_sum_df[:10], kind="bar", x="cumulative_pcnt_per_bench", y="label", width=1
)
g.despine(left=True)
g.set_axis_labels("Occurrence [%]", "")
# Without a `hue` variable the grid has no legend and `g.legend` is None.
if g.legend is not None:
    g.legend.set_title("")
plt.title("Top-10 instance types summed across benchmark datasets")
plt.gcf().set_size_inches(10, 5)

plt.savefig(os.path.join(DATASETS_DIR, 'instance_cumulative.png'), format='png', dpi=300, bbox_inches='tight')

In [None]:
# Pie chart of the ten instance types with the largest summed percentage.
# `instance_sum_df` was sorted, so its index labels are no longer in row order:
# `.loc[:10]` would slice up to the row *labelled* 10, whereas `.iloc[:10]`
# takes the first ten rows.
top_instances = instance_sum_df.iloc[:10]
plt.pie(top_instances["cumulative_pcnt_per_bench"], labels=top_instances["label"], colors=palette(10), autopct='%.0f%%')
plt.show()