In [None]:
%load_ext autoreload
%load_ext tensorboard
%matplotlib inline

# Purity dSprites Benchmarking

## Setup

In [None]:
import sys
import matplotlib
import concepts_xai
import numpy as np
import os
import random
import tensorflow as tf
import yaml
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import rc
from matplotlib import cm
import seaborn as sns
from importlib import reload
from pathlib import Path
import sklearn
from tensorflow.keras.models import load_model
from joblib import dump, load
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import RidgeClassifier
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.preprocessing import OneHotEncoder
from sklearn.neural_network import MLPClassifier
from sklearn.feature_selection import mutual_info_classif
from sklearn.metrics import accuracy_score, roc_auc_score
import utils
import model_utils


import concepts_xai.evaluation.metrics.niching as niching

In [None]:
################################################################################
## Set seeds up for reproducibility
################################################################################
# NOTE(review): later cells draw random value subsets with np.random at
# definition time, so they only reproduce if this cell runs first on a fresh
# kernel. Presumably utils.reseed seeds NumPy's global RNG (among others) --
# confirm in utils.
utils.reseed(87)


In [None]:
################################################################################
## Global Variables Defining Experiment Flow
################################################################################

# NOTE(review): presumably forwarded as `load_from_cache` to the experiment
# loops further down -- confirm at the call sites.
FROM_CACHE = True
# When set to "$", figure text is rendered through LaTeX (see `rc` call below
# and `bold_text`); any other value disables LaTeX rendering.
_LATEX_SYMBOL = "$"
BASE_DIR = '.'
RESULTS_DIR = os.path.join(BASE_DIR, "results/dsprites")
DATASETS_DIR = os.path.join(BASE_DIR, "results/dsprites", "datasets/")
NICHING_RESULTS_DIR = os.path.join(BASE_DIR, "results_concept_niching_integrated/dsprites")
# DATASETS_DIR lives inside RESULTS_DIR, so creating it with parents=True also
# creates RESULTS_DIR.
Path(DATASETS_DIR).mkdir(parents=True, exist_ok=True)
Path(NICHING_RESULTS_DIR).mkdir(parents=True, exist_ok=True)
rc('text', usetex=(_LATEX_SYMBOL == "$"))
# Matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-*" and
# 3.8 removed the old names, so use whichever spelling this install provides.
plt.style.use(
    'seaborn-whitegrid' if 'seaborn-whitegrid' in plt.style.available
    else 'seaborn-v0_8-whitegrid'
)

## Utility Functions

In [None]:
def bold_text(x):
    """Wrap `x` in a LaTeX bold command when LaTeX rendering is enabled.

    Returns `x` unchanged when `_LATEX_SYMBOL` is not "$".
    """
    latex_enabled = (_LATEX_SYMBOL == "$")
    return (r"$\textbf{" + x + "}$") if latex_enabled else x

# Dataset Construction

In [None]:
import concepts_xai.datasets.dSprites as dsprites
import concepts_xai.datasets.latentFactorData as latentFactorData

def generate_dsprites_dataset(
    label_fn,
    filter_fn=None,
    dataset_path=None,
    concept_map_fn=lambda x: x,
    sample_map_fn=lambda x: x,
    dsprites_path="dsprites/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz",
    force_reload=False,
):
    """Build (or load from disk) a dSprites-derived train/test dataset.

    Args:
        label_fn: function forwarded to `latentFactorData.get_task_data` to
            produce task labels.
        filter_fn: optional function forwarded to
            `latentFactorData.get_task_data` to select samples.
        dataset_path: optional `.npz` path. If it exists and `force_reload` is
            False the dataset is read from it; otherwise the freshly built
            dataset is written to it.
        concept_map_fn: applied to the train and test concept arrays before
            they are returned/serialized.
        sample_map_fn: applied to the train and test sample arrays before they
            are returned/serialized.
        dsprites_path: path of the raw dSprites archive.
        force_reload: when True, ignore any file at `dataset_path` and rebuild.

    Returns:
        `((x_train, y_train, c_train), (x_test, y_test, c_test))`.
    """
    if (not force_reload) and dataset_path and os.path.exists(dataset_path):
        # Reuse the serialized dataset. The context manager closes the
        # underlying archive once every array has been read into memory.
        with np.load(dataset_path) as ds:
            return (
                (ds["x_train"], ds["y_train"], ds["c_train"]),
                (ds["x_test"], ds["y_test"], ds["c_test"]),
            )

    def _task_fn(x_data, c_data):
        return latentFactorData.get_task_data(
            x_data=x_data,
            c_data=c_data,
            label_fn=label_fn,
            filter_fn=filter_fn,
        )

    loaded_dataset = dsprites.dSprites(
        dataset_path=dsprites_path,
        train_size=0.8,
        random_state=42,
        task=_task_fn,
    )
    # The return value is not needed: the splits are read from the attributes
    # of `loaded_dataset` below.
    _, _, _ = loaded_dataset.load_data()

    x_train = sample_map_fn(loaded_dataset.x_train)
    y_train = loaded_dataset.y_train
    c_train = concept_map_fn(loaded_dataset.c_train)

    x_test = sample_map_fn(loaded_dataset.x_test)
    y_test = loaded_dataset.y_test
    c_test = concept_map_fn(loaded_dataset.c_test)

    if dataset_path:
        # Serialize the result to speed things up next time, making sure the
        # destination directory exists first.
        parent_dir = os.path.dirname(dataset_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        np.savez(
            dataset_path,
            x_train=x_train,
            y_train=y_train,
            c_train=c_train,
            x_test=x_test,
            y_test=y_test,
            c_test=c_test,
        )
    return (x_train, y_train, c_train), (x_test, y_test, c_test)


In [None]:
def count_class_balance(y):
    """Return, per class, the fraction of entries of `y` with that label."""
    indicator = tf.keras.utils.to_categorical(y)
    num_samples = indicator.shape[0]
    return np.sum(indicator, axis=0) / num_samples

def multiclass_binary_concepts_map_fn(concepts):
    """Binarize the first five columns of `concepts`.

    Column 0 is thresholded at the fixed value 2. Columns 1-4 are thresholded
    at half the number of distinct values present in that column of
    `concepts` (true division for columns 1-2, floor division for 3-4).

    Args:
        concepts: 2-D array with at least five columns.

    Returns:
        Float array of shape `(concepts.shape[0], 5)` holding 0/1 entries.
    """
    new_concepts = np.zeros((concepts.shape[0], 5))
    # NOTE: the removed alias `np.int` (gone since NumPy 1.24) is replaced by
    # the builtin `int`, which is what the alias pointed to.

    # (0) "is it ellipse or square?"
    new_concepts[:, 0] = (concepts[:, 0] < 2).astype(int)

    # (1) "is size in the lower half of the observed sizes?"
    num_sizes = len(set(concepts[:, 1]))
    new_concepts[:, 1] = (concepts[:, 1] < num_sizes/2).astype(int)

    # (2) "is rotation in the lower half of the observed rotations?"
    num_rots = len(set(concepts[:, 2]))
    new_concepts[:, 2] = (concepts[:, 2] < num_rots/2).astype(int)

    # (3) "is x in the lower half of the observed x positions?"
    num_x_coords = len(set(concepts[:, 3]))
    new_concepts[:, 3] = (concepts[:, 3] < num_x_coords // 2).astype(int)

    # (4) "is y in the lower half of the observed y positions?"
    num_y_coords = len(set(concepts[:, 4]))
    new_concepts[:, 4] = (concepts[:, 4] < num_y_coords // 2).astype(int)

    return new_concepts

def _get_concept_vector(c_data):
    return np.array([
        # First check if it is an ellipse or a square
        int(c_data[0] < 2),
        # Now check that it is "small"
        int(c_data[1] < 3),
        # And it has not been rotated more than PI/2 radians
        int(c_data[2] < 20),
        # Finally, check whether it is in not in the the upper-left quadrant
        int(c_data[3] < 15),
        int(c_data[4] < 15),
    ])

def multiclass_task_label_fn(c_data):
    """Encode the binary concept vector of `c_data` as a label in 0..7.

    The label is the 3-bit number whose bits are, from most to least
    significant: (concept 0 OR concept 1), (concept 2 OR concept 3),
    concept 4.
    """
    shape_bit, size_bit, rotation_bit, x_bit, y_bit = _get_concept_vector(c_data)
    label_bits = (
        shape_bit or size_bit,
        rotation_bit or x_bit,
        y_bit,
    )
    label = 0
    for bit in label_bits:
        # Shift in one bit at a time, most significant first.
        label = 2 * label + int(bit)
    return label

def dep_0_filter_fn(concept):
    """Return True when each of the first five entries of `concept` is one of
    the allowed (sub-sampled) values for its position."""
    allowed_values = (
        tuple(range(3)),
        tuple(range(0, 6, 2)),
        tuple(range(0, 40, 4)),
        tuple(range(0, 32, 2)),
        tuple(range(0, 32, 2)),
    )
    membership = [
        concept[position] in allowed
        for position, allowed in enumerate(allowed_values)
    ]
    return all(membership)


# Randomly drawn "allowed value" subsets used to couple consecutive factors.
# Each `*_sets_lower[p]` / `*_sets_upper[p]` lists the child-factor values that
# are kept when the parent factor equals `p`; the lower list is used when the
# parent's binary concept is 1 and the upper list otherwise.
#
# The draws are kept in this exact order (scale|shape, rotation|scale,
# x|rotation, y|x; lower before upper) because they all consume the global
# NumPy RNG: reordering them would change the subsets obtained for a seed.
scale_shape_sets_lower = [
    list(np.random.permutation(4))[:3] for i in range(3)
]

scale_shape_sets_upper = [
    list(2 + np.random.permutation(4))[:3] for i in range(3)
]

rotation_scale_sets_lower = [
    list(np.random.permutation(30))[:20] for i in range(6)
]

rotation_scale_sets_upper = [
    list(10 + np.random.permutation(30))[:20] for i in range(6)
]

x_pos_rotation_sets_lower = [
    list(np.random.permutation(20))[:16]
    for i in range(40)
]

x_pos_rotation_sets_upper = [
    list(12 + np.random.permutation(20))[:16]
    for i in range(40)
]

y_pos_x_pos_sets_lower = [
    list(np.random.permutation(20))[:16]
    for i in range(32)
]

y_pos_x_pos_sets_upper = [
    list(12 + np.random.permutation(20))[:16]
    for i in range(32)
]

# One entry per dependency level:
# (parent factor index, child factor index, lower sets, upper sets).
_DEPENDENCY_CHAIN = (
    (0, 1, scale_shape_sets_lower, scale_shape_sets_upper),
    (1, 2, rotation_scale_sets_lower, rotation_scale_sets_upper),
    (2, 3, x_pos_rotation_sets_lower, x_pos_rotation_sets_upper),
    (3, 4, y_pos_x_pos_sets_lower, y_pos_x_pos_sets_upper),
)


def _passes_dependency_filter(concept, ranges, num_dependencies):
    """Shared implementation of the `dep_k_filter_fn` family.

    Args:
        concept: indexable of (integer) factor values.
        ranges: per-factor collections of values that are kept at all.
        num_dependencies: how many leading entries of `_DEPENDENCY_CHAIN`
            must be satisfied.

    Returns:
        True when the sample is kept, False otherwise.
    """
    # First restrict every factor to its sub-sampled range to constrain the
    # size of the data a bit.
    if not all([
        (concept[i] in ranges[i]) for i in range(len(ranges))
    ]):
        return False

    concept_vector = _get_concept_vector(concept)
    for parent, child, lower_sets, upper_sets in _DEPENDENCY_CHAIN[:num_dependencies]:
        allowed_sets = lower_sets if concept_vector[parent] else upper_sets
        # The child value must belong to the subset drawn for this parent
        # value. (The original upper branch of the first dependency tested
        # concept[0] against a set of scales instead of concept[1].)
        if concept[child] not in allowed_sets[concept[parent]]:
            return False
    return True


def dep_1_filter_fn(concept):
    """Keep samples satisfying the scale|shape dependency."""
    ranges = [
        range(3),
        range(6),
        range(0, 40, 4),
        range(0, 32, 2),
        range(0, 32, 2),
    ]
    return _passes_dependency_filter(concept, ranges, 1)


def dep_2_filter_fn(concept):
    """Keep samples satisfying the scale|shape and rotation|scale dependencies."""
    ranges = [
        range(3),
        range(6),
        range(0, 40, 2),
        range(0, 32, 2),
        range(0, 32, 2),
    ]
    return _passes_dependency_filter(concept, ranges, 2)


def dep_3_filter_fn(concept):
    """Keep samples satisfying the first three dependencies (up to x|rotation)."""
    ranges = [
        range(3),
        range(6),
        range(0, 40, 2),
        range(0, 32),
        range(0, 32, 2),
    ]
    return _passes_dependency_filter(concept, ranges, 3)


def dep_4_filter_fn(concept):
    """Keep samples satisfying all four dependencies (up to y|x)."""
    ranges = [
        range(3),
        range(6),
        range(0, 40, 2),
        range(0, 32),
        range(0, 32),
    ]
    return _passes_dependency_filter(concept, ranges, 4)

In [None]:
def balanced_multiclass_task_label_fn(c_data):
    """Map the binary concept vector of `c_data` to a label in 0..7.

    When concept 0 is 1 the label is the 2-bit number formed by concepts
    (1, 2), giving 0..3. Otherwise it is 4 plus the 2-bit number formed by
    concepts (3, 4), giving 4..7.
    """
    concept_vector = _get_concept_vector(c_data)
    if concept_vector[0] == 1:
        offset, high_bit, low_bit = 0, concept_vector[1], concept_vector[2]
    else:
        offset, high_bit, low_bit = 4, concept_vector[3], concept_vector[4]
    return offset + 2 * int(high_bit) + int(low_bit)

def _concept_label_correlations(concepts, labels):
    """Correlation between every concept column and every label indicator.

    Entry `[c, l]` is the Pearson correlation (via `np.corrcoef`) between
    column `c` of `concepts` and the 0/1 indicator `labels == l`, for
    `l` in `range(number of distinct labels)`.
    """
    corr_mat = np.ones((concepts.shape[-1], len(set(labels))))
    for c in range(corr_mat.shape[0]):
        for l in range(corr_mat.shape[1]):
            corr_mat[c][l] = np.corrcoef(
                concepts[:, c],
                (labels == l).astype(np.int32),
            )[0, 1]
    return corr_mat


def _build_and_report_dep_dataset(dep_level, filter_fn):
    """Generate one dependency-level dataset, plot and summarize it.

    Args:
        dep_level: integer used in the dataset file name, figure title and
            printed summary.
        filter_fn: sample filter forwarded to `generate_dsprites_dataset`.

    Returns:
        `(train, test, corr_mat)` where `train`/`test` are the `(x, y, c)`
        tuples from `generate_dsprites_dataset` and `corr_mat` is the signed
        concept-label correlation matrix of the training split.
    """
    dataset_name = f"balanced_multiclass_task_bin_concepts_dep_{dep_level}_complete_dataset"
    train, test = generate_dsprites_dataset(
        label_fn=balanced_multiclass_task_label_fn,
        filter_fn=filter_fn,
        dataset_path=os.path.join(DATASETS_DIR, dataset_name + ".npz"),
        dsprites_path=os.path.join(BASE_DIR, "dsprites/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz"),
        concept_map_fn=multiclass_binary_concepts_map_fn,
        # Always rebuild: the filters depend on random subsets drawn in this
        # session, so a file from a previous run may not match them.
        force_reload=True,
    )

    corr_mat = _concept_label_correlations(train[2], train[1])

    fig, ax = plt.subplots(1, figsize=(8, 6))
    im, cbar = utils.heatmap(
        np.abs(corr_mat),
        [f"$c_{i}$" for i in range(corr_mat.shape[0])],
        [f"$l_{i}$" for i in range(corr_mat.shape[1])],
        ax=ax,
        cmap="magma",
        cbarlabel="Correlation Coef",
        vmin=0,
        vmax=1,
    )
    utils.annotate_heatmap(im, valfmt="{x:.2f}")
    fig.tight_layout()

    # Raw f-string so that "\l" reaches LaTeX untouched instead of being an
    # invalid Python escape sequence.
    fig.suptitle(
        rf"dSprites Concept-Label Absolute Correlations ($\lambda = {dep_level}$)",
        fontsize=25,
    )
    fig.subplots_adjust(top=0.85)
    plt.show()

    print(f"{dataset_name} train size:", train[0].shape[0])
    print(f"{dataset_name} train concept size:", train[2].shape)
    print("\tTrain balance:", count_class_balance(train[1]))
    print("\tConcept balance:", np.sum(train[2], axis=0)/train[2].shape[0])
    return train, test, corr_mat


(
    balanced_multiclass_task_bin_concepts_dep_0_complete_train,
    balanced_multiclass_task_bin_concepts_dep_0_complete_test,
    dep_0_corr_mat,
) = _build_and_report_dep_dataset(0, dep_0_filter_fn)
print("")

(
    balanced_multiclass_task_bin_concepts_dep_1_complete_train,
    balanced_multiclass_task_bin_concepts_dep_1_complete_test,
    dep_1_corr_mat,
) = _build_and_report_dep_dataset(1, dep_1_filter_fn)
print("")

(
    balanced_multiclass_task_bin_concepts_dep_2_complete_train,
    balanced_multiclass_task_bin_concepts_dep_2_complete_test,
    dep_2_corr_mat,
) = _build_and_report_dep_dataset(2, dep_2_filter_fn)
print("")

(
    balanced_multiclass_task_bin_concepts_dep_3_complete_train,
    balanced_multiclass_task_bin_concepts_dep_3_complete_test,
    dep_3_corr_mat,
) = _build_and_report_dep_dataset(3, dep_3_filter_fn)
print("")

(
    balanced_multiclass_task_bin_concepts_dep_4_complete_train,
    balanced_multiclass_task_bin_concepts_dep_4_complete_test,
    dep_4_corr_mat,
) = _build_and_report_dep_dataset(4, dep_4_filter_fn)

In [None]:
def balanced_niching_dataset(x, y, c):
    """Resample `(x, y, c)` so every class has the same number of rows.

    Args:
        x: samples, indexable along axis 0 by an integer array.
        y: 1-D array-like of class labels (one per sample).
        c: concept annotations, indexable along axis 0 by an integer array.

    Returns:
        `(x_balanced, y_balanced, c_balanced)` where `y_balanced` is the
        one-hot (float) encoding of the labels, with one column per distinct
        label in sorted order. Rows are grouped by class, and every class
        contributes as many rows as the smallest class has samples.
    """
    # One-hot encode with NumPy: equivalent to a dense OneHotEncoder, but
    # independent of the scikit-learn `sparse` -> `sparse_output` rename.
    classes, class_index = np.unique(
        np.asarray(y).reshape(-1),
        return_inverse=True,
    )
    y = np.eye(classes.shape[0])[class_index.reshape(-1)]

    n_samples_per_class = y.sum(axis=0, dtype=int)
    n_samples_to_draw = int(n_samples_per_class.min())
    samples_per_class = []
    for i in range(y.shape[1]):
        # flatnonzero always yields a 1-D index array, so this also works
        # when only a single sample is drawn per class.
        members = np.flatnonzero(y[:, i] == 1)
        # NOTE(review): np.random.choice samples WITH replacement by default,
        # so a class can contribute duplicated rows. Kept as in the original;
        # pass replace=False if strict subsampling is intended.
        picks = np.random.choice(n_samples_per_class[i], n_samples_to_draw)
        samples_per_class.append(members[picks])
    samples_per_class = np.concatenate(samples_per_class)
    return x[samples_per_class], y[samples_per_class], c[samples_per_class]

In [None]:
# Class-balanced training splits, one per dependency level. The splits are
# resampled in order (0..4) so the global RNG is consumed exactly as before.
train_0, train_1, train_2, train_3, train_4 = [
    balanced_niching_dataset(*split)
    for split in (
        balanced_multiclass_task_bin_concepts_dep_0_complete_train,
        balanced_multiclass_task_bin_concepts_dep_1_complete_train,
        balanced_multiclass_task_bin_concepts_dep_2_complete_train,
        balanced_multiclass_task_bin_concepts_dep_3_complete_train,
        balanced_multiclass_task_bin_concepts_dep_4_complete_train,
    )
]

# Class-balanced test splits, resampled after all the training splits.
test_0, test_1, test_2, test_3, test_4 = [
    balanced_niching_dataset(*split)
    for split in (
        balanced_multiclass_task_bin_concepts_dep_0_complete_test,
        balanced_multiclass_task_bin_concepts_dep_1_complete_test,
        balanced_multiclass_task_bin_concepts_dep_2_complete_test,
        balanced_multiclass_task_bin_concepts_dep_3_complete_test,
        balanced_multiclass_task_bin_concepts_dep_4_complete_test,
    )
]

## Model Construction

In [None]:
# Construct the encoder model
def _extract_concepts(activations, concept_cardinality):
    concepts = []
    total_seen = 0
    if all(np.array(concept_cardinality) <= 1):
        # Then nothing to do here as they are all binary concepts
        return activations
    for num_values in concept_cardinality:
        concepts.append(activations[:, total_seen: total_seen + num_values])
        total_seen += num_values
    return concepts
    
def construct_encoder(
    input_shape,
    filter_groups,
    units,
    concept_cardinality,
    drop_prob=0.5,
    max_pool_window=(2,2),
    max_pool_stride=2,
    latent_dims=0,
    output_logits=False,
):
    """Build the convolutional input-to-concepts Keras model.

    Args:
        input_shape: shape passed to `tf.keras.Input`.
        filter_groups: iterable of groups, each an iterable of
            `(num_filters, kernel_size)` pairs. Every pair becomes a
            Conv2D -> BatchNormalization -> ReLU stack and every group is
            followed by a max-pool.
        units: sizes of the ReLU dense layers applied after flattening.
        concept_cardinality: number of output entries per concept; their sum
            is the width of the concept output layer.
        drop_prob: dropout rate applied after flattening (skipped if falsy).
        max_pool_window: pool size of the max-pool layers.
        max_pool_stride: stride of the max-pool layers.
        latent_dims: if truthy, width of an extra sigmoid "bypass" output.
        output_logits: if True, concept outputs are left as raw logits.

    Returns:
        A `tf.keras.Model` named "encoder" whose output is the concept
        output(s), or `[concept_outputs, bypass]` when a bypass is requested.
    """
    layers = tf.keras.layers
    activations = tf.keras.activations

    encoder_inputs = tf.keras.Input(shape=input_shape)
    hidden = encoder_inputs

    # Convolutional trunk: one max-pool at the end of each filter group keeps
    # the parameter count of the model under control.
    conv_index = 0
    for filter_group in filter_groups:
        for num_filters, kernel_size in filter_group:
            hidden = layers.Conv2D(
                filters=num_filters,
                kernel_size=kernel_size,
                padding="SAME",
                activation=None,
                name=f'encoder_conv_{conv_index}',
            )(hidden)
            conv_index += 1
            hidden = layers.BatchNormalization()(hidden)
            hidden = activations.relu(hidden)
        hidden = layers.MaxPooling2D(
            pool_size=max_pool_window,
            strides=max_pool_stride,
        )(hidden)

    hidden = layers.Flatten()(hidden)

    # Optional dropout right after flattening.
    if drop_prob:
        hidden = layers.Dropout(drop_prob)(hidden)

    # Fully connected bottleneck.
    for dense_index, num_units in enumerate(units):
        hidden = layers.Dense(
            num_units,
            activation='relu',
            name=f"encoder_dense_{dense_index}",
        )(hidden)

    # Optional bypass channel; created before the concept head.
    bypass = None
    if latent_dims:
        bypass = layers.Dense(
            latent_dims,
            activation="sigmoid",
            name="encoder_bypass_channel",
        )(hidden)

    # Single flat vector of concept logits, then split into per-concept heads.
    concept_logits = layers.Dense(
        sum(concept_cardinality),
        activation=None,
        name="encoder_concept_outputs",
    )(hidden)
    concept_outputs = _extract_concepts(concept_logits, concept_cardinality)

    if not output_logits:
        if isinstance(concept_outputs, list):
            # Width-1 heads are binary concepts (sigmoid); wider heads are
            # treated as a distribution over concept values (softmax).
            concept_outputs = [
                activations.sigmoid(head) if head.shape[-1] == 1
                else activations.softmax(head, axis=-1)
                for head in concept_outputs
            ]
        else:
            # Every concept is binary, so sigmoid the whole vector at once.
            concept_outputs = activations.sigmoid(concept_outputs)

    model_outputs = (
        concept_outputs if bypass is None else [concept_outputs, bypass]
    )
    return tf.keras.Model(
        encoder_inputs,
        model_outputs,
        name="encoder",
    )

In [None]:
############################################################################
## Build concepts-to-labels model
############################################################################

def construct_decoder(units, num_outputs):
    """Build the concepts-to-labels model.

    Args:
        units: sizes of the hidden ReLU dense layers.
        num_outputs: number of task classes. The output layer emits raw
            logits: a single unit when `num_outputs <= 2`, otherwise
            `num_outputs` units.

    Returns:
        A `tf.keras.Sequential` model.
    """
    hidden_layers = []
    for position, width in enumerate(units, start=1):
        hidden_layers.append(
            tf.keras.layers.Dense(
                width,
                activation=tf.nn.relu,
                name=f"decoder_dense_{position}",
            )
        )
    output_width = 1 if num_outputs <= 2 else num_outputs
    output_layer = tf.keras.layers.Dense(
        output_width,
        activation=None,
        name="decoder_model_output",
    )
    return tf.keras.Sequential(hidden_layers + [output_layer])

In [None]:
# Construct the complete model
def construct_end_to_end_model(
    input_shape,
    encoder,
    decoder,
    num_outputs,
    learning_rate=1e-3,
):
    """Chain `encoder` and `decoder` into one compiled Keras model.

    Args:
        input_shape: shape passed to `tf.keras.Input`.
        encoder: callable model mapping inputs to a tensor or list of tensors.
        decoder: callable model mapping the (concatenated) encoder output to
            task logits.
        num_outputs: number of task classes; `<= 2` selects the binary loss
            and metric, otherwise the sparse categorical ones.
        learning_rate: learning rate of the Adam optimizer.

    Returns:
        `(model, encoder, decoder)`.
    """
    model_inputs = tf.keras.Input(shape=input_shape)
    latent = encoder(model_inputs)

    # Collapse a multi-head encoder output back into a single vector.
    if not isinstance(latent, list):
        compacted_vector = latent
    elif len(latent) > 1:
        compacted_vector = tf.keras.layers.Concatenate(axis=-1)(latent)
    else:
        compacted_vector = latent[0]

    model = tf.keras.Model(
        model_inputs,
        decoder(compacted_vector),
        name="complete_model",
    )

    if num_outputs <= 2:
        task_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
        task_metric = "binary_accuracy"
    else:
        task_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        task_metric = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1)

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate),
        loss=task_loss,
        metrics=[task_metric],
    )
    return model, encoder, decoder

