Breast cancer stage prediction from pathological whole slide images with hierarchical image pyramid transformers.
Project developed under the "High Risk Breast Cancer Prediction Contest Phase 2" 
by Nightingale, Association for Health Learning & Inference (AHLI)
and Providence St. Joseph Health

Parts of the code were taken over and adapted from the HIPT library.

https://github.com/mahmoodlab/HIPT/tree/master/2-Weakly-Supervised-Subtyping/utils

Copyright (C) 2023 Zsolt Bedohazi, Andras Biricz, Istvan Csabai

In [None]:
import sys
import os
import random
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
from collections import Counter
from torch.utils.data import Dataset, DataLoader, Sampler, BatchSampler
from model_hierarchical_mil_stage3_vit_level1 import HIPT_LGP_FC_STAGE3ONLY, Attn_Net_Gated
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.metrics import auc as calc_auc
import matplotlib.pyplot as plt

### Functions and classes adapted from HIPT

In [None]:
def create_balanced_biopsy_subset(labels, minority_class_ratio=0.2, rnd_seed=38):
    """Draw a class-balanced random subset of biopsies.

    The same number of samples is drawn (without replacement) from every
    class. That number is the size of the rarest class multiplied by
    ``minority_class_ratio`` and truncated to an integer. A class holding
    fewer samples than that contributes all of its samples.

    Parameters
    ----------
    labels : array-like of shape (n_samples,)
        Class label of every biopsy.
    minority_class_ratio : float, default=0.2
        Fraction of the rarest class that is drawn from each class.
    rnd_seed : int, default=38
        Seed used for the per-class random draws.

    Returns
    -------
    train_idx : np.ndarray
        Ascending indices of all samples that are NOT in the balanced subset.
    test_idx : np.ndarray
        Indices of the balanced subset, shuffled with a fixed seed (23) so
        that the labels are not ordered class by class.
    """
    labels = np.asarray(labels)
    n_samples = labels.shape[0]
    all_idx = np.arange(n_samples)

    # nothing to sample from: return two empty index arrays instead of failing in np.min
    if n_samples == 0:
        return all_idx, all_idx.copy()

    # NOTE: this reseeds NumPy's *global* generator; kept so that existing
    # runs stay reproducible with the same seeds.
    np.random.seed(rnd_seed)

    # class occurrences, ordered by ascending label value
    classes, class_occurrence = np.unique(labels, return_counts=True)

    # relative class frequencies
    class_weights = (class_occurrence / class_occurrence.sum()).astype(np.float32)

    # how many biopsies of each class go into the balanced subset
    nr_class_test = int(n_samples * np.min(class_weights) * minority_class_ratio)

    # collect the randomly selected indices class by class
    test_local_idx = []
    for s in classes:
        s_idx = all_idx[labels == s]
        rnd_idx = np.random.permutation(s_idx.shape[0])
        test_local_idx.append(s_idx[rnd_idx[:nr_class_test]])

    # aggregate all selected indices and shuffle them, otherwise labels are ordered
    test_idx = np.concatenate(test_local_idx)
    random.Random(23).shuffle(test_idx)

    # Everything that was not selected forms the remainder. The flat index
    # array is used for the membership test so that classes contributing a
    # different number of samples cannot produce a ragged input.
    train_idx = all_idx[~np.isin(all_idx, test_idx)]

    return train_idx, test_idx

In [None]:
def give_back_balanced_training_fold( X_current, y_current,
                                      minority_class_ratio=0.5, rnd_seed=12 ):
    """Return a class-balanced subset of the given samples and labels.

    The balanced indices are the second value returned by
    ``create_balanced_biopsy_subset``; the remaining indices are ignored.
    Both ``X_current`` and ``y_current`` are indexed with that index array.
    """
    balanced_idx = create_balanced_biopsy_subset(
        y_current, minority_class_ratio, rnd_seed
    )[1]

    return X_current[balanced_idx], y_current[balanced_idx]

