In [None]:
import os
import numpy as np
import pandas as pd
from pathlib import Path
import json

import matplotlib.pyplot as plt
import seaborn as sns

import sys 
root_code = os.path.dirname(os.path.dirname(os.path.dirname(os.getcwd())))
sys.path.insert(0, root_code)


from codebase.utils.constants import *
from codebase.utils.eval_utils import *
from codebase.utils.metrics import get_density_bins

# Settings

In [None]:
# Root folder holding data, metadata and results.
# Alternative location: Path('/cluster/work/grlab/projects/projects2021-multivstain/')
project_path = Path('/raid/sonali/project_mvs/')
# Cross-validation split and data subset to evaluate
cv_split = 'split3'
data_set = 'test'
# Level index: selects the 'level_<n>' result folders and rescales pixel sizes further below
level = 2
# Name of the run folder under results/
submission_id = "mj3pqeyk_dataaug-v2-flip_split3_selected-snr_no-wt_no-checkerboard"

# Epoch to evaluate: best epoch, last epoch, or an explicitly named one
epoch = get_best_epoch_w_imgs(project_path, submission_id)
#epoch = get_last_epoch_w_imgs(project_path, submission_id)
#epoch = 'epoch93-1'

In [None]:
PROJECT_PATH = Path(project_path)

# Get job args (context managers make sure the file handles are closed again)
with open(PROJECT_PATH.joinpath('results',submission_id, 'args.txt')) as args_file:
    job_args = json.load(args_file)
# The split stored with the run replaces the cv_split chosen in the settings cell
cv_split = job_args['cv_split']
# Get sample_roi list for a given split and data_set
with open(PROJECT_PATH.joinpath(CV_SPLIT_ROIS_PATH)) as cv_file:
    cv = json.load(cv_file)
sample_rois = cv[cv_split][data_set]


# Ground-truth cell-type tables (one <sample_roi>.tsv per ROI).
# NOTE(review): the folder name hard-codes 'split3' while cv_split comes from the job args —
# confirm that both agree when evaluating a run trained on another split.
GT_PATH = PROJECT_PATH.joinpath('meta', 'imc_gt-celltype-predictions','rf-cell_type-selected_snr-raw_clip99_arc_otsu3_std_minmax_split3-r5-ntrees100-maxdepth30', cv_split)
#GT_PATH = PROJECT_PATH.joinpath('meta', 'imc_gt-celltype-predictions','rf-cell_type-prots_pseudo_multiplex-raw_clip99_arc_otsu3_std_minmax_split3-r5-ntrees100-maxdepth30', cv_split)

# Predicted cell-type tables of the selected run / data set / epoch / level
PRED_PATH = PROJECT_PATH.joinpath('results', submission_id, data_set+'_ct', epoch, 'level_'+str(level))


In [None]:
# Output folder for the figures of this run / data set / epoch / level
SAVE_PATH = PROJECT_PATH.joinpath('results',submission_id,data_set+'_pics', epoch, 'level_'+str(level))
# exist_ok=True makes the cell safely re-runnable and avoids the check-then-create race
SAVE_PATH.mkdir(parents=True, exist_ok=True)
# When True, figures are only shown and nothing is written to SAVE_PATH
dry_run = False
# Prefix of all saved file names: first '_'-separated token of the submission id
save_fname = submission_id.split('_')[0]

# Compare number of detected cells per cell type

In [None]:
# Count the cells per cell type in ground truth (GT) and prediction for every ROI.
# The per-ROI frames are collected in a list and concatenated once (no quadratic
# re-copying), and the counts are integers instead of object-dtype cells.
ct_counts_per_roi = []
for s_roi in sample_rois:
    gt = pd.read_csv(GT_PATH.joinpath(s_roi+'.tsv'), sep='\t', index_col=[0])
    pred = pd.read_csv(PRED_PATH.joinpath(s_roi+'.tsv'), sep='\t', index_col=[0])
    # Cell types absent from a table get a count of 0 via reindex
    ct_counts = pd.DataFrame({
        'GT': gt['pred_cell_type'].value_counts().reindex(CELL_TYPES, fill_value=0),
        'Prediction': pred['pred_cell_type'].value_counts().reindex(CELL_TYPES, fill_value=0),
    }, index=CELL_TYPES)
    ct_counts['sample_roi'] = s_roi
    ct_counts_per_roi.append(ct_counts)
