In [1]:
#NOTE: use python env acmil in ACMIL folder
import sys
import os
import numpy as np
import openslide
import matplotlib.pyplot as plt

import matplotlib
matplotlib.use('Agg')
import pandas as pd
import warnings
import torch
import torch.nn as nn

from sklearn.model_selection import KFold, train_test_split
from torch.utils.data import DataLoader, Subset, ConcatDataset
import torch.optim as optim
from pathlib import Path
import PIL
from skimage import filters
import random

    
sys.path.insert(0, '../Utils/')
from Utils import create_dir_if_not_exists
from Utils import generate_deepzoom_tiles, extract_tile_start_end_coords, get_map_startend
from Utils import get_downsample_factor
from Utils import minmax_normalize, set_seed
from Utils import log_message
from Eval import compute_performance, plot_LOSS, compute_performance_each_label, get_attention_and_tileinfo
from Eval import get_performance
from train_utils import pull_tiles, FocalLoss, ModelReadyData_Instance_based, modify_to_instance_based
from train_utils import convert_to_dict, prediction_sepatt, BCE_Weighted_Reg, BCE_Weighted_Reg_focal, compute_loss_for_all_labels_sepatt
from Model import Mutation_MIL_MT_sepAtt #, Mutation_MIL_MT
from ACMIL import ACMIL_GA_MultiTask
warnings.filterwarnings("ignore")
%matplotlib inline


#FOR ACMIL
current_dir = os.getcwd()
grandparent_subfolder = os.path.join(current_dir, '..', '..', 'other_model_code','ACMIL-main')
grandparent_subfolder = os.path.normpath(grandparent_subfolder)
sys.path.insert(0, grandparent_subfolder)
from utils.utils import save_model, Struct, set_seed
import yaml
import sys
import os
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
import yaml
from pprint import pprint

import argparse
import torch
from torch import nn
from torch.utils.data import DataLoader

from utils.utils import save_model, Struct, set_seed
from datasets.datasets import build_HDF5_feat_dataset
from architecture.transformer import ACMIL_GA #ACMIL_GA
from architecture.transformer import ACMIL_MHA
import torch.nn.functional as F

from utils.utils import MetricLogger, SmoothedValue, adjust_learning_rate
from timm.utils import accuracy
import torchmetrics
import wandb

In [2]:
####################################
######      USERINPUT       ########
####################################
# Outcome names; entry i is used downstream as the name of task i.
ALL_LABELS = ["AR","MMR (MSH2, MSH6, PMS2, MLH1, MSH3, MLH3, EPCAM)2","PTEN","RB1","TP53","TMB_HIGHorINTERMEDITATE","MSI_POS"]
TUMOR_FRAC_THRES = 0.9
feature_extraction_method = 'uni2' #retccl, uni1
focal_gamma = 2
focal_alpha = 0.1
loss_method = '' #ATTLOSS

################################
#model Para
BATCH_SIZE  = 1
DROPOUT = 0
DIM_OUT = 128
SELECTED_MUTATION = "MT"
SELECTED_FOLD = 0

# Feature-vector width for every supported extractor.
FEATURE_DIMS = {'retccl': 2048, 'uni1': 1024, 'uni2': 1536}
if feature_extraction_method not in FEATURE_DIMS:
    # Fail here with a clear message instead of a NameError on N_FEATURE further down.
    raise ValueError(
        "Unsupported feature_extraction_method '" + str(feature_extraction_method)
        + "'; expected one of " + str(sorted(FEATURE_DIMS))
    )
N_FEATURE = FEATURE_DIMS[feature_extraction_method]
SELECTED_FEATURE = [str(i) for i in range(0, N_FEATURE)] + ['TUMOR_PIXEL_PERC']

################################
# get config
config_dir = "myconf.yml"
with open(config_dir, "r") as ymlfile:
    # NOTE(review): FullLoader is more permissive than yaml.safe_load; only load
    # a trusted config file here, or switch to safe_load if the file allows it.
    c = yaml.load(ymlfile, Loader=yaml.FullLoader)
    #c.update(vars(args))
    conf = Struct(**c)

# Notebook-level overrides of the values read from the YAML file.
conf.train_epoch = 100
conf.D_feat = N_FEATURE
conf.D_inner = DIM_OUT
conf.wandb_mode = 'disabled'
conf.n_task = len(ALL_LABELS)  # one prediction head per label (7 with the list above)

