# Datasets

When implementing a MIL model, the first step is to define a `MILDataset` to apply the model to. In this tutorial, we show how MIL datasets are implemented in <tt>torchmil</tt>.

## Required data and hard drive organization

In a <tt>torchmil</tt> dataset, each bag is composed of four elements:

- `features`: an array of shape `(bag_size, feature_dim)` with the features of the bag. Usually, the `features` are the result of applying a feature extractor to the patches of the slide.
- `label`: an array of shape `()` containing the bag label.
- `patch_labels`: an array of shape `(bag_size,)` containing the label of each instance.
- `coords`: an array of shape `(bag_size, coords_dim)` containing the coordinates of each patch in the bag.
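
As an illustration, the four elements of a single bag can be mocked with NumPy arrays. The sizes below (6 patches, 1024-dimensional features, 2-D coordinates) are arbitrary choices for this sketch, not values required by <tt>torchmil</tt>.

```python
import numpy as np

# Arbitrary sizes for this sketch: 6 patches with 1024-dim features
bag_size, feature_dim = 6, 1024
rng = np.random.default_rng(0)

features = rng.standard_normal((bag_size, feature_dim)).astype(np.float32)
label = np.array(1)                              # bag label, shape ()
patch_labels = np.array([0, 0, 1, 0, 0, 0])      # one label per instance, shape (bag_size,)
coords = np.array([[0, 0], [0, 1], [1, 0], [1, 1], [2, 0], [2, 1]])  # (bag_size, coords_dim)

print(features.shape, label.shape, patch_labels.shape, coords.shape)
# (6, 1024) () (6,) (6, 2)
```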

!!! note "Not everything is required"
    Even though the implementation handles `features`, `labels`, `patch_labels` and `coords`, not all of them are strictly required to build a dataset. For instance, if the `patch_labels` are not available, the dataset can still be used with the rest of the elements.
    

The first step is to organize the files on disk in a way that is compatible with the implementation of `ProcessedMILDataset`, the base class used to implement new datasets. One folder is required for each element. Assuming that the patches have size `patch_size` and that the features were extracted using `feature_extractor`, the following folder structure is required:
```
{dataset_name}
├── patches_{patch_size}
│   ├── features
│   │   ├── features_{feature_extractor}
│   │   │   ├── bag1.npy
│   │   │   ├── bag2.npy
│   │   │   └── ...
│   ├── labels
│   │   ├── bag1.npy
│   │   ├── bag2.npy
│   │   └── ...
│   ├── patch_labels
│   │   ├── bag1.npy
│   │   ├── bag2.npy
│   │   └── ...
│   ├── coords
│   │   ├── bag1.npy
│   │   ├── bag2.npy
│   │   └── ...
```
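
A minimal sketch of how a bag could be written to this layout with `np.save`. The helper `save_bag` is hypothetical (it is not part of <tt>torchmil</tt>), and the dataset name, patch size, and extractor name used below are placeholders.

```python
import tempfile
from pathlib import Path

import numpy as np


def save_bag(root, patch_size, feature_extractor, bag_name,
             features, label, patch_labels, coords):
    """Hypothetical helper: store one bag as one .npy file per element."""
    base = Path(root) / f"patches_{patch_size}"
    targets = {
        base / "features" / f"features_{feature_extractor}": features,
        base / "labels": label,
        base / "patch_labels": patch_labels,
        base / "coords": coords,
    }
    for folder, array in targets.items():
        folder.mkdir(parents=True, exist_ok=True)
        np.save(folder / f"{bag_name}.npy", array)  # file named after the bag


# Usage: write two small bags into a throwaway directory
root = Path(tempfile.mkdtemp()) / "my_dataset"
rng = np.random.default_rng(0)
for name in ["bag1", "bag2"]:
    n = 5
    save_bag(root, 512, "UNI", name,
             features=rng.standard_normal((n, 8)).astype(np.float32),
             label=np.array(0),
             patch_labels=np.zeros(n, dtype=int),
             coords=np.stack([np.arange(n), np.zeros(n, dtype=int)], axis=1))

print(sorted(p.name for p in (root / "patches_512" / "coords").iterdir()))
# ['bag1.npy', 'bag2.npy']
```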

!!! important
    As the folder structure shows, there must be one folder for each type of element (`features`, `labels`, `patch_labels`, `coords`), and one `.npy` file per bag inside each folder. Each `.npy` file must be named after the bag it represents.
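
Since a bag missing from one folder is easy to overlook, a quick consistency check can help before building the dataset. The function below is a hypothetical helper, not part of <tt>torchmil</tt>: it compares the `.npy` file names across folders and reports which bags are missing where.

```python
import tempfile
from pathlib import Path


def find_incomplete_bags(folders):
    """Hypothetical helper: map each folder to the bag names it is missing."""
    names = {Path(f): {p.stem for p in Path(f).glob("*.npy")} for f in folders}
    all_bags = set().union(*names.values())
    return {str(f): sorted(all_bags - present)
            for f, present in names.items() if all_bags - present}


# Usage: bag2 has features but no coords
root = Path(tempfile.mkdtemp())
for folder, bags in {"features": ["bag1", "bag2"], "coords": ["bag1"]}.items():
    (root / folder).mkdir()
    for bag in bags:
        (root / folder / f"{bag}.npy").touch()

print(find_incomplete_bags([root / "features", root / "coords"]))
# reports that bag2 is missing from the coords folder
```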

## The `ProcessedMILDataset` class

The `ProcessedMILDataset` class manages bag loading and builds the adjacency matrix (if `coords` are available). Only 2 parameters are required to initialize this class: `features_path` and `labels_path`, the paths to the respective folders. However, we recommend also passing `inst_labels_path` and `coords_path` (both paths), as well as `bag_names` (a list of all the bag names in the dataset). Another important parameter is `bag_keys`, which indicates which elements the `__getitem__` method returns (`X` are the `features`, `Y` the bag `label`, and `y_inst` the `patch_labels`). It defaults to all the options: `["X", "Y", "y_inst", "adj", "coords"]`.
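
To make the role of `bag_keys` concrete, here is a deliberately simplified, hypothetical loader (not <tt>torchmil</tt>'s implementation) that reads one `.npy` file per key and returns only the requested keys. It handles `X` and `Y` only; the adjacency matrix is left out.

```python
import tempfile
from pathlib import Path

import numpy as np


class TinyBagDataset:
    """Hypothetical loader illustrating bag_keys; not torchmil's ProcessedMILDataset."""

    def __init__(self, features_path, labels_path, bag_names, bag_keys=("X", "Y")):
        # Map each key to the folder that stores it (X: features, Y: bag labels)
        self.folders = {"X": Path(features_path), "Y": Path(labels_path)}
        self.bag_names = list(bag_names)
        self.bag_keys = list(bag_keys)

    def __len__(self):
        return len(self.bag_names)

    def __getitem__(self, index):
        name = self.bag_names[index]
        # Only the requested keys are read from disk
        return {key: np.load(self.folders[key] / f"{name}.npy") for key in self.bag_keys}


# Usage on a throwaway folder with a single bag
root = Path(tempfile.mkdtemp())
(root / "features").mkdir()
(root / "labels").mkdir()
np.save(root / "features" / "bag1.npy", np.ones((4, 8)))
np.save(root / "labels" / "bag1.npy", np.array(1))

dataset = TinyBagDataset(root / "features", root / "labels", ["bag1"], bag_keys=["X"])
bag = dataset[0]
print(list(bag.keys()), bag["X"].shape)  # ['X'] (4, 8)
```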

!!! note "First time loading the data in an execution"
    The `ProcessedMILDataset` class has an extra parameter, `load_at_init`, which defaults to `True`. With this option, the data is loaded into a cache when the dataset is instantiated. Setting `load_at_init=False` means that a bag is not read until it is requested by the `__getitem__` method.
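
The difference between the two settings can be sketched with a small cache. This is a hypothetical illustration that assumes each bag stays in memory after its first read; `load_fn` stands in for whatever reads a bag from disk.

```python
class CachedLoader:
    """Hypothetical sketch of eager (load_at_init=True) vs lazy (False) loading."""

    def __init__(self, load_fn, bag_names, load_at_init=True):
        self.load_fn = load_fn
        self.bag_names = list(bag_names)
        self.cache = {}
        if load_at_init:
            # Eager: every bag is read once, when the object is created
            for name in self.bag_names:
                self.cache[name] = load_fn(name)

    def __getitem__(self, index):
        name = self.bag_names[index]
        if name not in self.cache:
            # Lazy: the bag is read the first time it is requested
            self.cache[name] = self.load_fn(name)
        return self.cache[name]


calls = []


def fake_load(name):
    calls.append(name)  # record every simulated disk read
    return [0.0, 0.0, 0.0]


lazy = CachedLoader(fake_load, ["bag1", "bag2"], load_at_init=False)
print(calls)  # []
lazy[0]
lazy[0]
print(calls)  # ['bag1']
```

Eager loading pays the reading cost up front, while lazy loading spreads it over the first accesses and never reads bags that are not requested.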

The adjacency matrix `adj` of each bag is built when the bag is loaded for the first time. Several options control how `adj` is built; see the documentation of [<tt><b>torchmil.datasets.processed_mil_dataset</b></tt>](../api/datasets/processed_mil_dataset.md) for details.
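
For intuition, here is one common way to build such a matrix from `coords`: connect patches whose coordinates lie within a distance threshold and apply the symmetric normalization $D^{-1/2} A D^{-1/2}$. This is a sketch under assumptions (unit-spaced grid coordinates, a threshold `dist_thr=1.5` so that the 8 surrounding patches count as neighbors, no self-loops); `build_adjacency` and its parameters are hypothetical, and the exact construction in <tt>torchmil</tt> may differ.

```python
import numpy as np


def build_adjacency(coords, dist_thr=1.5, norm_adj=True):
    """Hypothetical sketch: binary adjacency between nearby patches, optionally normalized."""
    coords = np.asarray(coords, dtype=float)
    diff = coords[:, None, :] - coords[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))                # (N, N) pairwise distances
    adj = ((dist > 0) & (dist <= dist_thr)).astype(float)   # neighbors, no self-loops
    if norm_adj:
        deg = adj.sum(axis=1)
        inv_sqrt = np.zeros_like(deg)
        nonzero = deg > 0
        inv_sqrt[nonzero] = deg[nonzero] ** -0.5            # isolated patches keep a zero row
        adj = inv_sqrt[:, None] * adj * inv_sqrt[None, :]   # D^{-1/2} A D^{-1/2}
    return adj


coords = [[0, 0], [0, 1], [1, 0], [1, 1], [5, 5]]  # a 2x2 block plus one isolated patch
print(build_adjacency(coords, norm_adj=False))
# the 2x2 block is fully connected; the last patch has no neighbors
```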

### Extending the functionality of `ProcessedMILDataset`

As an additional feature, <tt>torchmil</tt> implements `BinaryClassificationDataset`, a subclass of `ProcessedMILDataset`. This subclass assumes that

$$
\begin{gather}
        Y \in \left\{ 0, 1 \right\}, \quad y_n \in \left\{ 0, 1 \right\}, \quad \forall n \in \left\{ 1, \ldots, N \right\},\\
        Y = \max \left\{ y_1, \ldots, y_N \right\}.
\end{gather}
$$

This class extends the functionality of `ProcessedMILDataset` by explicitly checking that the conditions above are fulfilled. If they are not, a warning is shown on the output stream. Check all the information about this class in [<tt><b>torchmil.datasets.binary_classification_dataset</b></tt>](../api/datasets/binary_classification_dataset.md).
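
The two conditions can be checked with a few lines of NumPy. The function `check_binary_bag` below is a hypothetical sketch of such a check, not the code of `BinaryClassificationDataset`; like the class, it emits a warning rather than raising an error. It assumes a non-empty bag.

```python
import warnings

import numpy as np


def check_binary_bag(bag_label, inst_labels):
    """Return True if Y and all y_n are in {0, 1} and Y = max(y_1, ..., y_N); warn otherwise."""
    Y = int(bag_label)
    y = np.asarray(inst_labels)
    if Y not in (0, 1) or not np.isin(y, (0, 1)).all():
        warnings.warn("Bag and instance labels must be binary (0 or 1).")
        return False
    if Y != int(y.max()):
        warnings.warn(f"Bag label {Y} differs from the max of the instance labels ({int(y.max())}).")
        return False
    return True


print(check_binary_bag(1, [0, 1, 0]))  # True
print(check_binary_bag(0, [0, 1, 0]))  # False, with a warning
```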


Also, the class `WSIDataset` assumes that the bags are Whole Slide Images (WSIs) and gives the coordinates of the patches (`coords`) special treatment by normalizing their values. Find more information about this class in [<tt><b>torchmil.datasets.wsi_dataset</b></tt>](../api/datasets/wsi_dataset.md).
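
One plausible normalization, shown here only as a hypothetical sketch (the exact transformation applied by `WSIDataset` may differ), converts the pixel coordinates of grid-aligned patches into integer grid positions, so that adjacent patches are one unit apart, and shifts them so that the top-left patch sits at the origin.

```python
import numpy as np


def normalize_coords(coords, patch_size):
    """Hypothetical: map grid-aligned pixel coordinates to integer patch positions."""
    coords = np.asarray(coords)
    grid = coords // patch_size      # adjacent patches now differ by exactly 1
    return grid - grid.min(axis=0)   # shift so the minimum position is (0, 0)


pixel_coords = [[1024, 2048], [1536, 2048], [1024, 2560]]
print(normalize_coords(pixel_coords, patch_size=512).tolist())
# [[0, 0], [1, 0], [0, 1]]
```

With unit-spaced positions like these, a fixed distance threshold (for example 1.5) identifies the 8 surrounding patches as neighbors regardless of the patch size.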


## Creating your own dataset

With this explanation, we are now ready to define a custom class to use <tt>torchmil</tt> with your own dataset. We will implement a WSI dataset using slides from the [Genotype-Tissue Expression (GTEx) Project](https://www.gtexportal.org/home/), which can be downloaded for free. In particular, we will use slides of <tt>UrinaryBladder</tt> tissue.

To create the dataset, we must first extract the `coords` of the patches from the original <tt>.tiff</tt> files and then extract `features` from those patches. To achieve that, a tool like [CLAM](https://github.com/mahmoodlab/CLAM) can be used. We will assume that no masks are provided, so we will not have access to `labels` or `inst_labels`. We have extracted the features using the foundation model [UNI](https://huggingface.co/MahmoodLab/UNI).

Then, creating the dataset is as simple as defining a new class that extends `WSIDataset` and passes the path of each folder, together with the list of bag names, to its constructor:


```python
from torchmil.datasets import WSIDataset
from torchmil.utils.common import read_csv, keep_only_existing_files


class GTExUrinaryBladderDataset(WSIDataset):
    r"""
    GTEx Urinary Bladder dataset.
    This dataset is an example of a dataset with no labels.
    """

    def __init__(
        self,
        root: str,
        features: str = 'UNI',
        partition: str = 'train',
        bag_keys: list = ["X", "Y", "y_inst", "adj", "coords"],
        patch_size: int = 512,
        adj_with_dist: bool = False,
        norm_adj: bool = True,
        load_at_init: bool = True
    ) -> None:
        """
        Arguments:
            root: Path to the root directory of the dataset.
            features: Type of features to use. Must be one of ['UNI'].
            partition: Partition of the dataset. Must be one of ['train', 'test']. Not used in this example.
            bag_keys: List of keys to use for the bags. Must be in ['X', 'Y', 'y_inst', 'adj', 'coords'].
            patch_size: Size of the patches. Currently, only 512 is supported.
            adj_with_dist: If True, the adjacency matrix is built using the Euclidean distance between the patches features. If False, the adjacency matrix is binary.
            norm_adj: If True, normalize the adjacency matrix.
            load_at_init: If True, load the bags at initialization. If False, load the bags on demand.
        """

        # Paths follow the folder structure described above
        features_path = f'{root}/patches_{patch_size}/features/features_{features}/'
        labels_path = f'{root}/patches_{patch_size}/labels/'
        patch_labels_path = f'{root}/patches_{patch_size}/patch_labels/'
        coords_path = f'{root}/patches_{patch_size}/coords/'

        # This csv is generated by CLAM, with slide_id containing "bag_name.format"
        bag_names_file = f'{root}/patches_{patch_size}/process_list_autogen.csv'
        dict_list = read_csv(bag_names_file)
        wsi_names = list(set([row['slide_id'].split('.')[0] for row in dict_list]))
        # Keep only the slides whose features were actually extracted
        wsi_names = keep_only_existing_files(features_path, wsi_names)

        WSIDataset.__init__(
            self,
            features_path=features_path,
            labels_path=labels_path,
            patch_labels_path=patch_labels_path,
            coords_path=coords_path,
            wsi_names=wsi_names,
            bag_keys=bag_keys,
            patch_size=patch_size,
            adj_with_dist=adj_with_dist,
            norm_adj=norm_adj,
            load_at_init=load_at_init
        )
```
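
The bag names in the class above come from the CSV that CLAM writes (`process_list_autogen.csv`), whose `slide_id` column holds `bag_name.format`. The same two steps (reading the names, then keeping only those with extracted features) can be approximated with the standard library. `read_bag_names` is a hypothetical helper, not a replacement for <tt>torchmil</tt>'s `read_csv` and `keep_only_existing_files`, and details such as sorting the result are assumptions of this sketch.

```python
import csv
import tempfile
from pathlib import Path


def read_bag_names(csv_path, features_path):
    """Hypothetical helper: bag names listed in a CLAM CSV that also have a features file."""
    with open(csv_path, newline="") as f:
        # slide_id looks like "bag_name.format"; keep the part before the first dot
        names = {row["slide_id"].split(".")[0] for row in csv.DictReader(f)}
    # Keep only the bags whose features were actually extracted
    return sorted(n for n in names if (Path(features_path) / f"{n}.npy").exists())


# Usage on a throwaway folder: two slides listed, features exist for only one
root = Path(tempfile.mkdtemp())
(root / "features").mkdir()
(root / "features" / "slideA.npy").touch()
(root / "process_list.csv").write_text("slide_id\nslideA.tiff\nslideB.tiff\n")
print(read_bag_names(root / "process_list.csv", root / "features"))  # ['slideA']
```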


We have now defined our new `GTExUrinaryBladderDataset` class, and we only have to specify the root path to instantiate it! Since this dataset has no instance labels, we request only the features `X` and the adjacency matrix `adj` as `bag_keys` (requesting `y_inst` would fail, since there are no instance labels to return). We use `load_at_init=False` so that the features of each slide are only loaded when needed.

```python
root = '/data/data_fjaviersaezm/GTExTorchmil/UrinaryBladder/'
dataset = GTExUrinaryBladderDataset(
    root=root,
    partition='train',
    features='UNI',
    bag_keys=['X', 'adj'],  # no 'y_inst': this dataset has no instance labels
    patch_size=512,
    adj_with_dist=False,
    norm_adj=True,
    load_at_init=False,
)
print(dataset.bag_names)
```

```
['GTEX-N7MS-2125', 'GTEX-N7MT-1825', 'GTEX-NFK9-2125', 'GTEX-NL3G-1925', 'GTEX-NL3H-2125', 'GTEX-NL4W-1125', 'GTEX-NPJ7-2625', 'GTEX-NPJ8-1125', 'GTEX-O5YT-1925', 'GTEX-O5YU-2125', 'GTEX-O5YV-2125', 'GTEX-OHPJ-1925', 'GTEX-OHPK-1925', 'GTEX-OHPL-1925', 'GTEX-OHPM-1925', 'GTEX-OHPN-1825', 'GTEX-OIZF-1925', 'GTEX-OIZG-1825', 'GTEX-OIZH-1925', 'GTEX-OIZI-1925', 'GTEX-OOBJ-1925', 'GTEX-OOBK-1925', 'GTEX-OXRK-1925', 'GTEX-OXRL-1925', 'GTEX-OXRN-2225', 'GTEX-OXRO-1125', 'GTEX-OXRP-1425', 'GTEX-P44G-2025', 'GTEX-P44H-2225', 'GTEX-P4PP-1925', 'GTEX-P4PQ-1925', 'GTEX-P4QR-2325', 'GTEX-P4QS-1925', 'GTEX-P4QT-1925', 'GTEX-P78B-2125', 'GTEX-PLZ4-2025', 'GTEX-PLZ5-1325', 'GTEX-PLZ6-1025', 'GTEX-POMQ-0925', 'GTEX-POYW-2225', 'GTEX-PSDG-2025', 'GTEX-PVOW-2425', 'GTEX-PW2O-1025', 'GTEX-PWCY-1125', 'GTEX-PWN1-1925', 'GTEX-PWO3-2725', 'GTEX-PWOO-1425', 'GTEX-PX3G-1925', 'GTEX-Q2AG-2125', 'GTEX-Q2AH-1325', 'GTEX-Q2AI-1025', 'GTEX-Q734-1225', 'GTEX-QCQG-0825', 'GTEX-QDT8-2825', 'GTEX-QDVJ-1525', 'GTEX-QDV
```

Great! All the bags have been recognized. Now we can display a bag, which is returned as a `dict`:

```python
el = dataset[0]
print(el.keys())
print(el['X'].shape)
print(el['adj'].shape)
```

Even though the instance labels could not be found (the dataset prints a warning about them), the `features` of the bag are loaded properly and, using the `coords`, the adjacency matrix is built correctly.
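
When adapting this example to another dataset, a quick shape check on a returned bag can catch mismatched files early. The helper below is hypothetical; it only relies on the `.shape` and `.ndim` attributes, which both NumPy arrays and PyTorch tensors provide, and is demonstrated here on NumPy arrays.

```python
import numpy as np


def check_bag_shapes(bag):
    """Hypothetical check: 'X' is (bag_size, feature_dim) and 'adj' is (bag_size, bag_size)."""
    if bag["X"].ndim != 2:
        raise ValueError(f"X should be 2-D, got shape {tuple(bag['X'].shape)}")
    bag_size = bag["X"].shape[0]
    if tuple(bag["adj"].shape) != (bag_size, bag_size):
        raise ValueError(f"adj should be ({bag_size}, {bag_size}), got {tuple(bag['adj'].shape)}")
    return bag_size


bag = {"X": np.zeros((10, 1024)), "adj": np.eye(10)}
print(check_bag_shapes(bag))  # 10
```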

Building a new dataset for <tt>torchmil</tt> was super easy!