### Load all data

### Generate val set

In [None]:
class CollectionsDataset(Dataset):
    """Map-style dataset pairing each entry of ``data`` with its label.

    ``num_classes`` is only stored on the instance; it is not used by the
    methods of this class.
    """

    def __init__(self,
                 data,
                 labels,
                 num_classes,
                 transform=None):
        self.data, self.labels = data, labels
        self.num_classes = num_classes
        self.transform = transform

    def __len__(self):
        """Number of entries in ``data``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample, label)`` where the label gets a leading axis of size 1."""
        sample, raw_label = self.data[idx], self.labels[idx]

        # the optional transform is applied to the sample only
        if self.transform:
            sample = self.transform(sample)

        target = np.expand_dims(raw_label, 0)
        return sample, target

In [None]:
# Testing the dataloader

In [None]:
# not modified
def train_loop(cur, X_train_all, y_train_all, X_val_all, y_val_all, results_dir, num_epochs, model, n_classes, loss_fn=None, gc=32):
    """Train ``model`` for one cross-validation fold, validating and checkpointing every epoch.

    Every epoch a new class-balanced training subset is drawn from
    ``X_train_all`` / ``y_train_all`` (seed derived from the epoch number),
    the model is trained on it one sample at a time, and then evaluated on
    the full validation set. The model weights are saved after every epoch
    into ``<results_dir>/cv_<cur>/`` with the losses and AUCs encoded in the
    file name.

    Parameters
    ----------
    cur : int
        Fold identifier; only used to name the checkpoint sub-directory.
    X_train_all, y_train_all
        Training samples and labels; indexed with an integer index array.
    X_val_all, y_val_all
        Validation samples and labels.
    results_dir : str
        Root directory for checkpoints.
    num_epochs : int
        Number of epochs to train.
    model : torch.nn.Module
        Called as ``model(data)``; must return a 5-tuple whose first three
        items are ``logits, Y_prob, Y_hat``. The model is not moved to a
        device here.
    n_classes : int
        Number of classes.
    loss_fn : callable, optional
        Called as ``loss_fn(logits, label)``. Defaults to
        ``nn.CrossEntropyLoss()`` when ``None``.
    gc : int, default=32
        Divisor applied to the loss before ``backward()``.

    Returns
    -------
    dict
        ``num_epochs``, ``lr``, ``weight_decay`` and the per-epoch lists
        ``train_loss_all_epoch``, ``val_loss_all_epoch``, ``val_auc_all_epoch``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # a missing loss function would otherwise fail on the first batch
    if loss_fn is None:
        loss_fn = nn.CrossEntropyLoss()

    print('\nInit optimizer ...', end=' ')

    lr = 8e-5
    weight_decay = 1e-6

    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=weight_decay)
    print('Done!')

    # One directory per fold. Built with os.path.join for both creating the
    # directory and saving into it, so the two paths agree whether or not
    # results_dir ends with a path separator.
    checkpoint_dir = os.path.join(results_dir, f"cv_{cur}")
    os.makedirs(checkpoint_dir, exist_ok=True)

    # the validation set does not change between epochs, so build it once
    val_dataset = CollectionsDataset(data=X_val_all,
                                     labels=y_val_all,
                                     num_classes=n_classes,
                                     transform=None)
    val_dataset_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, num_workers=2)

    train_loss_all_epoch = []
    val_loss_all_epoch = []
    val_auc_all_epoch = []

    # training loop with balanced folds
    for epoch in range(0, num_epochs):
        acc_logger = Accuracy_Logger(n_classes=n_classes)

        # generate a balanced train set; the seed changes with the epoch so
        # that each epoch sees a different subset
        X_train, y_train = give_back_balanced_training_fold(X_train_all, y_train_all, minority_class_ratio=0.5, rnd_seed=int(epoch*1.5+3*epoch))

        train_dataset = CollectionsDataset(data=X_train,
                                           labels=y_train,
                                           num_classes=n_classes,
                                           transform=None)

        # create the pytorch data loader (one sample per batch)
        train_dataset_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, num_workers=2)

        print('Epoch {}/{}'.format(epoch, num_epochs - 1))

        model.train()
        train_loss = 0.
        train_error = 0.

        print('\n')

        # iterate over the training samples
        for bi, (data, label) in enumerate(train_dataset_loader):

            # drop the batch axis added by the loader
            label = label.squeeze(0)
            data = data.to(device, dtype=torch.float, non_blocking=True)
            label = label.to(device, dtype=torch.long, non_blocking=True)

            logits, Y_prob, Y_hat, _, _ = model(data)
            acc_logger.log(Y_hat, label)

            loss = loss_fn(logits, label)
            train_loss += loss.item()
            train_error += calculate_error(Y_hat, label)

            # NOTE(review): the loss is scaled by 1/gc, but the optimizer
            # steps and clears gradients after every sample, so gradients
            # are not accumulated over `gc` samples. Confirm whether stepping
            # only every `gc` iterations was intended.
            loss = loss / gc
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

        # average loss and error over the epoch
        train_loss /= len(train_dataset_loader)
        train_error /= len(train_dataset_loader)

        print('\nEpoch: {}, train_loss: {:.4f}, train_error: {:.4f}'.format(epoch, train_loss, train_error))
        for i in range(n_classes):
            acc, correct, count = acc_logger.get_summary(i)
            print('class {}: acc {}, correct {}/{}'.format(i, acc, correct, count))

        # VALIDATION
        stop, val_loss_epoch, auc_epoch, auc_separated = validate(model, val_dataset_loader, n_classes, loss_fn, results_dir)

        # checkpoint after every epoch; the metrics are part of the file name
        torch.save(model.state_dict(), os.path.join(checkpoint_dir,
                                                    f"trainloss_{np.round(train_loss,3)}_valloss_{np.round(val_loss_epoch,3)}_auc_{np.round(auc_epoch,3)}_"\
                                                    +'_'.join(auc_separated)+"_checkpoint.pt"))

        train_loss_all_epoch.append(train_loss)
        val_loss_all_epoch.append(val_loss_epoch)
        val_auc_all_epoch.append(auc_epoch)

    # training parameters and per-epoch curves handed back to the caller
    param_dict = {'num_epochs': num_epochs,
                  'lr': lr,
                  'weight_decay': weight_decay,
                  'train_loss_all_epoch': train_loss_all_epoch,
                  'val_loss_all_epoch': val_loss_all_epoch,
                  'val_auc_all_epoch': val_auc_all_epoch}

    return param_dict

In [None]:
from sklearn.metrics import roc_auc_score, roc_curve

def plot_roc(y_true, y_pred):
    """Plot one ROC curve per class column of ``y_pred`` and show the figure.

    Parameters
    ----------
    y_true : np.ndarray
        Either an indicator matrix with the same shape as ``y_pred`` or an
        array of integer-valued class labels, which is then one-hot encoded.
    y_pred : np.ndarray
        Scores with one column per class.
    """
    n_classes = y_pred.shape[-1]

    if y_pred.shape != y_true.shape:
        # One-hot encode the labels. The width follows y_pred instead of a
        # fixed 5, so any number of classes works.
        class_ids = np.arange(n_classes)
        y_true = (np.asarray(y_true).astype(np.int64)[:, None] == class_ids).astype(np.int64)

    plt.figure(figsize=(6, 6))

    for class_ind in range(n_classes):
        fpr, tpr, _ = roc_curve(y_true[:, class_ind], y_pred[:, class_ind])
        auc = roc_auc_score(y_true[:, class_ind], y_pred[:, class_ind])
        plt.plot(fpr, tpr, '-', label='AUC : %.3f, label : %d' % (auc, class_ind))
    plt.legend()
    plt.show()

In [None]:
def validate(model, loader, n_classes, loss_fn = None, results_dir=None):
    """Evaluate ``model`` on ``loader`` and report loss, error and AUC.

    The loader is expected to yield one ``(data, label)`` pair per batch
    with a single sample each: the label is read with ``.item()`` and one
    row of the probability matrix is filled per batch.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(data)``; must return a 5-tuple whose first three
        items are ``logits, Y_prob, Y_hat``.
    loader : iterable with ``len()``
        Validation data loader.
    n_classes : int
        Number of classes.
    loss_fn : callable, optional
        Called as ``loss_fn(logits, label)``. Defaults to
        ``nn.CrossEntropyLoss()`` when ``None``.
    results_dir : str, optional
        Not used by this function.

    Returns
    -------
    tuple
        ``(False, val_loss, auc, auc_separated)`` where ``val_loss`` is the
        mean loss, ``auc`` the overall AUC (class-1 scores for two classes,
        one-vs-rest otherwise) and ``auc_separated`` a list with the
        per-class one-vs-rest AUC, rounded to 3 decimals, as strings.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # a missing loss function would otherwise fail on the first batch
    if loss_fn is None:
        loss_fn = nn.CrossEntropyLoss()

    model.eval()
    acc_logger = Accuracy_Logger(n_classes=n_classes)
    val_loss = 0.
    val_error = 0.

    prob = np.zeros((len(loader), n_classes))
    labels = np.zeros(len(loader))

    with torch.no_grad():
        for batch_idx, batch in enumerate(loader):

            data, label = batch
            # drop the batch axis added by the loader
            label = label.squeeze(0)
            data = data.to(device, dtype=torch.float, non_blocking=True)
            label = label.to(device, dtype=torch.long, non_blocking=True)

            logits, Y_prob, Y_hat, _, _ = model(data)
            acc_logger.log(Y_hat, label)

            loss = loss_fn(logits, label)

            prob[batch_idx] = Y_prob.cpu().numpy()
            labels[batch_idx] = label.item()

            val_loss += loss.item()
            val_error += calculate_error(Y_hat, label)

    val_error /= len(loader)
    val_loss /= len(loader)

    if n_classes == 2:
        auc = roc_auc_score(labels, prob[:, 1])
    else:
        auc = roc_auc_score(labels, prob, multi_class='ovr')

    # Per-class one-vs-rest AUC. Computed for every n_classes so that the
    # returned list is always defined (it used to be missing for 2 classes).
    auc_separated = []
    for class_ind in range(n_classes):
        is_class = (labels == class_ind).astype(np.int64)
        auc_current = np.round(roc_auc_score(is_class, prob[:, class_ind]), 3)
        auc_separated.append(str(auc_current))

    print('\nVal Set, val_loss: {:.4f}, val_error: {:.4f}, auc: {:.4f}'.format(val_loss, val_error, auc))

    for i in range(n_classes):
        acc, correct, count = acc_logger.get_summary(i)
        print('class {}: acc {}, correct {}/{}'.format(i, acc, correct, count))

    # print ROC curve
    print(labels.shape, prob.shape)
    plot_roc(labels, prob)

    return False, val_loss, auc, auc_separated