##################
###### DIR  ######
##################
proj_dir = '/fh/fast/etzioni_r/Lucas/mh_proj/mutation_pred/'
folder_name_overlap = "IMSIZE250_OL100"
folder_name_nonoverlap = "IMSIZE250_OL0"
feature_path_opx_train =  os.path.join(proj_dir + 'intermediate_data/Old_5_model_ready_data', "OPX", folder_name_overlap, 'feature_' + feature_extraction_method, 'TFT' + str(TUMOR_FRAC_THRES))
feature_path_opx_test =  os.path.join(proj_dir + 'intermediate_data/Old_5_model_ready_data', "OPX", folder_name_nonoverlap, 'feature_' + feature_extraction_method, 'TFT' + str(TUMOR_FRAC_THRES))
feature_path_tma = os.path.join(proj_dir + 'intermediate_data/5_model_ready_data', "TAN_TMA_Cores",folder_name_nonoverlap, 'feature_' + feature_extraction_method, 'TFT' + str(TUMOR_FRAC_THRES))
# NOTE(review): the split-ID folder is pinned to 'uni1' regardless of
# feature_extraction_method -- confirm this is intentional.
folder_name_ids = 'uni1/TrainOL100_TestOL0_TFT' + str(TUMOR_FRAC_THRES)  + "/"
train_val_test_id_path =  os.path.join(proj_dir + 'intermediate_data/6_Train_TEST_IDS', folder_name_ids)

In [3]:
######################
#Create output-dir
################################################
# Root output folder for this feature extractor / fold / mutation setting.
folder_name1 = f"{feature_extraction_method}/TrainOL100_TestOL0_TFT{TUMOR_FRAC_THRES}/"
outdir0 = f"{proj_dir}intermediate_data/pred_out02262025——Instance_based/{folder_name1}FOLD{SELECTED_FOLD}/{SELECTED_MUTATION}/"

# One sub-folder per artefact type, in the order outdir1 .. outdir5.
outdir1, outdir2, outdir3, outdir4, outdir5 = (
    outdir0 + sub_dir
    for sub_dir in ("/saved_model/", "/model_para/", "/logs/", "/predictions/", "/perf/")
)

for out_dir in (outdir0, outdir1, outdir2, outdir3, outdir4, outdir5):
    create_dir_if_not_exists(out_dir)

##################
#Select GPU
##################
# Use the first CUDA device when one is available, otherwise run on CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

Directory '/fh/fast/etzioni_r/Lucas/mh_proj/mutation_pred/intermediate_data/pred_out02262025——Instance_based/uni2/TrainOL100_TestOL0_TFT0.9/FOLD0/MT/' already exists.
Directory '/fh/fast/etzioni_r/Lucas/mh_proj/mutation_pred/intermediate_data/pred_out02262025——Instance_based/uni2/TrainOL100_TestOL0_TFT0.9/FOLD0/MT//saved_model/' already exists.
Directory '/fh/fast/etzioni_r/Lucas/mh_proj/mutation_pred/intermediate_data/pred_out02262025——Instance_based/uni2/TrainOL100_TestOL0_TFT0.9/FOLD0/MT//model_para/' already exists.
Directory '/fh/fast/etzioni_r/Lucas/mh_proj/mutation_pred/intermediate_data/pred_out02262025——Instance_based/uni2/TrainOL100_TestOL0_TFT0.9/FOLD0/MT//logs/' already exists.
Directory '/fh/fast/etzioni_r/Lucas/mh_proj/mutation_pred/intermediate_data/pred_out02262025——Instance_based/uni2/TrainOL100_TestOL0_TFT0.9/FOLD0/MT//predictions/' already exists.
Directory '/fh/fast/etzioni_r/Lucas/mh_proj/mutation_pred/intermediate_data/pred_out02262025——Instance_based/uni2/TrainOL

In [None]:
################################################
#     Model ready data 
################################################
# Every feature folder stores three files: <prefix>_data.pth, <prefix>_ids.pth
# and <prefix>_info.pth. They are loaded in that order for each source.
opx_data_ol100, opx_ids_ol100, opx_info_ol100 = (
    torch.load(feature_path_opx_train + '/OPX_' + part + '.pth')
    for part in ('data', 'ids', 'info')
)

