In [None]:
import os
import numpy as np
import pandas as pd
from pathlib import Path
import tifffile
import json
import torch
import torch.nn as nn
import torchvision.transforms as tt

import matplotlib.pyplot as plt
import seaborn as sns

import sys 
root_code = os.path.dirname(os.path.dirname(os.getcwd()))
sys.path.insert(0, root_code)

from codebase.utils.constants import *
from codebase.utils.eval_utils import * 
from codebase.utils.metrics import get_density_bins

In [None]:
def plt_ax_adjust(plt_ax, title=''):
    """Turn an axes into a bare, square, titled image panel.

    Args:
        plt_ax: axes-like object exposing ``set_box_aspect``, ``set_title``,
            ``set_xticks``, ``set_yticks``, ``set_xlabel`` and ``set_ylabel``.
        title: text placed above the panel (empty by default).

    Returns:
        None; the axes is modified in place.
    """
    # (setter name, argument) pairs: square box, title, then no ticks and no labels.
    panel_settings = (
        ('set_box_aspect', 1),
        ('set_title', title),
        ('set_xticks', []),
        ('set_yticks', []),
        ('set_ylabel', ''),
        ('set_xlabel', ''),
    )
    for setter_name, value in panel_settings:
        getattr(plt_ax, setter_name)(value)
    

In [None]:
%load_ext autoreload
%autoreload 2

# Settings

In [None]:
##### arguments to specify
project_path = Path('/raid/sonali/project_mvs/')  # alternative root: Path('/cluster/work/grlab/projects/projects2021-multivstain/')
dev0 = "cpu"  # device passed to preprocess_img; cuda:5 corresponds to dgx:5
cv_split = 'split3'
data_set = 'test'
show_he = True  # whether to put the H&E image first in each row
relevant_args = []  # which job args to display in the image title, e.g. ['protein_set', 'seed']
tumor_prots = ['MelanA', 'gp100', 'S100', 'SOX9', 'SOX10']
# Proteins to plot; the pseudo-protein 'tumor_prots' combines the markers listed in tumor_prots.
# Earlier selections: ['tumor_prots','CD3', 'CD20', 'HLA-DR', 'HLA-ABC', 'CD16'],
#                     ['tumor_prots', 'CD3', 'CD8a'], ['MelanA','CD3','CD8a'], ['tumor_prots', 'CD8a']
proteins = ['MelanA', 'CD3', 'HLA-DR']
max_imgs = 17  # maximum number of images displayed in a row
add_stats = False  # whether to add min&max values to img title
cohort_scale = True  # whether to use q95 from the cohort to scale the image (anything above will be clipped to q0.95)

job_ids = ['mj3pqeyk_dataaug-v2-flip_split3_selected-snr_no-wt_no-checkerboard']

# Which epoch to use per job. A non-empty manual_epochs takes precedence (e.g. epochs selected
# from the top validation scores); leave it empty to fall back to get_best_epoch_w_imgs, so the
# lookup only runs when its result is actually used.
# Other options: [get_last_epoch_w_imgs(project_path, x) for x in job_ids],
#                ['epoch'+str(x)+'-1' for x in ['55', '50', '45', '40']], ['epoch59-1']
manual_epochs = ['epoch39-1']  # e.g. also 'epoch60-1', 'epoch45-1'
if manual_epochs:
    sel_epochs = list(manual_epochs)
else:
    sel_epochs = [get_best_epoch_w_imgs(project_path, x) for x in job_ids]

# full_names entries have the form '<job_id>|<epoch>'.
if len(job_ids) > 1:
    # Several jobs are paired one-to-one with epochs; zip would silently drop the surplus entries.
    if len(job_ids) != len(sel_epochs):
        raise ValueError('Need exactly one epoch per job: got '+str(len(job_ids))+' job_ids and '
                         +str(len(sel_epochs))+' sel_epochs.')
    full_names = [job_id+'|'+epoch_name for job_id, epoch_name in zip(job_ids, sel_epochs)]
else:
    # A single job may be shown at several epochs.
    full_names = [job_ids[0]+'|'+epoch_name for epoch_name in sel_epochs]
resolution = 'level_2'
level = int(resolution.split('_')[-1])

# Gaussian blur
blur_sigma = 1
kernel_width = 3

# averaging kernel
avg_kernel = 0  # avg_pool is only built when > 0; values tried before: 16, 32
avg_stride = 1
if avg_kernel > 0:
    avg_pool = nn.AvgPool2d(kernel_size=avg_kernel, padding=int(np.floor(avg_kernel/2)), stride=avg_stride)