In [None]:
class Accuracy_Logger(object):
    """Per-class tally of seen samples and correctly predicted samples."""

    def __init__(self, n_classes):
        super().__init__()
        self.n_classes = n_classes
        self.initialize()

    def initialize(self):
        """Reset the counters of every class to zero."""
        self.data = [dict(count=0, correct=0) for _ in range(self.n_classes)]

    def log(self, Y_hat, Y):
        """Record a single prediction ``Y_hat`` for true class ``Y``."""
        predicted, target = int(Y_hat), int(Y)
        entry = self.data[target]
        entry["count"] += 1
        entry["correct"] += (predicted == target)

    def log_batch(self, Y_hat, Y):
        """Record a batch of predictions, grouped by true class."""
        predictions = np.array(Y_hat).astype(int)
        targets = np.array(Y).astype(int)
        for label_class in np.unique(targets):
            selected = targets == label_class
            entry = self.data[label_class]
            entry["count"] += selected.sum()
            entry["correct"] += (predictions[selected] == targets[selected]).sum()

    def get_summary(self, c):
        """Return ``(accuracy, correct, count)`` for class ``c``.

        The accuracy is ``None`` when no sample of that class was logged.
        """
        entry = self.data[c]
        count, correct = entry["count"], entry["correct"]
        acc = None if count == 0 else float(correct) / count
        return acc, correct, count

