# Synthetic Dataset Figures

This colab is intended to provide code for reproducing all synthetic-dataset related figures presented in the main paper.

Note: This colab was built to support the data config used in the main paper, so it makes a few assumptions:
* The dataset has been saved as a .pkl file (this is how the provided dataset-generation script saves datasets).
* There are two labels.
* There are two concepts, named 'SINE0' and 'SINE1' (the names themselves are not critical — only certain color mappings depend on them).

If your dataset violates any of these assumptions, you will have to update the code to fit your use case.

In [1]:
#@title Imports and helper functions

import os
import collections
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
from matplotlib import gridspec
import sklearn.metrics

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

import dataset_utils
import model_utils

%matplotlib inline
%config InlineBackend.figure_format='retina'

# Font sizes shared by every figure in this notebook.
SMALL_SIZE, MEDIUM_SIZE, LARGE_SIZE = 14, 18, 22

# (rc group, rc parameter, size) applied to matplotlib's global rcParams.
for _rc_group, _rc_param, _rc_size in (
    ('font', 'size', SMALL_SIZE),          # default text size
    ('axes', 'titlesize', LARGE_SIZE),     # axes titles
    ('axes', 'labelsize', MEDIUM_SIZE),    # x and y axis labels
    ('xtick', 'labelsize', SMALL_SIZE),    # x tick labels
    ('ytick', 'labelsize', SMALL_SIZE),    # y tick labels
    ('legend', 'fontsize', MEDIUM_SIZE),   # legend text
    ('figure', 'titlesize', LARGE_SIZE),   # figure suptitle
):
  plt.rc(_rc_group, **{_rc_param: _rc_size})

# Helper functions

def set_axis_style_bar(ax, labels=None, spineon=False):
  """Apply the shared bar-chart styling to `ax` in place.

  Args:
    ax: Matplotlib axes to style.
    labels: Optional sequence of category labels. When non-empty, one tick is
      placed at each integer position 0..len(labels)-1, labelled in order, and
      the x-limits are padded by half a unit on each side.
    spineon: When False (the default), the top and right spines are hidden.
  """
  for axis in (ax.get_xaxis(), ax.get_yaxis()):
    axis.set_tick_params(direction='out')
  if labels:
    num_labels = len(labels)
    ax.set_xticks(np.arange(num_labels))
    ax.set_xticklabels(labels, fontsize="large")
    ax.set_xlim(-0.5, num_labels - 0.5)
  if not spineon:
    for side in ('top', 'right'):
      ax.spines[side].set_visible(False)
  ax.yaxis.grid(True)
  ax.xaxis.set_ticks_position('bottom')
  ax.yaxis.set_ticks_position('left')

def set_axis_style(ax, labels=None):
  """Apply the shared scatter/line-plot styling to `ax` in place.

  Besides outward ticks and a horizontal grid, this always draws a grey
  reference line at y=0 spanning x=0..100 and hides the top/right spines.

  Args:
    ax: Matplotlib axes to style.
    labels: Optional sequence of category labels; when non-empty, ticks are
      placed at 0..len(labels)-1 with these labels and the x-limits are padded
      by half a unit on each side.
  """
  for axis in (ax.get_xaxis(), ax.get_yaxis()):
    axis.set_tick_params(direction='out')
  if labels:
    num_labels = len(labels)
    ax.set_xticks(np.arange(num_labels))
    ax.set_xticklabels(labels, fontsize="large")
    ax.set_xlim(-0.5, num_labels - 0.5)
  ax.yaxis.grid(True)
  zero_line_rgba = [0.5, 0.5, 0.5, 0.8]
  ax.hlines(0, 0, 100, linewidth=2, color=zero_line_rgba)
  for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
  ax.xaxis.set_ticks_position('bottom')
  ax.yaxis.set_ticks_position('left')

def shift_array(xs, n):
  """Shift a 1D sequence by `n` positions, filling vacated slots with NaN.

  A positive `n` moves values towards higher indices (NaNs are inserted at the
  front). A negative `n` moves them towards lower indices (NaNs are appended at
  the end). The result always has the same length as `xs`; when
  |n| >= len(xs), every element is NaN.

  Args:
    xs: 1D array-like of numbers.
    n: Integer number of positions to shift by.

  Returns:
    A 1D numpy array with len(xs) elements.
  """
  xs = np.asarray(xs)
  length = xs.shape[0]
  if n == 0:
    return xs
  # Clamp the shift so an oversized shift yields an all-NaN array of the
  # original length rather than a longer one (callers reshape the result to
  # the original time length).
  n = max(-length, min(length, n))
  if n > 0:
    return np.concatenate((np.full(n, np.nan), xs[:length - n]))
  return np.concatenate((xs[-n:], np.full(-n, np.nan)))

