In [10]:
import os, json, glob
import collections
import numpy as np
import pandas as pd
from scipy.stats import weightedtau
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style="ticks", context="talk")

from utils.weighted_spearman import weighted_spearman

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [11]:
def recursively_default_dict():
    """Build an auto-vivifying mapping: any missing key yields another such mapping."""
    factory = recursively_default_dict
    return collections.defaultdict(factory)

In [12]:
# Human-readable display names for the validator / score columns, used as bar and axis
# labels in the plots below. Keys follow "<split(s)>_<method>_score"; e.g.
# "src_val_target_val_bnm_score" is displayed as "BNM (Source Val + Target Val)".
validator_names ={
        "src_train_mse_score": "MSE (Source Train)",
        "src_val_mse_score": "MSE (Source Val)",
        "target_train_mse_score": "MSE (Target Train)",
        "target_val_mse_score": "MSE (Target Val)",
        "target_test_mse_score": "MSE (Target Test)",

        "src_train_mae_score": "MAE (Source Train)",
        "src_val_mae_score": "MAE (Source Val)",
        "target_train_mae_score": "MAE (Target Train)",
        "target_val_mae_score": "MAE (Target Val)",
        "target_test_mae_score": "MAE (Target Test)",

        "src_train_bnm_score": "BNM (Source Train)",
        "src_val_bnm_score": "BNM (Source Val)",
        "target_train_bnm_score": "BNM (Target Train)",
        "target_val_bnm_score": "BNM (Target Val)",

        "src_train_target_train_bnm_score": "BNM (Source Train + Target Train)",
        "src_train_target_val_bnm_score": "BNM (Source Train + Target Val)",
        "src_val_target_train_bnm_score": "BNM (Source Val + Target Train)",
        "src_val_target_val_bnm_score": "BNM (Source Val + Target Val)",

        "src_train_class_ami_score": "ClassAMI (Source Train Features)",
        "src_val_class_ami_score": "ClassAMI (Source Val Features)",
        "target_train_class_ami_score": "ClassAMI (Target Train Features)",
        "target_val_class_ami_score": "ClassAMI (Target Val Features)",

        "src_train_target_train_class_ami_score": "ClassAMI (Source Train + Target Train Features)",
        "src_train_target_val_class_ami_score": "ClassAMI (Source Train + Target Val Features)",
        "src_val_target_train_class_ami_score": "ClassAMI (Source Val + Target Train Features)",
        "src_val_target_val_class_ami_score": "ClassAMI (Source Val + Target Val Features)",

        "src_train_logits_class_ami_score": "ClassAMI (Source Train Logits)",
        "src_val_logits_class_ami_score": "ClassAMI (Source Val Logits)",
        "target_train_logits_class_ami_score": "ClassAMI (Target Train Logits)",
        "target_val_logits_class_ami_score": "ClassAMI (Target Val Logits)",

        "src_train_target_train_logits_class_ami_score": "ClassAMI (Source Train + Target Train Logits)",
        "src_train_target_val_logits_class_ami_score": "ClassAMI (Source Train + Target Val Logits)",
        "src_val_target_train_logits_class_ami_score": "ClassAMI (Source Val + Target Train Logits)",
        "src_val_target_val_logits_class_ami_score": "ClassAMI (Source Val + Target Val Logits)",

        "src_train_snd_score": "SND (Source Train)",
        "src_val_snd_score": "SND (Source Val)",
        "target_train_snd_score": "SND (Target Train)",
        "target_val_snd_score": "SND (Target Val)",

        # Negated SND columns; these are derived after loading from the *_snd_score columns.
        "src_train_neg_snd_score": "-SND (Source Train)",
        "src_val_neg_snd_score": "-SND (Source Val)",
        "target_train_neg_snd_score": "-SND (Target Train)",
        "target_val_neg_snd_score": "-SND (Target Val)",

        "src_train_target_train_mmd_score": "MMD (Source Train + Target Train)",
        "src_train_target_val_mmd_score": "MMD (Source Train + Target Val)",
        "src_val_target_train_mmd_score": "MMD (Source Val + Target Train)",
        "src_val_target_val_mmd_score": "MMD (Source Val + Target Val)",

        "src_train_target_train_mmd_per_class_score": "MMDPerClass (Source Train + Target Train)",
        "src_train_target_val_mmd_per_class_score": "MMDPerClass (Source Train + Target Val)",
        "src_val_target_train_mmd_per_class_score": "MMDPerClass (Source Val + Target Train)",
        "src_val_target_val_mmd_per_class_score": "MMDPerClass (Source Val + Target Val)",

        "src_train_target_train_logits_mmd_score": "MMD (Source Train + Target Train Logits)",
        "src_train_target_val_logits_mmd_score": "MMD (Source Train + Target Val Logits)",
        "src_val_target_train_logits_mmd_score": "MMD (Source Val + Target Train Logits)",
        "src_val_target_val_logits_mmd_score": "MMD (Source Val + Target Val Logits)",

        "src_train_target_train_logits_mmd_per_class_score": "MMDPerClass (Source Train + Target Train Logits)",
        "src_train_target_val_logits_mmd_per_class_score": "MMDPerClass (Source Train + Target Val Logits)",
        "src_val_target_train_logits_mmd_per_class_score": "MMDPerClass (Source Val + Target Train Logits)",
        "src_val_target_val_logits_mmd_per_class_score": "MMDPerClass (Source Val + Target Val Logits)",

        "src_train_target_train_preds_mmd_score": "MMD (Source Train + Target Train Preds)",
        "src_train_target_val_preds_mmd_score": "MMD (Source Train + Target Val Preds)",
        "src_val_target_train_preds_mmd_score": "MMD (Source Val + Target Train Preds)",
        "src_val_target_val_preds_mmd_score": "MMD (Source Val + Target Val Preds)",

        "src_train_target_train_preds_mmd_per_class_score": "MMDPerClass (Source Train + Target Train Preds)",
        "src_train_target_val_preds_mmd_per_class_score": "MMDPerClass (Source Train + Target Val Preds)",
        "src_val_target_train_preds_mmd_per_class_score": "MMDPerClass (Source Val + Target Train Preds)",
        "src_val_target_val_preds_mmd_per_class_score": "MMDPerClass (Source Val + Target Val Preds)",

        "target_train_henry_score": "H-Score-Simple (Target Train)",
        "target_val_henry_score": "H-Score-Simple (Target Val)",

        "target_train_henry_v3_score": "H-Score-v3 (Target Train)",
        "target_val_henry_v3_score": "H-Score-v3 (Target Val)",
}

