In [2]:
%load_ext autoreload
%autoreload 2

In [6]:
from model_selection import top_model_confusion
from utils import serialize, deserialize, serialize_model, deserialize_model
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import utils
import glob
from evaluation import get_label, check_eval_metric

encoder = 'COLLAPSE'
metal = 'ZN'

results_cache_dir = f"../data/{encoder}_{metal}_gridsearch_results_2/{encoder}-eval_results"
model_cache_dir = f"../data/{encoder}_{metal}_gridsearch_results_2/{encoder}-fitted_k2_models"
linearized_cache_dir = f"../data/{encoder}_{metal}_gridsearch_results_2/{encoder}-linearized_data"


## Ranking gridsearch models

In [5]:
key_conf_metrics = ["precision", "correlation", "dice"]
key_cont_metrics = ["ap"]

### Baselines

In [9]:

valid_metrics = ["msd", "specificity", "precision", "fnr", "fdr", "recall", "accuracy", "balanced_acc", "correlation", "threat_score", "prevalence", "dice", "jaccard"]

encoder = 'COLLAPSE'
baseline = 'Attention'
eval_class = 1
metric_str = 'precision'
metric = check_eval_metric(metric_str, valid_metrics)

model_results_dict = utils.deserialize(f'../data/baselines/{encoder}_{baseline}_test_results.pkl')


In [10]:
metric_dict = {}
all_scores = {}
stabilities = []
for path in glob.glob(f'../data/baselines/*_test_results.pkl'):
    model_results_dict = utils.deserialize(path)
    model_str = path.split('/')[-1].split('.')[0]
    N=0
    for graph_id in model_results_dict.keys():
        datum_results_dict = model_results_dict[graph_id]
        if eval_class in [0,1]:
            y = get_label(datum_results_dict)
            if y != eval_class:
                continue
        if metric_str == "msd":
            datum_cms = datum_results_dict["thresh_msd"]
        else:
            datum_cms = datum_results_dict["thresh_cm"]
        
        model_cms = {} 
        for thresh in datum_cms.keys():
            if type(thresh) == tuple:
                new_thresh = thresh[0]
                model_cms[new_thresh] = 0.0
            else:
                model_cms[thresh] = 0.0

        for thresh in datum_cms.keys():
            if type(thresh) == tuple:
                new_thresh = thresh[0]
                model_cms[new_thresh] += metric(datum_cms[thresh])
            else:
                model_cms[thresh] += metric(datum_cms[thresh])
        N += 1
            
    # now average over all graphs
    for thresh in model_cms.keys():
        model_cms[thresh] /= N
    # get top score and threshold
    scores = [(thresh, model_cms[thresh]) for thresh in model_cms.keys()]
    
    max_score = max(scores, key=lambda item: item[1])
    min_score = min(scores, key=lambda item: item[1])
    stability = max_score[1] - min_score[1]
    stabilities.append(stability)
    metric_dict[model_str] = (max_score[0], max_score[1], stability)
    all_scores[model_str] = scores
    
data = []
for model_str, scores in all_scores.items():
    for thresh, score in scores:
        data.append([model_str, thresh, score])
baseline_df = pd.DataFrame(data, columns=['model_name', 'threshold', 'score'])

In [11]:
baseline_df.sort_values('score', ascending=False).groupby('model_name').head(1)

Unnamed: 0,model_name,threshold,score
145,COLLAPSE_Attention_test_results,0.3,0.008197
82,COLLAPSE_GNNExplainer_test_results,0.7,0.002522
221,ESM_test_results,0.7,0.000431
167,COLLAPSE_test_results,0.7,0.000311
50,AA_GNNExplainer_test_results,0.7,0.00026
105,AA_Attention_test_results,0.4,0.000222
124,ESM_Attention_test_results,0.3,0.000187
5,COLLAPSE_Mask_test_results,0.5,0.000126
11,ESM_Mask_test_results,0.0,5.9e-05
194,ESM_GNNExplainer_test_results,0.2,5.9e-05


### Prospectors

In [12]:
import pandas as pd

encoder = 'COLLAPSE'

results_cache_dir = f"../data/{encoder}_{metal}_gridsearch_results/{encoder}-eval_results"
model_cache_dir = f"../data/{encoder}_{metal}_gridsearch_results/{encoder}-fitted_k2_models"
processor_cache_dir = f"../data/{encoder}_{metal}_gridsearch_results/{encoder}-fitted_k2_processors"
linearized_cache_dir = f"../data/{encoder}_{metal}_gridsearch_results/{encoder}-linearized_data"

