In [None]:
import numpy as np
import pandas as pd
from skimage import draw
from pathlib import Path
import argparse
import os
import json
import cv2
import torch
import torchvision.transforms.functional as ttf
import matplotlib.pyplot as plt
from scipy.stats.stats import pearsonr, spearmanr

import sys 
# Prepend the directory two levels above the notebook's working directory to the import
# path so that the `codebase` package (imported below) can be found.
root_code = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))
sys.path.insert(0, root_code)

from codebase.utils.constants import *


In [None]:
# Project root. Defaults to the original location; set the PROJECT_PATH environment
# variable to run elsewhere (e.g. /cluster/work/grlab/projects/projects2021-multivstain/).
PROJECT_PATH = Path(os.environ.get('PROJECT_PATH', '/raid/sonali/project_mvs/'))

# Cross-validation ROI splits; `with` closes the file handle once it has been parsed.
with open(PROJECT_PATH.joinpath(CV_SPLIT_ROIS_PATH)) as cv_file:
    cv = json.load(cv_file)
cv_split = 'split3'
data_set = 'all'  # {train, valid, test, all}
if data_set == 'all':
    # Build a new list: extending cv[cv_split]['train'] in place would also add the
    # valid/test ROIs to the 'train' entry of `cv`.
    sample_rois = [roi for part in ('train', 'valid', 'test') for roi in cv[cv_split][part]]
else:
    # Copy so that later edits to sample_rois cannot leak back into `cv`.
    sample_rois = list(cv[cv_split][data_set])

In [None]:
# Nuclei centroids for the selected ROIs: H&E from the HoVer-Net export, IMC from coldata.tsv.
# H&E coordinates are integer-divided by HE_TO_IMC_SCALE, presumably to bring them onto the
# IMC pixel grid (the H&E image is downscaled by the same factor of 4 for plotting).
# `.assign` returns a new frame, so no columns are written into a `.loc`-filtered selection
# (which triggers pandas' SettingWithCopyWarning).
HE_TO_IMC_SCALE = 4

he_coords = (
    pd.read_csv(PROJECT_PATH.joinpath('meta/hovernet/hovernet_nuclei-coordinates_all-samples.csv'))
    .loc[lambda df: df['sample_roi'].isin(sample_rois)]
    .assign(
        X=lambda df: df['X'] // HE_TO_IMC_SCALE,
        Y=lambda df: df['Y'] // HE_TO_IMC_SCALE,
    )
)
imc_coords = (
    pd.read_csv(PROJECT_PATH.joinpath('data/tupro/imc_updated/coldata.tsv'), sep='\t')
    .loc[lambda df: df['sample_roi'].isin(sample_rois)]
)

In [None]:
# Peek at the first two rows of the (rescaled) H&E nuclei coordinates.
he_coords.iloc[:2]

# Plot nuclei densities (and their overlap)

In [None]:
# Compare 2-D nuclei density maps of H&E and IMC for the first few ROIs.
plot_density = True
alpha = 0.6           # opacity of each layer in the overlay panel
axmax = 1024          # extent (px) covered by the binning grid along each axis
resolutions = [64]    # bin widths in px, e.g. [1,4,16,32,64,75,128,256] (use few when plotting many ROIs)
n_rois_to_plot = 1    # how many ROIs from the start of sample_rois to plot


def _density_map(coords, edges):
    """Return the normalised 2-D histogram of one ROI's nuclei centroids, ready for imshow.

    Parameters
    ----------
    coords : DataFrame with 'X' and 'Y' columns, one row per nucleus.
    edges : bin edges, used for both axes.

    Returns
    -------
    Array of shape (len(edges) - 1, len(edges) - 1), indexed [Y bin, X bin].
    """
    n = len(edges) - 1
    if len(coords) == 0:
        # histogram2d(density=True) divides by the total count and would return all-NaN.
        return np.zeros((n, n))
    density, _, _ = np.histogram2d(coords['X'], coords['Y'], [edges, edges], density=True)
    # histogram2d puts X on the first (row) axis; transposing puts X on the horizontal axis,
    # so with origin='lower' the maps share the orientation of the pointwise nuclei plots.
    return density.T