opx_data_ol0, opx_ids_ol0, opx_info_ol0 = (
    torch.load(feature_path_opx_test + '/OPX_' + part + '.pth')
    for part in ('data', 'ids', 'info')
)

tma_data, tma_ids, tma_info = (
    torch.load(feature_path_tma + '/tma_' + part + '.pth')
    for part in ('data', 'ids', 'info')
)

In [None]:
########################################################
#Update tma
########################################################
# Keep only the TMA items whose label entry (element 1) is not entirely NaN.
haslabel_indexes = [
    idx for idx in range(len(tma_data))
    if not torch.isnan(tma_data[idx][1]).all()
]

# Apply the same selection to the data, the ids and the info records.
tma_data = Subset(tma_data, haslabel_indexes)
tma_ids, tma_info = (list(Subset(records, haslabel_indexes)) for records in (tma_ids, tma_info))
len(tma_info) #355 if TF0.9, a lot of cores does not have enough cancer tiles > 0.9

In [None]:
################################################
#Get train, test IDs
#NOTE: this was in the old train: ['OPX_207','OPX_209','OPX_213','OPX_214','OPX_215']
################################################
train_test_val_id_df = pd.read_csv(train_val_test_id_path + "train_test_split.csv")

# The column 'FOLD<k>' holds the split assignment (TRAIN / TEST / VALID) of every sample.
fold_assignment = train_test_val_id_df['FOLD' + str(SELECTED_FOLD)]
train_ids_all, test_ids_all, val_ids_all = (
    list(train_test_val_id_df.loc[fold_assignment == split_name, 'SAMPLE_ID'])
    for split_name in ('TRAIN', 'TEST', 'VALID')
)

In [None]:
################################################
#Get Train, test, val data
################################################
# For each split: look up the position of every wanted sample ID in the source
# ID list, then take the matching entries from the data, ids and info records.

#Train: overlapping-tile (OL100) features
inc_idx = list(map(opx_ids_ol100.index, train_ids_all))
train_data = Subset(opx_data_ol100, inc_idx)
train_ids, train_info = (list(Subset(records, inc_idx)) for records in (opx_ids_ol100, opx_info_ol100))

#Val: overlapping-tile (OL100) features
inc_idx = list(map(opx_ids_ol100.index, val_ids_all))
val_data = Subset(opx_data_ol100, inc_idx)
val_ids, val_info = (list(Subset(records, inc_idx)) for records in (opx_ids_ol100, opx_info_ol100))

#Test: non-overlapping-tile (OL0) features
inc_idx = list(map(opx_ids_ol0.index, test_ids_all))
test_data = Subset(opx_data_ol0, inc_idx)
test_ids, test_info = (list(Subset(records, inc_idx)) for records in (opx_ids_ol0, opx_info_ol0))

In [None]:
def _print_positive_label_stats(dataset):
    """Print, per label column, the count and the percentage of entries equal to 1.

    The label entry (element 1) of every item in ``dataset`` is concatenated
    into one tensor, which is returned so the caller can keep it.
    """
    stacked_labels = torch.concat([item[1] for item in dataset])
    n_positive = (stacked_labels == 1).sum(dim=0)
    print(n_positive)
    pct_positive = n_positive/stacked_labels.shape[0] * 100
    print([f"{value.item():.1f}" for value in pct_positive])
    return stacked_labels


#count labels in train
train_label_counts = _print_positive_label_stats(train_data)

#count labels in test
test_label_counts = _print_positive_label_stats(test_data)

#count labels in tma
tma_label_counts = _print_positive_label_stats(tma_data) #["AR","PTEN","RB1","TP53"

In [None]:
# Number of items in the train, validation and test splits (in that order).
for split_data in (train_data, val_data, test_data):
    print(len(split_data))

In [None]:
# Modify data
# Only the training split is passed through modify_to_instance_based; the
# test / val / TMA conversions below are disabled, so those splits keep their
# original layout.
# NOTE(review): this cell rebinds `train_data` to its own converted version, so
# running it twice presumably converts already-converted data -- re-run the
# split cell first if this cell has to be executed again.
train_data = modify_to_instance_based(train_data)
#test_data = modify_to_instance_based(test_data)
#val_data = modify_to_instance_based(val_data)
#tma_data = modify_to_instance_based(tma_data)