densitycorr_px = 0  # density binning is only set up when > 0; value tried before: 64
q = 0.5  # for now using mean across training set as quantiles on cohort level not computed yet
if densitycorr_px > 0:
    x_bins, y_bins = get_density_bins(densitycorr_px, 1000, 1000)
    cohort_thrs = pd.read_csv(project_path.joinpath('data','tupro', 'agg_stats_qc', cv_split, 'imc_rois_raw_clip99_arc_otsu3-agg_stats.tsv'), sep='\t', index_col=[0])
    cohort_thrs = cohort_thrs['mean_minmaxed_cohort'].to_dict()
    # threshold of the pseudo-protein: the largest threshold among the tumor markers
    cohort_thrs['tumor_prots'] = max([cohort_thrs[x] for x in tumor_prots])
print(full_names)

In [None]:
# Anchor the path constants brought in by the star imports at the project root.
# NOTE(review): assumes these constants are relative paths; joinpath with an absolute
# argument discards project_path.
DATA_DIR = project_path.joinpath(DATA_DIR)
RESULTS_DIR = project_path.joinpath(RESULTS_DIR)
# Read the CV split definition with a context manager so the file handle is closed right away.
with open(project_path.joinpath(CV_SPLIT_ROIS_PATH)) as cv_split_file:
    cv_splits = json.load(cv_split_file)
s_rois = cv_splits[cv_split][data_set]  # can also specify one ROI by s_rois=['MECADAC_F3']

# Ground-truth IMC arrays live in a folder named after the preprocessing string and the CV split.
gt_imc_prep = 'raw_clip99_arc_otsu3_std_minmax_'+cv_split
GT_IMC_PATH = DATA_DIR.joinpath('binary_imc_rois_'+gt_imc_prep)


# Protein maps

In [None]:
# Use cohort-level quantiles: the q0.95 row of the training-set quantile table gives the
# upper colour limit per protein when cohort_scale is on.
if cohort_scale:
    # Build the file name from cv_split so the quantiles always belong to the selected split.
    fname = 'agg_masked_data-raw_clip99_arc_otsu3_std_minmax_'+cv_split+'-r5-train_quantiles'
    cohort_quantiles = pd.read_csv(project_path.joinpath('data/tupro/imc_updated/'+fname+'.tsv'), sep='\t', index_col=[0])
    cohort_quantiles = cohort_quantiles.loc[['q0.95'],:]
    if 'tumor_prots' in proteins:
        # pseudo-protein: the largest q0.95 among the tumor markers present in the table
        tumor_cols = [x for x in tumor_prots if x in cohort_quantiles.columns]
        cohort_quantiles['tumor_prots'] = cohort_quantiles.loc[:,tumor_cols].max(axis=1)

In [None]:
# One figure per (ROI, protein): optional H&E, the ground-truth IMC channel, then the
# prediction of every job/epoch in full_names.

# Job metadata does not depend on the ROI or the protein, so read it once and close
# every args.txt right after parsing.
job_infos = []
for full_name in full_names:
    job_name = full_name.split('|')[0]
    epoch = full_name.split('|')[-1]
    with open(RESULTS_DIR.joinpath(job_name, 'args.txt')) as args_file:
        job_args = json.load(args_file)
    assert job_args['cv_split']==cv_split, 'Prediction cv_split does not correspond to GT!'
    img_dir = RESULTS_DIR.joinpath(job_name,data_set+'_images', epoch, resolution)
    protein_list = get_protein_list(job_args['protein_set'])
    job_infos.append((epoch, job_args, protein_list, img_dir))

# identical preprocessing settings for GT and predictions
prep_kwargs = dict(downsample_factor=1000//(4000//2**level), kernel_width=kernel_width,
                   blur_sigma=blur_sigma, avg_kernel=avg_kernel, avg_stride=avg_stride)