def plot_with_concept_mask_over_time(array, concept_mask, ax, color_pair, concept_name):
  """Plot per-timestep mean +/- std of `array`, split by concept presence.

  For "absent" (concept_mask == False) and then "present" (concept_mask ==
  True), the matching examples are selected along axis 0 and averaged over
  that axis, ignoring NaNs. The resulting [time, num_bootstraps] array is then
  summarised across bootstraps as a mean line with a +/- one std band.

  Args:
    array: Array of values of shape [batch_size, time, num_bootstraps].
    concept_mask: Boolean array of shape [batch_size,] marking examples in
      which the concept is present.
    ax: Axes to draw on.
    color_pair: (absent_color, present_color).
    concept_name: Concept name used in the legend labels.
  """
  # Per-presence drawing style: (legend word, color, line alpha, linestyle, band alpha).
  styles = {
      False: ("absent", color_pair[0], 0.6, "--", 0.2),
      True: ("present", color_pair[1], 1, "-", 0.5),
  }
  for presence in (False, True):
    word, color, line_alpha, line_style, band_alpha = styles[presence]
    subset = array[concept_mask == presence, :]
    per_step_bootstrap_means = np.nanmean(subset, axis=0)
    center = np.nanmean(per_step_bootstrap_means, axis=1)
    spread = np.nanstd(per_step_bootstrap_means, axis=1)
    num_steps = subset.shape[1]
    time_axis = list(range(num_steps))
    ax.plot(
        time_axis,
        center,
        label=f"{concept_name} {word}",
        color=color,
        alpha=line_alpha,
        linestyle=line_style,
        linewidth=2)
    ax.fill_between(
        time_axis,
        center - spread,
        center + spread,
        color=color,
        alpha=band_alpha)
    ax.set_xlim(0, num_steps)

In [2]:
#@title Load data

dataset_filename = "Enter .pkl dataset filename here" #@param {type: 'string'}
dataset = dataset_utils.load_pickled_data(dataset_filename)

# Concept-indexed arrays follow the order of config["concept_specs"]; keep that
# order so concept names can be mapped to array indices when plotting.
dataset_concept_names = [spec["name"] for spec in dataset["config"]["concept_specs"]]

#@markdown Provide the synthetic_n_eval command line argument used for run_cs_tca.py below. We will collect the first n values of the loaded test dataset where n=synthetic_n_eval. 
#@markdown This is the dataset that was used to compute cs and tca metrics, and we will need these arrays for various figures.
synthetic_n_eval = 100 #@param {type: 'integer'}

_test_split = dataset["test_split"]
# These arrays hold examples on axis 1 (they appear to be [time, example, ...]);
# keep the first synthetic_n_eval examples.
eval_dataset_sequence, eval_dataset_label, eval_dataset_concept_sequences = (
    _test_split[key][:, :synthetic_n_eval, :]
    for key in ("sequence", "label", "concept_sequence"))
# These arrays hold examples on axis 0 ([example, concept]); keep the first
# synthetic_n_eval rows.
eval_dataset_concept_changes, eval_dataset_concepts = (
    _test_split[key][:synthetic_n_eval, :]
    for key in ("changes", "concept"))

In [None]:
#@title Comparing Bootstrapped CAV metrics across strategies
#@markdown There are various CAV building strategies discussed in the main paper, namely 't1_only', 't0_to_t1', and 't0_to_t1_diff'.
#@markdown In this section, we plot the accuracy of the classifiers trained to discriminate between concept and non-concept examples for each strategy.
#@markdown Edit the cell and fill in directories to the cav output files (these should be the output_dir arguments of the create_cavs.py job)
#@markdown for each of the strategies.

#@markdown Note: Assumes metrics collected are same for all strategies. If this is not the case, the code will need to be updated.

strategy_and_cav_dir_pairs = [
  ("t1 only", "cav directory here"),
  ("t0 to t1", "cav directory here"),
  ("t0 to t1 diff", "cav directory here")
]
# strategy name -> cav_metrics.pkl contents, indexed below as
# metrics[concept_name][layer][metric_name].
metrics_by_strategy = {
    strat: dataset_utils.load_pickled_data(os.path.join(cav_dir, "cav_metrics.pkl"))
    for strat, cav_dir in strategy_and_cav_dir_pairs}

# get all relevant strategy, metric, concept, and layers from metrics objects.
strategy_names = set()
metric_names = set()
concept_names = set()
layers = set()

for strat, metrics in metrics_by_strategy.items():
  strategy_names.add(strat)
  for concept_name in metrics:
    concept_names.add(concept_name)
    for layer in metrics[concept_name]:
      layers.add(layer)
      for metric_name in metrics[concept_name][layer]:
        metric_names.add(metric_name)

# Sort instead of list(): string-set iteration order changes between
# interpreter runs, and `layers` must be ascending to match the
# sort_values("layer") row order used for the bars below.
strategy_names = sorted(strategy_names)
metric_names = sorted(metric_names)
concept_names = sorted(concept_names)
layers = sorted(layers)

print("All strategies:", strategy_names)
print("All metrics:", metric_names)
print("All concepts:", concept_names)
print("All layers:", layers)

# Create dataframe with reported CAV metrics
metric_results_rows = []
for strat, metrics in metrics_by_strategy.items():
  for concept_name in metrics:
    for layer in metrics[concept_name]:
      metric_results_row = {
          "strategy": strat,
          "concept": concept_name,
          "layer": layer
      }
      for metric_name in metrics[concept_name][layer]:
        metric_result = metrics[concept_name][layer][metric_name]
        pvalue, baseline_score, baseline_score_std, permuted_scores = metric_result
        # Pessimistic p-value: count permuted scores at or above one std below
        # the baseline mean.
        worst_num_extremes = np.sum(np.array(permuted_scores) >= (baseline_score - baseline_score_std))
        worst_pvalue = (worst_num_extremes + 1) / float(len(permuted_scores) + 1)
        metric_results_row.update({
            f"{metric_name}_avg": baseline_score,
            f"{metric_name}_std": baseline_score_std,
            f"{metric_name}_pvalue": pvalue,
            f"{metric_name}_pvalue_worst": worst_pvalue
        })
      metric_results_rows.append(metric_results_row)