def _panel_title(label, resolution_px, n_bins):
    """Panel title: modality label, bin width in px and total number of bins."""
    return label + ' \n resolution in px: ' + str(resolution_px) + '\n n_bins: ' + str(n_bins ** 2)


for desired_resolution_px in resolutions:
    n_bins = axmax // desired_resolution_px
    # Edges span [0, axmax], so each bin is exactly desired_resolution_px wide whenever it
    # divides axmax (the width reported in the titles). An upper edge of 1000 would make
    # 16 bins 62.5 px wide while they are labelled 64 px.
    bin_edges = np.linspace(0, axmax, n_bins + 1)

    for s_roi in sample_rois[:n_rois_to_plot]:
        density_he = _density_map(he_coords.loc[he_coords['sample_roi'] == s_roi], bin_edges)
        density_imc = _density_map(imc_coords.loc[imc_coords['sample_roi'] == s_roi], bin_edges)
        if not plot_density:
            continue

        imshow_kw = dict(interpolation='spline36', origin='lower')
        fig, axes = plt.subplots(1, 3, figsize=(8, 2))
        axes[0].imshow(density_he, cmap='Blues', alpha=1, **imshow_kw)
        axes[0].set_title(_panel_title('H&E', desired_resolution_px, n_bins))
        axes[1].imshow(density_imc, cmap='Oranges', alpha=1, **imshow_kw)
        axes[1].set_title(_panel_title('IMC', desired_resolution_px, n_bins))
        axes[2].imshow(density_he, cmap='Blues', alpha=alpha, **imshow_kw)
        axes[2].imshow(density_imc, cmap='Oranges', alpha=alpha, **imshow_kw)
        axes[2].set_title(_panel_title('Blue: H&E, Orange: IMC', desired_resolution_px, n_bins))
        plt.show()
        plt.close(fig)
    


# Aggregate density comparisons

In [None]:
# Aggregated (mean or median) metric per resolution
# Aggregate the precomputed H&E-vs-IMC density correlations into one value per
# (metric, split, resolution).
# Distinct names are used on purpose: rebinding `axmax` to a string broke the integer
# binning in the density cell on re-run, and a loop variable called `data_set` overwrote
# the split selection from the CV-split cell.
density_axmax = 1024  # grid extent the TSVs were computed with; selects the '-max<N>.tsv' files
agg_method = 'mean'   # 'mean' or 'median'
if agg_method not in ('mean', 'median'):
    # Fail loudly instead of reusing a stale aggregate from a previous iteration.
    raise ValueError(f"agg_method must be 'mean' or 'median', got {agg_method!r}")

agg_columns = []
for metric in ['pcorr', 'spcorr']:
    for split_name in ['all', 'split3_train', 'split3_valid', 'split3_test']:
        tsv_name = f'nuclei_density-he_imc-{split_name}-{metric}-max{density_axmax}.tsv'
        metric_df = pd.read_csv(PROJECT_PATH.joinpath('meta', 'nuclei_density', tsv_name),
                                sep='\t', index_col=[0])
        # Column-wise mean/median over the TSV rows (presumably one row per ROI).
        agg = getattr(metric_df, agg_method)().rename(f'{metric}-{split_name}')
        # Keep only the suffix after the last '_' of each column label (the resolution in px).
        agg.index = [label.split('_')[-1] for label in agg.index]
        agg_columns.append(agg)
# Concatenate once instead of growing a DataFrame inside the loop.
agg_df_all = pd.concat(agg_columns, axis=1)
agg_df_all