for s_roi in reversed(sorted(s_rois)):
    print(s_roi)
    # per-ROI inputs are shared by all proteins
    img_he = np.load(DATA_DIR.joinpath('binary_he_rois',s_roi+'.npy')) if show_he else None
    img_gt_full = np.load(GT_IMC_PATH.joinpath(s_roi+'.npy'))

    for protein in proteins:
        # ordered dictionary title -> image, so any number of images can be plotted (max_imgs per row)
        all_img = dict()
        if show_he:
            all_img['H&E'] = img_he

        # GT IMC (independent of the job, so computed once per protein)
        if protein=='tumor_prots':
            # take max across all tumor markers (independent of which were used for a given job)
            img_gt = get_tumor_prots_signal(img_gt_full, PROTEIN_LIST_MVS, tumor_prots)
        else:
            protein_gt_idx = PROTEIN_LIST_MVS.index(protein)
            img_gt = img_gt_full[:,:,[protein_gt_idx]]
        img_gt = preprocess_img(img_gt, dev0, **prep_kwargs)
        img_title = 'GT_'+protein
        if add_stats:
            minmax_gt = 'min: '+str(round(np.min(img_gt),2))+' max: '+str(round(np.max(img_gt),2))
            img_title = img_title+'\n'+minmax_gt
        if densitycorr_px > 0:
            # zero out pixels below the cohort threshold, then bin the remaining coordinates
            img_gt_bin = img_gt.reshape(img_gt.shape[0], img_gt.shape[1])
            img_gt_bin[img_gt_bin<cohort_thrs[protein]] = 0
            coords_gt = pd.DataFrame(np.nonzero(img_gt_bin), index=['X', 'Y']).transpose()
            img_gt, _, _ = np.histogram2d(coords_gt['X'], coords_gt['Y'], [x_bins, y_bins], density=True)
        all_img[img_title] = img_gt

        # predictions
        # NOTE(review): titles are built from the epoch only, so two jobs evaluated at the
        # same epoch share a key and the later one replaces the earlier one.
        for epoch, job_args, protein_list, img_dir in job_infos:
            if len(relevant_args)>0:
                img_title = epoch.split('-')[0]+'\n'+'\n'.join([str(job_args[x]) for x in relevant_args])+' : '+protein
            else:
                img_title = epoch.split('-')[0]+'\n'+protein

            if protein=='tumor_prots':
                # take max across all tumor markers the job actually predicted
                pred_prots = [x for x in tumor_prots if x in protein_list]
            elif protein in protein_list:
                pred_prots = [protein]
            else:
                pred_prots = []
            if len(pred_prots)==0:
                # this job did not predict the protein; nothing to show
                continue

            protein_pred_idx = [protein_list.index(prot) for prot in pred_prots]
            img_pred = np.load(img_dir.joinpath(s_roi+'.npy'))[:,:,protein_pred_idx]
            if protein=='tumor_prots':
                img_pred = get_tumor_prots_signal(img_pred, pred_prots, tumor_prots)
            img_pred = preprocess_img(img_pred, dev0, **prep_kwargs)
            if add_stats:
                minmax_pred = 'min: '+str(round(np.min(img_pred),2))+' max: '+str(round(np.max(img_pred),2))
                img_title = img_title+'\n'+minmax_pred
            if densitycorr_px > 0:
                img_pred_bin = img_pred.reshape(img_pred.shape[0], img_pred.shape[1])
                img_pred_bin[img_pred_bin<cohort_thrs[protein]] = 0
                coords_pred = pd.DataFrame(np.nonzero(img_pred_bin), index=['X', 'Y']).transpose()
                img_pred, _, _ = np.histogram2d(coords_pred['X'], coords_pred['Y'], [x_bins, y_bins], density=True)
            all_img[img_title] = img_pred

        if len(all_img)>max_imgs:
            print('Cannot print more than '+str(max_imgs)+' images per row! Plotting only first '+str(max_imgs)+'.')
            # keep exactly the first max_imgs entries
            all_img = dict(list(all_img.items())[:max_imgs])

        # squeeze=False keeps axes two-dimensional even for a single panel, so indexing always works
        fig, axes = plt.subplots(1,len(all_img), figsize=(20,5), squeeze=False)
        for ax, (k, img) in zip(axes[0], all_img.items()):
            vmin, vmax = None, None
            if cohort_scale:
                vmin,vmax = 0,1
                # every panel except H&E shows the current protein, so use its cohort q0.95 directly
                if k!='H&E':
                    vmax = cohort_quantiles.loc[:,protein].values[0]
            ax.imshow(img, origin='lower', vmin=vmin, vmax=vmax)
            plt_ax_adjust(ax, title=k)
        plt.show()


# Protein maps across multiple proteins (for MICCAI)

In [None]:
# One figure per ROI with all selected proteins side by side: optional H&E, then for each
# protein the ground-truth IMC channel followed by the prediction(s). The figure is saved as PDF.
s_rois = ['MECYGYR_F2']

# Job metadata does not depend on the ROI or the protein, so read it once and close
# every args.txt right after parsing.
job_infos = []
for full_name in full_names:
    job_name = full_name.split('|')[0]
    epoch = full_name.split('|')[-1]
    with open(RESULTS_DIR.joinpath(job_name, 'args.txt')) as args_file:
        job_args = json.load(args_file)
    assert job_args['cv_split']==cv_split, 'Prediction cv_split does not correspond to GT!'
    img_dir = RESULTS_DIR.joinpath(job_name,data_set+'_images', epoch, resolution)
    protein_list = get_protein_list(job_args['protein_set'])
    job_infos.append((epoch, job_args, protein_list, img_dir))

# identical preprocessing settings for GT and predictions
prep_kwargs = dict(downsample_factor=1000//(4000//2**level), kernel_width=kernel_width,
                   blur_sigma=blur_sigma, avg_kernel=avg_kernel, avg_stride=avg_stride)