metrics_results_df = pd.DataFrame(metric_results_rows)

# Plot held out bootstrap CAV classifier performance
plt.style.use("classic")
name = "Blues"
cmap = get_cmap(name)
idx_col = np.array(range(len(strategy_and_cav_dir_pairs))) / len(strategy_and_cav_dir_pairs)
colors = [cmap(ic) for ic in idx_col]
alphas = [0.9, 0.9]

bar_width = 1 / (len(strategy_and_cav_dir_pairs)+1)
plt.rcParams.update({'font.size': 14})
# Plot strategies in the order they were listed above.
strategy_names = list(metrics_by_strategy.keys())
# Bar `strat_idx` in each layer group is centred at layer_idx + bar_width * strat_idx,
# so the middle of the group is layer_idx + bar_width * (num_strategies - 1) / 2.
bar_group_centers = np.arange(len(layers)) + bar_width * (len(strategy_names) - 1) / 2

fig, ax = plt.subplots(1, len(concept_names), figsize=(13, 4), dpi=300, sharey=True)
for strat_idx, strat in enumerate(strategy_names):
  for concept_idx, concept_name in enumerate(concept_names):
    results_df = (metrics_results_df[(metrics_results_df["strategy"] == strat) & (metrics_results_df["concept"] == concept_name)]
                  .sort_values("layer"))
    ax[concept_idx].bar(
        np.array(range(len(layers))) + bar_width * strat_idx,
        results_df["balanced_accuracy_avg"]*100,
        width=bar_width,
        color=colors[strat_idx],
        alpha=alphas[1])
    ax[concept_idx].errorbar(
        np.array(range(len(layers))) + bar_width * strat_idx,
        results_df["balanced_accuracy_avg"].to_numpy()*100,
        yerr=results_df["balanced_accuracy_std"].to_numpy()*100,
        fmt="none",
        label="_nolegend_",
        ecolor=[0, 0, 0],
        elinewidth=2)
    ax[concept_idx].set_ylim([0, 105])
    ax[concept_idx].set_xticks(bar_group_centers)
    ax[concept_idx].set_xticklabels([f"layer {layer}" for layer in layers],
                                    fontsize=18, fontweight='bold')
    set_axis_style_bar(ax[concept_idx])
    ax[concept_idx].set_title(f"Concept: {concept_name}", fontsize="large", wrap=True, fontweight='bold')
ax[0].set_ylabel("Bootstrapped accuracy [%]", fontsize="large", wrap=True, fontweight='bold')

legend = ax[1].legend(strategy_names, loc = "upper right", bbox_to_anchor=(2, 1), fontsize=18, frameon=False, title="CAV strategy")
legend.get_title().set_fontsize("20")
legend.get_title().set_fontweight("bold")

plt.tight_layout()
plt.show()

In [None]:
#@title Comparing CAV metrics on unseen test data across strategies
#@markdown Here we recompute accuracy and rocauc for each strategy on an unseen cohort of test data elements. We use a different slice of the test
#@markdown dataset than is used for final cs and tca evaluation. This code uses the same directories provided in the cell above. Provide
#@markdown a path to a trained model checkpoint and the number of unseen data to compute metrics for below.

#@markdown Note: Assumes metrics collected are same for all strategies. If this is not the case, the code will need to be updated.


# Compute performance on unseen test data
model_checkpoint_path = "Enter model checkpoint path here. It should end with /tfhub" #@param {type: 'string'}
num_unseen_data_to_test = 100 #@param {type: 'integer'}

# Test-split examples after the first `synthetic_n_eval` (those are used for
# cs/tca evaluation), so this cohort is disjoint from that evaluation set.
unseen_slice = slice(synthetic_n_eval, synthetic_n_eval + num_unseen_data_to_test)
unseen_test_sequence = dataset["test_split"]["sequence"][:, unseen_slice, :]
unseen_test_label = dataset["test_split"]["label"][:, unseen_slice, :]
unseen_test_concept_sequences = dataset["test_split"]["concept_sequence"][:, unseen_slice, :]
unseen_test_output = model_utils.unroll_model_per_step(
    model_checkpoint_path,
    unseen_test_sequence,
    unseen_test_label)

models_by_strategy = {
    strat: dataset_utils.load_pickled_data(os.path.join(cav_dir, "cav_classifiers.pkl"))
    for strat, cav_dir in strategy_and_cav_dir_pairs}