In [13]:
from model_selection import top_model_confusion, top_model_continuous_avg, top_model_continuous_iid
conf_res = []
for metric in key_conf_metrics:
    print(metric)
    res = top_model_confusion(metric,results_cache_dir, model_cache_dir, eval_class=1, return_all=True)
    res["metric"] = [metric]*len(res)
    conf_res.append(res)
conf_res = pd.concat(conf_res)

cont_res = []
for metric in key_cont_metrics:
    print(metric)
    res = top_model_continuous_avg(metric, results_cache_dir, model_cache_dir, return_all=True)
    res["metric"] = [metric]*len(res)
    cont_res.append(res)
cont_res = pd.concat(cont_res)

# iid_res = []
# for metric in key_cont_metrics:
#     print(metric)
#     res = top_model_continuous_iid(metric, model_cache_dir, linearized_cache_dir, return_all=True)
#     res["metric"] = [metric]*len(res)
#     iid_res.append(res)
# iid_res = pd.concat(iid_res)

precision
correlation
dice
ap


In [14]:
numzero_dict = dict(zip(cont_res.model_name, cont_res.num_zeros))

In [16]:
conf_res['num_zeros'] = conf_res['model_name'].map(numzero_dict)
conf_pvt = conf_res.pivot(index=['model_name', 'threshold', 'num_zeros'], columns='metric', values='score')
for met in key_conf_metrics:
    rank = conf_pvt[met].rank(method='dense', ascending=False)
    conf_pvt[f'rank_{met}'] = rank
# conf_pvt['rank'] = conf_pvt[key_conf_metrics].apply(tuple,axis=1).rank(method='dense',ascending=False)

cont_pvt = cont_res.pivot(index=['model_name', 'num_zeros'], columns='metric', values='score')
cont_pvt['rank_auprc'] = cont_pvt['ap'].rank(method='dense', ascending=False)
# cont_pvt['rank_auroc'] = cont_pvt['auroc'].rank(method='dense', ascending=False)

merged = pd.merge(conf_pvt.reset_index(level=['threshold', 'num_zeros']), cont_pvt, on='model_name', how='left')
merged['rank_zeros'] = merged['num_zeros'].rank(method='dense', ascending=True)
merged['avg_rank'] = merged[[f'rank_{i}' for i in key_conf_metrics] + ['rank_auprc']].mean(axis=1)
merged.sort_values('avg_rank').head(10)

metric,threshold,num_zeros,correlation,dice,precision,rank_precision,rank_correlation,rank_dice,ap,rank_auprc,rank_zeros,avg_rank
model_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
k30_r0_cutoff6.00_alpha0.0010_tau0.00_lamnan.model,0.9,134,0.266111,0.264567,0.288106,24.0,1.0,2.0,0.263871,6.0,17.0,8.25
k25_r0_cutoff8.00_alpha0.0100_tau0.00_lamnan.model,0.9,114,0.26129,0.265489,0.288582,23.0,4.0,1.0,0.263718,8.0,12.0,9.0
k30_r0_cutoff8.00_alpha0.0010_tau0.00_lamnan.model,0.9,134,0.265969,0.264551,0.288098,25.0,2.0,3.0,0.263862,7.0,17.0,9.25
k30_r0_cutoff8.00_alpha0.0100_tau0.00_lamnan.model,0.9,134,0.265969,0.264551,0.288098,25.0,2.0,3.0,0.263862,7.0,17.0,9.25
k30_r0_cutoff8.00_alpha1.0000_tau0.00_lamnan.model,0.9,2,0.253023,0.264004,0.289574,22.0,10.0,4.0,0.266219,3.0,3.0,9.75
k30_r0_cutoff8.00_alpha1.0000_tau0.00_lamnan.model,0.95,2,0.253023,0.264004,0.289574,22.0,10.0,4.0,0.266219,3.0,3.0,9.75
k30_r0_cutoff6.00_alpha0.0100_tau0.00_lamnan.model,0.9,10,0.249237,0.255671,0.283701,38.0,14.0,11.0,0.264128,5.0,9.0,17.0
k30_r0_cutoff8.00_alpha1.0000_tau1.00_lamnan.model,0.9,335,0.258353,0.254677,0.283077,41.0,5.0,14.0,0.263647,9.0,22.0,17.25
k30_r0_cutoff8.00_alpha0.0010_tau4.00_lamnan.model,0.9,335,0.258353,0.254677,0.283077,41.0,5.0,14.0,0.263647,9.0,22.0,17.25
k30_r0_cutoff8.00_alpha0.0001_tau1.00_lamnan.model,0.9,335,0.258353,0.254677,0.283077,41.0,5.0,14.0,0.263647,9.0,22.0,17.25