In [None]:
# Sanity check: shape of the label entry (element 1) of the first training item.
train_data[0][1].shape

In [None]:
####################################################
#            Train 
####################################################
set_seed(0)
# Training batch size; overrides the BATCH_SIZE value from the user-input cell.
BATCH_SIZE = 32

# Training batches are shuffled; the evaluation loaders yield one item at a
# time in a fixed order.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader, val_loader, tma_loader = (
    DataLoader(eval_split, batch_size=1, shuffle=False)
    for eval_split in (test_data, val_data, tma_data)
)

## TODO
## For TEST, use one patient per batch; call the patient positive if at least one tile is predicted positive

In [None]:
class Instance_Based(nn.Module):
    """Multi-task classifier with a shared MLP embedding and one linear head per outcome.

    The embedding maps ``in_features -> 1024 -> 512 -> 256 -> dim_out`` with the
    chosen activation between the linear layers. Each outcome has its own
    ``nn.Linear(dim_out, 1)`` head that returns a raw logit; no sigmoid is
    applied inside the model.

    Args:
        in_features: size of the last dimension of the input.
        act_func: activation name, one of 'leakyrelu', 'tanh' or 'relu'.
        drop_out: dropout probability; dropout is only applied when > 0.
        n_outcomes: number of output heads.
        dim_out: width of the shared embedding.

    Raises:
        ValueError: if ``act_func`` is not one of the supported names.
    """

    # Supported activation names and the module class each one maps to.
    _ACTIVATIONS = {'leakyrelu': nn.LeakyReLU, 'tanh': nn.Tanh, 'relu': nn.ReLU}

    def __init__(self, in_features = 2048, act_func = 'tanh', drop_out = 0, n_outcomes = 7, dim_out = 128):
        super().__init__()
        self.in_features = in_features
        self.n_outs = n_outcomes     # number of outcomes
        self.d_out = dim_out          # dim of output layers
        self.drop_out = drop_out

        # Reject unknown names up front; otherwise self.act_func would never be
        # set and building the embedding below would fail with an obscure
        # AttributeError.
        if act_func not in self._ACTIVATIONS:
            raise ValueError(
                "Unsupported act_func '" + str(act_func) + "'; expected one of "
                + str(sorted(self._ACTIVATIONS))
            )
        self.act_func = self._ACTIVATIONS[act_func]()

        # The same (parameter-free) activation module is reused between layers.
        # Linear layers sit at positions 0, 2, 4 and 6 of the Sequential.
        self.embedding_layer = nn.Sequential(
            nn.Linear(self.in_features, 1024), #linear layer
            self.act_func,
            nn.Linear(1024, 512), #linear layer
            self.act_func,
            nn.Linear(512, 256), #linear layer
            self.act_func,
            nn.Linear(256, self.d_out), #linear layer
        )

        #Outcome layers
        self.hidden_layers =  nn.ModuleList([nn.Linear(self.d_out, 1) for _ in range(self.n_outs)])

        self.dropout = nn.Dropout(p=drop_out)

    def forward(self, x):
        r'''
        Compute one logit tensor per outcome.

        Args:
            x: tensor whose last dimension is ``in_features``; every leading
               dimension is preserved by the linear layers.

        Returns:
            list of ``n_outcomes`` tensors, each with the leading dimensions of
            ``x`` and a last dimension of size 1 (raw logits).
        '''
        #Linear
        x = self.embedding_layer(x)

        out = [head(x) for head in self.hidden_layers]

        #Drop out
        # NOTE(review): dropout is applied to the final logits, not to the
        # embedding -- kept as in the original; confirm this is intended before
        # training with drop_out > 0.
        if self.drop_out > 0:
            out = [self.dropout(logit) for logit in out]

        return out

In [None]:
# define network
# The number of heads and the dropout rate come from the configuration cell
# (conf.n_task, DROPOUT) instead of being repeated here as literals.
model = Instance_Based(in_features = N_FEATURE, act_func = 'tanh', drop_out = DROPOUT, n_outcomes = conf.n_task, dim_out = DIM_OUT)
model.to(device)

# Loss applied to each task's logits; the training loop sums it over all tasks.
criterion = FocalLoss(alpha=focal_alpha, gamma=focal_gamma, reduction='mean')
#criterion = nn.CrossEntropyLoss()