In [None]:
def calculate_error(Y_hat, Y):
    """Return the fraction of entries where ``Y_hat`` differs from ``Y``.

    Both tensors are compared as floats; the result is a Python float.
    """
    matches = Y_hat.float().eq(Y.float())
    accuracy = matches.float().mean().item()
    return 1. - accuracy

### Training - test with CV

In [None]:
def seed_torch(seed=7):
    """Seed Python, NumPy and PyTorch and switch cuDNN to deterministic mode.

    Also sets the ``PYTHONHASHSEED`` environment variable to ``str(seed)``.
    The CUDA generators are only seeded when CUDA is available.
    """
    import random

    os.environ['PYTHONHASHSEED'] = str(seed)

    # the CPU-side generators all take the same seed
    for set_seed in (random.seed, np.random.seed, torch.manual_seed):
        set_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed) # covers the multi-GPU case

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
# Folder that holds one .npz file per biopsy.
# Alternative input folder used earlier (files with array of shape Mx192):
#biopsy_embeddings_folder = '/home/ngsci/project/nightingale_breast/Preprocessing/hipt_stage3_biopsy_bag_inputs/'
biopsy_embeddings_folder = '/home/ngsci/project/nightingale_breast_working_development_directory/Preprocessing/biopsy_embeddings/biopsy_bag_vit_xs_embeddings_two-nightingale-finetuned-vits_level1/'

