In [None]:
import numpy as np
import pandas as pd
from pathlib import Path

import matplotlib.pyplot as plt
import seaborn as sns
import os
import sys 
# Take the directory two levels above the current working directory and put it
# first on sys.path. Presumably this is the repository root that holds the
# `codebase` package imported below; the result depends on the directory the
# kernel was started in -- TODO confirm.
root_code = os.path.dirname(os.path.dirname(os.getcwd()))
sys.path.insert(0, root_code)

from codebase.utils.constants import *
from codebase.utils.eval_utils import *

# Settings

In [None]:
# arguments to specify
project_path = Path('/raid/sonali/project_mvs/')
data_set = 'test'

job_ids=['mj3pqeyk_dataaug-v2-flip_split3_selected-snr_no-wt_no-checkerboard']
#job_ids = ['0jpt9ixq_dataaug-v2-flip_split3_pseudo_multiplex_no-wt_no-checkerboard'] #['dn2rwhyl_split3_selectedsnr_otsu3_seed3'] #['skoe17l6_dataaug_05-0-0-0_split3_selected_snr_seed3_no-wt-init']
#job_ids=["1kh90kst_dataaug-v2-flip_split3_pseudo_multiplex_selected-snr-set_no-wt_no-checkerboard"]
#job_ids = ['tzav1irg_dataaug-v2-flip_split3_real-multiplex-pseudoset_no-wt_no-checkerboard']
#job_ids = ['gj5tmvbv_dataaug-v2-flip_split3_TLS-set_no-wt_no-checkerboard']
# the first job is the reference job used for output paths and single-job cells
job_id = job_ids[0]

# map every job id to the epoch whose results are analysed
#sel_epochs=['epoch33-1']
#sel_epochs = [get_last_epoch_w_imgs(project_path, x) for x in job_ids]
sel_epochs = {x: get_best_epoch_w_imgs(project_path, x) for x in job_ids}
print(sel_epochs)

eval_metrics = ['pcorr']

# figures are written below results/<job>/<data_set>_pics/<epoch>/level_2
SAVE_PATH = project_path.joinpath('results', job_id, data_set+'_pics', sel_epochs[job_id], 'level_2')
# exist_ok=True makes re-running the cell safe and avoids the check-then-create race
SAVE_PATH.mkdir(parents=True, exist_ok=True)
# file name prefix: the part of the job id before the first underscore
save_fname = job_id.split('_')[0]


# when True, figures are only shown and never written to SAVE_PATH
dry_run = True

In [None]:
# Read the tab-separated clinical table and the H&E quality-control table
# from the project's meta directory.
meta_dir = project_path.joinpath('meta')
meta = pd.read_csv(meta_dir / 'melanoma-merged_clinical_data-v8.tsv', sep='\t')
he_qc = pd.read_csv(meta_dir / 'HE-QC.tsv', sep='\t')

# Similarity of nuclei density

In [None]:
#nuclei_pcorr = pd.read_csv(project_path.joinpath('meta','nuclei_density', 'nuclei_density-he_imc-split3_'+data_set+'-'+'pcorr'+'-max'+'1024'+'.tsv'), sep='\t', index_col=[0])
nuclei_pcorr = pd.read_csv(project_path.joinpath('meta','nuclei_density', 'nuclei_density-he_imc-'+'all'+'-'+'pcorr'+'-max'+'1024'+'.tsv'), sep='\t', index_col=[0])
# keep the kernel-64 column: the evaluation tables loaded in this notebook are
# read from the 'avgkernel_64' directory (alternative column: pcorr_32)
nuclei_pcorr = nuclei_pcorr.loc[:,['pcorr_64']]
nuclei_pcorr.columns = ['nuclei_sim']
# group ROIs based on the pcorr median; the median is computed once and the
# comparison is vectorized (NaN similarities compare False and land in 'low')
nuclei_sim_median = nuclei_pcorr['nuclei_sim'].median()
nuclei_pcorr['nuclei_slice_sim'] = np.where(nuclei_pcorr['nuclei_sim'] >= nuclei_sim_median, 'high', 'low')