In [None]:
def train_one_epoch_instance_based(model, criterion, data_loader, optimizer0, device, epoch, conf, loss_method = 'none'):
    """
    Train ``model`` for one epoch, minimising the sum of ``criterion`` over all tasks.

    Args:
        model: network that returns a sequence with one logit tensor per task.
        criterion: loss called as ``criterion(logits, labels)`` for each task.
        data_loader: yields batches indexed as ``data[0]`` (input features) and
            ``data[1]`` (labels, one column per task, read as ``[:, k]``).
        optimizer0: optimizer whose learning rate is adjusted every iteration
            via ``adjust_learning_rate``.
        device: device the batch is moved to.
        epoch: index of the current epoch (used for the LR schedule and the log header).
        conf: configuration object; ``conf.n_task`` gives the number of tasks.
        loss_method: currently unused; kept so existing calls keep working.

    Returns:
        float: mean of the summed per-batch loss over the epoch
        (``nan`` if the loader produced no batches).
    """

    # Set the network to training mode
    model.train()

    metric_logger = MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 5000

    n_batches = len(data_loader)
    loss_total = 0.0
    batches_seen = 0

    for data_it, data in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        image_patches = data[0].to(device, dtype=torch.float32)
        # Move the whole label matrix once instead of once per task.
        label_lists = data[1].to(device, dtype=torch.float32)

        # Calculate and set new learning rate (fractional epoch for the schedule)
        adjust_learning_rate(optimizer0, epoch + data_it/n_batches, conf)

        out = model(image_patches)

        #Compute loss for each task, then sum
        loss = 0
        for k in range(conf.n_task):
            # Labels are reshaped to a column so they line up with the
            # single-logit output of each head.
            loss += criterion(out[k], label_lists[:, k].unsqueeze(1))

        optimizer0.zero_grad()
        # Backpropagate error and update parameters
        loss.backward()
        optimizer0.step()

        loss_value = loss.item()
        loss_total += loss_value
        batches_seen += 1

        metric_logger.update(lr=optimizer0.param_groups[0]['lr'])
        metric_logger.update(total_loss=loss_value)

    return loss_total / batches_seen if batches_seen else float('nan')


In [None]:
# Disable gradient calculation during evaluation
@torch.no_grad()
def evaluate_instance_based(net, criterion, data_loader, device, conf, header):
    """
    Evaluate ``net`` on ``data_loader`` and report binary metrics for every task.

    Each task head emits a single logit, so predictions are scored as binary
    problems: sigmoid probability, 0.5 threshold for accuracy, and binary
    AUROC / F1 from torchmetrics. The loss uses the same label layout as
    ``train_one_epoch_instance_based`` (float labels reshaped to a column).

    Args:
        net: network under evaluation; returns one logit tensor per task.
        criterion: loss called as ``criterion(logits, labels)`` for each task.
        data_loader: yields batches indexed as ``data[0]`` (inputs) and
            ``data[1]`` (labels, one column per task).
            NOTE(review): assumes one label per prediction, i.e. the same
            instance-level layout as the training loader -- bag-level loaders
            need an aggregation step first; confirm against the caller.
        device: device used for the forward pass and the metrics.
        conf: configuration object; ``conf.n_task`` gives the number of tasks.
        header: prefix used by the metric logger when printing progress.

    Returns:
        tuple ``(auroc, acc, f1_score, loss)``: AUROC and F1 averaged over the
        tasks, accuracy in percent averaged over tasks and batches, and the
        average summed loss per batch.

    Raises:
        ValueError: if the loader yields no batches, or if a task's number of
            predictions differs from its number of labels.
    """

    # Set the network to evaluation mode
    net.eval()

    metric_logger = MetricLogger(delimiter="  ")

    # Per-task lists of flattened probabilities / integer labels, one entry per batch.
    probs_per_task = [[] for _ in range(conf.n_task)]
    labels_per_task = [[] for _ in range(conf.n_task)]

    for data in metric_logger.log_every(data_loader, 100, header):
        image_patches = data[0].to(device, dtype=torch.float32)
        label_lists = data[1].to(device, dtype=torch.float32)

        # Use the network passed in, not a notebook-level global.
        out = net(image_patches)

        #Compute loss for each task, then sum
        loss = 0
        acc_sum = 0.0
        for k in range(conf.n_task):
            logits = out[k]
            labels = label_lists[:, k]
            probs = torch.sigmoid(logits).reshape(-1)
            flat_labels = labels.reshape(-1)
            if probs.numel() != flat_labels.numel():
                raise ValueError(
                    "Task " + str(k) + ": got " + str(probs.numel()) + " predictions for "
                    + str(flat_labels.numel()) + " labels; instance-level batches are expected."
                )

            loss += criterion(logits, labels.unsqueeze(1))

            # Binary accuracy in percent at a 0.5 probability threshold.
            hard_pred = (probs >= 0.5).to(flat_labels.dtype)
            acc_sum += (hard_pred == flat_labels).float().mean().item() * 100.0

            probs_per_task[k].append(probs)
            labels_per_task[k].append(flat_labels.to(torch.int64))

        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc_sum / conf.n_task, n=label_lists.shape[0])

    if not any(probs_per_task):
        raise ValueError("data_loader produced no batches; nothing to evaluate.")

    #get performance for each task
    auroc_each = 0
    f1_score_each = 0
    for k in range(conf.n_task):
        y_pred_each = torch.cat(probs_per_task[k], dim=0)
        y_true_each = torch.cat(labels_per_task[k], dim=0)

        AUROC_metric = torchmetrics.AUROC(task='binary').to(device)
        AUROC_metric.update(y_pred_each, y_true_each)
        auroc_k = AUROC_metric.compute().item()
        auroc_each += auroc_k

        F1_metric = torchmetrics.F1Score(task='binary').to(device)
        F1_metric.update(y_pred_each, y_true_each)
        f1_score_each += F1_metric.compute().item()
        print("AUROC",str(k),":",auroc_k)
    auroc = auroc_each/conf.n_task
    f1_score = f1_score_each/conf.n_task

    return auroc, metric_logger.meters['acc1'].global_avg, f1_score, metric_logger.meters['loss'].global_avg

