In [None]:
%load_ext autoreload
%matplotlib inline

# Metric Comparison Experiments

## Setup

In [None]:
import matplotlib
import concepts_xai
import numpy as np
import os
import random
import tensorflow as tf
import yaml
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import rc
from matplotlib import cm
import seaborn as sns
from importlib import reload
from pathlib import Path
import sklearn

In [None]:
################################################################################
## Set seeds up for reproducibility
################################################################################

SEED = 87

# Note: PYTHONHASHSEED is read when an interpreter starts, so setting it here
# only affects child processes launched afterwards, not this kernel's hashing.
os.environ['PYTHONHASHSEED'] = str(SEED)
tf.random.set_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)


In [None]:
################################################################################
## Global Variables Defining Experiment Flow
################################################################################

# Set to "$" to render plot text with LaTeX (needs a working LaTeX install,
# e.g. on the server); leave empty to use matplotlib's built-in mathtext.
LATEX_SYMBOL = ""
RESULTS_DIR = "results/metric_example_results"
Path(RESULTS_DIR).mkdir(parents=True, exist_ok=True)
rc('text', usetex=(LATEX_SYMBOL == "$"))
# matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and
# 3.8 removed the old names, so prefer the new name whenever it exists.
_WHITEGRID_STYLE = (
    'seaborn-v0_8-whitegrid'
    if 'seaborn-v0_8-whitegrid' in plt.style.available
    else 'seaborn-whitegrid'
)
plt.style.use(_WHITEGRID_STYLE)

## Utility Functions

In [None]:
import joblib

def _as_array(value):
    """Return `value` unchanged if it is already an ndarray, else `np.array(value)`."""
    return value if isinstance(value, np.ndarray) else np.array(value)


def _save_mean_std_npz(entries, results_dir, prefix):
    """Write `<prefix>_means.npz` / `<prefix>_stds.npz` from entry[0] / entry[1] of each entry."""
    for suffix, position in (("means", 0), ("stds", 1)):
        np.savez(
            os.path.join(results_dir, f"{prefix}_{suffix}.npz"),
            *[_as_array(entry[position]) for entry in entries],
        )


def serialize_results(results_dict, results_dir):
    """Persist experiment results into `results_dir`.

    The whole dict is dumped to `raw_results.joblib`. Each value is expected to
    be a sequence of `(mean, std, ...)` entries, or a dict mapping sub-keys to
    such sequences. Means and stds are additionally saved as `.npz` archives
    (`<name>_means.npz` / `<name>_stds.npz`, or `<name>_<key>_means.npz` /
    `<name>_<key>_stds.npz` for dict values), one array per entry stored
    positionally (arr_0, arr_1, ...).

    Args:
        results_dict: mapping from result name to results as described above.
        results_dir: existing directory to write into.
    """
    joblib.dump(results_dict, os.path.join(results_dir, "raw_results.joblib"))
    for result_name, result_entries in results_dict.items():
        if isinstance(result_entries, (list, tuple, np.ndarray)):
            _save_mean_std_npz(result_entries, results_dir, result_name)
        else:
            # Nested results: one pair of archives per sub-key.
            for key, sub_entries in result_entries.items():
                _save_mean_std_npz(sub_entries, results_dir, f"{result_name}_{key}")

def serialize_experiment_config(config, results_dir):
    """Dump `config` as YAML (keys sorted) into `<results_dir>/config.yaml`, overwriting it."""
    config_path = os.path.join(results_dir, "config.yaml")
    yaml_text = yaml.dump(config, sort_keys=True)
    with open(config_path, 'w') as out_file:
        out_file.write(yaml_text)

        
def deserialize_experiment_config(results_dir):
    """Load the YAML config stored at `<results_dir>/config.yaml` (via `yaml.safe_load`)."""
    config_path = os.path.join(results_dir, "config.yaml")
    with open(config_path, 'r') as in_file:
        loaded_config = yaml.safe_load(in_file)
    return loaded_config

## Dataset Construction

In [None]:
############################################################################
## Generate Data
############################################################################

