In [1]:
#NOTE: use paimg1 env, the retccl one has package issue with torchvision
import sys
import os
import numpy as np
import openslide
import matplotlib.pyplot as plt

import matplotlib
matplotlib.use('Agg')
import pandas as pd
import warnings
import torch
import torch.nn as nn

from sklearn.model_selection import KFold, train_test_split
from torch.utils.data import DataLoader
import torch.optim as optim
from pathlib import Path

sys.path.insert(0, '../Utils/')
from Utils import create_dir_if_not_exists
from Utils import generate_deepzoom_tiles, extract_tile_start_end_coords, get_map_startend
from Utils import get_downsample_factor
from Utils import minmax_normalize
from Utils import log_message, set_seed
from Eval import compute_performance, plot_LOSS, compute_performance_each_label, get_attention_and_tileinfo
from train_utils import pull_tiles, get_feature_label_array_dynamic
from train_utils import ModelReadyData_diffdim, convert_to_dict, prediction
from Model import Mutation_MIL_MT
warnings.filterwarnings("ignore")
%matplotlib inline

In [2]:
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans

# Get clustering data
def get_cluster_data(feature_list, label_list, id_list, selected_labels):
    r'''
    Stack per-patient tile features into a single long DataFrame for clustering.

    Parameters
    ----------
    feature_list : list of 2-D array-likes
        One (n_tiles, n_features) matrix per patient; its columns become integer
        column labels 0..n_features-1.
    label_list : list of array-likes
        One label array per patient (e.g. shape (1, len(selected_labels))); it is
        flattened with ``reshape(-1)`` before indexing.
    id_list : list
        Patient ID for each entry of feature_list (same order).
    selected_labels : list of str
        Names of the label columns, in the same order as the label values.

    Returns
    -------
    pandas.DataFrame
        One row per tile of every patient (each patient keeps its own 0-based row
        index), columns = feature columns, then 'ID', then one int column per label.
        Feature column 2048, if present, is renamed to 'TUMOR_PIXEL_PERC'.
    '''
    per_patient = []
    for i, features in enumerate(feature_list):
        tile_df = pd.DataFrame(features)
        # reshape(-1) instead of squeeze(): a single-label (1, 1) array squeezes to a
        # 0-d array, which cannot be indexed with [j].
        flat_labels = label_list[i].reshape(-1)
        tags = {'ID': id_list[i]}
        for j, label_name in enumerate(selected_labels):
            tags[label_name] = int(flat_labels[j])
        # Attach all tag columns in one concat rather than inserting them one by one,
        # which fragments a ~2k-column frame.
        per_patient.append(pd.concat([tile_df, pd.DataFrame(tags, index=tile_df.index)], axis=1))
    feature_df = pd.concat(per_patient)

    # Column 2048 (right after the 0..2047 embedding dims) is the tumor pixel fraction.
    return feature_df.rename(columns={2048: 'TUMOR_PIXEL_PERC'})


def get_cluster_label(feature_df, cluster_centers, cluster_features, pca_model=None):
    r'''
    Assign every row of feature_df to its nearest K-means cluster center in PC space.

    Parameters
    ----------
    feature_df : pandas.DataFrame
        Tiles to label (e.g. the test or validation tiles).
    cluster_centers : numpy.ndarray, shape (k, n_components)
        Cluster centers fitted on the training principal components.
    cluster_features : list-like
        Columns of feature_df that were fed to the PCA.
    pca_model : fitted PCA-like object, optional
        Must expose ``transform``. Defaults to the notebook-global ``pca`` that was
        fitted on the training tiles.

    Returns
    -------
    numpy.ndarray of int, shape (n_rows,)
        Index of the closest center for every row.
    '''
    if pca_model is None:
        pca_model = pca
    # Project with the *training* PCA. Re-fitting here (fit_transform) would give the
    # held-out tiles their own PC axes, making distances to the training centers meaningless.
    pcs = pca_model.transform(feature_df[cluster_features])
    # (k, 1, d) - (n, d) broadcasts to (k, n, d); the norm over d gives a (k, n) distance matrix.
    distances = np.linalg.norm(cluster_centers[:, np.newaxis] - pcs, axis=2)
    return np.argmin(distances, axis=0)
    
