### Notebook to visualize, min-max normalize, and plot IMC images generated by the MVS model

In [None]:
import os 
import numpy as np 
import cv2
import glob
import pandas as pd 
from pathlib import Path
import openslide 

from PIL import Image
import tifffile
from sklearn.cluster import KMeans

import sys 
import matplotlib.pyplot as plt
root_code = os.path.dirname(os.path.dirname(os.getcwd()))
sys.path.insert(0, root_code)
from codebase.utils.constants import *
from codebase.utils.raw_utils import str2bool
from codebase.utils.eval_utils import *
from codebase.utils.metrics import *


In [None]:
# ---- Paths and settings ----
# This cell defines where predictions, training quantiles and H&E slides live,
# and derives the marker order used by the predicted IMC arrays.
import json  # imported explicitly instead of relying on a wildcard import to provide it

project_path = '/raid/sonali/project_mvs' # UPDATE as needed
results_path = os.path.join(project_path, 'results')
submission_id = 'pfggdx8b_dgm4h_5GP+ASP_selected_snr_dgm4h'

# -- settings to get predicted IMC wsi --
which_set = 'external_test' # or 'valid'
wsi_base = os.path.join(results_path, submission_id, which_set + '_wsis/step_350K')
level = '6' # which imc level to use for plotting: 6,4,2
# sorted: glob returns files in arbitrary order; sorting keeps the sample order reproducible
wsi_paths_pred = sorted(glob.glob(os.path.join(wsi_base,  'level_' + level) + '/*'))
if not wsi_paths_pred:
    # fail with a clear message instead of an IndexError in the print below
    raise FileNotFoundError('No predicted IMC images found in ' + os.path.join(wsi_base, 'level_' + level))
print('Number of predicted IMC images: ', len(wsi_paths_pred), wsi_paths_pred[0])

# -- settings to get marker order -- 
# the job arguments are stored as JSON in args.txt; the context manager closes the file
with open(Path(results_path).joinpath(submission_id, 'args.txt')) as args_file:
    job_args = json.load(args_file)
protein_subset = get_protein_list(job_args['protein_set'])
channel_list = [protein2index[protein] for protein in protein_subset]
print(protein_subset)

# -- getting train quantiles -- 
cohort_quantiles_path =  Path(project_path).joinpath('data/tupro/imc_updated/agg_masked_data-raw_clip99_arc_otsu3_std_minmax_split3-r5-train_quantiles.tsv') # UPDATE as needed
df_cohort_quantiles = pd.read_csv(cohort_quantiles_path, sep='\t', index_col=[0])
# keep only the row labelled 'q0.95', then order its columns like protein_subset
df_cohort_quantiles = df_cohort_quantiles.loc[['q0.95'],:]
cohort_quantiles = df_cohort_quantiles.loc[:,protein_subset].values[0]
print('cohort_quantiles_protein_subset: ', cohort_quantiles)

# -- settings for input HE wsis -- 
HE_base = '/raid/sonali/project_mvs/downstream_tasks/immune_phenotyping/tupro/HE_new_wsi' # path where wsis are saved -- UPDATE as needed
# HE_base = '/home/sonali/github_code/Boqi/dryrun_tcga_inference/wsi'
level_he = 2 # 6 for tupro; 2 for TCGA 


In [None]:
# ---- Functions needed when annotations are available for H&E ----
def map_colors(image, colors):
    """Snap every pixel of ``image`` to the nearest colour in ``colors``.

    Nearest means smallest Euclidean distance in colour space; on an exact
    tie the colour that appears first in ``colors`` wins.

    Args:
        image: array whose elements can be regrouped into 3-value pixels
            (it is reshaped to ``(-1, 3)``).
        colors: sequence/array of reference colours, one 3-value row each.

    Returns:
        Array with the same shape as ``image`` and dtype ``uint8`` in which
        every pixel has been replaced by its nearest reference colour.

    Raises:
        ValueError: if ``colors`` is empty.
    """
    # use the palette that was passed in, not a module-level global
    palette = np.asarray(colors, dtype=np.float64)
    if palette.shape[0] == 0:
        raise ValueError("colors must contain at least one colour")

    pixels = image.reshape(-1, 3).astype(np.float64)

    # running arg-min over the palette: one distance vector at a time keeps the
    # peak memory at the order of the pixel array rather than pixels x colours
    nearest = np.zeros(pixels.shape[0], dtype=np.intp)
    best_dist = np.full(pixels.shape[0], np.inf)
    for color_idx, color in enumerate(palette):
        dist = ((pixels - color) ** 2).sum(axis=1)
        closer = dist < best_dist  # strict '<' keeps the earlier colour on ties
        nearest[closer] = color_idx
        best_dist[closer] = dist[closer]

    # vectorised lookup replaces a per-pixel Python-level mapping
    mapped_image = palette[nearest].reshape(image.shape).astype(np.uint8)

    return mapped_image