In [27]:
merged.sort_values(['num_zeros', 'rank_auprc', 'rank_precision']).head(10)

metric,threshold,num_zeros,correlation,dice,precision,rank_precision,rank_correlation,rank_dice,ap,rank_auprc,rank_zeros,avg_rank
model_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
k25_r4_cutoff8.00_alpha1.0000_tau0.00_lamnan.model,0.95,0,0.150881,0.138778,0.277277,66.0,1421.0,1483.0,0.247638,23.0,1.0,598.8
k25_r4_cutoff8.00_alpha1.0000_tau0.00_lamnan.model,0.9,0,0.171571,0.169508,0.26523,138.0,1000.0,965.0,0.247638,23.0,1.0,425.4
k25_r4_cutoff8.00_alpha1.0000_tau0.00_lamnan.model,0.85,0,0.183719,0.191359,0.249217,251.0,585.0,341.0,0.247638,23.0,1.0,240.2
k25_r4_cutoff8.00_alpha1.0000_tau0.00_lamnan.model,0.8,0,0.19768,0.209151,0.230465,379.0,128.0,61.0,0.247638,23.0,1.0,118.4
k25_r4_cutoff8.00_alpha1.0000_tau0.00_lamnan.model,0.75,0,0.191222,0.206418,0.198232,576.0,287.0,84.0,0.247638,23.0,1.0,194.2
k25_r4_cutoff8.00_alpha1.0000_tau0.00_lamnan.model,0.7,0,0.187806,0.200922,0.17672,730.0,430.0,139.0,0.247638,23.0,1.0,264.6
k25_r4_cutoff8.00_alpha1.0000_tau0.00_lamnan.model,0.65,0,0.182604,0.193293,0.161235,881.0,620.0,297.0,0.247638,23.0,1.0,364.4
k25_r4_cutoff8.00_alpha1.0000_tau0.00_lamnan.model,0.6,0,0.179496,0.187859,0.151348,1049.0,717.0,437.0,0.247638,23.0,1.0,445.4
k25_r4_cutoff8.00_alpha1.0000_tau0.00_lamnan.model,0.55,0,0.18014,0.185765,0.144932,1140.0,698.0,490.0,0.247638,23.0,1.0,470.4
k25_r4_cutoff8.00_alpha1.0000_tau0.00_lamnan.model,0.5,0,0.180339,0.183493,0.141717,1192.0,690.0,563.0,0.247638,23.0,1.0,493.8


In [22]:
merged[(~merged.index.str.contains('r0'))].sort_values(['avg_rank']).groupby('model_name').head(1).head(20)