unseen_metrics_results_rows = []
# Flatten all leading axes so each row is one (time step, example) cell.
# Assumes the per-layer activations share the same leading layout so rows stay
# aligned after reshaping (TODO confirm against model_utils).
ys_all = np.reshape(unseen_test_concept_sequences, [-1, unseen_test_concept_sequences.shape[-1]])
for strat, models in models_by_strategy.items():
  for dataset_concept_idx, concept_name in enumerate(dataset_concept_names):
    # 1-D concept labels, the shape sklearn's metric functions expect.
    ys = ys_all[:, dataset_concept_idx]
    for layer, acts in enumerate(unseen_test_output["activations"]):
      xs = np.reshape(acts, [-1, acts.shape[-1]])
      balanced_accuracies = []
      rocaucs = []
      # Several trained classifiers per (concept, layer), presumably one per
      # bootstrap; report the mean/std of their scores.
      for trained_model in models[concept_name][layer]:
        preds = trained_model.predict(xs)
        balanced_accuracies.append(sklearn.metrics.balanced_accuracy_score(ys, preds))
        try:
          probs = trained_model.predict_proba(xs)[:, 1]
        except AttributeError:
          # Classifiers without predict_proba: rank by decision_function scores.
          probs = trained_model.decision_function(xs)
        rocaucs.append(sklearn.metrics.roc_auc_score(ys, probs))
      unseen_metrics_results_rows.append({
          "strategy": strat,
          "concept": concept_name,
          "layer": layer,
          "balanced_accuracy_avg": np.mean(balanced_accuracies),
          "balanced_accuracy_std": np.std(balanced_accuracies),
          "rocauc_avg": np.mean(rocaucs),
          "rocauc_std": np.std(rocaucs)
      })
unseen_metrics_results_df = pd.DataFrame(unseen_metrics_results_rows)

# Plot CAV classifier performance on the unseen test data
plt.style.use("classic")
name = "Blues"
cmap = get_cmap(name)
idx_col = np.array(range(len(strategy_and_cav_dir_pairs))) / len(strategy_and_cav_dir_pairs)
colors = [cmap(ic) for ic in idx_col]
alphas = [0.9, 0.9]

bar_width = 1 / (len(strategy_and_cav_dir_pairs)+1)
plt.rcParams.update({'font.size': 14})
strategy_names = list(metrics_by_strategy.keys())
# Bar `strat_idx` in each layer group is centred at layer_idx + bar_width * strat_idx,
# so the middle of the group is layer_idx + bar_width * (num_strategies - 1) / 2.
bar_group_centers = np.arange(len(layers)) + bar_width * (len(strategy_names) - 1) / 2

fig, ax = plt.subplots(1, len(concept_names), figsize=(13, 4), dpi=300, sharey=True)
for strat_idx, strat in enumerate(strategy_names):
  for concept_idx, concept_name in enumerate(concept_names):
    results_df = (unseen_metrics_results_df[(unseen_metrics_results_df["strategy"] == strat) & (unseen_metrics_results_df["concept"] == concept_name)]
                  .sort_values("layer"))
    ax[concept_idx].bar(
        np.array(range(len(layers))) + bar_width * strat_idx,
        results_df["balanced_accuracy_avg"]*100,
        width=bar_width,
        color=colors[strat_idx],
        alpha=alphas[1])
    ax[concept_idx].errorbar(
        np.array(range(len(layers))) + bar_width * strat_idx,
        results_df["balanced_accuracy_avg"].to_numpy()*100,
        yerr=results_df["balanced_accuracy_std"].to_numpy()*100,
        fmt="none",
        label="_nolegend_",
        ecolor=[0, 0, 0],
        elinewidth=2)
    ax[concept_idx].set_ylim([0, 105])
    ax[concept_idx].set_xticks(bar_group_centers)
    ax[concept_idx].set_xticklabels([f"layer {layer}" for layer in layers],
                                    fontsize=18, fontweight='bold')
    set_axis_style_bar(ax[concept_idx])
    ax[concept_idx].set_title(f"Concept: {concept_name}", fontsize="large", wrap=True, fontweight='bold')
ax[0].set_ylabel("Bootstrapped accuracy [%]", fontsize="large", wrap=True, fontweight='bold')

legend = ax[1].legend(strategy_names, loc = "upper right", bbox_to_anchor=(2, 1), fontsize=18, frameon=False, title="CAV strategy")
legend.get_title().set_fontsize("20")
legend.get_title().set_fontweight("bold")

plt.tight_layout()
plt.show()

In [None]:
#@title Load CS and tCA results

def load_cs_tca_results(cs_tca_directory):
  """Load per-(target, concept, layer) CS and tCA results written by run_cs_tca.py.

  Expected layout:
    <cs_tca_directory>/target=<int>/concept=<name>/layer=<int>/cs_results.pkl
    <cs_tca_directory>/target=<int>/concept=<name>/layer=<int>/tca_results.pkl
  At every level, entries that are not directories named with the expected
  "<key>=" prefix (e.g. ".ipynb_checkpoints" or stray files) are skipped.

  Args:
    cs_tca_directory: Root output directory of run_cs_tca.py.

  Returns:
    (cs_results, tca_results): defaultdicts indexed as
    results[target][concept_name][layer] -> unpickled results object.
  """
  cs_results = collections.defaultdict(dict)
  tca_results = collections.defaultdict(dict)

  def _prefixed_subdirs(parent, prefix):
    """Yield (value, path) for child directories named '<prefix><value>', sorted by name."""
    for entry in sorted(os.listdir(parent)):
      path = os.path.join(parent, entry)
      if entry.startswith(prefix) and os.path.isdir(path):
        yield entry[len(prefix):], path

  for target_str, target_dir in _prefixed_subdirs(cs_tca_directory, "target="):
    target = int(target_str)
    for concept_name, concept_dir in _prefixed_subdirs(target_dir, "concept="):
      cs_results[target][concept_name], tca_results[target][concept_name] = {}, {}
      for layer_str, layer_dir in _prefixed_subdirs(concept_dir, "layer="):
        layer = int(layer_str)
        cs_results[target][concept_name][layer] = dataset_utils.load_pickled_data(os.path.join(layer_dir, "cs_results.pkl"))
        tca_results[target][concept_name][layer] = dataset_utils.load_pickled_data(os.path.join(layer_dir, "tca_results.pkl"))
  return cs_results, tca_results

