Breast cancer stage prediction from pathological whole slide images with hierarchical image pyramid transformers.
Project developed under the "High Risk Breast Cancer Prediction Contest Phase 2" 
by Nightingale, Association for Health Learning & Inference (AHLI)
and Providence St. Joseph Health

Copyright (C) 2023 Zsolt Bedohazi, Andras Biricz, Istvan Csabai

In [None]:
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from collections import Counter
from torch.utils.data import Dataset, DataLoader
#from model_hierarchical_mil_stage3_vit_level1 import HIPT_LGP_FC_STAGE3ONLY, Attn_Net_Gated
from model_hierarchical_mil_stage3_resnet_level0 import HIPT_LGP_FC_STAGE3ONLY, Attn_Net_Gated
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.metrics import auc as calc_auc
import random
import glob

import ngsci

## Load models from CV folds

##### resnet

In [None]:
# Checkpoint directories of the five CV folds (ResNet run, checkpoints_cv5_balanced_run5).
# Each directory is later globbed for its *.pt checkpoint files.
checkpoints_dir_cv_0 = '/home/ngsci/project/nightingale_breast_working_development_directory/Preprocessing/runs/runs_before_april26/checkpoints_cv5_balanced_run5/cv_0/'
checkpoints_dir_cv_1 = '/home/ngsci/project/nightingale_breast_working_development_directory/Preprocessing/runs/runs_before_april26/checkpoints_cv5_balanced_run5/cv_1/'
checkpoints_dir_cv_2 = '/home/ngsci/project/nightingale_breast_working_development_directory/Preprocessing/runs/runs_before_april26/checkpoints_cv5_balanced_run5/cv_2/'
checkpoints_dir_cv_3 = '/home/ngsci/project/nightingale_breast_working_development_directory/Preprocessing/runs/runs_before_april26/checkpoints_cv5_balanced_run5/cv_3/'
checkpoints_dir_cv_4 = '/home/ngsci/project/nightingale_breast_working_development_directory/Preprocessing/runs/runs_before_april26/checkpoints_cv5_balanced_run5/cv_4/'

##### vit

##### resnet10 fold

scores:

    0.8054431085623077
    0.7990998543885838
    0.7602979518607795
    0.7814769456841988
    0.7797447111810377
    0.8201680607020354
    0.7763702620577697
    0.8226519769535041
    0.7691339792819564
    0.779663379238015

# Checkpoint directories of the five CV folds for the run trained on non-finetuned ResNet50
# level-0 embeddings (checkpoints_cv5_balanced_run4).
# NOTE(review): these rebind the same names as the absolute /home/ngsci/... paths above, so
# whichever assignment runs last wins; these paths are relative to the working directory.
checkpoints_dir_cv_0 = 'project/nightingale_breast_working_development_directory/Preprocessing/runs/nightingale-nofinetuned_resnet50_embeddings_level0/checkpoints_cv5_balanced_run4/cv_0/'
checkpoints_dir_cv_1 = 'project/nightingale_breast_working_development_directory/Preprocessing/runs/nightingale-nofinetuned_resnet50_embeddings_level0/checkpoints_cv5_balanced_run4/cv_1/'
checkpoints_dir_cv_2 = 'project/nightingale_breast_working_development_directory/Preprocessing/runs/nightingale-nofinetuned_resnet50_embeddings_level0/checkpoints_cv5_balanced_run4/cv_2/'
checkpoints_dir_cv_3 = 'project/nightingale_breast_working_development_directory/Preprocessing/runs/nightingale-nofinetuned_resnet50_embeddings_level0/checkpoints_cv5_balanced_run4/cv_3/'
checkpoints_dir_cv_4 = 'project/nightingale_breast_working_development_directory/Preprocessing/runs/nightingale-nofinetuned_resnet50_embeddings_level0/checkpoints_cv5_balanced_run4/cv_4/'