In [13]:
# Where figures are written, where the per-checkpoint score JSON files are read from,
# and which algorithms' results are loaded.
figures_root = "figures"
results_root, algorithms = "results", ["source-only", "adda", "coral", "dann", "gan", "mmd", "vada"]
# Per dataset: the (source, target) domain pairs to load and the "oracle" score column
# that every validator is compared against.
datasets = {
    "mnistmr": {
        "domains": [("mnist", "mnistm")],
        "oracle": "target_test_mse_score"
    },
    "dogs_and_birds": {
        "domains": [
            ("dogs", "birds"),
            ("birds", "dogs")
        ],
        "oracle": "target_test_mse_score"
    }
}
# name -> (rank_first, method). If rank_first is True the scores are ranked before
# DataFrame.corr is called; method is a pandas method name or a callable on two arrays.
correlations = {
    "pearson": (False, "pearson"),
    #"weighted pearson": lambda x, y: ,
    "spearman": (False, "spearman"),
    "weighted spearman": (False, lambda x, y: weighted_spearman(x, y, 2)), # used by musgrave et al.
    "kendall": (False, "kendall"),
    "weighted kendall": (False, lambda x, y: weightedtau(x, y)[0]) # used by ferrari
}
correlation = "weighted spearman"  # key of `correlations` used for every correlation plot
post_computed_scores = True  # True: load "post_scores_*.json" files instead of "scores_*.json"
plot_scatter = False  # True: also draw the oracle histogram and per-validator scatter plots

In [14]:
# Validation-split validators summarized in the tables below. The "Table 6" cells also
# look up each one's train-split counterpart by replacing its last "val" with "train".
validators_to_check = [
    "target_val_henry_score", "target_val_henry_v3_score",
    "src_val_target_val_bnm_score", "src_val_target_val_class_ami_score",
    "src_val_target_val_mmd_score", "target_val_snd_score",
    "src_val_mse_score"
]

In [15]:
def load_results(results_root, dataset, source, target, algorithms, post=False, merge=True):
    """Load the per-checkpoint score files of one (dataset, source, target) task.

    Every ``{prefix}*.json`` file in ``results_root/dataset/source/target/<algorithm>``
    becomes one entry keyed ``"<algorithm>_<checkpoint>"``, where ``<checkpoint>`` is
    the file name with the prefix and the ``.json`` extension removed.

    Args:
        results_root: root directory of the results tree.
        dataset, source, target: path components identifying the task.
        algorithms: algorithm sub-directories to read.
        post: read ``post_scores_*.json`` instead of ``scores_*.json``.
        merge: if True, all algorithms' entries share one flat mapping (one row per
            checkpoint in the result); if False, entries are grouped per algorithm.

    Returns:
        DataFrame whose rows are the collected entries (transposed from
        ``DataFrame.from_dict``); empty if no files were found.
    """
    prefix = "post_scores_" if post else "scores_"
    suffix = ".json"
    all_scores = {}
    for algorithm in algorithms:
        pattern = os.path.join(results_root, dataset, source, target, algorithm, f"{prefix}*{suffix}")
        scores = {}
        # sorted(): glob order is filesystem-dependent; sorting makes row order reproducible.
        for path in sorted(glob.glob(pattern)):
            # basename() is portable, unlike splitting on "/" (Windows paths use "\\").
            checkpoint = os.path.basename(path)[len(prefix):-len(suffix)]
            with open(path, "r") as f:  # close every file instead of leaking handles
                scores[f"{algorithm}_{checkpoint}"] = json.load(f)
        if merge:
            all_scores.update(scores)
        else:
            all_scores[algorithm] = scores
    return pd.DataFrame.from_dict(all_scores).T