def get_annotation_masks(annots_path, wsi_shape):
    """Load a colour-coded annotation image and return its tumor mask.

    The annotation image is read, downsampled by a factor of 4 per axis,
    snapped to the module-level ``pre_defined_colors`` palette with
    ``map_colors``, and the pixels equal to ``[255, 0, 0]`` (the colour
    treated as tumor here) are extracted and resized to the target shape.

    Args:
        annots_path: path of the annotation TIFF; only page 0 is read.
        wsi_shape: shape of the image the mask must match; only the first
            two entries (rows, columns) are used.

    Returns:
        ``uint8`` array of shape ``(wsi_shape[0], wsi_shape[1])`` derived
        from a 0/1 tumor mask.
    """
    # loading annotation masks: tumor and stroma within tumor compartment
    # reading using tifffile 
    img_annots = tifffile.imread(annots_path, key=0)

    # downsampling annotations as otherwise color mapping is slow
    # nearest-neighbour sampling: the image is a label map, and interpolating
    # between two class colours would create blends that can snap to a third class
    img_annots = cv2.resize(img_annots, (0,0), fx=0.25, fy=0.25, interpolation=cv2.INTER_NEAREST)

    # map colors correctly to annotations 
    img_annots = map_colors(img_annots, pre_defined_colors)
    print(img_annots.shape)

    # sanity check on the number of distinct colours left after mapping
    colors = np.unique(img_annots.reshape(-1, img_annots.shape[-1]), axis=0)
    print('Unique colors and counts in annotated image: ', len(colors))#, colors, counts)
    assert len(colors)<=6, "more than 6 colors/annotations found in the image"

    # get tumor mask 
    mask_tumor = (np.all(img_annots == [255, 0, 0], axis=-1)).astype(np.uint8)

    # make sure shape matches of annotations and wsi_pred 
    # cv2.resize takes the target size as (width, height), hence the swapped indices
    mask_tumor = cv2.resize(mask_tumor, (wsi_shape[1], wsi_shape[0]))
    return mask_tumor

In [None]:
# ---- getting annotations transferred from CD8 to HE ----
# NOTE: ignore this cell for TCGA, as annotations are not available yet

annotation_path = '/raid/sonali/project_mvs/downstream_tasks/immune_phenotyping/tupro/annotation_transfer/annotation_images_new'

# RGB colour -> annotation class name
color_code = {
    (255, 255, 0): "Whitespace",  # yellow
    (255, 0, 255): "Positive Lymphocytes",  # pink
    (0, 0, 0): "Pigment",  # black
    (0, 128, 0): "Stroma",  # green
    (255, 0, 0): "Tumor",  # red
    (192, 64, 0): "Blood and necrosis",  # dirty red
}

# palette as a float array with one RGB row per class, in dictionary order
pre_defined_colors = np.array(list(color_code), dtype='double')

# samples with annotations on H&E WSI (transferred from CD8):
# the sample name is the file name up to its first '.'
samples_annotations = [
    file_name.partition('.')[0] for file_name in os.listdir(annotation_path)
]


In [None]:
# scaling using cohort quantiles
def minmax_scaling(x, max_cohort):
    """Scale ``x`` by dividing it by ``max_cohort``.

    No offset is subtracted, so the lower bound of the scaling is taken to
    be 0. Both arguments may be scalars or arrays; arrays follow the usual
    element-wise/broadcasting rules of the ``/`` operator.

    Args:
        x: value(s) to scale.
        max_cohort: value(s) used as the upper reference.

    Returns:
        The quotient ``x / max_cohort``.
    """
    scaled = x / max_cohort
    return scaled

save_plots = False


def _max_over_markers(wsi_pred, protein_subset, markers):
    """Return the pixel-wise maximum over the channels of ``markers`` (never below 0)."""
    combined = np.zeros(wsi_pred.shape[0:2], dtype=float)
    for protein in markers:
        combined = np.maximum(combined, wsi_pred[:,:,protein_subset.index(protein)])
    return combined


# --- marker groups combined for the plots (same for every sample) ----
pred_Tcells = ['CD8a', 'CD3']
pred_Bcells = ['CD20']
pred_tumor = ['S100', 'MelanA', 'gp100']

# the annotation cell may have been skipped (e.g. for TCGA); in that case no
# sample is treated as annotated instead of failing with a NameError
try:
    annotated_samples = samples_annotations
