In [None]:
%load_ext autoreload
%matplotlib inline

# Purity Correlation Experiment

## Setup

In [None]:
import matplotlib
import concepts_xai
import numpy as np
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
import random
import tensorflow as tf
import yaml
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import rc
from matplotlib import cm
import seaborn as sns
from importlib import reload
from pathlib import Path
import sklearn
import scipy
import utils
import model_utils

In [None]:
################################################################################
## Set seeds up for reproducibility
################################################################################

# NOTE(review): `utils.reseed` is a project-local helper; presumably it seeds
# the Python, NumPy and TensorFlow RNGs -- confirm in utils.py. Every
# experiment loop defined further down calls `utils.reseed(87)` again when it
# starts, so this call only affects code that runs before those loops.
utils.reseed(87)

In [None]:
################################################################################
## Global Variables Defining Experiment Flow
################################################################################

# Set to "$" to have Matplotlib render text through LaTeX (see `bold_text`).
LATEX_SYMBOL = ""
NUM_TRIALS = 5
LOAD_FROM_CACHE = True
RESULTS_DIR = "results/toy_tabular"
Path(RESULTS_DIR).mkdir(parents=True, exist_ok=True)
rc('text', usetex=(LATEX_SYMBOL == "$"))
# Matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and
# later releases dropped the old aliases. Prefer the new name when it is
# available and otherwise fall back to the legacy one so older installs keep
# working (and an unknown style still surfaces Matplotlib's own error).
_WHITEGRID_STYLES = ("seaborn-v0_8-whitegrid", "seaborn-whitegrid")
plt.style.use(next(
    (style for style in _WHITEGRID_STYLES if style in plt.style.available),
    "seaborn-whitegrid",
))


def bold_text(x):
    """Return ``x`` wrapped in a LaTeX bold command when LaTeX is enabled.

    LaTeX is considered enabled when the module-level ``LATEX_SYMBOL`` equals
    ``"$"``; in every other case the string is returned untouched.
    """
    latex_enabled = LATEX_SYMBOL == "$"
    return r"$\textbf{" + x + "}$" if latex_enabled else x

## Dataset Construction

In [None]:
############################################################################
## Generate Data
############################################################################
def produce_data(samples, cov=0.0, num_concepts=3):
    """Sample a toy tabular dataset driven by three correlated latents.

    Three latent variables are drawn from a zero-mean, unit-variance
    multivariate normal whose off-diagonal covariance entries all equal
    ``cov``. Features are fixed non-linear functions of those latents, the
    concepts flag which latents are positive, and the label is 1 when at
    least two of the three latents are positive.

    Args:
        samples: number of rows to draw.
        cov: shared pairwise covariance between the three latents.
        num_concepts: how many of the three positivity concepts to return
            (the first ``num_concepts`` of x, y, z). The label always uses
            all three latents regardless of this value.

    Returns:
        A tuple ``(features, labels, concepts)`` where ``features`` has shape
        ``(samples, 7, 1)``, ``labels`` is an int32 array of shape
        ``(samples, 1)`` and ``concepts`` is a float32 array of shape
        ``(samples, num_concepts)``.
    """
    # Sample the x, y, and z latent variables (drawn from NumPy's global RNG
    # so that seeding done by the caller keeps runs reproducible).
    latents = np.random.multivariate_normal(
        mean=[0, 0, 0],
        cov=[
            [1, cov, cov],
            [cov, 1, cov],
            [cov, cov, 1],
        ],
        size=(samples,),
    )
    x_vars = latents[:, :1]
    y_vars = latents[:, 1:2]
    z_vars = latents[:, 2:]

    # The features are just non-linear functions applied to each variable;
    # every entry is (samples, 1) so stacking on axis 1 gives (samples, 7, 1).
    features = np.stack(
        [
            np.sin(x_vars) + x_vars,
            np.cos(x_vars) + x_vars,
            np.sin(y_vars) + y_vars,
            np.cos(y_vars) + y_vars,
            np.sin(z_vars) + z_vars,
            np.cos(z_vars) + z_vars,
            x_vars**2 + y_vars**2 + z_vars**2,
        ],
        axis=1,
    )

    # The concepts just check if the variables are positive
    x_pos = (x_vars > 0).astype(np.int32)
    y_pos = (y_vars > 0).astype(np.int32)
    z_pos = (z_vars > 0).astype(np.int32)
    # Concatenate the (samples, 1) columns instead of stack + squeeze: a bare
    # squeeze would also drop the concept axis when num_concepts == 1 (and the
    # sample axis when samples == 1), breaking `concepts[:, :k]` in callers.
    concepts = np.concatenate(
        [x_pos, y_pos, z_pos][:num_concepts],
        axis=1,
    ).astype(np.float32)

    # The labels are generated by checking if at least two of the
    # latent concepts are greater than zero
    labels = ((x_pos + y_pos + z_pos) > 1).astype(np.int32)

    return features, labels, concepts

## Model Construction

In [None]:
# Construct the encoder model
def construct_encoder(
    input_shape,
    units,
    num_concepts,
    end_activation="sigmoid",
    latent_dims=0,
    output_logits=False,
):
    """Build the Keras model mapping raw inputs to concept activations.

    Args:
        input_shape: shape handed to ``tf.keras.Input`` (without batch axis).
        units: iterable with the width of each hidden ReLU dense layer.
        num_concepts: width of the concept output layer.
        end_activation: activation of the optional bypass layer.
        latent_dims: width of the extra "bypass" output; when falsy no bypass
            layer is created.
        output_logits: when true the concept layer has no activation,
            otherwise it uses a sigmoid.

    Returns:
        A ``tf.keras.Model`` named ``"encoder"`` whose output is the concept
        tensor, or the list ``[concepts, bypass]`` when ``latent_dims`` is set.
    """
    encoder_inputs = tf.keras.Input(shape=input_shape)

    # Hidden stack: one ReLU dense layer per requested width.
    hidden = encoder_inputs
    for layer_idx, layer_width in enumerate(units):
        hidden = tf.keras.layers.Dense(
            layer_width,
            activation='relu',
            name=f"encoder_dense_{layer_idx}",
        )(hidden)
    hidden = tf.keras.layers.Flatten()(hidden)

    # Optional bypass channel, created before the concept head.
    bypass = None
    if latent_dims:
        bypass = tf.keras.layers.Dense(
            latent_dims,
            activation=end_activation,
            name="encoder_bypass_channel",
        )(hidden)

    # Concept head: one unit per concept.
    concept_activation = None if output_logits else "sigmoid"
    concept_outputs = tf.keras.layers.Dense(
        num_concepts,
        activation=concept_activation,
        name="encoder_concept_outputs",
    )(hidden)

    if bypass is None:
        model_outputs = concept_outputs
    else:
        model_outputs = [concept_outputs, bypass]
    return tf.keras.Model(encoder_inputs, model_outputs, name="encoder")


In [None]:
############################################################################
## Build concepts-to-labels model
############################################################################

def construct_decoder(units, num_outputs=1,):
    """Build the sequential model mapping bottleneck codes to task logits.

    Args:
        units: iterable with the width of each hidden ReLU dense layer.
        num_outputs: width of the final, activation-free output layer.

    Returns:
        A ``tf.keras.Sequential`` made of a ``Flatten`` layer, the hidden
        dense layers (named ``decoder_dense_1``, ``decoder_dense_2``, ...) and
        the linear ``decoder_model_output`` layer.
    """
    stack = [tf.keras.layers.Flatten()]
    for position, width in enumerate(units, start=1):
        stack.append(
            tf.keras.layers.Dense(
                width,
                activation=tf.nn.relu,
                name=f"decoder_dense_{position}",
            )
        )
    stack.append(
        tf.keras.layers.Dense(
            num_outputs,
            activation=None,
            name="decoder_model_output",
        )
    )
    return tf.keras.Sequential(stack)


# CBM Benchmark

In [None]:
import concepts_xai.methods.CBM.CBModel as CBM
reload(CBM)

############################################################################
## Build CBM
############################################################################

def construct_cbm(
    encoder,
    decoder,
    latent_dims=0,
    alpha=0.1,
    learning_rate=1e-3,
    encoder_output_logits=False,
):
    """Wrap an encoder/decoder pair into a compiled joint CBM.

    Args:
        encoder: inputs-to-concepts model.
        decoder: concepts-to-label model (expected to output logits, since
            the task loss is built with ``from_logits=True``).
        latent_dims: when truthy ``CBM.BypassJointCBM`` is used, otherwise
            ``CBM.JointConceptBottleneckModel``.
        alpha: forwarded unchanged to the CBM constructor.
            NOTE(review): presumably the concept-loss weight -- confirm
            against concepts_xai.
        learning_rate: learning rate of the Adam optimizer.
        encoder_output_logits: forwarded as ``pass_concept_logits``.

    Returns:
        The compiled CBM model.
    """
    if latent_dims:
        cbm_class = CBM.BypassJointCBM
    else:
        cbm_class = CBM.JointConceptBottleneckModel

    joint_model = cbm_class(
        encoder=encoder,
        decoder=decoder,
        task_loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
        name="joint_cbm",
        metrics=[tf.keras.metrics.BinaryAccuracy()],
        alpha=alpha,
        pass_concept_logits=encoder_output_logits,
    )

    # Only the optimizer is supplied at compile time; the task loss and the
    # metrics were handed to the constructor above.
    optimizer = tf.keras.optimizers.Adam(learning_rate)
    joint_model.compile(optimizer=optimizer)
    return joint_model

# Construct the complete model
def construct_end_to_end_model(
    encoder,
    decoder,
    input_shape,
    learning_rate=1e-3,
):
    """Chain ``encoder`` and ``decoder`` into one compiled Keras model.

    When the encoder returns a list of tensors they are concatenated along
    the last axis before being fed to the decoder.

    Args:
        encoder: model producing the bottleneck from the inputs.
        decoder: model producing task logits from the bottleneck.
        input_shape: shape handed to ``tf.keras.Input`` (without batch axis).
        learning_rate: learning rate of the Adam optimizer.

    Returns:
        The tuple ``(model, encoder, decoder)`` where ``model`` is compiled
        with binary cross-entropy on logits and the binary accuracy metric.
    """
    inputs = tf.keras.Input(shape=input_shape)

    bottleneck = encoder(inputs)
    if isinstance(bottleneck, list):
        # Several encoder outputs: merge them into a single code vector.
        bottleneck = tf.concat(bottleneck, axis=-1)
    logits = decoder(bottleneck)

    model = tf.keras.Model(inputs, logits, name="complete_model")
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate),
        loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
        metrics=["binary_accuracy"],
    )
    return model, encoder, decoder


## Purity experiment with concept logits

In [None]:
import concepts_xai.evaluation.metrics.oracle as oracle

############################################################################
## Experiment loop
############################################################################

def construct_trivial_auc_mat(num_concepts):
    """Return a float32 square matrix with 1.0 on the diagonal, 0.5 elsewhere.

    The matrix has shape ``(num_concepts, num_concepts)``.
    """
    auc_mat = np.full((num_concepts, num_concepts), 0.5, dtype=np.float32)
    auc_mat[np.diag_indices(num_concepts)] = 1.0
    return auc_mat
    
def cbm_experiment_loop(experiment_config, load_from_cache=False):
    """Train and evaluate joint CBMs over a sweep of latent covariances.

    For every covariance in ``experiment_config["covariances"]`` the loop
    runs ``experiment_config["trials"]`` trials. Each trial samples fresh
    train/test data with ``produce_data``, optionally pre-trains an
    end-to-end model, trains a CBM (optional warm-up followed by training
    with early stopping on ``val_concept_accuracy``), saves the encoder and
    decoder under ``<results_dir>/models`` and computes task/concept metrics,
    the concept/activation correlation matrix and the oracle impurity
    scores. After each covariance the aggregated results are serialized with
    ``utils.serialize_results``.

    Config keys read: ``results_dir``, ``covariances``, ``trials``,
    ``train_samples``, ``test_samples``, ``num_concepts``, ``input_shape``,
    ``encoder_units``, ``decoder_units``, ``num_outputs``, ``alpha``,
    ``learning_rate``, ``min_delta``, ``patience``, ``max_epochs``,
    ``batch_size``, ``holdout_fraction`` and, optionally, ``data_concepts``
    (defaults to ``num_concepts``), ``latent_dims`` (default 0),
    ``encoder_output_logits`` (default False), ``pre_train_epochs``,
    ``warmup_epochs`` and ``verbosity`` (default 0).

    Args:
        experiment_config: dict with the keys above. It is mutated in place:
            ``data_concepts`` is filled in when missing.
        load_from_cache: when true, previously serialized ``*_means.npz`` /
            ``*_stds.npz`` files in ``results_dir`` are reused. A complete
            cache is returned directly; a partial one is resumed from the
            first covariance that has no cached result.

    Returns:
        Dict mapping each tracked variable name to a list with one
        ``(mean, std)`` pair per covariance.

    Raises:
        ValueError: if cached variables disagree on how many covariances
            have already been processed.
    """
    experiment_variables = dict(
        task_accuracies=[],
        task_aucs=[],
        concept_accuracies=[],
        purity_scores=[],
        non_oracle_purity_scores=[],
        purity_matrices=[],
        oracle_matrices=[],
        correlation_matrices=[],
    )
    utils.reseed(87)
    experiment_config["data_concepts"] = experiment_config.get(
        "data_concepts",
        experiment_config["num_concepts"],
    )
    results_dir = experiment_config["results_dir"]
    # Read the optional keys once so every use below shares the same default.
    latent_dims = experiment_config.get("latent_dims", 0)
    encoder_output_logits = experiment_config.get("encoder_output_logits", False)

    def _load_npz_arrays(path):
        # Open the archive a single time and close it once all entries have
        # been materialised (instead of re-opening it for every key).
        with np.load(path) as archive:
            return [archive[key] for key in archive]

    def _summarise_scalars(label, values):
        # Print and return the mean/std of a list of per-trial scalars.
        mean, std = np.mean(values), np.std(values)
        print(f"\t{label}: {mean:.4f} ± {std:.4f}")
        return mean, std

    def _summarise_matrices(title, matrices):
        # Print and return the element-wise mean/std of per-trial matrices.
        stacked = np.stack(matrices, axis=0)
        mean, std = np.mean(stacked, axis=0), np.std(stacked, axis=0)
        print(f"\t{title}:")
        for i in range(mean.shape[0]):
            line = "\t\t"
            for j in range(mean.shape[1]):
                line += f'{mean[i, j]:.4f} ± {std[i, j]:.4f}    '
            print(line)
        return mean, std

    start_ind = 0
    if load_from_cache:
        cached = True
        complete_cache = True
        for var_name in experiment_variables:
            file_name = os.path.join(results_dir, var_name)
            if not (
                os.path.exists(f'{file_name}_means.npz') and
                os.path.exists(f'{file_name}_stds.npz')
            ):
                cached = False
                # Nothing will be loaded from the cache, so the sweep has to
                # restart from the first covariance even if an earlier
                # variable looked like a partial run.
                start_ind = 0
                break
            with np.load(f'{file_name}_means.npz') as loaded_means:
                num_cached = len(loaded_means)
            if num_cached != len(experiment_config["covariances"]):
                # Then we have a partial run here so let's just run the rest
                if start_ind and start_ind != num_cached:
                    raise ValueError(
                        f'Found inconsistent start indices in cached '
                        f'data {start_ind} vs {num_cached}.'
                    )
                start_ind = num_cached
                complete_cache = False
        if cached:
            # Then we have found all of the arrays of interest, so let's
            # load them and use them
            print("Experiment cache was hit")
            print("\tLoading variables from", results_dir)
            for var_name in experiment_variables:
                file_name = os.path.join(results_dir, var_name)
                experiment_variables[var_name] = list(zip(
                    _load_npz_arrays(f'{file_name}_means.npz'),
                    _load_npz_arrays(f'{file_name}_stds.npz'),
                ))
            if complete_cache:
                # Then we are good to go
                if not os.path.exists(os.path.join(results_dir, "config.yaml")):
                    # then serialize the config as this is a different version run
                    utils.serialize_experiment_config(
                        experiment_config,
                        results_dir,
                    )
                return experiment_variables

    # Let's save our config here either way
    utils.serialize_experiment_config(
        experiment_config,
        results_dir,
    )

    Path(os.path.join(results_dir, "models")).mkdir(parents=True, exist_ok=True)
    verbosity = experiment_config.get("verbosity", 0)
    for cov in experiment_config["covariances"][start_ind:]:
        print("Training with covariance:", cov)
        task_accs = []
        concept_accs = []
        aucs = []
        purity_mats = []
        oracle_mats = []
        purities = []
        non_oracle_purities = []
        corr_mats = []

        for trial in range(experiment_config["trials"]):
            print(f"\tTrial {trial + 1}/{experiment_config['trials']} with covariance {cov}")
            # First construct the dataset
            (x_train, y_train, y_train_concepts) = produce_data(
                experiment_config["train_samples"],
                cov=cov,
                num_concepts=experiment_config["data_concepts"],
            )
            (x_test, y_test, y_test_concepts) = produce_data(
                experiment_config["test_samples"],
                cov=cov,
                num_concepts=experiment_config["data_concepts"],
            )

            # Then proceed to build an end-to-end model in case we want to
            # do some task-specific pretraining
            end_to_end_model, encoder, decoder = construct_end_to_end_model(
                input_shape=experiment_config["input_shape"],
                encoder=construct_encoder(
                    input_shape=experiment_config["input_shape"],
                    units=experiment_config["encoder_units"],
                    num_concepts=experiment_config["num_concepts"],
                    end_activation="sigmoid",
                    latent_dims=latent_dims,
                    output_logits=encoder_output_logits,
                ),
                decoder=construct_decoder(
                    units=experiment_config["decoder_units"],
                    num_outputs=experiment_config["num_outputs"],
                ),
            )

            if experiment_config.get("pre_train_epochs"):
                print("\tModel pre-training...")
                end_to_end_model.fit(
                    x=x_train,
                    y=y_train,
                    epochs=experiment_config["pre_train_epochs"],
                    batch_size=experiment_config["batch_size"],
                    validation_split=experiment_config["holdout_fraction"],
                    verbose=verbosity,
                )
                print("\t\tModel pre-training completed")

            # Now time to actually construct and train the CBM
            cbm_model = construct_cbm(
                encoder=encoder,
                decoder=decoder,
                alpha=experiment_config["alpha"],
                learning_rate=experiment_config["learning_rate"],
                latent_dims=latent_dims,
                encoder_output_logits=encoder_output_logits,
            )

            early_stopping_monitor = tf.keras.callbacks.EarlyStopping(
                monitor="val_concept_accuracy",
                min_delta=experiment_config["min_delta"],
                patience=experiment_config["patience"],
                restore_best_weights=True,
                verbose=2,
                mode='max',
            )
            # The CBM is only supervised on the first `num_concepts` concepts
            # even when the dataset carries more (`data_concepts`).
            train_targets = (
                y_train,
                y_train_concepts[:, :experiment_config["num_concepts"]],
            )
            # Warm-up is optional, just like pre-training above.
            if experiment_config.get("warmup_epochs"):
                print("\tWarmup training...")
                cbm_model.fit(
                    x=x_train,
                    y=train_targets,
                    epochs=experiment_config["warmup_epochs"],
                    batch_size=experiment_config["batch_size"],
                    validation_split=experiment_config["holdout_fraction"],
                    verbose=verbosity,
                )
                print("\t\tWarmup training completed")

            print("\tCBM training...")
            cbm_model.fit(
                x=x_train,
                y=train_targets,
                epochs=experiment_config["max_epochs"],
                batch_size=experiment_config["batch_size"],
                callbacks=[
                    early_stopping_monitor,
                ],
                validation_split=experiment_config["holdout_fraction"],
                verbose=verbosity,
            )
            print("\t\tCBM training completed")
            print("\tSerializing model")
            encoder.save(
                os.path.join(
                    results_dir,
                    f"models/encoder_cov_{cov:.1f}_trial_{trial}"
                )
            )
            decoder.save(
                os.path.join(
                    results_dir,
                    f"models/decoder_cov_{cov:.1f}_trial_{trial}"
                )
            )
            print("\tEvaluating model")

            test_result = cbm_model.evaluate(
                x_test,
                (
                    y_test,
                    y_test_concepts[:, :experiment_config["num_concepts"]],
                ),
                verbose=0,
                return_dict=True,
            )
            task_accs.append(test_result['binary_accuracy'])
            concept_accs.append(test_result['concept_accuracy'])
            aucs.append(sklearn.metrics.roc_auc_score(
                y_test,
                cbm_model.predict(x_test)[0],
            ))

            print(
                f"\t\tTest auc = {aucs[-1]:.4f}, "
                f"test concept accuracy = {concept_accs[-1]:.4f}, "
                f"task accuracy = {task_accs[-1]:.4f}"
            )

            print("\tComputing linear correlations...")
            # With a bypass channel the encoder returns [concepts, bypass];
            # both parts are joined into one activation matrix.
            soft_acts = (
                np.concatenate(cbm_model.encoder(x_test), axis=-1)
                if latent_dims else encoder(x_test).numpy()
            )
            # corr_mat[c, l] is the Pearson correlation between bottleneck
            # unit c and ground-truth concept l on the test set.
            corr_mat = np.ones((soft_acts.shape[-1], y_test_concepts.shape[-1]))
            for c in range(corr_mat.shape[0]):
                for l in range(corr_mat.shape[1]):
                    corr_mat[c][l] = np.corrcoef(
                        soft_acts[:, c],
                        y_test_concepts[:, l],
                    )[0, 1]

            corr_mats.append(corr_mat)

            print("\t\tComputing OIS...")
            purity_score, purity_mat, oracle_mat = oracle.oracle_impurity_score(
                c_soft=soft_acts,
                c_true=y_test_concepts,
                output_matrices=True,
            )
            purity_mats.append(purity_mat)
            oracle_mats.append(oracle_mat)
            purities.append(purity_score)
            print(f"\t\t\tDone {purity_score:.4f}")

            print("\t\tComputing non-oracle purity score...")

            # Same score, but measured against the trivial oracle matrix and
            # reusing the purity matrix computed just above.
            non_oracle_purities.append(oracle.oracle_impurity_score(
                c_soft=soft_acts,
                c_true=y_test_concepts,
                oracle_matrix=construct_trivial_auc_mat(
                    experiment_config["data_concepts"]
                ),
                purity_matrix=purity_mat,
            ))
            print(f"\t\t\tDone {non_oracle_purities[-1]:.4f}")
            print("\t\tDone with trial", trial + 1)

        # Aggregate the per-trial results of this covariance.
        experiment_variables["task_accuracies"].append(
            _summarise_scalars("Test task accuracy", task_accs)
        )
        experiment_variables["concept_accuracies"].append(
            _summarise_scalars("Test concept accuracy", concept_accs)
        )
        experiment_variables["task_aucs"].append(
            _summarise_scalars("Test task AUC", aucs)
        )
        experiment_variables["correlation_matrices"].append(
            _summarise_matrices("Correlation matrix", corr_mats)
        )
        experiment_variables["purity_matrices"].append(
            _summarise_matrices("Purity matrix", purity_mats)
        )
        experiment_variables["oracle_matrices"].append(
            _summarise_matrices("Oracle matrix", oracle_mats)
        )
        experiment_variables["purity_scores"].append(
            _summarise_scalars("OIS", purities)
        )
        experiment_variables["non_oracle_purity_scores"].append(
            _summarise_scalars("Non-oracle purity score", non_oracle_purities)
        )

        # And serialize the results
        utils.serialize_results(
            results_dict=experiment_variables,
            results_dir=results_dir,
        )
    return experiment_variables

def cbm_bottleneck_predict_experiment_loop(
    experiment_config,
    load_from_cache=False,
):
    """Measure how well the task label can be predicted from encoder codes.

    For every covariance and trial this loads the encoder saved at
    ``<results_dir>/models/encoder_cov_<cov>_trial_<trial>`` (the path used
    by ``cbm_experiment_loop`` when it saves encoders), encodes freshly
    sampled train/test data, fits a new decoder on the train codes and
    records its test accuracy and ROC-AUC. Results are serialized with
    ``utils.serialize_results`` after each covariance.

    Config keys read: ``results_dir``, ``covariances``, ``trials``,
    ``train_samples``, ``test_samples``, ``latent_decoder_units``,
    ``num_outputs``, ``learning_rate``, ``predictor_max_epochs``,
    ``batch_size``, ``holdout_fraction`` and, optionally, ``verbosity``
    (default 0) and ``data_concepts`` (falls back to ``num_concepts``).

    A decoder with a single output is treated as a binary classifier on
    logits; with more than one output it is trained with sparse categorical
    cross-entropy and evaluated on softmax probabilities.

    Args:
        experiment_config: dict with the keys above.
        load_from_cache: when true, previously serialized ``*_means.npz`` /
            ``*_stds.npz`` files in ``results_dir`` are reused. A complete
            cache is returned directly; a partial one is resumed.

    Returns:
        Dict with keys ``latent_predictive_accuracies`` and
        ``latent_predictive_aucs``, each a list with one ``(mean, std)`` pair
        per covariance.

    Raises:
        ValueError: if cached variables disagree on how many covariances
            have already been processed.
    """
    utils.reseed(87)
    experiment_variables = dict(
        latent_predictive_accuracies=[],
        latent_predictive_aucs=[],
    )
    results_dir = experiment_config["results_dir"]

    def _load_npz_arrays(path):
        # Open the archive a single time and close it once all entries have
        # been materialised (instead of re-opening it for every key).
        with np.load(path) as archive:
            return [archive[key] for key in archive]

    def _encode(encoder, inputs):
        # Run the encoder and merge multi-output codes into one array.
        codes = encoder(inputs)
        if isinstance(codes, list):
            return np.concatenate([part.numpy() for part in codes], axis=-1)
        return codes.numpy()

    start_ind = 0
    if load_from_cache:
        cached = True
        complete_cache = True
        for var_name in experiment_variables:
            file_name = os.path.join(results_dir, var_name)
            if not (
                os.path.exists(f'{file_name}_means.npz') and
                os.path.exists(f'{file_name}_stds.npz')
            ):
                print("Could not find", f'"{file_name}_means.npz" or "{file_name}_stds.npz" in cache.')
                cached = False
                # Nothing will be loaded from the cache, so the sweep has to
                # restart from the first covariance even if an earlier
                # variable looked like a partial run.
                start_ind = 0
                break
            with np.load(f'{file_name}_means.npz') as loaded_means:
                num_cached = len(loaded_means)
            if num_cached != len(experiment_config["covariances"]):
                print("Found", num_cached, "means for variable", var_name, "vs", len(experiment_config["covariances"]), "datasets")
                # Then we have a partial run here so let's just run the rest
                if start_ind and start_ind != num_cached:
                    raise ValueError(
                        f'Found inconsistent start indices in cached '
                        f'data {start_ind} vs {num_cached} ({file_name}).'
                    )
                start_ind = num_cached
                complete_cache = False
        if cached:
            # Then we have found all of the arrays of interest, so let's
            # load them and use them
            print("Experiment cache was hit")
            print("\tLoading variables from", results_dir)
            for var_name in experiment_variables:
                file_name = os.path.join(results_dir, var_name)
                experiment_variables[var_name] = list(zip(
                    _load_npz_arrays(f'{file_name}_means.npz'),
                    _load_npz_arrays(f'{file_name}_stds.npz'),
                ))
            if complete_cache:
                # Then we are good to go
                if not os.path.exists(os.path.join(results_dir, "config.yaml")):
                    # then serialize the config as this is a different version run
                    utils.serialize_experiment_config(
                        experiment_config,
                        results_dir,
                    )
                return experiment_variables

    # Else, let's go ahead and run the whole thing
    verbosity = experiment_config.get("verbosity", 0)
    num_outputs = experiment_config["num_outputs"]
    # One predicate drives the loss, the metric and the AUC computation so
    # that the metric read back after `evaluate` is the one that was compiled.
    multiclass = num_outputs > 1
    # Do not rely on another experiment loop having filled in
    # `data_concepts` on the shared config beforehand.
    data_concepts = (
        experiment_config["data_concepts"]
        if "data_concepts" in experiment_config
        else experiment_config["num_concepts"]
    )
    Path(os.path.join(results_dir, "models")).mkdir(parents=True, exist_ok=True)
    for cov in experiment_config["covariances"][start_ind:]:
        print("Training with covariance:", cov)
        task_accs = []
        aucs = []

        for trial in range(experiment_config["trials"]):
            print(f"\tTrial {trial + 1}/{experiment_config['trials']} with covariance {cov}")
            # First construct the dataset
            (x_train, y_train, y_train_concepts) = produce_data(
                experiment_config["train_samples"],
                cov=cov,
                num_concepts=data_concepts,
            )
            (x_test, y_test, y_test_concepts) = produce_data(
                experiment_config["test_samples"],
                cov=cov,
                num_concepts=data_concepts,
            )

            encoder = tf.keras.models.load_model(
                os.path.join(
                    results_dir,
                    f"models/encoder_cov_{cov:.1f}_trial_{trial}"
                )
            )

            predictive_decoder = construct_decoder(
                units=experiment_config["latent_decoder_units"],
                num_outputs=num_outputs,
            )
            predictive_decoder.compile(
                optimizer=tf.keras.optimizers.Adam(experiment_config["learning_rate"]),
                loss=(
                    tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
                    if multiclass
                    else tf.keras.losses.BinaryCrossentropy(from_logits=True)
                ),
                metrics=[
                    tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1)
                    if multiclass
                    else "binary_accuracy"
                ],
            )

            print("\tTraining model")
            train_codes = _encode(encoder, x_train)
            test_codes = _encode(encoder, x_test)
            predictive_decoder.fit(
                x=train_codes,
                y=y_train,
                epochs=experiment_config["predictor_max_epochs"],
                batch_size=experiment_config["batch_size"],
                validation_split=experiment_config["holdout_fraction"],
                verbose=verbosity,
            )
            print("\tEvaluating model")
            test_result = predictive_decoder.evaluate(
                test_codes,
                y_test,
                verbose=0,
                return_dict=True,
            )
            task_accs.append(
                test_result['sparse_top_k_categorical_accuracy']
                if multiclass else
                test_result['binary_accuracy']
            )

            if multiclass:
                # Then lets apply a softmax activation over all the probability
                # classes
                preds = scipy.special.softmax(
                    predictive_decoder.predict(test_codes),
                    axis=-1,
                )
                if num_outputs == 2:
                    # Binary targets need a 1-D score: use the probability
                    # assigned to the positive class.
                    aucs.append(sklearn.metrics.roc_auc_score(
                        y_test,
                        preds[:, 1],
                    ))
                else:
                    aucs.append(sklearn.metrics.roc_auc_score(
                        y_test,
                        preds,
                        multi_class='ovo',
                    ))
            else:
                # AUC only depends on the ranking, so raw logits are fine.
                aucs.append(sklearn.metrics.roc_auc_score(
                    y_test,
                    predictive_decoder.predict(test_codes),
                ))

            print(
                f"\t\tTest auc = {aucs[-1]:.4f}, "
                f"task accuracy = {task_accs[-1]:.4f}"
            )
            print("\t\tDone with trial", trial + 1)

        task_acc_mean, task_acc_std = np.mean(task_accs), np.std(task_accs)
        experiment_variables["latent_predictive_accuracies"].append((task_acc_mean, task_acc_std))
        print(f"\tTest task accuracy: {task_acc_mean:.4f} ± {task_acc_std:.4f}")

        task_auc_mean, task_auc_std = np.mean(aucs), np.std(aucs)
        experiment_variables["latent_predictive_aucs"].append((task_auc_mean, task_auc_std))
        print(f"\tTest task AUC: {task_auc_mean:.4f} ± {task_auc_std:.4f}")

        # And serialize the results
        utils.serialize_results(
            results_dict=experiment_variables,
            results_dir=results_dir,
        )
    return experiment_variables

