In [1]:
#NOTE: use paimg1 env, the retccl one has package issue with torchvision
import sys
import os
import numpy as np
import openslide
import matplotlib.pyplot as plt

import matplotlib
# Select the non-interactive Agg backend (renders to files only).
# NOTE(review): with Agg, plt.show() is not expected to display figures in the notebook - confirm this is intended.
matplotlib.use('Agg')
import pandas as pd
import warnings
import torch
import torch.nn as nn

from sklearn.model_selection import KFold, train_test_split
from torch.utils.data import DataLoader
import torch.optim as optim
from pathlib import Path

# Make the project-local helper modules (Utils, Eval, Model, train_utils) importable
sys.path.insert(0, '../Utils/')
from Utils import generate_deepzoom_tiles
from Utils import get_downsample_factor
from Utils import minmax_normalize
from Utils import log_message
from Eval import compute_performance, plot_LOSS
from Model import Mutation_MIL_MT
from train_utils import pull_tiles, get_feature_label_array, ModelReadyData_MT_V2, convert_to_dict
# Silence ALL warnings for the rest of the session (this also hides deprecation and backend warnings)
warnings.filterwarnings("ignore")


In [2]:
##################
###### DIR  ######
##################
# Project root. It ends with '/', so every sub-path below is appended WITHOUT a leading slash.
proj_dir = '/fh/scratch/delete90/etzioni_r/lucas_l/michael_project/mutation_pred/'
wsi_path = proj_dir + 'data/OPX/'  #whole-slide images (.tif); was '/data/OPX/', which produced a double slash
label_path = proj_dir + 'data/MutationCalls/'
model_path = proj_dir + 'models/feature_extraction_models/'
tile_path = proj_dir + 'intermediate_data/cancer_prediction_results110224/IMSIZE250_OL0/'
ft_ids_path =  proj_dir + 'intermediate_data/cd_finetune/cancer_detection_training/' #the ID used for fine-tuning cancer detection model, needs to be excluded from mutation study
pretrain_model_name = 'retccl'  #used to build the per-slide feature file name

##################
#Select GPU
##################
torch.manual_seed(0)  #fix the torch RNG so shuffling / weight init are repeatable
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  #fall back to CPU when no GPU is visible
print(device)

cuda:0


In [3]:
################################################
#Create output dir
################################################
SELECTED_MUTATION = "MT"
model_name = "MIL" #Model tag used in the output path, e.g. MIL, Linear, LinearMT
outdir = proj_dir + "intermediate_data/pred_out/"  + SELECTED_MUTATION + "/saved_model/" + model_name + "/"  #per-epoch model checkpoints
outdir2 = proj_dir + "intermediate_data/pred_out/" + SELECTED_MUTATION + "/model_para/"  #hyper-parameter table
outdir3 = proj_dir + "intermediate_data/pred_out/" + SELECTED_MUTATION + "/logs/"  #training log

# exist_ok=True makes the cell safe to re-run and avoids the check-then-create race
for cur_dir in (outdir, outdir2, outdir3):
    os.makedirs(cur_dir, exist_ok=True)

In [4]:
############################################################################################################
#Select IDS
############################################################################################################

#All available IDs: one per .tif slide found in wsi_path.
#Only entries ending in '.tif' are kept, so stray files (e.g. hidden files, '.tiff') cannot become bogus IDs.
opx_ids = sorted(os.path.splitext(x)[0] for x in os.listdir(wsi_path) if x.endswith('.tif')) #207

#Get IDs that are in FT train or already processed to exclude
ft_ids_df = pd.read_csv(ft_ids_path + 'all_tumor_fraction_info.csv')
ft_train_ids = ft_ids_df.loc[ft_ids_df['Train_OR_Test'] == 'Train','sample_id'].tolist()

#OPX_182 - Exclude: Possible Colon AdenoCa
toexclude_ids = ft_train_ids + ['OPX_182']  #25


#Exclude ids in ft_train or processed (set gives constant-time membership tests)
toexclude_lookup = set(toexclude_ids)
selected_ids = [x for x in opx_ids if x not in toexclude_lookup] #199

