In [1]:
#NOTE: use paimg9 env
import sys
import os
import numpy as np
import openslide
import matplotlib.pyplot as plt

import matplotlib
matplotlib.use('Agg')
import pandas as pd
import warnings
import torch
import torch.nn as nn

from sklearn.model_selection import KFold, train_test_split
from torch.utils.data import DataLoader
import torch.optim as optim
from pathlib import Path
import PIL
from skimage import filters
import random

    
sys.path.insert(0, '../Utils/')
from Utils import create_dir_if_not_exists
from Utils import generate_deepzoom_tiles, extract_tile_start_end_coords, get_map_startend
from Utils import get_downsample_factor
from Utils import minmax_normalize, set_seed
from Utils import log_message
from Eval import compute_performance, plot_LOSS, compute_performance_each_label, get_attention_and_tileinfo
from train_utils import pull_tiles
from train_utils import ModelReadyData_diffdim, convert_to_dict, prediction, BCE_Weighted_Reg, compute_loss_for_all_labels
from Model import Mutation_Multihead
warnings.filterwarnings("ignore")
%matplotlib inline

In [2]:
####################################
######      USERINPUT       ########
####################################
model_name = "MIL"  # run label only; not referenced by the other cells of this notebook
SELECTED_LABEL = ["AR","MMR (MSH2, MSH6, PMS2, MLH1, MSH3, MLH3, EPCAM)2","PTEN","RB1","TP53","TMB_HIGHorINTERMEDITATE","MSI_POS"]
TRAIN_SAMPLE_SIZE = "ALLTUMORTILES"
TRAIN_OVERLAP = 100
TEST_OVERLAP = 0
SELECTED_FOLD = 0
TUMOR_FRAC_THRES = 0
feature_extraction_method = 'retccl'
INCLUDE_TF = False        # adds one extra input feature column when True
INCLUDE_CLUSTER = False   # adds one cluster feature column when True
N_CLUSTERS = 4


####
#model Para
LEARNING_RATE = 0.001  #0.00001
BATCH_SIZE  = 1
ACCUM_SIZE = 16  # Number of steps to accumulate gradients
EPOCHS = 100
DROPOUT = 0.2
DIM_OUT = 128
NUM_HEADS = 8

# Model input width: 2048 embedding features plus one column per optional
# feature enabled above (2048 / 2049 / 2049 / 2050).
N_FEATURE = 2048 + int(INCLUDE_TF) + int(INCLUDE_CLUSTER)

LOSS_FUNC_NAME = "BCELoss" #"BCE_Weighted_Reg"
REG_COEEF = 0.00001 #0.0000001
REG_TYPE = 'L1'
OPTMIZER = "ADAM"
ATT_REG_FLAG = False
SELECTED_MUTATION = "MSI_POS"  # one entry of SELECTED_LABEL, or "MT" to model all labels jointly

# One [NEG, POS] loss-weight pair per entry (presumably aligned with SELECTED_LABEL by position).
if SELECTED_MUTATION == "MT":
    N_LABELS = len(SELECTED_LABEL)
    LOSS_WEIGHTS_LIST = [[1, 10], [1, 100], [1, 50], [1, 100], [1, 100], [1, 100], [1, 100]]  #NEG, POS
else:
    N_LABELS = 1
    LOSS_WEIGHTS_LIST = [[1, 10], [1, 10], [1, 50], [1, 100], [1, 100], [1, 100], [1, 100]]  #NEG, POS

##################
###### DIR  ######
##################
# Override with the MUTATION_PRED_PROJ_DIR environment variable to run elsewhere (must end in "/").
proj_dir = os.environ.get("MUTATION_PRED_PROJ_DIR", '/fh/fast/etzioni_r/Lucas/mh_proj/mutation_pred/')
folder_name = feature_extraction_method + '/MAXSS'+ str(TRAIN_SAMPLE_SIZE)  + '_TrainOL' + str(TRAIN_OVERLAP) +  '_TestOL' + str(TEST_OVERLAP) + '_TFT' + str(TUMOR_FRAC_THRES) + "/split_fold" + str(SELECTED_FOLD) + "/"
wsi_path = proj_dir + 'data/OPX/'
in_data_path = proj_dir + 'intermediate_data/model_ready_data/feature_' + folder_name + "model_input/"