In [None]:
# Debug: compare the shape of the last task's predictions with its labels.
# NOTE(review): `slide_preds` and `labels` are only assigned at notebook level
# by the ad-hoc evaluation loop in a later cell, so this cell fails on a fresh
# Restart & Run All.
print(slide_preds.shape)
print(labels.unsqueeze(0).shape)

In [None]:
# Set the network to evaluation mode
model.eval()

y_pred = []
y_true = []

metric_logger = MetricLogger(delimiter="  ")

for data in metric_logger.log_every(test_loader, 100000, ""):
    image_patches = data[0].to(device, dtype=torch.float32)
    label_lists = data[1]

    out = model(image_patches) #torch.Size([BS, 1])
    
    #Compute loss for each task, then sum
    loss = 0
    pred_list = []
    acc1_list = []
    for k in range(conf.n_task):
        slide_preds = out[k]
        labels = label_lists[:,k].to(device, dtype = torch.float32).to(device)
        loss += criterion(slide_preds, labels.unsqueeze(0))
        pred = torch.sigmoid(slide_preds)
        acc1 = accuracy(pred, labels, topk=(1,))[0]
        pred_list.append(pred)
        acc1_list.append(acc1)
        
    avg_acc = sum(acc1_list)/conf.n_task

    metric_logger.update(loss=loss.item())
    metric_logger.meters['acc1'].update(avg_acc.item(), n=labels.shape[0])
    
    y_pred.append(pred_list)
    y_true.append(label_lists)

In [None]:
# Debug: show label column k of the last batch seen by the ad-hoc evaluation
# loop (`label_lists` and `k` are leftovers of that loop).
label_lists[:,k]

In [None]:
#Get prediction for each task
y_pred_tasks = []
y_true_tasks = []
for k in range(conf.n_task):
    y_pred_tasks.append([p[k] for p in y_pred])
    y_true_tasks.append([t[:,k].to(device, dtype = torch.int64) for t in y_true])

#get performance for each calss
auroc_each = 0
f1_score_each = 0
for k in range(conf.n_task):
    y_pred_each = torch.cat(y_pred_tasks[k], dim=0)
    y_true_each = torch.cat(y_true_tasks[k], dim=0)

    AUROC_metric = torchmetrics.AUROC(num_classes = conf.n_class, task='multiclass').to(device)
    AUROC_metric(y_pred_each, y_true_each)
    auroc_each += AUROC_metric.compute().item()

    F1_metric = torchmetrics.F1Score(num_classes = conf.n_class, task='multiclass').to(device)
    F1_metric(y_pred_each, y_true_each)
    f1_score_each += F1_metric.compute().item()
    print("AUROC",str(k),":",AUROC_metric.compute().item())