def produce_data_larger(samples, cov=0.0, num_concepts=3):
    """Generate a synthetic dataset with `num_concepts` correlated latent variables.

    `num_concepts` latents are drawn from a zero-mean Gaussian with unit
    variances and covariance `cov` between every pair. Only the first three
    latents (x, y, z) feed the features and the label; every latent yields a
    binary concept.

    Args:
        samples: number of samples to draw.
        cov: covariance shared by every pair of latent variables.
        num_concepts: number of latent variables / concepts; must be >= 3.

    Returns:
        features: float array of shape (samples, 7, 1) with non-linear
            transforms of x, y and z.
        labels: int32 array of shape (samples,); 1 iff at least two of the
            first three latents are positive.
        concepts: int32 array of shape (samples, num_concepts); 1 where the
            corresponding latent is positive.

    Raises:
        ValueError: if `num_concepts` < 3.
    """
    if num_concepts < 3:
        raise ValueError(
            f"produce_data_larger needs num_concepts >= 3, got {num_concepts}"
        )

    # Unit variances on the diagonal and `cov` everywhere else. The matrix gets
    # its own name: reusing `cov` for it used to overwrite the scalar argument,
    # which left every off-diagonal entry at 0 regardless of `cov`.
    cov_matrix = np.full((num_concepts, num_concepts), cov, dtype=np.float64)
    np.fill_diagonal(cov_matrix, 1.0)
    latents = np.random.multivariate_normal(
        mean=np.zeros((num_concepts,)),
        cov=cov_matrix,
        size=(samples,),
    )
    x_vars = latents[:, :1]
    y_vars = latents[:, 1:2]
    z_vars = latents[:, 2:3]

    # The features are just non-linear functions applied to each variable.
    features = np.stack(
        [
            np.sin(x_vars) + x_vars,
            np.cos(x_vars) + x_vars,
            np.sin(y_vars) + y_vars,
            np.cos(y_vars) + y_vars,
            np.sin(z_vars) + z_vars,
            np.cos(z_vars) + z_vars,
            x_vars**2 + y_vars**2 + z_vars**2,
        ],
        axis=1,
    )

    # The concepts just check if the variables are positive.
    concepts = (latents > 0).astype(np.int32)

    # The label is 1 when at least two of the first three concepts are on.
    labels = (np.sum(concepts[:, :3], axis=-1) > 1).astype(np.int32)
    return features, labels, concepts

def produce_data(samples, cov=0.0, num_concepts=3):
    """Generate the 3-latent synthetic dataset used by the metric experiments.

    Three latents (x, y, z) are drawn from a zero-mean Gaussian with unit
    variances and pairwise covariance `cov`.

    Args:
        samples: number of samples to draw.
        cov: covariance shared by every pair of latent variables.
        num_concepts: how many of the three concepts (in x, y, z order) to
            return; values above 3 behave like 3.

    Returns:
        features: float array of shape (samples, 7, 1) with non-linear
            transforms of the latents.
        labels: int32 array of shape (samples, 1); 1 iff at least two of the
            three latents are positive (always uses all three latents,
            regardless of `num_concepts`).
        concepts: int32 array of shape (samples, num_concepts) with 1 where
            the corresponding latent is positive; 1-D (samples,) when a single
            concept is requested.
    """
    latents = np.random.multivariate_normal(
        mean=[0, 0, 0],
        cov=[
            [1, cov, cov],
            [cov, 1, cov],
            [cov, cov, 1],
        ],
        size=(samples,),
    )
    x_vars = latents[:, :1]
    y_vars = latents[:, 1:2]
    z_vars = latents[:, 2:]

    # The features are just non-linear functions applied to each variable.
    features = np.stack(
        [
            np.sin(x_vars) + x_vars,
            np.cos(x_vars) + x_vars,
            np.sin(y_vars) + y_vars,
            np.cos(y_vars) + y_vars,
            np.sin(z_vars) + z_vars,
            np.cos(z_vars) + z_vars,
            x_vars**2 + y_vars**2 + z_vars**2,
        ],
        axis=1,
    )

    # The concepts just check if the variables are positive.
    x_pos = (x_vars > 0).astype(np.int32)
    y_pos = (y_vars > 0).astype(np.int32)
    z_pos = (z_vars > 0).astype(np.int32)
    # Concatenating the (samples, 1) columns gives (samples, num_concepts)
    # directly; squeezing every unit axis also dropped the sample axis when
    # samples == 1.
    concepts = np.concatenate([x_pos, y_pos, z_pos][:num_concepts], axis=1)
    if concepts.shape[1] == 1:
        # Keep the historical 1-D output for a single concept.
        concepts = concepts[:, 0]

    # The label is 1 when at least two of the three latent concepts are on.
    labels = x_pos + y_pos + z_pos
    labels = (labels > 1).astype(np.int32)
    return features, labels, concepts

## Metric methods

In [None]:
import utils
from sklearn import svm

########
## NIS
########

def construct_trivial_auc_mat(num_concepts):
    """Return a float32 (num_concepts x num_concepts) matrix: 1.0 on the diagonal, 0.5 elsewhere."""
    trivial = np.full((num_concepts, num_concepts), 0.5, dtype=np.float32)
    np.fill_diagonal(trivial, 1.0)
    return trivial

##############
## SAP. Code taken from Locatello et al. https://github.com/google-research/disentanglement_lib
##############

