In [None]:
%load_ext autoreload
%load_ext tensorboard
%matplotlib inline

# Purity Correlation Experiment

## Setup

In [None]:
import sys
import matplotlib
import concepts_xai
import numpy as np
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
import random
import tensorflow as tf
import yaml
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import rc
from matplotlib import cm
import seaborn as sns
from importlib import reload
from pathlib import Path
import sklearn
from tensorflow.keras.models import load_model
from joblib import dump, load
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import RidgeClassifier
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.preprocessing import OneHotEncoder
from sklearn.neural_network import MLPClassifier
from sklearn.feature_selection import mutual_info_classif
from sklearn.metrics import accuracy_score, roc_auc_score

import utils
import concepts_xai.evaluation.metrics.niching as niching
from collections import defaultdict

In [None]:
################################################################################
## Set seeds up for reproducibility
################################################################################
# `utils` is a local helper module; presumably `reseed` seeds the Python, NumPy
# and TensorFlow RNGs -- TODO confirm. `cbm_experiment_loop` and
# `cw_experiment_loop` call `utils.reseed(87)` again on entry, presumably so
# each sweep starts from the same RNG state regardless of earlier cells.
utils.reseed(87)


In [None]:
################################################################################
## Global Variables Defining Experiment Flow
################################################################################
NUM_CONCEPTS = 5  # default number of concept columns returned by `produce_data`
INPUT_SHAPE = [12]  # `produce_data` emits 12 features per sample
FROM_CACHE = True  # reuse an existing results_niching.joblib instead of retraining
_LATEX_SYMBOL = ""  # set to "$" to render figure text through LaTeX
BASE_DIR = '.'
RESULTS_DIR = os.path.join(BASE_DIR, "results/toy_tabular")
NICHING_RESULTS_DIR = os.path.join(BASE_DIR, "results_concept_niching_integrated/toy_tabular")
rc('text', usetex=(_LATEX_SYMBOL == "$"))
# matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and
# 3.8 removed the old names, so prefer the new name when it is available.
_WHITEGRID_STYLE = "seaborn-v0_8-whitegrid"
if _WHITEGRID_STYLE not in plt.style.available:
    _WHITEGRID_STYLE = "seaborn-whitegrid"
plt.style.use(_WHITEGRID_STYLE)

## Utility Functions

In [None]:
def bold_text(x):
    """Return ``x`` wrapped in LaTeX bold math markup when LaTeX is enabled.

    LaTeX is considered enabled when the module-level ``_LATEX_SYMBOL`` is
    ``"$"``; otherwise ``x`` is returned untouched.
    """
    latex_enabled = _LATEX_SYMBOL == "$"
    return (r"$\textbf{" + x + "}$") if latex_enabled else x

## Dataset Construction

In [None]:
############################################################################
## Generate Data
############################################################################
def produce_data(samples, cov=0.0, num_concepts=NUM_CONCEPTS):
    """Sample the synthetic tabular dataset used by the experiments.

    Five latent Gaussians are drawn per sample with NumPy's global RNG:
    ``x, y, z`` share pairwise covariance ``cov``, ``a, b`` share covariance
    ``cov``, and the two groups are independent of each other.

    Args:
        samples: number of rows to draw.
        cov: off-diagonal covariance inside each latent group. The 3x3 block
            is only positive semi-definite for ``-0.5 <= cov <= 1``.
        num_concepts: number of concept columns to return, taken in the order
            ``(x>0, y>0, z>0, a>0, b>0)``.

    Returns:
        features: float array ``(samples, 12)`` of non-linear transforms of
            the latents.
        labels: integer array ``(samples,)``. A label is 1 exactly when
            ``a > 0`` and ``b > 0`` and at most one of ``x, y, z`` is
            positive; otherwise it is 0.
        concepts: float32 array ``(samples, min(num_concepts, 5))`` of binary
            concept indicators. The result is always 2-D, even for a single
            sample or a single concept.
    """
    latent_cov = [
        [1, cov, cov, 0, 0],
        [cov, 1, cov, 0, 0],
        [cov, cov, 1, 0, 0],
        [0, 0, 0, 1, cov],
        [0, 0, 0, cov, 1],
    ]
    latents = np.random.multivariate_normal(
        mean=[0, 0, 0, 0, 0],
        cov=latent_cov,
        size=(samples,),
    )
    # Keep every latent as a (samples, 1) column so features and concepts can
    # be concatenated along axis 1 without a shape-dependent squeeze.
    x_vars, y_vars, z_vars, a_vars, b_vars = (
        latents[:, i:i + 1] for i in range(5)
    )

    # The features are just non-linear functions applied to each variable
    features = np.concatenate(
        [
            np.sin(x_vars) + x_vars,
            np.cos(x_vars) + x_vars,
            np.sin(y_vars) + y_vars,
            np.cos(y_vars) + y_vars,
            np.sin(z_vars) + z_vars,
            np.cos(z_vars) + z_vars,
            x_vars**2 + y_vars**2 + z_vars**2,
            np.sin(a_vars) + a_vars,
            np.cos(a_vars) + a_vars,
            np.sin(b_vars) + b_vars,
            np.cos(b_vars) + b_vars,
            a_vars**2 + b_vars**2,
        ],
        axis=1,
    )

    # The concepts just check if the variables are positive
    positives = [
        (latent > 0).astype(np.int32)
        for latent in (x_vars, y_vars, z_vars, a_vars, b_vars)
    ]
    concepts = np.concatenate(positives[:num_concepts], axis=1).astype(np.float32)

    # Two candidate labels: "at least two of x, y, z positive" and "both a and
    # b positive". argmax picks index 1 only when it is the unique maximum;
    # ties (both 0 or both 1) resolve to index 0.
    x_pos, y_pos, z_pos, a_pos, b_pos = positives
    xyz_majority = (x_pos + y_pos + z_pos)[:, 0] > 1
    ab_both = (a_pos + b_pos)[:, 0] > 1
    labels = np.stack([xyz_majority, ab_both], axis=1).astype(np.int32)

    return features, labels.argmax(axis=1), concepts

## Model Construction

In [None]:
# Construct the encoder model
def construct_encoder(
    input_shape,
    units,
    num_concepts,
    end_activation="sigmoid",
    latent_dims=0,
    output_logits=False,
):
    """Build the MLP concept encoder as a functional Keras model.

    One ReLU ``Dense`` layer (``encoder_dense_<i>``) is stacked per entry of
    ``units``; the result is flattened and projected to ``num_concepts``
    outputs (``encoder_concept_outputs``). These are sigmoid-activated unless
    ``output_logits`` is set. When ``latent_dims`` is truthy, an extra
    ``encoder_bypass_channel`` head with ``end_activation`` is built and the
    model outputs ``[concepts, bypass]`` instead of just ``concepts``.
    """
    inputs = tf.keras.Input(shape=input_shape)
    hidden = inputs
    for layer_idx, width in enumerate(units):
        hidden = tf.keras.layers.Dense(
            width,
            activation='relu',
            name=f"encoder_dense_{layer_idx}",
        )(hidden)
    hidden = tf.keras.layers.Flatten()(hidden)

    # The bypass head is built before the concept head, so weight creation
    # order (and hence seeded initialisation) is unchanged.
    bypass = None
    if latent_dims:
        bypass = tf.keras.layers.Dense(
            latent_dims,
            activation=end_activation,
            name="encoder_bypass_channel",
        )(hidden)

    concept_activation = None if output_logits else "sigmoid"
    concepts = tf.keras.layers.Dense(
        num_concepts,
        activation=concept_activation,
        name="encoder_concept_outputs",
    )(hidden)

    outputs = concepts if bypass is None else [concepts, bypass]
    return tf.keras.Model(inputs, outputs, name="encoder")


In [None]:
############################################################################
## Build concepts-to-labels model
############################################################################

def construct_decoder(units, num_outputs=1):
    """Build the concepts-to-label MLP as a ``tf.keras.Sequential`` model.

    One ReLU ``Dense`` layer is added per entry of ``units``, named
    ``decoder_dense_1``, ``decoder_dense_2``, ... It is followed by a linear
    ``decoder_model_output`` layer with ``num_outputs`` units, so the model
    emits logits.
    """
    stack = []
    for position, width in enumerate(units, start=1):
        stack.append(
            tf.keras.layers.Dense(
                width,
                activation=tf.nn.relu,
                name=f"decoder_dense_{position}",
            )
        )
    stack.append(
        tf.keras.layers.Dense(
            num_outputs,
            activation=None,
            name="decoder_model_output",
        )
    )
    return tf.keras.Sequential(stack)


# CBM Benchmark

In [None]:
import concepts_xai.methods.CBM.CBModel as CBM
reload(CBM)