auroc = auroc_each/conf.n_task
f1_score = f1_score_each/conf.n_task

In [None]:
# Run the evaluation routine on the validation loader with the current model;
# the per-task AUROC values are printed by the function itself.
evaluate_instance_based(model, criterion, val_loader, device, conf, 'Val')

In [None]:
# Checkpoints of this run are written below the saved-model folder.
ckpt_dir = outdir1 + SELECTED_MUTATION + "/"
create_dir_if_not_exists(ckpt_dir)

# define optimizer, lr not important at this point
# (the training routine adjusts the learning rate on every iteration)
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer0 = torch.optim.AdamW(trainable_params, lr=0.001, weight_decay=conf.wd)


# NOTE(review): `best_state` is never updated while the evaluation calls below
# stay disabled, so the final print only shows these initial values.
best_state = {'epoch':-1, 'val_acc':0, 'val_auc':0, 'val_f1':0, 'test_acc':0, 'test_auc':0, 'test_f1':0}
train_epoch = conf.train_epoch
for epoch in range(train_epoch):
    train_one_epoch_instance_based(model, criterion, train_loader, optimizer0, device, epoch, conf, loss_method)

    # Per-epoch evaluation is currently disabled:
    #val_auc, val_acc, val_f1, val_loss = evaluate_instance_based(model, criterion, val_loader, device, conf, 'Val')
    #test_auc, test_acc, test_f1, test_loss = evaluate_instance_based(model, criterion, test_loader, device, conf, 'Test')
    #tma_auc, tma_acc, tma_f1, tma_loss = evaluate_multitask(model, criterion, tma_loader, device, conf, 'TMA')

    # One checkpoint file per epoch.
    checkpoint_path = f"{ckpt_dir}checkpoint_epoch{epoch}.pth"
    save_model(conf=conf, model=model, optimizer=optimizer0, epoch=epoch, save_path=checkpoint_path)
print("Results on best epoch:")
print(best_state)
wandb.finish()

In [None]:
# define network
# Must match the architecture of the trained model so the state_dict loads;
# head count and dropout come from the configuration cell.
model2 = Instance_Based(in_features = N_FEATURE, act_func = 'tanh', drop_out = DROPOUT, n_outcomes = conf.n_task, dim_out = DIM_OUT)
model2.to(device)

# Load the checkpoint written after the last training epoch.
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
#checkpoint = torch.load(ckpt_dir + 'checkpoint-best.pth')
last_epoch = conf.train_epoch - 1
checkpoint = torch.load(ckpt_dir + 'checkpoint_epoch' + str(last_epoch) + '.pth', map_location=device)

# Load the state_dict into the model and switch to inference mode
model2.load_state_dict(checkpoint['model'])
model2.eval()

# NOTE(review): `predict` is neither defined nor imported in the visible part of
# this notebook -- confirm where it comes from before relying on this cell.
y_pred_tasks_test, y_predprob_task_test, y_true_task_test = predict(model2, criterion, test_loader, device, conf, 'Test')
pred_df_list = []
perf_df_list = []
for i in range(conf.n_task):
    # NOTE(review): the decision threshold is the 80th percentile of the test-set
    # probabilities themselves, i.e. it is tuned on the data being scored.
    pred_df, perf_df = get_performance(y_predprob_task_test[i], y_true_task_test[i], test_ids, ALL_LABELS[i],THRES = round(np.quantile(y_predprob_task_test[i],0.8),2))
    pred_df_list.append(pred_df)
    perf_df_list.append(perf_df)

# Stack the per-task prediction and performance tables.
all_perd_df = pd.concat(pred_df_list)
all_perf_df = pd.concat(perf_df_list)
print(all_perf_df)

all_perd_df.to_csv(outdir4 + "/n_token" + str(conf.n_token) + "_TEST_pred_df.csv",index = False)
all_perf_df.to_csv(outdir5 + "/n_token" + str(conf.n_token) + "_TEST_perf.csv",index = True)

In [None]:
# Mean of the 'AUC' column of the stacked test performance table, rounded to 2 decimals.
print(round(all_perf_df['AUC'].mean(),2))