# Training a MIL model for cancer detection in Whole Slide Images

One application of Multiple Instance Learning is the detection of cancerous tissue in Whole Slide Images (WSIs). A WSI is obtained by "digitally converting a tissue on a glass slide into a high-resolution virtual slide" ([NIH](https://pmc.ncbi.nlm.nih.gov/articles/PMC7522141/)). These slides can reach gigapixel sizes, which makes processing them as a single image computationally infeasible. A possible solution is to divide the slide into <tt>patches</tt> and treat the patches individually. We can then view a WSI as a bag in which each patch is an instance, which places us in the Multiple Instance Learning scenario.
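
To make the bag-of-patches view concrete, here is a minimal sketch of tiling an image tensor into non-overlapping patches. It is an illustration only: the helper `tile_into_patches` and the toy image are hypothetical, and this is not how the dataset used in this tutorial was produced. It also assumes that the image height and width are multiples of the patch size.

```python
import torch

def tile_into_patches(image: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Split a (C, H, W) image into non-overlapping square patches.

    Returns a tensor of shape (N, C, patch_size, patch_size), where N is the
    number of patches. H and W are assumed to be multiples of patch_size.
    """
    c, h, w = image.shape
    assert h % patch_size == 0 and w % patch_size == 0
    # Unfold the height, then the width: (C, nH, nW, patch_size, patch_size)
    patches = image.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
    # Move the patch grid to the front and flatten it into a single instance axis
    return patches.permute(1, 2, 0, 3, 4).reshape(-1, c, patch_size, patch_size)

# Toy "slide": 3 channels, 8 x 8 pixels, split into 4 x 4 patches
slide = torch.arange(3 * 8 * 8, dtype=torch.float32).reshape(3, 8, 8)
bag = tile_into_patches(slide, patch_size=4)
print(bag.shape)  # torch.Size([4, 3, 4, 4])
```

Each row of `bag` is one instance. Real WSI pipelines typically also discard background patches, which this sketch does not attempt.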

In the following, we explain how to train a Multiple Instance Learning (MIL) model to detect cancerous WSIs and localize the cancerous tissue using the <tt>torchmil</tt> library.

In [None]:
import sys
sys.path.append('../../torchmil/')
import torch


## The dataset

!!! example "MIL binary classification"
    In this case, the bags have the form $\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times D}$, where each $\mathbf{x}_n \in \mathbb{R}^D$ is an instance. 
    The labels of the instances are $\mathbf{y} = \left[ y_1, \ldots, y_N \right]^\top \in \{0, 1\}^N$, but we do not have access to them at training time (they may be accessible at test time). The label of the bag is $Y \in \{0, 1\}$, and the relation between the instance labels and the bag label is as follows:

    $$ Y = \max \left\{ y_1, \ldots, y_N \right\} $$

    This example is the most common in MIL, but there are many other possibilities. 
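
The relation between instance labels and the bag label can be checked on a few toy bags. The instance labels below are made up for illustration; in a real MIL dataset they are not available at training time.

```python
import torch

# Hypothetical instance labels for three bags of different sizes
bags_instance_labels = [
    torch.tensor([0, 0, 0, 0]),     # no positive instance
    torch.tensor([0, 1, 0]),        # one positive instance
    torch.tensor([1, 1, 0, 0, 1]),  # several positive instances
]

# Y = max{y_1, ..., y_N}: a bag is positive iff it has at least one positive instance
bag_labels = torch.stack([y.max() for y in bags_instance_labels])
print(bag_labels)  # tensor([0, 1, 1])
```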

For this tutorial, we will use the Camelyon16 dataset, a public dataset for the detection of breast cancer metastasis. The [original version](https://camelyon17.grand-challenge.org/Data/) of this dataset has been processed to be used for MIL binary classification problems. It can be downloaded from [here](https://huggingface.co/datasets/Franblueee/Camelyon16_MIL/). 


In practice, training a MIL model directly on the raw patches is computationally intractable. For this reason, MIL models usually operate on pre-computed features extracted from each instance. Although <tt>torchmil</tt> allows defining models that receive the original patches as input, in this tutorial we will use pre-computed features, since this is the most common way of alleviating the computational barrier. To facilitate your MIL tasks, the processed version of CAMELYON16 includes features extracted from all the patches using a <tt>Resnet50</tt> trained with self-supervised learning ([Barlow Twins method](https://arxiv.org/abs/2103.03230)) and the **foundation model** <tt>UNI</tt> (check this very cool model [here](https://huggingface.co/MahmoodLab/UNI)).

We now use `torchmil.datasets.CAMELYON16MILDataset` to create an object that behaves as a `torch.utils.data.Dataset` and contains the CAMELYON16 bags. You only need to provide the `root` of the processed dataset, the `patch_size`, the desired feature extractor (`features`), and the `partition`. Also, when working with <tt>WSI</tt> datasets, you may be interested in the relative positioning of the patches. When loading the dataset, <tt>torchmil</tt> also loads the adjacency matrix and offers two additional parameters to refine it: `adj_with_dist`, which builds the adjacency matrix using the Euclidean distance between the patch features, and `norm_adj`, which normalizes the adjacency matrix. In this example we will not use the adjacency matrix.

See how simple it is to instantiate the training dataset, using the <tt>UNI</tt> features and patches of size $512 \times 512$:

In [3]:
from torchmil.datasets import CAMELYON16MILDataset
from sklearn.model_selection import train_test_split

root = '/data/datasets/CAMELYON16/'
features = 'UNI'
patch_size = 512

dataset        = CAMELYON16MILDataset(  root        = root,
                                        features    = features,
                                        patch_size  = patch_size,
                                        partition   = 'train')
            
# Split the dataset into train and validation sets
bag_labels = dataset.get_bag_labels()
idx = list(range(len(bag_labels)))
val_prop = 0.2
idx_train, idx_val = train_test_split(idx, test_size=val_prop, random_state=1234, stratify=bag_labels)
train_dataset = dataset.subset(idx_train)
val_dataset = dataset.subset(idx_val)

test_dataset   = CAMELYON16MILDataset( root        = root,
                                       features    = features,
                                       patch_size  = patch_size,  # same patch size as the training set
                                       partition   = 'test')

In <tt>torchmil</tt>, each bag is a `TensorDict` whose keys correspond to the different elements of the bag. In this case, each bag has a feature matrix `X`, the bag label `Y`, and the instance labels `y_inst`. A patch is considered positive if at least $50\%$ of its pixels are cancerous. Recall that the instance labels cannot be used during training; they are available only for evaluation purposes.
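
The $50\%$ rule for instance labels can be sketched as follows. This is an illustration under assumptions, not the code used to build the dataset: the helper `patch_label` and the toy pixel-level annotations are hypothetical.

```python
import torch

def patch_label(tumor_mask: torch.Tensor, threshold: float = 0.5) -> int:
    """Return 1 if at least `threshold` of the pixels in the patch are cancerous."""
    cancerous_fraction = tumor_mask.float().mean().item()
    return int(cancerous_fraction >= threshold)

# Toy 4 x 4 pixel-level annotations (1 = cancerous pixel)
mostly_tumor = torch.tensor([[1, 1, 1, 1],
                             [1, 1, 1, 0],
                             [1, 1, 0, 0],
                             [1, 0, 0, 0]])    # 10 / 16 = 62.5% cancerous
mostly_normal = torch.zeros(4, 4, dtype=torch.long)
mostly_normal[0, :3] = 1                        # 3 / 16 = 18.75% cancerous

print(patch_label(mostly_tumor), patch_label(mostly_normal))  # 1 0
```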

## Mini-batching of bags

Typically, the bags in a MIL dataset have different sizes, which is a problem when creating mini-batches. To solve this, we use the function `collate_fn` from the [<tt><b>torchmil.data</b></tt>](../api/data/index.md) module. This function creates a mini-batch of bags by padding the bags with zeros to the size of the largest bag in the batch. It also returns a mask tensor that indicates which instances are real and which are padding.

!!! question "Why not use [`torch.nested`](https://pytorch.org/docs/stable/nested.html)?"
    `torch.nested` offers a more flexible method for handling bags of varying sizes. However, since the PyTorch API for nested tensors is still in the prototype stage, <tt><b>torchmil</b></tt> currently relies on the padding approach.
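
The padding strategy described above can be sketched in plain PyTorch. The helper `pad_bags` below is hypothetical and is not the `torchmil` implementation of `collate_fn`, which handles `TensorDict` bags and more keys; the boolean mask convention is also an assumption of this sketch.

```python
import torch

def pad_bags(bags: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
    """Zero-pad bags of shape (N_i, D) to a common length.

    Returns X of shape (B, N_max, D) and a boolean mask of shape (B, N_max)
    that is True for real instances and False for padding.
    """
    n_max = max(bag.shape[0] for bag in bags)
    dim = bags[0].shape[1]
    X = torch.zeros(len(bags), n_max, dim)
    mask = torch.zeros(len(bags), n_max, dtype=torch.bool)
    for i, bag in enumerate(bags):
        X[i, : bag.shape[0]] = bag      # copy the real instances
        mask[i, : bag.shape[0]] = True  # mark them as real
    return X, mask

# Three toy bags with 3, 5 and 2 instances of dimension 8
bags = [torch.randn(3, 8), torch.randn(5, 8), torch.randn(2, 8)]
X, mask = pad_bags(bags)
print(X.shape, mask.shape)  # torch.Size([3, 5, 8]) torch.Size([3, 5])
print(mask.sum(dim=1))      # tensor([3, 5, 2])
```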

Let's create the dataloaders and inspect the shape of a mini-batch. With a patch size of $512 \times 512$, some of the bags in CAMELYON16 contain more than $20{,}000$ instances. Because of this, we need to use a small `batch_size` so that a batch fits in the memory of a standard <tt>GPU</tt>.

In [4]:
from torchmil.data import collate_fn

batch_size = 4

# Create dataloaders
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)


it = iter(train_dataloader)
batch = next(it)
data_shape = (batch['X'].shape[-1], )
print("Batch X shape: ", batch['X'].shape)
print("Batch Y shape: ", batch['Y'].shape)
print("Batch y_inst shape: ", batch['y_inst'].shape)
print("Batch mask shape: ", batch['mask'].shape)


Batch X shape:  torch.Size([4, 29946, 1024])
Batch Y shape:  torch.Size([4])
Batch y_inst shape:  torch.Size([4, 29946])
Batch mask shape:  torch.Size([4, 29946])


Each batch is again a `TensorDict`, now with an additional key `mask` that indicates which instances are real and which are padding. As the shapes show, every bag is zero-padded to the size of the largest bag in the batch. The function `collate_fn` also pads other tensors, such as the adjacency matrix or the instance coordinates.
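
To see why the mask matters, consider a simple mean pooling over the instances of a padded bag. The sketch below is illustrative only: `masked_mean` is a hypothetical helper, not a `torchmil` function, and the tutorial's model uses attention pooling rather than a mean.

```python
import torch

def masked_mean(X: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Average instance features over real instances only.

    X: (B, N, D) padded features; mask: (B, N), nonzero for real instances.
    """
    weights = mask.to(X.dtype).unsqueeze(-1)    # (B, N, 1)
    n_real = weights.sum(dim=1).clamp(min=1.0)  # (B, 1), avoids dividing by zero
    return (X * weights).sum(dim=1) / n_real    # (B, D)

# One bag with 2 real instances, padded to length 4
X = torch.tensor([[[2.0, 4.0], [4.0, 8.0], [0.0, 0.0], [0.0, 0.0]]])
mask = torch.tensor([[1, 1, 0, 0]])

print(X.mean(dim=1))         # naive mean, diluted by padding: tensor([[1.5000, 3.0000]])
print(masked_mean(X, mask))  # tensor([[3., 6.]])
```

Ignoring the mask would make the bag representation depend on how much padding the batch happened to need.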

## Training a model in CAMELYON16

We have shown how to load the CAMELYON16 dataset for the binary classification task. Now, let us train a MIL model on this dataset! For this example, we will use the <tt>torchmil</tt> implementation of [Transformer ABMIL](../api/models/transformer_abmil.md), a version of [ABMIL](../api/models/abmil.md) in which a [Transformer Encoder](../api/nn/transformers/conventional_transformer.md) refines the instance representations before the [Attention Pool](../api/nn/attention/attention_pool.md). To highlight how simple it is to instantiate a model in <tt>torchmil</tt>, we will leave all the parameters at their default values except for `in_shape`, which reflects the data shape. Feel free to check the [documentation of Transformer ABMIL](../api/models/transformer_abmil.md) to see the different parameters that this model accepts.

In [5]:
from torchmil.models import TransformerABMIL
model = TransformerABMIL( in_shape = data_shape)

See? It cannot get any easier! Now, let's train the model. <tt>torchmil</tt> offers an easy-to-use trainer class, `torchmil.utils.trainer.Trainer`, that provides a generic training loop for any MIL model. It also shows the evolution of the losses and the chosen metrics over the epochs.

!!! note
    This `Trainer` gives you the flexibility to log the results using any wrapped `logger`, use annealing for the loss functions via the `annealing_scheduler_dict` dictionary, or set a learning rate scheduler using the parameter `lr_scheduler`. You can also track multiple metrics during training thanks to the parameter `metrics_dict` and the integration with the <tt>torchmetrics</tt> package.

For now, let us keep it simple and train with the `torch.optim.Adam` optimizer. When using the features from the <tt>UNI</tt> model, <tt>torchmil</tt> models obtain very good results in just a few epochs, so we will train the model for only 2 epochs. First, we instantiate the trainer, and then we train the model. We hide the training output using `verbose = False` and disable the progress bars using `disable_pbar = True`.

!!! note
    Transformers are computationally expensive models, so training this model for two epochs takes approximately `20` minutes on our <tt>CPUs</tt>.

In [6]:
from torchmil.utils.trainer import Trainer
import torchmetrics
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
device = 'cpu'

trainer = Trainer(  model        = model,
                    optimizer    = optimizer,
                    metrics_dict = {'auroc' : torchmetrics.AUROC(task='binary').to(device), 'acc': torchmetrics.Accuracy(task='binary').to(device)},
                    obj_metric   = 'acc',
                    device       = device,
                    disable_pbar = True,
                    verbose      = False, )


In [7]:
EPOCHS = 2
trainer.train( max_epochs       = EPOCHS,
               train_dataloader = train_dataloader,
               val_dataloader   = val_dataloader,
               test_dataloader  = test_dataloader)



!!! question "Why is the first epoch much slower?"
    If you set `disable_pbar = False`, you may observe that the first epoch is much slower than the rest of them. The first time the dataloader indexes a bag, that bag is loaded from a <tt>.npy</tt> file. Thus, if the data is not used before the training, during the first epoch the bags are loaded from the hard drive to the computer's memory, causing a delay in the training.


## Evaluating the model

Let's evaluate the model. We are going to compute the accuracy and the F1-score on the test set. The accuracy is the proportion of correctly classified bags, while the F1-score is the harmonic mean of precision and recall, which makes it a good metric for imbalanced datasets.
Typically, in MIL datasets, there are many more negative instances than positive instances. In this case, the F1-score will be very useful.

First, we define some auxiliary functions.

In [8]:
def accuracy(pred, y):
    return (pred == y).sum().item() / len(y)

def f1_score(pred, y):
    tp = ((pred == 1) & (y == 1)).sum().item()
    fp = ((pred == 1) & (y == 0)).sum().item()
    fn = ((pred == 0) & (y == 1)).sum().item()
    precision = tp / (tp + fp + 1e-6)
    recall = tp / (tp + fn + 1e-6)
    f1 = 2 * precision * recall / (precision + recall + 1e-6)
    return f1
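
As a quick sanity check of why both metrics are reported, the sketch below applies the same two definitions to made-up, imbalanced labels and a trivial predictor that always answers "negative". The data is hypothetical and only serves to illustrate the behavior of the metrics.

```python
import torch

def accuracy(pred, y):
    return (pred == y).sum().item() / len(y)

def f1_score(pred, y, eps=1e-6):
    tp = ((pred == 1) & (y == 1)).sum().item()
    fp = ((pred == 1) & (y == 0)).sum().item()
    fn = ((pred == 0) & (y == 1)).sum().item()
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    return 2 * precision * recall / (precision + recall + eps)

# Toy imbalanced labels: 9 negatives, 1 positive
y = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1])
always_negative = torch.zeros_like(y)

print(accuracy(always_negative, y))  # 0.9
print(f1_score(always_negative, y))  # 0.0
```

The trivial predictor looks good in terms of accuracy, but its F1-score reveals that it never finds the positive class.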

Now we can easily obtain the model's performance on the test set:

In [9]:
Y_pred_list = []
Y_list = []

model.eval()

# Disable gradient tracking: evaluation only needs forward passes
with torch.no_grad():
    for batch in test_dataloader:
        batch = batch.to(device)

        # Predict the bag label by thresholding the model output at 0
        out = model(batch['X'], batch['mask'])
        Y_pred = (out > 0).float()

        Y_pred_list.append(Y_pred)
        Y_list.append(batch['Y'])

Y_pred = torch.cat(Y_pred_list)
Y = torch.cat(Y_list)

print(f"test/bag/acc: {accuracy(Y_pred, Y)}")
print(f"test/bag/f1: {f1_score(Y_pred, Y)}")

test/bag/acc: 0.9612403100775194
test/bag/f1: 0.9462360402361285


Excellent! Our model has reached a very high accuracy and F1-score in only 2 epochs! This shows how simple it is to obtain very good results on <tt>Camelyon16</tt>, one of the most famous <tt>WSI</tt> classification datasets, thanks to <tt>torchmil</tt>!
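
A note on the thresholding used in the evaluation loop: assuming the model outputs bag-level logits, as the `out > 0` test suggests, thresholding the logit at 0 is the same as thresholding the sigmoid probability at 0.5. The logits below are made up for illustration.

```python
import torch

# Hypothetical bag-level logits, standing in for the model output
logits = torch.tensor([-2.0, -0.1, 0.0, 0.3, 4.0])

pred_from_logits = (logits > 0).float()
pred_from_probs = (torch.sigmoid(logits) > 0.5).float()

print(pred_from_logits)                                # tensor([0., 0., 0., 1., 1.])
print(torch.equal(pred_from_logits, pred_from_probs))  # True
```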