# CBM Benchmarking

## Model Definition

In [None]:
############################################################################
## Build CBM
############################################################################
import concepts_xai.methods.CBM.CBModel as CBM

def construct_cbm(
    encoder,
    decoder,
    num_outputs,
    alpha=0.1,
    learning_rate=1e-3,
    latent_dims=0,
    encoder_output_logits=False,
):
    """Wrap `encoder` and `decoder` in a compiled joint concept bottleneck model.

    Args:
        encoder: input-to-concepts model.
        decoder: concepts-to-labels model.
        num_outputs: number of task classes; `<= 2` selects the binary loss
            and metric, otherwise the sparse categorical ones.
        alpha: forwarded to the CBM constructor as `alpha`.
        learning_rate: learning rate of the Adam optimizer.
        latent_dims: if truthy, `CBM.BypassJointCBM` is used instead of
            `CBM.JointConceptBottleneckModel`.
        encoder_output_logits: forwarded as `pass_concept_logits`.

    Returns:
        The compiled CBM model.
    """
    if latent_dims:
        model_factory = CBM.BypassJointCBM
    else:
        model_factory = CBM.JointConceptBottleneckModel

    is_binary_task = (num_outputs <= 2)
    if is_binary_task:
        task_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
        task_metric = tf.keras.metrics.BinaryAccuracy()
    else:
        task_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        task_metric = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1)

    cbm_model = model_factory(
        encoder=encoder,
        decoder=decoder,
        task_loss=task_loss,
        name="joint_cbm",
        metrics=[task_metric],
        alpha=alpha,
        pass_concept_logits=encoder_output_logits,
    )

    # Only the optimizer is given here; the task loss and metrics were handed
    # to the CBM constructor above.
    cbm_model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate),
    )
    return cbm_model


## Experiment Loop

In [None]:
import concepts_xai.evaluation.metrics.purity as purity
from collections import defaultdict

############################################################################
## Experiment loop
############################################################################

def construct_trivial_auc_mat(num_concepts):
    """Return a float32 `(num_concepts, num_concepts)` matrix with 1.0 on the
    diagonal and 0.5 everywhere else."""
    auc_mat = np.full((num_concepts, num_concepts), 0.5, dtype=np.float32)
    np.fill_diagonal(auc_mat, 1.0)
    return auc_mat
    
def cbm_experiment_loop(
    experiment_config,
    datasets,
    load_from_cache=False,
    oracle_matrix_cache=None,
):
    """Compute niche impurity scores (NIS) for previously saved CBM encoders.

    For every dataset and trial an end-to-end model and a CBM are constructed,
    the CBM's encoder is then replaced by the one stored under
    ``<results_dir>/models/encoder_<ds_name>_trial_<trial>`` and the NIS of its
    concept predictions is appended to the results. No training happens here.
    Results are dumped to ``<niching_results_dir>/results_niching.joblib``
    after every trial.

    Args:
        experiment_config: dict of experiment settings (architecture,
            ``trials``, ``alpha``, ``results_dir``, ``niching_results_dir``...).
        datasets: sliceable sequence of
            ``(name, (x_train, y_train, c_train), (x_test, y_test, c_test))``.
        load_from_cache: when True and the results file already exists, the
            cached object is returned without computing anything.
        oracle_matrix_cache: accepted for interface compatibility; not used.

    Returns:
        dict with keys ``"config"`` and ``"niss"`` (one NIS per dataset/trial),
        or whatever object was stored in the cache file.
    """
    utils.reseed(87)
    experiment_variables = dict(config=experiment_config, niss=[])
    res_dir = experiment_config['niching_results_dir']
    results_file = os.path.join(res_dir, 'results_niching.joblib')
    if load_from_cache and os.path.exists(results_file):
        return load(results_file)

    # No usable cache: go ahead and run the whole thing.
    models_dir = os.path.join(experiment_config["results_dir"], "models")
    Path(models_dir).mkdir(parents=True, exist_ok=True)

    first_dataset_idx = 0
    for dataset in datasets[first_dataset_idx:]:
        ds_name, (x_train, y_train, c_train), (x_test, y_test, c_test) = dataset
        print("Training with dataset:", ds_name)
        for trial in range(experiment_config["trials"]):
            print(f"\tTrial {trial + 1}/{experiment_config['trials']} of dataset {ds_name}")

            # Build an end-to-end model first; it provides the encoder and
            # decoder that the CBM is assembled from.
            _end_to_end_model, encoder, decoder = construct_end_to_end_model(
                input_shape=experiment_config["input_shape"],
                num_outputs=experiment_config["num_outputs"],
                learning_rate=experiment_config["learning_rate"],
                encoder=construct_encoder(
                    input_shape=experiment_config["input_shape"],
                    filter_groups=experiment_config["encoder_filter_groups"],
                    units=experiment_config["encoder_units"],
                    concept_cardinality=experiment_config["concept_cardinality"],
                    drop_prob=experiment_config.get("drop_prob", 0.5),
                    max_pool_window=experiment_config.get("max_pool_window", (2, 2)),
                    max_pool_stride=experiment_config.get("max_pool_stride", (2, 2)),
                    latent_dims=experiment_config.get("latent_dims", 0),
                    output_logits=experiment_config.get("encoder_output_logits", False),
                ),
                decoder=construct_decoder(
                    units=experiment_config["decoder_units"],
                    num_outputs=experiment_config["num_outputs"],
                ),
            )

            # Assemble the CBM around that encoder/decoder pair.
            cbm_model = construct_cbm(
                encoder=encoder,
                decoder=decoder,
                alpha=experiment_config["alpha"],
                learning_rate=experiment_config["learning_rate"],
                num_outputs=experiment_config["num_outputs"],
                latent_dims=experiment_config.get("latent_dims", 0),
                encoder_output_logits=experiment_config.get("encoder_output_logits", False),
            )

            print("\t\tCBM training completed")
            print("\tSerializing model")
            # The encoder that is actually evaluated is the one saved on disk.
            saved_encoder_path = os.path.join(
                experiment_config["results_dir"],
                f"models/encoder_{ds_name}_trial_{trial}"
            )
            cbm_model.encoder = load_model(saved_encoder_path)

            print("\t\tEncode...")
            concept_train_list = cbm_model.encoder(x_train)
            concept_test_list = cbm_model.encoder(x_test)

            # Keep column 1 of every per-concept output and stack the columns
            # into one (samples, concepts) array.
            # NOTE(review): presumably column 1 is the "concept on" score of a
            # two-column output -- confirm against construct_encoder.
            c_train_pred = tf.stack(
                [concept[:, 1] for concept in concept_train_list],
                axis=1,
            ).numpy()

            print("\t\tListed...")

            c_test_pred = tf.stack(
                [concept[:, 1] for concept in concept_test_list],
                axis=1,
            ).numpy()

            nis = niching.niche_impurity_score(
                c_soft=c_test_pred,
                c_true=c_test,
                c_soft_train=c_train_pred,
                c_true_train=c_train,
            )
            experiment_variables['niss'].append(nis)
            print(f'\t\tNIS: {nis:.2f}')

            # Checkpoint the results after every trial.
            os.makedirs(res_dir, exist_ok=True)
            dump(experiment_variables, results_file)

    return experiment_variables


### Experiments

In [None]:
reload(niching)
reload(CBM)

############################################################################
## Experiment config
############################################################################

cbm_base_experiment_config = dict(
    batch_size=64,
    max_epochs=100,
    warmup_epochs=0,
    pre_train_epochs=0,
    trials=5,
    alpha=10,
    learning_rate=1e-3,
    encoder_filter_groups=[
        [(8, (7, 7))],
        [(16, (5, 5))],
        [(32, (3, 3))],
        [(64, (3, 3))]
    ],
    encoder_units=[64, 64],
    decoder_units=[64, 64],
    drop_prob=0.5,
    max_pool_window=(2,2),
    # cbm_experiment_loop looks this key up as "max_pool_stride"; it used to be
    # misspelled ("pax_pool_stride"), so the value was silently ignored.
    max_pool_stride=2,
    # A task with at most two distinct labels is modelled with one output.
    num_outputs=(
        len(set(multiclass_task_bin_concepts_dep_0_complete_train[1]))
        if len(set(multiclass_task_bin_concepts_dep_0_complete_train[1])) > 2 else 1
    ),
    # Every concept is binary.
    concept_cardinality=[
        2 for _ in range(multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1])
    ],
    num_concepts=multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    patience=float("inf"),
    min_delta=1e-5,
    results_dir=os.path.join(RESULTS_DIR, "cbm/graph_dependency_balanced_multiclass_tasks_purity"),
    niching_results_dir=os.path.join(NICHING_RESULTS_DIR, "cbm/base"),
    input_shape=multiclass_task_bin_concepts_dep_0_complete_train[0].shape[1:],
    latent_dims=0,
    holdout_fraction=0.1,
    verbosity=0,
    early_stop_metric="val_concept_accuracy",
    early_stop_mode="max",
    delta_beta=0.05,
    encoder_output_logits=False,
)

# Generate the experiment directory if it does not exist already
Path(cbm_base_experiment_config["results_dir"]).mkdir(parents=True, exist_ok=True)
cbm_base_figure_dir = os.path.join(cbm_base_experiment_config["results_dir"], "figures")
Path(cbm_base_figure_dir).mkdir(parents=True, exist_ok=True)

############################################################################
## Experiment run
############################################################################

cbm_base_results = cbm_experiment_loop(
    cbm_base_experiment_config,
    load_from_cache=FROM_CACHE,
    datasets=[
        ("balanced_multiclass_task_bin_concepts_dep_0_complete", train_0, test_0),
        ("balanced_multiclass_task_bin_concepts_dep_1_complete", train_1, test_1),
        ("balanced_multiclass_task_bin_concepts_dep_2_complete", train_2, test_2),
        ("balanced_multiclass_task_bin_concepts_dep_3_complete", train_3, test_3),
        ("balanced_multiclass_task_bin_concepts_dep_4_complete", train_4, test_4),
    ],
)