In [5]:
############################################################################################################
#Build the K-fold splits (each fold: 80% train - 20% test), then hold out part of train for validation
############################################################################################################
n_splits = 5  # number of folds

# Shuffled K-fold splitter with a fixed seed
kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)

# Materialise the positional splits once, then map the positions back to sample IDs
fold_indices = list(kf.split(selected_ids))
train_ids_folds = [[selected_ids[pos] for pos in train_pos] for train_pos, _ in fold_indices]
test_ids_folds = [[selected_ids[pos] for pos in test_pos] for _, test_pos in fold_indices]

# Work with a single fold
selected_fold = 0
full_train_ids = train_ids_folds[selected_fold]
test_ids = test_ids_folds[selected_fold]

# Hold out a random 5% of the fold's training IDs as the validation set
train_ids, val_ids = train_test_split(full_train_ids, test_size=0.05, random_state=42)

# Report split sizes in the order train / validation / test
for split_ids in (train_ids, val_ids, test_ids):
    print(len(split_ids))

151
8
40


In [None]:
############################################################################################################
#Load the feature and label arrays for the train / test / validation IDs
############################################################################################################
SELECTED_ID = train_ids
# Outcome columns, in the order used for every per-outcome loop downstream
SELECTED_LABEL = [
    "AR",
    "MMR (MSH2, MSH6, PMS2, MLH1, MSH3, MLH3, EPCAM)2",
    "PTEN",
    "RB1",
    "TP53",
    "TMB_HIGHorINTERMEDITATE",
    "MSI_POS",
]
# Feature column names are the strings "0" .. "2047"
SELECTED_FEATURE = list(map(str, range(2048)))


def _load_split(sample_ids):
    """Return the (feature, label) arrays for the given sample IDs."""
    return get_feature_label_array(tile_path, pretrain_model_name, sample_ids, SELECTED_LABEL, SELECTED_FEATURE)


train_feature_np, train_label_np = _load_split(train_ids)
test_feature_np, test_label_np = _load_split(test_ids)
val_feature_np, val_label_np = _load_split(val_ids)

In [None]:
# Report how many labels equal 1 per outcome column, for the train and test label arrays
def _positive_counts(label_np):
    """Return (count, percentage rounded to 2 decimals) of entries equal to 1 in each column."""
    n_pos = np.sum(label_np == 1, axis=0)
    pct_pos = np.round((n_pos / label_np.shape[0]) * 100, 2)
    return n_pos, pct_pos


# Short display names printed under each report
short_outcome_names = ["AR","HR","PTEN","RB1","TP53","TMB","MSI_POS"]

# Train
count_ones, percentage_ones = _positive_counts(train_label_np)
print("Number of 1s in each column:", count_ones)
print("% of 1s in each column:", percentage_ones)
print(short_outcome_names)

# Test
count_ones, percentage_ones = _positive_counts(test_label_np)
print("--------TEST------")
print("Number of 1s in each column:", count_ones)
print("% of 1s in each column:", percentage_ones)
print(short_outcome_names)

In [None]:
################################################
#     Wrap the arrays as model-ready datasets
################################################
train_data, test_data, val_data = [
    ModelReadyData_MT_V2(features, labels)
    for features, labels in (
        (train_feature_np, train_label_np),
        (test_feature_np, test_label_np),
        (val_feature_np, val_label_np),
    )
]

####################################################
#            Training hyper-parameters
####################################################
LEARNING_RATE = 0.1
BATCH_SIZE  = 1
EPOCHS = 200

# One shuffled DataLoader per split, all sharing the same batch size
loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=True)
train_loader = DataLoader(dataset=train_data, **loader_kwargs)
test_loader = DataLoader(dataset=test_data, **loader_kwargs)
val_loader = DataLoader(dataset=val_data, **loader_kwargs)

In [None]:
# Build the multi-task MIL model and move it to the selected device
model = Mutation_MIL_MT()
model.to(device)

# Plain SGD over all model parameters
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)

# Binary cross-entropy, applied per outcome in the training loop
loss_func = torch.nn.BCELoss()

# Total number of parameter elements in the model
total_params = 0
for param in model.parameters():
    total_params += param.numel()
print(f"Number of parameters: {total_params}")

