This tutorial shows how to make MSI predictions at both the slide level and the tile level with the trained MSI-SEER models. We provide model weights trained on Yonsei-1 for colorectal cancer slides and on the combined data (TCGA-STAD and Yonsei-Classic) for gastric cancer. The Python script "script_msi_seer_train_inference.py" contains the code that trains the MSI-SEER model on the image features computed by the pre-trained deep learning models. This repository includes the two example slides that were used to generate Figure 4 (the tile-level prediction heatmaps), so that readers can see how our model works on whole slide images and check its reproducibility. Note that MSI-SEER relies on Monte Carlo sampling in both training and testing, so prediction results may vary somewhat from run to run. However, they should not differ substantially from those reported in the paper.

In [49]:
#- load the packages: standard library, third-party, then project modules
import copy
import os
from types import SimpleNamespace

import numpy as np

from models.wsl_binary_classifier_fea_vi_agg_ensemble import (
    wsl_classifier as dgp_rf_agg_ens,
    cal_unc_quntities,
)
from models.utils import get_MSI_data_with_tileinfo, data_feature_mat, seed_everything

#- fix the random seed before any model is built
seed_everything(11111)

# number of models in the ensemble (one MODEL_<i> per model)
N_MSIDETECT_models = 9

#- options for MSI-SEER, handed to the classifier as attributes of a namespace
setting = SimpleNamespace(
    batch_size=8,
    max_epoch=100,
    minimum_epoch=75,
    sub_Ni=300,
    sel_instances='sampling',
    nMCsamples=10,
    alpha=0.5,
    n_layers=6,
    n_RFs=100,
    ker_type='arccosin',
    exp_ensemble=True,
    n_runs=10,
    flag_mean_function=True,
    iter_print=True,
    save_directory='./model_weights',
    w_pos=1.0,
    model_ref_path='./model_weights/prior_means_fromMSIDETECT/',
)

In [None]:
#- Make both slide-level and tile-level predictions on the selected examples.
#  Each of the N_MSIDETECT_models MSI-SEER models reads its own features from
#  PATH_FEATURES/MODEL_<i> and writes its outputs to str_save_path/MODEL_<i>;
#  the per-slide outputs of all models are then aggregated as an ensemble.
PATH_FEATURES = './sample_wsi_features'

str_save_path = './results/predictions/examples'
os.makedirs(str_save_path, exist_ok=True)

str_trn_data = 'GC_Combined'

Yest_list = []  # one entry per model; entry[i] is that model's output for slide i
for ith_run in range(N_MSIDETECT_models):
    str_model = 'MODEL_' + str(ith_run)

    str_save_each_model_path = os.path.join(str_save_path, str_model)
    os.makedirs(str_save_each_model_path, exist_ok=True)

    # data (image features) loading
    img_names_tst, img_tilenames_tst, X_tst, _, _ \
        = get_MSI_data_with_tileinfo(os.path.join(PATH_FEATURES, str_model))

    if ith_run == 0:
        # The first model fixes the reference slide order used by all models.
        img_names_tst_ref = copy.deepcopy(img_names_tst)
    else:
        # Re-order this model's data to the reference slide order so that
        # Yest_list[r][i] refers to the same slide for every model r.
        # A dict lookup avoids an O(n) list.index() scan per slide;
        # setdefault keeps the first occurrence, as list.index() does.
        name_to_idx = {}
        for idx, name in enumerate(img_names_tst):
            name_to_idx.setdefault(name, idx)
        missing = [elm for elm in img_names_tst_ref if elm not in name_to_idx]
        if missing:
            raise ValueError(f'{str_model} has no features for slides: {missing}')
        idx_matched = [name_to_idx[elm] for elm in img_names_tst_ref]
        X_tst = [X_tst[idx] for idx in idx_matched]
        img_tilenames_tst = [img_tilenames_tst[idx] for idx in idx_matched]

    N_total = len(img_names_tst_ref)

    X_tst_cs = data_feature_mat(
        sample_ids=img_names_tst_ref, tile_names=img_tilenames_tst, data_mat=X_tst)

    # create the MSI model, identified by training-data name and model index
    str_trn_model_name = str_trn_data + '/MODEL_' + str(ith_run)

    model_sepbest = dgp_rf_agg_ens(X_tst_cs, None, setting, str_trndata=str_trn_model_name)

    #- predictions: X_tst_cs is built from the reference list, so index N_total slides
    Ytst_probs_V1, _ = model_sepbest.predict(
        np.arange(N_total), data_set_=X_tst_cs, save_pred_path=str_save_each_model_path)

    Yest_list.append(Ytst_probs_V1)

# inference in ensemble learning: for each slide, stack the outputs of all models
Yest_all = [
    np.hstack([Yest_list[r][mn_sub] for r in range(N_MSIDETECT_models)])
    for mn_sub in range(N_total)
]

