In [None]:
%load_ext autoreload
%load_ext tensorboard
%matplotlib inline

In [None]:
import matplotlib
import concepts_xai
import numpy as np
import os
import random
import tensorflow as tf
import yaml
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import rc
from matplotlib import cm
import seaborn as sns
from importlib import reload
from pathlib import Path
import sklearn
from tensorflow.keras.models import load_model
from joblib import dump, load
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import RidgeClassifier
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.preprocessing import OneHotEncoder
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score

from concepts_xai.evaluation.metrics.niching import niche_completeness
from concepts_xai.evaluation.metrics.niching import niche_completeness_ratio
from concepts_xai.evaluation.metrics.niching import niche_impurity
from concepts_xai.evaluation.metrics.niching import niche_finding

In [None]:
################################################################################
## Global Variables Defining Experiment Flow
################################################################################

LATEX_SYMBOL = "$"
RESULTS_DIR = "results/"
Path(RESULTS_DIR).mkdir(parents=True, exist_ok=True)
rc('text', usetex=(LATEX_SYMBOL == "$"))
# matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and 3.8
# removed the old names, so pick whichever name this installation provides.
_WHITEGRID_STYLE = (
    'seaborn-v0_8-whitegrid'
    if 'seaborn-v0_8-whitegrid' in plt.style.available
    else 'seaborn-whitegrid'
)
plt.style.use(_WHITEGRID_STYLE)
# Method name -> bar colour, shared by every figure below so a method keeps the
# same colour across plots. The plotting cells fill it in lazily.
color_map = {}

In [None]:
from collections import defaultdict
from joblib import dump, load

def load_results_dir(dirname):
    """Load paired ``<var>_means.npz`` / ``<var>_stds.npz`` archives from `dirname`.

    Files whose names end in neither suffix are ignored. The arrays inside each
    archive are read in ``arr_0, arr_1, ...`` order (the names ``np.savez``
    assigns to positional arrays).

    Args:
        dirname: Directory to scan (not recursively).

    Returns:
        Dict mapping each ``<var>`` to a list of ``(mean, std)`` pairs, one per
        stored array. Pairing uses ``zip``, so it stops at the shorter list.

    Raises:
        ValueError: If the set of mean variables differs from the set of std
            variables.
    """
    means_suffix = "_means.npz"
    stds_suffix = "_stds.npz"

    def _read_arrays(path):
        # The context manager closes the NpzFile's file handle (previously left
        # open); each array is fully loaded into memory before the close.
        with np.load(path) as archive:
            return [archive[f"arr_{i}"] for i in range(len(archive))]

    means_vars = {}
    stds_vars = {}
    for filename in os.listdir(dirname):
        path = os.path.join(dirname, filename)
        if filename.endswith(means_suffix):
            means_vars[filename[:-len(means_suffix)]] = _read_arrays(path)
        elif filename.endswith(stds_suffix):
            stds_vars[filename[:-len(stds_suffix)]] = _read_arrays(path)

    if set(means_vars.keys()) != set(stds_vars.keys()):
        raise ValueError(
            f"Found mean variables {set(means_vars.keys())} vs std variables {set(stds_vars.keys())}."
        )
    return {
        var: list(zip(means, stds_vars[var]))
        for var, means in means_vars.items()
    }

def trials_mean_and_std(values, num_trials=5):
    """Split `values` into consecutive groups of `num_trials` rows and summarize each.

    The reduction runs along axis 0 of each group. Trailing rows that do not
    fill a complete group are dropped.

    Returns:
        Tuple ``(means, stds)`` of lists, with one entry per complete group.
    """
    values = np.array(values)
    group_count = values.shape[0] // num_trials
    groups = [
        values[g * num_trials:(g + 1) * num_trials, ...]
        for g in range(group_count)
    ]
    return (
        [np.mean(group, axis=0) for group in groups],
        [np.std(group, axis=0) for group in groups],
    )
        
def load_from_joblib(path):
    """Load a joblib-saved results mapping and collapse per-trial entries.

    Every list / ndarray value is passed through ``trials_mean_and_std``
    (default ``num_trials``). It is replaced by a list of ``(mean, std)``
    tuples, one per trial group. Values that cannot be reduced are left as
    loaded and a message is printed. All other values pass through unchanged.

    NOTE: ``joblib.load`` unpickles the file, so only use this on trusted,
    locally produced results.
    """
    results = load(path)
    for key, vals in results.items():
        if not isinstance(vals, (list, np.ndarray)):
            continue
        # Then split it up across different trials.
        try:
            results[key] = list(zip(*trials_mean_and_std(vals)))
        except Exception:
            # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
            # still propagate; other failures stay best-effort as before.
            print("Could not reduce to mean and std with entry", key, "at", path)
    return results

def bold_text(x):
    """Return `x` wrapped in ``$\\textbf{...}$`` when LaTeX rendering is on.

    When ``LATEX_SYMBOL`` is anything other than ``"$"``, `x` is returned as is.
    """
    if LATEX_SYMBOL != "$":
        return x
    return r"$\textbf{" + x + "}$"

# Load Results (NOTE: all experiments must have finished running before executing this step)

In [None]:
################################################################################
## Toy Tabular Oracle Impurity Score Results Loading
################################################################################

toy_tabular_oracle_dir = os.path.join(RESULTS_DIR, "toy_tabular")
oracle_toy_tabular_results = defaultdict(dict)

# (method, variant, sub-directory of toy_tabular_oracle_dir), loaded in order.
for _method, _variant, _subdir in [
    ("cbm", "base", "cbm/scratch_training_purity"),
    ("cbm", "logits", "cbm/from_logits"),
    # NOTE(review): both CW variants read the same directory -- confirm intended.
    ("cw", "max_pool_mean", "cw/purity_new"),
    ("cw", "mean", "cw/purity_new"),
    ("ccd", "base", "ccd/purity"),
    ("ccd", "extended", "ccd/purity_double_concepts"),
    ("senn", "base", "senn/purity"),
    ("senn", "extended", "senn/purity_extended"),
]:
    oracle_toy_tabular_results[_method][_variant] = load_results_dir(
        os.path.join(toy_tabular_oracle_dir, _subdir)
    )

################################################################################
## Toy Tabular Niching Impurity Score Results Loading
################################################################################

toy_tabular_niching_concepts_dir = os.path.join("results_concept_niching_integrated", "toy_tabular")
niching_concepts_toy_tabular_results = defaultdict(dict)