ct_counts_all = pd.concat(ct_counts_per_roi, axis=0)

# Sample id is the part of the ROI name before the first underscore
ct_counts_all['sample_id'] = ct_counts_all['sample_roi'].str.split('_').str[0]
ct_counts_all.index.name = 'cell_type'
ct_counts_all = ct_counts_all.reset_index(drop=False)

In [None]:
# Scatter plots of predicted vs GT cell counts per ROI, one panel per cell type.
logscale = True
add_spcorr = True
cts = [x for x in CELL_TYPES if x !='other']
# whether to color by the slice-slice similarity derived from comparing nuclei density between H&E and GT IMC
color_by_nuclei = True #False

if color_by_nuclei:
    nuclei_pcorr = pd.read_csv(project_path.joinpath('meta','nuclei_density', 'nuclei_density-he_imc-'+'all'+'-'+'pcorr'+'-max'+'1024'+'.tsv'), sep='\t', index_col=[0])
    nuclei_pcorr = nuclei_pcorr.loc[:,['pcorr_64']] #pcorr_32
    nuclei_pcorr.columns = ['nuclei_sim']
    # group ROIs based on pcorr median (computed once instead of once per row)
    nuclei_median = nuclei_pcorr['nuclei_sim'].median()
    nuclei_pcorr['nuclei_slice_sim'] = np.where(nuclei_pcorr['nuclei_sim']>=nuclei_median, 'high', 'low')
    # Drop columns left over from an earlier run of this cell; otherwise the merge
    # would create '_x'/'_y' suffixed duplicates and 'nuclei_slice_sim' would vanish.
    ct_counts_all = ct_counts_all.drop(columns=list(nuclei_pcorr.columns), errors='ignore')
    ct_counts_all = ct_counts_all.merge(nuclei_pcorr.reset_index(), on='sample_roi', how='left')

fig, axes = plt.subplots(1, len(cts), figsize=(30,4))
# plt.subplots returns a bare Axes for a single panel; keep indexing uniform
axes = np.atleast_1d(axes)
spcorr_all = dict()
for i,ct in enumerate(cts):
    # Explicit copy with numeric counts: avoids writing into a slice of ct_counts_all
    # and keeps log / correlation working when the counts are stored as objects.
    plot_df = ct_counts_all.loc[ct_counts_all.cell_type==ct,:].copy()
    plot_df = plot_df.astype({'GT': float, 'Prediction': float})
    # change to log scale
    if logscale:
        plot_df['GT'] = np.log1p(plot_df['GT'])
        plot_df['Prediction'] = np.log1p(plot_df['Prediction'])
    if color_by_nuclei:
        sns.scatterplot(x='Prediction', y='GT', data=plot_df.sort_values('nuclei_slice_sim'), ax=axes[i], hue='nuclei_slice_sim', palette='Accent', legend=(i== len(cts)-1))
    else:
        sns.scatterplot(x='Prediction', y='GT', data=plot_df, ax=axes[i])
    ylab = 'GT cell counts'
    xlab = 'Predicted cell counts'
    if logscale:
        ylab = ylab+' (log(x+1))'
        xlab = xlab+' (log(x+1))'
    axes[i].set_ylabel(ylab)
    axes[i].set_xlabel(xlab)
    title = ct
    if add_spcorr:
        spcorr = plot_df.loc[:,['GT','Prediction']].corr(method='spearman').iloc[0,1]
        spcorr_all[ct] = spcorr
        title = title+'\n  spcorr: '+str(round(spcorr,2))
        if color_by_nuclei:
            # ROIs without a nuclei-similarity entry carry NaN after the left merge; skip them
            for nuc_class in sorted(plot_df['nuclei_slice_sim'].dropna().unique()):
                spcorr_all[ct+' | '+nuc_class] = plot_df.loc[plot_df['nuclei_slice_sim']==nuc_class,['GT','Prediction']].corr(method='spearman').iloc[0,1]
    axes[i].set_title(title)
    axes[i].set_box_aspect(1)

    # Identity line from the overall minimum up to the smaller of the two maxima
    minmax = min(plot_df['GT'].max(), plot_df['Prediction'].max())
    minmin = min(plot_df['GT'].min(), plot_df['Prediction'].min())
    axes[i].plot([minmin,minmax], [minmin,minmax], color='lightgrey', linestyle='--')