# Persist the run's hyper-parameters as a one-row table
hyper_para = {
    "Target_Mutation": SELECTED_MUTATION,
    "BATCH_SIZE": BATCH_SIZE,
    "N_EPOCH": EPOCHS,
    "Learning_Rate": LEARNING_RATE,
    "NUM_MODEL_PARA": total_params,
}
hyper_df = pd.DataFrame(hyper_para, index = [0])
hyper_df.to_csv(outdir2 + "hyperpara_df.csv")

# Mark the start of training in the log file
log_message("Start Training", outdir3 + "training_log.txt")

In [None]:
####################################################################################
#Training
# For each epoch: one pass over train_loader with an optimizer step per batch, then a
# gradient-free pass over val_loader. The per-batch loss is the SUM of the BCE losses
# of all outcomes. Model weights are saved after every epoch.
####################################################################################
train_loss = []
valid_loss = []
n_outcomes = len(SELECTED_LABEL)  #one model output per selected label

for epoch in range(EPOCHS):
    #Put the model in training mode (matters if it contains dropout / batch-norm layers)
    model.train()
    running_loss = 0
    for x,y in train_loader:
        optimizer.zero_grad() #zero the grad
        yhat_list, _ = model(x.to(device)) #Forward

        #Total loss = sum of the per-outcome BCE losses
        loss = sum(loss_func(yhat_list[i].squeeze(), y[:,i].to(device)) for i in range(n_outcomes))

        #One backward pass on the summed loss yields the same accumulated gradients as
        #back-propagating every outcome separately with retain_graph=True, at lower cost
        loss.backward()
        optimizer.step() #Optimize
        running_loss += loss.item() #accumulated batch loss

    #Training loss
    epoch_loss = running_loss/len(train_loader) #accumulated loss/total # batches (averaged loss over batches)
    train_loss.append(epoch_loss)

    #Validation: evaluation mode, no gradient tracking
    model.eval()
    with torch.no_grad():
        val_running_loss = 0
        for x_val,y_val in val_loader:
            val_yhat_list, _ = model(x_val.to(device))
            val_loss = sum(loss_func(val_yhat_list[i].squeeze(), y_val[:,i].to(device)) for i in range(n_outcomes))
            val_running_loss += val_loss.item()
        val_epoch_loss = val_running_loss/len(val_loader)
        valid_loss.append(val_epoch_loss)

    print("Epoch"+ str(epoch) + ":",
          "Train-LOSS:" + "{:.5f}".format(train_loss[epoch]) + ", " +
          "Valid-LOSS:" +  "{:.5f}".format(valid_loss[epoch]))

    #Save model parameters (one checkpoint file per epoch)
    torch.save(model.state_dict(), outdir + "model" + str(epoch))


#Plot LOSS
plot_LOSS(train_loss,valid_loss, outdir)
log_message("End Training", outdir3 + "training_log.txt")

In [None]:
# Evaluate the trained model on the test set: per-outcome loss, predictions table, performance table
#Loss
loss_func = torch.nn.BCELoss()
THRES = 0.1  #probability cut-off: Pred_Class is 1 when Pred_Prob > THRES
n_outcomes = len(SELECTED_LABEL)
indir = proj_dir + "intermediate_data/pred_out/"  + SELECTED_MUTATION + "/"

#Prediction (evaluation mode, no gradient tracking)
model.eval()
with torch.no_grad():
    pred_prob_list, test_attention = model(test_data.x.to(device))

    #Get y_True list
    y_true_list = test_data.y

    #order matches with SELECTED_LABEL
    test_loss_list = []
    for i in range(n_outcomes):
        test_loss_list.append(loss_func(pred_prob_list[i].squeeze(),y_true_list[:,i].to(device)).item())
        print("Test-Loss " + SELECTED_LABEL[i] + ":" + "{:.5f}".format(test_loss_list[i]))
    #Total loss
    test_loss = sum(test_loss_list)
    print("Test-Loss TOTAL: " + "{:.5f}".format(test_loss))

    #prediction list per outcome: probabilities as 1-D CPU tensors, classes as lists of 0.0 / 1.0
    test_pred_list = []
    test_pred_class_list = []
    for i in range(n_outcomes):
        cur_prob = torch.flatten(pred_prob_list[i].squeeze().detach().cpu())
        test_pred_list.append(cur_prob)
        test_pred_class_list.append((cur_prob > THRES).float().tolist())