def cbm_bottleneck_concept_predict_experiment_loop(
    experiment_config,
    load_from_cache=False,
):
    """Measure how well each concept can be predicted from the bottleneck codes.

    For every covariance in ``experiment_config["covariances"]`` and every
    trial, the encoder stored at
    ``<results_dir>/models/encoder_cov_<cov>_trial_<trial>`` is loaded, newly
    sampled train/test data is encoded with it, and one binary decoder per
    concept is fitted on the training codes. A trial's score is the mean test
    accuracy / ROC-AUC over the concepts; each covariance contributes the mean
    and standard deviation of that score over its trials. Results are
    serialized through ``utils.serialize_results`` after every covariance.

    Args:
        experiment_config: dict of experiment settings. Keys read here:
            "results_dir", "covariances", "trials", "train_samples",
            "test_samples", "num_concepts", "latent_decoder_units",
            "learning_rate", "concept_predictor_max_epochs", "batch_size",
            "holdout_fraction", plus the optional "verbosity" (default 0) and
            "data_concepts" (defaults to "num_concepts").
        load_from_cache: when True, previously serialized
            ``<variable>_means.npz`` / ``<variable>_stds.npz`` files in the
            results directory are reused. A complete cache is returned as-is;
            a partial cache is extended starting from the first covariance
            that has no cached entry.

    Returns:
        dict mapping "latent_avg_concept_predictive_accuracies" and
        "latent_avg_concept_predictive_aucs" to a list with one
        ``(mean, std)`` pair per covariance.

    Raises:
        ValueError: if the cached variables disagree on how many covariances
            have already been processed.
    """
    utils.reseed(87)
    experiment_variables = dict(
        latent_avg_concept_predictive_accuracies=[],
        latent_avg_concept_predictive_aucs=[],
    )
    results_dir = experiment_config["results_dir"]
    covariances = experiment_config["covariances"]

    def _read_npz_arrays(path):
        # Load every array eagerly so the archive's file handle can be closed
        # right away instead of being reopened once per stored key.
        with np.load(path) as archive:
            return [archive[key] for key in archive.files]

    def _codes_to_numpy(codes):
        # The encoder may return a list of tensors; in that case they are
        # joined along the last axis into a single code matrix.
        if isinstance(codes, list):
            return np.concatenate([part.numpy() for part in codes], axis=-1)
        return codes.numpy()

    start_ind = 0
    if load_from_cache:
        cached = True
        complete_cash = True
        for var_name in experiment_variables:
            file_name = os.path.join(results_dir, var_name)
            if (
                (not os.path.exists(f'{file_name}_means.npz')) or
                (not os.path.exists(f'{file_name}_stds.npz'))
            ):
                print("Could not find", f'"{file_name}_means.npz" or "{file_name}_stds.npz" in cache.')
                cached = False
                # Nothing will be loaded, so the run must start from the first
                # covariance even if an earlier variable looked partially cached.
                start_ind = 0
                break
            with np.load(f'{file_name}_means.npz') as loaded_means:
                num_cached = len(loaded_means.files)
            if num_cached != len(covariances):
                print("Found", num_cached, "means for variable", var_name, "vs", len(covariances), "datasets")
                # Then we have a partial run here so let's just run the rest
                if start_ind and start_ind != num_cached:
                    raise ValueError(
                        f'Found inconsistent start indices in cached '
                        f'data {start_ind} vs {num_cached} ({file_name}).'
                    )
                start_ind = num_cached
                complete_cash = False
        if cached:
            # Then we have found all of the arrays of interest, so let's
            # load them and use them
            print("Experiment cache was hit")
            print("\tLoading variables from", results_dir)
            for var_name in experiment_variables:
                file_name = os.path.join(results_dir, var_name)
                experiment_variables[var_name] = list(zip(
                    _read_npz_arrays(f'{file_name}_means.npz'),
                    _read_npz_arrays(f'{file_name}_stds.npz'),
                ))
            if complete_cash:
                # Then we are good to go
                if not os.path.exists(os.path.join(results_dir, "config.yaml")):
                    # then serialize the config as this is a different version run
                    utils.serialize_experiment_config(
                        experiment_config,
                        results_dir,
                    )
                return experiment_variables

    # Else, let's go ahead and run (the rest of) the experiment
    verbosity = experiment_config.get("verbosity", 0)
    # Same default the capacity loop applies, but without relying on another
    # loop having written "data_concepts" into the config beforehand.
    data_concepts = experiment_config.get(
        "data_concepts",
        experiment_config["num_concepts"],
    )
    Path(os.path.join(results_dir, "models")).mkdir(parents=True, exist_ok=True)
    for cov in covariances[start_ind:]:
        print("Training with covariance:", cov)
        avg_concept_accs = []
        avg_concept_aucs = []

        for trial in range(experiment_config["trials"]):
            print(f"\tTrial {trial + 1}/{experiment_config['trials']} with covariance {cov}")
            # First construct the dataset (task labels are not needed here)
            (x_train, _, c_train) = produce_data(
                experiment_config["train_samples"],
                cov=cov,
                num_concepts=data_concepts,
            )
            (x_test, _, c_test) = produce_data(
                experiment_config["test_samples"],
                cov=cov,
                num_concepts=data_concepts,
            )
            encoder = tf.keras.models.load_model(
                os.path.join(
                    results_dir,
                    f"models/encoder_cov_{cov:.1f}_trial_{trial}"
                )
            )

            train_codes = _codes_to_numpy(encoder(x_train))
            test_codes = _codes_to_numpy(encoder(x_test))

            current_accuracies = []
            current_aucs = []
            for concept_idx in range(experiment_config["num_concepts"]):
                print("\tTraining model for concept", concept_idx)
                predictive_decoder = construct_decoder(
                    units=experiment_config["latent_decoder_units"],
                    num_outputs=1,
                )
                predictive_decoder.compile(
                    optimizer=tf.keras.optimizers.Adam(experiment_config["learning_rate"]),
                    loss=(
                        tf.keras.losses.BinaryCrossentropy(from_logits=True)
                    ),
                    metrics=[
                        "binary_accuracy"
                    ],
                )
                predictive_decoder.fit(
                    x=train_codes,
                    y=c_train[:, concept_idx],
                    epochs=experiment_config["concept_predictor_max_epochs"],
                    batch_size=experiment_config["batch_size"],
                    validation_split=experiment_config["holdout_fraction"],
                    verbose=verbosity,
                )
                print("\t\tEvaluating model")
                test_result = predictive_decoder.evaluate(
                    test_codes,
                    c_test[:, concept_idx],
                    verbose=0,
                    return_dict=True,
                )
                current_accuracies.append(test_result['binary_accuracy'])

                # The decoder is trained with from_logits=True; the raw scores
                # are passed to roc_auc_score, which only depends on ranking.
                current_aucs.append(sklearn.metrics.roc_auc_score(
                    c_test[:, concept_idx],
                    predictive_decoder.predict(test_codes),
                ))
                print(
                    f"\t\t\tConcept {concept_idx} test accuracy = {current_accuracies[-1]:.4f}, "
                    f"test AUC = {current_aucs[-1]:.4f}"
                )

            avg_concept_accs.append(np.mean(current_accuracies))
            avg_concept_aucs.append(np.mean(current_aucs))
            print(
                f"\t\tAverage test concept accuracy = {avg_concept_accs[-1]:.4f}, "
                f"average test concept AUC = {avg_concept_aucs[-1]:.4f}"
            )
            print("\t\tDone with trial", trial + 1)

        avg_concept_acc_mean, avg_concept_acc_std = np.mean(avg_concept_accs), np.std(avg_concept_accs)
        experiment_variables["latent_avg_concept_predictive_accuracies"].append((avg_concept_acc_mean, avg_concept_acc_std))
        print(f"\tAverage test concept accuracy: {avg_concept_acc_mean:.4f} ± {avg_concept_acc_std:.4f}")

        avg_concept_auc_mean, avg_concept_auc_std = np.mean(avg_concept_aucs), np.std(avg_concept_aucs)
        experiment_variables["latent_avg_concept_predictive_aucs"].append((avg_concept_auc_mean, avg_concept_auc_std))
        print(f"\tAverage test concept AUC: {avg_concept_auc_mean:.4f} ± {avg_concept_auc_std:.4f}")

        # And serialize the results
        utils.serialize_results(
            results_dict=experiment_variables,
            results_dir=results_dir,
        )
    return experiment_variables

In [None]:
# Re-import the CBM module so edits made to it since the first import are used.
reload(CBM)

############################################################################
## Experiment config
############################################################################

# Identical to the base configuration defined further down except for
# "results_dir" and "encoder_output_logits".
from_logits_experiment_config = {
    "batch_size": 32,
    "max_epochs": 300,
    "warmup_epochs": 0,
    "pre_train_epochs": 0,
    "trials": NUM_TRIALS,
    "alpha": 0.1,
    "learning_rate": 1e-3,
    "encoder_units": [128, 64],
    "decoder_units": [128, 64],
    "num_outputs": 1,

    # Settings for the decoders trained on top of the bottleneck codes
    "latent_decoder_units": [128, 64],
    "predictor_max_epochs": 300,
    "concept_predictor_max_epochs": 300,

    "patience": float("inf"),
    "min_delta": 1e-5,
    "results_dir": os.path.join(RESULTS_DIR, "cbm/from_logits"),
    "input_shape": [7],
    "num_concepts": 3,
    "latent_dims": 0,
    "holdout_fraction": 0.1,
    "train_samples": 2000,
    "test_samples": 1000,
    "covariances": np.arange(0, 1, 0.1),
    "verbosity": 0,
    "encoder_output_logits": True,
}

# Make sure both the experiment directory and its figures sub-directory exist
os.makedirs(from_logits_experiment_config["results_dir"], exist_ok=True)
from_logits_figure_dir = os.path.join(
    from_logits_experiment_config["results_dir"],
    "figures",
)
os.makedirs(from_logits_figure_dir, exist_ok=True)

############################################################################
## Experiment run
############################################################################

from_logits_results = cbm_experiment_loop(
    experiment_config=from_logits_experiment_config,
    load_from_cache=LOAD_FROM_CACHE,
) if False else cbm_experiment_loop(
    from_logits_experiment_config,
    load_from_cache=LOAD_FROM_CACHE,
)
print(f"task_accuracies: {from_logits_results['task_accuracies']}")
print(f"concept_accuracies: {from_logits_results['concept_accuracies']}")
print(f"task_aucs: {from_logits_results['task_aucs']}")

In [None]:
# Merge the bottleneck task-prediction results into the from-logits results
# dict; dict.update overwrites any key that is already present.
from_logits_results.update(cbm_bottleneck_predict_experiment_loop(
    from_logits_experiment_config,
    load_from_cache=LOAD_FROM_CACHE,
))

In [None]:
# Add the per-concept bottleneck probe results (accuracies and AUCs) to the
# from-logits results dict.
from_logits_results.update(
    cbm_bottleneck_concept_predict_experiment_loop(
        experiment_config=from_logits_experiment_config,
        load_from_cache=LOAD_FROM_CACHE,
    )
)

In [None]:
# Re-import the CBM module so edits made to it since the first import are used.
reload(CBM)

############################################################################
## Experiment config
############################################################################

# Identical to the from-logits configuration defined earlier except for
# "results_dir" and "encoder_output_logits".
base_experiment_config = {
    "batch_size": 32,
    "max_epochs": 300,
    "warmup_epochs": 0,
    "pre_train_epochs": 0,
    "trials": NUM_TRIALS,
    "alpha": 0.1,
    "learning_rate": 1e-3,
    "encoder_units": [128, 64],
    "decoder_units": [128, 64],
    "num_outputs": 1,

    # Settings for the decoders trained on top of the bottleneck codes
    "latent_decoder_units": [128, 64],
    "predictor_max_epochs": 300,
    "concept_predictor_max_epochs": 300,

    "patience": float("inf"),
    "min_delta": 1e-5,
    "results_dir": os.path.join(RESULTS_DIR, "cbm/base"),
    "input_shape": [7],
    "num_concepts": 3,
    "latent_dims": 0,
    "holdout_fraction": 0.1,
    "train_samples": 2000,
    "test_samples": 1000,
    "covariances": np.arange(0, 1, 0.1),
    "verbosity": 0,
    "encoder_output_logits": False,
}

# Make sure both the experiment directory and its figures sub-directory exist
os.makedirs(base_experiment_config["results_dir"], exist_ok=True)
base_figure_dir = os.path.join(
    base_experiment_config["results_dir"],
    "figures",
)
os.makedirs(base_figure_dir, exist_ok=True)

############################################################################
## Experiment run
############################################################################

base_results = cbm_experiment_loop(
    base_experiment_config,
    load_from_cache=LOAD_FROM_CACHE,
)
print(f"task_accuracies: {base_results['task_accuracies']}")
print(f"concept_accuracies: {base_results['concept_accuracies']}")
print(f"task_aucs: {base_results['task_aucs']}")

In [None]:
# Merge the bottleneck task-prediction results into the base results dict;
# dict.update overwrites any key that is already present.
base_results.update(cbm_bottleneck_predict_experiment_loop(
    base_experiment_config,
    load_from_cache=LOAD_FROM_CACHE,
))

In [None]:
# Add the per-concept bottleneck probe results (accuracies and AUCs) to the
# base results dict.
base_results.update(
    cbm_bottleneck_concept_predict_experiment_loop(
        experiment_config=base_experiment_config,
        load_from_cache=LOAD_FROM_CACHE,
    )
)

## Capacity Experiment