def _compute_sap(mus, ys, mus_test, ys_test, continuous_factors):
  """Return {"SAP_score": score} for codes and factors laid out one per row.

  `mus`/`mus_test` hold one latent dimension per row and `ys`/`ys_test` one
  factor per row (train and test splits respectively).
  """
  scores = compute_score_matrix(mus, ys, mus_test, ys_test, continuous_factors)
  # Expect one row per latent dimension and one column per factor.
  assert scores.shape[0] == mus.shape[0]
  assert scores.shape[1] == ys.shape[0]
  return {"SAP_score": compute_avg_diff_top_two(scores)}


def compute_sap_on_fixed_data(observations, labels, representation_function,
                              train_percentage=0.2,
                              continuous_factors=False,
                              batch_size=100):
  """SAP score of `representation_function` on a fixed set of observations.

  Args:
    observations: inputs fed to `representation_function` through
      `utils.obtain_representation`; axis 0 indexes observations.
    labels: ground-truth factors of variation, shape
      (num_observations, num_factors); transposed internally so each row is
      one factor.
    representation_function: maps a batch of observations to codes.
    train_percentage: passed to `utils.split_train_test` to split both codes
      and factors into train / test parts.
    continuous_factors: treat factors as continuous (squared correlation)
      instead of discrete (linear-SVM accuracy).
    batch_size: batch size used when computing the representation.

  Returns:
    The SAP score.
  """
  factors = np.transpose(labels)
  codes = utils.obtain_representation(observations, representation_function,
                                      batch_size)
  num_observations = observations.shape[0]
  # Both codes and factors are expected as (dims, num_observations).
  assert factors.shape[1] == num_observations, "Wrong labels shape."
  assert codes.shape[1] == num_observations, "Wrong representation shape."
  codes_train, codes_test = utils.split_train_test(codes, train_percentage)
  factors_train, factors_test = utils.split_train_test(factors, train_percentage)
  sap = _compute_sap(codes_train, factors_train, codes_test, factors_test,
                     continuous_factors)
  return sap["SAP_score"]


def compute_score_matrix(mus, ys, mus_test, ys_test, continuous_factors):
  """Score how well each latent (row of `mus`) predicts each factor (row of `ys`).

  Continuous factors: squared Pearson correlation on the training data (0 for
  an essentially constant latent). Discrete factors: test accuracy of a 1-D
  balanced linear SVM fitted on the training split.

  Returns:
    Array of shape (num_latents, num_factors).
  """
  n_latents, n_factors = mus.shape[0], ys.shape[0]
  scores = np.zeros([n_latents, n_factors])
  for latent_idx in range(n_latents):
    latent = mus[latent_idx, :]
    for factor_idx in range(n_factors):
      factor = ys[factor_idx, :]
      if not continuous_factors:
        # Discrete factor: fit on train, score accuracy on test.
        clf = svm.LinearSVC(C=0.01, class_weight="balanced")
        clf.fit(latent[:, np.newaxis], factor)
        predictions = clf.predict(mus_test[latent_idx, :][:, np.newaxis])
        scores[latent_idx, factor_idx] = np.mean(
            predictions == ys_test[factor_idx, :])
        continue
      # Continuous factor: squared correlation from the 2x2 covariance matrix.
      joint_cov = np.cov(latent, factor, ddof=1)
      var_latent = joint_cov[0, 0]
      if var_latent <= 1e-12:
        scores[latent_idx, factor_idx] = 0.
      else:
        scores[latent_idx, factor_idx] = (
            joint_cov[0, 1]**2 * 1. / (var_latent * joint_cov[1, 1]))
  return scores


def compute_avg_diff_top_two(matrix):
  """Mean over columns of (largest - second largest) entry in each column."""
  top_two = np.sort(matrix, axis=0)[-2:, :]
  runner_up, best = top_two[0], top_two[1]
  return np.mean(best - runner_up)

In [None]:
##############
## Code taken from Ross et al. https://github.com/dtak/hierarchical-disentanglement
##############



##############
## R4
##############
import numpy as np
import scipy
from sklearn.metrics import mutual_info_score
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.svm import LinearSVC
from collections import Counter

###############################################################################
#
# R4 and R4c scores (our contribution)
#
# These metrics quantify the extent to which every dimension of a ground-truth
# representation V can be mapped individually (via an invertible function) to
# dimensions of a learned representation Z. They accomplish this by considering
# the R^2 coefficient of determination in both directions and taking geometric
# means.
#
# The conditional version (R4c) also takes into account the hierarchy, scoping
# comparisons to cases where both learned and ground-truth factors are active,
# and not penalizing minor differences in the distribution of continuous dims.
#
###############################################################################