cs_tca_directory = "Enter cs tca directory here. It should be the same as the output_dir command line argument used for run_cs_tca.py" #@param {type: 'string'}
cs_results, tca_results = load_cs_tca_results(cs_tca_directory)

# Collect every target, concept and layer that has CS results. Sort them so
# subplot order is deterministic (string-set iteration order varies between
# interpreter runs) and integer targets/layers are ascending. Later cells
# index gridspec columns by target value and label rows by layer.
targets = sorted(cs_results)
concept_names = sorted({
    concept_name
    for target in cs_results
    for concept_name in cs_results[target]})
layers = sorted({
    layer
    for target in cs_results
    for concept_name in cs_results[target]
    for layer in cs_results[target][concept_name]})

print("All targets:", targets)
print("All concepts:", concept_names)
print("All layers:", layers)

In [None]:
#@title Global CS scores

# Shared colour palette; `cols` is also used by later cells.
name = "tab20c"
cmap = get_cmap(name)
cols = cmap.colors
mycolors = dict(SINE0=cols[13], SINE1=cols[9])

fig = plt.figure(figsize=(12, 8), constrained_layout=False, dpi=300)
outer = fig.add_gridspec(1, 2, wspace=0.4, hspace=0.5)
alpha = [0.4, 0.9]  # bar opacity indexed by concept presence (0 = absent, 1 = present)
ax = None  # each new subplot shares its y-axis with the previously created one
for target in targets:
  subplts = outer[target].subgridspec(len(layers), 1)
  plt.rcParams.update({'font.size': 14})
  barWidth = 0.3
  last_row = len(layers) - 1
  for ax_count, layer in enumerate(layers):
    ax = fig.add_subplot(subplts[ax_count], sharey=ax)
    xvals_bars = range(1, len(concept_names) + 1)
    for concept_idx, concept_name in enumerate(concept_names):
      concept_presence = eval_dataset_concept_sequences[:, :, dataset_concept_names.index(concept_name)]
      bootstrap_cs = cs_results[target][concept_name][layer]["bootstrap_CS"]
      for presence in (0, 1):
        # Average over every cell where the concept has this presence value,
        # leaving one value per bootstrap; the bar shows mean +/- std of those.
        per_bootstrap = np.nanmean(bootstrap_cs[concept_presence == presence], axis=0)
        bar_x = xvals_bars[concept_idx] + barWidth * presence
        bar_height = np.nanmean(per_bootstrap)
        ax.bar(bar_x, bar_height, width=barWidth, color=mycolors[concept_name],
               alpha=alpha[presence],
               label=concept_name + (" present" if presence else " absent"))
        ax.errorbar(bar_x, bar_height, yerr=np.nanstd(per_bootstrap),
                    fmt='none', ecolor=[0, 0, 0], elinewidth=2)
    if ax_count == 0:
      ax.set_title(f"target y_{target}", fontsize=24, fontweight="bold")
    ax.set_ylabel("Layer " + str(layer), fontsize=24, fontweight='bold')
    ax.set_xticks(np.array(xvals_bars) + barWidth / 2)
    # Only the bottom row of each column shows the concept names.
    ax.set_xticklabels(concept_names if ax_count == last_row else "",
                       fontsize=18, fontweight='bold')

legend = ax.legend(loc="upper left", bbox_to_anchor=(1.05, 1.1), title="Score", fontsize=18, frameon=True)
legend.get_title().set_fontsize("20")
legend.get_title().set_fontweight("bold")
plt.show()

In [None]:
#@title Global CS scores as time series

desired_changepoint = 50
t0, t1 = 0, 100
layer = 2


def _align_to_changepoint(score_array, changepoints, aligned_changepoint):
  """Shift each example's score series so its change point lands at `aligned_changepoint`.

  Args:
    score_array: Scores of shape [time, batch_size, num_bootstraps].
    changepoints: Per-example change-point time indices (length batch_size).
    aligned_changepoint: Time index that every change point is moved to.

  Returns:
    Array of shape [batch_size, time, num_bootstraps]. Example i is shifted by
    `aligned_changepoint - changepoints[i]` along time, padded with NaN.
  """
  num_bootstraps = score_array.shape[-1]
  aligned_examples = []
  for i, cp in enumerate(changepoints):
    shift = aligned_changepoint - cp
    aligned_examples.append(np.stack(
        [shift_array(score_array[:, i, b], shift) for b in range(num_bootstraps)],
        axis=-1))
  return np.stack(aligned_examples, axis=0)