In [16]:
def load_all_results(results_root, datasets, algorithms, post=False):
    """Load the score table of every cross-domain task listed in ``datasets``.

    Returns a nested defaultdict ``result[dataset][source][target]`` holding the
    DataFrame produced by ``load_results``. Domain pairs whose source equals their
    target are skipped.
    """
    nested_tables = recursively_default_dict()
    for dataset_name, dataset_config in datasets.items():
        for source, target in dataset_config["domains"]:
            if source != target:
                nested_tables[dataset_name][source][target] = load_results(
                    results_root, dataset_name, source, target, algorithms, post=post
                )
    return nested_tables

In [17]:
# Load every task's score table: tables[dataset][source][target] is a DataFrame with one
# row per score file (labelled "<algorithm>_<checkpoint>") and one column per logged score.
tables = load_all_results(results_root, datasets, algorithms, post=post_computed_scores)

In [18]:
# Post-process every loaded task table in place: add negated copies of the SND columns
# (the "*_neg_snd_score" validators) and drop the "epoch" bookkeeping column.
# NOTE: a filter that added validators with fewer than 2400 logged values to
# `validators_to_remove` is disabled, so the set stays empty.
validators_to_remove = set()
snd_splits = ["src_train", "src_val", "target_train", "target_val"]
for dataset in tables.keys():
    for source in tables[dataset].keys():
        for target in tables[dataset][source].keys():
            table = tables[dataset][source][target]
            if table.empty:
                continue

            # Negate every SND column that was logged; report tasks where any is missing.
            missing_snd = False
            for split in snd_splits:
                snd_column = f"{split}_snd_score"
                if snd_column in table.columns:
                    table[f"{split}_neg_snd_score"] = -table[snd_column]
                else:
                    missing_snd = True
            if missing_snd:
                print(f"No SND scores logged for {dataset} {source} {target}")

            # "epoch" is not a validator; drop it when present (explicit check instead of
            # a bare except that would also hide unrelated errors).
            if "epoch" in table.columns:
                del table["epoch"]

In [19]:
# Report the validators that are about to be dropped from every task table.
# A plain loop (not a list comprehension built only for its side effects), in sorted
# order so the output is reproducible.
for removed_validator in sorted(validators_to_remove):
    print("Removing", removed_validator)

In [20]:
# Drop every validator in `validators_to_remove` from every task table (in place).
for dataset_name in tables:
    for source_name in tables[dataset_name]:
        for target_name in tables[dataset_name][source_name]:
            task_table = tables[dataset_name][source_name][target_name]
            for validator in validators_to_remove:
                # Not every task logs every validator; skip absent columns explicitly
                # rather than swallowing all exceptions with a bare except.
                if validator in task_table.columns:
                    del task_table[validator]

Make a nested dictionary that maps dataset → source → target → algorithm to that algorithm's full table of results.

In [21]:
# Split each task table into per-algorithm tables:
# algorithm_tables[dataset][source][target][algorithm] -> the rows of that algorithm only.
algorithm_tables = recursively_default_dict()
for dataset in tables.keys():
    for source in tables[dataset].keys():
        for target in tables[dataset][source].keys():
            table = tables[dataset][source][target]
            for algorithm in algorithms:
                # Row labels are "<algorithm>_<checkpoint>" (see load_results). Match the
                # full "<algorithm>_" prefix so an algorithm whose name is a prefix of
                # another's (e.g. "dann" vs. "dann2") cannot pick up the other's rows.
                row_prefix = f"{algorithm}_"
                algorithm_table = table[[str(label).startswith(row_prefix) for label in table.index]]
                algorithm_tables[dataset][source][target][algorithm] = algorithm_table

For each validator, make a table of the oracle score (target test MSE) of the checkpoint that validator selects, per task and algorithm.

In [22]:
# For every validator, pick each algorithm's checkpoint with the highest validator score
# and record (a) that checkpoint's oracle score and (b) the validator score itself.
# model_selection_tables[validator]: DataFrame with algorithms as rows and tasks
# ("<DATASET>-<S><T>") as columns; model_selection_tables_validator_values likewise.
model_selection_tables = recursively_default_dict()
model_selection_tables_validator_values = recursively_default_dict()
for validator in tables["mnistmr"]["mnist"]["mnistm"].columns:
    selected_oracle_scores = {}     # task -> {algorithm: oracle score of selected checkpoint}
    selected_validator_scores = {}  # task -> {algorithm: best validator score}
    for dataset in tables.keys():
        oracle = datasets[dataset]["oracle"]
        for source in tables[dataset].keys():
            for target in tables[dataset][source].keys():
                task = f"{dataset.upper()}-{source[0].upper()}{target[0].upper()}"
                oracle_row = selected_oracle_scores.setdefault(task, {})
                value_row = selected_validator_scores.setdefault(task, {})
                for algorithm in algorithms:
                    algorithm_table = algorithm_tables[dataset][source][target][algorithm]
                    usable = (
                        validator in algorithm_table.columns
                        and oracle in algorithm_table.columns
                        and algorithm_table[validator].notna().any()
                    )
                    if not usable:
                        print(f"Couldn't do {validator}, {task}, {algorithm}")
                        continue
                    # idxmax skips NaN and returns a row label, so the oracle lookup always
                    # refers to the checkpoint holding the maximum. (argmax on an all-NaN
                    # column returns -1, so iloc[-1] would silently pick the wrong row.)
                    best_checkpoint = algorithm_table[validator].idxmax()
                    oracle_row[algorithm] = algorithm_table.loc[best_checkpoint, oracle]
                    value_row[algorithm] = algorithm_table[validator].max()
    # Build the DataFrames once, after all datasets are collected, so no frame is
    # extended through chained assignment (which breaks under pandas Copy-on-Write).
    model_selection_tables[validator] = pd.DataFrame.from_dict(selected_oracle_scores)
    model_selection_tables_validator_values[validator] = pd.DataFrame.from_dict(selected_validator_scores)