In [None]:
# Sorted *.pt checkpoint paths of every CV fold: entry i holds the files of fold i.
# The fold directories are listed explicitly instead of being looked up with eval().
checkpoints_dirs_all_cv = [checkpoints_dir_cv_0, checkpoints_dir_cv_1, checkpoints_dir_cv_2, checkpoints_dir_cv_3, checkpoints_dir_cv_4]
file_names_all_cv = np.array([np.array(sorted(glob.glob(os.path.join(fold_dir, "*.pt")))) for fold_dir in checkpoints_dirs_all_cv], dtype=object)
file_names_all_cv.shape

In [None]:
# Peek at the first (sorted) checkpoint path of fold 0.
file_names_all_cv[0][0]

## Load biopsy bags -> input data

In [None]:
# Directories with one .npz embedding bag per biopsy: local test set and leaderboard holdout.
biopsy_bag_input_path = '/home/ngsci/project/nightingale_breast_working_development_directory/Preprocessing/biopsy_embeddings/biopsy_bag_vit_xs_embeddings_nightingale-nofinetuned_resnet50_embeddings_level0/'
biopsy_bag_input_path_holdout = '/home/ngsci/project/nightingale_breast_working_development_directory/Preprocessing/biopsy_embeddings/biopsy_bag_vit_xs_embeddings_nightingale-nofinetuned_resnet50_embeddings_level0_holdout/'

# Local test split, sorted by biopsy_id; this row order is the prediction order.
biopsy_df_local_test = pd.read_csv('cv_splits_stratified_with_test_set_10fold/test_split_stratified.csv')
biopsy_df_local_test.sort_values('biopsy_id', inplace=True)
# Build the bag paths directly in DataFrame row order instead of re-sorting the path strings.
# Sorting 'id.npz' strings is not guaranteed to give the same order as sorting the bare ids,
# and the evaluation labels (biopsy_df_local_test.stage) must line up row by row with the
# predictions made from these files.
biopsy_bag_input_files_local_test = np.array([os.path.join(biopsy_bag_input_path, f'{i}.npz') for i in biopsy_df_local_test.biopsy_id.values])
# Holdout bags: every file with a .npz extension (endswith, so names like 'x.npz.tmp' are skipped).
biopsy_bag_input_files_holdout = np.array(sorted(os.path.join(biopsy_bag_input_path_holdout, k) for k in os.listdir(biopsy_bag_input_path_holdout) if k.endswith('.npz')))

biopsy_bag_input_files_local_test.shape, biopsy_bag_input_files_holdout.shape

## Select the best models by validation AUC or by validation loss

In [None]:
# For every CV fold, keep the `select` checkpoints with the highest validation AUC, best first.
# The AUC is parsed from field 3 of the '_'-separated checkpoint file name.
# Fold blocks are concatenated, so fold i occupies entries i*select .. (i+1)*select - 1.
best_models_on_val_auc = []
select = 10

for fold_files in file_names_all_cv:
    fold_aucs = [float(os.path.basename(path).split('_')[3]) for path in fold_files]
    order_desc = np.argsort(fold_aucs)[::-1]
    best_models_on_val_auc.extend(fold_files[order_desc[rank]] for rank in range(select))

best_models_on_val_auc = np.array(best_models_on_val_auc)

In [None]:
# Every `select`-th entry is the first of a fold block, i.e. the top-AUC checkpoint per fold.
# NOTE(review): assumes `select` still holds 10; the val-loss cell rebinds it to 6.
best_models_on_val_auc[::select]

In [None]:
# Collect, per CV fold, the 6 checkpoints with the smallest value in field 3 of the
# '_'-separated file name (intended: the lowest validation loss).
# NOTE(review): field 3 is the same field the val-AUC cell parses as the AUC (the commented
# `max_auc_index` line below also treats it as AUC). If so, this ascending sort selects the
# *lowest-AUC* checkpoints, not the lowest-loss ones. Confirm the checkpoint naming scheme
# and parse the loss field before using best_models_on_val_loss.
best_models_on_val_loss = []