def get_updated_feature(input_df, selected_ids, selected_feature):
    r'''
    Split a long tile-level DataFrame back into one feature matrix per patient.

    Parameters
    ----------
    input_df : pandas.DataFrame
        Tile rows of several patients, with an 'ID' column.
    selected_ids : iterable
        Patient IDs, in the order the output list should follow.
    selected_feature : list
        Columns to extract for every tile.

    Returns
    -------
    list of numpy.ndarray
        One (n_tiles_of_patient, len(selected_feature)) array per ID; an ID with no
        rows yields an empty (0, len(selected_feature)) array. Prints the running
        position every 10 patients as progress.
    '''
    id_column = input_df['ID']
    per_patient = []
    for position, patient_id in enumerate(selected_ids):
        if position % 10 == 0:
            print(position)
        patient_rows = input_df.loc[id_column == patient_id]
        per_patient.append(patient_rows[selected_feature].values)
    return per_patient

In [3]:
####################################
######      USERINPUT       ########
####################################
SELECTED_LABEL = ["AR","MMR (MSH2, MSH6, PMS2, MLH1, MSH3, MLH3, EPCAM)2","PTEN","RB1","TP53","TMB_HIGHorINTERMEDITATE","MSI_POS"]
SELECTED_FEATURE = [str(i) for i in range(0,2048)] + ['TUMOR_PIXEL_PERC']
TUMOR_FRAC_THRES = 0
TRAIN_SAMPLE_SIZE = "ALLTUMORTILES"
TRAIN_OVERLAP = 100
TEST_OVERLAP = 0
SELECTED_FOLD = 0

##################
###### DIR  ######
##################
def _as_dir(*parts):
    """Join path components with os.path.join and end the result with exactly one separator.

    Later cells build file paths by plain string concatenation (e.g. outdir2 + 'train_ids.pth'),
    so every directory string must end with a separator. Joining also avoids the '//' that
    hand-concatenation produced. Note: no component after the first may start with '/',
    because os.path.join discards everything before an absolute component.
    """
    return os.path.join(*parts, '')

proj_dir = '/fh/fast/etzioni_r/Lucas/mh_proj/mutation_pred/'
wsi_path = _as_dir(proj_dir, 'data', 'OPX')
label_path = _as_dir(proj_dir, 'data', 'MutationCalls')
# IDs used for fine-tuning the cancer detection model; they must be excluded from the mutation study
ft_ids_path = _as_dir(proj_dir, 'intermediate_data', 'cd_finetune', 'cancer_detection_training')
train_tile_path = _as_dir(proj_dir, 'intermediate_data', 'cancer_prediction_results110224', f'IMSIZE250_OL{TRAIN_OVERLAP}')
test_tile_path = _as_dir(proj_dir, 'intermediate_data', 'cancer_prediction_results110224', f'IMSIZE250_OL{TEST_OVERLAP}')
feature_name = 'features_alltiles_retccl'

################################################
#Create output dir
################################################
outdir = _as_dir(proj_dir, 'intermediate_data', 'model_ready_data',
                 f'MAX_SS{TRAIN_SAMPLE_SIZE}_TrainOL{TRAIN_OVERLAP}_TestOL{TEST_OVERLAP}')
create_dir_if_not_exists(outdir)
outdir1 = _as_dir(outdir, 'clusters')
create_dir_if_not_exists(outdir1)
outdir2 = _as_dir(outdir, f'split_fold{SELECTED_FOLD}')
create_dir_if_not_exists(outdir2)

##################
#Select GPU
##################
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
set_seed(0)