def activity_mask(v):
    """Elementwise 1 where |v| > 1e-10 ("active"), else 0, as an int array."""
    # Heuristic activity detection; a separate mask variable could be passed
    # instead.
    is_active = np.abs(v) > 1e-10
    return is_active.astype(int)

def is_categorical(v, max_uniq=10):
    """Heuristic: `v` is categorical if it has at most `max_uniq` distinct values
    and all of them are (approximately) integers."""
    if len(np.unique(v)) > max_uniq:
        return False
    return np.allclose(v.astype(int), v)

def sample_R2_oneway(inputs, targets, reg=GradientBoostingRegressor, kls=GradientBoostingClassifier):
    """Predictive score of `targets` from 1-D `inputs` on one random train/test split.

    Uses `kls` (accuracy) when `targets` look categorical and `reg` (R^2)
    otherwise. Returns 0 for fewer than 2 inputs and 1 when either split
    contains a single distinct target value.
    """
    if len(inputs) < 2:
        # Too few points to split.
        return 0

    x_train, x_test, y_train, y_test = train_test_split(inputs.reshape(-1,1), targets)
    fewest_distinct = min(len(np.unique(y_train)), len(np.unique(y_test)))
    if fewest_distinct == 1:
        # Only one target value seen in a split.
        return 1

    if is_categorical(targets):
        model = kls()
        y_train, y_test = y_train.astype(int), y_test.astype(int)
    else:
        model = reg()
    model.fit(x_train, y_train)
    return model.score(x_test, y_test)

def R2_oneway(inputs, targets, iters=5, **kw):
    """Average of `iters` independent `sample_R2_oneway` scores (fresh random split each time)."""
    split_scores = []
    for _ in range(iters):
        split_scores.append(sample_R2_oneway(inputs, targets, **kw))
    return np.mean(split_scores)

def R2_bothways(x, y):
    """Geometric mean of the zero-clipped R^2 scores for x -> y and y -> x."""
    forward = R2_oneway(x, y)
    backward = R2_oneway(y, x)
    return np.sqrt(max(0, forward) * max(0, backward))

def R4_scores(V, Z):
    """Mean over ground-truth columns of V of the best `R2_bothways` match among Z's columns."""
    best_per_factor = [
        max([0] + [R2_bothways(V[:, i], Z[:, j]) for j in range(Z.shape[1])])
        for i in range(V.shape[1])
    ]
    return np.mean(best_per_factor)


###############################################################################
#
# Mutual Information Gap (MIG) Baseline
#
# Technically not defined for continuous targets, but we discretize with 20-bin
# histograms.
#
###############################################################################

def estimate_mutual_information(X, Y, bins=20):
  """Histogram-based estimate of I(X; Y) in bits."""
  joint_counts, _, _ = np.histogram2d(X, Y, bins)
  mi_nats = mutual_info_score(None, None, contingency=joint_counts)
  # mutual_info_score uses the natural log; convert nats to bits.
  return mi_nats / np.log(2)

def estimate_entropy(X, **kw):
  """Histogram-based entropy of X in bits, computed as I(X; X)."""
  self_information = estimate_mutual_information(X, X, **kw)
  return self_information

def MIG(Z_true, Z_learned, **kw):
  """Mutual Information Gap averaged over the columns of `Z_true`.

  For each ground-truth factor: (largest - second largest MI with any learned
  dimension) / entropy of the factor. Requires at least two learned columns.
  """
  num_factors = Z_true.shape[1]
  total_gap = 0
  for factor_idx in range(num_factors):
    factor = Z_true[:, factor_idx]
    factor_entropy = estimate_entropy(factor, **kw)
    infos = [
      estimate_mutual_information(Z_learned[:, j], factor, **kw)
      for j in range(Z_learned.shape[1])
    ]
    infos.sort(reverse=True)
    total_gap += (infos[0] - infos[1]) / (factor_entropy * num_factors)
  return total_gap


###############################################################################
#
# DCI (Disentanglement, Completeness, Informativeness) Baseline
#
# Code adapted from https://github.com/google-research/disentanglement_lib,
# original paper at https://openreview.net/forum?id=By-7dz-AZ.
#
###############################################################################