############################################################################
## Build CBM
############################################################################

def construct_cbm(
    encoder,
    decoder,
    latent_dims=0,
    alpha=0.1,
    learning_rate=1e-3,
    encoder_output_logits=False,
):
    """Wrap ``encoder``/``decoder`` in a joint concept bottleneck model.

    ``CBM.BypassJointCBM`` is used when ``latent_dims`` is truthy, otherwise
    ``CBM.JointConceptBottleneckModel``. The task loss is binary cross-entropy
    computed from decoder logits. ``alpha`` and ``encoder_output_logits`` (as
    ``pass_concept_logits``) are forwarded to the CBM. The model is compiled
    with Adam at ``learning_rate`` and returned.
    """
    if latent_dims:
        cbm_class = CBM.BypassJointCBM
    else:
        cbm_class = CBM.JointConceptBottleneckModel

    task_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    cbm_model = cbm_class(
        encoder=encoder,
        decoder=decoder,
        task_loss=task_loss,
        name="joint_cbm",
        metrics=[tf.keras.metrics.BinaryAccuracy()],
        alpha=alpha,
        pass_concept_logits=encoder_output_logits,
    )

    # Compile CBM Model
    cbm_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate))
    return cbm_model

# Construct the complete model
def construct_end_to_end_model(
    encoder,
    decoder,
    input_shape,
    learning_rate=1e-3,
):
    """Chain ``encoder`` and ``decoder`` into one compiled binary classifier.

    When the encoder has several outputs (e.g. ``[concepts, bypass]`` from
    ``construct_encoder`` with ``latent_dims``), they are concatenated along
    the last axis before being fed to the decoder. The model is compiled with
    Adam, from-logits binary cross-entropy and binary accuracy.

    Returns:
        ``(model, encoder, decoder)``; the encoder and decoder are returned
        unchanged so callers can reuse their (now shared) weights.
    """
    model_inputs = tf.keras.Input(shape=input_shape)
    encoder_out = encoder(model_inputs)
    if isinstance(encoder_out, (list, tuple)):
        encoder_out = list(encoder_out)
        # Use a Keras layer rather than raw `tf.concat`: raw TF ops on
        # symbolic Keras tensors are not supported by every Keras version.
        if len(encoder_out) == 1:
            encoder_out = encoder_out[0]
        else:
            encoder_out = tf.keras.layers.Concatenate(axis=-1)(encoder_out)
    model = tf.keras.Model(
        model_inputs,
        decoder(encoder_out),
        name="complete_model",
    )
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate),
        loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
        metrics=["binary_accuracy"],
    )
    return model, encoder, decoder


## CBM-Prob Experiment

In [None]:
import concepts_xai.evaluation.metrics.oracle as oracle

############################################################################
## Experiment loop
############################################################################

def construct_trivial_auc_mat(num_concepts):
    """Return a float32 ``(num_concepts, num_concepts)`` reference AUC matrix.

    The diagonal is 1.0 and every off-diagonal entry is 0.5. Presumably this
    is the ideal case where each concept is perfectly predicted by its own
    representation and at chance level by every other one.
    """
    auc_mat = np.full((num_concepts, num_concepts), 0.5, dtype=np.float32)
    np.fill_diagonal(auc_mat, 1.0)
    return auc_mat
    
def _cbm_concept_predictions(encoder, x):
    """Run ``encoder`` on ``x`` and return only its concept head as NumPy.

    Encoders built with ``latent_dims > 0`` return ``[concepts, bypass]``;
    calling ``.numpy()`` on that list would fail, so the first output is kept.
    """
    outputs = encoder(x)
    if isinstance(outputs, (list, tuple)):
        outputs = outputs[0]
    return outputs.numpy()


def cbm_experiment_loop(experiment_config, load_from_cache=False):
    """Train joint CBMs over a covariance sweep and record niche impurity.

    For every covariance in ``experiment_config["covariances"]`` and every
    trial, fresh train/test data are sampled with ``produce_data``. An encoder
    and decoder are built, optionally pre-trained end to end, then trained
    jointly as a CBM with early stopping on ``val_concept_accuracy``. The
    resulting NIS is appended to ``experiment_variables["niss"]``. Results are
    dumped to ``<niching_results_dir>/results_niching.joblib`` after every
    trial.

    Args:
        experiment_config: experiment settings dict. ``data_concepts`` is
            written back into it (defaulting to ``num_concepts``).
        load_from_cache: when True and the results file exists, return its
            contents without training.

    Returns:
        dict with keys ``config``, ``task_accuracies``, ``concept_accuracies``
        and ``niss``. Only ``niss`` is populated by this loop.
    """
    utils.reseed(87)
    # Written back into the caller's dict so the stored config records how
    # many ground-truth concepts the data carried.
    experiment_config.setdefault("data_concepts", experiment_config["num_concepts"])
    experiment_variables = dict(
        config=experiment_config,
        task_accuracies=[],
        concept_accuracies=[],
        niss=[],
    )
    res_dir = experiment_config['niching_results_dir']
    results_path = os.path.join(res_dir, 'results_niching.joblib')
    if load_from_cache and os.path.exists(results_path):
        return load(results_path)

    # Else, let's go ahead and run the whole thing
    Path(os.path.join(res_dir, "models")).mkdir(parents=True, exist_ok=True)
    verbosity = experiment_config.get("verbosity", 0)
    num_concepts = experiment_config["num_concepts"]
    data_concepts = experiment_config["data_concepts"]
    latent_dims = experiment_config.get("latent_dims", 0)
    output_logits = experiment_config.get("encoder_output_logits", False)
    batch_size = experiment_config["batch_size"]
    holdout_fraction = experiment_config["holdout_fraction"]
    num_trials = experiment_config["trials"]

    for cov in experiment_config["covariances"]:
        print("Training with covariance:", cov)
        for trial in range(num_trials):
            print(f"\tTrial {trial + 1}/{num_trials} with covariance {cov}")
            # First construct the dataset
            (x_train, y_train, y_train_concepts) = produce_data(
                experiment_config["train_samples"],
                cov=cov,
                num_concepts=data_concepts,
            )
            (x_test, y_test, y_test_concepts) = produce_data(
                experiment_config["test_samples"],
                cov=cov,
                num_concepts=data_concepts,
            )
            # The CBM is only supervised on the first `num_concepts` concepts.
            train_targets = (y_train, y_train_concepts[:, :num_concepts])

            # Build an end-to-end model first in case we want task-specific
            # pre-training; the CBM below reuses the same encoder/decoder.
            end_to_end_model, encoder, decoder = construct_end_to_end_model(
                input_shape=experiment_config["input_shape"],
                encoder=construct_encoder(
                    input_shape=experiment_config["input_shape"],
                    units=experiment_config["encoder_units"],
                    num_concepts=num_concepts,
                    end_activation="sigmoid",
                    latent_dims=latent_dims,
                    output_logits=output_logits,
                ),
                decoder=construct_decoder(
                    units=experiment_config["decoder_units"],
                    num_outputs=experiment_config["num_outputs"],
                ),
            )

            if experiment_config.get("pre_train_epochs"):
                print("\tModel pre-training...")
                end_to_end_model.fit(
                    x=x_train,
                    y=y_train,
                    epochs=experiment_config["pre_train_epochs"],
                    batch_size=batch_size,
                    validation_split=holdout_fraction,
                    verbose=verbosity,
                )
                print("\t\tModel pre-training completed")

            # Now time to actually construct and train the CBM
            cbm_model = construct_cbm(
                encoder=encoder,
                decoder=decoder,
                alpha=experiment_config["alpha"],
                learning_rate=experiment_config["learning_rate"],
                latent_dims=latent_dims,
                encoder_output_logits=output_logits,
            )

            early_stopping_monitor = tf.keras.callbacks.EarlyStopping(
                monitor="val_concept_accuracy",
                min_delta=experiment_config["min_delta"],
                patience=experiment_config["patience"],
                restore_best_weights=True,
                verbose=2,
                mode='max',
            )
            if experiment_config.get("warmup_epochs"):
                print("\tWarmup training...")
                cbm_model.fit(
                    x=x_train,
                    y=train_targets,
                    epochs=experiment_config["warmup_epochs"],
                    batch_size=batch_size,
                    validation_split=holdout_fraction,
                    verbose=verbosity,
                )
                print("\t\tWarmup training completed")

            print("\tCBM training...")
            cbm_model.fit(
                x=x_train,
                y=train_targets,
                epochs=experiment_config["max_epochs"],
                batch_size=batch_size,
                callbacks=[early_stopping_monitor],
                validation_split=holdout_fraction,
                verbose=verbosity,
            )
            print("\t\tCBM training completed")
            print("\tEvaluating model")
            # NOTE(review): with encoder_output_logits=True these are raw
            # logits, not probabilities -- confirm niche_impurity_score is
            # meant to be fed logits in that configuration.
            c_train_pred = _cbm_concept_predictions(cbm_model.encoder, x_train)
            c_test_pred = _cbm_concept_predictions(cbm_model.encoder, x_test)

            experiment_variables['niss'].append(
                niching.niche_impurity_score(
                    c_soft=c_test_pred,
                    c_true=y_test_concepts,
                    c_soft_train=c_train_pred,
                    c_true_train=y_train_concepts,
                )
            )
            print(f'\t\tNIS: {experiment_variables["niss"][-1]:.2f}')

            # Checkpoint after every trial so partial sweeps are not lost.
            os.makedirs(res_dir, exist_ok=True)
            dump(experiment_variables, results_path)

    return experiment_variables