In [None]:
def cbm_capacity_experiment_loop(experiment_config, load_from_cache=False):
    """Train and evaluate CBMs of varying capacity on uncorrelated toy data.

    For every entry ``units`` of ``experiment_config["model_units"]`` and every
    trial, a CBM whose encoder and decoder both use hidden layers
    ``[units, units // 2]`` is trained on data produced with ``cov=0``. Each
    trial records test task accuracy, concept accuracy, task ROC-AUC, the
    oracle impurity score with its purity/oracle matrices, and the impurity
    score computed against ``construct_trivial_auc_mat``. Per-capacity means
    and standard deviations over trials are appended to the result lists and
    serialized through ``utils.serialize_results`` after every capacity.

    Side effect: ``experiment_config["data_concepts"]`` is set to
    ``experiment_config["num_concepts"]`` when it is missing.

    Args:
        experiment_config: dict of experiment settings. Keys read here:
            "results_dir", "model_units", "trials", "train_samples",
            "test_samples", "input_shape", "num_concepts", "num_outputs",
            "alpha", "learning_rate", "min_delta", "patience",
            "warmup_epochs", "max_epochs", "batch_size", "holdout_fraction",
            plus the optional "data_concepts", "latent_dims" (default 0),
            "encoder_output_logits" (default False), "pre_train_epochs" and
            "verbosity" (default 0).
        load_from_cache: when True, previously serialized
            ``<variable>_means.npz`` / ``<variable>_stds.npz`` files in the
            results directory are reused. A complete cache is returned as-is;
            a partial cache is extended starting from the first capacity that
            has no cached entry.

    Returns:
        dict mapping each of "task_accuracies", "task_aucs",
        "concept_accuracies", "purity_scores", "non_oracle_purity_scores",
        "purity_matrices" and "oracle_matrices" to a list with one
        ``(mean, std)`` pair per capacity.

    Raises:
        ValueError: if the cached variables disagree on how many capacities
            have already been processed.
    """
    experiment_variables = dict(
        task_accuracies=[],
        task_aucs=[],
        concept_accuracies=[],
        purity_scores=[],
        non_oracle_purity_scores=[],
        purity_matrices=[],
        oracle_matrices=[],
    )
    utils.reseed(87)
    experiment_config["data_concepts"] = experiment_config.get(
        "data_concepts",
        experiment_config["num_concepts"],
    )
    results_dir = experiment_config["results_dir"]
    model_units = experiment_config["model_units"]

    def _read_npz_arrays(path):
        # Load every array eagerly so the archive's file handle can be closed
        # right away instead of being reopened once per stored key.
        with np.load(path) as archive:
            return [archive[key] for key in archive.files]

    def _summarize_matrices(title, matrices):
        # Print the element-wise mean ± std over trials, one matrix row per
        # line, and return both statistics.
        stacked = np.stack(matrices, axis=0)
        mat_mean = np.mean(stacked, axis=0)
        mat_std = np.std(stacked, axis=0)
        print(f"\t{title}:")
        for i in range(mat_mean.shape[0]):
            line = "\t\t"
            for j in range(mat_mean.shape[1]):
                line += f'{mat_mean[i, j]:.4f} ± {mat_std[i, j]:.4f}    '
            print(line)
        return mat_mean, mat_std

    start_ind = 0
    if load_from_cache:
        cached = True
        complete_cash = True
        for var_name in experiment_variables:
            file_name = os.path.join(results_dir, var_name)
            if (
                (not os.path.exists(f'{file_name}_means.npz')) or
                (not os.path.exists(f'{file_name}_stds.npz'))
            ):
                cached = False
                # Nothing will be loaded, so the run must start from the first
                # capacity even if an earlier variable looked partially cached.
                start_ind = 0
                break
            with np.load(f'{file_name}_means.npz') as loaded_means:
                num_cached = len(loaded_means.files)
            if num_cached != len(model_units):
                # Then we have a partial run here so let's just run the rest
                if start_ind and start_ind != num_cached:
                    raise ValueError(
                        f'Found inconsistent start indices in cached '
                        f'data {start_ind} vs {num_cached}.'
                    )
                start_ind = num_cached
                complete_cash = False
        if cached:
            # Then we have found all of the arrays of interest, so let's
            # load them and use them
            print("Experiment cache was hit")
            print("\tLoading variables from", results_dir)
            for var_name in experiment_variables:
                file_name = os.path.join(results_dir, var_name)
                experiment_variables[var_name] = list(zip(
                    _read_npz_arrays(f'{file_name}_means.npz'),
                    _read_npz_arrays(f'{file_name}_stds.npz'),
                ))
            if complete_cash:
                # Then we are good to go
                if not os.path.exists(os.path.join(results_dir, "config.yaml")):
                    # then serialize the config as this is a different version run
                    utils.serialize_experiment_config(
                        experiment_config,
                        results_dir,
                    )
                return experiment_variables

    # Let's save our config here either way
    utils.serialize_experiment_config(
        experiment_config,
        results_dir,
    )

    Path(os.path.join(results_dir, "models")).mkdir(parents=True, exist_ok=True)
    verbosity = experiment_config.get("verbosity", 0)
    # Read once with the same default everywhere it is used below.
    latent_dims = experiment_config.get("latent_dims", 0)
    num_concepts = experiment_config["num_concepts"]
    for units in model_units[start_ind:]:
        # Encoder and decoder share the same two hidden-layer widths.
        layer_units = [units, units//2]
        print("Training with units:", layer_units)
        task_accs = []
        concept_accs = []
        aucs = []
        purity_mats = []
        oracle_mats = []
        purities = []
        non_oracle_purities = []

        for trial in range(experiment_config["trials"]):
            print(f"\tTrial {trial + 1}/{experiment_config['trials']}")
            # First construct the dataset
            (x_train, y_train, y_train_concepts) = produce_data(
                experiment_config["train_samples"],
                cov=0,
                num_concepts=experiment_config["data_concepts"],
            )
            (x_test, y_test, y_test_concepts) = produce_data(
                experiment_config["test_samples"],
                cov=0,
                num_concepts=experiment_config["data_concepts"],
            )

            # Then proceed to do and end-to-end model in case we want to
            # do some task-specific pretraining
            end_to_end_model, encoder, decoder = construct_end_to_end_model(
                input_shape=experiment_config["input_shape"],
                encoder=construct_encoder(
                    input_shape=experiment_config["input_shape"],
                    units=layer_units,
                    num_concepts=num_concepts,
                    end_activation="sigmoid",
                    latent_dims=latent_dims,
                ),
                decoder=construct_decoder(
                    units=layer_units,
                    num_outputs=experiment_config["num_outputs"],
                ),
            )

            if experiment_config.get("pre_train_epochs"):
                print("\tModel pre-training...")
                end_to_end_model.fit(
                    x=x_train,
                    y=y_train,
                    epochs=experiment_config["pre_train_epochs"],
                    batch_size=experiment_config["batch_size"],
                    validation_split=experiment_config["holdout_fraction"],
                    verbose=verbosity,
                )
                print("\t\tModel pre-training completed")

            # Now time to actually construct and train the CBM
            cbm_model = construct_cbm(
                encoder=encoder,
                decoder=decoder,
                alpha=experiment_config["alpha"],
                learning_rate=experiment_config["learning_rate"],
                latent_dims=latent_dims,
                encoder_output_logits=experiment_config.get("encoder_output_logits", False),
            )

            early_stopping_monitor = tf.keras.callbacks.EarlyStopping(
                monitor="val_concept_accuracy",
                min_delta=experiment_config["min_delta"],
                patience=experiment_config["patience"],
                restore_best_weights=True,
                verbose=2,
                mode='max',
            )
            # Only the first num_concepts concept columns are supervised.
            train_targets = (
                y_train,
                y_train_concepts[:, :num_concepts],
            )
            if experiment_config["warmup_epochs"]:
                print("\tWarmup training...")
                cbm_model.fit(
                    x=x_train,
                    y=train_targets,
                    epochs=experiment_config["warmup_epochs"],
                    batch_size=experiment_config["batch_size"],
                    validation_split=experiment_config["holdout_fraction"],
                    verbose=verbosity,
                )
                print("\t\tWarmup training completed")

            print("\tCBM training...")
            cbm_model.fit(
                x=x_train,
                y=train_targets,
                epochs=experiment_config["max_epochs"],
                batch_size=experiment_config["batch_size"],
                callbacks=[
                    early_stopping_monitor,
                ],
                validation_split=experiment_config["holdout_fraction"],
                verbose=verbosity,
            )
            print("\t\tCBM training completed")
            print("\tSerializing model")
            encoder.save(
                os.path.join(
                    results_dir,
                    f"models/encoder_capacity_{units}_trial_{trial}"
                )
            )
            decoder.save(
                os.path.join(
                    results_dir,
                    f"models/decoder_capacity_{units}_trial_{trial}"
                )
            )
            print("\tEvaluating model")
            test_result = cbm_model.evaluate(
                x_test,
                (
                    y_test,
                    y_test_concepts[:, :num_concepts],
                ),
                verbose=0,
                return_dict=True,
            )
            task_accs.append(test_result['binary_accuracy'])
            concept_accs.append(test_result['concept_accuracy'])
            # The first output of the CBM's predict() is used as the task score.
            aucs.append(sklearn.metrics.roc_auc_score(
                y_test,
                cbm_model.predict(x_test)[0],
            ))

            print(
                f"\t\tTest auc = {aucs[-1]:.4f}, "
                f"test concept accuracy = {concept_accs[-1]:.4f}, "
                f"task accuracy = {task_accs[-1]:.4f}"
            )

            print("\t\tComputing purity score...")
            # With extra latent dimensions the encoder output is concatenated
            # along the last axis; otherwise it is a single tensor.
            soft_acts = (
                np.concatenate(cbm_model.encoder(x_test), axis=-1)
                if latent_dims else encoder(x_test).numpy()
            )
            purity_score, purity_mat, oracle_mat = oracle.oracle_impurity_score(
                c_soft=soft_acts,
                c_true=y_test_concepts,
                output_matrices=True,
            )
            purity_mats.append(purity_mat)
            oracle_mats.append(oracle_mat)
            purities.append(purity_score)
            print(f"\t\t\tDone {purity_score:.4f}")

            print("\t\tComputing non-oracle purity score...")

            # Reuse the purity matrix from above and only swap the oracle
            # matrix for the trivial one.
            non_oracle_purities.append(oracle.oracle_impurity_score(
                c_soft=soft_acts,
                c_true=y_test_concepts,
                oracle_matrix=construct_trivial_auc_mat(
                    experiment_config["data_concepts"]
                ),
                purity_matrix=purity_mat,
            ))
            print(f"\t\t\tDone {non_oracle_purities[-1]:.4f}")
            print("\t\tDone with trial", trial + 1)

        task_acc_mean, task_acc_std = np.mean(task_accs), np.std(task_accs)
        experiment_variables["task_accuracies"].append((task_acc_mean, task_acc_std))
        print(f"\tTest task accuracy: {task_acc_mean:.4f} ± {task_acc_std:.4f}")

        concept_acc_mean, concept_acc_std = np.mean(concept_accs), np.std(concept_accs)
        experiment_variables["concept_accuracies"].append((concept_acc_mean, concept_acc_std))
        print(f"\tTest concept accuracy: {concept_acc_mean:.4f} ± {concept_acc_std:.4f}")

        task_auc_mean, task_auc_std = np.mean(aucs), np.std(aucs)
        experiment_variables["task_aucs"].append((task_auc_mean, task_auc_std))
        print(f"\tTest task AUC: {task_auc_mean:.4f} ± {task_auc_std:.4f}")

        experiment_variables["purity_matrices"].append(
            _summarize_matrices("Purity matrix", purity_mats)
        )
        experiment_variables["oracle_matrices"].append(
            _summarize_matrices("Oracle matrix", oracle_mats)
        )

        purity_mean, purity_std = np.mean(purities), np.std(purities)
        experiment_variables["purity_scores"].append((purity_mean, purity_std))
        print(f"\tPurity score: {purity_mean:.4f} ± {purity_std:.4f}")

        non_oracle_purity_mean, non_oracle_purity_std = np.mean(non_oracle_purities), np.std(non_oracle_purities)
        experiment_variables["non_oracle_purity_scores"].append((non_oracle_purity_mean, non_oracle_purity_std))
        print(f"\tNon-oracle purity score: {non_oracle_purity_mean:.4f} ± {non_oracle_purity_std:.4f}")

        # And serialize the results
        utils.serialize_results(
            results_dict=experiment_variables,
            results_dir=results_dir,
        )
    return experiment_variables

In [None]:
############################################################################
## Experiment config
############################################################################

capacity_experiment_config = {
    "batch_size": 32,
    "max_epochs": 300,
    "warmup_epochs": 0,
    "pre_train_epochs": 0,
    "trials": NUM_TRIALS,
    "alpha": 0.1,
    "learning_rate": 1e-3,
    # Each entry `units` is turned into hidden layers [units, units // 2] by
    # the capacity loop, for both the encoder and the decoder.
    "model_units": [256, 128, 64, 32, 16, 8, 4],
    "num_outputs": 1,
    "patience": float("inf"),
    "min_delta": 1e-5,
    "results_dir": os.path.join(RESULTS_DIR, "cbm/capacity_purity"),
    "input_shape": [7],
    "num_concepts": 3,
    "latent_dims": 0,
    "holdout_fraction": 0.1,
    "train_samples": 2000,
    "test_samples": 1000,
    "verbosity": 0,
    "data_concepts": 3,  # But we still use three concepts in the data
}

# Make sure both the experiment directory and its figures sub-directory exist
os.makedirs(capacity_experiment_config["results_dir"], exist_ok=True)
capacity_figure_dir = os.path.join(
    capacity_experiment_config["results_dir"],
    "figures",
)
os.makedirs(capacity_figure_dir, exist_ok=True)

############################################################################
## Experiment run
############################################################################

capacity_results = cbm_capacity_experiment_loop(
    experiment_config=capacity_experiment_config,
    load_from_cache=LOAD_FROM_CACHE,
)
print(f"task_accuracies: {capacity_results['task_accuracies']}")
print(f"concept_accuracies: {capacity_results['concept_accuracies']}")
print(f"task_aucs: {capacity_results['task_aucs']}")

# Mixed Capacity Experiments

In [None]:
def cbm_mixed_capacity_experiment_loop(experiment_config, load_from_cache=False):
    """Sweep the capacity of one half of a CBM and collect purity metrics.

    For every value ``units`` in ``experiment_config["model_units"]`` the list
    ``[units, units // 2]`` is used as the ``units`` argument of the encoder
    (when ``experiment_config["encoder_experiment"]`` is truthy) or of the
    decoder (otherwise). The other half uses the fixed
    ``experiment_config["decoder_units"]`` / ``experiment_config["encoder_units"]``.
    Each capacity is trained ``experiment_config["trials"]`` times on freshly
    generated data (``produce_data`` with ``cov=0``), and the mean and standard
    deviation of each metric across trials are recorded.

    Side effects:
        * Reseeds the random generators through ``utils.reseed(87)``.
        * Sets ``experiment_config["data_concepts"]`` (defaulting to
          ``"num_concepts"``) on the dictionary passed in by the caller.
        * Writes the config, the trained encoders/decoders and, after every
          capacity, the accumulated results under
          ``experiment_config["results_dir"]``.

    Args:
        experiment_config: dictionary of experiment settings. Keys read here:
            ``results_dir``, ``model_units``, ``trials``, ``train_samples``,
            ``test_samples``, ``num_concepts``, ``encoder_experiment``,
            ``encoder_units`` or ``decoder_units``, ``input_shape``,
            ``latent_dims``, ``num_outputs``, ``batch_size``,
            ``holdout_fraction``, ``alpha``, ``learning_rate``, ``min_delta``,
            ``patience``, ``warmup_epochs``, ``max_epochs`` and, optionally,
            ``data_concepts``, ``pre_train_epochs``, ``encoder_output_logits``
            and ``verbosity``.
        load_from_cache: when True, previously serialized results found in
            ``results_dir`` are loaded. A complete cache is returned right
            away; a partial cache is extended by running only the capacities
            that are still missing.

    Returns:
        Dictionary mapping every metric name to a list holding one
        ``(mean, std)`` tuple per evaluated capacity.

    Raises:
        ValueError: if the cached variables do not all hold the same number
            of entries, which means the cache cannot be resumed consistently.
    """
    experiment_variables = dict(
        task_accuracies=[],
        task_aucs=[],
        concept_accuracies=[],
        purity_scores=[],
        non_oracle_purity_scores=[],
        purity_matrices=[],
        oracle_matrices=[],
    )
    utils.reseed(87)
    experiment_config["data_concepts"] = experiment_config.get(
        "data_concepts",
        experiment_config["num_concepts"],
    )

    start_ind = 0
    if load_from_cache:
        cached = True
        complete_cache = True
        cached_len = None
        for var_name in experiment_variables:
            file_name = os.path.join(
                experiment_config["results_dir"],
                var_name
            )
            if (
                (not os.path.exists(f'{file_name}_means.npz')) or
                (not os.path.exists(f'{file_name}_stds.npz'))
            ):
                cached = False
                break
            # Only the number of stored entries is needed at this point, so
            # close the archive right away instead of leaking its file handle
            with np.load(f'{file_name}_means.npz') as loaded_means:
                num_loaded = len(loaded_means)
            if cached_len is None:
                cached_len = num_loaded
            elif cached_len != num_loaded:
                # Every variable must have the same number of cached entries,
                # otherwise resuming would misalign the result lists
                raise ValueError(
                    f'Found inconsistent start indices in cached '
                    f'data {cached_len} vs {num_loaded}.'
                )
        if (
            cached and
            (cached_len is not None) and
            (cached_len != len(experiment_config["model_units"]))
        ):
            # Then we have a partial run here so let's just run the rest.
            # The start index is only moved when the whole cache is usable:
            # if any file is missing we retrain everything from index 0
            # rather than skipping capacities we hold no results for.
            start_ind = cached_len
            complete_cache = False
        if cached:
            # Then we have found all of the arrays of interest, so let's
            # load them and use them
            print("Experiment cache was hit")
            print("\tLoading variables from", experiment_config["results_dir"])
            for var_name in experiment_variables:
                file_name = os.path.join(
                    experiment_config["results_dir"],
                    var_name
                )
                # Open each archive once and read every entry while it is
                # open; the arrays are materialised in memory so they remain
                # valid after the archives are closed
                with np.load(f'{file_name}_means.npz') as means_npz:
                    with np.load(f'{file_name}_stds.npz') as stds_npz:
                        experiment_variables[var_name] = list(zip(
                            [means_npz[key] for key in means_npz],
                            [stds_npz[key] for key in stds_npz],
                        ))
            if complete_cache:
                # Then we are good to go
                if not os.path.exists(
                    os.path.join(experiment_config["results_dir"], "config.yaml")
                ):
                    # then serialize the config as this is a different version run
                    utils.serialize_experiment_config(
                        experiment_config,
                        experiment_config["results_dir"],
                    )
                return experiment_variables

    # Let's save our config here either way
    utils.serialize_experiment_config(
        experiment_config,
        experiment_config["results_dir"],
    )

    Path(
        os.path.join(
            experiment_config["results_dir"],
            "models",
        )
    ).mkdir(parents=True, exist_ok=True)
    verbosity = experiment_config.get("verbosity", 0)
    for units in experiment_config["model_units"][start_ind:]:
        print("Training with units:", [units, units//2])
        # Per-trial metrics for this capacity
        task_accs = []
        concept_accs = []
        aucs = []
        purity_mats = []
        oracle_mats = []
        purities = []
        non_oracle_purities = []

        for trial in range(experiment_config["trials"]):
            print(f"\tTrial {trial + 1}/{experiment_config['trials']}")
            # First construct the dataset
            (x_train, y_train, y_train_concepts) = produce_data(
                experiment_config["train_samples"],
                cov=0,
                num_concepts=experiment_config["data_concepts"],
            )
            (x_test, y_test, y_test_concepts) = produce_data(
                experiment_config["test_samples"],
                cov=0,
                num_concepts=experiment_config["data_concepts"],
            )

            # Only one half of the model has its capacity swept; the other
            # half keeps the fixed architecture given in the config
            encoder_units = (
                [units, units//2] if experiment_config["encoder_experiment"]
                else experiment_config["encoder_units"]
            )
            decoder_units = (
                [units, units//2] if (not experiment_config["encoder_experiment"])
                else experiment_config["decoder_units"]
            )
            # Then proceed to build an end-to-end model in case we want to
            # do some task-specific pretraining
            end_to_end_model, encoder, decoder = construct_end_to_end_model(
                input_shape=experiment_config["input_shape"],
                encoder=construct_encoder(
                    input_shape=experiment_config["input_shape"],
                    units=encoder_units,
                    num_concepts=experiment_config["num_concepts"],
                    end_activation="sigmoid",
                    latent_dims=experiment_config["latent_dims"],
                    output_logits=experiment_config.get("encoder_output_logits", False),
                ),
                decoder=construct_decoder(
                    units=decoder_units,
                    num_outputs=experiment_config["num_outputs"],
                ),
            )

            if experiment_config.get("pre_train_epochs"):
                print("\tModel pre-training...")
                end_to_end_model.fit(
                    x=x_train,
                    y=y_train,
                    epochs=experiment_config["pre_train_epochs"],
                    batch_size=experiment_config["batch_size"],
                    validation_split=experiment_config["holdout_fraction"],
                    verbose=verbosity,
                )
                print("\t\tModel pre-training completed")

            # Now time to actually construct and train the CBM
            cbm_model = construct_cbm(
                encoder=encoder,
                decoder=decoder,
                alpha=experiment_config["alpha"],
                learning_rate=experiment_config["learning_rate"],
                latent_dims=experiment_config.get("latent_dims", 0),
                encoder_output_logits=experiment_config.get("encoder_output_logits", False),
            )

            early_stopping_monitor = tf.keras.callbacks.EarlyStopping(
                monitor="val_concept_accuracy",
                min_delta=experiment_config["min_delta"],
                patience=experiment_config["patience"],
                restore_best_weights=True,
                verbose=2,
                mode='max',
            )
            if experiment_config["warmup_epochs"]:
                print("\tWarmup training...")
                cbm_model.fit(
                    x=x_train,
                    y=(
                        y_train,
                        # Only the first `num_concepts` concepts are supervised
                        y_train_concepts[:, :experiment_config["num_concepts"]],
                    ),
                    epochs=experiment_config["warmup_epochs"],
                    batch_size=experiment_config["batch_size"],
                    validation_split=experiment_config["holdout_fraction"],
                    verbose=verbosity,
                )
                print("\t\tWarmup training completed")


            print("\tCBM training...")
            cbm_model.fit(
                x=x_train,
                y=(
                    y_train,
                    y_train_concepts[:, :experiment_config["num_concepts"]],
                ),
                epochs=experiment_config["max_epochs"],
                batch_size=experiment_config["batch_size"],
                callbacks=[
                    early_stopping_monitor,
                ],
                validation_split=experiment_config["holdout_fraction"],
                verbose=verbosity,
            )
            print("\t\tCBM training completed")
            print("\tSerializing model")
            # NOTE(review): the file name of the half that is NOT being swept
            # does not depend on `units` (it uses the fixed config units), so
            # that half is overwritten by every later capacity. Confirm this
            # is acceptable before relying on these checkpoints.
            encoder.save(
                os.path.join(
                    experiment_config["results_dir"],
                    f"models/encoder_capacity_{encoder_units[0]}_trial_{trial}"
                )
            )
            decoder.save(
                os.path.join(
                    experiment_config["results_dir"],
                    f"models/decoder_capacity_{decoder_units[0]}_trial_{trial}"
                )
            )
            print("\tEvaluating model")
            test_result = cbm_model.evaluate(
                x_test,
                (
                    y_test,
                    y_test_concepts[:, :experiment_config["num_concepts"]],
                ),
                verbose=0,
                return_dict=True,
            )
            task_accs.append(test_result['binary_accuracy'])
            concept_accs.append(test_result['concept_accuracy'])
            aucs.append(sklearn.metrics.roc_auc_score(
                y_test,
                # First output of the CBM is used as the task prediction
                cbm_model.predict(x_test)[0],
            ))

            print(
                f"\t\tTest auc = {aucs[-1]:.4f}, "
                f"test concept accuracy = {concept_accs[-1]:.4f}, "
                f"task accuracy = {task_accs[-1]:.4f}"
            )

            print("\t\tComputing purity score...")
            # With extra latent dimensions the encoder returns several
            # outputs that are joined along the last axis
            soft_acts = (
                np.concatenate(cbm_model.encoder(x_test), axis=-1)
                if experiment_config["latent_dims"] else encoder(x_test).numpy()
            )
            purity_score, purity_mat, oracle_mat = oracle.oracle_impurity_score(
                c_soft=soft_acts,
                c_true=y_test_concepts,
                output_matrices=True,
            )
            purity_mats.append(purity_mat)
            oracle_mats.append(oracle_mat)
            purities.append(purity_score)
            print(f"\t\t\tDone {purity_score:.4f}")

            print("\t\tComputing non-oracle purity score...")

            # Same score, but against the trivial matrix instead of the
            # data-derived oracle; the purity matrix computed above is reused
            non_oracle_purities.append(oracle.oracle_impurity_score(
                c_soft=soft_acts,
                c_true=y_test_concepts,
                oracle_matrix=construct_trivial_auc_mat(
                    experiment_config["data_concepts"]
                ),
                purity_matrix=purity_mat,
            ))
            print(f"\t\t\tDone {non_oracle_purities[-1]:.4f}")
            print("\t\tDone with trial", trial + 1)

        # Aggregate every metric across the trials of this capacity
        task_acc_mean, task_acc_std = np.mean(task_accs), np.std(task_accs)
        experiment_variables["task_accuracies"].append((task_acc_mean, task_acc_std))
        print(f"\tTest task accuracy: {task_acc_mean:.4f} ± {task_acc_std:.4f}")

        concept_acc_mean, concept_acc_std = np.mean(concept_accs), np.std(concept_accs)
        experiment_variables["concept_accuracies"].append((concept_acc_mean, concept_acc_std))
        print(f"\tTest concept accuracy: {concept_acc_mean:.4f} ± {concept_acc_std:.4f}")


        task_auc_mean, task_auc_std = np.mean(aucs), np.std(aucs)
        experiment_variables["task_aucs"].append((task_auc_mean, task_auc_std))
        print(f"\tTest task AUC: {task_auc_mean:.4f} ± {task_auc_std:.4f}")

        # Element-wise statistics of the purity matrices over trials
        purity_mats = np.stack(purity_mats, axis=0)
        purity_mat_mean = np.mean(purity_mats, axis=0)
        purity_mat_std = np.std(purity_mats, axis=0)
        print("\tPurity matrix:")
        for i in range(purity_mat_mean.shape[0]):
            line = "\t\t"
            for j in range(purity_mat_mean.shape[1]):
                line += f'{purity_mat_mean[i, j]:.4f} ± {purity_mat_std[i, j]:.4f}    '
            print(line)

        experiment_variables["purity_matrices"].append((purity_mat_mean, purity_mat_std))


        # Element-wise statistics of the oracle matrices over trials
        oracle_mats = np.stack(oracle_mats, axis=0)
        oracle_mat_mean = np.mean(oracle_mats, axis=0)
        oracle_mat_std = np.std(oracle_mats, axis=0)
        print("\tOracle matrix:")
        for i in range(oracle_mat_mean.shape[0]):
            line = "\t\t"
            for j in range(oracle_mat_mean.shape[1]):
                line += f'{oracle_mat_mean[i, j]:.4f} ± {oracle_mat_std[i, j]:.4f}    '
            print(line)

        experiment_variables["oracle_matrices"].append((oracle_mat_mean, oracle_mat_std))

        purity_mean, purity_std = np.mean(purities), np.std(purities)
        experiment_variables["purity_scores"].append((purity_mean, purity_std))
        print(f"\tPurity score: {purity_mean:.4f} ± {purity_std:.4f}")

        non_oracle_purity_mean, non_oracle_purity_std = np.mean(non_oracle_purities), np.std(non_oracle_purities)
        experiment_variables["non_oracle_purity_scores"].append((non_oracle_purity_mean, non_oracle_purity_std))
        print(f"\tNon-oracle purity score: {non_oracle_purity_mean:.4f} ± {non_oracle_purity_std:.4f}")

        # And serialize the results after every capacity so that an
        # interrupted sweep can be resumed from the cache
        utils.serialize_results(
            results_dict=experiment_variables,
            results_dir=experiment_config["results_dir"],
        )
    return experiment_variables

In [None]:
############################################################################
## Experiment config
############################################################################

# Encoder-capacity sweep with the encoder emitting logits: each value in
# `model_units` sizes the encoder while the decoder stays at `decoder_units`.
encoder_capacity_logits_experiment_config = {
    # Optimisation
    "batch_size": 32,
    "max_epochs": 300,
    "warmup_epochs": 0,
    "pre_train_epochs": 0,
    "trials": NUM_TRIALS,
    "alpha": 0.1,
    "learning_rate": 1e-3,
    # Swept encoder capacities and the fixed decoder architecture
    "model_units": [256, 128, 64, 32, 16, 8, 4],
    "decoder_units": [128, 64],
    "encoder_experiment": True,

    "num_outputs": 1,
    # Early stopping
    "patience": float("inf"),
    "min_delta": 1e-5,
    # Where results are cached/serialized
    "results_dir": os.path.join(
        RESULTS_DIR,
        "cbm/encoder_capacity_logits_purity",
    ),
    # Model and data dimensions
    "input_shape": [7],
    "num_concepts": 3,
    "latent_dims": 0,
    "holdout_fraction": 0.1,
    "train_samples": 2000,
    "test_samples": 1000,
    "verbosity": 0,
    # Number of concepts in the generated data (same as `num_concepts` here)
    "data_concepts": 3,
    "encoder_output_logits": True,
}

# Make sure the experiment directory and its figures sub-directory exist
encoder_capacity_logits_figure_dir = os.path.join(
    encoder_capacity_logits_experiment_config["results_dir"],
    "figures",
)
Path(encoder_capacity_logits_experiment_config["results_dir"]).mkdir(parents=True, exist_ok=True)
Path(encoder_capacity_logits_figure_dir).mkdir(parents=True, exist_ok=True)

############################################################################
## Experiment run
############################################################################

encoder_capacity_logits_results = cbm_mixed_capacity_experiment_loop(
    encoder_capacity_logits_experiment_config,
    load_from_cache=LOAD_FROM_CACHE,
)
# One summary line per headline metric
print(
    *(
        f"{metric}: {encoder_capacity_logits_results[metric]}"
        for metric in ("task_accuracies", "concept_accuracies", "task_aucs")
    ),
    sep="\n",
)

In [None]:
############################################################################
## Experiment config
############################################################################

# Decoder-capacity sweep with the encoder emitting logits: each value in
# `model_units` sizes the decoder while the encoder stays at `encoder_units`.
decoder_capacity_logits_experiment_config = {
    # Optimisation
    "batch_size": 32,
    "max_epochs": 300,
    "warmup_epochs": 0,
    "pre_train_epochs": 0,
    "trials": NUM_TRIALS,
    "alpha": 0.1,
    "learning_rate": 1e-3,
    # Swept decoder capacities and the fixed encoder architecture
    "model_units": [256, 128, 64, 32, 16, 8, 4],
    "encoder_units": [128, 64],
    "encoder_experiment": False,

    "num_outputs": 1,
    # Early stopping
    "patience": float("inf"),
    "min_delta": 1e-5,
    # Where results are cached/serialized
    "results_dir": os.path.join(
        RESULTS_DIR,
        "cbm/decoder_capacity_logits_purity",
    ),
    # Model and data dimensions
    "input_shape": [7],
    "num_concepts": 3,
    "latent_dims": 0,
    "holdout_fraction": 0.1,
    "train_samples": 2000,
    "test_samples": 1000,
    "verbosity": 0,
    # Number of concepts in the generated data (same as `num_concepts` here)
    "data_concepts": 3,
    "encoder_output_logits": True,
}

# Make sure the experiment directory and its figures sub-directory exist
decoder_capacity_logits_figure_dir = os.path.join(
    decoder_capacity_logits_experiment_config["results_dir"],
    "figures",
)
Path(decoder_capacity_logits_experiment_config["results_dir"]).mkdir(parents=True, exist_ok=True)
Path(decoder_capacity_logits_figure_dir).mkdir(parents=True, exist_ok=True)

############################################################################
## Experiment run
############################################################################

decoder_capacity_logits_results = cbm_mixed_capacity_experiment_loop(
    decoder_capacity_logits_experiment_config,
    load_from_cache=LOAD_FROM_CACHE,
)
# One summary line per headline metric
print(
    *(
        f"{metric}: {decoder_capacity_logits_results[metric]}"
        for metric in ("task_accuracies", "concept_accuracies", "task_aucs")
    ),
    sep="\n",
)

# Concept Whitening Benchmark

In [None]:
import concepts_xai.methods.CW.CWLayer as CW

def construct_cw_model(
    input_shape,
    encoder,
    decoder,
    learning_rate=1e-3,
    activation=tf.keras.activations.relu,
    activation_mode='max_pool_mean',
):
    """Wire an encoder, a concept whitening layer and a decoder together.

    Two Keras models sharing the same input, encoder and whitening layer are
    built:

    * ``cw_model`` maps the input to the output of the whitening layer.
    * ``model`` additionally applies ``activation`` and ``decoder`` to the
      whitened output and is compiled for training.

    Args:
        input_shape: shape handed to ``tf.keras.Input``.
        encoder: callable applied to the model input.
        decoder: callable applied to the activated whitening output.
        learning_rate: learning rate of the Adam optimizer.
        activation: callable applied between the whitening layer and the
            decoder.
        activation_mode: forwarded to ``CW.ConceptWhiteningLayer``.

    Returns:
        Tuple ``(model, cw_model)``.
    """
    inputs = tf.keras.Input(shape=input_shape)
    whitening = CW.ConceptWhiteningLayer(activation_mode=activation_mode)

    # Input -> encoder -> whitening: exposes the whitened representation
    whitened = whitening(encoder(inputs))
    cw_model = tf.keras.Model(inputs=inputs, outputs=whitened, name="cw_model")

    # Full pipeline; encoder and whitening are applied to the input again so
    # that the layer call order is unchanged
    task_output = decoder(activation(whitening(encoder(inputs))))
    model = tf.keras.Model(
        inputs=inputs,
        outputs=task_output,
        name="complete_model",
    )
    # NOTE(review): the loss treats the outputs as logits, while the
    # "binary_accuracy" string metric thresholds predictions at 0.5 by
    # default; confirm that this is the intended accuracy definition.
    model.compile(
        loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        metrics=["binary_accuracy"],
    )
    return model, cw_model

def channels_corr_mat(outputs):
    """Compute the correlation matrix between the channels of ``outputs``.

    Every channel is standardised (zero mean, unit population standard
    deviation) over all samples and spatial positions, and the resulting
    ``(C, C)`` matrix of average products is returned.

    The input is never modified: the previous implementation centred the data
    with an in-place subtraction on a view of the caller's array, which
    silently overwrote the caller's data and failed for integer inputs.

    Args:
        outputs: array-like with channels last, of shape ``(N, C)`` or
            ``(N, H, W, C)``. Non-floating inputs are converted to float64.

    Returns:
        ``np.ndarray`` of shape ``(C, C)``. Channels with zero variance yield
        non-finite entries because of the division by their standard
        deviation.

    Raises:
        ValueError: if ``outputs`` has neither 2 nor 4 dimensions.
    """
    outputs = np.asarray(outputs)
    if not np.issubdtype(outputs.dtype, np.inexact):
        # Integer/boolean data cannot hold the centred values
        outputs = outputs.astype(np.float64)
    if outputs.ndim == 2:
        # Treat (N, C) as (N, 1, 1, C)
        outputs = outputs[:, np.newaxis, np.newaxis, :]
    if outputs.ndim != 4:
        raise ValueError(
            f'Expected outputs with 2 or 4 dimensions but got shape '
            f'{outputs.shape}.'
        )
    # Change (N, H, W, C) to (NxHxW, C): one row per sample/position
    flat = outputs.reshape(-1, outputs.shape[-1])
    # Out-of-place arithmetic so that the caller's buffer stays untouched
    centered = flat - np.mean(flat, axis=0, keepdims=True)
    normalized = centered / np.std(centered, axis=0, keepdims=True)
    return np.dot(normalized.transpose(), normalized) / normalized.shape[0]

In [None]:
import concepts_xai.evaluation.metrics.leakage as leakage

def cw_experiment_loop(experiment_config, load_from_cache=False):
    experiment_variables = dict(
        task_accuracies=[],
        task_aucs=[],
        concept_aucs=[],
        purity_scores=[],
        non_oracle_purity_scores=[],
        purity_matrices=[],
        oracle_matrices=[],
        similarity_ratio_matrices=[],
        correlation_matrices=[],
    )
    utils.reseed(87)
    experiment_config["data_concepts"] = experiment_config.get(
        "data_concepts",
        experiment_config["num_concepts"],
    )
    
    start_ind = 0
    if load_from_cache:
        cached = True
        complete_cash = True
        for var_name in experiment_variables:
            file_name = os.path.join(
                experiment_config["results_dir"],
                var_name
            )
            if (
                (not os.path.exists(f'{file_name}_means.npz')) or
                (not os.path.exists(f'{file_name}_stds.npz'))
            ):
                cached = False
                break
            loaded_means = np.load(f'{file_name}_means.npz')
            if len(loaded_means) != len(experiment_config["covariances"]):
                # Then we have a partial run here so let's just run the rest
                if start_ind and start_ind != len(loaded_means):
                    raise ValueError(
                        f'Found inconsistent start indices in cached '
                        f'data {start_ind} vs {len(loaded_means)}.'
                    )
                start_ind = len(loaded_means)
                complete_cash = False
        if cached:
            # Then we have found all of the arrays of interest, so let's
            # load them and use them
            print("Experiment cache was hit")
            print("\tLoading variables from", experiment_config["results_dir"])
            for var_name in experiment_variables:
                file_name = os.path.join(
                    experiment_config["results_dir"],
                    var_name
                )
                experiment_variables[var_name] = list(zip(
                    list(map(
                        lambda x: np.load(f'{file_name}_means.npz')[x],
                        np.load(f'{file_name}_means.npz')
                    )),
                    list(map(
                        lambda x: np.load(f'{file_name}_stds.npz')[x],
                        np.load(f'{file_name}_stds.npz')
                    )),
                ))
            if complete_cash:
                # Then we are good to go
                if not os.path.exists(
                    os.path.join(experiment_config["results_dir"], "config.yaml")
                ):
                    # then serialize the config as this is a different version run
                    utils.serialize_experiment_config(
                        experiment_config,
                        experiment_config["results_dir"],
                    )
                return experiment_variables
    
    # Let's save our config here either way
    utils.serialize_experiment_config(
        experiment_config,
        experiment_config["results_dir"],
    )
    
    Path(
        os.path.join(
            experiment_config["results_dir"],
            "models",
        )
    ).mkdir(parents=True, exist_ok=True)
    verbosity = experiment_config.get("verbosity", 0)
    for cov in experiment_config["covariances"][start_ind:]:
        print("Training with covariance:", cov)
        task_accs = []
        c_aucs = []
        aucs = []
        purity_mats = []
        oracle_mats = []
        purities = []
        non_oracle_purities = []
        similarities = []
        correlations = []

        for trial in range(experiment_config["trials"]):
            print(f"\tTrial {trial + 1}/{experiment_config['trials']}")

            # First construct the dataset
            (x_train, y_train, y_train_concepts) = produce_data(
                experiment_config["train_samples"],
                cov=cov,
                num_concepts=experiment_config["data_concepts"],
            )
            (x_test, y_test, y_test_concepts) = produce_data(
                experiment_config["test_samples"],
                cov=cov,
                num_concepts=experiment_config["data_concepts"],
            )
            x_true_inds = (y_train_concepts[:, 0] == 1)
            y_true_inds = (y_train_concepts[:, 1] == 1)
            z_true_inds = (y_train_concepts[:, 2] == 1)
            x_group_inds = np.logical_and(
                x_true_inds,
                np.logical_and(
                    np.logical_not(y_true_inds),
                    np.logical_not(z_true_inds),
                )
            )
            y_group_inds = np.logical_and(
                y_true_inds,
                np.logical_and(
                    np.logical_not(x_true_inds),
                    np.logical_not(z_true_inds),
                )
            )
            z_group_inds = np.logical_and(
                z_true_inds,
                np.logical_and(
                    np.logical_not(x_true_inds),
                    np.logical_not(y_true_inds),
                )
            )
            exclusive_concept_groups = [
                x_train[x_group_inds, :],
                x_train[y_group_inds, :],
                x_train[z_group_inds, :],
            ][:experiment_config["data_concepts"]]
            
            if not experiment_config.get("exclusive_concepts", False):
                x_group_inds = x_true_inds
                y_group_inds = y_true_inds
                z_group_inds = z_true_inds
            concept_groups = [
                x_train[x_group_inds, :],
                x_train[y_group_inds, :],
                x_train[z_group_inds, :],
            ][:experiment_config["data_concepts"]]
            
            
            
            # Construct our CW model
            encoder = construct_encoder(
                input_shape=experiment_config["input_shape"],
                units=experiment_config["encoder_units"],
                num_concepts=experiment_config["num_concepts"],
                end_activation=None,
                latent_dims=experiment_config["latent_dims"],
            )
            decoder = construct_decoder(
                units=experiment_config["decoder_units"],
                num_outputs=experiment_config["num_outputs"],
            )
            
            model, cw_model = construct_cw_model(
                input_shape=experiment_config["input_shape"],
                encoder=encoder,
                decoder=decoder,
                learning_rate=experiment_config["learning_rate"],
                activation=tf.keras.activations.relu,
                activation_mode=experiment_config['activation_mode'],
            )
            
            # First do some pretraining for warming up the estimates if needed
            if experiment_config.get("pre_train_epochs"):
                print("\tModel pre-training...")
                model.fit(
                    x=x_train,
                    y=y_train,
                    epochs=experiment_config["pre_train_epochs"],
                    batch_size=experiment_config["batch_size"],
                    validation_split=experiment_config["holdout_fraction"],
                    verbose=verbosity,
                )
                fig, ax = plt.subplots(1, figsize=(8, 6))
                similarity_ratio = oracle.concept_similarity_matrix(
                    concept_representations=list(map(
                        lambda x: cw_model(x).numpy(),
                        concept_groups
                    )),
                    compute_ratios=True,
                )
                im, cbar = utils.heatmap(
                    similarity_ratio,
                    [f"$c_{i}$" for i in range(len(concept_groups))],
                    [f"$c_{i}$" for i in range(len(concept_groups))],
                    ax=ax,
                    cmap="magma",
                    cbarlabel=f"Similarity Ratio",
                    vmin=0,
                    vmax=1,
                )
                texts = utils.annotate_heatmap(im, valfmt="{x:.2f}")
                fig.tight_layout()

                fig.suptitle(f"Baseline Concept Axis Separability", fontsize=25)
                fig.subplots_adjust(top=0.85)
                plt.show()
                print("\t\tModel pre-training completed")
            
            # Set up the dataset in a nice usable form for unrolling the training
            # loop
            main_dataset_loader = tf.data.Dataset.from_tensor_slices((x_train, y_train))
            main_dataset_loader = main_dataset_loader.shuffle(buffer_size=1000).batch(
                experiment_config["batch_size"]
            )
            
            min_size = min(list(map(lambda x: x.shape[0], concept_groups)))
            print("Minimum size is", min_size, "given concept datasets", list(map(lambda x: x.shape[0], concept_groups)))
            concept_groups = list(map(lambda x: x[:min_size, :], concept_groups))
            concept_group_loader = tf.data.Dataset.from_tensor_slices(tuple(concept_groups))
            concept_group_loader = concept_group_loader.shuffle(buffer_size=1000).batch(
                experiment_config["batch_size"]
            )

            @tf.function
            def _train_step(model, x_batch_train, y_batch_train):
                # Update the other model parameters
                with tf.GradientTape() as tape:
                    logits = model(x_batch_train, training=True)
                    loss_value = model.loss(y_batch_train, logits)

                grads = tape.gradient(loss_value, model.trainable_weights)
                model.optimizer.apply_gradients(zip(grads, model.trainable_weights))
                return loss_value
            
            total_steps = 0
            for epoch in range(experiment_config["max_epochs"]):
                for current_step, (x_batch_train, y_batch_train) in enumerate(main_dataset_loader):
                    print(
                        f'Epoch {epoch + 1} and step {current_step}/{int(np.ceil(x_train.shape[0] / experiment_config["batch_size"]))}         ',
                        end="\r",
                    )
                    # Need to update the rotation matrix
                    if (total_steps + 1) % experiment_config["cw_train_freq"] == 0:
                        for _ in range(experiment_config.get("cw_train_iterations", 1)):
                            cw_batch_steps = 0
                            for concept_groups_batch in concept_group_loader:
                                if cw_batch_steps > experiment_config.get("cw_train_batch_steps", float("inf")):
                                    break
                                model.layers[experiment_config["cw_layer"]].update_rotation_matrix(
                                    concept_groups=list(map(lambda x: encoder(x), concept_groups_batch)),
                                )
                                cw_batch_steps += 1
                    if experiment_config.get("concept_auc_freq"):
                        if (total_steps % experiment_config["concept_auc_freq"]) == 0:
                            concept_aucs = leakage.compute_concept_aucs(
                                cw_model=model,
                                encoder=encoder,
                                cw_layer=experiment_config["cw_layer"],
                                x_test=x_test,
                                c_test=y_test_concepts,
                                num_concepts=experiment_config["num_concepts"],
                            )
                            print(
                                f'Concept AUC at step {total_steps}:',
                                concept_aucs
                            )
                    _train_step(model, x_batch_train, y_batch_train)
                    total_steps += 1
            
            if experiment_config.get("post_cw_train_epochs"):
                for post_epoch in range(experiment_config.get("post_cw_train_epochs", 0)):
                    cw_batch_steps = 0
                    steps_in_batch = len(concept_group_loader)
                    for concept_groups_batch in concept_group_loader:
                        print(
                            f'Post epoch {post_epoch + 1} and step {cw_batch_steps}/{steps_in_batch}         ',
                            end="\r",
                        )
                        model.layers[experiment_config["cw_layer"]].update_rotation_matrix(
                            concept_groups=list(map(lambda x: encoder(x), concept_groups_batch)),
                        )
                        cw_batch_steps += 1
            
            print("\t\tCW training completed")
            print("\tSerializing model")
            model.save(
                os.path.join(
                    experiment_config["results_dir"],
                    f"models/end_to_end_model_cov_{cov:.1f}_trial_{trial}"
                )
            )
            print("\tEvaluating model")
            test_result = model.evaluate(
                x_test,
                y_test,
                verbose=0,
                return_dict=True,
            )
            task_accs.append(test_result['binary_accuracy'])
            c_aucs.append(leakage.compute_concept_aucs(
                cw_model=model,
                encoder=encoder,
                cw_layer=experiment_config["cw_layer"],
                x_test=x_test,
                c_test=y_test_concepts,
                num_concepts=experiment_config["num_concepts"],
            ))
            
            aucs.append(sklearn.metrics.roc_auc_score(
                y_test,
                model.predict(x_test),
            ))
            
            print(
                f"\t\tTest auc = {aucs[-1]:.4f}, "
                f"test concept AUCs = {c_aucs[-1]}, "
                f"task accuracy = {task_accs[-1]:.4f}"
            )
            
            
            print("\t\tComputing purity score...")
            concept_scores = cw_model.layers[-1].concept_scores(
                encoder(x_test),
                aggregator=experiment_config['aggregator'],
            ).numpy()
            purity_score, purity_mat, oracle_mat = oracle.oracle_impurity_score(
                c_soft=concept_scores,
                c_true=y_test_concepts,
                output_matrices=True,
            )
            purity_mats.append(purity_mat)
            oracle_mats.append(oracle_mat)
            purities.append(purity_score)
            
            print(f"\t\t\tDone {purity_score:.4f}")

            print("\t\tComputing non-oracle purity score...")
            non_oracle_purities.append(oracle.oracle_impurity_score(
                c_soft=concept_scores,
                c_true=y_test_concepts,
                oracle_matrix=construct_trivial_auc_mat(
                    experiment_config["data_concepts"]
                ),
                purity_matrix=purity_mat,
            ))
            print(f"\t\t\tDone {non_oracle_purities[-1]:.4f}")
            
            print("\t\tComputing similarity ratios...")
            similarity_ratio = oracle.concept_similarity_matrix(
                concept_representations=list(map(
                    lambda x: cw_model(x).numpy(), #[:, :len(concept_groups)],
                    concept_groups, #exclusive_concept_groups,  #concept_groups
                )),
                compute_ratios=True,
            )
            fig, ax = plt.subplots(1, figsize=(8, 6))
            im, cbar = utils.heatmap(
                similarity_ratio,
                [f"$x_+$", f"$y_+$", f"$z_+$"][:experiment_config["num_concepts"]],
                [f"$x_+$", f"$y_+$", f"$z_+$"][:experiment_config["num_concepts"]],
                ax=ax,
                cmap="magma",
                cbarlabel=f"Similarity Ratio",
                vmin=0,
                vmax=1,
            )
            texts = utils.annotate_heatmap(im, valfmt="{x:.2f}")
            fig.tight_layout()

            fig.suptitle(f"Concept Axis Separability", fontsize=25)
            fig.subplots_adjust(top=0.85)
            plt.show()
            similarities.append(similarity_ratio)
            
            # Compute correlation matrices
            print("\t\tComputing correlation matrix...")
            corr_mat = channels_corr_mat(cw_model(x_test).numpy())
            correlations.append(corr_mat)
            fig, ax = plt.subplots(1, figsize=(8, 6))
            im, cbar = utils.heatmap(
                np.abs(corr_mat),
                [f"$f_{i}$" for i in range(corr_mat.shape[-1])],
                [f"$f_{i}$" for i in range(corr_mat.shape[-1])],
                ax=ax,
                cmap="magma",
                cbarlabel=f"Correlation Coef",
                vmin=0,
                vmax=1,
            )
            texts = utils.annotate_heatmap(im, valfmt="{x:.2f}")
            fig.tight_layout()

            fig.suptitle(f"Latent Dimension Correlation", fontsize=25)
            fig.subplots_adjust(top=0.85)
            plt.show()
            
            

            print("\t\tDone with trial", trial + 1)

        task_acc_mean, task_acc_std = np.mean(task_accs), np.std(task_accs)
        experiment_variables["task_accuracies"].append((task_acc_mean, task_acc_std))
        print(f"\tTest task accuracy: {task_acc_mean:.4f} ± {task_acc_std:.4f}")

        concept_aucs_mean = np.mean(np.stack(c_aucs, axis=0), axis=0)
        concept_aucs_std = np.std(np.stack(c_aucs, axis=0), axis=0)
        experiment_variables["concept_aucs"].append((concept_aucs_mean, concept_aucs_std))
        print(f"\tConcept AUCS:")
        line = "\t\t"
        for i in range(concept_aucs_mean.shape[0]):
            line += f'{concept_aucs_mean[i]:.4f} ± {concept_aucs_std[i]:.4f}    '
        print(line)


        task_auc_mean, task_auc_std = np.mean(aucs), np.std(aucs)
        experiment_variables["task_aucs"].append((task_auc_mean, task_auc_std))
        print(f"\tTest task AUC: {task_auc_mean:.4f} ± {task_auc_std:.4f}")

        purity_mats = np.stack(purity_mats, axis=0)
        purity_mat_mean = np.mean(purity_mats, axis=0)
        purity_mat_std = np.std(purity_mats, axis=0)
        print("\tPurity matrix:")
        for i in range(purity_mat_mean.shape[0]):
            line = "\t\t"
            for j in range(purity_mat_mean.shape[1]):
                line += f'{purity_mat_mean[i, j]:.4f} ± {purity_mat_std[i, j]:.4f}    '
            print(line)

        experiment_variables["purity_matrices"].append((purity_mat_mean, purity_mat_std))
        
        similarities = np.stack(similarities, axis=0)
        similarities_mean = np.mean(similarities, axis=0)
        similarities_std = np.std(similarities, axis=0)
        print("\tSimilarity ratio matrix:")
        for i in range(similarities_mean.shape[0]):
            line = "\t\t"
            for j in range(similarities_mean.shape[1]):
                line += f'{similarities_mean[i, j]:.4f} ± {similarities_std[i, j]:.4f}    '
            print(line)
        
        fig, ax = plt.subplots(1, figsize=(8, 6))
        im, cbar = utils.heatmap(
            similarities_mean,
            [f"$x_+$", f"$y_+$", f"$z_+$"][:experiment_config["num_concepts"]],
            [f"$x_+$", f"$y_+$", f"$z_+$"][:experiment_config["num_concepts"]],
            ax=ax,
            cmap="magma",
            cbarlabel=f"Similarity Ratio",
            vmin=0,
            vmax=1,
        )
        texts = utils.annotate_heatmap(im, valfmt="{x:.2f}")
        fig.tight_layout()

        fig.suptitle(f"Mean Concept Axis Separability", fontsize=25)
        fig.subplots_adjust(top=0.85)
        plt.show()
        
        experiment_variables["similarity_ratio_matrices"].append(
            (similarities_mean, similarities_std)
        )
        
        
        correlations = np.stack(correlations, axis=0)
        correlations_mean = np.mean(correlations, axis=0)
        correlations_std = np.std(correlations, axis=0)
        print("\tCorrelation ratio matrix:")
        for i in range(correlations_mean.shape[0]):
            line = "\t\t"
            for j in range(correlations_mean.shape[1]):
                line += f'{correlations_mean[i, j]:.4f} ± {correlations_std[i, j]:.4f}    '
            print(line)
        
        fig, ax = plt.subplots(1, figsize=(8, 6))
        im, cbar = utils.heatmap(
            np.abs(corr_mat),
            [f"$f_{i}$" for i in range(corr_mat.shape[-1])],
            [f"$f_{i}$" for i in range(corr_mat.shape[-1])],
            ax=ax,
            cmap="magma",
            cbarlabel=f"Mean Correlation Coef",
            vmin=0,
            vmax=1,
        )
        texts = utils.annotate_heatmap(im, valfmt="{x:.2f}")
        fig.tight_layout()

        fig.suptitle(f"Latent Dimension Correlation", fontsize=25)
        fig.subplots_adjust(top=0.85)
        plt.show()

        experiment_variables["correlation_matrices"].append(
            (correlations_mean, correlations_std)
        )


        oracle_mats = np.stack(oracle_mats, axis=0)
        oracle_mat_mean = np.mean(oracle_mats, axis=0)
        oracle_mat_std = np.std(oracle_mats, axis=0)
        print("\tOracle matrix:")
        for i in range(oracle_mat_mean.shape[0]):
            line = "\t\t"
            for j in range(oracle_mat_mean.shape[1]):
                line += f'{oracle_mat_mean[i, j]:.4f} ± {oracle_mat_std[i, j]:.4f}    '
            print(line)

        experiment_variables["oracle_matrices"].append((oracle_mat_mean, oracle_mat_std))

        purity_mean, purity_std = np.mean(purities), np.std(purities)
        experiment_variables["purity_scores"].append((purity_mean, purity_std))
        print(f"\tPurity score: {purity_mean:.4f} ± {purity_std:.4f}")

        non_oracle_purity_mean, non_oracle_purity_std = np.mean(non_oracle_purities), np.std(non_oracle_purities)
        experiment_variables["non_oracle_purity_scores"].append((non_oracle_purity_mean, non_oracle_purity_std))
        print(f"\tNon-oracle purity score: {non_oracle_purity_mean:.4f} ± {non_oracle_purity_std:.4f}")

        # And serialize the results
        utils.serialize_results(
            results_dict=experiment_variables,
            results_dir=experiment_config["results_dir"],
        )
    return experiment_variables

# HACK: deserialization messes up custom methods, so this method is re-defined here as a standalone function
def concept_scores(
    self,
    inputs,
    aggregator='max_pool_mean',
    concept_indices=None,
    data_format="channels_last",
):
    """Turns the output of a callable layer into one score per channel.

    This is a free function that receives the layer explicitly as `self`, so
    it can be applied to layers whose own method is not available.

    Args:
        self: callable invoked as `self(inputs, training=False)`.
        inputs: value forwarded unchanged to `self`.
        aggregator: one of 'mean', 'max_pool_mean' or 'max'. Only consulted
            when the layer output is not rank 2.
        concept_indices: optional index applied to the second axis of the
            resulting scores. If None, all scores are returned.
        data_format: when equal to "channels_last", a non-rank-2 output is
            transposed with perm [0, 3, 1, 2] before aggregating. Any other
            value leaves the output untouched.

    Returns:
        The rank-2 layer output as is, or the output reduced over axes
        [2, 3] according to `aggregator`, optionally indexed by
        `concept_indices`.

    Raises:
        ValueError: if the output is not rank 2 and `aggregator` is not one
            of the supported names.
    """
    activations = self(inputs, training=False)

    if len(tf.shape(activations)) != 2:
        if data_format == "channels_last":
            # NHWC -> NCHW so every reduction below can assume the trailing
            # two axes are the ones to aggregate over
            activations = tf.transpose(activations, perm=[0, 3, 1, 2])

        reduce_axes = [2, 3]
        if aggregator == 'max':
            # Largest activation found in each channel
            scores = tf.math.reduce_max(activations, axis=reduce_axes)
        elif aggregator == 'mean':
            # Average activation of each channel
            scores = tf.math.reduce_mean(activations, axis=reduce_axes)
        elif aggregator == 'max_pool_mean':
            # Max-pool with a window of at most 2 (smaller if a trailing
            # dimension is smaller), then average what is left
            pool_size = min(
                2,
                activations.shape[-1],
                activations.shape[-2],
            )
            pooled = tf.nn.max_pool(
                activations,
                ksize=pool_size,
                strides=pool_size,
                padding="SAME",
                data_format="NCHW",
            )
            scores = tf.math.reduce_mean(pooled, axis=reduce_axes)
        else:
            raise ValueError(f'Unsupported aggregator {aggregator}.')
    else:
        # A rank-2 output is already a matrix of scores
        scores = activations

    return scores if concept_indices is None else scores[:, concept_indices]

def cw_bottleneck_predict_experiment_loop(
    experiment_config,
    load_from_cache=False,
):
    """Fits task predictors on CW concept scores, one per covariance/trial.

    For every covariance in `experiment_config["covariances"]` and every
    trial, this loads the model stored at
    `<results_dir>/models/end_to_end_model_cov_<cov>_trial_<trial>`, computes
    concept scores with `concept_scores` on the layer at index
    `experiment_config["cw_layer"]`, trains a fresh decoder to predict the
    task labels from the first `num_concepts` scores, and records its test
    accuracy and AUC. Results are serialized after every covariance.

    Args:
        experiment_config: dict-like configuration. Keys read here:
            "results_dir", "covariances", "trials", "train_samples",
            "test_samples", "data_concepts", "cw_layer",
            "latent_decoder_units", "num_outputs", "learning_rate",
            "aggregator", "num_concepts", "predictor_max_epochs",
            "batch_size", "holdout_fraction" and, optionally, "verbosity".
        load_from_cache: if True, look for `<var>_means.npz` and
            `<var>_stds.npz` files in the results directory. A complete cache
            is returned directly; a partial one is loaded and the remaining
            covariances are run.

    Returns:
        Dict with keys "latent_predictive_accuracies" and
        "latent_predictive_aucs", each a list of (mean, std) pairs across
        trials with one entry per covariance.

    Raises:
        ValueError: if the cached variables disagree on how many
            covariances have already been processed.
    """
    utils.reseed(87)
    experiment_variables = dict(
        latent_predictive_accuracies=[],
        latent_predictive_aucs=[],
    )
    results_dir = experiment_config["results_dir"]
    num_covariances = len(experiment_config["covariances"])

    start_ind = 0
    if load_from_cache:
        cached = True
        complete_cache = True
        for var_name in experiment_variables:
            file_name = os.path.join(results_dir, var_name)
            if (
                (not os.path.exists(f'{file_name}_means.npz')) or
                (not os.path.exists(f'{file_name}_stds.npz'))
            ):
                print("Could not find", f'"{file_name}_means.npz" or "{file_name}_stds.npz" in cache.')
                cached = False
                # Nothing will be loaded from the cache, so any resume index
                # picked up from an earlier variable must be dropped.
                # Otherwise the run below would skip covariances whose
                # results are not in `experiment_variables`.
                start_ind = 0
                break
            # Only the number of stored entries is needed here; close the
            # archive straight away rather than leaking the file handle.
            with np.load(f'{file_name}_means.npz') as loaded_means:
                num_loaded = len(loaded_means)
            if num_loaded != num_covariances:
                print("Found", num_loaded, "means for variable", var_name, "vs", num_covariances, "datasets")
                # Then we have a partial run here so let's just run the rest
                if start_ind and start_ind != num_loaded:
                    raise ValueError(
                        f'Found inconsistent start indices in cached '
                        f'data {start_ind} vs {num_loaded} ({file_name}).'
                    )
                start_ind = num_loaded
                complete_cache = False
        if cached:
            # Then we have found all of the arrays of interest, so let's
            # load them and use them
            print("Experiment cache was hit")
            print("\tLoading variables from", results_dir)
            for var_name in experiment_variables:
                file_name = os.path.join(results_dir, var_name)
                # Open each archive once and read every entry from it
                with np.load(f'{file_name}_means.npz') as means_npz:
                    with np.load(f'{file_name}_stds.npz') as stds_npz:
                        experiment_variables[var_name] = list(zip(
                            [means_npz[key] for key in means_npz],
                            [stds_npz[key] for key in stds_npz],
                        ))
            print(experiment_variables)
            if complete_cache:
                # Then we are good to go
                if not os.path.exists(os.path.join(results_dir, "config.yaml")):
                    # then serialize the config as this is a different version run
                    utils.serialize_experiment_config(
                        experiment_config,
                        results_dir,
                    )
                return experiment_variables

    # Else, let's go ahead and run the whole thing
    verbosity = experiment_config.get("verbosity", 0)
    Path(os.path.join(results_dir, "models")).mkdir(parents=True, exist_ok=True)

    # Single source of truth for the binary/multiclass split. The loss and the
    # metric are compiled with the `num_outputs <= 2` rule, so reading the
    # metric and computing the AUC must follow the same rule. Otherwise
    # `num_outputs == 2` looks up a metric that was never compiled.
    is_multiclass = experiment_config["num_outputs"] > 2

    for cov in experiment_config["covariances"][start_ind:]:
        print("Training with covariance:", cov)
        task_accs = []
        aucs = []

        for trial in range(experiment_config["trials"]):
            print(f"\tTrial {trial + 1}/{experiment_config['trials']} with covariance {cov}")
            # First construct the dataset
            (x_train, y_train, y_train_concepts) = produce_data(
                experiment_config["train_samples"],
                cov=cov,
                num_concepts=experiment_config["data_concepts"],
            )
            (x_test, y_test, y_test_concepts) = produce_data(
                experiment_config["test_samples"],
                cov=cov,
                num_concepts=experiment_config["data_concepts"],
            )
            # Reload the end-to-end model saved for this covariance/trial
            complete_model = tf.keras.models.load_model(
                os.path.join(
                    results_dir,
                    f"models/end_to_end_model_cov_{cov:.1f}_trial_{trial}"
                ),
            )
            complete_model.summary()
            # Sub-model ending at the CW layer; only its last layer is used
            # below, as the `self` argument of `concept_scores`
            cw_output_model = tf.keras.Model(
                complete_model.inputs,
                [complete_model.layers[experiment_config["cw_layer"]].output],
                name="cw_output_model",
            )

            # NOTE(review): this decoder is never used below. It is kept
            # because constructing it may advance TF/Keras global state
            # (RNG, layer naming), so removing it could change seeded
            # results. Confirm before deleting.
            feature_predictive_decoder = construct_decoder(
                units=experiment_config["latent_decoder_units"],
                num_outputs=experiment_config["num_outputs"],
            )

            # Applies the layer just before the CW layer directly to the
            # model inputs; its output is what gets fed to `concept_scores`
            encoder_model = tf.keras.Model(
                complete_model.inputs,
                [complete_model.layers[experiment_config["cw_layer"] - 1](complete_model.inputs)],
                name="cw_output_model",
            )

            score_predictive_decoder = construct_decoder(
                units=experiment_config["latent_decoder_units"],
                num_outputs=experiment_config["num_outputs"],
            )

            score_predictive_decoder.compile(
                optimizer=tf.keras.optimizers.Adam(experiment_config["learning_rate"]),
                loss=(
                    tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) if is_multiclass
                    else tf.keras.losses.BinaryCrossentropy(from_logits=True)
                ),
                metrics=[
                    tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1) if is_multiclass
                    else "binary_accuracy"
                ],
            )

            print("\tTraining score model")
            concept_columns = list(range(experiment_config["num_concepts"]))
            score_train_codes = concept_scores(
                cw_output_model.layers[-1],
                encoder_model(x_train),
                aggregator=experiment_config['aggregator'],
            ).numpy()[:, concept_columns]
            score_test_codes = concept_scores(
                cw_output_model.layers[-1],
                encoder_model(x_test),
                aggregator=experiment_config['aggregator'],
            ).numpy()[:, concept_columns]
            score_predictive_decoder.fit(
                x=score_train_codes,
                y=y_train,
                epochs=experiment_config["predictor_max_epochs"],
                batch_size=experiment_config["batch_size"],
                validation_split=experiment_config["holdout_fraction"],
                verbose=verbosity,
            )

            print("\tEvaluating score model")
            test_result = score_predictive_decoder.evaluate(
                score_test_codes,
                y_test,
                verbose=0,
                return_dict=True,
            )
            # The key must match the metric that was compiled above
            task_accs.append(
                test_result['sparse_top_k_categorical_accuracy']
                if is_multiclass else
                test_result['binary_accuracy']
            )

            if is_multiclass:
                # Then lets apply a softmax activation over all the probability
                # classes
                preds = scipy.special.softmax(
                    score_predictive_decoder.predict(score_test_codes),
                    axis=-1,
                )

                # One-vs-one AUC over the class probabilities
                aucs.append(sklearn.metrics.roc_auc_score(
                    y_test,
                    preds,
                    multi_class='ovo',
                ))
            else:
                # The raw decoder outputs are used directly as ranking scores
                aucs.append(sklearn.metrics.roc_auc_score(
                    y_test,
                    score_predictive_decoder.predict(score_test_codes),
                ))

            print(
                f"\t\tTest auc = {aucs[-1]:.4f}, "
                f"task accuracy = {task_accs[-1]:.4f}"
            )

            print("\t\tDone with trial", trial + 1)

        task_acc_mean, task_acc_std = np.mean(task_accs), np.std(task_accs)
        experiment_variables["latent_predictive_accuracies"].append((task_acc_mean, task_acc_std))
        print(f"\tTest task accuracy: {task_acc_mean:.4f} ± {task_acc_std:.4f}")

        task_auc_mean, task_auc_std = np.mean(aucs), np.std(aucs)
        experiment_variables["latent_predictive_aucs"].append((task_auc_mean, task_auc_std))
        print(f"\tTest task AUC: {task_auc_mean:.4f} ± {task_auc_std:.4f}")

        # And serialize the results after every covariance so that an
        # interrupted run can be resumed from the cache
        utils.serialize_results(
            results_dict=experiment_variables,
            results_dir=results_dir,
        )
    return experiment_variables

def cw_bottleneck_concept_predict_experiment_loop(
    experiment_config,
    load_from_cache=False,
):
    """Fits one concept predictor per concept on top of CW concept scores.

    For every covariance in `experiment_config["covariances"]` and every
    trial, this loads the model stored at
    `<results_dir>/models/end_to_end_model_cov_<cov>_trial_<trial>`, computes
    concept scores with `concept_scores` on the layer at index
    `experiment_config["cw_layer"]`, and then, for each of the first
    `num_concepts` concepts, trains a fresh single-output decoder to predict
    that concept's label from all of the scores. Accuracy and AUC are
    averaged over concepts within a trial, then summarized across trials.
    Results are serialized after every covariance.

    Args:
        experiment_config: dict-like configuration. Keys read here:
            "results_dir", "covariances", "trials", "train_samples",
            "test_samples", "data_concepts", "cw_layer",
            "latent_decoder_units", "learning_rate", "aggregator",
            "num_concepts", "predictor_max_epochs", "batch_size",
            "holdout_fraction" and, optionally, "verbosity".
        load_from_cache: if True, look for `<var>_means.npz` and
            `<var>_stds.npz` files in the results directory. A complete cache
            is returned directly; a partial one is loaded and the remaining
            covariances are run.

    Returns:
        Dict with keys "latent_avg_concept_predictive_accuracies" and
        "latent_avg_concept_predictive_aucs", each a list of (mean, std)
        pairs across trials with one entry per covariance.

    Raises:
        ValueError: if the cached variables disagree on how many
            covariances have already been processed.
    """
    utils.reseed(87)
    experiment_variables = dict(
        latent_avg_concept_predictive_accuracies=[],
        latent_avg_concept_predictive_aucs=[],
    )
    results_dir = experiment_config["results_dir"]
    num_covariances = len(experiment_config["covariances"])

    start_ind = 0
    if load_from_cache:
        cached = True
        complete_cache = True
        for var_name in experiment_variables:
            file_name = os.path.join(results_dir, var_name)
            if (
                (not os.path.exists(f'{file_name}_means.npz')) or
                (not os.path.exists(f'{file_name}_stds.npz'))
            ):
                print("Could not find", f'"{file_name}_means.npz" or "{file_name}_stds.npz" in cache.')
                cached = False
                # Nothing will be loaded from the cache, so any resume index
                # picked up from an earlier variable must be dropped.
                # Otherwise the run below would skip covariances whose
                # results are not in `experiment_variables`.
                start_ind = 0
                break
            # Only the number of stored entries is needed here; close the
            # archive straight away rather than leaking the file handle.
            with np.load(f'{file_name}_means.npz') as loaded_means:
                num_loaded = len(loaded_means)
            if num_loaded != num_covariances:
                print("Found", num_loaded, "means for variable", var_name, "vs", num_covariances, "datasets")
                # Then we have a partial run here so let's just run the rest
                if start_ind and start_ind != num_loaded:
                    raise ValueError(
                        f'Found inconsistent start indices in cached '
                        f'data {start_ind} vs {num_loaded} ({file_name}).'
                    )
                start_ind = num_loaded
                complete_cache = False
        if cached:
            # Then we have found all of the arrays of interest, so let's
            # load them and use them
            print("Experiment cache was hit")
            print("\tLoading variables from", results_dir)
            for var_name in experiment_variables:
                file_name = os.path.join(results_dir, var_name)
                # Open each archive once and read every entry from it
                with np.load(f'{file_name}_means.npz') as means_npz:
                    with np.load(f'{file_name}_stds.npz') as stds_npz:
                        experiment_variables[var_name] = list(zip(
                            [means_npz[key] for key in means_npz],
                            [stds_npz[key] for key in stds_npz],
                        ))
            print(experiment_variables)
            if complete_cache:
                # Then we are good to go
                if not os.path.exists(os.path.join(results_dir, "config.yaml")):
                    # then serialize the config as this is a different version run
                    utils.serialize_experiment_config(
                        experiment_config,
                        results_dir,
                    )
                return experiment_variables

    # Else, let's go ahead and run the whole thing
    verbosity = experiment_config.get("verbosity", 0)
    Path(os.path.join(results_dir, "models")).mkdir(parents=True, exist_ok=True)
    for cov in experiment_config["covariances"][start_ind:]:
        print("Training with covariance:", cov)
        avg_concept_accs = []
        avg_concept_aucs = []

        for trial in range(experiment_config["trials"]):
            print(f"\tTrial {trial + 1}/{experiment_config['trials']} with covariance {cov}")
            # First construct the dataset
            (x_train, y_train, c_train) = produce_data(
                experiment_config["train_samples"],
                cov=cov,
                num_concepts=experiment_config["data_concepts"],
            )
            (x_test, y_test, c_test) = produce_data(
                experiment_config["test_samples"],
                cov=cov,
                num_concepts=experiment_config["data_concepts"],
            )
            # Reload the end-to-end model saved for this covariance/trial
            complete_model = tf.keras.models.load_model(
                os.path.join(
                    results_dir,
                    f"models/end_to_end_model_cov_{cov:.1f}_trial_{trial}"
                ),
            )
            complete_model.summary()
            # Sub-model ending at the CW layer; only its last layer is used
            # below, as the `self` argument of `concept_scores`
            cw_output_model = tf.keras.Model(
                complete_model.inputs,
                [complete_model.layers[experiment_config["cw_layer"]].output],
                name="cw_output_model",
            )

            # Applies the layer just before the CW layer directly to the
            # model inputs; its output is what gets fed to `concept_scores`
            encoder_model = tf.keras.Model(
                complete_model.inputs,
                [complete_model.layers[experiment_config["cw_layer"] - 1](complete_model.inputs)],
                name="cw_output_model",
            )

            current_aucs = []
            current_accs = []
            # The scores are computed once per trial and shared by every
            # per-concept predictor below
            concept_columns = list(range(experiment_config["num_concepts"]))
            score_train_codes = concept_scores(
                cw_output_model.layers[-1],
                encoder_model(x_train),
                aggregator=experiment_config['aggregator'],
            ).numpy()[:, concept_columns]
            score_test_codes = concept_scores(
                cw_output_model.layers[-1],
                encoder_model(x_test),
                aggregator=experiment_config['aggregator'],
            ).numpy()[:, concept_columns]
            for concept_idx in range(experiment_config["num_concepts"]):
                # Fresh binary predictor for this concept
                score_predictive_decoder = construct_decoder(
                    units=experiment_config["latent_decoder_units"],
                    num_outputs=1,
                )

                score_predictive_decoder.compile(
                    optimizer=tf.keras.optimizers.Adam(experiment_config["learning_rate"]),
                    loss=(
                        tf.keras.losses.BinaryCrossentropy(from_logits=True)
                    ),
                    metrics=[
                        "binary_accuracy"
                    ],
                )

                print("\t\tTraining score model for concept", concept_idx)
                score_predictive_decoder.fit(
                    x=score_train_codes,
                    y=c_train[:, concept_idx],
                    epochs=experiment_config["predictor_max_epochs"],
                    batch_size=experiment_config["batch_size"],
                    validation_split=experiment_config["holdout_fraction"],
                    verbose=verbosity,
                )

                print("\t\tEvaluating score model")
                test_result = score_predictive_decoder.evaluate(
                    score_test_codes,
                    c_test[:, concept_idx],
                    verbose=0,
                    return_dict=True,
                )
                current_accs.append(
                    test_result['binary_accuracy']
                )

                # The raw decoder outputs are used directly as ranking scores
                current_aucs.append(sklearn.metrics.roc_auc_score(
                    c_test[:, concept_idx],
                    score_predictive_decoder.predict(score_test_codes),
                ))
                print(
                    f"\t\t\tTest concept AUC = {current_aucs[-1]:.4f}, "
                    f"concept accuracy = {current_accs[-1]:.4f}"
                )
            # Collapse the per-concept metrics into one number per trial
            avg_concept_aucs.append(np.mean(current_aucs))
            avg_concept_accs.append(np.mean(current_accs))
            print(
                f"\t\tTest avg concept AUC = {avg_concept_aucs[-1]:.4f}, "
                f"avg concept accuracy = {avg_concept_accs[-1]:.4f}"
            )

            print("\t\tDone with trial", trial + 1)

        # These summaries describe concept prediction, not the downstream
        # task, so they are labelled accordingly
        avg_concept_acc_mean, avg_concept_acc_std = np.mean(avg_concept_accs), np.std(avg_concept_accs)
        experiment_variables["latent_avg_concept_predictive_accuracies"].append((avg_concept_acc_mean, avg_concept_acc_std))
        print(f"\tTest avg concept accuracy: {avg_concept_acc_mean:.4f} ± {avg_concept_acc_std:.4f}")

        avg_concept_auc_mean, avg_concept_auc_std = np.mean(avg_concept_aucs), np.std(avg_concept_aucs)
        experiment_variables["latent_avg_concept_predictive_aucs"].append((avg_concept_auc_mean, avg_concept_auc_std))
        print(f"\tTest avg concept AUC: {avg_concept_auc_mean:.4f} ± {avg_concept_auc_std:.4f}")

        # And serialize the results after every covariance so that an
        # interrupted run can be resumed from the cache
        utils.serialize_results(
            results_dict=experiment_variables,
            results_dir=results_dir,
        )
    return experiment_variables

In [None]:
# Re-execute the CW and leakage modules in this kernel before running
reload(CW)
reload(leakage)

############################################################################
## Experiment config
############################################################################

cw_covariance_experiment_config = dict(
    batch_size=128,
    max_epochs=300,
    pre_train_epochs=0,
    post_cw_train_epochs=300,
    cw_train_freq=20,
    cw_train_batch_steps=20,
    learning_rate=1e-3,
    encoder_units=[128, 64],
    decoder_units=[128, 64],
    num_outputs=1,

    # Used by the bottleneck prediction loops, which train decoders on top
    # of the concept scores
    latent_decoder_units=[128, 64],
    predictor_max_epochs=300,

    patience=float("inf"),
    min_delta=1e-5,
    results_dir=os.path.join(
        RESULTS_DIR,
        "cw/purity_new",
    ),
    input_shape=[7],
    num_concepts=3,
    latent_dims=0,
    # One experiment is run per covariance value: 0.0, 0.1, ..., 0.9
    covariances=np.arange(0, 1, 0.1),
    train_samples=2000,
    test_samples=1000,
    trials=NUM_TRIALS,
    verbosity=0,
    data_concepts=3,
    # Index into the model's layers of the concept whitening layer
    cw_layer=2,
    aggregator='max_pool_mean',
    activation_mode='max_pool_mean',
    concept_auc_freq=0,
    cw_train_iterations=1,
    # Fraction of the training codes held out for validation when fitting
    # the bottleneck predictors
    holdout_fraction=0.1,
    exclusive_concepts=False,
)


# Generate the experiment directory if it does not exist already
Path(cw_covariance_experiment_config["results_dir"]).mkdir(parents=True, exist_ok=True)
cw_covariance_figure_dir = os.path.join(cw_covariance_experiment_config["results_dir"], "figures")
Path(cw_covariance_figure_dir).mkdir(parents=True, exist_ok=True)

############################################################################
## Experiment run
############################################################################

cw_covariance_results = cw_experiment_loop(
    cw_covariance_experiment_config,
    load_from_cache=LOAD_FROM_CACHE,
)
print("task_accuracies:", cw_covariance_results["task_accuracies"])
# The label names the key that is actually printed
print("task_aucs:", cw_covariance_results["task_aucs"])

In [None]:
# Merge the task-from-concept-scores metrics into the CW results dict
cw_covariance_results.update(
    cw_bottleneck_predict_experiment_loop(
        experiment_config=cw_covariance_experiment_config,
        load_from_cache=LOAD_FROM_CACHE,
    )
)

In [None]:
# Merge the concept-from-concept-scores metrics into the CW results dict
cw_covariance_results.update(
    cw_bottleneck_concept_predict_experiment_loop(
        experiment_config=cw_covariance_experiment_config,
        load_from_cache=LOAD_FROM_CACHE,
    )
)

# CCD Benchmark

## Model Setup

In [None]:
def construct_ccd_encoder(
    input_shape,
    units,
    end_activation="sigmoid",
    latent_dims=0,
    latent_act=None,  # Original paper used "sigmoid" but this is troublesome in deep architectures
):
    """Builds a dense Keras encoder ending in a `latent_dims`-wide layer.

    Args:
        input_shape: shape passed to `tf.keras.Input`.
        units: iterable with the width of each hidden ReLU Dense layer, in
            order.
        end_activation: accepted for interface compatibility; not used by
            this function.
        latent_dims: number of units of the final Dense layer.
        latent_act: activation of the final Dense layer.

    Returns:
        A `tf.keras.Model` named "encoder" mapping the inputs to the output
        of the final "encoder_bypass_channel" layer.
    """
    encoder_inputs = tf.keras.Input(shape=input_shape)

    # Stack of hidden ReLU layers, named after their position
    hidden = encoder_inputs
    for layer_idx, layer_width in enumerate(units):
        dense_layer = tf.keras.layers.Dense(
            layer_width,
            activation='relu',
            name=f"encoder_dense_{layer_idx}",
        )
        hidden = dense_layer(hidden)

    # Flatten and project onto the latent dimensions
    flattened = tf.keras.layers.Flatten()(hidden)
    latent_code = tf.keras.layers.Dense(
        latent_dims,
        activation=latent_act,
        name="encoder_bypass_channel",
    )(flattened)

    return tf.keras.Model(
        encoder_inputs,
        latent_code,
        name="encoder",
    )

## Experiment Loop

In [None]:
import concepts_xai.methods.OCACE.topicModel as CCD
import concepts_xai.evaluation.metrics.oracle as oracle
import concepts_xai.evaluation.metrics.completeness as completeness

############################################################################
## Experiment loop
############################################################################

def _ccd_load_npz_arrays(npz_path):
    """Reads every array of an ``.npz`` archive (in archive order) and closes it."""
    with np.load(npz_path) as archive:
        return [archive[key] for key in archive]


def _ccd_summarize_matrices(title, matrices):
    """Prints the element-wise mean ± std of per-trial matrices; returns (mean, std)."""
    stacked = np.stack(matrices, axis=0)
    mat_mean = np.mean(stacked, axis=0)
    mat_std = np.std(stacked, axis=0)
    print(title)
    for i in range(mat_mean.shape[0]):
        line = "\t\t"
        for j in range(mat_mean.shape[1]):
            line += f'{mat_mean[i, j]:.4f} ± {mat_std[i, j]:.4f}    '
        print(line)
    return mat_mean, mat_std


def ccd_experiment_loop(
    experiment_config,
    load_from_cache=False,
):
    """Trains and evaluates CCD topic models across a sweep of covariances.

    For every covariance in `experiment_config["covariances"]` and for
    `experiment_config["trials"]` trials, this:
      1. samples train/test data with `produce_data`,
      2. trains an end-to-end encoder/decoder model and saves both halves,
      3. trains a `CCD.TopicModel` on the encoder's output and saves its
         `g_model` and topic vector,
      4. computes task/reconstruction accuracy and AUC, completeness and
         direct completeness scores, and (non-)oracle impurity scores.
    After each covariance the aggregated results are written with
    `utils.serialize_results`.

    Args:
        experiment_config: dict-like experiment configuration. Keys are read
            with `[...]` (required) or `.get(...)` (optional, with defaults
            given at the point of use below).
        load_from_cache: if True, previously serialized `<name>_means.npz` /
            `<name>_stds.npz` files in `experiment_config["results_dir"]` are
            loaded. A complete cache is returned directly; a partial cache is
            extended by running only the remaining covariances.

    Returns:
        Dict mapping each tracked variable name to a list with one
        `(mean, std)` pair per covariance.

    Raises:
        ValueError: if cached variables disagree on how many covariances
            have already been computed.
    """
    utils.reseed(87)
    experiment_variables = dict(
        task_accuracies=[],
        task_aucs=[],
        reconstruction_accuracies=[],
        reconstruction_aucs=[],
        purity_scores=[],
        non_oracle_purity_scores=[],
        purity_matrices=[],
        aligned_purity_matrices=[],
        oracle_matrices=[],
        completeness_scores=[],
        direct_completeness_scores=[],
        mean_similarities=[],
    )
    results_dir = experiment_config["results_dir"]
    start_ind = 0
    if load_from_cache:
        cached = True
        complete_cache = True
        for var_name in experiment_variables:
            file_name = os.path.join(results_dir, var_name)
            if (
                (not os.path.exists(f'{file_name}_means.npz')) or
                (not os.path.exists(f'{file_name}_stds.npz'))
            ):
                cached = False
                break
            # Only the number of stored entries is needed here; close the
            # archive right away instead of leaking the file handle.
            with np.load(f'{file_name}_means.npz') as loaded_means:
                num_cached = len(loaded_means)
            if num_cached != len(experiment_config["covariances"]):
                # Then we have a partial run here so let's just run the rest
                if start_ind and start_ind != num_cached:
                    raise ValueError(
                        f'Found inconsistent start indices in cached '
                        f'data {start_ind} vs {num_cached}.'
                    )
                start_ind = num_cached
                complete_cache = False
        if cached:
            # Then we have found all of the arrays of interest, so let's
            # load them and use them
            print("Experiment cache was hit")
            print("\tLoading variables from", results_dir)
            for var_name in experiment_variables:
                file_name = os.path.join(results_dir, var_name)
                experiment_variables[var_name] = list(zip(
                    _ccd_load_npz_arrays(f'{file_name}_means.npz'),
                    _ccd_load_npz_arrays(f'{file_name}_stds.npz'),
                ))
            if complete_cache:
                # Then we are good to go
                if not os.path.exists(
                    os.path.join(results_dir, "config.yaml")
                ):
                    # then serialize the config as this is a different version run
                    utils.serialize_experiment_config(
                        experiment_config,
                        results_dir,
                    )
                return experiment_variables
        else:
            # Cache miss: nothing was loaded, so a start index picked up from
            # an earlier, partially cached variable must not be used.
            # Otherwise the results would only cover the tail of the sweep.
            start_ind = 0

    # Else, let's go ahead and run the whole thing (or its uncached remainder)
    verbosity = experiment_config.get("verbosity", 0)
    Path(os.path.join(results_dir, "models")).mkdir(parents=True, exist_ok=True)

    # The helpers below do not depend on the covariance or trial, so they are
    # built once up front.
    channels_axis = (
        -1 if experiment_config.get("data_format", "channels_last") == "channels_last"
        else 1
    )
    if experiment_config["num_outputs"] == 1:
        acc_fn = lambda y_true, y_pred: sklearn.metrics.roc_auc_score(
            y_true,
            y_pred
        )
    else:
        acc_fn = lambda y_true, y_pred: sklearn.metrics.roc_auc_score(
            tf.keras.utils.to_categorical(y_true),
            scipy.special.softmax(y_pred, axis=-1),
            multi_class='ovo',
        )
    concept_score_fn = lambda f, c: completeness.dot_prod_concept_score(
        features=f,
        concept_vectors=c,
        channels_axis=channels_axis,
        beta=experiment_config.get("threshold", 0.5),
    )

    for cov in experiment_config["covariances"][start_ind:]:
        print("Training with covariance:", cov)
        # Per-trial accumulators for this covariance
        task_accs = []
        recon_accs = []
        aucs = []
        recon_aucs = []
        purity_mats = []
        aligned_purity_mats = []
        oracle_mats = []
        purities = []
        non_oracle_purities = []
        compl_scores = []
        dir_compl_scores = []
        mean_sims = []

        for trial in range(experiment_config["trials"]):
            print(f"\tTrial {trial + 1}/{experiment_config['trials']} for covariance {cov:.2f}")

            (x_train, y_train, c_train) = produce_data(
                experiment_config["train_samples"],
                cov=cov,
                num_concepts=experiment_config["data_concepts"],
            )
            (x_test, y_test, c_test) = produce_data(
                experiment_config["test_samples"],
                cov=cov,
                num_concepts=experiment_config["data_concepts"],
            )
            y_train = np.squeeze(y_train)
            y_test = np.squeeze(y_test)

            # Proceed to do and end-to-end model in case we want to
            # do some task-specific pretraining
            end_to_end_model, encoder, decoder = construct_end_to_end_model(
                input_shape=experiment_config["input_shape"],
                learning_rate=experiment_config["learning_rate"],
                encoder=construct_ccd_encoder(
                    input_shape=experiment_config["input_shape"],
                    units=experiment_config["encoder_units"],
                    latent_act=experiment_config.get("latent_act", None),
                    latent_dims=experiment_config["latent_dims"],
                ),
                decoder=construct_decoder(
                    units=experiment_config["decoder_units"],
                    num_outputs=experiment_config["num_outputs"],
                ),
            )

            print("\tModel pre-training...")

            early_stopping_monitor = tf.keras.callbacks.EarlyStopping(
                monitor=experiment_config.get(
                    "early_stop_metric",
                    "val_loss",
                ),
                min_delta=experiment_config["min_delta"],
                patience=experiment_config["patience"],
                restore_best_weights=True,
                verbose=2,
                mode=experiment_config.get(
                    "early_stop_mode",
                    "min",
                ),
            )
            end_to_end_model.fit(
                x=x_train,
                y=y_train,
                epochs=experiment_config["max_epochs"],
                batch_size=experiment_config["batch_size"],
                callbacks=[
                    early_stopping_monitor,
                ],
                validation_split=experiment_config["holdout_fraction"],
                verbose=verbosity,
            )
            print("\t\tModel pre-training completed")
            print("\tSerializing model")
            # NOTE: these file names are what
            # ccd_bottleneck_concept_predict_experiment_loop loads back, so
            # they must stay in sync with it.
            encoder.save(
                os.path.join(
                    results_dir,
                    f"models/encoder_cov_{cov:.2f}__num_concepts_{experiment_config['num_concepts']}_trial_{trial}"
                )
            )
            decoder.save(
                os.path.join(
                    results_dir,
                    f"models/decoder_cov_{cov:.2f}__num_concepts_{experiment_config['num_concepts']}_trial_{trial}"
                )
            )
            print("\tEvaluating model")

            test_result = end_to_end_model.evaluate(
                x_test,
                y_test,
                verbose=0,
                return_dict=True,
            )
            if experiment_config["num_outputs"] > 1:
                task_accs.append(test_result['sparse_top_k_categorical_accuracy'])

                # Then lets apply a softmax activation over all the probability
                # classes
                preds = scipy.special.softmax(end_to_end_model.predict(x_test), axis=-1)

                # And select just the labels that are in fact being used
                one_hot_labels = tf.keras.utils.to_categorical(y_test)
                aucs.append(sklearn.metrics.roc_auc_score(
                    one_hot_labels,
                    preds,
                    multi_class='ovo',
                ))
            else:
                task_accs.append(test_result['binary_accuracy'])
                aucs.append(sklearn.metrics.roc_auc_score(
                    y_test,
                    end_to_end_model.predict(x_test),
                ))

            # Now extract our concept vectors
            if "top_k" not in experiment_config:
                top_k = ccd_compute_k(y=y_train, batch_size=experiment_config["batch_size"])
            else:
                top_k = experiment_config["top_k"]
            topic_model = CCD.TopicModel(
                concepts_to_labels_model=decoder,
                n_channels=experiment_config["latent_dims"],
                n_concepts=experiment_config["num_concepts"],
                threshold=experiment_config.get("threshold", 0.5),
                loss_fn=end_to_end_model.loss,
                top_k=top_k,
                lambda1=experiment_config.get("lambda1", 0.1),
                lambda2=experiment_config.get("lambda2", 0.1),
                seed=experiment_config.get("seed", None),
                eps=experiment_config.get("eps", 1e-5),
                data_format=experiment_config.get(
                    "data_format",
                    "channels_last"
                ),
                allow_gradient_flow_to_c2l=experiment_config.get(
                    'allow_gradient_flow_to_c2l',
                    False,
                ),
                acc_metric=(
                    tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1)
                    if experiment_config["num_outputs"] > 1 else
                    tf.keras.metrics.BinaryAccuracy()
                ),
            )
            topic_model.compile(
                optimizer=tf.keras.optimizers.Adam(
                    experiment_config.get("learning_rate", 1e-3),
                )
            )

            # Train it for a few epochs
            print("\tTopic model training...")
            topic_model.fit(
                x=encoder(x_train),
                y=y_train,
                epochs=experiment_config["topic_model_train_epochs"],
                batch_size=experiment_config["batch_size"],
                validation_split=experiment_config["holdout_fraction"],
                verbose=verbosity,
            )
            print("\t\tTopic model training completed")

            print("\tSerializing model")
            topic_model.g_model.save(
                os.path.join(
                    results_dir,
                    f"models/topic_g_model_cov_{cov:.2f}_num_concepts_{experiment_config['num_concepts']}_trial_{trial}"
                )
            )
            np.save(
                os.path.join(
                    results_dir,
                    f"models/topic_vector_cov_{cov:.2f}_num_concepts_{experiment_config['num_concepts']}_trial_{trial}.npy"
                ),
                topic_model.topic_vector.numpy(),
            )
            print("\tEvaluating model")

            topic_result = topic_model.evaluate(
                encoder(x_test),
                y_test,
                verbose=0,
                return_dict=True,
            )

            recon_accs.append(topic_result['accuracy'])
            if experiment_config["num_outputs"] > 1:
                # Then lets apply a softmax activation over all the probability
                # classes
                preds = scipy.special.softmax(
                    topic_model(encoder(x_test))[0],
                    axis=-1,
                )

                # And select just the labels that are in fact being used
                one_hot_labels = tf.keras.utils.to_categorical(y_test)
                recon_aucs.append(sklearn.metrics.roc_auc_score(
                    one_hot_labels,
                    preds,
                    multi_class='ovo',
                ))
            else:
                recon_aucs.append(sklearn.metrics.roc_auc_score(
                    y_test,
                    topic_model(encoder(x_test))[0],
                ))
            mean_sims.append(topic_result['mean_sim'])

            print(
                f"\t\tTest auc = {aucs[-1]:.4f}, "
                f"task accuracy = {task_accs[-1]:.4f}, "
                f"task reconstruction accuracy = {recon_accs[-1]:.4f}, "
                f"task reconstruction auc = {recon_aucs[-1]:.4f}, "
                f"mean concept similarity = {mean_sims[-1]:.4f}"
            )

            # We start by extracting a completeness score for the extracted
            # concept vectors
            print("\t\tComputing completeness scores...")
            concept_vectors = np.transpose(topic_model.topic_vector.numpy())
            compl_score, _ = completeness.completeness_score(
                X=x_test,
                y=y_test,
                features_to_concepts_fn=encoder,
                concepts_to_labels_model=decoder,
                concept_vectors=concept_vectors,
                task_loss=end_to_end_model.loss,
                channels_axis=channels_axis,
                concept_score_fn=concept_score_fn,
                acc_fn=acc_fn,
            )
            compl_scores.append(compl_score)

            dir_compl_score, _ = completeness.direct_completeness_score(
                X=x_test,
                y=y_test,
                features_to_concepts_fn=encoder,
                concept_vectors=concept_vectors,
                task_loss=end_to_end_model.loss,
                channels_axis=channels_axis,
                concept_score_fn=concept_score_fn,
                acc_fn=acc_fn,
            )
            dir_compl_scores.append(dir_compl_score)

            print(
                f"\t\t\tCompleteness Score: {compl_scores[-1]:.4f} "
                f"and Direct Completeness Score: {dir_compl_scores[-1]:.4f}"
            )

            print("\t\tComputing purity score...")
            concept_scores = topic_model.concept_scores(encoder(x_test)).numpy()
            purity_score, (purity_mat, aligned_purity_mat), oracle_mat = oracle.oracle_impurity_score(
                c_soft=concept_scores,
                c_true=c_test,
                output_matrices=True,
                alignment_function=oracle.max_alignment_matrix,
            )

            purity_mats.append(purity_mat)
            aligned_purity_mats.append(aligned_purity_mat)
            oracle_mats.append(oracle_mat)
            purities.append(purity_score)
            print(f"\t\t\tDone {purity_score:.4f}")

            # Same score, but measured against a trivial oracle matrix and
            # reusing the aligned purity matrix computed just above.
            print("\t\tComputing non-oracle purity score...")
            non_oracle_purities.append(oracle.oracle_impurity_score(
                c_soft=concept_scores,
                c_true=c_test,
                alignment_function=oracle.max_alignment_matrix,
                oracle_matrix=construct_trivial_auc_mat(
                    c_test.shape[-1]
                ),
                purity_matrix=aligned_purity_mat,
            ))
            print(f"\t\t\tDone {non_oracle_purities[-1]:.4f}")
            print("\t\tDone with trial", trial + 1)

        # Aggregate all trials of this covariance into (mean, std) pairs
        task_acc_mean, task_acc_std = np.mean(task_accs), np.std(task_accs)
        experiment_variables["task_accuracies"].append((task_acc_mean, task_acc_std))
        print(f"\tTest task accuracy: {task_acc_mean:.4f} ± {task_acc_std:.4f}")

        recon_acc_mean, recon_acc_std = np.mean(recon_accs), np.std(recon_accs)
        experiment_variables["reconstruction_accuracies"].append((recon_acc_mean, recon_acc_std))
        print(f"\tTest reconstruction accuracy: {recon_acc_mean:.4f} ± {recon_acc_std:.4f}")

        task_auc_mean, task_auc_std = np.mean(aucs), np.std(aucs)
        experiment_variables["task_aucs"].append((task_auc_mean, task_auc_std))
        print(f"\tTest task AUC: {task_auc_mean:.4f} ± {task_auc_std:.4f}")

        recon_auc_mean, recon_auc_std = np.mean(recon_aucs), np.std(recon_aucs)
        experiment_variables["reconstruction_aucs"].append((recon_auc_mean, recon_auc_std))
        # This line reports the AUC, so it is labelled as such (it used to
        # repeat the "accuracy" label of the line above).
        print(f"\tTest reconstruction AUC: {recon_auc_mean:.4f} ± {recon_auc_std:.4f}")

        mean_sim_mean, mean_sim_std = np.mean(mean_sims), np.std(mean_sims)
        experiment_variables["mean_similarities"].append((mean_sim_mean, mean_sim_std))
        print(f"\tMean concept similarity: {mean_sim_mean:.4f} ± {mean_sim_std:.4f}")

        compl_score_mean, compl_score_std = np.mean(compl_scores), np.std(compl_scores)
        experiment_variables["completeness_scores"].append((compl_score_mean, compl_score_std))
        print(f"\tCompleteness Score: {compl_score_mean:.4f} ± {compl_score_std:.4f}")

        dir_compl_score_mean, dir_compl_score_std = np.mean(dir_compl_scores), np.std(dir_compl_scores)
        experiment_variables["direct_completeness_scores"].append((dir_compl_score_mean, dir_compl_score_std))
        print(f"\tDirect completeness Score: {dir_compl_score_mean:.4f} ± {dir_compl_score_std:.4f}")

        experiment_variables["purity_matrices"].append(
            _ccd_summarize_matrices("\tPurity matrix:", purity_mats)
        )
        experiment_variables["aligned_purity_matrices"].append(
            _ccd_summarize_matrices("\tAligned purity matrix:", aligned_purity_mats)
        )
        experiment_variables["oracle_matrices"].append(
            _ccd_summarize_matrices("\tOracle matrix:", oracle_mats)
        )

        purity_mean, purity_std = np.mean(purities), np.std(purities)
        experiment_variables["purity_scores"].append((purity_mean, purity_std))
        print(f"\tPurity score: {purity_mean:.4f} ± {purity_std:.4f}")

        non_oracle_purity_mean, non_oracle_purity_std = np.mean(non_oracle_purities), np.std(non_oracle_purities)
        experiment_variables["non_oracle_purity_scores"].append((non_oracle_purity_mean, non_oracle_purity_std))
        print(f"\tNon-oracle purity score: {non_oracle_purity_mean:.4f} ± {non_oracle_purity_std:.4f}")

        # And serialize the results after every covariance so that an
        # interrupted sweep can be resumed from the cache
        utils.serialize_results(
            results_dict=experiment_variables,
            results_dir=results_dir,
        )
    return experiment_variables

def ccd_bottleneck_concept_predict_experiment_loop(
    experiment_config,
    load_from_cache=False,
):
    """Measures how well ground-truth concepts can be predicted from CCD scores.

    For every covariance and trial, the encoder, decoder and topic vector
    saved by `ccd_experiment_loop` are loaded from
    `<results_dir>/models`, a `CCD.TopicModel` is rebuilt around them, and one
    small binary predictor per ground-truth concept is trained on the topic
    model's concept scores. The per-concept test accuracy and AUC are averaged
    per trial and then aggregated across trials. Results are written with
    `utils.serialize_results` after each covariance.

    Args:
        experiment_config: dict-like experiment configuration; the saved
            model files for its `results_dir`, `covariances`, `num_concepts`
            and `trials` must already exist.
        load_from_cache: if True, previously serialized `<name>_means.npz` /
            `<name>_stds.npz` files are loaded; a complete cache is returned
            directly and a partial one is extended.

    Returns:
        Dict with keys "latent_avg_concept_predictive_accuracies" and
        "latent_avg_concept_predictive_aucs", each a list with one
        `(mean, std)` pair per covariance.

    Raises:
        ValueError: if cached variables disagree on how many covariances
            have already been computed.
    """
    utils.reseed(87)
    experiment_variables = dict(
        latent_avg_concept_predictive_accuracies=[],
        latent_avg_concept_predictive_aucs=[],
    )
    results_dir = experiment_config["results_dir"]

    start_ind = 0
    if load_from_cache:
        cached = True
        complete_cache = True
        for var_name in experiment_variables:
            file_name = os.path.join(results_dir, var_name)
            if (
                (not os.path.exists(f'{file_name}_means.npz')) or
                (not os.path.exists(f'{file_name}_stds.npz'))
            ):
                print("Could not find", f'"{file_name}_means.npz" or "{file_name}_stds.npz" in cache.')
                cached = False
                break
            # Only the entry count is needed; close the archive immediately.
            with np.load(f'{file_name}_means.npz') as loaded_means:
                num_cached = len(loaded_means)
            if num_cached != len(experiment_config["covariances"]):
                print("Found", num_cached, "means for variable", var_name, "vs", len(experiment_config["covariances"]), "datasets")
                # Then we have a partial run here so let's just run the rest
                if start_ind and start_ind != num_cached:
                    raise ValueError(
                        f'Found inconsistent start indices in cached '
                        f'data {start_ind} vs {num_cached} ({file_name}).'
                    )
                start_ind = num_cached
                complete_cache = False
        if cached:
            # Then we have found all of the arrays of interest, so let's
            # load them and use them
            print("Experiment cache was hit")
            print("\tLoading variables from", results_dir)
            for var_name in experiment_variables:
                file_name = os.path.join(results_dir, var_name)
                # Open each archive once and close it when done, rather than
                # re-opening it for every stored entry.
                with np.load(f'{file_name}_means.npz') as means_npz:
                    cached_means = [means_npz[key] for key in means_npz]
                with np.load(f'{file_name}_stds.npz') as stds_npz:
                    cached_stds = [stds_npz[key] for key in stds_npz]
                experiment_variables[var_name] = list(zip(cached_means, cached_stds))
            if complete_cache:
                # Then we are good to go
                if not os.path.exists(
                    os.path.join(results_dir, "config.yaml")
                ):
                    # then serialize the config as this is a different version run
                    utils.serialize_experiment_config(
                        experiment_config,
                        results_dir,
                    )
                return experiment_variables
        else:
            # Cache miss: nothing was loaded, so a start index picked up from
            # an earlier, partially cached variable must not be used.
            start_ind = 0

    # Else, let's go ahead and run the whole thing (or its uncached remainder)
    verbosity = experiment_config.get("verbosity", 0)
    Path(os.path.join(results_dir, "models")).mkdir(parents=True, exist_ok=True)
    for cov in experiment_config["covariances"][start_ind:]:
        print("Training with covariance:", cov)
        avg_concept_accs = []
        avg_concept_aucs = []

        for trial in range(experiment_config["trials"]):
            print(f"\tTrial {trial + 1}/{experiment_config['trials']} with covariance {cov}")
            # First construct the dataset (task labels are not needed here)
            (x_train, _, c_train) = produce_data(
                experiment_config["train_samples"],
                cov=cov,
                num_concepts=experiment_config["data_concepts"],
            )
            (x_test, _, c_test) = produce_data(
                experiment_config["test_samples"],
                cov=cov,
                num_concepts=experiment_config["data_concepts"],
            )

            # Load the artifacts written by ccd_experiment_loop. The saved
            # topic `g_model` is intentionally not loaded: the original code
            # loaded it but never used it afterwards.
            encoder = tf.keras.models.load_model(
                os.path.join(
                    results_dir,
                    f"models/encoder_cov_{cov:.2f}__num_concepts_{experiment_config['num_concepts']}_trial_{trial}"
                )
            )
            decoder = tf.keras.models.load_model(
                os.path.join(
                    results_dir,
                    f"models/decoder_cov_{cov:.2f}__num_concepts_{experiment_config['num_concepts']}_trial_{trial}"
                )
            )
            topic_vector = np.load(
                os.path.join(
                    results_dir,
                    f"models/topic_vector_cov_{cov:.2f}_num_concepts_{experiment_config['num_concepts']}_trial_{trial}.npy"
                )
            )

            # Now rebuild the topic model around the stored topic vector. It
            # is only used to compute concept scores; it is not trained here.
            # NOTE(review): loss_fn switches on `num_outputs <= 2` while
            # acc_metric switches on `num_outputs > 1`, so they disagree for
            # num_outputs == 2 — confirm which is intended.
            topic_model = CCD.TopicModel(
                concepts_to_labels_model=decoder,
                n_channels=experiment_config["latent_dims"],
                n_concepts=experiment_config['num_concepts'],
                threshold=experiment_config.get("threshold", 0.5),
                loss_fn=(
                    tf.keras.losses.BinaryCrossentropy(from_logits=True) if (experiment_config["num_outputs"] <= 2)
                    else tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
                ),
                top_k=experiment_config.get("top_k", 32),
                lambda1=experiment_config.get("lambda1", 0.1),
                lambda2=experiment_config.get("lambda2", 0.1),
                seed=experiment_config.get("seed", None),
                eps=experiment_config.get("eps", 1e-5),
                data_format=experiment_config.get(
                    "data_format",
                    "channels_last"
                ),
                allow_gradient_flow_to_c2l=experiment_config.get(
                    'allow_gradient_flow_to_c2l',
                    False,
                ),
                acc_metric=(
                    tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1)
                    if experiment_config["num_outputs"] > 1 else
                    tf.keras.metrics.BinaryAccuracy()
                ),
                initial_topic_vector=topic_vector,
            )

            current_aucs = []
            current_accs = []
            score_train_codes = topic_model.concept_scores(encoder(x_train)).numpy()
            score_test_codes = topic_model.concept_scores(encoder(x_test)).numpy()
            for concept_idx in range(experiment_config["data_concepts"]):
                # A fresh binary predictor per ground-truth concept, fed with
                # the full vector of concept scores
                score_predictive_decoder = construct_decoder(
                    units=experiment_config["latent_decoder_units"],
                    num_outputs=1,
                )

                score_predictive_decoder.compile(
                    optimizer=tf.keras.optimizers.Adam(experiment_config["learning_rate"]),
                    loss=(
                        tf.keras.losses.BinaryCrossentropy(from_logits=True)
                    ),
                    metrics=[
                        "binary_accuracy"
                    ],
                )

                print("\t\tTraining score model for concept", concept_idx)
                score_predictive_decoder.fit(
                    x=score_train_codes,
                    y=c_train[:, concept_idx],
                    epochs=experiment_config["predictor_max_epochs"],
                    batch_size=experiment_config["batch_size"],
                    validation_split=experiment_config["holdout_fraction"],
                    verbose=verbosity,
                )

                print("\t\tEvaluating score model")
                test_result = score_predictive_decoder.evaluate(
                    score_test_codes,
                    c_test[:, concept_idx],
                    verbose=0,
                    return_dict=True,
                )
                current_accs.append(
                    test_result['binary_accuracy']
                )

                current_aucs.append(sklearn.metrics.roc_auc_score(
                    c_test[:, concept_idx],
                    score_predictive_decoder.predict(score_test_codes),
                ))
                print(
                    f"\t\t\tTest concept AUC = {current_aucs[-1]:.4f}, "
                    f"concept accuracy = {current_accs[-1]:.4f}"
                )
            avg_concept_aucs.append(np.mean(current_aucs))
            avg_concept_accs.append(np.mean(current_accs))
            print(
                f"\t\tTest avg concept AUC = {avg_concept_aucs[-1]:.4f}, "
                f"avg concept accuracy = {avg_concept_accs[-1]:.4f}"
            )

            print("\t\tDone with trial", trial + 1)

        # These are averages over concept predictors, not task metrics, so
        # the summary lines are labelled accordingly.
        avg_concept_acc_mean, avg_concept_acc_std = np.mean(avg_concept_accs), np.std(avg_concept_accs)
        experiment_variables["latent_avg_concept_predictive_accuracies"].append((avg_concept_acc_mean, avg_concept_acc_std))
        print(f"\tTest avg concept accuracy: {avg_concept_acc_mean:.4f} ± {avg_concept_acc_std:.4f}")

        avg_concept_auc_mean, avg_concept_auc_std = np.mean(avg_concept_aucs), np.std(avg_concept_aucs)
        experiment_variables["latent_avg_concept_predictive_aucs"].append((avg_concept_auc_mean, avg_concept_auc_std))
        print(f"\tTest avg concept AUC: {avg_concept_auc_mean:.4f} ± {avg_concept_auc_std:.4f}")

        # And serialize the results after every covariance so that an
        # interrupted sweep can be resumed from the cache
        utils.serialize_results(
            results_dict=experiment_variables,
            results_dir=results_dir,
        )
    return experiment_variables

def ccd_compute_k(y, batch_size):
    """Derives the integer `k` used by the CCD experiments from the label balance.

    `k` is half of the number of samples that an "average" class is expected
    to contribute to a batch of `batch_size` samples, truncated to an int.

    Args:
        y: array of labels; its distinct values are treated as the classes.
        batch_size: number of samples in a training batch.

    Returns:
        int((mean class count / number of labels) * batch_size / 2).
    """
    # np.unique returns (values, counts); only the per-class counts matter
    class_counts = np.unique(y, return_counts=True)[1]
    mean_class_fraction = np.mean(class_counts) / y.shape[0]
    return int((mean_class_fraction * batch_size) / 2)

## CCD Experiments

In [None]:
# Re-import the method modules so that local edits to them are picked up
for _ccd_module in (completeness, CBM, CCD):
    reload(_ccd_module)

############################################################################
## Experiment config
############################################################################

ccd_covariance_experiment_config = {
    "batch_size": 32,
    "max_epochs": 300,
    "topic_model_train_epochs": 50,
    "trials": NUM_TRIALS,
    "learning_rate": 1e-3,

    # Architecture
    "num_concepts": 3,
    "input_shape": [7],
    "latent_dims": 10,
    "encoder_units": [128, 64],
    "decoder_units": [128, 64],
    "latent_act": None,

    "threshold": 0.0,
    "lambda1": 0.1,
    "lambda2": 0.1,
    "eps": 1e-5,

    "latent_decoder_units": [128, 64],
    "predictor_max_epochs": 300,

    "num_outputs": 1,
    "patience": float("inf"),
    "min_delta": 1e-5,
    "results_dir": os.path.join(
        RESULTS_DIR,
        "ccd/purity",
    ),
    # Sweep over cross-concept covariances 0.0, 0.1, ..., 0.9
    "covariances": np.arange(0, 1, 0.1),
    "train_samples": 2000,
    "test_samples": 1000,
    "verbosity": 0,
    "data_concepts": 3,
    "cw_layer": 2,
    "aggregator": 'max_pool_mean',
    "activation_mode": 'max_pool_mean',
    "cw_train_freq": 20,
    "concept_auc_freq": 0,
    "cw_train_iterations": 1,
    "holdout_fraction": 0.1,
    "exclusive_concepts": True,
}

# Generate the experiment directory (and its figures sub-directory) if they
# do not exist already
ccd_covariance_figure_dir = os.path.join(ccd_covariance_experiment_config["results_dir"], "figures")
for _ccd_dir in (
    ccd_covariance_experiment_config["results_dir"],
    ccd_covariance_figure_dir,
):
    Path(_ccd_dir).mkdir(parents=True, exist_ok=True)

############################################################################
## Experiment run
############################################################################

ccd_covariance_results = ccd_experiment_loop(
    experiment_config=ccd_covariance_experiment_config,
    load_from_cache=LOAD_FROM_CACHE,
)


# Report every tracked metric as a list of per-covariance entries
for _ccd_metric in (
    "task_accuracies",
    "reconstruction_accuracies",
    "task_aucs",
    "reconstruction_aucs",
    "purity_scores",
    "non_oracle_purity_scores",
    "completeness_scores",
    "direct_completeness_scores",
    "mean_similarities",
):
    print(f"{_ccd_metric}:", ccd_covariance_results[_ccd_metric])


In [None]:
# Run the bottleneck concept-prediction experiment with the same config and
# merge its result entries into the CCD results dictionary
_ccd_concept_predict_results = ccd_bottleneck_concept_predict_experiment_loop(
    experiment_config=ccd_covariance_experiment_config,
    load_from_cache=LOAD_FROM_CACHE,
)
ccd_covariance_results.update(_ccd_concept_predict_results)

In [None]:
# Re-import the method modules so that local edits to them are picked up
for _ccd_double_module in (completeness, CBM, CCD):
    reload(_ccd_double_module)

############################################################################
## Experiment config
############################################################################

ccd_covariance_double_experiment_config = {
    "batch_size": 32,
    "max_epochs": 300,
    "topic_model_train_epochs": 50,
    "trials": NUM_TRIALS,
    "learning_rate": 1e-3,

    # Architecture
    "num_concepts": 6,  # Let's extract twice as many concepts in here
    "input_shape": [7],
    "latent_dims": 10,
    "encoder_units": [128, 64],
    "decoder_units": [128, 64],
    "latent_act": None,

    "threshold": 0.0,
    "lambda1": 0.1,
    "lambda2": 0.1,
    "eps": 1e-5,

    "num_outputs": 1,

    "latent_decoder_units": [128, 64],
    "predictor_max_epochs": 300,

    "patience": float("inf"),
    "min_delta": 1e-5,
    "results_dir": os.path.join(
        RESULTS_DIR,
        "ccd/purity_double_concepts",
    ),
    # Sweep over cross-concept covariances 0.0, 0.1, ..., 0.9
    "covariances": np.arange(0, 1, 0.1),
    "train_samples": 2000,
    "test_samples": 1000,
    "verbosity": 0,
    "data_concepts": 3,
    "cw_layer": 2,
    "aggregator": 'max_pool_mean',
    "activation_mode": 'max_pool_mean',
    "cw_train_freq": 20,
    "concept_auc_freq": 0,
    "cw_train_iterations": 1,
    "holdout_fraction": 0.1,
    "exclusive_concepts": True,
}

# Generate the experiment directory (and its figures sub-directory) if they
# do not exist already
ccd_covariance_double_figure_dir = os.path.join(ccd_covariance_double_experiment_config["results_dir"], "figures")
for _ccd_double_dir in (
    ccd_covariance_double_experiment_config["results_dir"],
    ccd_covariance_double_figure_dir,
):
    Path(_ccd_double_dir).mkdir(parents=True, exist_ok=True)

############################################################################
## Experiment run
############################################################################

ccd_covariance_double_results = ccd_experiment_loop(
    experiment_config=ccd_covariance_double_experiment_config,
    load_from_cache=LOAD_FROM_CACHE,
)


# Report every tracked metric as a list of per-covariance entries
for _ccd_double_metric in (
    "task_accuracies",
    "reconstruction_accuracies",
    "task_aucs",
    "reconstruction_aucs",
    "purity_scores",
    "non_oracle_purity_scores",
    "completeness_scores",
    "direct_completeness_scores",
    "mean_similarities",
):
    print(f"{_ccd_double_metric}:", ccd_covariance_double_results[_ccd_double_metric])


In [None]:
# Run the bottleneck concept-prediction experiment with the double-concept
# config and merge its result entries into the matching results dictionary
_ccd_double_concept_predict_results = ccd_bottleneck_concept_predict_experiment_loop(
    experiment_config=ccd_covariance_double_experiment_config,
    load_from_cache=LOAD_FROM_CACHE,
)
ccd_covariance_double_results.update(_ccd_double_concept_predict_results)

# SENN Benchmark

In [None]:
import concepts_xai.methods.SENN.base_senn as SENN
import concepts_xai.methods.SENN.aggregators as aggregators
reload(SENN)
reload(aggregators)


def construct_senn_coefficient_model(units, num_concepts, num_outputs):
    """Builds the MLP that emits one coefficient per (output, concept) pair.

    The model flattens its input, applies one ReLU Dense layer per entry of
    `units`, and ends in a linear Dense layer whose activations are reshaped
    to `[num_outputs, num_concepts]`.

    Args:
        units: iterable with the width of each hidden Dense layer, in order.
        num_concepts: size of the last axis of the output.
        num_outputs: size of the second-to-last axis of the output.

    Returns:
        An uncompiled `tf.keras.Sequential` model.
    """
    stack = [tf.keras.layers.Flatten()]
    # Hidden layers are named starting from 1
    for layer_number, layer_width in enumerate(units, start=1):
        stack.append(
            tf.keras.layers.Dense(
                layer_width,
                activation=tf.nn.relu,
                name=f"coefficient_model_dense_{layer_number}",
            )
        )
    # Linear head producing every coefficient at once...
    stack.append(
        tf.keras.layers.Dense(
            num_concepts * num_outputs,
            activation=None,
            name="coefficient_model_output",
        )
    )
    # ... which is then arranged as one row of coefficients per output
    stack.append(tf.keras.layers.Reshape([num_outputs, num_concepts]))
    return tf.keras.Sequential(stack)

def construct_senn_encoder(
    input_shape,
    units,
    end_activation="sigmoid",
    latent_dims=0,
    latent_act=None,  # Original paper used "sigmoid" but this is troublesome in deep architectures
):
    """Builds two Keras encoders that share one ReLU MLP trunk.

    The trunk feeds two linear Dense heads, "means" and "log_var", each of
    width `latent_dims`.

    NOTE(review): `end_activation` and `latent_act` are accepted but never
    read in this function, so both heads are always linear. Confirm whether
    an activation was meant to be applied to the "means" head.

    Args:
        input_shape: shape of a single input sample (without the batch axis).
        units: iterable with the width of each hidden Dense layer, in order.
        end_activation: currently unused.
        latent_dims: width of the "means" and "log_var" heads.
        latent_act: currently unused.

    Returns:
        A `(senn_encoder, vae_encoder)` tuple of uncompiled Keras models over
        the same input and layers: `senn_encoder` outputs the means only and
        `vae_encoder` outputs `[means, log_var]`.
    """
    inputs = tf.keras.Input(shape=input_shape)
    hidden = inputs
    for layer_idx, layer_width in enumerate(units):
        hidden = tf.keras.layers.Dense(
            layer_width,
            activation='relu',
            name=f"encoder_dense_{layer_idx}",
        )(hidden)

    latent_mean = tf.keras.layers.Dense(
        latent_dims,
        activation=None,
        name="means",
    )(hidden)
    latent_log_var = tf.keras.layers.Dense(
        latent_dims,
        activation=None,
        name="log_var",
    )(hidden)

    # Both models are views over the same layers, so they share weights
    senn_encoder = tf.keras.Model(
        inputs=inputs,
        outputs=latent_mean,
        name="senn_encoder",
    )
    vae_encoder = tf.keras.Model(
        inputs=inputs,
        outputs=[latent_mean, latent_log_var],
        name="vae_encoder",
    )
    return senn_encoder, vae_encoder


def construct_vae_decoder(
    units,
    output_shape,
    latent_dims,
):
    """Builds a fully-connected decoder mapping latent codes back to inputs.

    The decoder applies one ReLU Dense layer per entry of `units` and ends in
    a linear Dense layer of width `output_shape`.

    Note: model is uncompiled

    Args:
        units: iterable with the width of each hidden Dense layer, in order.
        output_shape: width of the final linear Dense layer.
        latent_dims: dimensionality of the latent code taken as input.

    Returns:
        A `tf.keras.Model` from a `(latent_dims,)` input to a single-element
        list of outputs.
    """
    latent_code = tf.keras.Input(shape=(latent_dims,))
    hidden = latent_code
    for layer_width in units:
        hidden = tf.keras.layers.Dense(layer_width, activation='relu')(hidden)
    reconstruction = tf.keras.layers.Dense(output_shape, activation=None)(hidden)

    return tf.keras.Model(inputs=latent_code, outputs=[reconstruction])


def construct_senn_model(
    concept_encoder,
    concept_decoder,
    coefficient_model,
    num_outputs,
    regularization_strength=0.1,
    learning_rate=1e-3,
    sparsity_strength=2e-5,
):
    """Assembles and compiles a `SENN.SelfExplainingNN` model.

    With fewer than two outputs the model uses the scalar additive
    aggregator, a binary cross-entropy task loss on logits and binary
    accuracy. Otherwise it uses the multiclass additive aggregator, a sparse
    categorical cross-entropy task loss on logits and top-1 sparse
    categorical accuracy.

    Args:
        concept_encoder: model passed to SENN as its `encoder_model`.
        concept_decoder: model used by the reconstruction loss to map what
            SENN passes as `y_pred` back to input space.
        coefficient_model: model passed to SENN as its `coefficient_model`.
        num_outputs: number of task outputs; selects the binary or the
            multiclass configuration described above.
        regularization_strength: forwarded to `SENN.SelfExplainingNN`.
        learning_rate: learning rate of the Adam optimizer.
        sparsity_strength: forwarded to `SENN.SelfExplainingNN`.

    Returns:
        The SENN model, compiled with an Adam optimizer.
    """
    def reconstruction_loss_fn(y_true, y_pred):
        # Squared reconstruction error summed over the last axis. An earlier
        # alternative, vae_losses.bernoulli_fn_wrapper(), was left commented out here.
        residual = y_true - concept_decoder(y_pred)
        return tf.reduce_sum(tf.square(residual), [-1])

    if num_outputs < 2:
        aggregator_fn = aggregators.scalar_additive_aggregator
        task_loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)
        task_metric = tf.keras.metrics.BinaryAccuracy()
    else:
        aggregator_fn = aggregators.multiclass_additive_aggregator
        task_loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True
        )
        task_metric = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1)

    senn_model = SENN.SelfExplainingNN(
        encoder_model=concept_encoder,
        coefficient_model=coefficient_model,
        aggregator_fn=aggregator_fn,
        task_loss_fn=task_loss_fn,
        reconstruction_loss_fn=reconstruction_loss_fn,
        regularization_strength=regularization_strength,
        sparsity_strength=sparsity_strength,
        name="SENN",
        metrics=[task_metric],
    )
    optimizer = tf.keras.optimizers.Adam(learning_rate)
    senn_model.compile(optimizer=optimizer)
    return senn_model