In [None]:
reload(niching)
reload(CBM)

############################################################################
## Experiment config
############################################################################

cbm_from_logits_experiment_config = dict(
    batch_size=64,
    max_epochs=100,
    warmup_epochs=0,
    pre_train_epochs=0,
    trials=5,
    alpha=10,
    learning_rate=1e-3,
    encoder_filter_groups=[
        [(8, (7, 7))],
        [(16, (5, 5))],
        [(32, (3, 3))],
        [(64, (3, 3))]
    ],
    encoder_units=[64, 64],
    decoder_units=[64, 64],
    drop_prob=0.5,
    max_pool_window=(2,2),
    # cbm_experiment_loop looks this key up as "max_pool_stride"; it used to be
    # misspelled ("pax_pool_stride"), so the value was silently ignored.
    max_pool_stride=2,
    # A task with at most two distinct labels is modelled with one output.
    num_outputs=(
        len(set(multiclass_task_bin_concepts_dep_0_complete_train[1]))
        if len(set(multiclass_task_bin_concepts_dep_0_complete_train[1])) > 2 else 1
    ),
    # Every concept is binary.
    concept_cardinality=[
        2 for _ in range(multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1])
    ],
    num_concepts=multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    patience=float("inf"),
    min_delta=1e-5,
    results_dir=os.path.join(RESULTS_DIR, "cbm/graph_dependency_balanced_multiclass_from_logits_tasks_purity"),
    niching_results_dir=os.path.join(NICHING_RESULTS_DIR, "cbm/from_logits"),
    input_shape=multiclass_task_bin_concepts_dep_0_complete_train[0].shape[1:],
    latent_dims=0,
    holdout_fraction=0.1,
    verbosity=0,
    early_stop_metric="val_concept_accuracy",
    early_stop_mode="max",
    delta_beta=0.05,
    # The only functional difference from the base CBM config: the encoder
    # hands logits to the rest of the CBM.
    encoder_output_logits=True,
)

############################################################################
## Experiment run
############################################################################

cbm_from_logits_results = cbm_experiment_loop(
    cbm_from_logits_experiment_config,
    load_from_cache=FROM_CACHE,
    datasets=[
        ("balanced_multiclass_task_bin_concepts_dep_0_complete", train_0, test_0),
        ("balanced_multiclass_task_bin_concepts_dep_1_complete", train_1, test_1),
        ("balanced_multiclass_task_bin_concepts_dep_2_complete", train_2, test_2),
        ("balanced_multiclass_task_bin_concepts_dep_3_complete", train_3, test_3),
        ("balanced_multiclass_task_bin_concepts_dep_4_complete", train_4, test_4),
    ],
)

# CW Benchmark

In [None]:
import concepts_xai.methods.CW.CWLayer as CW

def conv_predictor_model_fn(
    input_concept_classes=1,
    output_concept_classes=2,
):
    """Build and compile a small convolutional concept predictor.

    The network is two 3x3 ReLU convolutions (16 then 32 filters,
    channels-last), a flatten and a linear dense head that emits logits.

    NOTE(review): an identical definition of this function appears again in
    the following cell of the notebook; one of the two can be removed.

    Args:
        input_concept_classes: not used by this function.
        output_concept_classes: number of classes of the predicted concept.
            More than two classes gives one logit per class and a sparse
            categorical loss; otherwise a single logit with a binary loss.

    Returns:
        A compiled ``tf.keras.models.Sequential`` model.
    """
    is_multiclass = output_concept_classes > 2

    def _conv_relu(num_filters):
        # 3x3 "SAME" convolution followed by a ReLU.
        return tf.keras.layers.Conv2D(
            filters=num_filters,
            kernel_size=(3,3),
            padding="SAME",
            activation="relu",
            data_format="channels_last",
        )

    layers = [
        _conv_relu(16),
        _conv_relu(32),
        tf.keras.layers.Flatten(),
        # No activation here: it is folded into the loss for numerical
        # stability.
        tf.keras.layers.Dense(
            output_concept_classes if is_multiclass else 1,
            activation=None,
        ),
    ]
    estimator = tf.keras.models.Sequential(layers)

    # Labels are expected as class indices (no one-hot encoding) when the
    # concept is categorical.
    if is_multiclass:
        loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    else:
        loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    estimator.compile(optimizer='adam', loss=loss)
    return estimator



def construct_cw_model(
    input_shape,
    num_outputs,
    filter_groups,
    units,
    activation_mode,
    drop_prob=0.0,
    max_pool_window=(2,2),
    max_pool_stride=2,
    T=5,
    eps=1e-5,
    momentum=0.9,
    c1=1e-4,
    c2=0.9,
    max_tau_iterations=500,
    initial_tau=1000.0,
    initial_beta=1e8,
    initial_alpha=0,
    learning_rate=1e-3,
):
    """Build a convolutional classifier with concept-whitening (CW) layers.

    Each filter group is a list of ``(num_filters, kernel_size)`` or
    ``(num_filters, kernel_size, use_cw)`` tuples. Every entry adds a linear
    convolution followed by either a ``CW.ConceptWhiteningLayer`` (when
    ``use_cw`` is truthy) or batch normalization, and then a ReLU. A max pool
    closes every group. The convolutional stack is followed by a flatten,
    optional dropout, ReLU dense layers and a linear ``"logits"`` layer.

    Args:
        input_shape: shape of one input sample (without the batch axis).
        num_outputs: number of task logits. Values ``<= 2`` compile the model
            with binary cross-entropy, larger values with sparse categorical
            cross-entropy.
        filter_groups: list of filter groups as described above.
        units: sizes of the dense layers that precede the logits layer.
        activation_mode: forwarded to every CW layer.
        drop_prob: dropout rate applied after the flatten; falsy disables it.
        max_pool_window: pool size of the max pool closing each group.
        max_pool_stride: stride of that max pool.
        T, eps, momentum, c1, c2, max_tau_iterations, initial_tau,
            initial_beta, initial_alpha: forwarded unchanged to every
            ``CW.ConceptWhiteningLayer``.
        learning_rate: learning rate of the Adam optimizer (default ``1e-3``,
            the value that was previously hard-coded).

    Returns:
        Tuple ``(cw_model, encoder, cw_output_model)`` where ``cw_model`` is
        the compiled end-to-end classifier, ``encoder`` maps the inputs to the
        tensors that are fed into the CW layers and ``cw_output_model`` maps
        the inputs to the outputs of the CW layers.
    """
    model_inputs = tf.keras.Input(shape=input_shape)
    model_compute_graph = model_inputs

    # Start with our convolutions
    num_convs = 0
    cw_inputs = []
    cw_outputs = []
    for filter_group in filter_groups:
        # Entries that do not say whether they want a CW layer default to
        # "no CW layer".
        filter_group = [
            entry if len(entry) == 3 else (entry[0], entry[1], False)
            for entry in filter_group
        ]
        for (num_filters, kernel_size, cw_layer) in filter_group:
            model_compute_graph = tf.keras.layers.Conv2D(
                filters=num_filters,
                kernel_size=kernel_size,
                padding="SAME",
                activation=None,
                name=f'encoder_conv_{num_convs}',
            )(model_compute_graph)
            num_convs += 1
            if cw_layer:
                # Remember both sides of the CW layer so that dedicated
                # models exposing them can be built below.
                cw_inputs.append(model_compute_graph)
                model_compute_graph = CW.ConceptWhiteningLayer(
                    data_format="channels_last",
                    activation_mode=activation_mode,
                    T=T,
                    eps=eps,
                    momentum=momentum,
                    c1=c1,
                    c2=c2,
                    max_tau_iterations=max_tau_iterations,
                    initial_tau=initial_tau,
                    initial_beta=initial_beta,
                    initial_alpha=initial_alpha,
                )(
                    model_compute_graph
                )
                cw_outputs.append(model_compute_graph)
            else:
                model_compute_graph = tf.keras.layers.BatchNormalization(
                    axis=-1,
                )(model_compute_graph)
            # Kept as a functional activation call: the position of each layer
            # in ``model.layers`` is relied upon by the experiment loop.
            model_compute_graph = tf.keras.activations.relu(model_compute_graph)
        # Then do a max pool here to control the parameter count of the model
        # at the end of each group
        model_compute_graph = tf.keras.layers.MaxPooling2D(
            pool_size=max_pool_window,
            strides=max_pool_stride,
        )(
            model_compute_graph
        )

    # Flatten the feature maps
    model_compute_graph = tf.keras.layers.Flatten()(model_compute_graph)

    # Add a dropout if requested
    if drop_prob:
        model_compute_graph = tf.keras.layers.Dropout(drop_prob)(
            model_compute_graph
        )

    # Finally, include the fully connected bottleneck here. The loop variable
    # no longer shadows the ``units`` argument.
    for i, num_units in enumerate(units):
        model_compute_graph = tf.keras.layers.Dense(
            num_units,
            activation='relu',
            name=f"encoder_dense_{i}",
        )(model_compute_graph)

    # Linear task head: the activation is folded into the loss.
    model_compute_graph = tf.keras.layers.Dense(
        num_outputs,
        activation=None,
        name="logits",
    )(model_compute_graph)

    cw_model = tf.keras.Model(
        model_inputs,
        model_compute_graph,
        name="cw_model",
    )
    cw_model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate),
        loss=(
            tf.keras.losses.BinaryCrossentropy(from_logits=True) if (num_outputs <= 2)
            else tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        ),
        metrics=[
            "binary_accuracy" if (num_outputs <= 2)
            else tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1)
        ],
    )

    # Inputs -> tensors entering the CW layers.
    encoder = tf.keras.Model(
        model_inputs,
        cw_inputs,
        name="encoder_model",
    )

    # Inputs -> tensors leaving the CW layers.
    cw_output_model = tf.keras.Model(
        model_inputs,
        cw_outputs,
        name="cw_output_model",
    )
    return cw_model, encoder, cw_output_model

In [None]:
from scipy.special import softmax

def channels_corr_mat(outputs):
    """Correlation matrix between the channels of a batch of activations.

    Every channel is standardized (zero mean, unit standard deviation) over
    all samples and spatial positions; the result is the matrix of average
    products of the standardized channels.

    The input is never modified. The previous implementation centered the
    data with ``-=`` on what is a *view* of the caller's array for
    C-contiguous inputs, which overwrote the caller's data and raised a
    casting error for integer inputs.

    Args:
        outputs: array-like of shape ``(N, H, W, C)`` (channels last) or
            ``(N, C)``.

    Returns:
        ``np.ndarray`` of shape ``(C, C)``. Channels with zero standard
        deviation produce NaN entries (division by zero).

    Raises:
        ValueError: if ``outputs`` is neither 2- nor 4-dimensional.
    """
    outputs = np.asarray(outputs)
    if outputs.ndim not in (2, 4):
        raise ValueError(
            f'Expected a 2D or 4D array but got {outputs.ndim} dimensions.'
        )
    # Collapse every non-channel axis: (N, H, W, C) or (N, C) -> (samples, C).
    # This yields the same layout as transposing to (C, N, H, W), flattening
    # to (C, N*H*W) and transposing back.
    samples = np.reshape(outputs, [-1, outputs.shape[-1]])
    # Out-of-place arithmetic: leaves the caller's array untouched and
    # promotes integer inputs to floating point.
    centered = samples - np.mean(samples, axis=0, keepdims=True)
    standardized = centered / np.std(centered, axis=0, keepdims=True)
    return np.dot(standardized.transpose(), standardized) / standardized.shape[0]

def conv_predictor_model_fn(
    input_concept_classes=1,
    output_concept_classes=2,
):
    """Build and compile a small convolutional concept predictor.

    The network is two 3x3 ReLU convolutions (16 then 32 filters,
    channels-last), a flatten and a linear dense head that emits logits.

    NOTE(review): this re-defines, with identical behaviour, the function of
    the same name from the previous cell of the notebook; one of the two
    definitions can be removed.

    Args:
        input_concept_classes: not used by this function.
        output_concept_classes: number of classes of the predicted concept.
            More than two classes gives one logit per class and a sparse
            categorical loss; otherwise a single logit with a binary loss.

    Returns:
        A compiled ``tf.keras.models.Sequential`` model.
    """
    is_multiclass = output_concept_classes > 2

    def _conv_relu(num_filters):
        # 3x3 "SAME" convolution followed by a ReLU.
        return tf.keras.layers.Conv2D(
            filters=num_filters,
            kernel_size=(3,3),
            padding="SAME",
            activation="relu",
            data_format="channels_last",
        )

    layers = [
        _conv_relu(16),
        _conv_relu(32),
        tf.keras.layers.Flatten(),
        # No activation here: it is folded into the loss for numerical
        # stability.
        tf.keras.layers.Dense(
            output_concept_classes if is_multiclass else 1,
            activation=None,
        ),
    ]
    estimator = tf.keras.models.Sequential(layers)

    # Labels are expected as class indices (no one-hot encoding) when the
    # concept is categorical.
    if is_multiclass:
        loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    else:
        loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    estimator.compile(optimizer='adam', loss=loss)
    return estimator



def concept_scores(
    cw_layer,
    inputs,
    aggregator='max_pool_mean',
    concept_indices=None,
):
    """Compute per-sample concept scores from a layer's activations.

    ``cw_layer`` is called on ``inputs`` in inference mode. A rank-2 result is
    used as the scores directly; any other result is reduced over axes 2 and 3
    according to ``aggregator``.

    NOTE(review): the reductions over axes ``[2, 3]`` and the ``"NCHW"`` max
    pool treat the activations as channels-first -- confirm this matches the
    layout the layer actually emits.

    Args:
        cw_layer: callable layer accepting ``(inputs, training=False)``.
        inputs: batch of inputs for ``cw_layer``.
        aggregator: ``'mean'``, ``'max_pool_mean'`` or ``'max'``; only
            consulted when the layer output is not rank 2.
        concept_indices: optional column selection applied to the scores.

    Returns:
        Tensor of scores, restricted to ``concept_indices`` when given.

    Raises:
        ValueError: if aggregation is needed and ``aggregator`` is unknown.
    """
    activations = cw_layer(inputs, training=False)

    if len(tf.shape(activations)) != 2:
        # Spatial activations: collapse them into one score per channel.
        if aggregator == 'mean':
            scores = tf.math.reduce_mean(activations, axis=[2, 3])
        elif aggregator == 'max_pool_mean':
            # Downsample with a max pool first (window capped by the two
            # trailing dimensions), then average what is left.
            window_size = min(
                2,
                activations.shape[-1],
                activations.shape[-2],
            )
            pooled = tf.nn.max_pool(
                activations,
                ksize=window_size,
                strides=window_size,
                padding="SAME",
                data_format="NCHW",
            )
            scores = tf.math.reduce_mean(pooled, axis=[2, 3])
        elif aggregator == 'max':
            scores = tf.math.reduce_max(activations, axis=[2, 3])
        else:
            raise ValueError(f'Unsupported aggregator {aggregator}.')
    else:
        # The forward pass already produced one score per concept.
        scores = activations

    if concept_indices is None:
        return scores
    return scores[:, concept_indices]