In [None]:
reload(niching)
reload(CBM)

############################################################################
## Experiment config
############################################################################

cbm_base_experiment_config = dict(
    batch_size=32,
    max_epochs=300,
    warmup_epochs=0,
    pre_train_epochs=0,
    trials=5,
    alpha=0.1,
    learning_rate=1e-3,
    encoder_units=[128, 64],
    decoder_units=[128, 64],
    num_outputs=1,
    patience=float("inf"),
    min_delta=1e-5,
    niching_results_dir=os.path.join(
        NICHING_RESULTS_DIR,
        "cbm/base"
    ),
    input_shape=INPUT_SHAPE,
    num_concepts=NUM_CONCEPTS,
    latent_dims=0,
    holdout_fraction=0.1,
    train_samples=2000,
    test_samples=1000,
    covariances=np.arange(0, 1, 0.1),
    verbosity=0,
    delta_beta=0.05,
    encoder_output_logits=False,
)

############################################################################
## Experiment run
############################################################################

# Follow the notebook-wide FROM_CACHE switch (as the CW run does) instead of
# unconditionally retraining all covariance/trial models on every run.
cbm_base_results = cbm_experiment_loop(
    cbm_base_experiment_config,
    load_from_cache=FROM_CACHE,
)

In [None]:
reload(niching)
reload(CBM)

############################################################################
## Experiment config
############################################################################

# Identical to the base CBM setup except that the encoder emits concept
# logits and the results are stored in their own directory.
cbm_from_logits_experiment_config = dict(
    cbm_base_experiment_config,
    niching_results_dir=os.path.join(
        NICHING_RESULTS_DIR,
        "cbm/from_logits"
    ),
    encoder_output_logits=True,
)

############################################################################
## Experiment run
############################################################################

cbm_from_logits_results = cbm_experiment_loop(
    cbm_from_logits_experiment_config,
    load_from_cache=FROM_CACHE,
)

# Concept Whitening Benchmark

In [None]:
import concepts_xai.methods.CW.CWLayer as CW

def construct_cw_model(
    input_shape,
    encoder,
    decoder,
    learning_rate=1e-3,
    activation=tf.keras.activations.relu,
    activation_mode='max_pool_mean',
):
    """Build a Concept Whitening classifier around ``encoder``/``decoder``.

    A fresh ``CW.ConceptWhiteningLayer`` is inserted after the encoder. Two
    models share it:

    * ``model``: inputs -> encoder -> CW -> ``activation`` -> decoder,
      compiled with Adam and from-logits binary cross-entropy;
    * ``cw_model``: inputs -> encoder -> CW, used to inspect whitened outputs.

    Layer order matters: the CW experiment config addresses the CW layer as
    ``model.layers[2]`` (``cw_layer=2``).

    Returns:
        ``(model, cw_model)``.
    """
    model_inputs = tf.keras.Input(shape=input_shape)
    cw_layer = CW.ConceptWhiteningLayer(
        activation_mode=activation_mode,
    )

    def whiten(inputs):
        # Encoder followed by the shared concept-whitening rotation.
        return cw_layer(encoder(inputs))

    cw_model = tf.keras.Model(model_inputs, whiten(model_inputs), name="cw_model")

    classifier_out = decoder(activation(whiten(model_inputs)))
    model = tf.keras.Model(model_inputs, classifier_out, name="complete_model")
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate),
        loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
        metrics=["binary_accuracy"],
    )
    return model, cw_model

def channels_corr_mat(outputs):
    """Pearson correlation matrix between the channels (last axis) of ``outputs``.

    ``outputs`` may be ``(N, C)`` or ``(N, H, W, C)``. Every sample (and every
    spatial position, for feature maps) counts as one observation of the
    ``C`` channels.

    The input is never modified. The previous in-place centring wrote through
    a NumPy view back into the caller's array, and raised for integer input.
    A constant channel has zero standard deviation and yields NaN entries.

    Returns:
        ``(C, C)`` correlation matrix.
    """
    outputs = np.asarray(outputs)
    # (N, ..., C) -> (observations, C); row order does not affect correlations.
    observations = outputs.reshape(-1, outputs.shape[-1])
    centered = observations - np.mean(observations, axis=0, keepdims=True)
    standardized = centered / np.std(centered, axis=0, keepdims=True)
    return np.dot(standardized.T, standardized) / standardized.shape[0]

In [None]:
import concepts_xai.evaluation.metrics.leakage as leakage

def concept_scores(
    cw_layer,
    inputs,
    aggregator='max_pool_mean',
    concept_indices=None,
):
    """Per-sample, per-axis scores of a Concept Whitening layer.

    ``cw_layer`` is run on ``inputs`` in inference mode. A 2-D output is
    returned as-is. A 4-D output (feature maps) is reduced over its spatial
    dimensions with ``aggregator``:

    * ``'mean'``: spatial mean per channel;
    * ``'max_pool_mean'``: 2x2 max-pool (smaller if the map is smaller),
      then spatial mean;
    * ``'max'``: spatial max per channel.

    Pooling is done in NHWC layout because TensorFlow's CPU kernels do not
    support NCHW max-pooling.

    Args:
        cw_layer: the CW layer. Its ``data_format`` is only consulted for
            4-D outputs.
        inputs: tensor fed to the layer.
        aggregator: spatial aggregation mode (see above).
        concept_indices: optional column selection applied to the scores.

    Raises:
        ValueError: for an unknown ``aggregator`` with 4-D outputs.
    """
    outputs = cw_layer(inputs, training=False)
    if len(outputs.shape) == 2:
        # Then the scores are already computed by our forward pass
        scores = outputs
    else:
        if cw_layer.data_format != "channels_last":
            # Normalise to channels-last so that pooling works on CPU.
            # NCHW -> NHWC
            outputs = tf.transpose(outputs, perm=[0, 2, 3, 1])
        spatial_axes = [1, 2]

        if aggregator == 'mean':
            scores = tf.math.reduce_mean(outputs, axis=spatial_axes)
        elif aggregator == 'max_pool_mean':
            # Downsample with a max pool, then average what is left.
            window_size = min(2, outputs.shape[1], outputs.shape[2])
            pooled = tf.nn.max_pool(
                outputs,
                ksize=window_size,
                strides=window_size,
                padding="SAME",
                data_format="NHWC",
            )
            scores = tf.math.reduce_mean(pooled, axis=spatial_axes)
        elif aggregator == 'max':
            scores = tf.math.reduce_max(outputs, axis=spatial_axes)
        else:
            raise ValueError(f'Unsupported aggregator {aggregator}.')

    if concept_indices is not None:
        return scores[:, concept_indices]
    return scores


def _cw_concept_groups(x_train, c_train, num_groups, exclusive=False):
    """Split ``x_train`` into one sample group per concept for CW alignment.

    Group ``j`` holds the rows of ``x_train`` whose ``j``-th concept label is
    1, for each of the first ``num_groups`` columns of ``c_train``. With
    ``exclusive`` set, a row only joins group ``j`` when every other one of
    those concepts is 0.
    """
    active = c_train[:, :num_groups] == 1
    groups = []
    for j in range(active.shape[1]):
        members = active[:, j]
        if exclusive:
            others_active = np.delete(active, j, axis=1).any(axis=1)
            members = np.logical_and(members, np.logical_not(others_active))
        groups.append(x_train[members, :])
    return groups


def _plot_cw_concept_similarity(cw_model, concept_groups):
    """Plot the similarity-ratio matrix of CW outputs across concept groups."""
    similarity_ratio = oracle.concept_similarity_matrix(
        concept_representations=[cw_model(group).numpy() for group in concept_groups],
        compute_ratios=True,
    )
    tick_labels = [f"$c_{i}$" for i in range(len(concept_groups))]
    fig, ax = plt.subplots(1, figsize=(8, 6))
    sns.heatmap(
        similarity_ratio,
        xticklabels=tick_labels,
        yticklabels=tick_labels,
        ax=ax,
        cmap="magma",
        vmin=0,
        vmax=1,
        annot=True,
        fmt=".2f",
        cbar_kws={"label": "Similarity Ratio"},
    )
    fig.tight_layout()
    fig.suptitle("Baseline Concept Axis Separability", fontsize=25)
    fig.subplots_adjust(top=0.85)
    plt.show()