Directory '/fh/fast/etzioni_r/Lucas/mh_proj/mutation_pred/intermediate_data/model_ready_data/MAX_SSALLTUMORTILES_TrainOL100_TestOL0/' already exists.
Directory '/fh/fast/etzioni_r/Lucas/mh_proj/mutation_pred/intermediate_data/model_ready_data/MAX_SSALLTUMORTILES_TrainOL100_TestOL0//clusters/' already exists.
Directory '/fh/fast/etzioni_r/Lucas/mh_proj/mutation_pred/intermediate_data/model_ready_data/MAX_SSALLTUMORTILES_TrainOL100_TestOL0//split_fold0/' already exists.
cuda:0


In [4]:
############################################################################################################
#Select IDS
############################################################################################################
# All available slide IDs: one per .tif file in wsi_path (207 at the time of writing).
# Only .tif files count, because later cells rebuild the slide path as wsi_path + ID + ".tif";
# any other file in the folder would otherwise become a bogus ID and shift the fold split.
opx_ids = sorted(f.replace('.tif', '') for f in os.listdir(wsi_path) if f.endswith('.tif'))

#Get IDs that are in FT train or already processed to exclude
ft_ids_df = pd.read_csv(ft_ids_path + 'all_tumor_fraction_info.csv')
ft_train_ids = list(ft_ids_df.loc[ft_ids_df['Train_OR_Test'] == 'Train','sample_id'])

# OPX_182 – excluded: possible colon adenocarcinoma
toexclude_ids = ft_train_ids + ['OPX_182']  #25

# Exclude ids in ft_train or processed; a set keeps each membership test O(1).
_excluded_id_set = set(toexclude_ids)
selected_ids = [x for x in opx_ids if x not in _excluded_id_set] #199
print(len(selected_ids))

199


In [5]:
############################################################################################################
#Get Train and test IDs, 80% - 20%
############################################################################################################
# Patient-level 5-fold split (shuffled, fixed seed); SELECTED_FOLD chooses the held-out test fold.
n_splits = 5
kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)

fold_splits = list(kf.split(selected_ids))
train_ids_folds = [[selected_ids[k] for k in tr_idx] for tr_idx, _ in fold_splits]
test_ids_folds = [[selected_ids[k] for k in te_idx] for _, te_idx in fold_splits]

full_train_ids, test_ids = train_ids_folds[SELECTED_FOLD], test_ids_folds[SELECTED_FOLD]

# Hold out 5% of the training patients as a validation set.
train_ids, val_ids = train_test_split(full_train_ids, test_size=0.05, random_state=42)
for split_size in (len(train_ids), len(val_ids), len(test_ids)):
    print(split_size)

151
8
40


In [6]:
############################################################################################################
#Get features and labels
#NOTE: OPX_005 has no tumor tiles, so excluded in this step
############################################################################################################
# Load per-patient tile features, labels and tile info for each split with the project helper
# get_feature_label_array_dynamic. Its 5th return value appears to be the IDs actually kept
# (per the NOTE above, OPX_005 is dropped here), so later cells use selected_train_ids /
# select_val_ids / select_test_ids rather than train_ids / val_ids / test_ids.
# NOTE(review): validation patients are read from train_tile_path in "Train" mode (same overlap and
# sample-size setting as training), while test uses test_tile_path in "Test" mode and relies on the
# helper's default train_sample_size — confirm this is intended.
train_feature, train_label, train_info, train_tf_info, selected_train_ids = get_feature_label_array_dynamic(train_tile_path,feature_name, train_ids, SELECTED_LABEL,SELECTED_FEATURE,"Train" ,tumor_fraction_thres = TUMOR_FRAC_THRES,train_sample_size = TRAIN_SAMPLE_SIZE)
val_feature, val_label, val_info, val_tf_info, select_val_ids = get_feature_label_array_dynamic(train_tile_path,feature_name, val_ids, SELECTED_LABEL,SELECTED_FEATURE, "Train" ,tumor_fraction_thres = TUMOR_FRAC_THRES,train_sample_size = TRAIN_SAMPLE_SIZE)
test_feature, test_label, test_info, test_tf_info, select_test_ids = get_feature_label_array_dynamic(test_tile_path, feature_name, test_ids, SELECTED_LABEL, SELECTED_FEATURE, "Test", tumor_fraction_thres = TUMOR_FRAC_THRES)

