In [1]:
#NOTE: use paimg1 env, the retccl one has package issue with torchvision
import sys
import os
import numpy as np
import openslide
import matplotlib.pyplot as plt
import cv2
import matplotlib
matplotlib.use('Agg')
import pandas as pd
import warnings
import torch
import torch.nn as nn

from sklearn.model_selection import KFold, train_test_split
from torch.utils.data import DataLoader
import torch.optim as optim
from pathlib import Path
from scipy.ndimage import gaussian_filter
from scipy.spatial.distance import cdist


sys.path.insert(0, '../Utils/')
from Utils import get_downsample_factor, get_image_at_target_mag
from Utils import do_mask_original,check_tissue,whitespace_check
from Utils import create_dir_if_not_exists
from Utils import generate_deepzoom_tiles, extract_tile_start_end_coords, get_map_startend
from Utils import get_downsample_factor
from Utils import minmax_normalize, count_label
from Utils import log_message, set_seed
from Eval import compute_performance, plot_LOSS, compute_performance_each_label, get_attention_and_tileinfo
from train_utils import pull_tiles, get_feature_label_array_dynamic
from train_utils import ModelReadyData_diffdim, convert_to_dict, prediction
from cluster_utils import plot_cluster_distribution, load_alltile_tumor_info
from Model import Mutation_MIL_MT
warnings.filterwarnings("ignore")
%matplotlib inline

In [2]:
####################################
######      USERINPUT       ########
####################################
SELECTED_LABEL = ["AR","MMR (MSH2, MSH6, PMS2, MLH1, MSH3, MLH3, EPCAM)2","PTEN","RB1","TP53","TMB_HIGHorINTERMEDITATE","MSI_POS"]
# Feature columns "0".."2047" plus the tile's tumor pixel fraction
SELECTED_FEATURE = [str(i) for i in range(0,2048)] + ['TUMOR_PIXEL_PERC']
TUMOR_FRAC_THRES = 0
TRAIN_SAMPLE_SIZE = "ALLTUMORTILES"
TRAIN_OVERLAP = 100     # deepzoom pixel overlap of the train tiles
TEST_OVERLAP = 0        # deepzoom pixel overlap of the test tiles
SELECTED_FOLD = 0
CLUSTER_ALG = 'KMEAN'
N_CLUSTERS = 8
CLUSTER_DIST = 'L2'
feature_extraction_method = 'retccl'
SEED = 0

##################
###### DIR  ######
##################
proj_dir = '/fh/fast/etzioni_r/Lucas/mh_proj/mutation_pred/'
wsi_path = proj_dir + 'data/OPX/'   # proj_dir already ends with '/'
data_dir = (proj_dir + 'intermediate_data/model_ready_data/'
            f'feature_{feature_extraction_method}/'
            f'MAXSS{TRAIN_SAMPLE_SIZE}_TrainOL{TRAIN_OVERLAP}_TestOL{TEST_OVERLAP}_TFT{TUMOR_FRAC_THRES}/'
            f'split_fold{SELECTED_FOLD}/')
tile_root = proj_dir + 'intermediate_data/cancer_prediction_results110224/'
train_tile_path = f'{tile_root}IMSIZE250_OL{TRAIN_OVERLAP}/'
test_tile_path = f'{tile_root}IMSIZE250_OL{TEST_OVERLAP}/'
cluster_info_path = os.path.join(data_dir, "clusters", CLUSTER_ALG, "ClusterInfo")
# Suffix identifying the clustering setting; shared by the cluster-info CSVs and output folders
save_name = f"_NCLUSTER_{N_CLUSTERS}_DISTMETRIC_{CLUSTER_DIST}"

################################################
#Create output dir
################################################
outdir0 = os.path.join(cluster_info_path, "Distribution_Plots" + save_name)            # distribution plots
outdir1 = os.path.join(data_dir, "spatial_model_input", save_name, "heatmaps")          # per-slide cluster heatmaps
outdir2 = os.path.join(data_dir, "spatial_model_input", save_name, "spatial_features/") # trailing '/': later cells append file names directly
for _out_dir in (outdir0, outdir1, outdir2):
    create_dir_if_not_exists(_out_dir)

##################
#Select GPU
##################
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
set_seed(SEED)

Directory '/fh/fast/etzioni_r/Lucas/mh_proj/mutation_pred/intermediate_data/model_ready_data/feature_retccl/MAXSSALLTUMORTILES_TrainOL100_TestOL0_TFT0/split_fold0/clusters/KMEAN/ClusterInfo/Distribution_Plots_NCLUSTER_8_DISTMETRIC_L2' created.
Directory '/fh/fast/etzioni_r/Lucas/mh_proj/mutation_pred/intermediate_data/model_ready_data/feature_retccl/MAXSSALLTUMORTILES_TrainOL100_TestOL0_TFT0/split_fold0/spatial_model_input/_NCLUSTER_8_DISTMETRIC_L2/heatmaps' created.
Directory '/fh/fast/etzioni_r/Lucas/mh_proj/mutation_pred/intermediate_data/model_ready_data/feature_retccl/MAXSSALLTUMORTILES_TrainOL100_TestOL0_TFT0/split_fold0/spatial_model_input/_NCLUSTER_8_DISTMETRIC_L2/spatial_features/' created.
cuda:0


