In [None]:
import os
import numpy as np
import pandas as pd
from pathlib import Path
import tifffile
import json

import matplotlib.pyplot as plt
import seaborn as sns
import sys 
root_code = os.path.dirname(os.path.dirname(os.getcwd()))
sys.path.insert(0, root_code)

from codebase.utils.constants import *
from codebase.utils.eval_utils import *

In [None]:
def plot_triangle_heatmap(df_corr, ax=None, title='', cbar_title='Spearman correlation coefficient'):
    """Draw the strict lower triangle of a correlation matrix as an annotated heatmap.

    Parameters
    ----------
    df_corr : pd.DataFrame
        Square correlation matrix; drawn on a fixed [-1, 1] 'RdBu_r' colour scale.
    ax : matplotlib Axes, optional
        Axes to draw on; passed straight to ``sns.heatmap``.
    title : str
        Axes title.
    cbar_title : str
        Colour-bar label.

    Returns
    -------
    The Axes returned by ``sns.heatmap``.
    """
    # Mask by *position* (upper triangle incl. diagonal). Building the mask from the values
    # (np.triu(df_corr)) left upper-triangle cells with a correlation of exactly 0 unmasked.
    mask = np.triu(np.ones(df_corr.shape, dtype=bool))
    ax = sns.heatmap(df_corr, vmin=-1, vmax=1, cmap='RdBu_r', annot=True, fmt='.2f',
                     cbar_kws={'label': cbar_title}, ax=ax, mask=mask)
    ax.set_title(title)
    return ax
    
def plot_masked_heatmap(df_corr, ax=None, title='', cbar_title='Spearman correlation coefficient'):
    """Draw a full correlation matrix as an annotated heatmap with only the diagonal hidden.

    Parameters
    ----------
    df_corr : pd.DataFrame
        Correlation matrix; drawn on a fixed [-1, 1] 'RdBu_r' colour scale.
    ax : matplotlib Axes, optional
        Axes to draw on; passed straight to ``sns.heatmap``.
    title : str
        Axes title.
    cbar_title : str
        Colour-bar label.

    Returns
    -------
    The Axes returned by ``sns.heatmap``.
    """
    # Hide the main diagonal (self-correlations) so the off-diagonal structure stands out.
    mask = np.eye(df_corr.shape[0], df_corr.shape[1], dtype=bool)
    ax = sns.heatmap(df_corr, vmin=-1, vmax=1, cmap='RdBu_r', annot=True, fmt='.2f',
                     cbar_kws={'label': cbar_title}, ax=ax, mask=mask)
    ax.set_title(title)
    return ax
    

In [None]:
# ---- Run configuration --------------------------------------------------------------
# Project root. Cluster alternative: Path('/cluster/work/grlab/projects/projects2021-multivstain/')
project_path = Path('/raid/sonali/project_mvs/')
data_set = 'test'        # key into the CV-split ROI lists
resolution = 'level_2'
dry_run = True           # set to False to write figures / CSVs to disk

# Job whose predictions are evaluated
submission_id = "mj3pqeyk_dataaug-v2-flip_split3_selected-snr_no-wt_no-checkerboard"

# Epoch to evaluate: best (active), last, or an explicit one.
epoch = get_best_epoch_w_imgs(project_path, submission_id)
# epoch = get_last_epoch_w_imgs(project_path, submission_id)
# epoch = 'epoch93-1'
print(epoch)

# Optional comparator job overlaid on the dot plots (as in the MICCAI submission). To use it:
#   1. run this notebook for the comparator job id with dry_run=False;
#   2. set add_comparison to the absolute path of its *-coexpression_sdplot_preddf.csv.
# To plot only the current job, set add_comparison = None instead.
add_comparison = str(
    project_path / 'results'
    / '1kh90kst_dataaug-v2-flip_split3_pseudo_multiplex_selected-snr-set_no-wt_no-checkerboard'
    / 'test_pics' / 'epoch87-1' / 'level_2'
    / '1kh90kst-coexpression_sdplot_preddf.csv'
)
# add_comparison = None

# Name of the GT aggregation folder (under <DATA_DIR>/imc_updated/)
gt_prep = 'agg_masked_data-raw_clip99_arc_otsu3_std_minmax_split3-r5'

In [None]:
# Resolve project directories. This rebinds the star-imported DATA_DIR / RESULTS_DIR constants;
# re-running is harmless because Path.joinpath with an already-absolute path returns that path.
DATA_DIR = project_path.joinpath(DATA_DIR)
RESULTS_DIR = project_path.joinpath(RESULTS_DIR)
GT_PATH = DATA_DIR.joinpath('imc_updated', gt_prep)
PRED_PATH = RESULTS_DIR.joinpath(submission_id, data_set + '_scdata', epoch, resolution)
SAVE_PATH = RESULTS_DIR.joinpath(submission_id, data_set + '_pics', epoch, resolution)
# exist_ok=True is idempotent and avoids the exists()-then-mkdir race of a separate check.
SAVE_PATH.mkdir(parents=True, exist_ok=True)
# File-name prefix for saved outputs: the job id part of the submission id.
save_fname = submission_id.split('_')[0]