0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
0
0
10
20
30


In [None]:
# Persist every per-split artifact as <split>_<kind>.pth in outdir2
# (for each kind, files are written in train, test, val order).
_raw_artifacts = [
    ('feature', train_feature, test_feature, val_feature),
    ('label', train_label, test_label, val_label),
    ('info', train_info, test_info, val_info),
    ('tf_info', train_tf_info, test_tf_info, val_tf_info),
    ('ids', selected_train_ids, select_test_ids, select_val_ids),
]
for _kind, *_split_objs in _raw_artifacts:
    for _split, _obj in zip(('train', 'test', 'val'), _split_objs):
        torch.save(_obj, outdir2 + _split + '_' + _kind + '.pth')

In [None]:
# Label prevalence per split: count and percentage of positives (value 1) in each label column.
_label_display_names = ["Mutation labels  :","AR","HR","PTEN","RB1","TP53","TMB","MSI_POS"]
for _header, _split_labels in ((None, train_label), ("--------TEST------", test_label)):
    _stacked = np.concatenate(_split_labels)
    count_ones = (_stacked == 1).sum(axis=0)
    if _header is not None:
        print(_header)
    print("Number of 1s in each column:", count_ones)
    percentage_ones = np.round(count_ones / _stacked.shape[0] * 100, 1)
    print("% of 1s in each column:", percentage_ones)
    print(_label_display_names)

In [None]:
################################################
#Clustering
################################################
# Embedding columns used for PCA / K-means (integer labels 0..2047 from get_cluster_data).
feature_for_cluster = range(2048)

# One long tile-level table per split, tagged with patient ID and labels.
train_feature_df, test_feature_df, valid_feature_df = (
    get_cluster_data(split_features, split_labels, split_ids, SELECTED_LABEL)
    for split_features, split_labels, split_ids in (
        (train_feature, train_label, selected_train_ids),
        (test_feature, test_label, select_test_ids),
        (val_feature, val_label, select_val_ids),
    )
)

# Two principal components of the training tiles; `pca` stays fitted for later cells.
pca = PCA(n_components=2)
principal_components = pca.fit_transform(train_feature_df[feature_for_cluster])

In [None]:
# Elbow method: K-means inertia (within-cluster sum of squares) for k = 1..10 on the training PCs.
candidate_ks = range(1, 11)
wcss = []
for k in candidate_ks:
    print(k)
    kmeans = KMeans(n_clusters=k, random_state=42)
    wcss.append(kmeans.fit(principal_components).inertia_)

# Look for the k after which WCSS stops dropping sharply.
fig, ax = plt.subplots()
ax.plot(candidate_ks, wcss)
ax.set_xlabel('Number of clusters (k)')
ax.set_ylabel('WCSS')
ax.set_title('Elbow Method')
plt.show()

In [None]:
# Perform K-means clustering on the training PCs (k = 4, presumably picked from the elbow plot).
# random_state pins the initialisation so the cluster ids that end up in the saved model-ready
# data are reproducible across re-runs; without it the result depends on how much of the global
# RNG earlier cells happened to consume. 42 matches the seed used in the elbow search.
kmeans = KMeans(n_clusters=4, random_state=42)
kmeans.fit(principal_components)

# Get cluster centers and labels
centers = kmeans.cluster_centers_
cluster_labels_train = kmeans.labels_

# Scatter of the training PCs coloured by cluster, with each center drawn as its cluster number.
fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(principal_components[:, 0], principal_components[:, 1], c=cluster_labels_train, cmap='viridis', alpha=0.6)
for cluster_id, center in enumerate(centers):
    ax.scatter(center[0], center[1], c='red', marker=f'${cluster_id}$', s=200)