def DCI(gen_factors, latents):
  """Compute DCI disentanglement, completeness and (test) informativeness.

  Args:
    gen_factors: ground-truth generative factors, shape (num_samples, num_factors).
    latents: learned codes, shape (num_samples, num_codes).

  Returns:
    (disentanglement, completeness, informativeness_test), where
    informativeness is the mean test R^2 of the per-factor regressors.
  """
  # train_test_split returns splits in the order of its inputs. The codes
  # (`mus`) must be the regression inputs and the factors (`ys`) the targets,
  # so `latents` goes first; passing `gen_factors` first swapped the roles and
  # regressed the learned codes from the ground-truth factors.
  mus_train, mus_test, ys_train, ys_test = train_test_split(
      latents, gen_factors, test_size=0.1)
  importance_matrix, _train_informativeness, test_informativeness = (
      compute_importance_gbt(mus_train, ys_train, mus_test, ys_test))
  # importance_matrix is [num_codes, num_factors].
  assert importance_matrix.shape[0] == mus_train.shape[1]
  assert importance_matrix.shape[1] == ys_train.shape[1]
  return (
      disentanglement(importance_matrix),
      completeness(importance_matrix),
      test_informativeness,
  )

def compute_importance_gbt(x_train, y_train, x_test, y_test):
  """Fit one gradient-boosted regressor per factor (column of y) from all codes.

  Returns:
    importance_matrix: [num_codes, num_factors] absolute feature importances.
    Mean train R^2 and mean test R^2 over the factors.
  """
  num_codes, num_factors = x_train.shape[1], y_train.shape[1]
  importance_matrix = np.zeros(shape=[num_codes, num_factors], dtype=np.float64)
  train_scores, test_scores = [], []
  for factor_idx in range(num_factors):
    target_train = y_train[:, factor_idx]
    regressor = GradientBoostingRegressor()
    regressor.fit(x_train, target_train)
    importance_matrix[:, factor_idx] = np.abs(regressor.feature_importances_)
    train_scores.append(regressor.score(x_train, target_train))
    test_scores.append(regressor.score(x_test, y_test[:, factor_idx]))
  return importance_matrix, np.mean(train_scores), np.mean(test_scores)


def disentanglement_per_code(importance_matrix):
  """Per-code disentanglement: 1 - entropy (base num_factors) of each row of importances."""
  num_factors = importance_matrix.shape[1]
  # scipy.stats.entropy normalises along axis 0, so transpose to get one
  # distribution per code; the epsilon avoids log(0).
  per_code_entropy = scipy.stats.entropy(importance_matrix.T + 1e-11,
                                         base=num_factors)
  return 1. - per_code_entropy


def disentanglement(importance_matrix):
  """DCI disentanglement: per-code scores weighted by each code's share of total importance."""
  per_code = disentanglement_per_code(importance_matrix)
  weights = importance_matrix
  if weights.sum() == 0.:
    # All-zero importances: weight every code equally.
    weights = np.ones_like(weights)
  row_totals = weights.sum(axis=1)
  code_importance = row_totals / weights.sum()
  return np.sum(per_code * code_importance)

def completeness_per_factor(importance_matrix):
  """Per-factor completeness: 1 - entropy (base num_codes) of each column of importances."""
  num_codes = importance_matrix.shape[0]
  # Axis-0 normalisation yields one distribution over codes per factor; the
  # epsilon avoids log(0).
  per_factor_entropy = scipy.stats.entropy(importance_matrix + 1e-11,
                                           base=num_codes)
  return 1. - per_factor_entropy


def completeness(importance_matrix):
  """DCI completeness: per-factor scores weighted by each factor's share of total importance."""
  per_factor = completeness_per_factor(importance_matrix)
  weights = importance_matrix
  if weights.sum() == 0.:
    # All-zero importances: weight every factor equally.
    weights = np.ones_like(weights)
  column_totals = weights.sum(axis=0)
  factor_importance = column_totals / weights.sum()
  return np.sum(per_factor * factor_importance)

# Metric Comparison Experiment

In [None]:
import concepts_xai.evaluation.metrics.purity as purity
import concepts_xai.evaluation.metrics.completeness as xai_completeness
from sklearn.preprocessing import OneHotEncoder
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score

import concepts_xai.evaluation.metrics.niching as niching
from collections import defaultdict

############################################################################
## Experiment loop
############################################################################

def _bits_to_index(bits):
    """Parse a string of '0'/'1' characters as a base-2 integer; '' maps to 0."""
    # With a single concept there are no other concepts to encode, so the
    # string is empty and segment 0 (the whole interval) is the only choice;
    # int('', 2) would raise.
    return int(bits, 2) if bits else 0


def _other_concepts_bits(c_test, sample_idx, concept_idx):
    """Concatenate the 0/1 values of every concept of one sample except `concept_idx`, in index order."""
    return ''.join(
        str(int(c_test[sample_idx, j]))
        for j in range(c_test.shape[-1])
        if j != concept_idx
    )