def cw_experiment_loop(
    experiment_config,
    datasets,
    load_from_cache=False,
    oracle_matrix_cache=None,
):
    """Compute niche impurity scores (NIS) for previously saved CW models.

    For every dataset and trial a CW model is constructed from the config,
    then replaced by the model stored under
    ``<results_dir>/models/end_to_end_model_<ds_name>_trial_<trial>``. Concept
    representations are read from the layer at index ``cw_layer`` of the
    loaded model and scored with ``niching.niche_impurity_score``. No training
    happens here. Results are dumped to
    ``<niching_results_dir>/results_niching.joblib`` after every trial.

    Args:
        experiment_config: dict of experiment settings. Keys read here include
            ``niching_results_dir``, ``results_dir``, ``trials``,
            ``feature_map``, ``num_concepts``, ``aggregator`` (only when
            ``feature_map`` is falsy), the architecture keys passed to
            ``construct_cw_model`` and the optional ``cw_layer`` (default 14),
            ``exclusive_test_concepts`` and ``max_concept_group_size``.
        datasets: iterable of
            ``(name, (x_train, y_train, c_train), (x_test, y_test, c_test))``.
        load_from_cache: when True and the results file already exists, the
            cached object is returned without computing anything.
        oracle_matrix_cache: accepted for interface compatibility; not used.

    Returns:
        dict with keys ``"config"`` and ``"niss"`` (one NIS per dataset/trial),
        or whatever object was stored in the cache file.
    """
    utils.reseed(87)
    experiment_variables = dict(
        config=experiment_config,
        niss=[],
    )
    res_dir = experiment_config['niching_results_dir']
    results_path = os.path.join(res_dir, 'results_niching.joblib')
    if load_from_cache and os.path.exists(results_path):
        return load(results_path)

    # Else, let's go ahead and run the whole thing.
    # Fix: the directory is now actually created (the original built the Path
    # object but never called mkdir on it).
    Path(
        os.path.join(
            experiment_config["results_dir"],
            "models",
        )
    ).mkdir(parents=True, exist_ok=True)

    for (ds_name, (x_train, y_train, c_train), (x_test, y_test, c_test)) in datasets:
        print("Training with dataset:", ds_name)
        for trial in range(experiment_config["trials"]):
            print(f"\tTrial {trial + 1}/{experiment_config['trials']} of {ds_name}")

            # Group the *test* samples by active concept. Fix: the default
            # branch used to slice the training split here.
            if not experiment_config.get("exclusive_test_concepts", False):
                test_concept_groups = [
                    x_test[c_test[:, i] == 1, :, :, :]
                    for i in range(c_test.shape[-1])
                ]
            else:
                # Only keep samples whose single active concept is concept i.
                test_concept_groups = [
                    x_test[np.logical_and(c_test[:, i] == 1, np.sum(c_test, axis=-1) == 1), :, :, :]
                    for i in range(c_test.shape[-1])
                ]
            print("Sizes of test_concept_groups:", list(map(lambda x: x.shape, test_concept_groups)))
            for i, group in enumerate(test_concept_groups):
                # Fix: without an explicit cap the group is kept whole (the
                # default used to be the size of the last axis, not the
                # number of samples).
                max_test_size = experiment_config.get("max_concept_group_size", group.shape[0])
                if group.shape[0] > max_test_size:
                    # Subsamples with the global NumPy RNG (with replacement).
                    test_concept_groups[i] = group[
                        np.random.choice(group.shape[0], max_test_size), :, :, :
                    ]

            # Construct our CW model.
            # NOTE(review): the freshly built models are not used afterwards;
            # the model that is evaluated is the one loaded from disk below.
            _fresh_model, _fresh_encoder, _fresh_cw_model = construct_cw_model(
                input_shape=experiment_config["input_shape"],
                num_outputs=experiment_config["num_outputs"],
                filter_groups=experiment_config["filter_groups"],
                units=experiment_config["units"],
                drop_prob=experiment_config.get("drop_prob", 0),
                max_pool_window=experiment_config.get("max_pool_window", (2, 2)),
                max_pool_stride=experiment_config.get("max_pool_stride", (2, 2)),
                T=experiment_config.get("T", 5),
                eps=experiment_config.get("eps", 1e-5),
                momentum=experiment_config.get("momentum", 0.9),
                activation_mode=experiment_config["activation_mode"],
                c1=experiment_config.get("c1", 1e-4),
                c2=experiment_config.get("c2", 0.9),
                max_tau_iterations=experiment_config.get("max_tau_iterations", 500),
                initial_tau=experiment_config.get("initial_tau", 1000),
                initial_beta=experiment_config.get("initial_beta", 1e8),
                initial_alpha=experiment_config.get("initial_alpha", 0),
            )

            print("\t\tCW training completed")
            model = load_model(
                os.path.join(
                    experiment_config["results_dir"],
                    f"models/end_to_end_model_{ds_name}_trial_{trial}"
                )
            )

            # Position of the CW layer inside ``model.layers``. It comes from
            # the config and falls back to 14, the previously hard-coded index.
            cw_layer_idx = experiment_config.get("cw_layer", 14)
            num_concepts = experiment_config["num_concepts"]
            if not experiment_config['feature_map']:
                # Scalar scores: feed the activations entering the CW layer
                # through it and aggregate them into one score per channel.
                encoder = tf.keras.Model(
                    inputs=model.layers[0].input,
                    outputs=model.layers[cw_layer_idx - 1].output,
                )
                c_train_pred = concept_scores(
                    model.layers[cw_layer_idx],
                    encoder(x_train),
                    aggregator=experiment_config['aggregator'],
                ).numpy()[:, :num_concepts]
                c_test_pred = concept_scores(
                    model.layers[cw_layer_idx],
                    encoder(x_test),
                    aggregator=experiment_config['aggregator'],
                ).numpy()[:, :num_concepts]
            else:
                # Feature maps: use the CW layer output itself, keep the first
                # num_concepts channels and flatten the two spatial axes.
                encoder = tf.keras.Model(
                    inputs=model.layers[0].input,
                    outputs=model.layers[cw_layer_idx].output,
                )

                c_train_pred = encoder(x_train)[:, :, :, :num_concepts]
                c_test_pred = encoder(x_test)[:, :, :, :num_concepts]
                out_shape = c_train_pred.shape[1] * c_train_pred.shape[2]
                c_train_pred = c_train_pred.numpy().reshape(-1, out_shape, c_train_pred.shape[3])
                c_test_pred = c_test_pred.numpy().reshape(-1, out_shape, c_test_pred.shape[3])

            nis = niching.niche_impurity_score(
                c_soft=c_test_pred,
                c_true=c_test,
                c_soft_train=c_train_pred,
                c_true_train=c_train,
            )
            experiment_variables['niss'].append(nis)
            print(f'\t\tNIS: {nis:.2f}')

            # Checkpoint the results after every trial.
            os.makedirs(res_dir, exist_ok=True)
            dump(experiment_variables, results_path)

    return experiment_variables

## Experiments

In [None]:
reload(CW)
reload(niching)

############################################################################
## Experiment config
############################################################################

cw_features_experiment_config = dict(
    batch_size=256,
    max_epochs=50,
    pre_train_epochs=50,
    post_train_epochs=0,
    cw_train_freq=20,
    cw_train_batch_steps=20,
    cw_train_iterations=1,
    exclusive_concepts=False,
    exclusive_test_concepts=False,
    trials=5,
    learning_rate=1e-3,
    # (num_filters, kernel_size, use_cw_layer): only the last convolution is
    # followed by a concept-whitening layer.
    filter_groups=[
        [(8, (7, 7), False)],
        [(16, (5, 5), False)],
        [(32, (3, 3), False)],
        [(64, (3, 3), True)],
    ],
    units=[64, 64, 64, 64],
    # A task with at most two distinct labels is modelled with one output.
    num_outputs=(
        len(set(multiclass_task_bin_concepts_dep_0_complete_train[1]))
        if len(set(multiclass_task_bin_concepts_dep_0_complete_train[1])) > 2 else 1
    ),
    patience=float("inf"),
    min_delta=1e-5,
    results_dir=os.path.join(RESULTS_DIR, "cw/balanced_multiclass_tasks_purity"),
    input_shape=multiclass_task_bin_concepts_dep_0_complete_train[0].shape[1:],
    num_concepts=multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    verbosity=0,
    data_concepts=multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    cw_layer=14,
    aggregator='max_pool_mean',
    activation_mode='max_pool_mean',
    concept_auc_freq=0,
    holdout_fraction=0.1,
    heatmap_display=True,
    T=5,
    eps=1e-5,
    momentum=0.9,
    initial_tau=1000,
    initial_beta=1e8,
    initial_alpha=0,
    drop_prob=0,
    # Score the raw CW feature maps rather than aggregated scalar scores.
    feature_map=True,
    add='',
    delta_beta=0.05,
)
cw_features_experiment_config['niching_results_dir'] = os.path.join(
    NICHING_RESULTS_DIR,
    f'cw/base_feature_{cw_features_experiment_config["feature_map"]}{cw_features_experiment_config["add"]}',
)

############################################################################
## Experiment run
############################################################################

cw_features_results = cw_experiment_loop(
    cw_features_experiment_config,
    datasets=[
        ("balanced_multiclass_task_bin_concepts_dep_0_complete", train_0, test_0),
        ("balanced_multiclass_task_bin_concepts_dep_1_complete", train_1, test_1),
        ("balanced_multiclass_task_bin_concepts_dep_2_complete", train_2, test_2),
        ("balanced_multiclass_task_bin_concepts_dep_3_complete", train_3, test_3),
        ("balanced_multiclass_task_bin_concepts_dep_4_complete", train_4, test_4),
    ],
    # Follow the notebook-wide switch like every other experiment cell does
    # (this used to be hard-coded to True).
    load_from_cache=FROM_CACHE,
)

In [None]:
reload(CW)
reload(niching)

############################################################################
## Experiment config: CW scalar scores, 'max_pool_mean' aggregation
############################################################################

cw_base_experiment_config = {
    "batch_size": 256,
    "max_epochs": 50,
    "pre_train_epochs": 50,
    "post_train_epochs": 0,
    "cw_train_freq": 20,
    "cw_train_batch_steps": 20,
    "cw_train_iterations": 1,
    "exclusive_concepts": False,
    "exclusive_test_concepts": False,
    "trials": 5,
    "learning_rate": 1e-3,
    # (num_filters, kernel_size, use_cw_layer): only the last convolution is
    # followed by a concept-whitening layer.
    "filter_groups": [
        [(8, (7, 7), False)],
        [(16, (5, 5), False)],
        [(32, (3, 3), False)],
        [(64, (3, 3), True)],
    ],
    "units": [64, 64, 64, 64],
    # A task with at most two distinct labels is modelled with one output.
    "num_outputs": (
        1 if len(set(multiclass_task_bin_concepts_dep_0_complete_train[1])) <= 2
        else len(set(multiclass_task_bin_concepts_dep_0_complete_train[1]))
    ),
    "patience": float("inf"),
    "min_delta": 1e-5,
    "results_dir": os.path.join(
        RESULTS_DIR,
        "cw/balanced_multiclass_tasks_purity_max_pool_mean",
    ),
    "input_shape": multiclass_task_bin_concepts_dep_0_complete_train[0].shape[1:],
    "num_concepts": multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    "verbosity": 0,
    "data_concepts": multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    "cw_layer": 14,
    "aggregator": 'max_pool_mean',
    "activation_mode": 'max_pool_mean',
    "concept_auc_freq": 0,
    "holdout_fraction": 0.1,
    "heatmap_display": True,
    "T": 5,
    "eps": 1e-5,
    "momentum": 0.9,
    "initial_tau": 1000,
    "initial_beta": 1e8,
    "initial_alpha": 0,
    "drop_prob": 0,
    # Use aggregated scalar concept scores instead of raw feature maps.
    "feature_map": False,
    "add": '',
    "delta_beta": 0.05,
}
# The niching results directory is derived from the feature_map flag and the
# "add" suffix of this very config.
cw_base_experiment_config["niching_results_dir"] = os.path.join(
    NICHING_RESULTS_DIR,
    "cw/base_feature_{feature_map}{add}".format(**cw_base_experiment_config),
)


############################################################################
## Experiment run
############################################################################

cw_base_datasets = None  # placeholder removed below; see the call
del cw_base_datasets
cw_base_results = cw_experiment_loop(
    cw_base_experiment_config,
    datasets=[
        ("balanced_multiclass_task_bin_concepts_dep_0_complete", train_0, test_0),
        ("balanced_multiclass_task_bin_concepts_dep_1_complete", train_1, test_1),
        ("balanced_multiclass_task_bin_concepts_dep_2_complete", train_2, test_2),
        ("balanced_multiclass_task_bin_concepts_dep_3_complete", train_3, test_3),
        ("balanced_multiclass_task_bin_concepts_dep_4_complete", train_4, test_4),
    ],
    load_from_cache=FROM_CACHE,
)

In [None]:
reload(CW)

############################################################################
## Experiment config: CW scalar scores, 'mean' aggregation
############################################################################

cw_mean_experiment_config = {
    "batch_size": 256,
    "max_epochs": 50,
    "pre_train_epochs": 50,
    "post_train_epochs": 0,
    "cw_train_freq": 20,
    "cw_train_batch_steps": 20,
    "cw_train_iterations": 1,
    "exclusive_concepts": False,
    "exclusive_test_concepts": False,
    "trials": 5,
    "learning_rate": 1e-3,
    # (num_filters, kernel_size, use_cw_layer): only the last convolution is
    # followed by a concept-whitening layer.
    "filter_groups": [
        [(8, (7, 7), False)],
        [(16, (5, 5), False)],
        [(32, (3, 3), False)],
        [(64, (3, 3), True)],
    ],
    "units": [64, 64, 64, 64],
    # A task with at most two distinct labels is modelled with one output.
    "num_outputs": (
        1 if len(set(multiclass_task_bin_concepts_dep_0_complete_train[1])) <= 2
        else len(set(multiclass_task_bin_concepts_dep_0_complete_train[1]))
    ),
    "patience": float("inf"),
    "min_delta": 1e-5,
    "results_dir": os.path.join(
        RESULTS_DIR,
        "cw/balanced_multiclass_tasks_purity_mean",
    ),
    "input_shape": multiclass_task_bin_concepts_dep_0_complete_train[0].shape[1:],
    "num_concepts": multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    "verbosity": 0,
    "data_concepts": multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    "cw_layer": 14,
    "aggregator": 'mean',
    "activation_mode": 'mean',
    "concept_auc_freq": 0,
    "holdout_fraction": 0.1,
    "heatmap_display": True,
    "T": 5,
    "eps": 1e-5,
    "momentum": 0.9,
    "initial_tau": 1000,
    "initial_beta": 1e8,
    "initial_alpha": 0,
    "drop_prob": 0,
    # Use aggregated scalar concept scores instead of raw feature maps.
    "feature_map": False,
    "add": '_mean',
    "delta_beta": 0.05,
}
# The niching results directory is derived from the feature_map flag and the
# "add" suffix of this very config.
cw_mean_experiment_config["niching_results_dir"] = os.path.join(
    NICHING_RESULTS_DIR,
    "cw/base_feature_{feature_map}{add}".format(**cw_mean_experiment_config),
)

# Make sure the experiment directory and its figures folder exist. Creating
# the nested folder with parents=True also creates the experiment directory.
cw_mean_figure_dir = os.path.join(cw_mean_experiment_config["results_dir"], "figures")
Path(cw_mean_experiment_config["results_dir"]).mkdir(parents=True, exist_ok=True)
Path(cw_mean_figure_dir).mkdir(parents=True, exist_ok=True)

############################################################################
## Experiment run
############################################################################

cw_mean_results = cw_experiment_loop(
    cw_mean_experiment_config,
    datasets=[
        ("balanced_multiclass_task_bin_concepts_dep_0_complete", train_0, test_0),
        ("balanced_multiclass_task_bin_concepts_dep_1_complete", train_1, test_1),
        ("balanced_multiclass_task_bin_concepts_dep_2_complete", train_2, test_2),
        ("balanced_multiclass_task_bin_concepts_dep_3_complete", train_3, test_3),
        ("balanced_multiclass_task_bin_concepts_dep_4_complete", train_4, test_4),
    ],
    load_from_cache=FROM_CACHE,
)

# Ada-ML-VAE Benchmarking

## Weakly Labelled Dataset Construction

In [None]:
# Every weakly-labelled split pairs each input with itself along axis 1 and
# swaps the two remaining entries, i.e. (x, y, c) becomes [x|x, c, y].
cwtrain_0, cwtrain_1, cwtrain_2, cwtrain_3, cwtrain_4 = (
    [np.concatenate((split[0], split[0]), axis=1), split[2], split[1]]
    for split in (train_0, train_1, train_2, train_3, train_4)
)

cwtest_0, cwtest_1, cwtest_2, cwtest_3, cwtest_4 = (
    [np.concatenate((split[0], split[0]), axis=1), split[2], split[1]]
    for split in (test_0, test_1, test_2, test_3, test_4)
)

## Model Setup

In [None]:
import concepts_xai.methods.VAE.weak_vae as weak_vae
import concepts_xai.methods.VAE.baseVAE as base_vae
import concepts_xai.methods.VAE.losses as vae_losses
reload(vae_losses)
reload(weak_vae)

def construct_vae_encoder(
    input_shape,
    filter_groups,
    units,
    latent_dims,
    drop_prob=0.5,
    max_pool_window=(2,2),
    max_pool_stride=2,
    include_norm=False,
    include_pool=False,
):
    """Build the convolutional encoder of a VAE as an uncompiled Keras model.

    Args:
        input_shape: shape of one input sample, passed to `tf.keras.Input`.
        filter_groups: iterable of groups; each group is an iterable of
            `(num_filters, kernel_size)` or `(num_filters, kernel_size, stride)`
            entries. A missing stride defaults to 1.
        units: widths of the fully connected layers that follow the
            convolutions.
        latent_dims: size of each of the two output heads.
        drop_prob: dropout rate applied after flattening; a falsy value
            disables the dropout layer.
        max_pool_window: pool size of the optional max-pooling layers.
        max_pool_stride: stride of the optional max-pooling layers.
        include_norm: when True, every convolution is linear and is followed
            by batch normalisation and a ReLU; otherwise the convolution
            itself uses a ReLU activation.
        include_pool: when True, a max-pooling layer closes every filter group.

    Returns:
        A `tf.keras.Model` named "encoder" mapping the input to
        `[mean, log_var]`, the outputs of two linear dense layers named
        "means" and "log_var".
    """
    inputs = tf.keras.Input(shape=input_shape)
    hidden = inputs
    conv_activation = None if include_norm else "relu"

    # Convolutional trunk
    conv_index = 0
    for group in filter_groups:
        for conv_spec in group:
            if len(conv_spec) == 2:
                # Stride omitted from the spec: fall back to a stride of 1
                num_filters, kernel_size = conv_spec
                stride = 1
            else:
                num_filters, kernel_size, stride = conv_spec
            conv_layer = tf.keras.layers.Conv2D(
                filters=num_filters,
                kernel_size=kernel_size,
                strides=stride,
                padding="SAME",
                activation=conv_activation,
                name=f'encoder_conv_{conv_index}',
            )
            hidden = conv_layer(hidden)
            conv_index += 1
            if include_norm:
                hidden = tf.keras.layers.BatchNormalization()(hidden)
                hidden = tf.keras.activations.relu(hidden)
        if include_pool:
            # Pooling at the end of each group keeps the parameter count of
            # the model under control
            pool_layer = tf.keras.layers.MaxPooling2D(
                pool_size=max_pool_window,
                strides=max_pool_stride,
            )
            hidden = pool_layer(hidden)

    # Collapse the feature maps into a vector
    hidden = tf.keras.layers.Flatten()(hidden)

    # Optional dropout
    if drop_prob:
        hidden = tf.keras.layers.Dropout(drop_prob)(hidden)

    # Fully connected bottleneck
    for dense_index, width in enumerate(units):
        hidden = tf.keras.layers.Dense(
            width,
            activation='relu',
            name=f"encoder_dense_{dense_index}",
        )(hidden)

    # Two linear heads parameterising the latent distribution
    mean = tf.keras.layers.Dense(latent_dims, activation=None, name="means")(hidden)
    log_var = tf.keras.layers.Dense(latent_dims, activation=None, name="log_var")(hidden)
    return tf.keras.Model(inputs, [mean, log_var], name="encoder")

def construct_vae_decoder(
    units,
    output_shape,
    latent_dims,
):
    """CNN decoder architecture used in the 'Challenging Common Assumptions in the Unsupervised Learning
       of Disentangled Representations' paper (https://arxiv.org/abs/1811.12359)

       The latent vector goes through one ReLU dense layer per entry of
       `units`, is reshaped to a 4x4x32 feature map (so the last entry of
       `units` has to provide 4 * 4 * 32 = 512 values), and is then upsampled
       by four stride-2 transposed convolutions. The last convolution is
       linear and has `output_shape[-1]` filters; its output is reshaped to
       `output_shape`.

       NOTE(review): four stride-2 "same" upsamplings of a 4x4 map give a
       64x64 map, so `output_shape` is presumably 64x64xC -- confirm.

       Note: model is uncompiled
    """
    decoder_inputs = tf.keras.Input(shape=(latent_dims,))
    hidden = decoder_inputs
    for width in units:
        hidden = tf.keras.layers.Dense(width, activation='relu')(hidden)

    # Fold the flat features into a small spatial map
    hidden = tf.keras.layers.Reshape([4, 4, 32])(hidden)

    # Three ReLU upsampling stages that differ only in their filter count
    for num_filters in (64, 32, 32):
        hidden = tf.keras.layers.Conv2DTranspose(
            filters=num_filters,
            kernel_size=4,
            strides=2,
            activation='relu',
            padding="same",
        )(hidden)

    # Final linear upsampling stage producing the output channels
    hidden = tf.keras.layers.Conv2DTranspose(
        filters=output_shape[-1],
        kernel_size=4,
        strides=2,
        padding="same",
        activation=None,
    )(hidden)
    reconstruction = tf.keras.layers.Reshape(output_shape)(hidden)

    return tf.keras.Model(
        inputs=decoder_inputs,
        outputs=[reconstruction],
    )