# (method, variant, path relative to toy_tabular_niching_concepts_dir), loaded in order.
for _method, _variant, _relpath in [
    ("cbm", "base", "cbm/base/results_niching.joblib"),
    ("cbm", "logits", "cbm/from_logits/results_niching.joblib"),
    # NOTE(review): all three CW variants read the same file -- confirm intended.
    ("cw", "max_pool_mean", "cw/base/results_niching.joblib"),
    ("cw", "max_pool_mean_feat", "cw/base/results_niching.joblib"),
    ("cw", "mean", "cw/base/results_niching.joblib"),
    ("ccd", "base", "ccd/base/results_niching.joblib"),
    ("ccd", "extended", "ccd/extended/results_niching.joblib"),
    ("senn", "base", "senn/purity/results_niching.joblib"),
    ("senn", "extended", "senn/purity_extended/results_niching.joblib"),
]:
    niching_concepts_toy_tabular_results[_method][_variant] = load_from_joblib(
        os.path.join(toy_tabular_niching_concepts_dir, _relpath)
    )

In [None]:
################################################################################
## dSprites Oracle Results Loading
################################################################################

dsprites_oracle_dir = os.path.join(RESULTS_DIR, "dsprites")
oracle_dsprites_results = defaultdict(dict)

# (method, variant, sub-directory of dsprites_oracle_dir), loaded in order.
for _method, _variant, _subdir in [
    ("cbm", "base", "cbm/graph_dependency_balanced_multiclass_tasks_purity"),
    ("cbm", "logits", "cbm/graph_dependency_balanced_multiclass_from_logits_tasks_purity"),
    ("cw", "max_pool_mean", "cw/balanced_multiclass_tasks_purity_max_pool_mean"),
    ("cw", "mean", "cw/balanced_multiclass_tasks_purity_mean"),
    ("ada_ml_vae", "base", "ada_ml_vae/multilabel_purity_latent_5"),
    ("ada_ml_vae", "extended", "ada_ml_vae/multilabel_purity_latent_10"),
    ("ada_g_vae", "base", "ada_g_vae/multilabel_purity_latent_5"),
    ("ada_g_vae", "extended", "ada_g_vae/multilabel_purity_latent_10"),
    # Plain VAE results are the beta = 1 runs of the beta-VAE experiment.
    ("vae", "base", "beta_vae/balanced_multilabel_purity_latent_5_beta_1"),
    ("vae", "extended", "beta_vae/balanced_multilabel_purity_latent_10_beta_1"),
    ("beta_vae", "base", "beta_vae/balanced_multilabel_purity_latent_5_beta_10"),
    ("beta_vae", "extended", "beta_vae/balanced_multilabel_purity_latent_10_beta_10"),
    ("ccd", "base", "ccd/balanced_multiclass_thresh_0_num_concepts_5"),
    ("ccd", "extended", "ccd/balanced_multiclass_thresh_0_num_concepts_10"),
    ("senn", "base", "senn/dependency_multiclass"),
    ("senn", "extended", "senn/dependency_multiclass_extended"),
]:
    oracle_dsprites_results[_method][_variant] = load_results_dir(
        os.path.join(dsprites_oracle_dir, _subdir)
    )


################################################################################
## dSprites Niching Impurity Score Results Loading
################################################################################

dsprites_niching_concepts_dir = os.path.join("results_concept_niching_integrated", "dsprites")
niching_concepts_dsprites_results = defaultdict(dict)

# (method, variant, path relative to dsprites_niching_concepts_dir), loaded in order.
for _method, _variant, _relpath in [
    ("cbm", "base", "cbm/base/results_niching.joblib"),
    ("cbm", "logits", "cbm/from_logits/results_niching.joblib"),
    ("cw", "max_pool_mean", "cw/base_feature_False/results_niching.joblib"),
    ("cw", "max_pool_mean_feat", "cw/base_feature_True/results_niching.joblib"),
    ("cw", "mean", "cw/base_feature_False_mean/results_niching.joblib"),
    ("ada_ml_vae", "base", "ada_ml_vae/multilabel_purity_latent_5/results_niching.joblib"),
    ("ada_ml_vae", "extended", "ada_ml_vae/multilabel_purity_latent_10/results_niching.joblib"),
    ("ada_g_vae", "base", "ada_g_vae/multilabel_purity_latent_5/results_niching.joblib"),
    ("ada_g_vae", "extended", "ada_g_vae/multilabel_purity_latent_10/results_niching.joblib"),
    ("vae", "base", "beta_vae/purity_latent_5_beta_1/results_niching.joblib"),
    ("vae", "extended", "beta_vae/purity_latent_10_beta_1/results_niching.joblib"),
    ("beta_vae", "base", "beta_vae/purity_latent_5_beta_10/results_niching.joblib"),
    ("beta_vae", "extended", "beta_vae/purity_latent_10_beta_10/results_niching.joblib"),
    ("ccd", "base", "ccd/balanced_multiclass_thresh_0_num_concepts_5/results_niching.joblib"),
    ("ccd", "extended", "ccd/balanced_multiclass_thresh_0_num_concepts_10/results_niching.joblib"),
    ("senn", "base", "senn/dependency_multiclass/results_niching.joblib"),
    ("senn", "extended", "senn/dependency_multiclass_extended/results_niching.joblib"),
]:
    niching_concepts_dsprites_results[_method][_variant] = load_from_joblib(
        os.path.join(dsprites_niching_concepts_dir, _relpath)
    )

In [None]:
from collections import defaultdict
from joblib import dump, load

################################################################################
## shapes3d Oracle Results Loading
################################################################################

shapes3d_oracle_dir = os.path.join(RESULTS_DIR, "shapes3d")
oracle_shapes3d_results = defaultdict(dict)

# (method, variant, sub-directory of shapes3d_oracle_dir), loaded in order.
for _method, _variant, _subdir in [
    ("cbm", "base", "cbm/dependency_balanced_multiclass_tasks"),
    ("cbm", "logits", "cbm/dependency_balanced_multiclass_from_logits_tasks"),
    ("cw", "max_pool_mean", "cw/balanced_multiclass_purity_max_pool_mean"),
    ("cw", "mean", "cw/balanced_multiclass_purity_mean"),
    ("ada_ml_vae", "base", "ada_ml_vae/multilabel_purity_latent_6"),
    ("ada_ml_vae", "extended", "ada_ml_vae/multilabel_extended_purity_latent_12"),
    ("ada_g_vae", "base", "ada_g_vae/multilabel_purity_latent_6"),
    ("ada_g_vae", "extended", "ada_g_vae/multilabel_extended_purity_latent_12"),
    # Plain VAE results are the beta = 1 runs of the beta-VAE experiment.
    ("vae", "base", "beta_vae/balanced_multilabel_purity_latent_6_beta_1"),
    ("vae", "extended", "beta_vae/balanced_multilabel_purity_latent_12_beta_1"),
    ("beta_vae", "base", "beta_vae/balanced_multilabel_purity_latent_6_beta_10"),
    ("beta_vae", "extended", "beta_vae/balanced_multilabel_purity_latent_12_beta_10"),
    ("ccd", "base", "ccd/balanced_multiclass_thresh_0_num_concepts_6"),
    ("ccd", "extended", "ccd/balanced_multiclass_thresh_0_num_concepts_12"),
    ("senn", "base", "senn/dependency_multiclass"),
    ("senn", "extended", "senn/dependency_multiclass_extended"),
]:
    oracle_shapes3d_results[_method][_variant] = load_results_dir(
        os.path.join(shapes3d_oracle_dir, _subdir)
    )