Coundn't do src_train_mse_score, DOGS_AND_BIRDS-BD, dann
Coundn't do src_train_mse_score, DOGS_AND_BIRDS-BD, mmd
Coundn't do src_val_mse_score, DOGS_AND_BIRDS-BD, dann
Coundn't do src_val_mse_score, DOGS_AND_BIRDS-BD, mmd
Coundn't do target_train_mse_score, DOGS_AND_BIRDS-BD, dann
Coundn't do target_train_mse_score, DOGS_AND_BIRDS-BD, mmd
Coundn't do target_val_mse_score, DOGS_AND_BIRDS-BD, dann
Coundn't do target_val_mse_score, DOGS_AND_BIRDS-BD, mmd
Coundn't do target_test_mse_score, DOGS_AND_BIRDS-BD, dann
Coundn't do target_test_mse_score, DOGS_AND_BIRDS-BD, mmd
Coundn't do src_train_mae_score, DOGS_AND_BIRDS-BD, dann
Coundn't do src_train_mae_score, DOGS_AND_BIRDS-BD, mmd
Coundn't do src_val_mae_score, DOGS_AND_BIRDS-BD, dann
Coundn't do src_val_mae_score, DOGS_AND_BIRDS-BD, mmd
Coundn't do target_train_mae_score, DOGS_AND_BIRDS-BD, dann
Coundn't do target_train_mae_score, DOGS_AND_BIRDS-BD, mmd
Coundn't do target_val_mae_score, DOGS_AND_BIRDS-BD, dann
Coundn't do target_val_mae_s

Table 5

In [23]:
# Oracle score of the checkpoint picked by the source-validation MSE validator, per
# algorithm (rows) and task (columns).
model_selection_tables["src_val_mse_score"]

Unnamed: 0,MNISTMR-MM,DOGS_AND_BIRDS-DB,DOGS_AND_BIRDS-BD
adda,-46.265926,-1421.468994,-1261.349731
coral,-47.114094,-1251.896606,-1122.283325
dann,-45.405823,-1263.771729,
gan,-44.104088,-1245.370117,-1085.512695
mmd,-61.472012,-1246.888794,
source-only,-46.955608,-1170.086426,-1018.940857
vada,-48.440907,-1234.859619,-1070.798218


In [24]:
# Selecting with the oracle column itself: each entry is the maximum target_test_mse_score
# reached by that algorithm on that task (the best any validator could do).
model_selection_tables["target_test_mse_score"]

Unnamed: 0,MNISTMR-MM,DOGS_AND_BIRDS-DB,DOGS_AND_BIRDS-BD
adda,-16.496634,-1394.919556,-1261.349731
coral,-43.59753,-1249.920532,-967.836182
dann,-35.224033,-1203.738647,
gan,-16.285812,-1155.002808,-1042.553711
mmd,-45.124035,-1229.689697,
source-only,-42.277546,-1021.140076,-818.450867
vada,-14.37987,-1095.346802,-970.195984


In [25]:
# Best validator value per algorithm, averaged over tasks, for each validator of interest.
d = pd.DataFrame()
for validator_name in validators_to_check:
    per_task_values = model_selection_tables_validator_values[validator_name]
    d[validator_name] = per_task_values.mean(axis=1)
d.index = d.index.str.upper()
display(d)
latex_table = d.to_latex(float_format="%.2f")
print(latex_table)

Unnamed: 0,target_val_henry_score,target_val_henry_v3_score,src_val_target_val_bnm_score,src_val_target_val_class_ami_score,src_val_target_val_mmd_score,target_val_snd_score,src_val_mse_score
ADDA,0.240567,-294.273341,0.014765,0.218726,-0.011637,6.768794,-293.274897
CORAL,0.256033,-242.170715,0.013632,0.215105,-0.057445,4.963744,-241.17198
DANN,0.29695,-228.38565,0.013772,0.220088,-0.019587,4.805622,-227.386983
GAN,0.249733,-245.783697,0.013829,0.205697,-0.010252,6.007957,-244.784989
MMD,0.2965,-227.155869,0.013977,0.211332,-0.026309,4.737028,-226.15766
SOURCE-ONLY,,,,,,,-239.813514
VADA,0.249567,-248.787444,0.013665,0.208304,-0.01309,6.55573,-247.78895