# Sub-folder name of the model-ready data for the chosen feature combination.
feature_type = {
    (False, False): "emb_only",
    (True, False): "emb_and_tf",
    (False, True): "emb_and_cluster" + str(N_CLUSTERS),
    (True, True): "emb_and_tf_and_cluster" + str(N_CLUSTERS),
}[(bool(INCLUDE_TF), bool(INCLUDE_CLUSTER))]

model_data_path = in_data_path + feature_type + "/"

################################################
#Create output-dir
################################################
# Every component already ends in "/", so no separators are added between them.
outdir0 = proj_dir + "intermediate_data/pred_out1212_mutation/" + folder_name + "DL_" + feature_type + "/" + SELECTED_MUTATION + "/"
outdir1 = outdir0 + "saved_model/"   # per-epoch checkpoints, loss plot, Valid_LOSS.csv
outdir2 = outdir0 + "model_para/"    # hyperpara_df.csv
outdir3 = outdir0 + "logs/"          # training_log.txt
outdir4 = outdir0 + "predictions/"   # pred_df.csv and per-patient attention output
outdir5 = outdir0 + "perf/"          # perf.csv

for _outdir in (outdir0, outdir1, outdir2, outdir3, outdir4, outdir5):
    create_dir_if_not_exists(_outdir)

##################
#Select GPU
##################
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

Directory '/fh/fast/etzioni_r/Lucas/mh_proj/mutation_pred/intermediate_data/pred_out1212_mutation/retccl/MAXSSALLTUMORTILES_TrainOL100_TestOL0_TFT0/split_fold0//DL_emb_only/MSI_POS/' already exists.
Directory '/fh/fast/etzioni_r/Lucas/mh_proj/mutation_pred/intermediate_data/pred_out1212_mutation/retccl/MAXSSALLTUMORTILES_TrainOL100_TestOL0_TFT0/split_fold0//DL_emb_only/MSI_POS//saved_model/' already exists.
Directory '/fh/fast/etzioni_r/Lucas/mh_proj/mutation_pred/intermediate_data/pred_out1212_mutation/retccl/MAXSSALLTUMORTILES_TrainOL100_TestOL0_TFT0/split_fold0//DL_emb_only/MSI_POS//model_para/' already exists.
Directory '/fh/fast/etzioni_r/Lucas/mh_proj/mutation_pred/intermediate_data/pred_out1212_mutation/retccl/MAXSSALLTUMORTILES_TrainOL100_TestOL0_TFT0/split_fold0//DL_emb_only/MSI_POS//logs/' already exists.
Directory '/fh/fast/etzioni_r/Lucas/mh_proj/mutation_pred/intermediate_data/pred_out1212_mutation/retccl/MAXSSALLTUMORTILES_TrainOL100_TestOL0_TFT0/split_fold0//DL_emb_only/

In [3]:
################################################
#     Model ready data 
################################################
def _load_model_ready(stem):
    """Return torch.load(model_data_path + stem + '.pth').

    torch.load unpickles the file, so only load artifacts produced by this project.
    """
    return torch.load(model_data_path + stem + '.pth')


# Datasets for the three splits, then the sample IDs / tile info used at test time.
train_data, test_data, val_data = (_load_model_ready(stem) for stem in ('train_data', 'test_data', 'val_data'))
train_ids, test_ids, test_info = (_load_model_ready(stem) for stem in ('train_ids', 'test_ids', 'test_info'))

In [4]:
####################################################
#            Train 
####################################################
set_seed(0)

# Dataloaders. With BATCH_SIZE = 1 each step sees a single sample; the training
# loop accumulates gradients over ACCUM_SIZE samples.
# NOTE(review): shuffle=False visits the training samples in the same order every epoch.
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False)
val_loader = DataLoader(dataset=val_data, batch_size=BATCH_SIZE, shuffle=False)