#Prediction df: one row per (test sample, outcome)
tile_pred_df_list = []
for i in range(n_outcomes):
    tile_pred_df_list.append(pd.DataFrame({"SAMPLE_IDs":  test_ids,
                                           "Y_True": torch.flatten(y_true_list[:,i]).tolist(),
                                           "Pred_Prob" :  test_pred_list[i].tolist(),
                                           "Pred_Class": test_pred_class_list[i],
                                           "OUTCOME": SELECTED_LABEL[i]}))
tile_pred_df = pd.concat(tile_pred_df_list)
tile_pred_df.to_csv(proj_dir + "intermediate_data/pred_out/" + SELECTED_MUTATION + "/tile_pred_df.csv",index = False)

#Performance per outcome
comb_perf_list = []
for mut in SELECTED_LABEL:
    cur_tile_pred_df = tile_pred_df.loc[tile_pred_df['OUTCOME'] == mut]
    cur_tile_level_perf = compute_performance(cur_tile_pred_df['Y_True'],cur_tile_pred_df['Pred_Prob'],cur_tile_pred_df['Pred_Class'],'TILE_LEVEL')
    cur_tile_level_perf['OUTCOME'] = mut
    comb_perf_list.append(cur_tile_level_perf)
comb_perf = pd.concat(comb_perf_list)

################################################
#Create output dir
# A dedicated name is used so the checkpoint directory held in `outdir` is not overwritten
# (re-running the training cell would otherwise save models into perf_table/).
################################################
perf_outdir =  proj_dir + "intermediate_data/pred_out/" + SELECTED_MUTATION + "/perf_table/"
os.makedirs(perf_outdir, exist_ok=True)

comb_perf.to_csv(perf_outdir + "perf.csv",index = True)
print(comb_perf)
print(comb_perf['AUC'].mean())

In [None]:
# Smoke test: forward the first 100 tiles of the first test slide as a batch of one.
# Results go into dedicated names so the full-test-set `pred_prob_list` / `test_attention`
# (indexed per test slide by the attention cells) are not overwritten.
sample_tiles = test_data.x[0][0:100,]
with torch.no_grad():  #inference only, no gradient tracking needed
    sample_pred_prob_list, sample_attention = model(sample_tiles.unsqueeze(0).to(device))
print(sample_tiles.shape)

In [None]:
# Shape check: apply a 2048 -> 128 linear layer to two slices of the first test slide's tile features
check_model = nn.Linear(2048, 128)
first_slide_feats = test_data.x[0]
sp1 = first_slide_feats[:100]  #first 100 rows
sp2 = first_slide_feats[:8]    #first 8 rows
print(sp1.shape)
print(check_model(sp1).shape)
check_model(sp2).shape

In [None]:
#TODO: Get attention score
# For every test slide, min-max normalise its attention and attach it as column 'ATT' to the
# slide's tile_info table. Iterating over test_ids (instead of a fixed 40) keeps this correct
# for folds of a different size.
# NOTE(review): the column-wise concat aligns on the index - presumably tile_info has a default
# 0..n-1 index matching the attention rows; confirm.
all_att = []
for slide_idx, cur_id in enumerate(test_ids):
    cur_attent = pd.DataFrame(minmax_normalize(test_attention[slide_idx]).cpu().numpy()).rename(columns = {0: 'ATT'})
    input_dir = tile_path + cur_id + '/' + 'features/' + 'train_features_' + pretrain_model_name + '.h5'
    cur_label_df = pd.read_hdf(input_dir, key='tile_info')
    cur_att_df = pd.concat([cur_label_df,cur_attent], axis = 1)
    all_att.append(cur_att_df)
all_att_df = pd.concat(all_att)

In [None]:
#Load slide
# Pick one test slide, rebuild its tile grid, attach the attention scores to its tile table and
# save the highest- and lowest-attention tiles as PNG files.
slide_idx = 6  #position of the slide in test_ids (dedicated name: loop counters below must not clobber it)
pt = test_ids[slide_idx]
print(pt)