fig, ax = plt.subplots(figsize=(2,3))
sns.boxplot(y='nuclei_sim', x='nuclei_slice_sim', data=nuclei_pcorr.sort_values(by='nuclei_slice_sim'), palette='Accent', ax=ax)
#plt.ylim(0,1)
ax.set_ylabel("Pearson's correlation coeff.")
ax.set_xlabel('Slice similarity')
# reference line at zero correlation
ax.axhline(0,linestyle='--', color='grey')
if not dry_run:
    fig.savefig(SAVE_PATH.joinpath(save_fname+'_nuclei_stratification.png'), bbox_inches='tight', dpi=300)
    fig.savefig(SAVE_PATH.joinpath(save_fname+'_nuclei_stratification.pdf'), bbox_inches='tight', dpi=300)
plt.show()

# Plot boxplots of metrics per protein

In [None]:
# Load the per-ROI evaluation tables of every job / resolution into one long
# frame with explicit 'job_id' and 'resolution' columns.
resolutions = ['level_2'] #, 'level_4', 'level_6']
eval_frames = []
for res in resolutions:
    # the loop variable is deliberately not called `job_id`: later cells read
    # the `job_id` chosen in the settings cell and it must not be overwritten
    for eval_job_id in job_ids:
        eval_dir = project_path.joinpath('results', eval_job_id, data_set+'_eval', sel_epochs[eval_job_id], res, 'avgkernel_64')
        job_evals = None
        for eval_metric in eval_metrics:
            df = pd.read_csv(eval_dir.joinpath(eval_metric+'-eval.csv'), index_col=[0])
            if eval_metric != 'pcorr':
                # for these tables the first CSV column is exposed as 'protein'
                df.index.name = 'protein'
                df = df.reset_index()
            if job_evals is None:
                job_evals = df
            else:
                job_evals = job_evals.merge(df, on=['protein', 'sample_id', 'roi'])
        if job_evals is None:
            # no metric requested, nothing to collect for this job
            continue
        # the row index carries no information downstream, so it is dropped
        job_evals = job_evals.reset_index(drop=True)
        job_evals.insert(0, 'job_id', eval_job_id)
        job_evals['resolution'] = res
        eval_frames.append(job_evals)

if not eval_frames:
    raise ValueError('No evaluation tables were loaded; check job_ids, resolutions and eval_metrics.')
all_evals = pd.concat(eval_frames, ignore_index=True)

In [None]:
# Enrich the evaluation table with sample- and ROI-level annotations.
# clinical annotations, matched on the sample id
clinical_cols = meta.loc[:, ['tupro_id', 'subtype_group', 'cd8_phenotype_revised']]
all_evals = all_evals.merge(clinical_cols, left_on='sample_id', right_on='tupro_id', how='left')
# H&E QC status, matched on the sample_roi identifier
qc_cols = he_qc.loc[:, ['sample', 'status']]
all_evals = all_evals.merge(qc_cols, left_on='sample_roi', right_on='sample', how='left')
# ROI location = first character of the ROI name if it is 'C' or 'F', else 'other'
all_evals['roi_loc'] = all_evals['roi'].map(lambda roi: roi[0] if roi[0] in ['C', 'F'] else 'other')
# nuclei-density similarity and its high/low group (nuclei_pcorr is indexed by sample_roi)
all_evals = all_evals.merge(nuclei_pcorr, left_on='sample_roi', right_index=True, how='left')
# protein correlation relative to the nuclei-density correlation of the same ROI
all_evals['rel_pcorr'] = all_evals['pcorr'].div(all_evals['nuclei_sim'])

In [None]:
# One figure per (hue column, metric, job): per-protein boxplots of the metric
# with the individual ROIs overlaid as points coloured by the hue column.
hue_cols = ['sample_id', 'status', 'subtype_group','cd8_phenotype_revised', 'roi_loc', 'nuclei_slice_sim']

#eval_cols = [x for x in all_evals.columns if x not in ['job_id', 'resolution','protein','sample_roi', 'sample_id', 'roi','pval']]
#eval_cols = [x for x in eval_cols if x.split('_')[0] not in ['overlap', 'dice', 'perc_pos', 'pixelsGT','pixelsPred']]
eval_cols = ['pcorr'] #, 'densitycorr_0.8']#, 'rel_pcorr']
# fixed x-axis limits, slightly wider than the [-1, 1] correlation range
min_metric = -1.01
max_metric = 1.01