metric,threshold,num_zeros,correlation,dice,precision,rank_precision,rank_correlation,rank_dice,ap,rank_auprc,rank_zeros,avg_rank
model_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
k25_r4_cutoff8.00_alpha1.0000_tau4.00_lamnan.model,0.75,335,0.226671,0.219496,0.281939,44.0,33.0,40.0,0.274555,1.0,22.0,29.5
k25_r4_cutoff8.00_alpha0.0001_tau4.00_lamnan.model,0.8,335,0.21893,0.209035,0.28981,21.0,49.0,62.0,0.274555,1.0,22.0,33.25
k25_r4_cutoff8.00_alpha0.0100_tau4.00_lamnan.model,0.8,335,0.21893,0.209035,0.28981,21.0,49.0,62.0,0.274555,1.0,22.0,33.25
k25_r4_cutoff8.00_alpha0.0010_tau4.00_lamnan.model,0.8,335,0.21893,0.209035,0.28981,21.0,49.0,62.0,0.274555,1.0,22.0,33.25
k30_r4_cutoff8.00_alpha1.0000_tau4.00_lamnan.model,0.75,335,0.225916,0.219324,0.276147,74.0,34.0,42.0,0.271477,2.0,22.0,38.0
k15_r4_cutoff8.00_alpha0.0010_tau4.00_lamnan.model,0.8,335,0.216336,0.208486,0.277856,63.0,54.0,68.0,0.258941,15.0,22.0,50.0
k15_r4_cutoff8.00_alpha0.0100_tau4.00_lamnan.model,0.8,335,0.216336,0.208486,0.277856,63.0,54.0,68.0,0.258941,15.0,22.0,50.0
k15_r4_cutoff8.00_alpha0.0001_tau4.00_lamnan.model,0.8,335,0.216336,0.208486,0.277856,63.0,54.0,68.0,0.258941,15.0,22.0,50.0
k15_r4_cutoff8.00_alpha1.0000_tau4.00_lamnan.model,0.8,335,0.216336,0.208486,0.277856,63.0,54.0,68.0,0.258941,15.0,22.0,50.0
k25_r2_cutoff8.00_alpha1.0000_tau4.00_lamnan.model,0.7,335,0.21684,0.210917,0.271562,104.0,53.0,54.0,0.262593,12.0,22.0,55.75


In [None]:
merged.reset_index().sort_values('ap', ascending=False).groupby('model_name').head(1)

In [None]:
sns.relplot(data=merged.sort_values('ap').groupby('model_name').head(1), x='correlation', y='ap', hue='num_zeros')

In [None]:
from model_selection import extract_params
copy = merged.copy().reset_index()
copy[['k', 'r', 'cutoff', 'alpha', 'tau', 'lambda']] = copy.apply(lambda x: extract_params(x.model_name), axis='columns', result_type='expand')
copy

In [None]:
copy[['num_zeros', 'k', 'r', 'cutoff', 'alpha', 'tau', 'lambda']].corr()

In [None]:
sns.relplot(data=copy.sort_values('ap').groupby('model_name').head(1), x='correlation', y='ap', hue='r', size='num_zeros')

In [None]:
from evaluation import gridsearch_iteration, fetch_model, fetch_processor
from model_selection import extract_params

k, r, cutoff, alpha, tau, lamb = extract_params('k15_r4_cutoff8.00_alpha0.0010_tau4.00_lamnan.model')

process_args = {"datatype": "protein",
            "k": None,
            "dataset": 'ZN',
            "quantizer_type": "AA" if encoder == "AA" else "kmeans",
            "embeddings_path": f"../data/{encoder}_ZN_{cutoff}_train_embeddings_2.pkl",
            "embeddings_type": "dict",
            "mapping_path": None,
            "sample_size": None,
            "sample_scheme": None,
            "dataset_path": None,
            "verbosity": "low", # change this to low!
            "so_dict_path": None}

proc, processor_name = fetch_processor(k, processor_cache_dir, process_args, cutoff=cutoff)
hparams = {"alpha": alpha, "tau": tau, "lambda": lamb}
model_args = {"modality":"graph",
        "processor": proc,
        "r": r,
        "variant": "inferential",
        "hparams": hparams,
        "train_graph_path": f"../data/{encoder}_ZN_{cutoff}_train_graphs_2",
        "train_label_dict": None}
model, model_str = fetch_model(proc, r, model_cache_dir, model_args, cutoff=cutoff, alpha=alpha, tau=tau)

In [None]:
model_results_dict, data_linearized_dict = gridsearch_iteration(model, model_args, gt_dir=None, thresh="all", arm="train")

In [None]:
metric_str = 'precision'
eval_class = 1
metric = check_eval_metric(metric_str, valid_metrics)
N=0
for graph_id in model_results_dict.keys():
    datum_results_dict = model_results_dict[graph_id]
    print
    if eval_class in [0,1]:
        y = get_label(datum_results_dict)
        if y != eval_class:
            continue
    if metric_str == "msd":
        datum_cms = datum_results_dict["thresh_msd"]
    else:
        datum_cms = datum_results_dict["thresh_cm"]
    
    model_cms = {} 
    for thresh in datum_cms.keys():
        if type(thresh) == tuple:
            new_thresh = thresh[0]
            model_cms[new_thresh] = 0.0
        else:
            model_cms[thresh] = 0.0

    for thresh in datum_cms.keys():
        if type(thresh) == tuple:
            new_thresh = thresh[0]
            model_cms[new_thresh] += metric(datum_cms[thresh])
        else:
            model_cms[thresh] += metric(datum_cms[thresh])
    N += 1
        