################################################################################
## shapes3d Niching Impurity Score Results Loading
################################################################################

shapes3d_niching_concepts_dir = os.path.join("results_concept_niching_integrated", "shapes3d")
niching_concepts_shapes3d_results = defaultdict(dict)

# (method, variant, path relative to shapes3d_niching_concepts_dir), loaded in order.
# NOTE(review): the Ada-VAE and CCD run names differ from their oracle
# counterparts above (graph_dependency_* / no thresh_0) -- confirm intended.
for _method, _variant, _relpath in [
    ("cbm", "base", "cbm/base/results_niching.joblib"),
    ("cbm", "logits", "cbm/from_logits/results_niching.joblib"),
    ("cw", "max_pool_mean", "cw/base_feature_False/results_niching.joblib"),
    ("cw", "max_pool_mean_feat", "cw/base_feature_True/results_niching.joblib"),
    ("cw", "mean", "cw/base_feature_False_mean/results_niching.joblib"),
    ("ada_ml_vae", "base", "ada_ml_vae/graph_dependency_multiclass_tasks_purity_latent_6/results_niching.joblib"),
    ("ada_ml_vae", "extended", "ada_ml_vae/graph_dependency_multiclass_tasks_purity_latent_12/results_niching.joblib"),
    ("ada_g_vae", "base", "ada_g_vae/graph_dependency_multiclass_tasks_purity_latent_6/results_niching.joblib"),
    ("ada_g_vae", "extended", "ada_g_vae/graph_dependency_multiclass_tasks_purity_latent_12/results_niching.joblib"),
    ("vae", "base", "beta_vae/purity_latent_6_beta_1/results_niching.joblib"),
    ("vae", "extended", "beta_vae/purity_latent_12_beta_1/results_niching.joblib"),
    ("beta_vae", "base", "beta_vae/purity_latent_6_beta_10/results_niching.joblib"),
    ("beta_vae", "extended", "beta_vae/purity_latent_12_beta_10/results_niching.joblib"),
    ("ccd", "base", "ccd/balanced_multiclass_num_concepts_6/results_niching.joblib"),
    ("ccd", "extended", "ccd/balanced_multiclass_num_concepts_12/results_niching.joblib"),
    ("senn", "base", "senn/dependency_multiclass/results_niching.joblib"),
    ("senn", "extended", "senn/dependency_multiclass_extended/results_niching.joblib"),
]:
    niching_concepts_shapes3d_results[_method][_variant] = load_from_joblib(
        os.path.join(shapes3d_niching_concepts_dir, _relpath)
    )

# Oracle Impurity Score (OIS) and Niche Impurity Score (NIS) Benchmarking Results

In [None]:
# Plot niche impurity scores (NIS) for all methods on each dataset, at the
# first and last setting of each dataset's dependency parameter.
# Set up our figure
clrs = sns.color_palette("tab20", 30)
fig_width = 20
fig_height = 3
# Method name -> bar colour, shared with the other figures so a method keeps
# the same colour everywhere. Only create it when no earlier cell has: it used
# to be commented out, so a fresh kernel raised NameError in the plotting loop.
if "color_map" not in globals():
    color_map = {}
scale = 1
fig, axs = plt.subplots(1, 3, figsize=(fig_width, fig_height))

# Indices into each method's per-setting result list: first and last setting.
all_vars = [0, -1]
NIS_YTICKS = np.arange(0.1, 1.1, 0.2)


def _identity(x):
    return x


def _draw_grouped_bars(ax, models, all_vars, num_models):
    """Draw one bar group per index in `all_vars` on `ax`, one bar per model.

    Each model is ``(name, results, kword, transform_fn)``. ``results[kword]``
    is indexed by `all_vars`, and each selected entry's ``[0]`` / ``[1]`` give
    the mean / std. Bars show the transformed mean with +/- 2 transformed-std
    error bars. Bar width is ``scale / num_models``; a `num_models` larger than
    ``len(models)`` leaves empty space inside each group.
    """
    group_offsets = np.arange(0, len(all_vars)) - (1/2 - 1/(2 * num_models))
    for i, (method_name, results, kword, transform_fn) in enumerate(models):
        if method_name not in color_map:
            color_map[method_name] = clrs[len(color_map)]
        selected = np.array(results[kword])[all_vars]
        means = np.array([transform_fn(entry[0]) for entry in selected])
        stds = np.array([transform_fn(entry[1]) for entry in selected])
        ax.bar(
            scale * (group_offsets + i/num_models),
            means,
            width=scale/num_models,
            color=color_map[method_name],
            align='center',
            label=method_name,
            yerr=2*stds,
            capsize=5,
            edgecolor="black",
        )


def _format_nis_axis(ax, ylabel, xlabel, title, xticklabels):
    """Apply the shared NIS panel styling: labels, ticks and a 0.1-1.0 y-range."""
    if ylabel:
        ax.set_ylabel(ylabel, fontsize=30)
    else:
        ax.set_ylabel("")
    ax.set_xlabel(xlabel, fontsize=30)
    ax.set_title(bold_text(title), fontsize=30)
    ax.set_xticks(np.arange(0, len(xticklabels)))
    ax.set_xticklabels(xticklabels, fontsize=25)
    ax.set_yticks(NIS_YTICKS)
    ax.set_ylim((0.1, 1))
    ax.set_yticklabels([f"{y:.1f}" for y in NIS_YTICKS], fontsize=18)
    ax.grid(False)


def _image_dataset_models(results):
    """NIS entries (in legend/colour order) for an image dataset's niching results."""
    return [
        ("Joint-CBM", results["cbm"]["base"], "niss", _identity),
        ("CW MaxPool-Mean", results["cw"]["max_pool_mean"], "niss", _identity),
        (r"Ada-MLVAE (n\_latent = k)", results["ada_ml_vae"]["base"], "niss", _identity),
        (r"Ada-GVAE (n\_latent = k)", results["ada_g_vae"]["base"], "niss", _identity),
        (r"Beta-VAE (n\_latent = k)", results["beta_vae"]["base"], "niss", _identity),
        (r"VAE (n\_latent = k)", results["vae"]["base"], "niss", _identity),
        (r"CCD (n\_concepts = k)", results["ccd"]["base"], "niss", _identity),
        (r"SENN (n\_concepts = k)", results["senn"]["base"], "niss", _identity),
    ]


