diff --git a/.gitattributes b/.gitattributes
index b04fc3fa..6aa148d8 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -2,5 +2,6 @@ tests/eva/assets/**/*.h5 filter=lfs diff=lfs merge=lfs -text
tests/eva/assets/**/*.png filter=lfs diff=lfs merge=lfs -text
tests/eva/assets/**/*.jpg filter=lfs diff=lfs merge=lfs -text
tests/eva/assets/**/*.tif filter=lfs diff=lfs merge=lfs -text
+tests/eva/assets/**/*.tiff filter=lfs diff=lfs merge=lfs -text
tests/eva/assets/**/*.csv filter=lfs diff=lfs merge=lfs -text
tests/eva/assets/**/*.pt filter=lfs diff=lfs merge=lfs -text
diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 6c279c54..7b6cf4ed 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -34,6 +34,12 @@ jobs:
- "3.10"
runs-on: ${{ matrix.os }}
steps:
+ - name: Install OS dependencies
+ run: |
+ sudo apt update
+ sudo apt install -y software-properties-common
+ sudo add-apt-repository ppa:openslide/openslide
+ sudo apt install -y openslide-tools
- name: Checkout
uses: actions/checkout@a5ac7e51b41094c92402da3b24376905380afc29 # v4
with:
diff --git a/README.md b/README.md
index 4cbbdf5f..2e150690 100644
--- a/README.md
+++ b/README.md
@@ -104,29 +104,27 @@ and [tutorials](https://kaiko-ai.github.io/eva/dev/user-guide/advanced/replicate
In this section you will find model benchmarks which were generated with _`eva`_.
-### Table I: WSI patch-level benchmark
+### Table I: WSI classification tasks
-| Model | BACH | CRC | MHIST | PCam/val | PCam/test |
-|--------------------------------------------------|-------|-------|-------|----------|-----------|
-| ViT-S/16 _(random)_ [1] | 0.410 | 0.617 | 0.501 | 0.753 | 0.728 |
-| ViT-S/16 _(ImageNet)_ [1] | 0.695 | 0.935 | 0.831 | 0.864 | 0.849 |
-| ViT-B/8 _(ImageNet)_ [1] | 0.710 | 0.939 | 0.814 | 0.870 | 0.856 |
-| ViT-L/14 _(ImageNet)_ [1] | 0.707 | 0.916 | 0.832 | 0.873 | 0.888 |
-| DINO(p=16) [2] | 0.801 | 0.934 | 0.768 | 0.889 | 0.895 |
-| Phikon [3] | 0.725 | 0.935 | 0.777 | 0.912 | 0.915 |
-| UNI [4] | 0.814 | 0.950 | 0.837 | 0.936 | 0.938 |
-| ViT-S/16 _(kaiko.ai)_ [5] | 0.797 | 0.943 | 0.828 | 0.903 | 0.893 |
-| ViT-S/8 _(kaiko.ai)_ [5] | 0.834 | 0.946 | 0.832 | 0.897 | 0.887 |
-| ViT-B/16 _(kaiko.ai)_ [5] | 0.810 | 0.960 | 0.826 | 0.900 | 0.898 |
-| ViT-B/8 _(kaiko.ai)_ [5] | 0.865 | 0.956 | 0.809 | 0.913 | 0.921 |
-| ViT-L/14 _(kaiko.ai)_ [5] | 0.870 | 0.930 | 0.809 | 0.908 | 0.898 |
+| Model | BACH | CRC | MHIST | PCam | Camelyon16 | PANDA |
+|---------|-------|-------|-------|--------|------------|-------|
+| ViT-S/16 _(random)_ [1] | 0.411|0.613|0.5|0.752|0.551|0.347|
+| ViT-S/16 _(ImageNet)_ [1] | 0.675|0.936|0.827|0.861|0.751|0.676|
+| DINO(p=16) [2] | 0.77|0.936|0.751|0.905|0.869|0.737|
+| Phikon [3] | 0.715|0.942|0.766|0.925|0.879|0.784|
+| UNI [4] | 0.797|0.95|0.835|0.939|0.933|0.774|
+| ViT-S/16 _(kaiko.ai)_ [5] | 0.8|0.949|0.831|0.902|0.897|0.77|
+| ViT-S/8 _(kaiko.ai)_ [5] | 0.825|0.948|0.826|0.887|0.879|0.741|
+| ViT-B/16 _(kaiko.ai)_ [5] | 0.846|0.959|0.839|0.906|0.891|0.753|
+| ViT-B/8 _(kaiko.ai)_ [5] | 0.867|0.952|0.814|0.921|0.939|0.761|
+| ViT-L/14 _(kaiko.ai)_ [5] | 0.862|0.935|0.822|0.907|0.941|0.769|
_Table I: Linear probing evaluation of FMs on patch-level downstream datasets.
We report averaged balanced accuracy
-over 5 runs, with an average standard deviation of ±0.003._
+over 5 runs. Results are reported on the "test" split if available and otherwise on the "validation" split.
diff --git a/configs/vision/dino_vit/offline/camelyon16.yaml b/configs/vision/dino_vit/offline/camelyon16.yaml
new file mode 100644
index 00000000..19886da4
--- /dev/null
+++ b/configs/vision/dino_vit/offline/camelyon16.yaml
@@ -0,0 +1,134 @@
+---
+trainer:
+ class_path: eva.Trainer
+ init_args:
+ n_runs: &N_RUNS ${oc.env:N_RUNS, 5}
+ default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:DINO_BACKBONE, dino_vits16}/offline/camelyon16}
+ max_epochs: &MAX_EPOCHS ${oc.env:MAX_EPOCHS, 100}
+ callbacks:
+ - class_path: lightning.pytorch.callbacks.LearningRateMonitor
+ init_args:
+ logging_interval: epoch
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
+ init_args:
+ filename: best
+ save_last: true
+ save_top_k: 1
+ monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/BinaryAccuracy}
+ mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max}
+ - class_path: lightning.pytorch.callbacks.EarlyStopping
+ init_args:
+ min_delta: 0
+ patience: ${oc.env:PATIENCE, 10}
+ monitor: *MONITOR_METRIC
+ mode: *MONITOR_METRIC_MODE
+ - class_path: eva.callbacks.ClassificationEmbeddingsWriter
+ init_args:
+ output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings/${oc.env:DINO_BACKBONE, dino_vits16}/camelyon16}
+ save_every_n: 10_000
+ dataloader_idx_map:
+ 0: train
+ 1: val
+ 2: test
+ metadata_keys: ["wsi_id"]
+ backbone:
+ class_path: eva.models.ModelFromFunction
+ init_args:
+ path: torch.hub.load
+ arguments:
+ repo_or_dir: ${oc.env:REPO_OR_DIR, facebookresearch/dino:main}
+ model: ${oc.env:DINO_BACKBONE, dino_vits16}
+ pretrained: ${oc.env:PRETRAINED, true}
+ force_reload: ${oc.env:FORCE_RELOAD, false}
+ checkpoint_path: ${oc.env:CHECKPOINT_PATH, null}
+ logger:
+ - class_path: lightning.pytorch.loggers.TensorBoardLogger
+ init_args:
+ save_dir: *OUTPUT_ROOT
+ name: ""
+model:
+ class_path: eva.HeadModule
+ init_args:
+ head:
+ class_path: eva.vision.models.networks.ABMIL
+ init_args:
+ input_size: ${oc.env:IN_FEATURES, 384}
+ output_size: &NUM_CLASSES 1
+ projected_input_size: 128
+ criterion: torch.nn.BCEWithLogitsLoss
+ optimizer:
+ class_path: torch.optim.AdamW
+ init_args:
+ lr: ${oc.env:LR_VALUE, 0.001}
+ betas: [0.9, 0.999]
+ lr_scheduler:
+ class_path: torch.optim.lr_scheduler.CosineAnnealingLR
+ init_args:
+ T_max: *MAX_EPOCHS
+ eta_min: 0.0
+ metrics:
+ common:
+ - class_path: eva.metrics.AverageLoss
+ - class_path: eva.metrics.BinaryClassificationMetrics
+data:
+ class_path: eva.DataModule
+ init_args:
+ datasets:
+ train:
+ class_path: eva.datasets.MultiEmbeddingsClassificationDataset
+ init_args: &DATASET_ARGS
+ root: *DATASET_EMBEDDINGS_ROOT
+ manifest_file: manifest.csv
+ split: train
+ embeddings_transforms:
+ class_path: eva.core.data.transforms.Pad2DTensor
+ init_args:
+ pad_size: 10_000
+ target_transforms:
+ class_path: eva.core.data.transforms.dtype.ArrayToFloatTensor
+ val:
+ class_path: eva.datasets.MultiEmbeddingsClassificationDataset
+ init_args:
+ <<: *DATASET_ARGS
+ split: val
+ test:
+ class_path: eva.datasets.MultiEmbeddingsClassificationDataset
+ init_args:
+ <<: *DATASET_ARGS
+ split: test
+ predict:
+ - class_path: eva.vision.datasets.Camelyon16
+ init_args: &PREDICT_DATASET_ARGS
+ root: ${oc.env:DATA_ROOT, ./data/camelyon16}
+ sampler:
+ class_path: eva.vision.data.wsi.patching.samplers.ForegroundGridSampler
+ init_args:
+ max_samples: 10_000
+ width: 224
+ height: 224
+ target_mpp: 0.25
+ split: train
+ image_transforms:
+ class_path: eva.vision.data.transforms.common.ResizeAndCrop
+ init_args:
+ size: ${oc.env:RESIZE_DIM, 224}
+ mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]}
+ std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]}
+ - class_path: eva.vision.datasets.Camelyon16
+ init_args:
+ <<: *PREDICT_DATASET_ARGS
+ split: val
+ - class_path: eva.vision.datasets.Camelyon16
+ init_args:
+ <<: *PREDICT_DATASET_ARGS
+ split: test
+ dataloaders:
+ train:
+ batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 32}
+ shuffle: true
+ val:
+ batch_size: *BATCH_SIZE
+ test:
+ batch_size: *BATCH_SIZE
+ predict:
+ batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64}
diff --git a/configs/vision/dino_vit/offline/panda.yaml b/configs/vision/dino_vit/offline/panda.yaml
new file mode 100644
index 00000000..57f34696
--- /dev/null
+++ b/configs/vision/dino_vit/offline/panda.yaml
@@ -0,0 +1,133 @@
+---
+trainer:
+ class_path: eva.Trainer
+ init_args:
+ n_runs: &N_RUNS ${oc.env:N_RUNS, 5}
+ default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:DINO_BACKBONE, dino_vits16}/offline/panda}
+ max_epochs: &MAX_EPOCHS ${oc.env:MAX_EPOCHS, 49}
+ callbacks:
+ - class_path: lightning.pytorch.callbacks.LearningRateMonitor
+ init_args:
+ logging_interval: epoch
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
+ init_args:
+ filename: best
+ save_last: true
+ save_top_k: 1
+ monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MulticlassAccuracy}
+ mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max}
+ - class_path: lightning.pytorch.callbacks.EarlyStopping
+ init_args:
+ min_delta: 0
+ patience: ${oc.env:PATIENCE, 8}
+ monitor: *MONITOR_METRIC
+ mode: *MONITOR_METRIC_MODE
+ - class_path: eva.callbacks.ClassificationEmbeddingsWriter
+ init_args:
+ output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings/${oc.env:DINO_BACKBONE, dino_vits16}/panda}
+ dataloader_idx_map:
+ 0: train
+ 1: val
+ 2: test
+ metadata_keys: ["wsi_id"]
+ backbone:
+ class_path: eva.models.ModelFromFunction
+ init_args:
+ path: torch.hub.load
+ arguments:
+ repo_or_dir: ${oc.env:REPO_OR_DIR, facebookresearch/dino:main}
+ model: ${oc.env:DINO_BACKBONE, dino_vits16}
+ pretrained: ${oc.env:PRETRAINED, true}
+ force_reload: ${oc.env:FORCE_RELOAD, false}
+ checkpoint_path: ${oc.env:CHECKPOINT_PATH, null}
+ logger:
+ - class_path: lightning.pytorch.loggers.TensorBoardLogger
+ init_args:
+ save_dir: *OUTPUT_ROOT
+ name: ""
+model:
+ class_path: eva.HeadModule
+ init_args:
+ head:
+ class_path: eva.vision.models.networks.ABMIL
+ init_args:
+ input_size: ${oc.env:IN_FEATURES, 384}
+ output_size: &NUM_CLASSES 6
+ projected_input_size: 128
+ criterion: torch.nn.CrossEntropyLoss
+ optimizer:
+ class_path: torch.optim.AdamW
+ init_args:
+ lr: ${oc.env:LR_VALUE, 0.001}
+ betas: [0.9, 0.999]
+ lr_scheduler:
+ class_path: torch.optim.lr_scheduler.CosineAnnealingLR
+ init_args:
+ T_max: *MAX_EPOCHS
+ eta_min: 0.0
+ metrics:
+ common:
+ - class_path: eva.metrics.AverageLoss
+ - class_path: eva.metrics.MulticlassClassificationMetrics
+ init_args:
+ num_classes: *NUM_CLASSES
+data:
+ class_path: eva.DataModule
+ init_args:
+ datasets:
+ train:
+ class_path: eva.datasets.MultiEmbeddingsClassificationDataset
+ init_args: &DATASET_ARGS
+ root: *DATASET_EMBEDDINGS_ROOT
+ manifest_file: manifest.csv
+ split: train
+ embeddings_transforms:
+ class_path: eva.core.data.transforms.Pad2DTensor
+ init_args:
+ pad_size: &N_PATCHES 1000
+ val:
+ class_path: eva.datasets.MultiEmbeddingsClassificationDataset
+ init_args:
+ <<: *DATASET_ARGS
+ split: val
+ test:
+ class_path: eva.datasets.MultiEmbeddingsClassificationDataset
+ init_args:
+ <<: *DATASET_ARGS
+ split: test
+ predict:
+ - class_path: eva.vision.datasets.PANDA
+ init_args: &PREDICT_DATASET_ARGS
+ root: ${oc.env:DATA_ROOT, ./data/panda/prostate-cancer-grade-assessment}
+ sampler:
+ class_path: eva.vision.data.wsi.patching.samplers.ForegroundGridSampler
+ init_args:
+ max_samples: *N_PATCHES
+ width: 224
+ height: 224
+ target_mpp: 0.5
+ split: train
+ image_transforms:
+ class_path: eva.vision.data.transforms.common.ResizeAndCrop
+ init_args:
+ size: ${oc.env:RESIZE_DIM, 224}
+ mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]}
+ std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]}
+ - class_path: eva.vision.datasets.PANDA
+ init_args:
+ <<: *PREDICT_DATASET_ARGS
+ split: val
+ - class_path: eva.vision.datasets.PANDA
+ init_args:
+ <<: *PREDICT_DATASET_ARGS
+ split: test
+ dataloaders:
+ train:
+ batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 32}
+ shuffle: true
+ val:
+ batch_size: *BATCH_SIZE
+ test:
+ batch_size: *BATCH_SIZE
+ predict:
+ batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64}
diff --git a/configs/vision/owkin/phikon/offline/camelyon16.yaml b/configs/vision/owkin/phikon/offline/camelyon16.yaml
new file mode 100644
index 00000000..b3bf1dca
--- /dev/null
+++ b/configs/vision/owkin/phikon/offline/camelyon16.yaml
@@ -0,0 +1,130 @@
+---
+trainer:
+ class_path: eva.Trainer
+ init_args:
+ n_runs: &N_RUNS ${oc.env:N_RUNS, 5}
+ default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/owkin/phikon/offline/camelyon16}
+ max_epochs: &MAX_EPOCHS ${oc.env:MAX_EPOCHS, 100}
+ callbacks:
+ - class_path: lightning.pytorch.callbacks.LearningRateMonitor
+ init_args:
+ logging_interval: epoch
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
+ init_args:
+ filename: best
+ save_last: true
+ save_top_k: 1
+ monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/BinaryAccuracy}
+ mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max}
+ - class_path: lightning.pytorch.callbacks.EarlyStopping
+ init_args:
+ min_delta: 0
+ patience: ${oc.env:PATIENCE, 10}
+ monitor: *MONITOR_METRIC
+ mode: *MONITOR_METRIC_MODE
+ - class_path: eva.callbacks.ClassificationEmbeddingsWriter
+ init_args:
+ output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings/owkin/phikon/camelyon16}
+ save_every_n: 10_000
+ dataloader_idx_map:
+ 0: train
+ 1: val
+ 2: test
+ metadata_keys: ["wsi_id"]
+ backbone:
+ class_path: eva.models.HuggingFaceModel
+ init_args:
+ model_name_or_path: owkin/phikon
+ tensor_transforms:
+ class_path: eva.core.models.networks.transforms.ExtractCLSFeatures
+ logger:
+ - class_path: lightning.pytorch.loggers.TensorBoardLogger
+ init_args:
+ save_dir: *OUTPUT_ROOT
+ name: ""
+model:
+ class_path: eva.HeadModule
+ init_args:
+ head:
+ class_path: eva.vision.models.networks.ABMIL
+ init_args:
+ input_size: ${oc.env:IN_FEATURES, 768}
+ output_size: &NUM_CLASSES 1
+ projected_input_size: 128
+ criterion: torch.nn.BCEWithLogitsLoss
+ optimizer:
+ class_path: torch.optim.AdamW
+ init_args:
+ lr: ${oc.env:LR_VALUE, 0.001}
+ betas: [0.9, 0.999]
+ lr_scheduler:
+ class_path: torch.optim.lr_scheduler.CosineAnnealingLR
+ init_args:
+ T_max: *MAX_EPOCHS
+ eta_min: 0.0
+ metrics:
+ common:
+ - class_path: eva.metrics.AverageLoss
+ - class_path: eva.metrics.BinaryClassificationMetrics
+data:
+ class_path: eva.DataModule
+ init_args:
+ datasets:
+ train:
+ class_path: eva.datasets.MultiEmbeddingsClassificationDataset
+ init_args: &DATASET_ARGS
+ root: *DATASET_EMBEDDINGS_ROOT
+ manifest_file: manifest.csv
+ split: train
+ embeddings_transforms:
+ class_path: eva.core.data.transforms.Pad2DTensor
+ init_args:
+ pad_size: 10_000
+ target_transforms:
+ class_path: eva.core.data.transforms.dtype.ArrayToFloatTensor
+ val:
+ class_path: eva.datasets.MultiEmbeddingsClassificationDataset
+ init_args:
+ <<: *DATASET_ARGS
+ split: val
+ test:
+ class_path: eva.datasets.MultiEmbeddingsClassificationDataset
+ init_args:
+ <<: *DATASET_ARGS
+ split: test
+ predict:
+ - class_path: eva.vision.datasets.Camelyon16
+ init_args: &PREDICT_DATASET_ARGS
+ root: ${oc.env:DATA_ROOT, ./data/camelyon16}
+ sampler:
+ class_path: eva.vision.data.wsi.patching.samplers.ForegroundGridSampler
+ init_args:
+ max_samples: 10_000
+ width: 224
+ height: 224
+ target_mpp: 0.25
+ split: train
+ image_transforms:
+ class_path: eva.vision.data.transforms.common.ResizeAndCrop
+ init_args:
+ size: ${oc.env:RESIZE_DIM, 224}
+ mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]}
+ std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]}
+ - class_path: eva.vision.datasets.Camelyon16
+ init_args:
+ <<: *PREDICT_DATASET_ARGS
+ split: val
+ - class_path: eva.vision.datasets.Camelyon16
+ init_args:
+ <<: *PREDICT_DATASET_ARGS
+ split: test
+ dataloaders:
+ train:
+ batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 32}
+ shuffle: true
+ val:
+ batch_size: *BATCH_SIZE
+ test:
+ batch_size: *BATCH_SIZE
+ predict:
+ batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64}
diff --git a/configs/vision/owkin/phikon/offline/panda.yaml b/configs/vision/owkin/phikon/offline/panda.yaml
new file mode 100644
index 00000000..462c2b53
--- /dev/null
+++ b/configs/vision/owkin/phikon/offline/panda.yaml
@@ -0,0 +1,128 @@
+---
+trainer:
+ class_path: eva.Trainer
+ init_args:
+ n_runs: &N_RUNS ${oc.env:N_RUNS, 5}
+ default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/owkin/phikon/offline/panda}
+ max_epochs: &MAX_EPOCHS ${oc.env:MAX_EPOCHS, 49}
+ callbacks:
+ - class_path: lightning.pytorch.callbacks.LearningRateMonitor
+ init_args:
+ logging_interval: epoch
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
+ init_args:
+ filename: best
+ save_last: true
+ save_top_k: 1
+ monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MulticlassAccuracy}
+ mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max}
+ - class_path: lightning.pytorch.callbacks.EarlyStopping
+ init_args:
+ min_delta: 0
+ patience: ${oc.env:PATIENCE, 8}
+ monitor: *MONITOR_METRIC
+ mode: *MONITOR_METRIC_MODE
+ - class_path: eva.callbacks.ClassificationEmbeddingsWriter
+ init_args:
+ output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings/owkin/phikon/panda}
+ dataloader_idx_map:
+ 0: train
+ 1: val
+ 2: test
+ metadata_keys: ["wsi_id"]
+ backbone:
+ class_path: eva.models.HuggingFaceModel
+ init_args:
+ model_name_or_path: owkin/phikon
+ tensor_transforms:
+ class_path: eva.core.models.networks.transforms.ExtractCLSFeatures
+ logger:
+ - class_path: lightning.pytorch.loggers.TensorBoardLogger
+ init_args:
+ save_dir: *OUTPUT_ROOT
+ name: ""
+model:
+ class_path: eva.HeadModule
+ init_args:
+ head:
+ class_path: eva.vision.models.networks.ABMIL
+ init_args:
+ input_size: ${oc.env:IN_FEATURES, 768}
+ output_size: &NUM_CLASSES 6
+ criterion: torch.nn.CrossEntropyLoss
+ optimizer:
+ class_path: torch.optim.AdamW
+ init_args:
+ lr: ${oc.env:LR_VALUE, 0.001}
+ betas: [0.9, 0.999]
+ lr_scheduler:
+ class_path: torch.optim.lr_scheduler.CosineAnnealingLR
+ init_args:
+ T_max: *MAX_EPOCHS
+ eta_min: 0.0
+ metrics:
+ common:
+ - class_path: eva.metrics.AverageLoss
+ - class_path: eva.metrics.MulticlassClassificationMetrics
+ init_args:
+ num_classes: *NUM_CLASSES
+data:
+ class_path: eva.DataModule
+ init_args:
+ datasets:
+ train:
+ class_path: eva.datasets.MultiEmbeddingsClassificationDataset
+ init_args: &DATASET_ARGS
+ root: *DATASET_EMBEDDINGS_ROOT
+ manifest_file: manifest.csv
+ split: train
+ embeddings_transforms:
+ class_path: eva.core.data.transforms.Pad2DTensor
+ init_args:
+ pad_size: &N_PATCHES 1000
+ val:
+ class_path: eva.datasets.MultiEmbeddingsClassificationDataset
+ init_args:
+ <<: *DATASET_ARGS
+ split: val
+ test:
+ class_path: eva.datasets.MultiEmbeddingsClassificationDataset
+ init_args:
+ <<: *DATASET_ARGS
+ split: test
+ predict:
+ - class_path: eva.vision.datasets.PANDA
+ init_args: &PREDICT_DATASET_ARGS
+ root: ${oc.env:DATA_ROOT, ./data/panda/prostate-cancer-grade-assessment}
+ sampler:
+ class_path: eva.vision.data.wsi.patching.samplers.ForegroundGridSampler
+ init_args:
+ max_samples: *N_PATCHES
+ width: 224
+ height: 224
+ target_mpp: 0.5
+ split: train
+ image_transforms:
+ class_path: eva.vision.data.transforms.common.ResizeAndCrop
+ init_args:
+ size: ${oc.env:RESIZE_DIM, 224}
+ mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]}
+ std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]}
+ - class_path: eva.vision.datasets.PANDA
+ init_args:
+ <<: *PREDICT_DATASET_ARGS
+ split: val
+ - class_path: eva.vision.datasets.PANDA
+ init_args:
+ <<: *PREDICT_DATASET_ARGS
+ split: test
+ dataloaders:
+ train:
+ batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 32}
+ shuffle: true
+ val:
+ batch_size: *BATCH_SIZE
+ test:
+ batch_size: *BATCH_SIZE
+ predict:
+ batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64}
diff --git a/configs/vision/tests/offline/panda.yaml b/configs/vision/tests/offline/panda.yaml
new file mode 100644
index 00000000..28844dd1
--- /dev/null
+++ b/configs/vision/tests/offline/panda.yaml
@@ -0,0 +1,128 @@
+---
+trainer:
+ class_path: eva.Trainer
+ init_args:
+ default_root_dir: &LIGHTNING_ROOT ${oc.env:LIGHTNING_ROOT, logs/test/offline/panda}
+ max_epochs: &MAX_EPOCHS 1
+ limit_train_batches: 2
+ limit_val_batches: 2
+ callbacks:
+ - class_path: eva.callbacks.ClassificationEmbeddingsWriter
+ init_args:
+ output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT}/panda
+ dataloader_idx_map:
+ 0: train
+ 1: val
+ 2: test
+ metadata_keys: ["wsi_id"]
+ backbone:
+ class_path: eva.models.ModelFromFunction
+ init_args:
+ path: torch.hub.load
+ arguments:
+ repo_or_dir: facebookresearch/dino:main
+ model: dino_vits16
+ pretrained: false
+ checkpoint_path: &CHECKPOINT_PATH ${oc.env:CHECKPOINT_PATH, null}
+model:
+ class_path: eva.HeadModule
+ init_args:
+ head:
+ class_path: eva.vision.models.networks.ABMIL
+ init_args:
+ input_size: ${oc.env:IN_FEATURES, 384}
+ output_size: &NUM_CLASSES 6
+ criterion: torch.nn.CrossEntropyLoss
+ optimizer:
+ class_path: torch.optim.SGD
+ init_args:
+ lr: &LR_VALUE ${oc.env:LR_VALUE, 0.00004}
+ momentum: 0.9
+ weight_decay: 0.0
+ lr_scheduler:
+ class_path: torch.optim.lr_scheduler.CosineAnnealingLR
+ init_args:
+ T_max: *MAX_EPOCHS
+ eta_min: 0.0
+ metrics:
+ common:
+ - class_path: eva.metrics.AverageLoss
+ - class_path: eva.metrics.MulticlassClassificationMetrics
+ init_args:
+ num_classes: *NUM_CLASSES
+data:
+ class_path: eva.DataModule
+ init_args:
+ datasets:
+ train:
+ class_path: eva.datasets.MultiEmbeddingsClassificationDataset
+ init_args: &DATASET_ARGS
+ root: *DATASET_EMBEDDINGS_ROOT
+ manifest_file: manifest.csv
+ split: train
+ embeddings_transforms:
+ class_path: eva.core.data.transforms.Pad2DTensor
+ init_args:
+ pad_size: &N_PATCHES 5
+ val:
+ class_path: eva.datasets.MultiEmbeddingsClassificationDataset
+ init_args:
+ <<: *DATASET_ARGS
+ split: val
+ test:
+ class_path: eva.datasets.MultiEmbeddingsClassificationDataset
+ init_args:
+ <<: *DATASET_ARGS
+ split: test
+ predict:
+ - class_path: eva.vision.datasets.PANDA
+ init_args: &PREDICT_DATASET_ARGS
+ root: ${oc.env:TESTS_ROOT, tests/eva}/assets/vision/datasets/panda
+ sampler:
+ class_path: eva.vision.data.wsi.patching.samplers.ForegroundGridSampler
+ init_args:
+ max_samples: *N_PATCHES
+ width: 2
+ height: 2
+ target_mpp: 0.5
+ split: train
+ image_transforms:
+ class_path: eva.vision.data.transforms.common.ResizeAndCrop
+ init_args:
+ size: ${oc.env:RESIZE_DIM, 224}
+ mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]}
+ std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]}
+ - class_path: eva.vision.datasets.PANDA
+ init_args:
+ <<: *PREDICT_DATASET_ARGS
+ split: val
+ - class_path: eva.vision.datasets.PANDA
+ init_args:
+ <<: *PREDICT_DATASET_ARGS
+ split: test
+ dataloaders:
+ train:
+ batch_size: &BATCH_SIZE 2
+ shuffle: true
+ num_workers: 0
+ pin_memory: false
+ persistent_workers: false
+ prefetch_factor: null
+ val:
+ batch_size: *BATCH_SIZE
+ num_workers: 0
+ pin_memory: false
+ persistent_workers: false
+ prefetch_factor: null
+ test:
+ batch_size: *BATCH_SIZE
+ num_workers: 0
+ pin_memory: false
+ persistent_workers: false
+ prefetch_factor: null
+ predict:
+ batch_size: &PREDICT_BATCH_SIZE 2
+ num_workers: 0
+ pin_memory: false
+ persistent_workers: false
+ prefetch_factor: null
diff --git a/configs/vision/tests/offline/patch_camelyon.yaml b/configs/vision/tests/offline/patch_camelyon.yaml
index bf9722e4..16286058 100644
--- a/configs/vision/tests/offline/patch_camelyon.yaml
+++ b/configs/vision/tests/offline/patch_camelyon.yaml
@@ -16,7 +16,7 @@ trainer:
1: val
2: test
backbone:
- class_path: eva.core.models.networks.wrappers.ModelFromFunction
+ class_path: eva.models.ModelFromFunction
init_args:
path: torch.hub.load
arguments:
diff --git a/docs/datasets/camelyon16.md b/docs/datasets/camelyon16.md
new file mode 100644
index 00000000..56bb69ff
--- /dev/null
+++ b/docs/datasets/camelyon16.md
@@ -0,0 +1,69 @@
+# Camelyon16
+
+The Camelyon16 dataset consists of 400 WSIs of lymph nodes for breast cancer metastasis classification. The dataset is a combination of two independent datasets, collected from two separate medical centers in the Netherlands (Radboud University Medical Center and University Medical Center Utrecht). The dataset contains the slides from which [PatchCamelyon](patch_camelyon.md)-patches were extracted.
+
+The dataset is divided in a train set (270 slides) and test set (130 slides), both containing images from both centers. Note that one test set slide was a duplicate has been removed (see [here](https://github.com/DIDSR/dldp?tab=readme-ov-file#04-data-description-important)).
+
+The task was part of [Grand Challenge](https://grand-challenge.org/) in 2016 and has later been replaced by Camelyon17.
+
+Source: https://camelyon16.grand-challenge.org
+
+## Raw data
+
+### Key stats
+
+| | |
+|---------------------------|----------------------------------------------------------|
+| **Modality** | Vision (WSI) |
+| **Task** | Binary classification |
+| **Cancer type** | Breast |
+| **Data size** | ~700 GB |
+| **Image dimension** | ~100-250k x ~100-250k x 3 |
+| **Magnification (μm/px)** | 40x (0.25) - Level 0 |
+| **Files format** | `.tif` |
+| **Number of images** | 399 (270 train, 129 test) |
+
+
+### Organization
+
+The data `CAMELYON16` (download links [here](https://camelyon17.grand-challenge.org/Data/)) is organized as follows:
+
+```
+CAMELYON16
+├── training
+│ ├── normal
+| │ ├── normal_001.tif
+| │ └── ...
+│ ├── tumor
+| │ ├── tumor_001.tif
+| │ └── ...
+│ └── lesion_annotations.zip
+├── testing
+│ ├── images
+| │ ├── test_001.tif
+| │ └── ...
+│ ├── evaluation # masks not in use
+│ ├── reference.csv # targets
+│ └── lesion_annotations.zip
+```
+
+## Download and preprocessing
+
+The `Camelyon16` dataset class doesn't download the data during runtime and must be downloaded manually from links provided [here](https://camelyon17.grand-challenge.org/Data/).
+
+The dataset is split into train / test. Additionally, we split the train set into train/val using the same splits as [PatchCamelyon](patch_camelyon.md) (see metadata CSV files on [Zenodo](https://zenodo.org/records/2546921)).
+
+| Splits | Train | Validation | Test |
+|----------|-------------|-------------|------------|
+| #Samples | 216 (54.1%) | 54 (13.5%) | 129 (32.3%)|
+
+
+## Relevant links
+
+* [Grand Challenge dataset description](https://camelyon16.grand-challenge.org/Data/)
+* [Download links](https://camelyon17.grand-challenge.org/Data/)
+* [GitHub with dataset description by DIDSR](https://github.com/DIDSR/dldp)
+
+
+## References
+1 : [A General-Purpose Self-Supervised Model for Computational Pathology](https://arxiv.org/abs/2308.15474)
\ No newline at end of file
diff --git a/docs/datasets/index.md b/docs/datasets/index.md
index 963de114..cc947e80 100644
--- a/docs/datasets/index.md
+++ b/docs/datasets/index.md
@@ -6,6 +6,7 @@
### Whole Slide (WSI) and microscopy image datasets
+#### Patch-level
| Dataset | #Patches | Patch Size | Magnification (μm/px) | Task | Cancer Type |
|------------------------------------|----------|------------|------------------------|----------------------------|------------------|
| [BACH](bach.md) | 400 | 2048x1536 | 20x (0.5) | Classification (4 classes) | Breast |
@@ -15,6 +16,13 @@
\* Downsampled from 40x (0.25 μm/px) to increase the field of view.
+#### Slide-level
+| Dataset | #Slides | Slide Size | Magnification (μm/px) | Task | Cancer Type |
+|------------------------------------|----------|---------------------------|------------------------|----------------------------|------------------|
+| [Camelyon16](camelyon16.md) | 400 | ~100-250k x ~100-250k x 3 | 40x (0.25) | Classification (2 classes) | Breast |
+| [PANDA](panda.md) | 10,616 | ~20k x 20k x 3 | 20x (0.5) | Classification (6 classes) | Prostate |
+
+
### Radiology datasets
| Dataset | #Images | Image Size | Task | Download provided
diff --git a/docs/datasets/panda.md b/docs/datasets/panda.md
new file mode 100644
index 00000000..cf29488e
--- /dev/null
+++ b/docs/datasets/panda.md
@@ -0,0 +1,68 @@
+# PANDA (Prostate cANcer graDe Assessment)
+
+The PANDA datasets consists of 10,616 whole-slide images of digitized H&E-stained prostate tissue biopsies originating from two medical centers. After the biopsy, the slides were classified into Gleason patterns (3, 4 or 5) based on the architectural growth patterns of the tumor, which are then converted into an ISUP grade on a 0-5 scale.
+
+The Gleason grading system is the most important prognostic marker for prostate cancer and the ISUP grade has a crucial role when deciding how a patient should be treated. However, the system suffers from significant inter-observer variability between pathologists, leading to imperfect and noisy labels.
+
+Source: https://www.kaggle.com/competitions/prostate-cancer-grade-assessment
+
+
+## Raw data
+
+### Key stats
+
+| | |
+|---------------------------|----------------------------------------------------------|
+| **Modality** | Vision (WSI) |
+| **Task** | Multiclass classification (6 classes) |
+| **Cancer type** | Prostate |
+| **Data size** | 347 GB |
+| **Image dimension** | ~20k x 20k x 3 |
+| **Magnification (μm/px)** | 20x (0.5) - Level 0 |
+| **Files format** | `.tiff` |
+| **Number of images** | 10,616 (9,555 after removing noisy labels) |
+
+
+### Organization
+
+The data `prostate-cancer-grade-assessment.zip` from [kaggle](https://www.kaggle.com/competitions/prostate-cancer-grade-assessment/data) is organized as follows:
+
+```
+prostate-cancer-grade-assessment
+├── train_images
+│ ├── 0005f7aaab2800f6170c399693a96917.tiff
+│ └── ...
+├── train_label_masks (not used in eva)
+│ ├── 0005f7aaab2800f6170c399693a96917_mask.tiff
+│ └── ...
+├── train.csv (contains Gleason & ISUP labels)
+├── test.csv
+├── sample_submission.csv
+```
+
+## Download and preprocessing
+
+The `PANDA` dataset class doesn't download the data during runtime and must be downloaded manually from [kaggle](https://www.kaggle.com/competitions/prostate-cancer-grade-assessment/data).
+
+As done in other studies1 we exclude ~10% of the samples with noisy labels according to kaggle's [6th place solution](https://www.kaggle.com/competitions/prostate-cancer-grade-assessment/discussion/169230) resulting in a total dataset size of 9555 WSIs.
+
+We then generate random stratified train / validation and test splits using a 0.7 / 0.15 / 0.15 ratio:
+
+
+| Splits | Train | Validation | Test |
+|----------|-------------|-------------|------------|
+| #Samples | 6686 (70%) | 1430 (15%) | 1439 (15%) |
+
+
+## Relevant links
+
+* [Kaggle Challenge](https://www.kaggle.com/competitions/prostate-cancer-grade-assessment)
+* [Noisy Labels](https://github.com/analokmaus/kaggle-panda-challenge-public)
+
+
+## License
+
+[CC BY-SA-NC 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en)
+
+## References
+1 : [A General-Purpose Self-Supervised Model for Computational Pathology](https://arxiv.org/abs/2308.15474)
\ No newline at end of file
diff --git a/docs/images/starplot.png b/docs/images/starplot.png
new file mode 100644
index 00000000..50d20d97
Binary files /dev/null and b/docs/images/starplot.png differ
diff --git a/docs/index.md b/docs/index.md
index cee477de..3a0448ec 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -31,17 +31,17 @@ hide:
_Oncology FM Evaluation Framework by [kaiko.ai](https://www.kaiko.ai/)_
-With the first release, *eva* supports performance evaluation for vision Foundation Models ("FMs") and supervised machine learning models on WSI-patch-level image classification task. Support for radiology (CT-scans) segmentation tasks will be added soon.
+*eva* currently supports performance evaluation for vision Foundation Models ("FMs") and supervised machine learning models on WSI (patch- and slide-level) as well as radiology image classification tasks.
With *eva* we provide the open-source community with an easy-to-use framework that follows industry best practices to deliver a robust, reproducible and fair evaluation benchmark across FMs of different sizes and architectures.
-Support for additional modalities and tasks will be added in future releases.
+Support for additional modalities and tasks will be added soon.
## Use cases
### 1. Evaluate your own FMs on public benchmark datasets
-With a specified FM as input, you can run *eva* on several publicly available datasets & tasks. One evaluation run will download and preprocess the relevant data, compute embeddings, fit and evaluate a downstream head and report the mean and standard deviation of the relevant performance metrics.
+With a specified FM as input, you can run *eva* on several publicly available datasets & tasks. One evaluation run will download (if supported) and preprocess the relevant data, compute embeddings, fit and evaluate a downstream head and report the mean and standard deviation of the relevant performance metrics.
Supported datasets & tasks include:
@@ -52,6 +52,11 @@ Supported datasets & tasks include:
- **[CRC](datasets/crc.md)**: multiclass colorectal cancer classification
- **[MHIST](datasets/mhist.md)**: binary colorectal polyp cancer classification
+*WSI slide-level pathology datasets*
+
+- **[Camelyon16](datasets/camelyon16.md)**: binary breast cancer classification
+- **[PANDA](datasets/panda.md)**: multiclass prostate cancer classification
+
*Radiology datasets*
- **[TotalSegmentator](datasets/total_segmentator.md)**: radiology/CT-scan for segmentation of anatomical structures (*support coming soon*)
@@ -65,62 +70,7 @@ If you have your own labeled dataset, all that is needed is to implement a datas
## Evaluation results
-We evaluated the following FMs on the 4 supported WSI-patch-level image classification tasks. On the table below we report *Balanced Accuracy* for binary & multiclass tasks and show the average performance & standard deviation over 5 runs.
-
-
-
-
-| FM-backbone | pretraining | BACH | CRC | MHIST | PCam/val | PCam/test |
-|-----------------------------|-------------|------------------ |----------------- |----------------- |----------------- |-------------- |
-| DINO ViT-S16 | N/A | 0.410 (±0.009) | 0.617 (±0.008) | 0.501 (±0.004) | 0.753 (±0.002) | 0.728 (±0.003) |
-| DINO ViT-S16 | ImageNet | 0.695 (±0.004) | 0.935 (±0.003) | 0.831 (±0.002) | 0.864 (±0.007) | 0.849 (±0.007) |
-| DINO ViT-B8 | ImageNet | 0.710 (±0.007) | 0.939 (±0.001) | 0.814 (±0.003) | 0.870 (±0.003) | 0.856 (±0.004) |
-| DINOv2 ViT-L14 | ImageNet | 0.707 (±0.008) | 0.916 (±0.002) | 0.832 (±0.003) | 0.873 (±0.001) | 0.888 (±0.001) |
-| Lunit - ViT-S16 | TCGA | 0.801 (±0.005) | 0.934 (±0.001) | 0.768 (±0.004) | 0.889 (±0.002) | 0.895 (±0.006) |
-| Owkin - iBOT ViT-B16 | TCGA | 0.725 (±0.004) | 0.935 (±0.001) | 0.777 (±0.005) | 0.912 (±0.002) | 0.915 (±0.003) |
-| UNI - DINOv2 ViT-L16 | Mass-100k | 0.814 (±0.008) | 0.950 (±0.001) | **0.837 (±0.001)** | **0.936 (±0.001)** | **0.938 (±0.001)**|
-| kaiko.ai - DINO ViT-S16 | TCGA | 0.797 (±0.003) | 0.943 (±0.001) | 0.828 (±0.003) | 0.903 (±0.001) | 0.893 (±0.005) |
-| kaiko.ai - DINO ViT-S8 | TCGA | 0.834 (±0.012) | 0.946 (±0.002) | 0.832 (±0.006) | 0.897 (±0.001) | 0.887 (±0.002) |
-| kaiko.ai - DINO ViT-B16 | TCGA | 0.810 (±0.008) | **0.960 (±0.001)** | 0.826 (±0.003) | 0.900 (±0.002) | 0.898 (±0.003) |
-| kaiko.ai - DINO ViT-B8 | TCGA | 0.865 (±0.019) | 0.956 (±0.001) | 0.809 (±0.021) | 0.913 (±0.001) | 0.921 (±0.002) |
-| kaiko.ai - DINOv2 ViT-L14 | TCGA | **0.870 (±0.005)**| 0.930 (±0.001) | 0.809 (±0.001) | 0.908 (±0.001) | 0.898 (±0.002) |
-
-
-
-The runs use the default setup described in the section below.
-
-*eva* trains the decoder on the "train" split and uses the "validation" split for monitoring, early stopping and checkpoint selection. Evaluation results are reported on the "validation" split and, if available, on the "test" split.
-
-For more details on the FM-backbones and instructions to replicate the results, check out [Replicate evaluations](user-guide/advanced/replicate_evaluations.md).
-
-## Evaluation setup
-
-*Note that the current version of eva implements the task- & model-independent and fixed default set up following the standard evaluation protocol proposed by [1] and described in the table below. We selected this approach to prioritize reliable, robust and fair FM-evaluation while being in line with common literature. Additionally, with future versions we are planning to allow the use of cross-validation and hyper-parameter tuning to find the optimal setup to achieve best possible performance on the implemented downstream tasks.*
-
-With a provided FM, *eva* computes embeddings for all input images (WSI patches) which are then used to train a downstream head consisting of a single linear layer in a supervised setup for each of the benchmark datasets. We use early stopping with a patience of 5% of the maximal number of epochs.
-
-| | |
-|-------------------------|---------------------------|
-| **Backbone** | frozen |
-| **Hidden layers** | none |
-| **Dropout** | 0.0 |
-| **Activation function** | none |
-| **Number of steps** | 12,500 |
-| **Base Batch size** | 4,096 |
-| **Batch size** | dataset specific* |
-| **Base learning rate** | 0.01 |
-| **Learning Rate** | [Base learning rate] * [Batch size] / [Base batch size] |
-| **Max epochs** | [Number of samples] * [Number of steps] / [Batch size] |
-| **Early stopping** | 5% * [Max epochs] |
-| **Optimizer** | SGD |
-| **Momentum** | 0.9 |
-| **Weight Decay** | 0.0 |
-| **Nesterov momentum** | true |
-| **LR Schedule** | Cosine without warmup |
-
-\* For smaller datasets (e.g. BACH with 400 samples) we reduce the batch size to 256 and scale the learning rate accordingly.
-
-- [1]: [Virchow: A Million-Slide Digital Pathology Foundation Model, 2024](https://arxiv.org/pdf/2309.07778.pdf)
+Check out our [Leaderboards](leaderboards.md) to inspect evaluation results of publicly available FMs.
## License
diff --git a/docs/leaderboards.md b/docs/leaderboards.md
new file mode 100644
index 00000000..f55d2b55
--- /dev/null
+++ b/docs/leaderboards.md
@@ -0,0 +1,75 @@
+---
+hide:
+ - navigation
+---
+
+# Leaderboards
+
+We evaluated the following FMs on the 6 supported WSI-classification tasks. We report *Balanced Accuracy* for binary & multiclass tasks. The score shows the average performance over 5 runs.
+
+
+
+
+
+| Vision FM | pretraining | [BACH](datasets/bach.md) | [CRC](datasets/crc.md) | [MHIST](datasets/mhist.md) | [PCam](datasets/patch_camelyon.md) |[Camelyon16](datasets/camelyon16.md)| [PANDA](datasets/panda.md)|
+|---------|-------------|--------- |-----------|-----------|----------|----------|----------|
+| [DINO ViT-S16](https://arxiv.org/abs/2104.14294) | N/A | 0.411|0.613|0.5|0.752|0.551|0.347|
+| [DINO ViT-S16](https://arxiv.org/abs/2104.14294) | ImageNet | 0.675|0.936|0.827|0.861|0.751|0.676|
+| [Lunit - ViT-S16](https://github.com/lunit-io/benchmark-ssl-pathology/releases/) | TCGA | 0.77|0.936|0.751|0.905|0.869|0.737|
+| [Owkin (Phikon) - iBOT ViT-B16](https://huggingface.co/owkin/phikon) | TCGA | 0.715|0.942|0.766|0.925|0.879|0.784|
+| [UNI - DINOv2 ViT-L16](https://huggingface.co/MahmoodLab/UNI) | Mass-100k | 0.797|0.95|0.835|0.939|0.933|0.774|
+| [kaiko.ai - DINO ViT-S16](https://github.com/kaiko-ai/towards_large_pathology_fms) | TCGA | 0.8|0.949|0.831|0.902|0.897|0.77|
+| [kaiko.ai - DINO ViT-S8](https://github.com/kaiko-ai/towards_large_pathology_fms) | TCGA | 0.825|0.948|0.826|0.887|0.879|0.741|
+| [kaiko.ai - DINO ViT-B16](https://github.com/kaiko-ai/towards_large_pathology_fms) | TCGA | 0.846|0.959|0.839|0.906|0.891|0.753|
+| [kaiko.ai - DINO ViT-B8](https://github.com/kaiko-ai/towards_large_pathology_fms) | TCGA | 0.867|0.952|0.814|0.921|0.939|0.761|
+| [kaiko.ai - DINOv2 ViT-L14](https://github.com/kaiko-ai/towards_large_pathology_fms)| TCGA | 0.862|0.935|0.822|0.907|0.941|0.769|
+
+
+
+![Screenshot](images/starplot.png)
+
+
+
+
+
+The runs use the default setup described in the section below.
+
+*eva* trains the decoder on the "train" split and uses the "validation" split for monitoring, early stopping and checkpoint selection. Evaluation results are reported on the "test" split if available and otherwise on the "validation" split.
+
+For details on the FM-backbones and instructions to replicate the results, check out [Replicate evaluations](user-guide/advanced/replicate_evaluations.md). For information on the tasks, check out [Datasets](datasets/index.md).
+
+## Evaluation protocol
+
+*eva* uses a task- & model-independent and fixed default set up which closely follows the standard evaluation protocol proposed by [1] (with adjustments for slide-level tasks to ensure convergence and computational efficiency).
+
+We selected this approach to prioritize reliable, robust and fair FM-evaluation while being in line with common literature.
+
+| | WSI patch-level tasks | WSI slide-level tasks |
+|--------------------------------|---------------------------|---------------------------|
+| **Backbone** | frozen | frozen |
+| **Head** | single layer MLP | ABMIL |
+| **Dropout** | 0.0 | 0.0 |
+| **Hidden activation function** | n/a | ReLU |
+| **Output activation function** | none | none |
+| **Number of steps** | 12,500 | 12,500 (2) |
+| **Base batch size** | 4,096 (1) | 32 |
+| **Base learning rate** | 0.01 (1) | 0.001 |
+| **Early stopping** | 5% * [Max epochs] | 10% * [Max epochs] (3) |
+| **Optimizer** | SGD | AdamW |
+| **Momentum** | 0.9 | n/a |
+| **Weight Decay** | 0.0 | n/a |
+| **betas** | n/a | [0.9, 0.999] |
+| **LR Schedule** | Cosine without warmup | Cosine without warmup |
+| **number of patches per slide**| 1 | dataset specific (4) |
+
+
+(1) For smaller datasets (e.g. BACH with 400 samples) we reduce the batch size to 256 and scale the learning rate accordingly.
+
+(2) Upper cap at a maximum of 100 epochs.
+
+(3) Lower cap at a minimum of 8 epochs.
+
+(4) Number of patches per slide depends on task and slide size. For PANDA and Camelyon16 we use a max of 1,000 and 10,000 random patches per slide respectively.
+
+
+- [1]: [Virchow: A Million-Slide Digital Pathology Foundation Model, 2024](https://arxiv.org/pdf/2309.07778.pdf)
diff --git a/docs/user-guide/advanced/replicate_evaluations.md b/docs/user-guide/advanced/replicate_evaluations.md
index d3770586..964fa711 100644
--- a/docs/user-guide/advanced/replicate_evaluations.md
+++ b/docs/user-guide/advanced/replicate_evaluations.md
@@ -4,7 +4,7 @@ To produce the evaluation results presented [here](../../index.md#evaluation-res
Make sure to replace `` in the commands below with `bach`, `crc`, `mhist` or `patch_camelyon`.
-Note that to run the commands below you will need to first download the data. [BACH](../../datasets/bach.md), [CRC](../../datasets/crc.md) and [PatchCamelyon](../../datasets/patch_camelyon.md) provide automatic download by setting the argument `download: true` (either modify the config-files or set the environment variable `DOWNLOAD=true`). In the case of MHIST you will need to download the data manually by following the instructions provided [here](../../datasets/mhist.md#download-and-preprocessing).*
+*Note that to run the commands below you will need to first download the data. [BACH](../../datasets/bach.md), [CRC](../../datasets/crc.md) and [PatchCamelyon](../../datasets/patch_camelyon.md) provide automatic download by setting the argument `download: true` (either modify the config-files or set the environment variable `DOWNLOAD=true`). In the case of MHIST you will need to download the data manually by following the instructions provided [here](../../datasets/mhist.md#download-and-preprocessing).*
## DINO ViT-S16 (random weights)
@@ -25,29 +25,6 @@ EMBEDDINGS_ROOT="./data/embeddings/dino_vits16_imagenet" \
eva predict_fit --config configs/vision/dino_vit/offline/.yaml
```
-## DINO ViT-B8 (ImageNet)
-
-To evaluate performance on the larger ViT-B8 backbone pretrained on ImageNet, run:
-```
-EMBEDDINGS_ROOT="./data/embeddings/dino_vitb8_imagenet" \
-DINO_BACKBONE=dino_vitb8 \
-IN_FEATURES=768 \
-eva predict_fit --config configs/vision/dino_vit/offline/.yaml
-```
-
-## DINOv2 ViT-L14 (ImageNet)
-
-To evaluate performance on Dino v2 ViT-L14 backbone pretrained on ImageNet, run:
-```
-PRETRAINED=true \
-EMBEDDINGS_ROOT="./data/embeddings/dinov2_vitl14_kaiko" \
-REPO_OR_DIR=facebookresearch/dinov2:main \
-DINO_BACKBONE=dinov2_vitl14_reg \
-FORCE_RELOAD=true \
-IN_FEATURES=1024 \
-eva predict_fit --config configs/vision/dino_vit/offline/.yaml
-```
-
## Lunit - DINO ViT-S16 (TCGA)
[Lunit](https://www.lunit.io/en), released the weights for a DINO ViT-S16 backbone, pretrained on TCGA data
@@ -110,12 +87,13 @@ eva predict_fit --config path/to/.yaml
## kaiko.ai - DINO ViT-S16 (TCGA)
To evaluate [kaiko.ai's](https://www.kaiko.ai/) FM with DINO ViT-S16 backbone, pretrained on TCGA data
-on [GitHub](https://github.com/lunit-io/benchmark-ssl-pathology/releases/), run:
+and available on [GitHub](https://github.com/kaiko-ai/towards_large_pathology_fms), run:
```
PRETRAINED=false \
EMBEDDINGS_ROOT="./data/embeddings/dino_vits16_kaiko" \
-CHECKPOINT_PATH=[TBD*] \
+REPO_OR_DIR="kaiko-ai/towards_large_pathology_fms" \
+DINO_BACKBONE=="vits16" \
NORMALIZE_MEAN=[0.5,0.5,0.5] \
NORMALIZE_STD=[0.5,0.5,0.5] \
eva predict_fit --config configs/vision/dino_vit/offline/.yaml
@@ -126,13 +104,13 @@ eva predict_fit --config configs/vision/dino_vit/offline/.yaml
## kaiko.ai - DINO ViT-S8 (TCGA)
To evaluate [kaiko.ai's](https://www.kaiko.ai/) FM with DINO ViT-S8 backbone, pretrained on TCGA data
-on [GitHub](https://github.com/lunit-io/benchmark-ssl-pathology/releases/), run:
+and available on [GitHub](https://github.com/kaiko-ai/towards_large_pathology_fms), run:
```
PRETRAINED=false \
EMBEDDINGS_ROOT="./data/embeddings/dino_vits8_kaiko" \
-DINO_BACKBONE=dino_vits8 \
-CHECKPOINT_PATH=[TBD*] \
+REPO_OR_DIR="kaiko-ai/towards_large_pathology_fms" \
+DINO_BACKBONE=="vits8" \
NORMALIZE_MEAN=[0.5,0.5,0.5] \
NORMALIZE_STD=[0.5,0.5,0.5] \
eva predict_fit --config configs/vision/dino_vit/offline/.yaml
@@ -142,14 +120,14 @@ eva predict_fit --config configs/vision/dino_vit/offline/.yaml
## kaiko.ai - DINO ViT-B16 (TCGA)
-To evaluate [kaiko.ai's](https://www.kaiko.ai/) FM with the larger DINO ViT-B16 backbone, pretrained on TCGA data,
-run:
+To evaluate [kaiko.ai's](https://www.kaiko.ai/) FM with DINO ViT-B16 backbone, pretrained on TCGA data
+and available on [GitHub](https://github.com/kaiko-ai/towards_large_pathology_fms), run:
```
PRETRAINED=false \
EMBEDDINGS_ROOT="./data/embeddings/dino_vitb16_kaiko" \
-DINO_BACKBONE=dino_vitb16 \
-CHECKPOINT_PATH=[TBD*] \
+REPO_OR_DIR="kaiko-ai/towards_large_pathology_fms" \
+DINO_BACKBONE=="vitb16" \
IN_FEATURES=768 \
NORMALIZE_MEAN=[0.5,0.5,0.5] \
NORMALIZE_STD=[0.5,0.5,0.5] \
@@ -160,14 +138,14 @@ eva predict_fit --config configs/vision/dino_vit/offline/.yaml
## kaiko.ai - DINO ViT-B8 (TCGA)
-To evaluate [kaiko.ai's](https://www.kaiko.ai/) FM with the larger DINO ViT-B8 backbone, pretrained on TCGA data,
-run:
+To evaluate [kaiko.ai's](https://www.kaiko.ai/) FM with DINO ViT-B8 backbone, pretrained on TCGA data
+and available on [GitHub](https://github.com/kaiko-ai/towards_large_pathology_fms), run:
```
PRETRAINED=false \
EMBEDDINGS_ROOT="./data/embeddings/dino_vitb8_kaiko" \
-DINO_BACKBONE=dino_vitb8 \
-CHECKPOINT_PATH=[TBD*] \
+REPO_OR_DIR="kaiko-ai/towards_large_pathology_fms" \
+DINO_BACKBONE=="vitb8" \
IN_FEATURES=768 \
NORMALIZE_MEAN=[0.5,0.5,0.5] \
NORMALIZE_STD=[0.5,0.5,0.5] \
@@ -178,14 +156,14 @@ eva predict_fit --config configs/vision/dino_vit/offline/.yaml
## kaiko.ai - DINOv2 ViT-L14 (TCGA)
-To evaluate [kaiko.ai's](https://www.kaiko.ai/) FM with the larger DINOv2 ViT-L14 backbone, pretrained on TCGA data,
-run:
+To evaluate [kaiko.ai's](https://www.kaiko.ai/) FM with DINOv2 ViT-L14 backbone, pretrained on TCGA data
+and available on [GitHub](https://github.com/kaiko-ai/towards_large_pathology_fms), run:
```
PRETRAINED=false \
EMBEDDINGS_ROOT="./data/embeddings/dinov2_vitl14_kaiko" \
-REPO_OR_DIR=facebookresearch/dinov2:main \
-DINO_BACKBONE=dinov2_vitl14_reg \
+REPO_OR_DIR="kaiko-ai/towards_large_pathology_fms" \
+DINO_BACKBONE=="vitbl14" \
FORCE_RELOAD=true \
CHECKPOINT_PATH=[TBD*] \
IN_FEATURES=1024 \
diff --git a/mkdocs.yml b/mkdocs.yml
index a1642ff8..6584e7e6 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -65,6 +65,7 @@ markdown_extensions:
- pymdownx.superfences
nav:
- Introduction: index.md
+ - Leaderboards: leaderboards.md
- User Guide:
- user-guide/index.md
- Getting started:
@@ -78,11 +79,15 @@ nav:
- user-guide/advanced/model_wrappers.md
- Datasets:
- datasets/index.md
- - WSI-patches:
- - BACH: datasets/bach.md
- - CRC: datasets/crc.md
- - MHIST: datasets/mhist.md
- - PatchCamelyon: datasets/patch_camelyon.md
+ - WSI:
+ - Patch-level:
+ - BACH: datasets/bach.md
+ - CRC: datasets/crc.md
+ - MHIST: datasets/mhist.md
+ - PatchCamelyon: datasets/patch_camelyon.md
+ - Slide-level:
+ - Camelyon16: datasets/camelyon16.md
+ - PANDA: datasets/panda.md
- Radiology:
- TotalSegmentator: datasets/total_segmentator.md
- Reference API:
diff --git a/pdm.lock b/pdm.lock
index 59abd0d0..b881e343 100644
--- a/pdm.lock
+++ b/pdm.lock
@@ -1049,6 +1049,7 @@ dependencies = [
]
files = [
{file = "mkdocs-redirects-1.2.1.tar.gz", hash = "sha256:9420066d70e2a6bb357adf86e67023dcdca1857f97f07c7fe450f8f1fb42f861"},
+ {file = "mkdocs_redirects-1.2.1-py3-none-any.whl", hash = "sha256:497089f9e0219e7389304cffefccdfa1cac5ff9509f2cb706f4c9b221726dffb"},
]
[[package]]
@@ -1537,7 +1538,7 @@ files = [
[[package]]
name = "opencv-python-headless"
-version = "4.9.0.80"
+version = "4.10.0.82"
requires_python = ">=3.6"
summary = "Wrapper package for OpenCV python bindings."
groups = ["all", "vision"]
@@ -1552,13 +1553,32 @@ dependencies = [
"numpy>=1.26.0; python_version >= \"3.12\"",
]
files = [
- {file = "opencv-python-headless-4.9.0.80.tar.gz", hash = "sha256:71a4cd8cf7c37122901d8e81295db7fb188730e33a0e40039a4e59c1030b0958"},
- {file = "opencv_python_headless-4.9.0.80-cp37-abi3-macosx_10_16_x86_64.whl", hash = "sha256:2ea8a2edc4db87841991b2fbab55fc07b97ecb602e0f47d5d485bd75cee17c1a"},
- {file = "opencv_python_headless-4.9.0.80-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:e0ee54e27be493e8f7850847edae3128e18b540dac1d7b2e4001b8944e11e1c6"},
- {file = "opencv_python_headless-4.9.0.80-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:57ce2865e8fec431c6f97a81e9faaf23fa5be61011d0a75ccf47a3c0d65fa73d"},
- {file = "opencv_python_headless-4.9.0.80-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:976656362d68d9f40a5c66f83901430538002465f7db59142784f3893918f3df"},
- {file = "opencv_python_headless-4.9.0.80-cp37-abi3-win32.whl", hash = "sha256:11e3849d83e6651d4e7699aadda9ec7ed7c38957cbbcb99db074f2a2d2de9670"},
- {file = "opencv_python_headless-4.9.0.80-cp37-abi3-win_amd64.whl", hash = "sha256:a8056c2cb37cd65dfcdf4153ca16f7362afcf3a50d600d6bb69c660fc61ee29c"},
+ {file = "opencv-python-headless-4.10.0.82.tar.gz", hash = "sha256:de9e742c1b9540816fbd115b0b03841d41ed0c65566b0d7a5371f98b131b7e6d"},
+ {file = "opencv_python_headless-4.10.0.82-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:a09ed50ba21cc5bf5d436cb0e784ad09c692d6b1d1454252772f6c8f2c7b4088"},
+ {file = "opencv_python_headless-4.10.0.82-cp37-abi3-macosx_12_0_x86_64.whl", hash = "sha256:977a5fd21e1fe0d3d2134887db4441f8725abeae95150126302f31fcd9f548fa"},
+ {file = "opencv_python_headless-4.10.0.82-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:db4ec6755838b0be12510bfc9ffb014779c612418f11f4f7e6f505c36124a3aa"},
+ {file = "opencv_python_headless-4.10.0.82-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10a37fa5276967ecf6eb297295b16b28b7a2eb3b568ca0ee469fb1a5954de298"},
+ {file = "opencv_python_headless-4.10.0.82-cp37-abi3-win32.whl", hash = "sha256:94736e9b322d13db4768fd35588ad5e8995e78e207263076bfbee18aac835ad5"},
+ {file = "opencv_python_headless-4.10.0.82-cp37-abi3-win_amd64.whl", hash = "sha256:c1822fa23d1641c0249ed5eb906f4c385f7959ff1bd601a776d56b0c18914af4"},
+]
+
+[[package]]
+name = "openslide-python"
+version = "1.3.1"
+requires_python = ">=3.8"
+summary = "Python interface to OpenSlide"
+groups = ["all", "vision"]
+dependencies = [
+ "Pillow",
+]
+files = [
+ {file = "openslide-python-1.3.1.tar.gz", hash = "sha256:0909c6257cd8decfbbd0082e8c0cd94bbe3a89ad31e142cfa9accc8bb959294e"},
+ {file = "openslide_python-1.3.1-cp310-cp310-win32.whl", hash = "sha256:7a5c0c5bddb518f3e643d0ce2e8d5dfe6b3a374a966ca2c316ef56196dd3c602"},
+ {file = "openslide_python-1.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:d208e53d3da82313213303058b2ca9fc66c2d98365b9338e27ecc46ab8b07e9d"},
+ {file = "openslide_python-1.3.1-cp311-cp311-win32.whl", hash = "sha256:c4720598ba39e7b879e757eff31195f8b80d4638dcb0fbb297ca9823039724ae"},
+ {file = "openslide_python-1.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:53a468cd92bdd17cf5b56592242709519c0c7d7028b2f466d20d75264471cc6d"},
+ {file = "openslide_python-1.3.1-cp312-cp312-win32.whl", hash = "sha256:d10caf1a1c1e1f598d80e7a5e1a266979ed9bccf9ba8bf45aa34cf04639d9f9e"},
+ {file = "openslide_python-1.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:d834fbca0824b902da9d8541f7c34a3e62496823a42de5ac7bf6c35e4c799678"},
]
[[package]]
@@ -1722,18 +1742,18 @@ files = [
[[package]]
name = "protobuf"
-version = "5.26.0"
+version = "4.25.3"
requires_python = ">=3.8"
summary = ""
groups = ["default"]
files = [
- {file = "protobuf-5.26.0-cp310-abi3-win32.whl", hash = "sha256:f9ecc8eb6f18037e0cbf43256db0325d4723f429bca7ef5cd358b7c29d65f628"},
- {file = "protobuf-5.26.0-cp310-abi3-win_amd64.whl", hash = "sha256:dfd29f6eb34107dccf289a93d44fb6b131e68888d090b784b691775ac84e8213"},
- {file = "protobuf-5.26.0-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:7e47c57303466c867374a17b2b5e99c5a7c8b72a94118e2f28efb599f19b4069"},
- {file = "protobuf-5.26.0-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:e184175276edc222e2d5e314a72521e10049938a9a4961fe4bea9b25d073c03f"},
- {file = "protobuf-5.26.0-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:6ee9d1aa02f951c5ce10bf8c6cfb7604133773038e33f913183c8b5201350600"},
- {file = "protobuf-5.26.0-py3-none-any.whl", hash = "sha256:a49b6c5359bf34fb7bf965bf21abfab4476e4527d822ab5289ee3bf73f291159"},
- {file = "protobuf-5.26.0.tar.gz", hash = "sha256:82f5870d74c99addfe4152777bdf8168244b9cf0ac65f8eccf045ddfa9d80d9b"},
+ {file = "protobuf-4.25.3-cp310-abi3-win32.whl", hash = "sha256:d4198877797a83cbfe9bffa3803602bbe1625dc30d8a097365dbc762e5790faa"},
+ {file = "protobuf-4.25.3-cp310-abi3-win_amd64.whl", hash = "sha256:209ba4cc916bab46f64e56b85b090607a676f66b473e6b762e6f1d9d591eb2e8"},
+ {file = "protobuf-4.25.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:f1279ab38ecbfae7e456a108c5c0681e4956d5b1090027c1de0f934dfdb4b35c"},
+ {file = "protobuf-4.25.3-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:e7cb0ae90dd83727f0c0718634ed56837bfeeee29a5f82a7514c03ee1364c019"},
+ {file = "protobuf-4.25.3-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:7c8daa26095f82482307bc717364e7c13f4f1c99659be82890dcfc215194554d"},
+ {file = "protobuf-4.25.3-py3-none-any.whl", hash = "sha256:f0700d54bcf45424477e46a9f0944155b46fb0639d69728739c0e47bab83f2b9"},
+ {file = "protobuf-4.25.3.tar.gz", hash = "sha256:25b5d0b42fd000320bd7830b349e3b696435f3b329810427a6bcce6a5492cc5c"},
]
[[package]]
@@ -2187,7 +2207,7 @@ files = [
[[package]]
name = "tensorboard"
-version = "2.16.2"
+version = "2.17.0"
requires_python = ">=3.9"
summary = "TensorBoard lets you watch Tensors Flow"
groups = ["default"]
@@ -2196,14 +2216,14 @@ dependencies = [
"grpcio>=1.48.2",
"markdown>=2.6.8",
"numpy>=1.12.0",
- "protobuf!=4.24.0,>=3.19.6",
+ "protobuf!=4.24.0,<5.0.0,>=3.19.6",
"setuptools>=41.0.0",
"six>1.9",
"tensorboard-data-server<0.8.0,>=0.7.0",
"werkzeug>=1.0.1",
]
files = [
- {file = "tensorboard-2.16.2-py3-none-any.whl", hash = "sha256:9f2b4e7dad86667615c0e5cd072f1ea8403fc032a299f0072d6f74855775cc45"},
+ {file = "tensorboard-2.17.0-py3-none-any.whl", hash = "sha256:859a499a9b1fb68a058858964486627100b71fcb21646861c61d31846a6478fb"},
]
[[package]]
diff --git a/pyproject.toml b/pyproject.toml
index f3fa10ed..97c2dfb3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -62,6 +62,7 @@ vision = [
"opencv-python-headless>=4.9.0.80",
"timm>=0.9.12",
"torchvision>=0.17.0",
+ "openslide-python>=1.3.1",
]
all = [
"h5py>=3.10.0",
@@ -69,6 +70,7 @@ all = [
"opencv-python-headless>=4.9.0.80",
"timm>=0.9.12",
"torchvision>=0.17.0",
+ "openslide-python>=1.3.1",
]
[project.scripts]
diff --git a/src/eva/core/data/datasets/classification/multi_embeddings.py b/src/eva/core/data/datasets/classification/multi_embeddings.py
index 2d42df76..130c17b7 100644
--- a/src/eva/core/data/datasets/classification/multi_embeddings.py
+++ b/src/eva/core/data/datasets/classification/multi_embeddings.py
@@ -42,6 +42,8 @@ def __init__(
the `root` argument.
split: The dataset split to use. The `split` column of the manifest
file will be splitted based on this value.
+ n_embeddings: Expected number of embeddings per sample. If less, the embeddings
+ will be padded with zeros.
column_mapping: Defines the map between the variables and the manifest
columns. It will overwrite the `default_column_mapping` with
the provided values, so that `column_mapping` can contain only the
diff --git a/src/eva/core/data/datasets/embeddings.py b/src/eva/core/data/datasets/embeddings.py
index 1ad9edb6..81b22ad1 100644
--- a/src/eva/core/data/datasets/embeddings.py
+++ b/src/eva/core/data/datasets/embeddings.py
@@ -19,7 +19,7 @@
"path": "embeddings",
"target": "target",
"split": "split",
- "multi_id": "slide_id",
+ "multi_id": "wsi_id",
}
"""The default column mapping of the variables to the manifest columns."""
diff --git a/src/eva/core/data/splitting/__init__.py b/src/eva/core/data/splitting/__init__.py
new file mode 100644
index 00000000..5faeccd6
--- /dev/null
+++ b/src/eva/core/data/splitting/__init__.py
@@ -0,0 +1,5 @@
+"""Dataset splitting API."""
+
+from eva.core.data.splitting.stratified import stratified_split
+
+__all__ = ["stratified_split"]
diff --git a/src/eva/core/data/splitting/stratified.py b/src/eva/core/data/splitting/stratified.py
new file mode 100644
index 00000000..ad9377a7
--- /dev/null
+++ b/src/eva/core/data/splitting/stratified.py
@@ -0,0 +1,56 @@
+"""Functions for stratified splitting."""
+
+from typing import Any, List, Sequence, Tuple
+
+import numpy as np
+
+
+def stratified_split(
+ samples: Sequence[Any],
+ targets: Sequence[Any],
+ train_ratio: float,
+ val_ratio: float,
+ test_ratio: float = 0.0,
+ seed: int = 42,
+) -> Tuple[List[int], List[int], List[int] | None]:
+ """Splits the samples into stratified train, validation, and test (optional) sets.
+
+ Args:
+ samples: The samples to split.
+ targets: The corresponding targets used for stratification.
+ train_ratio: The ratio of the training set.
+ val_ratio: The ratio of the validation set.
+ test_ratio: The ratio of the test set (optional).
+ seed: The seed for reproducibility.
+
+ Returns:
+ The indices of the train, validation, and test sets.
+ """
+ if len(samples) != len(targets):
+ raise ValueError("The number of samples and targets must be equal.")
+ if train_ratio + val_ratio + (test_ratio or 0) != 1:
+ raise ValueError("The sum of the ratios must be equal to 1.")
+
+ np.random.seed(seed)
+ unique_classes, y_indices = np.unique(targets, return_inverse=True)
+ n_classes = unique_classes.shape[0]
+
+ train_indices, val_indices, test_indices = [], [], []
+
+ for c in range(n_classes):
+ class_indices = np.where(y_indices == c)[0]
+ np.random.shuffle(class_indices)
+
+ n_train = int(np.floor(train_ratio * len(class_indices))) or 1
+ n_val = (
+ len(class_indices) - n_train
+ if test_ratio == 0.0
+ else int(np.floor(val_ratio * len(class_indices))) or 1
+ )
+
+ train_indices.extend(class_indices[:n_train])
+ val_indices.extend(class_indices[n_train : n_train + n_val])
+ if test_ratio > 0.0:
+ test_indices.extend(class_indices[n_train + n_val :])
+
+ return train_indices, val_indices, test_indices or None
diff --git a/src/eva/core/models/networks/mlp.py b/src/eva/core/models/networks/mlp.py
index 4decad2a..c8403dbe 100644
--- a/src/eva/core/models/networks/mlp.py
+++ b/src/eva/core/models/networks/mlp.py
@@ -1,6 +1,6 @@
"""Multi-layer Perceptron (MLP) implemented in PyTorch."""
-from typing import Type
+from typing import Tuple, Type
import torch
import torch.nn as nn
@@ -13,7 +13,7 @@ def __init__(
self,
input_size: int,
output_size: int,
- hidden_layer_sizes: tuple[int, ...] | None = None,
+ hidden_layer_sizes: Tuple[int, ...] | None = None,
hidden_activation_fn: Type[torch.nn.Module] | None = nn.ReLU,
output_activation_fn: Type[torch.nn.Module] | None = None,
dropout: float = 0.0,
diff --git a/src/eva/vision/data/datasets/__init__.py b/src/eva/vision/data/datasets/__init__.py
index ff9ad73e..7d05c16e 100644
--- a/src/eva/vision/data/datasets/__init__.py
+++ b/src/eva/vision/data/datasets/__init__.py
@@ -1,8 +1,17 @@
"""Vision Datasets API."""
-from eva.vision.data.datasets.classification import BACH, CRC, MHIST, PatchCamelyon
+from eva.vision.data.datasets.classification import (
+ BACH,
+ CRC,
+ MHIST,
+ PANDA,
+ Camelyon16,
+ PatchCamelyon,
+ WsiClassificationDataset,
+)
from eva.vision.data.datasets.segmentation import ImageSegmentation, TotalSegmentator2D
from eva.vision.data.datasets.vision import VisionDataset
+from eva.vision.data.datasets.wsi import MultiWsiDataset, WsiDataset
__all__ = [
"BACH",
@@ -10,6 +19,11 @@
"MHIST",
"ImageSegmentation",
"PatchCamelyon",
+ "PANDA",
+ "Camelyon16",
"TotalSegmentator2D",
"VisionDataset",
+ "WsiDataset",
+ "MultiWsiDataset",
+ "WsiClassificationDataset",
]
diff --git a/src/eva/vision/data/datasets/_validators.py b/src/eva/vision/data/datasets/_validators.py
index 9989bc45..ef6407e4 100644
--- a/src/eva/vision/data/datasets/_validators.py
+++ b/src/eva/vision/data/datasets/_validators.py
@@ -13,7 +13,7 @@
def check_dataset_integrity(
dataset: vision.VisionDataset,
*,
- length: int,
+ length: int | None,
n_classes: int,
first_and_last_labels: Tuple[str, str],
) -> None:
@@ -23,7 +23,7 @@ def check_dataset_integrity(
ValueError: If the input dataset's values do not
match the expected ones.
"""
- if len(dataset) != length:
+ if length and len(dataset) != length:
raise ValueError(
f"Dataset's '{dataset.__class__.__qualname__}' length "
f"({len(dataset)}) does not match the expected one ({length}). "
@@ -57,3 +57,16 @@ def check_dataset_exists(dataset_dir: str, download_available: bool) -> None:
if download_available:
error_message += " You can set `download=True` to download the dataset automatically."
raise FileNotFoundError(error_message)
+
+
+def check_number_of_files(file_paths: List[str], expected_length: int, split: str | None) -> None:
+ """Verifies the number of files in the dataset.
+
+ Raise:
+ ValueError: If the number of files in the dataset does not match the expected one.
+ """
+ if len(file_paths) != expected_length:
+ raise ValueError(
+ f"Expected {expected_length} files, for split '{split}' found {len(file_paths)}. "
+ f"{_SUFFIX_ERROR_MESSAGE}"
+ )
diff --git a/src/eva/vision/data/datasets/classification/__init__.py b/src/eva/vision/data/datasets/classification/__init__.py
index a300cfc2..c9daabbe 100644
--- a/src/eva/vision/data/datasets/classification/__init__.py
+++ b/src/eva/vision/data/datasets/classification/__init__.py
@@ -1,8 +1,19 @@
"""Image classification datasets API."""
from eva.vision.data.datasets.classification.bach import BACH
+from eva.vision.data.datasets.classification.camelyon16 import Camelyon16
from eva.vision.data.datasets.classification.crc import CRC
from eva.vision.data.datasets.classification.mhist import MHIST
+from eva.vision.data.datasets.classification.panda import PANDA
from eva.vision.data.datasets.classification.patch_camelyon import PatchCamelyon
+from eva.vision.data.datasets.classification.wsi import WsiClassificationDataset
-__all__ = ["BACH", "CRC", "MHIST", "PatchCamelyon"]
+__all__ = [
+ "BACH",
+ "CRC",
+ "MHIST",
+ "PatchCamelyon",
+ "WsiClassificationDataset",
+ "PANDA",
+ "Camelyon16",
+]
diff --git a/src/eva/vision/data/datasets/classification/base.py b/src/eva/vision/data/datasets/classification/base.py
index fb358cb7..1127f6db 100644
--- a/src/eva/vision/data/datasets/classification/base.py
+++ b/src/eva/vision/data/datasets/classification/base.py
@@ -35,12 +35,11 @@ def classes(self) -> List[str] | None:
def class_to_idx(self) -> Dict[str, int] | None:
"""Returns a mapping of the class name to its target index."""
- def load_metadata(self, index: int | None) -> Dict[str, Any] | List[Dict[str, Any]] | None:
+ def load_metadata(self, index: int) -> Dict[str, Any] | None:
"""Returns the dataset metadata.
Args:
index: The index of the data sample to return the metadata of.
- If `None`, it will return the metadata of the current dataset.
Returns:
The sample metadata.
@@ -74,10 +73,11 @@ def __len__(self) -> int:
raise NotImplementedError
@override
- def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor]:
+ def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
image = self.load_image(index)
target = self.load_target(index)
- return self._apply_transforms(image, target)
+ image, target = self._apply_transforms(image, target)
+ return image, target, self.load_metadata(index) or {}
def _apply_transforms(
self, image: tv_tensors.Image, target: torch.Tensor
diff --git a/src/eva/vision/data/datasets/classification/camelyon16.py b/src/eva/vision/data/datasets/classification/camelyon16.py
new file mode 100644
index 00000000..10846440
--- /dev/null
+++ b/src/eva/vision/data/datasets/classification/camelyon16.py
@@ -0,0 +1,247 @@
+"""Camelyon16 dataset class."""
+
+import functools
+import glob
+import os
+from typing import Any, Callable, Dict, List, Literal, Tuple
+
+import pandas as pd
+import torch
+from torchvision import tv_tensors
+from torchvision.transforms.v2 import functional
+from typing_extensions import override
+
+from eva.vision.data.datasets import _validators, wsi
+from eva.vision.data.datasets.classification import base
+from eva.vision.data.wsi.patching import samplers
+
+
+class Camelyon16(wsi.MultiWsiDataset, base.ImageClassification):
+ """Dataset class for Camelyon16 images and corresponding targets."""
+
+ _val_slides = [
+ "normal_010",
+ "normal_013",
+ "normal_016",
+ "normal_017",
+ "normal_019",
+ "normal_020",
+ "normal_025",
+ "normal_030",
+ "normal_031",
+ "normal_032",
+ "normal_052",
+ "normal_056",
+ "normal_057",
+ "normal_067",
+ "normal_076",
+ "normal_079",
+ "normal_085",
+ "normal_095",
+ "normal_098",
+ "normal_099",
+ "normal_101",
+ "normal_102",
+ "normal_105",
+ "normal_106",
+ "normal_109",
+ "normal_129",
+ "normal_132",
+ "normal_137",
+ "normal_142",
+ "normal_143",
+ "normal_148",
+ "normal_152",
+ "tumor_001",
+ "tumor_005",
+ "tumor_011",
+ "tumor_012",
+ "tumor_013",
+ "tumor_019",
+ "tumor_031",
+ "tumor_037",
+ "tumor_043",
+ "tumor_046",
+ "tumor_057",
+ "tumor_065",
+ "tumor_069",
+ "tumor_071",
+ "tumor_073",
+ "tumor_079",
+ "tumor_080",
+ "tumor_081",
+ "tumor_082",
+ "tumor_085",
+ "tumor_097",
+ "tumor_109",
+ ]
+ """Validation slide names, same as the ones in patch camelyon."""
+
+ def __init__(
+ self,
+ root: str,
+ sampler: samplers.Sampler,
+ split: Literal["train", "val", "test"] | None = None,
+ width: int = 224,
+ height: int = 224,
+ target_mpp: float = 0.5,
+ backend: str = "openslide",
+ image_transforms: Callable | None = None,
+ seed: int = 42,
+ ) -> None:
+ """Initializes the dataset.
+
+ Args:
+ root: Root directory of the dataset.
+ sampler: The sampler to use for sampling patch coordinates.
+ split: Dataset split to use. If `None`, the entire dataset is used.
+ width: Width of the patches to be extracted, in pixels.
+ height: Height of the patches to be extracted, in pixels.
+ target_mpp: Target microns per pixel (mpp) for the patches.
+ backend: The backend to use for reading the whole-slide images.
+ image_transforms: Transforms to apply to the extracted image patches.
+ seed: Random seed for reproducibility.
+ """
+ self._split = split
+ self._root = root
+ self._width = width
+ self._height = height
+ self._target_mpp = target_mpp
+ self._seed = seed
+
+ wsi.MultiWsiDataset.__init__(
+ self,
+ root=root,
+ file_paths=self._load_file_paths(split),
+ width=width,
+ height=height,
+ sampler=sampler,
+ target_mpp=target_mpp,
+ backend=backend,
+ image_transforms=image_transforms,
+ )
+
+ @property
+ @override
+ def classes(self) -> List[str]:
+ return ["normal", "tumor"]
+
+ @property
+ @override
+ def class_to_idx(self) -> Dict[str, int]:
+ return {"normal": 0, "tumor": 1}
+
+ @functools.cached_property
+ def annotations_test_set(self) -> Dict[str, str]:
+ """Loads the dataset labels."""
+ path = os.path.join(self._root, "testing/reference.csv")
+ reference_df = pd.read_csv(path, header=None)
+ return {k: v.lower() for k, v in reference_df[[0, 1]].itertuples(index=False)}
+
+ @functools.cached_property
+ def annotations(self) -> Dict[str, str]:
+ """Loads the dataset labels."""
+ annotations = {}
+ if self._split in ["test", None]:
+ path = os.path.join(self._root, "testing/reference.csv")
+ reference_df = pd.read_csv(path, header=None)
+ annotations.update(
+ {k: v.lower() for k, v in reference_df[[0, 1]].itertuples(index=False)}
+ )
+
+ if self._split in ["train", "val", None]:
+ annotations.update(
+ {
+ self._get_id_from_path(file_path): self._get_class_from_path(file_path)
+ for file_path in self._file_paths
+ if "test" not in file_path
+ }
+ )
+ return annotations
+
+ @override
+ def prepare_data(self) -> None:
+ _validators.check_dataset_exists(self._root, True)
+
+ expected_directories = ["training/normal", "training/tumor", "testing/images"]
+ for resource in expected_directories:
+ if not os.path.isdir(os.path.join(self._root, resource)):
+ raise FileNotFoundError(f"'{resource}' not found in the root folder.")
+
+ if not os.path.isfile(os.path.join(self._root, "testing/reference.csv")):
+ raise FileNotFoundError("'reference.csv' file not found in the testing folder.")
+
+ @override
+ def validate(self) -> None:
+
+ expected_n_files = {
+ "train": 216,
+ "val": 54,
+ "test": 129,
+ None: 399,
+ }
+ length = expected_n_files[self._split]
+ _validators.check_number_of_files(self._file_paths, length, self._split)
+ _validators.check_dataset_integrity(
+ self,
+ length=None,
+ n_classes=2,
+ first_and_last_labels=("normal", "tumor"),
+ )
+
+ @override
+ def filename(self, index: int) -> str:
+ return os.path.basename(self._file_paths[self._get_dataset_idx(index)])
+
+ @override
+ def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
+ return base.ImageClassification.__getitem__(self, index)
+
+ @override
+ def load_image(self, index: int) -> tv_tensors.Image:
+ image_array = wsi.MultiWsiDataset.__getitem__(self, index)
+ return functional.to_image(image_array)
+
+ @override
+ def load_target(self, index: int) -> torch.Tensor:
+ file_path = self._file_paths[self._get_dataset_idx(index)]
+ class_name = self.annotations[self._get_id_from_path(file_path)]
+ return torch.tensor(self.class_to_idx[class_name], dtype=torch.int64)
+
+ @override
+ def load_metadata(self, index: int) -> Dict[str, Any]:
+ return {"wsi_id": self.filename(index).split(".")[0]}
+
+ def _load_file_paths(self, split: Literal["train", "val", "test"] | None = None) -> List[str]:
+ """Loads the file paths of the corresponding dataset split."""
+ train_paths, val_paths = [], []
+ for path in glob.glob(os.path.join(self._root, "training/**/*.tif")):
+ if self._get_id_from_path(path) in self._val_slides:
+ val_paths.append(path)
+ else:
+ train_paths.append(path)
+ test_paths = glob.glob(os.path.join(self._root, "testing/images", "*.tif"))
+
+ match split:
+ case "train":
+ paths = train_paths
+ case "val":
+ paths = val_paths
+ case "test":
+ paths = test_paths
+ case None:
+ paths = train_paths + val_paths + test_paths
+ case _:
+ raise ValueError("Invalid split. Use 'train', 'val' or `None`.")
+ return sorted([os.path.relpath(path, self._root) for path in paths])
+
+ def _get_id_from_path(self, file_path: str) -> str:
+ """Extracts the slide ID from the file path."""
+ return os.path.basename(file_path).replace(".tif", "")
+
+ def _get_class_from_path(self, file_path: str) -> str:
+ """Extracts the class name from the file path."""
+ class_name = self._get_id_from_path(file_path).split("_")[0]
+ if class_name not in self.classes:
+ raise ValueError(f"Invalid class name '{class_name}' in file path '{file_path}'.")
+ return class_name
diff --git a/src/eva/vision/data/datasets/classification/panda.py b/src/eva/vision/data/datasets/classification/panda.py
new file mode 100644
index 00000000..b8d2f49c
--- /dev/null
+++ b/src/eva/vision/data/datasets/classification/panda.py
@@ -0,0 +1,188 @@
+"""PANDA dataset class."""
+
+import functools
+import glob
+import os
+from typing import Any, Callable, Dict, List, Literal, Tuple
+
+import pandas as pd
+import torch
+from torchvision import tv_tensors
+from torchvision.datasets import utils
+from torchvision.transforms.v2 import functional
+from typing_extensions import override
+
+from eva.core.data import splitting
+from eva.vision.data.datasets import _validators, structs, wsi
+from eva.vision.data.datasets.classification import base
+from eva.vision.data.wsi.patching import samplers
+
+
+class PANDA(wsi.MultiWsiDataset, base.ImageClassification):
+ """Dataset class for PANDA images and corresponding targets."""
+
+ _train_split_ratio: float = 0.7
+ """Train split ratio."""
+
+ _val_split_ratio: float = 0.15
+ """Validation split ratio."""
+
+ _test_split_ratio: float = 0.15
+ """Test split ratio."""
+
+ _resources: List[structs.DownloadResource] = [
+ structs.DownloadResource(
+ filename="train_with_noisy_labels.csv",
+ url="https://raw.githubusercontent.com/analokmaus/kaggle-panda-challenge-public/master/train.csv",
+ md5="5e4bfc78bda9603d2e2faf3ed4b21dfa",
+ )
+ ]
+ """Download resources."""
+
+ def __init__(
+ self,
+ root: str,
+ sampler: samplers.Sampler,
+ split: Literal["train", "val", "test"] | None = None,
+ width: int = 224,
+ height: int = 224,
+ target_mpp: float = 0.5,
+ backend: str = "openslide",
+ image_transforms: Callable | None = None,
+ seed: int = 42,
+ ) -> None:
+ """Initializes the dataset.
+
+ Args:
+ root: Root directory of the dataset.
+ sampler: The sampler to use for sampling patch coordinates.
+ split: Dataset split to use. If `None`, the entire dataset is used.
+ width: Width of the patches to be extracted, in pixels.
+ height: Height of the patches to be extracted, in pixels.
+ target_mpp: Target microns per pixel (mpp) for the patches.
+ backend: The backend to use for reading the whole-slide images.
+ image_transforms: Transforms to apply to the extracted image patches.
+ seed: Random seed for reproducibility.
+ """
+ self._split = split
+ self._root = root
+ self._seed = seed
+
+ self._download_resources()
+
+ wsi.MultiWsiDataset.__init__(
+ self,
+ root=root,
+ file_paths=self._load_file_paths(split),
+ width=width,
+ height=height,
+ sampler=sampler,
+ target_mpp=target_mpp,
+ backend=backend,
+ image_transforms=image_transforms,
+ )
+
+ @property
+ @override
+ def classes(self) -> List[str]:
+ return ["0", "1", "2", "3", "4", "5"]
+
+ @functools.cached_property
+ def annotations(self) -> pd.DataFrame:
+ """Loads the dataset labels."""
+ path = os.path.join(self._root, "train_with_noisy_labels.csv")
+ return pd.read_csv(path, index_col="image_id")
+
+ @override
+ def prepare_data(self) -> None:
+ _validators.check_dataset_exists(self._root, False)
+
+ if not os.path.isdir(os.path.join(self._root, "train_images")):
+ raise FileNotFoundError("'train_images' directory not found in the root folder.")
+ if not os.path.isfile(os.path.join(self._root, "train_with_noisy_labels.csv")):
+ raise FileNotFoundError("'train.csv' file not found in the root folder.")
+
+ def _download_resources(self) -> None:
+ """Downloads the dataset resources."""
+ for resource in self._resources:
+ utils.download_url(resource.url, self._root, resource.filename, resource.md5)
+
+ @override
+ def validate(self) -> None:
+ _validators.check_dataset_integrity(
+ self,
+ length=None,
+ n_classes=6,
+ first_and_last_labels=("0", "5"),
+ )
+
+ @override
+ def filename(self, index: int) -> str:
+ return os.path.basename(self._file_paths[self._get_dataset_idx(index)])
+
+ @override
+ def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
+ return base.ImageClassification.__getitem__(self, index)
+
+ @override
+ def load_image(self, index: int) -> tv_tensors.Image:
+ image_array = wsi.MultiWsiDataset.__getitem__(self, index)
+ return functional.to_image(image_array)
+
+ @override
+ def load_target(self, index: int) -> torch.Tensor:
+ file_path = self._file_paths[self._get_dataset_idx(index)]
+ return torch.tensor(self._get_target_from_path(file_path), dtype=torch.int64)
+
+ @override
+ def load_metadata(self, index: int) -> Dict[str, Any]:
+ return {"wsi_id": self.filename(index).split(".")[0]}
+
+ def _load_file_paths(self, split: Literal["train", "val", "test"] | None = None) -> List[str]:
+ """Loads the file paths of the corresponding dataset split."""
+ image_dir = os.path.join(self._root, "train_images")
+ file_paths = sorted(glob.glob(os.path.join(image_dir, "*.tiff")))
+ file_paths = [os.path.relpath(path, self._root) for path in file_paths]
+ if len(file_paths) != len(self.annotations):
+ raise ValueError(
+ f"Expected {len(self.annotations)} images, found {len(file_paths)} in {image_dir}."
+ )
+ file_paths = self._filter_noisy_labels(file_paths)
+ targets = [self._get_target_from_path(file_path) for file_path in file_paths]
+
+ train_indices, val_indices, test_indices = splitting.stratified_split(
+ samples=file_paths,
+ targets=targets,
+ train_ratio=self._train_split_ratio,
+ val_ratio=self._val_split_ratio,
+ test_ratio=self._test_split_ratio,
+ seed=self._seed,
+ )
+
+ match split:
+ case "train":
+ return [file_paths[i] for i in train_indices]
+ case "val":
+ return [file_paths[i] for i in val_indices]
+ case "test":
+ return [file_paths[i] for i in test_indices or []]
+ case None:
+ return file_paths
+ case _:
+ raise ValueError("Invalid split. Use 'train', 'val' or `None`.")
+
+ def _filter_noisy_labels(self, file_paths: List[str]):
+ is_noisy_filter = self.annotations["noise_ratio_10"] == 0
+ non_noisy_image_ids = set(self.annotations.loc[~is_noisy_filter].index)
+ filtered_file_paths = [
+ file_path
+ for file_path in file_paths
+ if self._get_id_from_path(file_path) in non_noisy_image_ids
+ ]
+ return filtered_file_paths
+
+ def _get_target_from_path(self, file_path: str) -> int:
+ return self.annotations.loc[self._get_id_from_path(file_path), "isup_grade"]
+
+ def _get_id_from_path(self, file_path: str) -> str:
+ return os.path.basename(file_path).replace(".tiff", "")
diff --git a/src/eva/vision/data/datasets/classification/wsi.py b/src/eva/vision/data/datasets/classification/wsi.py
new file mode 100644
index 00000000..3889be1e
--- /dev/null
+++ b/src/eva/vision/data/datasets/classification/wsi.py
@@ -0,0 +1,105 @@
+"""WSI classification dataset."""
+
+import os
+from typing import Any, Callable, Dict, Literal, Tuple
+
+import numpy as np
+import pandas as pd
+import torch
+from torchvision import tv_tensors
+from typing_extensions import override
+
+from eva.vision.data.datasets import wsi
+from eva.vision.data.datasets.classification import base
+from eva.vision.data.wsi.patching import samplers
+
+
+class WsiClassificationDataset(wsi.MultiWsiDataset, base.ImageClassification):
+ """A general dataset class for whole-slide image classification using manifest files."""
+
+ default_column_mapping: Dict[str, str] = {
+ "path": "path",
+ "target": "target",
+ "split": "split",
+ }
+
+ def __init__(
+ self,
+ root: str,
+ manifest_file: str,
+ width: int,
+ height: int,
+ target_mpp: float,
+ sampler: samplers.Sampler,
+ backend: str = "openslide",
+ split: Literal["train", "val", "test"] | None = None,
+ image_transforms: Callable | None = None,
+ column_mapping: Dict[str, str] = default_column_mapping,
+ ):
+ """Initializes the dataset.
+
+ Args:
+ root: Root directory of the dataset.
+ manifest_file: The path to the manifest file, relative to
+ the `root` argument. The `path` column is expected to contain
+ relative paths to the whole-slide images.
+ width: Width of the patches to be extracted, in pixels.
+ height: Height of the patches to be extracted, in pixels.
+ target_mpp: Target microns per pixel (mpp) for the patches.
+ sampler: The sampler to use for sampling patch coordinates.
+ backend: The backend to use for reading the whole-slide images.
+ split: The split of the dataset to load.
+ image_transforms: Transforms to apply to the extracted image patches.
+ column_mapping: Mapping of the columns in the manifest file.
+ """
+ self._split = split
+ self._column_mapping = self.default_column_mapping | column_mapping
+ self._manifest = self._load_manifest(os.path.join(root, manifest_file))
+
+ wsi.MultiWsiDataset.__init__(
+ self,
+ root=root,
+ file_paths=self._manifest[self._column_mapping["path"]].tolist(),
+ width=width,
+ height=height,
+ sampler=sampler,
+ target_mpp=target_mpp,
+ backend=backend,
+ image_transforms=image_transforms,
+ )
+
+ @override
+ def filename(self, index: int) -> str:
+ path = self._manifest.at[self._get_dataset_idx(index), self._column_mapping["path"]]
+ return os.path.basename(path) if os.path.isabs(path) else path
+
+ @override
+ def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
+ return base.ImageClassification.__getitem__(self, index)
+
+ @override
+ def load_image(self, index: int) -> np.ndarray:
+ return wsi.MultiWsiDataset.__getitem__(self, index)
+
+ @override
+ def load_target(self, index: int) -> np.ndarray:
+ target = self._manifest.at[self._get_dataset_idx(index), self._column_mapping["target"]]
+ return np.asarray(target)
+
+ @override
+ def load_metadata(self, index: int) -> Dict[str, Any]:
+ return {"wsi_id": self.filename(index).split(".")[0]}
+
+ def _load_manifest(self, manifest_path: str) -> pd.DataFrame:
+ df = pd.read_csv(manifest_path)
+
+ missing_columns = set(self._column_mapping.values()) - set(df.columns)
+ if self._split is None:
+ missing_columns = missing_columns - {self._column_mapping["split"]}
+ if missing_columns:
+ raise ValueError(f"Missing columns in the manifest file: {missing_columns}")
+
+ if self._split is not None:
+ df = df.loc[df[self._column_mapping["split"]] == self._split]
+
+ return df.reset_index(drop=True)
diff --git a/src/eva/vision/data/datasets/wsi.py b/src/eva/vision/data/datasets/wsi.py
new file mode 100644
index 00000000..3557bfc5
--- /dev/null
+++ b/src/eva/vision/data/datasets/wsi.py
@@ -0,0 +1,150 @@
+"""Dataset classes for whole-slide images."""
+
+import bisect
+import os
+from typing import Callable, List
+
+from loguru import logger
+from torch.utils.data import dataset as torch_datasets
+from torchvision import tv_tensors
+from torchvision.transforms.v2 import functional
+from typing_extensions import override
+
+from eva.core.data.datasets import base
+from eva.vision.data import wsi
+from eva.vision.data.datasets import vision
+from eva.vision.data.wsi.patching import samplers
+
+
+class WsiDataset(vision.VisionDataset):
+ """Dataset class for reading patches from whole-slide images."""
+
+ def __init__(
+ self,
+ file_path: str,
+ width: int,
+ height: int,
+ target_mpp: float,
+ sampler: samplers.Sampler,
+ backend: str = "openslide",
+ image_transforms: Callable | None = None,
+ ):
+ """Initializes a new dataset instance.
+
+ Args:
+ file_path: Path to the whole-slide image file.
+ width: Width of the patches to be extracted, in pixels.
+ height: Height of the patches to be extracted, in pixels.
+ target_mpp: Target microns per pixel (mpp) for the patches.
+ sampler: The sampler to use for sampling patch coordinates.
+ backend: The backend to use for reading the whole-slide images.
+ image_transforms: Transforms to apply to the extracted image patches.
+ """
+ self._file_path = file_path
+ self._width = width
+ self._height = height
+ self._target_mpp = target_mpp
+ self._sampler = sampler
+ self._backend = backend
+ self._image_transforms = image_transforms
+
+ @override
+ def __len__(self):
+ return len(self._coords.x_y)
+
+ @override
+ def filename(self, index: int) -> str:
+ return f"{self._file_path}_{index}"
+
+ @property
+ def _wsi(self) -> wsi.Wsi:
+ return wsi.get_cached_wsi(self._file_path, self._backend)
+
+ @property
+ def _coords(self) -> wsi.PatchCoordinates:
+ return wsi.get_cached_coords(
+ file_path=self._file_path,
+ width=self._width,
+ height=self._height,
+ target_mpp=self._target_mpp,
+ sampler=self._sampler,
+ backend=self._backend,
+ )
+
+ @override
+ def __getitem__(self, index: int) -> tv_tensors.Image:
+ x, y = self._coords.x_y[index]
+ width, height, level_idx = self._coords.width, self._coords.height, self._coords.level_idx
+ patch = self._wsi.read_region((x, y), level_idx, (width, height))
+ patch = functional.to_image(patch)
+ patch = self._apply_transforms(patch)
+ return patch
+
+ def _apply_transforms(self, image: tv_tensors.Image) -> tv_tensors.Image:
+ if self._image_transforms is not None:
+ image = self._image_transforms(image)
+ return image
+
+
+class MultiWsiDataset(torch_datasets.ConcatDataset, base.Dataset):
+ """Dataset class for reading patches from multiple whole-slide images."""
+
+ def __init__(
+ self,
+ root: str,
+ file_paths: List[str],
+ width: int,
+ height: int,
+ target_mpp: float,
+ sampler: samplers.Sampler,
+ backend: str = "openslide",
+ image_transforms: Callable | None = None,
+ ):
+ """Initializes a new dataset instance.
+
+ Args:
+ root: Root directory of the dataset.
+ file_paths: List of paths to the whole-slide image files, relative to the root.
+ width: Width of the patches to be extracted, in pixels.
+ height: Height of the patches to be extracted, in pixels.
+ target_mpp: Target microns per pixel (mpp) for the patches.
+ sampler: The sampler to use for sampling patch coordinates.
+ backend: The backend to use for reading the whole-slide images.
+ image_transforms: Transforms to apply to the extracted image patches.
+ """
+ self._root = root
+ self._file_paths = file_paths
+ self._width = width
+ self._height = height
+ self._target_mpp = target_mpp
+ self._sampler = sampler
+ self._backend = backend
+ self._image_transforms = image_transforms
+
+ @override
+ def setup(self):
+ super().__init__(self._load_datasets())
+
+ def _load_datasets(self) -> list[WsiDataset]:
+ logger.info(f"Initializing dataset with {len(self._file_paths)} WSIs ...")
+ wsi_datasets = []
+ for file_path in self._file_paths:
+ file_path = os.path.join(self._root, file_path) if self._root else file_path
+ if not os.path.exists(file_path):
+ raise FileNotFoundError(f"File not found: {file_path}")
+
+ wsi_datasets.append(
+ WsiDataset(
+ file_path=file_path,
+ width=self._width,
+ height=self._height,
+ target_mpp=self._target_mpp,
+ sampler=self._sampler,
+ backend=self._backend,
+ image_transforms=self._image_transforms,
+ )
+ )
+ return wsi_datasets
+
+ def _get_dataset_idx(self, index: int) -> int:
+ return bisect.bisect_right(self.cumulative_sizes, index)
diff --git a/src/eva/vision/data/wsi/__init__.py b/src/eva/vision/data/wsi/__init__.py
new file mode 100644
index 00000000..116fec74
--- /dev/null
+++ b/src/eva/vision/data/wsi/__init__.py
@@ -0,0 +1,16 @@
+"""WSI API."""
+
+from eva.vision.data.wsi.backends import Wsi, get_cached_wsi, wsi_backend
+from eva.vision.data.wsi.patching.coordinates import PatchCoordinates, get_cached_coords
+from eva.vision.data.wsi.patching.mask import Mask, get_mask, get_mask_level
+
+__all__ = [
+ "Wsi",
+ "PatchCoordinates",
+ "Mask",
+ "get_cached_coords",
+ "wsi_backend",
+ "get_cached_wsi",
+ "get_mask",
+ "get_mask_level",
+]
diff --git a/src/eva/vision/data/wsi/backends/__init__.py b/src/eva/vision/data/wsi/backends/__init__.py
new file mode 100644
index 00000000..273707bf
--- /dev/null
+++ b/src/eva/vision/data/wsi/backends/__init__.py
@@ -0,0 +1,65 @@
+"""WSI Backends API."""
+
+import functools
+import importlib.util
+from typing import Callable
+
+from eva.vision.data.wsi.backends.base import Wsi
+
+LRU_CACHE_SIZE = 32
+
+
+def _is_openslide_available() -> bool:
+ """Whether the OpenSlide library is available."""
+ return importlib.util.find_spec("openslide") is not None
+
+
+def _is_tiffslide_available() -> bool:
+ """Whether the TiffSlide library is available."""
+ return importlib.util.find_spec("tiffslide") is not None
+
+
+def is_backend_available(backend: str) -> bool:
+ """Whether the specified backend is available."""
+ match backend:
+ case "openslide":
+ return _is_openslide_available()
+ case "tiffslide":
+ return _is_tiffslide_available()
+ return False
+
+
+def wsi_backend(backend: str = "openslide") -> Callable[..., Wsi]:
+ """Returns the backend to use for reading the whole-slide images."""
+ match backend:
+ case "openslide":
+ if _is_openslide_available():
+ from eva.vision.data.wsi.backends.openslide import WsiOpenslide
+
+ return WsiOpenslide
+ else:
+ raise ValueError(
+ "Missing optional dependency: openslide.\n"
+ "Please install using `pip install openslide-python`."
+ )
+ case "tiffslide":
+ if _is_tiffslide_available():
+ from eva.vision.data.wsi.backends.tiffslide import WsiTiffslide
+
+ return WsiTiffslide
+ else:
+ raise ValueError(
+ "Missing optional dependency: tiffslide.\n"
+ "Please install using `pip install tiffslide`."
+ )
+ case _:
+ raise ValueError(f"Unknown WSI backend selected: {backend}")
+
+
+@functools.lru_cache(LRU_CACHE_SIZE)
+def get_cached_wsi(file_path: str, backend: str) -> Wsi:
+ """Returns a cached instance of the whole-slide image backend reader."""
+ return wsi_backend(backend)(file_path)
+
+
+__all__ = ["Wsi", "wsi_backend", "get_cached_wsi", "_is_openslide_available"]
diff --git a/src/eva/vision/data/wsi/backends/base.py b/src/eva/vision/data/wsi/backends/base.py
new file mode 100644
index 00000000..b8cc4e06
--- /dev/null
+++ b/src/eva/vision/data/wsi/backends/base.py
@@ -0,0 +1,113 @@
+"""Base Module for loading data from WSI files."""
+
+import abc
+from typing import Any, Sequence, Tuple
+
+import numpy as np
+
+
+class Wsi(abc.ABC):
+ """Base class for loading data from Whole Slide Image (WSI) files."""
+
+ def __init__(self, file_path: str):
+ """Initializes a Wsi object.
+
+ Args:
+ file_path: The path to the WSI file.
+ """
+ self._wsi = self.open_file(file_path)
+
+ @abc.abstractmethod
+ def open_file(self, file_path: str) -> Any:
+ """Opens the WSI file.
+
+ Args:
+ file_path: The path to the WSI file.
+ """
+
+ @property
+ @abc.abstractmethod
+ def level_dimensions(self) -> Sequence[Tuple[int, int]]:
+ """A list of (width, height) tuples for each level, from highest to lowest resolution."""
+
+ @property
+ @abc.abstractmethod
+ def level_downsamples(self) -> Sequence[float]:
+ """A list of downsampling factors for each level, relative to the highest resolution."""
+
+ @property
+ @abc.abstractmethod
+ def mpp(self) -> float:
+ """Microns per pixel at the highest resolution (level 0)."""
+
+ @abc.abstractmethod
+ def _read_region(
+ self, location: Tuple[int, int], level: int, size: Tuple[int, int]
+ ) -> np.ndarray:
+ """Abstract method to read a region at a specified zoom level."""
+
+ def read_region(
+ self, location: Tuple[int, int], level: int, size: Tuple[int, int]
+ ) -> np.ndarray:
+ """Reads and returns image data for a specified region and zoom level.
+
+ Args:
+ location: Top-left corner (x, y) to start reading at level 0.
+ level: WSI level to read from.
+ size: Region size as (width, height) in pixels at the selected read level.
+ Remember to scale the size correctly.
+ """
+ self._verify_location(location, size)
+ data = self._read_region(location, level, size)
+ return self._read_postprocess(data)
+
+ def get_closest_level(self, target_mpp: float) -> int:
+ """Calculate the slide level that is closest to the target mpp.
+
+ Args:
+ slide: The whole-slide image object.
+ target_mpp: The target microns per pixel (mpp) value.
+ """
+ # Calculate the mpp for each level
+ level_mpps = self.mpp * np.array(self.level_downsamples)
+
+ # Ignore levels with higher mpp
+ level_mpps_filtered = level_mpps.copy()
+ level_mpps_filtered[level_mpps_filtered > target_mpp] = 0
+
+ if level_mpps_filtered.max() == 0:
+ # When all levels have higher mpp than target_mpp return the level with lowest mpp
+ level_idx = np.argmin(level_mpps)
+ else:
+ level_idx = np.argmax(level_mpps_filtered)
+
+ return int(level_idx)
+
+ def _verify_location(self, location: Tuple[int, int], size: Tuple[int, int]) -> None:
+ """Verifies that the requested region is within the slide dimensions.
+
+ Args:
+ location: Top-left corner (x, y) to start reading at level 0.
+ size: Region size as (width, height) in pixels at the selected read level.
+ """
+ x_max, y_max = self.level_dimensions[0]
+ x_scale = x_max / self.level_dimensions[0][0]
+ y_scale = y_max / self.level_dimensions[0][1]
+
+ if (
+ int(location[0] + x_scale * size[0]) > x_max
+ or int(location[1] + y_scale * size[1]) > y_max
+ ):
+ raise ValueError(f"Out of bounds region: {location}, {size}")
+
+ def _read_postprocess(self, data: np.ndarray) -> np.ndarray:
+ """Post-processes the read region data.
+
+ Args:
+ data: The read region data as a numpy array of shape (height, width, channels).
+ """
+ # Change color to white where the alpha channel is 0
+ if data.shape[2] == 4:
+ data[data[:, :, 3] == 0] = 255
+
+ return data[:, :, :3]
diff --git a/src/eva/vision/data/wsi/backends/openslide.py b/src/eva/vision/data/wsi/backends/openslide.py
new file mode 100644
index 00000000..10c7f8a9
--- /dev/null
+++ b/src/eva/vision/data/wsi/backends/openslide.py
@@ -0,0 +1,73 @@
+"""Module for loading data from WSI files using the OpenSlide library."""
+
+from typing import Sequence, Tuple
+
+import numpy as np
+import openslide
+from typing_extensions import override
+
+from eva.vision.data.wsi.backends import base
+
+
+class WsiOpenslide(base.Wsi):
+ """Class for loading data from WSI files using the OpenSlide library."""
+
+ _wsi: openslide.OpenSlide
+
+ @override
+ def open_file(self, file_path: str) -> openslide.OpenSlide:
+ return openslide.OpenSlide(file_path)
+
+ @property
+ @override
+ def level_dimensions(self) -> Sequence[Tuple[int, int]]:
+ return self._wsi.level_dimensions
+
+ @property
+ @override
+ def level_downsamples(self) -> Sequence[float]:
+ return self._wsi.level_downsamples
+
+ @property
+ @override
+ def mpp(self) -> float:
+ # TODO: add overwrite_mpp class attribute to allow setting a default value
+ if self._wsi.properties.get(openslide.PROPERTY_NAME_MPP_X) and self._wsi.properties.get(
+ openslide.PROPERTY_NAME_MPP_Y
+ ):
+ x_mpp = float(self._wsi.properties[openslide.PROPERTY_NAME_MPP_X])
+ y_mpp = float(self._wsi.properties[openslide.PROPERTY_NAME_MPP_Y])
+ elif (
+ self._wsi.properties.get("tiff.XResolution")
+ and self._wsi.properties.get("tiff.YResolution")
+ and self._wsi.properties.get("tiff.ResolutionUnit")
+ ):
+ unit = self._wsi.properties.get("tiff.ResolutionUnit")
+ if unit not in _conversion_factor_to_micrometer:
+ raise ValueError(f"Unit {unit} not supported.")
+
+ conversion_factor = float(_conversion_factor_to_micrometer.get(unit)) # type: ignore
+ x_mpp = conversion_factor / float(self._wsi.properties["tiff.XResolution"])
+ y_mpp = conversion_factor / float(self._wsi.properties["tiff.YResolution"])
+ else:
+ raise ValueError("`mpp` cannot be obtained for this slide.")
+
+ return (x_mpp + y_mpp) / 2.0
+
+ @override
+ def _read_region(
+ self, location: Tuple[int, int], level: int, size: Tuple[int, int]
+ ) -> np.ndarray:
+ return np.array(self._wsi.read_region(location, level, size))
+
+
+_conversion_factor_to_micrometer = {
+ "meter": 10**6,
+ "decimeter": 10**5,
+ "centimeter": 10**4,
+ "millimeter": 10**3,
+ "micrometer": 1,
+ "nanometer": 10**-3,
+ "picometer": 10**-6,
+ "femtometer": 10**-9,
+}
diff --git a/src/eva/vision/data/wsi/backends/tiffslide.py b/src/eva/vision/data/wsi/backends/tiffslide.py
new file mode 100644
index 00000000..7577e19e
--- /dev/null
+++ b/src/eva/vision/data/wsi/backends/tiffslide.py
@@ -0,0 +1,42 @@
+"""Module for loading data from WSI files using the OpenSlide library."""
+
+from typing import Sequence, Tuple
+
+import numpy as np
+import tiffslide # type: ignore
+from typing_extensions import override
+
+from eva.vision.data.wsi.backends import base
+
+
+class WsiTiffslide(base.Wsi):
+ """Class for loading data from WSI files using the TiffSlide library."""
+
+ _wsi: tiffslide.TiffSlide
+
+ @override
+ def open_file(self, file_path: str) -> tiffslide.TiffSlide:
+ return tiffslide.TiffSlide(file_path)
+
+ @property
+ @override
+ def level_dimensions(self) -> Sequence[Tuple[int, int]]:
+ return self._wsi.level_dimensions
+
+ @property
+ @override
+ def level_downsamples(self) -> Sequence[float]:
+ return self._wsi.level_downsamples
+
+ @property
+ @override
+ def mpp(self) -> float:
+ x_mpp = float(self._wsi.properties[tiffslide.PROPERTY_NAME_MPP_X])
+ y_mpp = float(self._wsi.properties[tiffslide.PROPERTY_NAME_MPP_Y])
+ return (x_mpp + y_mpp) / 2.0
+
+ @override
+ def _read_region(
+ self, location: Tuple[int, int], level: int, size: Tuple[int, int]
+ ) -> np.ndarray:
+ return np.array(self._wsi.read_region(location, level, size))
diff --git a/src/eva/vision/data/wsi/patching/__init__.py b/src/eva/vision/data/wsi/patching/__init__.py
new file mode 100644
index 00000000..f14b1b51
--- /dev/null
+++ b/src/eva/vision/data/wsi/patching/__init__.py
@@ -0,0 +1,5 @@
+"""WSI Patching API."""
+
+from eva.vision.data.wsi.patching.coordinates import PatchCoordinates
+
+__all__ = ["PatchCoordinates"]
diff --git a/src/eva/vision/data/wsi/patching/coordinates.py b/src/eva/vision/data/wsi/patching/coordinates.py
new file mode 100644
index 00000000..0600db98
--- /dev/null
+++ b/src/eva/vision/data/wsi/patching/coordinates.py
@@ -0,0 +1,94 @@
+"""A module for handling coordinates of patches from a whole-slide image."""
+
+import dataclasses
+import functools
+from typing import List, Tuple
+
+from eva.vision.data.wsi import backends
+from eva.vision.data.wsi.patching import samplers
+from eva.vision.data.wsi.patching.mask import Mask, get_mask, get_mask_level
+
+LRU_CACHE_SIZE = 32
+
+
+@dataclasses.dataclass
+class PatchCoordinates:
+ """A class to store coordinates of patches from a whole-slide image.
+
+ Args:
+ x_y: A list of (x, y) coordinates of the patches (refer to level 0).
+ width: The width of the patches, in pixels (refers to level_idx).
+ height: The height of the patches, in pixels (refers to level_idx).
+ level_idx: The level index at which to extract the patches.
+ mask: The foreground mask of the wsi.
+ """
+
+ x_y: List[Tuple[int, int]]
+ width: int
+ height: int
+ level_idx: int
+ mask: Mask | None = None
+
+ @classmethod
+ def from_file(
+ cls,
+ wsi_path: str,
+ width: int,
+ height: int,
+ target_mpp: float,
+ sampler: samplers.Sampler,
+ backend: str = "openslide",
+ ) -> "PatchCoordinates":
+ """Create a new instance of PatchCoordinates from a whole-slide image file.
+
+ Patches will be read from the level that is closest to the specified target_mpp.
+
+ Args:
+ wsi_path: The path to the whole-slide image file.
+ width: The width of the patches to be extracted, in pixels.
+ height: The height of the patches to be extracted, in pixels.
+ target_mpp: The target microns per pixel (mpp) for the patches.
+ sampler: The sampler to use for sampling patch coordinates.
+ backend: The backend to use for reading the whole-slide images.
+ """
+ wsi = backends.wsi_backend(backend)(wsi_path)
+
+ # Sample patch coordinates at level 0
+ mpp_ratio_0 = target_mpp / wsi.mpp
+ sample_args = {
+ "width": int(mpp_ratio_0 * width),
+ "height": int(mpp_ratio_0 * height),
+ "layer_shape": wsi.level_dimensions[0],
+ }
+ if isinstance(sampler, samplers.ForegroundSampler):
+ mask_level_idx = get_mask_level(wsi, width, height, target_mpp)
+ sample_args["mask"] = get_mask(wsi, mask_level_idx)
+
+ x_y = list(sampler.sample(**sample_args))
+
+ # Scale dimensions to level that is closest to the target_mpp
+ level_idx = wsi.get_closest_level(target_mpp)
+ mpp_ratio = target_mpp / (wsi.mpp * wsi.level_downsamples[level_idx])
+ scaled_width, scaled_height = int(mpp_ratio * width), int(mpp_ratio * height)
+
+ return cls(x_y, scaled_width, scaled_height, level_idx, sample_args.get("mask"))
+
+
+@functools.lru_cache(LRU_CACHE_SIZE)
+def get_cached_coords(
+ file_path: str,
+ width: int,
+ height: int,
+ target_mpp: float,
+ sampler: samplers.Sampler,
+ backend: str,
+) -> PatchCoordinates:
+ """Get a cached instance of PatchCoordinates for the specified parameters."""
+ return PatchCoordinates.from_file(
+ wsi_path=file_path,
+ width=width,
+ height=height,
+ target_mpp=target_mpp,
+ backend=backend,
+ sampler=sampler,
+ )
diff --git a/src/eva/vision/data/wsi/patching/mask.py b/src/eva/vision/data/wsi/patching/mask.py
new file mode 100644
index 00000000..2a69425b
--- /dev/null
+++ b/src/eva/vision/data/wsi/patching/mask.py
@@ -0,0 +1,123 @@
+"""Functions for extracting foreground masks."""
+
+import dataclasses
+from typing import Tuple
+
+import cv2
+import numpy as np
+
+from eva.vision.data.wsi.backends.base import Wsi
+
+
+@dataclasses.dataclass
+class Mask:
+ """A class to store the mask of a whole-slide image."""
+
+ mask_array: np.ndarray
+ """Binary mask array where 1s represent the foreground and 0s represent the background."""
+
+ mask_level_idx: int
+ """WSI level index at which the mask_array was extracted."""
+
+ scale_factors: Tuple[float, float]
+ """Factors to scale x/y coordinates from mask_level_idx to level 0."""
+
+
+def get_mask(
+ wsi: Wsi,
+ mask_level_idx: int,
+ saturation_threshold: int = 20,
+ median_blur_kernel_size: int | None = None,
+ fill_holes: bool = False,
+ holes_kernel_size: Tuple[int, int] = (7, 7),
+ use_otsu: bool = False,
+) -> Mask:
+ """Generates a binary foreground mask for a given WSI.
+
+ The is a simplified version of the algorithm proposed in [1] (CLAM):
+ 1. Convert the image to the HSV color space (easier to seperate specific colors with RGB).
+ 2. (optional) Apply a median blur to the saturation channel to reduce noise
+ & closing small gaps in the mask. While this yields cleaner masks, this step is the most
+ computationally expensive and thus disabled by default (CLAM uses a value of 7).
+ 3. Calculate binary mask by thresholding accross the saturation channel.
+
+ [1] Lu, Ming Y., et al. "Data-efficient and weakly supervised computational
+ pathology on whole-slide images." Nature biomedical engineering 5.6 (2021): 555-570.
+ https://github.com/mahmoodlab/CLAM
+
+ Args:
+ wsi: The WSI object.
+ mask_level_idx: The level index of the WSI at which we want to extract the mask.
+ saturation_threshold: The threshold value for the saturation channel.
+ median_blur_kernel_size: Kernel size for the median blur operation.
+ holes_kernel_size: The size of the kernel for morphological operations to fill holes.
+ fill_holes: Whether to fill holes in the mask.
+ use_otsu: Whether to use Otsu's method for the thresholding operation. If False,
+ a fixed threshold value is used.
+
+ Returns: A Mask object instance.
+ """
+ image = wsi.read_region((0, 0), mask_level_idx, wsi.level_dimensions[mask_level_idx])
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
+ image = (
+ cv2.medianBlur(image[:, :, 1], median_blur_kernel_size)
+ if median_blur_kernel_size
+ else image[:, :, 1]
+ )
+
+ threshold_type = cv2.THRESH_BINARY + cv2.THRESH_OTSU if use_otsu else cv2.THRESH_BINARY
+ _, mask_array = cv2.threshold(image, saturation_threshold, 1, threshold_type)
+
+ if fill_holes:
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, holes_kernel_size)
+ mask_array = cv2.dilate(mask_array, kernel, iterations=1)
+ contour, _ = cv2.findContours(mask_array, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
+ for cnt in contour:
+ cv2.drawContours(mask_array, [cnt], 0, (1,), -1)
+
+ mask_array = mask_array.astype(np.uint8)
+ scale_factors = (
+ wsi.level_dimensions[0][0] / wsi.level_dimensions[mask_level_idx][0],
+ wsi.level_dimensions[0][1] / wsi.level_dimensions[mask_level_idx][1],
+ )
+
+ return Mask(mask_array=mask_array, mask_level_idx=mask_level_idx, scale_factors=scale_factors)
+
+
+def get_mask_level(
+ wsi: Wsi,
+ width: int,
+ height: int,
+ target_mpp: float,
+ min_mask_patch_pixels: int = 3 * 3,
+) -> int:
+ """For performance reasons, we generate the mask at the lowest resolution level possible.
+
+ However, if minimum resolution level has too few pixels, the patches scaled to that level will
+ be too small or even collapse to a single pixel. This function allows to find the lowest
+ resolution level that yields mask patches with at least `min_mask_patch_pixels` pixels.
+
+ Args:
+ wsi: The WSI object.
+ width: The width of the patches to be extracted, in pixels (at target_mpp).
+ height: The height of the patches to be extracted, in pixels.
+ target_mpp: The target microns per pixel (mpp) for the patches.
+ min_mask_patch_pixels: The minimum number of pixels required for the mask patches.
+ Mask patch refers to width / height at target_mpp scaled down to the WSI level
+ at which the mask is generated.
+ """
+ level_mpps = wsi.mpp * np.array(wsi.level_downsamples)
+ mask_level_idx = None
+
+ for level_idx, level_mpp in reversed(list(enumerate(level_mpps))):
+ mpp_ratio = target_mpp / level_mpp
+ scaled_width, scaled_height = int(mpp_ratio * width), int(mpp_ratio * height)
+
+ if scaled_width * scaled_height >= min_mask_patch_pixels:
+ mask_level_idx = level_idx
+ break
+
+ if mask_level_idx is None:
+ raise ValueError("No level with the specified minimum number of patch pixels available.")
+
+ return mask_level_idx
diff --git a/src/eva/vision/data/wsi/patching/samplers/__init__.py b/src/eva/vision/data/wsi/patching/samplers/__init__.py
new file mode 100644
index 00000000..49860968
--- /dev/null
+++ b/src/eva/vision/data/wsi/patching/samplers/__init__.py
@@ -0,0 +1,8 @@
+"""Patch Sampler API."""
+
+from eva.vision.data.wsi.patching.samplers.base import ForegroundSampler, Sampler
+from eva.vision.data.wsi.patching.samplers.foreground_grid import ForegroundGridSampler
+from eva.vision.data.wsi.patching.samplers.grid import GridSampler
+from eva.vision.data.wsi.patching.samplers.random import RandomSampler
+
+__all__ = ["Sampler", "ForegroundSampler", "RandomSampler", "GridSampler", "ForegroundGridSampler"]
diff --git a/src/eva/vision/data/wsi/patching/samplers/_utils.py b/src/eva/vision/data/wsi/patching/samplers/_utils.py
new file mode 100644
index 00000000..af8418df
--- /dev/null
+++ b/src/eva/vision/data/wsi/patching/samplers/_utils.py
@@ -0,0 +1,50 @@
+import random
+from typing import Tuple
+
+import numpy as np
+
+
+def set_seed(seed: int) -> None:
+ random.seed(seed)
+ np.random.seed(seed)
+
+
+def get_grid_coords_and_indices(
+ layer_shape: Tuple[int, int],
+ width: int,
+ height: int,
+ overlap: Tuple[int, int],
+ shuffle: bool = True,
+ seed: int = 42,
+):
+ """Get grid coordinates and indices.
+
+ Args:
+ layer_shape: The shape of the layer.
+ width: The width of the patches.
+ height: The height of the patches.
+ overlap: The overlap between patches in the grid.
+ shuffle: Whether to shuffle the indices.
+ seed: The random seed.
+ """
+ x_range = range(0, layer_shape[0] - width + 1, width - overlap[0])
+ y_range = range(0, layer_shape[1] - height + 1, height - overlap[1])
+ x_y = [(x, y) for x in x_range for y in y_range]
+
+ indices = list(range(len(x_y)))
+ if shuffle:
+ set_seed(seed)
+ np.random.shuffle(indices)
+ return x_y, indices
+
+
+def validate_dimensions(width: int, height: int, layer_shape: Tuple[int, int]) -> None:
+ """Checks if the width / height is bigger than the layer shape.
+
+ Args:
+ width: The width of the patches.
+ height: The height of the patches.
+ layer_shape: The shape of the layer.
+ """
+ if width > layer_shape[0] or height > layer_shape[1]:
+ raise ValueError("The width / height cannot be bigger than the layer shape.")
diff --git a/src/eva/vision/data/wsi/patching/samplers/base.py b/src/eva/vision/data/wsi/patching/samplers/base.py
new file mode 100644
index 00000000..fa9a24ac
--- /dev/null
+++ b/src/eva/vision/data/wsi/patching/samplers/base.py
@@ -0,0 +1,48 @@
+"""Base classes for samplers."""
+
+import abc
+from typing import Generator, Tuple
+
+from eva.vision.data.wsi.patching.mask import Mask
+
+
+class Sampler(abc.ABC):
+ """Base class for samplers."""
+
+ @abc.abstractmethod
+ def sample(
+ self,
+ width: int,
+ height: int,
+ layer_shape: Tuple[int, int],
+ mask: Mask | None = None,
+ ) -> Generator[Tuple[int, int], None, None]:
+ """Sample patche coordinates.
+
+ Args:
+ width: The width of the patches.
+ height: The height of the patches.
+ layer_shape: The shape of the layer.
+ mask: Tuple containing the mask array and the scaling factor with respect to the
+ provided layer_shape. Optional, only required for samplers with foreground
+ filtering.
+
+ Returns:
+ A generator producing sampled patch coordinates.
+ """
+
+
+class ForegroundSampler(Sampler):
+ """Base class for samplers with foreground filtering capabilities."""
+
+ @abc.abstractmethod
+ def is_foreground(
+ self,
+ mask: Mask,
+ x: int,
+ y: int,
+ width: int,
+ height: int,
+ min_foreground_ratio: float,
+ ) -> bool:
+ """Check if a patch contains sufficient foreground."""
diff --git a/src/eva/vision/data/wsi/patching/samplers/foreground_grid.py b/src/eva/vision/data/wsi/patching/samplers/foreground_grid.py
new file mode 100644
index 00000000..e062caf5
--- /dev/null
+++ b/src/eva/vision/data/wsi/patching/samplers/foreground_grid.py
@@ -0,0 +1,87 @@
+"""Foreground grid sampler."""
+
+from typing import Tuple
+
+from eva.vision.data.wsi.patching.mask import Mask
+from eva.vision.data.wsi.patching.samplers import _utils, base
+
+
+class ForegroundGridSampler(base.ForegroundSampler):
+ """Sample patches based on a grid, only returning patches containing foreground.
+
+ Args:
+ max_samples: The maximum number of samples to return.
+ overlap: The overlap between patches in the grid.
+ min_foreground_ratio: The minimum amount of foreground within a sampled patch.
+ seed: The random seed.
+ """
+
+ def __init__(
+ self,
+ max_samples: int = 20,
+ overlap: Tuple[int, int] = (0, 0),
+ min_foreground_ratio: float = 0.35,
+ seed: int = 42,
+ ):
+ """Initializes the sampler."""
+ self.max_samples = max_samples
+ self.overlap = overlap
+ self.min_foreground_ratio = min_foreground_ratio
+ self.seed = seed
+
+ def sample(
+ self,
+ width: int,
+ height: int,
+ layer_shape: Tuple[int, int],
+ mask: Mask,
+ ):
+ """Sample patches from a grid containing foreground.
+
+ Args:
+ width: The width of the patches.
+ height: The height of the patches.
+ layer_shape: The shape of the layer.
+ mask: The mask of the image.
+ """
+ _utils.validate_dimensions(width, height, layer_shape)
+ x_y, indices = _utils.get_grid_coords_and_indices(
+ layer_shape, width, height, self.overlap, seed=self.seed
+ )
+
+ count = 0
+ for i in indices:
+ if count >= self.max_samples:
+ break
+ if self.is_foreground(
+ mask, x_y[i][0], x_y[i][1], width, height, self.min_foreground_ratio
+ ):
+ count += 1
+ yield x_y[i]
+
+ def is_foreground(
+ self,
+ mask: Mask,
+ x: int,
+ y: int,
+ width: int,
+ height: int,
+ min_foreground_ratio: float,
+ ) -> bool:
+ """Check if a patch contains sufficient foreground.
+
+ Args:
+ mask: The mask of the image.
+ x: The x-coordinate of the patch.
+ y: The y-coordinate of the patch.
+ width: The width of the patch.
+ height: The height of the patch.
+ min_foreground_ratio: The minimum amount of foreground in the patch.
+ """
+ x_, y_ = self._scale_coords(x, y, mask.scale_factors)
+ width_, height_ = self._scale_coords(width, height, mask.scale_factors)
+ patch_mask = mask.mask_array[y_ : y_ + height_, x_ : x_ + width_]
+ return patch_mask.sum() / patch_mask.size >= min_foreground_ratio
+
+ def _scale_coords(self, x: int, y: int, scale_factors: Tuple[float, float]) -> Tuple[int, int]:
+ return int(x / scale_factors[0]), int(y / scale_factors[1])
diff --git a/src/eva/vision/data/wsi/patching/samplers/grid.py b/src/eva/vision/data/wsi/patching/samplers/grid.py
new file mode 100644
index 00000000..3f2b0081
--- /dev/null
+++ b/src/eva/vision/data/wsi/patching/samplers/grid.py
@@ -0,0 +1,47 @@
+"""Grid sampler."""
+
+from typing import Generator, Tuple
+
+from eva.vision.data.wsi.patching.samplers import _utils, base
+
+
+class GridSampler(base.Sampler):
+ """Sample patches based on a grid.
+
+ Args:
+ max_samples: The maximum number of samples to return.
+ overlap: The overlap between patches in the grid.
+ seed: The random seed.
+ """
+
+ def __init__(
+ self,
+ max_samples: int | None = None,
+ overlap: Tuple[int, int] = (0, 0),
+ seed: int = 42,
+ ):
+ """Initializes the sampler."""
+ self.max_samples = max_samples
+ self.overlap = overlap
+ self.seed = seed
+
+ def sample(
+ self,
+ width: int,
+ height: int,
+ layer_shape: Tuple[int, int],
+ ) -> Generator[Tuple[int, int], None, None]:
+ """Sample patches from a grid.
+
+ Args:
+ width: The width of the patches.
+ height: The height of the patches.
+ layer_shape: The shape of the layer.
+ """
+ _utils.validate_dimensions(width, height, layer_shape)
+ x_y, indices = _utils.get_grid_coords_and_indices(
+ layer_shape, width, height, self.overlap, seed=self.seed
+ )
+ max_samples = len(indices) if self.max_samples is None else self.max_samples
+ for i in indices[:max_samples]:
+ yield x_y[i]
diff --git a/src/eva/vision/data/wsi/patching/samplers/random.py b/src/eva/vision/data/wsi/patching/samplers/random.py
new file mode 100644
index 00000000..09ae5729
--- /dev/null
+++ b/src/eva/vision/data/wsi/patching/samplers/random.py
@@ -0,0 +1,41 @@
+"""Random sampler."""
+
+import random
+from typing import Generator, Tuple
+
+from eva.vision.data.wsi.patching.samplers import _utils, base
+
+
+class RandomSampler(base.Sampler):
+ """Sample patch coordinates randomly.
+
+ Args:
+ n_samples: The number of samples to return.
+ seed: The random seed.
+ """
+
+ def __init__(self, n_samples: int = 1, seed: int = 42):
+ """Initializes the sampler."""
+ self.seed = seed
+ self.n_samples = n_samples
+
+ def sample(
+ self,
+ width: int,
+ height: int,
+ layer_shape: Tuple[int, int],
+ ) -> Generator[Tuple[int, int], None, None]:
+ """Sample random patches.
+
+ Args:
+ width: The width of the patches.
+ height: The height of the patches.
+ layer_shape: The shape of the layer.
+ """
+ _utils.validate_dimensions(width, height, layer_shape)
+ _utils.set_seed(self.seed)
+
+ x_max, y_max = layer_shape[0], layer_shape[1]
+ for _ in range(self.n_samples):
+ x, y = random.randint(0, x_max - width), random.randint(0, y_max - height) # nosec
+ yield x, y
diff --git a/tests/eva/assets/core/datasets/embeddings/embeddings/tensor_2_shape_8.pt b/tests/eva/assets/core/datasets/embeddings/embeddings/tensor_2_shape_8.pt
deleted file mode 100644
index 4356f915..00000000
--- a/tests/eva/assets/core/datasets/embeddings/embeddings/tensor_2_shape_8.pt
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:998be8c190a910135c2ea2722543c2750ddc070280427b7ab211db3da59ee9b8
-size 1225
diff --git a/tests/eva/assets/core/datasets/embeddings/embeddings/tensor_2_shape_8_list.pt b/tests/eva/assets/core/datasets/embeddings/embeddings/tensor_2_shape_8_list.pt
new file mode 100644
index 00000000..abe16557
--- /dev/null
+++ b/tests/eva/assets/core/datasets/embeddings/embeddings/tensor_2_shape_8_list.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5ea1a50d4c8e714ea4d13cbe41f7ab3f455f8fb8963be1ba1787a7b5ab9a4545
+size 1250
diff --git a/tests/eva/assets/core/datasets/embeddings/embeddings/tensor_3_shape_8.pt b/tests/eva/assets/core/datasets/embeddings/embeddings/tensor_3_shape_8.pt
deleted file mode 100644
index 4356f915..00000000
--- a/tests/eva/assets/core/datasets/embeddings/embeddings/tensor_3_shape_8.pt
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:998be8c190a910135c2ea2722543c2750ddc070280427b7ab211db3da59ee9b8
-size 1225
diff --git a/tests/eva/assets/core/datasets/embeddings/embeddings/tensor_3_shape_8_list.pt b/tests/eva/assets/core/datasets/embeddings/embeddings/tensor_3_shape_8_list.pt
new file mode 100644
index 00000000..f42c1040
--- /dev/null
+++ b/tests/eva/assets/core/datasets/embeddings/embeddings/tensor_3_shape_8_list.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf79fa09a95cc06d3b07a0badde28badc86a5412a50c43836e86c2e1215aeddf
+size 1250
diff --git a/tests/eva/assets/core/datasets/embeddings/embeddings/tensor_6_shape_1x8.pt b/tests/eva/assets/core/datasets/embeddings/embeddings/tensor_6_shape_1x8.pt
deleted file mode 100644
index fdab67c6..00000000
--- a/tests/eva/assets/core/datasets/embeddings/embeddings/tensor_6_shape_1x8.pt
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:85c2a3fe8ded76eda86274f0066d8f34445a88d1e90c23e2b598194d2bd6542c
-size 1235
diff --git a/tests/eva/assets/core/datasets/embeddings/embeddings/tensor_6_shape_1x8_list.pt b/tests/eva/assets/core/datasets/embeddings/embeddings/tensor_6_shape_1x8_list.pt
new file mode 100644
index 00000000..35b400a5
--- /dev/null
+++ b/tests/eva/assets/core/datasets/embeddings/embeddings/tensor_6_shape_1x8_list.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eabea1dec7432f00698f3a0aa2ed9d838a4dc126450979e7d807cc9c90feb6bc
+size 1324
diff --git a/tests/eva/assets/core/datasets/embeddings/embeddings/tensor_7_shape_1x8.pt b/tests/eva/assets/core/datasets/embeddings/embeddings/tensor_7_shape_1x8.pt
deleted file mode 100644
index fdab67c6..00000000
--- a/tests/eva/assets/core/datasets/embeddings/embeddings/tensor_7_shape_1x8.pt
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:85c2a3fe8ded76eda86274f0066d8f34445a88d1e90c23e2b598194d2bd6542c
-size 1235
diff --git a/tests/eva/assets/core/datasets/embeddings/embeddings/tensor_7_shape_1x8_list.pt b/tests/eva/assets/core/datasets/embeddings/embeddings/tensor_7_shape_1x8_list.pt
new file mode 100644
index 00000000..d32d60f9
--- /dev/null
+++ b/tests/eva/assets/core/datasets/embeddings/embeddings/tensor_7_shape_1x8_list.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aa9320e36efabc9dc6fa9f70e3c8232a6a906a8e622b92d81318aa855e6d044e
+size 1324
diff --git a/tests/eva/assets/core/datasets/embeddings/manifest.csv b/tests/eva/assets/core/datasets/embeddings/manifest.csv
index c0a54e26..ea0224b9 100644
--- a/tests/eva/assets/core/datasets/embeddings/manifest.csv
+++ b/tests/eva/assets/core/datasets/embeddings/manifest.csv
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:0fc8758e7b95dfb048d9cfba4c64667f553c9980dde874066ed795382980b2d0
-size 337
+oid sha256:5798f6f5031188227f211531f20d79e6df1916e620eaf653777dc4417840b65d
+size 357
diff --git a/tests/eva/assets/core/datasets/multi-embeddings/embeddings/tensor_6_shape_1x8.pt b/tests/eva/assets/core/datasets/multi-embeddings/embeddings/tensor_6_shape_1x8.pt
deleted file mode 100644
index 417ac877..00000000
--- a/tests/eva/assets/core/datasets/multi-embeddings/embeddings/tensor_6_shape_1x8.pt
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:e7a7f391ab7f206a92cca08d630538a430c0d5cadf3eaadb3d3f845724d76692
-size 1235
diff --git a/tests/eva/assets/core/datasets/multi-embeddings/embeddings/tensor_6_shape_1x8_list.pt b/tests/eva/assets/core/datasets/multi-embeddings/embeddings/tensor_6_shape_1x8_list.pt
new file mode 100644
index 00000000..79d7ed13
--- /dev/null
+++ b/tests/eva/assets/core/datasets/multi-embeddings/embeddings/tensor_6_shape_1x8_list.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:356b1db144d75d6a6cc2bb78e553bfeb0248cb78ce5272c1614fdf0c1c11a9ea
+size 1324
diff --git a/tests/eva/assets/core/datasets/multi-embeddings/embeddings/tensor_7_shape_6x8.pt b/tests/eva/assets/core/datasets/multi-embeddings/embeddings/tensor_7_shape_6x8.pt
deleted file mode 100644
index 4fe6454c..00000000
--- a/tests/eva/assets/core/datasets/multi-embeddings/embeddings/tensor_7_shape_6x8.pt
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:6ea1a6a0588de18e5ca1e744a6d7e2cd933773cdae280a92985e78eca06a7c62
-size 1427
diff --git a/tests/eva/assets/core/datasets/multi-embeddings/embeddings/tensor_7_shape_6x8_list.pt b/tests/eva/assets/core/datasets/multi-embeddings/embeddings/tensor_7_shape_6x8_list.pt
new file mode 100644
index 00000000..f6213ac3
--- /dev/null
+++ b/tests/eva/assets/core/datasets/multi-embeddings/embeddings/tensor_7_shape_6x8_list.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1b4b1cc0cfc033ee5643c1c7b1d1f911641ba115f4dfd7d74c4fe488fc07c51e
+size 1772
diff --git a/tests/eva/assets/core/datasets/multi-embeddings/embeddings/tensor_8_shape_2x8.pt b/tests/eva/assets/core/datasets/multi-embeddings/embeddings/tensor_8_shape_2x8.pt
deleted file mode 100644
index 0c37da5d..00000000
--- a/tests/eva/assets/core/datasets/multi-embeddings/embeddings/tensor_8_shape_2x8.pt
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d730238ef61f0dc4b98719e72ecff2fcee0f3c69b5c634554429828792f6a251
-size 1299
diff --git a/tests/eva/assets/core/datasets/multi-embeddings/embeddings/tensor_8_shape_2x8_list.pt b/tests/eva/assets/core/datasets/multi-embeddings/embeddings/tensor_8_shape_2x8_list.pt
new file mode 100644
index 00000000..043d17e8
--- /dev/null
+++ b/tests/eva/assets/core/datasets/multi-embeddings/embeddings/tensor_8_shape_2x8_list.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3323f1ad0aa408b081c0a6eac1c9b3c46d0d1d61e5fbfe084ea39794c2c3cee2
+size 1452
diff --git a/tests/eva/assets/core/datasets/multi-embeddings/embeddings/tensor_9_shape_5x8.pt b/tests/eva/assets/core/datasets/multi-embeddings/embeddings/tensor_9_shape_5x8.pt
deleted file mode 100644
index fa63fd0f..00000000
--- a/tests/eva/assets/core/datasets/multi-embeddings/embeddings/tensor_9_shape_5x8.pt
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:6ad0657aaa9a8d0521ce9f51ae2fcc748eaf4c94a7e27845b32d9587542b98de
-size 1363
diff --git a/tests/eva/assets/core/datasets/multi-embeddings/embeddings/tensor_9_shape_5x8_list.pt b/tests/eva/assets/core/datasets/multi-embeddings/embeddings/tensor_9_shape_5x8_list.pt
new file mode 100644
index 00000000..e5d062b7
--- /dev/null
+++ b/tests/eva/assets/core/datasets/multi-embeddings/embeddings/tensor_9_shape_5x8_list.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3c1de7941fe339d557a99fcedeae0a8ab7f1f52302f44a400573ce7aa2f8e66a
+size 1644
diff --git a/tests/eva/assets/core/datasets/multi-embeddings/manifest.csv b/tests/eva/assets/core/datasets/multi-embeddings/manifest.csv
index 1eb25968..71e152bf 100644
--- a/tests/eva/assets/core/datasets/multi-embeddings/manifest.csv
+++ b/tests/eva/assets/core/datasets/multi-embeddings/manifest.csv
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:5e884b93aa81257148dbb73564d734045ffe110463f60a6064814bf95aa82044
-size 514
+oid sha256:9c414260a55131f60431d23ca34baf5da47f3eb18d614cd8e71d5e51da402cef
+size 532
diff --git a/tests/eva/assets/vision/datasets/camelyon16/testing/images/test_001.tif b/tests/eva/assets/vision/datasets/camelyon16/testing/images/test_001.tif
new file mode 100644
index 00000000..64bc6f24
--- /dev/null
+++ b/tests/eva/assets/vision/datasets/camelyon16/testing/images/test_001.tif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:09a0877005a9da2360e67107b25c4657696516a54504a5f903b895ebdfad5062
+size 246784
diff --git a/tests/eva/assets/vision/datasets/camelyon16/testing/images/test_002.tif b/tests/eva/assets/vision/datasets/camelyon16/testing/images/test_002.tif
new file mode 100644
index 00000000..64bc6f24
--- /dev/null
+++ b/tests/eva/assets/vision/datasets/camelyon16/testing/images/test_002.tif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:09a0877005a9da2360e67107b25c4657696516a54504a5f903b895ebdfad5062
+size 246784
diff --git a/tests/eva/assets/vision/datasets/camelyon16/testing/reference.csv b/tests/eva/assets/vision/datasets/camelyon16/testing/reference.csv
new file mode 100644
index 00000000..5b36aa91
--- /dev/null
+++ b/tests/eva/assets/vision/datasets/camelyon16/testing/reference.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fe2d8f0df36ba2b44f1ff875300019aa7df443fe1f428d7142dcc2f4ddc1a908
+size 50
diff --git a/tests/eva/assets/vision/datasets/camelyon16/training/normal/normal_001.tif b/tests/eva/assets/vision/datasets/camelyon16/training/normal/normal_001.tif
new file mode 100644
index 00000000..64bc6f24
--- /dev/null
+++ b/tests/eva/assets/vision/datasets/camelyon16/training/normal/normal_001.tif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:09a0877005a9da2360e67107b25c4657696516a54504a5f903b895ebdfad5062
+size 246784
diff --git a/tests/eva/assets/vision/datasets/camelyon16/training/normal/normal_002.tif b/tests/eva/assets/vision/datasets/camelyon16/training/normal/normal_002.tif
new file mode 100644
index 00000000..64bc6f24
--- /dev/null
+++ b/tests/eva/assets/vision/datasets/camelyon16/training/normal/normal_002.tif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:09a0877005a9da2360e67107b25c4657696516a54504a5f903b895ebdfad5062
+size 246784
diff --git a/tests/eva/assets/vision/datasets/camelyon16/training/tumor/tumor_001.tif b/tests/eva/assets/vision/datasets/camelyon16/training/tumor/tumor_001.tif
new file mode 100644
index 00000000..64bc6f24
--- /dev/null
+++ b/tests/eva/assets/vision/datasets/camelyon16/training/tumor/tumor_001.tif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:09a0877005a9da2360e67107b25c4657696516a54504a5f903b895ebdfad5062
+size 246784
diff --git a/tests/eva/assets/vision/datasets/camelyon16/training/tumor/tumor_002.tif b/tests/eva/assets/vision/datasets/camelyon16/training/tumor/tumor_002.tif
new file mode 100644
index 00000000..64bc6f24
--- /dev/null
+++ b/tests/eva/assets/vision/datasets/camelyon16/training/tumor/tumor_002.tif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:09a0877005a9da2360e67107b25c4657696516a54504a5f903b895ebdfad5062
+size 246784
diff --git a/tests/eva/assets/vision/datasets/panda/train_images/0214df71ae527e2144021178c453d204.tiff b/tests/eva/assets/vision/datasets/panda/train_images/0214df71ae527e2144021178c453d204.tiff
new file mode 100644
index 00000000..64bc6f24
--- /dev/null
+++ b/tests/eva/assets/vision/datasets/panda/train_images/0214df71ae527e2144021178c453d204.tiff
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:09a0877005a9da2360e67107b25c4657696516a54504a5f903b895ebdfad5062
+size 246784
diff --git a/tests/eva/assets/vision/datasets/panda/train_images/02d302a8d723fa00331f373091b29135.tiff b/tests/eva/assets/vision/datasets/panda/train_images/02d302a8d723fa00331f373091b29135.tiff
new file mode 100644
index 00000000..64bc6f24
--- /dev/null
+++ b/tests/eva/assets/vision/datasets/panda/train_images/02d302a8d723fa00331f373091b29135.tiff
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:09a0877005a9da2360e67107b25c4657696516a54504a5f903b895ebdfad5062
+size 246784
diff --git a/tests/eva/assets/vision/datasets/panda/train_images/157565e23ba28d5a42f63f34f3dd4425.tiff b/tests/eva/assets/vision/datasets/panda/train_images/157565e23ba28d5a42f63f34f3dd4425.tiff
new file mode 100644
index 00000000..64bc6f24
--- /dev/null
+++ b/tests/eva/assets/vision/datasets/panda/train_images/157565e23ba28d5a42f63f34f3dd4425.tiff
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:09a0877005a9da2360e67107b25c4657696516a54504a5f903b895ebdfad5062
+size 246784
diff --git a/tests/eva/assets/vision/datasets/panda/train_images/682a1fd346b6fff340afbdb80c2f7caf.tiff b/tests/eva/assets/vision/datasets/panda/train_images/682a1fd346b6fff340afbdb80c2f7caf.tiff
new file mode 100644
index 00000000..64bc6f24
--- /dev/null
+++ b/tests/eva/assets/vision/datasets/panda/train_images/682a1fd346b6fff340afbdb80c2f7caf.tiff
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:09a0877005a9da2360e67107b25c4657696516a54504a5f903b895ebdfad5062
+size 246784
diff --git a/tests/eva/assets/vision/datasets/panda/train_images/8582b59b41635fa38401d1bddad66707.tiff b/tests/eva/assets/vision/datasets/panda/train_images/8582b59b41635fa38401d1bddad66707.tiff
new file mode 100644
index 00000000..64bc6f24
--- /dev/null
+++ b/tests/eva/assets/vision/datasets/panda/train_images/8582b59b41635fa38401d1bddad66707.tiff
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:09a0877005a9da2360e67107b25c4657696516a54504a5f903b895ebdfad5062
+size 246784
diff --git a/tests/eva/assets/vision/datasets/panda/train_images/8c357871e57c5c60277230412f2d9028.tiff b/tests/eva/assets/vision/datasets/panda/train_images/8c357871e57c5c60277230412f2d9028.tiff
new file mode 100644
index 00000000..64bc6f24
--- /dev/null
+++ b/tests/eva/assets/vision/datasets/panda/train_images/8c357871e57c5c60277230412f2d9028.tiff
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:09a0877005a9da2360e67107b25c4657696516a54504a5f903b895ebdfad5062
+size 246784
diff --git a/tests/eva/assets/vision/datasets/panda/train_images/979cf5a2fa4079eaf74343d6ff5e1b51.tiff b/tests/eva/assets/vision/datasets/panda/train_images/979cf5a2fa4079eaf74343d6ff5e1b51.tiff
new file mode 100644
index 00000000..64bc6f24
--- /dev/null
+++ b/tests/eva/assets/vision/datasets/panda/train_images/979cf5a2fa4079eaf74343d6ff5e1b51.tiff
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:09a0877005a9da2360e67107b25c4657696516a54504a5f903b895ebdfad5062
+size 246784
diff --git a/tests/eva/assets/vision/datasets/panda/train_images/9dd40c0127d217bc4917e4db40e06e94.tiff b/tests/eva/assets/vision/datasets/panda/train_images/9dd40c0127d217bc4917e4db40e06e94.tiff
new file mode 100644
index 00000000..64bc6f24
--- /dev/null
+++ b/tests/eva/assets/vision/datasets/panda/train_images/9dd40c0127d217bc4917e4db40e06e94.tiff
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:09a0877005a9da2360e67107b25c4657696516a54504a5f903b895ebdfad5062
+size 246784
diff --git a/tests/eva/assets/vision/datasets/panda/train_images/9ed8ec7bf90653bc4ca86b3ca53cbb96.tiff b/tests/eva/assets/vision/datasets/panda/train_images/9ed8ec7bf90653bc4ca86b3ca53cbb96.tiff
new file mode 100644
index 00000000..64bc6f24
--- /dev/null
+++ b/tests/eva/assets/vision/datasets/panda/train_images/9ed8ec7bf90653bc4ca86b3ca53cbb96.tiff
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:09a0877005a9da2360e67107b25c4657696516a54504a5f903b895ebdfad5062
+size 246784
diff --git a/tests/eva/assets/vision/datasets/panda/train_images/a04310d441e8d2c7a5066627baeec9b6.tiff b/tests/eva/assets/vision/datasets/panda/train_images/a04310d441e8d2c7a5066627baeec9b6.tiff
new file mode 100644
index 00000000..64bc6f24
--- /dev/null
+++ b/tests/eva/assets/vision/datasets/panda/train_images/a04310d441e8d2c7a5066627baeec9b6.tiff
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:09a0877005a9da2360e67107b25c4657696516a54504a5f903b895ebdfad5062
+size 246784
diff --git a/tests/eva/assets/vision/datasets/panda/train_images/fb8886059879eaac70139336cb525838.tiff b/tests/eva/assets/vision/datasets/panda/train_images/fb8886059879eaac70139336cb525838.tiff
new file mode 100644
index 00000000..64bc6f24
--- /dev/null
+++ b/tests/eva/assets/vision/datasets/panda/train_images/fb8886059879eaac70139336cb525838.tiff
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:09a0877005a9da2360e67107b25c4657696516a54504a5f903b895ebdfad5062
+size 246784
diff --git a/tests/eva/assets/vision/datasets/panda/train_with_noisy_labels.csv b/tests/eva/assets/vision/datasets/panda/train_with_noisy_labels.csv
new file mode 100644
index 00000000..db3d8230
--- /dev/null
+++ b/tests/eva/assets/vision/datasets/panda/train_with_noisy_labels.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dd9cec7cd6b94b2cb845ab5093659dd127d1e31ad2b94a8f97effd9c0184bfff
+size 465
diff --git a/tests/eva/assets/vision/datasets/wsi/0/a.tiff b/tests/eva/assets/vision/datasets/wsi/0/a.tiff
new file mode 100644
index 00000000..64bc6f24
--- /dev/null
+++ b/tests/eva/assets/vision/datasets/wsi/0/a.tiff
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:09a0877005a9da2360e67107b25c4657696516a54504a5f903b895ebdfad5062
+size 246784
diff --git a/tests/eva/assets/vision/datasets/wsi/0/b.tiff b/tests/eva/assets/vision/datasets/wsi/0/b.tiff
new file mode 100644
index 00000000..64bc6f24
--- /dev/null
+++ b/tests/eva/assets/vision/datasets/wsi/0/b.tiff
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:09a0877005a9da2360e67107b25c4657696516a54504a5f903b895ebdfad5062
+size 246784
diff --git a/tests/eva/assets/vision/datasets/wsi/1/a.tiff b/tests/eva/assets/vision/datasets/wsi/1/a.tiff
new file mode 100644
index 00000000..64bc6f24
--- /dev/null
+++ b/tests/eva/assets/vision/datasets/wsi/1/a.tiff
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:09a0877005a9da2360e67107b25c4657696516a54504a5f903b895ebdfad5062
+size 246784
diff --git a/tests/eva/assets/vision/datasets/wsi/manifest.csv b/tests/eva/assets/vision/datasets/wsi/manifest.csv
new file mode 100644
index 00000000..d9e7d867
--- /dev/null
+++ b/tests/eva/assets/vision/datasets/wsi/manifest.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ac6feb39305e51bc126f0599bbb097af2525bfcbfd2e028d71bfebb7a29fdcab
+size 65
diff --git a/tests/eva/core/callbacks/writers/embeddings/test_classification.py b/tests/eva/core/callbacks/writers/embeddings/test_classification.py
index 59685b23..0b7ab822 100644
--- a/tests/eva/core/callbacks/writers/embeddings/test_classification.py
+++ b/tests/eva/core/callbacks/writers/embeddings/test_classification.py
@@ -1,14 +1,17 @@
"""Tests the embeddings writer."""
+import functools
import os
import random
import tempfile
from pathlib import Path
-from typing import List, Literal
+from typing import List, Literal, Set
import lightning.pytorch as pl
import pandas as pd
import pytest
+import torch
+from lightning.pytorch import callbacks
from lightning.pytorch.demos import boring_classes
from torch import nn
from typing_extensions import override
@@ -21,57 +24,129 @@
@pytest.mark.parametrize(
- "batch_size, n_samples",
+ "batch_size, n_samples, metadata_keys, filenames",
[
- (5, 7),
- (8, 16),
+ (5, 7, None, None),
+ (5, 7, ["wsi_id"], None),
+ (8, 16, None, None),
+ (8, 32, ["wsi_id"], ["slide_1", "slide_2"]),
],
)
def test_embeddings_writer(datamodule: datamodules.DataModule, model: modules.HeadModule) -> None:
- """Tests the embeddings writer callback."""
+ """Tests the embeddings writer callback.
+
+ This test executes a lightning trainer predict operation and checks if the expected
+ embedding tensors & manifest files are correctly written to disk.
+ """
with tempfile.TemporaryDirectory() as output_dir:
- trainer = pl.Trainer(
- logger=False,
- callbacks=writers.ClassificationEmbeddingsWriter(
- output_dir=output_dir,
- dataloader_idx_map={0: "train", 1: "val", 2: "test"},
- backbone=nn.Flatten(),
- ),
- )
- all_predictions = trainer.predict(
- model=model, datamodule=datamodule, return_predictions=True
+ metadata_keys = datamodule.datasets.predict[0]._metadata_keys # type: ignore
+ expected_filenames = datamodule.datasets.predict[0]._filenames # type: ignore
+ grouping_enabled = expected_filenames is not None
+ callback = writers.ClassificationEmbeddingsWriter(
+ output_dir=output_dir,
+ dataloader_idx_map={0: "train", 1: "val", 2: "test"},
+ backbone=nn.Flatten(),
+ metadata_keys=metadata_keys,
)
- files = Path(output_dir).glob("*.pt")
- files = [f.relative_to(output_dir).as_posix() for f in files]
+ trainer = _init_and_run_trainer([callback], model, datamodule)
assert isinstance(trainer.predict_dataloaders, list)
assert len(trainer.predict_dataloaders) == 3
- assert isinstance(all_predictions, list)
- assert len(all_predictions) == 3
- total_n_predictions = 0
+
+ unique_filenames = set()
+ tot_n_samples = 0
for dataloader_idx in range(len(trainer.predict_dataloaders)):
+ _check_embedding_dimensions(output_dir, grouping_enabled)
dataset = trainer.predict_dataloaders[dataloader_idx].dataset
+ filenames = _check_if_embedding_files_exist(output_dir, dataset, expected_filenames)
+ unique_filenames.update(filenames)
+ tot_n_samples += len(dataset)
+
+ expected_file_count = len(unique_filenames) if expected_filenames else tot_n_samples
+ _check_expected_n_files(output_dir, expected_file_count)
+ _check_manifest(output_dir, len(unique_filenames), metadata_keys)
+
+
+def _init_and_run_trainer(
+ callbacks: List[callbacks.Callback],
+ model: pl.LightningModule,
+ datamodule: datamodules.DataModule,
+):
+ """Initializes and runs the trainer with the given callbacks."""
+ trainer = pl.Trainer(
+ logger=False,
+ accelerator="cpu",
+ callbacks=callbacks,
+ )
+ trainer.predict(model=model, datamodule=datamodule, return_predictions=True)
+
+ return trainer
+
+
+def _check_if_embedding_files_exist(
+ output_dir: str, dataset: datasets.Dataset, expected_filenames: List[str] | None
+) -> Set[str]:
+ """Checks if the expected embedding files exist in the output directory."""
+ output_files = _get_output_filenames(output_dir)
+
+ dataset_filenames = set()
+ for idx in range(len(dataset)): # type: ignore
+ filename = f"{dataset.filename(idx)}.pt" # type: ignore
+ assert filename in output_files
+ dataset_filenames.add(filename)
+
+ if expected_filenames:
+ assert len(set(expected_filenames) - {Path(x).stem for x in output_files}) == 0
+
+ return dataset_filenames
- # Check if the number of predictions is correct
- predictions = all_predictions[dataloader_idx]
- assert isinstance(predictions, list)
- n_predictions = sum(len(p) for p in predictions)
- assert len(dataset) == n_predictions
- # Check if the expected files are present
- for idx in range(len(dataset)):
- filename = dataset.filename(idx)
- assert f"{filename}.pt" in files
+def _check_embedding_dimensions(output_dir: str, grouping_enabled: bool):
+ """Checks if the produced embeddings have the expected dimensions."""
+ embedding_paths = Path(output_dir).glob("*.pt")
- total_n_predictions += n_predictions
+ for path in embedding_paths:
+ tensor_list = torch.load(path)
+ assert isinstance(tensor_list, list)
+ for t in tensor_list:
+ assert isinstance(t, torch.Tensor)
+ assert t.ndim == 1
- # Check if the manifest file is in the expected format
- df_manifest = pd.read_csv(os.path.join(output_dir, "manifest.csv"))
- assert "origin" in df_manifest.columns
- assert "embeddings" in df_manifest.columns
- assert "target" in df_manifest.columns
- assert "split" in df_manifest.columns
- assert len(df_manifest) == total_n_predictions
+ if grouping_enabled:
+ assert len(tensor_list) > 1
+ else:
+ assert len(tensor_list) == 1
+
+
+def _check_expected_n_files(output_dir: str, expected_file_count: int):
+ """Checks if the number of produced output files matches the expected count."""
+ output_files = _get_output_filenames(output_dir)
+ assert len(output_files) == expected_file_count
+
+
+def _check_manifest(
+ output_dir: str, expected_n_entries: int, metadata_keys: List[str] | None = None
+):
+ """Checks if the manifest file contains the expected number of entries and columns."""
+ manifest_path = os.path.join(output_dir, "manifest.csv")
+ assert os.path.isfile(manifest_path)
+ df_manifest = pd.read_csv(manifest_path)
+
+ expected_columns = ["origin", "embeddings", "target", "split"] + (metadata_keys or [])
+ for column in expected_columns:
+ assert column in df_manifest.columns
+
+ assert len(df_manifest) == expected_n_entries
+
+ if metadata_keys:
+ assert all(key in df_manifest.columns for key in metadata_keys)
+
+
+def _get_output_filenames(output_dir: str) -> List[str]:
+ """Returns the list of output embedding filenames in the output directory."""
+ output_files = Path(output_dir).glob("*.pt")
+ output_files = [f.relative_to(output_dir).as_posix() for f in output_files]
+ return output_files
@pytest.fixture(scope="function")
@@ -87,11 +162,20 @@ def model(n_classes: int = 4) -> modules.HeadModule:
@pytest.fixture(scope="function")
def dataset(
n_samples: int,
+ metadata_keys: List[str] | None,
+ filenames: List[str] | None,
) -> List[datasets.Dataset]:
"""Fake dataset fixture."""
- train_dataset = FakeDataset(split="train", length=n_samples, size=SAMPLE_SHAPE)
- val_dataset = FakeDataset(split="val", length=n_samples, size=SAMPLE_SHAPE)
- test_dataset = FakeDataset(split="test", length=n_samples, size=SAMPLE_SHAPE)
+ Dataset = functools.partial(
+ FakeDataset,
+ length=n_samples,
+ size=SAMPLE_SHAPE,
+ metadata_keys=metadata_keys,
+ filenames=filenames,
+ )
+ train_dataset = Dataset(split="train")
+ val_dataset = Dataset(split="val")
+ test_dataset = Dataset(split="test")
return [train_dataset, val_dataset, test_dataset]
@@ -99,17 +183,35 @@ def dataset(
class FakeDataset(boring_classes.RandomDataset, datasets.Dataset):
"""Fake prediction dataset."""
- def __init__(self, split: Literal["train", "val", "test"], size: int = 32, length: int = 10):
+ def __init__(
+ self,
+ split: Literal["train", "val", "test"],
+ size: int = 32,
+ length: int = 10,
+ metadata_keys: List[str] | None = None,
+ filenames: List[str] | None = None,
+ ):
"""Initializes the dataset."""
super().__init__(size=size, length=length)
self._split = split
+ self._metadata_keys = metadata_keys
+ self._filenames = filenames
def filename(self, index: int) -> str:
"""Returns the filename for the given index."""
- return f"{self._split}-{index}"
+ if self._filenames:
+ # This simulates the case where where multiple items can correspond to the same file.
+ # e.g. in WSI classification, multiple patches can belong to the same slide.
+ return random.choice(self._filenames)
+ else:
+ return f"{self._split}-{index}"
@override
def __getitem__(self, index: int):
data = boring_classes.RandomDataset.__getitem__(self, index)
target = random.choice([0, 1])
- return data, target
+ if self._metadata_keys:
+ metadata = {key: random.choice([0, 1, 2]) for key in self._metadata_keys}
+ return data, target, metadata
+ else:
+ return data, target
diff --git a/tests/eva/core/data/splitting/__init__.py b/tests/eva/core/data/splitting/__init__.py
new file mode 100644
index 00000000..18a90221
--- /dev/null
+++ b/tests/eva/core/data/splitting/__init__.py
@@ -0,0 +1 @@
+"""Tests core splitting module."""
diff --git a/tests/eva/core/data/splitting/test_stratified.py b/tests/eva/core/data/splitting/test_stratified.py
new file mode 100644
index 00000000..2b65ccd8
--- /dev/null
+++ b/tests/eva/core/data/splitting/test_stratified.py
@@ -0,0 +1,70 @@
+"""Tests for the stratified split function."""
+
+import pytest
+
+from eva.core.data import splitting
+
+
+@pytest.mark.parametrize(
+ "targets, train_ratio, val_ratio, test_ratio",
+ [
+ ([0] * 50 + [1] * 50, 0.8, 0.2, 0.0),
+ ([0] * 50 + [1] * 50, 0.7, 0.15, 0.15),
+ ([0] * 30 + [1] * 70, 0.8, 0.2, 0.0),
+ ([0] * 30 + [1] * 70, 0.7, 0.15, 0.15),
+ ],
+)
+def test_stratification(
+ targets: list[int], train_ratio: float, val_ratio: float, test_ratio: float
+):
+ """Tests if the stratified split maintains the class proportions."""
+ samples = list(range(len(targets)))
+ train_indices, val_indices, test_indices = splitting.stratified_split(
+ samples, targets, train_ratio, val_ratio, test_ratio
+ )
+ train_classes = [targets[i] for i in train_indices]
+ val_classes = [targets[i] for i in val_indices]
+
+ for c in set(targets):
+ expected_train_proportion = train_ratio * targets.count(c)
+ expected_val_proportion = val_ratio * targets.count(c)
+ assert train_classes.count(c) == pytest.approx(expected_train_proportion, abs=1)
+ assert val_classes.count(c) == pytest.approx(expected_val_proportion, abs=1)
+
+ assert len(train_indices) + len(val_indices) + len(test_indices or []) == len(samples)
+
+
+@pytest.mark.parametrize("train_ratio, val_ratio, test_ratio", [(0.6, 0.3, 0.0), (0.6, 0.4, 0.3)])
+def test_invalid_ratio_sums(train_ratio: float, val_ratio: float, test_ratio: float):
+ """Tests if the function raises an error when the ratios do not sum to 1."""
+ samples = list(range(100))
+ targets = [0] * 50 + [1] * 50
+ expected_error = "The sum of the ratios must be equal to 1."
+ with pytest.raises(ValueError, match=expected_error):
+ splitting.stratified_split(samples, targets, train_ratio, val_ratio, test_ratio)
+
+
+@pytest.mark.parametrize("seed1, seed2", [(42, 43), (123, 124), (999, 1000)])
+def test_different_seeds_produce_different_outputs(seed1, seed2):
+ """Tests if different seeds produce different train, validation, and test indices."""
+ samples = list(range(100))
+ targets = [0] * 50 + [1] * 50
+ train1, val1, test1 = splitting.stratified_split(samples, targets, 0.6, 0.2, 0.2, seed=seed1)
+ train2, val2, test2 = splitting.stratified_split(samples, targets, 0.6, 0.2, 0.2, seed=seed2)
+
+ assert train1 != train2, "Different seeds should produce different train indices"
+ assert val1 != val2, "Different seeds should produce different validation indices"
+ assert test1 != test2, "Different seeds should produce different test indices"
+
+
+@pytest.mark.parametrize("seed", [42, 123, 999])
+def test_same_seed_produces_same_outputs(seed):
+ """Tests if the same seed produces the same train, validation, and test indices."""
+ samples = list(range(100))
+ targets = [0] * 50 + [1] * 50
+ train1, val1, test1 = splitting.stratified_split(samples, targets, 0.6, 0.2, 0.2, seed=seed)
+ train2, val2, test2 = splitting.stratified_split(samples, targets, 0.6, 0.2, 0.2, seed=seed)
+
+ assert train1 == train2, "Same seed should produce the same train indices"
+ assert val1 == val2, "Same seed should produce the same validation indices"
+ assert test1 == test2, "Same seed should produce the same test indices"
diff --git a/tests/eva/vision/data/datasets/classification/test_bach.py b/tests/eva/vision/data/datasets/classification/test_bach.py
index 41ccd88d..8c3ad47f 100644
--- a/tests/eva/vision/data/datasets/classification/test_bach.py
+++ b/tests/eva/vision/data/datasets/classification/test_bach.py
@@ -33,9 +33,9 @@ def test_sample(bach_dataset: datasets.BACH, index: int) -> None:
# assert data sample is a tuple
sample = bach_dataset[index]
assert isinstance(sample, tuple)
- assert len(sample) == 2
+ assert len(sample) == 3
# assert the format of the `image` and `target`
- image, target = sample
+ image, target, _ = sample
assert isinstance(image, tv_tensors.Image)
assert image.shape == (3, 16, 16)
assert isinstance(target, torch.Tensor)
diff --git a/tests/eva/vision/data/datasets/classification/test_camelyon16.py b/tests/eva/vision/data/datasets/classification/test_camelyon16.py
new file mode 100644
index 00000000..e198dc87
--- /dev/null
+++ b/tests/eva/vision/data/datasets/classification/test_camelyon16.py
@@ -0,0 +1,81 @@
+"""Camelyon16 dataset tests."""
+
+import os
+from typing import Any, Literal
+
+import pytest
+import torch
+import torchvision.transforms.v2 as torch_transforms
+from torchvision import tv_tensors
+
+from eva.vision.data import datasets
+from eva.vision.data import transforms as eva_transforms
+from eva.vision.data.wsi.patching import samplers
+
+TARGET_SIZE = 224
+DEFAULT_ARGS = {
+ "width": 16,
+ "height": 16,
+ "target_mpp": 0.5,
+ "sampler": samplers.GridSampler(),
+ "backend": "openslide",
+ "image_transforms": torch_transforms.Compose([eva_transforms.ResizeAndCrop(size=TARGET_SIZE)]),
+}
+
+
+def test_split_and_expected_shapes(root: str):
+ """Test loading the dataset with different splits."""
+ train_dataset = datasets.Camelyon16(root=root, split="train", **DEFAULT_ARGS)
+ val_dataset = datasets.Camelyon16(root=root, split="val", **DEFAULT_ARGS)
+ test_dataset = datasets.Camelyon16(root=root, split="test", **DEFAULT_ARGS)
+
+ _setup_datasets(train_dataset, val_dataset, test_dataset)
+
+ assert len(train_dataset.datasets) == 3
+ assert len(val_dataset.datasets) == 1
+ assert len(test_dataset.datasets) == 2
+
+ assert len(train_dataset) == 192
+ assert len(val_dataset) == 64
+ assert len(test_dataset) == 128
+
+ _check_batch_shape(train_dataset[0])
+ _check_batch_shape(val_dataset[0])
+ _check_batch_shape(test_dataset[0])
+
+
+@pytest.mark.parametrize("split", ["train", "val", "test", None])
+def test_filenames(root: str, split: Literal["train", "val", "test"]):
+ """Tests that the number of filenames matches the dataset size."""
+ dataset = datasets.Camelyon16(root=root, split=split, **DEFAULT_ARGS)
+ _setup_datasets(dataset)
+
+ filenames = set()
+ for i in range(len(dataset)):
+ filenames.add(dataset.filename(i))
+
+ assert len(filenames) == len(dataset.datasets)
+
+
+def _check_batch_shape(batch: Any):
+ assert isinstance(batch, tuple)
+ assert len(batch) == 3
+
+ image, target, metadata = batch
+ assert isinstance(image, tv_tensors.Image)
+ assert image.shape == (3, TARGET_SIZE, TARGET_SIZE)
+
+ assert isinstance(target, torch.Tensor)
+ assert isinstance(metadata, dict)
+ assert "wsi_id" in metadata
+
+
+@pytest.fixture
+def root(assets_path: str) -> str:
+ """Fixture returning the root directory of the dataset."""
+ return os.path.join(assets_path, "vision/datasets/camelyon16")
+
+
+def _setup_datasets(*datasets: datasets.Camelyon16):
+ for dataset in datasets:
+ dataset.setup()
diff --git a/tests/eva/vision/data/datasets/classification/test_crc.py b/tests/eva/vision/data/datasets/classification/test_crc.py
index a199c0b2..c3f5ba09 100644
--- a/tests/eva/vision/data/datasets/classification/test_crc.py
+++ b/tests/eva/vision/data/datasets/classification/test_crc.py
@@ -21,12 +21,12 @@
)
def test_sample(crc_dataset: datasets.CRC, index: int) -> None:
"""Tests the format of a dataset sample."""
- # assert data sample is a tuple
sample = crc_dataset[index]
+ # assert data sample is a tuple
assert isinstance(sample, tuple)
- assert len(sample) == 2
+ assert len(sample) == 3
# assert the format of the `image` and `target`
- image, target = sample
+ image, target, _ = sample
assert isinstance(image, tv_tensors.Image)
assert image.shape == (3, 16, 16)
assert isinstance(target, torch.Tensor)
diff --git a/tests/eva/vision/data/datasets/classification/test_mhist.py b/tests/eva/vision/data/datasets/classification/test_mhist.py
index f93d8294..5249e52e 100644
--- a/tests/eva/vision/data/datasets/classification/test_mhist.py
+++ b/tests/eva/vision/data/datasets/classification/test_mhist.py
@@ -30,12 +30,12 @@ def test_length(mhist_dataset: datasets.BACH, expected_length: int) -> None:
)
def test_sample(mhist_dataset: datasets.MHIST, index: int) -> None:
"""Tests the format of a dataset sample."""
- # assert data sample is a tuple
sample = mhist_dataset[index]
+ # assert data sample is a tuple
assert isinstance(sample, tuple)
- assert len(sample) == 2
+ assert len(sample) == 3
# assert the format of the `image` and `target`
- image, target = sample
+ image, target, _ = sample
assert isinstance(image, tv_tensors.Image)
assert image.shape == (3, 224, 224)
assert isinstance(target, torch.Tensor)
diff --git a/tests/eva/vision/data/datasets/classification/test_panda.py b/tests/eva/vision/data/datasets/classification/test_panda.py
new file mode 100644
index 00000000..6b901344
--- /dev/null
+++ b/tests/eva/vision/data/datasets/classification/test_panda.py
@@ -0,0 +1,112 @@
+"""PANDA dataset tests."""
+
+import os
+from typing import Any, Literal
+from unittest.mock import patch
+
+import numpy as np
+import pytest
+import torch
+import torchvision.transforms.v2 as torch_transforms
+from torchvision import tv_tensors
+
+from eva.vision.data import datasets
+from eva.vision.data import transforms as eva_transforms
+from eva.vision.data.wsi.patching import samplers
+
+TARGET_SIZE = 224
+DEFAULT_ARGS = {
+ "width": 16,
+ "height": 16,
+ "target_mpp": 0.5,
+ "sampler": samplers.GridSampler(),
+ "backend": "openslide",
+ "image_transforms": torch_transforms.Compose([eva_transforms.ResizeAndCrop(size=TARGET_SIZE)]),
+}
+
+
+def test_split_and_expected_shapes(root: str):
+ """Test loading the dataset with different splits."""
+ train_dataset = datasets.PANDA(root=root, split="train", **DEFAULT_ARGS)
+ val_dataset = datasets.PANDA(root=root, split="val", **DEFAULT_ARGS)
+ test_dataset = datasets.PANDA(root=root, split="test", **DEFAULT_ARGS)
+ _setup_datasets(train_dataset, val_dataset, test_dataset)
+
+ assert len(train_dataset.datasets) == 6
+ assert len(val_dataset.datasets) == 2
+ assert len(test_dataset.datasets) == 2
+
+ assert len(train_dataset) == 384
+ assert len(val_dataset) == 128
+ assert len(test_dataset) == 128
+
+ _check_batch_shape(train_dataset[0])
+ _check_batch_shape(val_dataset[0])
+ _check_batch_shape(test_dataset[0])
+
+
+@pytest.mark.parametrize("split", ["train", "val", "test", None])
+def test_filenames(root: str, split: Literal["train", "val", "test"]):
+ """Tests that the number of filenames matches the dataset size."""
+ dataset = datasets.PANDA(root=root, split=split, **DEFAULT_ARGS)
+ _setup_datasets(dataset)
+
+ filenames = set()
+ for i in range(len(dataset)):
+ filenames.add(dataset.filename(i))
+
+ assert len(filenames) == len(dataset.datasets)
+
+
+def test_same_split_same_seed(root: str):
+ """Test that the generated split is deterministic when using the same seed."""
+ dataset1 = datasets.PANDA(root=root, split="train", seed=42, **DEFAULT_ARGS)
+ dataset2 = datasets.PANDA(root=root, split="train", seed=42, **DEFAULT_ARGS)
+ _setup_datasets(dataset1, dataset2)
+
+ assert len(dataset1) == len(dataset2)
+ assert dataset1._file_paths == dataset2._file_paths
+
+ for i in range(len(dataset1)):
+ assert np.allclose(dataset1[i][1], dataset2[i][1])
+
+
+def test_different_seed_different_split(root: str):
+ """Test that the generated split is different when using a different seed."""
+ dataset1 = datasets.PANDA(root=root, split="train", seed=42, **DEFAULT_ARGS)
+ dataset2 = datasets.PANDA(root=root, split="train", seed=43, **DEFAULT_ARGS)
+ _setup_datasets(dataset1, dataset2)
+
+ assert len(dataset1) == len(dataset2)
+ assert dataset1._file_paths != dataset2._file_paths
+
+
+def _check_batch_shape(batch: Any):
+ assert isinstance(batch, tuple)
+ assert len(batch) == 3
+
+ image, target, metadata = batch
+ assert isinstance(image, tv_tensors.Image)
+ assert image.shape == (3, TARGET_SIZE, TARGET_SIZE)
+
+ assert isinstance(target, torch.Tensor)
+ assert isinstance(metadata, dict)
+ assert "wsi_id" in metadata
+
+
+@pytest.fixture
+def root(assets_path: str) -> str:
+ """Fixture returning the root directory of the dataset."""
+ return os.path.join(assets_path, "vision/datasets/panda")
+
+
+@pytest.fixture(autouse=True)
+def mock_download():
+ """Mocks the download function to avoid downloading resources when running tests."""
+ with patch.object(datasets.PANDA, "_download_resources", return_value=None):
+ yield
+
+
+def _setup_datasets(*datasets: datasets.PANDA):
+ for dataset in datasets:
+ dataset.setup()
diff --git a/tests/eva/vision/data/datasets/classification/test_patch_camelyon.py b/tests/eva/vision/data/datasets/classification/test_patch_camelyon.py
index 3eaaa596..30ecb73a 100644
--- a/tests/eva/vision/data/datasets/classification/test_patch_camelyon.py
+++ b/tests/eva/vision/data/datasets/classification/test_patch_camelyon.py
@@ -25,12 +25,12 @@ def test_length(patch_camelyon_dataset: datasets.PatchCamelyon, expected_length:
)
def test_sample(patch_camelyon_dataset: datasets.PatchCamelyon) -> None:
"""Tests the format of a dataset sample."""
- # assert data sample is a tuple
sample = patch_camelyon_dataset[0]
+ # assert data sample is a tuple
assert isinstance(sample, tuple)
- assert len(sample) == 2
+ assert len(sample) == 3
# assert the format of the `image` and `target`
- image, target = sample
+ image, target, _ = sample
assert isinstance(image, tv_tensors.Image)
assert image.shape == (3, 96, 96)
assert isinstance(target, torch.Tensor)
diff --git a/tests/eva/vision/data/datasets/classification/test_wsi.py b/tests/eva/vision/data/datasets/classification/test_wsi.py
new file mode 100644
index 00000000..d14573d8
--- /dev/null
+++ b/tests/eva/vision/data/datasets/classification/test_wsi.py
@@ -0,0 +1,95 @@
+"""WsiClassificationDataset tests."""
+
+import os
+import pickle
+import re
+from typing import Any
+
+import numpy as np
+import pytest
+import torch
+import torchvision.transforms.v2 as torch_transforms
+
+from eva.vision.data import datasets
+from eva.vision.data import transforms as eva_transforms
+from eva.vision.data.wsi.patching import samplers
+
+TARGET_SIZE = 224
+DEFAULT_ARGS = {
+ "manifest_file": "manifest.csv",
+ "width": 32,
+ "height": 32,
+ "target_mpp": 0.25,
+ "sampler": samplers.GridSampler(),
+ "backend": "openslide",
+ "image_transforms": torch_transforms.Compose([eva_transforms.ResizeAndCrop(size=TARGET_SIZE)]),
+}
+
+
+def test_pickleable(dataset: datasets.WsiClassificationDataset):
+ """Tests if the dataset is pickleable (required for multi-worker torch data loaders)."""
+ pickled = pickle.dumps(dataset)
+
+ # Check if it works after unpickling
+ unpickled_dataset = pickle.loads(pickled)
+ for batch in unpickled_dataset:
+ _check_batch_shape(batch)
+
+
+def test_split(root: str):
+ """Test loading the dataset with different splits."""
+ dataset = datasets.WsiClassificationDataset(root=root, split=None, **DEFAULT_ARGS)
+ dataset.setup()
+ assert len(dataset) == 192
+ _check_batch_shape(dataset[0])
+
+ train_dataset = datasets.WsiClassificationDataset(root=root, split="train", **DEFAULT_ARGS)
+ train_dataset.setup()
+ assert len(train_dataset) == 64
+ _check_batch_shape(train_dataset[0])
+
+
+def test_filename(dataset: datasets.WsiClassificationDataset):
+ """Tests the filename method."""
+ pattern = r"^\d+/[a-z]\.tiff$"
+ for i in range(len(dataset)):
+ assert bool(re.match(pattern, dataset.filename(i)))
+
+
+def test_missing_columns(root: str):
+ """Test if error is raised if columns are missing in the manifest file."""
+ with pytest.raises(ValueError, match="Missing columns in the manifest file"):
+ datasets.WsiClassificationDataset(
+ root=root,
+ column_mapping={"target": "label"},
+ **DEFAULT_ARGS,
+ )
+
+
+def _check_batch_shape(batch: Any):
+ assert isinstance(batch, tuple)
+ assert len(batch) == 3
+
+ image, target, metadata = batch
+ assert isinstance(image, torch.Tensor)
+ assert image.shape == (3, TARGET_SIZE, TARGET_SIZE)
+
+ assert isinstance(target, np.ndarray)
+ assert target.size == 1
+
+ assert isinstance(metadata, dict)
+ assert "wsi_id" in metadata
+
+
+@pytest.fixture
+def dataset(root: str) -> datasets.WsiClassificationDataset:
+ """Fixture returning a dataset instance."""
+ dataset = datasets.WsiClassificationDataset(root=root, **DEFAULT_ARGS)
+ dataset.setup()
+ return dataset
+
+
+@pytest.fixture
+def root(assets_path: str) -> str:
+ """Fixture returning the root path to the test dataset assets."""
+ return os.path.join(assets_path, "vision/datasets/wsi")
diff --git a/tests/eva/vision/data/datasets/test_wsi.py b/tests/eva/vision/data/datasets/test_wsi.py
new file mode 100644
index 00000000..87959a60
--- /dev/null
+++ b/tests/eva/vision/data/datasets/test_wsi.py
@@ -0,0 +1,108 @@
+"""WsiDataset & MultiWsiDataset tests."""
+
+import os
+from typing import Tuple
+
+import pytest
+
+from eva.vision.data import datasets
+from eva.vision.data.wsi.backends import is_backend_available
+from eva.vision.data.wsi.patching import samplers
+
+
+@pytest.mark.parametrize(
+ "width, height, overlap, backend",
+ [
+ (4, 4, (0, 0), "openslide"),
+ (4, 4, (2, 2), "openslide"),
+ (33, 33, (0, 0), "openslide"),
+ (224, 224, (0, 0), "openslide"),
+ (4, 4, (0, 0), "tiffslide"),
+ (4, 4, (2, 2), "tiffslide"),
+ (33, 33, (0, 0), "tiffslide"),
+ (224, 224, (0, 0), "tiffslide"),
+ ],
+)
+def test_len(width: int, height: int, root: str, overlap: Tuple[int, int], backend: str):
+ """Test the length of the dataset using different patch dimensions."""
+ if not is_backend_available(backend):
+ pytest.skip(f"{backend} backend is not available.")
+ dataset = datasets.WsiDataset(
+ file_path=os.path.join(root, "0/a.tiff"),
+ width=width,
+ height=height,
+ target_mpp=0.25,
+ sampler=samplers.GridSampler(max_samples=None, overlap=overlap),
+ backend=backend,
+ )
+
+ layer_shape = dataset._wsi.level_dimensions[0]
+ assert len(dataset) == _expected_n_patches(layer_shape, width, height, overlap)
+
+
+@pytest.mark.parametrize(
+ "width, height, target_mpp, backend",
+ [
+ (4, 4, 0.25, "openslide"),
+ (4, 4, 1.3, "openslide"),
+ (4, 4, 0.25, "tiffslide"),
+ (4, 4, 1.3, "tiffslide"),
+ ],
+)
+def test_patch_shape(width: int, height: int, target_mpp: float, root: str, backend: str):
+ """Test the shape of the extracted patches."""
+ if not is_backend_available(backend):
+ pytest.skip(f"{backend} backend is not available.")
+ dataset = datasets.WsiDataset(
+ file_path=os.path.join(root, "0/a.tiff"),
+ width=width,
+ height=height,
+ target_mpp=target_mpp,
+ sampler=samplers.GridSampler(max_samples=None),
+ backend=backend,
+ )
+
+ mpp_ratio = target_mpp / (
+ dataset._wsi.mpp * dataset._wsi.level_downsamples[dataset._coords.level_idx]
+ )
+ scaled_width, scaled_height = int(mpp_ratio * width), int(mpp_ratio * height)
+ assert dataset[0].shape == (3, scaled_width, scaled_height)
+
+
+def test_multi_dataset(root: str):
+ """Test MultiWsiDataset with multiple whole-slide image paths."""
+ file_paths = [
+ os.path.join(root, "0/a.tiff"),
+ os.path.join(root, "0/b.tiff"),
+ os.path.join(root, "1/a.tiff"),
+ ]
+
+ width, height = 32, 32
+ dataset = datasets.MultiWsiDataset(
+ root=root,
+ file_paths=file_paths,
+ width=width,
+ height=height,
+ target_mpp=0.25,
+ sampler=samplers.GridSampler(max_samples=None),
+ backend="openslide",
+ )
+ dataset.setup()
+
+ assert isinstance(dataset.datasets[0], datasets.WsiDataset)
+ layer_shape = dataset.datasets[0]._wsi.level_dimensions[0]
+ assert len(dataset) == _expected_n_patches(layer_shape, width, height, (0, 0)) * len(file_paths)
+ assert dataset.cumulative_sizes == [64, 128, 192]
+
+
+def _expected_n_patches(layer_shape, width, height, overlap):
+ """Calculate the expected number of patches."""
+ n_patches_x = (layer_shape[0] - width) // (width - overlap[0]) + 1
+ n_patches_y = (layer_shape[1] - height) // (height - overlap[1]) + 1
+ return n_patches_x * n_patches_y
+
+
+@pytest.fixture
+def root(assets_path: str) -> str:
+ """Fixture returning the root path to the test dataset assets."""
+ return os.path.join(assets_path, "vision/datasets/wsi")
diff --git a/tests/eva/vision/data/wsi/__init__.py b/tests/eva/vision/data/wsi/__init__.py
new file mode 100644
index 00000000..c3adfdd3
--- /dev/null
+++ b/tests/eva/vision/data/wsi/__init__.py
@@ -0,0 +1 @@
+"""WSI module tests."""
diff --git a/tests/eva/vision/data/wsi/patching/__init__.py b/tests/eva/vision/data/wsi/patching/__init__.py
new file mode 100644
index 00000000..686c6e8d
--- /dev/null
+++ b/tests/eva/vision/data/wsi/patching/__init__.py
@@ -0,0 +1 @@
+"""WSI patch extraction tests."""
diff --git a/tests/eva/vision/data/wsi/patching/samplers/__init__.py b/tests/eva/vision/data/wsi/patching/samplers/__init__.py
new file mode 100644
index 00000000..e7064022
--- /dev/null
+++ b/tests/eva/vision/data/wsi/patching/samplers/__init__.py
@@ -0,0 +1 @@
+"""WSI patch samplers tests."""
diff --git a/tests/eva/vision/data/wsi/patching/samplers/test_foreground_grid.py b/tests/eva/vision/data/wsi/patching/samplers/test_foreground_grid.py
new file mode 100644
index 00000000..9a5510ac
--- /dev/null
+++ b/tests/eva/vision/data/wsi/patching/samplers/test_foreground_grid.py
@@ -0,0 +1,93 @@
+"""ForegroundGridSampler tests."""
+
+import numpy as np
+import pytest
+
+from eva.vision.data.wsi.patching import mask, samplers
+
+TEST_MASK = mask.Mask(
+ mask_array=np.array(
+ [
+ [0, 0, 0, 0, 0, 0],
+ [0, 1, 1, 1, 1, 0],
+ [0, 1, 1, 1, 1, 0],
+ [0, 1, 1, 1, 1, 0],
+ [0, 1, 1, 1, 1, 0],
+ [0, 0, 0, 0, 0, 0],
+ ]
+ ),
+ mask_level_idx=3,
+ scale_factors=(6.0, 6.0),
+)
+
+TEST_ARGS = {"width": 12, "height": 12, "layer_shape": (36, 36), "mask": TEST_MASK}
+
+
+@pytest.mark.parametrize(
+ "min_foreground_ratio, max_samples, expected_n_samples",
+ [(0.0, 3, 3), (0.0, 100, 9), (0.5, 100, 5), (0.9, 100, 1)],
+)
+def test_length(min_foreground_ratio: float, max_samples: int, expected_n_samples: int) -> None:
+ """Tests if the sampler returns the correct number of samples."""
+ sampler = samplers.ForegroundGridSampler(
+ max_samples=max_samples, min_foreground_ratio=min_foreground_ratio
+ )
+
+ x_y = list(sampler.sample(**TEST_ARGS))
+
+ assert len(x_y) == expected_n_samples
+
+
+@pytest.mark.parametrize("n_samples, seed", [(10, 8), (22, 42)])
+def test_same_seed(n_samples: int, seed: int) -> None:
+ """Tests if the sampler returns the same samples for the same seed."""
+ sampler = samplers.ForegroundGridSampler(
+ max_samples=n_samples, seed=seed, min_foreground_ratio=0.5
+ )
+
+ x_y_1 = list(sampler.sample(**TEST_ARGS))
+ x_y_2 = list(sampler.sample(**TEST_ARGS))
+
+ assert x_y_1 == x_y_2
+
+
+@pytest.mark.parametrize("n_samples, seed_1, seed_2", [(3, 1, 2), (5, 3, 4)])
+def test_different_seed(n_samples: int, seed_1: int, seed_2: int) -> None:
+ """Tests if the sampler returns different samples for different seeds."""
+ sampler_1 = samplers.ForegroundGridSampler(max_samples=n_samples, seed=seed_1)
+ sampler_2 = samplers.ForegroundGridSampler(max_samples=n_samples, seed=seed_2)
+
+ x_y_1 = list(sampler_1.sample(**TEST_ARGS))
+ x_y_2 = list(sampler_2.sample(**TEST_ARGS))
+
+ assert x_y_1 != x_y_2
+
+
+def test_invalid_width_height() -> None:
+ """Tests if the sampler raises an error when width / height is bigger than layer_shape."""
+ sampler = samplers.ForegroundGridSampler(max_samples=10, seed=42)
+
+ with pytest.raises(ValueError):
+ list(sampler.sample(width=200, height=200, layer_shape=(100, 100), mask=TEST_MASK))
+
+
+@pytest.mark.parametrize("min_foreground_ratio", [0.0, 0.5, 0.9])
+def test_min_foreground_ratio(min_foreground_ratio: float) -> None:
+ """Tests if sampled coordinates respect the min_foreground_ratio."""
+ sampler = samplers.ForegroundGridSampler(
+ max_samples=100, min_foreground_ratio=min_foreground_ratio
+ )
+
+ x_y = list(sampler.sample(**TEST_ARGS))
+
+ mask = TEST_MASK
+ width, height = TEST_ARGS["width"], TEST_ARGS["height"]
+
+ for x, y in x_y:
+ x_, y_ = sampler._scale_coords(x, y, mask.scale_factors)
+ width_, height_ = sampler._scale_coords(width, height, mask.scale_factors)
+
+ patch_mask = mask.mask_array[x_ : x_ + width_, y_ : y_ + height_]
+ foreground_ratio = patch_mask.sum() / patch_mask.size
+
+ assert foreground_ratio >= min_foreground_ratio
diff --git a/tests/eva/vision/data/wsi/patching/samplers/test_grid.py b/tests/eva/vision/data/wsi/patching/samplers/test_grid.py
new file mode 100644
index 00000000..efeecf54
--- /dev/null
+++ b/tests/eva/vision/data/wsi/patching/samplers/test_grid.py
@@ -0,0 +1,69 @@
+"""GridSampler tests."""
+
+from typing import Tuple
+
+import pytest
+
+from eva.vision.data.wsi.patching import samplers
+
+TEST_ARGS = {"width": 10, "height": 10, "layer_shape": (100, 100)}
+
+
+@pytest.mark.parametrize("max_samples, expected_n_samples", [(3, 3), (10, 10), (200, 100)])
+def test_length(max_samples: int, expected_n_samples: int) -> None:
+ """Tests if the sampler returns the correct number of samples."""
+ sampler = samplers.GridSampler(max_samples=max_samples)
+
+ x_y = list(sampler.sample(**TEST_ARGS))
+
+ assert len(x_y) == expected_n_samples
+
+
+@pytest.mark.parametrize("max_samples, seed", [(10, 8), (22, 42)])
+def test_same_seed(max_samples: int, seed: int) -> None:
+ """Tests if the sampler returns the same samples for the same seed."""
+ sampler = samplers.GridSampler(max_samples=max_samples, seed=seed)
+
+ x_y_1 = list(sampler.sample(**TEST_ARGS))
+ x_y_2 = list(sampler.sample(**TEST_ARGS))
+
+ assert x_y_1 == x_y_2
+
+
+@pytest.mark.parametrize("max_samples, seed_1, seed_2", [(3, 1, 2), (5, 3, 4)])
+def test_different_seed(max_samples: int, seed_1: int, seed_2: int) -> None:
+ """Tests if the sampler returns different samples for different seeds."""
+ sampler_1 = samplers.GridSampler(max_samples=max_samples, seed=seed_1)
+ sampler_2 = samplers.GridSampler(max_samples=max_samples, seed=seed_2)
+
+ x_y_1 = list(sampler_1.sample(**TEST_ARGS))
+ x_y_2 = list(sampler_2.sample(**TEST_ARGS))
+
+ assert x_y_1 != x_y_2
+
+
+def test_invalid_width_height() -> None:
+ """Tests if the sampler raises an error when width / height is bigger than layer_shape."""
+ sampler = samplers.GridSampler(max_samples=10, seed=42)
+
+ with pytest.raises(ValueError):
+ list(sampler.sample(width=200, height=200, layer_shape=(100, 100)))
+
+
+@pytest.mark.parametrize(
+ "width, height, layer_shape",
+ [
+ (5, 5, (25, 25)),
+ (5, 5, (100, 100)),
+ (224, 224, (1000, 1000)),
+ ],
+)
+def test_expected_n_patches(width: int, height: int, layer_shape: Tuple[int, int]) -> None:
+ """Tests if the sampler respects the max_samples limit."""
+ sampler = samplers.GridSampler(max_samples=None)
+
+ expected_max_samples = (layer_shape[0] // width) * (layer_shape[1] // height)
+
+ x_y = list(sampler.sample(width=width, height=height, layer_shape=layer_shape))
+
+ assert len(x_y) == expected_max_samples
diff --git a/tests/eva/vision/data/wsi/patching/samplers/test_random.py b/tests/eva/vision/data/wsi/patching/samplers/test_random.py
new file mode 100644
index 00000000..85110a6c
--- /dev/null
+++ b/tests/eva/vision/data/wsi/patching/samplers/test_random.py
@@ -0,0 +1,48 @@
+"""RandomSampler tests."""
+
+import pytest
+
+from eva.vision.data.wsi.patching import samplers
+
+TEST_ARGS = {"width": 10, "height": 10, "layer_shape": (100, 100)}
+
+
+@pytest.mark.parametrize("n_samples", [3, 10, 22])
+def test_length(n_samples: int) -> None:
+ """Tests if the sampler returns the correct number of samples."""
+ sampler = samplers.RandomSampler(n_samples=n_samples)
+
+ x_y = list(sampler.sample(**TEST_ARGS))
+
+ assert len(x_y) == n_samples
+
+
+@pytest.mark.parametrize("n_samples, seed", [(10, 8), (22, 42)])
+def test_same_seed(n_samples: int, seed: int) -> None:
+ """Tests if the sampler returns the same samples for the same seed."""
+ sampler = samplers.RandomSampler(n_samples=n_samples, seed=seed)
+
+ x_y_1 = list(sampler.sample(**TEST_ARGS))
+ x_y_2 = list(sampler.sample(**TEST_ARGS))
+
+ assert x_y_1 == x_y_2
+
+
+@pytest.mark.parametrize("n_samples, seed_1, seed_2", [(10, 1, 2), (22, 3, 4)])
+def test_different_seed(n_samples: int, seed_1: int, seed_2: int) -> None:
+ """Tests if the sampler returns different samples for different seeds."""
+ sampler_1 = samplers.RandomSampler(n_samples=n_samples, seed=seed_1)
+ sampler_2 = samplers.RandomSampler(n_samples=n_samples, seed=seed_2)
+
+ x_y_1 = list(sampler_1.sample(**TEST_ARGS))
+ x_y_2 = list(sampler_2.sample(**TEST_ARGS))
+
+ assert x_y_1 != x_y_2
+
+
+def test_invalid_width_height() -> None:
+ """Tests if the sampler raises an error when width / height is bigger than layer_shape."""
+ sampler = samplers.RandomSampler(n_samples=10, seed=42)
+
+ with pytest.raises(ValueError):
+ list(sampler.sample(width=200, height=200, layer_shape=(100, 100)))
diff --git a/tests/eva/vision/data/wsi/patching/test_mask.py b/tests/eva/vision/data/wsi/patching/test_mask.py
new file mode 100644
index 00000000..e63375ea
--- /dev/null
+++ b/tests/eva/vision/data/wsi/patching/test_mask.py
@@ -0,0 +1,91 @@
+"""WSI foreground mask tests."""
+
+import os
+
+import numpy as np
+import pytest
+
+from eva.vision.data import wsi as eva_wsi
+
+DEFAULT_ARGS = {
+ "saturation_threshold": 20,
+ "median_blur_kernel_size": 7,
+ "fill_holes": False,
+ "use_otsu": False,
+ "holes_kernel_size": (7, 7),
+}
+
+
+@pytest.mark.parametrize(
+ "mask_level_idx, mask_args",
+ [
+ (0, DEFAULT_ARGS),
+ (1, DEFAULT_ARGS),
+ (0, DEFAULT_ARGS | {"median_blur_kernel_size": None}),
+ (0, DEFAULT_ARGS | {"fill_holes": True}),
+ (0, DEFAULT_ARGS | {"use_otsu": True}),
+ (0, DEFAULT_ARGS | {"fill_holes": True, "use_otsu": True}),
+ ],
+)
+def test_get_mask(wsi: eva_wsi.Wsi, mask_level_idx: int, mask_args: dict):
+ """Tests the foreground mask generation with different configurations."""
+ mask = eva_wsi.get_mask(wsi, mask_level_idx=0, **mask_args)
+
+ assert isinstance(mask, eva_wsi.Mask)
+ assert isinstance(mask.mask_array, np.ndarray)
+ assert mask.mask_array.shape == wsi.level_dimensions[mask.mask_level_idx]
+ assert np.all(np.isin(mask.mask_array, [0, 1]))
+
+ if mask.mask_level_idx == 0:
+ assert mask.scale_factors == (1.0, 1.0)
+ elif mask_level_idx == 1:
+ assert mask.scale_factors == (0.5, 0.5)
+
+
+@pytest.mark.parametrize(
+ "width, height, target_mpp, expected_level",
+ [
+ (4, 4, 0.25, 0),
+ (16, 16, 0.05, 0),
+ (4, 4, 0.5, 1),
+ ],
+)
+def test_get_mask_level(
+ wsi: eva_wsi.Wsi, width: int, height: int, target_mpp: float, expected_level: int
+):
+ """Tests the selection of the mask level based on the patch dimensions."""
+ level = eva_wsi.get_mask_level(wsi, width, height, target_mpp)
+ assert level == expected_level
+
+
+@pytest.mark.parametrize(
+ "width, height, target_mpp",
+ [
+ (4, 4, 0.1),
+ (16, 16, 0.01),
+ (2, 2, 0.25),
+ ],
+)
+def test_no_suitable_level_available(wsi: eva_wsi.Wsi, width: int, height: int, target_mpp: float):
+ """Tests the case where no suitable mask level is available.
+
+ This can happen for instance when the patch dimensions scaled to the selected mask level
+ are too small or even collapse to zero pixels.
+ """
+ with pytest.raises(
+ ValueError, match="No level with the specified minimum number of patch pixels available."
+ ):
+ eva_wsi.get_mask_level(wsi, width, height, target_mpp)
+
+
+@pytest.fixture
+def wsi(assets_path: str) -> eva_wsi.Wsi:
+ """Fixture for loading a WSI object.
+
+ The test WSI slide has the following specs:
+ - level_dimensions: ((256, 256), (128, 128))
+ - level_downsamples: (1.0, 2.0)
+ - mpp (level 0): 0.25
+ """
+ path = os.path.join(assets_path, "vision/datasets/wsi/0/a.tiff")
+ return eva_wsi.wsi_backend("openslide")(path)
diff --git a/tests/eva/vision/models/networks/test_abmil.py b/tests/eva/vision/models/networks/test_abmil.py
index 7ca80a02..b89fb01d 100644
--- a/tests/eva/vision/models/networks/test_abmil.py
+++ b/tests/eva/vision/models/networks/test_abmil.py
@@ -1,6 +1,7 @@
"""ABMIL network tests."""
import itertools
+from typing import Tuple
import pytest
import torch
@@ -15,7 +16,7 @@
def test_masked_abmil(
input_size: int,
output_size: int,
- hidden_sizes_mlp: tuple[int],
+ hidden_sizes_mlp: Tuple[int],
batch_size: int,
n_instances: int,
masked_fraction: float,
diff --git a/tests/eva/vision/test_vision_cli.py b/tests/eva/vision/test_vision_cli.py
index a1c69963..174f0c32 100644
--- a/tests/eva/vision/test_vision_cli.py
+++ b/tests/eva/vision/test_vision_cli.py
@@ -3,6 +3,7 @@
import os
import tempfile
from unittest import mock
+from unittest.mock import patch
import pytest
@@ -21,10 +22,14 @@
"configs/vision/dino_vit/offline/crc.yaml",
"configs/vision/dino_vit/offline/mhist.yaml",
"configs/vision/dino_vit/offline/patch_camelyon.yaml",
+ "configs/vision/dino_vit/offline/panda.yaml",
+ "configs/vision/dino_vit/offline/camelyon16.yaml",
"configs/vision/owkin/phikon/offline/bach.yaml",
"configs/vision/owkin/phikon/offline/crc.yaml",
"configs/vision/owkin/phikon/offline/mhist.yaml",
"configs/vision/owkin/phikon/offline/patch_camelyon.yaml",
+ "configs/vision/owkin/phikon/offline/panda.yaml",
+ "configs/vision/owkin/phikon/offline/camelyon16.yaml",
],
)
def test_configuration_initialization(configuration_file: str, lib_path: str) -> None:
@@ -61,6 +66,7 @@ def test_fit_from_configuration(configuration_file: str, lib_path: str) -> None:
"configuration_file",
[
"configs/vision/tests/offline/patch_camelyon.yaml",
+ "configs/vision/tests/offline/panda.yaml",
],
)
def test_predict_fit_from_configuration(configuration_file: str, lib_path: str) -> None:
@@ -80,3 +86,10 @@ def test_predict_fit_from_configuration(configuration_file: str, lib_path: str)
def _skip_dataset_validation() -> None:
"""Mocks the validation step of the datasets."""
datasets.PatchCamelyon.validate = mock.MagicMock(return_value=None)
+
+
+@pytest.fixture(autouse=True)
+def mock_download():
+ """Mocks the download functions to avoid downloading resources when running tests."""
+ with patch.object(datasets.PANDA, "_download_resources", return_value=None):
+ yield