# now average over all graphs
for thresh in model_cms.keys():
    model_cms[thresh] /= N
# get top score and threshold
scores = [(thresh, model_cms[thresh]) for thresh in model_cms.keys()]

In [None]:
res_precision = top_model_confusion("precision",results_cache_dir, model_cache_dir, eval_class=1, return_all=True)

In [None]:
from model_selection import top_model_continuous_avg
res_auc = top_model_continuous_avg('auprc', results_cache_dir, model_cache_dir, return_all=True)

In [None]:

res_precision[['k', 'r', 'cutoff', 'alpha', 'tau', 'lambda']] = res_precision.apply(lambda x: extract_params(x.model_name), axis='columns', result_type='expand')

In [None]:
res_auc[['k', 'r', 'cutoff', 'alpha', 'tau', 'lambda']] = res_auc.apply(lambda x: extract_params(x.model_name), axis='columns', result_type='expand')

In [None]:
res = pd.merge(res_precision, res_auc, on=['model_name', 'k', 'r', 'cutoff', 'alpha', 'tau', 'lambda'], how='inner', suffixes=('_precision', '_auprc'))

In [None]:
res_inf = res.dropna(subset=['alpha', 'tau'])

In [None]:
res_inf['alpha'] = res_inf['alpha'].astype('str')

In [None]:
res_inf = res_inf.sort_values('score_precision', ascending=False).groupby('model_name').nth(0)

In [None]:
sns.relplot(data=res_inf, x='score_precision', y='score_auprc', hue='tau', size='cutoff', markers='alpha', palette='viridis')
plt.show()

In [None]:
sns.relplot(data=res_inf, x='score_precision', y='score_auprc', hue='k', size='r', markers='alpha', palette='viridis')
plt.show()

In [None]:
sns.relplot(data=res_inf, x='score_precision', y='score_auprc', hue='k', size='alpha', markers='alpha', palette='viridis')
plt.show()

In [None]:
res_lin = res.dropna(subset=['lambda'])
res_lin = res_lin.sort_values('score_precision', ascending=False).groupby('model_name').nth(0)

In [None]:
sns.relplot(data=res_lin, x='score_precision', y='score_auprc', hue='k', size='r', markers='alpha', palette='viridis')
plt.show()

In [None]:
res_lin

## Plotting on test set

In [None]:
# hard coded for now from looking at top of ranked dataframe
encoder_top_models = \
    {'COLLAPSE': ('k15_r1_cutoff8.00_alpha0.500_tau4.00_lamnan.model', 0.9), \
     # {'COLLAPSE': ('k20_r0_cutoff4.00_alpha0.001_tau0.00_lamnan.model', 0.9), \
    'ESM': ('k30_r1_cutoff4.00_alpha0.500_tau1.00_lamnan.model', 0.0), \
    'AA': ('k21_r2_cutoff6.00_alphanan_taunan_lam1.00.model', 0.5)}

baseline_top_models = \
    {'COLLAPSE': ('COLLAPSE-ZN-8.0-0.0005', 0.7), \
    'ESM': ('ESM-ZN-8.0-0.0005', 0.4), \
    'AA': ('AA-ZN-8.0-0.0005', 0.6)}

In [None]:
key_conf_metrics = ["precision", "correlation", "dice"]
key_cont_metrics = ["auprc"]

In [None]:
from evaluation import test_eval, extract_params
import matplotlib.pyplot as plt
import seaborn as sns
import utils

In [None]:
import warnings
warnings.filterwarnings("ignore", "is_categorical_dtype")
warnings.filterwarnings("ignore", "use_inf_as_na")

In [None]:
def setup_figure(width=6, height=3):
    sns.set(style='white')
    sns.set_context('paper')
    plt.figure(figsize=(width,height))
pal = sns.color_palette('tab20')

In [None]:
metal = 'ZN'

