diff --git a/torch_em/data/datasets/pannuke.py b/torch_em/data/datasets/pannuke.py index 2a2a9cb9..df46154f 100644 --- a/torch_em/data/datasets/pannuke.py +++ b/torch_em/data/datasets/pannuke.py @@ -4,6 +4,7 @@ import shutil import numpy as np from glob import glob +from typing import List import torch_em from torch_em.data.datasets import util @@ -144,12 +145,12 @@ def _channels_to_semantics(labels): def get_pannuke_dataset( path, patch_shape, - folds=("fold_1", "fold_2", "fold_3"), + folds: List[str] = ["fold_1", "fold_2", "fold_3"], rois={}, download=False, with_channels=True, with_label_channels=False, - custom_label_choice="instances", + custom_label_choice: str = "instances", **kwargs ): assert custom_label_choice in [ @@ -177,7 +178,7 @@ def get_pannuke_loader( path, patch_shape, batch_size, - folds=("fold_1", "fold_2", "fold_3"), + folds=["fold_1", "fold_2", "fold_3"], download=False, rois={}, custom_label_choice="instances",