for hue_col in hue_cols:
    # the nuclei similarity groups use the same palette as the stratification plots
    cmap = 'Accent' if hue_col=='nuclei_slice_sim' else None
    for metric in eval_cols:
        # the loop variable is deliberately not called `job_id`: later cells read
        # the `job_id` chosen in the settings cell and it must not be overwritten
        for plot_job_id in sorted(all_evals.job_id.unique()):
            df = all_evals.loc[all_evals.job_id == plot_job_id,:]
            #df = df.loc[df['roi_loc']=='F',:]
            #min_metric = df[metric].min()-0.1*df[metric].max()
            #max_metric = df[metric].max()+0.1*df[metric].max()
            res_levels = sorted(df.resolution.unique())
            n_resolutions = len(res_levels)
            # squeeze=False always returns a 2D axes array, also for a single panel
            fig, axes = plt.subplots(1, n_resolutions, figsize=(3+n_resolutions*5,5), squeeze=False)
            for j,res in enumerate(res_levels):
                ax_plot = axes[0, j]
                plot_df = df.loc[df['resolution']==res,:].sort_values(by='protein') #['sample_id','protein', metric]
                sns.boxplot(x=metric, y='protein', color='white', data=plot_df, showfliers=False, ax=ax_plot)
                sns.stripplot(x=metric, y='protein', hue=hue_col, data=plot_df, ax=ax_plot, palette=cmap)
                ax_plot.set_title(plot_job_id+', '+res)
                # only the right-most panel keeps the legend, placed outside the axes
                if j == (n_resolutions-1):
                    ax_plot.legend(bbox_to_anchor=(1,1))
                else:
                    ax_plot.legend([])
                #ax_plot.vline(0, )
                ax_plot.set_ylabel('')
                ax_plot.set_xlim(min_metric, max_metric)
            fig.tight_layout(pad=2.0)
            plt.show()
            # release the figure so that many iterations do not accumulate open figures
            plt.close(fig)


# Boxplots stratified by slice-slice similarity

In [None]:
# Per-protein Pearson correlation, split by the high/low nuclei slice-similarity group.
# Keep one row per (protein, ROI) and order rows by the similarity group.
plot_df = (
    all_evals.loc[:, ['protein','pcorr', 'nuclei_slice_sim', 'sample_roi']]
    .drop_duplicates(['protein', 'sample_roi'])
    .sort_values(by='nuclei_slice_sim')
)

fig, ax = plt.subplots(figsize=(6,4))
sns.boxplot(data=plot_df, x='pcorr', y='protein', hue='nuclei_slice_sim', palette='Accent',
            showfliers=True, fliersize=3, ax=ax)
#sns.stripplot(x='pcorr', y='protein', data=plot_df, ax=ax, hue='nuclei_slice_sim', palette='Accent', dodge=True)
ax.set_ylabel('')
ax.set_xlabel('Pearson correlation')
# reference line at zero correlation
ax.axvline(0, linestyle='--', color='grey')
ax.legend(bbox_to_anchor=(1,1))
if not dry_run:
    for ext in ('png', 'pdf'):
        fig.savefig(SAVE_PATH.joinpath(save_fname+'-pcorr_by_nuclei.'+ext), bbox_inches='tight', dpi=300)
plt.show()

In [None]:
# Median of the metric per protein, with one column per value of a metadata
# column (currently the nuclei slice-similarity group; 'roi_loc' is an alternative).
col = 'nuclei_slice_sim'#'roi_loc'
metric = 'pcorr'
agg_df = (
    all_evals
    .groupby([col,'protein'])[metric]
    .median()
    .to_frame('agg_'+metric)
    .reset_index()
    .pivot(index='protein', columns=col, values='agg_'+metric)
)
agg_df

# Growing window analysis

In [None]:
# Load the pcorr evaluation of the selected job for every available averaging
# kernel size, draw one boxplot per kernel and collect all tables in `df_all`.
eval_root = project_path.joinpath('results', job_id, data_set+'_eval', sel_epochs[job_id], 'level_2')
kernel_frames = []
missing_kernels = []
for avg_kernel in [0,4,16,32,64,75,128,256]:
    try:
        df = pd.read_csv(eval_root.joinpath('avgkernel_'+str(avg_kernel),'pcorr-eval.csv'), index_col=[0])
    except FileNotFoundError:
        # this kernel size was not evaluated for the job; remember it and move on
        missing_kernels.append(avg_kernel)
        continue
    #df.index.name = 'protein'
    #df = df.reset_index(drop=False)
    df['avgkernel'] = avg_kernel
    kernel_frames.append(df)
    # plot
    fig, ax = plt.subplots(figsize=(5,4))
    sns.boxplot(x='pcorr', y='protein', data=df, ax=ax, palette='tab10')#, showfliers=False)#, hue='nuclei_slice_sim', palette='Accent')
    #sns.stripplot(x='pcorr', y='protein', data=plot_df, ax=ax, hue='nuclei_slice_sim', palette='Accent', dodge=True)
    ax.set_ylabel('')
    ax.set_xlabel('Pearson correlation')
    ax.axvline(0, linestyle='--', color='grey')
    ax.set_xlim(-1,1)
    ax.set_title(avg_kernel)
    #plt.legend(bbox_to_anchor=(1,1))