def _cw_predict_concepts(cw_layer, encoder, x, experiment_config):
    """Predicted concept scores for ``x``, truncated to ``num_concepts`` columns."""
    num_concepts = experiment_config["num_concepts"]
    if not experiment_config.get("feature_map", False):
        scores = concept_scores(
            cw_layer,
            encoder(x),
            aggregator=experiment_config.get("aggregator", "max_pool_mean"),
        )
        return scores.numpy()[:, :num_concepts]
    # NOTE(review): this branch scores the raw encoder feature maps, not the
    # CW-rotated ones -- confirm that this is intended.
    feature_maps = encoder(x)[:, :, :, :num_concepts]
    height, width, channels = feature_maps.shape[1:]
    return feature_maps.numpy().reshape(-1, height * width, channels)


def cw_experiment_loop(experiment_config, load_from_cache=False):
    """Train Concept Whitening models over a covariance sweep and record NIS.

    For every covariance and trial:

    1. Sample data with ``produce_data``.
    2. Build one concept group per data concept.
    3. Train the encoder/CW/decoder model, re-aligning the CW rotation every
       ``cw_train_freq`` steps.
    4. Optionally run ``post_cw_train_epochs`` of rotation-only updates.
    5. Append the niche impurity score of the CW concept scores to
       ``experiment_variables["niss"]``.

    Results are dumped to ``<niching_results_dir>/results_niching.joblib``
    after every trial.

    Args:
        experiment_config: experiment settings dict (not modified).
        load_from_cache: when True and the results file exists, return its
            contents without training.

    Returns:
        dict with keys ``config``, ``task_accuracies``, ``concept_accuracies``
        and ``niss``. Only ``niss`` is populated by this loop.

    Raises:
        ValueError: if some concept group has no training samples, since the
            CW rotation could not be aligned.
    """
    utils.reseed(87)
    experiment_variables = dict(
        config=experiment_config,
        task_accuracies=[],
        concept_accuracies=[],
        niss=[],
    )
    res_dir = experiment_config['niching_results_dir']
    results_path = os.path.join(res_dir, 'results_niching.joblib')
    if load_from_cache and os.path.exists(results_path):
        return load(results_path)

    verbosity = experiment_config.get("verbosity", 0)
    num_trials = experiment_config["trials"]
    batch_size = experiment_config["batch_size"]
    data_concepts = experiment_config.get("data_concepts", experiment_config["num_concepts"])
    cw_layer_index = experiment_config["cw_layer"]
    cw_train_freq = experiment_config["cw_train_freq"]
    cw_train_iterations = experiment_config.get("cw_train_iterations", 1)
    # Maximum number of concept batches per rotation re-alignment.
    max_cw_batch_steps = experiment_config.get("cw_train_batch_steps", float("inf"))
    concept_auc_freq = experiment_config.get("concept_auc_freq")
    post_cw_train_epochs = experiment_config.get("post_cw_train_epochs") or 0

    for cov in experiment_config["covariances"]:
        print("Training with covariance:", cov)

        for trial in range(num_trials):
            print(f"\tTrial {trial + 1}/{num_trials}")

            # First construct the dataset
            (x_train, y_train, y_train_concepts) = produce_data(
                experiment_config["train_samples"],
                cov=cov,
                num_concepts=data_concepts,
            )
            (x_test, y_test, y_test_concepts) = produce_data(
                experiment_config["test_samples"],
                cov=cov,
                num_concepts=data_concepts,
            )
            # One group per data concept (previously hard-coded to x, y, z
            # only, which left the remaining concept axes unaligned).
            concept_groups = _cw_concept_groups(
                x_train,
                y_train_concepts,
                num_groups=data_concepts,
                exclusive=experiment_config.get("exclusive_concepts", False),
            )

            # Construct our CW model
            encoder = construct_encoder(
                input_shape=experiment_config["input_shape"],
                units=experiment_config["encoder_units"],
                num_concepts=experiment_config["num_concepts"],
                end_activation=None,
                latent_dims=experiment_config["latent_dims"],
            )
            decoder = construct_decoder(
                units=experiment_config["decoder_units"],
                num_outputs=experiment_config["num_outputs"],
            )
            model, cw_model = construct_cw_model(
                input_shape=experiment_config["input_shape"],
                encoder=encoder,
                decoder=decoder,
                learning_rate=experiment_config["learning_rate"],
                activation=tf.keras.activations.relu,
                activation_mode=experiment_config['activation_mode'],
            )
            # The same CW layer is used for training updates and evaluation.
            cw_layer = model.layers[cw_layer_index]

            # First do some pretraining for warming up the estimates if needed
            if experiment_config.get("pre_train_epochs"):
                print("\tModel pre-training...")
                model.fit(
                    x=x_train,
                    y=y_train,
                    epochs=experiment_config["pre_train_epochs"],
                    batch_size=batch_size,
                    validation_split=experiment_config["holdout_fraction"],
                    verbose=verbosity,
                )
                _plot_cw_concept_similarity(cw_model, concept_groups)
                print("\t\tModel pre-training completed")

            # Set up the datasets for the unrolled training loop
            main_dataset_loader = (
                tf.data.Dataset.from_tensor_slices((x_train, y_train))
                .shuffle(buffer_size=1000)
                .batch(batch_size)
            )
            min_size = min(group.shape[0] for group in concept_groups)
            if min_size == 0:
                raise ValueError(
                    "Every concept group needs at least one training sample to "
                    f"align the CW rotation (covariance {cov})."
                )
            # Equal-sized groups so they can be zipped into one dataset.
            concept_group_loader = (
                tf.data.Dataset.from_tensor_slices(
                    tuple(group[:min_size, :] for group in concept_groups)
                )
                .shuffle(buffer_size=1000)
                .batch(batch_size)
            )

            @tf.function
            def _train_step(model, x_batch_train, y_batch_train):
                # Update the other model parameters
                with tf.GradientTape() as tape:
                    logits = model(x_batch_train, training=True)
                    loss_value = model.loss(y_batch_train, logits)

                grads = tape.gradient(loss_value, model.trainable_weights)
                model.optimizer.apply_gradients(zip(grads, model.trainable_weights))
                return loss_value

            def _align_rotation(concept_groups_batch):
                # One CW rotation update from a batch of per-concept samples.
                cw_layer.update_rotation_matrix(
                    concept_groups=[encoder(group) for group in concept_groups_batch],
                )

            steps_per_epoch = int(np.ceil(x_train.shape[0] / batch_size))
            total_steps = 0
            for epoch in range(experiment_config["max_epochs"]):
                for current_step, (x_batch_train, y_batch_train) in enumerate(main_dataset_loader):
                    print(
                        f'Epoch {epoch + 1} and step {current_step}/{steps_per_epoch}         ',
                        end="\r",
                    )
                    # Need to update the rotation matrix
                    if (total_steps + 1) % cw_train_freq == 0:
                        for _ in range(cw_train_iterations):
                            for cw_batch_steps, concept_groups_batch in enumerate(concept_group_loader):
                                # `>=` so that exactly `cw_train_batch_steps`
                                # batches are used (`>` allowed one extra).
                                if cw_batch_steps >= max_cw_batch_steps:
                                    break
                                _align_rotation(concept_groups_batch)
                    if concept_auc_freq and (total_steps % concept_auc_freq) == 0:
                        concept_aucs = leakage.compute_concept_aucs(
                            cw_model=model,
                            encoder=encoder,
                            cw_layer=cw_layer_index,
                            x_test=x_test,
                            c_test=y_test_concepts,
                            num_concepts=experiment_config["num_concepts"],
                        )
                        print(
                            f'Concept AUC at step {total_steps}:',
                            concept_aucs
                        )
                    _train_step(model, x_batch_train, y_batch_train)
                    total_steps += 1

            if post_cw_train_epochs:
                steps_in_batch = len(concept_group_loader)
                for post_epoch in range(post_cw_train_epochs):
                    for cw_batch_steps, concept_groups_batch in enumerate(concept_group_loader):
                        print(
                            f'Post epoch {post_epoch + 1} and step {cw_batch_steps}/{steps_in_batch}         ',
                            end="\r",
                        )
                        _align_rotation(concept_groups_batch)

            print("\tEvaluating model")
            c_train_pred = _cw_predict_concepts(cw_layer, encoder, x_train, experiment_config)
            c_test_pred = _cw_predict_concepts(cw_layer, encoder, x_test, experiment_config)

            experiment_variables['niss'].append(
                niching.niche_impurity_score(
                    c_soft=c_test_pred,
                    c_true=y_test_concepts,
                    c_soft_train=c_train_pred,
                    c_true_train=y_train_concepts,
                )
            )
            print(f'\t\tNIS: {experiment_variables["niss"][-1]:.2f}')

            # Checkpoint after every trial so partial sweeps are not lost.
            os.makedirs(res_dir, exist_ok=True)
            dump(experiment_variables, results_path)

    return experiment_variables