ax.set_xlabel('PC1')
ax.set_ylabel('PC2')
ax.set_title('K-Means Clustering for Tile Embeddings on PCs')
ax.grid(True)
fig.savefig(outdir1 + 'original_cluster_scatter.png')
plt.close(fig)

In [None]:
# Label each test/validation tile with its nearest training cluster center.
cluster_labels_test, cluster_labels_val = (
    get_cluster_label(split_df, centers, feature_for_cluster)
    for split_df in (test_feature_df, valid_feature_df)
)

# Store the cluster id on every split's tile table.
for _split_df, _split_clusters in (
    (train_feature_df, cluster_labels_train),
    (test_feature_df, cluster_labels_test),
    (valid_feature_df, cluster_labels_val),
):
    _split_df['Cluster'] = _split_clusters

# Model input columns: the 2048 embedding dims, the tumor pixel fraction and the cluster id.
updated_feature_list = [*range(2048), 'TUMOR_PIXEL_PERC', 'Cluster']
updated_train_feature = get_updated_feature(train_feature_df, selected_train_ids, updated_feature_list)
updated_test_feature = get_updated_feature(test_feature_df, select_test_ids, updated_feature_list)
updated_val_feature = get_updated_feature(valid_feature_df, select_val_ids, updated_feature_list)

In [None]:
################################################
#     Model ready data 
################################################
# Wrap each split (features incl. cluster id, labels, tf_info) in the project dataset class.
train_data, test_data, val_data = (
    ModelReadyData_diffdim(split_features, split_labels, split_tf_info)
    for split_features, split_labels, split_tf_info in (
        (updated_train_feature, train_label, train_tf_info),
        (updated_test_feature, test_label, test_tf_info),
        (updated_val_feature, val_label, val_tf_info),
    )
)

# Output: <split>_<kind>.pth in outdir2 (for each kind, files are written in train, test, val order).
_ready_artifacts = [
    ('data', train_data, test_data, val_data),
    ('info', train_info, test_info, val_info),
    ('tf_info', train_tf_info, test_tf_info, val_tf_info),
    ('ids', selected_train_ids, select_test_ids, select_val_ids),
]
for _kind, *_split_objs in _ready_artifacts:
    for _split, _obj in zip(('train', 'test', 'val'), _split_objs):
        torch.save(_obj, outdir2 + _split + '_' + _kind + '.pth')

In [None]:
# Cluster analysis: attach the two training principal components to the tile table for plotting.
for _pc_column, _pc_index in (('PC1', 0), ('PC2', 1)):
    train_feature_df[_pc_column] = principal_components[:, _pc_index]

In [None]:
# One PC scatter per outcome: positive tiles (value 1) in dark red at size 20, all other tiles
# shrunk to near-invisible points; groups containing positives are drawn on top (higher zorder).
for plot_outcome in SELECTED_LABEL:
    outcome_frame = train_feature_df[['PC1', 'PC2', 'Cluster', plot_outcome]]
    fig, ax = plt.subplots(figsize=(10, 6))

    for cluster_id in outcome_frame['Cluster'].unique():
        members = outcome_frame.loc[outcome_frame['Cluster'] == cluster_id]
        is_positive = members[plot_outcome] == 1
        point_colors = ['darkred' if value != 0 else 'steelblue' for value in members[plot_outcome]]
        ax.scatter(members['PC1'], members['PC2'],
                   s=np.where(is_positive, 20, 0.001),
                   c=point_colors,
                   alpha=0.6,
                   linewidth=1.5,
                   label=f'Cluster {cluster_id}',
                   zorder=3 if is_positive.any() else 2)

    ax.set_xlabel('PC1')
    ax.set_ylabel('PC2')
    ax.set_title(plot_outcome)
    ax.grid(True)
    fig.savefig(outdir1 +  'cluster_scatter_' + plot_outcome + '.png')
    plt.close(fig)