for i in range(file_names_all_cv.shape[0]):
    # (disabled) keep only the single best checkpoint per fold
    #max_auc_index = np.argmax([float(os.path.basename(item).split('_')[3]) for item in file_names_all_cv[i]])
    
    # keep several checkpoints per fold
    # NOTE(review): this rebinds the module-level `select` (10 in the val-AUC cell) to 6.
    select = 6
    min_loss_sort_index = np.argsort([float(os.path.basename(item).split('_')[3]) for item in file_names_all_cv[i]])
    
    for m in range(select):
        min_loss_model = file_names_all_cv[i][min_loss_sort_index[m]]
        best_models_on_val_loss.append(min_loss_model)

best_models_on_val_loss = np.array(best_models_on_val_loss)

##### best models on val loss

##### best models on val auc

In [None]:
# Mean validation AUC of all selected checkpoints; the value follows '_auc_' in the file name.
# Only the base name is parsed, so directory names can never interfere with the split.
np.mean([float(os.path.basename(e).split('_auc_')[1].split('_')[0]) for e in best_models_on_val_auc])

## Predict with HIPT stage 3 ViT with model ensemble

In [None]:
def plot_roc(y_true, y_pred):
    """Plot one-vs-rest ROC curves (one per class) and return the per-class ROC AUCs.

    Args:
        y_true: integer class labels of shape (n_samples,), or an indicator matrix
            with exactly the same shape as ``y_pred``.
        y_pred: per-class scores of shape (n_samples, n_classes).

    Returns:
        list of float: ROC AUC of each class column, in column order.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    n_classes = y_pred.shape[-1]

    if y_pred.shape != y_true.shape:
        # Integer labels -> one-hot indicator matrix. The width comes from y_pred rather
        # than a hard-coded 5, so other numbers of classes work as well.
        y_true = F.one_hot(torch.as_tensor(y_true, dtype=torch.int64), n_classes).numpy()

    plt.figure(figsize=(6, 6))
    auc_all = []
    for class_ind in range(n_classes):
        fpr, tpr, _ = roc_curve(y_true[:, class_ind], y_pred[:, class_ind])
        auc = roc_auc_score(y_true[:, class_ind], y_pred[:, class_ind])
        auc_all.append(auc)
        plt.plot(fpr, tpr, '-', label='AUC : %.3f, label : %d' % (auc, class_ind))
    plt.legend()
    plt.show()
    return auc_all

In [None]:
def pred_with_one_model(model, biopsy_bag_input_files, device='cuda:0'):
    """Run a single model over every biopsy bag and stack its outputs.

    Each file is an ``.npz`` archive; its ``'embedding'`` array gets a leading axis of
    size 1 and is passed to ``model`` as one bag, without gradient tracking.

    Args:
        model: module whose forward returns a 5-tuple; elements 1 and 2 are collected
            as the predictions and the labels, the others are ignored.
        biopsy_bag_input_files: sequence of ``.npz`` paths, processed in order.
        device: device the embeddings are moved to; must match the model's device.
            Defaults to ``'cuda:0'``, the previously hard-coded value.

    Returns:
        (preds_all, labels_all): per-bag outputs concatenated along axis 0 as numpy arrays.
    """
    preds_all = []
    labels_all = []

    with torch.no_grad():
        for bag_path in tqdm(biopsy_bag_input_files):
            # Use the NpzFile as a context manager so its file handle is closed right away
            # instead of whenever the garbage collector reclaims it.
            with np.load(bag_path) as bag:
                emb_npy = bag['embedding']

            emb = torch.from_numpy(np.expand_dims(emb_npy, 0).astype(np.float32)).to(device)
            _, preds, label, _, _ = model(emb)

            preds_all.append(preds.cpu().numpy())
            labels_all.append(label.cpu().numpy())

    return np.concatenate(preds_all), np.concatenate(labels_all)

### Local test set

In [None]:
# Uncomment to ensemble the val-loss selection instead of the val-AUC one:
#best_models_on_val_auc = best_models_on_val_loss # VAL LOSS

nr_models = best_models_on_val_auc.shape[0]

# One slot per checkpoint: 5 per-class scores and the label output for every local-test bag.
preds_ensemble_local_test = np.zeros((nr_models, biopsy_bag_input_files_local_test.shape[0], 5))
labels_ensmble_local_test = np.zeros((nr_models, biopsy_bag_input_files_local_test.shape[0], 1))

for m, best_model_path in enumerate(best_models_on_val_auc):
    # Fresh model per checkpoint: restore the weights, switch to eval mode, move to cuda:0.
    model = HIPT_LGP_FC_STAGE3ONLY()
    print(os.path.basename(best_model_path))
    model_state_dict = torch.load(best_model_path, map_location=torch.device('cuda:0'))
    model.load_state_dict(model_state_dict)
    model.eval()
    model.to('cuda:0')

    preds_all, labels_all = pred_with_one_model(model, biopsy_bag_input_files_local_test)
    preds_ensemble_local_test[m], labels_ensmble_local_test[m] = preds_all, labels_all

In [None]:
# Allocated as (n_models, n_local_test_bags, 5) and (n_models, n_local_test_bags, 1).
preds_ensemble_local_test.shape, labels_ensmble_local_test.shape


#### Simple mean

In [None]:
# Soft-voting ensemble: average the per-model scores, then take the arg-max class.
final_pred_ensemble_local_test = np.mean(preds_ensemble_local_test, axis=0)
# The label must come from the averaged scores (as done for the holdout set). Taking the
# arg-max of the per-model array gave one label row per model, shape (n_models, n_bags).
final_label_ensemble_local_test = np.argmax(final_pred_ensemble_local_test, axis=-1)
final_pred_ensemble_local_test.shape, final_label_ensemble_local_test.shape

In [None]:
# Per-class ROC/AUC of the mean ensemble on the local test set, followed by the macro mean.
# To inspect one model instead, pass preds_ensemble_local_test[k] as the scores.
auc = plot_roc(biopsy_df_local_test['stage'].values, final_pred_ensemble_local_test)
print(np.mean(auc))

In [None]:
# Largest averaged class-4 score over all local-test bags.
final_pred_ensemble_local_test[:,4].max()

#### Filtering then mean

In [None]:
def filt_one_sample_colwise(current_sample, n_std=1.5):
    """Average one sample's ensemble predictions class by class, ignoring outliers.

    For every class column, model scores at least ``n_std`` standard deviations away
    from that column's mean are dropped and the remaining scores are averaged. The
    resulting vector is renormalised to sum to 1.

    Args:
        current_sample: array of shape (n_models, n_classes), one row per model.
        n_std: outlier cut-off in standard deviations (default 1.5, the original value).

    Returns:
        np.ndarray of shape (n_classes,) that sums to 1.
    """
    current_sample = np.asarray(current_sample, dtype=float)
    mean_sample = np.mean(current_sample, axis=0)
    std_sample = np.std(current_sample, axis=0)

    n_classes = current_sample.shape[1]
    filted_sample = np.zeros(n_classes)
    for s in range(n_classes):
        column = current_sample[:, s]
        keep = np.abs(column - mean_sample[s]) < n_std * std_sample[s]
        # With std == 0 (all models agree, e.g. a single model) the strict '<' keeps no
        # score and np.mean of an empty selection would be NaN; use the column mean then.
        filted_sample[s] = np.mean(column[keep]) if keep.any() else mean_sample[s]

    # Small scores are deliberately not zeroed out: class-4 scores are very low and
    # such a threshold would filter them away.

    return filted_sample / np.sum(filted_sample)

In [None]:
# Class-wise outlier-filtered ensemble prediction for every local-test bag (one row per bag).
preds_ensemble_local_test_corr = np.array([filt_one_sample_colwise(per_model_preds) for per_model_preds in np.swapaxes(preds_ensemble_local_test, 0, 1)])
final_pred_ensemble_local_test_corr = preds_ensemble_local_test_corr

In [None]:
# Per-class ROC/AUC of the class-wise filtered ensemble, followed by the macro mean.
auc = plot_roc(biopsy_df_local_test['stage'].values, final_pred_ensemble_local_test_corr)
print(np.mean(auc))

In [None]:
def filt_one_sample_all(current_sample, keep_percentile=10):
    """Average only the ensemble members whose prediction vector is closest to the mean.

    The Euclidean distance of every model's prediction vector to the ensemble mean is
    computed. Models strictly below the ``keep_percentile``-th percentile of those
    distances are averaged, and the result is renormalised to sum to 1.

    Args:
        current_sample: array of shape (n_models, n_classes), one row per model.
        keep_percentile: distance percentile below which models are kept (default 10).

    Returns:
        np.ndarray of shape (n_classes,) that sums to 1.
    """
    current_sample = np.asarray(current_sample, dtype=float)
    mean_sample = np.mean(current_sample, axis=0)
    dist_from_mean = np.sqrt(np.sum((current_sample - mean_sample) ** 2, 1))

    filt = dist_from_mean < np.percentile(dist_from_mean, keep_percentile)
    if not filt.any():
        # With identical predictions or very few models the strict '<' keeps nobody, and
        # the mean of an empty selection would be NaN; use the member(s) nearest the mean.
        filt = dist_from_mean == dist_from_mean.min()

    filted_sample = np.mean(current_sample[filt], axis=0)
    return filted_sample / np.sum(filted_sample)

In [None]:
# Ensemble prediction from the models nearest the ensemble mean, for every local-test bag.
preds_ensemble_local_test_corr = np.array([filt_one_sample_all(per_model_preds) for per_model_preds in np.swapaxes(preds_ensemble_local_test, 0, 1)])
final_pred_ensemble_local_test_corr = preds_ensemble_local_test_corr

In [None]:
# Per-class ROC/AUC of the nearest-to-mean ensemble, followed by the macro mean.
auc = plot_roc(biopsy_df_local_test['stage'].values, final_pred_ensemble_local_test_corr)
print(np.mean(auc))

In [None]:
def filt_one_sample_all_upgraded(current_sample, lower_percentile=20, upper_percentile=80):
    """Trimmed ensemble mean based on each model's distance from the ensemble mean.

    The Euclidean distance of every model's prediction vector to the ensemble mean is
    computed. Only models whose distance lies strictly between the ``lower_percentile``-th
    and ``upper_percentile``-th percentiles are averaged, and the result is renormalised
    to sum to 1.

    Args:
        current_sample: array of shape (n_models, n_classes), one row per model.
        lower_percentile: lower distance percentile bound (default 20).
        upper_percentile: upper distance percentile bound (default 80).

    Returns:
        np.ndarray of shape (n_classes,) that sums to 1.
    """
    current_sample = np.asarray(current_sample, dtype=float)
    mean_sample = np.mean(current_sample, axis=0)
    dist_from_mean = np.sqrt(np.sum((current_sample - mean_sample) ** 2, 1))

    lower = np.percentile(dist_from_mean, lower_percentile)
    upper = np.percentile(dist_from_mean, upper_percentile)
    filt = np.logical_and(dist_from_mean > lower, dist_from_mean < upper)
    if not filt.any():
        # Identical predictions or very few models leave the open interval empty and the
        # mean would be NaN; fall back to averaging all models.
        filt = np.ones_like(dist_from_mean, dtype=bool)

    filted_sample = np.mean(current_sample[filt], axis=0)
    return filted_sample / np.sum(filted_sample)

In [None]:
# Distance-trimmed ensemble prediction for every local-test bag (one row per bag).
preds_ensemble_local_test_corr = np.array([filt_one_sample_all_upgraded(per_model_preds) for per_model_preds in np.swapaxes(preds_ensemble_local_test, 0, 1)])
final_pred_ensemble_local_test_corr = preds_ensemble_local_test_corr

In [None]:
# Per-class ROC/AUC of the distance-trimmed ensemble, followed by the macro mean.
auc = plot_roc(biopsy_df_local_test['stage'].values, final_pred_ensemble_local_test_corr)
print(np.mean(auc))

### HOLDOUT set for leaderboard

In [None]:
nr_models = best_models_on_val_auc.shape[0]

# One slot per checkpoint: 5 per-class scores and the label output for every holdout bag.
preds_ensemble_holdout = np.zeros((nr_models, biopsy_bag_input_files_holdout.shape[0], 5))
labels_ensmble_holdout = np.zeros((nr_models, biopsy_bag_input_files_holdout.shape[0], 1))

for m, best_model_path in enumerate(best_models_on_val_auc):
    # Fresh model per checkpoint: restore the weights, switch to eval mode, move to cuda:0.
    model = HIPT_LGP_FC_STAGE3ONLY()
    model_state_dict = torch.load(best_model_path, map_location=torch.device('cuda:0'))
    model.load_state_dict(model_state_dict)
    model.eval()
    model.to('cuda:0')

    preds_all, labels_all = pred_with_one_model(model, biopsy_bag_input_files_holdout)
    preds_ensemble_holdout[m], labels_ensmble_holdout[m] = preds_all, labels_all

In [None]:
# Allocated as (n_models, n_holdout_bags, 5).
preds_ensemble_holdout.shape

### Simple mean

In [None]:
# Soft-voting ensemble on the holdout set: mean of the per-model scores, then arg-max class.
final_pred_ensemble = preds_ensemble_holdout.mean(axis=0)
final_label_ensemble = final_pred_ensemble.argmax(axis=-1)
(final_pred_ensemble.shape, final_label_ensemble.shape)

In [None]:
# First five averaged holdout predictions.
final_pred_ensemble[:5]

### Correct the ensemble prediction using only the models closest to the ensemble mean

In [None]:
# Nearest-to-mean ensemble prediction per holdout bag, plus its arg-max class as an (N, 1) column.
preds_ensemble_holdout_corr = np.array([filt_one_sample_all(per_model_preds) for per_model_preds in np.swapaxes(preds_ensemble_holdout, 0, 1)])
final_pred_ensemble_holdout_corr = preds_ensemble_holdout_corr
final_pred_ensemble_holdout_corr_labels = np.argmax(final_pred_ensemble_holdout_corr, axis=1)[:, np.newaxis]

In [None]:
# First ten corrected holdout predictions and their predicted classes.
final_pred_ensemble_holdout_corr[:10], final_pred_ensemble_holdout_corr_labels[:10]

In [None]:
# Distance-trimmed ensemble prediction per holdout bag, plus its arg-max class as an (N, 1) column.
preds_ensemble_holdout_corr = np.array([filt_one_sample_all_upgraded(per_model_preds) for per_model_preds in np.swapaxes(preds_ensemble_holdout, 0, 1)])
final_pred_ensemble_holdout_corr = preds_ensemble_holdout_corr
final_pred_ensemble_holdout_corr_labels = np.argmax(final_pred_ensemble_holdout_corr, axis=1)[:, np.newaxis]

In [None]:
# First ten distance-trimmed holdout predictions and their predicted classes.
final_pred_ensemble_holdout_corr[:10], final_pred_ensemble_holdout_corr_labels[:10]

In [None]:
# Leaderboard table: biopsy id, the corrected per-class ensemble scores and the predicted class.
# Built column by column so the scores stay float and the label stays int; concatenating
# the ids, scores and labels into one numpy array first turned every column into strings.
pred_csv = pd.DataFrame(final_pred_ensemble_holdout_corr)
pred_csv.insert(0, 'biopsy_id', [os.path.basename(f).split('.npz')[0] for f in biopsy_bag_input_files_holdout])
pred_csv['label'] = final_pred_ensemble_holdout_corr_labels.reshape(-1)
# All column names are blanked, as in the original output.
pred_csv.columns = ['' for i in range(pred_csv.shape[1])]

In [None]:
# Preview the first rows of the leaderboard table.
pred_csv.head()