# concatenate once instead of growing the frame inside the loop; stays an empty
# frame when no kernel size was found
df_all = pd.concat(kernel_frames) if kernel_frames else pd.DataFrame()
if missing_kernels:
    print('No pcorr-eval.csv found for avgkernel sizes:', missing_kernels)

In [None]:
# # Mean of means +/- std
# plot_df = df_all.groupby(['avgkernel', 'protein', 'sample_id']).mean().reset_index()
# plot_df_means = plot_df.groupby(['avgkernel', 'protein']).mean().reset_index()
# # to make xticks equidistant
# plot_df = plot_df.sort_values(by=['protein', 'avgkernel'])
# plot_df['avgkernel'] = [str(x) for x in plot_df['avgkernel']]
# plot_df_means = plot_df_means.sort_values(by=['protein', 'avgkernel'])
# plot_df_means['avgkernel'] = [str(x) for x in plot_df_means['avgkernel']]

# fig, ax = plt.subplots(figsize=(5,4))
# sns.scatterplot(x='avgkernel', y='pcorr', data=plot_df_means, hue='protein', palette='tab10')
# sns.lineplot(x='avgkernel', y='pcorr', data=plot_df, hue='protein', palette='tab10', legend=False,
#             estimator='mean')
# ax.axhline(0,linestyle='--', color='grey')
# plt.ylabel('Mean Pearson correlation')
# plt.xlabel('Averaging kernel size')
# plt.legend(bbox_to_anchor=(1,1))

In [None]:
# Aggregated (mean and median) Pearson correlation per protein as a function
# of the averaging kernel size; one figure per aggregation method.
for agg_method in ['mean', 'median']:
    agg_col = agg_method+'_pcorr'
    # agg() with the method name gives the groupby mean / median
    plot_df = (
        df_all.groupby(['avgkernel', 'protein']).pcorr
        .agg(agg_method)
        .to_frame(agg_col)
        .reset_index()
        .sort_values(by=['protein', 'avgkernel'])
    )
    # to make xticks equidistant: kernel sizes become string labels after the numeric sort
    plot_df['avgkernel'] = plot_df['avgkernel'].astype(str)

    fig, ax = plt.subplots(figsize=(5,4))
    sns.scatterplot(data=plot_df, x='avgkernel', y=agg_col, hue='protein', palette='tab10', ax=ax)
    sns.lineplot(data=plot_df, x='avgkernel', y=agg_col, hue='protein', palette='tab10', legend=False, ax=ax)
    ax.axhline(0, linestyle='--', color='grey')
    ax.set_ylabel(agg_method.capitalize()+' Pearson correlation')
    ax.set_xlabel('Averaging kernel size')
    ax.legend(bbox_to_anchor=(1,1))
    if not dry_run:
        for ext in ('png', 'pdf'):
            fig.savefig(SAVE_PATH.joinpath(save_fname+'-pcorr_growing_'+agg_method+'.'+ext), bbox_inches='tight', dpi=300)
    plt.show()

# Poor ROIs
Identify problematic ROIs (on which the model performs poorly)

In [None]:
# Rank ROIs by an aggregate of the selected metric across proteins and show
# the `bottom_n` ROIs at the poor end of the ranking.
# based on which metric to perform selection
sel_metric = 'pcorr'
# based on which proteins to perform selection (can also be a selected one, eg ['MelanA'])
sel_proteins = sorted(all_evals.protein.unique())
# aggregation method (across proteins) {mean, median, perc_proteins_above_thrs}
agg_method = 'median'#'perc_proteins_above_thrs'#'mean'
# threshold used by perc_proteins_above_thrs only: the aggregate is the
# fraction of proteins whose metric is strictly above this value
thrs = 0
# how many ROI names to return
bottom_n = 5
# whether higher values are better (True for all current metrics)
higher_is_better = True