# Construct model. The testing cell must rebuild it with the same hyper-parameters.
# NOTE(review): the last run of the training cell raised
#   RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x2048 and 128x1)
# i.e. a 2048-wide tensor reached a layer built for 128 inputs. This suggests
# embed_dim (2048) and dim_out (DIM_OUT = 128) disagree inside Mutation_Multihead;
# confirm against Model.py before changing either value.
model = Mutation_Multihead(in_features = N_FEATURE, num_heads = NUM_HEADS,
                           embed_dim = 2048, dim_feedforward = 2048,
                           act_func = 'tanh', drop_out = DROPOUT,
                           n_outcomes = N_LABELS, dim_out = DIM_OUT)
model.to(device)

# Optimizer. An unknown name fails here instead of as a NameError in the training cell.
if OPTMIZER == "ADAM":
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
elif OPTMIZER == "SGD":
    optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)
else:
    raise ValueError(f"Unsupported OPTMIZER {OPTMIZER!r}; expected 'ADAM' or 'SGD'")

# Loss
if LOSS_FUNC_NAME == "BCE_Weighted_Reg":
    loss_func = BCE_Weighted_Reg(REG_COEEF, REG_TYPE, model, reduction = 'mean', att_reg_flag = ATT_REG_FLAG)
elif LOSS_FUNC_NAME == "BCELoss":
    loss_func = torch.nn.BCELoss()
else:
    raise ValueError(f"Unsupported LOSS_FUNC_NAME {LOSS_FUNC_NAME!r}; expected 'BCE_Weighted_Reg' or 'BCELoss'")


# Model size
total_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {total_params}")


# Record the run's hyper-parameters next to its outputs.
hyper_df = pd.DataFrame({"Target_Mutation": SELECTED_MUTATION,
                         "TRAIN_OVERLAP": TRAIN_OVERLAP,
                         "TEST_OVERLAP": TEST_OVERLAP,
                         "TRAIN_SAMPLE_SIZE": TRAIN_SAMPLE_SIZE,
                         "TUMOR_FRAC_THRES": TUMOR_FRAC_THRES,
                         "N_FEATURE": N_FEATURE,
                         "N_LABELS": N_LABELS,
                         "BATCH_SIZE": BATCH_SIZE,
                         "ACCUM_SIZE": ACCUM_SIZE,
                         "N_EPOCH": EPOCHS,
                         "OPTMIZER": OPTMIZER,
                         "LEARNING_RATE": LEARNING_RATE,
                         "DROPOUT": DROPOUT,
                         "DIM_OUT": DIM_OUT,
                         "NUM_HEADS": NUM_HEADS,
                         "REG_TYPE": REG_TYPE,
                         "REG_COEEF": REG_COEEF,
                         "LOSS_FUNC_NAME": LOSS_FUNC_NAME,
                         "LOSS_WEIGHTS_LIST": str(LOSS_WEIGHTS_LIST),
                         "ATT_REG_FLAG": ATT_REG_FLAG,
                         "NUM_MODEL_PARA": total_params}, index = [0])
hyper_df.to_csv(outdir2 + "hyperpara_df.csv")

Number of parameters: 25186562


In [5]:
log_message("Start Training", outdir3 + "training_log.txt")
####################################################################################
#Training
####################################################################################
# During training, samples with more than this many entries along dim 1
# (presumably tiles) are randomly subsampled each time they are seen.
# Validation always uses the full sample.
MAX_TRAIN_TILES = 2000

train_loss = []   # mean training loss per epoch
valid_loss = []   # mean validation loss per epoch
n_train_batches = len(train_loader)