def construct_wvae(
    input_shape,
    latent_dims,
    filter_groups,
    encoder_units,
    decoder_units,
    drop_prob=0.5,
    include_pool=False,
    max_pool_window=(2,2),
    max_pool_stride=2,
    learning_rate=1e-3,
    beta=1.0,
    vae_model=weak_vae.MLVaeArgmax,
):
    """Assemble and compile a VAE from a fresh encoder and decoder.

    Args:
        input_shape: shape of one input sample; also used as the decoder's
            output shape.
        latent_dims: size of the latent code.
        filter_groups: convolution specs forwarded to `construct_vae_encoder`.
        encoder_units: dense-layer widths of the encoder.
        decoder_units: dense-layer widths of the decoder.
        drop_prob: dropout rate forwarded to the encoder.
        include_pool: whether the encoder max-pools after each filter group.
        max_pool_window: pool size forwarded to the encoder.
        max_pool_stride: pool stride forwarded to the encoder.
        learning_rate: learning rate of the Adam optimizer.
        beta: forwarded to `vae_model` as `beta`.
        vae_model: class (or callable) instantiated with `encoder`,
            `decoder`, `loss_fn` and `beta` keyword arguments.

    Returns:
        The `vae_model` instance after `compile` has been called on it with
        an Adam optimizer.
    """
    # The encoder is built before the decoder; keep this order so layers are
    # created in the same sequence on every run.
    encoder = construct_vae_encoder(
        input_shape=input_shape,
        filter_groups=filter_groups,
        units=encoder_units,
        drop_prob=drop_prob,
        include_pool=include_pool,
        max_pool_window=max_pool_window,
        max_pool_stride=max_pool_stride,
        latent_dims=latent_dims,
    )
    decoder = construct_vae_decoder(
        output_shape=input_shape,
        units=decoder_units,
        latent_dims=latent_dims,
    )

    reconstruction_loss = vae_losses.bernoulli_fn_wrapper()
    model = vae_model(
        encoder=encoder,
        decoder=decoder,
        loss_fn=reconstruction_loss,
        beta=beta,
    )
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate))
    return model

## Experiment Loop

In [None]:
# import concepts_xai.evaluation.metrics.purity as purity
import concepts_xai.methods.VAE.betaVAE as beta_vae

############################################################################
## Experiment loop
############################################################################

def construct_trivial_auc_mat(num_concepts):
    """Return a float32 `(num_concepts, num_concepts)` matrix holding 1.0 on
    the diagonal and 0.5 everywhere else."""
    identity = np.eye(num_concepts, dtype=np.float32)
    # (0 + 1) / 2 = 0.5 off the diagonal, (1 + 1) / 2 = 1.0 on it
    return np.float32(0.5) * (identity + np.float32(1.0))
    
def sigmoid(x):
    """Logistic function `1 / (1 + exp(-x))`, evaluated with NumPy."""
    exp_neg = np.exp(-x)
    return 1 / (exp_neg + 1)