if not dry_run:
    fig_stem = save_fname+('-cts_corr_by_nuclei' if color_by_nuclei else '-cts_corr')
    for fig_ext in ('.png', '.pdf'):
        fig.savefig(SAVE_PATH.joinpath(fig_stem+fig_ext), bbox_inches='tight', dpi=300)
plt.show()

In [None]:
# Turn the {label: Spearman correlation} dict into a one-column table indexed by
# 'cell_type', keep the rows listed in cts, round to two decimals and print as LaTeX.
spcorr_all = pd.Series(spcorr_all, name='spcorr').rename_axis('cell_type').to_frame()
spcorr_all = spcorr_all.loc[cts,:].round(2)
print(spcorr_all.to_latex(escape=False))

In [None]:
# Long format: the GT and Prediction columns are stacked into 'value', labelled by 'datatype'
plot_df = ct_counts_all[['cell_type', 'nuclei_slice_sim', 'GT', 'Prediction']].melt(
    id_vars=['cell_type', 'nuclei_slice_sim'], var_name='datatype'
)
# One density plot per cell type (alphabetical order, 'other' excluded)
kde_cell_types = sorted(name for name in plot_df['cell_type'].unique() if name != "other")
for ct in kde_cell_types:
    ct_rows = plot_df[plot_df['cell_type'] == ct]
    sns.kdeplot(data=ct_rows, x='value', hue='datatype', cut=0)
    plt.title(ct)
    plt.show()

# Number of detected cells per cell type by metadata

In [None]:
# Attach clinical metadata to the per-ROI counts (matched on sample_id == tupro_id)
meta = pd.read_csv(PROJECT_PATH.joinpath('meta', 'melanoma-merged_clinical_data-v8.tsv'), sep='\t')
meta_cols = ['tupro_id', 'subtype_group', 'cd8_phenotype_revised']
# Remove the metadata columns of a previous run of this cell first; re-running would
# otherwise produce '_x'/'_y' suffixed duplicates and break the column lookups below.
ct_counts_all = (
    ct_counts_all
    .drop(columns=meta_cols, errors='ignore')
    .merge(meta.loc[:, meta_cols], left_on='sample_id', right_on='tupro_id', how='left')
)

In [None]:
# Predicted counts per ROI, grouped by CD8 phenotype; one figure per cell type
for ct in cts:
    is_ct = ct_counts_all['cell_type'] == ct
    plot_df = ct_counts_all[is_ct]
    sns.boxplot(data=plot_df, x='cd8_phenotype_revised', y='Prediction')
    plt.title(ct)
    plt.show()

# Compare cell-type density maps at different resolutions

In [None]:
# Fixed colour per cell type, shared by the plots below
from matplotlib.colors import ListedColormap
# Colormap built from the colours at indices 0..8 of 'Set1'
cmap_sel = ListedColormap(plt.get_cmap('Set1')(np.arange(9)))
cts = [name for name in CELL_TYPES if name != 'other']
# zip pairs cell types with colours in order and stops at the shorter of the two
color_palette = {name: rgba for name, rgba in zip(cts, cmap_sel.colors)}

In [None]:
# Pearson correlation between GT and predicted 2D cell-density histograms,
# per cell type and ROI, at several bin resolutions.
bin_lim = 1000//(2**(level-2))
axmax = 1024//(2**(level-2))
resolutions = [32,64,128,256]

# Read every ROI table once; the files do not depend on the resolution
roi_tables = {}
for s_roi in sample_rois:
    gt = pd.read_csv(GT_PATH.joinpath(s_roi+'.tsv'), sep='\t', index_col=[0])
    pred = pd.read_csv(PRED_PATH.joinpath(s_roi+'.tsv'), sep='\t', index_col=[0])
    roi_tables[s_roi] = (gt, pred)