for epoch in range(EPOCHS):
    model.train()
    optimizer.zero_grad()  # never carry gradients over from an earlier epoch or an interrupted run
    running_loss = 0

    for step, (x, y, tf) in enumerate(train_loader, start=1):
        if x.size(1) > MAX_TRAIN_TILES:
            indices = torch.randperm(x.size(1))[:MAX_TRAIN_TILES]
            x = x[:, indices, :]

        yhat_list, train_att_list = model(x.to(device))  # forward
        loss = compute_loss_for_all_labels(yhat_list, y, LOSS_WEIGHTS_LIST, LOSS_FUNC_NAME, loss_func,
                                           device, tf, train_att_list, SELECTED_MUTATION, SELECTED_LABEL)
        running_loss += loss.detach().item()  # for reporting only

        # Gradient accumulation. Back-propagate every batch right away so its
        # graph is freed, scaled by the size of the accumulation window it
        # belongs to. Each optimizer step therefore applies the mean gradient of
        # its window. The last window of an epoch may be shorter than
        # ACCUM_SIZE; it is still applied instead of being silently dropped.
        window_start = ((step - 1) // ACCUM_SIZE) * ACCUM_SIZE
        window_size = min(ACCUM_SIZE, n_train_batches - window_start)
        (loss / window_size).backward()

        if step % ACCUM_SIZE == 0 or step == n_train_batches:
            optimizer.step()
            optimizer.zero_grad()

    # Training loss: accumulated loss / number of batches
    epoch_loss = running_loss / n_train_batches
    train_loss.append(epoch_loss)

    # Validation
    model.eval()
    with torch.no_grad():
        val_running_loss = 0
        for x_val, y_val, tf_val in val_loader:
            val_yhat_list, val_att_list = model(x_val.to(device))
            val_loss = compute_loss_for_all_labels(val_yhat_list, y_val, LOSS_WEIGHTS_LIST, LOSS_FUNC_NAME, loss_func,
                                                   device, tf_val, val_att_list, SELECTED_MUTATION, SELECTED_LABEL)
            val_running_loss += val_loss.detach().item()
        val_epoch_loss = val_running_loss / len(val_loader)
        valid_loss.append(val_epoch_loss)

    if epoch % 10 == 0:
        print("Epoch"+ str(epoch) + ":",
              "Train-LOSS:" + "{:.5f}".format(train_loss[epoch]) + ", " +
              "Valid-LOSS:" +  "{:.5f}".format(valid_loss[epoch]))

    # One checkpoint per epoch; the testing cell reloads the one with the lowest validation loss.
    torch.save(model.state_dict(), outdir1 + "model" + str(epoch))


#Plot LOSS
plot_LOSS(train_loss, valid_loss, outdir1)
log_message("End Training", outdir3 + "training_log.txt")

#SAVE VALIDATION LOSS (row index == epoch)
valid_loss_df = pd.DataFrame({"VALID_LOSS": valid_loss})
valid_loss_df.to_csv(outdir1 + "Valid_LOSS.csv")

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x2048 and 128x1)

In [None]:
####################################################################################
# Testing
####################################################################################
# Select the epoch with the lowest validation loss. Valid_LOSS.csv was written
# with its default RangeIndex, so the row label is the epoch number used in the
# checkpoint name "model<epoch>".
valid_loss_df = pd.read_csv(outdir1 + "Valid_LOSS.csv")
min_index = valid_loss_df['VALID_LOSS'].idxmin()
print(min_index)

# Rebuild the network with the SAME hyper-parameters as the training cell.
# load_state_dict (strict by default) rejects checkpoints whose parameter shapes
# differ, so num_heads / embed_dim / dim_out must match what was trained.
model = Mutation_Multihead(in_features = N_FEATURE, num_heads = NUM_HEADS,
                           embed_dim = 2048, dim_feedforward = 2048,
                           act_func = 'tanh', drop_out = DROPOUT,
                           n_outcomes = N_LABELS, dim_out = DIM_OUT)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