In [None]:
reload(CW)
reload(niching)
reload(leakage)

############################################################################
## Experiment config
############################################################################

# `patience`, `min_delta` and `add` are not read by cw_experiment_loop; they
# are only stored alongside the results as part of the config.
cw_base_experiment_config = {
    "batch_size": 128,
    "max_epochs": 300,
    "pre_train_epochs": 0,
    "post_cw_train_epochs": 300,  # rotation-only epochs after training
    "cw_train_freq": 20,  # re-align the CW rotation every 20 training steps
    "cw_train_batch_steps": 20,
    "learning_rate": 1e-3,
    "encoder_units": [128, 64],
    "decoder_units": [128, 64],
    "num_outputs": 1,
    "patience": float("inf"),
    "min_delta": 1e-5,
    "niching_results_dir": os.path.join(
        NICHING_RESULTS_DIR,
        "cw/base",
    ),
    "input_shape": INPUT_SHAPE,
    "num_concepts": NUM_CONCEPTS,
    "latent_dims": 0,
    "covariances": np.arange(0, 1, 0.1),
    "train_samples": 2000,
    "test_samples": 1000,
    "trials": 5,
    "verbosity": 0,
    "data_concepts": NUM_CONCEPTS,
    "cw_layer": 2,  # index of the CW layer in model.layers
    "aggregator": 'max_pool_mean',
    "activation_mode": 'max_pool_mean',
    "concept_auc_freq": 0,  # 0 disables periodic concept-AUC logging
    "cw_train_iterations": 1,
    "holdout_fraction": 0.1,
    "exclusive_concepts": False,
    "feature_map": False,
    "add": '',
    "delta_beta": 0.05,
}


############################################################################
## Experiment run
############################################################################

cw_base_results = cw_experiment_loop(
    cw_base_experiment_config,
    load_from_cache=FROM_CACHE,
)

# CCD Benchmark

## Model Setup

In [None]:
def construct_ccd_encoder(
    input_shape,
    units,
    end_activation="sigmoid",
    latent_dims=0,
    latent_act=None,  # Original paper used "sigmoid" but this is troublesome in deep architectures
):
    """Build the dense encoder used as the CCD latent space.

    The network is a stack of ReLU ``Dense`` layers (one per entry of
    ``units``), followed by a ``Flatten`` and a final ``Dense`` layer with
    ``latent_dims`` outputs and activation ``latent_act``.

    Args:
        input_shape: shape of a single input sample (without batch axis).
        units: iterable of hidden-layer widths.
        end_activation: unused by this function; kept for signature
            compatibility.
        latent_dims: width of the output (latent) layer.
        latent_act: activation of the output layer (``None`` = linear).

    Returns:
        An uncompiled ``tf.keras.Model`` named ``"encoder"``.
    """
    inputs = tf.keras.Input(shape=input_shape)

    hidden = inputs
    for layer_idx, width in enumerate(units):
        hidden = tf.keras.layers.Dense(
            width,
            activation='relu',
            name=f"encoder_dense_{layer_idx}",
        )(hidden)

    # Project the flattened hidden features onto the latent space the topic
    # model operates on.
    flat = tf.keras.layers.Flatten()(hidden)
    latent = tf.keras.layers.Dense(
        latent_dims,
        activation=latent_act,
        name="encoder_bypass_channel",
    )(flat)

    return tf.keras.Model(inputs, latent, name="encoder")

## Experiment Loop

In [None]:
import concepts_xai.methods.OCACE.topicModel as CCD

############################################################################
## Experiment loop
############################################################################

def ccd_experiment_loop(
    experiment_config,
    load_from_cache=False,
):
    """Train CCD topic models across data covariances and record niche impurity scores.

    For every covariance in ``experiment_config["covariances"]`` and every one
    of ``experiment_config["trials"]`` trials this:

    1. draws fresh train/test data with ``produce_data``;
    2. trains an end-to-end encoder/decoder model on the task labels;
    3. fits a ``CCD.TopicModel`` on the encoder's (precomputed) latent outputs,
       reusing the trained decoder as the concepts-to-labels model;
    4. appends the niche impurity score (NIS) of the topic model's concept
       scores, measured against the ground-truth concepts, to
       ``experiment_variables["niss"]``.

    Results are checkpointed to ``<niching_results_dir>/results_niching.joblib``
    after every trial.

    Args:
        experiment_config: dict of hyper-parameters. Required keys include
            ``niching_results_dir``, ``covariances``, ``trials``,
            ``train_samples``, ``test_samples``, ``data_concepts``,
            ``input_shape``, ``encoder_units``, ``decoder_units``,
            ``latent_dims``, ``num_concepts``, ``num_outputs``,
            ``learning_rate``, ``min_delta``, ``patience``, ``max_epochs``,
            ``batch_size``, ``holdout_fraction`` and
            ``topic_model_train_epochs``. Optional keys (``verbosity``,
            ``top_k``, ``threshold``, ``lambda1``, ``lambda2``, ``seed``,
            ``eps``, ``early_stop_metric``, ...) fall back to defaults.
        load_from_cache: if True and the cached results file exists, return
            its contents without training anything.

    Returns:
        dict with ``config`` (the experiment config) and ``niss`` (a flat list
        of NIS values, one per (covariance, trial) pair, in loop order).
    """
    experiment_variables = dict(
        config=experiment_config,
        niss=[],
    )
    res_dir = experiment_config["niching_results_dir"]
    results_path = os.path.join(res_dir, 'results_niching.joblib')
    if load_from_cache and os.path.exists(results_path):
        return load(results_path)

    # Index into `covariances` to resume an interrupted sweep from (0 = all).
    start_ind = 0
    verbosity = experiment_config.get("verbosity", 0)
    for cov in experiment_config["covariances"][start_ind:]:
        print("Training with covariance:", cov)

        for trial in range(experiment_config["trials"]):
            print(f"\tTrial {trial + 1}/{experiment_config['trials']} for covariance {cov:.2f}")

            (x_train, y_train, c_train) = produce_data(
                experiment_config["train_samples"],
                cov=cov,
                num_concepts=experiment_config["data_concepts"],
            )
            (x_test, y_test, c_test) = produce_data(
                experiment_config["test_samples"],
                cov=cov,
                num_concepts=experiment_config["data_concepts"],
            )
            y_train = np.squeeze(y_train)
            y_test = np.squeeze(y_test)

            # Train an end-to-end model first: its encoder provides the latent
            # space the topic model is fit on, and its decoder is reused as the
            # concepts-to-labels model.
            end_to_end_model, encoder, decoder = construct_end_to_end_model(
                input_shape=experiment_config["input_shape"],
                learning_rate=experiment_config["learning_rate"],
                encoder=construct_ccd_encoder(
                    input_shape=experiment_config["input_shape"],
                    units=experiment_config["encoder_units"],
                    latent_act=experiment_config.get("latent_act", None),
                    latent_dims=experiment_config["latent_dims"],
                ),
                decoder=construct_decoder(
                    units=experiment_config["decoder_units"],
                    num_outputs=experiment_config["num_outputs"],
                ),
            )

            print("\tModel pre-training...")

            early_stopping_monitor = tf.keras.callbacks.EarlyStopping(
                monitor=experiment_config.get(
                    "early_stop_metric",
                    "val_loss",
                ),
                min_delta=experiment_config["min_delta"],
                patience=experiment_config["patience"],
                restore_best_weights=True,
                verbose=2,
                mode=experiment_config.get(
                    "early_stop_mode",
                    "min",
                ),
            )
            end_to_end_model.fit(
                x=x_train,
                y=y_train,
                epochs=experiment_config["max_epochs"],
                batch_size=experiment_config["batch_size"],
                callbacks=[
                    early_stopping_monitor,
                ],
                validation_split=experiment_config["holdout_fraction"],
                verbose=verbosity,
            )
            print("\t\tModel pre-training completed")

            print("\tEvaluating model")
            test_result = end_to_end_model.evaluate(
                x_test,
                y_test,
                verbose=0,
                return_dict=True,
            )
            print(f"\t\tTest results: {test_result}")

            # Number of neighbours the topic model uses; derived from the class
            # balance unless explicitly configured.
            if "top_k" not in experiment_config:
                top_k = ccd_compute_k(y=y_train, batch_size=experiment_config["batch_size"])
            else:
                top_k = experiment_config["top_k"]
            topic_model = CCD.TopicModel(
                concepts_to_labels_model=decoder,
                n_channels=experiment_config["latent_dims"],
                n_concepts=experiment_config["num_concepts"],
                threshold=experiment_config.get("threshold", 0.5),
                loss_fn=end_to_end_model.loss,
                top_k=top_k,
                lambda1=experiment_config.get("lambda1", 0.1),
                lambda2=experiment_config.get("lambda2", 0.1),
                seed=experiment_config.get("seed", None),
                eps=experiment_config.get("eps", 1e-5),
                data_format=experiment_config.get(
                    "data_format",
                    "channels_last"
                ),
                allow_gradient_flow_to_c2l=experiment_config.get(
                    'allow_gradient_flow_to_c2l',
                    False,
                ),
                acc_metric=(
                    tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1)
                    if experiment_config["num_outputs"] > 1 else
                    tf.keras.metrics.BinaryAccuracy()
                ),
            )
            topic_model.compile(
                optimizer=tf.keras.optimizers.Adam(
                    experiment_config.get("learning_rate", 1e-3),
                )
            )

            # The encoder is not part of the topic model, so its training-set
            # outputs do not change while the topic model trains: compute them
            # once and reuse them for fitting and scoring.
            z_train = encoder(x_train)

            print("\tTopic model training...")
            topic_model.fit(
                x=z_train,
                y=y_train,
                epochs=experiment_config["topic_model_train_epochs"],
                batch_size=experiment_config["batch_size"],
                validation_split=experiment_config["holdout_fraction"],
                verbose=verbosity,
            )
            print("\t\tTopic model training completed")

            c_train_pred = topic_model.concept_scores(z_train).numpy()
            c_test_pred = topic_model.concept_scores(encoder(x_test)).numpy()

            print("\t\tComputing niching scores...")

            experiment_variables['niss'].append(
                niching.niche_impurity_score(
                    c_soft=c_test_pred,
                    c_true=c_test,
                    c_soft_train=c_train_pred,
                    c_true_train=c_train,
                )
            )
            print(f'\t\tNIS: {experiment_variables["niss"][-1]:.2f}')

            # Checkpoint after every trial so an interrupted sweep keeps its
            # completed results.
            os.makedirs(res_dir, exist_ok=True)
            dump(experiment_variables, results_path)

    return experiment_variables

