diff --git a/torchio/data/dataset.py b/torchio/data/dataset.py index 3b31cf68e..37c74e338 100644 --- a/torchio/data/dataset.py +++ b/torchio/data/dataset.py @@ -1,6 +1,5 @@ import copy -import collections -from typing import Sequence, Optional, Callable +from typing import Sequence, Optional, Callable, Iterable from torch.utils.data import Dataset @@ -17,8 +16,7 @@ class SubjectsDataset(Dataset): and an optional transform applied to the volumes after loading. Args: - subjects: List of instances of - :class:`~torchio.Subject`. + subjects: List of instances of :class:`~torchio.Subject`. transform: An instance of :class:`~torchio.transforms.Transform` that will be applied to each subject. load_getitem: Load all subject images before returning it in @@ -96,7 +94,7 @@ def set_transform(self, transform: Optional[Callable]) -> None: self._transform = transform @staticmethod - def _parse_subjects_list(subjects_list: Sequence[Subject]) -> None: + def _parse_subjects_list(subjects_list: Iterable[Subject]) -> None: # Check that it's an iterable try: iter(subjects_list)