# Load job args and the CV split; context managers close the file handles.
with open(RESULTS_DIR.joinpath(submission_id, 'args.txt')) as args_fh:
    job_args = json.load(args_fh)
cv_split = job_args['cv_split']
with open(project_path.joinpath(CV_SPLIT_ROIS_PATH)) as split_fh:
    cv_splits = json.load(split_fh)
s_rois = cv_splits[cv_split][data_set]  # can also specify one ROI by s_rois=['MECADAC_F3']

# Plot internal correlation structure in GT and predictions

In [None]:
# Per-ROI Spearman co-expression matrix: GT in the lower triangle, prediction in the upper one.
df_corr_agg = {}
for s_roi in s_rois:
    print(s_roi)

    # Predicted expression without the non-marker columns.
    pred = pd.read_csv(PRED_PATH.joinpath(s_roi + '.tsv'), sep='\t', index_col=[0])
    pred = pred.drop(columns=['sample_roi', 'X', 'Y', 'radius'], errors='ignore')
    print('pred NaNs: ' + str(pred.isna().any(axis=1).sum()) + ' objects')

    # GT expression restricted to, and ordered like, the predicted markers.
    gt = pd.read_csv(GT_PATH.joinpath(s_roi + '.tsv'), sep='\t', index_col=[0])
    gt = gt.loc[:, pred.columns]
    print('gt NaNs: ' + str(gt.isna().any(axis=1).sum()) + ' objects')

    gt_corr = gt.corr('spearman')
    pred_corr = pred.corr('spearman')
    # Keep GT on/below and prediction on/above the diagonal; undefined (NaN) correlations become 0.
    gt_corr = gt_corr.where(np.tril(np.ones(gt_corr.shape, dtype=bool))).fillna(0)
    pred_corr = pred_corr.where(np.triu(np.ones(pred_corr.shape, dtype=bool))).fillna(0)
    df_corr = gt_corr + pred_corr
    # Both triangles contributed to the diagonal, so force it back to exactly one.
    df_corr = df_corr.mask(np.eye(df_corr.shape[0], df_corr.shape[1], dtype=bool), 1)

    df_corr_agg[s_roi] = df_corr

    # Show this ROI's co-expression pattern.
    plt.figure(figsize=(8, 6))
    plot_masked_heatmap(df_corr, ax=None, title=s_roi + '\nBottom triangle: GT \nTop triangle: Prediction')
    plt.show()

# Plot aggregated metrics: mean and standard deviation

In [None]:
# Element-wise mean and standard deviation of the per-ROI correlation matrices.
# Uses df_corr_agg directly (not the loop-leaked df_corr) so the cell has no hidden-state dependency.
corr_frames = list(df_corr_agg.values())
if not corr_frames:
    raise ValueError('df_corr_agg is empty: no ROI correlation matrices to aggregate')
ref_frame = corr_frames[0]
for frame in corr_frames[1:]:
    # Stacking below is positional, so every ROI must share the same protein labels/order.
    if not (frame.index.equals(ref_frame.index) and frame.columns.equals(ref_frame.columns)):
        raise ValueError('Per-ROI correlation matrices do not share the same protein labels')
stacked = np.stack([frame.to_numpy() for frame in corr_frames])

df_corr_mean = pd.DataFrame(stacked.mean(axis=0), index=ref_frame.index, columns=ref_frame.columns)
# np.std with ddof=0: square root of the mean squared deviation across ROIs (a standard
# deviation, not a variance, so it can be plotted as +/- bands around the mean).
df_corr_std = pd.DataFrame(stacked.std(axis=0), index=ref_frame.index, columns=ref_frame.columns)

In [None]:
# Mean co-expression across ROIs (saved unless dry_run).
fig, ax = plt.subplots(figsize=(8, 6))
plot_masked_heatmap(df_corr_mean, ax=ax, title='Mean across ROIs\nBottom triangle: GT \nTop triangle: Prediction')
if not dry_run:
    for ext in ('png', 'pdf'):
        fig.savefig(SAVE_PATH.joinpath(save_fname + '-coexpression_patterns.' + ext), bbox_inches='tight', dpi=300)
plt.show()

# Spread of the co-expression across ROIs (display only).
fig, ax = plt.subplots(figsize=(8, 6))
plot_masked_heatmap(df_corr_std, ax=ax, title='Std across ROIs\nBottom triangle: GT \nTop triangle: Prediction',
                    cbar_title='Std of Spearman correlation coefficient')
plt.show()