In [None]:
import concepts_xai.evaluation.metrics.oracle as oracle

############################################################################
## Experiment loop
############################################################################

def construct_trivial_auc_mat(num_concepts):
    """Returns a float32 matrix with 1.0 on the diagonal and 0.5 elsewhere.

    Args:
        num_concepts: number of rows and columns of the square matrix.

    Returns:
        np.ndarray of shape (num_concepts, num_concepts) and dtype float32.
    """
    shape = (num_concepts, num_concepts)
    # (1 + 1) / 2 = 1.0 on the diagonal and (1 + 0) / 2 = 0.5 off it
    return 0.5 * (
        np.ones(shape, dtype=np.float32) + np.eye(num_concepts, dtype=np.float32)
    )

def get_argmax_concept_explanations(preds, class_theta_scores):
    """Selects, for every sample, the scores of its highest-scoring class.

    Args:
        preds: array whose last axis holds per-class predictions.
        class_theta_scores: array indexed by class along axis 1.
            # assumes shape (samples, classes, concepts) -- TODO confirm

    Returns:
        `class_theta_scores` gathered along axis 1 at each sample's argmax
        class, with that (now singleton) axis removed.
    """
    winning_class = np.argmax(preds, axis=-1)
    # Append two singleton axes so the indices broadcast against the scores
    gather_index = winning_class[..., np.newaxis, np.newaxis]
    gathered = np.take_along_axis(class_theta_scores, gather_index, axis=1)
    return gathered.squeeze(axis=1)

