In [None]:
# Load the IPython extensions used by this notebook and render figures inline.
# NOTE(review): the autoreload extension is loaded but no `%autoreload 2` line
# follows in this cell -- confirm that module auto-reloading is really active.
%load_ext autoreload
%load_ext tensorboard
%matplotlib inline

# Purity dSprites Benchmarking

## Setup

In [None]:
import matplotlib
import concepts_xai
import numpy as np
import os
import random
import tensorflow as tf
import yaml
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import rc
from matplotlib import cm
import seaborn as sns
from importlib import reload
from pathlib import Path
import sklearn
import scipy
import utils
import model_utils

In [None]:
################################################################################
## Set seeds up for reproducibility
################################################################################

# Several module-level lookup tables later in this notebook are drawn from
# `np.random` at definition time, so this cell has to run before them.
# NOTE(review): presumably `utils.reseed` seeds the Python/NumPy/TensorFlow
# generators -- confirm in `utils`.
utils.reseed(87)


In [None]:
################################################################################
## Global Variables Defining Experiment Flow
################################################################################

# When set to "$", figure text is rendered through LaTeX (see `bold_text` and
# the `usetex` flag below); any other value disables LaTeX rendering.
_LATEX_SYMBOL = "$"
RESULTS_DIR = "results/dsprites"
# Generated datasets are cached inside the results directory.
DATASETS_DIR = os.path.join(RESULTS_DIR, "datasets/")
Path(DATASETS_DIR).mkdir(parents=True, exist_ok=True)
rc('text', usetex=(_LATEX_SYMBOL == "$"))
# Matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-*" and
# later releases dropped the old aliases. Prefer whichever name this
# installation provides; fall back to the historical name so that a missing
# style still fails loudly, exactly as before.
_WHITEGRID_STYLE = next(
    (
        style_name
        for style_name in ("seaborn-v0_8-whitegrid", "seaborn-whitegrid")
        if style_name in plt.style.available
    ),
    "seaborn-whitegrid",
)
plt.style.use(_WHITEGRID_STYLE)

## Utility Functions

In [None]:
def bold_text(x):
    """Return `x` wrapped in a LaTeX bold command when LaTeX rendering is on.

    LaTeX rendering is considered enabled when the module-level
    `_LATEX_SYMBOL` equals "$"; otherwise `x` is returned untouched.
    """
    if _LATEX_SYMBOL != "$":
        # Plain-text mode: nothing to decorate
        return x
    return r"$\textbf{" + x + "}$"

# Graph Dependency Dataset Construction

In [None]:
import concepts_xai.datasets.dSprites as dsprites
import concepts_xai.datasets.latentFactorData as latentFactorData

def generate_dsprites_dataset(
    label_fn,
    filter_fn=None,
    dataset_path=None,
    concept_map_fn=lambda x: x,
    sample_map_fn=lambda x: x,
    dsprites_path="dsprites/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz",
    force_reload=False,
):
    """Build (or load from cache) a dSprites-derived train/test dataset.

    Args:
        label_fn: forwarded to `latentFactorData.get_task_data` to produce the
            task label of each sample.
        filter_fn: optional predicate forwarded to
            `latentFactorData.get_task_data` to select which samples are kept.
        dataset_path: optional path of an `.npz` cache. If it exists (and
            `force_reload` is False) the dataset is read from it; otherwise
            the freshly generated dataset is written to it. Note that
            `np.savez` appends ".npz" when the path lacks that extension, in
            which case the existence check above will never hit the cache.
        concept_map_fn: applied to the train and test concept arrays before
            they are returned/serialized.
        sample_map_fn: applied to the train and test sample arrays before
            they are returned/serialized.
        dsprites_path: location of the raw dSprites archive.
        force_reload: when True, ignore any existing cache at `dataset_path`
            and regenerate (and overwrite) it.

    Returns:
        A pair `((x_train, y_train, c_train), (x_test, y_test, c_test))`.
    """
    if (not force_reload) and dataset_path and os.path.exists(dataset_path):
        # Then time to load up this dataset! Every array is read while the
        # archive is open so that its file handle is released right after
        # (np.load on an .npz otherwise keeps the file open).
        with np.load(dataset_path) as ds:
            return (
                (ds["x_train"], ds["y_train"], ds["c_train"]),
                (ds["x_test"], ds["y_test"], ds["c_test"]),
            )

    def _task_fn(x_data, c_data):
        # Adapter handing our label/filter functions to the dataset loader
        return latentFactorData.get_task_data(
            x_data=x_data,
            c_data=c_data,
            label_fn=label_fn,
            filter_fn=filter_fn,
        )

    loaded_dataset = dsprites.dSprites(
        dataset_path=dsprites_path,
        train_size=0.8,
        random_state=42,
        task=_task_fn,
    )
    # The returned values are not needed: the splits are read back from the
    # attributes that load_data() populates on the dataset object.
    _, _, _ = loaded_dataset.load_data()

    x_train = sample_map_fn(loaded_dataset.x_train)
    y_train = loaded_dataset.y_train
    c_train = concept_map_fn(loaded_dataset.c_train)

    x_test = sample_map_fn(loaded_dataset.x_test)
    y_test = loaded_dataset.y_test
    c_test = concept_map_fn(loaded_dataset.c_test)

    if dataset_path:
        # Then serialize it to speed up things next time
        np.savez(
            dataset_path,
            x_train=x_train,
            y_train=y_train,
            c_train=c_train,
            x_test=x_test,
            y_test=y_test,
            c_test=c_test,
        )
    return (x_train, y_train, c_train), (x_test, y_test, c_test)


## Multi-class Task Dataset Construction (Dependent Binary Concepts)

In [None]:
def count_class_balance(y):
    """Return the fraction of samples in `y` that belong to each class.

    `y` is one-hot encoded with `tf.keras.utils.to_categorical`; the result
    has one entry per one-hot column.
    """
    indicator = tf.keras.utils.to_categorical(y)
    num_samples = indicator.shape[0]
    per_class_counts = np.sum(indicator, axis=0)
    return per_class_counts / num_samples

def multiclass_binary_concepts_map_fn(concepts):
    """Binarize the first five columns of a concept matrix.

    Args:
        concepts: 2D array with at least five columns.
            NOTE(review): presumably the dSprites factors
            (shape, scale, rotation, x, y) as class indices -- confirm
            against the dataset loader.

    Returns:
        A float array of shape `(concepts.shape[0], 5)` holding 0/1 values.
        Column 0 is `concepts[:, 0] < 2`; every other column `j` is 1 when the
        value is below half the number of *distinct* values found in column
        `j` of `concepts` (true division for columns 1-2, floor division for
        columns 3-4). The thresholds therefore depend on the data passed in.
    """
    new_concepts = np.zeros((concepts.shape[0], 5))
    # `np.int` was removed in NumPy 1.24; it was only an alias of the builtin
    # `int`, which is used below instead.

    # (0) "is it ellipse or square?"
    new_concepts[:, 0] = (concepts[:, 0] < 2).astype(int)

    # (1) "is the size in the lower half of the observed sizes?"
    num_sizes = len(set(concepts[:, 1]))
    new_concepts[:, 1] = (concepts[:, 1] < num_sizes/2).astype(int)

    # (2) "is the rotation in the lower half of the observed rotations?"
    num_rots = len(set(concepts[:, 2]))
    new_concepts[:, 2] = (concepts[:, 2] < num_rots/2).astype(int)

    # (3) "is x in the lower half of the observed x coordinates?"
    num_x_coords = len(set(concepts[:, 3]))
    new_concepts[:, 3] = (concepts[:, 3] < num_x_coords // 2).astype(int)

    # (4) "is y in the lower half of the observed y coordinates?"
    num_y_coords = len(set(concepts[:, 4]))
    new_concepts[:, 4] = (concepts[:, 4] < num_y_coords // 2).astype(int)

    return new_concepts

def _get_concept_vector(c_data):
    return np.array([
        # First check if it is an ellipse or a square
        int(c_data[0] < 2),
        # Now check that it is "small"
        int(c_data[1] < 3),
        # And it has not been rotated more than PI/2 radians
        int(c_data[2] < 20),
        # Finally, check whether it is in not in the the upper-left quadrant
        int(c_data[3] < 15),
        int(c_data[4] < 15),
    ])

def multiclass_task_label_fn(c_data):
    """Map raw concept values to an integer class label in [0, 7].

    The label is the 3-bit number whose bits, from most to least significant,
    are (c0 OR c1), (c2 OR c3) and c4, where c0..c4 are the binary concepts
    produced by `_get_concept_vector(c_data)`.
    """
    flags = _get_concept_vector(c_data)
    high_bit = flags[0] or flags[1]
    middle_bit = flags[2] or flags[3]
    low_bit = flags[4]
    # Same value as parsing the three bits as a base-2 string
    return int(4 * high_bit + 2 * middle_bit + low_bit)

def dep_0_filter_fn(concept):
    """Keep a sample only if its first five factors lie on a subsampled grid.

    Allowed values per factor index: 0 -> {0, 1, 2}; 1 -> every 2nd value in
    [0, 6); 2 -> every 4th value in [0, 40); 3 and 4 -> every 2nd value in
    [0, 32).
    """
    allowed_values = (
        range(3),
        range(0, 6, 2),
        range(0, 40, 4),
        range(0, 32, 2),
        range(0, 32, 2),
    )
    membership = [
        concept[factor] in allowed
        for factor, allowed in enumerate(allowed_values)
    ]
    return all(membership)



# Random lookup tables used by the dep_*_filter_fn functions, where they are
# indexed by concept[0]. Each table holds 3 lists; every list keeps the first
# 3 entries of a random permutation, i.e. 3 distinct values out of 0..3
# ("lower") or out of 2..5 ("upper").
# These are drawn from the global `np.random` state when the cell executes, so
# their contents depend on the seed and on the order in which cells are run.
scale_shape_sets_lower = [
    list(np.random.permutation(4))[:3] for i in range(3)
]

scale_shape_sets_upper = [
    list(2 + np.random.permutation(4))[:3] for i in range(3)
]
def dep_1_filter_fn(concept):
    """Grid filter plus a dependency between factor 0 and factor 1.

    A sample is kept when (a) its first five factors lie on the subsampled
    grid below and (b) it passes a membership test against the random
    `scale_shape_sets_*` tables selected by binary concept 0.
    """
    allowed_values = (
        range(3),
        range(6),
        range(0, 40, 4),
        range(0, 32, 2),
        range(0, 32, 2),
    )
    flags = _get_concept_vector(concept)

    # First constrain the size of the data a bit, as in the small dataset
    on_grid = all([
        concept[factor] in allowed
        for factor, allowed in enumerate(allowed_values)
    ])
    if not on_grid:
        return False

    shape = concept[0]
    if flags[0]:
        return concept[1] in scale_shape_sets_lower[shape]
    # NOTE(review): this branch tests concept[0] where the branch above tests
    # concept[1]. It looks like a copy-paste slip, but changing it would alter
    # the generated datasets, so the original behavior is preserved here.
    return shape in scale_shape_sets_upper[shape]



# Random lookup tables used by dep_2/3/4_filter_fn, where they are indexed by
# concept[1]. Each table holds 6 lists; every list keeps the first 20 entries
# of a random permutation, i.e. 20 distinct values out of 0..29 ("lower") or
# out of 10..39 ("upper").
# Drawn from the global `np.random` state at cell execution time.
rotation_scale_sets_lower = [
    list(np.random.permutation(30))[:20] for i in range(6)
]

rotation_scale_sets_upper = [
    list(10 + np.random.permutation(30))[:20] for i in range(6)
]
def dep_2_filter_fn(concept):
    """Grid filter plus dependencies on factors 0->1 and 1->2.

    A sample is kept when its first five factors lie on the subsampled grid
    below and it passes the membership tests against the random
    `scale_shape_sets_*` and `rotation_scale_sets_*` tables.
    """
    allowed_values = (
        range(3),
        range(6),
        range(0, 40, 2),
        range(0, 32, 2),
        range(0, 32, 2),
    )
    flags = _get_concept_vector(concept)

    # First constrain the size of the data a bit, as in the small dataset
    on_grid = all([
        concept[factor] in allowed
        for factor, allowed in enumerate(allowed_values)
    ])
    if not on_grid:
        return False

    shape, scale, rotation = concept[0], concept[1], concept[2]

    if flags[0]:
        shape_ok = scale in scale_shape_sets_lower[shape]
    else:
        # NOTE(review): tests concept[0] where the branch above tests
        # concept[1]; looks like a copy-paste slip but is preserved because
        # changing it would alter the generated datasets.
        shape_ok = shape in scale_shape_sets_upper[shape]
    if not shape_ok:
        return False

    rotation_sets = (
        rotation_scale_sets_lower if flags[1] else rotation_scale_sets_upper
    )
    return rotation in rotation_sets[scale]



# Random lookup tables used by dep_3/4_filter_fn, where they are indexed by
# concept[2]. Each table holds 40 lists; every list keeps the first 16 entries
# of a random permutation, i.e. 16 distinct values out of 0..19 ("lower") or
# out of 12..31 ("upper").
# Drawn from the global `np.random` state at cell execution time.
x_pos_rotation_sets_lower = [
    list(np.random.permutation(20))[:16]
    for i in range(40)
]

x_pos_rotation_sets_upper = [
    list(12 + np.random.permutation(20))[:16]
    for i in range(40)
]
def dep_3_filter_fn(concept):
    """Grid filter plus dependencies on factors 0->1, 1->2 and 2->3.

    A sample is kept when its first five factors lie on the subsampled grid
    below and it passes the membership tests against the random
    `scale_shape_sets_*`, `rotation_scale_sets_*` and `x_pos_rotation_sets_*`
    tables.
    """
    allowed_values = (
        range(3),
        range(6),
        range(0, 40, 2),
        range(0, 32),
        range(0, 32, 2),
    )
    flags = _get_concept_vector(concept)

    # First constrain the size of the data a bit, as in the small dataset
    on_grid = all([
        concept[factor] in allowed
        for factor, allowed in enumerate(allowed_values)
    ])
    if not on_grid:
        return False

    shape, scale, rotation, x_pos = (
        concept[0], concept[1], concept[2], concept[3]
    )

    if flags[0]:
        shape_ok = scale in scale_shape_sets_lower[shape]
    else:
        # NOTE(review): tests concept[0] where the branch above tests
        # concept[1]; looks like a copy-paste slip but is preserved because
        # changing it would alter the generated datasets.
        shape_ok = shape in scale_shape_sets_upper[shape]
    if not shape_ok:
        return False

    rotation_sets = (
        rotation_scale_sets_lower if flags[1] else rotation_scale_sets_upper
    )
    if rotation not in rotation_sets[scale]:
        return False

    x_pos_sets = (
        x_pos_rotation_sets_lower if flags[2] else x_pos_rotation_sets_upper
    )
    return x_pos in x_pos_sets[rotation]

# Random lookup tables used by dep_4_filter_fn, where they are indexed by
# concept[3]. Each table holds 32 lists; every list keeps the first 16 entries
# of a random permutation, i.e. 16 distinct values out of 0..19 ("lower") or
# out of 12..31 ("upper").
# Drawn from the global `np.random` state at cell execution time.
y_pos_x_pos_sets_lower = [
    list(np.random.permutation(20))[:16]
    for i in range(32)
]

y_pos_x_pos_sets_upper = [
    list(12 + np.random.permutation(20))[:16]
    for i in range(32)
]
def dep_4_filter_fn(concept):
    """Grid filter plus dependencies on factors 0->1, 1->2, 2->3 and 3->4.

    A sample is kept when its first five factors lie on the grid below and it
    passes the membership tests against the random `scale_shape_sets_*`,
    `rotation_scale_sets_*`, `x_pos_rotation_sets_*` and `y_pos_x_pos_sets_*`
    tables.
    """
    allowed_values = (
        range(3),
        range(6),
        range(0, 40, 2),
        range(0, 32),
        range(0, 32),
    )
    flags = _get_concept_vector(concept)

    # First constrain the size of the data a bit, as in the small dataset
    on_grid = all([
        concept[factor] in allowed
        for factor, allowed in enumerate(allowed_values)
    ])
    if not on_grid:
        return False

    shape, scale, rotation, x_pos, y_pos = (
        concept[0], concept[1], concept[2], concept[3], concept[4]
    )

    if flags[0]:
        shape_ok = scale in scale_shape_sets_lower[shape]
    else:
        # NOTE(review): tests concept[0] where the branch above tests
        # concept[1]; looks like a copy-paste slip but is preserved because
        # changing it would alter the generated datasets.
        shape_ok = shape in scale_shape_sets_upper[shape]
    if not shape_ok:
        return False

    rotation_sets = (
        rotation_scale_sets_lower if flags[1] else rotation_scale_sets_upper
    )
    if rotation not in rotation_sets[scale]:
        return False

    x_pos_sets = (
        x_pos_rotation_sets_lower if flags[2] else x_pos_rotation_sets_upper
    )
    if x_pos not in x_pos_sets[rotation]:
        return False

    y_pos_sets = y_pos_x_pos_sets_lower if flags[3] else y_pos_x_pos_sets_upper
    return y_pos in y_pos_sets[x_pos]

In [None]:
def balanced_multiclass_task_label_fn(c_data):
    """Map raw concept values to an integer class label in [0, 7].

    With c0..c4 the binary concepts from `_get_concept_vector(c_data)`:
      * if c0 == 1 the label is the 2-bit number (c1, c2), i.e. 0..3;
      * otherwise it is 4 plus the 2-bit number (c3, c4), i.e. 4..7.
    """
    flags = _get_concept_vector(c_data)
    if flags[0] == 1:
        offset, high_bit, low_bit = 0, flags[1], flags[2]
    else:
        offset, high_bit, low_bit = 4, flags[3], flags[4]
    # Same value as parsing the two bits as a base-2 string
    return offset + int(2 * high_bit + low_bit)

def _load_and_report_dep_dataset(dep_level, filter_fn, print_trailing_blank=True):
    """Load (or generate) one balanced dataset, plot and print its statistics.

    Args:
        dep_level: integer used in the cache file name, in the figure title
            (as the value of lambda) and in the printed report.
        filter_fn: sample filter forwarded to `generate_dsprites_dataset`.
        print_trailing_blank: whether to finish the report with an empty
            line.

    Returns:
        `(train, test, corr_mat)` where `train`/`test` are the
        `(x, y, c)` tuples returned by `generate_dsprites_dataset` and
        `corr_mat[c, l]` is the Pearson correlation between training concept
        `c` and the indicator "training label == l".
    """
    dataset_name = (
        f"balanced_multiclass_task_bin_concepts_dep_{dep_level}_complete_dataset"
    )
    # Pass force_reload=True below to regenerate instead of using the cache
    train, test = generate_dsprites_dataset(
        label_fn=balanced_multiclass_task_label_fn,
        filter_fn=filter_fn,
        dataset_path=os.path.join(DATASETS_DIR, f"{dataset_name}.npz"),
        dsprites_path="dsprites/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz",
        concept_map_fn=multiclass_binary_concepts_map_fn,
    )
    x_train, y_train, c_train = train[0], train[1], train[2]

    # Concept-vs-label correlation matrix (one row per concept, one column
    # per distinct training label)
    corr_mat = np.ones((c_train.shape[-1], len(set(y_train))))
    for concept_idx in range(corr_mat.shape[0]):
        for label in range(corr_mat.shape[1]):
            corr_mat[concept_idx][label] = np.corrcoef(
                c_train[:, concept_idx],
                (y_train == label).astype(np.int32),
            )[0, 1]

    fig, ax = plt.subplots(1, figsize=(8, 6))
    im, cbar = utils.heatmap(
        np.abs(corr_mat),
        [f"$c_{i}$" for i in range(corr_mat.shape[0])],
        [f"$l_{i}$" for i in range(corr_mat.shape[1])],
        ax=ax,
        cmap="magma",
        cbarlabel="Correlation Coef",
        vmin=0,
        vmax=1,
    )
    utils.annotate_heatmap(im, valfmt="{x:.2f}")
    fig.tight_layout()

    # Raw f-string: "\l" is not a valid escape sequence in a regular string
    fig.suptitle(
        rf"dSprites Concept-Label Absolute Correlations ($\lambda = {dep_level}$)",
        fontsize=25,
    )
    fig.subplots_adjust(top=0.85)
    plt.show()

    print(f"{dataset_name} train size:", x_train.shape[0])
    print(f"{dataset_name} train concept size:", c_train.shape)
    print("\tTrain balance:", count_class_balance(y_train))
    print("\tConcept balance:", np.sum(c_train, axis=0)/c_train.shape[0])
    if print_trailing_blank:
        print("")
    return train, test, corr_mat


(
    balanced_multiclass_task_bin_concepts_dep_0_complete_train,
    balanced_multiclass_task_bin_concepts_dep_0_complete_test,
    dep_0_corr_mat,
) = _load_and_report_dep_dataset(0, dep_0_filter_fn)

(
    balanced_multiclass_task_bin_concepts_dep_1_complete_train,
    balanced_multiclass_task_bin_concepts_dep_1_complete_test,
    dep_1_corr_mat,
) = _load_and_report_dep_dataset(1, dep_1_filter_fn)

(
    balanced_multiclass_task_bin_concepts_dep_2_complete_train,
    balanced_multiclass_task_bin_concepts_dep_2_complete_test,
    dep_2_corr_mat,
) = _load_and_report_dep_dataset(2, dep_2_filter_fn)

(
    balanced_multiclass_task_bin_concepts_dep_3_complete_train,
    balanced_multiclass_task_bin_concepts_dep_3_complete_test,
    dep_3_corr_mat,
) = _load_and_report_dep_dataset(3, dep_3_filter_fn)

(
    balanced_multiclass_task_bin_concepts_dep_4_complete_train,
    balanced_multiclass_task_bin_concepts_dep_4_complete_test,
    dep_4_corr_mat,
) = _load_and_report_dep_dataset(4, dep_4_filter_fn, print_trailing_blank=False)

## Model Construction

In [None]:
# Construct the encoder model
def _extract_concepts(activations, concept_cardinality):
    concepts = []
    total_seen = 0
    if all(np.array(concept_cardinality) <= 1):
        # Then nothing to do here as they are all binary concepts
        return activations
    for num_values in concept_cardinality:
        concepts.append(activations[:, total_seen: total_seen + num_values])
        total_seen += num_values
    return concepts
    
def construct_encoder(
    input_shape,
    filter_groups,
    units,
    concept_cardinality,
    drop_prob=0.5,
    max_pool_window=(2,2),
    max_pool_stride=2,
    latent_dims=0,
    output_logits=False,
):
    """Build the convolutional input-to-concepts encoder as a Keras model.

    Args:
        input_shape: shape of a single input sample (without batch axis).
        filter_groups: iterable of groups, each an iterable of
            `(num_filters, kernel_size)` pairs. Every pair becomes a
            Conv2D -> BatchNormalization -> ReLU stack and every group is
            followed by a max-pooling layer.
        units: iterable with the width of each fully connected (ReLU) layer
            applied after flattening.
        concept_cardinality: number of outputs used by each concept; the
            concept head has `sum(concept_cardinality)` units.
        drop_prob: dropout rate applied after flattening; falsy disables it.
        max_pool_window: pool size of the max-pooling layers.
        max_pool_stride: stride of the max-pooling layers.
        latent_dims: if non-zero, an extra sigmoid "bypass" head with this
            many units is added next to the concept head.
        output_logits: when True the concept head is left without activation;
            otherwise sigmoid (single-output concepts) or softmax
            (multi-output concepts) is applied.

    Returns:
        A `tf.keras.Model` named "encoder" whose output is the concept
        output(s), or `[concept_outputs, bypass]` when `latent_dims` is set.
    """
    encoder_inputs = tf.keras.Input(shape=input_shape)
    encoder_compute_graph = encoder_inputs
    
    # Start with our convolutions
    num_convs = 0
    for filter_group in filter_groups:
        for (num_filters, kernel_size) in filter_group:
            encoder_compute_graph = tf.keras.layers.Conv2D(
                filters=num_filters,
                kernel_size=kernel_size,
                padding="SAME",
                activation=None,
                name=f'encoder_conv_{num_convs}',
            )(encoder_compute_graph)
            num_convs += 1
            encoder_compute_graph = tf.keras.layers.BatchNormalization()(
                encoder_compute_graph
            )
            encoder_compute_graph = tf.keras.activations.relu(encoder_compute_graph)
        # Then do a max pool here to control the parameter count of the model
        # at the end of each group
        encoder_compute_graph = tf.keras.layers.MaxPooling2D(
            pool_size=max_pool_window,
            strides=max_pool_stride,
        )(
            encoder_compute_graph
        )
    
    # Flatten the feature maps into a single vector per sample
    encoder_compute_graph = tf.keras.layers.Flatten()(encoder_compute_graph)
    
    # Add a dropout if requested
    if drop_prob:
        encoder_compute_graph = tf.keras.layers.Dropout(drop_prob)(
            encoder_compute_graph
        )
    
    # Finally, include the fully connected bottleneck here. The loop variable
    # is deliberately NOT called `units` so that the parameter is not
    # overwritten while it is being iterated.
    for i, num_units in enumerate(units):
        encoder_compute_graph = tf.keras.layers.Dense(
            num_units,
            activation='relu',
            name=f"encoder_dense_{i}",
        )(encoder_compute_graph)
    
    if latent_dims:
        # Extra head that lets information bypass the concept outputs
        bypass = tf.keras.layers.Dense(
            latent_dims,
            activation="sigmoid",
            name="encoder_bypass_channel",
        )(encoder_compute_graph)
    else:
        bypass = None
    
    # Map to our output distribution to a flattened
    # vector where we will extract distributions over
    # all concept values
    encoder_compute_graph = tf.keras.layers.Dense(
        sum(concept_cardinality),
        activation=None,
        name="encoder_concept_outputs",
    )(encoder_compute_graph)
        
    # Separate this vector into all of its heads
    concept_outputs = _extract_concepts(
        encoder_compute_graph,
        concept_cardinality,
    )
    if not output_logits:
        if isinstance(concept_outputs, list):
            for i, concept_vec in enumerate(concept_outputs):
                if concept_vec.shape[-1] == 1:
                    # Then this is a binary concept so simply apply sigmoid
                    concept_outputs[i] = tf.keras.activations.sigmoid(concept_vec)
                else:
                    # Else we will apply a softmax layer as we assume that all of these
                    # entries represent a multi-modal probability distribution
                    concept_outputs[i] = tf.keras.activations.softmax(
                        concept_vec,
                        axis=-1,
                    )
        else:
            # Else they are all binary concepts so let's sigmoid them
            concept_outputs = tf.keras.activations.sigmoid(concept_outputs)
    return tf.keras.Model(
        encoder_inputs,
        [concept_outputs, bypass] if bypass is not None else concept_outputs,
        name="encoder",
    )

In [None]:
############################################################################
## Build concepts-to-labels model
############################################################################

def construct_decoder(units, num_outputs):
    """Build the concepts-to-labels MLP as a `tf.keras.Sequential` model.

    Args:
        units: iterable with the width of each hidden (ReLU) dense layer.
        num_outputs: number of task classes. The final layer emits logits:
            a single unit when `num_outputs <= 2`, else `num_outputs` units.
    """
    layers = [tf.keras.layers.Flatten()]
    for layer_idx, width in enumerate(units, start=1):
        layers.append(
            tf.keras.layers.Dense(
                width,
                activation=tf.nn.relu,
                name=f"decoder_dense_{layer_idx}",
            )
        )
    # Binary tasks use one logit; multi-class tasks use one logit per class
    output_width = num_outputs if num_outputs > 2 else 1
    layers.append(
        tf.keras.layers.Dense(
            output_width,
            activation=None,
            name="decoder_model_output",
        )
    )
    return tf.keras.Sequential(layers)

In [None]:
# Construct the complete model
def construct_end_to_end_model(
    input_shape,
    encoder,
    decoder,
    num_outputs,
    learning_rate=1e-3,
):
    """Chain `encoder` and `decoder` into one compiled Keras model.

    Args:
        input_shape: shape of a single input sample (without batch axis).
        encoder: model mapping inputs to a tensor or a list of tensors.
        decoder: model mapping the (concatenated) encoder output to logits.
        num_outputs: number of task classes; `<= 2` selects binary
            cross-entropy / binary accuracy, otherwise the sparse categorical
            loss / top-1 accuracy are used. Both losses expect logits.
        learning_rate: learning rate of the Adam optimizer.

    Returns:
        The tuple `(model, encoder, decoder)`.
    """
    model_inputs = tf.keras.Input(shape=input_shape)
    latent = encoder(model_inputs)

    # Collapse the encoder output(s) into a single vector for the decoder
    if not isinstance(latent, list):
        compacted_vector = latent
    elif len(latent) > 1:
        compacted_vector = tf.keras.layers.Concatenate(axis=-1)(latent)
    else:
        compacted_vector = latent[0]

    model = tf.keras.Model(
        model_inputs,
        decoder(compacted_vector),
        name="complete_model",
    )

    optimizer = tf.keras.optimizers.Adam(learning_rate)
    if num_outputs <= 2:
        task_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
        task_metric = "binary_accuracy"
    else:
        task_loss = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True
        )
        task_metric = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1)
    model.compile(
        optimizer=optimizer,
        loss=task_loss,
        metrics=[task_metric],
    )
    return model, encoder, decoder

# CBM Benchmarking

## Model Definition

In [None]:
############################################################################
## Build CBM
############################################################################
import concepts_xai.methods.CBM.CBModel as CBM

def construct_cbm(
    encoder,
    decoder,
    num_outputs,
    alpha=0.1,
    learning_rate=1e-3,
    latent_dims=0,
    encoder_output_logits=False,
):
    """Build and compile a joint concept bottleneck model.

    Args:
        encoder: input-to-concepts model.
        decoder: concepts-to-labels model.
        num_outputs: number of task classes; `<= 2` selects the binary loss
            and metric, otherwise the sparse categorical ones.
        alpha: forwarded as `alpha` to the CBM constructor.
        learning_rate: learning rate of the Adam optimizer.
        latent_dims: when non-zero `CBM.BypassJointCBM` is used, otherwise
            `CBM.JointConceptBottleneckModel`.
        encoder_output_logits: forwarded as `pass_concept_logits`.

    Returns:
        The compiled CBM model.
    """
    if latent_dims:
        model_factory = CBM.BypassJointCBM
    else:
        model_factory = CBM.JointConceptBottleneckModel

    # Task loss and metric depend on whether the task is binary
    if num_outputs <= 2:
        task_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
        task_metric = tf.keras.metrics.BinaryAccuracy()
    else:
        task_loss = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True
        )
        task_metric = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1)

    cbm_model = model_factory(
        encoder=encoder,
        decoder=decoder,
        task_loss=task_loss,
        name="joint_cbm",
        metrics=[task_metric],
        alpha=alpha,
        pass_concept_logits=encoder_output_logits,
    )

    ############################################################################
    ## Compile CBM Model
    ############################################################################

    cbm_model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate),
    )
    return cbm_model


## Experiment Loop

In [None]:
import concepts_xai.evaluation.metrics.oracle as oracle

############################################################################
## Experiment loop
############################################################################

def construct_trivial_auc_mat(num_concepts):
    """Return a float32 square matrix with 1.0 on the diagonal, 0.5 elsewhere.

    The matrix has shape `(num_concepts, num_concepts)`.
    """
    auc_mat = np.full((num_concepts, num_concepts), 0.5, dtype=np.float32)
    np.fill_diagonal(auc_mat, 1.0)
    return auc_mat
    
def cbm_experiment_loop(
    experiment_config,
    datasets,
    load_from_cache=False,
    oracle_matrix_cache=None,
):
    """Train and evaluate a joint CBM on every dataset over several trials.

    For each dataset, ``experiment_config["trials"]`` models are built,
    optionally pre-trained end-to-end, trained as a CBM, saved under
    ``<results_dir>/models`` and evaluated on the test split. The mean and
    standard deviation across trials of each metric are appended to the
    result lists, which are serialized with ``utils.serialize_results``
    after every dataset so that an interrupted run can be resumed.

    Args:
        experiment_config: dict of hyper-parameters and paths. Keys read
            here include "results_dir", "trials", "input_shape",
            "num_outputs", "num_concepts", "learning_rate", "alpha",
            "batch_size", "max_epochs", "warmup_epochs", "holdout_fraction",
            "min_delta", "patience" and the encoder/decoder settings.
        datasets: list of ``(name, (x_train, y_train, c_train),
            (x_test, y_test, c_test))`` tuples.
        load_from_cache: when True, previously serialized results found in
            "results_dir" are loaded. A complete cache is returned directly;
            a partial one makes the run resume after the cached datasets.
        oracle_matrix_cache: optional dict mapping a dataset name to an
            oracle matrix that is forwarded to
            ``oracle.oracle_impurity_score``.

    Returns:
        dict mapping every variable name to a list holding one
        ``(mean, std)`` pair per dataset.

    Raises:
        ValueError: if the cached variables disagree on how many datasets
            have already been processed.
    """
    oracle_matrix_cache = oracle_matrix_cache or {}
    utils.reseed(87)
    experiment_variables = dict(
        task_accuracies=[],
        task_aucs=[],
        concept_accuracies=[],
        purity_scores=[],
        non_oracle_purity_scores=[],
        purity_matrices=[],
        oracle_matrices=[],
    )
    results_dir = experiment_config["results_dir"]
    start_ind = 0
    if load_from_cache:
        cached = True
        complete_cache = True
        # Number of datasets stored by a partially complete cache. ``None``
        # (rather than 0) marks "no partial variable seen yet" so that an
        # empty cache file still takes part in the consistency check.
        partial_count = None
        for var_name in experiment_variables:
            file_name = os.path.join(results_dir, var_name)
            if not (
                os.path.exists(f'{file_name}_means.npz') and
                os.path.exists(f'{file_name}_stds.npz')
            ):
                cached = False
                break
            # Close the archive right away; only its entry count is needed.
            with np.load(f'{file_name}_means.npz') as loaded_means:
                num_cached = len(loaded_means)
            if num_cached != len(datasets):
                print("Found", num_cached, "means for variable", var_name, "vs", len(datasets), "datasets")
                # Then we have a partial run here so let's just run the rest
                if partial_count is not None and partial_count != num_cached:
                    raise ValueError(
                        f'Found inconsistent start indices in cached '
                        f'data {partial_count} vs {num_cached}.'
                    )
                partial_count = num_cached
                complete_cache = False
        if cached:
            # Datasets are only skipped when their results are really loaded
            # below. If any cache file is missing, start_ind stays at 0 and
            # the whole experiment is re-run.
            start_ind = partial_count or 0
            # Then we have found all of the arrays of interest, so let's
            # load them and use them
            print("Experiment cache was hit")
            print("\tLoading variables from", results_dir)
            for var_name in experiment_variables:
                file_name = os.path.join(results_dir, var_name)
                # Open each archive once and release its file handle as soon
                # as the arrays have been read into memory.
                with np.load(f'{file_name}_means.npz') as means_npz, \
                        np.load(f'{file_name}_stds.npz') as stds_npz:
                    experiment_variables[var_name] = list(zip(
                        [means_npz[key] for key in means_npz],
                        [stds_npz[key] for key in stds_npz],
                    ))
            if complete_cache:
                # Then we are good to go
                if not os.path.exists(os.path.join(results_dir, "config.yaml")):
                    # then serialize the config as this is a different version run
                    serialize_experiment_config(
                        experiment_config,
                        results_dir,
                    )
                return experiment_variables

    # Else, let's go ahead and run the whole thing (or whatever is missing)
    Path(os.path.join(results_dir, "models")).mkdir(parents=True, exist_ok=True)
    verbosity = experiment_config.get("verbosity", 0)
    for (ds_name, (x_train, y_train, c_train), (x_test, y_test, c_test)) in datasets[start_ind:]:
        print("Training with dataset:", ds_name)
        task_accs = []
        concept_accs = []
        aucs = []
        purity_mats = []
        oracle_mats = []
        purities = []
        non_oracle_purities = []

        for trial in range(experiment_config["trials"]):
            print(f"\tTrial {trial + 1}/{experiment_config['trials']} of dataset {ds_name}")

            # Build an end-to-end model first in case we want to do some
            # task-specific pretraining. Its encoder and decoder are the
            # same objects that are later wrapped by the CBM.
            end_to_end_model, encoder, decoder = construct_end_to_end_model(
                input_shape=experiment_config["input_shape"],
                num_outputs=experiment_config["num_outputs"],
                learning_rate=experiment_config["learning_rate"],
                encoder=construct_encoder(
                    input_shape=experiment_config["input_shape"],
                    filter_groups=experiment_config["encoder_filter_groups"],
                    units=experiment_config["encoder_units"],
                    concept_cardinality=experiment_config["concept_cardinality"],
                    drop_prob=experiment_config.get("drop_prob", 0.5),
                    max_pool_window=experiment_config.get("max_pool_window", (2, 2)),
                    max_pool_stride=experiment_config.get("max_pool_stride", (2, 2)),
                    latent_dims=experiment_config.get("latent_dims", 0),
                    output_logits=experiment_config.get("encoder_output_logits", False),
                ),
                decoder=construct_decoder(
                    units=experiment_config["decoder_units"],
                    num_outputs=experiment_config["num_outputs"],
                ),
            )

            if experiment_config.get("pre_train_epochs"):
                print("\tModel pre-training...")
                end_to_end_model.fit(
                    x=x_train,
                    y=y_train,
                    epochs=experiment_config["pre_train_epochs"],
                    batch_size=experiment_config["batch_size"],
                    validation_split=experiment_config["holdout_fraction"],
                    verbose=verbosity,
                )
                print("\t\tModel pre-training completed")

            # Now time to actually construct and train the CBM
            cbm_model = construct_cbm(
                encoder=encoder,
                decoder=decoder,
                alpha=experiment_config["alpha"],
                learning_rate=experiment_config["learning_rate"],
                num_outputs=experiment_config["num_outputs"],
                latent_dims=experiment_config.get("latent_dims", 0),
                encoder_output_logits=experiment_config.get("encoder_output_logits", False),
            )

            early_stopping_monitor = tf.keras.callbacks.EarlyStopping(
                monitor=experiment_config.get(
                    "early_stop_metric",
                    "val_concept_accuracy",
                ),
                min_delta=experiment_config["min_delta"],
                patience=experiment_config["patience"],
                restore_best_weights=True,
                verbose=2,
                mode=experiment_config.get(
                    "early_stop_mode",
                    "max",
                ),
            )
            if experiment_config["warmup_epochs"]:
                # Warmup runs without the early-stopping callback.
                print("\tWarmup training...")
                cbm_model.fit(
                    x=x_train,
                    y=(y_train, c_train),
                    epochs=experiment_config["warmup_epochs"],
                    batch_size=experiment_config["batch_size"],
                    validation_split=experiment_config["holdout_fraction"],
                    verbose=verbosity,
                )
                print("\t\tWarmup training completed")

            print("\tCBM training...")
            cbm_model.fit(
                x=x_train,
                y=(y_train, c_train),
                epochs=experiment_config["max_epochs"],
                batch_size=experiment_config["batch_size"],
                callbacks=[
                    early_stopping_monitor,
                ],
                validation_split=experiment_config["holdout_fraction"],
                verbose=verbosity,
            )
            print("\t\tCBM training completed")
            print("\tSerializing model")
            # These saved encoders are reloaded by the bottleneck-predict
            # experiment loops, which rely on this exact naming scheme.
            encoder.save(
                os.path.join(
                    results_dir,
                    f"models/encoder_{ds_name}_trial_{trial}"
                )
            )
            decoder.save(
                os.path.join(
                    results_dir,
                    f"models/decoder_{ds_name}_trial_{trial}"
                )
            )
            print("\tEvaluating model")
            test_result = cbm_model.evaluate(
                x_test,
                (y_test, c_test),
                verbose=0,
                return_dict=True,
            )
            # Mirror the metric choice made in construct_cbm, which registers
            # BinaryAccuracy whenever num_outputs <= 2. The previous `> 1`
            # test looked up a metric that was never registered when
            # num_outputs == 2.
            task_accs.append(
                test_result['binary_accuracy']
                if experiment_config['num_outputs'] <= 2 else
                test_result['sparse_top_k_categorical_accuracy']
            )
            concept_accs.append(test_result['concept_accuracy'])

            if experiment_config['num_outputs'] > 1:
                # Then lets apply a softmax activation over all the probability
                # classes
                preds = scipy.special.softmax(cbm_model.predict(x_test)[0], axis=-1)

                # NOTE(review): the labels are one-hot encoded here, whereas
                # the bottleneck-predict loop passes integer labels to
                # roc_auc_score; confirm both are meant to be comparable.
                one_hot_labels = tf.keras.utils.to_categorical(y_test)
                aucs.append(sklearn.metrics.roc_auc_score(
                    one_hot_labels,
                    preds,
                    multi_class='ovo',
                ))
            else:
                aucs.append(sklearn.metrics.roc_auc_score(
                    y_test,
                    cbm_model.predict(x_test)[0],
                ))

            print(
                f"\t\tTest auc = {aucs[-1]:.4f}, "
                f"test concept accuracy = {concept_accs[-1]:.4f}, "
                f"task accuracy = {task_accs[-1]:.4f}"
            )

            print(f"\t\tComputing purity score...")
            purity_score, purity_mat, oracle_mat = oracle.oracle_impurity_score(
                encoder_model=cbm_model.encoder,
                features=x_test,
                concepts=c_test,
                output_matrices=True,
                oracle_matrix=oracle_matrix_cache.get(ds_name),
            )
            purity_mats.append(purity_mat)
            oracle_mats.append(oracle_mat)
            purities.append(purity_score)
            print(f"\t\t\tDone {purity_score:.4f}")

            print("\t\tComputing non-oracle purity score...")
            # Reuse the purity matrix computed above and compare it against
            # the trivial matrix (1 on the diagonal, 0.5 elsewhere).
            non_oracle_purities.append(oracle.oracle_impurity_score(
                encoder_model=cbm_model.encoder,
                features=x_test,
                concepts=c_test,
                oracle_matrix=construct_trivial_auc_mat(
                    experiment_config["num_concepts"]
                ),
                purity_matrix=purity_mat,
            ))
            print(f"\t\t\tDone {non_oracle_purities[-1]:.4f}")
            print("\t\tDone with trial", trial + 1)

        # Aggregate the per-trial measurements of this dataset
        task_acc_mean, task_acc_std = np.mean(task_accs), np.std(task_accs)
        experiment_variables["task_accuracies"].append((task_acc_mean, task_acc_std))
        print(f"\tTest task accuracy: {task_acc_mean:.4f} ± {task_acc_std:.4f}")

        concept_acc_mean, concept_acc_std = np.mean(concept_accs), np.std(concept_accs)
        experiment_variables["concept_accuracies"].append((concept_acc_mean, concept_acc_std))
        print(f"\tTest concept accuracy: {concept_acc_mean:.4f} ± {concept_acc_std:.4f}")

        task_auc_mean, task_auc_std = np.mean(aucs), np.std(aucs)
        experiment_variables["task_aucs"].append((task_auc_mean, task_auc_std))
        print(f"\tTest task AUC: {task_auc_mean:.4f} ± {task_auc_std:.4f}")

        purity_mats = np.stack(purity_mats, axis=0)
        purity_mat_mean = np.mean(purity_mats, axis=0)
        purity_mat_std = np.std(purity_mats, axis=0)
        print("\tPurity matrix:")
        for i in range(purity_mat_mean.shape[0]):
            line = "\t\t"
            for j in range(purity_mat_mean.shape[1]):
                line += f'{purity_mat_mean[i, j]:.4f} ± {purity_mat_std[i, j]:.4f}    '
            print(line)

        experiment_variables["purity_matrices"].append((purity_mat_mean, purity_mat_std))

        oracle_mats = np.stack(oracle_mats, axis=0)
        oracle_mat_mean = np.mean(oracle_mats, axis=0)
        oracle_mat_std = np.std(oracle_mats, axis=0)
        print("\tOracle matrix:")
        for i in range(oracle_mat_mean.shape[0]):
            line = "\t\t"
            for j in range(oracle_mat_mean.shape[1]):
                line += f'{oracle_mat_mean[i, j]:.4f} ± {oracle_mat_std[i, j]:.4f}    '
            print(line)

        experiment_variables["oracle_matrices"].append((oracle_mat_mean, oracle_mat_std))

        purity_mean, purity_std = np.mean(purities), np.std(purities)
        experiment_variables["purity_scores"].append((purity_mean, purity_std))
        print(f"\tPurity score: {purity_mean:.4f} ± {purity_std:.4f}")

        non_oracle_purity_mean, non_oracle_purity_std = np.mean(non_oracle_purities), np.std(non_oracle_purities)
        experiment_variables["non_oracle_purity_scores"].append((non_oracle_purity_mean, non_oracle_purity_std))
        print(f"\tNon-oracle purity score: {non_oracle_purity_mean:.4f} ± {non_oracle_purity_std:.4f}")

        # Serialize after every dataset so that a partial run can be resumed
        utils.serialize_results(
            results_dict=experiment_variables,
            results_dir=results_dir,
        )
    return experiment_variables

def cbm_bottleneck_predict_experiment_loop(
    experiment_config,
    datasets,
    load_from_cache=False,
):
    """Measure how well the task can be predicted from saved CBM bottlenecks.

    For every dataset and trial, the encoder stored under
    ``<results_dir>/models/encoder_<dataset>_trial_<trial>`` is loaded, the
    train and test samples are encoded with it, and a fresh decoder is
    trained to predict the task labels from those codes. The mean and
    standard deviation across trials of the test accuracy and AUC are
    appended per dataset and serialized with ``utils.serialize_results``.

    Args:
        experiment_config: dict of hyper-parameters and paths. Keys read
            here include "results_dir", "trials", "num_outputs",
            "latent_decoder_units", "learning_rate", "batch_size",
            "predictor_max_epochs" and "holdout_fraction".
        datasets: list of ``(name, (x_train, y_train, c_train),
            (x_test, y_test, c_test))`` tuples.
        load_from_cache: when True, previously serialized results found in
            "results_dir" are loaded. A complete cache is returned directly;
            a partial one makes the run resume after the cached datasets.

    Returns:
        dict with keys "latent_predictive_accuracies" and
        "latent_predictive_aucs", each a list with one ``(mean, std)`` pair
        per dataset.

    Raises:
        ValueError: if the cached variables disagree on how many datasets
            have already been processed.
    """
    utils.reseed(87)
    experiment_variables = dict(
        latent_predictive_accuracies=[],
        latent_predictive_aucs=[],
    )
    results_dir = experiment_config["results_dir"]

    start_ind = 0
    if load_from_cache:
        cached = True
        complete_cache = True
        # Number of datasets stored by a partially complete cache. ``None``
        # (rather than 0) marks "no partial variable seen yet" so that an
        # empty cache file still takes part in the consistency check.
        partial_count = None
        for var_name in experiment_variables:
            file_name = os.path.join(results_dir, var_name)
            if not (
                os.path.exists(f'{file_name}_means.npz') and
                os.path.exists(f'{file_name}_stds.npz')
            ):
                print("Could not find", f'"{file_name}_means.npz" or "{file_name}_stds.npz" in cache.')
                cached = False
                break
            # Close the archive right away; only its entry count is needed.
            with np.load(f'{file_name}_means.npz') as loaded_means:
                num_cached = len(loaded_means)
            if num_cached != len(datasets):
                print("Found", num_cached, "means for variable", var_name, "vs", len(datasets), "datasets")
                # Then we have a partial run here so let's just run the rest
                if partial_count is not None and partial_count != num_cached:
                    raise ValueError(
                        f'Found inconsistent start indices in cached '
                        f'data {partial_count} vs {num_cached} ({file_name}).'
                    )
                partial_count = num_cached
                complete_cache = False
        if cached:
            # Datasets are only skipped when their results are really loaded
            # below. If any cache file is missing, start_ind stays at 0 and
            # the whole experiment is re-run.
            start_ind = partial_count or 0
            # Then we have found all of the arrays of interest, so let's
            # load them and use them
            print("Experiment cache was hit")
            print("\tLoading variables from", results_dir)
            for var_name in experiment_variables:
                file_name = os.path.join(results_dir, var_name)
                # Open each archive once and release its file handle as soon
                # as the arrays have been read into memory.
                with np.load(f'{file_name}_means.npz') as means_npz, \
                        np.load(f'{file_name}_stds.npz') as stds_npz:
                    experiment_variables[var_name] = list(zip(
                        [means_npz[key] for key in means_npz],
                        [stds_npz[key] for key in stds_npz],
                    ))
            print(experiment_variables)
            if complete_cache:
                # Then we are good to go
                if not os.path.exists(os.path.join(results_dir, "config.yaml")):
                    # then serialize the config as this is a different version run
                    serialize_experiment_config(
                        experiment_config,
                        results_dir,
                    )
                return experiment_variables

    # Else, let's go ahead and run the whole thing (or whatever is missing)
    verbosity = experiment_config.get("verbosity", 0)
    Path(os.path.join(results_dir, "models")).mkdir(parents=True, exist_ok=True)
    for (ds_name, (x_train, y_train, c_train), (x_test, y_test, c_test)) in datasets[start_ind:]:
        print("Training with dataset", ds_name)
        task_accs = []
        aucs = []

        for trial in range(experiment_config["trials"]):
            print(f"\tTrial {trial + 1}/{experiment_config['trials']}")
            # Reuse the encoder that the CBM experiment saved for this
            # dataset and trial.
            encoder = tf.keras.models.load_model(
                os.path.join(
                    results_dir,
                    f"models/encoder_{ds_name}_trial_{trial}"
                )
            )

            predictive_decoder = construct_decoder(
                units=experiment_config["latent_decoder_units"],
                num_outputs=experiment_config["num_outputs"],
            )
            predictive_decoder.compile(
                optimizer=tf.keras.optimizers.Adam(experiment_config["learning_rate"]),
                loss=(
                    tf.keras.losses.BinaryCrossentropy(from_logits=True) if (experiment_config["num_outputs"] <= 2)
                    else tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
                ),
                metrics=[
                    "binary_accuracy" if (experiment_config["num_outputs"] <= 2)
                    else tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1)
                ],
            )

            print("\tTraining model")
            # NOTE(review): each split is pushed through the encoder in one
            # call; this may need batching for larger datasets.
            # An encoder with several outputs returns a list, which is
            # flattened into a single code matrix along the last axis.
            train_codes = encoder(x_train)
            if isinstance(train_codes, list):
                train_codes = np.concatenate(list(map(lambda x: x.numpy(), train_codes)), axis=-1)
            else:
                train_codes = train_codes.numpy()
            test_codes = encoder(x_test)
            if isinstance(test_codes, list):
                test_codes = np.concatenate(list(map(lambda x: x.numpy(), test_codes)), axis=-1)
            else:
                test_codes = test_codes.numpy()
            predictive_decoder.fit(
                x=train_codes,
                y=y_train,
                epochs=experiment_config["predictor_max_epochs"],
                batch_size=experiment_config["batch_size"],
                validation_split=experiment_config["holdout_fraction"],
                verbose=verbosity,
            )
            print("\tEvaluating model")
            test_result = predictive_decoder.evaluate(
                test_codes,
                y_test,
                verbose=0,
                return_dict=True,
            )
            # Mirror the metric choice made at compile time above, which
            # registers "binary_accuracy" whenever num_outputs <= 2. The
            # previous `> 1` test looked up a metric that was never
            # registered when num_outputs == 2.
            task_accs.append(
                test_result['binary_accuracy']
                if experiment_config['num_outputs'] <= 2 else
                test_result['sparse_top_k_categorical_accuracy']
            )

            if experiment_config['num_outputs'] > 1:
                # Then lets apply a softmax activation over all the probability
                # classes so that every row sums to one, as required by the
                # multi-class AUC computation
                preds = scipy.special.softmax(
                    predictive_decoder.predict(test_codes),
                    axis=-1,
                )
                aucs.append(sklearn.metrics.roc_auc_score(
                    y_test,
                    preds,
                    multi_class='ovo',
                ))
            else:
                aucs.append(sklearn.metrics.roc_auc_score(
                    y_test,
                    predictive_decoder.predict(test_codes),
                ))

            print(
                f"\t\tTest auc = {aucs[-1]:.4f}, "
                f"task accuracy = {task_accs[-1]:.4f}"
            )
            print("\t\tDone with trial", trial + 1)

        # Aggregate the per-trial measurements of this dataset
        task_acc_mean, task_acc_std = np.mean(task_accs), np.std(task_accs)
        experiment_variables["latent_predictive_accuracies"].append((task_acc_mean, task_acc_std))
        print(f"\tTest task accuracy: {task_acc_mean:.4f} ± {task_acc_std:.4f}")

        task_auc_mean, task_auc_std = np.mean(aucs), np.std(aucs)
        experiment_variables["latent_predictive_aucs"].append((task_auc_mean, task_auc_std))
        print(f"\tTest task AUC: {task_auc_mean:.4f} ± {task_auc_std:.4f}")

        # Serialize after every dataset so that a partial run can be resumed
        utils.serialize_results(
            results_dict=experiment_variables,
            results_dir=results_dir,
        )
    return experiment_variables

def cbm_bottleneck_concept_predict_experiment_loop(
    experiment_config,
    datasets,
    load_from_cache=False,
):
    """Measure how well each concept can be predicted from saved bottlenecks.

    For every dataset and trial, the encoder stored under
    ``<results_dir>/models/encoder_<dataset>_trial_<trial>`` is loaded and
    used to encode the train and test samples. One single-output decoder
    per concept is then trained on those codes, and the test accuracy and
    AUC are averaged over concepts. The mean and standard deviation across
    trials of these averages are appended per dataset and serialized with
    ``utils.serialize_results``.

    Args:
        experiment_config: dict of hyper-parameters and paths. Keys read
            here include "results_dir", "trials", "num_concepts",
            "latent_decoder_units", "learning_rate", "batch_size",
            "concept_predictor_max_epochs" and "holdout_fraction".
        datasets: list of ``(name, (x_train, y_train, c_train),
            (x_test, y_test, c_test))`` tuples.
        load_from_cache: when True, previously serialized results found in
            "results_dir" are loaded. A complete cache is returned directly;
            a partial one makes the run resume after the cached datasets.

    Returns:
        dict with keys "latent_avg_concept_predictive_accuracies" and
        "latent_avg_concept_predictive_aucs", each a list with one
        ``(mean, std)`` pair per dataset.

    Raises:
        ValueError: if the cached variables disagree on how many datasets
            have already been processed.
    """
    utils.reseed(87)
    experiment_variables = dict(
        latent_avg_concept_predictive_accuracies=[],
        latent_avg_concept_predictive_aucs=[],
    )
    results_dir = experiment_config["results_dir"]

    start_ind = 0
    if load_from_cache:
        cached = True
        complete_cache = True
        # Number of datasets stored by a partially complete cache. ``None``
        # (rather than 0) marks "no partial variable seen yet" so that an
        # empty cache file still takes part in the consistency check.
        partial_count = None
        for var_name in experiment_variables:
            file_name = os.path.join(results_dir, var_name)
            if not (
                os.path.exists(f'{file_name}_means.npz') and
                os.path.exists(f'{file_name}_stds.npz')
            ):
                print("Could not find", f'"{file_name}_means.npz" or "{file_name}_stds.npz" in cache.')
                cached = False
                break
            # Close the archive right away; only its entry count is needed.
            with np.load(f'{file_name}_means.npz') as loaded_means:
                num_cached = len(loaded_means)
            if num_cached != len(datasets):
                print("Found", num_cached, "means for variable", var_name, "vs", len(datasets), "datasets")
                # Then we have a partial run here so let's just run the rest
                if partial_count is not None and partial_count != num_cached:
                    raise ValueError(
                        f'Found inconsistent start indices in cached '
                        f'data {partial_count} vs {num_cached} ({file_name}).'
                    )
                partial_count = num_cached
                complete_cache = False
        if cached:
            # Datasets are only skipped when their results are really loaded
            # below. If any cache file is missing, start_ind stays at 0 and
            # the whole experiment is re-run.
            start_ind = partial_count or 0
            # Then we have found all of the arrays of interest, so let's
            # load them and use them
            print("Experiment cache was hit")
            print("\tLoading variables from", results_dir)
            for var_name in experiment_variables:
                file_name = os.path.join(results_dir, var_name)
                # Open each archive once and release its file handle as soon
                # as the arrays have been read into memory.
                with np.load(f'{file_name}_means.npz') as means_npz, \
                        np.load(f'{file_name}_stds.npz') as stds_npz:
                    experiment_variables[var_name] = list(zip(
                        [means_npz[key] for key in means_npz],
                        [stds_npz[key] for key in stds_npz],
                    ))
            print(experiment_variables)
            if complete_cache:
                # Then we are good to go
                if not os.path.exists(os.path.join(results_dir, "config.yaml")):
                    # then serialize the config as this is a different version run
                    serialize_experiment_config(
                        experiment_config,
                        results_dir,
                    )
                return experiment_variables

    # Else, let's go ahead and run the whole thing (or whatever is missing)
    verbosity = experiment_config.get("verbosity", 0)
    Path(os.path.join(results_dir, "models")).mkdir(parents=True, exist_ok=True)
    for (ds_name, (x_train, y_train, c_train), (x_test, y_test, c_test)) in datasets[start_ind:]:
        print("Training with dataset", ds_name)
        avg_concept_accs = []
        avg_concept_aucs = []

        for trial in range(experiment_config["trials"]):
            print(f"\tTrial {trial + 1}/{experiment_config['trials']}")
            # Reuse the encoder that the CBM experiment saved for this
            # dataset and trial.
            encoder = tf.keras.models.load_model(
                os.path.join(
                    results_dir,
                    f"models/encoder_{ds_name}_trial_{trial}"
                )
            )

            # NOTE(review): each split is pushed through the encoder in one
            # call; this may need batching for larger datasets.
            # An encoder with several outputs returns a list, which is
            # flattened into a single code matrix along the last axis.
            train_codes = encoder(x_train)
            if isinstance(train_codes, list):
                train_codes = np.concatenate(list(map(lambda x: x.numpy(), train_codes)), axis=-1)
            else:
                train_codes = train_codes.numpy()
            test_codes = encoder(x_test)
            if isinstance(test_codes, list):
                test_codes = np.concatenate(list(map(lambda x: x.numpy(), test_codes)), axis=-1)
            else:
                test_codes = test_codes.numpy()

            current_accuracies = []
            current_aucs = []
            for concept_idx in range(experiment_config["num_concepts"]):
                print("\tTraining model for concept", concept_idx)
                # A fresh single-logit predictor is trained for every concept
                predictive_decoder = construct_decoder(
                    units=experiment_config["latent_decoder_units"],
                    num_outputs=1,
                )
                predictive_decoder.compile(
                    optimizer=tf.keras.optimizers.Adam(experiment_config["learning_rate"]),
                    loss=(
                        tf.keras.losses.BinaryCrossentropy(from_logits=True)
                    ),
                    metrics=[
                        "binary_accuracy"
                    ],
                )
                predictive_decoder.fit(
                    x=train_codes,
                    y=c_train[:, concept_idx],
                    epochs=experiment_config["concept_predictor_max_epochs"],
                    batch_size=experiment_config["batch_size"],
                    validation_split=experiment_config["holdout_fraction"],
                    verbose=verbosity,
                )
                print("\t\tEvaluating model")
                test_result = predictive_decoder.evaluate(
                    test_codes,
                    c_test[:, concept_idx],
                    verbose=0,
                    return_dict=True,
                )
                current_accuracies.append(test_result['binary_accuracy'])

                current_aucs.append(sklearn.metrics.roc_auc_score(
                    c_test[:, concept_idx],
                    predictive_decoder.predict(test_codes),
                ))
                # These are the metrics of a single concept, not an average,
                # so the message names the concept being reported.
                print(
                    f"\t\t\tConcept {concept_idx} test accuracy = {current_accuracies[-1]:.4f}, "
                    f"test AUC = {current_aucs[-1]:.4f}"
                )

            avg_concept_accs.append(np.mean(current_accuracies))
            avg_concept_aucs.append(np.mean(current_aucs))
            print(
                f"\t\tAverage test concept accuracy = {avg_concept_accs[-1]:.4f}, "
                f"average test concept AUC = {avg_concept_aucs[-1]:.4f}"
            )
            print("\t\tDone with trial", trial + 1)

        # Aggregate the per-trial concept averages of this dataset
        avg_concept_acc_mean, avg_concept_acc_std = np.mean(avg_concept_accs), np.std(avg_concept_accs)
        experiment_variables["latent_avg_concept_predictive_accuracies"].append((avg_concept_acc_mean, avg_concept_acc_std))
        print(f"\tTest average concept accuracy: {avg_concept_acc_mean:.4f} ± {avg_concept_acc_std:.4f}")

        avg_concept_auc_mean, avg_concept_auc_std = np.mean(avg_concept_aucs), np.std(avg_concept_aucs)
        experiment_variables["latent_avg_concept_predictive_aucs"].append((avg_concept_auc_mean, avg_concept_auc_std))
        print(f"\tTest average concept AUC: {avg_concept_auc_mean:.4f} ± {avg_concept_auc_std:.4f}")

        # Serialize after every dataset so that a partial run can be resumed
        utils.serialize_results(
            results_dict=experiment_variables,
            results_dir=results_dir,
        )
    return experiment_variables

## CBM Experiments

In [None]:
# Pick up any edits made to the CBM module since it was first imported
reload(CBM)

############################################################################
## Experiment config
############################################################################

# Shapes, class counts and concept counts are all derived from the dep_0
# training split, a (samples, labels, concepts) tuple.
graph_dependency_balanced_multiclass_experiment_config = dict(
    batch_size=64,
    max_epochs=100,
    warmup_epochs=0,
    pre_train_epochs=0,
    trials=5,
    alpha=10,
    learning_rate=1e-3,
    # One list per convolutional group; each entry looks like
    # (filters, kernel_size) -- TODO confirm against construct_encoder
    encoder_filter_groups=[
        [(8, (7, 7))],
        [(16, (5, 5))],
        [(32, (3, 3))],
        [(64, (3, 3))]
    ],
    encoder_units=[64, 64],
    decoder_units=[64, 64],

    latent_decoder_units=[64, 64],
    predictor_max_epochs=100,
    concept_predictor_max_epochs=100,

    drop_prob=0.5,
    max_pool_window=(2,2),
    # This key used to be misspelled `pax_pool_stride`, so it was never read
    # and the experiment loop fell back to its (2, 2) default. The spelling
    # is fixed and that same effective value is written out explicitly.
    max_pool_stride=(2, 2),
    # A single output unit is used for binary tasks; otherwise one unit per
    # distinct label.
    num_outputs=(
        len(set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1]))
        if len(set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1])) > 2 else 1
    ),
    # Every concept is binary
    concept_cardinality=[
        2 for _ in range(balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1])
    ],
    num_concepts=balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    # Infinite patience: early stopping never triggers before max_epochs
    patience=float("inf"),
    min_delta=1e-5,
    results_dir=os.path.join(RESULTS_DIR, "cbm/graph_dependency_balanced_multiclass_tasks_purity"),
    input_shape=balanced_multiclass_task_bin_concepts_dep_0_complete_train[0].shape[1:],
    latent_dims=0,
    holdout_fraction=0.1,
    verbosity=0,
    early_stop_metric="val_concept_accuracy",
    early_stop_mode="max",
    encoder_output_logits=False,
)

# Generate the experiment directory if it does not exist already
Path(graph_dependency_balanced_multiclass_experiment_config["results_dir"]).mkdir(parents=True, exist_ok=True)
graph_dependency_balanced_multiclass_figure_dir = os.path.join(graph_dependency_balanced_multiclass_experiment_config["results_dir"], "figures")
Path(graph_dependency_balanced_multiclass_figure_dir).mkdir(parents=True, exist_ok=True)

############################################################################
## Experiment run
############################################################################

graph_dependency_balanced_multiclass_results = cbm_experiment_loop(
    graph_dependency_balanced_multiclass_experiment_config,
    load_from_cache=True,
    datasets=[
        (
            "balanced_multiclass_task_bin_concepts_dep_0_complete",
            balanced_multiclass_task_bin_concepts_dep_0_complete_train,
            balanced_multiclass_task_bin_concepts_dep_0_complete_test,
        ),
        (
            "balanced_multiclass_task_bin_concepts_dep_1_complete",
            balanced_multiclass_task_bin_concepts_dep_1_complete_train,
            balanced_multiclass_task_bin_concepts_dep_1_complete_test,
        ),
        (
            "balanced_multiclass_task_bin_concepts_dep_2_complete",
            balanced_multiclass_task_bin_concepts_dep_2_complete_train,
            balanced_multiclass_task_bin_concepts_dep_2_complete_test,
        ),
        (
            "balanced_multiclass_task_bin_concepts_dep_3_complete",
            balanced_multiclass_task_bin_concepts_dep_3_complete_train,
            balanced_multiclass_task_bin_concepts_dep_3_complete_test,
        ),
        (
            "balanced_multiclass_task_bin_concepts_dep_4_complete",
            balanced_multiclass_task_bin_concepts_dep_4_complete_train,
            balanced_multiclass_task_bin_concepts_dep_4_complete_test,
        ),
    ],
)
print("task_accuracies:", graph_dependency_balanced_multiclass_results["task_accuracies"])
print("concept_accuracies:", graph_dependency_balanced_multiclass_results["concept_accuracies"])
print("task_aucs:", graph_dependency_balanced_multiclass_results["task_aucs"])
print("purity_scores:", graph_dependency_balanced_multiclass_results["purity_scores"])

# And let's generate an oracle matrix cache to accelerate the experiments that
# follow up with the same datasets. Each result entry is a (mean, std) pair and
# only the mean oracle matrix is cached, keyed by the dataset name used above.
balanced_oracle_matrix_cache = {
    f"balanced_multiclass_task_bin_concepts_dep_{dep_idx}_complete":
        graph_dependency_balanced_multiclass_results["oracle_matrices"][dep_idx][0]
    for dep_idx in range(5)
}


In [None]:
# Merge the output of `cbm_bottleneck_predict_experiment_loop` into the results
# of the balanced multiclass CBM experiment, using the same config and the same
# five (name, train split, test split) datasets as the main run.
# NOTE: `.update(...)` mutates the results object in place, so this cell must
# run after the cell that creates `graph_dependency_balanced_multiclass_results`.
# NOTE(review): the loop is defined elsewhere in the notebook; `load_from_cache`
# presumably lets it reuse previously serialized results -- confirm there.
graph_dependency_balanced_multiclass_results.update(cbm_bottleneck_predict_experiment_loop(
    graph_dependency_balanced_multiclass_experiment_config,
    load_from_cache=True,
    datasets=[
        ("balanced_multiclass_task_bin_concepts_dep_0_complete", balanced_multiclass_task_bin_concepts_dep_0_complete_train, balanced_multiclass_task_bin_concepts_dep_0_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_1_complete", balanced_multiclass_task_bin_concepts_dep_1_complete_train, balanced_multiclass_task_bin_concepts_dep_1_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_2_complete", balanced_multiclass_task_bin_concepts_dep_2_complete_train, balanced_multiclass_task_bin_concepts_dep_2_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_3_complete", balanced_multiclass_task_bin_concepts_dep_3_complete_train, balanced_multiclass_task_bin_concepts_dep_3_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_4_complete", balanced_multiclass_task_bin_concepts_dep_4_complete_train, balanced_multiclass_task_bin_concepts_dep_4_complete_test),
    ],
))

In [None]:
# Merge the output of `cbm_bottleneck_concept_predict_experiment_loop` into the
# results of the balanced multiclass CBM experiment, using the same config and
# the same five (name, train split, test split) datasets as the main run.
# NOTE: `.update(...)` mutates the results object in place, so this cell must
# run after the cell that creates `graph_dependency_balanced_multiclass_results`.
# NOTE(review): the loop is defined elsewhere in the notebook; `load_from_cache`
# presumably lets it reuse previously serialized results -- confirm there.
graph_dependency_balanced_multiclass_results.update(cbm_bottleneck_concept_predict_experiment_loop(
    graph_dependency_balanced_multiclass_experiment_config,
    load_from_cache=True,
    datasets=[
        ("balanced_multiclass_task_bin_concepts_dep_0_complete", balanced_multiclass_task_bin_concepts_dep_0_complete_train, balanced_multiclass_task_bin_concepts_dep_0_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_1_complete", balanced_multiclass_task_bin_concepts_dep_1_complete_train, balanced_multiclass_task_bin_concepts_dep_1_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_2_complete", balanced_multiclass_task_bin_concepts_dep_2_complete_train, balanced_multiclass_task_bin_concepts_dep_2_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_3_complete", balanced_multiclass_task_bin_concepts_dep_3_complete_train, balanced_multiclass_task_bin_concepts_dep_3_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_4_complete", balanced_multiclass_task_bin_concepts_dep_4_complete_train, balanced_multiclass_task_bin_concepts_dep_4_complete_test),
    ],
))

# CBM With Logits

In [None]:
reload(CBM)

############################################################################
## Experiment config
############################################################################

# Configuration of the "CBM from logits" variant of the balanced multiclass
# experiment. Shapes and class counts are read off the dep_0 training split,
# which is an (inputs, labels, concepts) triple.
graph_dependency_balanced_multiclass_from_logits_experiment_config = {
    "batch_size": 64,
    "max_epochs": 100,
    "warmup_epochs": 0,
    "pre_train_epochs": 0,
    "trials": 5,
    "alpha": 10,
    "learning_rate": 1e-3,
    "encoder_filter_groups": [
        [(8, (7, 7))],
        [(16, (5, 5))],
        [(32, (3, 3))],
        [(64, (3, 3))],
    ],
    "encoder_units": [64, 64],
    "decoder_units": [64, 64],

    "latent_decoder_units": [64, 64],
    "predictor_max_epochs": 100,
    "concept_predictor_max_epochs": 100,

    "drop_prob": 0.5,
    "max_pool_window": (2, 2),
    # NOTE(review): "pax_pool_stride" looks like a typo of "max_pool_stride";
    # the key is kept as-is because the consumer of this config is not visible
    # from this cell -- confirm which key the CBM loop actually reads.
    "pax_pool_stride": 2,
    # A single output unit for binary tasks, otherwise one unit per distinct
    # training label
    "num_outputs": (
        len(set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1]))
        if len(set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1])) > 2 else 1
    ),
    # Every concept is declared with cardinality 2
    "concept_cardinality": (
        [2] * balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]
    ),
    "num_concepts": balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    "patience": float("inf"),
    "min_delta": 1e-5,
    "results_dir": os.path.join(RESULTS_DIR, "cbm/graph_dependency_balanced_multiclass_from_logits_tasks_purity"),
    "input_shape": balanced_multiclass_task_bin_concepts_dep_0_complete_train[0].shape[1:],
    "latent_dims": 0,
    "holdout_fraction": 0.1,
    "verbosity": 0,
    "early_stop_metric": "val_concept_accuracy",
    "early_stop_mode": "max",
    "encoder_output_logits": True,
}

# Make sure the experiment's results directory and its "figures"
# sub-directory exist before the experiment writes anything to them
graph_dependency_balanced_multiclass_from_logits_figure_dir = os.path.join(
    graph_dependency_balanced_multiclass_from_logits_experiment_config["results_dir"],
    "figures",
)
os.makedirs(
    graph_dependency_balanced_multiclass_from_logits_experiment_config["results_dir"],
    exist_ok=True,
)
os.makedirs(graph_dependency_balanced_multiclass_from_logits_figure_dir, exist_ok=True)

############################################################################
## Experiment run
############################################################################

# Same five datasets as the non-logits run; the oracle matrices computed there
# are passed in through `balanced_oracle_matrix_cache`.
graph_dependency_balanced_multiclass_from_logits_results = cbm_experiment_loop(
    graph_dependency_balanced_multiclass_from_logits_experiment_config,
    datasets=[
        ("balanced_multiclass_task_bin_concepts_dep_0_complete", balanced_multiclass_task_bin_concepts_dep_0_complete_train, balanced_multiclass_task_bin_concepts_dep_0_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_1_complete", balanced_multiclass_task_bin_concepts_dep_1_complete_train, balanced_multiclass_task_bin_concepts_dep_1_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_2_complete", balanced_multiclass_task_bin_concepts_dep_2_complete_train, balanced_multiclass_task_bin_concepts_dep_2_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_3_complete", balanced_multiclass_task_bin_concepts_dep_3_complete_train, balanced_multiclass_task_bin_concepts_dep_3_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_4_complete", balanced_multiclass_task_bin_concepts_dep_4_complete_train, balanced_multiclass_task_bin_concepts_dep_4_complete_test),
    ],
    load_from_cache=True,
    oracle_matrix_cache=balanced_oracle_matrix_cache,
)

# Summary of the headline metrics, one line per metric
print(f"task_accuracies: {graph_dependency_balanced_multiclass_from_logits_results['task_accuracies']!s}")
print(f"concept_accuracies: {graph_dependency_balanced_multiclass_from_logits_results['concept_accuracies']!s}")
print(f"task_aucs: {graph_dependency_balanced_multiclass_from_logits_results['task_aucs']!s}")

In [None]:
# Merge the output of `cbm_bottleneck_predict_experiment_loop` into the results
# of the "from logits" CBM experiment, using the same config and the same five
# (name, train split, test split) datasets as the main run.
# NOTE: `.update(...)` mutates the results object in place, so this cell must
# run after the cell that creates
# `graph_dependency_balanced_multiclass_from_logits_results`.
# NOTE(review): the loop is defined elsewhere in the notebook; `load_from_cache`
# presumably lets it reuse previously serialized results -- confirm there.
graph_dependency_balanced_multiclass_from_logits_results.update(cbm_bottleneck_predict_experiment_loop(
    graph_dependency_balanced_multiclass_from_logits_experiment_config,
    load_from_cache=True,
    datasets=[
        ("balanced_multiclass_task_bin_concepts_dep_0_complete", balanced_multiclass_task_bin_concepts_dep_0_complete_train, balanced_multiclass_task_bin_concepts_dep_0_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_1_complete", balanced_multiclass_task_bin_concepts_dep_1_complete_train, balanced_multiclass_task_bin_concepts_dep_1_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_2_complete", balanced_multiclass_task_bin_concepts_dep_2_complete_train, balanced_multiclass_task_bin_concepts_dep_2_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_3_complete", balanced_multiclass_task_bin_concepts_dep_3_complete_train, balanced_multiclass_task_bin_concepts_dep_3_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_4_complete", balanced_multiclass_task_bin_concepts_dep_4_complete_train, balanced_multiclass_task_bin_concepts_dep_4_complete_test),
    ],
))

In [None]:
# Merge the output of `cbm_bottleneck_concept_predict_experiment_loop` into the
# results of the "from logits" CBM experiment, using the same config and the
# same five (name, train split, test split) datasets as the main run.
# NOTE: `.update(...)` mutates the results object in place, so this cell must
# run after the cell that creates
# `graph_dependency_balanced_multiclass_from_logits_results`.
# NOTE(review): the loop is defined elsewhere in the notebook; `load_from_cache`
# presumably lets it reuse previously serialized results -- confirm there.
graph_dependency_balanced_multiclass_from_logits_results.update(cbm_bottleneck_concept_predict_experiment_loop(
    graph_dependency_balanced_multiclass_from_logits_experiment_config,
    load_from_cache=True,
    datasets=[
        ("balanced_multiclass_task_bin_concepts_dep_0_complete", balanced_multiclass_task_bin_concepts_dep_0_complete_train, balanced_multiclass_task_bin_concepts_dep_0_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_1_complete", balanced_multiclass_task_bin_concepts_dep_1_complete_train, balanced_multiclass_task_bin_concepts_dep_1_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_2_complete", balanced_multiclass_task_bin_concepts_dep_2_complete_train, balanced_multiclass_task_bin_concepts_dep_2_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_3_complete", balanced_multiclass_task_bin_concepts_dep_3_complete_train, balanced_multiclass_task_bin_concepts_dep_3_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_4_complete", balanced_multiclass_task_bin_concepts_dep_4_complete_train, balanced_multiclass_task_bin_concepts_dep_4_complete_test),
    ],
))

# CW Benchmark

In [None]:
import concepts_xai.methods.CW.CWLayer as CW

def conv_predictor_model_fn(
    input_concept_classes=1,
    output_concept_classes=2,
):
    """Build and compile a small convolutional predictor that outputs logits.

    The model is two 3x3 "SAME"-padded ReLU convolutions (16 then 32 filters,
    channels-last), a flatten, and a linear dense head.

    NOTE: an identical definition of this function appears again in a later
    cell of this notebook; whichever cell runs last provides the binding.

    Args:
        input_concept_classes: Not used by this implementation.
        output_concept_classes: Number of classes of the target. More than two
            classes yields one logit per class trained with sparse categorical
            cross-entropy (integer labels, not one-hot); otherwise a single
            logit trained with binary cross-entropy.

    Returns:
        The compiled `tf.keras.models.Sequential` estimator.
    """
    is_multiclass = output_concept_classes > 2

    # Both convolutions share everything but their number of filters
    shared_conv_kwargs = dict(
        kernel_size=(3,3),
        padding="SAME",
        activation="relu",
        data_format="channels_last",
    )
    layer_stack = [
        tf.keras.layers.Conv2D(filters=16, **shared_conv_kwargs),
        tf.keras.layers.Conv2D(filters=32, **shared_conv_kwargs),
        tf.keras.layers.Flatten(),
        # No activation on the head: it is folded into the loss
        # (from_logits=True) for numerical stability
        tf.keras.layers.Dense(
            output_concept_classes if is_multiclass else 1,
            activation=None,
        ),
    ]
    estimator = tf.keras.models.Sequential(layer_stack)

    if is_multiclass:
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    else:
        loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    # ADAM with Keras' default settings
    estimator.compile(optimizer='adam', loss=loss_fn)
    return estimator


def construct_cw_model(
    input_shape,
    num_outputs,
    filter_groups,
    units,
    activation_mode,
    drop_prob=0.0,
    max_pool_window=(2,2),
    max_pool_stride=2,
    T=5,
    eps=1e-5,
    momentum=0.9,
    c1=1e-4,
    c2=0.9,
    max_tau_iterations=500,
    initial_tau=1000.0,
    initial_beta=1e8,
    initial_alpha=0,
):
    """Build a CNN classifier with optional Concept Whitening (CW) layers.

    The network is a sequence of convolutional groups (each closed by a max
    pool), a flatten, an optional dropout, a stack of ReLU dense layers and a
    final linear layer named "logits". Every convolution is followed by either
    a `CW.ConceptWhiteningLayer` (when its spec requests one) or a batch
    normalisation, and then by a ReLU.

    Args:
        input_shape: Shape of one input sample, passed to `tf.keras.Input`.
        num_outputs: Number of units of the "logits" layer. With a single
            output the model is compiled for binary classification (binary
            cross-entropy, "binary_accuracy"); with more than one output it is
            compiled for sparse categorical classification (integer labels,
            top-1 accuracy).
        filter_groups: Iterable of groups, each an iterable of
            `(num_filters, kernel_size)` or
            `(num_filters, kernel_size, use_cw_layer)` tuples. Two-element
            specs default to `use_cw_layer=False`.
        units: Iterable with the width of each dense layer placed between the
            flatten and the logits layer.
        activation_mode: Forwarded unchanged to `CW.ConceptWhiteningLayer`.
        drop_prob: Dropout rate applied after the flatten; a falsy value
            disables the dropout layer altogether.
        max_pool_window: Pool size of the max pool closing each group.
        max_pool_stride: Stride of the max pool closing each group.
        T, eps, momentum, c1, c2, max_tau_iterations, initial_tau,
        initial_beta, initial_alpha: Forwarded unchanged to every
            `CW.ConceptWhiteningLayer`.

    Returns:
        A `(cw_model, encoder, cw_output_model)` tuple sharing the same input:
        `cw_model` is the compiled end-to-end classifier, `encoder` outputs
        the list of tensors fed into each CW layer, and `cw_output_model`
        outputs the list of tensors produced by each CW layer.
    """
    model_inputs = tf.keras.Input(shape=input_shape)
    model_compute_graph = model_inputs

    # NOTE: layers are created in a fixed order on purpose; code elsewhere in
    # this notebook looks layers up by their position in `model.layers`.
    num_convs = 0
    cw_inputs = []
    cw_outputs = []
    for filter_group in filter_groups:
        # Specs that do not say whether they want a CW layer get "no CW layer"
        filter_group = [
            spec if len(spec) == 3 else (spec[0], spec[1], False)
            for spec in filter_group
        ]
        for (num_filters, kernel_size, cw_layer) in filter_group:
            model_compute_graph = tf.keras.layers.Conv2D(
                filters=num_filters,
                kernel_size=kernel_size,
                padding="SAME",
                activation=None,
                name=f'encoder_conv_{num_convs}',
            )(model_compute_graph)
            num_convs += 1
            if cw_layer:
                # Keep handles on both sides of the CW layer so they can be
                # exposed through the auxiliary models returned below
                cw_inputs.append(model_compute_graph)
                model_compute_graph = CW.ConceptWhiteningLayer(
                    data_format="channels_last",
                    activation_mode=activation_mode,
                    T=T,
                    eps=eps,
                    momentum=momentum,
                    c1=c1,
                    c2=c2,
                    max_tau_iterations=max_tau_iterations,
                    initial_tau=initial_tau,
                    initial_beta=initial_beta,
                    initial_alpha=initial_alpha,
                )(
                    model_compute_graph
                )
                cw_outputs.append(model_compute_graph)
            else:
                model_compute_graph = tf.keras.layers.BatchNormalization(
                    axis=-1,
                )(model_compute_graph)
            model_compute_graph = tf.keras.activations.relu(model_compute_graph)
        # Max pool at the end of each group to keep the parameter count of
        # the model under control
        model_compute_graph = tf.keras.layers.MaxPooling2D(
            pool_size=max_pool_window,
            strides=max_pool_stride,
        )(
            model_compute_graph
        )

    model_compute_graph = tf.keras.layers.Flatten()(model_compute_graph)

    # Add a dropout if requested
    if drop_prob:
        model_compute_graph = tf.keras.layers.Dropout(drop_prob)(
            model_compute_graph
        )

    # Fully connected stack; the loop variable deliberately does not reuse the
    # name of the `units` parameter it iterates over
    for i, layer_units in enumerate(units):
        model_compute_graph = tf.keras.layers.Dense(
            layer_units,
            activation='relu',
            name=f"encoder_dense_{i}",
        )(model_compute_graph)

    # Linear head: activations are folded into the loss (from_logits=True)
    model_compute_graph = tf.keras.layers.Dense(
        num_outputs,
        activation=None,
        name="logits",
    )(model_compute_graph)

    cw_model = tf.keras.Model(
        model_inputs,
        model_compute_graph,
        name="cw_model",
    )

    # Anything with more than one logit is a categorical problem. This matches
    # how the CW experiment loop in this notebook reads the metrics back
    # ('sparse_top_k_categorical_accuracy' whenever num_outputs > 1), so a
    # two-logit model no longer ends up compiled with a binary loss/metric.
    is_categorical = num_outputs > 1
    if is_categorical:
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        accuracy_metric = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1)
    else:
        loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)
        accuracy_metric = "binary_accuracy"
    cw_model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-3),
        loss=loss_fn,
        metrics=[accuracy_metric],
    )

    encoder = tf.keras.Model(
        model_inputs,
        cw_inputs,
        name="encoder_model",
    )

    cw_output_model = tf.keras.Model(
        model_inputs,
        cw_outputs,
        name="cw_output_model",
    )
    return cw_model, encoder, cw_output_model

In [None]:
import concepts_xai.evaluation.metrics.leakage as leakage

def channels_corr_mat(outputs):
    """Compute the Pearson correlation matrix between activation channels.

    Every position of every sample is treated as one observation of the
    channel vector, i.e. an (N, H, W, C) input contributes N*H*W observations
    of C variables.

    Args:
        outputs: Array-like of shape (N, C) or (N, H, W, C) with channels on
            the last axis. The input is never modified.

    Returns:
        A (C, C) array with the correlation between every pair of channels.
        Channels with zero variance produce NaN entries.

    Raises:
        ValueError: If `outputs` is neither 2- nor 4-dimensional.
    """
    outputs = np.asarray(outputs)
    if outputs.ndim == 2:
        # Treat (N, C) as (N, 1, 1, C)
        outputs = outputs[:, np.newaxis, np.newaxis, :]
    if outputs.ndim != 4:
        raise ValueError(
            f'Expected a 2D or 4D array of activations but got an array '
            f'with shape {outputs.shape}.'
        )

    # Change (N, H, W, C) to (NxHxW, C): one row per observation
    samples = np.reshape(outputs, (-1, outputs.shape[-1]))

    # Standardise out-of-place: `samples` may be a view of the caller's array,
    # so in-place arithmetic would silently overwrite the caller's data (and
    # would fail outright for integer inputs).
    centered = samples - np.mean(samples, axis=0, keepdims=True)
    standardized = centered / np.std(centered, axis=0, keepdims=True)
    return np.dot(standardized.transpose(), standardized) / standardized.shape[0]

def conv_predictor_model_fn(
    input_concept_classes=1,
    output_concept_classes=2,
):
    """Build and compile a small convolutional predictor that outputs logits.

    The model is two 3x3 "SAME"-padded ReLU convolutions (16 then 32 filters,
    channels-last), a flatten, and a linear dense head.

    NOTE: this re-defines a function of the same name (with the same body)
    that is already defined in an earlier cell of this notebook.

    Args:
        input_concept_classes: Not used by this implementation.
        output_concept_classes: Number of classes of the target. More than two
            classes yields one logit per class trained with sparse categorical
            cross-entropy (integer labels, not one-hot); otherwise a single
            logit trained with binary cross-entropy.

    Returns:
        The compiled `tf.keras.models.Sequential` estimator.
    """
    is_multiclass = output_concept_classes > 2

    # Both convolutions share everything but their number of filters
    shared_conv_kwargs = dict(
        kernel_size=(3,3),
        padding="SAME",
        activation="relu",
        data_format="channels_last",
    )
    layer_stack = [
        tf.keras.layers.Conv2D(filters=16, **shared_conv_kwargs),
        tf.keras.layers.Conv2D(filters=32, **shared_conv_kwargs),
        tf.keras.layers.Flatten(),
        # No activation on the head: it is folded into the loss
        # (from_logits=True) for numerical stability
        tf.keras.layers.Dense(
            output_concept_classes if is_multiclass else 1,
            activation=None,
        ),
    ]
    estimator = tf.keras.models.Sequential(layer_stack)

    if is_multiclass:
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    else:
        loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    # ADAM with Keras' default settings
    estimator.compile(optimizer='adam', loss=loss_fn)
    return estimator


def cw_experiment_loop(
    experiment_config,
    datasets,
    load_from_cache=False,
    oracle_matrix_cache=None,
):
    utils.reseed(87)
    oracle_matrix_cache = oracle_matrix_cache or {}
    experiment_variables = dict(
        task_accuracies=[],
        task_aucs=[],
        concept_aucs=[],
        purity_scores=[],
        repr_purity_scores=[],
        non_oracle_purity_scores=[],
        purity_matrices=[],
        repr_purity_matrices=[],
        oracle_matrices=[],
        similarity_ratio_matrices=[],
        correlation_matrices=[],
    )
    experiment_config["data_concepts"] = experiment_config.get(
        "data_concepts",
        experiment_config["num_concepts"],
    )
    
    start_ind = 0
    if load_from_cache:
        cached = True
        complete_cash = True
        for var_name in experiment_variables:
            file_name = os.path.join(
                experiment_config["results_dir"],
                var_name
            )
            if (
                (not os.path.exists(f'{file_name}_means.npz')) or
                (not os.path.exists(f'{file_name}_stds.npz'))
            ):
                cached = False
                break
            loaded_means = np.load(f'{file_name}_means.npz')
            if len(loaded_means) != len(datasets):
                # Then we have a partial run here so let's just run the rest
                if start_ind and start_ind != len(loaded_means):
                    raise ValueError(
                        f'Found inconsistent start indices in cached '
                        f'data {start_ind} vs {len(loaded_means)}.'
                    )
                start_ind = len(loaded_means)
                complete_cash = False
        if cached:
            # Then we have found all of the arrays of interest, so let's
            # load them and use them
            print("Experiment cache was hit")
            print("\tLoading variables from", experiment_config["results_dir"])
            for var_name in experiment_variables:
                file_name = os.path.join(
                    experiment_config["results_dir"],
                    var_name
                )
                experiment_variables[var_name] = list(zip(
                    list(map(
                        lambda x: np.load(f'{file_name}_means.npz')[x],
                        np.load(f'{file_name}_means.npz')
                    )),
                    list(map(
                        lambda x: np.load(f'{file_name}_stds.npz')[x],
                        np.load(f'{file_name}_stds.npz')
                    )),
                ))
            if complete_cash:
                # Then we are good to go
                if not os.path.exists(
                    os.path.join(experiment_config["results_dir"], "config.yaml")
                ):
                    # then serialize the config as this is a different version run
                    serialize_experiment_config(
                        experiment_config,
                        experiment_config["results_dir"],
                    )
                return experiment_variables
    
    # Let's save our config here either way
    serialize_experiment_config(
        experiment_config,
        experiment_config["results_dir"],
    )
    
    Path(
        os.path.join(
            experiment_config["results_dir"],
            "models",
        )
    ).mkdir(parents=True, exist_ok=True)
    verbosity = experiment_config.get("verbosity", 0)
    for (ds_name, (x_train, y_train, c_train), (x_test, y_test, c_test)) in datasets[start_ind:]:
        print("Training with dataset:", ds_name)
        task_accs = []
        c_aucs = []
        aucs = []
        purity_mats = []
        repr_purity_mats = []
        oracle_mats = []
        purities = []
        repr_purities = []
        non_oracle_purities = []
        similarities = []
        correlations = []
        
        if not experiment_config.get("exclusive_concepts", False):
            concept_groups = [
                x_train[c_train[:, i] == 1, :, :, :]
                for i in range(c_train.shape[-1])
            ]
        else:
            concept_groups = [
                x_train[np.logical_and(c_train[:, i] == 1, np.sum(c_train, axis=-1) == 1), :, :, :]
                for i in range(c_train.shape[-1])
            ]
        for trial in range(experiment_config["trials"]):
            print(f"\tTrial {trial + 1}/{experiment_config['trials']} of {ds_name}")
            if not experiment_config.get("exclusive_test_concepts", False):
                test_concept_groups = [
                    x_train[c_train[:, i] == 1, :, :, :]
                    for i in range(c_test.shape[-1])
                ]
            else:
                test_concept_groups = [
                    x_test[np.logical_and(c_test[:, i] == 1, np.sum(c_test, axis=-1) == 1), :, :, :]
                    for i in range(c_test.shape[-1])
                ]
            print("Sizes of test_concept_groups:", list(map(lambda x: x.shape, test_concept_groups)))
            for i, group in enumerate(test_concept_groups):
                max_test_size = experiment_config.get("max_concept_group_size", group.shape[-1])
                test_concept_groups[i] = (
                    group[np.random.choice(group.shape[0], max_test_size), : :, :]
                    if group.shape[0] > max_test_size else group
                )


            # Construct our CW model
            model, encoder, cw_model = construct_cw_model(
                input_shape=experiment_config["input_shape"],
                num_outputs=experiment_config["num_outputs"],
                filter_groups=experiment_config["filter_groups"],
                units=experiment_config["units"],
                drop_prob=experiment_config.get("drop_prob", 0),
                max_pool_window=experiment_config.get("max_pool_window", (2, 2)),
                max_pool_stride=experiment_config.get("max_pool_stride", (2, 2)),
                T=experiment_config.get("T", 5),
                eps=experiment_config.get("eps", 1e-5),
                momentum=experiment_config.get("momentum", 0.9),
                activation_mode=experiment_config["activation_mode"],
                c1=experiment_config.get("c1", 1e-4),
                c2=experiment_config.get("c2", 0.9),
                max_tau_iterations=experiment_config.get("max_tau_iterations", 500),
                initial_tau=experiment_config.get("initial_tau", 1000),
                initial_beta=experiment_config.get("initial_beta", 1e8),
                initial_alpha=experiment_config.get("initial_alpha", 0),
            )
            # First do some pretraining for warming up the estimates if needed
            if experiment_config.get("pre_train_epochs"):
                print("\tModel pre-training...")
                model.fit(
                    x=x_train,
                    y=y_train,
                    epochs=experiment_config["pre_train_epochs"],
                    batch_size=experiment_config["batch_size"],
                    validation_split=experiment_config["holdout_fraction"],
                    verbose=verbosity,
                )
                if experiment_config.get("heatmap_display"):
                    fig, ax = plt.subplots(1, figsize=(8, 6))
                    similarity_ratio = oracle.concept_similarity_matrix(
                        concept_representations=list(map(
                            lambda x: np.mean(cw_model(x).numpy(), axis=(1, 2)),
                            concept_groups
                        )),
                        compute_ratios=True,
                    )
                    im, cbar = utils.heatmap(
                        similarity_ratio,
                        [f"$c_{i}$" for i in range(len(concept_groups))],
                        [f"$c_{i}$" for i in range(len(concept_groups))],
                        ax=ax,
                        cmap="magma",
                        cbarlabel=f"Similarity Ratio",
                        vmin=0,
                        vmax=1,
                    )
                    texts = utils.annotate_heatmap(im, valfmt="{x:.2f}")
                    fig.tight_layout()

                    fig.suptitle(f"Baseline Concept Axis Separability", fontsize=25)
                    fig.subplots_adjust(top=0.85)
                    plt.show()
                    
                    fig, ax = plt.subplots(1, figsize=(8, 6))
                    corr_mat = channels_corr_mat(cw_model(x_test).numpy())[:len(test_concept_groups), :len(test_concept_groups)]
                    im, cbar = utils.heatmap(
                        np.abs(corr_mat),
                        [f"{_LATEX_SYMBOL}f_{i}{_LATEX_SYMBOL}" for i in range(corr_mat.shape[-1])],
                        [f"{_LATEX_SYMBOL}f_{i}{_LATEX_SYMBOL}" for i in range(corr_mat.shape[-1])],
                        ax=ax,
                        cmap="magma",
                        cbarlabel=f"Mean Correlation Coef",
                        vmin=0,
                        vmax=1,
                    )
                    texts = utils.annotate_heatmap(im, valfmt="{x:.2f}")
                    fig.tight_layout()

                    fig.suptitle(f"Latent Dimension Correlation", fontsize=25)
                    fig.subplots_adjust(top=0.85)
                    plt.show()
                print("\t\tModel pre-training completed")
            
            # Set up the dataset in a nice usable form for unrolling the training
            # loop
            main_dataset_loader = tf.data.Dataset.from_tensor_slices((x_train, y_train))
            main_dataset_loader = main_dataset_loader.shuffle(buffer_size=1000).batch(
                experiment_config["batch_size"]
            )
            min_size = min(list(map(lambda x: x.shape[0], concept_groups)))
            print("Minimum size is", min_size, "given concept datasets", list(map(lambda x: x.shape[0], concept_groups)))
            loader_concept_groups = list(map(lambda x: x[:min_size, :, :, :], concept_groups))
            concept_group_loader = tf.data.Dataset.from_tensor_slices(tuple(loader_concept_groups))
            concept_group_loader = concept_group_loader.shuffle(buffer_size=1000).batch(
                experiment_config["batch_size"]
            )

            @tf.function
            def _train_step(model, x_batch_train, y_batch_train):
                # Update the other model parameters
                with tf.GradientTape() as tape:
                    logits = model(x_batch_train, training=True)
                    loss_value = model.loss(y_batch_train, logits)

                grads = tape.gradient(loss_value, model.trainable_weights)
                model.optimizer.apply_gradients(zip(grads, model.trainable_weights))
                return loss_value

            total_steps = 0
            for epoch in range(experiment_config["max_epochs"]):
                num_batches = len(main_dataset_loader)
                for current_step, (x_batch_train, y_batch_train) in enumerate(main_dataset_loader):
                    print(
                        f'Epoch {epoch + 1} and step {current_step}/{num_batches}         ',
                        end="\r",
                    )
                    # Need to update the rotation matrix
                    if (total_steps + 1) % experiment_config["cw_train_freq"] == 0:
                        for _ in range(experiment_config.get("cw_train_iterations", 1)):
                            cw_batch_steps = 0
                            for concept_groups_batch in concept_group_loader:
                                if cw_batch_steps > experiment_config.get("cw_train_batch_steps", float("inf")):
                                    break
                                model.layers[experiment_config["cw_layer"]].update_rotation_matrix(
                                    concept_groups=list(map(lambda x: encoder(x), concept_groups_batch)),
                                )
                                cw_batch_steps += 1
                    if experiment_config.get("concept_auc_freq"):
                        if (total_steps % experiment_config["concept_auc_freq"]) == 0:
                            concept_aucs = leakage.compute_concept_aucs(
                                cw_model=model,
                                encoder=encoder,
                                cw_layer=experiment_config["cw_layer"],
                                x_test=x_test,
                                c_test=c_test,
                                num_concepts=experiment_config["num_concepts"],
                                aggregator=experiment_config['aggregator'],
                            )
                            print(
                                f'Concept AUC at step {total_steps}:',
                                concept_aucs
                            )
                            print("Similarity ratios...")
                            print(oracle.concept_similarity_matrix(
                                concept_representations=list(map(
                                    lambda x: np.mean(cw_model(x).numpy(), axis=(1, 2)),
                                    test_concept_groups,
                                )),
                                compute_ratios=True,
                            ))
                            
                            print("Correlation matrix...")
                            print(np.abs(channels_corr_mat(cw_model(x_test).numpy())[:len(test_concept_groups), :len(test_concept_groups)]))
                    _train_step(model, x_batch_train, y_batch_train)
                    total_steps += 1
            
            print("\tBegining post-training of CW module")
            for epoch in range(experiment_config["post_train_epochs"]):
                # Need to update the rotation matrix
                num_batches = len(concept_group_loader)
                for i, concept_groups_batch in enumerate(concept_group_loader):
                    print(
                        f'Epoch {epoch + 1} and step {i}/{num_batches}         ',
                        end="\r",
                    )
                    model.layers[experiment_config["cw_layer"]].update_rotation_matrix(
                        concept_groups=list(map(lambda x: encoder(x), concept_groups_batch)),
                    )
            
            print("\t\tCW training completed")
            print("\tSerializing model")
            model.save(
                os.path.join(
                    experiment_config["results_dir"],
                    f"models/end_to_end_model_{ds_name}_trial_{trial}"
                )
            )
            print("\tEvaluating model")
            test_result = model.evaluate(
                x_test,
                y_test,
                verbose=0,
                return_dict=True,
            )
            task_accs.append(test_result[
                'sparse_top_k_categorical_accuracy' if experiment_config["num_outputs"] > 1
                else 'binary_accuracy'
            ])
            c_aucs.append(leakage.compute_concept_aucs(
                cw_model=model,
                encoder=encoder,
                cw_layer=experiment_config["cw_layer"],
                x_test=x_test,
                c_test=c_test,
                num_concepts=experiment_config["num_concepts"],
                aggregator=experiment_config['aggregator'],
            ))
            
            if experiment_config["num_outputs"] > 1:
                # Then lets apply a softmax activation over all the probability
                # classes
                preds = scipy.special.softmax(model.predict(x_test), axis=-1)

                # And select just the labels that are in fact being used
                one_hot_labels = tf.keras.utils.to_categorical(y_test)
                aucs.append(sklearn.metrics.roc_auc_score(
                    one_hot_labels,
                    preds,
                    multi_class='ovo',
                ))
            else:
                aucs.append(sklearn.metrics.roc_auc_score(
                    y_test,
                    model.predict(x_test),
                ))
            
            print(
                f"\t\tTest auc = {aucs[-1]:.4f}, "
                f"test concept AUCs = {c_aucs[-1]}, "
                f"task accuracy = {task_accs[-1]:.4f}"
            )
            
            
            print("\t\tComputing purity score...")
            concept_scores = cw_model.layers[-1].concept_scores(
                encoder(x_test),
                aggregator=experiment_config['aggregator'],
            ).numpy()[:, list(range(experiment_config["num_concepts"]))]
            purity_score, purity_mat, oracle_mat = oracle.oracle_impurity_score(
                c_soft=concept_scores,
                c_true=c_test,
                output_matrices=True,
                oracle_matrix=oracle_matrix_cache.get(ds_name),
            )
            purity_mats.append(purity_mat)
            oracle_mats.append(oracle_mat)
            purities.append(purity_score)
            
            print(f"\t\t\tDone {purity_score:.4f}")
            
            if len(experiment_config["input_shape"]) > 2:
                print("\t\tComputing full representation purity score...")
                repr_purity_score, repr_purity_mat, _ = oracle.oracle_impurity_score(
                    c_soft=cw_model(x_test).numpy()[:, :, :, :experiment_config["num_concepts"]],
                    c_true=c_test,
                    output_matrices=True,
                    oracle_matrix=oracle_mat,
                    predictor_model_fn=conv_predictor_model_fn,
                )
                repr_purity_mats.append(repr_purity_mat)
                repr_purities.append(repr_purity_score)
                print(f"\t\t\tDone {repr_purity_score:.4f}")

            print("\t\tComputing non-oracle purity score...")
            non_oracle_purities.append(oracle.oracle_impurity_score(
                c_soft=concept_scores,
                c_true=c_test,
                oracle_matrix=construct_trivial_auc_mat(
                    experiment_config["data_concepts"]
                ),
                purity_matrix=purity_mat,
            ))
            print(f"\t\t\tDone {non_oracle_purities[-1]:.4f}")
            
            # Compute similarity matrices
            print("\t\tComputing similarity ratios...")
            similarity_ratio = oracle.concept_similarity_matrix(
                concept_representations=list(map(
                    lambda x: np.mean(cw_model(x).numpy(), axis=(1, 2)),
                    test_concept_groups,
                )),
                compute_ratios=True,
            )
            if experiment_config.get("heatmap_display"):
                fig, ax = plt.subplots(1, figsize=(8, 6))
                im, cbar = utils.heatmap(
                    similarity_ratio,
                    [f"{_LATEX_SYMBOL}c_{i}{_LATEX_SYMBOL}" for i in range(len(concept_groups))],
                    [f"{_LATEX_SYMBOL}c_{i}{_LATEX_SYMBOL}" for i in range(len(concept_groups))],
                    ax=ax,
                    cmap="magma",
                    cbarlabel=f"Similarity Ratio",
                    vmin=0,
                    vmax=1,
                )
                texts = utils.annotate_heatmap(im, valfmt="{x:.2f}")
                fig.tight_layout()

                fig.suptitle(f"Concept Axis Separability", fontsize=25)
                fig.subplots_adjust(top=0.85)
                plt.show()
            similarities.append(similarity_ratio)

            
            # Compute correlation matrices
            print("\t\tComputing correlation matrix...")
            corr_mat = channels_corr_mat(cw_model(x_test).numpy())
            correlations.append(corr_mat)
            
            if experiment_config.get("heatmap_display"):
                corr_mat = corr_mat[:len(test_concept_groups), :len(test_concept_groups)]
                fig, ax = plt.subplots(1, figsize=(8, 6))
                im, cbar = utils.heatmap(
                    np.abs(corr_mat),
                    [f"{_LATEX_SYMBOL}f_{i}{_LATEX_SYMBOL}" for i in range(corr_mat.shape[-1])],
                    [f"{_LATEX_SYMBOL}f_{i}{_LATEX_SYMBOL}" for i in range(corr_mat.shape[-1])],
                    ax=ax,
                    cmap="magma",
                    cbarlabel=f"Correlation Coef",
                    vmin=0,
                    vmax=1,
                )
                texts = utils.annotate_heatmap(im, valfmt="{x:.2f}")
                fig.tight_layout()

                fig.suptitle(f"Latent Dimension Correlation", fontsize=25)
                fig.subplots_adjust(top=0.85)
                plt.show()
                
            # Compute representation purity score
            
            print("\t\tDone with trial", trial + 1)

        task_acc_mean, task_acc_std = np.mean(task_accs), np.std(task_accs)
        experiment_variables["task_accuracies"].append((task_acc_mean, task_acc_std))
        print(f"\tTest task accuracy: {task_acc_mean:.4f} ± {task_acc_std:.4f}")

        concept_aucs_mean = np.mean(np.stack(c_aucs, axis=0), axis=0)
        concept_aucs_std = np.std(np.stack(c_aucs, axis=0), axis=0)
        experiment_variables["concept_aucs"].append((concept_aucs_mean, concept_aucs_std))
        print(f"\tConcept AUCS:")
        line = "\t\t"
        for i in range(concept_aucs_mean.shape[0]):
            line += f'{concept_aucs_mean[i]:.4f} ± {concept_aucs_std[i]:.4f}    '
        print(line)


        task_auc_mean, task_auc_std = np.mean(aucs), np.std(aucs)
        experiment_variables["task_aucs"].append((task_auc_mean, task_auc_std))
        print(f"\tTest task AUC: {task_auc_mean:.4f} ± {task_auc_std:.4f}")

        purity_mats = np.stack(purity_mats, axis=0)
        purity_mat_mean = np.mean(purity_mats, axis=0)
        purity_mat_std = np.std(purity_mats, axis=0)
        print("\tPurity matrix:")
        for i in range(purity_mat_mean.shape[0]):
            line = "\t\t"
            for j in range(purity_mat_mean.shape[1]):
                line += f'{purity_mat_mean[i, j]:.4f} ± {purity_mat_std[i, j]:.4f}    '
            print(line)
        experiment_variables["purity_matrices"].append((purity_mat_mean, purity_mat_std))
        
        repr_purity_mats = np.stack(repr_purity_mats, axis=0)
        repr_purity_mat_mean = np.mean(repr_purity_mats, axis=0)
        repr_purity_mat_std = np.std(repr_purity_mats, axis=0)
        print("\tRepresentation Purity matrix:")
        for i in range(repr_purity_mat_mean.shape[0]):
            line = "\t\t"
            for j in range(repr_purity_mat_mean.shape[1]):
                line += f'{repr_purity_mat_mean[i, j]:.4f} ± {repr_purity_mat_std[i, j]:.4f}    '
            print(line)
        experiment_variables["repr_purity_matrices"].append((repr_purity_mat_mean, repr_purity_mat_std))
        
        similarities = np.stack(similarities, axis=0)
        similarities_mean = np.mean(similarities, axis=0)
        similarities_std = np.std(similarities, axis=0)
        print("\tSimilarity ratio matrix:")
        for i in range(similarities_mean.shape[0]):
            line = "\t\t"
            for j in range(similarities_mean.shape[1]):
                line += f'{similarities_mean[i, j]:.4f} ± {similarities_std[i, j]:.4f}    '
            print(line)
        experiment_variables["similarity_ratio_matrices"].append(
            (similarities_mean, similarities_std)
        )
        
        if experiment_config.get("heatmap_display"):
            fig, ax = plt.subplots(1, figsize=(8, 6))
            im, cbar = utils.heatmap(
                similarities_mean,
                [f"{_LATEX_SYMBOL}c_{i}{_LATEX_SYMBOL}" for i in range(len(concept_groups))],
                [f"{_LATEX_SYMBOL}c_{i}{_LATEX_SYMBOL}" for i in range(len(concept_groups))],
                ax=ax,
                cmap="magma",
                cbarlabel=f"Similarity Ratio",
                vmin=0,
                vmax=1,
            )
            texts = utils.annotate_heatmap(im, valfmt="{x:.2f}")
            fig.tight_layout()

            fig.suptitle(f"Mean Concept Axis Separability", fontsize=25)
            fig.subplots_adjust(top=0.85)
            plt.show()
        
        
        correlations = np.stack(correlations, axis=0)
        correlations_mean = np.mean(correlations, axis=0)
        correlations_std = np.std(correlations, axis=0)
        print("\tCorrelation ratio matrix:")
        for i in range(len(concept_groups)):
            line = "\t\t"
            for j in range(len(concept_groups)):
                line += f'{correlations_mean[i, j]:.4f} ± {correlations_std[i, j]:.4f}    '
            print(line)
        experiment_variables["correlation_matrices"].append(
            (correlations_mean, correlations_std)
        )
        if experiment_config.get("heatmap_display"):
            fig, ax = plt.subplots(1, figsize=(8, 6))
            im, cbar = utils.heatmap(
                np.abs(correlations_mean[:len(concept_groups), :len(concept_groups)]),
                [f"{_LATEX_SYMBOL}f_{i}{_LATEX_SYMBOL}" for i in range(len(concept_groups))],
                [f"{_LATEX_SYMBOL}f_{i}{_LATEX_SYMBOL}" for i in range(len(concept_groups))],
                ax=ax,
                cmap="magma",
                cbarlabel=f"Mean Correlation Coef",
                vmin=0,
                vmax=1,
            )
            texts = utils.annotate_heatmap(im, valfmt="{x:.2f}")
            fig.tight_layout()

            fig.suptitle(f"Latent Dimension Correlation", fontsize=25)
            fig.subplots_adjust(top=0.85)
            plt.show()


        oracle_mats = np.stack(oracle_mats, axis=0)
        oracle_mat_mean = np.mean(oracle_mats, axis=0)
        oracle_mat_std = np.std(oracle_mats, axis=0)
        print("\tOracle matrix:")
        for i in range(oracle_mat_mean.shape[0]):
            line = "\t\t"
            for j in range(oracle_mat_mean.shape[1]):
                line += f'{oracle_mat_mean[i, j]:.4f} ± {oracle_mat_std[i, j]:.4f}    '
            print(line)

        experiment_variables["oracle_matrices"].append((oracle_mat_mean, oracle_mat_std))

        purity_mean, purity_std = np.mean(purities), np.std(purities)
        experiment_variables["purity_scores"].append((purity_mean, purity_std))
        print(f"\tPurity score: {purity_mean:.4f} ± {purity_std:.4f}")
        
        repr_purity_mean, repr_purity_std = np.mean(repr_purities), np.std(repr_purities)
        experiment_variables["repr_purity_scores"].append((repr_purity_mean, repr_purity_std))
        print(f"\tRepresentation Purity score: {repr_purity_mean:.4f} ± {repr_purity_std:.4f}")

        non_oracle_purity_mean, non_oracle_purity_std = np.mean(non_oracle_purities), np.std(non_oracle_purities)
        experiment_variables["non_oracle_purity_scores"].append((non_oracle_purity_mean, non_oracle_purity_std))
        print(f"\tNon-oracle purity score: {non_oracle_purity_mean:.4f} ± {non_oracle_purity_std:.4f}")

        # And serialize the results
        utils.serialize_results(
            results_dict=experiment_variables,
            results_dir=experiment_config["results_dir"],
        )
    return experiment_variables

# HACK: deserialization messes up custom methods, so this one is re-implemented here as a free function
def concept_scores(
    self,
    inputs,
    aggregator='max_pool_mean',
    concept_indices=None,
    data_format="channels_last",
):
    """Reduce the output of a layer to one score per channel.

    The layer is passed explicitly as `self` so this free function can be
    used in place of a bound method: it is invoked as
    `self(inputs, training=False)`.

    Args:
        self: callable layer/model accepting `(inputs, training=False)`.
        inputs: input forwarded unchanged to `self`.
        aggregator: how rank-4 outputs are collapsed over their two spatial
            axes. One of 'mean', 'max_pool_mean' (max pool with a window of
            at most 2, then mean) or 'max'. Ignored when the layer output is
            already rank 2.
        concept_indices: optional selection applied to the channel axis of
            the scores. May be an int, a slice, or a sequence/array/tensor
            of indices. `None` returns every channel.
        data_format: "channels_last" if the layer output is laid out as
            NHWC; any other value is treated as NCHW.

    Returns:
        The scores, indexed as (sample, channel), optionally restricted to
        `concept_indices`.

    Raises:
        ValueError: if `aggregator` is not one of the supported names and
            the layer output is not rank 2.
    """
    outputs = self(inputs, training=False)
    # Use the static rank so the check also works on symbolic tensors, where
    # `len(tf.shape(...))` is not defined.
    if len(outputs.shape) == 2:
        # Then the scores are already computed by our forward pass
        scores = outputs
    else:
        if data_format != "channels_last":
            # Normalise to channels-last so every branch below can assume
            # NHWC. NHWC is the layout `tf.nn.max_pool` supports on all
            # devices, whereas NCHW pooling is commonly unavailable on CPU.
            # NCHW -> NHWC
            outputs = tf.transpose(
                outputs,
                perm=[0, 2, 3, 1],
            )

        spatial_axes = [1, 2]
        if aggregator == 'mean':
            # Average every channel over its spatial positions
            scores = tf.math.reduce_mean(outputs, axis=spatial_axes)
        elif aggregator == 'max_pool_mean':
            # First downsample using a max pool and then continue with
            # a mean. The window never exceeds the spatial extent.
            window_size = min(
                2,
                outputs.shape[1],
                outputs.shape[2],
            )
            pooled = tf.nn.max_pool(
                outputs,
                ksize=window_size,
                strides=window_size,
                padding="SAME",
                data_format="NHWC",
            )
            scores = tf.math.reduce_mean(pooled, axis=spatial_axes)
        elif aggregator == 'max':
            # Simply select the maximum value across a given channel
            scores = tf.math.reduce_max(outputs, axis=spatial_axes)
        else:
            raise ValueError(f'Unsupported aggregator {aggregator}.')

    if concept_indices is None:
        return scores
    if isinstance(concept_indices, (int, slice)):
        # Plain integers and slices are valid tensor subscripts
        return scores[:, concept_indices]
    # Sequences of indices are not valid subscripts for a tf.Tensor, so
    # gather the requested channels explicitly instead.
    return tf.gather(scores, concept_indices, axis=1)

def _cw_predict_load_npz_arrays(npz_path):
    """Return every array stored in `npz_path`, in archive order, and close the file."""
    with np.load(npz_path) as archive:
        return [archive[key] for key in archive]


def _cw_predict_build_decoder(experiment_config):
    """Construct and compile a fresh task decoder as described by `experiment_config`."""
    # Up to two outputs are handled as a binary task
    is_binary = experiment_config["num_outputs"] <= 2
    decoder = construct_decoder(
        units=experiment_config["latent_decoder_units"],
        num_outputs=experiment_config["num_outputs"],
    )
    decoder.compile(
        optimizer=tf.keras.optimizers.Adam(experiment_config["learning_rate"]),
        loss=(
            tf.keras.losses.BinaryCrossentropy(from_logits=True) if is_binary
            else tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        ),
        metrics=[
            "binary_accuracy" if is_binary
            else tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1)
        ],
    )
    return decoder


def _cw_predict_evaluate_decoder(decoder, test_codes, y_test, num_outputs):
    """Return `(accuracy, auc)` of a fitted decoder on the given test codes."""
    test_result = decoder.evaluate(
        test_codes,
        y_test,
        verbose=0,
        return_dict=True,
    )
    # Must mirror the predicate used when compiling the decoder, otherwise
    # the metric key looked up here does not exist in `test_result`.
    if num_outputs > 2:
        accuracy = test_result['sparse_top_k_categorical_accuracy']
        # Then lets apply a softmax activation over all the probability
        # classes
        preds = scipy.special.softmax(decoder.predict(test_codes), axis=-1)
        auc = sklearn.metrics.roc_auc_score(
            y_test,
            preds,
            multi_class='ovo',
        )
    else:
        accuracy = test_result['binary_accuracy']
        auc = sklearn.metrics.roc_auc_score(
            y_test,
            decoder.predict(test_codes),
        )
    return accuracy, auc


def cw_bottleneck_predict_experiment_loop(
    experiment_config,
    datasets,
    load_from_cache=False,
):
    """Measure how well the task can be predicted from a trained CW bottleneck.

    For every dataset and trial, the previously serialized end-to-end model
    is loaded from `<results_dir>/models/end_to_end_model_<ds_name>_trial_<trial>`
    and two fresh decoders are trained to predict the task labels:

    * a "feature" decoder fed with the full output of the layer at index
      `experiment_config["cw_layer"]`, and
    * a "score" decoder fed with the first `num_concepts` concept scores
      computed by `concept_scores` from that layer.

    Per-dataset mean/std of test accuracy and AUC over the trials are
    appended to the returned dictionary, which is serialized through
    `utils.serialize_results` after every dataset.

    Args:
        experiment_config: dict-like configuration. Keys read here are
            "results_dir", "trials", "cw_layer", "latent_decoder_units",
            "num_outputs", "learning_rate", "predictor_max_epochs",
            "batch_size", "holdout_fraction", "aggregator", "num_concepts"
            and, optionally, "verbosity" (defaults to 0).
        datasets: sequence of
            `(ds_name, (x_train, y_train, c_train), (x_test, y_test, c_test))`.
        load_from_cache: if True, previously serialized results found in
            "results_dir" are loaded. A complete cache is returned directly;
            a partial one is resumed from the first dataset without results.

    Returns:
        Dict mapping each result name to a list with one `(mean, std)` pair
        per processed dataset.

    Raises:
        ValueError: if cached variables disagree on how many datasets have
            already been processed.
    """
    utils.reseed(87)
    experiment_variables = dict(
        latent_predictive_accuracies=[],
        latent_predictive_aucs=[],
        latent_feature_predictive_accuracies=[],
        latent_feature_predictive_aucs=[],
    )

    start_ind = 0
    if load_from_cache:
        cached = True
        complete_cache = True
        for var_name in experiment_variables:
            file_name = os.path.join(
                experiment_config["results_dir"],
                var_name
            )
            if (
                (not os.path.exists(f'{file_name}_means.npz')) or
                (not os.path.exists(f'{file_name}_stds.npz'))
            ):
                print("Could not find", f'"{file_name}_means.npz" or "{file_name}_stds.npz" in cache.')
                cached = False
                break
            # Only the number of stored entries is needed here; close the
            # archive right away instead of leaking the file handle.
            with np.load(f'{file_name}_means.npz') as loaded_means:
                num_cached = len(loaded_means)
            if num_cached != len(datasets):
                print("Found", num_cached, "means for variable", var_name, "vs", len(datasets), "datasets")
                # Then we have a partial run here so let's just run the rest
                if start_ind and start_ind != num_cached:
                    raise ValueError(
                        f'Found inconsistent start indices in cached '
                        f'data {start_ind} vs {num_cached} ({file_name}).'
                    )
                start_ind = num_cached
                complete_cache = False
        if cached:
            # Then we have found all of the arrays of interest, so let's
            # load them and use them
            print("Experiment cache was hit")
            print("\tLoading variables from", experiment_config["results_dir"])
            for var_name in experiment_variables:
                file_name = os.path.join(
                    experiment_config["results_dir"],
                    var_name
                )
                experiment_variables[var_name] = list(zip(
                    _cw_predict_load_npz_arrays(f'{file_name}_means.npz'),
                    _cw_predict_load_npz_arrays(f'{file_name}_stds.npz'),
                ))
            print(experiment_variables)
            if complete_cache:
                # Then we are good to go
                if not os.path.exists(
                    os.path.join(experiment_config["results_dir"], "config.yaml")
                ):
                    # then serialize the config as this is a different version run
                    serialize_experiment_config(
                        experiment_config,
                        experiment_config["results_dir"],
                    )
                return experiment_variables
        else:
            # Nothing was loaded, so a start index picked up from an earlier
            # variable must not make us skip datasets we hold no results for.
            start_ind = 0

    # Else, let's go ahead and run the whole thing
    verbosity = experiment_config.get("verbosity", 0)
    num_outputs = experiment_config["num_outputs"]
    Path(
        os.path.join(
            experiment_config["results_dir"],
            "models",
        )
    ).mkdir(parents=True, exist_ok=True)
    for (ds_name, (x_train, y_train, c_train), (x_test, y_test, c_test)) in datasets[start_ind:]:
        print("Training with dataset", ds_name)
        task_accs = []
        aucs = []
        feat_task_accs = []
        feat_aucs = []

        for trial in range(experiment_config["trials"]):
            print(f"\tTrial {trial + 1}/{experiment_config['trials']}")
            # NOTE: custom_objects is not passed, so the loaded CW layer may
            # lack its custom methods; the module-level `concept_scores`
            # function is applied to the layer below instead.
            complete_model = tf.keras.models.load_model(
                os.path.join(
                    experiment_config["results_dir"],
                    f"models/end_to_end_model_{ds_name}_trial_{trial}"
                ),
            )
            cw_output_model = tf.keras.Model(
                complete_model.inputs,
                [complete_model.layers[experiment_config["cw_layer"]].output],
                name="cw_output_model",
            )

            # ---- Decoder trained on the full CW layer output ----
            feature_predictive_decoder = _cw_predict_build_decoder(experiment_config)

            print("\tTraining full representation model")
            train_codes = cw_output_model(x_train).numpy()
            print("train_codes.shape =", train_codes.shape)
            print("y_train.shape =", y_train.shape)
            test_codes = cw_output_model(x_test).numpy()
            feature_predictive_decoder.fit(
                x=train_codes,
                y=y_train,
                epochs=experiment_config["predictor_max_epochs"],
                batch_size=experiment_config["batch_size"],
                validation_split=experiment_config["holdout_fraction"],
                verbose=verbosity,
            )

            print("\tEvaluating feature model")
            feat_acc, feat_auc = _cw_predict_evaluate_decoder(
                feature_predictive_decoder,
                test_codes,
                y_test,
                num_outputs,
            )
            feat_task_accs.append(feat_acc)
            feat_aucs.append(feat_auc)

            print(
                f"\t\tFeature Test auc = {feat_aucs[-1]:.4f}, "
                f"feature task accuracy = {feat_task_accs[-1]:.4f}"
            )

            # ---- Decoder trained on the aggregated concept scores ----
            # Everything up to (but excluding) the CW layer; its output is
            # what gets fed into the CW layer when computing the scores.
            encoder_model = tf.keras.Model(
                complete_model.inputs,
                [complete_model.layers[experiment_config["cw_layer"] - 1].output],
                name="cw_output_model",
            )

            score_predictive_decoder = _cw_predict_build_decoder(experiment_config)

            print("\tTraining score model")
            concept_columns = list(range(experiment_config["num_concepts"]))
            score_train_codes = concept_scores(
                cw_output_model.layers[-1],
                encoder_model(x_train),
                aggregator=experiment_config['aggregator'],
            ).numpy()[:, concept_columns]
            score_test_codes = concept_scores(
                cw_output_model.layers[-1],
                encoder_model(x_test),
                aggregator=experiment_config['aggregator'],
            ).numpy()[:, concept_columns]
            score_predictive_decoder.fit(
                x=score_train_codes,
                y=y_train,
                epochs=experiment_config["predictor_max_epochs"],
                batch_size=experiment_config["batch_size"],
                validation_split=experiment_config["holdout_fraction"],
                verbose=verbosity,
            )

            print("\tEvaluating score model")
            score_acc, score_auc = _cw_predict_evaluate_decoder(
                score_predictive_decoder,
                score_test_codes,
                y_test,
                num_outputs,
            )
            task_accs.append(score_acc)
            aucs.append(score_auc)

            print(
                f"\t\tScore Test auc = {aucs[-1]:.4f}, "
                f"score task accuracy = {task_accs[-1]:.4f}"
            )

            print("\t\tDone with trial", trial + 1)

        task_acc_mean, task_acc_std = np.mean(task_accs), np.std(task_accs)
        experiment_variables["latent_predictive_accuracies"].append((task_acc_mean, task_acc_std))
        print(f"\tTest task accuracy: {task_acc_mean:.4f} ± {task_acc_std:.4f}")

        task_auc_mean, task_auc_std = np.mean(aucs), np.std(aucs)
        experiment_variables["latent_predictive_aucs"].append((task_auc_mean, task_auc_std))
        print(f"\tTest task AUC: {task_auc_mean:.4f} ± {task_auc_std:.4f}")

        feat_task_acc_mean, feat_task_acc_std = np.mean(feat_task_accs), np.std(feat_task_accs)
        experiment_variables["latent_feature_predictive_accuracies"].append((feat_task_acc_mean, feat_task_acc_std))
        print(f"\tTest feature task accuracy: {feat_task_acc_mean:.4f} ± {feat_task_acc_std:.4f}")

        feat_task_auc_mean, feat_task_auc_std = np.mean(feat_aucs), np.std(feat_aucs)
        experiment_variables["latent_feature_predictive_aucs"].append((feat_task_auc_mean, feat_task_auc_std))
        print(f"\tTest feature task AUC: {feat_task_auc_mean:.4f} ± {feat_task_auc_std:.4f}")

        # And serialize the results
        utils.serialize_results(
            results_dict=experiment_variables,
            results_dir=experiment_config["results_dir"],
        )
    return experiment_variables

def _fit_and_evaluate_concept_probe(
    experiment_config,
    train_codes,
    train_targets,
    test_codes,
    test_targets,
    verbosity,
    eval_message,
):
    """Fits a fresh single-output probe on `train_codes` and scores it on the test split.

    The probe is built with `construct_decoder` (one output unit) and trained
    with a from-logits binary cross-entropy loss to predict `train_targets`.

    Args:
        experiment_config: dict providing "latent_decoder_units",
            "learning_rate", "predictor_max_epochs", "batch_size" and
            "holdout_fraction".
        train_codes: inputs used to fit the probe.
        train_targets: binary targets aligned with `train_codes`.
        test_codes: inputs used to evaluate the probe.
        test_targets: binary targets aligned with `test_codes`.
        verbosity: Keras verbosity level used while fitting.
        eval_message: message printed right before the evaluation starts.

    Returns:
        A `(binary_accuracy, roc_auc)` pair measured on the test split.
    """
    probe = construct_decoder(
        units=experiment_config["latent_decoder_units"],
        num_outputs=1,
    )
    # NOTE(review): the loss treats the probe outputs as logits, while the
    # "binary_accuracy" metric applies Keras' default threshold to the raw
    # outputs -- confirm this matches what `construct_decoder` emits.
    probe.compile(
        optimizer=tf.keras.optimizers.Adam(experiment_config["learning_rate"]),
        loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
        metrics=["binary_accuracy"],
    )
    probe.fit(
        x=train_codes,
        y=train_targets,
        epochs=experiment_config["predictor_max_epochs"],
        batch_size=experiment_config["batch_size"],
        validation_split=experiment_config["holdout_fraction"],
        verbose=verbosity,
    )

    print(eval_message)
    test_result = probe.evaluate(
        test_codes,
        test_targets,
        verbose=0,
        return_dict=True,
    )
    test_auc = sklearn.metrics.roc_auc_score(
        test_targets,
        probe.predict(test_codes),
    )
    return test_result['binary_accuracy'], test_auc


def cw_bottleneck_concept_predict_experiment_loop(
    experiment_config,
    datasets,
    load_from_cache=False,
):
    """Measures how well each concept can be predicted from CW representations.

    For every dataset and trial, the previously saved end-to-end model is
    loaded and two kinds of probes are trained per concept: one on the full
    output of the CW layer ("feature" probes) and one on the concept scores
    computed by `concept_scores` ("score" probes). Per-trial metrics are
    averaged over concepts, and the mean/std across trials is stored per
    dataset. Results are serialized after every dataset.

    Args:
        experiment_config: dict of experiment settings. Keys read here:
            "results_dir", "trials", "cw_layer", "num_concepts",
            "aggregator", optional "verbosity", plus the keys read by
            `_fit_and_evaluate_concept_probe`.
        datasets: list of
            `(name, (x_train, y_train, c_train), (x_test, y_test, c_test))`.
        load_from_cache: when True, previously serialized results found in
            `results_dir` are loaded; if they cover only a prefix of
            `datasets`, only the remaining datasets are run.

    Returns:
        dict mapping each metric name to a list with one `(mean, std)` pair
        per dataset.

    Raises:
        ValueError: if the cached variables cover different numbers of
            datasets.
    """
    utils.reseed(87)
    experiment_variables = dict(
        latent_avg_concept_predictive_accuracies=[],
        latent_avg_concept_predictive_aucs=[],
        latent_feature_avg_concept_predictive_accuracies=[],
        latent_feature_avg_concept_predictive_aucs=[],
    )

    start_ind = 0
    if load_from_cache:
        cached = True
        complete_cash = True
        for var_name in experiment_variables:
            file_name = os.path.join(
                experiment_config["results_dir"],
                var_name
            )
            if (
                (not os.path.exists(f'{file_name}_means.npz')) or
                (not os.path.exists(f'{file_name}_stds.npz'))
            ):
                print("Could not find", f'"{file_name}_means.npz" or "{file_name}_stds.npz" in cache.')
                cached = False
                break
            # Only the number of cached entries is needed here, so close the
            # archive right away instead of leaving the handle open.
            with np.load(f'{file_name}_means.npz') as loaded_means:
                num_cached = len(loaded_means)
            if num_cached != len(datasets):
                print("Found", num_cached, "means for variable", var_name, "vs", len(datasets), "datasets")
                # Then we have a partial run here so let's just run the rest
                if start_ind and start_ind != num_cached:
                    raise ValueError(
                        f'Found inconsistent start indices in cached '
                        f'data {start_ind} vs {num_cached} ({file_name}).'
                    )
                start_ind = num_cached
                complete_cash = False
        if not cached:
            # Nothing is going to be loaded, so every dataset has to be run
            # from scratch. Without this reset, a partial cache seen before the
            # missing file would make the loop below skip leading datasets
            # whose results are not in `experiment_variables`.
            start_ind = 0
        else:
            # Then we have found all of the arrays of interest, so let's
            # load them and use them
            print("Experiment cache was hit")
            print("\tLoading variables from", experiment_config["results_dir"])
            for var_name in experiment_variables:
                file_name = os.path.join(
                    experiment_config["results_dir"],
                    var_name
                )
                # Open each archive once (instead of once per key) and close it
                # as soon as its arrays have been read.
                with np.load(f'{file_name}_means.npz') as means_file:
                    cached_means = [means_file[key] for key in means_file]
                with np.load(f'{file_name}_stds.npz') as stds_file:
                    cached_stds = [stds_file[key] for key in stds_file]
                experiment_variables[var_name] = list(zip(cached_means, cached_stds))
            print(experiment_variables)
            if complete_cash:
                # Then we are good to go
                if not os.path.exists(
                    os.path.join(experiment_config["results_dir"], "config.yaml")
                ):
                    # then serialize the config as this is a different version run
                    serialize_experiment_config(
                        experiment_config,
                        experiment_config["results_dir"],
                    )
                return experiment_variables

    # Else, let's go ahead and run the whole thing
    verbosity = experiment_config.get("verbosity", 0)
    Path(
        os.path.join(
            experiment_config["results_dir"],
            "models",
        )
    ).mkdir(parents=True, exist_ok=True)
    num_concepts = experiment_config["num_concepts"]
    for (ds_name, (x_train, y_train, c_train), (x_test, y_test, c_test)) in datasets[start_ind:]:
        print("Training with dataset", ds_name)
        avg_concept_accs = []
        avg_concept_aucs = []
        feat_avg_concept_accs = []
        feat_avg_concept_aucs = []

        for trial in range(experiment_config["trials"]):
            print(f"\tTrial {trial + 1}/{experiment_config['trials']}")
            # NOTE: `custom_objects={"ConceptWhiteningLayer": CW.ConceptWhiteningLayer}`
            # used to be passed here and is intentionally left disabled.
            complete_model = tf.keras.models.load_model(
                os.path.join(
                    experiment_config["results_dir"],
                    f"models/end_to_end_model_{ds_name}_trial_{trial}"
                ),
            )
            # Sub-model ending at the CW layer and sub-model ending at the
            # layer right before it.
            cw_output_model = tf.keras.Model(
                complete_model.inputs,
                [complete_model.layers[experiment_config["cw_layer"]].output],
                name="cw_output_model",
            )
            encoder_model = tf.keras.Model(
                complete_model.inputs,
                [complete_model.layers[experiment_config["cw_layer"] - 1].output],
                name="cw_output_model",
            )

            # None of these inputs depend on the concept being probed, so they
            # are computed once per trial rather than once per concept.
            train_codes = cw_output_model(x_train).numpy()
            test_codes = cw_output_model(x_test).numpy()
            concept_columns = list(range(num_concepts))
            score_train_codes = concept_scores(
                cw_output_model.layers[-1],
                encoder_model(x_train),
                aggregator=experiment_config['aggregator'],
            ).numpy()[:, concept_columns]
            score_test_codes = concept_scores(
                cw_output_model.layers[-1],
                encoder_model(x_test),
                aggregator=experiment_config['aggregator'],
            ).numpy()[:, concept_columns]

            current_accs = []
            current_aucs = []
            current_feat_accs = []
            current_feat_aucs = []
            for concept_idx in range(num_concepts):
                # Probe trained on the full CW-layer output.
                print("\tTraining full representation model for concept", concept_idx)
                feat_acc, feat_auc = _fit_and_evaluate_concept_probe(
                    experiment_config,
                    train_codes=train_codes,
                    train_targets=c_train[:, concept_idx],
                    test_codes=test_codes,
                    test_targets=c_test[:, concept_idx],
                    verbosity=verbosity,
                    eval_message="\t\tEvaluating feature model",
                )
                current_feat_accs.append(feat_acc)
                current_feat_aucs.append(feat_auc)
                print(
                    f"\t\tFeature Test concept auc = {current_feat_aucs[-1]:.4f}, "
                    f"feature concept accuracy = {current_feat_accs[-1]:.4f}"
                )

                # Probe trained on the concept scores only.
                print("\tTraining score model for concept", concept_idx)
                score_acc, score_auc = _fit_and_evaluate_concept_probe(
                    experiment_config,
                    train_codes=score_train_codes,
                    train_targets=c_train[:, concept_idx],
                    test_codes=score_test_codes,
                    test_targets=c_test[:, concept_idx],
                    verbosity=verbosity,
                    eval_message="\t\tEvaluating score model",
                )
                current_accs.append(score_acc)
                current_aucs.append(score_auc)
                # This line reports the score probe, so it is not labelled
                # as a "feature" metric.
                print(
                    f"\t\tTest concept auc = {current_aucs[-1]:.4f}, "
                    f"concept accuracy = {current_accs[-1]:.4f}"
                )

            avg_concept_accs.append(np.mean(current_accs))
            avg_concept_aucs.append(np.mean(current_aucs))
            feat_avg_concept_accs.append(np.mean(current_feat_accs))
            feat_avg_concept_aucs.append(np.mean(current_feat_aucs))
            print("\tDone with trial", trial + 1)

        avg_concept_acc_mean, avg_concept_acc_std = np.mean(avg_concept_accs), np.std(avg_concept_accs)
        experiment_variables["latent_avg_concept_predictive_accuracies"].append((avg_concept_acc_mean, avg_concept_acc_std))
        print(f"\tTest avg concept accuracy: {avg_concept_acc_mean:.4f} ± {avg_concept_acc_std:.4f}")

        avg_concept_auc_mean, avg_concept_auc_std = np.mean(avg_concept_aucs), np.std(avg_concept_aucs)
        experiment_variables["latent_avg_concept_predictive_aucs"].append((avg_concept_auc_mean, avg_concept_auc_std))
        print(f"\tTest avg concept AUC: {avg_concept_auc_mean:.4f} ± {avg_concept_auc_std:.4f}")

        feat_avg_concept_acc_mean, feat_avg_concept_acc_std = np.mean(feat_avg_concept_accs), np.std(feat_avg_concept_accs)
        experiment_variables["latent_feature_avg_concept_predictive_accuracies"].append((feat_avg_concept_acc_mean, feat_avg_concept_acc_std))
        print(f"\tTest feature avg concept accuracy: {feat_avg_concept_acc_mean:.4f} ± {feat_avg_concept_acc_std:.4f}")

        feat_avg_concept_auc_mean, feat_avg_concept_auc_std = np.mean(feat_avg_concept_aucs), np.std(feat_avg_concept_aucs)
        experiment_variables["latent_feature_avg_concept_predictive_aucs"].append((feat_avg_concept_auc_mean, feat_avg_concept_auc_std))
        print(f"\tTest feature avg concept AUC: {feat_avg_concept_auc_mean:.4f} ± {feat_avg_concept_auc_std:.4f}")

        # And serialize the results
        utils.serialize_results(
            results_dict=experiment_variables,
            results_dir=experiment_config["results_dir"],
        )
    return experiment_variables

## Experiments

In [None]:
# Concept-whitening (CW) experiment on the five balanced multiclass datasets
# (dep_0 ... dep_4) using the 'max_pool_mean' aggregator / activation mode.

# Re-import the CW and leakage modules before running the experiment.
reload(CW)
reload(leakage)

############################################################################
## Experiment config
############################################################################

cw_binary_balanced_multiclass_experiment_config = dict(
    batch_size=128,
    max_epochs=50,
    pre_train_epochs=50,
    post_train_epochs=0,
    cw_train_freq=20,
    cw_train_batch_steps=20,
    cw_train_iterations=1,
    exclusive_concepts=False,
    exclusive_test_concepts=False,
    trials=5,
    learning_rate=1e-3,
    filter_groups=[
        [(8, (7, 7), False)],
        [(16, (5, 5), False)],
        [(32, (3, 3), False)],
        [(64, (3, 3), True)],
    ],
    units=[64, 64, 64, 64],
    # One output per distinct task label of the dep_0 training set, or a
    # single output when there are at most two distinct labels.
    num_outputs=(
        len(set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1]))
        if len(set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1])) > 2 else 1
    ),

    # Hidden units and epoch budget of the probe decoders fitted by
    # `cw_bottleneck_concept_predict_experiment_loop`.
    latent_decoder_units=[64, 64],
    predictor_max_epochs=100,

    # NOTE(review): an infinite patience presumably disables early stopping
    # inside `cw_experiment_loop` -- confirm against that function.
    patience=float("inf"),
    min_delta=1e-5,
    results_dir=os.path.join(RESULTS_DIR, "cw/balanced_multiclass_tasks_purity_max_pool_mean"),
    # Input shape and concept count are taken from the dep_0 training set.
    input_shape=balanced_multiclass_task_bin_concepts_dep_0_complete_train[0].shape[1:],
    num_concepts=balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    verbosity=0,
    data_concepts=balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    # Index into the loaded model's `layers` list; the probe loop reads the
    # output of this layer (and of the layer right before it).
    cw_layer=14,
    # `aggregator` is forwarded to `concept_scores` by the probe loop.
    aggregator='max_pool_mean',
    activation_mode='max_pool_mean',
    concept_auc_freq=0,
    # Used as the Keras `validation_split` when fitting the probe decoders.
    holdout_fraction=0.1,
    heatmap_display=True,
    T=5,
    eps=1e-5,
    momentum=0.9,
    initial_tau=1000,
    initial_beta=1e8,
    initial_alpha=0,
    drop_prob=0,
)


# Generate the experiment directory if it does not exist already
Path(cw_binary_balanced_multiclass_experiment_config["results_dir"]).mkdir(parents=True, exist_ok=True)
cw_binary_balanced_multiclass_figure_dir = os.path.join(cw_binary_balanced_multiclass_experiment_config["results_dir"], "figures")
Path(cw_binary_balanced_multiclass_figure_dir).mkdir(parents=True, exist_ok=True)

############################################################################
## Experiment run
############################################################################

# Each dataset entry is (name, train split, test split).
cw_binary_balanced_multiclass_results = cw_experiment_loop(
    cw_binary_balanced_multiclass_experiment_config,
    datasets=[
        ("balanced_multiclass_task_bin_concepts_dep_0_complete", balanced_multiclass_task_bin_concepts_dep_0_complete_train, balanced_multiclass_task_bin_concepts_dep_0_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_1_complete", balanced_multiclass_task_bin_concepts_dep_1_complete_train, balanced_multiclass_task_bin_concepts_dep_1_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_2_complete", balanced_multiclass_task_bin_concepts_dep_2_complete_train, balanced_multiclass_task_bin_concepts_dep_2_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_3_complete", balanced_multiclass_task_bin_concepts_dep_3_complete_train, balanced_multiclass_task_bin_concepts_dep_3_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_4_complete", balanced_multiclass_task_bin_concepts_dep_4_complete_train, balanced_multiclass_task_bin_concepts_dep_4_complete_test),
    ],
    load_from_cache=True,
    oracle_matrix_cache=balanced_oracle_matrix_cache,
)
# Summarize the headline metrics returned by the experiment loop.
print("task_accuracies:", cw_binary_balanced_multiclass_results["task_accuracies"])
print("purity_scores:", cw_binary_balanced_multiclass_results["purity_scores"])
print("concept_aucs:", cw_binary_balanced_multiclass_results["concept_aucs"])
print("task_aucs:", cw_binary_balanced_multiclass_results["task_aucs"])

In [None]:
# Merge the metrics returned by `cw_bottleneck_predict_experiment_loop` into
# the results dict of the 'max_pool_mean' CW experiment. The same config and
# the same five (name, train, test) datasets as the main experiment are used.
cw_binary_balanced_multiclass_results.update(cw_bottleneck_predict_experiment_loop(
    cw_binary_balanced_multiclass_experiment_config,
    datasets=[
        ("balanced_multiclass_task_bin_concepts_dep_0_complete", balanced_multiclass_task_bin_concepts_dep_0_complete_train, balanced_multiclass_task_bin_concepts_dep_0_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_1_complete", balanced_multiclass_task_bin_concepts_dep_1_complete_train, balanced_multiclass_task_bin_concepts_dep_1_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_2_complete", balanced_multiclass_task_bin_concepts_dep_2_complete_train, balanced_multiclass_task_bin_concepts_dep_2_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_3_complete", balanced_multiclass_task_bin_concepts_dep_3_complete_train, balanced_multiclass_task_bin_concepts_dep_3_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_4_complete", balanced_multiclass_task_bin_concepts_dep_4_complete_train, balanced_multiclass_task_bin_concepts_dep_4_complete_test),
    ],
    load_from_cache=True,
))

In [None]:
# Merge the per-concept probe metrics returned by
# `cw_bottleneck_concept_predict_experiment_loop` into the results dict of the
# 'max_pool_mean' CW experiment. With `load_from_cache=True` the loop first
# looks for previously serialized results under the config's `results_dir`.
cw_binary_balanced_multiclass_results.update(cw_bottleneck_concept_predict_experiment_loop(
    cw_binary_balanced_multiclass_experiment_config,
    datasets=[
        ("balanced_multiclass_task_bin_concepts_dep_0_complete", balanced_multiclass_task_bin_concepts_dep_0_complete_train, balanced_multiclass_task_bin_concepts_dep_0_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_1_complete", balanced_multiclass_task_bin_concepts_dep_1_complete_train, balanced_multiclass_task_bin_concepts_dep_1_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_2_complete", balanced_multiclass_task_bin_concepts_dep_2_complete_train, balanced_multiclass_task_bin_concepts_dep_2_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_3_complete", balanced_multiclass_task_bin_concepts_dep_3_complete_train, balanced_multiclass_task_bin_concepts_dep_3_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_4_complete", balanced_multiclass_task_bin_concepts_dep_4_complete_train, balanced_multiclass_task_bin_concepts_dep_4_complete_test),
    ],
    load_from_cache=True,
))

In [None]:
# Concept-whitening (CW) experiment on the five balanced multiclass datasets
# (dep_0 ... dep_4) using the 'mean' aggregator / activation mode. Results go
# to their own "cw/balanced_multiclass_tasks_purity_mean" directory.

# Re-import the CW and leakage modules before running the experiment.
reload(CW)
reload(leakage)

############################################################################
## Experiment config
############################################################################

cw_mean_binary_balanced_multiclass_experiment_config = dict(
    batch_size=128,
    max_epochs=50,
    pre_train_epochs=50,
    post_train_epochs=0,
    cw_train_freq=20,
    cw_train_batch_steps=20,
    cw_train_iterations=1,
    exclusive_concepts=False,
    exclusive_test_concepts=False,
    trials=5,
    learning_rate=1e-3,
    filter_groups=[
        [(8, (7, 7), False)],
        [(16, (5, 5), False)],
        [(32, (3, 3), False)],
        [(64, (3, 3), True)],
    ],
    units=[64, 64, 64, 64],
    # One output per distinct task label of the dep_0 training set, or a
    # single output when there are at most two distinct labels.
    num_outputs=(
        len(set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1]))
        if len(set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1])) > 2 else 1
    ),

    # Hidden units and epoch budget of the probe decoders fitted by
    # `cw_bottleneck_concept_predict_experiment_loop`.
    latent_decoder_units=[64, 64],
    predictor_max_epochs=100,

    # NOTE(review): an infinite patience presumably disables early stopping
    # inside `cw_experiment_loop` -- confirm against that function.
    patience=float("inf"),
    min_delta=1e-5,
    results_dir=os.path.join(RESULTS_DIR, "cw/balanced_multiclass_tasks_purity_mean"),
    # Input shape and concept count are taken from the dep_0 training set.
    input_shape=balanced_multiclass_task_bin_concepts_dep_0_complete_train[0].shape[1:],
    num_concepts=balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    verbosity=0,
    data_concepts=balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    # Index into the loaded model's `layers` list; the probe loop reads the
    # output of this layer (and of the layer right before it).
    cw_layer=14,
    # `aggregator` is forwarded to `concept_scores` by the probe loop.
    aggregator='mean',
    activation_mode='mean',
    concept_auc_freq=0,
    # Used as the Keras `validation_split` when fitting the probe decoders.
    holdout_fraction=0.1,
    heatmap_display=True,
    T=5,
    eps=1e-5,
    momentum=0.9,
    initial_tau=1000,
    initial_beta=1e8,
    initial_alpha=0,
    drop_prob=0,
)


# Generate the experiment directory if it does not exist already
Path(cw_mean_binary_balanced_multiclass_experiment_config["results_dir"]).mkdir(parents=True, exist_ok=True)
cw_mean_binary_balanced_multiclass_figure_dir = os.path.join(cw_mean_binary_balanced_multiclass_experiment_config["results_dir"], "figures")
Path(cw_mean_binary_balanced_multiclass_figure_dir).mkdir(parents=True, exist_ok=True)

############################################################################
## Experiment run
############################################################################

# Each dataset entry is (name, train split, test split).
cw_mean_binary_balanced_multiclass_results = cw_experiment_loop(
    cw_mean_binary_balanced_multiclass_experiment_config,
    datasets=[
        ("balanced_multiclass_task_bin_concepts_dep_0_complete", balanced_multiclass_task_bin_concepts_dep_0_complete_train, balanced_multiclass_task_bin_concepts_dep_0_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_1_complete", balanced_multiclass_task_bin_concepts_dep_1_complete_train, balanced_multiclass_task_bin_concepts_dep_1_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_2_complete", balanced_multiclass_task_bin_concepts_dep_2_complete_train, balanced_multiclass_task_bin_concepts_dep_2_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_3_complete", balanced_multiclass_task_bin_concepts_dep_3_complete_train, balanced_multiclass_task_bin_concepts_dep_3_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_4_complete", balanced_multiclass_task_bin_concepts_dep_4_complete_train, balanced_multiclass_task_bin_concepts_dep_4_complete_test),
    ],
    load_from_cache=True,
    oracle_matrix_cache=balanced_oracle_matrix_cache,
)
# Summarize the headline metrics returned by the experiment loop.
print("task_accuracies:", cw_mean_binary_balanced_multiclass_results["task_accuracies"])
print("purity_scores:", cw_mean_binary_balanced_multiclass_results["purity_scores"])
print("concept_aucs:", cw_mean_binary_balanced_multiclass_results["concept_aucs"])
print("task_aucs:", cw_mean_binary_balanced_multiclass_results["task_aucs"])

In [None]:
# Merge the metrics returned by `cw_bottleneck_predict_experiment_loop` into
# the results dict of the 'mean' CW experiment. The same config and the same
# five (name, train, test) datasets as the main experiment are used.
cw_mean_binary_balanced_multiclass_results.update(cw_bottleneck_predict_experiment_loop(
    cw_mean_binary_balanced_multiclass_experiment_config,
    datasets=[
        ("balanced_multiclass_task_bin_concepts_dep_0_complete", balanced_multiclass_task_bin_concepts_dep_0_complete_train, balanced_multiclass_task_bin_concepts_dep_0_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_1_complete", balanced_multiclass_task_bin_concepts_dep_1_complete_train, balanced_multiclass_task_bin_concepts_dep_1_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_2_complete", balanced_multiclass_task_bin_concepts_dep_2_complete_train, balanced_multiclass_task_bin_concepts_dep_2_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_3_complete", balanced_multiclass_task_bin_concepts_dep_3_complete_train, balanced_multiclass_task_bin_concepts_dep_3_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_4_complete", balanced_multiclass_task_bin_concepts_dep_4_complete_train, balanced_multiclass_task_bin_concepts_dep_4_complete_test),
    ],
    load_from_cache=True,
))

In [None]:
# Merge the per-concept probe metrics returned by
# `cw_bottleneck_concept_predict_experiment_loop` into the results dict of the
# 'mean' CW experiment. With `load_from_cache=True` the loop first looks for
# previously serialized results under the config's `results_dir`.
cw_mean_binary_balanced_multiclass_results.update(cw_bottleneck_concept_predict_experiment_loop(
    cw_mean_binary_balanced_multiclass_experiment_config,
    datasets=[
        ("balanced_multiclass_task_bin_concepts_dep_0_complete", balanced_multiclass_task_bin_concepts_dep_0_complete_train, balanced_multiclass_task_bin_concepts_dep_0_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_1_complete", balanced_multiclass_task_bin_concepts_dep_1_complete_train, balanced_multiclass_task_bin_concepts_dep_1_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_2_complete", balanced_multiclass_task_bin_concepts_dep_2_complete_train, balanced_multiclass_task_bin_concepts_dep_2_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_3_complete", balanced_multiclass_task_bin_concepts_dep_3_complete_train, balanced_multiclass_task_bin_concepts_dep_3_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_4_complete", balanced_multiclass_task_bin_concepts_dep_4_complete_train, balanced_multiclass_task_bin_concepts_dep_4_complete_test),
    ],
    load_from_cache=True,
))

# Ada-ML-VAE Benchmarking

## Weakly Labelled Dataset Construction

In [None]:
def generate_weak_sample_pairs(x_train, y_train, c_train, size):
    """Randomly draws ordered sample pairs that agree on some, but not all, concepts.

    Index pairs are drawn uniformly (with replacement) from the dataset. A
    pair is kept only when the two samples share at least one concept value
    and differ in at least one. Every distinct index pair is considered at
    most once. Sampling stops once `size` pairs were collected or every
    possible ordered index pair has been tried.

    Args:
        x_train: samples, indexed along the first axis.
        y_train: labels aligned with `x_train`.
        c_train: 2D concept annotations aligned with `x_train`.
        size: maximum number of pairs to generate.

    Returns:
        Three lists of 2-tuples: the paired samples, the paired labels and
        the paired concept vectors.
    """
    total_samples = x_train.shape[0]
    total_concepts = c_train.shape[-1]
    visited_pairs = set()
    pair_samples, pair_labels, pair_concepts = [], [], []

    while (
        len(visited_pairs) < (total_samples * total_samples)
        and len(pair_samples) < size
    ):
        # Two independent draws, in this order, so the consumption of the
        # global NumPy random stream is unchanged.
        first = np.random.choice(x_train.shape[0], 1)[0]
        second = np.random.choice(x_train.shape[0], 1)[0]
        candidate = (first, second)
        if candidate in visited_pairs:
            continue

        shared = np.sum(c_train[first, :] == c_train[second, :])
        # Whether accepted or rejected, this pair is never considered again.
        visited_pairs.add(candidate)
        if (shared != 0) and (shared != total_concepts):
            pair_samples.append((x_train[first], x_train[second]))
            pair_labels.append((y_train[first], y_train[second]))
            pair_concepts.append((c_train[first], c_train[second]))

    return pair_samples, pair_labels, pair_concepts

def _join_pair_samples(pairs, axis=1):
    return np.concatenate(
        [
            np.stack(
                list(map(lambda x: x[0], pairs)),
                axis=0,
            ),
            np.stack(
                list(map(lambda x: x[1], pairs)),
                axis=0,
            ),
        ],
        axis=axis,
    )

def generate_wvae_dataset(
    train_data,
    test_data,
    train_size,
    test_size,
):
    """Builds paired train and test splits for weakly-supervised VAE training.

    Each split is turned into pairs with `generate_weak_sample_pairs`. Paired
    samples and paired concepts are joined along axis 1, paired labels along
    axis 0 (see `_join_pair_samples`).

    Args:
        train_data: `(samples, labels, concepts)` of the training split.
        test_data: `(samples, labels, concepts)` of the test split.
        train_size: maximum number of training pairs.
        test_size: maximum number of test pairs.

    Returns:
        A 6-tuple `(train_pairs, train_label_pairs, train_concept_pairs,
        test_pairs, test_label_pairs, test_concept_pairs)`.
    """
    def _paired_split(split, max_pairs):
        # Draw the pairs for one split and join each of its three components.
        sample_pairs, label_pairs, concept_pairs = generate_weak_sample_pairs(
            split[0],
            split[1],
            split[2],
            max_pairs,
        )
        return (
            _join_pair_samples(sample_pairs),
            _join_pair_samples(label_pairs, axis=0),
            _join_pair_samples(concept_pairs),
        )

    # The training split is processed before the test split.
    paired_train = _paired_split(train_data, train_size)
    paired_test = _paired_split(test_data, test_size)
    return (*paired_train, *paired_test)

# Build one weakly-paired dataset per dependency level (dep_0 ... dep_4).
# Each split requests two thirds as many pairs as it has samples. The count is
# rounded up to a whole number so that `generate_weak_sample_pairs` receives an
# integer rather than a fraction (the original `int(n)/1.5` applied `int` to
# the wrong operand). Rounding up yields the same number of pairs as the
# fractional bound did.
balanced_multiclass_wvae_datasets = [
    generate_wvae_dataset(
        train,
        test,
        int(np.ceil(train[0].shape[0] / 1.5)),
        int(np.ceil(test[0].shape[0] / 1.5)),
    ) for (train, test) in [
        (balanced_multiclass_task_bin_concepts_dep_0_complete_train, balanced_multiclass_task_bin_concepts_dep_0_complete_test),
        (balanced_multiclass_task_bin_concepts_dep_1_complete_train, balanced_multiclass_task_bin_concepts_dep_1_complete_test),
        (balanced_multiclass_task_bin_concepts_dep_2_complete_train, balanced_multiclass_task_bin_concepts_dep_2_complete_test),
        (balanced_multiclass_task_bin_concepts_dep_3_complete_train, balanced_multiclass_task_bin_concepts_dep_3_complete_test),
        (balanced_multiclass_task_bin_concepts_dep_4_complete_train, balanced_multiclass_task_bin_concepts_dep_4_complete_test),
    ]
]

## Model Setup

In [None]:
import concepts_xai.methods.VAE.weak_vae as weak_vae
import concepts_xai.methods.VAE.baseVAE as base_vae
import concepts_xai.methods.VAE.losses as vae_losses
reload(vae_losses)
reload(weak_vae)

def construct_vae_encoder(
    input_shape,
    filter_groups,
    units,
    latent_dims,
    drop_prob=0.5,
    max_pool_window=(2,2),
    max_pool_stride=2,
    include_norm=False,
    include_pool=False,
):
    """Builds a convolutional VAE encoder returning latent means and log-variances.

    Args:
        input_shape: shape of a single input sample.
        filter_groups: list of groups, each a list of
            `(num_filters, kernel_size)` or
            `(num_filters, kernel_size, stride)` convolution specs. A missing
            stride defaults to 1.
        units: sizes of the dense layers applied after flattening.
        latent_dims: size of the latent space.
        drop_prob: dropout rate applied after flattening; a falsy value
            disables dropout.
        max_pool_window: pooling window used when `include_pool` is set.
        max_pool_stride: pooling stride used when `include_pool` is set.
        include_norm: if True, each convolution is followed by batch
            normalization and then a ReLU, instead of using a fused ReLU
            activation.
        include_pool: if True, a max-pool layer is added after each group.

    Returns:
        An uncompiled `tf.keras.Model` named "encoder" with outputs
        `[mean, log_var]`.
    """
    encoder_inputs = tf.keras.Input(shape=input_shape)
    hidden = encoder_inputs
    # With normalization the ReLU is applied after the batch norm instead.
    conv_activation = None if include_norm else "relu"

    # Convolutional trunk
    conv_index = 0
    for group in filter_groups:
        for conv_spec in group:
            if len(conv_spec) == 2:
                num_filters, kernel_size = conv_spec
                stride = 1
            else:
                num_filters, kernel_size, stride = conv_spec
            hidden = tf.keras.layers.Conv2D(
                filters=num_filters,
                kernel_size=kernel_size,
                strides=stride,
                padding="SAME",
                activation=conv_activation,
                name=f'encoder_conv_{conv_index}',
            )(hidden)
            conv_index += 1
            if include_norm:
                hidden = tf.keras.layers.BatchNormalization()(hidden)
                hidden = tf.keras.activations.relu(hidden)
        if include_pool:
            # Pool at the end of each group to keep the parameter count of
            # the model under control.
            hidden = tf.keras.layers.MaxPooling2D(
                pool_size=max_pool_window,
                strides=max_pool_stride,
            )(hidden)

    hidden = tf.keras.layers.Flatten()(hidden)

    # Optional dropout on the flattened features
    if drop_prob:
        hidden = tf.keras.layers.Dropout(drop_prob)(hidden)

    # Fully connected bottleneck
    for dense_index, num_units in enumerate(units):
        hidden = tf.keras.layers.Dense(
            num_units,
            activation='relu',
            name=f"encoder_dense_{dense_index}",
        )(hidden)

    # Two linear heads parameterizing the latent distribution
    mean = tf.keras.layers.Dense(latent_dims, activation=None, name="means")(hidden)
    log_var = tf.keras.layers.Dense(latent_dims, activation=None, name="log_var")(hidden)
    return tf.keras.Model(
        encoder_inputs,
        [mean, log_var],
        name="encoder",
    )

def construct_vae_decoder(
    units,
    output_shape,
    latent_dims,
):
    """Builds the CNN decoder from 'Challenging Common Assumptions in the
       Unsupervised Learning of Disentangled Representations'
       (https://arxiv.org/abs/1811.12359).

       The latent code goes through the dense layers in `units`, is reshaped
       to a 4x4x32 feature map and is then upsampled by four stride-2
       transposed convolutions; the last one is linear and has
       `output_shape[-1]` filters. The result is reshaped to `output_shape`.

       Note: the returned model is uncompiled.
    """
    latent_inputs = tf.keras.Input(shape=(latent_dims,))

    hidden = latent_inputs
    for num_units in units:
        hidden = tf.keras.layers.Dense(
            num_units,
            activation='relu',
        )(hidden)
    # The last dense layer has to produce 4 * 4 * 32 values for this reshape.
    hidden = tf.keras.layers.Reshape([4, 4, 32])(hidden)

    # Three ReLU upsampling stages ...
    for num_filters in (64, 32, 32):
        hidden = tf.keras.layers.Conv2DTranspose(
            filters=num_filters,
            kernel_size=4,
            strides=2,
            activation='relu',
            padding="same",
        )(hidden)

    # ... followed by a linear stage producing the output channels.
    hidden = tf.keras.layers.Conv2DTranspose(
        filters=output_shape[-1],
        kernel_size=4,
        strides=2,
        padding="same",
        activation=None,
    )(hidden)
    reconstruction = tf.keras.layers.Reshape(output_shape)(hidden)

    return tf.keras.Model(
        inputs=latent_inputs,
        outputs=[reconstruction],
    )

def construct_wvae(
    input_shape,
    latent_dims,
    filter_groups,
    encoder_units,
    decoder_units,
    drop_prob=0.5,
    include_pool=False,
    max_pool_window=(2,2),
    max_pool_stride=2,
    learning_rate=1e-3,
    beta=1.0,
    vae_model=weak_vae.MLVaeArgmax,
):
    """Builds and compiles a VAE model with a freshly constructed encoder/decoder.

    The encoder comes from `construct_vae_encoder` and the decoder from
    `construct_vae_decoder`; both are wrapped by `vae_model` using the loss
    returned by `vae_losses.bernoulli_fn_wrapper()`.

    Args:
        input_shape: shape of a single input sample (also the decoder's
            output shape).
        latent_dims: size of the latent space.
        filter_groups: convolution specs forwarded to the encoder.
        encoder_units: dense layer sizes of the encoder.
        decoder_units: dense layer sizes of the decoder.
        drop_prob: dropout rate forwarded to the encoder.
        include_pool: whether the encoder max-pools after each filter group.
        max_pool_window: encoder pooling window.
        max_pool_stride: encoder pooling stride.
        learning_rate: learning rate of the Adam optimizer.
        beta: `beta` argument forwarded to `vae_model`.
        vae_model: model class to instantiate.

    Returns:
        The compiled `vae_model` instance.
    """
    encoder = construct_vae_encoder(
        input_shape=input_shape,
        filter_groups=filter_groups,
        units=encoder_units,
        latent_dims=latent_dims,
        drop_prob=drop_prob,
        max_pool_window=max_pool_window,
        max_pool_stride=max_pool_stride,
        include_pool=include_pool,
    )
    decoder = construct_vae_decoder(
        units=decoder_units,
        output_shape=input_shape,
        latent_dims=latent_dims,
    )

    model = vae_model(
        encoder=encoder,
        decoder=decoder,
        loss_fn=vae_losses.bernoulli_fn_wrapper(),
        beta=beta,
    )
    optimizer = tf.keras.optimizers.Adam(learning_rate)
    model.compile(optimizer=optimizer)
    return model

def construct_pretrained_wvae(
    encoder,
    decoder,
    learning_rate=1e-3,
    beta=1.0,
    vae_model=weak_vae.MLVaeArgmax,
):
    """Wraps an existing encoder/decoder pair into a compiled `vae_model`.

    Args:
        encoder: already-built encoder model, passed through unchanged.
        decoder: already-built decoder model, passed through unchanged.
        learning_rate: learning rate of the Adam optimizer used to compile.
        beta: forwarded to `vae_model`.
        vae_model: callable/class used to assemble the final model.

    Returns:
        The compiled model instance produced by `vae_model`, using the loss
        returned by `vae_losses.bernoulli_fn_wrapper()`.
    """
    model_kwargs = dict(
        encoder=encoder,
        decoder=decoder,
        loss_fn=vae_losses.bernoulli_fn_wrapper(),
        beta=beta,
    )
    pretrained = vae_model(**model_kwargs)
    pretrained.compile(optimizer=tf.keras.optimizers.Adam(learning_rate))
    return pretrained

## Experiment Loop

In [None]:
import concepts_xai.evaluation.metrics.oracle as oracle
import concepts_xai.methods.VAE.betaVAE as beta_vae

############################################################################
## Experiment loop
############################################################################

def construct_trivial_auc_mat(num_concepts):
    """Returns a float32 square matrix with 1.0 on the diagonal, 0.5 elsewhere.

    Args:
        num_concepts: number of rows/columns of the returned matrix.

    Returns:
        A `(num_concepts, num_concepts)` float32 numpy array.
    """
    trivial_mat = np.full((num_concepts, num_concepts), 0.5, dtype=np.float32)
    np.fill_diagonal(trivial_mat, 1.0)
    return trivial_mat
    
def sigmoid(x):
    """Element-wise logistic function, i.e. ``1 / (1 + exp(-x))``.

    Uses `scipy.special.expit` rather than the literal formula: evaluating
    `np.exp(-x)` directly overflows (and emits a RuntimeWarning) for
    strongly negative inputs, even though the final value is a plain 0.

    Args:
        x: scalar or array-like of logits.

    Returns:
        The logistic function applied element-wise to `x`.
    """
    # Imported locally so this does not rely on `scipy.special` having been
    # loaded as a side effect of some other import.
    from scipy.special import expit
    return expit(x)

def wvae_experiment_loop(
    experiment_config,
    datasets,
    load_from_cache=False,
    include_all_losses=True,
    vae_model=weak_vae.MLVaeArgmax,
    oracle_matrix_cache=None,
    model_cache=None,
    include_encoder_purity=False,
):
    """Trains/evaluates one VAE per (dataset, trial) and aggregates metrics.

    For every dataset, `experiment_config["trials"]` models are obtained
    (either deserialized through `model_cache` or trained from scratch),
    saved under `<results_dir>/models`, evaluated on the test split, and
    scored with `oracle.oracle_impurity_score`. Per-dataset means and
    standard deviations across trials are appended to the returned dict and
    serialized with `utils.serialize_results` after each dataset.

    Args:
        experiment_config: mapping with the experiment hyper-parameters. Keys
            read here: "results_dir", "latent_dim", "trials", "beta",
            "input_shape", "filter_groups", "encoder_units", "decoder_units",
            "drop_prob", "max_pool_window", "max_pool_stride", "include_pool",
            "min_delta", "patience", "max_epochs", "batch_size",
            "holdout_fraction", "num_concepts", and the optional "verbosity",
            "early_stop_metric", "early_stop_mode",
            "visualization_frequency" and "visualization_samples".
        datasets: sequence of
            `(ds_name, (x_train, _, c_train), (x_test, _, c_test))` tuples.
        load_from_cache: if True, previously serialized `*_means.npz` /
            `*_stds.npz` files in the results directory are reused. A complete
            cache is returned directly; a partial one makes the loop resume at
            the first dataset without cached results.
        include_all_losses: if True, "elbo_losses" and
            "reconstruction_losses" are also recorded.
        vae_model: model class handed to `construct_wvae` /
            `construct_pretrained_wvae`. Unless it is `beta_vae.BetaVAE`,
            samples and concept labels are halved with `split_fn` before being
            encoded/scored.
        oracle_matrix_cache: optional mapping from dataset name to a
            precomputed oracle matrix; datasets without an entry pass `None`
            to `oracle.oracle_impurity_score`.
        model_cache: optional mapping from `(ds_name, trial)` to a pair
            `(encoder_model_dir, decoder_model_dir)` of saved Keras models.
        include_encoder_purity: if True, purity is also computed on the
            stacked encoder outputs (the two tensors the encoder returns).

    Returns:
        Dict mapping each metric name to a list with one `(mean, std)` tuple
        per dataset.

    Raises:
        ValueError: if the cached files disagree on how many datasets have
            already been processed.
    """
    utils.reseed(87)
    if vae_model != beta_vae.BetaVAE:
        split_fn = lambda x: x[:, :x.shape[1]//2, ...] if len(x.shape) > 1 else x[:x.shape[0]//2]
    else:
        split_fn = lambda x: x

    model_cache = model_cache or {}
    # Must default to an empty mapping: `.get(ds_name)` is called on it below.
    oracle_matrix_cache = oracle_matrix_cache or {}
    experiment_variables = dict(
        total_losses=[],
        purity_scores=[],
        non_oracle_purity_scores=[],
        purity_matrices=[],
        aligned_purity_matrices=[],
        oracle_matrices=[],
    )
    if include_all_losses:
        experiment_variables['elbo_losses'] = []
        experiment_variables['reconstruction_losses'] = []
    if include_encoder_purity:
        experiment_variables['encoder_purity_scores'] = []
        experiment_variables['encoder_purity_matrices'] = []
        experiment_variables['encoder_aligned_purity_matrices'] = []

    start_ind = 0
    if load_from_cache:
        cached = True
        complete_cache = True
        for var_name in experiment_variables:
            file_name = os.path.join(
                experiment_config["results_dir"],
                var_name
            )
            if (
                (not os.path.exists(f'{file_name}_means.npz')) or
                (not os.path.exists(f'{file_name}_stds.npz'))
            ):
                print("Could not find", f'"{file_name}_means.npz" or "{file_name}_stds.npz" in cache.')
                cached = False
                break
            # Only the number of stored arrays is needed here; close the
            # archive right away instead of leaking the file handle.
            with np.load(f'{file_name}_means.npz') as loaded_means:
                num_cached = len(loaded_means.files)
            if num_cached != len(datasets):
                print("Found", num_cached, "means for variable", var_name, "vs", len(datasets), "datasets")
                # Then we have a partial run here so let's just run the rest
                if start_ind and start_ind != num_cached:
                    raise ValueError(
                        f'Found inconsistent start indices in cached '
                        f'data {start_ind} vs {num_cached} ({file_name}).'
                    )
                start_ind = num_cached
                complete_cache = False
        if not cached:
            # The cache is unusable, so nothing gets loaded below: make sure
            # we recompute every dataset rather than resuming part-way with
            # an empty results dict.
            start_ind = 0
        if cached:
            # Then we have found all of the arrays of interest, so let's
            # load them and use them
            print("Experiment cache was hit")
            print("\tLoading variables from", experiment_config["results_dir"])
            for var_name in experiment_variables:
                file_name = os.path.join(
                    experiment_config["results_dir"],
                    var_name
                )
                # Open each archive once and read its arrays in file order.
                with np.load(f'{file_name}_means.npz') as means_file:
                    cached_means = [means_file[key] for key in means_file.files]
                with np.load(f'{file_name}_stds.npz') as stds_file:
                    cached_stds = [stds_file[key] for key in stds_file.files]
                experiment_variables[var_name] = list(zip(cached_means, cached_stds))
            if complete_cache:
                # Then we are good to go
                if not os.path.exists(
                    os.path.join(experiment_config["results_dir"], "config.yaml")
                ):
                    # then serialize the config as this is a different version run
                    utils.serialize_experiment_config(
                        experiment_config,
                        experiment_config["results_dir"],
                    )
                return experiment_variables

    # Else, let's go ahead and run the whole thing
    verbosity = experiment_config.get("verbosity", 0)
    Path(
        os.path.join(
            experiment_config["results_dir"],
            "models",
        )
    ).mkdir(parents=True, exist_ok=True)
    for (ds_name, (x_train, _, c_train), (x_test, _, c_test)) in datasets[start_ind:]:
        latent_dim = experiment_config['latent_dim']
        print("Training with latent dimensions", latent_dim, "in dataset", ds_name)
        # Per-trial accumulators for this dataset
        tot_losses = []
        recon_losses = []
        el_losses = []
        purity_mats = []
        aligned_purity_mats = []
        oracle_mats = []
        purities = []
        non_oracle_purities = []
        if include_encoder_purity:
            encoder_purity_mats = []
            encoder_aligned_purity_mats = []
            encoder_purities = []

        for trial in range(experiment_config["trials"]):
            print(f"\tTrial {trial + 1}/{experiment_config['trials']}")
            if (ds_name, trial) in model_cache:
                print("Found trial", trial, "model for dataset", ds_name, "in model cache")
                print("\tDeserializing it...")
                encoder_model_dir, decoder_model_dir = model_cache[(ds_name, trial)]
                encoder = tf.keras.models.load_model(encoder_model_dir)
                decoder = tf.keras.models.load_model(decoder_model_dir)
                wvae_model = construct_pretrained_wvae(
                    encoder=encoder,
                    decoder=decoder,
                    beta=experiment_config['beta'],
                    vae_model=vae_model,
                )
            else:
                # Time to actually construct and train the WVAE
                wvae_model = construct_wvae(
                    input_shape=experiment_config["input_shape"],
                    latent_dims=latent_dim,
                    filter_groups=experiment_config["filter_groups"],
                    encoder_units=experiment_config["encoder_units"],
                    decoder_units=experiment_config["decoder_units"],
                    drop_prob=experiment_config['drop_prob'],
                    max_pool_window=experiment_config['max_pool_window'],
                    max_pool_stride=experiment_config['max_pool_stride'],
                    include_pool=experiment_config['include_pool'],
                    beta=experiment_config['beta'],
                    vae_model=vae_model,
                )

                early_stopping_monitor = tf.keras.callbacks.EarlyStopping(
                    monitor=experiment_config.get(
                        "early_stop_metric",
                        "val_loss",
                    ),
                    min_delta=experiment_config["min_delta"],
                    patience=experiment_config["patience"],
                    restore_best_weights=True,
                    verbose=2,
                    mode=experiment_config.get(
                        "early_stop_mode",
                        "min",
                    ),
                )

                print("\tVAE training...")
                epochs = 0
                viz_freq = experiment_config.get(
                    "visualization_frequency",
                    experiment_config["max_epochs"]
                )
                num_samples = experiment_config.get(
                    "visualization_samples",
                    4,
                )
                # Train in chunks of `viz_freq` epochs, plotting random
                # training samples before and random generations after each
                # chunk.
                # NOTE(review): an early stop inside one `fit` call does not
                # break this loop, so training resumes in the next chunk --
                # confirm this is intended.
                while epochs < experiment_config['max_epochs']:

                    fig, axs = plt.subplots(1, num_samples, figsize=(12, 8))
                    for j in range(num_samples):
                        axs[j].imshow(
                            x_train[np.random.randint(x_train.shape[0]), ...]
                        )
                        axs[j].get_xaxis().set_visible(False)
                        axs[j].get_yaxis().set_visible(False)
                    plt.title(f"Random training sample")
                    plt.show()
                    # Release the figure explicitly; harmless if the backend
                    # already closed it on show().
                    plt.close(fig)

                    wvae_model.fit(
                        x=x_train,
                        epochs=min(viz_freq, experiment_config['max_epochs'] - epochs),
                        batch_size=experiment_config["batch_size"],
                        callbacks=[
                            early_stopping_monitor,
                        ],
                        validation_split=experiment_config["holdout_fraction"],
                        verbose=verbosity,
                    )

                    epochs += viz_freq

                    fig, axs = plt.subplots(1, num_samples, figsize=(12, 8))
                    for j in range(num_samples):
                        if experiment_config["input_shape"][-1] > 1:
                            # Multi-channel output: squash to [0, 1] for display
                            axs[j].imshow(
                                sigmoid(wvae_model.generate_random_sample()[0, :, :, :])
                            )
                        else:
                            # Single-channel output: threshold at 0 for display
                            axs[j].imshow(
                                wvae_model.generate_random_sample()[0, :, :, :] >= 0
                            )
                        axs[j].get_xaxis().set_visible(False)
                        axs[j].get_yaxis().set_visible(False)
                    plt.title(f"Random sample at epoch {epochs}")
                    plt.show()
                    plt.close(fig)

                print("\t\tWVAE training completed")
            print("\tSerializing model")
            wvae_model.encoder.save(
                os.path.join(
                    experiment_config["results_dir"],
                    f"models/{ds_name}_encoder_trial_{trial}"
                )
            )
            wvae_model.decoder.save(
                os.path.join(
                    experiment_config["results_dir"],
                    f"models/{ds_name}_decoder_trial_{trial}"
                )
            )
            print("\tEvaluating model")
            test_result = wvae_model.evaluate(
                x_test,
                verbose=0,
                return_dict=True,
            )
            tot_losses.append(test_result['loss'])
            recon_losses.append(test_result['reconstruction_loss'])
            el_losses.append(test_result['elbo'])

            print(
                f"\t\tTest loss = {tot_losses[-1]:.4f}, "
                f"test reconstruction loss = {recon_losses[-1]:.4f}, "
                f"task elbo = {el_losses[-1]:.4f}"
            )

            print(f"\t\tComputing purity score...")
            latent_codes = wvae_model.sample_from_latent_distribution(
                *wvae_model.encoder(split_fn(x_test))
            ).numpy()
            purity_score, (purity_mat, aligned_purity_mat), oracle_mat = oracle.oracle_impurity_score(
                c_soft=latent_codes,
                c_true=split_fn(c_test),
                output_matrices=True,
                alignment_function=oracle.max_alignment_matrix,
                oracle_matrix=oracle_matrix_cache.get(ds_name),
            )

            purity_mats.append(purity_mat)
            aligned_purity_mats.append(aligned_purity_mat)
            oracle_mats.append(oracle_mat)
            purities.append(purity_score)
            print(f"\t\t\tDone {purity_score:.4f}")
            print("\t\tPurity matrix:")
            print(purity_mat)
            print("\t\tAligned purity matrix:")
            print(aligned_purity_mat)


            print("\t\tComputing non-oracle purity score...")
            # Same score, but measured against the trivial matrix from
            # `construct_trivial_auc_mat` and reusing the aligned purity
            # matrix computed above.
            non_oracle_purities.append(oracle.oracle_impurity_score(
                c_soft=latent_codes,
                c_true=split_fn(c_test),
                alignment_function=oracle.max_alignment_matrix,
                oracle_matrix=construct_trivial_auc_mat(
                    experiment_config["num_concepts"]
                ),
                purity_matrix=aligned_purity_mat,
            ))
            print(f"\t\t\tDone {non_oracle_purities[-1]:.4f}")

            if include_encoder_purity:
                # Then compute the purity when the concept representations are the (mean, std) vectors generated
                # by the encoder
                print("\t\tComputing encoder purity score...")
                encoder_means, encoder_logvars = wvae_model.encoder(
                    split_fn(x_test)
                )
                encoder_codes = np.stack([encoder_means.numpy(), encoder_logvars.numpy()], axis=1)
                encoder_purity_score, (encoder_purity_mat, encoder_aligned_purity_mat), _ = oracle.oracle_impurity_score(
                    c_soft=encoder_codes,
                    c_true=split_fn(c_test),
                    output_matrices=True,
                    alignment_function=oracle.max_alignment_matrix,
                    oracle_matrix=oracle_mat,
                )
                print(f"\t\t\tDone {encoder_purity_score:.4f}")
                print("\t\tEncoder purity matrix:")
                print(encoder_purity_mat)
                print("\t\tEncoder aligned purity matrix:")
                print(encoder_aligned_purity_mat)
                encoder_purities.append(encoder_purity_score)
                encoder_purity_mats.append(encoder_purity_mat)
                encoder_aligned_purity_mats.append(encoder_aligned_purity_mat)

            print("\t\tDone with trial", trial + 1)

        # Aggregate every metric across this dataset's trials
        tot_loss_mean, tot_loss_std = np.mean(tot_losses), np.std(tot_losses)
        experiment_variables["total_losses"].append((tot_loss_mean, tot_loss_std))
        print(f"\tTest total loss: {tot_loss_mean:.4f} ± {tot_loss_std:.4f}")

        recon_loss_mean, recon_loss_std = np.mean(recon_losses), np.std(recon_losses)
        if include_all_losses:
            experiment_variables["reconstruction_losses"].append((recon_loss_mean, recon_loss_std))
        print(f"\tTest reconstruction loss: {recon_loss_mean:.4f} ± {recon_loss_std:.4f}")

        el_loss_mean, el_loss_std = np.mean(el_losses), np.std(el_losses)
        if include_all_losses:
            experiment_variables["elbo_losses"].append((el_loss_mean, el_loss_std))
        print(f"\tTest elbo loss: {el_loss_mean:.4f} ± {el_loss_std:.4f}")

        purity_mats = np.stack(purity_mats, axis=0)
        purity_mat_mean = np.mean(purity_mats, axis=0)
        purity_mat_std = np.std(purity_mats, axis=0)
        _wvae_print_matrix_stats("\tPurity matrix:", purity_mat_mean, purity_mat_std)
        experiment_variables["purity_matrices"].append((purity_mat_mean, purity_mat_std))

        aligned_purity_mats = np.stack(aligned_purity_mats, axis=0)
        aligned_purity_mat_mean = np.mean(aligned_purity_mats, axis=0)
        aligned_purity_mat_std = np.std(aligned_purity_mats, axis=0)
        _wvae_print_matrix_stats(
            "\tAligned purity matrix:",
            aligned_purity_mat_mean,
            aligned_purity_mat_std,
        )
        experiment_variables["aligned_purity_matrices"].append((aligned_purity_mat_mean, aligned_purity_mat_std))

        oracle_mats = np.stack(oracle_mats, axis=0)
        oracle_mat_mean = np.mean(oracle_mats, axis=0)
        oracle_mat_std = np.std(oracle_mats, axis=0)
        _wvae_print_matrix_stats("\tOracle matrix:", oracle_mat_mean, oracle_mat_std)
        experiment_variables["oracle_matrices"].append((oracle_mat_mean, oracle_mat_std))

        purity_mean, purity_std = np.mean(purities), np.std(purities)
        experiment_variables["purity_scores"].append((purity_mean, purity_std))
        print(f"\tPurity score: {purity_mean:.4f} ± {purity_std:.4f}")

        non_oracle_purity_mean, non_oracle_purity_std = np.mean(non_oracle_purities), np.std(non_oracle_purities)
        experiment_variables["non_oracle_purity_scores"].append((non_oracle_purity_mean, non_oracle_purity_std))
        print(f"\tNon-oracle purity score: {non_oracle_purity_mean:.4f} ± {non_oracle_purity_std:.4f}")

        if include_encoder_purity:
            encoder_purity_mats = np.stack(encoder_purity_mats, axis=0)
            encoder_purity_mat_mean = np.mean(encoder_purity_mats, axis=0)
            encoder_purity_mat_std = np.std(encoder_purity_mats, axis=0)
            _wvae_print_matrix_stats(
                "\tEncoder purity matrix:",
                encoder_purity_mat_mean,
                encoder_purity_mat_std,
            )
            experiment_variables["encoder_purity_matrices"].append((encoder_purity_mat_mean, encoder_purity_mat_std))

            encoder_aligned_purity_mats = np.stack(encoder_aligned_purity_mats, axis=0)
            encoder_aligned_purity_mat_mean = np.mean(encoder_aligned_purity_mats, axis=0)
            encoder_aligned_purity_mat_std = np.std(encoder_aligned_purity_mats, axis=0)
            _wvae_print_matrix_stats(
                "\tEncoder aligned purity matrix:",
                encoder_aligned_purity_mat_mean,
                encoder_aligned_purity_mat_std,
            )
            experiment_variables["encoder_aligned_purity_matrices"].append((encoder_aligned_purity_mat_mean, encoder_aligned_purity_mat_std))

            encoder_purity_mean, encoder_purity_std = np.mean(encoder_purities), np.std(encoder_purities)
            experiment_variables["encoder_purity_scores"].append((encoder_purity_mean, encoder_purity_std))
            print(f"\tEncoder purity score: {encoder_purity_mean:.4f} ± {encoder_purity_std:.4f}")

        # And serialize the results
        utils.serialize_results(
            results_dict=experiment_variables,
            results_dir=experiment_config["results_dir"],
        )
    return experiment_variables


def _wvae_print_matrix_stats(header, mat_mean, mat_std):
    """Prints `header`, then each matrix row as 'mean ± std' entries."""
    print(header)
    for i in range(mat_mean.shape[0]):
        line = "\t\t"
        for j in range(mat_mean.shape[1]):
            line += f'{mat_mean[i, j]:.4f} ± {mat_std[i, j]:.4f}    '
        print(line)


def wvae_bottleneck_predict_experiment_loop(
    experiment_config,
    datasets,
    load_from_cache=False,
    vae_model=weak_vae.MLVaeArgmax,
):
    """Measures how well the task label can be predicted from VAE latents.

    For every dataset and trial, the encoder/decoder previously saved under
    `<results_dir>/models/{ds_name}_{encoder,decoder}_trial_{trial}` are
    loaded, latent codes are sampled for the train and test splits, and a
    predictor built by `construct_decoder` is trained on the train codes.
    Test accuracy and ROC-AUC are averaged across trials and serialized with
    `utils.serialize_results` after each dataset.

    Args:
        experiment_config: mapping with the experiment hyper-parameters. Keys
            read here: "results_dir", "latent_dim", "trials", "beta",
            "latent_decoder_units", "num_outputs", "learning_rate",
            "predictor_max_epochs", "batch_size", "holdout_fraction" and the
            optional "verbosity".
        datasets: sequence of
            `(ds_name, (x_train, y_train, c_train), (x_test, y_test, c_test))`
            tuples.
        load_from_cache: if True, previously serialized `*_means.npz` /
            `*_stds.npz` files in the results directory are reused. A complete
            cache is returned directly; a partial one makes the loop resume at
            the first dataset without cached results.
        vae_model: model class used to wrap the loaded encoder/decoder.
            Unless it is `beta_vae.BetaVAE`, samples and labels are halved
            with `split_fn` before use.

    Returns:
        Dict with keys "latent_predictive_accuracies" and
        "latent_predictive_aucs", each a list with one `(mean, std)` tuple
        per dataset.

    Raises:
        ValueError: if the cached files disagree on how many datasets have
            already been processed.
    """
    utils.reseed(87)
    if vae_model != beta_vae.BetaVAE:
        split_fn = lambda x: x[:, :x.shape[1]//2, ...] if len(x.shape) > 1 else x[:x.shape[0]//2]
    else:
        split_fn = lambda x: x

    experiment_variables = dict(
        latent_predictive_accuracies=[],
        latent_predictive_aucs=[],
    )

    start_ind = 0
    if load_from_cache:
        cached = True
        complete_cache = True
        for var_name in experiment_variables:
            file_name = os.path.join(
                experiment_config["results_dir"],
                var_name
            )
            if (
                (not os.path.exists(f'{file_name}_means.npz')) or
                (not os.path.exists(f'{file_name}_stds.npz'))
            ):
                print("Could not find", f'"{file_name}_means.npz" or "{file_name}_stds.npz" in cache.')
                cached = False
                break
            # Only the number of stored arrays is needed here; close the
            # archive right away instead of leaking the file handle.
            with np.load(f'{file_name}_means.npz') as loaded_means:
                num_cached = len(loaded_means.files)
            if num_cached != len(datasets):
                print("Found", num_cached, "means for variable", var_name, "vs", len(datasets), "datasets")
                # Then we have a partial run here so let's just run the rest
                if start_ind and start_ind != num_cached:
                    raise ValueError(
                        f'Found inconsistent start indices in cached '
                        f'data {start_ind} vs {num_cached} ({file_name}).'
                    )
                start_ind = num_cached
                complete_cache = False
        if not cached:
            # The cache is unusable, so nothing gets loaded below: make sure
            # we recompute every dataset rather than resuming part-way with
            # an empty results dict.
            start_ind = 0
        if cached:
            # Then we have found all of the arrays of interest, so let's
            # load them and use them
            print("Experiment cache was hit")
            print("\tLoading variables from", experiment_config["results_dir"])
            for var_name in experiment_variables:
                file_name = os.path.join(
                    experiment_config["results_dir"],
                    var_name
                )
                # Open each archive once and read its arrays in file order.
                with np.load(f'{file_name}_means.npz') as means_file:
                    cached_means = [means_file[key] for key in means_file.files]
                with np.load(f'{file_name}_stds.npz') as stds_file:
                    cached_stds = [stds_file[key] for key in stds_file.files]
                experiment_variables[var_name] = list(zip(cached_means, cached_stds))
            if complete_cache:
                # Then we are good to go
                if not os.path.exists(
                    os.path.join(experiment_config["results_dir"], "config.yaml")
                ):
                    # then serialize the config as this is a different version run
                    utils.serialize_experiment_config(
                        experiment_config,
                        experiment_config["results_dir"],
                    )
                return experiment_variables

    # Else, let's go ahead and run the whole thing
    verbosity = experiment_config.get("verbosity", 0)
    # Single source of truth for the binary/multi-class switch. The predictor
    # is compiled with a binary loss and the "binary_accuracy" metric whenever
    # num_outputs <= 2, so result handling below must use the same threshold.
    # NOTE(review): this assumes `construct_decoder` emits a single logit for
    # num_outputs <= 2 -- confirm against its definition.
    is_binary_task = experiment_config["num_outputs"] <= 2
    Path(
        os.path.join(
            experiment_config["results_dir"],
            "models",
        )
    ).mkdir(parents=True, exist_ok=True)
    for (ds_name, (x_train, y_train, c_train), (x_test, y_test, c_test)) in datasets[start_ind:]:
        latent_dim = experiment_config['latent_dim']
        print("Training with latent dimensions", latent_dim, "in dataset", ds_name)
        task_accs = []
        aucs = []

        for trial in range(experiment_config["trials"]):
            print(f"\tTrial {trial + 1}/{experiment_config['trials']}")
            encoder = tf.keras.models.load_model(
                os.path.join(
                    experiment_config["results_dir"],
                    f"models/{ds_name}_encoder_trial_{trial}"
                )
            )
            decoder = tf.keras.models.load_model(
                os.path.join(
                    experiment_config["results_dir"],
                    f"models/{ds_name}_decoder_trial_{trial}"
                )
            )

            # The wrapped model is only used here to encode inputs and sample
            # latent codes; it is never trained in this function.
            wvae_model = vae_model(
                encoder=encoder,
                decoder=decoder,
                loss_fn=vae_losses.l2_loss_wrapper(),
                beta=experiment_config["beta"],
            )

            predictive_decoder = construct_decoder(
                units=experiment_config["latent_decoder_units"],
                num_outputs=experiment_config["num_outputs"],
            )
            predictive_decoder.compile(
                optimizer=tf.keras.optimizers.Adam(experiment_config["learning_rate"]),
                loss=(
                    tf.keras.losses.BinaryCrossentropy(from_logits=True) if is_binary_task
                    else tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
                ),
                metrics=[
                    "binary_accuracy" if is_binary_task
                    else tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1)
                ],
            )

            print("\tTraining model")
            train_codes = wvae_model.sample_from_latent_distribution(
                *wvae_model.encoder(split_fn(x_train))
            ).numpy()
            test_codes = wvae_model.sample_from_latent_distribution(
                *wvae_model.encoder(split_fn(x_test))
            ).numpy()
            predictive_decoder.fit(
                x=train_codes,
                y=split_fn(y_train),
                epochs=experiment_config["predictor_max_epochs"],
                batch_size=experiment_config["batch_size"],
                validation_split=experiment_config["holdout_fraction"],
                verbose=verbosity,
            )
            print("\tEvaluating model")
            test_labels = split_fn(y_test)
            test_result = predictive_decoder.evaluate(
                test_codes,
                test_labels,
                verbose=0,
                return_dict=True,
            )
            # The key must match the metric the predictor was compiled with.
            task_accs.append(
                test_result['binary_accuracy']
                if is_binary_task else
                test_result['sparse_top_k_categorical_accuracy']
            )

            if not is_binary_task:
                # Then lets apply a softmax activation over all the probability
                # classes
                preds = scipy.special.softmax(
                    predictive_decoder.predict(test_codes),
                    axis=-1,
                )

                # And select just the labels that are in fact being used
                aucs.append(sklearn.metrics.roc_auc_score(
                    test_labels,
                    preds,
                    multi_class='ovo',
                ))
            else:
                # Score against the same (split) labels the predictions were
                # produced for; raw logits are fine since AUC is rank-based.
                aucs.append(sklearn.metrics.roc_auc_score(
                    test_labels,
                    predictive_decoder.predict(test_codes),
                ))

            print(
                f"\t\tTest auc = {aucs[-1]:.4f}, "
                f"task accuracy = {task_accs[-1]:.4f}"
            )
            print("\t\tDone with trial", trial + 1)

        task_acc_mean, task_acc_std = np.mean(task_accs), np.std(task_accs)
        experiment_variables["latent_predictive_accuracies"].append((task_acc_mean, task_acc_std))
        print(f"\tTest task accuracy: {task_acc_mean:.4f} ± {task_acc_std:.4f}")

        task_auc_mean, task_auc_std = np.mean(aucs), np.std(aucs)
        experiment_variables["latent_predictive_aucs"].append((task_auc_mean, task_auc_std))
        print(f"\tTest task AUC: {task_auc_mean:.4f} ± {task_auc_std:.4f}")

        # And serialize the results
        utils.serialize_results(
            results_dict=experiment_variables,
            results_dir=experiment_config["results_dir"],
        )
    return experiment_variables

def _load_cached_mean_std_pairs(file_name):
    """Read the ``<file_name>_means.npz`` / ``<file_name>_stds.npz`` archives.

    Each archive is opened exactly once and closed before returning, so no
    file handles are left dangling.

    Args:
        file_name: path prefix shared by the two ``.npz`` archives.

    Returns:
        A ``(means, stds)`` tuple of lists with the arrays stored in each
        archive, in the archive's own key order.
    """
    with np.load(f'{file_name}_means.npz') as means_archive:
        means = [means_archive[key] for key in means_archive.files]
    with np.load(f'{file_name}_stds.npz') as stds_archive:
        stds = [stds_archive[key] for key in stds_archive.files]
    return means, stds


def wvae_bottleneck_concept_predict_experiment_loop(
    experiment_config,
    datasets,
    load_from_cache=False,
    vae_model=weak_vae.MLVaeArgmax,
):
    """Probe how well each concept can be predicted from a VAE's latent codes.

    For every dataset and trial, the previously serialized encoder/decoder
    pair is loaded from ``<results_dir>/models``, wrapped in ``vae_model``,
    and one binary classifier per concept is trained on codes sampled from
    the latent distribution. Per-trial averages (over concepts) of the test
    binary accuracy and ROC-AUC are aggregated into a (mean, std) pair per
    dataset and serialized through ``utils.serialize_results`` after each
    dataset.

    Args:
        experiment_config: dict read for the keys "results_dir", "trials",
            "beta", "num_concepts", "latent_decoder_units", "learning_rate",
            "predictor_max_epochs", "batch_size", "holdout_fraction",
            "latent_dim" and, optionally, "verbosity" (defaults to 0).
        datasets: sequence of
            ``(ds_name, (x_train, y_train, c_train), (x_test, y_test, c_test))``
            entries.
        load_from_cache: when True, previously serialized results found in
            "results_dir" are reused; a partial cache resumes from the first
            dataset without cached results.
        vae_model: class instantiated with ``encoder``, ``decoder``,
            ``loss_fn`` and ``beta``; the instance must expose ``encoder``
            and ``sample_from_latent_distribution``.

    Returns:
        Dict with the keys "latent_avg_concept_predictive_accuracies" and
        "latent_avg_concept_predictive_aucs", each a list holding one
        (mean, std) pair per dataset.

    Raises:
        ValueError: if the cached variables disagree on how many datasets of
            a partial run have already been processed.
    """
    utils.reseed(87)
    if vae_model != beta_vae.BetaVAE:
        # Keep only the first half of the samples along axis 1 (or along
        # axis 0 for 1-D inputs).
        # NOTE(review): presumably the weakly-supervised datasets hold paired
        # samples stacked on that axis -- confirm against the dataset builder.
        split_fn = lambda x: x[:, :x.shape[1]//2, ...] if len(x.shape) > 1 else x[:x.shape[0]//2]
    else:
        split_fn = lambda x: x

    experiment_variables = dict(
        latent_avg_concept_predictive_accuracies=[],
        latent_avg_concept_predictive_aucs=[],
    )

    start_ind = 0
    if load_from_cache:
        cached = True
        complete_cache = True
        cached_pairs = {}
        # Number of datasets already covered by a partial cache. None (rather
        # than 0) marks "no partial variable seen yet" so that an empty cache
        # for one variable is still compared against the others.
        partial_count = None
        for var_name in experiment_variables:
            file_name = os.path.join(
                experiment_config["results_dir"],
                var_name
            )
            if (
                (not os.path.exists(f'{file_name}_means.npz')) or
                (not os.path.exists(f'{file_name}_stds.npz'))
            ):
                print("Could not find", f'"{file_name}_means.npz" or "{file_name}_stds.npz" in cache.')
                cached = False
                break
            loaded_means, loaded_stds = _load_cached_mean_std_pairs(file_name)
            cached_pairs[var_name] = list(zip(loaded_means, loaded_stds))
            if len(loaded_means) != len(datasets):
                print("Found", len(loaded_means), "means for variable", var_name, "vs", len(datasets), "datasets")
                # Then we have a partial run here so let's just run the rest
                if partial_count is not None and partial_count != len(loaded_means):
                    raise ValueError(
                        f'Found inconsistent start indices in cached '
                        f'data {partial_count} vs {len(loaded_means)} ({file_name}).'
                    )
                partial_count = len(loaded_means)
                complete_cache = False
        if cached:
            # Then we have found all of the arrays of interest, so let's
            # load them and use them
            print("Experiment cache was hit")
            print("\tLoading variables from", experiment_config["results_dir"])
            if not complete_cache:
                # Only resume part-way when every variable was found: on a
                # cache miss start_ind stays 0 and everything is recomputed
                # into the (still empty) result lists.
                start_ind = partial_count
            for var_name, pairs in cached_pairs.items():
                # On a partial cache, trim every variable to the resume point
                # so all result lists stay aligned with the datasets.
                experiment_variables[var_name] = (
                    pairs if complete_cache else pairs[:start_ind]
                )
            print(experiment_variables)
            if complete_cache:
                # Then we are good to go
                if not os.path.exists(
                    os.path.join(experiment_config["results_dir"], "config.yaml")
                ):
                    # then serialize the config as this is a different version run
                    utils.serialize_experiment_config(
                        experiment_config,
                        experiment_config["results_dir"],
                    )
                return experiment_variables

    # Else, let's go ahead and run the whole thing
    verbosity = experiment_config.get("verbosity", 0)
    Path(
        os.path.join(
            experiment_config["results_dir"],
            "models",
        )
    ).mkdir(parents=True, exist_ok=True)
    for (ds_name, (x_train, y_train, c_train), (x_test, y_test, c_test)) in datasets[start_ind:]:
        latent_dim = experiment_config['latent_dim']
        print("Training with latent dimensions", latent_dim, "in dataset", ds_name)
        avg_concept_accs = []
        avg_concept_aucs = []

        for trial in range(experiment_config["trials"]):
            print(f"\tTrial {trial + 1}/{experiment_config['trials']}")
            # Reuse the encoder/decoder serialized for this dataset and trial.
            encoder = tf.keras.models.load_model(
                os.path.join(
                    experiment_config["results_dir"],
                    f"models/{ds_name}_encoder_trial_{trial}"
                )
            )
            decoder = tf.keras.models.load_model(
                os.path.join(
                    experiment_config["results_dir"],
                    f"models/{ds_name}_decoder_trial_{trial}"
                )
            )

            wvae_model = vae_model(
                encoder=encoder,
                decoder=decoder,
                loss_fn=vae_losses.l2_loss_wrapper(),
                beta=experiment_config["beta"],
            )

            current_accs = []
            current_aucs = []

            for concept_idx in range(experiment_config["num_concepts"]):
                # One single-logit probe per concept, trained from scratch.
                predictive_decoder = construct_decoder(
                    units=experiment_config["latent_decoder_units"],
                    num_outputs=1,
                )
                predictive_decoder.compile(
                    optimizer=tf.keras.optimizers.Adam(experiment_config["learning_rate"]),
                    loss=(
                        tf.keras.losses.BinaryCrossentropy(from_logits=True)
                    ),
                    metrics=[
                        "binary_accuracy"
                    ],
                )

                print("\tTraining model for concept", concept_idx)
                # NOTE(review): the codes are re-encoded and re-sampled for
                # every concept. Hoisting this out of the concept loop would
                # save work but changes the random stream (and therefore the
                # reported numbers), so it is left untouched here.
                train_codes = wvae_model.sample_from_latent_distribution(
                    *wvae_model.encoder(split_fn(x_train))
                ).numpy()
                test_codes = wvae_model.sample_from_latent_distribution(
                    *wvae_model.encoder(split_fn(x_test))
                ).numpy()
                predictive_decoder.fit(
                    x=train_codes,
                    y=split_fn(c_train)[:, concept_idx],
                    epochs=experiment_config["predictor_max_epochs"],
                    batch_size=experiment_config["batch_size"],
                    validation_split=experiment_config["holdout_fraction"],
                    verbose=verbosity,
                )
                print("\t\tEvaluating model")
                test_result = predictive_decoder.evaluate(
                    test_codes,
                    split_fn(c_test)[:, concept_idx],
                    verbose=0,
                    return_dict=True,
                )
                current_accs.append(
                    test_result['binary_accuracy']
                )

                # The probe outputs raw logits; ROC-AUC only depends on their
                # ranking, so no sigmoid is applied first.
                current_aucs.append(sklearn.metrics.roc_auc_score(
                    split_fn(c_test)[:, concept_idx],
                    predictive_decoder.predict(test_codes),
                ))
                print(
                    f"\t\tTestconcept AUC = {current_aucs[-1]:.4f}, "
                    f"concept accuracy = {current_accs[-1]:.4f}"
                )
            avg_concept_accs.append(np.mean(current_accs))
            avg_concept_aucs.append(np.mean(current_aucs))
            print(
                f"\t\tTest avg concept AUC = {avg_concept_aucs[-1]:.4f}, "
                f"avg concept accuracy = {avg_concept_accs[-1]:.4f}"
            )
            print("\tDone with trial", trial + 1)

        avg_concept_acc_mean, avg_concept_acc_std = np.mean(avg_concept_accs), np.std(avg_concept_accs)
        experiment_variables["latent_avg_concept_predictive_accuracies"].append((avg_concept_acc_mean, avg_concept_acc_std))
        print(f"\tTest average concept accuracy: {avg_concept_acc_mean:.4f} ± {avg_concept_acc_std:.4f}")

        avg_concept_auc_mean, avg_concept_auc_std = np.mean(avg_concept_aucs), np.std(avg_concept_aucs)
        experiment_variables["latent_avg_concept_predictive_aucs"].append((avg_concept_auc_mean, avg_concept_auc_std))
        print(f"\tTest average concept AUC: {avg_concept_auc_mean:.4f} ± {avg_concept_auc_std:.4f}")

        # And serialize the results
        utils.serialize_results(
            results_dict=experiment_variables,
            results_dir=experiment_config["results_dir"],
        )
    return experiment_variables

## Experiments

In [None]:
# Pick up any edits made to the VAE modules since they were first imported.
reload(vae_losses)
reload(base_vae)
reload(weak_vae)

############################################################################
## Experiment config (latent size == number of concept annotations)
############################################################################

ada_mlvae_balanced_multilabel_experiment_config = {
    "batch_size": 32,
    "max_epochs": 100,
    "trials": 5,
    "learning_rate": 1e-3,
    "input_shape": balanced_multiclass_task_bin_concepts_dep_0_complete_train[0].shape[1:],
    "latent_dim": (balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]),
    "filter_groups": [
        [(8, (7, 7), 1)],
        [(16, (5, 5), 1)],
        [(32, (3, 3), 1)],
        [(64, (3, 3), 1)],
    ],
    "include_pool": True,
    "encoder_units": [64, 64],
    "decoder_units": [256, 512],

    "latent_decoder_units": [64, 64],
    # A single output when the task has at most two distinct labels,
    # otherwise one output per distinct label.
    "num_outputs": (
        len(set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1])) if len(set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1])) > 2
        else 1
    ),
    "predictor_max_epochs": 100,

    "drop_prob": 0.0,
    "max_pool_window": (2, 2),
    "max_pool_stride": 2,
    "beta": 1,
    "patience": float("inf"),
    "min_delta": 1e-5,
    "results_dir": os.path.join(
        RESULTS_DIR,
        f"ada_ml_vae/balanced_multilabel_purity_latent_{balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]}"
    ),
    "num_concepts": balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    "verbosity": 0,
    "data_concepts": balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    "holdout_fraction": 0.1,
    "visualization_frequency": 25,
    "visualization_samples": 6,
}


# Make sure both the experiment directory and its figures folder exist.
Path(
    ada_mlvae_balanced_multilabel_experiment_config["results_dir"]
).mkdir(parents=True, exist_ok=True)
ada_mlvae_balanced_multilabel_figure_dir = os.path.join(
    ada_mlvae_balanced_multilabel_experiment_config["results_dir"],
    "figures",
)
Path(ada_mlvae_balanced_multilabel_figure_dir).mkdir(parents=True, exist_ok=True)

############################################################################
## Experiment run
############################################################################
ada_mlvae_balanced_multilabel_results = wvae_experiment_loop(
    ada_mlvae_balanced_multilabel_experiment_config,
    load_from_cache=True,
    include_all_losses=True,
    include_encoder_purity=True,
    vae_model=weak_vae.MLVaeArgmax,
    # One entry per dependency level 0..4: entries 0-2 of each dataset tuple
    # form the first split and entries 3-5 the second one.
    datasets=[
        (
            f"balanced_multiclass_task_bin_concepts_dep_{dep}_complete",
            tuple(balanced_multiclass_wvae_datasets[dep][part] for part in range(0, 3)),
            tuple(balanced_multiclass_wvae_datasets[dep][part] for part in range(3, 6)),
        )
        for dep in range(5)
    ],
    oracle_matrix_cache=balanced_oracle_matrix_cache,
)
print("total_losses:", ada_mlvae_balanced_multilabel_results["total_losses"])
print("purity_scores:", ada_mlvae_balanced_multilabel_results["purity_scores"])

In [None]:
# Pick up any edits made to the VAE modules since they were first imported.
reload(vae_losses)
reload(base_vae)
reload(weak_vae)

############################################################################
## Task accuracy experiment run
############################################################################

# Merge the latent -> task-label predictive metrics into the results dict.
ada_mlvae_balanced_multilabel_results.update(wvae_bottleneck_predict_experiment_loop(
    ada_mlvae_balanced_multilabel_experiment_config,
    load_from_cache=True,
    vae_model=weak_vae.MLVaeArgmax,
    # One entry per dependency level 0..4: entries 0-2 of each dataset tuple
    # form the first split and entries 3-5 the second one.
    datasets=[
        (
            f"balanced_multiclass_task_bin_concepts_dep_{dep}_complete",
            tuple(balanced_multiclass_wvae_datasets[dep][part] for part in range(0, 3)),
            tuple(balanced_multiclass_wvae_datasets[dep][part] for part in range(3, 6)),
        )
        for dep in range(5)
    ],
))
print("latent_predictive_accuracies:", ada_mlvae_balanced_multilabel_results["latent_predictive_accuracies"])
print("latent_predictive_aucs:", ada_mlvae_balanced_multilabel_results["latent_predictive_aucs"])

In [None]:
# Pick up any edits made to the VAE modules since they were first imported.
reload(vae_losses)
reload(base_vae)
reload(weak_vae)

############################################################################
## Concept accuracy experiment run
############################################################################

# Merge the latent -> concept predictive metrics into the results dict.
ada_mlvae_balanced_multilabel_results.update(wvae_bottleneck_concept_predict_experiment_loop(
    ada_mlvae_balanced_multilabel_experiment_config,
    load_from_cache=True,
    vae_model=weak_vae.MLVaeArgmax,
    # One (name, train split, test split) entry per dependency level 0..4.
    datasets=[
        (
            f"balanced_multiclass_task_bin_concepts_dep_{dep}_complete",
            tuple(balanced_multiclass_wvae_datasets[dep][part] for part in range(0, 3)),
            tuple(balanced_multiclass_wvae_datasets[dep][part] for part in range(3, 6)),
        )
        for dep in range(5)
    ],
))
print("latent_avg_concept_predictive_accuracies:", ada_mlvae_balanced_multilabel_results["latent_avg_concept_predictive_accuracies"])
print("latent_avg_concept_predictive_aucs:", ada_mlvae_balanced_multilabel_results["latent_avg_concept_predictive_aucs"])

In [None]:
# Pick up any edits made to the VAE modules since they were first imported.
reload(vae_losses)
reload(base_vae)
reload(weak_vae)

############################################################################
## Experiment config (latent size == twice the number of concept annotations)
############################################################################

ada_mlvae_balanced_multilabel_extended_experiment_config = {
    "batch_size": 32,
    "max_epochs": 100,
    "trials": 5,
    "learning_rate": 1e-3,
    "input_shape": balanced_multiclass_task_bin_concepts_dep_0_complete_train[0].shape[1:],
    "latent_dim": (2 * balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]),
    "filter_groups": [
        [(8, (7, 7), 1)],
        [(16, (5, 5), 1)],
        [(32, (3, 3), 1)],
        [(64, (3, 3), 1)],
    ],
    "include_pool": True,
    "encoder_units": [64, 64],

    "decoder_units": [256, 512],

    "latent_decoder_units": [64, 64],
    # A single output when the task has at most two distinct labels,
    # otherwise one output per distinct label.
    "num_outputs": (
        len(set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1])) if len(set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1])) > 2
        else 1
    ),
    "predictor_max_epochs": 100,

    "drop_prob": 0.0,
    "max_pool_window": (2, 2),
    "max_pool_stride": 2,
    "beta": 1,
    "patience": float("inf"),
    "min_delta": 1e-5,
    "results_dir": os.path.join(
        RESULTS_DIR,
        f"ada_ml_vae/balanced_multilabel_purity_latent_{2 * balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]}"
    ),
    "num_concepts": balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    "verbosity": 0,
    "data_concepts": balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    "holdout_fraction": 0.1,
    "visualization_frequency": 25,
    "visualization_samples": 6,
}


# Make sure both the experiment directory and its figures folder exist.
Path(
    ada_mlvae_balanced_multilabel_extended_experiment_config["results_dir"]
).mkdir(parents=True, exist_ok=True)
ada_mlvae_balanced_multilabel_extended_figure_dir = os.path.join(
    ada_mlvae_balanced_multilabel_extended_experiment_config["results_dir"],
    "figures",
)
Path(ada_mlvae_balanced_multilabel_extended_figure_dir).mkdir(parents=True, exist_ok=True)

############################################################################
## Experiment run
############################################################################
ada_mlvae_balanced_multilabel_extended_results = wvae_experiment_loop(
    ada_mlvae_balanced_multilabel_extended_experiment_config,
    load_from_cache=True,
    include_all_losses=True,
    include_encoder_purity=True,
    vae_model=weak_vae.MLVaeArgmax,
    # One entry per dependency level 0..4: entries 0-2 of each dataset tuple
    # form the first split and entries 3-5 the second one.
    datasets=[
        (
            f"balanced_multiclass_task_bin_concepts_dep_{dep}_complete",
            tuple(balanced_multiclass_wvae_datasets[dep][part] for part in range(0, 3)),
            tuple(balanced_multiclass_wvae_datasets[dep][part] for part in range(3, 6)),
        )
        for dep in range(5)
    ],
    oracle_matrix_cache=balanced_oracle_matrix_cache,
)
print("total_losses:", ada_mlvae_balanced_multilabel_extended_results["total_losses"])
print("purity_scores:", ada_mlvae_balanced_multilabel_extended_results["purity_scores"])

In [None]:
# Pick up any edits made to the VAE modules since they were first imported.
reload(vae_losses)
reload(base_vae)
reload(weak_vae)

############################################################################
## Task accuracy experiment run
############################################################################

# Merge the latent -> task-label predictive metrics into the results dict.
ada_mlvae_balanced_multilabel_extended_results.update(wvae_bottleneck_predict_experiment_loop(
    ada_mlvae_balanced_multilabel_extended_experiment_config,
    load_from_cache=True,
    vae_model=weak_vae.MLVaeArgmax,
    # One entry per dependency level 0..4: entries 0-2 of each dataset tuple
    # form the first split and entries 3-5 the second one.
    datasets=[
        (
            f"balanced_multiclass_task_bin_concepts_dep_{dep}_complete",
            tuple(balanced_multiclass_wvae_datasets[dep][part] for part in range(0, 3)),
            tuple(balanced_multiclass_wvae_datasets[dep][part] for part in range(3, 6)),
        )
        for dep in range(5)
    ],
))
print("latent_predictive_accuracies:", ada_mlvae_balanced_multilabel_extended_results["latent_predictive_accuracies"])
print("latent_predictive_aucs:", ada_mlvae_balanced_multilabel_extended_results["latent_predictive_aucs"])

In [None]:
# Pick up any edits made to the VAE modules since they were first imported.
reload(vae_losses)
reload(base_vae)
reload(weak_vae)

############################################################################
## Concept accuracy experiment run
############################################################################

# Merge the latent -> concept predictive metrics into the results dict.
ada_mlvae_balanced_multilabel_extended_results.update(wvae_bottleneck_concept_predict_experiment_loop(
    ada_mlvae_balanced_multilabel_extended_experiment_config,
    load_from_cache=True,
    vae_model=weak_vae.MLVaeArgmax,
    # One (name, train split, test split) entry per dependency level 0..4.
    datasets=[
        (
            f"balanced_multiclass_task_bin_concepts_dep_{dep}_complete",
            tuple(balanced_multiclass_wvae_datasets[dep][part] for part in range(0, 3)),
            tuple(balanced_multiclass_wvae_datasets[dep][part] for part in range(3, 6)),
        )
        for dep in range(5)
    ],
))
print("latent_avg_concept_predictive_accuracies:", ada_mlvae_balanced_multilabel_extended_results["latent_avg_concept_predictive_accuracies"])
print("latent_avg_concept_predictive_aucs:", ada_mlvae_balanced_multilabel_extended_results["latent_avg_concept_predictive_aucs"])

# Ada-GVAE Benchmarking

## Multiclass Task Experiment

In [None]:
# Pick up any edits made to the VAE modules since they were first imported.
reload(vae_losses)
reload(base_vae)
reload(weak_vae)

############################################################################
## Experiment config (latent size == number of concept annotations)
############################################################################

ada_gvae_balanced_multilabel_experiment_config = dict(
    batch_size=32,
    max_epochs=100,
    trials=5,
    learning_rate=1e-3,
    input_shape=balanced_multiclass_task_bin_concepts_dep_0_complete_train[0].shape[1:],
    latent_dim=(balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]),
    filter_groups=[
       [(8, (7, 7), 1)],
       [(16, (5, 5), 1)],
       [(32, (3, 3), 1)],
       [(64, (3, 3), 1)],
    ],
    include_pool=True,
    encoder_units=[64, 64],
    decoder_units=[256, 512],
    
    latent_decoder_units=[64, 64],
    # A single output when the task has at most two distinct labels,
    # otherwise one output per distinct label.
    num_outputs=(
        len(set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1])) if len(set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1])) > 2
        else 1
    ),
    predictor_max_epochs=100,
    
    drop_prob=0.0,
    max_pool_window=(2,2),
    max_pool_stride=2,
    beta=1,
    patience=float("inf"),
    min_delta=1e-5,
    results_dir=os.path.join(
       RESULTS_DIR,
       f"ada_g_vae/balanced_multilabel_purity_latent_{balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]}"
    ),
    num_concepts=balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    verbosity=0,
    data_concepts=balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    holdout_fraction=0.1,
    visualization_frequency=25,
    visualization_samples=6,
)


# Generate the experiment directory if it does not exist already
Path(ada_gvae_balanced_multilabel_experiment_config["results_dir"]).mkdir(parents=True, exist_ok=True)
ada_gvae_balanced_multilabel_figure_dir = os.path.join(ada_gvae_balanced_multilabel_experiment_config["results_dir"], "figures")
Path(ada_gvae_balanced_multilabel_figure_dir).mkdir(parents=True, exist_ok=True)

############################################################################
## Experiment run
############################################################################
ada_gvae_balanced_multilabel_results = wvae_experiment_loop(
    ada_gvae_balanced_multilabel_experiment_config,
    load_from_cache=True,
    include_all_losses=True,
    include_encoder_purity=True,
    vae_model=weak_vae.GroupVAEArgmax,
    # Each entry is (name, first split, second split), built from entries
    # 0-2 and 3-5 of the corresponding dataset tuple. This is the same
    # 3-tuple layout every other experiment run in this notebook passes;
    # the previous 2-tuple form handed entry 2 of the dataset tuple over as
    # the first element of the second split.
    # NOTE(review): presumably the splits are (x, y, c) for train and test --
    # confirm against wvae_experiment_loop.
    datasets=[
        (
            f"balanced_multiclass_task_bin_concepts_dep_{dep}_complete",
            tuple(balanced_multiclass_wvae_datasets[dep][part] for part in range(0, 3)),
            tuple(balanced_multiclass_wvae_datasets[dep][part] for part in range(3, 6)),
        )
        for dep in range(5)
    ],
    oracle_matrix_cache=balanced_oracle_matrix_cache,
)
print("total_losses:", ada_gvae_balanced_multilabel_results["total_losses"])
print("purity_scores:", ada_gvae_balanced_multilabel_results["purity_scores"])

In [None]:
reload(vae_losses)
reload(base_vae)
reload(weak_vae)

############################################################################
## Task accuracy experiment run
############################################################################

# Merge the metrics returned by the bottleneck predictor loop into the
# existing Ada-GVAE results dictionary.
#
# NOTE: the dataset names below intentionally lack the "balanced_" prefix and
# the "_complete" suffix. Change them to
# "balanced_multiclass_task_bin_concepts_dep_<i>_complete" on rerun.
ada_gvae_balanced_multilabel_results.update(
    wvae_bottleneck_predict_experiment_loop(
        ada_gvae_balanced_multilabel_experiment_config,
        load_from_cache=True,
        vae_model=weak_vae.GroupVAEArgmax,
        datasets=[
            (
                f"multiclass_task_bin_concepts_dep_{dep}",
                # Entries 0..2 form the first tuple, entries 3..5 the second.
                tuple(balanced_multiclass_wvae_datasets[dep][idx] for idx in range(0, 3)),
                tuple(balanced_multiclass_wvae_datasets[dep][idx] for idx in range(3, 6)),
            )
            for dep in range(5)
        ],
    )
)
print(
    "latent_predictive_accuracies:",
    ada_gvae_balanced_multilabel_results["latent_predictive_accuracies"],
)
print(
    "latent_predictive_aucs:",
    ada_gvae_balanced_multilabel_results["latent_predictive_aucs"],
)

In [None]:
reload(vae_losses)
reload(base_vae)
reload(weak_vae)

############################################################################
## Concept accuracy experiment run
############################################################################

# Merge the metrics returned by the bottleneck concept-predictor loop into the
# existing Ada-GVAE results dictionary.
ada_gvae_balanced_multilabel_results.update(
    wvae_bottleneck_concept_predict_experiment_loop(
        ada_gvae_balanced_multilabel_experiment_config,
        load_from_cache=True,
        vae_model=weak_vae.GroupVAEArgmax,
        datasets=[
            (
                f"balanced_multiclass_task_bin_concepts_dep_{dep}_complete",
                # Entries 0..2 form the first tuple, entries 3..5 the second.
                tuple(balanced_multiclass_wvae_datasets[dep][idx] for idx in range(0, 3)),
                tuple(balanced_multiclass_wvae_datasets[dep][idx] for idx in range(3, 6)),
            )
            for dep in range(5)
        ],
    )
)
print(
    "latent_avg_concept_predictive_accuracies:",
    ada_gvae_balanced_multilabel_results["latent_avg_concept_predictive_accuracies"],
)
print(
    "latent_avg_concept_predictive_aucs:",
    ada_gvae_balanced_multilabel_results["latent_avg_concept_predictive_aucs"],
)

In [None]:
reload(vae_losses)
reload(base_vae)
reload(weak_vae)

############################################################################
## Experiment config
############################################################################

# Ada-GVAE configuration whose latent dimension is twice the number of
# concept annotations found in the dep_0 training split.
ada_gvae_balanced_multilabel_extended_experiment_config = {
    # --- Training schedule ---
    "batch_size": 32,
    "max_epochs": 100,
    "trials": 5,
    "learning_rate": 1e-3,
    # --- Shapes derived from the dep_0 training split ---
    "input_shape": balanced_multiclass_task_bin_concepts_dep_0_complete_train[0].shape[1:],
    "latent_dim": (
        2 * balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]
    ),
    # --- Architecture ---
    "filter_groups": [
        [(8, (7, 7), 1)],
        [(16, (5, 5), 1)],
        [(32, (3, 3), 1)],
        [(64, (3, 3), 1)],
    ],
    "include_pool": True,
    "encoder_units": [64, 64],
    "decoder_units": [256, 512],
    # --- Downstream predictor ---
    "latent_decoder_units": [64, 64],
    # A single output unit is used unless there are more than two distinct labels.
    "num_outputs": (
        len(set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1]))
        if len(set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1])) > 2
        else 1
    ),
    "predictor_max_epochs": 100,
    # --- Regularisation / pooling / stopping ---
    "drop_prob": 0.0,
    "max_pool_window": (2, 2),
    "max_pool_stride": 2,
    "beta": 1,
    "patience": float("inf"),  # presumably disables early stopping; confirm in the loop
    "min_delta": 1e-5,
    # --- Bookkeeping ---
    "results_dir": os.path.join(
        RESULTS_DIR,
        f"ada_g_vae/multilabel_purity_latent_{2 * balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]}",
    ),
    "num_concepts": balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    "verbosity": 0,
    "data_concepts": balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    "holdout_fraction": 0.1,
    "visualization_frequency": 25,
    "visualization_samples": 6,
}


# Make sure both the experiment directory and its "figures" subdirectory exist
Path(
    ada_gvae_balanced_multilabel_extended_experiment_config["results_dir"]
).mkdir(parents=True, exist_ok=True)
ada_gvae_balanced_multilabel_extended_figure_dir = os.path.join(
    ada_gvae_balanced_multilabel_extended_experiment_config["results_dir"],
    "figures",
)
Path(ada_gvae_balanced_multilabel_extended_figure_dir).mkdir(
    parents=True,
    exist_ok=True,
)

############################################################################
## Experiment run
############################################################################
# NOTE(review): this loop receives entries 0-1 and 2-3 of every dataset,
# whereas the predictor loops in the following cells receive entries 0-2 and
# 3-5. Confirm that this matches the layout of
# `balanced_multiclass_wvae_datasets`.
ada_gvae_balanced_multilabel_extended_results = wvae_experiment_loop(
    ada_gvae_balanced_multilabel_extended_experiment_config,
    load_from_cache=True,
    include_all_losses=True,
    include_encoder_purity=True,
    vae_model=weak_vae.GroupVAEArgmax,
    datasets=[
        (
            f"balanced_multiclass_task_bin_concepts_dep_{dep}_complete",
            tuple(balanced_multiclass_wvae_datasets[dep][idx] for idx in range(0, 2)),
            tuple(balanced_multiclass_wvae_datasets[dep][idx] for idx in range(2, 4)),
        )
        for dep in range(5)
    ],
    oracle_matrix_cache=balanced_oracle_matrix_cache,
)
print(
    "total_losses:",
    ada_gvae_balanced_multilabel_extended_results["total_losses"],
)
print(
    "purity_scores:",
    ada_gvae_balanced_multilabel_extended_results["purity_scores"],
)

In [None]:
reload(vae_losses)
reload(base_vae)
reload(weak_vae)

############################################################################
## Task accuracy experiment run
############################################################################

# Merge the metrics returned by the bottleneck predictor loop into the
# extended Ada-GVAE results dictionary.
ada_gvae_balanced_multilabel_extended_results.update(
    wvae_bottleneck_predict_experiment_loop(
        ada_gvae_balanced_multilabel_extended_experiment_config,
        load_from_cache=True,
        vae_model=weak_vae.GroupVAEArgmax,
        datasets=[
            (
                f"balanced_multiclass_task_bin_concepts_dep_{dep}_complete",
                # Entries 0..2 form the first tuple, entries 3..5 the second.
                tuple(balanced_multiclass_wvae_datasets[dep][idx] for idx in range(0, 3)),
                tuple(balanced_multiclass_wvae_datasets[dep][idx] for idx in range(3, 6)),
            )
            for dep in range(5)
        ],
    )
)
print(
    "latent_predictive_accuracies:",
    ada_gvae_balanced_multilabel_extended_results["latent_predictive_accuracies"],
)
print(
    "latent_predictive_aucs:",
    ada_gvae_balanced_multilabel_extended_results["latent_predictive_aucs"],
)

In [None]:
reload(vae_losses)
reload(base_vae)
reload(weak_vae)

############################################################################
## Concept accuracy experiment run
############################################################################

# Merge the metrics returned by the bottleneck concept-predictor loop into the
# extended Ada-GVAE results dictionary.
ada_gvae_balanced_multilabel_extended_results.update(
    wvae_bottleneck_concept_predict_experiment_loop(
        ada_gvae_balanced_multilabel_extended_experiment_config,
        load_from_cache=True,
        vae_model=weak_vae.GroupVAEArgmax,
        datasets=[
            (
                f"balanced_multiclass_task_bin_concepts_dep_{dep}_complete",
                # Entries 0..2 form the first tuple, entries 3..5 the second.
                tuple(balanced_multiclass_wvae_datasets[dep][idx] for idx in range(0, 3)),
                tuple(balanced_multiclass_wvae_datasets[dep][idx] for idx in range(3, 6)),
            )
            for dep in range(5)
        ],
    )
)
print(
    "latent_avg_concept_predictive_accuracies:",
    ada_gvae_balanced_multilabel_extended_results["latent_avg_concept_predictive_accuracies"],
)
print(
    "latent_avg_concept_predictive_aucs:",
    ada_gvae_balanced_multilabel_extended_results["latent_avg_concept_predictive_aucs"],
)

# Beta-VAE Benchmarking

In [None]:
import concepts_xai.methods.VAE.betaVAE as beta_vae

reload(vae_losses)
reload(base_vae)
reload(beta_vae)


############################################################################
## Experiment config
############################################################################

# Beta-VAE (beta=10) configuration with one latent dimension per concept
# annotation found in the dep_0 training split.
beta_vae_balanced_multilabel_experiment_config = {
    # --- Training schedule ---
    "batch_size": 32,
    "max_epochs": 100,
    "trials": 5,
    "learning_rate": 1e-3,
    # --- Shapes derived from the dep_0 training split ---
    "input_shape": balanced_multiclass_task_bin_concepts_dep_0_complete_train[0].shape[1:],
    "latent_dim": (
        balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]
    ),
    "beta": 10,
    # --- Architecture ---
    "filter_groups": [
        [(8, (7, 7), 1)],
        [(16, (5, 5), 1)],
        [(32, (3, 3), 1)],
        [(64, (3, 3), 1)],
    ],
    "include_pool": True,
    "encoder_units": [64, 64],
    "decoder_units": [256, 512],
    # --- Downstream predictor ---
    "latent_decoder_units": [64, 64],
    # A single output unit is used unless there are more than two distinct labels.
    "num_outputs": (
        len(set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1]))
        if len(set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1])) > 2
        else 1
    ),
    "predictor_max_epochs": 100,
    # --- Regularisation / pooling / stopping ---
    "drop_prob": 0.0,
    "max_pool_window": (2, 2),
    "max_pool_stride": 2,
    "patience": float("inf"),  # presumably disables early stopping; confirm in the loop
    "min_delta": 1e-5,
    # --- Bookkeeping ---
    "results_dir": os.path.join(
        RESULTS_DIR,
        f"beta_vae/balanced_multilabel_purity_latent_{balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]}_beta_{10}",
    ),
    "num_concepts": balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    "verbosity": 0,
    "data_concepts": balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    "holdout_fraction": 0.1,
    "visualization_frequency": 25,
    "visualization_samples": 6,
}


# Make sure both the experiment directory and its "figures" subdirectory exist
Path(
    beta_vae_balanced_multilabel_experiment_config["results_dir"]
).mkdir(parents=True, exist_ok=True)
beta_vae_balanced_multilabel_figure_dir = os.path.join(
    beta_vae_balanced_multilabel_experiment_config["results_dir"],
    "figures",
)
Path(beta_vae_balanced_multilabel_figure_dir).mkdir(
    parents=True,
    exist_ok=True,
)

############################################################################
## Experiment run
############################################################################
# Run the experiment loop for the beta-VAE over the five balanced datasets;
# each dataset is passed as a (name, train, test) triple.
beta_vae_balanced_multilabel_results = wvae_experiment_loop(
    beta_vae_balanced_multilabel_experiment_config,
    load_from_cache=True,
    include_all_losses=True,
    include_encoder_purity=True,
    vae_model=beta_vae.BetaVAE,
    datasets=list(zip(
        (f"balanced_multiclass_task_bin_concepts_dep_{dep}_complete" for dep in range(5)),
        (
            balanced_multiclass_task_bin_concepts_dep_0_complete_train,
            balanced_multiclass_task_bin_concepts_dep_1_complete_train,
            balanced_multiclass_task_bin_concepts_dep_2_complete_train,
            balanced_multiclass_task_bin_concepts_dep_3_complete_train,
            balanced_multiclass_task_bin_concepts_dep_4_complete_train,
        ),
        (
            balanced_multiclass_task_bin_concepts_dep_0_complete_test,
            balanced_multiclass_task_bin_concepts_dep_1_complete_test,
            balanced_multiclass_task_bin_concepts_dep_2_complete_test,
            balanced_multiclass_task_bin_concepts_dep_3_complete_test,
            balanced_multiclass_task_bin_concepts_dep_4_complete_test,
        ),
    )),
    oracle_matrix_cache=balanced_oracle_matrix_cache,
)
print(
    "total_losses:",
    beta_vae_balanced_multilabel_results["total_losses"],
)
print(
    "purity_scores:",
    beta_vae_balanced_multilabel_results["purity_scores"],
)

In [None]:
reload(vae_losses)
reload(base_vae)
reload(beta_vae)

############################################################################
## Task accuracy experiment run
############################################################################

# Merge the metrics returned by the bottleneck predictor loop into the
# beta-VAE results dictionary; datasets are (name, train, test) triples.
beta_vae_balanced_multilabel_results.update(
    wvae_bottleneck_predict_experiment_loop(
        beta_vae_balanced_multilabel_experiment_config,
        load_from_cache=True,
        vae_model=beta_vae.BetaVAE,
        datasets=list(zip(
            (f"balanced_multiclass_task_bin_concepts_dep_{dep}_complete" for dep in range(5)),
            (
                balanced_multiclass_task_bin_concepts_dep_0_complete_train,
                balanced_multiclass_task_bin_concepts_dep_1_complete_train,
                balanced_multiclass_task_bin_concepts_dep_2_complete_train,
                balanced_multiclass_task_bin_concepts_dep_3_complete_train,
                balanced_multiclass_task_bin_concepts_dep_4_complete_train,
            ),
            (
                balanced_multiclass_task_bin_concepts_dep_0_complete_test,
                balanced_multiclass_task_bin_concepts_dep_1_complete_test,
                balanced_multiclass_task_bin_concepts_dep_2_complete_test,
                balanced_multiclass_task_bin_concepts_dep_3_complete_test,
                balanced_multiclass_task_bin_concepts_dep_4_complete_test,
            ),
        )),
    )
)
print(
    "latent_predictive_accuracies:",
    beta_vae_balanced_multilabel_results["latent_predictive_accuracies"],
)
print(
    "latent_predictive_aucs:",
    beta_vae_balanced_multilabel_results["latent_predictive_aucs"],
)

In [None]:
reload(vae_losses)
reload(base_vae)
reload(beta_vae)

############################################################################
## Concept accuracy experiment run
############################################################################

# Merge the metrics returned by the bottleneck concept-predictor loop into the
# beta-VAE results dictionary; datasets are (name, train, test) triples.
beta_vae_balanced_multilabel_results.update(
    wvae_bottleneck_concept_predict_experiment_loop(
        beta_vae_balanced_multilabel_experiment_config,
        load_from_cache=True,
        vae_model=beta_vae.BetaVAE,
        datasets=list(zip(
            (f"balanced_multiclass_task_bin_concepts_dep_{dep}_complete" for dep in range(5)),
            (
                balanced_multiclass_task_bin_concepts_dep_0_complete_train,
                balanced_multiclass_task_bin_concepts_dep_1_complete_train,
                balanced_multiclass_task_bin_concepts_dep_2_complete_train,
                balanced_multiclass_task_bin_concepts_dep_3_complete_train,
                balanced_multiclass_task_bin_concepts_dep_4_complete_train,
            ),
            (
                balanced_multiclass_task_bin_concepts_dep_0_complete_test,
                balanced_multiclass_task_bin_concepts_dep_1_complete_test,
                balanced_multiclass_task_bin_concepts_dep_2_complete_test,
                balanced_multiclass_task_bin_concepts_dep_3_complete_test,
                balanced_multiclass_task_bin_concepts_dep_4_complete_test,
            ),
        )),
    )
)
print(
    "latent_avg_concept_predictive_accuracies:",
    beta_vae_balanced_multilabel_results["latent_avg_concept_predictive_accuracies"],
)
print(
    "latent_avg_concept_predictive_aucs:",
    beta_vae_balanced_multilabel_results["latent_avg_concept_predictive_aucs"],
)

In [None]:
import concepts_xai.methods.VAE.betaVAE as beta_vae

reload(vae_losses)
reload(base_vae)
reload(beta_vae)


############################################################################
## Experiment config
############################################################################

# Beta-VAE (beta=10) configuration whose latent dimension is twice the number
# of concept annotations found in the dep_0 training split.
beta_vae_extended_balanced_multilabel_experiment_config = {
    # --- Training schedule ---
    "batch_size": 32,
    "max_epochs": 100,
    "trials": 5,
    "learning_rate": 1e-3,
    # --- Shapes derived from the dep_0 training split ---
    "input_shape": balanced_multiclass_task_bin_concepts_dep_0_complete_train[0].shape[1:],
    "latent_dim": (
        2 * balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]
    ),
    "beta": 10,
    # --- Architecture ---
    "filter_groups": [
        [(8, (7, 7), 1)],
        [(16, (5, 5), 1)],
        [(32, (3, 3), 1)],
        [(64, (3, 3), 1)],
    ],
    "include_pool": True,
    "encoder_units": [64, 64],
    "decoder_units": [256, 512],
    # --- Downstream predictor ---
    "latent_decoder_units": [64, 64],
    # A single output unit is used unless there are more than two distinct labels.
    "num_outputs": (
        len(set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1]))
        if len(set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1])) > 2
        else 1
    ),
    "predictor_max_epochs": 100,
    # --- Regularisation / pooling / stopping ---
    "drop_prob": 0.0,
    "max_pool_window": (2, 2),
    "max_pool_stride": 2,
    "patience": float("inf"),  # presumably disables early stopping; confirm in the loop
    "min_delta": 1e-5,
    # --- Bookkeeping ---
    "results_dir": os.path.join(
        RESULTS_DIR,
        f"beta_vae/balanced_multilabel_purity_latent_{2 * balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]}_beta_{10}",
    ),
    "num_concepts": balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    "verbosity": 0,
    "data_concepts": balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    "holdout_fraction": 0.1,
    "visualization_frequency": 25,
    "visualization_samples": 6,
}


# Make sure both the experiment directory and its "figures" subdirectory exist
Path(
    beta_vae_extended_balanced_multilabel_experiment_config["results_dir"]
).mkdir(parents=True, exist_ok=True)
beta_vae_extended_balanced_multilabel_figure_dir = os.path.join(
    beta_vae_extended_balanced_multilabel_experiment_config["results_dir"],
    "figures",
)
Path(beta_vae_extended_balanced_multilabel_figure_dir).mkdir(
    parents=True,
    exist_ok=True,
)

############################################################################
## Experiment run
############################################################################
# Run the experiment loop for the extended-latent beta-VAE over the five
# balanced datasets; each dataset is passed as a (name, train, test) triple.
beta_vae_extended_balanced_multilabel_results = wvae_experiment_loop(
    beta_vae_extended_balanced_multilabel_experiment_config,
    load_from_cache=True,
    include_all_losses=True,
    # Every other wvae_experiment_loop call in this section requests encoder
    # purity, and "purity_scores" is printed right below, so request it here
    # too instead of relying on the loop's default.
    include_encoder_purity=True,
    vae_model=beta_vae.BetaVAE,
    datasets=[
        ("balanced_multiclass_task_bin_concepts_dep_0_complete", balanced_multiclass_task_bin_concepts_dep_0_complete_train, balanced_multiclass_task_bin_concepts_dep_0_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_1_complete", balanced_multiclass_task_bin_concepts_dep_1_complete_train, balanced_multiclass_task_bin_concepts_dep_1_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_2_complete", balanced_multiclass_task_bin_concepts_dep_2_complete_train, balanced_multiclass_task_bin_concepts_dep_2_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_3_complete", balanced_multiclass_task_bin_concepts_dep_3_complete_train, balanced_multiclass_task_bin_concepts_dep_3_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_4_complete", balanced_multiclass_task_bin_concepts_dep_4_complete_train, balanced_multiclass_task_bin_concepts_dep_4_complete_test),
    ],
    oracle_matrix_cache=balanced_oracle_matrix_cache,
)
print("total_losses:", beta_vae_extended_balanced_multilabel_results["total_losses"])
print("purity_scores:", beta_vae_extended_balanced_multilabel_results["purity_scores"])

In [None]:
reload(vae_losses)
reload(base_vae)
reload(beta_vae)

############################################################################
## Task accuracy experiment run
############################################################################

# Merge the metrics returned by the bottleneck predictor loop into the
# extended beta-VAE results dictionary; datasets are (name, train, test) triples.
beta_vae_extended_balanced_multilabel_results.update(
    wvae_bottleneck_predict_experiment_loop(
        beta_vae_extended_balanced_multilabel_experiment_config,
        load_from_cache=True,
        vae_model=beta_vae.BetaVAE,
        datasets=list(zip(
            (f"balanced_multiclass_task_bin_concepts_dep_{dep}_complete" for dep in range(5)),
            (
                balanced_multiclass_task_bin_concepts_dep_0_complete_train,
                balanced_multiclass_task_bin_concepts_dep_1_complete_train,
                balanced_multiclass_task_bin_concepts_dep_2_complete_train,
                balanced_multiclass_task_bin_concepts_dep_3_complete_train,
                balanced_multiclass_task_bin_concepts_dep_4_complete_train,
            ),
            (
                balanced_multiclass_task_bin_concepts_dep_0_complete_test,
                balanced_multiclass_task_bin_concepts_dep_1_complete_test,
                balanced_multiclass_task_bin_concepts_dep_2_complete_test,
                balanced_multiclass_task_bin_concepts_dep_3_complete_test,
                balanced_multiclass_task_bin_concepts_dep_4_complete_test,
            ),
        )),
    )
)
print(
    "latent_predictive_accuracies:",
    beta_vae_extended_balanced_multilabel_results["latent_predictive_accuracies"],
)
print(
    "latent_predictive_aucs:",
    beta_vae_extended_balanced_multilabel_results["latent_predictive_aucs"],
)

In [None]:
reload(vae_losses)
reload(base_vae)
reload(beta_vae)

############################################################################
## Concept accuracy experiment run
############################################################################

# Merge the metrics returned by the bottleneck concept-predictor loop into the
# extended beta-VAE results dictionary; datasets are (name, train, test) triples.
beta_vae_extended_balanced_multilabel_results.update(
    wvae_bottleneck_concept_predict_experiment_loop(
        beta_vae_extended_balanced_multilabel_experiment_config,
        load_from_cache=True,
        vae_model=beta_vae.BetaVAE,
        datasets=list(zip(
            (f"balanced_multiclass_task_bin_concepts_dep_{dep}_complete" for dep in range(5)),
            (
                balanced_multiclass_task_bin_concepts_dep_0_complete_train,
                balanced_multiclass_task_bin_concepts_dep_1_complete_train,
                balanced_multiclass_task_bin_concepts_dep_2_complete_train,
                balanced_multiclass_task_bin_concepts_dep_3_complete_train,
                balanced_multiclass_task_bin_concepts_dep_4_complete_train,
            ),
            (
                balanced_multiclass_task_bin_concepts_dep_0_complete_test,
                balanced_multiclass_task_bin_concepts_dep_1_complete_test,
                balanced_multiclass_task_bin_concepts_dep_2_complete_test,
                balanced_multiclass_task_bin_concepts_dep_3_complete_test,
                balanced_multiclass_task_bin_concepts_dep_4_complete_test,
            ),
        )),
    )
)
print(
    "latent_avg_concept_predictive_accuracies:",
    beta_vae_extended_balanced_multilabel_results["latent_avg_concept_predictive_accuracies"],
)
print(
    "latent_avg_concept_predictive_aucs:",
    beta_vae_extended_balanced_multilabel_results["latent_avg_concept_predictive_aucs"],
)

In [None]:
import concepts_xai.methods.VAE.betaVAE as beta_vae

reload(vae_losses)
reload(base_vae)
reload(beta_vae)


############################################################################
## Experiment config
############################################################################

# BetaVAE configuration with beta=1 and one latent dimension per concept
# annotation found in the dep_0 training split.
vae_balanced_multilabel_experiment_config = {
    # --- Training schedule ---
    "batch_size": 32,
    "max_epochs": 100,
    "trials": 5,
    "learning_rate": 1e-3,
    # --- Shapes derived from the dep_0 training split ---
    "input_shape": balanced_multiclass_task_bin_concepts_dep_0_complete_train[0].shape[1:],
    "latent_dim": (
        balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]
    ),
    "beta": 1,
    # --- Architecture ---
    "filter_groups": [
        [(8, (7, 7), 1)],
        [(16, (5, 5), 1)],
        [(32, (3, 3), 1)],
        [(64, (3, 3), 1)],
    ],
    "include_pool": True,
    "encoder_units": [64, 64],
    "decoder_units": [256, 512],
    # --- Downstream predictor ---
    "latent_decoder_units": [64, 64],
    # A single output unit is used unless there are more than two distinct labels.
    "num_outputs": (
        len(set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1]))
        if len(set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1])) > 2
        else 1
    ),
    "predictor_max_epochs": 100,
    # --- Regularisation / pooling / stopping ---
    "drop_prob": 0.0,
    "max_pool_window": (2, 2),
    "max_pool_stride": 2,
    "patience": float("inf"),  # presumably disables early stopping; confirm in the loop
    "min_delta": 1e-5,
    # --- Bookkeeping ---
    "results_dir": os.path.join(
        RESULTS_DIR,
        f"beta_vae/balanced_multilabel_purity_latent_{balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]}_beta_{1}",
    ),
    "num_concepts": balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    "verbosity": 0,
    "data_concepts": balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    "holdout_fraction": 0.1,
    "visualization_frequency": 25,
    "visualization_samples": 6,
}


# Make sure both the experiment directory and its "figures" subdirectory exist
Path(
    vae_balanced_multilabel_experiment_config["results_dir"]
).mkdir(parents=True, exist_ok=True)
vae_balanced_multilabel_figure_dir = os.path.join(
    vae_balanced_multilabel_experiment_config["results_dir"],
    "figures",
)
Path(vae_balanced_multilabel_figure_dir).mkdir(
    parents=True,
    exist_ok=True,
)

############################################################################
## Experiment run
############################################################################
# Run the experiment loop for the beta=1 BetaVAE over the five balanced
# datasets; each dataset is passed as a (name, train, test) triple.
vae_balanced_multilabel_results = wvae_experiment_loop(
    vae_balanced_multilabel_experiment_config,
    load_from_cache=True,
    include_all_losses=True,
    include_encoder_purity=True,
    vae_model=beta_vae.BetaVAE,
    datasets=list(zip(
        (f"balanced_multiclass_task_bin_concepts_dep_{dep}_complete" for dep in range(5)),
        (
            balanced_multiclass_task_bin_concepts_dep_0_complete_train,
            balanced_multiclass_task_bin_concepts_dep_1_complete_train,
            balanced_multiclass_task_bin_concepts_dep_2_complete_train,
            balanced_multiclass_task_bin_concepts_dep_3_complete_train,
            balanced_multiclass_task_bin_concepts_dep_4_complete_train,
        ),
        (
            balanced_multiclass_task_bin_concepts_dep_0_complete_test,
            balanced_multiclass_task_bin_concepts_dep_1_complete_test,
            balanced_multiclass_task_bin_concepts_dep_2_complete_test,
            balanced_multiclass_task_bin_concepts_dep_3_complete_test,
            balanced_multiclass_task_bin_concepts_dep_4_complete_test,
        ),
    )),
    oracle_matrix_cache=balanced_oracle_matrix_cache,
)
print(
    "total_losses:",
    vae_balanced_multilabel_results["total_losses"],
)
print(
    "purity_scores:",
    vae_balanced_multilabel_results["purity_scores"],
)

In [None]:
reload(vae_losses)
reload(base_vae)
reload(beta_vae)
############################################################################
## Task accuracy experiment run
############################################################################

# Merge the metrics returned by the bottleneck predictor loop into the
# beta=1 VAE results dictionary; datasets are (name, train, test) triples.
vae_balanced_multilabel_results.update(
    wvae_bottleneck_predict_experiment_loop(
        vae_balanced_multilabel_experiment_config,
        load_from_cache=True,
        vae_model=beta_vae.BetaVAE,
        datasets=list(zip(
            (f"balanced_multiclass_task_bin_concepts_dep_{dep}_complete" for dep in range(5)),
            (
                balanced_multiclass_task_bin_concepts_dep_0_complete_train,
                balanced_multiclass_task_bin_concepts_dep_1_complete_train,
                balanced_multiclass_task_bin_concepts_dep_2_complete_train,
                balanced_multiclass_task_bin_concepts_dep_3_complete_train,
                balanced_multiclass_task_bin_concepts_dep_4_complete_train,
            ),
            (
                balanced_multiclass_task_bin_concepts_dep_0_complete_test,
                balanced_multiclass_task_bin_concepts_dep_1_complete_test,
                balanced_multiclass_task_bin_concepts_dep_2_complete_test,
                balanced_multiclass_task_bin_concepts_dep_3_complete_test,
                balanced_multiclass_task_bin_concepts_dep_4_complete_test,
            ),
        )),
    )
)
print(
    "latent_predictive_accuracies:",
    vae_balanced_multilabel_results["latent_predictive_accuracies"],
)
print(
    "latent_predictive_aucs:",
    vae_balanced_multilabel_results["latent_predictive_aucs"],
)

In [None]:
reload(vae_losses)
reload(base_vae)
reload(beta_vae)

############################################################################
## Concept accuracy experiment run
############################################################################

# Merge the metrics returned by the bottleneck concept-predictor loop into the
# beta=1 VAE results dictionary; datasets are (name, train, test) triples.
vae_balanced_multilabel_results.update(wvae_bottleneck_concept_predict_experiment_loop(
    vae_balanced_multilabel_experiment_config,
    load_from_cache=True,
    vae_model=beta_vae.BetaVAE,
    datasets=[
        ("balanced_multiclass_task_bin_concepts_dep_0_complete", balanced_multiclass_task_bin_concepts_dep_0_complete_train, balanced_multiclass_task_bin_concepts_dep_0_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_1_complete", balanced_multiclass_task_bin_concepts_dep_1_complete_train, balanced_multiclass_task_bin_concepts_dep_1_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_2_complete", balanced_multiclass_task_bin_concepts_dep_2_complete_train, balanced_multiclass_task_bin_concepts_dep_2_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_3_complete", balanced_multiclass_task_bin_concepts_dep_3_complete_train, balanced_multiclass_task_bin_concepts_dep_3_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_4_complete", balanced_multiclass_task_bin_concepts_dep_4_complete_train, balanced_multiclass_task_bin_concepts_dep_4_complete_test),
    ],
))
# Report the concept-prediction metrics this loop produces. The task-level
# "latent_predictive_*" entries belong to the task predictor loop and were
# being re-printed here by mistake.
print("latent_avg_concept_predictive_accuracies:", vae_balanced_multilabel_results["latent_avg_concept_predictive_accuracies"])
print("latent_avg_concept_predictive_aucs:", vae_balanced_multilabel_results["latent_avg_concept_predictive_aucs"])

In [None]:
import concepts_xai.methods.VAE.betaVAE as beta_vae

# Re-import the VAE modules so that edits made to them on disk are picked up
# before the experiment below runs.
reload(vae_losses)
reload(base_vae)
reload(beta_vae)


############################################################################
## Experiment config
############################################################################

# Values derived once from the dep_0 training split, which is a
# (samples, labels, concepts) tuple. They are named so that the cache
# directory below can never drift away from the `latent_dim` / `beta` values
# actually stored in the config (results are loaded from that directory when
# `load_from_cache=True`).
vae_extended_balanced_multilabel_num_concepts = (
    balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]
)
# The "extended" bottleneck uses twice as many latent dimensions as there are
# concept annotations.
vae_extended_balanced_multilabel_latent_dim = (
    2 * vae_extended_balanced_multilabel_num_concepts
)
vae_extended_balanced_multilabel_beta = 1
# Number of distinct task labels found in the training split.
vae_extended_balanced_multilabel_num_labels = len(
    set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1])
)

vae_extended_balanced_multilabel_experiment_config = dict(
    batch_size=32,
    max_epochs=100,
    trials=5,
    learning_rate=1e-3,
    # Per-sample shape, i.e. the sample array's shape without the batch axis
    input_shape=balanced_multiclass_task_bin_concepts_dep_0_complete_train[0].shape[1:],
    latent_dim=vae_extended_balanced_multilabel_latent_dim,

    beta=vae_extended_balanced_multilabel_beta,
    filter_groups=[
       [(8, (7, 7), 1)],
       [(16, (5, 5), 1)],
       [(32, (3, 3), 1)],
       [(64, (3, 3), 1)],
    ],
    include_pool=True,
    encoder_units=[64, 64],

    decoder_units=[256, 512],

    latent_decoder_units=[64, 64],
    # A single output is used when the task has at most two labels
    num_outputs=(
        vae_extended_balanced_multilabel_num_labels
        if vae_extended_balanced_multilabel_num_labels > 2
        else 1
    ),
    predictor_max_epochs=100,

    drop_prob=0.0,
    max_pool_window=(2,2),
    max_pool_stride=2,
    # An infinite patience effectively disables early stopping
    patience=float("inf"),
    min_delta=1e-5,
    results_dir=os.path.join(
        RESULTS_DIR,
        "beta_vae/balanced_multilabel_purity_latent_"
        f"{vae_extended_balanced_multilabel_latent_dim}"
        f"_beta_{vae_extended_balanced_multilabel_beta}"
    ),
    num_concepts=vae_extended_balanced_multilabel_num_concepts,
    verbosity=0,
    data_concepts=vae_extended_balanced_multilabel_num_concepts,
    holdout_fraction=0.1,
    visualization_frequency=25,
    visualization_samples=6,
)


# Generate the experiment directory if it does not exist already
Path(vae_extended_balanced_multilabel_experiment_config["results_dir"]).mkdir(parents=True, exist_ok=True)
vae_extended_balanced_multilabel_figure_dir = os.path.join(vae_extended_balanced_multilabel_experiment_config["results_dir"], "figures")
Path(vae_extended_balanced_multilabel_figure_dir).mkdir(parents=True, exist_ok=True)

############################################################################
## Experiment run
############################################################################
# Each dataset entry is a (name, train split, test split) triple.
vae_extended_balanced_multilabel_results = wvae_experiment_loop(
    vae_extended_balanced_multilabel_experiment_config,
    load_from_cache=True,
    include_all_losses=True,
    include_encoder_purity=True,
    vae_model=beta_vae.BetaVAE,
    datasets=[
        (
            "balanced_multiclass_task_bin_concepts_dep_0_complete",
            balanced_multiclass_task_bin_concepts_dep_0_complete_train,
            balanced_multiclass_task_bin_concepts_dep_0_complete_test,
        ),
        (
            "balanced_multiclass_task_bin_concepts_dep_1_complete",
            balanced_multiclass_task_bin_concepts_dep_1_complete_train,
            balanced_multiclass_task_bin_concepts_dep_1_complete_test,
        ),
        (
            "balanced_multiclass_task_bin_concepts_dep_2_complete",
            balanced_multiclass_task_bin_concepts_dep_2_complete_train,
            balanced_multiclass_task_bin_concepts_dep_2_complete_test,
        ),
        (
            "balanced_multiclass_task_bin_concepts_dep_3_complete",
            balanced_multiclass_task_bin_concepts_dep_3_complete_train,
            balanced_multiclass_task_bin_concepts_dep_3_complete_test,
        ),
        (
            "balanced_multiclass_task_bin_concepts_dep_4_complete",
            balanced_multiclass_task_bin_concepts_dep_4_complete_train,
            balanced_multiclass_task_bin_concepts_dep_4_complete_test,
        ),
    ],
    oracle_matrix_cache=balanced_oracle_matrix_cache,
)
print("total_losses:", vae_extended_balanced_multilabel_results["total_losses"])
print("purity_scores:", vae_extended_balanced_multilabel_results["purity_scores"])

In [None]:
# Re-import the VAE modules so that edits made to them on disk are picked up
# before the experiment below runs.
reload(vae_losses)
reload(base_vae)
reload(beta_vae)

############################################################################
## Task accuracy experiment run
############################################################################

# Whatever the loop returns is merged into the results already gathered for
# the extended configuration. Each dataset entry is a
# (name, train split, test split) triple.
vae_extended_balanced_multilabel_results.update(
    wvae_bottleneck_predict_experiment_loop(
        vae_extended_balanced_multilabel_experiment_config,
        datasets=[
            (
                "balanced_multiclass_task_bin_concepts_dep_0_complete",
                balanced_multiclass_task_bin_concepts_dep_0_complete_train,
                balanced_multiclass_task_bin_concepts_dep_0_complete_test,
            ),
            (
                "balanced_multiclass_task_bin_concepts_dep_1_complete",
                balanced_multiclass_task_bin_concepts_dep_1_complete_train,
                balanced_multiclass_task_bin_concepts_dep_1_complete_test,
            ),
            (
                "balanced_multiclass_task_bin_concepts_dep_2_complete",
                balanced_multiclass_task_bin_concepts_dep_2_complete_train,
                balanced_multiclass_task_bin_concepts_dep_2_complete_test,
            ),
            (
                "balanced_multiclass_task_bin_concepts_dep_3_complete",
                balanced_multiclass_task_bin_concepts_dep_3_complete_train,
                balanced_multiclass_task_bin_concepts_dep_3_complete_test,
            ),
            (
                "balanced_multiclass_task_bin_concepts_dep_4_complete",
                balanced_multiclass_task_bin_concepts_dep_4_complete_train,
                balanced_multiclass_task_bin_concepts_dep_4_complete_test,
            ),
        ],
        vae_model=beta_vae.BetaVAE,
        load_from_cache=True,
    )
)
print(
    "latent_predictive_accuracies:",
    vae_extended_balanced_multilabel_results["latent_predictive_accuracies"],
)
print(
    "latent_predictive_aucs:",
    vae_extended_balanced_multilabel_results["latent_predictive_aucs"],
)

In [None]:
# Re-import the VAE modules so that edits made to them on disk are picked up
# before the experiment below runs.
reload(vae_losses)
reload(base_vae)
reload(beta_vae)

############################################################################
## Concept-prediction experiment run
## (wvae_bottleneck_concept_predict_experiment_loop)
############################################################################

# Whatever the loop returns is merged into the results already gathered for
# the extended configuration. Each dataset entry is a
# (name, train split, test split) triple.
vae_extended_balanced_multilabel_results.update(
    wvae_bottleneck_concept_predict_experiment_loop(
        vae_extended_balanced_multilabel_experiment_config,
        datasets=[
            (
                "balanced_multiclass_task_bin_concepts_dep_0_complete",
                balanced_multiclass_task_bin_concepts_dep_0_complete_train,
                balanced_multiclass_task_bin_concepts_dep_0_complete_test,
            ),
            (
                "balanced_multiclass_task_bin_concepts_dep_1_complete",
                balanced_multiclass_task_bin_concepts_dep_1_complete_train,
                balanced_multiclass_task_bin_concepts_dep_1_complete_test,
            ),
            (
                "balanced_multiclass_task_bin_concepts_dep_2_complete",
                balanced_multiclass_task_bin_concepts_dep_2_complete_train,
                balanced_multiclass_task_bin_concepts_dep_2_complete_test,
            ),
            (
                "balanced_multiclass_task_bin_concepts_dep_3_complete",
                balanced_multiclass_task_bin_concepts_dep_3_complete_train,
                balanced_multiclass_task_bin_concepts_dep_3_complete_test,
            ),
            (
                "balanced_multiclass_task_bin_concepts_dep_4_complete",
                balanced_multiclass_task_bin_concepts_dep_4_complete_train,
                balanced_multiclass_task_bin_concepts_dep_4_complete_test,
            ),
        ],
        vae_model=beta_vae.BetaVAE,
        load_from_cache=True,
    )
)

# NOTE(review): these are the same two keys the previous cell prints. The CCD
# counterpart of this loop stores its metrics under
# "latent_avg_concept_predictive_*" keys -- confirm which keys this loop
# actually returns before relying on this output.
print(
    "latent_predictive_accuracies:",
    vae_extended_balanced_multilabel_results["latent_predictive_accuracies"],
)
print(
    "latent_predictive_aucs:",
    vae_extended_balanced_multilabel_results["latent_predictive_aucs"],
)

# CCD Benchmarking

## Model Setup

In [None]:
def construct_ccd_encoder(
    input_shape,
    filter_groups,
    units,
    latent_dims,
    drop_prob=0.5,
    max_pool_window=(2,2),
    max_pool_stride=2,
    latent_act=None,  # No activation (linear latent code) unless one is given
):
    """Builds the convolutional encoder used by the CCD experiments.

    The layers are stacked in this order:
      1. For every group in `filter_groups`: one
         Conv2D -> BatchNormalization -> ReLU stack per
         `(num_filters, kernel_size)` entry, followed by one max-pool.
      2. Flatten, then Dropout when `drop_prob` is truthy.
      3. One ReLU Dense layer per entry of `units`.
      4. A final Dense layer of width `latent_dims` named
         "encoder_bypass_channel".

    Note: a function with this same name is defined again in the following
    notebook cell; whichever cell is executed last is the one in effect.

    Args:
        input_shape: Shape handed to `tf.keras.Input`.
        filter_groups: Iterable of groups, each an iterable of
            `(num_filters, kernel_size)` pairs.
        units: Iterable with the width of each hidden Dense layer.
        latent_dims: Width of the output (latent) Dense layer.
        drop_prob: Dropout rate applied after flattening; a falsy value
            (0, None) skips the Dropout layer altogether.
        max_pool_window: Pool size of the max-pool closing each group.
        max_pool_stride: Stride of the max-pool closing each group.
        latent_act: Activation of the latent layer. `None` means no
            activation is applied.

    Returns:
        A `tf.keras.Model` named "encoder" mapping inputs to the latent code.
    """
    encoder_inputs = tf.keras.Input(shape=input_shape)
    encoder_compute_graph = encoder_inputs

    # Start with our convolutions. The counter runs across groups so that
    # every convolution gets a unique layer name.
    num_convs = 0
    for filter_group in filter_groups:
        for (num_filters, kernel_size) in filter_group:
            encoder_compute_graph = tf.keras.layers.Conv2D(
                filters=num_filters,
                kernel_size=kernel_size,
                padding="SAME",
                activation=None,
                name=f'encoder_conv_{num_convs}',
            )(encoder_compute_graph)
            num_convs += 1
            encoder_compute_graph = tf.keras.layers.BatchNormalization()(
                encoder_compute_graph
            )
            encoder_compute_graph = tf.keras.activations.relu(encoder_compute_graph)
        # Then do a max pool here to control the parameter count of the model
        # at the end of each group
        encoder_compute_graph = tf.keras.layers.MaxPooling2D(
            pool_size=max_pool_window,
            strides=max_pool_stride,
        )(
            encoder_compute_graph
        )

    # Flatten the feature maps before the fully connected layers
    encoder_compute_graph = tf.keras.layers.Flatten()(encoder_compute_graph)

    # Add a dropout if requested
    if drop_prob:
        encoder_compute_graph = tf.keras.layers.Dropout(drop_prob)(
            encoder_compute_graph
        )

    # Finally, include the fully connected bottleneck here. The loop variable
    # deliberately has its own name so the `units` argument is not rebound.
    for layer_idx, hidden_units in enumerate(units):
        encoder_compute_graph = tf.keras.layers.Dense(
            hidden_units,
            activation='relu',
            name=f"encoder_dense_{layer_idx}",
        )(encoder_compute_graph)

    # Time to generate the latent code here
    encoder_compute_graph = tf.keras.layers.Dense(
        latent_dims,
        activation=latent_act,
        name="encoder_bypass_channel",
    )(encoder_compute_graph)

    return tf.keras.Model(
        encoder_inputs,
        encoder_compute_graph,
        name="encoder",
    )

In [None]:
def construct_ccd_encoder(
    input_shape,
    filter_groups,
    units,
    latent_dims,
    drop_prob=0.5,
    max_pool_window=(2,2),
    max_pool_stride=2,
    latent_act=None,  # No activation by default ("sigmoid" was the earlier choice)
):
    """Builds the convolutional encoder used by the CCD experiments.

    The layers are stacked in this order:
      1. For every group in `filter_groups`: one
         Conv2D -> BatchNormalization -> ReLU stack per
         `(num_filters, kernel_size)` entry, followed by one max-pool.
      2. Flatten, then Dropout when `drop_prob` is truthy.
      3. One ReLU Dense layer per entry of `units`.
      4. A final Dense layer of width `latent_dims` named
         "encoder_bypass_channel".

    Note: a function with this same name is also defined in the preceding
    notebook cell; whichever cell is executed last is the one in effect.

    Args:
        input_shape: Shape handed to `tf.keras.Input`.
        filter_groups: Iterable of groups, each an iterable of
            `(num_filters, kernel_size)` pairs.
        units: Iterable with the width of each hidden Dense layer.
        latent_dims: Width of the output (latent) Dense layer.
        drop_prob: Dropout rate applied after flattening; a falsy value
            (0, None) skips the Dropout layer altogether.
        max_pool_window: Pool size of the max-pool closing each group.
        max_pool_stride: Stride of the max-pool closing each group.
        latent_act: Activation of the latent layer. `None` means no
            activation is applied.

    Returns:
        A `tf.keras.Model` named "encoder" mapping inputs to the latent code.
    """
    encoder_inputs = tf.keras.Input(shape=input_shape)
    encoder_compute_graph = encoder_inputs

    # Start with our convolutions. The counter runs across groups so that
    # every convolution gets a unique layer name.
    num_convs = 0
    for filter_group in filter_groups:
        for (num_filters, kernel_size) in filter_group:
            encoder_compute_graph = tf.keras.layers.Conv2D(
                filters=num_filters,
                kernel_size=kernel_size,
                padding="SAME",
                activation=None,
                name=f'encoder_conv_{num_convs}',
            )(encoder_compute_graph)
            num_convs += 1
            encoder_compute_graph = tf.keras.layers.BatchNormalization()(
                encoder_compute_graph
            )
            encoder_compute_graph = tf.keras.activations.relu(encoder_compute_graph)
        # Then do a max pool here to control the parameter count of the model
        # at the end of each group
        encoder_compute_graph = tf.keras.layers.MaxPooling2D(
            pool_size=max_pool_window,
            strides=max_pool_stride,
        )(
            encoder_compute_graph
        )

    # Flatten the feature maps before the fully connected layers
    encoder_compute_graph = tf.keras.layers.Flatten()(encoder_compute_graph)

    # Add a dropout if requested
    if drop_prob:
        encoder_compute_graph = tf.keras.layers.Dropout(drop_prob)(
            encoder_compute_graph
        )

    # Finally, include the fully connected bottleneck here. The loop variable
    # deliberately has its own name so the `units` argument is not rebound.
    for layer_idx, hidden_units in enumerate(units):
        encoder_compute_graph = tf.keras.layers.Dense(
            hidden_units,
            activation='relu',
            name=f"encoder_dense_{layer_idx}",
        )(encoder_compute_graph)

    # Time to generate the latent code here
    encoder_compute_graph = tf.keras.layers.Dense(
        latent_dims,
        activation=latent_act,
        name="encoder_bypass_channel",
    )(encoder_compute_graph)

    return tf.keras.Model(
        encoder_inputs,
        encoder_compute_graph,
        name="encoder",
    )

def construct_ccd_decoder(units, num_outputs):
    """Builds the fully connected decoder placed on top of the latent code.

    The returned model flattens its input, applies one ReLU Dense layer per
    entry of `units` (named "decoder_dense_1", "decoder_dense_2", ...) and
    ends with an activation-free Dense layer named "decoder_model_output".

    Args:
        units: Iterable with the width of each hidden Dense layer.
        num_outputs: Number of outputs requested; any value of two or less
            produces a single-unit output layer.

    Returns:
        A `tf.keras.Sequential` model.
    """
    layer_stack = [tf.keras.layers.Flatten()]

    # Hidden layers are numbered from 1 in their layer names
    for position, width in enumerate(units, start=1):
        layer_stack.append(
            tf.keras.layers.Dense(
                width,
                activation=tf.nn.relu,
                name=f"decoder_dense_{position}",
            )
        )

    # The output layer has no activation and collapses to one unit when
    # there are at most two outputs
    head_width = 1 if num_outputs <= 2 else num_outputs
    layer_stack.append(
        tf.keras.layers.Dense(
            head_width,
            activation=None,
            name="decoder_model_output",
        )
    )
    return tf.keras.Sequential(layer_stack)

## Experiment Loop

In [None]:
import concepts_xai.methods.OCACE.topicModel as CCD
import concepts_xai.evaluation.metrics.oracle as oracle
import concepts_xai.evaluation.metrics.completeness as completeness

############################################################################
## Experiment loop
############################################################################

def ccd_experiment_loop(
    experiment_config,
    datasets,
    load_from_cache=False,
    oracle_matrix_cache=None,
):
    """Trains and scores CCD topic models on every given dataset.

    For each `(name, (x_train, y_train, c_train), (x_test, y_test, c_test))`
    entry of `datasets`, and for `experiment_config["trials"]` trials, this:
      1. trains an end-to-end encoder/decoder classifier,
      2. fits a `CCD.TopicModel` on the encoder's outputs,
      3. records task/reconstruction accuracy and AUC, the mean concept
         similarity, both completeness scores, and the scores/matrices
         returned by `oracle.oracle_impurity_score`.
    The mean and standard deviation across trials are appended to the
    returned dictionary, which is serialized with `utils.serialize_results`
    after every dataset.

    Args:
        experiment_config: Dictionary of hyperparameters. Read without a
            default: "results_dir", "num_concepts", "num_outputs", "trials",
            "input_shape", "learning_rate", "encoder_filter_groups",
            "latent_dims", "encoder_units", "decoder_units", "min_delta",
            "patience", "max_epochs", "batch_size", "holdout_fraction" and
            "topic_model_train_epochs". Every other key is optional.
        datasets: List of `(name, train_split, test_split)` triples where
            each split is an `(x, y, c)` tuple.
        load_from_cache: When True, previously serialized results found in
            `experiment_config["results_dir"]` are loaded; a complete cache
            is returned as is and a partial one is resumed from the first
            dataset without results.
        oracle_matrix_cache: Optional mapping from dataset name to a
            precomputed oracle matrix handed to
            `oracle.oracle_impurity_score`.

    Returns:
        Dictionary mapping each tracked variable name to a list with one
        `(mean, std)` tuple per dataset.

    Raises:
        ValueError: If the cached variables disagree on how many datasets
            have already been processed.
    """
    utils.reseed(87)
    oracle_matrix_cache = oracle_matrix_cache or {}
    results_dir = experiment_config["results_dir"]
    experiment_variables = dict(
        task_accuracies=[],
        task_aucs=[],
        reconstruction_accuracies=[],
        reconstruction_aucs=[],
        purity_scores=[],
        non_oracle_purity_scores=[],
        purity_matrices=[],
        aligned_purity_matrices=[],
        oracle_matrices=[],
        completeness_scores=[],
        direct_completeness_scores=[],
        mean_similarities=[],
    )

    def _record_scalar(var_name, label, values):
        """Stores and prints the mean/std across trials of a scalar metric."""
        mean, std = np.mean(values), np.std(values)
        experiment_variables[var_name].append((mean, std))
        print(f"\t{label}: {mean:.4f} ± {std:.4f}")

    def _record_matrix(var_name, label, matrices):
        """Stores and prints the element-wise mean/std of per-trial matrices."""
        stacked = np.stack(matrices, axis=0)
        mean, std = np.mean(stacked, axis=0), np.std(stacked, axis=0)
        print(f"\t{label}:")
        for i in range(mean.shape[0]):
            print("\t\t" + "".join(
                f'{mean[i, j]:.4f} ± {std[i, j]:.4f}    '
                for j in range(mean.shape[1])
            ))
        experiment_variables[var_name].append((mean, std))

    start_ind = 0
    if load_from_cache:
        cached = True
        complete_cache = True
        for var_name in experiment_variables:
            file_name = os.path.join(results_dir, var_name)
            if (
                (not os.path.exists(f'{file_name}_means.npz')) or
                (not os.path.exists(f'{file_name}_stds.npz'))
            ):
                cached = False
                break
            # Only the number of stored entries is needed here, so the archive
            # is closed right away
            with np.load(f'{file_name}_means.npz') as loaded_means:
                num_cached = len(loaded_means.files)
            if num_cached != len(datasets):
                print("Found", num_cached, "means for variable", var_name, "vs", len(datasets), "datasets")
                # Then we have a partial run here so let's just run the rest
                if start_ind and start_ind != num_cached:
                    raise ValueError(
                        f'Found inconsistent start indices in cached '
                        f'data {start_ind} vs {num_cached}.'
                    )
                start_ind = num_cached
                complete_cache = False
        if not cached:
            # At least one variable is missing from the cache, so nothing is
            # loaded below. Every dataset must then be re-run: resuming from a
            # partial index would silently drop the results of the datasets
            # that get skipped.
            start_ind = 0
        else:
            # Then we have found all of the arrays of interest, so let's
            # load them and use them
            print("Experiment cache was hit")
            print("\tLoading variables from", results_dir)
            for var_name in experiment_variables:
                file_name = os.path.join(results_dir, var_name)
                # Open each archive once (instead of once per stored entry)
                # and close it as soon as its arrays have been read
                with np.load(f'{file_name}_means.npz') as means_npz, \
                        np.load(f'{file_name}_stds.npz') as stds_npz:
                    experiment_variables[var_name] = list(zip(
                        [means_npz[key] for key in means_npz.files],
                        [stds_npz[key] for key in stds_npz.files],
                    ))
            if complete_cache:
                # Then we are good to go
                if not os.path.exists(
                    os.path.join(results_dir, "config.yaml")
                ):
                    # then serialize the config as this is a different version run
                    utils.serialize_experiment_config(
                        experiment_config,
                        results_dir,
                    )
                return experiment_variables

    # Else, let's go ahead and run the whole thing (or whatever is left of it)
    verbosity = experiment_config.get("verbosity", 0)
    Path(
        os.path.join(
            results_dir,
            "models",
        )
    ).mkdir(parents=True, exist_ok=True)
    for (
        ds_name,
        (x_train, y_train, c_train),
        (x_test, y_test, c_test),
    ) in datasets[start_ind:]:
        num_concepts = experiment_config["num_concepts"]
        print("Training with concepts", num_concepts, "in dataset", ds_name)
        task_accs = []
        recon_accs = []
        aucs = []
        recon_aucs = []
        purity_mats = []
        aligned_purity_mats = []
        oracle_mats = []
        purities = []
        non_oracle_purities = []
        compl_scores = []
        dir_compl_scores = []
        mean_sims = []

        channels_axis = (
            -1 if experiment_config.get("data_format", "channels_last") == "channels_last"
            else 1
        )
        # Accuracy function handed to the completeness scores: a ROC-AUC over
        # the raw outputs (binary) or over their softmax (multi-class)
        if experiment_config["num_outputs"] == 1:
            acc_fn = lambda y_true, y_pred: sklearn.metrics.roc_auc_score(
                y_true,
                y_pred
            )
        else:
            acc_fn = lambda y_true, y_pred: sklearn.metrics.roc_auc_score(
                tf.keras.utils.to_categorical(y_true),
                scipy.special.softmax(y_pred, axis=-1),
                multi_class='ovo',
            )

        # Shared by both completeness scores below
        concept_score_fn = lambda f, c: completeness.dot_prod_concept_score(
            features=f,
            concept_vectors=c,
            channels_axis=channels_axis,
            beta=experiment_config.get("threshold", 0.5),
        )

        for trial in range(experiment_config["trials"]):
            print(f"\tTrial {trial + 1}/{experiment_config['trials']} for {num_concepts} concepts")

            # Proceed to do and end-to-end model in case we want to
            # do some task-specific pretraining
            # NOTE(review): the decoder comes from `construct_decoder`
            # (presumably defined in an earlier cell), not from the
            # `construct_ccd_decoder` of the model-setup cell -- confirm this
            # is intended.
            end_to_end_model, encoder, decoder = construct_end_to_end_model(
                input_shape=experiment_config["input_shape"],
                num_outputs=experiment_config["num_outputs"],
                learning_rate=experiment_config["learning_rate"],
                encoder=construct_ccd_encoder(
                    input_shape=experiment_config["input_shape"],
                    filter_groups=experiment_config["encoder_filter_groups"],
                    latent_dims=experiment_config["latent_dims"],
                    units=experiment_config["encoder_units"],
                    drop_prob=experiment_config.get("drop_prob", 0.5),
                    max_pool_window=experiment_config.get("max_pool_window", (2, 2)),
                    max_pool_stride=experiment_config.get("max_pool_stride", (2, 2)),
                    latent_act=experiment_config.get("latent_act", None),
                ),
                decoder=construct_decoder(
                    units=experiment_config["decoder_units"],
                    num_outputs=experiment_config["num_outputs"],
                ),
            )

            print("\tModel pre-training...")

            early_stopping_monitor = tf.keras.callbacks.EarlyStopping(
                monitor=experiment_config.get(
                    "early_stop_metric",
                    "val_loss",
                ),
                min_delta=experiment_config["min_delta"],
                patience=experiment_config["patience"],
                restore_best_weights=True,
                verbose=2,
                mode=experiment_config.get(
                    "early_stop_mode",
                    "min",
                ),
            )
            end_to_end_model.fit(
                x=x_train,
                y=y_train,
                epochs=experiment_config["max_epochs"],
                batch_size=experiment_config["batch_size"],
                callbacks=[
                    early_stopping_monitor,
                ],
                validation_split=experiment_config["holdout_fraction"],
                verbose=verbosity,
            )
            print("\t\tModel pre-training completed")
            print("\tSerializing model")
            encoder.save(
                os.path.join(
                    results_dir,
                    f"models/{ds_name}_encoder_num_concepts_{num_concepts}_trial_{trial}"
                )
            )
            decoder.save(
                os.path.join(
                    results_dir,
                    f"models/{ds_name}_decoder_num_concepts_{num_concepts}_trial_{trial}"
                )
            )
            print("\tEvaluating model")

            test_result = end_to_end_model.evaluate(
                x_test,
                y_test,
                verbose=0,
                return_dict=True,
            )
            task_outputs = end_to_end_model.predict(x_test)
            # NOTE(review): the metric names looked up below must match the
            # metrics `construct_end_to_end_model` compiles the model with.
            if experiment_config["num_outputs"] > 1:
                task_accs.append(test_result['sparse_top_k_categorical_accuracy'])

                # Then lets apply a softmax activation over all the probability
                # classes
                preds = scipy.special.softmax(task_outputs, axis=-1)

                # And select just the labels that are in fact being used
                one_hot_labels = tf.keras.utils.to_categorical(y_test)
                aucs.append(sklearn.metrics.roc_auc_score(
                    one_hot_labels,
                    preds,
                    multi_class='ovo',
                ))
            else:
                task_accs.append(test_result['binary_accuracy'])
                aucs.append(sklearn.metrics.roc_auc_score(
                    y_test,
                    task_outputs,
                ))

            # Now extract our concept vectors
            topic_model = CCD.TopicModel(
                concepts_to_labels_model=decoder,
                n_channels=experiment_config["latent_dims"],
                n_concepts=num_concepts,
                threshold=experiment_config.get("threshold", 0.5),
                loss_fn=end_to_end_model.loss,
                top_k=experiment_config.get("top_k", 32),
                lambda1=experiment_config.get("lambda1", 0.1),
                lambda2=experiment_config.get("lambda2", 0.1),
                seed=experiment_config.get("seed", None),
                eps=experiment_config.get("eps", 1e-5),
                data_format=experiment_config.get(
                    "data_format",
                    "channels_last"
                ),
                allow_gradient_flow_to_c2l=experiment_config.get(
                    'allow_gradient_flow_to_c2l',
                    False,
                ),
                acc_metric=(
                    tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1)
                    if experiment_config["num_outputs"] > 1 else
                    tf.keras.metrics.BinaryAccuracy()
                ),
            )
            topic_model.compile(
                optimizer=tf.keras.optimizers.Adam(
                    experiment_config.get("learning_rate", 1e-3),
                )
            )

            # Train it for a few epochs
            print("\tTopic model training...")
            topic_model.fit(
                x=encoder(x_train),
                y=y_train,
                epochs=experiment_config["topic_model_train_epochs"],
                batch_size=experiment_config["batch_size"],
                validation_split=experiment_config["holdout_fraction"],
                verbose=verbosity,
            )
            print("\t\tTopic model training completed")

            print("\tSerializing model")
            topic_model.g_model.save(
                os.path.join(
                    results_dir,
                    f"models/{ds_name}_topic_g_model_num_concepts_{num_concepts}_trial_{trial}"
                )
            )
            np.save(
                os.path.join(
                    results_dir,
                    f"models/{ds_name}_topic_vector_num_concepts_{num_concepts}_trial_{trial}.npy"
                ),
                topic_model.topic_vector.numpy(),
            )
            print("\tEvaluating model")

            # The encoder is not handed to the topic model, so its test
            # outputs are computed once and shared by every evaluation below
            test_latents = encoder(x_test)
            topic_result = topic_model.evaluate(
                test_latents,
                y_test,
                verbose=0,
                return_dict=True,
            )
            # First output of the topic model: the task predictions rebuilt
            # from the concept scores
            recon_outputs = topic_model(test_latents)[0]

            recon_accs.append(topic_result['accuracy'])
            if experiment_config["num_outputs"] > 1:
                # Then lets apply a softmax activation over all the probability
                # classes
                preds = scipy.special.softmax(
                    recon_outputs,
                    axis=-1,
                )

                one_hot_labels = tf.keras.utils.to_categorical(y_test)
                recon_aucs.append(sklearn.metrics.roc_auc_score(
                    one_hot_labels,
                    preds,
                    multi_class='ovo',
                ))
            else:
                recon_aucs.append(sklearn.metrics.roc_auc_score(
                    y_test,
                    recon_outputs,
                ))
            mean_sims.append(topic_result['mean_sim'])

            print(
                f"\t\tTest auc = {aucs[-1]:.4f}, "
                f"task accuracy = {task_accs[-1]:.4f}, "
                f"task reconstruction accuracy = {recon_accs[-1]:.4f}, "
                f"task reconstruction auc = {recon_aucs[-1]:.4f}, "
                f"mean concept similarity = {mean_sims[-1]:.4f}"
            )

            # We start by extracting a completeness score for the extracted
            # concept vectors
            print("\t\tComputing completeness scores...")
            compl_score, _ = completeness.completeness_score(
                X=x_test,
                y=y_test,
                features_to_concepts_fn=encoder,
                concepts_to_labels_model=decoder,
                concept_vectors=np.transpose(topic_model.topic_vector.numpy()),
                task_loss=end_to_end_model.loss,
                channels_axis=channels_axis,
                concept_score_fn=concept_score_fn,
                acc_fn=acc_fn,
            )
            compl_scores.append(compl_score)

            dir_compl_score, _ = completeness.direct_completeness_score(
                X=x_test,
                y=y_test,
                features_to_concepts_fn=encoder,
                concept_vectors=np.transpose(topic_model.topic_vector.numpy()),
                task_loss=end_to_end_model.loss,
                channels_axis=channels_axis,
                concept_score_fn=concept_score_fn,
                acc_fn=acc_fn,
            )
            dir_compl_scores.append(dir_compl_score)

            print(
                f"\t\t\tCompleteness Score: {compl_scores[-1]:.4f} "
                f"and Direct Completeness Score: {dir_compl_scores[-1]:.4f}"
            )

            print("\t\tComputing purity score...")
            concept_scores = topic_model.concept_scores(test_latents).numpy()
            purity_score, (purity_mat, aligned_purity_mat), oracle_mat = oracle.oracle_impurity_score(
                c_soft=concept_scores,
                c_true=c_test,
                output_matrices=True,
                alignment_function=oracle.max_alignment_matrix,
                oracle_matrix=oracle_matrix_cache.get(ds_name),
            )

            purity_mats.append(purity_mat)
            aligned_purity_mats.append(aligned_purity_mat)
            oracle_mats.append(oracle_mat)
            purities.append(purity_score)
            print(f"\t\t\tDone {purity_score:.4f}")

            # Same score again, but measured against a trivial matrix instead
            # of the dataset's oracle matrix; the aligned purity matrix
            # computed above is passed in rather than recomputed
            print("\t\tComputing non-oracle purity score...")
            non_oracle_purities.append(oracle.oracle_impurity_score(
                c_soft=concept_scores,
                c_true=c_test,
                alignment_function=oracle.max_alignment_matrix,
                oracle_matrix=construct_trivial_auc_mat(
                    c_test.shape[-1]
                ),
                purity_matrix=aligned_purity_mat,
            ))
            print(f"\t\t\tDone {non_oracle_purities[-1]:.4f}")
            print("\t\tDone with trial", trial + 1)

        # Aggregate this dataset's trials into (mean, std) summaries
        _record_scalar("task_accuracies", "Test task accuracy", task_accs)
        _record_scalar("reconstruction_accuracies", "Test reconstruction accuracy", recon_accs)
        _record_scalar("task_aucs", "Test task AUC", aucs)
        _record_scalar("reconstruction_aucs", "Test reconstruction AUC", recon_aucs)
        _record_scalar("mean_similarities", "Mean concept similarity", mean_sims)
        _record_scalar("completeness_scores", "Completeness Score", compl_scores)
        _record_scalar("direct_completeness_scores", "Direct completeness Score", dir_compl_scores)

        _record_matrix("purity_matrices", "Purity matrix", purity_mats)
        _record_matrix("aligned_purity_matrices", "Aligned purity matrix", aligned_purity_mats)
        _record_matrix("oracle_matrices", "Oracle matrix", oracle_mats)

        _record_scalar("purity_scores", "Purity score", purities)
        _record_scalar("non_oracle_purity_scores", "Non-oracle purity score", non_oracle_purities)

        # And serialize the results
        utils.serialize_results(
            results_dict=experiment_variables,
            results_dir=results_dir,
        )
    return experiment_variables

def ccd_bottleneck_concept_predict_experiment_loop(
    experiment_config,
    datasets,
    load_from_cache=False,
):
    """Measure how well ground-truth concepts can be predicted from CCD concept scores.

    For every dataset and every trial, the previously trained encoder,
    decoder and topic vector are loaded from ``<results_dir>/models``, the
    CCD concept scores of the train and test samples are computed, and one
    small binary classifier per ground-truth concept is trained on those
    scores. The per-trial average test accuracy and ROC-AUC are aggregated
    into a (mean, std) pair per dataset and serialized after each dataset.

    Args:
        experiment_config: dict of experiment settings. Required keys read
            here: "results_dir", "num_concepts", "trials", "latent_dims",
            "num_outputs", "data_concepts", "latent_decoder_units",
            "learning_rate", "predictor_max_epochs", "batch_size" and
            "holdout_fraction". Optional keys: "verbosity", "threshold",
            "top_k", "lambda1", "lambda2", "seed", "eps", "data_format" and
            "allow_gradient_flow_to_c2l".
        datasets: sequence of
            ``(name, (x_train, y_train, c_train), (x_test, y_test, c_test))``
            tuples.
        load_from_cache: if True, reuse results previously serialized in
            ``results_dir``; a partial cache resumes at the first dataset
            that has no cached result.

    Returns:
        dict mapping "latent_avg_concept_predictive_accuracies" and
        "latent_avg_concept_predictive_aucs" to lists of (mean, std) pairs,
        one entry per dataset.

    Raises:
        ValueError: if the cached variables disagree on how many datasets
            have already been processed.
    """
    utils.reseed(87)
    experiment_variables = dict(
        latent_avg_concept_predictive_accuracies=[],
        latent_avg_concept_predictive_aucs=[],
    )

    start_ind = 0
    if load_from_cache:
        cached = True
        complete_cache = True
        for var_name in experiment_variables:
            file_name = os.path.join(
                experiment_config["results_dir"],
                var_name
            )
            if (
                (not os.path.exists(f'{file_name}_means.npz')) or
                (not os.path.exists(f'{file_name}_stds.npz'))
            ):
                print("Could not find", f'"{file_name}_means.npz" or "{file_name}_stds.npz" in cache.')
                cached = False
                break
            # Only the number of stored entries is needed here; close the
            # archive right away instead of leaking the file handle.
            with np.load(f'{file_name}_means.npz') as loaded_means:
                num_cached = len(loaded_means)
            if num_cached != len(datasets):
                print("Found", num_cached, "means for variable", var_name, "vs", len(datasets), "datasets")
                # Then we have a partial run here so let's just run the rest
                if start_ind and start_ind != num_cached:
                    raise ValueError(
                        f'Found inconsistent start indices in cached '
                        f'data {start_ind} vs {num_cached} ({file_name}).'
                    )
                start_ind = num_cached
                complete_cache = False
        if cached:
            # Then we have found all of the arrays of interest, so let's
            # load them and use them
            print("Experiment cache was hit")
            print("\tLoading variables from", experiment_config["results_dir"])
            for var_name in experiment_variables:
                file_name = os.path.join(
                    experiment_config["results_dir"],
                    var_name
                )
                # Open each archive once (rather than once per key) and pair
                # means with stds in their stored order.
                with np.load(f'{file_name}_means.npz') as means_file, \
                        np.load(f'{file_name}_stds.npz') as stds_file:
                    experiment_variables[var_name] = list(zip(
                        [means_file[key] for key in means_file],
                        [stds_file[key] for key in stds_file],
                    ))
            print(experiment_variables)
            if complete_cache:
                # Then we are good to go
                if not os.path.exists(
                    os.path.join(experiment_config["results_dir"], "config.yaml")
                ):
                    # then serialize the config as this is a different version run
                    utils.serialize_experiment_config(
                        experiment_config,
                        experiment_config["results_dir"],
                    )
                return experiment_variables
        else:
            # Nothing was loaded, so the result lists are empty: any resume
            # index picked up from an earlier variable must be discarded,
            # otherwise the first datasets would be silently skipped.
            start_ind = 0

    # Else, let's go ahead and run the whole thing (or the missing tail of it)
    verbosity = experiment_config.get("verbosity", 0)
    Path(
        os.path.join(
            experiment_config["results_dir"],
            "models",
        )
    ).mkdir(parents=True, exist_ok=True)
    for (ds_name, (x_train, y_train, c_train), (x_test, y_test, c_test)) in datasets[start_ind:]:
        num_concepts = experiment_config['num_concepts']
        print("Training with num concepts", num_concepts, "in dataset", ds_name)
        avg_concept_accs = []
        avg_concept_aucs = []

        for trial in range(experiment_config["trials"]):
            print(f"\tTrial {trial + 1}/{experiment_config['trials']}")
            encoder = tf.keras.models.load_model(
                os.path.join(
                    experiment_config["results_dir"],
                    f"models/{ds_name}_encoder_num_concepts_{num_concepts}_trial_{trial}"
                )
            )

            decoder = tf.keras.models.load_model(
                os.path.join(
                    experiment_config["results_dir"],
                    f"models/{ds_name}_decoder_num_concepts_{num_concepts}_trial_{trial}"
                )
            )

            # NOTE(review): g_model is not referenced below. The load is kept
            # so the per-trial load sequence is unchanged; consider dropping
            # it once confirmed unnecessary.
            g_model = tf.keras.models.load_model(
                os.path.join(
                    experiment_config["results_dir"],
                    f"models/{ds_name}_topic_g_model_num_concepts_{num_concepts}_trial_{trial}"
                )
            )

            topic_vector = np.load(
                os.path.join(
                    experiment_config["results_dir"],
                    f"models/{ds_name}_topic_vector_num_concepts_{num_concepts}_trial_{trial}.npy"
                )
            )

            # Now extract our concept vectors
            topic_model = CCD.TopicModel(
                concepts_to_labels_model=decoder,
                n_channels=experiment_config["latent_dims"],
                n_concepts=num_concepts,
                threshold=experiment_config.get("threshold", 0.5),
                loss_fn=(
                    tf.keras.losses.BinaryCrossentropy(from_logits=True) if (experiment_config["num_outputs"] <= 2)
                    else tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
                ),
                top_k=experiment_config.get("top_k", 32),
                lambda1=experiment_config.get("lambda1", 0.1),
                lambda2=experiment_config.get("lambda2", 0.1),
                seed=experiment_config.get("seed", None),
                eps=experiment_config.get("eps", 1e-5),
                data_format=experiment_config.get(
                    "data_format",
                    "channels_last"
                ),
                allow_gradient_flow_to_c2l=experiment_config.get(
                    'allow_gradient_flow_to_c2l',
                    False,
                ),
                acc_metric=(
                    tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1)
                    if experiment_config["num_outputs"] > 1 else
                    tf.keras.metrics.BinaryAccuracy()
                ),
                initial_topic_vector=topic_vector,
            )

            # The concept scores do not depend on which ground-truth concept
            # is being predicted, so compute them once per trial instead of
            # once per concept.
            train_codes = topic_model.concept_scores(encoder(x_train)).numpy()
            test_codes = topic_model.concept_scores(encoder(x_test)).numpy()

            current_accs = []
            current_aucs = []

            for concept_idx in range(experiment_config["data_concepts"]):
                # One fresh single-logit classifier per ground-truth concept
                predictive_decoder = construct_decoder(
                    units=experiment_config["latent_decoder_units"],
                    num_outputs=1,
                )
                predictive_decoder.compile(
                    optimizer=tf.keras.optimizers.Adam(experiment_config["learning_rate"]),
                    loss=(
                        tf.keras.losses.BinaryCrossentropy(from_logits=True)
                    ),
                    metrics=[
                        "binary_accuracy"
                    ],
                )

                print("\tTraining model for concept", concept_idx)
                predictive_decoder.fit(
                    x=train_codes,
                    y=c_train[:, concept_idx],
                    epochs=experiment_config["predictor_max_epochs"],
                    batch_size=experiment_config["batch_size"],
                    validation_split=experiment_config["holdout_fraction"],
                    verbose=verbosity,
                )
                print("\t\tEvaluating model")
                test_result = predictive_decoder.evaluate(
                    test_codes,
                    c_test[:, concept_idx],
                    verbose=0,
                    return_dict=True,
                )
                current_accs.append(
                    test_result['binary_accuracy']
                )

                current_aucs.append(sklearn.metrics.roc_auc_score(
                    c_test[:, concept_idx],
                    predictive_decoder.predict(test_codes),
                ))
            avg_concept_accs.append(np.mean(current_accs))
            avg_concept_aucs.append(np.mean(current_aucs))
            print(
                f"\t\tTest avg concept AUC = {avg_concept_aucs[-1]:.4f}, "
                f"avg concept accuracy = {avg_concept_accs[-1]:.4f}"
            )
            print("\tDone with trial", trial + 1)

        avg_concept_acc_mean, avg_concept_acc_std = np.mean(avg_concept_accs), np.std(avg_concept_accs)
        experiment_variables["latent_avg_concept_predictive_accuracies"].append((avg_concept_acc_mean, avg_concept_acc_std))
        print(f"\tTest average concept accuracy: {avg_concept_acc_mean:.4f} ± {avg_concept_acc_std:.4f}")

        avg_concept_auc_mean, avg_concept_auc_std = np.mean(avg_concept_aucs), np.std(avg_concept_aucs)
        experiment_variables["latent_avg_concept_predictive_aucs"].append((avg_concept_auc_mean, avg_concept_auc_std))
        print(f"\tTest average concept AUC: {avg_concept_auc_mean:.4f} ± {avg_concept_auc_std:.4f}")

        # And serialize the results after every dataset so that an
        # interrupted run can be resumed from the cache
        utils.serialize_results(
            results_dict=experiment_variables,
            results_dir=experiment_config["results_dir"],
        )
    return experiment_variables

def ccd_compute_k(y, batch_size):
    """Derive the CCD ``top_k`` value from label frequencies.

    Takes the mean per-class sample count as a fraction of ``y.shape[0]``,
    scales it by ``batch_size`` and returns half of that, truncated to int.
    """
    class_counts = np.unique(y, return_counts=True)[1]
    mean_class_fraction = class_counts.mean() / y.shape[0]
    expected_per_batch = mean_class_fraction * batch_size
    return int(expected_per_batch / 2)

## Experiments

In [None]:
# CCD experiment on the balanced multiclass task, using as many discovered
# concepts as there are ground-truth concept annotations in the data.
# Re-import the method modules first so that local edits to them are picked up.
reload(completeness)
reload(CBM)
reload(CCD)

############################################################################
## Experiment config
############################################################################

ccd_balanced_multiclass_experiment_config = dict(
    batch_size=32,
    max_epochs=100,
    topic_model_train_epochs=50,
    # Number of concepts to discover = width of the concept annotation array
    # (element [2] of the train tuple).
    num_concepts=balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    latent_dims=10,
    threshold=0.0,
    # Derived from the label array (element [1] of the train tuple); the
    # batch size passed here repeats the batch_size value set above.
    top_k=ccd_compute_k(y=balanced_multiclass_task_bin_concepts_dep_0_complete_train[1], batch_size=32),
    lambda1=0.1,
    lambda2=0.1,
    seed=42,
    eps=1e-5,
    learning_rate=1e-3,
    encoder_filter_groups=[
        [(8, (7, 7))],
        [(16, (5, 5))],
        [(32, (3, 3))],
        [(64, (3, 3))],
    ],
    encoder_units=[64, 64],
    decoder_units=[64, 64],
    drop_prob=0.5,
    max_pool_window=(2,2),
    # NOTE(review): this key is spelled "pax_pool_stride", whereas the matching
    # argument of construct_senn_encoder in this notebook is max_pool_stride.
    # Confirm which key the experiment loop actually reads before renaming.
    pax_pool_stride=2,
    # One output unit when the task has at most two distinct labels, otherwise
    # one output per distinct label.
    num_outputs=(
        len(set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1])) if len(set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1])) > 2
        else 1
    ),
    
    # Settings read by ccd_bottleneck_concept_predict_experiment_loop
    latent_decoder_units=[64, 64],
    predictor_max_epochs=100,
    data_concepts=balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    
    # NOTE(review): infinite patience presumably disables early stopping;
    # verify against the experiment loop.
    patience=float("inf"),
    min_delta=1e-5,
    results_dir=os.path.join(
        RESULTS_DIR,
        f"ccd/balanced_multiclass_thresh_0_num_concepts_{balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]}"
    ),
    input_shape=balanced_multiclass_task_bin_concepts_dep_0_complete_train[0].shape[1:],
    holdout_fraction=0.1,
    trials=5,
    verbosity=0,
    early_stop_metric="val_loss",
    early_stop_mode="min",
)

# Generate the experiment directory if it does not exist already
Path(ccd_balanced_multiclass_experiment_config["results_dir"]).mkdir(parents=True, exist_ok=True)
ccd_balanced_multiclass_figure_dir = os.path.join(ccd_balanced_multiclass_experiment_config["results_dir"], "figures")
Path(ccd_balanced_multiclass_figure_dir).mkdir(parents=True, exist_ok=True)

############################################################################
## Experiment run
############################################################################

# Run over the five datasets (dep_0 ... dep_4), reusing any results already
# serialized in results_dir.
ccd_balanced_multiclass_results = ccd_experiment_loop(
    experiment_config=ccd_balanced_multiclass_experiment_config,
    datasets=[
        ("balanced_multiclass_task_bin_concepts_dep_0_complete", balanced_multiclass_task_bin_concepts_dep_0_complete_train, balanced_multiclass_task_bin_concepts_dep_0_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_1_complete", balanced_multiclass_task_bin_concepts_dep_1_complete_train, balanced_multiclass_task_bin_concepts_dep_1_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_2_complete", balanced_multiclass_task_bin_concepts_dep_2_complete_train, balanced_multiclass_task_bin_concepts_dep_2_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_3_complete", balanced_multiclass_task_bin_concepts_dep_3_complete_train, balanced_multiclass_task_bin_concepts_dep_3_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_4_complete", balanced_multiclass_task_bin_concepts_dep_4_complete_train, balanced_multiclass_task_bin_concepts_dep_4_complete_test),
    ],
    load_from_cache=True,
    oracle_matrix_cache=balanced_oracle_matrix_cache,
)


# Summary of the collected metrics, one entry per dataset
print("task_accuracies:", ccd_balanced_multiclass_results["task_accuracies"])
print("reconstruction_accuracies:", ccd_balanced_multiclass_results["reconstruction_accuracies"])
print("task_aucs:", ccd_balanced_multiclass_results["task_aucs"])
print("reconstruction_aucs:", ccd_balanced_multiclass_results["reconstruction_aucs"])
print("purity_scores:", ccd_balanced_multiclass_results["purity_scores"])
print("non_oracle_purity_scores:", ccd_balanced_multiclass_results["non_oracle_purity_scores"])
print("completeness_scores:", ccd_balanced_multiclass_results["completeness_scores"])
print("direct_completeness_scores:", ccd_balanced_multiclass_results["direct_completeness_scores"])
print("mean_similarities:", ccd_balanced_multiclass_results["mean_similarities"])

In [None]:
# Add the concept-predictability metrics (how well each ground-truth concept
# can be predicted from the CCD concept scores) to the results dict of the
# balanced multiclass CCD experiment. Requires that experiment's config and
# results dict to exist already.
ccd_balanced_multiclass_results.update(ccd_bottleneck_concept_predict_experiment_loop(
    experiment_config=ccd_balanced_multiclass_experiment_config,
    datasets=[
        ("balanced_multiclass_task_bin_concepts_dep_0_complete", balanced_multiclass_task_bin_concepts_dep_0_complete_train, balanced_multiclass_task_bin_concepts_dep_0_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_1_complete", balanced_multiclass_task_bin_concepts_dep_1_complete_train, balanced_multiclass_task_bin_concepts_dep_1_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_2_complete", balanced_multiclass_task_bin_concepts_dep_2_complete_train, balanced_multiclass_task_bin_concepts_dep_2_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_3_complete", balanced_multiclass_task_bin_concepts_dep_3_complete_train, balanced_multiclass_task_bin_concepts_dep_3_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_4_complete", balanced_multiclass_task_bin_concepts_dep_4_complete_train, balanced_multiclass_task_bin_concepts_dep_4_complete_test),
    ],
    load_from_cache=True,
))
print("latent_avg_concept_predictive_accuracies:", ccd_balanced_multiclass_results["latent_avg_concept_predictive_accuracies"])
print("latent_avg_concept_predictive_aucs:", ccd_balanced_multiclass_results["latent_avg_concept_predictive_aucs"])

In [None]:
# "Extended" CCD experiment on the balanced multiclass task: same settings as
# the non-extended config except that twice as many concepts are discovered
# as there are ground-truth concept annotations in the data.
# Re-import the method modules first so that local edits to them are picked up.
reload(completeness)
reload(CBM)
reload(CCD)

############################################################################
## Experiment config
############################################################################

ccd_balanced_multiclass_extended_experiment_config = dict(
    batch_size=32,
    max_epochs=100,
    topic_model_train_epochs=50,
    # Twice the width of the concept annotation array (element [2] of the
    # train tuple).
    num_concepts=(2 * balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]),
    latent_dims=10,
    threshold=0.0,
    # Derived from the label array (element [1] of the train tuple); the
    # batch size passed here repeats the batch_size value set above.
    top_k=ccd_compute_k(y=balanced_multiclass_task_bin_concepts_dep_0_complete_train[1], batch_size=32),
    lambda1=0.1,
    lambda2=0.1,
    seed=42,
    eps=1e-5,
    learning_rate=1e-3,
    encoder_filter_groups=[
        [(8, (7, 7))],
        [(16, (5, 5))],
        [(32, (3, 3))],
        [(64, (3, 3))],
    ],
    encoder_units=[64, 64],
    decoder_units=[64, 64],
    
    # Settings read by ccd_bottleneck_concept_predict_experiment_loop; note
    # that data_concepts stays at the (undoubled) number of annotated concepts.
    latent_decoder_units=[64, 64],
    predictor_max_epochs=100,
    data_concepts=balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    
    drop_prob=0.5,
    max_pool_window=(2,2),
    # NOTE(review): this key is spelled "pax_pool_stride", whereas the matching
    # argument of construct_senn_encoder in this notebook is max_pool_stride.
    # Confirm which key the experiment loop actually reads before renaming.
    pax_pool_stride=2,
    # One output unit when the task has at most two distinct labels, otherwise
    # one output per distinct label.
    num_outputs=(
        len(set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1])) if len(set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1])) > 2
        else 1
    ),
    # NOTE(review): infinite patience presumably disables early stopping;
    # verify against the experiment loop.
    patience=float("inf"),
    min_delta=1e-5,
    results_dir=os.path.join(
        RESULTS_DIR,
        f"ccd/balanced_multiclass_thresh_0_num_concepts_{2* balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]}"
    ),
    input_shape=balanced_multiclass_task_bin_concepts_dep_0_complete_train[0].shape[1:],
    holdout_fraction=0.1,
    trials=5,
    verbosity=0,
    early_stop_metric="val_loss",
    early_stop_mode="min",
)

# Generate the experiment directory if it does not exist already
Path(ccd_balanced_multiclass_extended_experiment_config["results_dir"]).mkdir(parents=True, exist_ok=True)
ccd_balanced_multiclass_extended_figure_dir = os.path.join(ccd_balanced_multiclass_extended_experiment_config["results_dir"], "figures")
Path(ccd_balanced_multiclass_extended_figure_dir).mkdir(parents=True, exist_ok=True)

############################################################################
## Experiment run
############################################################################

# Run over the five datasets (dep_0 ... dep_4), reusing any results already
# serialized in results_dir.
ccd_balanced_multiclass_extended_results = ccd_experiment_loop(
    experiment_config=ccd_balanced_multiclass_extended_experiment_config,
    datasets=[
        ("balanced_multiclass_task_bin_concepts_dep_0_complete", balanced_multiclass_task_bin_concepts_dep_0_complete_train, balanced_multiclass_task_bin_concepts_dep_0_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_1_complete", balanced_multiclass_task_bin_concepts_dep_1_complete_train, balanced_multiclass_task_bin_concepts_dep_1_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_2_complete", balanced_multiclass_task_bin_concepts_dep_2_complete_train, balanced_multiclass_task_bin_concepts_dep_2_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_3_complete", balanced_multiclass_task_bin_concepts_dep_3_complete_train, balanced_multiclass_task_bin_concepts_dep_3_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_4_complete", balanced_multiclass_task_bin_concepts_dep_4_complete_train, balanced_multiclass_task_bin_concepts_dep_4_complete_test),
    ],
    load_from_cache=True,
    oracle_matrix_cache=balanced_oracle_matrix_cache,
)


# Summary of the collected metrics, one entry per dataset
print("task_accuracies:", ccd_balanced_multiclass_extended_results["task_accuracies"])
print("reconstruction_accuracies:", ccd_balanced_multiclass_extended_results["reconstruction_accuracies"])
print("task_aucs:", ccd_balanced_multiclass_extended_results["task_aucs"])
print("reconstruction_aucs:", ccd_balanced_multiclass_extended_results["reconstruction_aucs"])
print("purity_scores:", ccd_balanced_multiclass_extended_results["purity_scores"])
print("non_oracle_purity_scores:", ccd_balanced_multiclass_extended_results["non_oracle_purity_scores"])
print("completeness_scores:", ccd_balanced_multiclass_extended_results["completeness_scores"])
print("direct_completeness_scores:", ccd_balanced_multiclass_extended_results["direct_completeness_scores"])
print("mean_similarities:", ccd_balanced_multiclass_extended_results["mean_similarities"])

In [None]:
# Add the concept-predictability metrics (how well each ground-truth concept
# can be predicted from the CCD concept scores) to the results dict of the
# extended balanced multiclass CCD experiment. Requires that experiment's
# config and results dict to exist already.
ccd_balanced_multiclass_extended_results.update(ccd_bottleneck_concept_predict_experiment_loop(
    experiment_config=ccd_balanced_multiclass_extended_experiment_config,
    datasets=[
        ("balanced_multiclass_task_bin_concepts_dep_0_complete", balanced_multiclass_task_bin_concepts_dep_0_complete_train, balanced_multiclass_task_bin_concepts_dep_0_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_1_complete", balanced_multiclass_task_bin_concepts_dep_1_complete_train, balanced_multiclass_task_bin_concepts_dep_1_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_2_complete", balanced_multiclass_task_bin_concepts_dep_2_complete_train, balanced_multiclass_task_bin_concepts_dep_2_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_3_complete", balanced_multiclass_task_bin_concepts_dep_3_complete_train, balanced_multiclass_task_bin_concepts_dep_3_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_4_complete", balanced_multiclass_task_bin_concepts_dep_4_complete_train, balanced_multiclass_task_bin_concepts_dep_4_complete_test),
    ],
    load_from_cache=True,
))
print("latent_avg_concept_predictive_accuracies:", ccd_balanced_multiclass_extended_results["latent_avg_concept_predictive_accuracies"])
print("latent_avg_concept_predictive_aucs:", ccd_balanced_multiclass_extended_results["latent_avg_concept_predictive_aucs"])

# SENN Benchmarking

## Model Construction

In [None]:
import concepts_xai.methods.SENN.base_senn as SENN
import concepts_xai.methods.SENN.aggregators as aggregators
reload(SENN)
reload(aggregators)


def construct_senn_coefficient_model(units, num_concepts, num_outputs):
    """Build the SENN coefficient network as a Keras ``Sequential`` model.

    The input is flattened, passed through one ReLU dense layer per entry of
    ``units``, projected linearly to ``num_concepts * num_outputs`` values
    and reshaped to ``[num_outputs, num_concepts]`` per sample.
    """
    model_layers = [tf.keras.layers.Flatten()]
    for layer_idx, layer_width in enumerate(units):
        model_layers.append(
            tf.keras.layers.Dense(
                layer_width,
                activation=tf.nn.relu,
                name=f"coefficient_model_dense_{layer_idx+1}",
            )
        )
    # Linear head producing one coefficient per (output, concept) pair
    model_layers.append(
        tf.keras.layers.Dense(
            num_concepts * num_outputs,
            activation=None,
            name="coefficient_model_output",
        )
    )
    model_layers.append(tf.keras.layers.Reshape([num_outputs, num_concepts]))
    return tf.keras.Sequential(model_layers)

def construct_senn_encoder(
    input_shape,
    filter_groups,
    units,
    latent_dims,
    drop_prob=0.5,
    max_pool_window=(2,2),
    max_pool_stride=2,
    include_norm=False,
    include_pool=False,
):
    """Build a convolutional encoder and return it as two Keras models.

    Args:
        input_shape: shape of a single input sample.
        filter_groups: list of groups; each group is a list of
            ``(num_filters, kernel_size)`` or
            ``(num_filters, kernel_size, stride)`` tuples (stride defaults
            to 1 when omitted).
        units: widths of the ReLU dense layers applied after flattening.
        latent_dims: size of the ``means`` and ``log_var`` output layers.
        drop_prob: dropout rate applied after flattening; falsy disables it.
        max_pool_window: pool size used when ``include_pool`` is set.
        max_pool_stride: pool stride used when ``include_pool`` is set.
        include_norm: if True, each convolution is linear and followed by
            batch normalization and a ReLU; otherwise the convolution itself
            uses a ReLU activation.
        include_pool: if True, a max pool is added after every filter group.

    Returns:
        ``(senn_encoder, vae_encoder)``: two models over the same input and
        layers; the first outputs ``means`` only, the second outputs
        ``[means, log_var]``.
    """
    encoder_inputs = tf.keras.Input(shape=input_shape)
    hidden = encoder_inputs
    conv_activation = None if include_norm else "relu"

    # Convolutional stack; the running index keeps layer names unique
    # across groups
    conv_index = 0
    for group in filter_groups:
        for conv_spec in group:
            if len(conv_spec) == 2:
                num_filters, kernel_size = conv_spec
                stride = 1
            else:
                num_filters, kernel_size, stride = conv_spec
            hidden = tf.keras.layers.Conv2D(
                filters=num_filters,
                kernel_size=kernel_size,
                strides=stride,
                padding="SAME",
                activation=conv_activation,
                name=f'encoder_conv_{conv_index}',
            )(hidden)
            conv_index += 1
            if not include_norm:
                continue
            hidden = tf.keras.layers.BatchNormalization()(hidden)
            hidden = tf.keras.activations.relu(hidden)
        if include_pool:
            # A max pool at the end of each group keeps the parameter count
            # of the model under control
            pool_layer = tf.keras.layers.MaxPooling2D(
                pool_size=max_pool_window,
                strides=max_pool_stride,
            )
            hidden = pool_layer(hidden)

    hidden = tf.keras.layers.Flatten()(hidden)

    # Optional dropout on the flattened features
    if drop_prob:
        hidden = tf.keras.layers.Dropout(drop_prob)(hidden)

    # Fully connected layers leading up to the bottleneck
    for dense_index, dense_width in enumerate(units):
        hidden = tf.keras.layers.Dense(
            dense_width,
            activation='relu',
            name=f"encoder_dense_{dense_index}",
        )(hidden)

    # Two linear heads sharing the same trunk
    mean = tf.keras.layers.Dense(latent_dims, activation=None, name="means")(hidden)
    log_var = tf.keras.layers.Dense(latent_dims, activation=None, name="log_var")(hidden)

    senn_encoder = tf.keras.Model(encoder_inputs, mean, name="senn_encoder")
    vae_encoder = tf.keras.Model(encoder_inputs, [mean, log_var], name="vae_encoder")
    return senn_encoder, vae_encoder

def construct_senn_model(
    concept_encoder,
    concept_decoder,
    coefficient_model,
    num_outputs,
    regularization_strength=0.1,
    learning_rate=1e-3,
    sparsity_strength=2e-5,
):
    """Assemble and compile a ``SENN.SelfExplainingNN`` model.

    With more than two outputs the model uses the multiclass additive
    aggregator, sparse categorical cross-entropy and top-1 accuracy;
    otherwise it uses the scalar additive aggregator, binary cross-entropy
    and binary accuracy. Both losses are configured with ``from_logits=True``.
    The reconstruction loss decodes the prediction with ``concept_decoder``
    before comparing it against the target. The model is compiled with Adam
    at ``learning_rate`` and returned.
    """
    if num_outputs > 2:
        aggregator_fn = aggregators.multiclass_additive_aggregator
        task_loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        accuracy_metric = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1)
    else:
        aggregator_fn = aggregators.scalar_additive_aggregator
        task_loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)
        accuracy_metric = tf.keras.metrics.BinaryAccuracy()

    def reconstruction_loss_fn(y_true, y_pred):
        # y_pred is passed through the decoder before the loss is evaluated
        decoded = concept_decoder(y_pred)
        return vae_losses.bernoulli_fn_wrapper()(y_true, decoded)

    senn_model = SENN.SelfExplainingNN(
        encoder_model=concept_encoder,
        coefficient_model=coefficient_model,
        aggregator_fn=aggregator_fn,
        task_loss_fn=task_loss_fn,
        reconstruction_loss_fn=reconstruction_loss_fn,
        regularization_strength=regularization_strength,
        sparsity_strength=sparsity_strength,
        name="SENN",
        metrics=[accuracy_metric],
    )
    optimizer = tf.keras.optimizers.Adam(learning_rate)
    senn_model.compile(optimizer=optimizer)
    return senn_model



## Experiment Loop

In [None]:
import concepts_xai.evaluation.metrics.oracle as oracle

############################################################################
## Experiment loop
############################################################################

def construct_trivial_auc_mat(num_concepts):
    """Return a float32 ``(num_concepts, num_concepts)`` matrix with 1.0 on
    the diagonal and 0.5 everywhere else."""
    auc_mat = np.full((num_concepts, num_concepts), 0.5, dtype=np.float32)
    np.fill_diagonal(auc_mat, 1.0)
    return auc_mat
    
def get_argmax_concept_explanations(preds, class_theta_scores):
    """Select, for each sample, the slice of ``class_theta_scores`` along
    axis 1 at the index of the largest value on the last axis of ``preds``.

    The selected axis is kept with length 1 in the result.
    """
    winning_class = np.argmax(preds, axis=-1)
    # Two trailing singleton axes so the indices broadcast against the scores
    gather_index = winning_class[..., np.newaxis, np.newaxis]
    return np.take_along_axis(class_theta_scores, gather_index, axis=1)

def _senn_print_mean_std_matrix(title, mat_mean, mat_std):
    """Prints `title`, then one tab-indented row of "mean ± std" entries per matrix row."""
    print(title)
    for mean_row, std_row in zip(mat_mean, mat_std):
        print("\t\t" + "".join(
            f'{mean:.4f} ± {std:.4f}    '
            for mean, std in zip(mean_row, std_row)
        ))


def senn_experiment_loop(
    experiment_config,
    datasets,
    load_from_cache=False,
):
    """Trains and evaluates SENN models on every dataset over several trials.

    For each dataset and each trial this builds a fresh encoder, decoder and
    coefficient model, optionally pre-trains a BetaVAE autoencoder, trains the
    SENN, saves the three sub-models under `<results_dir>/models`, and records
    the test task accuracy, the test task AUC, the oracle impurity score (with
    its purity and oracle matrices) and the impurity score computed against
    the trivial oracle matrix of `construct_trivial_auc_mat`. After each
    dataset the mean and standard deviation over trials are appended to the
    result lists and written out with `utils.serialize_results`.

    Args:
        experiment_config: dict of experiment settings. Keys read here include
            "results_dir", "trials", "input_shape", "num_concepts",
            "num_outputs", "batch_size", "max_epochs", "holdout_fraction",
            "min_delta" and "patience", plus the model-architecture keys;
            several optional keys fall back to defaults via `.get`.
        datasets: sliceable sequence of
            `(ds_name, (x_train, y_train, c_train), (x_test, y_test, c_test))`.
        load_from_cache: if True, previously serialized `*_means.npz` /
            `*_stds.npz` files in "results_dir" are loaded; a complete cache is
            returned directly and a partial one resumes from the first dataset
            without cached results.

    Returns:
        Dict mapping each tracked variable name to a list holding one
        `(mean, std)` tuple per dataset.

    Raises:
        ValueError: if the cached variables disagree on how many datasets
            have already been processed.
    """
    utils.reseed(87)
    experiment_variables = dict(
        task_accuracies=[],
        task_aucs=[],
        purity_scores=[],
        non_oracle_purity_scores=[],
        purity_matrices=[],
        oracle_matrices=[],
    )
    results_dir = experiment_config["results_dir"]
    start_ind = 0
    if load_from_cache:
        cached = True
        complete_cache = True
        for var_name in experiment_variables:
            file_name = os.path.join(results_dir, var_name)
            if not (
                os.path.exists(f'{file_name}_means.npz') and
                os.path.exists(f'{file_name}_stds.npz')
            ):
                cached = False
                break
            # Only the number of stored entries is needed here, so the archive
            # is closed again right away.
            with np.load(f'{file_name}_means.npz') as loaded_means:
                num_cached = len(loaded_means)
            if num_cached != len(datasets):
                print("Found", num_cached, "means for variable", var_name, "vs", len(datasets), "datasets")
                # Then we have a partial run here so let's just run the rest
                if start_ind and start_ind != num_cached:
                    raise ValueError(
                        f'Found inconsistent start indices in cached '
                        f'data {start_ind} vs {num_cached}.'
                    )
                start_ind = num_cached
                complete_cache = False
        if not cached:
            # Nothing is loaded from an incomplete cache, so every dataset has
            # to be processed; a start index picked up from an earlier
            # variable would otherwise silently skip datasets.
            start_ind = 0
        else:
            # Then we have found all of the arrays of interest, so let's
            # load them and use them
            print("Experiment cache was hit")
            print("\tLoading variables from", results_dir)
            for var_name in experiment_variables:
                file_name = os.path.join(results_dir, var_name)
                # Open each archive once and close it when done.
                with np.load(f'{file_name}_means.npz') as means_npz, \
                        np.load(f'{file_name}_stds.npz') as stds_npz:
                    experiment_variables[var_name] = list(zip(
                        [means_npz[key] for key in means_npz.files],
                        [stds_npz[key] for key in stds_npz.files],
                    ))
            if complete_cache:
                # Then we are good to go
                if not os.path.exists(os.path.join(results_dir, "config.yaml")):
                    # then serialize the config as this is a different version run
                    utils.serialize_experiment_config(
                        experiment_config,
                        results_dir,
                    )
                return experiment_variables

    # Else, let's go ahead and run the whole thing (or what is left of it)
    Path(os.path.join(results_dir, "models")).mkdir(parents=True, exist_ok=True)
    verbosity = experiment_config.get("verbosity", 0)
    # Same threshold as `construct_senn_model`, which reports binary accuracy
    # for up to two outputs and sparse top-k accuracy otherwise.
    is_multiclass = experiment_config["num_outputs"] > 2
    for (ds_name, (x_train, y_train, c_train), (x_test, y_test, c_test)) in datasets[start_ind:]:
        print("Training with dataset:", ds_name)
        task_accs = []
        aucs = []
        purity_mats = []
        oracle_mats = []
        purities = []
        non_oracle_purities = []

        for trial in range(experiment_config["trials"]):
            print(f"\tTrial {trial + 1}/{experiment_config['trials']} of dataset {ds_name}")

            # Build fresh sub-models for this trial. The VAE encoder is only
            # used for the optional autoencoder pre-training below.
            concept_encoder, vae_encoder = construct_senn_encoder(
                input_shape=experiment_config["input_shape"],
                latent_dims=experiment_config["num_concepts"],
                filter_groups=experiment_config["filter_groups"],
                units=experiment_config["encoder_units"],
                drop_prob=experiment_config['drop_prob'],
                max_pool_window=experiment_config['max_pool_window'],
                max_pool_stride=experiment_config['max_pool_stride'],
                include_pool=experiment_config['include_pool'],
            )
            concept_decoder = construct_vae_decoder(
                output_shape=experiment_config["input_shape"],
                latent_dims=experiment_config["num_concepts"],
                units=experiment_config["decoder_units"],
            )
            coefficient_model = construct_senn_coefficient_model(
                units=experiment_config["coefficient_model_units"],
                num_concepts=experiment_config["num_concepts"],
                num_outputs=experiment_config["num_outputs"],
            )

            if experiment_config.get("pretrain_autoencoder_epochs"):
                autoencoder = beta_vae.BetaVAE(
                    encoder=vae_encoder,
                    decoder=concept_decoder,
                    loss_fn=vae_losses.bernoulli_fn_wrapper(),
                    beta=experiment_config.get("beta", 1),
                )

                autoencoder.compile(
                    optimizer=tf.keras.optimizers.Adam(
                        experiment_config.get("learning_rate", 1e-3)
                    ),
                )

                print("\tAutoencoder pre-training...")
                autoencoder.fit(
                    x=x_train,
                    epochs=experiment_config["pretrain_autoencoder_epochs"],
                    batch_size=experiment_config["batch_size"],
                    validation_split=experiment_config["holdout_fraction"],
                    verbose=verbosity,
                )
                print("\t\tAutoencoder training completed")

            # Now time to actually construct and train the SENN
            senn_model = construct_senn_model(
                concept_encoder=concept_encoder,
                concept_decoder=concept_decoder,
                coefficient_model=coefficient_model,
                num_outputs=experiment_config["num_outputs"],
                regularization_strength=experiment_config.get("regularization_strength", 0.1),
                learning_rate=experiment_config.get("learning_rate", 1e-3),
                sparsity_strength=experiment_config.get("sparsity_strength", 2e-5),
            )

            early_stopping_monitor = tf.keras.callbacks.EarlyStopping(
                monitor=experiment_config.get(
                    "early_stop_metric",
                    "val_total_loss",
                ),
                min_delta=experiment_config["min_delta"],
                patience=experiment_config["patience"],
                restore_best_weights=True,
                verbose=2,
                # "auto" lets Keras pick the direction from the metric name,
                # so the default loss monitor is minimized rather than
                # maximized.
                mode=experiment_config.get(
                    "early_stop_mode",
                    "auto",
                ),
            )

            print("\tSENN training...")
            senn_model.fit(
                x=x_train,
                y=y_train,
                epochs=experiment_config["max_epochs"],
                batch_size=experiment_config["batch_size"],
                callbacks=[
                    early_stopping_monitor,
                ],
                validation_split=experiment_config["holdout_fraction"],
                verbose=verbosity,
            )
            print("\t\tSENN training completed")
            print("\tSerializing model")
            concept_encoder.save(
                os.path.join(
                    results_dir,
                    f"models/concept_encoder_{ds_name}_trial_{trial}"
                )
            )
            concept_decoder.save(
                os.path.join(
                    results_dir,
                    f"models/concept_decoder_{ds_name}_trial_{trial}"
                )
            )
            coefficient_model.save(
                os.path.join(
                    results_dir,
                    f"models/coefficient_model_{ds_name}_trial_{trial}"
                )
            )

            print("\tEvaluating model")
            test_result = senn_model.evaluate(
                x_test,
                y_test,
                verbose=0,
                return_dict=True,
            )
            task_accs.append(
                test_result['sparse_top_k_categorical_accuracy']
                if is_multiclass else
                test_result['binary_accuracy']
            )

            # The first element of the SENN prediction holds the task outputs.
            test_outputs = senn_model.predict(x_test)[0]
            if is_multiclass:
                # Then lets apply a softmax activation over all the probability
                # classes
                preds = scipy.special.softmax(test_outputs, axis=-1)

                # And score them against the one-hot encoded labels
                one_hot_labels = tf.keras.utils.to_categorical(y_test)
                aucs.append(sklearn.metrics.roc_auc_score(
                    one_hot_labels,
                    preds,
                    multi_class='ovo',
                ))
            else:
                aucs.append(sklearn.metrics.roc_auc_score(
                    y_test,
                    test_outputs,
                ))

            print(
                f"\t\tTest auc = {aucs[-1]:.4f}, "
                f"task accuracy = {task_accs[-1]:.4f}"
            )

            print("\t\tComputing purity score...")
            x_test_preds, (_, x_test_theta_class_scores) = senn_model(x_test)
            # Keep only the explanation scores of the predicted class.
            test_concept_scores = get_argmax_concept_explanations(
                x_test_preds.numpy(),
                x_test_theta_class_scores.numpy(),
            )
            purity_score, purity_mat, oracle_mat = oracle.oracle_impurity_score(
                c_soft=test_concept_scores,
                c_true=c_test,
                output_matrices=True,
            )
            purity_mats.append(purity_mat)
            oracle_mats.append(oracle_mat)
            purities.append(purity_score)
            print(f"\t\t\tDone {purity_score:.4f}")

            print("\t\tComputing non-oracle purity score...")
            # Reuses the purity matrix from above and swaps the oracle matrix
            # for the trivial one.
            non_oracle_purities.append(oracle.oracle_impurity_score(
                c_soft=test_concept_scores,
                c_true=c_test,
                oracle_matrix=construct_trivial_auc_mat(
                    experiment_config["num_concepts"]
                ),
                purity_matrix=purity_mat,
            ))
            print(f"\t\t\tDone {non_oracle_purities[-1]:.4f}")
            print("\t\tDone with trial", trial + 1)

        # Aggregate this dataset's trials into (mean, std) summaries.
        task_acc_mean, task_acc_std = np.mean(task_accs), np.std(task_accs)
        experiment_variables["task_accuracies"].append((task_acc_mean, task_acc_std))
        print(f"\tTest task accuracy: {task_acc_mean:.4f} ± {task_acc_std:.4f}")

        task_auc_mean, task_auc_std = np.mean(aucs), np.std(aucs)
        experiment_variables["task_aucs"].append((task_auc_mean, task_auc_std))
        print(f"\tTest task AUC: {task_auc_mean:.4f} ± {task_auc_std:.4f}")

        purity_mats = np.stack(purity_mats, axis=0)
        purity_mat_mean = np.mean(purity_mats, axis=0)
        purity_mat_std = np.std(purity_mats, axis=0)
        _senn_print_mean_std_matrix("\tPurity matrix:", purity_mat_mean, purity_mat_std)
        experiment_variables["purity_matrices"].append((purity_mat_mean, purity_mat_std))

        oracle_mats = np.stack(oracle_mats, axis=0)
        oracle_mat_mean = np.mean(oracle_mats, axis=0)
        oracle_mat_std = np.std(oracle_mats, axis=0)
        _senn_print_mean_std_matrix("\tOracle matrix:", oracle_mat_mean, oracle_mat_std)
        experiment_variables["oracle_matrices"].append((oracle_mat_mean, oracle_mat_std))

        purity_mean, purity_std = np.mean(purities), np.std(purities)
        experiment_variables["purity_scores"].append((purity_mean, purity_std))
        print(f"\tPurity score: {purity_mean:.4f} ± {purity_std:.4f}")

        non_oracle_purity_mean, non_oracle_purity_std = np.mean(non_oracle_purities), np.std(non_oracle_purities)
        experiment_variables["non_oracle_purity_scores"].append((non_oracle_purity_mean, non_oracle_purity_std))
        print(f"\tNon-oracle purity score: {non_oracle_purity_mean:.4f} ± {non_oracle_purity_std:.4f}")

        # And serialize the results gathered so far
        utils.serialize_results(
            results_dict=experiment_variables,
            results_dir=results_dir,
        )
    return experiment_variables

def _senn_latent_codes(encoder_model, inputs):
    """Calls `encoder_model` on `inputs` and returns the result as a NumPy array.

    If the model returns several tensors they are concatenated along the last axis.
    """
    codes = encoder_model(inputs)
    if isinstance(codes, (list, tuple)):
        return np.concatenate([code.numpy() for code in codes], axis=-1)
    return codes.numpy()


def senn_bottleneck_predict_experiment_loop(
    experiment_config,
    datasets,
    load_from_cache=False,
):
    """Measures how well the task can be predicted from saved SENN concept codes.

    For each dataset and trial, the concept encoder saved under
    `<results_dir>/models/concept_encoder_<ds_name>_trial_<trial>` is loaded,
    the train and test inputs are encoded with it, and a fresh predictive
    decoder (from `construct_decoder`) is trained on the train codes. Test
    accuracy and AUC of that decoder are averaged over trials and serialized
    with `utils.serialize_results` after each dataset.

    Args:
        experiment_config: dict of experiment settings. Keys read here are
            "results_dir", "trials", "latent_decoder_units", "num_outputs",
            "learning_rate", "predictor_max_epochs", "batch_size",
            "holdout_fraction" and the optional "verbosity".
        datasets: sliceable sequence of
            `(ds_name, (x_train, y_train, c_train), (x_test, y_test, c_test))`.
        load_from_cache: if True, previously serialized `*_means.npz` /
            `*_stds.npz` files in "results_dir" are loaded; a complete cache is
            returned directly and a partial one resumes from the first dataset
            without cached results.

    Returns:
        Dict with keys "latent_predictive_accuracies" and
        "latent_predictive_aucs", each a list holding one `(mean, std)` tuple
        per dataset.

    Raises:
        ValueError: if the cached variables disagree on how many datasets
            have already been processed.
    """
    utils.reseed(87)
    experiment_variables = dict(
        latent_predictive_accuracies=[],
        latent_predictive_aucs=[],
    )
    results_dir = experiment_config["results_dir"]

    start_ind = 0
    if load_from_cache:
        cached = True
        complete_cache = True
        for var_name in experiment_variables:
            file_name = os.path.join(results_dir, var_name)
            if not (
                os.path.exists(f'{file_name}_means.npz') and
                os.path.exists(f'{file_name}_stds.npz')
            ):
                print("Could not find", f'"{file_name}_means.npz" or "{file_name}_stds.npz" in cache.')
                cached = False
                break
            # Only the number of stored entries is needed here, so the archive
            # is closed again right away.
            with np.load(f'{file_name}_means.npz') as loaded_means:
                num_cached = len(loaded_means)
            if num_cached != len(datasets):
                print("Found", num_cached, "means for variable", var_name, "vs", len(datasets), "datasets")
                # Then we have a partial run here so let's just run the rest
                if start_ind and start_ind != num_cached:
                    raise ValueError(
                        f'Found inconsistent start indices in cached '
                        f'data {start_ind} vs {num_cached} ({file_name}).'
                    )
                start_ind = num_cached
                complete_cache = False
        if not cached:
            # Nothing is loaded from an incomplete cache, so every dataset has
            # to be processed; a start index picked up from an earlier
            # variable would otherwise silently skip datasets.
            start_ind = 0
        else:
            # Then we have found all of the arrays of interest, so let's
            # load them and use them
            print("Experiment cache was hit")
            print("\tLoading variables from", results_dir)
            for var_name in experiment_variables:
                file_name = os.path.join(results_dir, var_name)
                # Open each archive once and close it when done.
                with np.load(f'{file_name}_means.npz') as means_npz, \
                        np.load(f'{file_name}_stds.npz') as stds_npz:
                    experiment_variables[var_name] = list(zip(
                        [means_npz[key] for key in means_npz.files],
                        [stds_npz[key] for key in stds_npz.files],
                    ))
            if complete_cache:
                # Then we are good to go
                if not os.path.exists(os.path.join(results_dir, "config.yaml")):
                    # then serialize the config as this is a different version run
                    utils.serialize_experiment_config(
                        experiment_config,
                        results_dir,
                    )
                return experiment_variables

    # Else, let's go ahead and run the whole thing (or what is left of it)
    verbosity = experiment_config.get("verbosity", 0)
    Path(os.path.join(results_dir, "models")).mkdir(parents=True, exist_ok=True)
    # Same threshold as the one used below to pick the loss and the metric.
    is_multiclass = experiment_config["num_outputs"] > 2
    for (ds_name, (x_train, y_train, c_train), (x_test, y_test, c_test)) in datasets[start_ind:]:
        print("Training with dataset", ds_name)
        task_accs = []
        aucs = []

        for trial in range(experiment_config["trials"]):
            print(f"\tTrial {trial + 1}/{experiment_config['trials']}")
            # Only the concept encoder is needed to produce bottleneck codes.
            concept_encoder_model = tf.keras.models.load_model(
                os.path.join(
                    results_dir,
                    f"models/concept_encoder_{ds_name}_trial_{trial}"
                )
            )

            predictive_decoder = construct_decoder(
                units=experiment_config["latent_decoder_units"],
                num_outputs=experiment_config["num_outputs"],
            )
            predictive_decoder.compile(
                optimizer=tf.keras.optimizers.Adam(experiment_config["learning_rate"]),
                loss=(
                    tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) if is_multiclass
                    else tf.keras.losses.BinaryCrossentropy(from_logits=True)
                ),
                metrics=[
                    tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1) if is_multiclass
                    else "binary_accuracy"
                ],
            )

            print("\tTraining model")
            # Encode with the model loaded for this dataset and trial rather
            # than with whatever `encoder` happens to exist in the notebook's
            # global namespace.
            train_codes = _senn_latent_codes(concept_encoder_model, x_train)
            test_codes = _senn_latent_codes(concept_encoder_model, x_test)
            predictive_decoder.fit(
                x=train_codes,
                y=y_train,
                epochs=experiment_config["predictor_max_epochs"],
                batch_size=experiment_config["batch_size"],
                validation_split=experiment_config["holdout_fraction"],
                verbose=verbosity,
            )
            print("\tEvaluating model")
            test_result = predictive_decoder.evaluate(
                test_codes,
                y_test,
                verbose=0,
                return_dict=True,
            )
            task_accs.append(
                test_result['sparse_top_k_categorical_accuracy']
                if is_multiclass else
                test_result['binary_accuracy']
            )

            test_outputs = predictive_decoder.predict(test_codes)
            if is_multiclass:
                # Then lets apply a softmax activation over all the probability
                # classes
                preds = scipy.special.softmax(test_outputs, axis=-1)
                aucs.append(sklearn.metrics.roc_auc_score(
                    y_test,
                    preds,
                    multi_class='ovo',
                ))
            else:
                aucs.append(sklearn.metrics.roc_auc_score(
                    y_test,
                    test_outputs,
                ))

            print(
                f"\t\tTest auc = {aucs[-1]:.4f}, "
                f"task accuracy = {task_accs[-1]:.4f}"
            )
            print("\t\tDone with trial", trial + 1)

        # Aggregate this dataset's trials into (mean, std) summaries.
        task_acc_mean, task_acc_std = np.mean(task_accs), np.std(task_accs)
        experiment_variables["latent_predictive_accuracies"].append((task_acc_mean, task_acc_std))
        print(f"\tTest task accuracy: {task_acc_mean:.4f} ± {task_acc_std:.4f}")

        task_auc_mean, task_auc_std = np.mean(aucs), np.std(aucs)
        experiment_variables["latent_predictive_aucs"].append((task_auc_mean, task_auc_std))
        print(f"\tTest task AUC: {task_auc_mean:.4f} ± {task_auc_std:.4f}")

        # And serialize the results gathered so far
        utils.serialize_results(
            results_dict=experiment_variables,
            results_dir=results_dir,
        )
    return experiment_variables

## Balanced Multiclass Task Experiments

In [None]:
# Re-import the SENN module so that edits made to it since the kernel started
# are picked up by this run.
reload(SENN)
# If convolution is used in encoder, then this may be needed :/
# Silence TensorFlow's C++ and Python logging below the error level.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
# NOTE(review): newer TensorFlow releases expose this switch as
# `tf.config.run_functions_eagerly` -- confirm against the pinned TF version.
tf.config.experimental_run_functions_eagerly(True)

############################################################################
## Experiment config
############################################################################

senn_dependency_multiclass_experiment_config = dict(
    max_epochs=30, #100,
    predictor_max_epochs=100,
    batch_size=32,
    trials=5,
    learning_rate=1e-3,
    
    # Sizes are derived from the dep_0 training split: the concept count is the
    # width of its concept array and the input shape drops the sample axis.
    num_concepts=(balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]),
    input_shape=balanced_multiclass_task_bin_concepts_dep_0_complete_train[0].shape[1:],
    # One output per distinct label when there are more than two labels,
    # otherwise a single output.
    num_outputs=(
        len(set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1])) if len(set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1])) > 2
        else 1
    ),
    
    # Forwarded to `construct_senn_encoder` as `filter_groups`.
    # NOTE(review): each tuple looks like (filters, kernel size, stride) --
    # confirm against `construct_senn_encoder`.
    filter_groups=[
       [(8, (7, 7), 1)],
       [(16, (5, 5), 1)],
       [(32, (3, 3), 1)],
       [(64, (3, 3), 1)],
    ],
    encoder_units=[64, 64],
#     encoder_units=[256, 128, 64, 64, 32],
    include_pool=True,
    decoder_units=[256, 512],
    coefficient_model_units=[64, 64],
    latent_decoder_units=[64, 64],
    drop_prob=0,
    max_pool_window=(2,2),
    max_pool_stride=2,
    
    regularization_strength=0.1,
    sparsity_strength=2e-5,
    
    # Every concept is treated as binary (cardinality 2).
    concept_cardinality=[2 for _ in range(balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1])],
    results_dir=os.path.join(RESULTS_DIR, "senn/dependency_multiclass"),
    verbosity=1,

    holdout_fraction=0.1,
    # These four values are handed to the EarlyStopping callback created in
    # `senn_experiment_loop`; an infinite patience presumably means training
    # always runs for the full `max_epochs`.
    patience=float("inf"),
    early_stop_metric="val_loss",
    early_stop_mode="min",
    min_delta=1e-5,
)

# Generate the experiment directory if it does not exist already
Path(senn_dependency_multiclass_experiment_config["results_dir"]).mkdir(parents=True, exist_ok=True)
senn_dependency_multiclass_figure_dir = os.path.join(senn_dependency_multiclass_experiment_config["results_dir"], "figures")
Path(senn_dependency_multiclass_figure_dir).mkdir(parents=True, exist_ok=True)

############################################################################
## Experiment run
############################################################################

# `load_from_cache=False` forces every dataset to be retrained even when
# serialized results already exist in the results directory.
senn_dependency_multiclass_results = senn_experiment_loop(
    senn_dependency_multiclass_experiment_config,
    load_from_cache=False,
    datasets=[
        ("balanced_multiclass_task_bin_concepts_dep_0_complete", balanced_multiclass_task_bin_concepts_dep_0_complete_train, balanced_multiclass_task_bin_concepts_dep_0_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_1_complete", balanced_multiclass_task_bin_concepts_dep_1_complete_train, balanced_multiclass_task_bin_concepts_dep_1_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_2_complete", balanced_multiclass_task_bin_concepts_dep_2_complete_train, balanced_multiclass_task_bin_concepts_dep_2_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_3_complete", balanced_multiclass_task_bin_concepts_dep_3_complete_train, balanced_multiclass_task_bin_concepts_dep_3_complete_test),
        ("balanced_multiclass_task_bin_concepts_dep_4_complete", balanced_multiclass_task_bin_concepts_dep_4_complete_train, balanced_multiclass_task_bin_concepts_dep_4_complete_test),
    ],
)
print("task_accuracies:", senn_dependency_multiclass_results["task_accuracies"])
print("task_aucs:", senn_dependency_multiclass_results["task_aucs"])
print("purity_scores:", senn_dependency_multiclass_results["purity_scores"])
# Keep the mean oracle matrix of the first dataset (dep_0): index [0] selects
# the dataset and the second [0] the mean of its (mean, std) pair.
balanced_oracle_matrix_cache = {
    "balanced_multiclass_task_bin_concepts_dep_0": senn_dependency_multiclass_results["oracle_matrices"][0][0],
}

# Dataset-wide Results

In [None]:
# Grouped bar chart of the normalized oracle impurity score of every method,
# with one group of bars per dependency level listed in `all_vars`.
all_vars = [0, 4]
num_concepts = 5
# Each entry is (legend label, results dict, key of the scores to plot).
# Raw f-strings keep the LaTeX escape `\_` without relying on Python's
# deprecated handling of unknown escape sequences.
all_models = [
    ("Joint-CBM", graph_dependency_balanced_multiclass_results, "purity_scores"),
    ("CW MaxPool-Mean", cw_binary_balanced_multiclass_results, "purity_scores"),
    ("CW Feature Map", cw_binary_balanced_multiclass_results, "repr_purity_scores"),
    (rf"Ada-MLVAE (n\_latent = {ada_mlvae_balanced_multilabel_experiment_config['latent_dim']})", ada_mlvae_balanced_multilabel_results, "purity_scores"),
    (rf"Ada-MLVAE (n\_latent = {ada_mlvae_balanced_multilabel_extended_experiment_config['latent_dim']})", ada_mlvae_balanced_multilabel_extended_results, "purity_scores"),
    (rf"Ada-GVAE (n\_latent = {ada_gvae_balanced_multilabel_experiment_config['latent_dim']})", ada_gvae_balanced_multilabel_results, "purity_scores"),
    (rf"Ada-GVAE (n\_latent = {ada_gvae_balanced_multilabel_extended_experiment_config['latent_dim']})", ada_gvae_balanced_multilabel_extended_results, "purity_scores"),
    (rf"CCD (n\_concepts = {ccd_balanced_multiclass_experiment_config['num_concepts']})", ccd_balanced_multiclass_results, "purity_scores"),
    (rf"CCD (n\_concepts = {ccd_balanced_multiclass_extended_experiment_config['num_concepts']})", ccd_balanced_multiclass_extended_results, "purity_scores"),
]

# Every group is split into len(all_models) + 1 slots of which only
# len(all_models) hold a bar, leaving a one-bar-wide gap between groups.
num_models = len(all_models) + 1

clrs = sns.color_palette("husl", num_models * 2)
scale = 1

fig, ax = plt.subplots(figsize=(num_models * 1.25, 7))
for i, (method_name, results, kword) in enumerate(all_models):
    # entry[0] becomes the bar height and entry[1] the error bar; presumably
    # these are the (mean, std) pairs produced by the experiment loops.
    selected_scores = np.array(results[kword])[all_vars]
    ax.bar(
        scale * (np.arange(0, len(all_vars)) - (1/2 - 1/(2 * num_models)) + i/num_models),
        [normalize_purity(entry[0], num_concepts) for entry in selected_scores],
        width=scale/num_models,
        color=clrs[i*2],
        align='center',
        label=method_name,
        # Error bars are twice the normalized spread value.
        yerr=[2*normalize_purity(entry[1], num_concepts) for entry in selected_scores],
        capsize=10,
        edgecolor="black",
    )


# At least one legend column, even when fewer than three methods are plotted.
lgd = ax.legend(
    prop={"size":20},
    loc='upper center',
    bbox_to_anchor=(0.5,-0.15),
    ncol=max(1, (num_models - 1)//3),
)

ax.set_ylabel("Oracle Impurity Score", fontsize=20)
if len(all_vars) > 1:
    ax.set_xlabel(r"Number of Dependency Edges ($\lambda$)", fontsize=20)
    ax.set_title(bold_text(r"Oracle Impurity (dSprites($\lambda$))"), fontsize=30)
    # One tick per group, labelled with the dependency level it represents.
    ax.set_xticks(np.arange(0, len(all_vars)))
    ax.set_xticklabels(all_vars, fontsize=20)
else:
    ax.set_title(bold_text(r"Oracle Impurity (dSprites($\lambda = " + str(all_vars[0]) + "$))"), fontsize=30)
    # A single group needs no x ticks; labels are only set in the branch
    # above because they must match the number of ticks.
    ax.set_xticks([])
plt.yticks(fontsize=20)

ax.grid(False)
plt.tight_layout()
plt.show()

In [None]:
# Line plot of the mean concept AUC (in %) of every method across the
# dependency levels in `all_vars`, with a shaded band of +/- one spread value.
all_vars = np.arange(0, balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1])
num_concepts = 5
fig, ax = plt.subplots(figsize=(14, 10))
clrs = sns.color_palette("husl", 20)

# Each entry is (legend label, results dict, result key, transform), where the
# transform reduces one stored value to a percentage. Raw f-strings keep the
# LaTeX escape `\_` without relying on Python's deprecated handling of unknown
# escape sequences.
all_models = [
    (
        "Joint-CBM",
        graph_dependency_balanced_multiclass_results,
        "concept_accuracies",
        lambda x: 100 * x,
    ),
    (
        "CW MaxPool-Mean",
        cw_binary_balanced_multiclass_results,
        "concept_aucs",
        lambda x: 100 * np.mean(x),
    ),
    (
        # Read from the feature-map representation results so that this
        # series is not an exact copy of the "CW MaxPool-Mean" one.
        "CW Feature Map",
        cw_binary_balanced_multiclass_results,
        "repr_purity_matrices",
        lambda x: 100 * np.mean(np.diagonal(x)),
    ),
    (
        rf"Ada-MLVAE (n\_latent = {ada_mlvae_balanced_multilabel_experiment_config['latent_dim']})",
        ada_mlvae_balanced_multilabel_results,
        "aligned_purity_matrices",
        lambda x: 100 * np.mean(np.diagonal(x)),
    ),
    (
        rf"Ada-MLVAE (n\_latent = {ada_mlvae_balanced_multilabel_extended_experiment_config['latent_dim']})",
        ada_mlvae_balanced_multilabel_extended_results,
        "aligned_purity_matrices",
        lambda x: 100 * np.mean(np.diagonal(x)),
    ),
    (
        rf"Ada-GVAE (n\_latent = {ada_gvae_balanced_multilabel_experiment_config['latent_dim']})",
        ada_gvae_balanced_multilabel_results,
        "aligned_purity_matrices",
        lambda x: 100 * np.mean(np.diagonal(x)),
    ),
    (
        rf"Ada-GVAE (n\_latent = {ada_gvae_balanced_multilabel_extended_experiment_config['latent_dim']})",
        ada_gvae_balanced_multilabel_extended_results,
        "aligned_purity_matrices",
        lambda x: 100 * np.mean(np.diagonal(x)),
    ),
    (
        # NOTE(review): the other matrix-based series use the "aligned_"
        # variant of this key -- confirm that the unaligned one is intended.
        rf"CCD (n\_concepts = {ccd_balanced_multiclass_experiment_config['num_concepts']})",
        ccd_balanced_multiclass_results,
        "purity_matrices",
        lambda x: 100 * np.mean(np.diagonal(x)),
    ),
    (
        rf"CCD (n\_concepts = {ccd_balanced_multiclass_extended_experiment_config['num_concepts']})",
        ccd_balanced_multiclass_extended_results,
        "purity_matrices",
        lambda x: 100 * np.mean(np.diagonal(x)),
    ),
]

x_positions = np.arange(0, len(all_vars))

with sns.axes_style("darkgrid"):
    for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
        # entry[0] gives the plotted value and entry[1] the band half-width;
        # presumably these are the (mean, std) pairs produced by the
        # experiment loops.
        selected_results = np.array(results[kword])[all_vars]
        accs_means = np.array([transform_fn(entry[0]) for entry in selected_results])
        accs_stds = np.array([transform_fn(entry[1]) for entry in selected_results])
        ax.plot(
            x_positions,
            accs_means,
            c=clrs[i*2],
            zorder=1,
        )
        # Markers are drawn above the line and carry the legend label.
        ax.scatter(
            x_positions,
            accs_means,
            label=method_name,
            s=150,
            color=clrs[i*2],
            zorder=2,
        )
        ax.fill_between(
            x_positions,
            accs_means - accs_stds,
            accs_means + accs_stds,
            alpha=0.3,
            facecolor=clrs[i*2],
        )

# The number of legend columns is derived from this cell's own model list
# rather than from a `num_models` variable left behind by another cell.
lgd = ax.legend(
    prop={"size":20},
    loc='upper center',
    bbox_to_anchor=(0.5,-0.15),
    ncol=max(1, len(all_models)//3),
)


ax.set_ylabel(r"AUC (\%)", fontsize=20)
if len(all_vars) > 1:
    ax.set_xlabel(r"Number of Dependency Edges ($\lambda$)", fontsize=20)
    ax.set_title(bold_text(r"Mean Concept AUC (dSprites($\lambda$))"), fontsize=30)
    ax.set_xticks(x_positions)
    ax.set_xticklabels(all_vars, fontsize=15)
else:
    ax.set_title(bold_text(r"Mean Concept AUC (dSprites($\lambda = " + str(all_vars[0]) + "$))"), fontsize=30)
    ax.set_xticks([])
plt.yticks(fontsize=15)
ax.grid(False)
plt.show()

In [None]:
# Grouped bar chart of the mean concept AUC of every method, one bar group
# per dependency level (lambda).
# Set up our figure
# One index per column of the third element of the dep_0 training tuple
# (presumably the concept matrix, so one lambda per concept -- confirm).
all_vars = np.arange(0, balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1])
num_concepts = 5

# Each entry: (legend label, results dict, key into that dict, transform
# applied to both halves of every stored pair before plotting).
# NOTE(review): Joint-CBM reads "concept_accuracies" while the figure is
# labelled as AUC -- confirm this is the intended metric.
all_models = [
    (
        "Joint-CBM",
        graph_dependency_balanced_multiclass_results,
        "concept_accuracies",
        lambda x: 100 * x,
    ),
    (
        "CW MaxPool-Mean",
        cw_binary_balanced_multiclass_results,
        "concept_aucs",
        lambda x: 100 * np.mean(x),
    ),
    (
        "CW Feature Map",
        cw_binary_balanced_multiclass_results,
        "repr_purity_matrices",
        lambda x: 100 * np.mean(np.diagonal(x)),
    ),
    (
        rf"Ada-MLVAE (n\_latent = {ada_mlvae_balanced_multilabel_experiment_config['latent_dim']})",
        ada_mlvae_balanced_multilabel_results,
        "aligned_purity_matrices",
        lambda x: 100 * np.mean(np.diagonal(x)),
    ),
    (
        rf"Ada-MLVAE (n\_latent = {ada_mlvae_balanced_multilabel_extended_experiment_config['latent_dim']})",
        ada_mlvae_balanced_multilabel_extended_results,
        "aligned_purity_matrices",
        lambda x: 100 * np.mean(np.diagonal(x)),
    ),
    (
        rf"Ada-GVAE (n\_latent = {ada_gvae_balanced_multilabel_experiment_config['latent_dim']})",
        ada_gvae_balanced_multilabel_results,
        "aligned_purity_matrices",
        lambda x: 100 * np.mean(np.diagonal(x)),
    ),
    (
        rf"Ada-GVAE (n\_latent = {ada_gvae_balanced_multilabel_extended_experiment_config['latent_dim']})",
        ada_gvae_balanced_multilabel_extended_results,
        "aligned_purity_matrices",
        lambda x: 100 * np.mean(np.diagonal(x)),
    ),
    (
        rf"CCD (n\_concepts = {ccd_balanced_multiclass_experiment_config['num_concepts']})",
        ccd_balanced_multiclass_results,
        "aligned_purity_matrices",
        lambda x: 100 * np.mean(np.diagonal(x)),
    ),
    (
        rf"CCD (n\_concepts = {ccd_balanced_multiclass_extended_experiment_config['num_concepts']})",
        ccd_balanced_multiclass_extended_results,
        "aligned_purity_matrices",
        lambda x: 100 * np.mean(np.diagonal(x)),
    ),
]

# One slot more than there are models: the unused last slot leaves a
# bar-width gap between neighbouring groups.
num_models = len(all_models) + 1

clrs = sns.color_palette("husl", num_models * 2)
scale = 1

fig, ax = plt.subplots(figsize=(num_models * 1.75, 5))
for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
    # Select the entries indexed by all_vars once; each entry is read as a
    # pair (presumably (mean, std) -- confirm against the results producers).
    selected = np.array(results[kword])[all_vars]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    ax.bar(
        # Group centred on the integer tick; model i takes the i-th slot.
        scale * (np.arange(0, len(all_vars)) - (1/2 - 1/(2 * num_models)) + i/num_models),
        accs_means,
        width=scale/num_models,
        color=clrs[i*2],
        align='center',
        label=method_name,
        yerr=2*accs_stds,  # error bars span two accs_stds on each side
        capsize=5,
        edgecolor="black",
    )

lgd = ax.legend(fontsize=20, loc='upper center', bbox_to_anchor=(0.5,-0.15), ncol=(num_models - 1)//3)

# Raw strings keep the LaTeX backslashes without relying on Python's
# (deprecated) pass-through of unrecognised escape sequences.
plt.ylabel(r"AUC (\%)", fontsize=20)
if len(all_vars) > 1:
    plt.xlabel(r"Number of Dependency Edges ($\lambda$)", fontsize=20)
    plt.title(bold_text(r"Mean Concept AUC (dSprites($\lambda$))"), fontsize=30)
    plt.xticks(np.arange(0, len(all_vars)), fontsize=15)
    ax.set_xticklabels(all_vars, fontsize=15)
else:
    # Single setting: put the value of lambda in the title and hide x ticks.
    plt.title(bold_text(r"Mean Concept AUC (dSprites($\lambda = " + str(all_vars[0]) + "$))"), fontsize=30)
    plt.xticks([], fontsize=15)
plt.yticks(fontsize=20)
ax.grid(False)
plt.show()

In [None]:
# Grouped bar chart of the mean concept AUC of every method, restricted to
# the dependency levels lambda = 0 and lambda = 4.
# Set up our figure
all_vars = np.array([0, 4])
num_concepts = 5

# Each entry: (legend label, results dict, key into that dict, transform
# applied to both halves of every stored pair before plotting).
# NOTE(review): Joint-CBM reads "concept_accuracies" while the figure is
# labelled as AUC -- confirm this is the intended metric.
all_models = [
    (
        "Joint-CBM",
        graph_dependency_balanced_multiclass_results,
        "concept_accuracies",
        lambda x: 100 * x,
    ),
    (
        "CW MaxPool-Mean",
        cw_binary_balanced_multiclass_results,
        "concept_aucs",
        lambda x: 100 * np.mean(x),
    ),
    (
        "CW Feature Map",
        cw_binary_balanced_multiclass_results,
        "repr_purity_matrices",
        lambda x: 100 * np.mean(np.diagonal(x)),
    ),
    (
        rf"Ada-MLVAE (n\_latent = {ada_mlvae_balanced_multilabel_experiment_config['latent_dim']})",
        ada_mlvae_balanced_multilabel_results,
        "aligned_purity_matrices",
        lambda x: 100 * np.mean(np.diagonal(x)),
    ),
    (
        rf"Ada-MLVAE (n\_latent = {ada_mlvae_balanced_multilabel_extended_experiment_config['latent_dim']})",
        ada_mlvae_balanced_multilabel_extended_results,
        "aligned_purity_matrices",
        lambda x: 100 * np.mean(np.diagonal(x)),
    ),
    (
        rf"Ada-GVAE (n\_latent = {ada_gvae_balanced_multilabel_experiment_config['latent_dim']})",
        ada_gvae_balanced_multilabel_results,
        "aligned_purity_matrices",
        lambda x: 100 * np.mean(np.diagonal(x)),
    ),
    (
        rf"Ada-GVAE (n\_latent = {ada_gvae_balanced_multilabel_extended_experiment_config['latent_dim']})",
        ada_gvae_balanced_multilabel_extended_results,
        "aligned_purity_matrices",
        lambda x: 100 * np.mean(np.diagonal(x)),
    ),
    (
        rf"CCD (n\_concepts = {ccd_balanced_multiclass_experiment_config['num_concepts']})",
        ccd_balanced_multiclass_results,
        "aligned_purity_matrices",
        lambda x: 100 * np.mean(np.diagonal(x)),
    ),
    (
        rf"CCD (n\_concepts = {ccd_balanced_multiclass_extended_experiment_config['num_concepts']})",
        ccd_balanced_multiclass_extended_results,
        "aligned_purity_matrices",
        lambda x: 100 * np.mean(np.diagonal(x)),
    ),
]

# One slot more than there are models: the unused last slot leaves a
# bar-width gap between neighbouring groups.
num_models = len(all_models) + 1

clrs = sns.color_palette("husl", num_models * 2)
scale = 1

fig, ax = plt.subplots(figsize=(num_models * 1, 5))
for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
    # Select the entries indexed by all_vars once; each entry is read as a
    # pair (presumably (mean, std) -- confirm against the results producers).
    selected = np.array(results[kword])[all_vars]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    ax.bar(
        # Group centred on the integer tick; model i takes the i-th slot.
        scale * (np.arange(0, len(all_vars)) - (1/2 - 1/(2 * num_models)) + i/num_models),
        accs_means,
        width=scale/num_models,
        color=clrs[i*2],
        align='center',
        label=method_name,
        yerr=2*accs_stds,  # error bars span two accs_stds on each side
        capsize=5,
        edgecolor="black",
    )

lgd = ax.legend(fontsize=15, loc='upper center', bbox_to_anchor=(0.5,-0.15), ncol=(num_models - 1)//3)

# Raw strings keep the LaTeX backslashes without relying on Python's
# (deprecated) pass-through of unrecognised escape sequences.
plt.ylabel(r"AUC (\%)", fontsize=20)
if len(all_vars) > 1:
    plt.xlabel(r"Number of Dependency Edges ($\lambda$)", fontsize=20)
    plt.title(bold_text(r"Mean Concept AUC (dSprites($\lambda$))"), fontsize=30)
    # Ticks sit at 0..len-1; the labels show the actual lambda values.
    plt.xticks(np.arange(0, len(all_vars)), fontsize=15)
    ax.set_xticklabels(all_vars, fontsize=15)
else:
    # Single setting: put the value of lambda in the title and hide x ticks.
    plt.title(bold_text(r"Mean Concept AUC (dSprites($\lambda = " + str(all_vars[0]) + "$))"), fontsize=30)
    plt.xticks([], fontsize=15)
plt.yticks(fontsize=20)
ax.grid(False)
plt.show()

In [None]:
# Grouped bar chart of the downstream task AUC of every method, one bar
# group per dependency level (lambda).
# Set up our figure
# One index per column of the third element of the dep_0 training tuple
# (presumably the concept matrix, so one lambda per concept -- confirm).
all_vars = np.arange(0, balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1])
num_concepts = 5

# Each entry: (legend label, results dict, key into that dict, transform
# applied to both halves of every stored pair before plotting).
all_models = [
    (
        "Joint-CBM",
        graph_dependency_balanced_multiclass_results,
        "task_aucs",
        lambda x: 100 * x,
    ),
    (
        "CW MaxPool-Mean",
        cw_binary_balanced_multiclass_results,
        "task_aucs",
        lambda x: 100 * x,
    ),
    (
        "CW Feature Map",
        cw_binary_balanced_multiclass_results,
        "task_aucs",
        lambda x: 100 * x,
    ),
    (
        rf"CCD (n\_concepts = {ccd_balanced_multiclass_experiment_config['num_concepts']})",
        ccd_balanced_multiclass_results,
        "task_aucs",
        lambda x: 100 * x,
    ),
    (
        rf"CCD (n\_concepts = {ccd_balanced_multiclass_extended_experiment_config['num_concepts']})",
        ccd_balanced_multiclass_extended_results,
        "task_aucs",
        lambda x: 100 * x,
    ),
]

# One slot more than there are models: the unused last slot leaves a
# bar-width gap between neighbouring groups.
num_models = len(all_models) + 1

clrs = sns.color_palette("husl", num_models * 2)
scale = 1

fig, ax = plt.subplots(figsize=(num_models * 2, 5))
for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
    # Select the entries indexed by all_vars once; each entry is read as a
    # pair (presumably (mean, std) -- confirm against the results producers).
    selected = np.array(results[kword])[all_vars]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    ax.bar(
        # Group centred on the integer tick; model i takes the i-th slot.
        scale * (np.arange(0, len(all_vars)) - (1/2 - 1/(2 * num_models)) + i/num_models),
        accs_means,
        width=scale/num_models,
        color=clrs[i*2],
        align='center',
        label=method_name,
        yerr=2*accs_stds,  # error bars span two accs_stds on each side
        capsize=10,
        edgecolor="black",
    )

lgd = ax.legend(fontsize=20, loc='upper center', bbox_to_anchor=(0.5,-0.14), ncol=(num_models - 1)//2)

# Raw strings keep the LaTeX backslashes without relying on Python's
# (deprecated) pass-through of unrecognised escape sequences.
plt.ylabel(r"AUC (\%)", fontsize=20)
if len(all_vars) > 1:
    plt.xlabel(r"Number of Dependency Edges ($\lambda$)", fontsize=20)
    plt.title(bold_text(r"Downstream Task AUC (dSprites($\lambda$))"), fontsize=30)
    plt.xticks(np.arange(0, len(all_vars)), fontsize=15)
    ax.set_xticklabels(all_vars, fontsize=15)
else:
    # Single setting: put the value of lambda in the title and hide x ticks.
    plt.title(bold_text(r"Downstream Task AUC (dSprites($\lambda = " + str(all_vars[0]) + "$))"), fontsize=30)
    plt.xticks([], fontsize=15)
plt.yticks(fontsize=20)
ax.grid(False)
plt.show()

In [None]:
# Grouped bar chart of the downstream task AUC of every method, restricted
# to the dependency levels lambda = 0 and lambda = 4.
# Set up our figure
all_vars = [0, 4]
num_concepts = 5

# Each entry: (legend label, results dict, key into that dict, transform
# applied to both halves of every stored pair before plotting).
# NOTE: a "CW Mean" entry (cw_mean_binary_balanced_multiclass_results,
# "task_aucs") was previously commented out here and is not plotted.
all_models = [
    (
        "Joint-CBM",
        graph_dependency_balanced_multiclass_results,
        "task_aucs",
        lambda x: 100 * x,
    ),
    (
        "CW MaxPool-Mean",
        cw_binary_balanced_multiclass_results,
        "task_aucs",
        lambda x: 100 * x,
    ),
    (
        "CW Feature Map",
        cw_binary_balanced_multiclass_results,
        "task_aucs",
        lambda x: 100 * x,
    ),
    (
        rf"CCD (n\_concepts = {ccd_balanced_multiclass_experiment_config['num_concepts']})",
        ccd_balanced_multiclass_results,
        "task_aucs",
        lambda x: 100 * x,
    ),
    (
        rf"CCD (n\_concepts = {ccd_balanced_multiclass_extended_experiment_config['num_concepts']})",
        ccd_balanced_multiclass_extended_results,
        "task_aucs",
        lambda x: 100 * x,
    ),
]

# One slot more than there are models: the unused last slot leaves a
# bar-width gap between neighbouring groups.
num_models = len(all_models) + 1

clrs = sns.color_palette("husl", num_models * 2)
scale = 1

fig, ax = plt.subplots(figsize=(num_models * 1.5, 5))
for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
    # Select the entries indexed by all_vars once; each entry is read as a
    # pair (presumably (mean, std) -- confirm against the results producers).
    selected = np.array(results[kword])[all_vars]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    ax.bar(
        # Group centred on the integer tick; model i takes the i-th slot.
        scale * (np.arange(0, len(all_vars)) - (1/2 - 1/(2 * num_models)) + i/num_models),
        accs_means,
        width=scale/num_models,
        color=clrs[i*2],
        align='center',
        label=method_name,
        yerr=2*accs_stds,  # error bars span two accs_stds on each side
        capsize=10,
        edgecolor="black",
    )

lgd = ax.legend(fontsize=20, loc='upper center', bbox_to_anchor=(0.5,-0.14), ncol=(num_models - 1)//2)

# Raw strings keep the LaTeX backslashes without relying on Python's
# (deprecated) pass-through of unrecognised escape sequences.
plt.ylabel(r"AUC (\%)", fontsize=20)
if len(all_vars) > 1:
    plt.xlabel(r"Number of Dependency Edges ($\lambda$)", fontsize=20)
    plt.title(bold_text(r"Downstream Task AUC (dSprites($\lambda$))"), fontsize=30)
    # Ticks sit at 0..len-1; the labels show the actual lambda values.
    plt.xticks(np.arange(0, len(all_vars)), fontsize=15)
    ax.set_xticklabels(all_vars, fontsize=15)
else:
    # Single setting: put the value of lambda in the title and hide x ticks.
    plt.title(bold_text(r"Downstream Task AUC (dSprites($\lambda = " + str(all_vars[0]) + "$))"), fontsize=30)
    plt.xticks([], fontsize=15)
plt.yticks(fontsize=20)
ax.grid(False)
plt.show()

In [None]:
# Grouped bar chart of the downstream task accuracy of every method, one
# bar group per dependency level (lambda).
# Set up our figure
# One index per column of the third element of the dep_0 training tuple
# (presumably the concept matrix, so one lambda per concept -- confirm).
all_vars = np.arange(0, balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1])
num_concepts = 5

# Each entry: (legend label, results dict, key into that dict, transform
# applied to both halves of every stored pair before plotting).
all_models = [
    (
        "Joint-CBM",
        graph_dependency_balanced_multiclass_results,
        "task_accuracies",
        lambda x: 100 * x,
    ),
    (
        "CW MaxPool-Mean",
        cw_binary_balanced_multiclass_results,
        "task_accuracies",
        lambda x: 100 * x,
    ),
    (
        "CW Feature Map",
        cw_binary_balanced_multiclass_results,
        "task_accuracies",
        lambda x: 100 * x,
    ),
    (
        rf"CCD (n\_concepts = {ccd_balanced_multiclass_experiment_config['num_concepts']})",
        ccd_balanced_multiclass_results,
        "task_accuracies",
        lambda x: 100 * x,
    ),
    (
        rf"CCD (n\_concepts = {ccd_balanced_multiclass_extended_experiment_config['num_concepts']})",
        ccd_balanced_multiclass_extended_results,
        "task_accuracies",
        lambda x: 100 * x,
    ),
]

# One slot more than there are models: the unused last slot leaves a
# bar-width gap between neighbouring groups.
num_models = len(all_models) + 1

clrs = sns.color_palette("husl", num_models * 2)
scale = 1

fig, ax = plt.subplots(figsize=(num_models * 2, 5))
for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
    # Select the entries indexed by all_vars once; each entry is read as a
    # pair (presumably (mean, std) -- confirm against the results producers).
    selected = np.array(results[kword])[all_vars]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    ax.bar(
        # Group centred on the integer tick; model i takes the i-th slot.
        scale * (np.arange(0, len(all_vars)) - (1/2 - 1/(2 * num_models)) + i/num_models),
        accs_means,
        width=scale/num_models,
        color=clrs[i*2],
        align='center',
        label=method_name,
        yerr=2*accs_stds,  # error bars span two accs_stds on each side
        capsize=10,
        edgecolor="black",
    )

lgd = ax.legend(fontsize=20, loc='upper center', bbox_to_anchor=(0.5,-0.14), ncol=(num_models - 1)//2)

# Raw strings keep the LaTeX backslashes without relying on Python's
# (deprecated) pass-through of unrecognised escape sequences.
plt.ylabel(r"Accuracy (\%)", fontsize=20)
if len(all_vars) > 1:
    plt.xlabel(r"Number of Dependency Edges ($\lambda$)", fontsize=20)
    plt.title(bold_text(r"Downstream Task Accuracy (dSprites($\lambda$))"), fontsize=30)
    plt.xticks(np.arange(0, len(all_vars)), fontsize=15)
    ax.set_xticklabels(all_vars, fontsize=15)
else:
    # Single setting: put the value of lambda in the title and hide x ticks.
    plt.title(bold_text(r"Downstream Task Accuracy (dSprites($\lambda = " + str(all_vars[0]) + "$))"), fontsize=30)
    plt.xticks([], fontsize=15)
plt.yticks(fontsize=20)
ax.grid(False)
plt.show()

In [None]:
# Grouped bar chart of the downstream task accuracy of every method,
# restricted to the dependency levels lambda = 0 and lambda = 4.
# Set up our figure
all_vars = [0, 4]
num_concepts = 5

# Each entry: (legend label, results dict, key into that dict, transform
# applied to both halves of every stored pair before plotting).
all_models = [
    (
        "Joint-CBM",
        graph_dependency_balanced_multiclass_results,
        "task_accuracies",
        lambda x: 100 * x,
    ),
    (
        "CW MaxPool-Mean",
        cw_binary_balanced_multiclass_results,
        "task_accuracies",
        lambda x: 100 * x,
    ),
    (
        "CW Feature Map",
        cw_binary_balanced_multiclass_results,
        "task_accuracies",
        lambda x: 100 * x,
    ),
    (
        rf"CCD (n\_concepts = {ccd_balanced_multiclass_experiment_config['num_concepts']})",
        ccd_balanced_multiclass_results,
        "task_accuracies",
        lambda x: 100 * x,
    ),
    (
        rf"CCD (n\_concepts = {ccd_balanced_multiclass_extended_experiment_config['num_concepts']})",
        ccd_balanced_multiclass_extended_results,
        "task_accuracies",
        lambda x: 100 * x,
    ),
]

# One slot more than there are models: the unused last slot leaves a
# bar-width gap between neighbouring groups.
num_models = len(all_models) + 1

clrs = sns.color_palette("husl", num_models * 2)
scale = 1

fig, ax = plt.subplots(figsize=(num_models * 1.5, 5))
for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
    # Select the entries indexed by all_vars once; each entry is read as a
    # pair (presumably (mean, std) -- confirm against the results producers).
    selected = np.array(results[kword])[all_vars]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    ax.bar(
        # Group centred on the integer tick; model i takes the i-th slot.
        scale * (np.arange(0, len(all_vars)) - (1/2 - 1/(2 * num_models)) + i/num_models),
        accs_means,
        width=scale/num_models,
        color=clrs[i*2],
        align='center',
        label=method_name,
        yerr=2*accs_stds,  # error bars span two accs_stds on each side
        capsize=10,
        edgecolor="black",
    )

lgd = ax.legend(fontsize=20, loc='upper center', bbox_to_anchor=(0.5,-0.14), ncol=(num_models - 1)//2)

# Raw strings keep the LaTeX backslashes without relying on Python's
# (deprecated) pass-through of unrecognised escape sequences.
plt.ylabel(r"Accuracy (\%)", fontsize=20)
if len(all_vars) > 1:
    plt.xlabel(r"Number of Dependency Edges ($\lambda$)", fontsize=20)
    plt.title(bold_text(r"Downstream Task Accuracy (dSprites($\lambda$))"), fontsize=30)
    # Ticks sit at 0..len-1; the labels show the actual lambda values.
    plt.xticks(np.arange(0, len(all_vars)), fontsize=15)
    ax.set_xticklabels(all_vars, fontsize=15)
else:
    # Single setting: put the value of lambda in the title and hide x ticks.
    plt.title(bold_text(r"Downstream Task Accuracy (dSprites($\lambda = " + str(all_vars[0]) + "$))"), fontsize=30)
    plt.xticks([], fontsize=15)
plt.yticks(fontsize=20)
ax.grid(False)
plt.show()

In [None]:
# Grouped bar chart of the downstream task AUC obtained from each method's
# concept representation, one bar group per dependency level (lambda).
# Set up our figure
# One index per column of the third element of the dep_0 training tuple
# (presumably the concept matrix, so one lambda per concept -- confirm).
all_vars = np.arange(0, balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1])
num_concepts = 5

# Each entry: (legend label, results dict, key into that dict, transform
# applied to both halves of every stored pair before plotting).
# NOTE(review): the CCD entries read "direct_completeness_scores" rather
# than a "latent_predictive_*" key -- confirm the two are comparable.
all_models = [
    (
        "Joint-CBM",
        graph_dependency_balanced_multiclass_results,
        "latent_predictive_aucs",
        lambda x: 100 * x,
    ),
    (
        "CW MaxPool-Mean",
        cw_binary_balanced_multiclass_results,
        "latent_predictive_aucs",
        lambda x: 100 * x,
    ),
    (
        "CW Feature Map",
        cw_binary_balanced_multiclass_results,
        "latent_feature_predictive_aucs",
        lambda x: 100 * x,
    ),
    (
        rf"Ada-MLVAE (n\_latent = {ada_mlvae_balanced_multilabel_experiment_config['latent_dim']})",
        ada_mlvae_balanced_multilabel_results,
        "latent_predictive_aucs",
        lambda x: 100 * x,
    ),
    (
        rf"Ada-MLVAE (n\_latent = {ada_mlvae_balanced_multilabel_extended_experiment_config['latent_dim']})",
        ada_mlvae_balanced_multilabel_extended_results,
        "latent_predictive_aucs",
        lambda x: 100 * x,
    ),
    (
        rf"Ada-GVAE (n\_latent = {ada_gvae_balanced_multilabel_experiment_config['latent_dim']})",
        ada_gvae_balanced_multilabel_results,
        "latent_predictive_aucs",
        lambda x: 100 * x,
    ),
    (
        rf"Ada-GVAE (n\_latent = {ada_gvae_balanced_multilabel_extended_experiment_config['latent_dim']})",
        ada_gvae_balanced_multilabel_extended_results,
        "latent_predictive_aucs",
        lambda x: 100 * x,
    ),
    (
        rf"CCD (n\_concepts = {ccd_balanced_multiclass_experiment_config['num_concepts']})",
        ccd_balanced_multiclass_results,
        "direct_completeness_scores",
        lambda x: 100 * x,
    ),
    (
        rf"CCD (n\_concepts = {ccd_balanced_multiclass_extended_experiment_config['num_concepts']})",
        ccd_balanced_multiclass_extended_results,
        "direct_completeness_scores",
        lambda x: 100 * x,
    ),
]

# One slot more than there are models: the unused last slot leaves a
# bar-width gap between neighbouring groups.
num_models = len(all_models) + 1

clrs = sns.color_palette("husl", num_models * 2)
scale = 1

fig, ax = plt.subplots(figsize=(num_models * 2, 5))
for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
    # Select the entries indexed by all_vars once; each entry is read as a
    # pair (presumably (mean, std) -- confirm against the results producers).
    selected = np.array(results[kword])[all_vars]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    ax.bar(
        # Group centred on the integer tick; model i takes the i-th slot.
        scale * (np.arange(0, len(all_vars)) - (1/2 - 1/(2 * num_models)) + i/num_models),
        accs_means,
        width=scale/num_models,
        color=clrs[i*2],
        align='center',
        label=method_name,
        yerr=2*accs_stds,  # error bars span two accs_stds on each side
        capsize=5,
        edgecolor="black",
    )

lgd = ax.legend(fontsize=20, loc='upper center', bbox_to_anchor=(0.5,-0.2), ncol=(num_models - 1)//2)

# Raw strings keep the LaTeX backslashes without relying on Python's
# (deprecated) pass-through of unrecognised escape sequences.
plt.ylabel(r"AUC (\%)", fontsize=25)
if len(all_vars) > 1:
    plt.xlabel(r"Number of Dependency Edges ($\lambda$)", fontsize=25)
    plt.title(bold_text(r"Downstream Task AUC from Concepts (dSprites($\lambda$))"), fontsize=30)
    plt.xticks(np.arange(0, len(all_vars)), fontsize=15)
    ax.set_xticklabels(all_vars, fontsize=15)
else:
    # Single setting: put the value of lambda in the title and hide x ticks.
    plt.title(bold_text(r"Downstream Task AUC from Concepts (dSprites($\lambda = " + str(all_vars[0]) + "$))"), fontsize=30)
    plt.xticks([], fontsize=15)
plt.yticks(fontsize=20)
ax.grid(False)
plt.show()

In [None]:
# Grouped bar chart of the downstream task AUC obtained from each method's
# concept representation, restricted to lambda = 0 and lambda = 4.
# Set up our figure
all_vars = [0, 4]
num_concepts = 5

# Each entry: (legend label, results dict, key into that dict, transform
# applied to both halves of every stored pair before plotting).
# NOTE(review): the CCD entries read "direct_completeness_scores" rather
# than a "latent_predictive_*" key -- confirm the two are comparable.
all_models = [
    (
        "Joint-CBM",
        graph_dependency_balanced_multiclass_results,
        "latent_predictive_aucs",
        lambda x: 100 * x,
    ),
    (
        "CW MaxPool-Mean",
        cw_binary_balanced_multiclass_results,
        "latent_predictive_aucs",
        lambda x: 100 * x,
    ),
    (
        "CW Feature Map",
        cw_binary_balanced_multiclass_results,
        "latent_feature_predictive_aucs",
        lambda x: 100 * x,
    ),
    (
        rf"Ada-MLVAE (n\_latent = {ada_mlvae_balanced_multilabel_experiment_config['latent_dim']})",
        ada_mlvae_balanced_multilabel_results,
        "latent_predictive_aucs",
        lambda x: 100 * x,
    ),
    (
        rf"Ada-MLVAE (n\_latent = {ada_mlvae_balanced_multilabel_extended_experiment_config['latent_dim']})",
        ada_mlvae_balanced_multilabel_extended_results,
        "latent_predictive_aucs",
        lambda x: 100 * x,
    ),
    (
        rf"Ada-GVAE (n\_latent = {ada_gvae_balanced_multilabel_experiment_config['latent_dim']})",
        ada_gvae_balanced_multilabel_results,
        "latent_predictive_aucs",
        lambda x: 100 * x,
    ),
    (
        rf"Ada-GVAE (n\_latent = {ada_gvae_balanced_multilabel_extended_experiment_config['latent_dim']})",
        ada_gvae_balanced_multilabel_extended_results,
        "latent_predictive_aucs",
        lambda x: 100 * x,
    ),
    (
        rf"CCD (n\_concepts = {ccd_balanced_multiclass_experiment_config['num_concepts']})",
        ccd_balanced_multiclass_results,
        "direct_completeness_scores",
        lambda x: 100 * x,
    ),
    (
        rf"CCD (n\_concepts = {ccd_balanced_multiclass_extended_experiment_config['num_concepts']})",
        ccd_balanced_multiclass_extended_results,
        "direct_completeness_scores",
        lambda x: 100 * x,
    ),
]

# One slot more than there are models: the unused last slot leaves a
# bar-width gap between neighbouring groups.
num_models = len(all_models) + 1

clrs = sns.color_palette("husl", num_models * 2)
scale = 1

fig, ax = plt.subplots(figsize=(num_models * 2, 5))
for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
    # Select the entries indexed by all_vars once; each entry is read as a
    # pair (presumably (mean, std) -- confirm against the results producers).
    selected = np.array(results[kword])[all_vars]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    ax.bar(
        # Group centred on the integer tick; model i takes the i-th slot.
        scale * (np.arange(0, len(all_vars)) - (1/2 - 1/(2 * num_models)) + i/num_models),
        accs_means,
        width=scale/num_models,
        color=clrs[i*2],
        align='center',
        label=method_name,
        yerr=2*accs_stds,  # error bars span two accs_stds on each side
        capsize=5,
        edgecolor="black",
    )

lgd = ax.legend(fontsize=20, loc='upper center', bbox_to_anchor=(0.5,-0.2), ncol=(num_models - 1)//2)

# Raw strings keep the LaTeX backslashes without relying on Python's
# (deprecated) pass-through of unrecognised escape sequences.
plt.ylabel(r"AUC (\%)", fontsize=25)
if len(all_vars) > 1:
    plt.xlabel(r"Number of Dependency Edges ($\lambda$)", fontsize=25)
    plt.title(bold_text(r"Downstream Task AUC from Concepts (dSprites($\lambda$))"), fontsize=30)
    # Ticks sit at 0..len-1; the labels show the actual lambda values.
    plt.xticks(np.arange(0, len(all_vars)), fontsize=15)
    ax.set_xticklabels(all_vars, fontsize=15)
else:
    # Single setting: put the value of lambda in the title and hide x ticks.
    plt.title(bold_text(r"Downstream Task AUC from Concepts (dSprites($\lambda = " + str(all_vars[0]) + "$))"), fontsize=30)
    plt.xticks([], fontsize=15)
plt.yticks(fontsize=20)
ax.grid(False)
plt.show()

In [None]:
# Grouped bar chart of the downstream task accuracy obtained from each
# method's concept representation, one bar group per dependency level
# (lambda).
# Set up our figure
# One index per column of the third element of the dep_0 training tuple
# (presumably the concept matrix, so one lambda per concept -- confirm).
all_vars = np.arange(0, balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1])
num_concepts = 5

# Each entry: (legend label, results dict, key into that dict, transform
# applied to both halves of every stored pair before plotting).
# NOTE(review): the CCD entries read "direct_completeness_scores", the same
# key used in the AUC version of this figure -- confirm it is an accuracy.
all_models = [
    (
        "Joint-CBM",
        graph_dependency_balanced_multiclass_results,
        "latent_predictive_accuracies",
        lambda x: 100 * x,
    ),
    (
        "CW MaxPool-Mean",
        cw_binary_balanced_multiclass_results,
        "latent_predictive_accuracies",
        lambda x: 100 * x,
    ),
    (
        "CW Feature Map",
        cw_binary_balanced_multiclass_results,
        "latent_feature_predictive_accuracies",
        lambda x: 100 * x,
    ),
    (
        rf"Ada-MLVAE (n\_latent = {ada_mlvae_balanced_multilabel_experiment_config['latent_dim']})",
        ada_mlvae_balanced_multilabel_results,
        "latent_predictive_accuracies",
        lambda x: 100 * x,
    ),
    (
        rf"Ada-MLVAE (n\_latent = {ada_mlvae_balanced_multilabel_extended_experiment_config['latent_dim']})",
        ada_mlvae_balanced_multilabel_extended_results,
        "latent_predictive_accuracies",
        lambda x: 100 * x,
    ),
    (
        rf"Ada-GVAE (n\_latent = {ada_gvae_balanced_multilabel_experiment_config['latent_dim']})",
        ada_gvae_balanced_multilabel_results,
        "latent_predictive_accuracies",
        lambda x: 100 * x,
    ),
    (
        rf"Ada-GVAE (n\_latent = {ada_gvae_balanced_multilabel_extended_experiment_config['latent_dim']})",
        ada_gvae_balanced_multilabel_extended_results,
        "latent_predictive_accuracies",
        lambda x: 100 * x,
    ),
    (
        rf"CCD (n\_concepts = {ccd_balanced_multiclass_experiment_config['num_concepts']})",
        ccd_balanced_multiclass_results,
        "direct_completeness_scores",
        lambda x: 100 * x,
    ),
    (
        rf"CCD (n\_concepts = {ccd_balanced_multiclass_extended_experiment_config['num_concepts']})",
        ccd_balanced_multiclass_extended_results,
        "direct_completeness_scores",
        lambda x: 100 * x,
    ),
]

# One slot more than there are models: the unused last slot leaves a
# bar-width gap between neighbouring groups.
num_models = len(all_models) + 1

clrs = sns.color_palette("husl", num_models * 2)
scale = 1

fig, ax = plt.subplots(figsize=(num_models * 2, 5))
for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
    # Select the entries indexed by all_vars once; each entry is read as a
    # pair (presumably (mean, std) -- confirm against the results producers).
    selected = np.array(results[kword])[all_vars]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    ax.bar(
        # Group centred on the integer tick; model i takes the i-th slot.
        scale * (np.arange(0, len(all_vars)) - (1/2 - 1/(2 * num_models)) + i/num_models),
        accs_means,
        width=scale/num_models,
        color=clrs[i*2],
        align='center',
        label=method_name,
        yerr=2*accs_stds,  # error bars span two accs_stds on each side
        capsize=5,
        edgecolor="black",
    )

lgd = ax.legend(fontsize=20, loc='upper center', bbox_to_anchor=(0.5,-0.2), ncol=(num_models - 1)//2)

# Raw strings keep the LaTeX backslashes without relying on Python's
# (deprecated) pass-through of unrecognised escape sequences.
plt.ylabel(r"Accuracy (\%)", fontsize=25)
if len(all_vars) > 1:
    plt.xlabel(r"Number of Dependency Edges ($\lambda$)", fontsize=25)
    plt.title(bold_text(r"Downstream Task Accuracy from Concepts (dSprites($\lambda$))"), fontsize=30)
    plt.xticks(np.arange(0, len(all_vars)), fontsize=15)
    ax.set_xticklabels(all_vars, fontsize=15)
else:
    # Single setting: put the value of lambda in the title and hide x ticks.
    plt.title(bold_text(r"Downstream Task Accuracy from Concepts (dSprites($\lambda = " + str(all_vars[0]) + "$))"), fontsize=30)
    plt.xticks([], fontsize=15)
plt.yticks(fontsize=20)
ax.grid(False)
plt.show()

In [None]:
# Grouped bar chart of the downstream task accuracy obtained from each
# method's concept representation, restricted to lambda = 0 and lambda = 4.
# Set up our figure
all_vars = [0, 4]
num_concepts = 5

# Each entry: (legend label, results dict, key into that dict, transform
# applied to both halves of every stored pair before plotting).
# NOTE(review): the CCD entries read "direct_completeness_scores", the same
# key used in the AUC version of this figure -- confirm it is an accuracy.
all_models = [
    (
        "Joint-CBM",
        graph_dependency_balanced_multiclass_results,
        "latent_predictive_accuracies",
        lambda x: 100 * x,
    ),
    (
        "CW MaxPool-Mean",
        cw_binary_balanced_multiclass_results,
        "latent_predictive_accuracies",
        lambda x: 100 * x,
    ),
    (
        "CW Feature Map",
        cw_binary_balanced_multiclass_results,
        "latent_feature_predictive_accuracies",
        lambda x: 100 * x,
    ),
    (
        rf"Ada-MLVAE (n\_latent = {ada_mlvae_balanced_multilabel_experiment_config['latent_dim']})",
        ada_mlvae_balanced_multilabel_results,
        "latent_predictive_accuracies",
        lambda x: 100 * x,
    ),
    (
        rf"Ada-MLVAE (n\_latent = {ada_mlvae_balanced_multilabel_extended_experiment_config['latent_dim']})",
        ada_mlvae_balanced_multilabel_extended_results,
        "latent_predictive_accuracies",
        lambda x: 100 * x,
    ),
    (
        rf"Ada-GVAE (n\_latent = {ada_gvae_balanced_multilabel_experiment_config['latent_dim']})",
        ada_gvae_balanced_multilabel_results,
        "latent_predictive_accuracies",
        lambda x: 100 * x,
    ),
    (
        rf"Ada-GVAE (n\_latent = {ada_gvae_balanced_multilabel_extended_experiment_config['latent_dim']})",
        ada_gvae_balanced_multilabel_extended_results,
        "latent_predictive_accuracies",
        lambda x: 100 * x,
    ),
    (
        rf"CCD (n\_concepts = {ccd_balanced_multiclass_experiment_config['num_concepts']})",
        ccd_balanced_multiclass_results,
        "direct_completeness_scores",
        lambda x: 100 * x,
    ),
    (
        rf"CCD (n\_concepts = {ccd_balanced_multiclass_extended_experiment_config['num_concepts']})",
        ccd_balanced_multiclass_extended_results,
        "direct_completeness_scores",
        lambda x: 100 * x,
    ),
]

# One slot more than there are models: the unused last slot leaves a
# bar-width gap between neighbouring groups.
num_models = len(all_models) + 1

clrs = sns.color_palette("husl", num_models * 2)
scale = 1

fig, ax = plt.subplots(figsize=(num_models * 2, 5))
for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
    # Select the entries indexed by all_vars once; each entry is read as a
    # pair (presumably (mean, std) -- confirm against the results producers).
    selected = np.array(results[kword])[all_vars]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    ax.bar(
        # Group centred on the integer tick; model i takes the i-th slot.
        scale * (np.arange(0, len(all_vars)) - (1/2 - 1/(2 * num_models)) + i/num_models),
        accs_means,
        width=scale/num_models,
        color=clrs[i*2],
        align='center',
        label=method_name,
        yerr=2*accs_stds,  # error bars span two accs_stds on each side
        capsize=5,
        edgecolor="black",
    )

lgd = ax.legend(fontsize=20, loc='upper center', bbox_to_anchor=(0.5,-0.2), ncol=(num_models - 1)//2)

# Raw strings keep the LaTeX backslashes without relying on Python's
# (deprecated) pass-through of unrecognised escape sequences.
plt.ylabel(r"Accuracy (\%)", fontsize=25)
if len(all_vars) > 1:
    plt.xlabel(r"Number of Dependency Edges ($\lambda$)", fontsize=25)
    plt.title(bold_text(r"Downstream Task Accuracy from Concepts (dSprites($\lambda$))"), fontsize=30)
    # Ticks sit at 0..len-1; the labels show the actual lambda values.
    plt.xticks(np.arange(0, len(all_vars)), fontsize=15)
    ax.set_xticklabels(all_vars, fontsize=15)
else:
    # Single setting: put the value of lambda in the title and hide x ticks.
    plt.title(bold_text(r"Downstream Task Accuracy from Concepts (dSprites($\lambda = " + str(all_vars[0]) + "$))"), fontsize=30)
    plt.xticks([], fontsize=15)
plt.yticks(fontsize=20)
ax.grid(False)
plt.show()