In [None]:
# Aggregated (per agg_method) Pearson correlation between H&E and IMC nuclei densities
# over all ROIs, as a function of the binning resolution.
pcorr_all = agg_df_all['pcorr-all']
fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(pcorr_all.index, pcorr_all.to_numpy())
ax.scatter(pcorr_all.index, pcorr_all.to_numpy())
ax.set_ylim(-0.01, 1)
ax.set_ylabel('Pearson correlation')
ax.set_xlabel('Resolution in pixels')
ax.set_title('H&E vs IMC nuclei density (all ROIs)')
plt.show()

# Plot pointwise nuclei locations (and their overlap)

In [None]:
# Overlay individual nuclei centroids (H&E vs IMC) on a zoomed-in window of each ROI.
yaxlim = (250, 506)
xaxlim = (250, 506)
x_max = 1000            # canvas width in px
y_max = 1000            # canvas height in px
he_downscale = 4        # same factor the H&E centroids were divided by when loaded
nucleus_radius_px = 5
# The canvases are single-channel, so only the first component of each colour tuple ends
# up in the image; H&E and IMC therefore get different intensities in the joint panel.
he_color = (50, 141, 168)
imc_color = (191, 127, 78)


def _load_he_rgb(s_roi, downscale):
    """Load <DATA_DIR>/binary_he_rois/<s_roi>.npy and shrink it by `downscale` on both axes.

    The stored array is assumed channel-last (H, W, C), as the transpose below implies;
    the returned array is channel-last too.
    """
    image = np.load(PROJECT_PATH.joinpath(DATA_DIR, 'binary_he_rois', s_roi + '.npy'))
    image = torch.from_numpy(image.copy().transpose(2, 0, 1))  # torchvision wants (C, H, W)
    height, width = image.shape[1], image.shape[2]
    # Pass both target sizes: an int size only fixes the shorter edge, which gives a uniform
    # 1/downscale only for square images.
    image = ttf.resize(image, [height // downscale, width // downscale])
    return np.asarray(image).transpose(1, 2, 0)


def _draw_nuclei(canvas, coords, color, radius):
    """Draw a filled circle of `radius` px at each (X, Y) centroid in `coords`; return the canvas."""
    for x, y in zip(coords['X'].to_numpy(), coords['Y'].to_numpy()):
        # OpenCV needs an integer (column, row) centre; thickness -1 fills the circle.
        canvas = cv2.circle(canvas, (int(x), int(y)), radius, color, -1)
    return canvas


for s_roi in sample_rois:
    image_he_3ch = _load_he_rgb(s_roi, he_downscale)
    imc = imc_coords.loc[imc_coords['sample_roi'] == s_roi]
    he = he_coords.loc[he_coords['sample_roi'] == s_roi]

    # Fresh (rows, cols) = (y_max, x_max) canvases for every ROI, assigned unconditionally,
    # so a ROI without nuclei shows an empty panel instead of the previous ROI's drawing.
    image_he = _draw_nuclei(np.zeros((y_max, x_max)), he, he_color, nucleus_radius_px)
    image_imc = _draw_nuclei(np.zeros((y_max, x_max)), imc, imc_color, nucleus_radius_px)
    # Joint panel: H&E first, then IMC drawn on top.
    image_joint = _draw_nuclei(np.zeros((y_max, x_max)), he, he_color, nucleus_radius_px)
    image_joint = _draw_nuclei(image_joint, imc, imc_color, nucleus_radius_px)

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    panels = [
        (image_he_3ch, None, None),
        (image_he, 'Blues', 'H&E'),
        (image_imc, 'Greens', 'IMC'),
        (image_joint, 'ocean_r', 'Joint'),
    ]
    for ax, (img, cmap, title) in zip(axes, panels):
        ax.imshow(img, origin='lower', cmap=cmap)
        if title is not None:
            ax.set_title(title)
        ax.set_xlim(xaxlim)
        ax.set_ylim(yaxlim)
    fig.set_facecolor("white")
    fig.suptitle(s_roi)
    plt.show()
    # One figure per ROI: close it so memory does not grow with the number of ROIs.
    plt.close(fig)