def senn_experiment_loop(
    experiment_config,
    load_from_cache=False,
    oracle_matrix_cache=None,
):
    """Trains and evaluates SENN models over a sweep of concept covariances.

    For every covariance in `experiment_config["covariances"]` this runs
    `experiment_config["trials"]` trials. Each trial samples fresh train and
    test data with `produce_data`, builds and fits a SENN model, saves its
    three sub-models under `<results_dir>/models`, and records the test task
    accuracy, the test task AUC and the scores/matrices returned by
    `oracle.oracle_impurity_score`. The per-covariance (mean, std) summaries
    are serialized with `utils.serialize_results` after every covariance.

    Args:
        experiment_config: dict of experiment settings. Required keys read
            here: "results_dir", "covariances", "trials", "train_samples",
            "test_samples", "data_concepts", "input_shape", "encoder_units",
            "decoder_units", "latent_dims", "coefficient_model_units",
            "num_outputs", "batch_size", "holdout_fraction", "max_epochs",
            "min_delta" and "patience". Optional keys: "verbosity",
            "latent_act", "pretrain_autoencoder_epochs", "beta",
            "learning_rate", "regularization_strength", "sparsity_strength",
            "early_stop_metric" (default "val_loss") and "early_stop_mode"
            (default "auto").
        load_from_cache: if True, previously serialized results found in
            "results_dir" are loaded. A complete cache is returned as is; a
            partial one is resumed from the first covariance not yet cached.
        oracle_matrix_cache: currently unused by this function.

    Returns:
        Dict mapping each tracked variable name to a list of (mean, std)
        tuples, one per covariance.

    Raises:
        ValueError: if the cached variables disagree on how many covariances
            have already been completed.
    """
    utils.reseed(87)
    oracle_matrix_cache = oracle_matrix_cache or {}
    experiment_variables = dict(
        task_accuracies=[],
        task_aucs=[],
        purity_scores=[],
        non_oracle_purity_scores=[],
        purity_matrices=[],
        oracle_matrices=[],
    )
    results_dir = experiment_config["results_dir"]
    covariances = experiment_config["covariances"]

    def _load_npz_entries(path):
        # Reads every array of an .npz archive, in stored order, opening the
        # file only once and closing it afterwards
        with np.load(path) as archive:
            return [archive[key] for key in archive.files]

    def _summarise_matrices(title, matrices):
        # Prints and returns the element-wise mean/std of a list of matrices
        stacked = np.stack(matrices, axis=0)
        mat_mean = np.mean(stacked, axis=0)
        mat_std = np.std(stacked, axis=0)
        print(f"\t{title}:")
        for i in range(mat_mean.shape[0]):
            line = "\t\t"
            for j in range(mat_mean.shape[1]):
                line += f'{mat_mean[i, j]:.4f} ± {mat_std[i, j]:.4f}    '
            print(line)
        return mat_mean, mat_std

    start_ind = 0
    if load_from_cache:
        cached = True
        complete_cache = True
        for var_name in experiment_variables:
            file_name = os.path.join(results_dir, var_name)
            if (
                (not os.path.exists(f'{file_name}_means.npz')) or
                (not os.path.exists(f'{file_name}_stds.npz'))
            ):
                cached = False
                break
            with np.load(f'{file_name}_means.npz') as loaded_means:
                num_cached = len(loaded_means.files)
            if num_cached != len(covariances):
                print("Found", num_cached, "means for variable", var_name, "vs", len(covariances), "datasets")
                # Then we have a partial run here so let's just run the rest
                if start_ind and start_ind != num_cached:
                    raise ValueError(
                        f'Found inconsistent start indices in cached '
                        f'data {start_ind} vs {num_cached}.'
                    )
                start_ind = num_cached
                complete_cache = False
        if not cached:
            # At least one variable has no cache at all, so nothing gets
            # loaded below. Any resume index picked up from the variables
            # inspected before it must be discarded, otherwise the sweep would
            # skip covariances whose results were never restored.
            start_ind = 0
        else:
            # Then we have found all of the arrays of interest, so let's
            # load them and use them
            print("Experiment cache was hit")
            print("\tLoading variables from", results_dir)
            for var_name in experiment_variables:
                file_name = os.path.join(results_dir, var_name)
                experiment_variables[var_name] = list(zip(
                    _load_npz_entries(f'{file_name}_means.npz'),
                    _load_npz_entries(f'{file_name}_stds.npz'),
                ))
            if complete_cache:
                # Then we are good to go
                if not os.path.exists(
                    os.path.join(results_dir, "config.yaml")
                ):
                    # then serialize the config as this is a different version run
                    utils.serialize_experiment_config(
                        experiment_config,
                        results_dir,
                    )
                return experiment_variables

    # Else, let's go ahead and run the whole thing (or whatever is left of it)
    Path(os.path.join(results_dir, "models")).mkdir(parents=True, exist_ok=True)
    verbosity = experiment_config.get("verbosity", 0)
    for cov in covariances[start_ind:]:
        print("Training with covariance:", cov)
        task_accs = []
        aucs = []
        purity_mats = []
        oracle_mats = []
        purities = []
        non_oracle_purities = []

        for trial in range(experiment_config["trials"]):
            print(f"\tTrial {trial + 1}/{experiment_config['trials']} for covariance {cov:.2f}")
            (x_train, y_train, c_train) = produce_data(
                experiment_config["train_samples"],
                cov=cov,
                num_concepts=experiment_config["data_concepts"],
            )
            (x_test, y_test, c_test) = produce_data(
                experiment_config["test_samples"],
                cov=cov,
                num_concepts=experiment_config["data_concepts"],
            )
            x_train = np.squeeze(x_train)
            x_test = np.squeeze(x_test)
            y_train = np.squeeze(y_train)
            y_test = np.squeeze(y_test)

            # Build the three sub-models first so that the encoder/decoder
            # pair can optionally be pre-trained as an autoencoder
            concept_encoder, vae_encoder = construct_senn_encoder(
                input_shape=experiment_config["input_shape"],
                units=experiment_config["encoder_units"],
                latent_act=experiment_config.get("latent_act", None),
                latent_dims=experiment_config["latent_dims"],
            )
            concept_decoder = construct_vae_decoder(
                units=experiment_config["decoder_units"],
                output_shape=experiment_config["input_shape"][-1],
                latent_dims=experiment_config["latent_dims"],
            )
            # Every latent dimension is treated as one SENN concept
            coefficient_model = construct_senn_coefficient_model(
                units=experiment_config["coefficient_model_units"],
                num_concepts=experiment_config["latent_dims"],
                num_outputs=experiment_config["num_outputs"],
            )

            if experiment_config.get("pretrain_autoencoder_epochs"):
                autoencoder = beta_vae.BetaVAE(
                    encoder=vae_encoder,
                    decoder=concept_decoder,
                    loss_fn=vae_losses.bernoulli_fn_wrapper(),
                    beta=experiment_config.get("beta", 1),
                )

                autoencoder.compile(
                    optimizer=tf.keras.optimizers.Adam(
                        experiment_config.get("learning_rate", 1e-3)
                    ),
                )

                print("\tAutoencoder pre-training...")
                autoencoder.fit(
                    x=x_train,
                    epochs=experiment_config["pretrain_autoencoder_epochs"],
                    batch_size=experiment_config["batch_size"],
                    validation_split=experiment_config["holdout_fraction"],
                    verbose=verbosity,
                )
                print("\t\tAutoencoder training completed")

            # Now time to actually construct and train the SENN model
            senn_model = construct_senn_model(
                concept_encoder=concept_encoder,
                concept_decoder=concept_decoder,
                coefficient_model=coefficient_model,
                num_outputs=experiment_config["num_outputs"],
                regularization_strength=experiment_config.get("regularization_strength", 0.1),
                learning_rate=experiment_config.get("learning_rate", 1e-3),
                sparsity_strength=experiment_config.get("sparsity_strength", 2e-5),
            )

            early_stopping_monitor = tf.keras.callbacks.EarlyStopping(
                monitor=experiment_config.get(
                    "early_stop_metric",
                    "val_loss",
                ),
                min_delta=experiment_config["min_delta"],
                patience=experiment_config["patience"],
                restore_best_weights=True,
                verbose=2,
                # "auto" lets Keras infer the direction from the monitored
                # metric's name, so the default "val_loss" is minimised. The
                # previous default ("max") treated a growing loss as an
                # improvement.
                mode=experiment_config.get(
                    "early_stop_mode",
                    "auto",
                ),
            )

            print("\tSENN training...")
            senn_model.fit(
                x=x_train,
                y=y_train,
                epochs=experiment_config["max_epochs"],
                batch_size=experiment_config["batch_size"],
                callbacks=[
                    early_stopping_monitor,
                ],
                validation_split=experiment_config["holdout_fraction"],
                verbose=verbosity,
            )
            print("\t\tSENN training completed")
            print("\tSerializing model")
            concept_encoder.save(
                os.path.join(
                    results_dir,
                    f"models/concept_encoder_{cov:.2f}_trial_{trial}"
                )
            )
            concept_decoder.save(
                os.path.join(
                    results_dir,
                    f"models/concept_decoder_{cov:.2f}_trial_{trial}"
                )
            )
            coefficient_model.save(
                os.path.join(
                    results_dir,
                    f"models/coefficient_model_{cov:.2f}_trial_{trial}"
                )
            )

            print("\tEvaluating model")
            test_result = senn_model.evaluate(
                x_test,
                y_test,
                verbose=0,
                return_dict=True,
            )
            task_accs.append(
                test_result['sparse_top_k_categorical_accuracy']
                if experiment_config['num_outputs'] > 1 else
                test_result['binary_accuracy']
            )

            if experiment_config['num_outputs'] > 1:
                # Then lets apply a softmax activation over all the probability
                # classes
                preds = scipy.special.softmax(
                    senn_model.predict(x_test)[0],
                    axis=-1
                )

                # And select just the labels that are in fact being used
                one_hot_labels = tf.keras.utils.to_categorical(y_test)
                aucs.append(sklearn.metrics.roc_auc_score(
                    one_hot_labels,
                    preds,
                    multi_class='ovo',
                ))
            else:
                aucs.append(sklearn.metrics.roc_auc_score(
                    y_test,
                    senn_model.predict(x_test)[0],
                ))

            print(
                f"\t\tTest auc = {aucs[-1]:.4f}, "
                f"task accuracy = {task_accs[-1]:.4f}"
            )

            print("\t\tComputing purity score...")
            x_test_preds, (_, x_test_theta_class_scores) = senn_model(x_test)
            # Keep only the explanation scores of each sample's predicted class
            test_concept_scores = get_argmax_concept_explanations(
                x_test_preds.numpy(),
                x_test_theta_class_scores.numpy(),
            )

            purity_score, (purity_mat, aligned_purity_mat), oracle_mat = oracle.oracle_impurity_score(
                c_soft=test_concept_scores,
                c_true=c_test,
                output_matrices=True,
                alignment_function=oracle.max_alignment_matrix,
            )
            purity_mats.append(aligned_purity_mat)
            oracle_mats.append(oracle_mat)
            purities.append(purity_score)
            print(f"\t\t\tDone {purity_score:.4f}")

            print("\t\tComputing non-oracle purity score...")
            # Same score, but compared against the trivial matrix (1 on the
            # diagonal, 0.5 elsewhere) and reusing the aligned purity matrix
            non_oracle_purities.append(oracle.oracle_impurity_score(
                c_soft=test_concept_scores,
                c_true=c_test,
                oracle_matrix=construct_trivial_auc_mat(
                    c_test.shape[-1]
                ),
                alignment_function=oracle.max_alignment_matrix,
                purity_matrix=aligned_purity_mat,
            ))
            print(f"\t\t\tDone {non_oracle_purities[-1]:.4f}")
            print("\t\tDone with trial", trial + 1)

        task_acc_mean, task_acc_std = np.mean(task_accs), np.std(task_accs)
        experiment_variables["task_accuracies"].append((task_acc_mean, task_acc_std))
        print(f"\tTest task accuracy: {task_acc_mean:.4f} ± {task_acc_std:.4f}")

        task_auc_mean, task_auc_std = np.mean(aucs), np.std(aucs)
        experiment_variables["task_aucs"].append((task_auc_mean, task_auc_std))
        print(f"\tTest task AUC: {task_auc_mean:.4f} ± {task_auc_std:.4f}")

        experiment_variables["purity_matrices"].append(
            _summarise_matrices("Purity matrix", purity_mats)
        )
        experiment_variables["oracle_matrices"].append(
            _summarise_matrices("Oracle matrix", oracle_mats)
        )

        purity_mean, purity_std = np.mean(purities), np.std(purities)
        experiment_variables["purity_scores"].append((purity_mean, purity_std))
        print(f"\tPurity score: {purity_mean:.4f} ± {purity_std:.4f}")

        non_oracle_purity_mean, non_oracle_purity_std = np.mean(non_oracle_purities), np.std(non_oracle_purities)
        experiment_variables["non_oracle_purity_scores"].append((non_oracle_purity_mean, non_oracle_purity_std))
        print(f"\tNon-oracle purity score: {non_oracle_purity_mean:.4f} ± {non_oracle_purity_std:.4f}")

        # And serialize the results so that an interrupted sweep can resume
        utils.serialize_results(
            results_dict=experiment_variables,
            results_dir=results_dir,
        )
    return experiment_variables