In [None]:
# Stacked bars: for each cluster, the percentage of its training tiles with each outcome value.
for plot_outcome in SELECTED_LABEL:
    tile_counts = pd.crosstab(train_feature_df['Cluster'], train_feature_df[plot_outcome])
    row_shares = tile_counts.div(tile_counts.sum(axis=1), axis=0) * 100

    ax = row_shares.plot(kind='bar', stacked=True, color=['steelblue', 'darkred'])
    ax.set_xlabel('Cluster')
    ax.set_ylabel('Percentage')
    ax.set_title('Bar Chart of ' + plot_outcome + ' per Cluster')
    ax.legend(title=plot_outcome, loc='center left', bbox_to_anchor=(1.0, 0.5))

    fig = ax.get_figure()
    fig.tight_layout()
    fig.savefig(outdir1 + "outcome_distribution_" + plot_outcome + '.png')
    plt.close(fig)

In [None]:
# For each outcome value, show how that outcome's training tiles are spread across clusters.
for plot_outcome in SELECTED_LABEL:
    # Rows = outcome value, columns = cluster id; cell = number of training tiles.
    crosstab = pd.crosstab(train_feature_df[plot_outcome], train_feature_df['Cluster'])

    # Normalise each row: percentage of that outcome's tiles falling in each cluster.
    percentage_crosstab = crosstab.div(crosstab.sum(axis=1), axis=0) * 100

    # Grouped (not stacked) bars, one per cluster. The fixed viridis palette covers up to 4
    # clusters; for more clusters, sample viridis evenly so every cluster gets its own colour.
    n_cluster_columns = percentage_crosstab.shape[1]
    palette = ['#440154', '#3b528b', '#5ec962', '#fde725']
    if n_cluster_columns > len(palette):
        palette = [matplotlib.colors.to_hex(rgba) for rgba in plt.cm.viridis(np.linspace(0, 1, n_cluster_columns))]
    ax = percentage_crosstab.plot(kind='bar', stacked=False, color=palette[:n_cluster_columns])
    ax.set_xlabel(plot_outcome)
    ax.set_ylabel('Percentage')
    ax.set_title('Bar Chart of Clusters Per Outcome')
    ax.legend(title='Cluster', loc='center left', bbox_to_anchor=(1.0, 0.5))

    fig = ax.get_figure()
    fig.tight_layout()
    fig.savefig(outdir1  + 'cluster_distribution_' +  plot_outcome + '.png')
    plt.close(fig)

In [None]:
# Parameters for re-tiling one slide so tile indices from train_info can be mapped back onto it.
save_image_size = 250          # tile edge in px; presumably must match the IMSIZE250 tiles in train_tile_path
# The mapped tile indices come from train_tile_path (..._OL{TRAIN_OVERLAP}); tie the overlap to that
# setting so the regenerated deep-zoom grid cannot drift from it.
pixel_overlap = TRAIN_OVERLAP
mag_extract = 20               # deep-zoom level whose tiles are placed on the map
limit_bounds = True
TOP_K = 5
pretrain_model_name = "retccl"
mag_target_prob = 2.5          # magnification of the cluster heat map
smooth = False
mag_target_tiss = 1.25         # magnification used for tissue detection

In [None]:
# Heat map of the per-tile cluster id for one training slide, masked to tissue.
from Utils import get_downsample_factor, get_image_at_target_mag
from Utils import do_mask_original,check_tissue,whitespace_check
import cv2

i = 6
pt = selected_train_ids[i]
cur_df = train_feature_df.loc[train_feature_df['ID'] == pt,['ID','Cluster']]
cur_info_df = train_info[i]

print(["Mutation labels  :","AR","HR","PTEN","RB1","TP53","TMB","MSI_POS"])
print(train_label[i])

# Radius passed to do_mask_original depends on the slide cohort. Fail loudly for an unrecognised ID
# instead of hitting a NameError on rad_tissue after the slide has already been opened and tiled.
if 'OPX' in pt:
    rad_tissue = 5
elif '(2017-0133)' in pt:
    rad_tissue = 2