roi_scores = all_evals.loc[all_evals.protein.isin(sel_proteins),['sample_roi', sel_metric]].groupby('sample_roi')[sel_metric]
if agg_method in ('mean', 'median'):
    poor_df = roi_scores.agg(agg_method)
elif agg_method == 'perc_proteins_above_thrs':
    poor_df = roi_scores.apply(lambda x: round(sum(x>thrs)/len(x),2))
else:
    # fail here instead of printing and hitting an unrelated KeyError in the sort below
    raise ValueError('Selected aggregation method not supported: '+repr(agg_method))
poor_df = poor_df.to_frame('agg_'+sel_metric)
# best ROIs first, poorest last, so tail() returns the poorest ones
poor_df = poor_df.sort_values(by='agg_'+sel_metric, ascending=not higher_is_better)
poor_df.tail(bottom_n)

In [None]:
# First `bottom_n` rows of the sorted frame. Because the sort direction follows
# `higher_is_better`, these are the ROIs at the best end of the ranking, for
# comparison with the poor ROIs returned by tail() in the previous cell.
poor_df.head(bottom_n)

# CW-SSIM

In [None]:
# CW-SSIM per protein, split by the high/low nuclei slice-similarity group.
# The path is derived from project_path like every other input of this
# notebook, instead of repeating the absolute project location.
cwssim_path = project_path.joinpath('results', job_id, data_set+'_eval', sel_epochs[job_id], 'level_2', 'cwssim_eval_nb.csv')
if cwssim_path.exists():
    cwssim_df_all = pd.read_csv(cwssim_path)
    plot_df = cwssim_df_all.reset_index()
    # attach nuclei similarity and its high/low group via the sample_roi identifier
    plot_df = plot_df.merge(nuclei_pcorr.reset_index(), on='sample_roi', how='left')

    fig, ax = plt.subplots(figsize=(6,4))
    sns.boxplot(x='cwssim_30', y='protein', data=plot_df.sort_values(by=['nuclei_slice_sim', 'protein']), showfliers=True,
                fliersize=3, ax=ax, hue='nuclei_slice_sim', palette='Accent')
    ax.set_ylabel('')
    ax.set_xlabel('CW-SSIM')
    # small margin around the [0, 1] range of the score
    ax.set_xlim(0-0.05,1+0.05)
    ax.legend(bbox_to_anchor=(1,1))
    if not dry_run:
        fig.savefig(SAVE_PATH.joinpath(save_fname+'-cwssim_by_nuclei.png'), bbox_inches='tight', dpi=300)
        fig.savefig(SAVE_PATH.joinpath(save_fname+'-cwssim_by_nuclei.pdf'), bbox_inches='tight', dpi=300)
    plt.show()
else:
    # make the skipped section visible instead of producing no output at all
    print('CW-SSIM results not found, skipping:', cwssim_path)


# Binarization-based score plots (dice, densitycorr, overlap)

In [None]:
# Threshold-dependent scores: for every protein and hue column, plot the score
# against the threshold. Runs only when the chosen metric is listed in eval_cols.
metric = 'dice' #'densitycorr' #'dice'
hue_cols = ['sample_id','status', 'subtype_group','cd8_phenotype_revised', 'nuclei_slice_sim']

if metric in eval_cols:
    id_cols = ['protein','sample_id','roi', 'subtype_group', 'cd8_phenotype_revised', 'status','nuclei_slice_sim']
    # every column whose name contains the metric, followed by the identifier columns
    dice_cols = [name for name in all_evals.columns if metric in name] + id_cols
    # long format: one row per (identifier combination, threshold column)
    dice_df = all_evals.loc[:, dice_cols].melt(id_vars=id_cols, var_name='thrs', value_name=metric)
    # the threshold is the numeric suffix after the last underscore of the column name
    dice_df['thrs'] = dice_df['thrs'].map(lambda name: float(name.split('_')[-1]))
    display(dice_df.head(2))

    for protein in sorted(dice_df.protein.unique()):
        protein_df = dice_df.loc[dice_df.protein==protein,:]
        for hue_col in hue_cols:
            fig, ax = plt.subplots(figsize=(8,5))
            sns.boxplot(data=protein_df, x='thrs', y=metric, color='lightgrey', ax=ax)
            sns.stripplot(data=protein_df, x='thrs', y=metric, hue=hue_col, alpha=0.6, ax=ax)
            ax.legend(bbox_to_anchor=(1,1))
            plt.xticks(rotation=90)
            ax.set_title(protein)
            plt.show()