In [None]:
# Re-import the method modules so that local edits to them are picked up
for _senn_module in (CBM, SENN):
    reload(_senn_module)

############################################################################
## Experiment config
############################################################################

senn_covariance_experiment_config = {
    "batch_size": 32,
    "max_epochs": 300,
    "topic_model_train_epochs": 50,
    "trials": NUM_TRIALS,
    "learning_rate": 1e-3,

    # Architecture
    "num_concepts": 3,
    "input_shape": [7],
    "latent_dims": 10,
    "encoder_units": [128, 64],
    "decoder_units": [128, 64],
    "latent_act": None,

    "coefficient_model_units": [64, 64],
    "latent_decoder_units": [64, 64],

    "predictor_max_epochs": 300,

    # SENN loss weights
    "regularization_strength": 0.1,
    "sparsity_strength": 2e-5,

    "num_outputs": 2,
    "patience": float("inf"),
    "min_delta": 1e-5,
    "results_dir": os.path.join(
        RESULTS_DIR,
        "senn/purity",
    ),
    # Sweep over cross-concept covariances 0.0, 0.1, ..., 0.9
    "covariances": np.arange(0, 1, 0.1),
    "train_samples": 2000,
    "test_samples": 1000,
    "verbosity": 0,
    "data_concepts": 3,
    "cw_layer": 2,
    "aggregator": 'max_pool_mean',
    "activation_mode": 'max_pool_mean',
    "cw_train_freq": 20,
    "concept_auc_freq": 0,
    "cw_train_iterations": 1,
    "holdout_fraction": 0.1,
    "exclusive_concepts": True,
}