# Sorted full paths of every directory entry whose name contains '.npz'.
file_paths = np.array(
    sorted(
        biopsy_embeddings_folder + entry
        for entry in os.listdir(biopsy_embeddings_folder)
        if '.npz' in entry
    )
)

In [None]:
def read_cv_data(data_df, embeddings_folder=None):
    """Load the embedding bag and label of every biopsy listed in ``data_df``.

    For each value in ``data_df.biopsy_id`` the file ``<biopsy_id>.npz`` is
    read from the embeddings folder; it must contain the arrays
    ``embedding`` and ``label``. Files whose embedding array is empty are
    skipped with a message. Bags with 15000 or more rows are randomly
    subsampled to 15000 rows using a fixed seed (23).

    Parameters
    ----------
    data_df : pandas.DataFrame
        Must provide a ``biopsy_id`` column with the file names (without
        the ``.npz`` extension).
    embeddings_folder : str, optional
        Folder to read from. Defaults to the module-level
        ``biopsy_embeddings_folder``.

    Returns
    -------
    embeddings : np.ndarray
        1-D object array holding one embedding array per kept biopsy.
    labels : np.ndarray
        Array with the label of every kept biopsy, in the same order.
    """
    if embeddings_folder is None:
        embeddings_folder = biopsy_embeddings_folder

    embeddings_all = []
    labels_all = []

    for filename in data_df.biopsy_id.values:
        file_path = os.path.join(embeddings_folder, f"{filename}.npz")

        # the context manager closes the underlying file once both arrays are read
        with np.load(file_path) as data:
            embeddings = data['embedding']
            labels = data['label']

        # skip empty files
        if embeddings.size == 0:
            print(f"Skipping empty file: {file_path}")
            continue

        if embeddings.shape[0] >= 15000:
            # NOTE: reseeds NumPy's global generator, so every large bag is
            # subsampled with the same fixed seed
            np.random.seed(23)
            rand_idx = np.random.permutation(embeddings.shape[0])

            print(f"Embedding>15k, subsampling...: {embeddings.shape[0]}, label: {labels}")
            embeddings = embeddings[rand_idx[:15000]]

        embeddings_all.append(embeddings)
        labels_all.append(labels)

    # Fill a pre-allocated 1-D object array element by element. Converting
    # the list directly would build an N-D array whenever all bags happen
    # to share the same shape.
    bags = np.empty(len(embeddings_all), dtype=object)
    for position, bag in enumerate(embeddings_all):
        bags[position] = bag

    return bags, np.array(labels_all)