# NOTE(review): pearsonr is not imported explicitly in this notebook; presumably it is
# provided by one of the star imports from codebase.utils — confirm.
pcorr_records = []
for res in resolutions:
    desired_resolution_px = res//(2**(level-2))
    x_bins, y_bins = get_density_bins(desired_resolution_px, bin_lim, axmax)
    for s_roi in sample_rois:
        gt, pred = roi_tables[s_roi]
        for ct in cts:
            # Coordinates of the cells of this type; selecting rows directly avoids a
            # pivot that fails when two cells of one type share the same (X, Y)
            gt_ct = gt.loc[gt['pred_cell_type']==ct, ['X', 'Y']]
            pred_ct = pred.loc[pred['pred_cell_type']==ct, ['X', 'Y']]
            if (len(gt_ct) > 0) and (len(pred_ct) > 0):
                density_gt, _, _ = np.histogram2d(gt_ct['X'], gt_ct['Y'], [x_bins, y_bins], density=True)
                density_pred, _, _ = np.histogram2d(pred_ct['X'], pred_ct['Y'], [x_bins, y_bins], density=True)
                pcorr = pearsonr(density_gt.flatten(), density_pred.flatten())[0]
            else:
                # cell type absent in GT or prediction of this ROI
                print(ct, 'missing')
                pcorr = np.nan
            pcorr_records.append({'cell_type': ct, 'pcorr': pcorr, 'sample_roi': s_roi, 'resolution': res})
# Built once from records, so 'pcorr' is a float column rather than object dtype
pcorr_df_all = pd.DataFrame(pcorr_records, columns=['cell_type', 'pcorr', 'sample_roi', 'resolution'])

In [None]:
# Number of ROIs in which a given cell type was missing (in either GT or prediction)
pcorr_missing = pcorr_df_all['pcorr'].isna()
pcorr_df_all[pcorr_missing].groupby('cell_type')['sample_roi'].nunique()

In [None]:
# Distribution of per-ROI correlations for each cell type, one figure per resolution
for res in resolutions:
    plot_df = pcorr_df_all[pcorr_df_all['resolution'] == res]
    fig, ax = plt.subplots(figsize=(6,4))
    sns.boxplot(data=plot_df, x='pcorr', y='cell_type', palette=color_palette, ax=ax)
    # vertical reference line at zero correlation
    ax.axvline(0, linestyle='--', color='lightgrey')
    ax.set_ylabel('')
    ax.set_xlabel('Pearson correlation')
    ax.set_title(res)
    ax.set_xlim(-1, 1)
    plt.show()

In [None]:
# Median correlation per cell type at resolution 64
pcorr_df_all[pcorr_df_all['resolution'] == 64].groupby('cell_type')['pcorr'].median()

In [None]:
# Median Pearson correlation per cell type as a function of the bin resolution
pcorr_df_all_agg = (
    pcorr_df_all
    .groupby(['resolution', 'cell_type'])['pcorr']
    .median()
    .to_frame('median_pcorr')
    .reset_index()
)
fig, ax = plt.subplots(figsize=(4.5,3.5))
# Points and connecting lines share data, mapping and colours
shared_kws = dict(x='resolution', y='median_pcorr', data=pcorr_df_all_agg, hue='cell_type', ax=ax, palette=color_palette)
sns.scatterplot(**shared_kws)
sns.lineplot(legend=False, **shared_kws)
ax.axhline(0, linestyle='--', color='lightgrey')
ax.set_title('Correspondence of cell-type maps')
ax.set_ylabel('Median Pearson correlation')
ax.set_xlabel('Resolution in px')
ax.legend(bbox_to_anchor=(1,1))
if not dry_run:
    for fig_ext in ('.png', '.pdf'):
        plt.savefig(SAVE_PATH.joinpath(save_fname+'-cts_corr_growing'+fig_ext), bbox_inches='tight', dpi=300)
plt.show()


# Overlay cell-type density maps