\begin{tabular}{lrrrrrrr}
\toprule
{} &  target\_val\_henry\_score &  target\_val\_henry\_v3\_score &  src\_val\_target\_val\_bnm\_score &  src\_val\_target\_val\_class\_ami\_score &  src\_val\_target\_val\_mmd\_score &  target\_val\_snd\_score &  src\_val\_mse\_score \\
\midrule
ADDA        &                    0.24 &                    -294.27 &                          0.01 &                                0.22 &                         -0.01 &                  6.77 &            -293.27 \\
CORAL       &                    0.26 &                    -242.17 &                          0.01 &                                0.22 &                         -0.06 &                  4.96 &            -241.17 \\
DANN        &                    0.30 &                    -228.39 &                          0.01 &                                0.22 &                         -0.02 &                  4.81 &            -227.39 \\
GAN         &                    0.25 &                    -245.78 &

In [26]:
# Oracle score of the validator-selected checkpoint, averaged over tasks per algorithm,
# followed by an average row and an average-rank row (rank 1 = highest score).
d = pd.DataFrame()
for validator_name in validators_to_check:
    per_task_scores = model_selection_tables[validator_name]
    d[validator_name] = per_task_scores.mean(axis=1)
d.index = d.index.str.upper()
avg = d.mean().to_frame("Avg.").T
avg_rank = d.rank(axis=1, ascending=False).mean().to_frame("Avg. Rank").T
d = pd.concat([d, avg, avg_rank])
display(d.round(2))
print(d.to_latex(float_format="%.2f"))

Unnamed: 0,target_val_henry_score,target_val_henry_v3_score,src_val_target_val_bnm_score,src_val_target_val_class_ami_score,src_val_target_val_mmd_score,target_val_snd_score,src_val_mse_score
ADDA,-909.69,-909.69,-1394.14,-1442.41,-1745.37,-1503.76,-909.69
CORAL,-792.87,-807.1,-824.1,-854.31,-822.77,-773.81,-807.1
DANN,-670.8,-654.59,-673.95,-687.61,-711.35,-702.85,-654.59
GAN,-835.71,-791.47,-823.09,-809.6,-924.53,-868.39,-791.66
MMD,-681.78,-654.18,-684.34,-670.16,-735.16,-697.48,-654.18
SOURCE-ONLY,-755.7,-755.7,-755.7,-755.7,-755.7,-755.7,-745.33
VADA,-798.74,-784.7,-824.16,-853.55,-841.42,-1094.43,-784.7
Avg.,-777.9,-765.35,-854.21,-867.62,-933.76,-913.77,-763.89
Avg. Rank,3.36,2.21,4.5,4.79,6.07,5.21,1.86


\begin{tabular}{lrrrrrrr}
\toprule
{} &  target\_val\_henry\_score &  target\_val\_henry\_v3\_score &  src\_val\_target\_val\_bnm\_score &  src\_val\_target\_val\_class\_ami\_score &  src\_val\_target\_val\_mmd\_score &  target\_val\_snd\_score &  src\_val\_mse\_score \\
\midrule
ADDA        &                 -909.69 &                    -909.69 &                      -1394.14 &                            -1442.41 &                      -1745.37 &              -1503.76 &            -909.69 \\
CORAL       &                 -792.87 &                    -807.10 &                       -824.10 &                             -854.31 &                       -822.77 &               -773.81 &            -807.10 \\
DANN        &                 -670.80 &                    -654.59 &                       -673.95 &                             -687.61 &                       -711.35 &               -702.85 &            -654.59 \\
GAN         &                 -835.71 &                    -791.47 &

Table 6: all algorithms

In [None]:
# Train- vs val-split variant of each validator: oracle score of the selected checkpoint,
# averaged over all tasks and algorithms.
d = pd.DataFrame()
for val_name in validators_to_check:
    # Swap the last "val" in the name for "train" (e.g. src_val_target_val_* -> src_val_target_train_*).
    train_name = "train".join(val_name.rsplit("val", 1))
    train_mean = model_selection_tables[train_name].mean().mean()
    val_mean = model_selection_tables[val_name].mean().mean()
    d[val_name] = [train_mean, val_mean]
d = d.rename({0: "Train", 1: "Val"})
display(d)
print(d.to_latex(float_format="%.2f"))

Table 6: DANN

In [None]:
# Same train- vs val-split comparison, restricted to DANN and averaged over tasks.
d = pd.DataFrame()
for val_name in validators_to_check:
    # Swap the last "val" in the name for "train".
    train_name = "train".join(val_name.rsplit("val", 1))
    dann_scores = [
        model_selection_tables[train_name].loc["dann"].mean(),
        model_selection_tables[val_name].loc["dann"].mean(),
    ]
    d[val_name] = dann_scores
d = d.rename({0: "Train", 1: "Val"})
display(d)
print(d.to_latex(float_format="%.2f"))

Gap tables and plots