In [None]:
# Number of cross-validation folds; the training loop below iterates over range(folds_nr).
folds_nr = 5

In [None]:
# Run one full training per cross-validation fold: fresh seeds, fresh model,
# fold-specific train/val splits, then store the curves as JSON and plot them.
for i in range(folds_nr):
    print(f'\n ############################ CV-Fold {i} - Balanced training ############################')
    seed_torch()

    print('\nInit loss function...', end=' ')
    loss_fn = nn.CrossEntropyLoss()
    print('Done!')

    print('\nInit Model...', end=' ')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = HIPT_LGP_FC_STAGE3ONLY().to(device)
    print('Done!')

    # fold-specific split definitions
    train_df = pd.read_csv(f'cv_splits_multi_stratified/train_split_multi_stratified_{i}.csv')
    val_df = pd.read_csv(f'cv_splits_multi_stratified/val_split_multi_stratified_{i}.csv')

    X_train_all, y_train_all = read_cv_data(train_df)
    X_val_all, y_val_all = read_cv_data(val_df)

    # run configuration
    n_classes = 5
    results_dir = './runs/two-nightingale-finetuned-vits_level1/checkpoints_multi_strat_with_test_set_run7/'
    num_epochs = 70
    os.makedirs(results_dir, exist_ok=True)

    param_dict = train_loop(i, X_train_all, y_train_all, X_val_all, y_val_all,
                            results_dir, num_epochs, model, n_classes, loss_fn, gc=32)

    # hyper-parameters are copied as they are, the curves get fold-level
    # names, and the best validation values are added
    json_data = {key: param_dict[key] for key in ('num_epochs', 'lr', 'weight_decay')}
    json_data['train_loss_all_epoch_cv'] = param_dict['train_loss_all_epoch']
    json_data['val_loss_all_epoch_all_cv'] = param_dict['val_loss_all_epoch']
    json_data['val_auc_all_epoch_all_cv'] = param_dict['val_auc_all_epoch']
    json_data['min_val_loss'] = np.min(param_dict['val_loss_all_epoch'])
    json_data['max_val_auc'] = np.max(param_dict['val_auc_all_epoch'])

    # Save training parameters to disk
    with open(f"{results_dir}cv_{i}/test_params.json", 'w') as file:
        json.dump(json_data, file)

    # left panel: loss curves, right panel: validation AUC
    fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

    ax1.plot(json_data['train_loss_all_epoch_cv'], "-b", label="train_loss")
    ax1.plot(json_data['val_loss_all_epoch_all_cv'], "-r", label="val_loss")
    ax1.set_xlabel('epoch #')
    ax1.legend()

    ax2.plot(json_data['val_auc_all_epoch_all_cv'], "-g", label="val_auc")
    ax2.set_xlabel('epoch #')
    ax2.legend()

    plt.tight_layout()
    plt.show()

## SAVE CODES WITH CHECKPOINTS

In [None]:
# Shell command: recursively zip this notebook and the model definition file into an archive under ./runs.
!zip -r ./runs/two-nightingale-finetuned-vits_level1/checkpoints_multi_strat_with_test_set_run7_codes.zip ./hipt_stage3_training_balanced_folds_cross_val_vit_level1.ipynb  ./model_hierarchical_mil_stage3_vit_level1.py