# target -> list (ordered like concept_names) of [batch, t1 - t0, bootstraps]
# arrays, aligned on each concept's change point.
CS_target_arrays = {}
for target in targets:
  CS_target_arrays[target] = [
      _align_to_changepoint(
          cs_results[target][concept_name][layer]["bootstrap_CS"],
          eval_dataset_concept_changes[:, dataset_concept_names.index(concept_name)],
          desired_changepoint)[:, t0:t1]
      for concept_name in concept_names]

concept_colors = {
    "SINE0": (cols[13], cols[13]),
    "SINE1": (cols[9], cols[9])
}

fig = plt.figure(figsize=(12, 4), dpi=300)
gs = gridspec.GridSpec(nrows=1, ncols=2, figure=fig)
gs.update(wspace=0.4, hspace=0.5)
ax1 = None
for target in targets:
  # Share y with the first panel drawn, whichever target that is, so this
  # works regardless of the order or values in `targets`.
  ax = fig.add_subplot(gs[0, target], sharey=ax1)
  if ax1 is None:
    ax1 = ax
  set_axis_style(ax)
  ax.set_xlabel("Time", fontsize="x-large", wrap=True, fontweight='bold')
  ax.set_ylabel("CS" if not target else None, fontsize="x-large", wrap=True, fontweight='bold')
  for concept_idx, concept_name in enumerate(concept_names):
    plot_with_concept_mask_over_time(
        array=CS_target_arrays[target][concept_idx],
        concept_mask=eval_dataset_concepts[:, dataset_concept_names.index(concept_name)],
        ax=ax,
        color_pair=concept_colors[concept_name],
        concept_name=concept_name)
  ax.set_title(f"Target y_{target}", fontweight='bold')
legend = ax.legend(loc="upper right", title="Score", bbox_to_anchor=(1.8,1),frameon=True)
legend.get_title().set_fontsize("20")
legend.get_title().set_fontweight("bold")
plt.show()

In [None]:
#@title Compute model performance on eval dataset for error type analysis
#@markdown In order to analyze concept-based scores for each error type (true positives, false positives, etc.),
#@markdown we need to compute model predictions for all eval dataset examples.

model_checkpoint_path = "Enter model checkpoint path here. It should end with /tfhub"  #@param {type: 'string'}

eval_dataset_outputs = model_utils.unroll_model(
    model_checkpoint_path,
    eval_dataset_sequence,
    eval_dataset_label)
eval_dataset_predictions = eval_dataset_outputs["predictions"]
eval_dataset_probs = eval_dataset_outputs["probs"]

# Pool every (time, example, label) entry into one binary problem. Assumes
# labels, predictions and probs share one shape (they are indexed identically
# in the plotting cells) — TODO confirm against model_utils.
flat_labels = eval_dataset_label.reshape(-1)
flat_predictions = eval_dataset_predictions.reshape(-1)
flat_probs = eval_dataset_probs.reshape(-1)

eval_dataset_acc = sklearn.metrics.accuracy_score(flat_labels, flat_predictions)*100
# PR-AUC and ROC-AUC are ranking metrics and need continuous scores. Computing
# them from thresholded predictions collapses each curve to a single operating
# point and misreports the model's ranking quality.
eval_dataset_prauc = sklearn.metrics.average_precision_score(flat_labels, flat_probs)
eval_dataset_rocauc = sklearn.metrics.roc_auc_score(flat_labels, flat_probs)
print("eval dataset accuracy: %.2f" % eval_dataset_acc)
print("eval dataset prauc: %.4f" % eval_dataset_prauc)
print("eval dataset rocauc: %.4f" % eval_dataset_rocauc)

In [None]:
#@title Global CS scores as time series, by error type
concept_colors = {
    "SINE0": (cols[13], cols[13]),
    "SINE1": (cols[9], cols[9])
}

gs_kw = dict(wspace=0.4, hspace=0.5)
fig, ax = plt.subplots(ncols=2, nrows=4, gridspec_kw=gs_kw, figsize=(10,16), sharey=True, sharex=True, dpi=300)

for target in targets:
  # An example counts as positive if any time step is predicted/labelled 1.
  pos_pred_mask = (np.max(eval_dataset_predictions[:, :, target], axis=0) == 1)
  pos_label_mask = (np.max(eval_dataset_label[:, :, target], axis=0) == 1)
  neg_pred_mask = np.logical_not(pos_pred_mask)
  neg_label_mask = np.logical_not(pos_label_mask)
  # (subplot row, error-type abbreviation, example mask)
  error_types = (
      (0, "TP", pos_pred_mask * pos_label_mask),
      (1, "FP", pos_pred_mask * neg_label_mask),
      (2, "TN", neg_pred_mask * neg_label_mask),
      (3, "FN", neg_pred_mask * pos_label_mask),
  )
  for concept_idx in [0, 1]:
    concept_name = concept_names[concept_idx]
    target_array = CS_target_arrays[target][concept_idx]
    concept_presence = eval_dataset_concepts[:, dataset_concept_names.index(concept_name)]
    for row, error_name, example_mask in error_types:
      panel = ax[row][target]
      filtered_target_array = target_array[example_mask, :]
      filtered_concept_mask = concept_presence[example_mask]
      plot_with_concept_mask_over_time(
          array=filtered_target_array,
          concept_mask=filtered_concept_mask,
          ax=panel,
          color_pair=concept_colors[concept_name],
          concept_name=concept_name)
      set_axis_style(panel)
      panel.set_xlabel("Time", fontsize="x-large", wrap=True, fontweight='bold')
      panel.set_ylabel("CS" if not target else None, fontsize="x-large", wrap=True, fontweight='bold')
      panel.set_title(f"Target y_{target} {error_name}, n={np.size(filtered_concept_mask)}", fontweight='bold')