In [3]:
############################################################################################################
#Get cluster info and label
############################################################################################################
def _load_cluster_info(split):
    """Read the cluster-assignment table `<cluster_info_path>/<split>_cluster_info<save_name>.csv`."""
    return pd.read_csv(f'{cluster_info_path}/{split}_cluster_info{save_name}.csv')


train_info_df = _load_cluster_info('train')
test_info_df = _load_cluster_info('test')
val_info_df = _load_cluster_info('valid')

In [4]:
############################################################################################################
#Plot distributions
############################################################################################################
# Cluster-distribution plots for the selected labels, drawn from the training table only
# and written to outdir0.
plot_cluster_distribution(
    train_info_df,
    SELECTED_LABEL,
    outdir0,
    "train",
)

In [5]:
############################################################################################################
#Get spatial cluster feature and plot
############################################################################################################
# For every slide of every cohort:
#   1. paint each tumor tile's cluster id onto a canvas at `mag_target_prob`;
#   2. average where tiles overlap and set non-tissue pixels to -1;
#   3. optionally smooth, then save a heatmap PNG to outdir1 and collect the map.
# Per cohort, the list of maps and the matching sorted slide IDs are written to outdir2.
# Map values before smoothing:
#   -1 = outside the tissue mask
#    0 = tissue not covered by any painted tile
#   positive = mean of (Cluster + 1) over the covering tiles
save_image_size = 250   # tile size passed to generate_deepzoom_tiles (cf. IMSIZE250 in the tile paths)
mag_extract = 20        # magnification whose deepzoom level the TILE_XY_INDEXES refer to
limit_bounds = True
mag_target_prob = 2.5   # magnification of the output cluster map
smooth = False
smooth_sigma = 2        # Gaussian sigma in map pixels; used only when smooth is True
mag_target_tiss = 1.25  # magnification used for tissue detection
clist = ['train','test','valid']

# Per cohort: (deepzoom pixel overlap, tile-info directory, cluster-info table).
# 'valid' reads the training tile directory, so it also uses the training overlap.
COHORT_INPUTS = {
    'train': (TRAIN_OVERLAP, train_tile_path, train_info_df),
    'test':  (TEST_OVERLAP,  test_tile_path,  test_info_df),
    'valid': (TRAIN_OVERLAP, train_tile_path, val_info_df),
}

# `rad` for do_mask_original, chosen by the first substring (in this order) found in the slide ID.
TISSUE_RADIUS_BY_ID = (('OPX', 5), ('(2017-0133)', 2))


def _tissue_radius(slide_id):
    """Return the tissue-mask radius for `slide_id`.

    Raises:
        ValueError: if no pattern in TISSUE_RADIUS_BY_ID matches (rather than
            silently reusing the radius of the previous slide).
    """
    for pattern, radius in TISSUE_RADIUS_BY_ID:
        if pattern in slide_id:
            return radius
    raise ValueError(f"No tissue-mask radius configured for slide {slide_id!r}")


def _paint_cluster_map(cur_info_df, tiles, lvl_in_deepzoom, lvl_resize, map_shape):
    """Average the 'Cluster' value of every tile above TUMOR_FRAC_THRES onto a canvas.

    Args:
        cur_info_df: one row per tile with 'TILE_XY_INDEXES' ("(x, y)"),
            'TUMOR_PIXEL_PERC' and 'Cluster'.
        tiles: deepzoom tile object returned by generate_deepzoom_tiles.
        lvl_in_deepzoom: deepzoom level the tile indexes refer to.
        lvl_resize: downsample factor from level 0 to the canvas.
        map_shape: (rows, cols) of the canvas.

    Returns:
        (cluster_map, n_skipped).
        cluster_map is the per-pixel mean 'Cluster' over covering tiles, 0 where no tile was painted.
        n_skipped is the number of tiles whose values could not be painted.
    """
    cluster_sum = np.zeros(map_shape, float)
    cover_count = np.zeros(map_shape, float)
    n_skipped = 0
    for _, row in cur_info_df.iterrows():
        # Tiles at or below the tumor-fraction threshold are left unpainted.
        if not row['TUMOR_PIXEL_PERC'] > TUMOR_FRAC_THRES:
            continue
        xy = row['TILE_XY_INDEXES'].strip("()").split(", ")
        x, y = int(xy[0]), int(xy[1])
        tile_starts, tile_ends, _, _ = extract_tile_start_end_coords(tiles, lvl_in_deepzoom, x, y)
        map_xstart, map_xend, map_ystart, map_yend = get_map_startend(tile_starts, tile_ends, lvl_resize)
        # NOTE(review): axis 0 of the canvas is rows (height); confirm that get_map_startend's
        # "x" bounds are row indices.
        region = np.s_[map_xstart:map_xend, map_ystart:map_yend]
        try:
            # Add the value before the count so a failure never leaves the two arrays out of step.
            cluster_sum[region] += row['Cluster']
            cover_count[region] += 1
        except (TypeError, ValueError, IndexError):
            n_skipped += 1
    cover_count[cover_count < 1] = 1   # unpainted pixels: avoid 0/0, keep them at 0
    return cluster_sum / cover_count, n_skipped