# Generate the experiment directory (and its figures sub-directory) if they
# do not exist already
senn_covariance_figure_dir = os.path.join(senn_covariance_experiment_config["results_dir"], "figures")
for _senn_dir in (
    senn_covariance_experiment_config["results_dir"],
    senn_covariance_figure_dir,
):
    Path(_senn_dir).mkdir(parents=True, exist_ok=True)

############################################################################
## Experiment run
############################################################################

senn_covariance_results = senn_experiment_loop(
    experiment_config=senn_covariance_experiment_config,
    load_from_cache=LOAD_FROM_CACHE,
)


# Report the scalar metrics as lists of per-covariance (mean, std) entries
for _senn_metric in (
    "task_accuracies",
    "task_aucs",
    "purity_scores",
    "non_oracle_purity_scores",
):
    print(f"{_senn_metric}:", senn_covariance_results[_senn_metric])


In [None]:
# Re-import the method modules so that local edits to them are picked up
for _senn_extended_module in (CBM, SENN):
    reload(_senn_extended_module)

############################################################################
## Experiment config
############################################################################

# NOTE(review): senn_experiment_loop sizes the SENN concept layer from
# "latent_dims" and never reads "num_concepts", so as written this sweep
# differs from the "senn/purity" one only in its results directory. Confirm
# whether "latent_dims" was meant to change here.
senn_covariance_extended_experiment_config = {
    "batch_size": 32,
    "max_epochs": 300,
    "topic_model_train_epochs": 50,
    "trials": NUM_TRIALS,
    "learning_rate": 1e-3,

    # Architecture
    "num_concepts": 2*3,
    "input_shape": [7],
    "latent_dims": 10,
    "encoder_units": [128, 64],
    "decoder_units": [128, 64],
    "latent_act": None,

    "coefficient_model_units": [64, 64],
    "latent_decoder_units": [64, 64],

    "predictor_max_epochs": 300,

    # SENN loss weights
    "regularization_strength": 0.1,
    "sparsity_strength": 2e-5,

    "num_outputs": 2,
    "patience": float("inf"),
    "min_delta": 1e-5,
    "results_dir": os.path.join(
        RESULTS_DIR,
        "senn/purity_extended",
    ),
    # Sweep over cross-concept covariances 0.0, 0.1, ..., 0.9
    "covariances": np.arange(0, 1, 0.1),
    "train_samples": 2000,
    "test_samples": 1000,
    "verbosity": 0,
    "data_concepts": 3,
    "cw_layer": 2,
    "aggregator": 'max_pool_mean',
    "activation_mode": 'max_pool_mean',
    "cw_train_freq": 20,
    "concept_auc_freq": 0,
    "cw_train_iterations": 1,
    "holdout_fraction": 0.1,
    "exclusive_concepts": True,
}

# Generate the experiment directory (and its figures sub-directory) if they
# do not exist already
senn_covariance_extended_figure_dir = os.path.join(senn_covariance_extended_experiment_config["results_dir"], "figures")
for _senn_extended_dir in (
    senn_covariance_extended_experiment_config["results_dir"],
    senn_covariance_extended_figure_dir,
):
    Path(_senn_extended_dir).mkdir(parents=True, exist_ok=True)

############################################################################
## Experiment run
############################################################################

senn_covariance_extended_results = senn_experiment_loop(
    experiment_config=senn_covariance_extended_experiment_config,
    load_from_cache=LOAD_FROM_CACHE,
)


# Report the scalar metrics as lists of per-covariance (mean, std) entries
for _senn_extended_metric in (
    "task_accuracies",
    "task_aucs",
    "purity_scores",
    "non_oracle_purity_scores",
):
    print(f"{_senn_extended_metric}:", senn_covariance_extended_results[_senn_extended_metric])


# Dataset-wide Results

In [None]:
# Grouped bar chart of the "purity_scores" entry (plotted as Oracle Impurity)
# of every method, with one group of bars per covariance of the sweep
all_vars = np.arange(0, 10)
real_values = np.array(list(map(lambda x: x*0.1, all_vars)))

# "_" only needs escaping when text is typeset by LaTeX (usetex is enabled iff
# LATEX_SYMBOL == "$"); without LaTeX the backslash would be drawn literally.
n_concepts_label = r"n\_concepts" if LATEX_SYMBOL == "$" else "n_concepts"

# Each entry is (legend label, results dict, results key, transform applied to
# every mean/std before plotting)
# NOTE(review): "CW MaxPool-Mean" and "CW Feature Map" both read
# `cw_covariance_results`, so their bars are identical -- confirm which
# results were intended for "CW Feature Map".
all_models = [
    (
        "Joint-CBM",
        base_results,
        "purity_scores",
        lambda x: x,
    ),
    (
        "CW MaxPool-Mean",
        cw_covariance_results,
        "purity_scores",
        lambda x: x,
    ),
    (
        "CW Feature Map",
        cw_covariance_results,
        "purity_scores",
        lambda x: x,
    ),
    (
        f"CCD ({n_concepts_label} = {ccd_covariance_experiment_config['num_concepts']})",
        ccd_covariance_results,
        "purity_scores",
        lambda x: x,
    ),
    (
        f"CCD ({n_concepts_label} = {ccd_covariance_double_experiment_config['num_concepts']})",
        ccd_covariance_double_results,
        "purity_scores",
        lambda x: x,
    ),
]

# One extra (empty) slot per group leaves a gap between neighbouring groups
num_models = len(all_models) + 1

clrs = sns.color_palette("husl", num_models * 2)
scale = 1

fig, ax = plt.subplots(figsize=(num_models * 2 + 5, 5))
for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
    accs_means = np.array([transform_fn(entry[0]) for entry in results[kword]])
    accs_stds = np.array([transform_fn(entry[1]) for entry in results[kword]])
    ax.bar(
        # Centre each group on its tick and offset this method's bar within it
        scale * (all_vars - (1/2 - 1/(2 * num_models)) + i/num_models),
        accs_means,
        width=scale/num_models,
        color=clrs[i*2],
        align='center',
        label=method_name,
        yerr=2*accs_stds,  # error bars span two standard deviations
        capsize=5,
        edgecolor="black",
    )

lgd = ax.legend(fontsize=15)

plt.ylabel("Oracle Impurity", fontsize=20)
# Raw strings keep the LaTeX backslashes without invalid escape sequences
plt.xlabel(r"Cross-concept Covariance ($\delta$)", fontsize=20)
plt.title(bold_text(r"Oracle Impurity (TabularToy($\delta$))"), fontsize=28)
plt.xticks(all_vars, fontsize=15)
ax.set_xticklabels(list(map(lambda x: f'{x:.1f}', real_values)), fontsize=15)
plt.yticks(fontsize=15)
ax.grid(False)
plt.show()

In [None]:
# Grouped bar chart of the "purity_scores" entry (plotted as Oracle Impurity)
# of every method, restricted to the first and last covariance of the sweep
all_vars = np.array([0, 9])
real_values = np.array(list(map(lambda x: x*0.1, all_vars)))

# "_" only needs escaping when text is typeset by LaTeX (usetex is enabled iff
# LATEX_SYMBOL == "$"); without LaTeX the backslash would be drawn literally.
n_concepts_label = r"n\_concepts" if LATEX_SYMBOL == "$" else "n_concepts"

# Each entry is (legend label, results dict, results key, transform applied to
# every mean/std before plotting, indices of the covariances to plot)
# NOTE(review): "CW MaxPool-Mean" and "CW Feature Map" both read
# `cw_covariance_results`, so their bars are identical -- confirm which
# results were intended for "CW Feature Map".
all_models = [
    (
        "Joint-CBM",
        base_results,
        "purity_scores",
        lambda x: x,
        all_vars
    ),
    (
        "CW MaxPool-Mean",
        cw_covariance_results,
        "purity_scores",
        lambda x: x,
        all_vars
    ),
    (
        "CW Feature Map",
        cw_covariance_results,
        "purity_scores",
        lambda x: x,
        all_vars
    ),
    (
        f"CCD ({n_concepts_label} = {ccd_covariance_experiment_config['num_concepts']})",
        ccd_covariance_results,
        "purity_scores",
        lambda x: x,
        all_vars
    ),
    (
        f"CCD ({n_concepts_label} = {ccd_covariance_double_experiment_config['num_concepts']})",
        ccd_covariance_double_results,
        "purity_scores",
        lambda x: x,
        all_vars
    ),
]

# One extra (empty) slot per group leaves a gap between neighbouring groups
num_models = len(all_models) + 1

clrs = sns.color_palette("husl", num_models * 2)
scale = 1

fig, ax = plt.subplots(figsize=(num_models*1.25, 3))
for i, (method_name, results, kword, transform_fn, subsample) in enumerate(all_models):
    # Keep only the (mean, std) rows of the requested covariances
    selected = np.array(results[kword])[subsample]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    ax.bar(
        # Centre each group on its tick and offset this method's bar within it
        scale * (np.arange(0, len(all_vars)) - (1/2 - 1/(2 * num_models)) + i/num_models),
        accs_means,
        width=scale/num_models,
        color=clrs[i*2],
        align='center',
        label=method_name,
        yerr=2*accs_stds,  # error bars span two standard deviations
        capsize=5,
        edgecolor="black",
    )