# Merge data for plotting

In [None]:
def prep_pointplot_df(df_corr, keep_lower=True):
    """Turn one triangle of a square correlation matrix into a long table of protein pairs.

    Parameters
    ----------
    df_corr : pd.DataFrame
        Square matrix with protein names as both index and columns.
    keep_lower : bool, default True
        Keep the lower triangle if True, otherwise the upper triangle.

    Returns
    -------
    pd.DataFrame
        Indexed by 'protein_pair' ("A : B" with the two names sorted, so the same pair taken
        from either triangle gets the same key) with columns protein1, protein2, corr_value.
        Diagonal entries and NaN values are dropped. ``df_corr`` itself is not modified.
    """
    triangle = np.tril if keep_lower else np.triu
    keep = triangle(np.ones(df_corr.shape)).astype(bool)
    long_df = (
        df_corr.where(keep)
        # rename_axis returns a new object; assigning .index.name instead can alter an Index
        # object shared with the caller's df_corr.
        .rename_axis('protein1')
        .reset_index(drop=False)
        .melt(id_vars='protein1', var_name='protein2', value_name='corr_value')
    )
    # Drop the cells masked out above and the self-pairs on the diagonal.
    long_df = long_df.loc[long_df['corr_value'].notna() & (long_df['protein1'] != long_df['protein2'])]
    # assign (instead of setitem on a filtered slice) avoids SettingWithCopyWarning.
    long_df = long_df.assign(
        protein_pair=[' : '.join(sorted([x, y])) for x, y in zip(long_df['protein1'], long_df['protein2'])]
    )
    return long_df.set_index('protein_pair')

# Merge mean values: GT (lower triangle) ordered by ascending correlation, prediction (upper
# triangle) re-indexed to the same protein-pair order so both share the x-axis ordering.
plot_df_gt = (
    prep_pointplot_df(df_corr_mean)
    .assign(data_type='GT')
    .sort_values(by='corr_value', ascending=True)
)
plot_df_pred = prep_pointplot_df(df_corr_mean, keep_lower=False).assign(data_type='Prediction')
plot_df_pred = plot_df_pred.loc[plot_df_gt.index, :]
merged = pd.concat([plot_df_gt, plot_df_pred]).reset_index(drop=False)

In [None]:
# Merge std values: same long format as the means, with the value column renamed to corr_std.
# (Row order does not matter here: the values are attached to `merged` by key below.)
std_gt = prep_pointplot_df(df_corr_std).assign(data_type='GT')
std_pred = prep_pointplot_df(df_corr_std, keep_lower=False).assign(data_type='Prediction')
merged_std = (
    pd.concat([std_gt, std_pred])
    .reset_index(drop=False)
    .rename(columns={'corr_value': 'corr_std'})
)

# Join only the keys + corr_std: merging the full frame duplicated protein1/protein2 as *_x/*_y.
# Dropping an existing corr_std first keeps this cell idempotent when re-run.
merged = merged.drop(columns='corr_std', errors='ignore').merge(
    merged_std[['protein_pair', 'data_type', 'corr_std']],
    on=['protein_pair', 'data_type'], how='left', validate='one_to_one',
)

# Dot plot of mean correlations with shaded bands depicting ±1 standard deviation across ROIs

In [None]:
# Comparator set-up. Its job id is the file-name prefix ("<jobid>-coexpression_sdplot_preddf.csv").
empty_prots = []  # proteins whose pairs are dropped from the comparator overlay (none by default)
if add_comparison is not None:
    comp_id = Path(add_comparison).name.split('-')[0]
    comp_suffix = '_comp_' + comp_id
    # Guard so re-running this cell does not append the suffix repeatedly.
    if not save_fname.endswith(comp_suffix):
        save_fname = save_fname + comp_suffix
    if comp_id == '1kh90kst':
        # NOTE(review): presumably markers that job does not predict — confirm.
        empty_prots = ['CD31', 'CD16', 'CD20']


def _plot_mean_sd_band(ax, df, color):
    """Scatter corr_value per protein_pair and shade corr_value +/- corr_std in the same colour."""
    ax.scatter(df['protein_pair'], df['corr_value'], color=color, alpha=0.8)
    lower = df['corr_value'] - df['corr_std']
    upper = df['corr_value'] + df['corr_std']
    ax.fill_between(df['protein_pair'], lower, upper, alpha=0.2, color=color)


fig, ax = plt.subplots(figsize=(10, 3))
plt.grid(True, color='lightgrey', linestyle='-', linewidth=0.5, alpha=0.6)

# GT is drawn first, so its (sorted) protein-pair order defines the categorical x-axis.
merged_gt = merged.loc[merged['data_type'] == 'GT', :]
merged_pred = merged.loc[merged['data_type'] == 'Prediction', :]
_plot_mean_sd_band(ax, merged_gt, 'tab:blue')
_plot_mean_sd_band(ax, merged_pred, 'tab:orange')