for cohort in clist:
    pixel_overlap, tile_path, info_df = COHORT_INPUTS[cohort]
    selected_ids = sorted(set(info_df['SAMPLE_ID']))

    sp_feature_list = []
    for ct, pt in enumerate(selected_ids):
        if ct % 10 == 0:
            print(ct)

        # All tiles of this slide with tumor info, cluster and labels
        cur_info_df = load_alltile_tumor_info(tile_path, pt, SELECTED_LABEL, info_df)
        if cur_info_df.empty:
            raise ValueError(f"No tile info found for slide {pt!r} under {tile_path}")
        # Shift cluster ids up by one so they stay distinct from 0 (unpainted) and -1 (background);
        # presumably the raw ids start at 0.
        cur_info_df['Cluster'] = cur_info_df['Cluster'] + 1

        sites = cur_info_df['Anatomic site'].dropna().unique()
        cur_site = str(sites[0]) if len(sites) > 0 else 'NA'

        # Label tag for the heatmap file name (e.g. "AR0_PTEN1_..."), read from the last tile row;
        # labels are presumably identical across a slide's tiles.
        last_row = cur_info_df.iloc[-1]
        label_save_name = '_'.join(label + str(int(last_row[label])) for label in SELECTED_LABEL)

        # `slide_name`, not `save_name`: the latter holds the cluster-setting suffix from the
        # config cell and must not be overwritten here.
        slide_path = wsi_path + pt + ".tif"
        slide_name = str(Path(os.path.basename(slide_path)).with_suffix(''))
        rad_tissue = _tissue_radius(pt)

        oslide = openslide.OpenSlide(slide_path)
        try:
            tiles, tile_lvls, _phys_size, base_mag = generate_deepzoom_tiles(oslide, save_image_size, pixel_overlap, limit_bounds)
            l0_w, l0_h = oslide.level_dimensions[0]

            # Tissue mask at mag_target_tiss
            lvl_resize_tissue = get_downsample_factor(base_mag, target_magnification=mag_target_tiss)
            lvl_img = get_image_at_target_mag(oslide, l0_w, l0_h, lvl_resize_tissue)
            _tissue, he_mask = do_mask_original(lvl_img, lvl_resize_tissue, rad=rad_tissue)

            # Cluster map at mag_target_prob
            lvl_resize = get_downsample_factor(base_mag, target_magnification=mag_target_prob)
            map_shape = (int(np.ceil(l0_h / lvl_resize)), int(np.ceil(l0_w / lvl_resize)))
            lvl_in_deepzoom = tile_lvls.index(mag_extract)
            x_map, n_skipped = _paint_cluster_map(cur_info_df, tiles, lvl_in_deepzoom, lvl_resize, map_shape)
        finally:
            oslide.close()
        if n_skipped:
            print(f"{pt}: {n_skipped} tile(s) could not be painted and were skipped")

        # Everything outside the (resized) tissue mask becomes background (-1)
        he_mask = cv2.resize(np.uint8(he_mask), (x_map.shape[1], x_map.shape[0]))
        x_map[he_mask < 1] = -1

        # mode='nearest' matches the default border handling of skimage.filters.gaussian
        x_sm = gaussian_filter(x_map, sigma=smooth_sigma, mode='nearest') if smooth else x_map
        sp_feature_list.append(x_sm)

        fig, ax = plt.subplots()
        fig.colorbar(ax.imshow(x_sm, cmap='Spectral_r'), ax=ax)
        fig.savefig(os.path.join(outdir1, slide_name + '_cluster_' + label_save_name + '_' + cur_site + '.png'),
                    dpi=500, bbox_inches='tight')
        plt.close(fig)

    # Per-cohort outputs: spatial maps and their slide IDs, in the same (sorted) order
    torch.save(sp_feature_list, outdir2 + cohort + '_sp_feature.pth')
    torch.save(selected_ids, outdir2 + cohort + '_ids.pth')
    

0
10
20
30
40
50
60
pred_union is invalid, buffering...
70
80
90
100
110
120
130
140
0
10
pred_union is invalid, buffering...
20
30
0