except NameError:
    annotated_samples = []

# iterating through samples 
for wsi_path_pred in wsi_paths_pred: 

    # get corresponding HE image from level 
    sample = wsi_path_pred.split('/')[-1].split('.')[0]
    he_matches = glob.glob(HE_base + '/' + sample + '*')
    if not he_matches:
        # nothing to show next to the prediction; report and move on
        print('No H&E WSI found for sample, skipping: ', sample)
        continue
    wsi = openslide.open_slide(he_matches[0])
    try:
        level_dims = wsi.level_dimensions[level_he]
        # read_region returns RGBA; keep the first three channels only
        wsi_he = np.array(wsi.read_region((0, 0), level_he, level_dims))[:,:,0:3]
    finally:
        # release the slide handle even if reading fails
        wsi.close()
    print(wsi_he.shape)

    # default to the plain H&E so the overlay variable is always defined and
    # never carries over from a previous sample
    wsi_he_overlay = wsi_he

    # get annotation mask if possible -- ignore for tcga for now 
    if sample in annotated_samples: 
        print(sample)
        annots_path = os.path.join(annotation_path, sample + '.tif')
        mask_tumor = get_annotation_masks(annots_path, wsi_he.shape)
        print('mask_tumor: ', mask_tumor.shape, np.amax(mask_tumor), np.amin(mask_tumor))

        # overlay tumor mask on HE
        mask_tumor = cv2.cvtColor((mask_tumor*255).astype(np.uint8), cv2.COLOR_GRAY2RGB)
        print('mask_tumor: ', mask_tumor.shape)

        # Use cv2.addWeighted() to overlay the mask on the image
        wsi_he_overlay = cv2.addWeighted(wsi_he, 0.6, mask_tumor, 0.4, 0)
        print('wsi_he_overlay: ', wsi_he_overlay.shape)

    # load wsi imc pred 
    wsi_pred = np.load(wsi_path_pred)
    print(wsi_pred.shape)
    # scaling: the per-marker quantiles broadcast along the last (channel) axis,
    # which avoids one Python-level call per pixel
    wsi_pred = minmax_scaling(wsi_pred, cohort_quantiles)
    print(wsi_pred.shape)

    # --- combining markers for plot ----
    # the tumor channel is shared by all three visualisations, so compute it once
    wsi_pred_tumor = _max_over_markers(wsi_pred, protein_subset, pred_tumor)

    # -- 1. Combining tumor, B cells and T cells -- 
    wsi_pred_Bcells = _max_over_markers(wsi_pred, protein_subset, pred_Bcells)
    wsi_pred_Tcells = _max_over_markers(wsi_pred, protein_subset, pred_Tcells)
    wsi_pred_viz1 = np.dstack((wsi_pred_tumor, wsi_pred_Bcells, wsi_pred_Tcells))
    print('wsi_pred_viz1: ', wsi_pred_viz1.shape)

    # -- 2. Combining tumor, HLA-DR and HLA-ABC -- 
    wsi_pred_viz2 = np.dstack((wsi_pred_tumor, wsi_pred[:,:,protein_subset.index('HLA-DR')], wsi_pred[:,:,protein_subset.index('HLA-ABC')]))
    print('wsi_pred_viz2: ', wsi_pred_viz2.shape)

    # -- 3. Combining tumor, CD16, CD31 -- 
    wsi_pred_viz3 = np.dstack((wsi_pred_tumor, wsi_pred[:,:,protein_subset.index('CD16')], wsi_pred[:,:,protein_subset.index('CD31')]))
    print('wsi_pred_viz3: ', wsi_pred_viz3.shape)

    # Plotting: H&E followed by the three marker combinations, side by side
    fig = plt.figure(figsize=(40, 20))

    # swap wsi_he for wsi_he_overlay to show the tumor annotation on the H&E
    panels = (wsi_he, wsi_pred_viz1, wsi_pred_viz2, wsi_pred_viz3)
    for panel_idx, panel in enumerate(panels, start=1):
        plt.subplot(1, 4, panel_idx)
        plt.imshow(panel)
        plt.axis('off')

    plt.subplots_adjust(wspace=0.01, hspace=0)

    # save 
    if save_plots: 
        save_path = os.path.join(os.getcwd(), 'plots') # change as needed
        os.makedirs(save_path, exist_ok=True)
        plt.savefig(os.path.join(save_path, sample+'-wsi_dgm4h.pdf'), bbox_inches='tight', dpi=300,  pad_inches = 0)
        
    plt.show() 
    # free the figure so memory does not grow with the number of samples
    plt.close(fig)

    