# calculate the aggregated MSI prediction probabilities and Bayesian confidence scores
[Ytst_mean, unc_aleat, unc_epist] = cal_unc_quntities(Yest_all)
BCS = 1 - 2*np.sqrt(unc_aleat + unc_epist)

In [None]:
import os

import numpy as np
from PIL import Image

import matplotlib.pyplot as plt
import cblind as cb
# scipy.ndimage.filters is a deprecated namespace; import from scipy.ndimage directly
from scipy.ndimage import gaussian_filter
from scipy.special import expit

from models.utils import fill_heatmap_grid_upscale

N_MSIDETECT_models = 9

img_list = ['Item_1', 'Item_2']

str_save_path = './results/predictions/examples/'

# Logistic function; expit is numerically stable where 1/(1+exp(-x)) overflows
sigmoid_np = expit

for img in img_list:
    #- collect every model's per-tile outputs, aligned to MODEL_0's tile order
    for ith_run in range(N_MSIDETECT_models):
        resolved_mat_path = os.path.join(str_save_path, f'MODEL_{ith_run}', img + '_patch_maps.npz')

        # allow_pickle presumably needed because tile ids are stored as an object
        # array; pickle can execute code, so only load files this notebook produced.
        # The context manager closes the underlying .npz file handle.
        with np.load(resolved_mat_path, allow_pickle=True) as patch_mat:
            cur_ids = np.reshape(patch_mat['arr_0'], [-1])
            cur_Fout = patch_mat['arr_1']

        if ith_run == 0:
            patch_ids = np.reshape(cur_ids, [-1, 1])
            patch_Fout = cur_Fout
        else:
            # Re-order this model's rows to the reference tile order, and fail
            # loudly (not via a strippable assert) if the tile sets differ.
            ref_ids = patch_ids[:, 0]
            pos = {pid: i for i, pid in enumerate(cur_ids)}
            if len(pos) != len(cur_ids) or len(cur_ids) != len(ref_ids):
                raise ValueError(f'MODEL_{ith_run}: tiles of {img} do not match MODEL_0')
            try:
                order = np.array([pos[pid] for pid in ref_ids], dtype=np.intp)
            except KeyError as err:
                raise ValueError(f'MODEL_{ith_run}: tile {err} of {img} is missing') from err

            patch_Fout = np.concatenate((patch_Fout, cur_Fout[order]), axis=1)

    # tile-level MSI probability: mean over models of the sigmoid of each output
    patch_preds = np.mean(sigmoid_np(patch_Fout), axis=1, keepdims=True)

    # Get WSI details
    wsi_width, wsi_height = np.load(f'./sample_wsi_features/{img}_wsiinfo.npy')

    # Get thumbnail
    thumbnail = Image.open(f'./sample_wsi_features/{img}_thumb.png')

    #- generate the heatmap
    # Tile size n: the text after the last "_dy" in the first tile id, with the
    # file extension removed (presumably ids look like "..._dy512.png" -> 512).
    first_id = str(patch_ids[0][0])
    n = int(os.path.splitext(first_id)[0].split('_dy')[-1])

    width, height = thumbnail.size

    # number of tiles along (width, height) of the WSI
    grid_dims = (np.array((wsi_width, wsi_height)) / n).astype(np.int32)

    heatmap_image = fill_heatmap_grid_upscale(
        [wsi_width, wsi_height], n, [width, height], patch_ids, patch_preds)

    heatmap_image_np = np.array(heatmap_image)

    #- tile-level prediction heatmap over the thumbnail
    plt.figure()

    plt.imshow(thumbnail)

    # Gaussian smoothing with sigma = sqrt(thumbnail pixels per tile).
    # NOTE(review): this divides the thumbnail *width* by grid_dims[1], the tile
    # count along the *height*; grid_dims[0] looks intended. Kept as-is so the
    # published Figure 4 heatmaps are reproduced -- confirm before changing.
    scale_factor_w = np.sqrt(width/grid_dims[1])
    filtered_arr = gaussian_filter(heatmap_image_np, sigma=scale_factor_w)

    # pixels that are 0 in the unsmoothed heatmap (presumably no tile) are masked out
    plt.imshow(np.ma.array(filtered_arr, mask=heatmap_image_np == 0),
               interpolation='hanning', alpha=0.4, cmap=cb.cbmap("cb.solstice", nbin=20))

    cbar = plt.colorbar()
    cbar.set_label("MSI-H predictive probability")

    plt.tight_layout(w_pad=0)
    plt.axis('off')

    plt.show()

    #- histogram of the MSI prediction probabilities in the slide
    fig, ax = plt.subplots()

    ax.hist(patch_preds, density=False, bins=50, alpha=0.75)

    plt.xlabel("MSI-H predictive probability")
    plt.show()