In [None]:
def compute_and_plot_gaps(v, mode="mean", oracle="target_test_mse_score", figures_dir="figures"):
    """Bar-plot, per validator, how far its model selection is from the oracle's.

    Args:
        v: mapping validator name -> DataFrame of selected oracle scores (rows =
            algorithms, columns = tasks), e.g. ``model_selection_tables``.
        mode: "mean" for the mean absolute gap over all cells, "max" for the largest.
        oracle: key of ``v`` holding the oracle-selected scores. The default is the
            oracle column of this notebook's datasets (the previous hard-coded
            "target_test_acc_score" does not exist for this MSE setup).
        figures_dir: directory the PDF and PNG figures are written to.

    Raises:
        ValueError: if ``mode`` is not "mean" or "max".
        KeyError: if ``oracle`` is not a key of ``v``.
    """
    if mode not in ("mean", "max"):
        raise ValueError(f"mode must be 'mean' or 'max', got {mode!r}")
    # `in` does not create entries, unlike indexing a defaultdict.
    if oracle not in v:
        raise KeyError(f"oracle {oracle!r} is not one of the validators in v")

    # Exclude the target-split accuracy/MSE/MAE columns (oracle variants, as in the
    # correlation plots) and the oracle itself.
    excluded = {
        f"target_{split}_{metric}_score"
        for split in ("train", "val", "test")
        for metric in ("acc", "mse", "mae")
    }
    excluded.add(oracle)

    oracle_table = v[oracle]
    records = []
    for validator, table in list(v.items()):
        if validator in excluded:
            continue
        abs_gap = (oracle_table - table).abs()
        gap = abs_gap.mean().mean() if mode == "mean" else abs_gap.max().max()
        records.append({"gap": gap, "validator": validator})
    if not records:
        print("No validators to compare against the oracle.")
        return

    gaps = pd.DataFrame(records).sort_values("gap", kind="stable")
    # Plot the display names directly so tick labels cannot drift out of sync with bars;
    # fall back to the raw name for validators without a display name.
    gaps["label"] = [validator_names.get(name, name) for name in gaps["validator"]]

    sns.set(style="ticks", context="paper")
    plt.figure(figsize=(7, 14))
    ax = sns.barplot(data=gaps, x="gap", y="label", color="cornflowerblue")
    for label in ax.get_yticklabels():
        if "-Score" in label.get_text():
            label.set_fontweight("bold")
    ax.set_ylabel("")
    ax.set_xlabel(f"{mode.capitalize()} oracle-score gap between best models, \n as selected by validator and oracle")
    sns.despine()
    ax.grid(alpha=0.3, axis="x")
    plt.tight_layout()
    plt.savefig(os.path.join(figures_dir, f"performance_gap_{mode}.pdf"))
    plt.savefig(os.path.join(figures_dir, f"performance_gap_{mode}.png"))

In [None]:
# Mean absolute gap between the oracle score of validator-selected and oracle-selected checkpoints.
compute_and_plot_gaps(model_selection_tables, "mean")

In [None]:
# Worst-case (maximum) absolute gap over all tasks and algorithms.
compute_and_plot_gaps(model_selection_tables, "max")

In [None]:
# Per-task correlation between every score column and the oracle. One bar chart per task
# (optionally also histograms/scatter plots); the tables are kept in all_corrs_tables.
all_corrs_tables = []
rank, corr_fn = correlations[correlation]
algorithms_str = ", ".join([alg.upper() for alg in algorithms])
# Target-split error rows are oracle variants, not validators; drop them when present.
oracle_like_rows = [
    "target_train_mse_score", "target_val_mse_score", "target_test_mse_score",
    "target_train_mae_score", "target_val_mae_score", "target_test_mae_score",
]
for dataset in tables.keys():
    oracle = datasets[dataset]["oracle"]
    for source in tables[dataset].keys():
        for target in tables[dataset][source].keys():
            table = tables[dataset][source][target]
            if table.empty:
                continue

            corrs_table = table.rank().corr(corr_fn) if rank else table.corr(corr_fn)
            corrs_table = corrs_table.sort_values(oracle, ascending=False)
            # errors="ignore": a task that did not log e.g. MAE must not abort the loop.
            corrs_table = corrs_table.drop(oracle_like_rows, errors="ignore")
            all_corrs_tables.append(corrs_table)

            # plot correlation
            sns.set(style="ticks", context="paper")
            plt.figure(figsize=(8, 14))
            ax = sns.barplot(
                x=corrs_table[oracle],
                y=[validator_names.get(r, r) for r in corrs_table.index],
                color='cornflowerblue',
            )
            ax.set(
                title=f"{dataset.capitalize()} ({source[0].capitalize()}{target[0].capitalize()}) - [{algorithms_str}] ({len(table)} checkpoints)",
                xlabel=f"{correlation.capitalize()} correlation with oracle",
                xlim=(-1, 1)
            )
            ax.bar_label(ax.containers[0], fmt="%.3f", label_type="edge", fontsize=8)
            for label in ax.get_yticklabels():
                if "-Score" in label.get_text():
                    label.set_fontweight('bold')
            sns.despine()
            ax.grid(alpha=0.3, axis="x")
            plt.tight_layout()
            file_stem = f"{dataset}_{source}_{target}_{algorithms_str.lower().replace(', ', '+')}_{correlation.lower().replace(' ', '_')}"
            plt.savefig(os.path.join(figures_root, f"{file_stem}.pdf"))
            plt.savefig(os.path.join(figures_root, f"{file_stem}.png"))

            if plot_scatter:
                # Distribution of oracle scores. Bins are left to seaborn: a fixed
                # binwidth=0.01 suits scores in [0, 1] but produces an enormous number of
                # bins for MSE-scale oracle values (tens to thousands).
                sns.set(style="ticks", context="talk")
                plt.figure(figsize=(4, 3))
                sns.histplot(y=oracle, data=table, color="cornflowerblue")
                plt.ylabel("Oracle performance")
                sns.despine()
                plt.grid(alpha=0.3)

                # plot scores vs oracle scatter plots
                sns.set(style="ticks", context="talk")
                for validator in corrs_table.index:
                    fig = plt.figure(figsize=(4, 3))
                    sns.scatterplot(x=validator, y=oracle, data=table, s=5, color="cornflowerblue")
                    plt.ylabel("Oracle performance")
                    plt.xlabel(validator_names.get(validator, validator))
                    # Figure-relative coordinates (this is a Figure, not an Axes).
                    fig.text(0.15, 0.85, r"$\rho = {corr:.3f}$".format(corr=corrs_table[oracle][validator]), fontsize=10)
                    sns.despine()
                    plt.grid(alpha=0.3)