toy_tabular_models = [
    ("Joint-CBM", niching_concepts_toy_tabular_results["cbm"]["base"], "niss", _identity),
    ("CW MaxPool-Mean", niching_concepts_toy_tabular_results["cw"]["max_pool_mean"], "niss", _identity),
    (r"CCD (n\_concepts = k)", niching_concepts_toy_tabular_results["ccd"]["base"], "niss", _identity),
    (r"SENN (n\_concepts = k)", niching_concepts_toy_tabular_results["senn"]["base"], "niss", _identity),
]
dsprites_models = _image_dataset_models(niching_concepts_dsprites_results)
shapes3d_models = _image_dataset_models(niching_concepts_shapes3d_results)

# (axis, models, num_models, ylabel, xlabel, title, xticklabels). Raw strings
# avoid the invalid "\d" / "\l" / "\_" escape sequences (SyntaxWarning on 3.12+).
panels = [
    (axs[0], toy_tabular_models, len(toy_tabular_models) + 1, "NIS",
     r"$\delta$", r"TabularToy($\delta$)", [f"{x:.1f}" for x in [0, 0.9]]),
    (axs[1], dsprites_models, len(dsprites_models) + 2, "",
     r"$\lambda$", r"dSprites($\lambda$)", [0, 4]),
    (axs[2], shapes3d_models, len(shapes3d_models) + 2, "",
     r"$\lambda$", r"3dshapes($\lambda$)", [0, 5]),
]
for ax, models, num_models, ylabel, xlabel, title, xticklabels in panels:
    _draw_grouped_bars(ax, models, all_vars, num_models)
    _format_nis_axis(ax, ylabel, xlabel, title, xticklabels)