def ccd_compute_k(y, batch_size):
    """Heuristic ``top_k`` for the CCD topic model.

    Computes ``int(mean_class_fraction * batch_size / 2)``, where
    ``mean_class_fraction`` is the mean per-class sample count divided by the
    total number of samples in ``y``.

    Args:
        y: array of class labels (first axis indexes samples).
        batch_size: training batch size.

    Returns:
        The truncated integer ``k``.
    """
    class_counts = np.unique(y, return_counts=True)[1]
    mean_class_fraction = class_counts.mean() / y.shape[0]
    return int(mean_class_fraction * batch_size / 2)

## Correlation Experiment

In [None]:
# Re-import the modules used by this experiment so local edits are picked up.
reload(niching)
reload(CBM)
reload(CCD)

############################################################################
## Experiment config
############################################################################

ccd_base_experiment_config = dict(
    batch_size=32,
    max_epochs=300,
    topic_model_train_epochs=50,
    trials=5,
    learning_rate=1e-3,

    num_concepts=NUM_CONCEPTS,
    input_shape=INPUT_SHAPE,
    latent_dims=NUM_CONCEPTS,
    encoder_units=[128, 64],
    decoder_units=[128, 64],
    latent_act=None,

    threshold=0.0,
    lambda1=0.1,
    lambda2=0.1,
    eps=1e-5,

    num_outputs=1,
    patience=float("inf"),
    min_delta=1e-5,
    niching_results_dir=os.path.join(
        NICHING_RESULTS_DIR,
        "ccd/base",
    ),
    covariances=np.arange(0, 1, 0.1),
    train_samples=2000,
    test_samples=1000,
    verbosity=0,
    data_concepts=NUM_CONCEPTS,
    cw_layer=2,
    aggregator='max_pool_mean',
    activation_mode='max_pool_mean',
    cw_train_freq=20,
    concept_auc_freq=0,
    cw_train_iterations=1,
    holdout_fraction=0.1,
    exclusive_concepts=True,
    delta_beta=0.05,
)


############################################################################
## Experiment run
############################################################################

ccd_base_results = ccd_experiment_loop(
    experiment_config=ccd_base_experiment_config,
    load_from_cache=FROM_CACHE,
)

In [None]:
# Pick up local edits to the modules this experiment depends on.
reload(CBM)
reload(CCD)
reload(niching)

############################################################################
## Experiment config
############################################################################

ccd_extended_experiment_config = {
    # Training schedule
    "batch_size": 32,
    "max_epochs": 300,
    "topic_model_train_epochs": 50,
    "trials": 5,
    "learning_rate": 1e-3,
    # Architecture: twice as many CCD concepts as ground-truth data concepts
    "num_concepts": 2*NUM_CONCEPTS,
    "input_shape": INPUT_SHAPE,
    "latent_dims": NUM_CONCEPTS,
    "encoder_units": [128, 64],
    "decoder_units": [128, 64],
    "latent_act": None,
    # Topic-model hyper-parameters
    "threshold": 0.0,
    "lambda1": 0.1,
    "lambda2": 0.1,
    "eps": 1e-5,
    # Task, early stopping and output location
    "num_outputs": 1,
    "patience": float("inf"),
    "min_delta": 1e-5,
    "niching_results_dir": os.path.join(NICHING_RESULTS_DIR, "ccd/extended"),
    # Data generation
    "covariances": np.arange(0, 1, 0.1),
    "train_samples": 2000,
    "test_samples": 1000,
    "verbosity": 0,
    "data_concepts": NUM_CONCEPTS,
    # Apart from holdout_fraction, the keys below are not read by
    # ccd_experiment_loop.
    "cw_layer": 2,
    "aggregator": 'max_pool_mean',
    "activation_mode": 'max_pool_mean',
    "cw_train_freq": 20,
    "concept_auc_freq": 0,
    "cw_train_iterations": 1,
    "holdout_fraction": 0.1,
    "exclusive_concepts": True,
    "delta_beta": 0.05,
}


############################################################################
## Experiment run
############################################################################

ccd_extended_results = ccd_experiment_loop(
    experiment_config=ccd_extended_experiment_config, load_from_cache=FROM_CACHE
)

# SENN Benchmarking

In [None]:
import concepts_xai.methods.SENN.base_senn as SENN
import concepts_xai.methods.SENN.aggregators as aggregators
reload(SENN)
reload(aggregators)