else:
    raise ValueError(f"No tissue-mask radius configured for slide ID {pt!r}")

_file = wsi_path + pt + ".tif"
save_name = str(Path(os.path.basename(_file)).with_suffix(''))

# NOTE(review): axis=1 concat aligns on the row index. cur_df keeps the per-patient 0-based index
# created in get_cluster_data; this assumes cur_info_df uses the same index, otherwise 'Cluster'
# ends up NaN for misaligned rows — confirm.
cur_comb_df = pd.concat([cur_info_df, cur_df], axis = 1)

# The slide is opened in a with-block so its handle is released once the maps are filled.
with openslide.OpenSlide(_file) as oslide:
    #Generate tiles
    tiles, tile_lvls, physSize, base_mag = generate_deepzoom_tiles(oslide,save_image_size, pixel_overlap, limit_bounds)

    #get level 0 size in px
    l0_w, l0_h = oslide.level_dimensions[0]

    #1.25x tissue detection for mask
    lvl_resize_tissue = get_downsample_factor(base_mag,target_magnification = mag_target_tiss) #downsample factor
    lvl_img = get_image_at_target_mag(oslide,l0_w, l0_h,lvl_resize_tissue)
    tissue, he_mask = do_mask_original(lvl_img, lvl_resize_tissue, rad = rad_tissue)

    #2.5x for cluster maps
    lvl_resize = get_downsample_factor(base_mag,target_magnification = mag_target_prob) #downsample factor
    map_shape = (int(np.ceil(l0_h/lvl_resize)), int(np.ceil(l0_w/lvl_resize)))
    x_map = np.zeros(map_shape, float)    # running sum of cluster ids per map pixel
    x_count = np.zeros(map_shape, float)  # number of tiles covering each map pixel

    # Deep-zoom level holding the mag_extract tiles; loop-invariant, so looked up once.
    lvl_in_deepzoom = tile_lvls.index(mag_extract)
    n_skipped = 0
    for index, row in cur_comb_df.iterrows():
        cur_xy = row['TILE_XY_INDEXES'].strip("()").split(", ")
        x ,y = int(cur_xy[0]) , int(cur_xy[1])

        tile_starts, tile_ends, save_coords, tile_coords = extract_tile_start_end_coords(tiles, lvl_in_deepzoom, x, y) #get tile coords
        map_xstart, map_xend, map_ystart, map_yend = get_map_startend(tile_starts,tile_ends,lvl_resize) #Get current tile position in map

        # Store the cluster id in the map and bump the count.
        # Best-effort: a tile that cannot be placed is skipped, but counted and reported below
        # (the previous bare except silently swallowed everything, including KeyboardInterrupt).
        try:
            x_count[map_xstart:map_xend,map_ystart:map_yend] += 1
            x_map[map_xstart:map_xend,map_ystart:map_yend] += row['Cluster']
        except Exception:
            n_skipped += 1

if n_skipped:
    print(f'skipped {n_skipped} of {len(cur_comb_df)} tiles that could not be placed on the map')
print('post-processing')
# Pixels no tile landed on keep count 1 so the division below leaves them at 0.
x_count = np.where(x_count < 1, 1, x_count)
# NOTE(review): overlapping tiles are averaged, so a pixel covered by tiles from different clusters
# gets a fractional value (e.g. 1.5) that is not a real cluster id — consider a majority vote instead.
x_map = x_map / x_count

x_sm = x_map

he_mask = cv2.resize(np.uint8(he_mask),(x_sm.shape[1],x_sm.shape[0])) #resize to output image size
# TODO: additionally restrict the map to a cancer mask (not implemented yet)
x_sm[he_mask < 1] = -1  # non-tissue pixels -> -1, below every cluster id

plt.imshow(x_sm, cmap='Spectral_r')
plt.colorbar()
plt.savefig(os.path.join(outdir1, save_name + '_cluster.png'), dpi=500,bbox_inches='tight')
plt.show()
plt.close()