legend = ax[0][1].legend(loc="upper right", title="Score", bbox_to_anchor=(1.8, 1), frameon=True)
legend.get_title().set_fontsize("20")
legend.get_title().set_fontweight("bold")

plt.show()

In [None]:
#@title Global tCA Scores

name = "tab20c"
cmap = get_cmap(name)
cols = cmap.colors
mycolors = {"SINE0": cols[13], "SINE1": cols[9]}

# tCA results are read from a single target directory, the same convention as
# `dummy_target = 0` in the tCA time-series cell. It is set explicitly so the
# plot does not depend on whatever `target` an earlier loop left behind.
dummy_target = 0

fig = plt.figure(figsize=(6, 7), constrained_layout=False, dpi=300)
outer = fig.add_gridspec(1, 1, wspace=0.3)
alpha = [0.4, 0.9]  # bar opacity indexed by concept presence (0 = absent, 1 = present)

subplts = outer[0].subgridspec(len(layers), 1)
plt.rcParams.update({'font.size': 14})
barWidth = 0.3
ax = None
for ax_count, layer in enumerate(layers):
  ax = fig.add_subplot(subplts[ax_count], sharey=ax)
  xvals_bars = range(1, len(concept_names) + 1)
  for concept_idx, concept_name in enumerate(concept_names):
    eval_dataset_concept_sequence = eval_dataset_concept_sequences[:, :, dataset_concept_names.index(concept_name)]
    for presence in [0, 1]:
      pres_label = " present" if presence else " absent"
      # Average over every cell with this presence value -> one value per bootstrap.
      data = np.nanmean(tca_results[dummy_target][concept_name][layer]["bootstrap_tCA"][eval_dataset_concept_sequence==presence],axis=0)
      ax.bar(xvals_bars[concept_idx] + barWidth * presence,
            np.nanmean(data), width=barWidth, color=mycolors[concept_name],
            alpha=alpha[presence], label=concept_name + pres_label)
      ax.errorbar(xvals_bars[concept_idx] + barWidth * presence,
                  np.nanmean(data),
                  yerr=np.nanstd(data),
                  fmt='none', ecolor=[0,0,0], elinewidth=2)
  ax.set_ylabel("Layer "+str(layer),fontsize=24, fontweight='bold')
  ax.set_xticks(np.array(xvals_bars) + barWidth/2)
  ax.set_xticklabels("", fontsize=18, fontweight='bold')
  # Only the bottom row shows the concept names.
  if ax_count == len(layers)-1:
    ax.set_xticklabels(concept_names, fontsize=18, fontweight='bold')

legend = ax.legend(loc="upper left", bbox_to_anchor=(1.05, 1.1), title="Score", fontsize=18, frameon=True)
legend.get_title().set_fontsize("20")
legend.get_title().set_fontweight("bold")

plt.show()

In [None]:
#@title Global tCA scores as time series

desired_changepoint = 50
t0, t1 = 0, 100
layer = 2
# tCA results are read from a single target directory.
dummy_target = 0


def _align_to_changepoint(score_array, changepoints, aligned_changepoint):
  """Shift each example's score series so its change point lands at `aligned_changepoint`.

  Args:
    score_array: Scores of shape [time, batch_size, num_bootstraps].
    changepoints: Per-example change-point time indices (length batch_size).
    aligned_changepoint: Time index that every change point is moved to.

  Returns:
    Array of shape [batch_size, time, num_bootstraps]. Example i is shifted by
    `aligned_changepoint - changepoints[i]` along time, padded with NaN.
  """
  num_bootstraps = score_array.shape[-1]
  aligned_examples = []
  for i, cp in enumerate(changepoints):
    shift = aligned_changepoint - cp
    aligned_examples.append(np.stack(
        [shift_array(score_array[:, i, b], shift) for b in range(num_bootstraps)],
        axis=-1))
  return np.stack(aligned_examples, axis=0)


# One [batch, t1 - t0, bootstraps] array per concept (ordered like
# concept_names), aligned on that concept's change point.
concept_arrays = [
    _align_to_changepoint(
        tca_results[dummy_target][concept_name][layer]["bootstrap_tCA"],
        eval_dataset_concept_changes[:, dataset_concept_names.index(concept_name)],
        desired_changepoint)[:, t0:t1]
    for concept_name in concept_names]

concept_colors = {
    "SINE0": (cols[13], cols[13]),
    "SINE1": (cols[9], cols[9])
}