if add_comparison is not None:
    merged_comp = pd.read_csv(add_comparison)
    pair_parts = merged_comp['protein_pair'].str.split(' : ')
    keep_pair = ~pair_parts.str[0].isin(empty_prots) & ~pair_parts.str[1].isin(empty_prots)
    merged_comp = merged_comp.loc[keep_pair, :]
    _plot_mean_sd_band(ax, merged_comp, 'tab:green')

plt.xticks(rotation=90, ha='center', va='top')
plt.ylim(-0.5, 1)
plt.xlim(-1, merged_pred.protein_pair.nunique())
plt.axhline(0, linestyle='--', color='grey')
plt.ylabel('Spearman correlation coefficient')
if not dry_run:
    # The CSV holds only this job's predictions, so save it under the plain job-id prefix:
    # that keeps "<jobid>-..." parseable when the file is later used as add_comparison.
    merged_pred.to_csv(SAVE_PATH.joinpath(submission_id.split('_')[0] + '-coexpression_sdplot_preddf.csv'))
    for ext in ('png', 'pdf'):
        fig.savefig(SAVE_PATH.joinpath(save_fname + '-coexpression_sdplot.' + ext), bbox_inches='tight', dpi=300)

plt.show()

# Dotplot with error bars

In [None]:
# Same dot plot as above, but with +/- corr_std error bars instead of shaded bands.
alpha = 0.6
fig, ax = plt.subplots(figsize=(10, 3))
plt.grid(True, color='lightgrey', linestyle='-', linewidth=0.5, alpha=0.6)

merged_gt = merged.loc[merged['data_type'] == 'GT', :]
merged_pred = merged.loc[merged['data_type'] == 'Prediction', :]
for group_df, color in ((merged_gt, 'tab:blue'), (merged_pred, 'tab:orange')):
    ax.scatter(group_df['protein_pair'], group_df['corr_value'], color=color, alpha=alpha)
    ax.errorbar(group_df['protein_pair'], group_df['corr_value'], color=color, yerr=group_df['corr_std'], fmt='.')

if add_comparison is not None:
    merged_comp = pd.read_csv(add_comparison)
    comp_id = Path(add_comparison).name.split('-')[0]
    # Pairs involving these proteins are blanked (NaN) for the comparator; defined for every
    # comparator so this cell no longer raises NameError for job ids other than 1kh90kst.
    empty_prots = ['CD31', 'CD16', 'CD20'] if comp_id == '1kh90kst' else []
    pair_parts = merged_comp['protein_pair'].str.split(' : ')
    hide_pair = pair_parts.str[0].isin(empty_prots) | pair_parts.str[1].isin(empty_prots)
    merged_comp.loc[hide_pair, ['corr_value', 'corr_std']] = np.nan

    ax.scatter(merged_comp['protein_pair'], merged_comp['corr_value'], color='tab:green', alpha=alpha)
    ax.errorbar(merged_comp['protein_pair'], merged_comp['corr_value'], color='tab:green', yerr=merged_comp['corr_std'], fmt='.')

plt.xticks(rotation=90, ha='center', va='top')
plt.ylim(-0.5, 1)
plt.xlim(-1, merged_pred.protein_pair.nunique())
plt.axhline(0, linestyle='--', color='grey')
plt.ylabel('Spearman correlation coefficient')
if not dry_run:
    for ext in ('png', 'pdf'):
        fig.savefig(SAVE_PATH.joinpath(save_fname + '-coexpression_sdbarplot.' + ext), bbox_inches='tight', dpi=300)

plt.show()

# Correspondence of the internal correlation structure between GT and predictions

In [None]:
# One row per protein pair, with the GT and Prediction mean correlations side by side.
merged_wide = merged.pivot(index='protein_pair', columns='data_type', values='corr_value')


def _style_correspondence_axes(axis):
    """Add the identity line, title and axis labels to a GT-vs-prediction correlation plot."""
    axis.plot([-0.4, 1], [-0.4, 1], linestyle='--', color='lightgrey')
    axis.set_title('Correspondence of internal \n correlation structure')
    axis.set_ylabel('GT: Spearman correlation')
    axis.set_xlabel('Prediction: Spearman correlation')


# Plain scatter of GT vs predicted pair correlations.
fig, ax = plt.subplots(figsize=(4.5, 3.5))
sns.scatterplot(x='Prediction', y='GT', data=merged_wide, color='tab:blue', ax=ax)
_style_correspondence_axes(ax)
plt.show()

# Same data with a fitted regression line.
fig, ax = plt.subplots(figsize=(4.5, 3.5))
g = sns.regplot(x='Prediction', y='GT', data=merged_wide, ax=ax)
_style_correspondence_axes(ax)
plt.show()