for s_roi in reversed(sorted(s_rois)):
    print(s_roi)
    # ordered dictionary title -> image, so any number of images can be plotted (max_imgs per row)
    all_img = dict()
    # title -> protein shown in that panel; used to pick the colour limit without string matching
    title_protein = dict()
    if show_he:
        all_img['H&E']= np.load(DATA_DIR.joinpath('binary_he_rois',s_roi+'.npy'))
    img_gt_full = np.load(GT_IMC_PATH.joinpath(s_roi+'.npy'))

    for protein in proteins:
        # GT IMC
        if protein=='tumor_prots':
            # take max across all tumor markers (independent of which were used for a given job)
            img_gt = get_tumor_prots_signal(img_gt_full, PROTEIN_LIST_MVS, tumor_prots)
        else:
            protein_gt_idx = PROTEIN_LIST_MVS.index(protein)
            img_gt = img_gt_full[:,:,[protein_gt_idx]]
        img_gt = preprocess_img(img_gt, dev0, **prep_kwargs)
        img_title = 'GT_'+protein
        if add_stats:
            minmax_gt = 'min: '+str(round(np.min(img_gt),2))+' max: '+str(round(np.max(img_gt),2))
            img_title = img_title+'\n'+minmax_gt
        if densitycorr_px > 0:
            # zero out pixels below the cohort threshold, then bin the remaining coordinates
            img_gt_bin = img_gt.reshape(img_gt.shape[0], img_gt.shape[1])
            img_gt_bin[img_gt_bin<cohort_thrs[protein]] = 0
            coords_gt = pd.DataFrame(np.nonzero(img_gt_bin), index=['X', 'Y']).transpose()
            img_gt, _, _ = np.histogram2d(coords_gt['X'], coords_gt['Y'], [x_bins, y_bins], density=True)
        all_img[img_title] = img_gt
        title_protein[img_title] = protein

        # predictions
        for epoch, job_args, protein_list, img_dir in job_infos:
            if len(relevant_args)>0:
                img_title = epoch.split('-')[0]+'\n'+'\n'.join([str(job_args[x]) for x in relevant_args])+' : '+protein
            else:
                img_title = epoch.split('-')[0]+'\n'+protein

            if protein=='tumor_prots':
                # take max across all tumor markers the job actually predicted
                pred_prots = [x for x in tumor_prots if x in protein_list]
            elif protein in protein_list:
                pred_prots = [protein]
            else:
                pred_prots = []
            if len(pred_prots)==0:
                # this job did not predict the protein; nothing to show
                continue

            protein_pred_idx = [protein_list.index(prot) for prot in pred_prots]
            img_pred = np.load(img_dir.joinpath(s_roi+'.npy'))[:,:,protein_pred_idx]
            if protein=='tumor_prots':
                img_pred = get_tumor_prots_signal(img_pred, pred_prots, tumor_prots)
            img_pred = preprocess_img(img_pred, dev0, **prep_kwargs)
            if add_stats:
                minmax_pred = 'min: '+str(round(np.min(img_pred),2))+' max: '+str(round(np.max(img_pred),2))
                img_title = img_title+'\n'+minmax_pred
            if densitycorr_px > 0:
                img_pred_bin = img_pred.reshape(img_pred.shape[0], img_pred.shape[1])
                img_pred_bin[img_pred_bin<cohort_thrs[protein]] = 0
                coords_pred = pd.DataFrame(np.nonzero(img_pred_bin), index=['X', 'Y']).transpose()
                img_pred, _, _ = np.histogram2d(coords_pred['X'], coords_pred['Y'], [x_bins, y_bins], density=True)
            all_img[img_title] = img_pred
            title_protein[img_title] = protein

    if len(all_img)>max_imgs:
        print('Cannot print more than '+str(max_imgs)+' images per row! Plotting only first '+str(max_imgs)+'.')
        # keep exactly the first max_imgs entries
        all_img = dict(list(all_img.items())[:max_imgs])

    # squeeze=False keeps axes two-dimensional even for a single panel, so indexing always works
    fig, axes = plt.subplots(1,len(all_img), figsize=(10,5), squeeze=False)
    for ax, (k, img) in zip(axes[0], all_img.items()):
        vmin, vmax = None, None
        if cohort_scale:
            vmin,vmax = 0,1
            panel_protein = title_protein.get(k)  # None for the H&E panel
            if panel_protein is not None:
                vmax = cohort_quantiles.loc[:,panel_protein].values[0]
        ax.imshow(img, origin='lower', vmin=vmin, vmax=vmax)
        plt_ax_adjust(ax, title=k)
    # File name follows the ROI being plotted, so each ROI gets its own PDF.
    # NOTE(review): the output folder is a user-specific absolute path; adjust before running elsewhere.
    fig.savefig('/home/joanna/'+job_ids[0]+'-'+s_roi+'.pdf', dpi=300, bbox_inches='tight')
    plt.show()