def wvae_experiment_loop(
    experiment_config,
    datasets,
    wvae,
    latent='',
    load_from_cache=False,
    include_all_losses=True,
    vae_model=weak_vae.MLVaeArgmax,
    oracle_matrix_cache=None,
):
    """Score previously saved VAE encoders with the niche impurity score.

    For every dataset and every trial, a VAE is built with `construct_wvae`,
    its encoder is replaced by the one stored at
    `<results_dir>/models/<prefix><ds_name>_encoder_trial_<trial>`, latent
    codes are sampled for the train and test inputs, and the value returned
    by `niching.niche_impurity_score` is appended to the `niss` list of the
    result. No model is trained by this function. Results are written to
    `<niching_results_dir>/results_niching.joblib` after every trial.

    Args:
        experiment_config: dict of settings. Keys read here: `trials`,
            `input_shape`, `latent_dim`, `filter_groups`, `encoder_units`,
            `decoder_units`, `drop_prob`, `max_pool_window`,
            `max_pool_stride`, `beta`, `results_dir` and
            `niching_results_dir`.
        datasets: list of `(ds_name, (x_train, c_train, y_train),
            (x_test, c_test, y_test))` entries; the `y_*` entries are unused.
        wvae: unused; kept for interface compatibility.
        latent: unused; kept for interface compatibility.
        load_from_cache: when True and a complete results file exists, it is
            returned without recomputing anything.
        include_all_losses: unused; kept for interface compatibility.
        vae_model: VAE class handed to `construct_wvae`. For anything other
            than `beta_vae.BetaVAE` only the first half of each input along
            axis 1 is fed to the encoder and no file-name prefix is used;
            for `beta_vae.BetaVAE` inputs are used as-is and saved encoders
            are looked up with a "balanced_" prefix.
        oracle_matrix_cache: unused; kept for interface compatibility.

    Returns:
        dict with the keys `config` (the given `experiment_config`) and
        `niss` (one score per dataset/trial pair, datasets varying slowest).

    NOTE(review): `include_pool` and `learning_rate` from the config are not
    forwarded to `construct_wvae`, so its defaults apply to the freshly built
    model; the encoder of that model is replaced by the loaded one.
    """
    utils.reseed(87)
    if vae_model != beta_vae.BetaVAE:
        prefix = ""
        # Keep only the first half along axis 1 (axis 0 for 1-D inputs)
        split_fn = lambda x: x[:, :x.shape[1]//2, ...] if len(x.shape) > 1 else x[:x.shape[0]//2]
    else:
        prefix = "balanced_"
        split_fn = lambda x: x

    res_dir = experiment_config['niching_results_dir']
    results_path = os.path.join(res_dir, 'results_niching.joblib')
    num_trials = experiment_config["trials"]
    expected_runs = len(datasets) * num_trials

    if load_from_cache and os.path.exists(results_path):
        cached_variables = load(results_path)
        # Results are checkpointed after every trial, so an interrupted run
        # leaves a file with fewer scores than expected. Such a file must not
        # be mistaken for a finished experiment.
        num_cached = len(cached_variables.get('niss', []))
        if num_cached >= expected_runs:
            return cached_variables
        print(
            f"Cached results in {results_path} hold {num_cached}/{expected_runs} "
            f"scores; recomputing from scratch"
        )

    # Else, let's go ahead and run the whole thing
    experiment_variables = dict(
        config=experiment_config,
        niss=[],
    )
    # Kept from the original flow: make sure the models folder exists. The
    # encoders themselves must already have been saved in it.
    Path(
        os.path.join(
            experiment_config["results_dir"],
            "models",
        )
    ).mkdir(parents=True, exist_ok=True)
    os.makedirs(res_dir, exist_ok=True)

    latent_dim = experiment_config['latent_dim']
    for (ds_name, (x_train, c_train, _y_train), (x_test, c_test, _y_test)) in datasets:
        print("Evaluating with latent dimensions", latent_dim, "in dataset", ds_name)

        for trial in range(num_trials):
            print(f"\tTrial {trial + 1}/{num_trials}")

            # Build the VAE shell; its encoder is swapped for the saved one
            # right below.
            wvae_model = construct_wvae(
                input_shape=experiment_config["input_shape"],
                latent_dims=latent_dim,
                filter_groups=experiment_config["filter_groups"],
                encoder_units=experiment_config["encoder_units"],
                decoder_units=experiment_config["decoder_units"],
                drop_prob=experiment_config['drop_prob'],
                max_pool_window=experiment_config['max_pool_window'],
                max_pool_stride=experiment_config['max_pool_stride'],
                beta=experiment_config['beta'],
                vae_model=vae_model,
            )

            print("\t\tLoading pretrained WVAE encoder...")
            wvae_model.encoder = load_model(
                os.path.join(
                    experiment_config["results_dir"],
                    f"models/{prefix}{ds_name}_encoder_trial_{trial}"
                )
            )

            print("\t\tComputing niche impurity score...")
            c_train_pred = wvae_model.sample_from_latent_distribution(
                *wvae_model.encoder(split_fn(x_train))
            ).numpy()
            c_test_pred = wvae_model.sample_from_latent_distribution(
                *wvae_model.encoder(split_fn(x_test))
            ).numpy()
            experiment_variables['niss'].append(
                niching.niche_impurity_score(
                    c_soft=c_test_pred,
                    c_true=c_test,
                    c_soft_train=c_train_pred,
                    c_true_train=c_train,
                )
            )
            print(f'\t\tNIS: {experiment_variables["niss"][-1]:.2f}')

            # Checkpoint after every trial
            dump(experiment_variables, results_path)

    return experiment_variables

## Experiments

In [None]:
# Reload the VAE modules so that source edits made after the first import
# take effect
reload(vae_losses)
reload(base_vae)
reload(weak_vae)

############################################################################
## Experiment config
############################################################################

# Ada-ML-VAE run; `latent_dim` is set to the same value as `num_concepts`.
# `wvae_experiment_loop` only reads `trials`, `input_shape`, `latent_dim`,
# `filter_groups`, `encoder_units`, `decoder_units`, `drop_prob`,
# `max_pool_window`, `max_pool_stride`, `beta`, `results_dir` and
# `niching_results_dir`; the remaining entries are stored alongside the
# results but are not read by that loop.
ada_mlvae_experiment_config = {
    "batch_size": 32,
    "max_epochs": 100,
    "trials": 5,
    "learning_rate": 1e-3,
    "input_shape": multiclass_task_bin_concepts_dep_0_complete_train[0].shape[1:],
    "latent_dim": multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    # One single-convolution group per line: (filters, kernel size, stride)
    "filter_groups": [
        [(8, (7, 7), 1)],
        [(16, (5, 5), 1)],
        [(32, (3, 3), 1)],
        [(64, (3, 3), 1)],
    ],
    "include_pool": True,
    "encoder_units": [64, 64],
    "decoder_units": [256, 512],
    "drop_prob": 0.0,
    "max_pool_window": (2, 2),
    "max_pool_stride": 2,
    "beta": 1,
    "patience": float("inf"),
    "min_delta": 1e-5,
    # Folder holding the saved models
    "results_dir": os.path.join(
        RESULTS_DIR,
        f"ada_ml_vae/multilabel_purity_latent_{multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]}"
    ),
    # Folder receiving the niching results
    "niching_results_dir": os.path.join(
        NICHING_RESULTS_DIR,
        f"ada_ml_vae/multilabel_purity_latent_{multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]}"
    ),
    "num_concepts": multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    "verbosity": 0,
    "data_concepts": multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    "holdout_fraction": 0.1,
    "visualization_frequency": 25,
    "visualization_samples": 6,
    "delta_beta": 0.05,
}

############################################################################
## Experiment run
############################################################################
ada_mlvae_results = wvae_experiment_loop(
    ada_mlvae_experiment_config,
    datasets=[
        ("multiclass_task_bin_concepts_dep_0", cwtrain_0, cwtest_0),
        ("multiclass_task_bin_concepts_dep_1", cwtrain_1, cwtest_1),
        ("multiclass_task_bin_concepts_dep_2", cwtrain_2, cwtest_2),
        ("multiclass_task_bin_concepts_dep_3", cwtrain_3, cwtest_3),
        ("multiclass_task_bin_concepts_dep_4", cwtrain_4, cwtest_4),
    ],
    wvae="ada_ml_vae",
    latent=f"_latent_{multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]}",
    vae_model=weak_vae.MLVaeArgmax,
    include_all_losses=True,
    load_from_cache=FROM_CACHE,
)

In [None]:
# Reload the VAE modules so that source edits made after the first import
# take effect
reload(vae_losses)
reload(base_vae)
reload(weak_vae)

############################################################################
## Experiment config
############################################################################

# Ada-ML-VAE run; `latent_dim` is twice the value stored as `num_concepts`.
# `wvae_experiment_loop` only reads `trials`, `input_shape`, `latent_dim`,
# `filter_groups`, `encoder_units`, `decoder_units`, `drop_prob`,
# `max_pool_window`, `max_pool_stride`, `beta`, `results_dir` and
# `niching_results_dir`; the remaining entries are stored alongside the
# results but are not read by that loop.
ada_mlvae_extended_experiment_config = {
    "batch_size": 32,
    "max_epochs": 100,
    "trials": 5,
    "learning_rate": 1e-3,
    "input_shape": multiclass_task_bin_concepts_dep_0_complete_train[0].shape[1:],
    "latent_dim": 2*multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    # One single-convolution group per line: (filters, kernel size, stride)
    "filter_groups": [
        [(8, (7, 7), 1)],
        [(16, (5, 5), 1)],
        [(32, (3, 3), 1)],
        [(64, (3, 3), 1)],
    ],
    "include_pool": True,
    "encoder_units": [64, 64],
    "decoder_units": [256, 512],
    "drop_prob": 0.0,
    "max_pool_window": (2, 2),
    "max_pool_stride": 2,
    "beta": 1,
    "patience": float("inf"),
    "min_delta": 1e-5,
    # Folder holding the saved models
    "results_dir": os.path.join(
        RESULTS_DIR,
        f"ada_ml_vae/multilabel_purity_latent_{2*multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]}"
    ),
    # Folder receiving the niching results
    "niching_results_dir": os.path.join(
        NICHING_RESULTS_DIR,
        f"ada_ml_vae/multilabel_purity_latent_{2*multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]}"
    ),
    "num_concepts": multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    "verbosity": 0,
    "data_concepts": multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    "holdout_fraction": 0.1,
    "visualization_frequency": 25,
    "visualization_samples": 6,
    "delta_beta": 0.05,
}


############################################################################
## Experiment run
############################################################################
ada_mlvae_extended_results = wvae_experiment_loop(
    ada_mlvae_extended_experiment_config,
    datasets=[
        ("multiclass_task_bin_concepts_dep_0", cwtrain_0, cwtest_0),
        ("multiclass_task_bin_concepts_dep_1", cwtrain_1, cwtest_1),
        ("multiclass_task_bin_concepts_dep_2", cwtrain_2, cwtest_2),
        ("multiclass_task_bin_concepts_dep_3", cwtrain_3, cwtest_3),
        ("multiclass_task_bin_concepts_dep_4", cwtrain_4, cwtest_4),
    ],
    wvae="ada_ml_vae",
    latent=f"_latent_{2* multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]}",
    vae_model=weak_vae.MLVaeArgmax,
    include_all_losses=True,
    load_from_cache=FROM_CACHE,
)

# Ada-GVAE Benchmarking

## Experiments

In [None]:
# Reload the VAE modules so that source edits made after the first import
# take effect
reload(vae_losses)
reload(base_vae)
reload(weak_vae)

############################################################################
## Experiment config
############################################################################

# Ada-GVAE run; `latent_dim` is set to the same value as `num_concepts`.
# `wvae_experiment_loop` only reads `trials`, `input_shape`, `latent_dim`,
# `filter_groups`, `encoder_units`, `decoder_units`, `drop_prob`,
# `max_pool_window`, `max_pool_stride`, `beta`, `results_dir` and
# `niching_results_dir`; the remaining entries are stored alongside the
# results but are not read by that loop.
ada_gvae_experiment_config = {
    "batch_size": 32,
    "max_epochs": 100,
    "trials": 5,
    "learning_rate": 1e-3,
    "input_shape": multiclass_task_bin_concepts_dep_0_complete_train[0].shape[1:],
    "latent_dim": multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    # One single-convolution group per line: (filters, kernel size, stride)
    "filter_groups": [
        [(8, (7, 7), 1)],
        [(16, (5, 5), 1)],
        [(32, (3, 3), 1)],
        [(64, (3, 3), 1)],
    ],
    "include_pool": True,
    "encoder_units": [64, 64],
    "decoder_units": [256, 512],
    "drop_prob": 0.0,
    "max_pool_window": (2, 2),
    "max_pool_stride": 2,
    "beta": 1,
    "patience": float("inf"),
    "min_delta": 1e-5,
    # Folder holding the saved models
    "results_dir": os.path.join(
        RESULTS_DIR,
        f"ada_g_vae/multilabel_purity_latent_{multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]}"
    ),
    # Folder receiving the niching results
    "niching_results_dir": os.path.join(
        NICHING_RESULTS_DIR,
        f"ada_g_vae/multilabel_purity_latent_{multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]}"
    ),
    "num_concepts": multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    "verbosity": 0,
    "data_concepts": multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    "holdout_fraction": 0.1,
    "visualization_frequency": 25,
    "visualization_samples": 6,
    "delta_beta": 0.05,
}


############################################################################
## Experiment run
############################################################################
ada_gvae_results = wvae_experiment_loop(
    ada_gvae_experiment_config,
    datasets=[
        ("multiclass_task_bin_concepts_dep_0", cwtrain_0, cwtest_0),
        ("multiclass_task_bin_concepts_dep_1", cwtrain_1, cwtest_1),
        ("multiclass_task_bin_concepts_dep_2", cwtrain_2, cwtest_2),
        ("multiclass_task_bin_concepts_dep_3", cwtrain_3, cwtest_3),
        ("multiclass_task_bin_concepts_dep_4", cwtrain_4, cwtest_4),
    ],
    wvae="ada_g_vae",
    latent=f"_latent_{multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]}",
    vae_model=weak_vae.GroupVAEArgmax,
    include_all_losses=True,
    load_from_cache=FROM_CACHE,
)

In [None]:
# Reload the VAE modules so that source edits made after the first import
# take effect
reload(vae_losses)
reload(base_vae)
reload(weak_vae)

############################################################################
## Experiment config
############################################################################

# Ada-GVAE run; `latent_dim` is twice the value stored as `num_concepts`.
# `wvae_experiment_loop` only reads `trials`, `input_shape`, `latent_dim`,
# `filter_groups`, `encoder_units`, `decoder_units`, `drop_prob`,
# `max_pool_window`, `max_pool_stride`, `beta`, `results_dir` and
# `niching_results_dir`; the remaining entries are stored alongside the
# results but are not read by that loop.
ada_gvae_extended_experiment_config = {
    "batch_size": 32,
    "max_epochs": 100,
    "trials": 5,
    "learning_rate": 1e-3,
    "input_shape": multiclass_task_bin_concepts_dep_0_complete_train[0].shape[1:],
    "latent_dim": 2*multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    # One single-convolution group per line: (filters, kernel size, stride)
    "filter_groups": [
        [(8, (7, 7), 1)],
        [(16, (5, 5), 1)],
        [(32, (3, 3), 1)],
        [(64, (3, 3), 1)],
    ],
    "include_pool": True,
    "encoder_units": [64, 64],
    "decoder_units": [256, 512],
    "drop_prob": 0.0,
    "max_pool_window": (2, 2),
    "max_pool_stride": 2,
    "beta": 1,
    "patience": float("inf"),
    "min_delta": 1e-5,
    # Folder holding the saved models
    "results_dir": os.path.join(
        RESULTS_DIR,
        f"ada_g_vae/multilabel_purity_latent_{2*multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]}"
    ),
    # Folder receiving the niching results
    "niching_results_dir": os.path.join(
        NICHING_RESULTS_DIR,
        f"ada_g_vae/multilabel_purity_latent_{2*multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]}"
    ),
    "num_concepts": multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    "verbosity": 0,
    "data_concepts": multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    "holdout_fraction": 0.1,
    "visualization_frequency": 25,
    "visualization_samples": 6,
    "delta_beta": 0.05,
}

############################################################################
## Experiment run
############################################################################
ada_gvae_extended_results = wvae_experiment_loop(
    ada_gvae_extended_experiment_config,
    datasets=[
        ("multiclass_task_bin_concepts_dep_0", cwtrain_0, cwtest_0),
        ("multiclass_task_bin_concepts_dep_1", cwtrain_1, cwtest_1),
        ("multiclass_task_bin_concepts_dep_2", cwtrain_2, cwtest_2),
        ("multiclass_task_bin_concepts_dep_3", cwtrain_3, cwtest_3),
        ("multiclass_task_bin_concepts_dep_4", cwtrain_4, cwtest_4),
    ],
    wvae="ada_g_vae",
    latent=f"_latent_{2*multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]}",
    vae_model=weak_vae.GroupVAEArgmax,
    include_all_losses=True,
    load_from_cache=FROM_CACHE,
)

# Beta-VAE Benchmarking

In [None]:
import concepts_xai.methods.VAE.betaVAE as beta_vae

# Reload the VAE modules so that source edits made after the first import
# take effect
reload(vae_losses)
reload(base_vae)
reload(beta_vae)

############################################################################
## Experiment config
############################################################################

# Beta-VAE run with beta = 10; `latent_dim` is set to the same value as
# `num_concepts`.
# `wvae_experiment_loop` only reads `trials`, `input_shape`, `latent_dim`,
# `filter_groups`, `encoder_units`, `decoder_units`, `drop_prob`,
# `max_pool_window`, `max_pool_stride`, `beta`, `results_dir` and
# `niching_results_dir`; the remaining entries are stored alongside the
# results but are not read by that loop.
beta_vae_experiment_config = {
    "batch_size": 64,
    "max_epochs": 100,
    "trials": 5,
    "learning_rate": 1e-3,
    "input_shape": multiclass_task_bin_concepts_dep_0_complete_train[0].shape[1:],
    "latent_dim": multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    "beta": 10,

    # One single-convolution group per line: (filters, kernel size, stride)
    "filter_groups": [
        [(8, (7, 7), 1)],
        [(16, (5, 5), 1)],
        [(32, (3, 3), 1)],
        [(64, (3, 3), 1)],
    ],
    "include_pool": True,
    "encoder_units": [64, 64],
    "decoder_units": [256, 512],

    "latent_decoder_units": [64, 64],
    # Number of distinct values in entry 1 of the reference split, or 1 when
    # there are at most two of them
    "num_outputs": (
        len(set(multiclass_task_bin_concepts_dep_0_complete_train[1])) if len(set(multiclass_task_bin_concepts_dep_0_complete_train[1])) > 2
        else 1
    ),
    "predictor_max_epochs": 100,

    "drop_prob": 0.0,
    "max_pool_window": (2, 2),
    "max_pool_stride": 2,
    "patience": float("inf"),
    "min_delta": 1e-5,
    # Folder holding the saved models
    "results_dir": os.path.join(
        RESULTS_DIR,
        f"beta_vae/balanced_multilabel_purity_latent_{multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]}_beta_{10}"
    ),
    # Folder receiving the niching results
    "niching_results_dir": os.path.join(
        NICHING_RESULTS_DIR,
        f"beta_vae/purity_latent_{multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]}_beta_{10}"
    ),
    "num_concepts": multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    "verbosity": 0,
    "data_concepts": multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    "holdout_fraction": 0.1,
    "visualization_frequency": 25,
    "visualization_samples": 6,
    "delta_beta": 0.05,
}


############################################################################
## Experiment run
############################################################################
# NOTE(review): `wvae_experiment_loop` unpacks every split as (x, c, y). The
# cw* splits used by the Ada-* runs are built with positions 1 and 2 of
# train_*/test_* swapped, whereas the splits below are passed unswapped --
# confirm that this is intended.
beta_vae_results = wvae_experiment_loop(
    beta_vae_experiment_config,
    datasets=[
        ("multiclass_task_bin_concepts_dep_0_complete", train_0, test_0),
        ("multiclass_task_bin_concepts_dep_1_complete", train_1, test_1),
        ("multiclass_task_bin_concepts_dep_2_complete", train_2, test_2),
        ("multiclass_task_bin_concepts_dep_3_complete", train_3, test_3),
        ("multiclass_task_bin_concepts_dep_4_complete", train_4, test_4),
    ],
    wvae="beta_vae",
    latent=f"_latent_{beta_vae_experiment_config['latent_dim']}",
    vae_model=beta_vae.BetaVAE,
    include_all_losses=True,
    load_from_cache=FROM_CACHE,
)

In [None]:
import concepts_xai.methods.VAE.betaVAE as beta_vae

# Reload the VAE modules so that source edits made after the first import
# take effect
reload(vae_losses)
reload(base_vae)
reload(beta_vae)

############################################################################
## Experiment config
############################################################################

# Beta-VAE run with beta = 10; `latent_dim` is twice the value stored as
# `num_concepts`.
# `wvae_experiment_loop` only reads `trials`, `input_shape`, `latent_dim`,
# `filter_groups`, `encoder_units`, `decoder_units`, `drop_prob`,
# `max_pool_window`, `max_pool_stride`, `beta`, `results_dir` and
# `niching_results_dir`; the remaining entries are stored alongside the
# results but are not read by that loop.
beta_vae_extended_experiment_config = dict(
    batch_size=64,
    max_epochs=100,
    trials=5,
    learning_rate=1e-3,
    input_shape=multiclass_task_bin_concepts_dep_0_complete_train[0].shape[1:],
    latent_dim=2*multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    beta=10,

    # One single-convolution group per line: (filters, kernel size, stride)
    filter_groups=[
       [(8, (7, 7), 1)],
       [(16, (5, 5), 1)],
       [(32, (3, 3), 1)],
       [(64, (3, 3), 1)],
    ],
    include_pool=True,
    encoder_units=[64, 64],
    decoder_units=[256, 512],

    latent_decoder_units=[64, 64],
    # Number of distinct values in entry 1 of the reference split, or 1 when
    # there are at most two of them
    num_outputs=(
        len(set(multiclass_task_bin_concepts_dep_0_complete_train[1])) if len(set(multiclass_task_bin_concepts_dep_0_complete_train[1])) > 2
        else 1
    ),
    predictor_max_epochs=100,

    drop_prob=0.0,
    max_pool_window=(2,2),
    max_pool_stride=2,
    patience=float("inf"),
    min_delta=1e-5,
    # Folder holding the saved models
    results_dir=os.path.join(
        RESULTS_DIR,
        f"beta_vae/balanced_multilabel_purity_latent_{2*multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]}_beta_{10}"
    ),
    # Folder receiving the niching results
    niching_results_dir=os.path.join(
        NICHING_RESULTS_DIR,
        f"beta_vae/purity_latent_{2*multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]}_beta_{10}"
    ),
    num_concepts=multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    verbosity=0,
    data_concepts=multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    holdout_fraction=0.1,
    visualization_frequency=25,
    visualization_samples=6,
    delta_beta=0.05,
)


############################################################################
## Experiment run
############################################################################
# NOTE(review): `wvae_experiment_loop` unpacks every split as (x, c, y). The
# cw* splits used by the Ada-* runs are built with positions 1 and 2 of
# train_*/test_* swapped, whereas the splits below are passed unswapped --
# confirm that this is intended.
beta_vae_extended_results = wvae_experiment_loop(
    beta_vae_extended_experiment_config,
    wvae="beta_vae",
    # `latent_dim` is already doubled in the config above, so it must not be
    # doubled a second time when building the label.
    latent=f"_latent_{beta_vae_extended_experiment_config['latent_dim']}",
    include_all_losses=True,
    vae_model=beta_vae.BetaVAE,
    datasets=[
        ("multiclass_task_bin_concepts_dep_0_complete", train_0, test_0),
        ("multiclass_task_bin_concepts_dep_1_complete", train_1, test_1),
        ("multiclass_task_bin_concepts_dep_2_complete", train_2, test_2),
        ("multiclass_task_bin_concepts_dep_3_complete", train_3, test_3),
        ("multiclass_task_bin_concepts_dep_4_complete", train_4, test_4),
    ],
    load_from_cache=FROM_CACHE,
)

In [None]:
import concepts_xai.methods.VAE.betaVAE as beta_vae

# Reload the VAE modules so that source edits made after the first import
# take effect
reload(vae_losses)
reload(base_vae)
reload(beta_vae)

############################################################################
## Experiment config
############################################################################

# Beta-VAE class run with beta = 1; `latent_dim` is set to the same value as
# `num_concepts`.
# `wvae_experiment_loop` only reads `trials`, `input_shape`, `latent_dim`,
# `filter_groups`, `encoder_units`, `decoder_units`, `drop_prob`,
# `max_pool_window`, `max_pool_stride`, `beta`, `results_dir` and
# `niching_results_dir`; the remaining entries are stored alongside the
# results but are not read by that loop.
vae_experiment_config = {
    "batch_size": 64,
    "max_epochs": 100,
    "trials": 5,
    "learning_rate": 1e-3,
    "input_shape": multiclass_task_bin_concepts_dep_0_complete_train[0].shape[1:],
    "latent_dim": multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    "beta": 1,

    # One single-convolution group per line: (filters, kernel size, stride)
    "filter_groups": [
        [(8, (7, 7), 1)],
        [(16, (5, 5), 1)],
        [(32, (3, 3), 1)],
        [(64, (3, 3), 1)],
    ],
    "include_pool": True,
    "encoder_units": [64, 64],
    "decoder_units": [256, 512],

    "latent_decoder_units": [64, 64],
    # Number of distinct values in entry 1 of the reference split, or 1 when
    # there are at most two of them
    "num_outputs": (
        len(set(multiclass_task_bin_concepts_dep_0_complete_train[1])) if len(set(multiclass_task_bin_concepts_dep_0_complete_train[1])) > 2
        else 1
    ),
    "predictor_max_epochs": 100,

    "drop_prob": 0.0,
    "max_pool_window": (2, 2),
    "max_pool_stride": 2,
    "patience": float("inf"),
    "min_delta": 1e-5,
    # Folder holding the saved models
    "results_dir": os.path.join(
        RESULTS_DIR,
        f"beta_vae/balanced_multilabel_purity_latent_{multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]}_beta_{1}"
    ),
    # Folder receiving the niching results
    "niching_results_dir": os.path.join(
        NICHING_RESULTS_DIR,
        f"beta_vae/purity_latent_{multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]}_beta_{1}"
    ),
    "num_concepts": multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    "verbosity": 0,
    "data_concepts": multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    "holdout_fraction": 0.1,
    "visualization_frequency": 25,
    "visualization_samples": 6,
    "delta_beta": 0.05,
}

############################################################################
## Experiment run
############################################################################
# NOTE(review): `wvae_experiment_loop` unpacks every split as (x, c, y). The
# cw* splits used by the Ada-* runs are built with positions 1 and 2 of
# train_*/test_* swapped, whereas the splits below are passed unswapped --
# confirm that this is intended.
vae_results = wvae_experiment_loop(
    vae_experiment_config,
    datasets=[
        ("multiclass_task_bin_concepts_dep_0_complete", train_0, test_0),
        ("multiclass_task_bin_concepts_dep_1_complete", train_1, test_1),
        ("multiclass_task_bin_concepts_dep_2_complete", train_2, test_2),
        ("multiclass_task_bin_concepts_dep_3_complete", train_3, test_3),
        ("multiclass_task_bin_concepts_dep_4_complete", train_4, test_4),
    ],
    wvae="beta_vae",
    latent=f"_latent_{vae_experiment_config['latent_dim']}",
    vae_model=beta_vae.BetaVAE,
    include_all_losses=True,
    load_from_cache=FROM_CACHE,
)

In [None]:
import concepts_xai.methods.VAE.betaVAE as beta_vae

# Reload the VAE modules so that source edits made after the first import
# take effect
reload(vae_losses)
reload(base_vae)
reload(beta_vae)

############################################################################
## Experiment config
############################################################################

# Beta-VAE class run with beta = 1; `latent_dim` is twice the value stored as
# `num_concepts`.
# `wvae_experiment_loop` only reads `trials`, `input_shape`, `latent_dim`,
# `filter_groups`, `encoder_units`, `decoder_units`, `drop_prob`,
# `max_pool_window`, `max_pool_stride`, `beta`, `results_dir` and
# `niching_results_dir`; the remaining entries are stored alongside the
# results but are not read by that loop.
vae_extended_experiment_config = dict(
    batch_size=64,
    max_epochs=100,
    trials=5,
    learning_rate=1e-3,
    input_shape=multiclass_task_bin_concepts_dep_0_complete_train[0].shape[1:],
    latent_dim=2*multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    beta=1,

    # One single-convolution group per line: (filters, kernel size, stride)
    filter_groups=[
       [(8, (7, 7), 1)],
       [(16, (5, 5), 1)],
       [(32, (3, 3), 1)],
       [(64, (3, 3), 1)],
    ],
    include_pool=True,
    encoder_units=[64, 64],
    decoder_units=[256, 512],

    latent_decoder_units=[64, 64],
    # Number of distinct values in entry 1 of the reference split, or 1 when
    # there are at most two of them
    num_outputs=(
        len(set(multiclass_task_bin_concepts_dep_0_complete_train[1])) if len(set(multiclass_task_bin_concepts_dep_0_complete_train[1])) > 2
        else 1
    ),
    predictor_max_epochs=100,

    drop_prob=0.0,
    max_pool_window=(2,2),
    max_pool_stride=2,
    patience=float("inf"),
    min_delta=1e-5,
    # Folder holding the saved models
    results_dir=os.path.join(
        RESULTS_DIR,
        f"beta_vae/balanced_multilabel_purity_latent_{2*multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]}_beta_{1}"
    ),
    # Folder receiving the niching results
    niching_results_dir=os.path.join(
        NICHING_RESULTS_DIR,
        f"beta_vae/purity_latent_{2*multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]}_beta_{1}"
    ),
    num_concepts=multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    verbosity=0,
    data_concepts=multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1],
    holdout_fraction=0.1,
    visualization_frequency=25,
    visualization_samples=6,
    delta_beta=0.05,
)

############################################################################
## Experiment run
############################################################################
# NOTE(review): `wvae_experiment_loop` unpacks every split as (x, c, y). The
# cw* splits used by the Ada-* runs are built with positions 1 and 2 of
# train_*/test_* swapped, whereas the splits below are passed unswapped --
# confirm that this is intended.
vae_extended_results = wvae_experiment_loop(
    vae_extended_experiment_config,
    wvae="beta_vae",
    # `latent_dim` is already doubled in the config above, so it must not be
    # doubled a second time when building the label.
    latent=f"_latent_{vae_extended_experiment_config['latent_dim']}",
    include_all_losses=True,
    vae_model=beta_vae.BetaVAE,
    datasets=[
        ("multiclass_task_bin_concepts_dep_0_complete", train_0, test_0),
        ("multiclass_task_bin_concepts_dep_1_complete", train_1, test_1),
        ("multiclass_task_bin_concepts_dep_2_complete", train_2, test_2),
        ("multiclass_task_bin_concepts_dep_3_complete", train_3, test_3),
        ("multiclass_task_bin_concepts_dep_4_complete", train_4, test_4),
    ],
    load_from_cache=FROM_CACHE,
)

# CCD Benchmarking

## Model Setup

In [None]:
def construct_ccd_encoder(
    input_shape,
    filter_groups,
    units,
    latent_dims,
    drop_prob=0.5,
    max_pool_window=(2,2),
    max_pool_stride=2,
    latent_act=None,  # Leaving sigmoid as used in original paper
):
    """Build a convolutional encoder with a single latent output layer.

    Every convolution is linear and is followed by batch normalisation and a
    ReLU; a max-pooling layer closes each filter group. The flattened
    features optionally go through dropout, then through one ReLU dense layer
    per entry of `units`, and finally through a dense layer of `latent_dims`
    units named "encoder_bypass_channel".

    Note: a later cell of this notebook defines a function with the same
    name, which replaces this one once that cell has run.

    Args:
        input_shape: shape of one input sample, passed to `tf.keras.Input`.
        filter_groups: iterable of groups; each group is an iterable of
            `(num_filters, kernel_size)` pairs.
        units: widths of the dense layers preceding the latent layer.
        latent_dims: number of units of the latent layer.
        drop_prob: dropout rate applied after flattening; a falsy value
            disables the dropout layer.
        max_pool_window: pool size of the max-pooling layers.
        max_pool_stride: stride of the max-pooling layers.
        latent_act: activation of the latent layer; the default `None` keeps
            it linear. NOTE(review): the inline comment on this parameter
            mentions sigmoid although the default is `None` -- confirm which
            one is intended.

    Returns:
        An uncompiled `tf.keras.Model` named "encoder".
    """
    inputs = tf.keras.Input(shape=input_shape)
    hidden = inputs

    # Convolutional trunk
    conv_index = 0
    for group in filter_groups:
        for (num_filters, kernel_size) in group:
            conv_layer = tf.keras.layers.Conv2D(
                filters=num_filters,
                kernel_size=kernel_size,
                padding="SAME",
                activation=None,
                name=f'encoder_conv_{conv_index}',
            )
            hidden = conv_layer(hidden)
            conv_index += 1
            hidden = tf.keras.layers.BatchNormalization()(hidden)
            hidden = tf.keras.activations.relu(hidden)
        # Pooling at the end of each group keeps the parameter count of the
        # model under control
        pool_layer = tf.keras.layers.MaxPooling2D(
            pool_size=max_pool_window,
            strides=max_pool_stride,
        )
        hidden = pool_layer(hidden)

    # Collapse the feature maps into a vector
    hidden = tf.keras.layers.Flatten()(hidden)

    # Optional dropout
    if drop_prob:
        hidden = tf.keras.layers.Dropout(drop_prob)(hidden)

    # Fully connected bottleneck
    for dense_index, width in enumerate(units):
        hidden = tf.keras.layers.Dense(
            width,
            activation='relu',
            name=f"encoder_dense_{dense_index}",
        )(hidden)

    # Latent code
    latent_code = tf.keras.layers.Dense(
        latent_dims,
        activation=latent_act,
        name="encoder_bypass_channel",
    )(hidden)

    return tf.keras.Model(
        inputs,
        latent_code,
        name="encoder",
    )

