In [None]:
import matplotlib.pyplot as plt
import mpl_lego as mplego
import numpy as np
import pandas as pd
import pickle

from hate_target import keys
from mpl_lego.labels import bold_text, apply_subplot_labels
from mpl_lego.colorbar import append_colorbar_to_axis
from pyprojroot import here
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score

%matplotlib inline

In [None]:
# Apply mpl_lego's LaTeX plotting style to every figure below.
# NOTE(review): presumably this enables LaTeX text rendering, which the
# bold_text(...) tick/colorbar labels in the figure cell rely on — confirm.
mplego.style.use_latex_style()

In [None]:
# Names of the target groups, used as tick labels in the figure.
# Assumed to be ordered like the label columns of y_true / y_pred — TODO confirm
# against hate_target.keys.
labels = keys.target_labels

In [None]:
# Load the pickled experiment results; later cells read its 'y_true' and
# 'y_pred' entries.
# NOTE(review): the variable is named exp06a_path but the file loaded is
# exp05f.pkl — confirm which experiment this figure is meant to show.
# pickle.load can execute arbitrary code: only point this at trusted,
# locally produced result files.
exp06a_path = here('experiments/exp05f.pkl')
with open(exp06a_path, 'rb') as file:
    results = pickle.load(file)


In [None]:
# Decision thresholds for binarizing ground truth and model outputs.
# NOTE(review): predictions are cut at 0.4 while the ground truth uses 0.5 —
# presumably a tuned operating point; confirm.
TRUE_THRESHOLD = 0.5
PRED_THRESHOLD = 0.4

# After the transpose, axis 1 indexes the target groups (n_groups is read from
# it and later cells slice columns per group). How the pickle lays the data
# out before the transpose is assumed — TODO confirm.
y_true = np.array(results['y_true']).squeeze().T
y_hard = (y_true >= TRUE_THRESHOLD).astype('int')
y_pred = np.array(results['y_pred']).squeeze().T
y_pred_labels = (y_pred >= PRED_THRESHOLD).astype('int')
# Compare the 0/1 predictions with the binarized ground truth. y_true is
# thresholded above, so it may hold soft values; comparing against it directly
# would only count matches where y_true is exactly 0 or 1.
hits = y_pred_labels == y_hard
n_groups = y_true.shape[1]

In [None]:
# Target-by-target metric matrices, sized from the data rather than a fixed 8.
# Off-diagonal entries are filled pairwise; the diagonal stays 0.
roc_aucs = np.zeros((n_groups, n_groups))
pr_aucs = np.zeros((n_groups, n_groups))
f1_scores = np.zeros((n_groups, n_groups))
# Grid coordinates (column, row) of pairs whose metrics are NaN; the figure
# marks these cells with an x.
xs = []
ys = []

In [None]:
# Pairwise metrics for every unordered pair of targets (ii, jj), restricted to
# those two label columns. Pairs flagged as degenerate get NaN in all three
# matrices, and their (column, row) grid position is recorded in xs / ys so the
# figure can mark them.
# xs / ys are reset here so re-running this cell does not append duplicates.
xs = []
ys = []
for ii in range(n_groups):
    for jj in range(ii + 1, n_groups):
        pair = [ii, jj]
        true_pair = y_hard[:, pair]
        label_pair = y_pred_labels[:, pair]
        # AUC-type metrics need continuous scores: thresholded 0/1 labels
        # collapse the ROC / PR curve to a single operating point.
        score_pair = y_pred[:, pair]
        # Samples on which both targets are positive in the ground truth or in
        # the thresholded predictions (sorted, unique indices).
        multi_idx = np.union1d(
            np.flatnonzero(np.all(true_pair, axis=1)),
            np.flatnonzero(np.all(label_pair, axis=1)),
        )
        degenerate = (
            np.all(y_hard[multi_idx, ii])
            or np.all(y_hard[multi_idx, jj])
            # roc_auc_score raises ValueError when a column has a single
            # class over all samples; record NaN instead of crashing.
            or np.unique(true_pair[:, 0]).size < 2
            or np.unique(true_pair[:, 1]).size < 2
        )
        if degenerate:
            roc_auc = pr_auc = f1 = np.nan
            xs.append(ii)
            ys.append(jj)
        else:
            roc_auc = roc_auc_score(true_pair, score_pair, average='weighted')
            pr_auc = average_precision_score(true_pair, score_pair, average='weighted')
            f1 = f1_score(true_pair, label_pair, average='weighted')
        # Support-weighted averages over the two columns do not depend on
        # column order, so (ii, jj) and (jj, ii) share one value.
        roc_aucs[ii, jj] = roc_aucs[jj, ii] = roc_auc
        pr_aucs[ii, jj] = pr_aucs[jj, ii] = pr_auc
        f1_scores[ii, jj] = f1_scores[jj, ii] = f1