fig = plt.figure(figsize=(6, 4), dpi=300)
gs = gridspec.GridSpec(nrows=1, ncols=1, figure=fig)
gs.update(wspace=0.4, hspace=0.2)
ax = fig.add_subplot(gs[0, 0])
set_axis_style(ax)
ax.set_xlabel("Time", fontsize="x-large", wrap=True, fontweight='bold')
ax.set_ylabel("tCA", fontsize="x-large", wrap=True, fontweight='bold')
for concept_idx, concept_name in enumerate(concept_names):
  plot_with_concept_mask_over_time(
      array=concept_arrays[concept_idx],
      concept_mask=eval_dataset_concepts[:, dataset_concept_names.index(concept_name)],
      ax=ax,
      color_pair=concept_colors[concept_name],
      concept_name=concept_name)
# Title with the target whose tCA results are actually plotted.
ax.set_title(f"Target y_{dummy_target}", fontweight='bold')
legend = ax.legend(loc="upper right", title="Score", bbox_to_anchor=(1.65,1.1),frameon=True)
legend.get_title().set_fontsize("20")
legend.get_title().set_fontweight("bold")
plt.show()

In [None]:
#@title Local CS & tCA score time series example

example_idx = 2 #@param {type: 'integer'}
layer = 2 #@param {type: 'integer'}

# Set here so the cell runs on its own; these match the values used by the
# global tCA and Global CS cells.
dummy_target = 0  # tCA results are read from a single target directory.
cols = get_cmap("tab20c").colors

label_idx_to_color = {
    0: cols[13],
    1: cols[9]
}
concept_to_color = {
    "SINE0": cols[13],
    "SINE1": cols[9]
}


def _plot_local_score(ax, bootstrap_scores, permuted_scores, concept_name):
  """Draw one example's score (mean +/- std over bootstraps) against its permuted baseline.

  Args:
    ax: Axes to draw on.
    bootstrap_scores: Scores for one example, shape [time, num_bootstraps].
    permuted_scores: Permuted-baseline scores for the same example; assumed to
      be [time, num_permutations] (TODO confirm against run_cs_tca.py output).
    concept_name: Concept name, used for the curve label and colour.
  """
  mean = np.nanmean(bootstrap_scores, axis=1)
  std = np.nanstd(bootstrap_scores, axis=1)
  # Time axis comes from the data rather than a hard-coded window.
  steps = np.arange(mean.shape[0])
  ax.plot(steps, mean, label=concept_name, color=concept_to_color[concept_name], linewidth=2)
  ax.fill_between(steps, mean-std, mean+std, color=concept_to_color[concept_name], alpha=0.4)
  mean = np.nanmean(permuted_scores, axis=1)
  std = np.nanstd(permuted_scores, axis=1)
  steps = np.arange(mean.shape[0])
  ax.plot(steps, mean, label="Permuted", color=[0.6, 0.6, 0.6, 0.8])
  ax.fill_between(steps, mean-std, mean+std, color=[0.6, 0.6, 0.6, 0.8], alpha=0.1)
  ax.legend(loc="upper right", bbox_to_anchor=(1.3, 1))
  set_axis_style(ax)


fig = plt.figure(figsize=(12,20), dpi=300)
gs = gridspec.GridSpec(nrows=7, ncols=1, figure=fig)
gs.update(wspace=0.4, hspace=0.5)

# Row 0: labels and predicted probabilities of the chosen example, indexed
# [time, example, label]; plot the full time range.
ax0 = fig.add_subplot(gs[0])
time_steps = range(eval_dataset_label.shape[0])
for label_idx in [0, 1]:
  ax0.plot(time_steps, eval_dataset_label[:, example_idx, label_idx], label=f"Label y_{label_idx}", color=label_idx_to_color[label_idx], linewidth=2)
  ax0.plot(time_steps, eval_dataset_probs[:, example_idx, label_idx], color=label_idx_to_color[label_idx],marker="o",linewidth=2, label=f"Prediction y_{label_idx}",
           markersize=4, linestyle="--")
plt.ylim(-0.1, 1.1)
legend = ax0.legend(loc="upper right", bbox_to_anchor=(1.36, 1))

ax0.set_ylabel("Predictions", fontsize="x-large", wrap=True, fontweight='bold')

# Rows 1-4: CS for each (target, concept) pair, all sharing one y-axis.
ax = None
for target in [0, 1]:
  for concept_idx, concept_name in enumerate(concept_names):
    ax = fig.add_subplot(gs[concept_idx + (2 * target) + 1], sharey=ax)
    _plot_local_score(
        ax,
        cs_results[target][concept_name][layer]["bootstrap_CS"][:, example_idx],
        cs_results[target][concept_name][layer]["permuted_CS"][:, example_idx],
        concept_name)
    ax.set_title(f"CS: label y_{target}, concept={concept_name}", fontweight="bold")
    ax.set_ylabel("CS", fontsize="x-large", wrap=True, fontweight="bold")

# Rows 5-6: tCA for each concept, sharing one y-axis.
ax = None
for concept_idx, concept_name in enumerate(concept_names):
  ax = fig.add_subplot(gs[concept_idx + 5], sharey=ax)
  _plot_local_score(
      ax,
      tca_results[dummy_target][concept_name][layer]["bootstrap_tCA"][:, example_idx],
      tca_results[dummy_target][concept_name][layer]["permuted_tCA"][:, example_idx],
      concept_name)
  ax.set_ylabel("tCA", fontsize="x-large", wrap=True, fontweight="bold")
  ax.set_title(f"tCA: concept={concept_name}", fontweight="bold")

plt.show()