save_image_size = 250
# NOTE(review): the features are read from a directory named 'IMSIZE250_OL0'; if 'OL0' means the
# tiles were extracted with overlap 0, an overlap of 100 here gives a different tile grid - confirm.
pixel_overlap = 100
limit_bounds = True

_file = wsi_path + pt + ".tif"
oslide = openslide.OpenSlide(_file)
save_name = str(Path(os.path.basename(_file)).with_suffix(''))

#Generate tiles
tiles, tile_lvls, physSize, base_mag = generate_deepzoom_tiles(oslide,save_image_size, pixel_overlap, limit_bounds)

#get level 0 size in px
l0_w = oslide.level_dimensions[0][0]
l0_h = oslide.level_dimensions[0][1]

#Empty heatmap at the target magnification (filled in by the heatmap cell)
mag_target_prob = 2.5
lvl_resize = get_downsample_factor(base_mag,target_magnification = mag_target_prob) #downsample factor
heatmap = np.zeros((int(np.ceil(l0_h/lvl_resize)),int(np.ceil(l0_w/lvl_resize))), float)

#Attention: min-max normalised scores joined column-wise to the slide's tile_info table
cur_attent = pd.DataFrame(minmax_normalize(test_attention[slide_idx]).cpu().numpy()).rename(columns = {0: 'ATT'})
input_dir = tile_path + pt + '/' + 'features/' + 'train_features_' + pretrain_model_name + '.h5'
cur_label_df = pd.read_hdf(input_dir, key='tile_info')
cur_att_df = pd.concat([cur_label_df,cur_attent], axis = 1)


def _save_leading_tiles(att_df, tile_dir, n_tiles=5):
    """Save the tiles of the first n_tiles rows of att_df (already sorted by 'ATT') into tile_dir.

    File names are "TF" + attention value + ".png". Fewer files are written when att_df has
    fewer than n_tiles rows.
    """
    os.makedirs(tile_dir, exist_ok=True)
    att_img = pull_tiles(att_df, tiles, tile_lvls)  #images are indexed in the row order of att_df
    for rank in range(min(n_tiles, len(att_df))):
        cur_att = att_df.iloc[rank]['ATT']
        att_img[rank].save(os.path.join(tile_dir, "TF" + str(cur_att) + ".png"))


#Top 5: highest attention first
cur_att_df = cur_att_df.sort_values(by = ['ATT'], ascending = False)
_save_leading_tiles(cur_att_df, proj_dir + "intermediate_data/pred_out/MT/top_tiles/" + pt + "/")

#Bottom 5: lowest attention first (cur_att_df stays in this order for the following cell)
cur_att_df = cur_att_df.sort_values(by = ['ATT'], ascending = True)
_save_leading_tiles(cur_att_df, proj_dir + "intermediate_data/pred_out/MT/bot_tiles/" + pt + "/")

In [None]:
# Convert every tile row into a dict holding its 'coords' and 'att' value
att_dict = cur_att_df.apply(convert_to_dict, axis=1)

# Paint each tile's attention score onto its rectangle of the heatmap
for att_coor in att_dict:
    startx, endx, starty, endy  = att_coor['coords']
    heatmap[starty:endy+1, startx:endx+1] = att_coor['att']

# Plot the heatmap
fig, ax = plt.subplots()
im = ax.imshow(heatmap, cmap='coolwarm', interpolation='nearest')
fig.colorbar(im, ax=ax, label='Attention Score')  #the map shows attention, not prediction probability
ax.set_title('Attention Scores')
ax.set_xlabel('X Coordinate')
ax.set_ylabel('Y Coordinate')

# The notebook selects the non-interactive Agg backend, under which plt.show() displays nothing,
# so the figure is also written to disk.
heatmap_dir = proj_dir + "intermediate_data/pred_out/MT/att_heatmap/" + pt + "/"
os.makedirs(heatmap_dir, exist_ok=True)
fig.savefig(os.path.join(heatmap_dir, "attention_heatmap.png"), bbox_inches='tight')
plt.show()