def construct_senn_coefficient_model(units, num_concepts, num_outputs):
    """Build SENN's coefficient (relevance) network.

    The model flattens its input, applies one ReLU ``Dense`` layer per entry of
    ``units``, then a linear ``Dense`` layer with ``num_concepts * num_outputs``
    units whose output is reshaped to ``(num_outputs, num_concepts)`` per
    sample.

    Args:
        units: iterable of hidden-layer widths.
        num_concepts: number of concepts to produce a coefficient for.
        num_outputs: number of task outputs.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    layer_stack = [tf.keras.layers.Flatten()]
    for layer_idx, width in enumerate(units, start=1):
        layer_stack.append(
            tf.keras.layers.Dense(
                width,
                activation=tf.nn.relu,
                name=f"coefficient_model_dense_{layer_idx}",
            )
        )
    layer_stack.append(
        tf.keras.layers.Dense(
            num_concepts * num_outputs,
            activation=None,
            name="coefficient_model_output",
        )
    )
    layer_stack.append(tf.keras.layers.Reshape([num_outputs, num_concepts]))
    return tf.keras.Sequential(layer_stack)

def construct_senn_encoder(
    input_shape,
    units,
    end_activation="sigmoid",
    latent_dims=0,
    latent_act=None,  # Original paper used "sigmoid" but this is troublesome in deep architectures
):
    """Build the SENN concept encoder and a VAE encoder that shares its layers.

    Both models share a stack of ReLU ``Dense`` layers (one per entry of
    ``units``). On top of it sit two ``Dense`` heads of width ``latent_dims``:
    ``"means"`` (activation ``latent_act``) and ``"log_var"`` (linear).

    Args:
        input_shape: shape of a single input sample (without batch axis).
        units: iterable of hidden-layer widths.
        end_activation: unused by this function; kept for signature
            compatibility.
        latent_dims: number of latent concepts.
        latent_act: activation applied to the ``"means"`` head. Previously this
            argument was accepted but ignored; the default ``None`` (linear)
            keeps the old behaviour.

    Returns:
        ``(senn_encoder, vae_encoder)``: ``senn_encoder`` maps inputs to the
        means, and ``vae_encoder`` maps inputs to ``[means, log_var]``. Because
        the layers are shared, training either model updates both.
    """
    encoder_inputs = tf.keras.Input(shape=input_shape)
    encoder_compute_graph = encoder_inputs
    for i, layer_units in enumerate(units):
        encoder_compute_graph = tf.keras.layers.Dense(
            layer_units,
            activation='relu',
            name=f"encoder_dense_{i}",
        )(encoder_compute_graph)

    mean = tf.keras.layers.Dense(
        latent_dims,
        activation=latent_act,
        name="means",
    )(encoder_compute_graph)
    log_var = tf.keras.layers.Dense(
        latent_dims,
        activation=None,
        name="log_var",
    )(encoder_compute_graph)
    senn_encoder = tf.keras.Model(
        encoder_inputs,
        mean,
        name="senn_encoder",
    )
    vae_encoder = tf.keras.Model(
        encoder_inputs,
        [mean, log_var],
        name="vae_encoder",
    )
    return senn_encoder, vae_encoder


def construct_vae_decoder(
    units,
    output_shape,
    latent_dims,
):
    """Fully-connected decoder mapping a latent vector back to input space.

    The network is a stack of ReLU ``Dense`` layers (one per entry of
    ``units``) followed by a linear ``Dense`` layer of width ``output_shape``.
    Although the docstring this replaces pointed at the CNN decoder from
    'Challenging Common Assumptions in the Unsupervised Learning of
    Disentangled Representations' (https://arxiv.org/abs/1811.12359), the
    implementation here uses only dense layers.

    Args:
        units: iterable of hidden-layer widths.
        output_shape: width of the reconstruction (an integer, used as the
            number of units of the final ``Dense`` layer).
        latent_dims: size of the latent input vector.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    latent_inputs = tf.keras.Input(shape=(latent_dims,))

    hidden = latent_inputs
    for width in units:
        hidden = tf.keras.layers.Dense(width, activation='relu')(hidden)

    reconstruction = tf.keras.layers.Dense(output_shape, activation=None)(hidden)
    return tf.keras.Model(inputs=latent_inputs, outputs=[reconstruction])


def construct_senn_model(
    concept_encoder,
    concept_decoder,
    coefficient_model,
    num_outputs,
    regularization_strength=0.1,
    learning_rate=1e-3,
    sparsity_strength=2e-5,
):
    """Assemble and compile a ``SENN.SelfExplainingNN``.

    When ``num_outputs >= 2`` it uses the multiclass additive aggregator, a
    sparse categorical cross-entropy task loss and top-1 sparse categorical
    accuracy. Otherwise it uses the scalar additive aggregator, a binary
    cross-entropy task loss and binary accuracy. Both losses take logits.

    The reconstruction loss decodes its second argument with
    ``concept_decoder`` and returns the squared error against the first,
    summed over the last axis.

    Returns:
        The SENN model, compiled with an Adam optimizer at ``learning_rate``.
    """
    is_multiclass = num_outputs >= 2

    def reconstruction_loss_fn(y_true, y_pred):
        residual = y_true - concept_decoder(y_pred)
        return tf.reduce_sum(tf.square(residual), [-1])

    if is_multiclass:
        aggregator_fn = aggregators.multiclass_additive_aggregator
        task_loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        accuracy_metric = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1)
    else:
        aggregator_fn = aggregators.scalar_additive_aggregator
        task_loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)
        accuracy_metric = tf.keras.metrics.BinaryAccuracy()

    senn_model = SENN.SelfExplainingNN(
        encoder_model=concept_encoder,
        coefficient_model=coefficient_model,
        aggregator_fn=aggregator_fn,
        task_loss_fn=task_loss_fn,
        reconstruction_loss_fn=reconstruction_loss_fn,
        regularization_strength=regularization_strength,
        sparsity_strength=sparsity_strength,
        name="SENN",
        metrics=[accuracy_metric],
    )
    senn_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate))
    return senn_model

In [None]:
import concepts_xai.evaluation.metrics.purity as purity
import scipy

############################################################################
## Experiment loop
############################################################################

def construct_trivial_auc_mat(num_concepts):
    """Return a float32 ``(num_concepts, num_concepts)`` matrix with 1.0 on the
    diagonal and 0.5 everywhere else."""
    auc_mat = np.full((num_concepts, num_concepts), 0.5, dtype=np.float32)
    np.fill_diagonal(auc_mat, 1.0)
    return auc_mat
    
def get_argmax_concept_explanations(preds, class_theta_scores):
    """Select, per sample, the concept coefficients of the predicted class.

    Args:
        preds: model outputs. If 1-D, class 0's coefficients are always
            selected; otherwise the class is ``argmax`` over the last axis.
        class_theta_scores: per-class coefficients indexed as
            ``[sample, class, concept]``. The class axis is axis 1.

    Returns:
        Array of the selected coefficients with the class axis removed.
    """
    if preds.ndim == 1:
        # A single (or binary) output: there is only one set of coefficients
        # to choose from.
        chosen_class = np.zeros(preds.shape, dtype=np.int32)
    else:
        chosen_class = preds.argmax(axis=-1)

    gathered = np.take_along_axis(
        class_theta_scores,
        chosen_class[..., np.newaxis, np.newaxis],
        axis=1,
    )
    return gathered.squeeze(axis=1)