In [None]:
# make joint table: element-wise mean and population standard deviation of the per-task
# correlation tables (NaNs ignored), on the row/column layout of the first table.
reference_table = all_corrs_tables[0]
mean_corrs_table = pd.DataFrame(np.nan, index=reference_table.index, columns=reference_table.columns)
std_corrs_table = pd.DataFrame(np.nan, index=reference_table.index, columns=reference_table.columns)
for col in mean_corrs_table.columns:
    for row in mean_corrs_table.index:
        # This cell's value from every non-empty task table that has it.
        values = pd.Series(
            [t.at[row, col] for t in all_corrs_tables
             if not t.empty and col in t.columns and row in t.index],
            dtype=float,
        )
        # Series.mean/std skip NaN like np.nanmean/np.nanstd (ddof=0 matches np.nanstd).
        # .loc writes replace chained assignment (table[col][row] = ...), which does not
        # update the frame under pandas Copy-on-Write; tables are float64, not object.
        mean_corrs_table.loc[row, col] = values.mean()
        std_corrs_table.loc[row, col] = values.std(ddof=0)

In [None]:
# Spot-check: correlation between H-Score-v3 (target val) and the oracle, averaged over tasks.
mean_corrs_table["target_test_mse_score"]["target_val_henry_v3_score"]

In [None]:
# Long-format table of every (correlation with oracle, validator) pair across all tasks,
# sorted by correlation; used by the pooled bar chart below.
# Pooling across tasks only makes sense when every dataset shares one oracle column
# (the previous lookup of datasets["office31"] raised KeyError: there is no such dataset).
pooled_oracles = {config["oracle"] for config in datasets.values()}
assert len(pooled_oracles) == 1, f"cannot pool correlations across different oracles: {pooled_oracles}"
(pooled_oracle,) = pooled_oracles
a = pd.DataFrame(
    [(score, name) for task_corrs in all_corrs_tables for name, score in task_corrs[pooled_oracle].items()],
    columns=["score", "validator"],
)
a = a.sort_values("score")

In [None]:
# plot correlation: one bar per validator pooled over all tasks (standard-error bars),
# ordered by the mean correlation with the oracle.
# All datasets must share one oracle column (there is no "office31" entry in datasets).
pooled_oracles = {config["oracle"] for config in datasets.values()}
assert len(pooled_oracles) == 1, f"cannot pool correlations across different oracles: {pooled_oracles}"
(pooled_oracle,) = pooled_oracles
validator_order = mean_corrs_table.sort_values(pooled_oracle, ascending=False).index

sns.set(style="ticks", context="paper")
plt.figure(figsize=(8, 14))
ax = sns.barplot(
    data=a, x="score", y="validator",
    order=validator_order,
    color="cornflowerblue",
    errorbar="se", errwidth=1,
)
# Tick labels follow the explicit `order`; fall back to raw names without a display name.
ax.set_yticklabels([validator_names.get(r, r) for r in validator_order])
algorithms_str = ", ".join([alg.upper() for alg in algorithms])
ax.set(
    xlabel=f"{correlation.capitalize()} correlation with oracle",
    xlim=(-1, 1),
    ylabel="",
)
for label in ax.get_yticklabels():
    if "-Score" in label.get_text():
        label.set_fontweight('bold')
sns.despine()
ax.grid(alpha=0.3, axis="x")
plt.tight_layout()
file_stem = f"{algorithms_str.lower().replace(', ', '+')}_{correlation.lower().replace(' ', '_')}"
plt.savefig(os.path.join(figures_root, f"{file_stem}.pdf"))
plt.savefig(os.path.join(figures_root, f"{file_stem}.png"))

In [None]:
# Validators ordered by mean correlation with the oracle, highest first.
mean_corrs_table.sort_values(datasets["mnistmr"]["oracle"], ascending=False).index.tolist()