In [None]:
# Binning for the overlay maps; all pixel sizes are divided by 2**(level-2)
level_scale = 2**(level-2)
bin_lim = 1000//level_scale
axmax = 1024//level_scale
desired_resolution_px = 32//level_scale
x_bins, y_bins = get_density_bins(desired_resolution_px, bin_lim, axmax)
# (1/n_bins)/10 with n_bins = (bin_lim/desired_resolution_px)**2
bins_per_axis = bin_lim/desired_resolution_px
max_density = 1/(bins_per_axis**2)/10
# From here on all cell types are used; alternative: [x for x in CELL_TYPES if x !='other']
cts = CELL_TYPES

# Display the derived settings
bin_lim, axmax, desired_resolution_px, axmax//desired_resolution_px, max_density

In [None]:
def plt_ax_adjust(plt_ax, title=''):
    """Strip an axes down to a square, titled panel on a white background.

    Sets the box aspect to 1, applies ``title``, removes the x/y ticks,
    clears the x/y axis labels and sets the face colour to white.

    Args:
        plt_ax: axes-like object providing the ``set_*`` methods used below.
        title: text passed to ``set_title`` (empty by default).
    """
    plt_ax.set_box_aspect(1)
    plt_ax.set_title(title)
    # no tick marks on either axis
    for set_ticks in (plt_ax.set_xticks, plt_ax.set_yticks):
        set_ticks([])
    # no axis labels either
    for set_label in (plt_ax.set_ylabel, plt_ax.set_xlabel):
        set_label('')
    plt_ax.set_facecolor('white')

In [None]:
### Joint plot of several cell types per ROI: H&E image, GT locations, predicted locations
# offset for setting y-/x-axis limits
offset = 20
# size of the point
marker_size = 10
# transparency of the point
alpha = 0.4
# which cell-types to plot
cts_sel = ['tumor', 'Tcells.CD8', 'Bcells']
# which ROIs to save plots for (only if dry_run=False)
save_rois = ['MYKOKIG_F1', 'MAHEFOG_F3', 'MAHEFOG_F2']

#color_palette = dict(zip(cts_sel, cmap_sel.colors))
# both scatter panels use the same limits on x and y
axis_range = (0-offset, 1000+offset)
for s_roi in sorted(save_rois):#sample_rois):
    print(s_roi)
    he = np.load(project_path.joinpath('data/tupro/binary_he_rois',s_roi+'.npy'))
    # keep only the selected cell types, ordered by cell type
    gt = pd.read_csv(GT_PATH.joinpath(s_roi+'.tsv'), sep='\t', index_col=[0])
    gt = gt[gt['pred_cell_type'].isin(cts_sel)].sort_values(by='pred_cell_type')
    gt['present'] = 1
    pred = pd.read_csv(PRED_PATH.joinpath(s_roi+'.tsv'), sep='\t', index_col=[0])
    pred = pred[pred['pred_cell_type'].isin(cts_sel)].sort_values(by='pred_cell_type')
    pred['present'] = 1

    fig, axes = plt.subplots(1, 3, figsize=(9,3))
    axes[0].imshow(he, origin='lower')
    plt_ax_adjust(axes[0], title="H&E")
    # (axes, cell table, panel title, draw legend?) — only the last panel carries the legend
    scatter_panels = [
        (axes[1], gt, 'GT cell-type location', False),
        (axes[2], pred, 'Predicted cell-type location', True),
    ]
    for panel_ax, cells, panel_title, with_legend in scatter_panels:
        sns.scatterplot(x='Y', y='X', data=cells, hue='pred_cell_type', ax=panel_ax, s=marker_size,
                        legend=with_legend, palette=color_palette, alpha=alpha)
        panel_ax.set_ylim(*axis_range)
        panel_ax.set_xlim(*axis_range)
        plt_ax_adjust(panel_ax, title=panel_title)
    plt.legend(bbox_to_anchor=(1,1))
    fig.subplots_adjust(wspace=0.05, hspace=-0.25)
    if (not dry_run) and (s_roi in save_rois):
        for fig_ext in ('.png', '.pdf'):
            plt.savefig(SAVE_PATH.joinpath(save_fname+'-cts_maps-tumor_cd8_bcells-'+s_roi+fig_ext), bbox_inches='tight', dpi=300)
    plt.show()