def generate_soft_activations(
    c_test,
    mode="random", # "random", "localized", "distributed"
    off_start_interval=0.0,
    off_end_interval=0.5,
    on_start_interval=0.5,
    on_end_interval=1.0,
):
    """Turn binary concept labels into synthetic soft concept activations.

    Each entry starts from the [on_start, on_end] interval if its concept is 1,
    else [off_start, off_end]. Depending on `mode` (substring checks):
      - "localized": the interval is split into 2**(k-1) equal segments and the
        segment is picked by the binary code of the sample's *other* concepts
        (in index order), so each activation encodes the other concepts.
      - "distributed": same split, but the other concepts are visited in a
        random order; concepts that are 1 contribute a '1', others a random bit.
      - "unaligned" (exact match only): the concept's own value is ignored and
        the segment of [0, 1] is picked from the other concepts.
      - anything else (e.g. "random"): the full on/off interval is used.
    If "fixed" is in `mode` the segment midpoint is used, otherwise a uniform
    sample from it.

    NOTE(review): "unaligned" is matched by equality, so e.g. "unaligned_fixed"
    falls through to the on/off intervals; confirm whether that is intended.

    Args:
        c_test: 2-D array of 0/1 concept labels, (num_samples, num_concepts).
        mode: activation mode, see above.
        off_start_interval, off_end_interval: interval for concepts that are 0.
        on_start_interval, on_end_interval: interval for concepts that are 1.

    Returns:
        float32 array with the same shape as `c_test`.
    """
    num_concepts = c_test.shape[-1]
    num_partitions = 2**(num_concepts - 1)
    soft_activations = c_test.astype(np.float32)
    for sample_idx in range(c_test.shape[0]):
        for concept_idx in range(num_concepts):
            if c_test[sample_idx, concept_idx] == 1:
                # Then the concept is ON
                start_interval, end_interval = on_start_interval, on_end_interval
            else:
                start_interval, end_interval = off_start_interval, off_end_interval

            if "localized" in mode:
                bits = _other_concepts_bits(c_test, sample_idx, concept_idx)
            elif "distributed" in mode:
                # Distribute the knowledge of a bit being active across the
                # other bits, visiting them in a random order.
                bits = ''
                for j in np.random.permutation(num_concepts):
                    if j == concept_idx:
                        continue
                    if c_test[sample_idx, j] == 1:
                        # Active concepts always mark their bit on
                        bits += '1'
                    else:
                        # Inactive concepts get a random bit
                        bits += str(np.random.randint(2))
            elif "unaligned" == mode:
                # Ignore the concept's own value and use the whole [0, 1] range
                start_interval, end_interval = 0.0, 1.0
                bits = _other_concepts_bits(c_test, sample_idx, concept_idx)
            else:
                bits = None

            if bits is not None:
                partition_size = (end_interval - start_interval)/num_partitions
                start_interval += _bits_to_index(bits) * partition_size
                end_interval = start_interval + partition_size

            if "fixed" in mode:
                soft_activations[sample_idx, concept_idx] = (start_interval + end_interval)/2
            else:
                soft_activations[sample_idx, concept_idx] = np.random.uniform(
                    start_interval,
                    end_interval
                )
    return soft_activations
    
def _record_scalar_metric(experiment_variables, key, label, values):
    """Append `(mean, std, values)` to `experiment_variables[key]` and print a summary line."""
    mean, std = np.mean(values), np.std(values)
    experiment_variables[key].append((mean, std, values))
    print(f"\t{label} score: {mean:.4f} ± {std:.4f}")


def _record_matrix_metric(experiment_variables, key, label, matrices):
    """Stack per-trial matrices, print their elementwise mean ± std, and append
    `(mean, std, stacked)` to `experiment_variables[key]`."""
    stacked = np.stack(matrices, axis=0)
    mean, std = np.mean(stacked, axis=0), np.std(stacked, axis=0)
    print(f"\t{label} matrix:")
    for i in range(mean.shape[0]):
        print("\t\t" + "".join(
            f'{mean[i, j]:.4f} ± {std[i, j]:.4f}    '
            for j in range(mean.shape[1])
        ))
    experiment_variables[key].append((mean, std, stacked))