In [None]:
def construct_ccd_encoder(
    input_shape,
    filter_groups,
    units,
    latent_dims,
    drop_prob=0.5,
    max_pool_window=(2,2),
    max_pool_stride=2,
    latent_act=None,
):
    """Assemble the convolutional encoder used in the CCD experiments.

    The network is a stack of convolution groups followed by a dense
    bottleneck. Inside a group every convolution is followed by batch
    normalisation and a ReLU; each group is closed by a max-pool.

    Args:
        input_shape: Shape of a single input sample, handed to
            `tf.keras.Input`.
        filter_groups: Iterable of groups; each group is an iterable of
            `(num_filters, kernel_size)` pairs, one per convolution.
        units: Iterable with the width of each hidden dense layer.
        latent_dims: Width of the final dense layer (the latent code).
        drop_prob: Dropout rate applied right after flattening; a falsy
            value skips the dropout layer altogether.
        max_pool_window: Pool size of the max-pool closing each group.
        max_pool_stride: Stride of that max-pool.
        latent_act: Activation of the latent layer; `None` (the default)
            applies no activation.

    Returns:
        A `tf.keras.Model` named "encoder" mapping the input to the output
        of the "encoder_bypass_channel" layer.
    """
    layers = tf.keras.layers
    encoder_inputs = tf.keras.Input(shape=input_shape)
    hidden = encoder_inputs

    # Convolutional trunk; `num_convs` numbers the convolutions across all
    # groups so every conv layer gets a unique name.
    num_convs = 0
    for filter_group in filter_groups:
        for num_filters, kernel_size in filter_group:
            hidden = layers.Conv2D(
                filters=num_filters,
                kernel_size=kernel_size,
                padding="SAME",
                activation=None,
                name=f'encoder_conv_{num_convs}',
            )(hidden)
            num_convs += 1
            hidden = layers.BatchNormalization()(hidden)
            hidden = tf.keras.activations.relu(hidden)
        # One max-pool per group keeps the parameter count in check.
        hidden = layers.MaxPooling2D(
            pool_size=max_pool_window,
            strides=max_pool_stride,
        )(hidden)

    hidden = layers.Flatten()(hidden)

    # Optional dropout between the conv trunk and the dense bottleneck.
    if drop_prob:
        hidden = layers.Dropout(drop_prob)(hidden)

    # Fully connected bottleneck.
    for i, layer_width in enumerate(units):
        hidden = layers.Dense(
            layer_width,
            activation='relu',
            name=f"encoder_dense_{i}",
        )(hidden)

    # Latent code.
    latent_code = layers.Dense(
        latent_dims,
        activation=latent_act,
        name="encoder_bypass_channel",
    )(hidden)

    return tf.keras.Model(encoder_inputs, latent_code, name="encoder")

def construct_ccd_decoder(units, num_outputs):
    """Build the dense decoder mapping a latent code to task logits.

    Args:
        units: Iterable with the width of each hidden ReLU dense layer.
        num_outputs: Number of task outputs. Values of 2 or less yield a
            single output unit; larger values yield one unit per output.

    Returns:
        A `tf.keras.Sequential` made of a flatten layer, the hidden dense
        layers and a final dense layer without activation.
    """
    hidden_layers = []
    for i, layer_width in enumerate(units):
        hidden_layers.append(
            tf.keras.layers.Dense(
                layer_width,
                activation=tf.nn.relu,
                name=f"decoder_dense_{i+1}",
            )
        )

    model_layers = [tf.keras.layers.Flatten(), *hidden_layers]
    output_width = num_outputs if num_outputs > 2 else 1
    model_layers.append(
        tf.keras.layers.Dense(
            output_width,
            activation=None,
            name="decoder_model_output",
        )
    )
    return tf.keras.Sequential(model_layers)

## Experiment Loop

In [None]:
import concepts_xai.methods.OCACE.topicModel as CCD
import concepts_xai.evaluation.metrics.purity as purity

############################################################################
## Experiment loop
############################################################################

def ccd_experiment_loop(
    experiment_config,
    datasets,
    load_from_cache=False,
    oracle_matrix_cache=None,
):
    """Compute niche impurity scores (NIS) for serialized CCD topic models.

    Nothing is trained here: for every dataset and trial the encoder,
    decoder, topic vector and topic `g` model are loaded from
    `<results_dir>/models`, wrapped in a `CCD.TopicModel`, and the resulting
    concept scores are handed to `niching.niche_impurity_score`. The running
    results are written to `<niching_results_dir>/results_niching.joblib`
    after every trial.

    Args:
        experiment_config: Dict of experiment settings. Required keys:
            "num_concepts", "niching_results_dir", "results_dir", "trials",
            "input_shape", "num_outputs", "learning_rate", "latent_dims",
            "encoder_filter_groups", "encoder_units" and "decoder_units";
            every other key read below falls back to a default.
        datasets: Iterable of
            `(name, (x_train, y_train, c_train), (x_test, y_test, c_test))`.
        load_from_cache: When true and the results file already exists, it
            is loaded and returned without evaluating anything.
        oracle_matrix_cache: Unused; kept so existing callers keep working.

    Returns:
        Dict with the experiment config under "config" and one NIS value per
        (dataset, trial) under "niss" -- or whatever the cached results file
        contains when the cache is used.
    """
    utils.reseed(87)
    experiment_variables = dict(
        config=experiment_config,
        niss=[],
    )
    num_concepts = experiment_config["num_concepts"]
    res_dir = experiment_config["niching_results_dir"]
    results_path = os.path.join(res_dir, 'results_niching.joblib')
    if load_from_cache and os.path.exists(results_path):
        return load(results_path)

    # No usable cache: evaluate every (dataset, trial) pair.
    Path(
        os.path.join(
            experiment_config["results_dir"],
            "models",
        )
    ).mkdir(parents=True, exist_ok=True)
    for (ds_name, (x_train, y_train, c_train), (x_test, y_test, c_test)) in datasets:
        print("Training with concepts", num_concepts, "in dataset", ds_name)

        for trial in range(experiment_config["trials"]):
            print(f"\tTrial {trial + 1}/{experiment_config['trials']} for {num_concepts} concepts")
            print("x_train.shape =", x_train.shape)
            print("y_train.shape =", y_train.shape)
            print("c_train.shape =", c_train.shape)

            # The freshly built end-to-end model is only needed for its loss
            # function; the encoder/decoder it returns are discarded because
            # the trained ones are loaded from disk right below.
            # NOTE(review): the decoder here comes from `construct_decoder`,
            # not `construct_ccd_decoder` -- confirm that is intended.
            end_to_end_model, _, _ = construct_end_to_end_model(
                input_shape=experiment_config["input_shape"],
                num_outputs=experiment_config["num_outputs"],
                learning_rate=experiment_config["learning_rate"],
                encoder=construct_ccd_encoder(
                    input_shape=experiment_config["input_shape"],
                    filter_groups=experiment_config["encoder_filter_groups"],
                    latent_dims=experiment_config["latent_dims"],
                    units=experiment_config["encoder_units"],
                    drop_prob=experiment_config.get("drop_prob", 0.5),
                    max_pool_window=experiment_config.get("max_pool_window", (2, 2)),
                    max_pool_stride=experiment_config.get("max_pool_stride", (2, 2)),
                    latent_act=experiment_config.get("latent_act", None),
                ),
                decoder=construct_decoder(
                    units=experiment_config["decoder_units"],
                    num_outputs=experiment_config["num_outputs"],
                ),
            )

            # NOTE(review): the two messages below are stale -- no training
            # or serialization happens in this function, the models are only
            # loaded. Text left unchanged to keep the log output stable.
            print("\t\tModel pre-training completed")
            print("\tSerializing model")

            # All artifacts of one run share the same prefix and run tag.
            artifact_prefix = os.path.join(
                experiment_config["results_dir"],
                f"models/{ds_name}",
            )
            run_tag = f"num_concepts_{num_concepts}_trial_{trial}"
            encoder = load_model(f"{artifact_prefix}_encoder_{run_tag}")
            decoder = load_model(f"{artifact_prefix}_decoder_{run_tag}")
            topic_vector = np.load(f"{artifact_prefix}_topic_vector_{run_tag}.npy")

            acc_metric = (
                tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1)
                if experiment_config["num_outputs"] > 1 else
                tf.keras.metrics.BinaryAccuracy()
            )
            g_model = load_model(f"{artifact_prefix}_topic_g_model_{run_tag}")

            # Rebuild the topic model around the loaded pieces.
            topic_model = CCD.TopicModel(
                concepts_to_labels_model=decoder,
                n_channels=experiment_config["latent_dims"],
                n_concepts=num_concepts,
                threshold=experiment_config.get("threshold", 0.5),
                loss_fn=end_to_end_model.loss,
                top_k=experiment_config.get("top_k", 32),
                lambda1=experiment_config.get("lambda1", 0.1),
                lambda2=experiment_config.get("lambda2", 0.1),
                seed=experiment_config.get("seed", None),
                eps=experiment_config.get("eps", 1e-5),
                data_format=experiment_config.get(
                    "data_format",
                    "channels_last"
                ),
                allow_gradient_flow_to_c2l=experiment_config.get(
                    'allow_gradient_flow_to_c2l',
                    False,
                ),
                acc_metric=acc_metric,
                g_model=g_model,
                initial_topic_vector=topic_vector,
            )

            print(f"\t\tComputing purity score...")
            topic_model.compile(
                optimizer=tf.keras.optimizers.Adam(
                    experiment_config.get("learning_rate", 1e-3),
                )
            )
            c_train_pred = topic_model.concept_scores(encoder(x_train)).numpy()
            # The test latent code feeds both the concept scores and the
            # accuracy check below, so compute it once.
            test_latent = encoder(x_test)
            c_test_pred = topic_model.concept_scores(test_latent).numpy()

            print("\t\tComputing niching scores...")
            # Task labels may arrive one-hot (2-D) or, presumably, as a 1-D
            # vector of integer class ids. Taking argmax over a 1-D vector
            # collapses it to a single scalar, which `accuracy_score`
            # rejects, so only decode when a class axis is present.
            y_test_labels = (
                np.argmax(y_test, axis=-1) if np.ndim(y_test) > 1 else y_test
            )
            print("Topic model evaluation:", sklearn.metrics.accuracy_score(
                y_test_labels,
                np.argmax(topic_model(test_latent)[0], axis=-1),
            ))

            experiment_variables['niss'].append(
                niching.niche_impurity_score(
                    c_soft=c_test_pred,
                    c_true=c_test,
                    c_soft_train=c_train_pred,
                    c_true_train=c_train,
                )
            )
            print(f'\t\tNIS: {experiment_variables["niss"][-1]:.2f}')

            # Checkpoint after every trial so partial runs are not lost.
            os.makedirs(res_dir, exist_ok=True)
            dump(experiment_variables, results_path)

    return experiment_variables
            

def ccd_compute_k(y, batch_size):
    """Derive the CCD `top_k` value from the label distribution.

    Args:
        y: Array of labels; its distinct values are counted with
            `np.unique` and `y.shape[0]` is used as the sample count.
        batch_size: Batch size the value is scaled to.

    Returns:
        Half of the average number of samples a distinct label value
        contributes to a batch of `batch_size`, truncated to an `int`.
    """
    class_counts = np.unique(y, return_counts=True)[1]
    mean_class_fraction = np.mean(class_counts) / y.shape[0]
    expected_per_batch = mean_class_fraction * batch_size
    return int(expected_per_batch / 2)

## Experiments

In [None]:
reload(CBM)
reload(CCD)

############################################################################
## Experiment config
############################################################################

# Every dataset-derived setting is read from the balanced dep-0 training
# split, so `num_concepts` and `top_k` agree with the result directories
# below (which are named after the same split).
ccd_num_concepts = balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]
ccd_num_classes = len(set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1]))

ccd_experiment_config = dict(
    batch_size=64,
    max_epochs=100,
    topic_model_train_epochs=50,
    num_concepts=ccd_num_concepts,
    latent_dims=10,
    threshold=0.0,
    # NOTE(review): `top_k` is derived with batch_size=32 while the config
    # itself uses batch_size=64 -- confirm this is deliberate.
    top_k=ccd_compute_k(y=balanced_multiclass_task_bin_concepts_dep_0_complete_train[1], batch_size=32),
    lambda1=0.1,
    lambda2=0.1,
    seed=42,
    eps=1e-5,
    learning_rate=1e-3,
    encoder_filter_groups=[
        [(8, (7, 7))],
        [(16, (5, 5))],
        [(32, (3, 3))],
        [(64, (3, 3))],
    ],
    encoder_units=[64, 64],
    decoder_units=[64, 64],
    drop_prob=0.5,
    max_pool_window=(2, 2),
    # This is the key `ccd_experiment_loop` reads for the pooling stride.
    max_pool_stride=2,
    # A single output unit is used for tasks with two or fewer classes.
    num_outputs=(ccd_num_classes if ccd_num_classes > 2 else 1),
    patience=float("inf"),
    min_delta=1e-5,
    results_dir=os.path.join(
        RESULTS_DIR,
        f"ccd/balanced_multiclass_thresh_0_num_concepts_{ccd_num_concepts}"
    ),
    niching_results_dir=os.path.join(
        NICHING_RESULTS_DIR,
        f"ccd/balanced_multiclass_thresh_0_num_concepts_{ccd_num_concepts}"
    ),
    input_shape=balanced_multiclass_task_bin_concepts_dep_0_complete_train[0].shape[1:],
    holdout_fraction=0.1,
    trials=5,
    verbosity=0,
    early_stop_metric="val_loss",
    early_stop_mode="min",
    delta_beta=0.05,
)

############################################################################
## Experiment run
############################################################################

num_concepts = ccd_experiment_config["num_concepts"]
ccd_results = ccd_experiment_loop(
    experiment_config=ccd_experiment_config,
    datasets=[
        ("balanced_multiclass_task_bin_concepts_dep_0_complete", train_0, test_0),
        ("balanced_multiclass_task_bin_concepts_dep_1_complete", train_1, test_1),
        ("balanced_multiclass_task_bin_concepts_dep_2_complete", train_2, test_2),
        ("balanced_multiclass_task_bin_concepts_dep_3_complete", train_3, test_3),
        ("balanced_multiclass_task_bin_concepts_dep_4_complete", train_4, test_4),
    ],
    load_from_cache=FROM_CACHE,
)

In [None]:
reload(CBM)
reload(CCD)

############################################################################
## Experiment config
############################################################################

# Extended setup: twice as many concepts as the balanced dep-0 training
# split has ground-truth concept columns.
ccd_extended_num_concepts = 2 * balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]
ccd_extended_num_classes = len(set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1]))

ccd_extended_experiment_config = dict(
    batch_size=64,
    max_epochs=100,
    topic_model_train_epochs=50,
    num_concepts=ccd_extended_num_concepts,
    latent_dims=10,
    threshold=0.0,
    # NOTE(review): `top_k` is derived with batch_size=32 while the config
    # itself uses batch_size=64 -- confirm this is deliberate.
    top_k=ccd_compute_k(y=balanced_multiclass_task_bin_concepts_dep_0_complete_train[1], batch_size=32),
    lambda1=0.1,
    lambda2=0.1,
    seed=42,
    eps=1e-5,
    learning_rate=1e-3,
    encoder_filter_groups=[
        [(8, (7, 7))],
        [(16, (5, 5))],
        [(32, (3, 3))],
        [(64, (3, 3))],
    ],
    encoder_units=[64, 64],
    decoder_units=[64, 64],
    drop_prob=0.5,
    max_pool_window=(2, 2),
    # This is the key `ccd_experiment_loop` reads for the pooling stride.
    max_pool_stride=2,
    # A single output unit is used for tasks with two or fewer classes.
    num_outputs=(ccd_extended_num_classes if ccd_extended_num_classes > 2 else 1),
    patience=float("inf"),
    min_delta=1e-5,
    results_dir=os.path.join(
        RESULTS_DIR,
        f"ccd/balanced_multiclass_thresh_0_num_concepts_{ccd_extended_num_concepts}"
    ),
    niching_results_dir=os.path.join(
        NICHING_RESULTS_DIR,
        f"ccd/balanced_multiclass_thresh_0_num_concepts_{ccd_extended_num_concepts}"
    ),
    input_shape=balanced_multiclass_task_bin_concepts_dep_0_complete_train[0].shape[1:],
    holdout_fraction=0.1,
    trials=5,
    verbosity=0,
    early_stop_metric="val_loss",
    early_stop_mode="min",
    delta_beta=0.05,
)

############################################################################
## Experiment run
############################################################################

ccd_extended_results = ccd_experiment_loop(
    experiment_config=ccd_extended_experiment_config,
    datasets=[
        ("balanced_multiclass_task_bin_concepts_dep_0_complete", train_0, test_0),
        ("balanced_multiclass_task_bin_concepts_dep_1_complete", train_1, test_1),
        ("balanced_multiclass_task_bin_concepts_dep_2_complete", train_2, test_2),
        ("balanced_multiclass_task_bin_concepts_dep_3_complete", train_3, test_3),
        ("balanced_multiclass_task_bin_concepts_dep_4_complete", train_4, test_4),
    ],
    load_from_cache=FROM_CACHE,
)