In [None]:
test_df = []
test_metrics = key_conf_metrics + key_cont_metrics
for encoder, (model_str, threshold) in encoder_top_models.items():
    results_cache_dir = f"../data/{encoder}_{metal}_gridsearch_results_2/{encoder}-eval_results"
    model_cache_dir = f"../data/{encoder}_{metal}_gridsearch_results_2/{encoder}-fitted_k2_models"
    processor_cache_dir = f"../data/{encoder}_{metal}_gridsearch_results_2/{encoder}-fitted_k2_processors"
    linearized_cache_dir = f"../data/{encoder}_{metal}_gridsearch_results_2/{encoder}-linearized_data"

    _,_,cutoff,_,_,_ = extract_params(model_str)

    if encoder == 'AA':
        g_encoder = 'COLLAPSE'
    else:
        g_encoder = encoder

    G_dir = f"../data/{g_encoder}_{metal}_{cutoff}_test_graphs_2"
    
    df = test_eval(model_str, threshold, test_metrics, model_cache_dir, processor_cache_dir, G_dir, gt_dir=None, label_dict=None, modality="graph")
    test_df.append(df)
test_df = pd.concat(test_df)

In [None]:
# baselines
from evaluation import get_test_metrics

base_df = []
for encoder in ['COLLAPSE', 'ESM', 'AA']:
    best_model, best_thresh = baseline_top_models[encoder]
    results_dict = deserialize(f'../data/baselines/{encoder}_test_results.pkl')
    df = get_test_metrics(results_dict, encoder, best_model, best_thresh, test_metrics)
    base_df.append(df)
base_df = pd.concat(base_df)

In [None]:
combined_df = pd.concat([test_df, base_df])
combined_df['method'] = ['Prospector']*len(test_df) + ['GAT+GNNExplainer']*len(base_df)

In [None]:
test_df.to_csv('../data/results/K2_test_results.csv')

In [None]:
for met in test_metrics:
    subdf = combined_df[combined_df.metric == met].reset_index()

    plt.clf()
    if met == 'auprc':
        setup_figure(6,3)
        ax = sns.barplot(data=subdf[subdf.regime == 'all'], x='value', y='encoder', hue='method', orient='horizontal', errorbar='se', capsize=0.05, errwidth=1.0, linewidth=1, edgecolor="k")
        sns.stripplot(data=subdf[subdf.regime == 'all'], x='value', y='encoder',  hue='method', orient='horizontal', dodge=True, alpha=0.1, linewidth=0.5, ax=ax, legend=False)
        plt.title(met, fontsize=12)
        plt.legend(loc=(0.65,1.01))
    else:
        fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(6, 3), sharey=True, gridspec_kw={'wspace': 0})
        sns.barplot(data=subdf[subdf['regime'] == 'class-1'], x='value', y='encoder', hue='method', orient='horizontal', dodge=True, ax=ax2, errorbar='se', capsize=0.05, errwidth=1.0, linewidth=1, edgecolor="k")
        sns.stripplot(data=subdf[subdf['regime'] == 'class-1'], x='value', y='encoder', hue='method', orient='horizontal', dodge=True, alpha=0.1, linewidth=0.5, ax=ax2, legend=False)
        # ax1.yaxis.set_label_position('left')

        ax2.set_title('  '+'class-1', loc='left')
        ax2.set_ylabel('')
        ax2.set_yticklabels([])
        ax2.legend_.remove()
    
        sns.barplot(data=subdf[subdf['regime'] == 'all'], x='value', y='encoder', hue='method', orient='horizontal', dodge=True, ax=ax1, errorbar='se', capsize=0.05, errwidth=1.0, linewidth=1, edgecolor="k")
        sns.stripplot(data=subdf[subdf['regime'] == 'all'], x='value', y='encoder', hue='method', orient='horizontal', dodge=True, alpha=0.1, linewidth=0.5, ax=ax1, legend=False)
        ax1.legend_.remove()
    
        # optionally use the same scale left and right
        xmax = max(ax1.get_xlim()[1], ax2.get_xlim()[1])
        ax1.set_xlim(xmax=xmax)
        ax2.set_xlim(xmax=xmax)

        ax1.invert_xaxis()  # reverse the direction
        ax1.tick_params(axis='y', labelleft=True, left=True, labelright=False, right=False)
        ax1.set_ylabel('')
        ax1.set_title('all data'+'  ', loc='right')

        plt.legend(loc=(-1.01,1.02))

    fig.suptitle(met, fontsize=12)
    
    plt.tight_layout()
    # plt.savefig(f'../data/figures/{met}-k2-vs-baseline.png', dpi=300, format='png')
    plt.show()