def metric_comp_experiment_loop(experiment_config):
    """Score synthetic soft concept activations with every metric, per soft mode.

    For each mode in `experiment_config["soft_mode"]` and each of
    `experiment_config["trials"]` trials, a fresh synthetic dataset is drawn,
    soft activations are generated from its ground-truth concepts, and MIG,
    SAP, DCI, R4, mean concept accuracy, NIS and OIS are computed. Per-mode
    summaries are appended to the returned dict, and results are serialized to
    `experiment_config["results_dir"]` after every mode.

    Args:
        experiment_config: dict with at least "results_dir", "covariance",
            "soft_mode", "trials", "test_samples" and "num_concepts". It is
            mutated: "data_concepts" defaults to "num_concepts".

    Returns:
        Dict mapping each metric name to a list with one
        `(mean, std, per-trial values)` tuple per soft mode.
    """
    # `oracle` is used for OIS below but was never imported in this notebook,
    # which raised a NameError on a fresh kernel.
    import concepts_xai.evaluation.metrics.oracle as oracle

    experiment_variables = dict(
        sap_scores=[],
        mig_scores=[],
        dci_disentanglement_scores=[],
        dci_completeness_scores=[],
        dci_informativeness_scores=[],
        r4_scores=[],

        mean_concept_acc_scores=[],

        ois_scores=[],
        nis_scores=[],
        purity_matrices=[],
        oracle_matrices=[],
    )
    experiment_config["data_concepts"] = experiment_config.get(
        "data_concepts",
        experiment_config["num_concepts"],
    )

    results_dir = experiment_config["results_dir"]
    Path(results_dir).mkdir(parents=True, exist_ok=True)
    # Save the config up front so it is on disk even if a later trial fails.
    serialize_experiment_config(experiment_config, results_dir)

    cov = experiment_config["covariance"]
    num_trials = experiment_config["trials"]
    for mode in experiment_config["soft_mode"]:
        saps = []
        migs = []
        dci_disentanglements = []
        dci_completenesss = []
        dci_informativenesss = []
        r4s = []
        mean_concept_accs = []
        purity_mats = []
        oracle_mats = []
        oiss = []
        niss = []
        # The oracle matrix returned by the first trial is passed back in on
        # later trials of the same mode.
        oracle_mat = None
        for trial in range(num_trials):
            print(f"\tTrial {trial + 1}/{num_trials} with covariance {cov} and mode {mode}")
            # First construct the dataset (only the ground-truth concepts are used).
            data_generator = (
                produce_data_larger
                if experiment_config["data_concepts"] > 3
                else produce_data
            )
            _, _, c_test = data_generator(
                experiment_config["test_samples"],
                cov=cov,
                num_concepts=experiment_config["data_concepts"],
            )

            print("\tComputing soft activations...")
            soft_acts = generate_soft_activations(
                c_test,
                mode=mode,
                on_start_interval=experiment_config.get('on_start_interval', 0.5),
                on_end_interval=experiment_config.get('on_end_interval', 1.0),
                off_start_interval=experiment_config.get('off_start_interval', 0.0),
                off_end_interval=experiment_config.get('off_end_interval', 0.5),
            )

            print("\t\tComputing MIG...")
            mig = MIG(Z_true=c_test, Z_learned=soft_acts, bins=experiment_config.get('num_bins', 10))
            migs.append(mig)
            print(f"\t\t\tDone MIG = {mig:.4f}")

            print("\t\tComputing SAP...")
            sap = compute_sap_on_fixed_data(
                observations=soft_acts,
                labels=c_test,
                representation_function=lambda x: x,
                batch_size=experiment_config.get('batch_size', 64),
            )
            saps.append(sap)
            print(f"\t\t\tDone SAP = {sap:.4f}")

            print("\t\tComputing DCI...")
            dci_disentanglement, dci_completeness, dci_informativeness = DCI(
                gen_factors=c_test,
                latents=soft_acts,
            )
            dci_disentanglements.append(dci_disentanglement)
            print(f"\t\t\tDone DCI disentanglement = {dci_disentanglement:.4f}")
            dci_completenesss.append(dci_completeness)
            print(f"\t\t\tDone DCI completeness = {dci_completeness:.4f}")
            dci_informativenesss.append(dci_informativeness)
            print(f"\t\t\tDone DCI informativeness = {dci_informativeness:.4f}")

            print("\t\tComputing R4...")
            r4 = R4_scores(V=c_test, Z=soft_acts)
            r4s.append(r4)
            print(f"\t\t\tDone R4 = {r4:.4f}")

            print("\t\tComputing mean concept accuracy...")
            # Fraction of individual (sample, concept) entries predicted
            # correctly. sklearn's accuracy_score on 2-D multilabel input
            # instead returns subset accuracy (a sample only counts if *all*
            # of its concepts match).
            mean_concept_acc = np.mean(
                (soft_acts > 0.5).astype(np.int32) == c_test
            )
            mean_concept_accs.append(mean_concept_acc)
            print(f"\t\t\tDone mean concept acc = {mean_concept_acc:.4f}")

            print("\t\tComputing NIS...")
            nis = niching.niche_impurity_score(
                c_soft=soft_acts,
                c_true=c_test,
                delta_beta=0.05,
            )
            niss.append(nis)
            print(f"\t\t\tDone NIS = {nis:.4f}")

            print("\t\tComputing OIS...")
            ois, purity_mat, oracle_mat = oracle.oracle_impurity_score(
                c_soft=soft_acts,
                c_true=c_test,
                output_matrices=True,
                oracle_matrix=oracle_mat,
            )
            purity_mats.append(purity_mat)
            oracle_mats.append(oracle_mat)
            oiss.append(ois)
            print(f"\t\t\tDone OIS = {ois:.4f}")

            print("\t\tDone with trial", trial + 1)

        _record_matrix_metric(experiment_variables, "purity_matrices", "Purity", purity_mats)
        _record_matrix_metric(experiment_variables, "oracle_matrices", "Oracle", oracle_mats)
        for key, label, values in (
            ("sap_scores", "SAP", saps),
            ("mig_scores", "MIG", migs),
            ("dci_disentanglement_scores", "DCI disentanglement", dci_disentanglements),
            ("dci_completeness_scores", "DCI completeness", dci_completenesss),
            ("dci_informativeness_scores", "DCI informativeness", dci_informativenesss),
            ("r4_scores", "R4", r4s),
            ("mean_concept_acc_scores", "Mean concept accuracy", mean_concept_accs),
            ("ois_scores", "OIS", oiss),
            ("nis_scores", "NIS", niss),
        ):
            _record_scalar_metric(experiment_variables, key, label, values)

        # Serialize after every mode so partial results survive a crash.
        serialize_results(
            results_dict=experiment_variables,
            results_dir=results_dir,
        )
    return experiment_variables