# Legend placed underneath the axes
lgd = ax.legend(fontsize=15, loc='upper center', bbox_to_anchor=(0.5,-0.2), ncol=(num_models - 1)//2)

plt.ylabel("Oracle Impurity", fontsize=20)
# Raw strings keep the LaTeX backslashes without invalid escape sequences
plt.xlabel(r"Cross-concept Covariance ($\delta$)", fontsize=20)
plt.title(bold_text(r"Oracle Impurity (TabularToy($\delta$))"), fontsize=25)
plt.xticks(np.arange(0, len(all_vars)), fontsize=15)
ax.set_xticklabels(list(map(lambda x: f'{x:.1f}', real_values)), fontsize=15)
plt.yticks(fontsize=15)
ax.grid(False)
plt.show()

In [None]:
# Grouped bar chart: oracle impurity of every method at result index 5 only.
# `all_vars` indexes into each results list; index k corresponds to k * 0.1.
all_vars = np.array([5])
real_values = np.array([idx * 0.1 for idx in all_vars])

# Row layout: (legend label, results dict, key to read, transform applied to
# both the plotted value and its spread, indices to plot).
# NOTE(review): both CW rows read the same key from `cw_covariance_results`,
# so their bars are identical -- confirm whether "CW Feature Map" should use a
# different results object.
all_models = [
    ("Joint-CBM", base_results, "purity_scores", lambda x: x, all_vars),
    ("CW MaxPool-Mean", cw_covariance_results, "purity_scores", lambda x: x, all_vars),
    ("CW Feature Map", cw_covariance_results, "purity_scores", lambda x: x, all_vars),
    (
        rf"CCD (n\_concepts = {ccd_covariance_experiment_config['num_concepts']})",
        ccd_covariance_results, "purity_scores", lambda x: x, all_vars,
    ),
    (
        rf"CCD (n\_concepts = {ccd_covariance_double_experiment_config['num_concepts']})",
        ccd_covariance_double_results, "purity_scores", lambda x: x, all_vars,
    ),
]

# One slot more than there are methods, so each group leaves a bar-wide gap.
num_models = len(all_models) + 1

clrs = sns.color_palette("husl", num_models * 2)
scale = 1

fig, ax = plt.subplots(figsize=(num_models, 5))
# Offset of the first bar's centre relative to the group's integer position.
first_bar_offset = 1/2 - 1/(2 * num_models)
for i, (method_name, results, kword, transform_fn, subsample) in enumerate(all_models):
    # Each selected entry is read as entry[0] (bar height) and entry[1] (spread).
    selected = np.array(results[kword])[subsample]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    ax.bar(
        scale * (np.arange(0, len(all_vars)) - first_bar_offset + i/num_models),
        accs_means,
        width=scale/num_models,
        color=clrs[i*2],  # every other hue, to keep neighbouring bars distinct
        align='center',
        label=method_name,
        yerr=2*accs_stds,
        capsize=5,
        edgecolor="black",
    )

# Legend sits below the axes.
lgd = ax.legend(fontsize=15, loc='upper center', bbox_to_anchor=(0.5,-0.1), ncol=(num_models - 1)//2)

plt.ylabel("Oracle Impurity", fontsize=20)
# No x-axis label or ticks: the single covariance setting is stated in the title.
plt.title(bold_text(r"Oracle Impurity (TabularToy($\delta = 0.5$))"), fontsize=28)
plt.xticks([], fontsize=15)
plt.yticks(fontsize=15)
ax.grid(False)
plt.show()

In [None]:
# Grouped bar chart: concept-level predictive score of every method across
# result indices 0..9. Index k is labelled k * 0.1.
all_vars = np.arange(0, 10)
real_values = np.array([idx * 0.1 for idx in all_vars])

# Row layout: (legend label, results dict, key to read, transform applied to
# both the plotted value and its spread, indices to plot).
# NOTE(review): the rows mix "concept_accuracies", "concept_aucs" and the
# diagonal of "aligned_purity_matrices", while the y-label says AUC and the
# title says Accuracy -- confirm the intended metric and wording.
# NOTE(review): both CW rows read the same key from `cw_covariance_results`,
# so their bars are identical -- confirm whether "CW Feature Map" should use a
# different results object.
all_models = [
    ("Joint-CBM", base_results, "concept_accuracies", lambda x: 100 * x, all_vars),
    ("CW MaxPool-Mean", cw_covariance_results, "concept_aucs", lambda x: 100 * np.mean(x), all_vars),
    ("CW Feature Map", cw_covariance_results, "concept_aucs", lambda x: 100 * np.mean(x), all_vars),
    (
        rf"CCD (n\_concepts = {ccd_covariance_experiment_config['num_concepts']})",
        ccd_covariance_results,
        "aligned_purity_matrices",
        lambda x: 100 * np.mean(np.diagonal(x)),
        all_vars,
    ),
    (
        rf"CCD (n\_concepts = {ccd_covariance_double_experiment_config['num_concepts']})",
        ccd_covariance_double_results,
        "aligned_purity_matrices",
        lambda x: 100 * np.mean(np.diagonal(x)),
        all_vars,
    ),
]

# One slot more than there are methods, so each group leaves a bar-wide gap.
num_models = len(all_models) + 1

clrs = sns.color_palette("husl", num_models * 2)
scale = 1

fig, ax = plt.subplots(figsize=(num_models * 2 + 5, 5))
# Offset of the first bar's centre relative to the group's integer position.
first_bar_offset = 1/2 - 1/(2 * num_models)
for i, (method_name, results, kword, transform_fn, subsample) in enumerate(all_models):
    # Each selected entry is read as entry[0] (bar height) and entry[1] (spread).
    selected = np.array(results[kword])[subsample]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    ax.bar(
        scale * (np.arange(0, len(all_vars)) - first_bar_offset + i/num_models),
        accs_means,
        width=scale/num_models,
        color=clrs[i*2],  # every other hue, to keep neighbouring bars distinct
        align='center',
        label=method_name,
        yerr=2*accs_stds,
        capsize=5,
        edgecolor="black",
    )

# Legend sits below the axes, one column per method.
lgd = ax.legend(fontsize=15, loc='upper center', bbox_to_anchor=(0.5,-0.2), ncol=(num_models - 1))

# Raw strings keep the LaTeX backslashes without invalid escape sequences.
plt.ylabel(r"AUC (\%)", fontsize=20)
plt.xlabel(r"Cross-concept Covariance ($\delta$)", fontsize=20)
plt.title(bold_text(r"Mean Concept Accuracy (TabularToy($\delta$))"), fontsize=28)
plt.xticks(np.arange(0, len(all_vars)), fontsize=15)
ax.set_xticklabels([f'{value:.1f}' for value in real_values], fontsize=15)
plt.yticks(fontsize=15)
ax.grid(False)
plt.show()

In [None]:
# Grouped bar chart: concept-level predictive score of every method at result
# indices 0 and 9. Index k is labelled k * 0.1.
all_vars = np.array([0, 9])
real_values = np.array([idx * 0.1 for idx in all_vars])

# Row layout: (legend label, results dict, key to read, transform applied to
# both the plotted value and its spread, indices to plot).
# NOTE(review): the rows mix "concept_accuracies", "concept_aucs" and the
# diagonal of "aligned_purity_matrices", while the y-label says AUC and the
# title says Accuracy -- confirm the intended metric and wording.
# NOTE(review): both CW rows read the same key from `cw_covariance_results`,
# so their bars are identical -- confirm whether "CW Feature Map" should use a
# different results object.
all_models = [
    ("Joint-CBM", base_results, "concept_accuracies", lambda x: 100 * x, all_vars),
    ("CW MaxPool-Mean", cw_covariance_results, "concept_aucs", lambda x: 100 * np.mean(x), all_vars),
    ("CW Feature Map", cw_covariance_results, "concept_aucs", lambda x: 100 * np.mean(x), all_vars),
    (
        rf"CCD (n\_concepts = {ccd_covariance_experiment_config['num_concepts']})",
        ccd_covariance_results,
        "aligned_purity_matrices",
        lambda x: 100 * np.mean(np.diagonal(x)),
        all_vars,
    ),
    (
        rf"CCD (n\_concepts = {ccd_covariance_double_experiment_config['num_concepts']})",
        ccd_covariance_double_results,
        "aligned_purity_matrices",
        lambda x: 100 * np.mean(np.diagonal(x)),
        all_vars,
    ),
]

# One slot more than there are methods, so each group leaves a bar-wide gap.
num_models = len(all_models) + 1

clrs = sns.color_palette("husl", num_models * 2)
scale = 1

fig, ax = plt.subplots(figsize=(num_models * 1.5, 5))
# Offset of the first bar's centre relative to the group's integer position.
first_bar_offset = 1/2 - 1/(2 * num_models)
for i, (method_name, results, kword, transform_fn, subsample) in enumerate(all_models):
    # Each selected entry is read as entry[0] (bar height) and entry[1] (spread).
    selected = np.array(results[kword])[subsample]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    ax.bar(
        scale * (np.arange(0, len(all_vars)) - first_bar_offset + i/num_models),
        accs_means,
        width=scale/num_models,
        color=clrs[i*2],  # every other hue, to keep neighbouring bars distinct
        align='center',
        label=method_name,
        yerr=2*accs_stds,
        capsize=5,
        edgecolor="black",
    )

# Legend sits below the axes.
lgd = ax.legend(fontsize=15, loc='upper center', bbox_to_anchor=(0.5,-0.2), ncol=(num_models - 1)//2)

# Raw strings keep the LaTeX backslashes without invalid escape sequences.
plt.ylabel(r"AUC (\%)", fontsize=20)
plt.xlabel(r"Cross-concept Covariance ($\delta$)", fontsize=20)
plt.title(bold_text(r"Mean Concept Accuracy (TabularToy($\delta$))"), fontsize=28)
plt.xticks(np.arange(0, len(all_vars)), fontsize=15)
ax.set_xticklabels([f'{value:.1f}' for value in real_values], fontsize=15)
plt.yticks(fontsize=15)
ax.grid(False)
plt.show()

In [None]:
# Grouped bar chart: concept-level predictive score of every method at result
# index 5 only. Index k corresponds to k * 0.1.
all_vars = np.array([5])
real_values = np.array([idx * 0.1 for idx in all_vars])

# Row layout: (legend label, results dict, key to read, transform applied to
# both the plotted value and its spread, indices to plot).
# NOTE(review): the rows mix "concept_accuracies", "concept_aucs" and the
# diagonal of "aligned_purity_matrices", while the y-label says AUC and the
# title says Accuracy -- confirm the intended metric and wording.
# NOTE(review): both CW rows read the same key from `cw_covariance_results`,
# so their bars are identical -- confirm whether "CW Feature Map" should use a
# different results object.
all_models = [
    ("Joint-CBM", base_results, "concept_accuracies", lambda x: 100 * x, all_vars),
    ("CW MaxPool-Mean", cw_covariance_results, "concept_aucs", lambda x: 100 * np.mean(x), all_vars),
    ("CW Feature Map", cw_covariance_results, "concept_aucs", lambda x: 100 * np.mean(x), all_vars),
    (
        rf"CCD (n\_concepts = {ccd_covariance_experiment_config['num_concepts']})",
        ccd_covariance_results,
        "aligned_purity_matrices",
        lambda x: 100 * np.mean(np.diagonal(x)),
        all_vars,
    ),
    (
        rf"CCD (n\_concepts = {ccd_covariance_double_experiment_config['num_concepts']})",
        ccd_covariance_double_results,
        "aligned_purity_matrices",
        lambda x: 100 * np.mean(np.diagonal(x)),
        all_vars,
    ),
]

# One slot more than there are methods, so each group leaves a bar-wide gap.
num_models = len(all_models) + 1

clrs = sns.color_palette("husl", num_models * 2)
scale = 1

fig, ax = plt.subplots(figsize=(num_models, 5))
# Offset of the first bar's centre relative to the group's integer position.
first_bar_offset = 1/2 - 1/(2 * num_models)
for i, (method_name, results, kword, transform_fn, subsample) in enumerate(all_models):
    # Each selected entry is read as entry[0] (bar height) and entry[1] (spread).
    selected = np.array(results[kword])[subsample]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    ax.bar(
        scale * (np.arange(0, len(all_vars)) - first_bar_offset + i/num_models),
        accs_means,
        width=scale/num_models,
        color=clrs[i*2],  # every other hue, to keep neighbouring bars distinct
        align='center',
        label=method_name,
        yerr=2*accs_stds,
        capsize=5,
        edgecolor="black",
    )

# Legend sits below the axes.
lgd = ax.legend(fontsize=15, loc='upper center', bbox_to_anchor=(0.5,-0.1), ncol=(num_models - 1)//2)

# Raw strings keep the LaTeX backslashes without invalid escape sequences.
plt.ylabel(r"AUC (\%)", fontsize=20)
# No x-axis label or ticks: the single covariance setting is stated in the title.
plt.title(bold_text(r"Mean Concept Accuracy (TabularToy($\delta = 0.5$))"), fontsize=28)
plt.xticks([], fontsize=15)
plt.yticks(fontsize=15)
ax.grid(False)
plt.show()

In [None]:
# Grouped bar chart: downstream task AUC of every method across result
# indices 0..9. Index k is labelled k * 0.1.
all_vars = np.arange(0, 10)
real_values = np.array([idx * 0.1 for idx in all_vars])

# Row layout: (legend label, results dict, key to read, transform applied to
# both the plotted value and its spread).
# NOTE(review): both CW rows read the same key from `cw_covariance_results`,
# so their bars are identical -- confirm whether "CW Feature Map" should use a
# different results object.
all_models = [
    ("Joint-CBM", base_results, "task_aucs", lambda x: x * 100),
    ("CW MaxPool-Mean", cw_covariance_results, "task_aucs", lambda x: x * 100),
    ("CW Feature Map", cw_covariance_results, "task_aucs", lambda x: x * 100),
    (
        rf"CCD (n\_concepts = {ccd_covariance_experiment_config['num_concepts']})",
        ccd_covariance_results, "task_aucs", lambda x: x * 100,
    ),
    (
        rf"CCD (n\_concepts = {ccd_covariance_double_experiment_config['num_concepts']})",
        ccd_covariance_double_results, "task_aucs", lambda x: x * 100,
    ),
]

# One slot more than there are methods, so each group leaves a bar-wide gap.
num_models = len(all_models) + 1

clrs = sns.color_palette("husl", num_models * 2)
scale = 1

fig, ax = plt.subplots(figsize=(num_models * 2, 5))
# Offset of the first bar's centre relative to the group's integer position.
first_bar_offset = 1/2 - 1/(2 * num_models)
for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
    # Each selected entry is read as entry[0] (bar height) and entry[1] (spread).
    selected = np.array(results[kword])[all_vars]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    ax.bar(
        scale * (np.arange(0, len(all_vars)) - first_bar_offset + i/num_models),
        accs_means,
        width=scale/num_models,
        color=clrs[i*2],  # every other hue, to keep neighbouring bars distinct
        align='center',
        label=method_name,
        yerr=2*accs_stds,
        capsize=5,
        edgecolor="black",
    )

# Legend sits below the axes.
lgd = ax.legend(fontsize=15, loc='upper center', bbox_to_anchor=(0.5,-0.2), ncol=(num_models - 1)//2)

# Raw strings keep the LaTeX backslashes without invalid escape sequences.
plt.ylabel(r"AUC (\%)", fontsize=20)
plt.xlabel(r"Cross-concept Covariance ($\delta$)", fontsize=20)
plt.title(bold_text(r"Downstream Task AUC (TabularToy($\delta$))"), fontsize=28)
plt.xticks(np.arange(0, len(all_vars)), fontsize=15)
ax.set_xticklabels([f'{value:.1f}' for value in real_values], fontsize=15)
plt.yticks(fontsize=15)
ax.grid(False)
plt.show()

In [None]:
# Grouped bar chart: downstream task AUC of every method at result indices
# 0 and 9. Index k is labelled k * 0.1.
all_vars = [0, 9]
real_values = np.array([idx * 0.1 for idx in all_vars])

# Row layout: (legend label, results dict, key to read, transform applied to
# both the plotted value and its spread).
# NOTE(review): both CW rows read the same key from `cw_covariance_results`,
# so their bars are identical -- confirm whether "CW Feature Map" should use a
# different results object.
all_models = [
    ("Joint-CBM", base_results, "task_aucs", lambda x: x * 100),
    ("CW MaxPool-Mean", cw_covariance_results, "task_aucs", lambda x: x * 100),
    ("CW Feature Map", cw_covariance_results, "task_aucs", lambda x: x * 100),
    (
        rf"CCD (n\_concepts = {ccd_covariance_experiment_config['num_concepts']})",
        ccd_covariance_results, "task_aucs", lambda x: x * 100,
    ),
    (
        rf"CCD (n\_concepts = {ccd_covariance_double_experiment_config['num_concepts']})",
        ccd_covariance_double_results, "task_aucs", lambda x: x * 100,
    ),
]

# One slot more than there are methods, so each group leaves a bar-wide gap.
num_models = len(all_models) + 1

clrs = sns.color_palette("husl", num_models * 2)
scale = 1

fig, ax = plt.subplots(figsize=(num_models * 1.5, 5))
# Offset of the first bar's centre relative to the group's integer position.
first_bar_offset = 1/2 - 1/(2 * num_models)
for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
    # Each selected entry is read as entry[0] (bar height) and entry[1] (spread).
    selected = np.array(results[kword])[all_vars]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    ax.bar(
        scale * (np.arange(0, len(all_vars)) - first_bar_offset + i/num_models),
        accs_means,
        width=scale/num_models,
        color=clrs[i*2],  # every other hue, to keep neighbouring bars distinct
        align='center',
        label=method_name,
        yerr=2*accs_stds,
        capsize=5,
        edgecolor="black",
    )

# Legend sits below the axes.
lgd = ax.legend(fontsize=15, loc='upper center', bbox_to_anchor=(0.5,-0.2), ncol=(num_models - 1)//2)

# Raw strings keep the LaTeX backslashes without invalid escape sequences.
plt.ylabel(r"AUC (\%)", fontsize=20)
plt.xlabel(r"Cross-concept Covariance ($\delta$)", fontsize=20)
plt.title(bold_text(r"Downstream Task AUC (TabularToy($\delta$))"), fontsize=28)
plt.xticks(np.arange(0, len(all_vars)), fontsize=15)
ax.set_xticklabels([f'{value:.1f}' for value in real_values], fontsize=15)
plt.yticks(fontsize=15)
ax.grid(False)
plt.show()

In [None]:
# Grouped bar chart: downstream task AUC of every method at result index 5
# only. Index k corresponds to k * 0.1.
all_vars = [5]
real_values = np.array([idx * 0.1 for idx in all_vars])

# Row layout: (legend label, results dict, key to read, transform applied to
# both the plotted value and its spread).
# NOTE(review): both CW rows read the same key from `cw_covariance_results`,
# so their bars are identical -- confirm whether "CW Feature Map" should use a
# different results object.
all_models = [
    ("Joint-CBM", base_results, "task_aucs", lambda x: x * 100),
    ("CW MaxPool-Mean", cw_covariance_results, "task_aucs", lambda x: x * 100),
    ("CW Feature Map", cw_covariance_results, "task_aucs", lambda x: x * 100),
    (
        rf"CCD (n\_concepts = {ccd_covariance_experiment_config['num_concepts']})",
        ccd_covariance_results, "task_aucs", lambda x: x * 100,
    ),
    (
        rf"CCD (n\_concepts = {ccd_covariance_double_experiment_config['num_concepts']})",
        ccd_covariance_double_results, "task_aucs", lambda x: x * 100,
    ),
]

# One slot more than there are methods, so each group leaves a bar-wide gap.
num_models = len(all_models) + 1

clrs = sns.color_palette("husl", num_models * 2)
scale = 1

fig, ax = plt.subplots(figsize=(num_models, 5))
# Offset of the first bar's centre relative to the group's integer position.
first_bar_offset = 1/2 - 1/(2 * num_models)
for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
    # Each selected entry is read as entry[0] (bar height) and entry[1] (spread).
    selected = np.array(results[kword])[all_vars]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    ax.bar(
        scale * (np.arange(0, len(all_vars)) - first_bar_offset + i/num_models),
        accs_means,
        width=scale/num_models,
        color=clrs[i*2],  # every other hue, to keep neighbouring bars distinct
        align='center',
        label=method_name,
        yerr=2*accs_stds,
        capsize=5,
        edgecolor="black",
    )

# Legend sits below the axes.
lgd = ax.legend(fontsize=15, loc='upper center', bbox_to_anchor=(0.5,-0.2), ncol=(num_models - 1)//2)

# Raw strings keep the LaTeX backslashes without invalid escape sequences.
plt.ylabel(r"AUC (\%)", fontsize=20)
if len(all_vars) > 1:
    plt.xlabel(r"Cross-concept Covariance ($\delta$)", fontsize=20)
    plt.title(bold_text(r"Downstream Task AUC (TabularToy($\delta$))"), fontsize=28)
    plt.xticks(np.arange(0, len(all_vars)), fontsize=15)
    ax.set_xticklabels([f'{value:.1f}' for value in real_values], fontsize=15)
else:
    # Single setting: put the value in the title. It is formatted to one
    # decimal so float noise (e.g. 3 * 0.1) never leaks into the text.
    plt.title(
        bold_text(rf"Downstream Task AUC (TabularToy($\delta = {real_values[0]:.1f}$))"),
        fontsize=28,
    )
    plt.xticks([], fontsize=15)
plt.yticks(fontsize=15)
ax.grid(False)
plt.show()

In [None]:
# Grouped bar chart: downstream task accuracy of every method across result
# indices 0..9. Index k is labelled k * 0.1.
all_vars = np.arange(0, 10)
real_values = np.array([idx * 0.1 for idx in all_vars])

# Row layout: (legend label, results dict, key to read, transform applied to
# both the plotted value and its spread).
# NOTE(review): both CW rows read the same key from `cw_covariance_results`,
# so their bars are identical -- confirm whether "CW Feature Map" should use a
# different results object.
all_models = [
    ("Joint-CBM", base_results, "task_accuracies", lambda x: x * 100),
    ("CW MaxPool-Mean", cw_covariance_results, "task_accuracies", lambda x: x * 100),
    ("CW Feature Map", cw_covariance_results, "task_accuracies", lambda x: x * 100),
    (
        rf"CCD (n\_concepts = {ccd_covariance_experiment_config['num_concepts']})",
        ccd_covariance_results, "task_accuracies", lambda x: x * 100,
    ),
    (
        rf"CCD (n\_concepts = {ccd_covariance_double_experiment_config['num_concepts']})",
        ccd_covariance_double_results, "task_accuracies", lambda x: x * 100,
    ),
]

# One slot more than there are methods, so each group leaves a bar-wide gap.
num_models = len(all_models) + 1

clrs = sns.color_palette("husl", num_models * 2)
scale = 1

fig, ax = plt.subplots(figsize=(num_models * 2 + 5, 5))
# Offset of the first bar's centre relative to the group's integer position.
first_bar_offset = 1/2 - 1/(2 * num_models)
for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
    # Each selected entry is read as entry[0] (bar height) and entry[1] (spread).
    selected = np.array(results[kword])[all_vars]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    ax.bar(
        scale * (np.arange(0, len(all_vars)) - first_bar_offset + i/num_models),
        accs_means,
        width=scale/num_models,
        color=clrs[i*2],  # every other hue, to keep neighbouring bars distinct
        align='center',
        label=method_name,
        yerr=2*accs_stds,
        capsize=5,
        edgecolor="black",
    )

# Legend sits below the axes, one column per method.
lgd = ax.legend(fontsize=15, loc='upper center', bbox_to_anchor=(0.5,-0.2), ncol=(num_models - 1))

# Raw strings keep the LaTeX backslashes without invalid escape sequences.
# The title says "Accuracy" to match the "task_accuracies" key and the y-label.
plt.ylabel(r"Accuracy (\%)", fontsize=20)
if len(all_vars) > 1:
    plt.xlabel(r"Cross-concept Covariance ($\delta$)", fontsize=20)
    plt.title(bold_text(r"Downstream Task Accuracy (TabularToy($\delta$))"), fontsize=28)
    plt.xticks(np.arange(0, len(all_vars)), fontsize=15)
    ax.set_xticklabels([f'{value:.1f}' for value in real_values], fontsize=15)
else:
    # Single setting: put the value in the title. It is formatted to one
    # decimal so float noise (e.g. 3 * 0.1) never leaks into the text.
    plt.title(
        bold_text(rf"Downstream Task Accuracy (TabularToy($\delta = {real_values[0]:.1f}$))"),
        fontsize=28,
    )
    plt.xticks([], fontsize=15)
plt.yticks(fontsize=15)
ax.grid(False)
plt.show()

In [None]:
# Grouped bar chart: downstream task accuracy of every method at result
# indices 0 and 9. Index k is labelled k * 0.1.
all_vars = [0, 9]
real_values = np.array([idx * 0.1 for idx in all_vars])

# Row layout: (legend label, results dict, key to read, transform applied to
# both the plotted value and its spread).
# NOTE(review): both CW rows read the same key from `cw_covariance_results`,
# so their bars are identical -- confirm whether "CW Feature Map" should use a
# different results object.
all_models = [
    ("Joint-CBM", base_results, "task_accuracies", lambda x: x * 100),
    ("CW MaxPool-Mean", cw_covariance_results, "task_accuracies", lambda x: x * 100),
    ("CW Feature Map", cw_covariance_results, "task_accuracies", lambda x: x * 100),
    (
        rf"CCD (n\_concepts = {ccd_covariance_experiment_config['num_concepts']})",
        ccd_covariance_results, "task_accuracies", lambda x: x * 100,
    ),
    (
        rf"CCD (n\_concepts = {ccd_covariance_double_experiment_config['num_concepts']})",
        ccd_covariance_double_results, "task_accuracies", lambda x: x * 100,
    ),
]

# One slot more than there are methods, so each group leaves a bar-wide gap.
num_models = len(all_models) + 1

clrs = sns.color_palette("husl", num_models * 2)
scale = 1

fig, ax = plt.subplots(figsize=(num_models * 1.5, 5))
# Offset of the first bar's centre relative to the group's integer position.
first_bar_offset = 1/2 - 1/(2 * num_models)
for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
    # Each selected entry is read as entry[0] (bar height) and entry[1] (spread).
    selected = np.array(results[kword])[all_vars]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    ax.bar(
        scale * (np.arange(0, len(all_vars)) - first_bar_offset + i/num_models),
        accs_means,
        width=scale/num_models,
        color=clrs[i*2],  # every other hue, to keep neighbouring bars distinct
        align='center',
        label=method_name,
        yerr=2*accs_stds,
        capsize=5,
        edgecolor="black",
    )

# Legend sits below the axes.
lgd = ax.legend(fontsize=15, loc='upper center', bbox_to_anchor=(0.5,-0.2), ncol=(num_models - 1)//2)

# Raw strings keep the LaTeX backslashes without invalid escape sequences.
# The title says "Accuracy" to match the "task_accuracies" key and the y-label.
plt.ylabel(r"Accuracy (\%)", fontsize=20)
if len(all_vars) > 1:
    plt.xlabel(r"Cross-concept Covariance ($\delta$)", fontsize=20)
    plt.title(bold_text(r"Downstream Task Accuracy (TabularToy($\delta$))"), fontsize=28)
    plt.xticks(np.arange(0, len(all_vars)), fontsize=15)
    ax.set_xticklabels([f'{value:.1f}' for value in real_values], fontsize=15)
else:
    # Single setting: put the value in the title. It is formatted to one
    # decimal so float noise (e.g. 3 * 0.1) never leaks into the text.
    plt.title(
        bold_text(rf"Downstream Task Accuracy (TabularToy($\delta = {real_values[0]:.1f}$))"),
        fontsize=28,
    )
    plt.xticks([], fontsize=15)
plt.yticks(fontsize=15)
ax.grid(False)
plt.show()

In [None]:
# Grouped bar chart: downstream task accuracy of every method at result
# index 5 only. Index k corresponds to k * 0.1.
all_vars = [5]
real_values = np.array([idx * 0.1 for idx in all_vars])

# Row layout: (legend label, results dict, key to read, transform applied to
# both the plotted value and its spread).
# NOTE(review): both CW rows read the same key from `cw_covariance_results`,
# so their bars are identical -- confirm whether "CW Feature Map" should use a
# different results object.
all_models = [
    ("Joint-CBM", base_results, "task_accuracies", lambda x: x * 100),
    ("CW MaxPool-Mean", cw_covariance_results, "task_accuracies", lambda x: x * 100),
    ("CW Feature Map", cw_covariance_results, "task_accuracies", lambda x: x * 100),
    (
        rf"CCD (n\_concepts = {ccd_covariance_experiment_config['num_concepts']})",
        ccd_covariance_results, "task_accuracies", lambda x: x * 100,
    ),
    (
        rf"CCD (n\_concepts = {ccd_covariance_double_experiment_config['num_concepts']})",
        ccd_covariance_double_results, "task_accuracies", lambda x: x * 100,
    ),
]

# One slot more than there are methods, so each group leaves a bar-wide gap.
num_models = len(all_models) + 1

clrs = sns.color_palette("husl", num_models * 2)
scale = 1

fig, ax = plt.subplots(figsize=(num_models, 5))
# Offset of the first bar's centre relative to the group's integer position.
first_bar_offset = 1/2 - 1/(2 * num_models)
for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
    # Each selected entry is read as entry[0] (bar height) and entry[1] (spread).
    selected = np.array(results[kword])[all_vars]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    ax.bar(
        scale * (np.arange(0, len(all_vars)) - first_bar_offset + i/num_models),
        accs_means,
        width=scale/num_models,
        color=clrs[i*2],  # every other hue, to keep neighbouring bars distinct
        align='center',
        label=method_name,
        yerr=2*accs_stds,
        capsize=5,
        edgecolor="black",
    )

# Legend sits below the axes.
lgd = ax.legend(fontsize=15, loc='upper center', bbox_to_anchor=(0.5,-0.2), ncol=(num_models - 1)//2)

# Raw strings keep the LaTeX backslashes without invalid escape sequences.
# The title says "Accuracy" to match the "task_accuracies" key and the y-label.
plt.ylabel(r"Accuracy (\%)", fontsize=20)
if len(all_vars) > 1:
    plt.xlabel(r"Cross-concept Covariance ($\delta$)", fontsize=20)
    plt.title(bold_text(r"Downstream Task Accuracy (TabularToy($\delta$))"), fontsize=28)
    plt.xticks(np.arange(0, len(all_vars)), fontsize=15)
    ax.set_xticklabels([f'{value:.1f}' for value in real_values], fontsize=15)
else:
    # Single setting: put the value in the title. It is formatted to one
    # decimal so float noise (e.g. 3 * 0.1) never leaks into the text.
    plt.title(
        bold_text(rf"Downstream Task Accuracy (TabularToy($\delta = {real_values[0]:.1f}$))"),
        fontsize=28,
    )
    plt.xticks([], fontsize=15)
plt.yticks(fontsize=15)
ax.grid(False)
plt.show()

In [None]:
# Grouped bar chart: downstream task accuracy of every method across result
# indices 0..9, drawn on a narrow figure. Index k is labelled k * 0.1.
# NOTE(review): ten groups are drawn on a (num_models, 5)-inch figure, the
# size the single-setting cells in this notebook use -- confirm this compact
# width is intended.
all_vars = np.arange(0, 10)
real_values = np.array([idx * 0.1 for idx in all_vars])

# Row layout: (legend label, results dict, key to read, transform applied to
# both the plotted value and its spread).
# NOTE(review): both CW rows read the same key from `cw_covariance_results`,
# so their bars are identical -- confirm whether "CW Feature Map" should use a
# different results object.
all_models = [
    ("Joint-CBM", base_results, "task_accuracies", lambda x: x * 100),
    ("CW MaxPool-Mean", cw_covariance_results, "task_accuracies", lambda x: x * 100),
    ("CW Feature Map", cw_covariance_results, "task_accuracies", lambda x: x * 100),
    (
        rf"CCD (n\_concepts = {ccd_covariance_experiment_config['num_concepts']})",
        ccd_covariance_results, "task_accuracies", lambda x: x * 100,
    ),
    (
        rf"CCD (n\_concepts = {ccd_covariance_double_experiment_config['num_concepts']})",
        ccd_covariance_double_results, "task_accuracies", lambda x: x * 100,
    ),
]

# One slot more than there are methods, so each group leaves a bar-wide gap.
num_models = len(all_models) + 1

clrs = sns.color_palette("husl", num_models * 2)
scale = 1

fig, ax = plt.subplots(figsize=(num_models, 5))
# Offset of the first bar's centre relative to the group's integer position.
first_bar_offset = 1/2 - 1/(2 * num_models)
for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
    # Each selected entry is read as entry[0] (bar height) and entry[1] (spread).
    selected = np.array(results[kword])[all_vars]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    ax.bar(
        scale * (np.arange(0, len(all_vars)) - first_bar_offset + i/num_models),
        accs_means,
        width=scale/num_models,
        color=clrs[i*2],  # every other hue, to keep neighbouring bars distinct
        align='center',
        label=method_name,
        yerr=2*accs_stds,
        capsize=5,
        edgecolor="black",
    )

# Legend sits below the axes.
lgd = ax.legend(fontsize=15, loc='upper center', bbox_to_anchor=(0.5,-0.2), ncol=(num_models - 1)//2)

# Raw strings keep the LaTeX backslashes without invalid escape sequences.
# The title says "Accuracy" to match the "task_accuracies" key and the y-label.
plt.ylabel(r"Accuracy (\%)", fontsize=20)
if len(all_vars) > 1:
    plt.xlabel(r"Cross-concept Covariance ($\delta$)", fontsize=20)
    plt.title(bold_text(r"Downstream Task Accuracy (TabularToy($\delta$))"), fontsize=28)
    plt.xticks(np.arange(0, len(all_vars)), fontsize=15)
    ax.set_xticklabels([f'{value:.1f}' for value in real_values], fontsize=15)
else:
    # Single setting: put the value in the title. It is formatted to one
    # decimal so float noise (e.g. 3 * 0.1) never leaks into the text.
    plt.title(
        bold_text(rf"Downstream Task Accuracy (TabularToy($\delta = {real_values[0]:.1f}$))"),
        fontsize=28,
    )
    plt.xticks([], fontsize=15)
plt.yticks(fontsize=15)
ax.grid(False)
plt.show()

In [None]:
# Grouped bar chart: downstream task AUC when predicting from the learnt
# concept representations, across result indices 0..9. Index k is labelled
# k * 0.1.
all_vars = np.arange(0, 10)
real_values = np.array([idx * 0.1 for idx in all_vars])

# Row layout: (legend label, results dict, key to read, transform applied to
# both the plotted value and its spread). CBM/CW rows read
# "latent_predictive_aucs"; CCD rows read "direct_completeness_scores".
# NOTE(review): both CW rows read the same key from `cw_covariance_results`,
# so their bars are identical -- confirm whether "CW Feature Map" should use a
# different results object.
all_models = [
    ("Joint-CBM", base_results, "latent_predictive_aucs", lambda x: x * 100),
    ("CW MaxPool-Mean", cw_covariance_results, "latent_predictive_aucs", lambda x: x * 100),
    ("CW Feature Map", cw_covariance_results, "latent_predictive_aucs", lambda x: x * 100),
    (
        rf"CCD (n\_concepts = {ccd_covariance_experiment_config['num_concepts']})",
        ccd_covariance_results, "direct_completeness_scores", lambda x: x * 100,
    ),
    (
        rf"CCD (n\_concepts = {ccd_covariance_double_experiment_config['num_concepts']})",
        ccd_covariance_double_results, "direct_completeness_scores", lambda x: x * 100,
    ),
]

# One slot more than there are methods, so each group leaves a bar-wide gap.
num_models = len(all_models) + 1

clrs = sns.color_palette("husl", num_models * 2)
scale = 1

fig, ax = plt.subplots(figsize=(num_models * 2 + 5, 5))
# Offset of the first bar's centre relative to the group's integer position.
first_bar_offset = 1/2 - 1/(2 * num_models)
for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
    # Each selected entry is read as entry[0] (bar height) and entry[1] (spread).
    selected = np.array(results[kword])[all_vars]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    ax.bar(
        scale * (np.arange(0, len(all_vars)) - first_bar_offset + i/num_models),
        accs_means,
        width=scale/num_models,
        color=clrs[i*2],  # every other hue, to keep neighbouring bars distinct
        align='center',
        label=method_name,
        yerr=2*accs_stds,
        capsize=5,
        edgecolor="black",
    )

# Legend sits below the axes, one column per method.
lgd = ax.legend(fontsize=15, loc='upper center', bbox_to_anchor=(0.5,-0.2), ncol=(num_models - 1))

# Raw strings keep the LaTeX backslashes without invalid escape sequences.
plt.ylabel(r"AUC (\%)", fontsize=20)
if len(all_vars) > 1:
    plt.xlabel(r"Cross-concept Covariance ($\delta$)", fontsize=20)
    plt.title(bold_text(r"Downstream Task AUC from Concepts (TabularToy($\delta$))"), fontsize=28)
    plt.xticks(np.arange(0, len(all_vars)), fontsize=15)
    ax.set_xticklabels([f'{value:.1f}' for value in real_values], fontsize=15)
else:
    # Single setting: put the value in the title. It is formatted to one
    # decimal so float noise (e.g. 3 * 0.1) never leaks into the text.
    plt.title(
        bold_text(rf"Downstream Task AUC from Concepts (TabularToy($\delta = {real_values[0]:.1f}$))"),
        fontsize=28,
    )
    plt.xticks([], fontsize=15)
plt.yticks(fontsize=15)
ax.grid(False)
plt.show()

In [None]:
# Grouped bar chart: downstream task AUC when predicting from the learnt
# concept representations, at result indices 0 and 9. Index k is labelled
# k * 0.1.
all_vars = [0, 9]
real_values = np.array([idx * 0.1 for idx in all_vars])

# Row layout: (legend label, results dict, key to read, transform applied to
# both the plotted value and its spread). CBM/CW rows read
# "latent_predictive_aucs"; CCD rows read "direct_completeness_scores".
# NOTE(review): both CW rows read the same key from `cw_covariance_results`,
# so their bars are identical -- confirm whether "CW Feature Map" should use a
# different results object.
all_models = [
    ("Joint-CBM", base_results, "latent_predictive_aucs", lambda x: x * 100),
    ("CW MaxPool-Mean", cw_covariance_results, "latent_predictive_aucs", lambda x: x * 100),
    ("CW Feature Map", cw_covariance_results, "latent_predictive_aucs", lambda x: x * 100),
    (
        rf"CCD (n\_concepts = {ccd_covariance_experiment_config['num_concepts']})",
        ccd_covariance_results, "direct_completeness_scores", lambda x: x * 100,
    ),
    (
        rf"CCD (n\_concepts = {ccd_covariance_double_experiment_config['num_concepts']})",
        ccd_covariance_double_results, "direct_completeness_scores", lambda x: x * 100,
    ),
]

# One slot more than there are methods, so each group leaves a bar-wide gap.
num_models = len(all_models) + 1

clrs = sns.color_palette("husl", num_models * 2)
scale = 1

fig, ax = plt.subplots(figsize=(num_models * 2 + 5, 5))
# Offset of the first bar's centre relative to the group's integer position.
first_bar_offset = 1/2 - 1/(2 * num_models)
for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
    # Each selected entry is read as entry[0] (bar height) and entry[1] (spread).
    selected = np.array(results[kword])[all_vars]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected])
    ax.bar(
        scale * (np.arange(0, len(all_vars)) - first_bar_offset + i/num_models),
        accs_means,
        width=scale/num_models,
        color=clrs[i*2],  # every other hue, to keep neighbouring bars distinct
        align='center',
        label=method_name,
        yerr=2*accs_stds,
        capsize=5,
        edgecolor="black",
    )

# Legend sits below the axes, one column per method.
lgd = ax.legend(fontsize=15, loc='upper center', bbox_to_anchor=(0.5,-0.2), ncol=(num_models - 1))

# Raw strings keep the LaTeX backslashes without invalid escape sequences.
plt.ylabel(r"AUC (\%)", fontsize=20)
if len(all_vars) > 1:
    plt.xlabel(r"Cross-concept Covariance ($\delta$)", fontsize=20)
    plt.title(bold_text(r"Downstream Task AUC from Concepts (TabularToy($\delta$))"), fontsize=28)
    plt.xticks(np.arange(0, len(all_vars)), fontsize=15)
    ax.set_xticklabels([f'{value:.1f}' for value in real_values], fontsize=15)
else:
    # Single setting: put the value in the title. It is formatted to one
    # decimal so float noise (e.g. 3 * 0.1) never leaks into the text.
    plt.title(
        bold_text(rf"Downstream Task AUC from Concepts (TabularToy($\delta = {real_values[0]:.1f}$))"),
        fontsize=28,
    )
    plt.xticks([], fontsize=15)
plt.yticks(fontsize=15)
ax.grid(False)
plt.show()

In [None]:
# Plot downstream task accuracy for all methods at the selected
# cross-concept covariance levels.
# Set up our figure
all_vars = [0, 9]
# Each index is scaled by 0.1 to obtain the covariance shown on the x axis.
real_values = np.array([x * 0.1 for x in all_vars])

# Backslash-escaped characters are only required when Matplotlib hands text
# to LaTeX; the built-in text renderer would draw the backslash as-is.
use_tex = plt.rcParams["text.usetex"]
percent_sign = r"\%" if use_tex else "%"
underscore = r"\_" if use_tex else "_"

# Each entry: (legend label, results dict, key to read, transform applied to
# both the mean and the std before plotting).
# NOTE(review): "CW MaxPool-Mean" and "CW Feature Map" read the same key from
# the same results dict, so their bars will be identical -- confirm whether
# one of them should use a different key.
all_models = [
    (
        "Joint-CBM",
        base_results,
        "latent_predictive_accuracies",
        lambda x: x * 100,
    ),
    (
        "CW MaxPool-Mean",
        cw_covariance_results,
        "latent_predictive_accuracies",
        lambda x: x * 100,
    ),
    (
        "CW Feature Map",
        cw_covariance_results,
        "latent_predictive_accuracies",
        lambda x: x * 100,
    ),
    (
        f"CCD (n{underscore}concepts = {ccd_covariance_experiment_config['num_concepts']})",
        ccd_covariance_results,
        "direct_completeness_scores",
        lambda x: x * 100,
    ),
    (
        f"CCD (n{underscore}concepts = {ccd_covariance_double_experiment_config['num_concepts']})",
        ccd_covariance_double_results,
        "direct_completeness_scores",
        lambda x: x * 100,
    ),
]

# One extra slot is reserved per group so neighbouring groups are separated
# by a bar-wide gap.
num_models = len(all_models) + 1

# Twice as many hues as slots; each method takes every other hue (see the
# `clrs[i*2]` lookup below).
clrs = sns.color_palette("husl", num_models * 2)
scale = 1

fig, ax = plt.subplots(figsize=(num_models * 2 + 5, 5))
for i, (method_name, results, kword, transform_fn) in enumerate(all_models):
    # Pick the entries for the requested covariance indices once. Element 0
    # of each entry is plotted as the mean, element 1 as the std.
    selected_stats = np.array(results[kword])[all_vars]
    accs_means = np.array([transform_fn(entry[0]) for entry in selected_stats])
    accs_stds = np.array([transform_fn(entry[1]) for entry in selected_stats])
    ax.bar(
        # Centre the whole group on the integer tick, then shift by the
        # method's slot index.
        scale * (np.arange(0, len(all_vars)) - (1/2 - 1/(2 * num_models)) + i/num_models),
        accs_means,
        width=scale/num_models,
        color=clrs[i*2],
        align='center',
        label=method_name,
        # Error bars extend two standard deviations on each side.
        yerr=2*accs_stds,
        capsize=5,
        edgecolor="black",
    )

# Legend sits below the axes, one column per plotted method.
lgd = ax.legend(fontsize=15, loc='upper center', bbox_to_anchor=(0.5,-0.2), ncol=(num_models - 1))

plt.ylabel(f"Accuracy ({percent_sign})", fontsize=20)
if len(all_vars) > 1:
    plt.xlabel(r"Cross-concept Covariance ($\delta$)", fontsize=20)
    plt.title(bold_text(r"Downstream Task Accuracy from Concepts (TabularToy($\delta$))"), fontsize=28)
    plt.xticks(np.arange(0, len(all_vars)), fontsize=15)
    ax.set_xticklabels([f'{value:.1f}' for value in real_values], fontsize=15)
else:
    # Format the single covariance value like the tick labels above so that
    # float noise (e.g. 0.30000000000000004) never reaches the title.
    single_delta = f"{real_values[0]:.1f}"
    plt.title(bold_text(r"Downstream Task Accuracy from Concepts (TabularToy($\delta = " + single_delta + "$))"), fontsize=28)
    plt.xticks([], fontsize=15)
plt.yticks(fontsize=15)
ax.grid(False)
plt.show()