def senn_experiment_loop(
    experiment_config,
    load_from_cache=False,
):
    """Train SENN models across data covariances and record niche impurity scores.

    For every covariance in ``experiment_config["covariances"]`` and every one
    of ``experiment_config["trials"]`` trials, this draws fresh train/test data
    with ``produce_data``. Optionally it pre-trains the encoder/decoder as a
    beta-VAE. It then trains a SENN model and reports its test AUC and task
    accuracy. Finally, it appends the niche impurity score (NIS) of the
    predicted-class concept coefficients, measured against the ground-truth
    concepts, to ``experiment_variables["niss"]``. Results are checkpointed to
    ``<niching_results_dir>/results_niching.joblib`` after every trial.

    Args:
        experiment_config: dict of hyper-parameters. Required keys include
            ``niching_results_dir``, ``covariances``, ``trials``,
            ``train_samples``, ``test_samples``, ``data_concepts``,
            ``input_shape``, ``encoder_units``, ``decoder_units``,
            ``coefficient_model_units``, ``latent_dims``, ``num_outputs``,
            ``min_delta``, ``patience``, ``max_epochs``, ``batch_size`` and
            ``holdout_fraction``. Optional keys include
            ``pretrain_autoencoder_epochs``, ``beta``, ``learning_rate``,
            ``regularization_strength``, ``sparsity_strength``,
            ``early_stop_metric`` and ``early_stop_mode``.
        load_from_cache: if True and the cached results file exists, return
            its contents without training anything.

    Returns:
        dict with ``config`` (the experiment config) and ``niss`` (a flat list
        of NIS values, one per (covariance, trial) pair, in loop order).
    """
    experiment_variables = dict(
        config=experiment_config,
        niss=[],
    )
    res_dir = experiment_config['niching_results_dir']
    results_path = os.path.join(res_dir, 'results_niching.joblib')
    if load_from_cache and os.path.exists(results_path):
        return load(results_path)

    # Index into `covariances` to resume an interrupted sweep from (0 = all).
    start_ind = 0
    verbosity = experiment_config.get("verbosity", 0)
    Path(os.path.join(res_dir, "models")).mkdir(parents=True, exist_ok=True)
    for cov in experiment_config["covariances"][start_ind:]:
        print("Training with covariance:", cov)

        for trial in range(experiment_config["trials"]):
            print(f"\tTrial {trial + 1}/{experiment_config['trials']} for covariance {cov:.2f}")

            (x_train, y_train, c_train) = produce_data(
                experiment_config["train_samples"],
                cov=cov,
                num_concepts=experiment_config["data_concepts"],
            )
            (x_test, y_test, c_test) = produce_data(
                experiment_config["test_samples"],
                cov=cov,
                num_concepts=experiment_config["data_concepts"],
            )
            y_train = np.squeeze(y_train)
            y_test = np.squeeze(y_test)

            # The SENN concept encoder and the VAE encoder share layers, so
            # optional VAE pre-training below also initialises the SENN encoder.
            concept_encoder, vae_encoder = construct_senn_encoder(
                input_shape=experiment_config["input_shape"],
                units=experiment_config["encoder_units"],
                latent_act=experiment_config.get("latent_act", None),
                latent_dims=experiment_config["latent_dims"],
            )
            concept_decoder = construct_vae_decoder(
                units=experiment_config["decoder_units"],
                output_shape=experiment_config["input_shape"][-1],
                latent_dims=experiment_config["latent_dims"],
            )
            coefficient_model = construct_senn_coefficient_model(
                units=experiment_config["coefficient_model_units"],
                num_concepts=experiment_config["latent_dims"],
                num_outputs=experiment_config["num_outputs"],
            )

            if experiment_config.get("pretrain_autoencoder_epochs"):
                # NOTE(review): `beta_vae` and `vae_losses` are not imported in
                # the cells visible here; this branch only runs when
                # `pretrain_autoencoder_epochs` is set, so confirm they are
                # imported before enabling it.
                autoencoder = beta_vae.BetaVAE(
                    encoder=vae_encoder,
                    decoder=concept_decoder,
                    loss_fn=vae_losses.bernoulli_fn_wrapper(),
                    beta=experiment_config.get("beta", 1),
                )

                autoencoder.compile(
                    optimizer=tf.keras.optimizers.Adam(
                        experiment_config.get("learning_rate", 1e-3)
                    ),
                )

                print("\tAutoencoder pre-training...")
                autoencoder.fit(
                    x=x_train,
                    epochs=experiment_config["pretrain_autoencoder_epochs"],
                    batch_size=experiment_config["batch_size"],
                    validation_split=experiment_config["holdout_fraction"],
                    verbose=verbosity,
                )
                print("\t\tAutoencoder training completed")

            # Now construct and train the SENN model itself.
            senn_model = construct_senn_model(
                concept_encoder=concept_encoder,
                concept_decoder=concept_decoder,
                coefficient_model=coefficient_model,
                num_outputs=experiment_config["num_outputs"],
                regularization_strength=experiment_config.get("regularization_strength", 0.1),
                learning_rate=experiment_config.get("learning_rate", 1e-3),
                sparsity_strength=experiment_config.get("sparsity_strength", 2e-5),
            )

            # The default monitor is `val_loss`, which must be minimised. With
            # mode="max" the callback would treat increases in the loss as
            # improvements (and restore_best_weights would keep the worst
            # epoch), so the default mode is "min", as in the CCD loop.
            early_stopping_monitor = tf.keras.callbacks.EarlyStopping(
                monitor=experiment_config.get(
                    "early_stop_metric",
                    "val_loss",
                ),
                min_delta=experiment_config["min_delta"],
                patience=experiment_config["patience"],
                restore_best_weights=True,
                verbose=2,
                mode=experiment_config.get(
                    "early_stop_mode",
                    "min",
                ),
            )

            print("\tSENN training...")
            senn_model.fit(
                x=x_train,
                y=y_train,
                epochs=experiment_config["max_epochs"],
                batch_size=experiment_config["batch_size"],
                callbacks=[
                    early_stopping_monitor,
                ],
                validation_split=experiment_config["holdout_fraction"],
                verbose=verbosity,
            )
            print("\t\tSENN training completed")

            print("\tEvaluating model")
            test_result = senn_model.evaluate(
                x_test,
                y_test,
                verbose=0,
                return_dict=True,
            )
            task_acc = (
                test_result['sparse_top_k_categorical_accuracy']
                if experiment_config['num_outputs'] > 1 else
                test_result['binary_accuracy']
            )

            if experiment_config['num_outputs'] > 1:
                # Apply a softmax over the class logits, then score against
                # one-hot labels.
                preds = scipy.special.softmax(
                    senn_model.predict(x_test)[0],
                    axis=-1
                )
                one_hot_labels = tf.keras.utils.to_categorical(y_test)
                auc = sklearn.metrics.roc_auc_score(
                    one_hot_labels,
                    preds,
                    multi_class='ovo',
                )
            else:
                auc = sklearn.metrics.roc_auc_score(
                    y_test,
                    senn_model.predict(x_test)[0],
                )

            print(
                f"\t\tTest auc = {auc:.4f}, "
                f"task accuracy = {task_acc:.4f}"
            )

            # The concept "scores" used for niching are the coefficients
            # (thetas) of the class each sample is predicted as.
            x_train_preds, (_, x_train_theta_class_scores) = senn_model(x_train)
            c_train_pred = get_argmax_concept_explanations(
                x_train_preds.numpy(),
                x_train_theta_class_scores.numpy(),
            )

            x_test_preds, (_, x_test_theta_class_scores) = senn_model(x_test)
            c_test_pred = get_argmax_concept_explanations(
                x_test_preds.numpy(),
                x_test_theta_class_scores.numpy(),
            )

            print("\t\tComputing niching scores...")

            experiment_variables['niss'].append(
                niching.niche_impurity_score(
                    c_soft=c_test_pred,
                    c_true=c_test,
                    c_soft_train=c_train_pred,
                    c_true_train=c_train,
                )
            )
            # `\t\t` (not `\t\N...`): `\N` starts a named-Unicode escape and
            # is a SyntaxError when not followed by `{name}`.
            print(f'\t\tNIS: {experiment_variables["niss"][-1]:.2f}')

            # Checkpoint after every trial so an interrupted sweep keeps its
            # completed results.
            os.makedirs(res_dir, exist_ok=True)
            dump(experiment_variables, results_path)

    return experiment_variables

In [None]:
reload(niching)
reload(CBM)
reload(SENN)

############################################################################
## Experiment config
############################################################################

senn_base_experiment_config = dict(
    batch_size=32,
    max_epochs=300,
    topic_model_train_epochs=50,
    trials=5,
    learning_rate=1e-3,
    
    num_concepts=NUM_CONCEPTS,
    input_shape=INPUT_SHAPE,
    latent_dims=10,
    encoder_units=[128, 64],
    decoder_units=[128, 64],
    latent_act=None,
    
    coefficient_model_units=[64, 64],
    latent_decoder_units=[64, 64],
    
    predictor_max_epochs=300,
    
    regularization_strength=0.1,
    sparsity_strength=2e-5,
    
    # NOTE(review): the extended SENN config uses num_outputs=1 for the same
    # binary task; confirm which output head is intended here.
    num_outputs=2,
    patience=float("inf"),
    min_delta=1e-5,
    niching_results_dir=os.path.join(
        NICHING_RESULTS_DIR,
        "senn/purity",
    ),
    covariances=np.arange(0, 1, 0.1),
    train_samples=2000,
    test_samples=1000,
    verbosity=0,
    data_concepts=NUM_CONCEPTS,
    cw_layer=2,
    aggregator='max_pool_mean',
    activation_mode='max_pool_mean',
    cw_train_freq=20,
    concept_auc_freq=0,
    cw_train_iterations=1,
    holdout_fraction=0.1,
    exclusive_concepts=True,
)

############################################################################
## Experiment run
############################################################################

# Honour the notebook-wide FROM_CACHE switch, like the other experiment cells,
# so setting it to False actually retrains this model.
senn_base_results = senn_experiment_loop(
    experiment_config=senn_base_experiment_config,
    load_from_cache=FROM_CACHE,
)

In [None]:
reload(niching)
reload(CBM)
reload(SENN)

############################################################################
## Experiment config
############################################################################

senn_extended_experiment_config = dict(
    batch_size=32,
    max_epochs=300,
    topic_model_train_epochs=50,
    trials=5,
    learning_rate=1e-3,
    
    num_concepts=NUM_CONCEPTS*2,
    input_shape=INPUT_SHAPE,
    latent_dims=10,
    encoder_units=[128, 64],
    decoder_units=[128, 64],
    latent_act=None,
    
    coefficient_model_units=[64, 64],
    latent_decoder_units=[64, 64],
    
    predictor_max_epochs=300,
    
    regularization_strength=0.1,
    sparsity_strength=2e-5,
    
    # NOTE(review): the base SENN config uses num_outputs=2 for the same
    # binary task; confirm which output head is intended here.
    num_outputs=1,
    patience=float("inf"),
    min_delta=1e-5,
    niching_results_dir=os.path.join(
        NICHING_RESULTS_DIR,
        "senn/purity_extended",
    ),
    covariances=np.arange(0, 1, 0.1),
    train_samples=2000,
    test_samples=1000,
    verbosity=0,
    data_concepts=NUM_CONCEPTS,
    cw_layer=2,
    aggregator='max_pool_mean',
    activation_mode='max_pool_mean',
    cw_train_freq=20,
    concept_auc_freq=0,
    cw_train_iterations=1,
    holdout_fraction=0.1,
    exclusive_concepts=True,
)

############################################################################
## Experiment run
############################################################################

# Honour the notebook-wide FROM_CACHE switch, like the other experiment cells,
# so setting it to False actually retrains this model.
senn_extended_results = senn_experiment_loop(
    experiment_config=senn_extended_experiment_config,
    load_from_cache=FROM_CACHE,
)