## Experiment runs

In [None]:
############################################################################
## Experiment config
############################################################################

# Covariance between the synthetic latent variables; also names the output folder.
COVARIANCE = 0.25

experiment_config = dict(
    trials=5,
    batch_size=128,
    num_outputs=1,
    # Derived from COVARIANCE so the folder name cannot drift from the setting.
    results_dir=os.path.join(
        RESULTS_DIR,
        f"covariance_{COVARIANCE}",
    ),
    input_shape=[7],
    num_bins=20,
    test_samples=3000,
    covariance=COVARIANCE,
    soft_mode=["localized", "random"],
    on_start_interval=0.95,
    on_end_interval=1.0,
    off_start_interval=0.0,
    off_end_interval=0.05,
    verbosity=0,
    num_concepts=5,
    data_concepts=5,
)

# Generate the experiment directory if it does not exist already
Path(experiment_config["results_dir"]).mkdir(parents=True, exist_ok=True)
figure_dir = os.path.join(experiment_config["results_dir"], "figures")
Path(figure_dir).mkdir(parents=True, exist_ok=True)

############################################################################
## Experiment run
############################################################################

results = metric_comp_experiment_loop(
    experiment_config,
)
# Print (mean, std) per soft mode for each metric.
for metric_key in (
    "ois_scores",
    "nis_scores",
    "sap_scores",
    "mig_scores",
    "r4_scores",
    "dci_disentanglement_scores",
    "dci_completeness_scores",
    "dci_informativeness_scores",
):
    print(f"{metric_key}:", [entry[:2] for entry in results[metric_key]])


### Results Table

In [None]:
from prettytable import PrettyTable
import scipy

metrics = [
    ("ois_scores", "OIS (ours)"),
    ("nis_scores", "NIS (ours)"),
    ("sap_scores", "SAP"),
    ("mig_scores", "MIG"),
    ("r4_scores", "R4"),
    ("dci_disentanglement_scores", "DCI Disentanglement"),
    ("dci_completeness_scores", "DCI Completeness"),
    ("dci_informativeness_scores", "DCI Informativeness"),
]


def _format_score(entry, p_value=None):
    """Render a (mean, std, ...) result entry as percentages, optionally with a p-value."""
    text = f'{entry[0] * 100:.2f}% ± {entry[1] * 100:.2f}%'
    if p_value is not None:
        text += f' (p = {p_value:.2e})'
    return text


# The last soft mode is reported as "Baseline" and the first as "Impure"
# (with soft_mode=["localized", "random"]: random vs. localized).
baseline_idx, impure_idx = -1, 0

tab = PrettyTable()
tab.field_names = [""] + [display_name for _, display_name in metrics]
tab.add_row(["Baseline"] + [_format_score(results[key][baseline_idx]) for key, _ in metrics])

# Welch's two-sided t-test between the per-trial scores of the two modes.
p_vals = {
    key: scipy.stats.ttest_ind(
        np.array(results[key][baseline_idx][2]) * 100,
        np.array(results[key][impure_idx][2]) * 100,
        equal_var=False,
        alternative='two-sided',
    )[1]
    for key, _ in metrics
}
tab.add_row(["Impure"] + [_format_score(results[key][impure_idx], p_vals[key]) for key, _ in metrics])
print(tab)