In [None]:
# Co-occurrence counts of the hard labels: entry (i, j) is the number of
# samples positive for both target i and target j; the diagonal holds each
# target's own positive count.
cross_counts = y_hard.T @ y_hard
label_counts = np.diag(cross_counts)
# Normalize each overlap by the smaller of the two targets' counts, so 1.0
# means every positive of the smaller target is also positive for the other.
# A target with no positives gives 0/0 -> nan (numpy emits a RuntimeWarning).
normalized_counts = cross_counts / np.minimum.outer(label_counts, label_counts)


In [None]:
def _plot_lower_triangle(ax, values, vmax, colorbar_label):
    """Draw the strictly-lower triangle of a square matrix as a heatmap.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axis to draw on; a colorbar is appended next to it.
    values : np.ndarray
        Square matrix; the diagonal and upper triangle are masked out.
    vmax : float
        Upper end of the colour scale (the lower end is 0).
    colorbar_label : str
        Text for the colorbar label (rendered with bold_text).
    """
    n = values.shape[0]
    # Mask the diagonal and upper triangle by position. A mask built from the
    # matrix values would leave zero-valued upper-triangle cells visible.
    upper = np.triu(np.ones((n, n), dtype=bool))
    img = ax.imshow(np.ma.array(values, mask=upper), vmin=0., vmax=vmax,
                    interpolation='nearest', cmap='plasma')
    cb, cax = append_colorbar_to_axis(ax, img, spacing=-0.08)
    cax.tick_params(labelsize=13)
    # Only request ticks that fall inside the colour range [0, vmax].
    cb.set_ticks([tick for tick in (0, 0.25, 0.5, 0.75, 1.0) if tick <= vmax])
    cb.set_label(bold_text(colorbar_label), fontsize=15, rotation=270, labelpad=20)

    # Row 0 and the last column hold no lower-triangle cells, so they get no
    # tick labels.
    ax.set_xticks(np.arange(n - 1))
    ax.set_xticklabels(bold_text(labels[:-1]), ha='right', rotation=30)
    ax.set_yticks(1 + np.arange(n - 1))
    ax.set_yticklabels(bold_text(labels[1:]), ha='right')
    ax.tick_params(labelsize=13)


fig, axes = plt.subplots(2, 2, figsize=(10, 9), sharex=True, sharey=True)

# Visible window: all columns, rows 1..n_groups-1 (row 0 is fully masked),
# with the y axis inverted so row 1 is at the top.
for ax in axes.ravel():
    ax.set_xlim([-0.5, n_groups - 0.5])
    ax.set_ylim([n_groups - 0.5, 0.5])

fig.subplots_adjust(wspace=0.35, hspace=0.3)

panels = (
    (axes[0, 0], normalized_counts, 0.5, 'Sample Overlap'),
    (axes[0, 1], f1_scores, 1.0, 'F1 Score'),
    (axes[1, 0], roc_aucs, 1.0, 'ROC AUC'),
    (axes[1, 1], pr_aucs, 1.0, 'PR AUC'),
)
for panel_ax, panel_values, panel_vmax, panel_label in panels:
    _plot_lower_triangle(panel_ax, panel_values, panel_vmax, panel_label)

# Mark pairs whose metrics are NaN with an x on the three metric panels.
for ax in (axes[0, 1], axes[1, 0], axes[1, 1]):
    ax.scatter(xs, ys, marker='x', color='black', s=200)

for ax in axes.ravel():
    for spine in ax.spines.values():
        spine.set_visible(False)

apply_subplot_labels(axes, bold=True, size=18, y=1.08)
fig.savefig('figure4.pdf', bbox_inches='tight')