# The last (image) panel contains every method, so its handles cover the legend.
handles, labels = ax.get_legend_handles_labels()
lgd = fig.legend(handles, labels, fontsize=20, loc='upper center', bbox_to_anchor=(0.5,-0.1), ncol=(num_models - 1)//2)
plt.show()

In [None]:
# Plot oracle impurity scores (OIS) for all methods
# Set up our figure
clrs = sns.color_palette("tab20", 30)
fig_width = 20
fig_height = 3
# Reuse the method -> colour mapping from earlier figures so each method keeps
# its colour; create it when missing. It used to be commented out, so a fresh
# kernel raised NameError in the plotting loops below.
if "color_map" not in globals():
    color_map = {}
scale = 1
trials = 5
fig, axs = plt.subplots(1, 3, figsize=(fig_width, fig_height))

all_vars = [0, -1]
real_values = [0.0, 0.9]
num_concepts = 3
# Raw strings avoid the invalid "\_" escape sequence (SyntaxWarning on 3.12+).
all_models = [
    ("Joint-CBM", oracle_toy_tabular_results["cbm"]["base"], "purity_scores", lambda x: x),
    ("CW MaxPool-Mean", oracle_toy_tabular_results["cw"]["max_pool_mean"], "purity_scores", lambda x: x),
    (r"CCD (n\_concepts = k)", oracle_toy_tabular_results["ccd"]["base"], "purity_scores", lambda x: x),
    (r"SENN (n\_concepts = k)", oracle_toy_tabular_results["senn"]["base"], "purity_scores", lambda x: x),
]
num_models = len(all_models) + 1
ax = axs[0]

for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
    if method_name not in color_map:
        color_map[method_name] = clrs[len(color_map)]
    color = color_map[method_name]
    accs_means = list(map(
        lambda x: x[0],
        np.array(results[kword])[all_vars],
    ))
    accs_means = np.array(list(map(transform_fn, accs_means)))
    accs_stds = list(map(
        lambda x: x[1],
        np.array(results[kword])[all_vars],
    ))
    accs_stds = np.array(list(map(transform_fn, accs_stds)))
    ax.bar(
        scale * (np.arange(0, len(all_vars)) - (1/2 - 1/(2 * num_models)) + i/num_models),
        accs_means,
        width=scale/num_models,
        color=color,
        align='center',
        label=method_name,
        yerr=2*accs_stds,
        capsize=5,
        edgecolor="black",
    )
ax.set_ylabel("OIS", fontsize=30)
ax.set_xlabel(r"$\delta$", fontsize=30)
ax.set_title(bold_text(r"TabularToy($\delta$)"), fontsize=30)
ax.set_xticks(np.arange(0, len(all_vars)))
ax.set_xticklabels(list(map(lambda x: f"{x:.1f}", real_values)), fontsize=25)
ax.set_yticks(np.arange(0, 0.9, 0.2))
ax.set_yticklabels(list(map(lambda x: f"{x:.1f}", np.arange(0, 0.9, 0.2))), fontsize=18)
ax.grid(False)
handles, labels = ax.get_legend_handles_labels()

all_vars = [0, -1]
all_values = [0, 4]
num_concepts = 5
# Raw strings avoid the invalid "\_" escape sequence (SyntaxWarning on 3.12+).
all_models = [
    ("Joint-CBM", oracle_dsprites_results["cbm"]["base"], "purity_scores", lambda x: x),
    ("CW MaxPool-Mean", oracle_dsprites_results["cw"]["max_pool_mean"], "purity_scores", lambda x: x),
    (r"Ada-MLVAE (n\_latent = k)", oracle_dsprites_results["ada_ml_vae"]["base"], "purity_scores", lambda x: x),
    (r"Ada-GVAE (n\_latent = k)", oracle_dsprites_results["ada_g_vae"]["base"], "purity_scores", lambda x: x),
    (r"Beta-VAE (n\_latent = k)", oracle_dsprites_results["beta_vae"]["base"], "purity_scores", lambda x: x),
    (r"VAE (n\_latent = k)", oracle_dsprites_results["vae"]["base"], "purity_scores", lambda x: x),
    (r"CCD (n\_concepts = k)", oracle_dsprites_results["ccd"]["base"], "purity_scores", lambda x: x),
    (r"SENN (n\_concepts = k)", oracle_dsprites_results["senn"]["base"], "purity_scores", lambda x: x),
]
# Two extra slots leave a visual gap between the bar groups.
num_models = len(all_models) + 2
ax = axs[1]

for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
    if method_name not in color_map:
        color_map[method_name] = clrs[len(color_map)]
    color = color_map[method_name]
    accs_means = list(map(
        lambda x: x[0],
        np.array(results[kword])[all_vars],
    ))
    accs_means = np.array(list(map(transform_fn, accs_means)))
    accs_stds = list(map(
        lambda x: x[1],
        np.array(results[kword])[all_vars],
    ))
    accs_stds = np.array(list(map(transform_fn, accs_stds)))
    ax.bar(
        scale * (np.arange(0, len(all_vars)) - (1/2 - 1/(2 * num_models)) + i/num_models),
        accs_means,
        width=scale/num_models,
        color=color,
        align='center',
        label=method_name,
        yerr=2*accs_stds,
        capsize=5,
        edgecolor="black",
    )
# (Removed a stray no-op `ax.set_title` attribute access from inside the loop.)
ax.set_xlabel(r"$\lambda$", fontsize=30)
ax.set_title(bold_text(r"dSprites($\lambda$)"), fontsize=30)
ax.set_xticks(np.arange(0, len(all_values)))
ax.set_xticklabels(all_values, fontsize=25)
ax.set_yticks(np.arange(0, 0.9, 0.2))
ax.set_yticklabels(list(map(lambda x: f"{x:.1f}", np.arange(0, 0.9, 0.2))), fontsize=18)
ax.grid(False)
handles, labels = ax.get_legend_handles_labels()

# 3dshapes(lambda) purity panel, followed by the figure-wide legend.
all_vars = [0, -1]  # first and last result entries
all_values = [0, 5]  # x-tick labels for those entries
num_concepts = 6
all_models = [
    ("Joint-CBM", oracle_shapes3d_results["cbm"]["base"], "purity_scores", lambda x: x),
    ("CW MaxPool-Mean", oracle_shapes3d_results["cw"]["max_pool_mean"], "purity_scores", lambda x: x),
    (r"Ada-MLVAE (n\_latent = k)", oracle_shapes3d_results["ada_ml_vae"]["base"], "purity_scores", lambda x: x),
    (r"Ada-GVAE (n\_latent = k)", oracle_shapes3d_results["ada_g_vae"]["extended"], "purity_scores", lambda x: x),
    (r"Beta-VAE (n\_latent = k)", oracle_shapes3d_results["beta_vae"]["base"], "purity_scores", lambda x: x),
    (r"VAE (n\_latent = k)", oracle_shapes3d_results["vae"]["base"], "purity_scores", lambda x: x),
    (r"CCD (n\_concepts = k)", oracle_shapes3d_results["ccd"]["base"], "purity_scores", lambda x: x),
    (r"SENN (n\_concepts = k)", oracle_shapes3d_results["senn"]["base"], "purity_scores", lambda x: x),
]

num_models = len(all_models) + 2  # extra slots keep a visible gap between groups
ax = axs[2]

for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
    # Reuse the color this method got in an earlier figure, else take the next one.
    if method_name not in color_map:
        color_map[method_name] = clrs[len(color_map)]
    color = color_map[method_name]
    # Each selected row of results[kword] is read as a (mean, std) pair.
    selected = np.array(results[kword])[all_vars]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    # Centre the group of len(all_models) bars on each x-tick.
    offset = (i - (len(all_models) - 1) / 2) / num_models
    ax.bar(
        scale * (np.arange(len(all_vars)) + offset),
        accs_means,
        width=scale / num_models,
        color=color,
        align="center",
        label=method_name,
        yerr=2 * accs_stds,  # +/- 2 std error bars
        capsize=5,
        edgecolor="black",
    )
ax.set_xlabel(r"$\lambda$", fontsize=30)
ax.set_title(bold_text(r"3dshapes($\lambda$)"), fontsize=30)
ax.set_xticks(np.arange(0, len(all_values)))
ax.set_xticklabels(all_values, fontsize=25)
ax.set_yticks(np.arange(0, 0.9, 0.2))
ax.set_yticklabels([f"{x:.1f}" for x in np.arange(0, 0.9, 0.2)], fontsize=18)
ax.grid(False)

# One legend for the whole figure, built from this (last) panel's bars.
handles, labels = ax.get_legend_handles_labels()
lgd = fig.legend(handles, labels, fontsize=20, loc='upper center', bbox_to_anchor=(0.5,-0.1), ncol=(num_models - 1)//2)
plt.show()

# Accuracy and AUC Plots

In [None]:
# Concept-prediction AUCs (%) for all methods, one panel per dataset:
# TabularToy(delta), dSprites(lambda), 3dshapes(lambda).
clrs = sns.color_palette("tab20", 30)
fig_width = 24
fig_height = 5
# `color_map` is deliberately not reset here: it comes from an earlier cell,
# presumably so each method keeps the same color across figures.
scale = 1
trials = 5
fig, axs = plt.subplots(1, 3, figsize=(fig_width, fig_height))

all_vars = [0, -1]  # first and last result entries
real_values = [0.0, 0.9]  # delta values shown as x-tick labels
num_concepts = 3
all_models = [
    ("Joint-CBM", oracle_toy_tabular_results["cbm"]["base"], "latent_predictive_aucs", lambda x: 100 * x),
    ("Joint-CBM-Logits", oracle_toy_tabular_results["cbm"]["logits"], "latent_predictive_aucs", lambda x: 100 * x),
    ("CW MaxPool-Mean", oracle_toy_tabular_results["cw"]["max_pool_mean"], "latent_predictive_aucs", lambda x: 100 * x),
    # NOTE(review): same results and metric key as "CW MaxPool-Mean", so this bar
    # duplicates it; the image panels use "latent_feature_predictive_aucs" — confirm.
    ("CW Feature Map", oracle_toy_tabular_results["cw"]["max_pool_mean"], "latent_predictive_aucs", lambda x: 100 * x),
    (r"CCD (n\_concepts = k)", oracle_toy_tabular_results["ccd"]["base"], "direct_completeness_scores", lambda x: 100 * x),
    # NOTE(review): the image panels plot SENN's "latent_predictive_aucs" here — confirm.
    (r"SENN (n\_concepts = k)", oracle_toy_tabular_results["senn"]["base"], "task_aucs", lambda x: 100 * x),
]
num_models = len(all_models) + 1  # one extra slot keeps a gap between groups
ax = axs[0]

for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
    # Reuse the color this method got in an earlier figure, else take the next one.
    if method_name not in color_map:
        color_map[method_name] = clrs[len(color_map)]
    color = color_map[method_name]
    # Each selected row of results[kword] is read as a (mean, std) pair.
    selected = np.array(results[kword])[all_vars]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    # Centre the group of len(all_models) bars on each x-tick.
    offset = (i - (len(all_models) - 1) / 2) / num_models
    ax.bar(
        scale * (np.arange(len(all_vars)) + offset),
        accs_means,
        width=scale / num_models,
        color=color,
        align="center",
        label=method_name,
        yerr=2 * accs_stds,  # +/- 2 std error bars
        capsize=5,
        edgecolor="black",
    )
ax.set_ylabel(r"AUC(\%)", fontsize=30)
ax.set_title(bold_text(r"TabularToy($\delta$)"), fontsize=30)
ax.set_xticks(np.arange(0, len(all_vars)))
ax.set_xticklabels([f"{x}" for x in real_values], fontsize=25)
ax.set_yticks(np.arange(70, 110, 10))
ax.set_ylim((70, 103))
ax.set_yticklabels([f"{x}" for x in np.arange(70, 110, 10)], fontsize=18)
ax.grid(False)

# dSprites(lambda) concept-prediction AUC (%) panel.
all_vars = [0, -1]  # first and last result entries
all_values = [0, 4]  # x-tick labels for those entries
num_concepts = 5
all_models = [
    ("Joint-CBM", oracle_dsprites_results["cbm"]["base"], "latent_predictive_aucs", lambda x: 100 * x),
    ("Joint-CBM-Logits", oracle_dsprites_results["cbm"]["logits"], "latent_predictive_aucs", lambda x: 100 * x),
    ("CW MaxPool-Mean", oracle_dsprites_results["cw"]["max_pool_mean"], "latent_predictive_aucs", lambda x: 100 * x),
    ("CW Feature Map", oracle_dsprites_results["cw"]["max_pool_mean"], "latent_feature_predictive_aucs", lambda x: 100 * x),
    (r"Ada-MLVAE (n\_latent = k)", oracle_dsprites_results["ada_ml_vae"]["base"], "latent_predictive_aucs", lambda x: 100 * x),
    (r"Ada-GVAE (n\_latent = k)", oracle_dsprites_results["ada_g_vae"]["base"], "latent_predictive_aucs", lambda x: 100 * x),
    (r"Beta-VAE (n\_latent = k)", oracle_dsprites_results["beta_vae"]["base"], "latent_predictive_aucs", lambda x: 100 * x),
    (r"VAE (n\_latent = k)", oracle_dsprites_results["vae"]["base"], "latent_predictive_aucs", lambda x: 100 * x),
    (r"CCD (n\_concepts = k)", oracle_dsprites_results["ccd"]["base"], "direct_completeness_scores", lambda x: 100 * x),
    (r"SENN (n\_concepts = k)", oracle_dsprites_results["senn"]["base"], "latent_predictive_aucs", lambda x: 100 * x),
]
num_models = len(all_models) + 2  # extra slots keep a visible gap between groups
ax = axs[1]

for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
    # Reuse the color this method got in an earlier figure, else take the next one.
    if method_name not in color_map:
        color_map[method_name] = clrs[len(color_map)]
    color = color_map[method_name]
    # Each selected row of results[kword] is read as a (mean, std) pair.
    selected = np.array(results[kword])[all_vars]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    # Centre the group of len(all_models) bars on each x-tick.
    offset = (i - (len(all_models) - 1) / 2) / num_models
    ax.bar(
        scale * (np.arange(len(all_vars)) + offset),
        accs_means,
        width=scale / num_models,
        color=color,
        align="center",
        label=method_name,
        yerr=2 * accs_stds,  # +/- 2 std error bars
        capsize=5,
        edgecolor="black",
    )
ax.set_title(bold_text(r"dSprites($\lambda$)"), fontsize=30)
ax.set_xticks(np.arange(0, len(all_values)))
ax.set_xticklabels(all_values, fontsize=25)
ax.set_yticks(np.arange(70, 110, 10))
ax.set_ylim((70, 103))
ax.set_yticklabels([f"{x}" for x in np.arange(70, 110, 10)], fontsize=18)
ax.grid(False)

# 3dshapes(lambda) concept-prediction AUC (%) panel, then the figure-wide legend.
all_vars = [0, -1]  # first and last result entries
all_values = [0, 5]  # x-tick labels for those entries
num_concepts = 6
all_models = [
    ("Joint-CBM", oracle_shapes3d_results["cbm"]["base"], "latent_predictive_aucs", lambda x: 100 * x),
    ("Joint-CBM-Logits", oracle_shapes3d_results["cbm"]["logits"], "latent_predictive_aucs", lambda x: 100 * x),
    ("CW MaxPool-Mean", oracle_shapes3d_results["cw"]["max_pool_mean"], "latent_predictive_aucs", lambda x: 100 * x),
    ("CW Feature Map", oracle_shapes3d_results["cw"]["max_pool_mean"], "latent_feature_predictive_aucs", lambda x: 100 * x),
    (r"Ada-MLVAE (n\_latent = k)", oracle_shapes3d_results["ada_ml_vae"]["base"], "latent_predictive_aucs", lambda x: 100 * x),
    (r"Ada-GVAE (n\_latent = k)", oracle_shapes3d_results["ada_g_vae"]["extended"], "latent_predictive_aucs", lambda x: 100 * x),
    (r"Beta-VAE (n\_latent = k)", oracle_shapes3d_results["beta_vae"]["base"], "latent_predictive_aucs", lambda x: 100 * x),
    (r"VAE (n\_latent = k)", oracle_shapes3d_results["vae"]["base"], "latent_predictive_aucs", lambda x: 100 * x),
    (r"CCD (n\_concepts = k)", oracle_shapes3d_results["ccd"]["base"], "direct_completeness_scores", lambda x: 100 * x),
    (r"SENN (n\_concepts = k)", oracle_shapes3d_results["senn"]["base"], "latent_predictive_aucs", lambda x: 100 * x),
]

num_models = len(all_models) + 2  # extra slots keep a visible gap between groups
ax = axs[2]

for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
    # Reuse the color this method got in an earlier figure, else take the next one.
    if method_name not in color_map:
        color_map[method_name] = clrs[len(color_map)]
    color = color_map[method_name]
    # Each selected row of results[kword] is read as a (mean, std) pair.
    selected = np.array(results[kword])[all_vars]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    # Centre the group of len(all_models) bars on each x-tick.
    offset = (i - (len(all_models) - 1) / 2) / num_models
    ax.bar(
        scale * (np.arange(len(all_vars)) + offset),
        accs_means,
        width=scale / num_models,
        color=color,
        align="center",
        label=method_name,
        yerr=2 * accs_stds,  # +/- 2 std error bars
        capsize=5,
        edgecolor="black",
    )
ax.set_title(bold_text(r"3dshapes($\lambda$)"), fontsize=30)
ax.set_xticks(np.arange(0, len(all_values)))
ax.set_xticklabels(all_values, fontsize=25)
ax.set_yticks(np.arange(70, 110, 10))
ax.set_ylim((70, 103))
ax.set_yticklabels([f"{x}" for x in np.arange(70, 110, 10)], fontsize=18)
ax.grid(False)

# One legend for the whole figure, built from this (last) panel's bars, which
# include every method shown in the other panels.
handles, labels = ax.get_legend_handles_labels()
lgd = fig.legend(handles, labels, fontsize=20, loc='upper center', bbox_to_anchor=(0.5,0.03), ncol=(num_models - 1)//2)
plt.show()

In [None]:
# Downstream task AUCs (%) for all methods, one panel per dataset:
# TabularToy(delta), dSprites(lambda), 3dshapes(lambda).
clrs = sns.color_palette("tab20", 30)
fig_width = 16
fig_height = 5
# `color_map` is deliberately not reset here: it comes from an earlier cell,
# presumably so each method keeps the same color across figures.
scale = 1
trials = 5
fig, axs = plt.subplots(1, 3, figsize=(fig_width, fig_height))

all_vars = [0, -1]  # first and last result entries
real_values = [0.0, 0.9]  # delta values shown as x-tick labels
num_concepts = 3
all_models = [
    ("Joint-CBM", oracle_toy_tabular_results["cbm"]["base"], "task_aucs", lambda x: 100 * x),
    ("Joint-CBM-Logits", oracle_toy_tabular_results["cbm"]["logits"], "task_aucs", lambda x: 100 * x),
    ("CW MaxPool-Mean", oracle_toy_tabular_results["cw"]["max_pool_mean"], "task_aucs", lambda x: 100 * x),
    # NOTE(review): same results and metric key as "CW MaxPool-Mean", so this bar
    # duplicates it — confirm this is intended.
    ("CW Feature Map", oracle_toy_tabular_results["cw"]["max_pool_mean"], "task_aucs", lambda x: 100 * x),
    # NOTE(review): CCD plots a completeness score on this task-AUC axis — confirm.
    (r"CCD (n\_concepts = k)", oracle_toy_tabular_results["ccd"]["base"], "direct_completeness_scores", lambda x: 100 * x),
    (r"SENN (n\_concepts = k)", oracle_toy_tabular_results["senn"]["base"], "task_aucs", lambda x: 100 * x),
]
num_models = len(all_models) + 1  # one extra slot keeps a gap between groups
ax = axs[0]

for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
    # Reuse the color this method got in an earlier figure, else take the next one.
    if method_name not in color_map:
        color_map[method_name] = clrs[len(color_map)]
    color = color_map[method_name]
    # Each selected row of results[kword] is read as a (mean, std) pair.
    selected = np.array(results[kword])[all_vars]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    # Centre the group of len(all_models) bars on each x-tick.
    offset = (i - (len(all_models) - 1) / 2) / num_models
    ax.bar(
        scale * (np.arange(len(all_vars)) + offset),
        accs_means,
        width=scale / num_models,
        color=color,
        align="center",
        label=method_name,
        yerr=2 * accs_stds,  # +/- 2 std error bars
        capsize=5,
        edgecolor="black",
    )
ax.set_ylabel(r"AUC(\%)", fontsize=30)
ax.set_title(bold_text(r"TabularToy($\delta$)"), fontsize=30)
ax.set_xticks(np.arange(0, len(all_vars)))
ax.set_xticklabels([f"{x}" for x in real_values], fontsize=25)
ax.set_yticks(np.arange(70, 110, 10))
ax.set_ylim((70, 108))
ax.set_yticklabels([f"{x}" for x in np.arange(70, 110, 10)], fontsize=18)
ax.grid(False)

# dSprites(lambda) task AUC (%) panel.
all_vars = [0, -1]  # first and last result entries
all_values = [0, 4]  # x-tick labels for those entries
num_concepts = 5
all_models = [
    ("Joint-CBM", oracle_dsprites_results["cbm"]["base"], "task_aucs", lambda x: 100 * x),
    ("Joint-CBM-Logits", oracle_dsprites_results["cbm"]["logits"], "task_aucs", lambda x: 100 * x),
    ("CW MaxPool-Mean", oracle_dsprites_results["cw"]["max_pool_mean"], "task_aucs", lambda x: 100 * x),
    # NOTE(review): this figure shows task AUCs, but "CW Feature Map" plots
    # "latent_feature_predictive_aucs" (and CCD a completeness score) — confirm.
    ("CW Feature Map", oracle_dsprites_results["cw"]["max_pool_mean"], "latent_feature_predictive_aucs", lambda x: 100 * x),
    (r"CCD (n\_concepts = k)", oracle_dsprites_results["ccd"]["base"], "direct_completeness_scores", lambda x: 100 * x),
    (r"SENN (n\_concepts = k)", oracle_dsprites_results["senn"]["base"], "task_aucs", lambda x: 100 * x),
]
num_models = len(all_models) + 2  # extra slots keep a visible gap between groups
ax = axs[1]

for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
    # Reuse the color this method got in an earlier figure, else take the next one.
    if method_name not in color_map:
        color_map[method_name] = clrs[len(color_map)]
    color = color_map[method_name]
    # Each selected row of results[kword] is read as a (mean, std) pair.
    selected = np.array(results[kword])[all_vars]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    # Centre the group of len(all_models) bars on each x-tick.
    offset = (i - (len(all_models) - 1) / 2) / num_models
    ax.bar(
        scale * (np.arange(len(all_vars)) + offset),
        accs_means,
        width=scale / num_models,
        color=color,
        align="center",
        label=method_name,
        yerr=2 * accs_stds,  # +/- 2 std error bars
        capsize=5,
        edgecolor="black",
    )
ax.set_title(bold_text(r"dSprites($\lambda$)"), fontsize=30)
ax.set_xticks(np.arange(0, len(all_values)))
ax.set_xticklabels(all_values, fontsize=25)
ax.set_yticks(np.arange(70, 110, 10))
ax.set_ylim((70, 108))
ax.set_yticklabels([f"{x}" for x in np.arange(70, 110, 10)], fontsize=18)
ax.grid(False)

# 3dshapes(lambda) task AUC (%) panel, then the figure-wide legend.
all_vars = [0, -1]  # first and last result entries
all_values = [0, 5]  # x-tick labels for those entries
num_concepts = 6
all_models = [
    ("Joint-CBM", oracle_shapes3d_results["cbm"]["base"], "task_aucs", lambda x: 100 * x),
    ("Joint-CBM-Logits", oracle_shapes3d_results["cbm"]["logits"], "task_aucs", lambda x: 100 * x),
    ("CW MaxPool-Mean", oracle_shapes3d_results["cw"]["max_pool_mean"], "task_aucs", lambda x: 100 * x),
    # NOTE(review): this figure shows task AUCs, but "CW Feature Map" plots
    # "latent_feature_predictive_aucs" (and CCD a completeness score) — confirm.
    ("CW Feature Map", oracle_shapes3d_results["cw"]["max_pool_mean"], "latent_feature_predictive_aucs", lambda x: 100 * x),
    (r"CCD (n\_concepts = k)", oracle_shapes3d_results["ccd"]["base"], "direct_completeness_scores", lambda x: 100 * x),
    (r"SENN (n\_concepts = k)", oracle_shapes3d_results["senn"]["base"], "task_aucs", lambda x: 100 * x),
]

num_models = len(all_models) + 2  # extra slots keep a visible gap between groups
ax = axs[2]

for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
    # Reuse the color this method got in an earlier figure, else take the next one.
    if method_name not in color_map:
        color_map[method_name] = clrs[len(color_map)]
    color = color_map[method_name]
    # Each selected row of results[kword] is read as a (mean, std) pair.
    selected = np.array(results[kword])[all_vars]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    # Centre the group of len(all_models) bars on each x-tick.
    offset = (i - (len(all_models) - 1) / 2) / num_models
    ax.bar(
        scale * (np.arange(len(all_vars)) + offset),
        accs_means,
        width=scale / num_models,
        color=color,
        align="center",
        label=method_name,
        yerr=2 * accs_stds,  # +/- 2 std error bars
        capsize=5,
        edgecolor="black",
    )
ax.set_title(bold_text(r"3dshapes($\lambda$)"), fontsize=30)
ax.set_xticks(np.arange(0, len(all_values)))
ax.set_xticklabels(all_values, fontsize=25)
ax.set_yticks(np.arange(70, 110, 10))
ax.set_ylim((70, 108))
ax.set_yticklabels([f"{x}" for x in np.arange(70, 110, 10)], fontsize=18)
ax.grid(False)

# One legend for the whole figure, built from this (last) panel's bars.
handles, labels = ax.get_legend_handles_labels()
lgd = fig.legend(handles, labels, fontsize=23, loc='upper center', bbox_to_anchor=(0.5,0.03), ncol=(num_models - 1)//2)
plt.show()

In [None]:
# Average concept predictive accuracy (%) of the VAE-family latent spaces,
# one panel per image dataset: dSprites(lambda), 3dshapes(lambda).
clrs = sns.color_palette("tab20", 30)
fig_width = 14
fig_height = 5
# `color_map` is deliberately not reset here: it comes from an earlier cell,
# presumably so each method keeps the same color across figures.
scale = 1
trials = 5
fig, axs = plt.subplots(1, 2, figsize=(fig_width, fig_height))

all_vars = [0, -1]  # first and last result entries
all_values = [0, 4]  # x-tick labels for those entries
num_concepts = 5
all_models = [
    (r"Ada-MLVAE (n\_latent = k)", oracle_dsprites_results["ada_ml_vae"]["base"], "latent_avg_concept_predictive_accuracies", lambda x: 100 * x),
    (r"Ada-GVAE (n\_latent = k)", oracle_dsprites_results["ada_g_vae"]["base"], "latent_avg_concept_predictive_accuracies", lambda x: 100 * x),
    (r"Beta-VAE (n\_latent = k)", oracle_dsprites_results["beta_vae"]["base"], "latent_avg_concept_predictive_accuracies", lambda x: 100 * x),
    (r"VAE (n\_latent = k)", oracle_dsprites_results["vae"]["base"], "latent_avg_concept_predictive_accuracies", lambda x: 100 * x),
]
num_models = len(all_models) + 2  # extra slots keep a visible gap between groups
ax = axs[0]

for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
    # Reuse the color this method got in an earlier figure, else take the next one.
    if method_name not in color_map:
        color_map[method_name] = clrs[len(color_map)]
    color = color_map[method_name]
    # Each selected row of results[kword] is read as a (mean, std) pair.
    selected = np.array(results[kword])[all_vars]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    # Centre the group of len(all_models) bars on each x-tick.
    offset = (i - (len(all_models) - 1) / 2) / num_models
    ax.bar(
        scale * (np.arange(len(all_vars)) + offset),
        accs_means,
        width=scale / num_models,
        color=color,
        align="center",
        label=method_name,
        yerr=2 * accs_stds,  # +/- 2 std error bars
        capsize=5,
        edgecolor="black",
    )

# NOTE(review): label says AUC but the metric key is "..._accuracies" — confirm
# which one the stored values actually are.
ax.set_ylabel(r"Avg Concept AUC (\%)", fontsize=30)
ax.set_title(bold_text(r"dSprites($\lambda$)"), fontsize=30)
ax.set_xticks(np.arange(0, len(all_values)))
ax.set_xticklabels(all_values, fontsize=25)
ax.set_yticks(np.arange(70, 110, 10))
ax.set_ylim((70, 103))
ax.set_yticklabels([f"{x}" for x in np.arange(70, 110, 10)], fontsize=18)
ax.grid(False)

# 3dshapes(lambda) average concept predictive accuracy (%) panel, then the legend.
all_vars = [0, -1]  # first and last result entries
all_values = [0, 5]  # x-tick labels for those entries
num_concepts = 6
all_models = [
    (r"Ada-MLVAE (n\_latent = k)", oracle_shapes3d_results["ada_ml_vae"]["base"], "latent_avg_concept_predictive_accuracies", lambda x: 100 * x),
    (r"Ada-GVAE (n\_latent = k)", oracle_shapes3d_results["ada_g_vae"]["extended"], "latent_avg_concept_predictive_accuracies", lambda x: 100 * x),
    (r"Beta-VAE (n\_latent = k)", oracle_shapes3d_results["beta_vae"]["base"], "latent_avg_concept_predictive_accuracies", lambda x: 100 * x),
    (r"VAE (n\_latent = k)", oracle_shapes3d_results["vae"]["base"], "latent_avg_concept_predictive_accuracies", lambda x: 100 * x),
]

num_models = len(all_models) + 2  # extra slots keep a visible gap between groups
ax = axs[1]

for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
    # Reuse the color this method got in an earlier figure, else take the next one.
    if method_name not in color_map:
        color_map[method_name] = clrs[len(color_map)]
    color = color_map[method_name]
    # Each selected row of results[kword] is read as a (mean, std) pair.
    selected = np.array(results[kword])[all_vars]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    # Centre the group of len(all_models) bars on each x-tick.
    offset = (i - (len(all_models) - 1) / 2) / num_models
    ax.bar(
        scale * (np.arange(len(all_vars)) + offset),
        accs_means,
        width=scale / num_models,
        color=color,
        align="center",
        label=method_name,
        yerr=2 * accs_stds,  # +/- 2 std error bars
        capsize=5,
        edgecolor="black",
    )
ax.set_title(bold_text(r"3dshapes($\lambda$)"), fontsize=30)
ax.set_xticks(np.arange(0, len(all_values)))
ax.set_xticklabels(all_values, fontsize=25)
ax.set_yticks(np.arange(70, 110, 10))
ax.set_ylim((70, 103))
ax.set_yticklabels([f"{x}" for x in np.arange(70, 110, 10)], fontsize=18)
ax.grid(False)

# One legend for the whole figure, built from this (last) panel's bars.
handles, labels = ax.get_legend_handles_labels()
lgd = fig.legend(handles, labels, fontsize=20, loc='upper center', bbox_to_anchor=(0.5,0.03), ncol=(num_models - 1)//2)
plt.show()