In [4]:
import plotly.graph_objects as go
import pickle
from transformers import AutoTokenizer
import numpy as np
import shap
from src.plot_text import text, get_grouped_vals

# from src.utils import format_fts_for_plotting, text_ft_index_ends, token_segments
# from src.utils import legacy_get_dataset_info
from datasets import load_dataset

# import matplotlib.pyplot as plt
from src.run_shap import load_shap_vals
from tqdm import tqdm

# import re
# import seaborn as sns
# from scipy import stats
from src.utils import (
    ConfigLoader,
    text_ft_index_ends,
    token_segments,
    format_fts_for_plotting,
    format_text_fts_too,
)
import pandas as pd
import plotly.express as px

In [8]:
configs = [
    "vet_50c_all_text",
    "vet_59c_ensemble_25",
    "vet_59c_ensemble_50",
    "vet_59c_ensemble_75",
    "vet_59c_stack",
    "vet_10c_all_text",
    "vet_19c_ensemble_25",
    "vet_19c_ensemble_50",
    "vet_19c_ensemble_75",
    "vet_19c_stack",
]
configs = [
    "vet_59c_stack",
    "vet_19c_stack",
    "vet_59c_ensemble_25",
    "vet_19c_ensemble_25",
    "vet_59c_ensemble_50",
    "vet_19c_ensemble_50",
    "vet_59c_ensemble_75",
    "vet_19c_ensemble_75",
    "vet_50c_all_text",
    "vet_10c_all_text",
]

configs = [
    "vet_19c_stack",
    "vet_59c_stack",
    "vet_19c_ensemble_25",
    "vet_59c_ensemble_25",
    "vet_19c_ensemble_50",
    "vet_59c_ensemble_50",
    "vet_19c_ensemble_75",
    "vet_59c_ensemble_75",
    "vet_10c_2_all_text",
    "vet_50c_all_text",
]

c2c_dict = {
    "vet_59c_stack": "A: PetBERT Stack",
    "vet_19c_stack": "B: BERT Stack",
    "vet_59c_ensemble_25": "C: PetBERT WE w=.25",
    "vet_19c_ensemble_25": "D: BERT WE w=.25",
    "vet_59c_ensemble_50": "E: PetBERT WE w=.50",
    "vet_19c_ensemble_50": "F: BERT WE w=.50",
    "vet_59c_ensemble_75": "G: PetBERT WE w=.75",
    "vet_19c_ensemble_75": "H: BERT WE w=.75",
    "vet_50c_all_text": "I: PetBERT All Text",
    "vet_10c_all_text": "J: BERT All Text",
    "vet_10c_2_all_text": "J: BERT All Text 2",
}

configs = [
    "vet_19c_stack",
    "vet_59c_stack",
    "vet_19c_ensemble_25",
    "vet_59c_ensemble_25",
    "vet_19c_ensemble_50",
    "vet_59c_ensemble_50",
    "vet_19c_ensemble_75",
    "vet_59c_ensemble_75",
    # "vet_10c_all_text",
    "vet_10c_2_all_text",
    "vet_50c_all_text",
]

c2c_dict = {
    "vet_19c_stack": "A: BERT Stack",
    "vet_59c_stack": "B: PetBERT Stack",
    "vet_19c_ensemble_25": "C: BERT WE w=.25",
    "vet_59c_ensemble_25": "D: PetBERT WE w=.25",
    "vet_19c_ensemble_50": "E: BERT WE w=.50",
    "vet_59c_ensemble_50": "F: PetBERT WE w=.50",
    "vet_19c_ensemble_75": "G: BERT WE w=.75",
    "vet_59c_ensemble_75": "H: PetBERT WE w=.75",
    "vet_50c_all_text": "I: PetBERT All Text",
    "vet_10c_all_text": "J: BERT All Text",
}


val_dict = {}
for config in configs:
    with open(f"../models/shap_vals/summed_{config}.pkl", "rb") as f:
        grouped_shap_vals = pickle.load(f)
    val_dict[config] = (
        abs(grouped_shap_vals).sum(axis=0).mean(axis=0)
        / abs(grouped_shap_vals).sum(axis=0).mean(axis=0).sum()
    )