# SENN Benchmarking

In [None]:
import concepts_xai.methods.SENN.base_senn as SENN
import concepts_xai.methods.SENN.aggregators as aggregators
reload(SENN)
reload(aggregators)


def construct_senn_coefficient_model(units, num_concepts, num_outputs):
    """Build the SENN model that emits one coefficient per (output, concept).

    Args:
        units: Iterable with the width of each hidden ReLU dense layer.
        num_concepts: Number of concepts (last axis of the result).
        num_outputs: Number of task outputs (first non-batch axis).

    Returns:
        A `tf.keras.Sequential` that flattens its input, applies the hidden
        dense layers, projects to `num_concepts * num_outputs` values without
        activation and reshapes them to `[num_outputs, num_concepts]`.
    """
    model_layers = [tf.keras.layers.Flatten()]
    for i, layer_width in enumerate(units):
        model_layers.append(
            tf.keras.layers.Dense(
                layer_width,
                activation=tf.nn.relu,
                name=f"coefficient_model_dense_{i+1}",
            )
        )

    model_layers.append(
        tf.keras.layers.Dense(
            num_concepts * num_outputs,
            activation=None,
            name="coefficient_model_output",
        )
    )
    model_layers.append(tf.keras.layers.Reshape([num_outputs, num_concepts]))
    return tf.keras.Sequential(model_layers)

def construct_senn_encoder(
    input_shape,
    filter_groups,
    units,
    latent_dims,
    drop_prob=0.5,
    max_pool_window=(2,2),
    max_pool_stride=2,
    include_norm=False,
    include_pool=False,
):
    """Build the SENN concept encoder and its VAE-style sibling.

    Both returned models share the same input and the same layers; they only
    differ in which heads they expose.

    Args:
        input_shape: Shape of a single input sample, handed to
            `tf.keras.Input`.
        filter_groups: Iterable of groups; each group is an iterable of
            `(num_filters, kernel_size)` or `(num_filters, kernel_size,
            stride)` entries. A missing stride defaults to 1.
        units: Iterable with the width of each hidden dense layer.
        latent_dims: Width of the "means" and "log_var" heads.
        drop_prob: Dropout rate applied right after flattening; a falsy
            value skips the dropout layer altogether.
        max_pool_window: Pool size of the optional max-pool per group.
        max_pool_stride: Stride of that max-pool.
        include_norm: When true, each convolution has no built-in activation
            and is followed by batch normalisation and a ReLU; otherwise the
            convolution applies ReLU itself.
        include_pool: When true, every group is closed by a max-pool.

    Returns:
        Tuple `(senn_encoder, vae_encoder)`: the first model outputs the
        "means" head only, the second outputs `[means, log_var]`.
    """
    layers = tf.keras.layers
    encoder_inputs = tf.keras.Input(shape=input_shape)
    hidden = encoder_inputs

    # Convolutional trunk; `num_convs` numbers the convolutions across all
    # groups so every conv layer gets a unique name.
    num_convs = 0
    for filter_group in filter_groups:
        for filter_args in filter_group:
            if len(filter_args) == 2:
                # No stride given: fall back to a stride of 1.
                num_filters, kernel_size = filter_args
                stride = 1
            else:
                num_filters, kernel_size, stride = filter_args
            hidden = layers.Conv2D(
                filters=num_filters,
                kernel_size=kernel_size,
                strides=stride,
                padding="SAME",
                activation=None if include_norm else "relu",
                name=f'encoder_conv_{num_convs}',
            )(hidden)
            num_convs += 1
            if include_norm:
                hidden = layers.BatchNormalization()(hidden)
                hidden = tf.keras.activations.relu(hidden)
        if include_pool:
            # One max-pool per group keeps the parameter count in check.
            hidden = layers.MaxPooling2D(
                pool_size=max_pool_window,
                strides=max_pool_stride,
            )(hidden)

    hidden = layers.Flatten()(hidden)

    # Optional dropout between the conv trunk and the dense bottleneck.
    if drop_prob:
        hidden = layers.Dropout(drop_prob)(hidden)

    # Fully connected bottleneck.
    for i, layer_width in enumerate(units):
        hidden = layers.Dense(
            layer_width,
            activation='relu',
            name=f"encoder_dense_{i}",
        )(hidden)

    # Two heads without activation on top of the shared trunk.
    mean = layers.Dense(latent_dims, activation=None, name="means")(hidden)
    log_var = layers.Dense(latent_dims, activation=None, name="log_var")(hidden)

    senn_encoder = tf.keras.Model(encoder_inputs, mean, name="senn_encoder")
    vae_encoder = tf.keras.Model(
        encoder_inputs,
        [mean, log_var],
        name="vae_encoder",
    )
    return senn_encoder, vae_encoder

def construct_senn_model(
    concept_encoder,
    concept_decoder,
    coefficient_model,
    num_outputs,
    regularization_strength=0.1,
    learning_rate=1e-3,
    sparsity_strength=2e-5,
):
    """Wire the given sub-models into a compiled `SENN.SelfExplainingNN`.

    Args:
        concept_encoder: Model passed to the SENN as `encoder_model`.
        concept_decoder: Model applied to the second argument of the
            reconstruction loss before the Bernoulli loss is evaluated.
        coefficient_model: Model passed to the SENN as `coefficient_model`.
        num_outputs: Number of task outputs. More than 2 selects the
            multiclass aggregator, sparse categorical cross-entropy and
            top-1 sparse categorical accuracy; otherwise the scalar
            aggregator, binary cross-entropy and binary accuracy are used.
            Both losses are built with `from_logits=True`.
        regularization_strength: Forwarded to the SENN constructor.
        learning_rate: Learning rate of the Adam optimizer.
        sparsity_strength: Forwarded to the SENN constructor.

    Returns:
        The SENN model, compiled with an Adam optimizer.
    """
    def reconstruction_loss_fn(y_true, y_pred):
        # Decode `y_pred` first, then score the decoded output against
        # `y_true`.
        bernoulli_loss = vae_losses.bernoulli_fn_wrapper()
        reconstruction = concept_decoder(y_pred)
        return bernoulli_loss(y_true, reconstruction)

    if num_outputs > 2:
        aggregator_fn = aggregators.multiclass_additive_aggregator
        task_loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        accuracy_metric = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1)
    else:
        aggregator_fn = aggregators.scalar_additive_aggregator
        task_loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)
        accuracy_metric = tf.keras.metrics.BinaryAccuracy()

    senn_model = SENN.SelfExplainingNN(
        encoder_model=concept_encoder,
        coefficient_model=coefficient_model,
        aggregator_fn=aggregator_fn,
        task_loss_fn=task_loss_fn,
        reconstruction_loss_fn=reconstruction_loss_fn,
        regularization_strength=regularization_strength,
        sparsity_strength=sparsity_strength,
        name="SENN",
        metrics=[accuracy_metric],
    )
    optimizer = tf.keras.optimizers.Adam(learning_rate)
    senn_model.compile(optimizer=optimizer)
    return senn_model

In [None]:
import concepts_xai.evaluation.metrics.purity as purity

############################################################################
## Experiment loop
############################################################################

def construct_trivial_auc_mat(num_concepts):
    """Return a float32 square matrix with 1.0 on the diagonal, 0.5 elsewhere.

    Args:
        num_concepts: Number of rows and columns of the matrix.

    Returns:
        A `(num_concepts, num_concepts)` `np.float32` array.
    """
    auc_mat = np.full((num_concepts, num_concepts), 0.5, dtype=np.float32)
    np.fill_diagonal(auc_mat, 1.0)
    return auc_mat
    
def get_argmax_concept_explanations(preds, class_theta_scores):
    """Pick, per sample, the scores belonging to the top-scoring class.

    Args:
        preds: Array whose last axis is reduced with `np.argmax` to get the
            selected class index.
        class_theta_scores: Array indexed along axis 1 by that class index.
            Presumably shaped (samples, classes, concepts) -- confirm
            against the SENN model output.

    Returns:
        `class_theta_scores` with axis 1 reduced to the selected entry and
        then removed.
    """
    winning_class = np.argmax(preds, axis=-1)
    # Append two singleton axes so the index lines up with axis 1 and
    # broadcasts over the trailing axis of `class_theta_scores`.
    gather_index = np.reshape(winning_class, np.shape(winning_class) + (1, 1))
    selected = np.take_along_axis(class_theta_scores, gather_index, axis=1)
    return selected.squeeze(axis=1)

def senn_experiment_loop(
    experiment_config,
    datasets,
    load_from_cache=False,
):
    """Compute niche impurity scores (NIS) for serialized SENN models.

    Nothing is trained here: for every dataset and trial the concept
    encoder, concept decoder and coefficient model are loaded from
    `<results_dir>/models`, assembled with `construct_senn_model`, and the
    coefficients of each sample's predicted class are scored with
    `niching.niche_impurity_score`. The running results are written to
    `<niching_results_dir>/results_niching.joblib` after every trial.

    Args:
        experiment_config: Dict of experiment settings. Required keys:
            "num_concepts", "niching_results_dir", "results_dir", "trials"
            and "num_outputs"; "regularization_strength", "learning_rate"
            and "sparsity_strength" fall back to defaults.
        datasets: Iterable of
            `(name, (x_train, y_train, c_train), (x_test, y_test, c_test))`.
            The task labels are not used by this function.
        load_from_cache: When true and the results file already exists, it
            is loaded and returned without evaluating anything.

    Returns:
        Dict with the experiment config under "config" and one NIS value per
        (dataset, trial) under "niss" -- or whatever the cached results file
        contains when the cache is used.
    """
    utils.reseed(87)
    experiment_variables = dict(
        config=experiment_config,
        niss=[],
    )
    num_concepts = experiment_config["num_concepts"]
    res_dir = experiment_config['niching_results_dir']
    results_path = os.path.join(res_dir, 'results_niching.joblib')
    if load_from_cache and os.path.exists(results_path):
        return load(results_path)

    # No usable cache: evaluate every (dataset, trial) pair.
    Path(
        os.path.join(
            experiment_config["results_dir"],
            "models",
        )
    ).mkdir(parents=True, exist_ok=True)
    for (ds_name, (x_train, y_train, c_train), (x_test, y_test, c_test)) in datasets:
        print("Training with concepts", num_concepts, "in dataset", ds_name)

        for trial in range(experiment_config["trials"]):
            print(f"\tTrial {trial + 1}/{experiment_config['trials']} for {num_concepts} concepts")
            concept_encoder = tf.keras.models.load_model(
                os.path.join(
                    experiment_config["results_dir"],
                    f"models/concept_encoder_{ds_name}_trial_{trial}"
                )
            )
            concept_decoder = tf.keras.models.load_model(
                os.path.join(
                    experiment_config["results_dir"],
                    f"models/concept_decoder_{ds_name}_trial_{trial}"
                )
            )
            coefficient_model = tf.keras.models.load_model(
                os.path.join(
                    experiment_config["results_dir"],
                    f"models/coefficient_model_{ds_name}_trial_{trial}"
                )
            )
            senn_model = construct_senn_model(
                concept_encoder=concept_encoder,
                concept_decoder=concept_decoder,
                coefficient_model=coefficient_model,
                num_outputs=experiment_config["num_outputs"],
                regularization_strength=experiment_config.get("regularization_strength", 0.1),
                learning_rate=experiment_config.get("learning_rate", 1e-3),
                sparsity_strength=experiment_config.get("sparsity_strength", 2e-5),
            )

            # The model returns its predictions plus a pair whose second
            # element holds the per-class scores; only the scores of each
            # sample's predicted class are kept.
            x_train_preds, (_, x_train_theta_class_scores) = senn_model(x_train)
            c_train_pred = get_argmax_concept_explanations(
                x_train_preds.numpy(),
                x_train_theta_class_scores.numpy(),
            )

            x_test_preds, (_, x_test_theta_class_scores) = senn_model(x_test)
            c_test_pred = get_argmax_concept_explanations(
                x_test_preds.numpy(),
                x_test_theta_class_scores.numpy(),
            )

            print("\t\tComputing niching scores...")
            experiment_variables['niss'].append(
                niching.niche_impurity_score(
                    c_soft=c_test_pred,
                    c_true=c_test,
                    c_soft_train=c_train_pred,
                    c_true_train=c_train,
                )
            )
            print(f'\t\tNIS: {experiment_variables["niss"][-1]:.2f}')

            # Checkpoint after every trial so partial runs are not lost.
            os.makedirs(res_dir, exist_ok=True)
            dump(experiment_variables, results_path)

    return experiment_variables

In [None]:
reload(SENN)

############################################################################
## Experiment config
############################################################################

# Dataset-derived settings, read once from the balanced dep-0 training split.
senn_num_concepts = balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]
senn_num_classes = len(set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1]))

senn_experiment_config = dict(
    max_epochs=50,
    pretrain_autoencoder_epochs=50,
    predictor_max_epochs=100,
    batch_size=64,
    trials=5,
    learning_rate=1e-3,

    num_concepts=senn_num_concepts,
    input_shape=balanced_multiclass_task_bin_concepts_dep_0_complete_train[0].shape[1:],
    # A single output unit is used for tasks with two or fewer classes.
    num_outputs=(senn_num_classes if senn_num_classes > 2 else 1),

    filter_groups=[
        [(8, (7, 7), 1)],
        [(16, (5, 5), 1)],
        [(32, (3, 3), 1)],
        [(64, (3, 3), 1)],
    ],
    encoder_units=[64, 64],
    include_pool=True,
    decoder_units=[256, 512],
    coefficient_model_units=[64, 64],
    latent_decoder_units=[64, 64],
    drop_prob=0,
    max_pool_window=(2, 2),
    max_pool_stride=2,

    regularization_strength=0.1,
    sparsity_strength=2e-5,

    # One cardinality entry (2) per ground-truth concept column.
    concept_cardinality=[2 for _ in range(senn_num_concepts)],
    results_dir=os.path.join(RESULTS_DIR, "senn/dependency_multiclass"),
    niching_results_dir=os.path.join(NICHING_RESULTS_DIR, "senn/dependency_multiclass"),
    verbosity=0,

    holdout_fraction=0.1,
    patience=float("inf"),
    early_stop_metric="val_loss",
    early_stop_mode="min",
    min_delta=1e-5,
    delta_beta=0.05,
)

# Generate the experiment directory if it does not exist already
Path(senn_experiment_config["results_dir"]).mkdir(parents=True, exist_ok=True)
senn_figure_dir = os.path.join(senn_experiment_config["results_dir"], "figures")
Path(senn_figure_dir).mkdir(parents=True, exist_ok=True)

############################################################################
## Experiment run
############################################################################

senn_results = senn_experiment_loop(
    senn_experiment_config,
    # Honour the notebook-wide caching switch, as the CCD cells do.
    load_from_cache=FROM_CACHE,
    datasets=[
        ("balanced_multiclass_task_bin_concepts_dep_0_complete", train_0, test_0),
        ("balanced_multiclass_task_bin_concepts_dep_1_complete", train_1, test_1),
        ("balanced_multiclass_task_bin_concepts_dep_2_complete", train_2, test_2),
        ("balanced_multiclass_task_bin_concepts_dep_3_complete", train_3, test_3),
        ("balanced_multiclass_task_bin_concepts_dep_4_complete", train_4, test_4),
    ],
)

In [None]:
reload(SENN)

############################################################################
## Experiment config
############################################################################

# Extended setup: twice as many concepts as the balanced dep-0 training
# split has ground-truth concept columns.
senn_extended_num_concepts = 2 * balanced_multiclass_task_bin_concepts_dep_0_complete_train[2].shape[-1]
senn_extended_num_classes = len(set(balanced_multiclass_task_bin_concepts_dep_0_complete_train[1]))

senn_extended_experiment_config = dict(
    max_epochs=50,
    pretrain_autoencoder_epochs=50,
    predictor_max_epochs=100,
    batch_size=64,
    trials=5,
    learning_rate=1e-3,

    num_concepts=senn_extended_num_concepts,
    input_shape=balanced_multiclass_task_bin_concepts_dep_0_complete_train[0].shape[1:],
    # A single output unit is used for tasks with two or fewer classes.
    num_outputs=(senn_extended_num_classes if senn_extended_num_classes > 2 else 1),

    filter_groups=[
        [(8, (7, 7), 1)],
        [(16, (5, 5), 1)],
        [(32, (3, 3), 1)],
        [(64, (3, 3), 1)],
    ],
    encoder_units=[64, 64],
    include_pool=True,
    decoder_units=[256, 512],
    coefficient_model_units=[64, 64],
    latent_decoder_units=[64, 64],
    drop_prob=0,
    max_pool_window=(2, 2),
    max_pool_stride=2,

    regularization_strength=0.1,
    sparsity_strength=2e-5,

    results_dir=os.path.join(RESULTS_DIR, "senn/dependency_multiclass_extended"),
    niching_results_dir=os.path.join(NICHING_RESULTS_DIR, "senn/dependency_multiclass_extended"),
    verbosity=0,

    holdout_fraction=0.1,
    patience=float("inf"),
    early_stop_metric="val_loss",
    early_stop_mode="min",
    min_delta=1e-5,
    delta_beta=0.05,
)

# Generate the experiment directory if it does not exist already
Path(senn_extended_experiment_config["results_dir"]).mkdir(parents=True, exist_ok=True)
senn_extended_figure_dir = os.path.join(senn_extended_experiment_config["results_dir"], "figures")
Path(senn_extended_figure_dir).mkdir(parents=True, exist_ok=True)

############################################################################
## Experiment run
############################################################################

senn_extended_results = senn_experiment_loop(
    senn_extended_experiment_config,
    # Honour the notebook-wide caching switch, as the CCD cells do.
    load_from_cache=FROM_CACHE,
    datasets=[
        ("balanced_multiclass_task_bin_concepts_dep_0_complete", train_0, test_0),
        ("balanced_multiclass_task_bin_concepts_dep_1_complete", train_1, test_1),
        ("balanced_multiclass_task_bin_concepts_dep_2_complete", train_2, test_2),
        ("balanced_multiclass_task_bin_concepts_dep_3_complete", train_3, test_3),
        ("balanced_multiclass_task_bin_concepts_dep_4_complete", train_4, test_4),
    ],
)