In [None]:
# Mean correlation of every validator with the oracle, sorted from highest to lowest.
mnistmr_oracle = datasets["mnistmr"]["oracle"]
corrs_val_train = mean_corrs_table.sort_values(mnistmr_oracle, ascending=False)[mnistmr_oracle]

In [None]:
# Mean correlation of each validator with the oracle, highest first.
corrs_val_train

In [None]:
# Validators compared through their validation-split and train-split variants in the
# val-vs-train plot; names that are not in corrs_val_train are simply filtered out.
# NOTE(review): the gouk / improved_gouk / im / entropy / acc validators are not listed in
# validator_names for this MSE setup, so they are presumably absent here. Confirm this, and
# check whether the source MSE validator should take the place of "src_*_acc_score".
val_validators_to_check = [
    "target_val_henry_v3_score", "src_val_gouk_v3_score", "src_val_improved_gouk_score",
    "target_val_henry_score", "src_val_gouk_score",
    "src_val_target_val_bnm_score", "src_val_target_val_class_ami_score",
    "src_val_target_val_mmd_score", "target_val_snd_score",
    "src_val_target_val_im_score", "src_val_target_val_entropy_score",
    "src_val_acc_score"
]
train_validators_to_check = [
    "target_train_henry_v3_score", "src_train_gouk_v3_score", "src_train_improved_gouk_score",
    "target_train_henry_score", "src_train_gouk_score",
    "src_val_target_train_bnm_score", "src_val_target_train_class_ami_score",
    "src_val_target_train_mmd_score", "target_train_snd_score",
    "src_val_target_train_im_score", "src_val_target_train_entropy_score",
    "src_train_acc_score"
]

In [None]:
# Keep only the validators named in the val/train lists (order preserved).
validators_to_keep = set(val_validators_to_check) | set(train_validators_to_check)
corrs_val_train = corrs_val_train.loc[[name in validators_to_keep for name in corrs_val_train.index]]

In [None]:
def _short_validator_name(full_name):
    """Strip the split prefix(es) so train/val variants share one validator name."""
    if "target" in full_name:
        return full_name.replace("target_train_", "", 1).replace("target_val_", "")
    return full_name.replace("src_train_", "", 1).replace("src_val_", "")


def _split_of(full_name):
    """Return "train"/"val" for the first matching split marker (target before src), else None."""
    markers = (
        ("target_train_", "train"),
        ("target_val_", "val"),
        ("src_train_", "train"),
        ("src_val_", "val"),
    )
    for marker, split in markers:
        if marker in full_name:
            return split
    return None


# One row per validator: full name, shared short name, correlation, and split label.
a = pd.DataFrame({
    "name": corrs_val_train.index,
    "validator": [_short_validator_name(n) for n in corrs_val_train.index],
    "score": corrs_val_train.values,
    "split": [_split_of(n) for n in corrs_val_train.index],
})

In [None]:
# Inspect the validator / split long-format table built in the previous cell.
a

In [None]:
# Short display names keyed by the prefix-stripped validator names in a["validator"];
# used as the y tick labels of the val-vs-train plot.
simple_validator_names = {
    "henry_v3_score": "H-Score-v3",
    "gouk_v3_score": "G-Score-v3",
    "henry_score": "H-Score-Simple",
    "gouk_score": "G-Score-Simple",
    "src_val_class_ami_score": "ClassAMI",
    "src_val_im_score": "IM",
    "src_val_bnm_score": "BNM",
    "src_val_mmd_score": "MMD",
    "src_val_entropy_score": "Entropy",
    "snd_score": "SND",
    "acc_score": "Source Accuracy",
}

In [None]:
# Unique short validator names in order of first appearance; used to order and label
# the categories of the plot below.
b = list(dict.fromkeys(a["validator"]))

In [None]:
# plot correlation: for each validator, one bar per split (train / val) showing its mean
# correlation with the oracle, each annotated with its value.
sns.set(style="ticks", context="paper")
# catplot is figure-level and creates its own figure, so size it via height/aspect. The
# earlier plt.figure(figsize=(8, 6)) only left an extra empty figure behind.
g = sns.catplot(
    data=a, x="score", y="validator",
    kind="bar",
    hue="split",
    order=b,
    height=6, aspect=8 / 6,
)
ax = g.axes[0, 0]
# Fall back to the raw short name for validators without a display name.
g.set_yticklabels([simple_validator_names.get(r, r) for r in b])
algorithms_str = ", ".join([alg.upper() for alg in algorithms])
ax.set(
    xlabel=f"{correlation.capitalize()} correlation with oracle",
    xlim=(-1, 1),
    ylabel="",
)
# Annotate the bars of every container (one per hue level).
for container in ax.containers:
    ax.bar_label(container, fmt="%.3f", label_type="edge", fontsize=8)
for label in ax.get_yticklabels():
    if "-Score" in label.get_text():
        label.set_fontweight('bold')
sns.despine()
ax.grid(alpha=0.3, axis="x")
plt.tight_layout()
file_stem = f"{algorithms_str.lower().replace(', ', '+')}_{correlation.lower().replace(' ', '_')}_val_train"
plt.savefig(os.path.join(figures_root, f"{file_stem}.pdf"))
plt.savefig(os.path.join(figures_root, f"{file_stem}.png"))