df1 = pd.DataFrame(columns=["config", "col", "val", "total"])
for idx, config in enumerate(configs):
    vals = val_dict[config]
    args = ConfigLoader(
        config, "../configs/shap_configs.yaml", "../configs/dataset_default.yaml"
    )
    cols = args.categorical_cols + args.numerical_cols + args.text_cols
    for i, col in enumerate(cols):
        df1 = pd.concat(
            [
                df1,
                pd.DataFrame(
                    [
                        {
                            "config": c2c_dict[config],
                            "col": col,
                            "val": vals[i],
                            "total": 1,
                        }
                    ]
                ),
            ],
            ignore_index=True,
        )


# country_index = col_index
# year = config
# country = col
# pop = val
# cols_to_use = set(df1[df1["val"] > 0.05]['col'].values)
cols_to_use = (
    df1.groupby("col")["val"]
    .sum()
    .reset_index()
    .sort_values("val", ascending=False)["col"]
    .values[:6]
)
# if col not in cols_to_use: set col = 'other'
df1.loc[~df1["col"].isin(cols_to_use), "col"] = "other"
# cols_to_use.add('other')
# cols_to_use = list(cols_to_use)
cols_to_use = list(cols_to_use) + ["other"]
# df1 = df1.sort_values(["val"], ignore_index=True, ascending=False)
# df1 = df1.sort_values(["config"], key=lambda x: x.map(dict(zip(configs, range(len(configs))))), ignore_index=False, ascending=True)
# make col index based on cols_to_use
df1["col_index"] = df1["col"].apply(lambda x: cols_to_use.index(x))
# if col == 'other': set col_index = max(col_index) + 1
df1.loc[df1["col"] == "other", "col_index"] = df1["col_index"].max() + 1
# df1 = df1.sort_values(["config"], ignore_index=True, ascending=True)
# df1 = df1.sort_values(["col_index"], ignore_index=True, ascending=True)


# custom_order = dict(zip(range(len(cols_to_use)),configs ))
# df1 = df1.sort_values(by=["col_index", "config"], ignore_index=True, ascending=[True, True], key=lambda x: x.map(custom_order))


# # custom_order = dict(zip(configs, range(len(configs)) ))
# # df1 = df1.sort_values(by=["config", "col_index" ], ignore_index=False, ascending=[True, True], key=lambda x: x.map(custom_order))
# # df1 =
# # # df1 = df1.sort_values(by=["col_index", "config"], ignore_index=True, ascending=[True, True], key=lambda x: x.map(custom_order))

df1 = df1.sort_values(
    ["col_index", "config"], ignore_index=True, ascending=[True, False]
)
#
# colors = px.colors.qualitative.Set3[:len(cols)]


df1["showlegend"] = df1.groupby("col")["config"].cumcount() == 0
# + px.colors.qualitative.Pastel + px.colors.qualitative.Bold + px.colors.qualitative.Light24
colors = px.colors.qualitative.Plotly

fig = go.Figure()
for index, row in df1.iterrows():
    fig.add_trace(
        go.Bar(
            y=[row["config"]],
            # y=[1],
            x=[row["val"]],
            # customdata=[row["gdpPercap"]],
            # hovertemplate="%{x:2f}, GDP per capita: %{customdata:2f}",
            marker=dict(color=colors[row["col_index"]]),
            name=row["col"],
            legendgroup=row["col"],
            showlegend=row["showlegend"],
            orientation="h",
        )
    )

fig.update_layout(barmode="stack", hovermode="y unified")

Updating with:
{'config': 'vet_19c_stack', 'my_text_model': 'james-burton/vet_19c', 'ds_name': 'james-burton/vet_month_1c_ordinal', 'text_model_base': 'bert-base-uncased', 'model_type': 'stack', 'ord_ds_name': 'james-burton/vet_month_1c_ordinal', 'text_cols': ['breed', 'region', 'record']}


{'categorical_cols': ['gender', 'neutered', 'species', 'insured'], 'numerical_cols_long': ['age_at_consult', 'Diseases of the ear or mastoid process', 'Mental, behavioural or neurodevelopmental disorders', 'Diseases of the blood or blood-forming organs', 'Diseases of the circulatory system', 'Dental', 'Developmental anomalies', 'Diseases of the digestive system', 'Endocrine, nutritional or metabolic diseases', 'Diseases of the Immune system', 'Certain infectious or parasitic diseases', 'Diseases of the skin', 'Diseases of the musculoskeletal system or connective tissue', 'Neoplasms', 'Diseases of the nervous system', 'Diseases of the visual system', 'Certain conditions originating in the perinatal 


The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.



In [6]:
sum(val_dict["vet_10c_all_text"][:3])

0.0005154414868679021