state_dict = torch.load(outdir1 + "model" + str(min_index), map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()


#Loss function used to report the test loss
loss_func = torch.nn.BCELoss()
THRES = 0.5  # probability threshold for Pred_Class

#predicts
test_pred_prob, test_true_label, test_att, test_loss = prediction(test_loader, model, N_LABELS, loss_func, device, SELECTED_MUTATION, SELECTED_LABEL, attention = True)
print("Test-Loss TOTAL: " + "{:.5f}".format(test_loss))


# Prediction df: one block of rows per outcome, stacked.
pred_df_list = []
for i in range(0, N_LABELS):
    if N_LABELS > 1:
        cur_outcome = SELECTED_LABEL[i]
        cur_prob = [l[i] for l in test_pred_prob]
    else:
        cur_outcome = SELECTED_MUTATION
        cur_prob = test_pred_prob
    pred_df_list.append(pd.DataFrame({"SAMPLE_IDs": test_ids,
                                      "Y_True": [l[i] for l in test_true_label],
                                      "Pred_Prob": cur_prob,
                                      "OUTCOME": cur_outcome}))
pred_df = pd.concat(pred_df_list)

#Add Predict class
pred_df['Pred_Class'] = 0
pred_df.loc[pred_df['Pred_Prob'] > THRES,'Pred_Class'] = 1
pred_df.to_csv(outdir4 + "/pred_df.csv",index = False)


#Compute performance
if SELECTED_MUTATION == "MT":
    perf_df = compute_performance_each_label(SELECTED_LABEL, pred_df, "SAMPLE_LEVEL")
else:
    perf_df = compute_performance_each_label([SELECTED_MUTATION], pred_df, "SAMPLE_LEVEL")
perf_df.to_csv(outdir5 + "/perf.csv",index = True)

print(perf_df.iloc[:,[0,5,6,7,8,9]])
print("AVG AUC:", round(perf_df['AUC'].mean(),2))
print("AVG PRAUC:", round(perf_df['PR_AUC'].mean(),2))

In [None]:
# Show only the test-set prediction rows whose OUTCOME is MSI_POS.
pred_df[pred_df['OUTCOME'].eq('MSI_POS')]

In [None]:
####################################################################################
# Attention-map settings
####################################################################################
# DeepZoom tiling (arguments to generate_deepzoom_tiles)
save_image_size = 250
pixel_overlap = 0
limit_bounds = True
mag_extract = 20          # magnification whose DeepZoom level the tiles are read from

# Map resolutions
mag_target_prob = 2.5     # magnification of the attention map
mag_target_tiss = 1.25    # magnification used for tissue detection
smooth = False            # Gaussian-smooth the attention map before plotting

# Reporting
TOP_K = 5                 # number of highest / lowest attention tiles to save
pretrain_model_name = "retccl"

In [None]:
# Test-set patient whose slide is visualised below (must be present in test_ids).
SELECTED_PT = 'OPX_011'
# A dedicated name, because the plain loop variable `i` is reused by later cells.
pt_idx = test_ids.index(SELECTED_PT)
pt = test_ids[pt_idx]
print(pt)

save_location = outdir4 + pt + "/"
create_dir_if_not_exists(save_location)

_file = wsi_path + pt + ".tif"
oslide = openslide.OpenSlide(_file)
save_name = Path(_file).stem  # slide file name without its extension


# Attention for this patient and the matching tile info
cur_pt_att = test_att[pt_idx]
cur_pt_info = test_info[pt_idx]
cur_att_df = get_attention_and_tileinfo(cur_pt_info, cur_pt_att)

In [None]:
#Generate tiles
tiles, tile_lvls, physSize, base_mag = generate_deepzoom_tiles(oslide, save_image_size, pixel_overlap, limit_bounds)

# Level-0 (full-resolution) slide size in pixels: (width, height)
l0_w, l0_h = oslide.level_dimensions[0]

#1.25x tissue detection for mask
from Utils import get_downsample_factor, get_image_at_target_mag
from Utils import do_mask_original,check_tissue,whitespace_check
import cv2

# do_mask_original radius depends on the slide cohort, inferred from the ID. An
# unknown cohort fails loudly here instead of raising a NameError or silently
# reusing the radius left over from a previously viewed patient.
if 'OPX' in pt:
    rad_tissue = 5
elif '(2017-0133)' in pt:
    rad_tissue = 2
else:
    raise ValueError(f"No tissue-mask radius configured for slide {pt!r}")

lvl_resize_tissue = get_downsample_factor(base_mag, target_magnification = mag_target_tiss)  # downsample factor
lvl_img = get_image_at_target_mag(oslide, l0_w, l0_h, lvl_resize_tissue)
tissue, he_mask = do_mask_original(lvl_img, lvl_resize_tissue, rad = rad_tissue)

In [None]:
def _save_extreme_tiles(att_df, tiles, tile_lvls, out_dir, k, ascending):
    """Save the k tiles with the highest (ascending=False) or lowest (ascending=True) attention.

    File names encode the rounded attention and the tile's map coordinates,
    which are taken from the 'pred_map_location' column. If fewer than k tiles
    exist, all available tiles are saved.
    """
    create_dir_if_not_exists(out_dir)
    ranked = att_df.sort_values(by = ['ATT'], ascending = ascending)
    pulled = pull_tiles(ranked.iloc[0:k], tiles, tile_lvls)
    for item in pulled[:k]:
        img, att = item[0], item[1]                 # image, attention
        coord = item[2].strip("()").split(", ")     # map coordinates (xs, xe, ys, ye)
        coord_save_name = '[xs' + coord[0] + '_xe' + coord[1] + '_ys' + coord[2] + '_ye' + coord[3] + "]"
        tile_save_name = "ATT" + str(round(att, 2)) + "_MAPCOORD" + coord_save_name + ".png"
        img.save(os.path.join(out_dir, tile_save_name))


#2.5x for probability maps
lvl_resize = get_downsample_factor(base_mag, target_magnification = mag_target_prob)  # downsample factor
map_shape = (int(np.ceil(l0_h / lvl_resize)), int(np.ceil(l0_w / lvl_resize)))
x_map = np.zeros(map_shape, float)    # summed attention per map pixel
x_count = np.zeros(map_shape, float)  # number of tiles covering each map pixel

# DeepZoom level of the extraction magnification (loop-invariant)
lvl_in_deepzoom = tile_lvls.index(mag_extract)

cur_att_df['pred_map_location'] = pd.NA
skipped_tiles = []
for index, row in cur_att_df.iterrows():
    cur_xy = row['TILE_XY_INDEXES'].strip("()").split(", ")
    x, y = int(cur_xy[0]), int(cur_xy[1])

    # Tile coordinates, then the tile's position in the map
    tile_starts, tile_ends, _, _ = extract_tile_start_end_coords(tiles, lvl_in_deepzoom, x, y)
    map_xstart, map_xend, map_ystart, map_yend = get_map_startend(tile_starts, tile_ends, lvl_resize)
    cur_att_df.loc[index, 'pred_map_location'] = str(tuple([map_xstart, map_xend, map_ystart, map_yend]))

    # NOTE(review): the "x" range indexes axis 0 (rows, sized from l0_h) and the
    # "y" range indexes axis 1; confirm get_map_startend returns them in that order.
    try:
        x_count[map_xstart:map_xend, map_ystart:map_yend] += 1
        x_map[map_xstart:map_xend, map_ystart:map_yend] += row['ATT']
    except (IndexError, TypeError, ValueError) as err:
        # Best effort: skip tiles that cannot be placed, but report them below.
        skipped_tiles.append((index, err))

if skipped_tiles:
    print(f"Skipped {len(skipped_tiles)} tile(s) that could not be placed on the map; "
          f"first error: {skipped_tiles[0][1]!r}")

print('post-processing')
x_count = np.where(x_count < 1, 1, x_count)  # uncovered pixels: avoid division by zero
x_map = np.minimum(x_map / x_count, 1)       # mean attention per pixel, capped at 1

x_sm = filters.gaussian(x_map, sigma=2) if smooth else x_map

# Tissue mask resized to the map grid (cv2.resize takes (width, height)).
# A new name keeps he_mask intact, so re-running this cell gives the same result.
he_mask_map = cv2.resize(np.uint8(he_mask), (x_sm.shape[1], x_sm.shape[0]))
# TODO: restrict the map to a cancer mask once one is available.
x_sm[he_mask_map < 1] = 0.001  # pixels outside the tissue mask get a constant 0.001

fig, ax = plt.subplots()
im = ax.imshow(x_sm, cmap='Spectral_r')
fig.colorbar(im, ax=ax)
ax.set_title(f"{pt} attention ({SELECTED_MUTATION})")
fig.savefig(os.path.join(save_location, save_name + '_attention.png'), dpi=500, bbox_inches='tight')
plt.show()
plt.close(fig)


# Highest- and lowest-attention tiles
_save_extreme_tiles(cur_att_df, tiles, tile_lvls, save_location + "top_tiles/", TOP_K, ascending = False)
_save_extreme_tiles(cur_att_df, tiles, tile_lvls, save_location + "bot_tiles/", TOP_K, ascending = True)

In [None]:
#TODO LIST
#1. Attention scores
#2. Placement of the gradient reset (optimizer.zero_grad())