diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/utils/utils.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/utils/utils.py index 369f9d1729..d29cada314 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/utils/utils.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/utils/utils.py @@ -1,6 +1,6 @@ from typing import Dict, Iterable, List, Optional, Tuple from os import PathLike -from os.path import splitext +from os.path import join, splitext from pathlib import Path from itertools import chain @@ -64,10 +64,20 @@ def make_image_folder_dataset(data_dir: str, return DatasetFolder( data_dir, loader=load_image, extensions=IMG_EXTENSIONS) + from rastervision.pipeline.file_system.utils import (file_exists, + list_paths) + + class_dirs = [join(data_dir, c) for c in classes] + classes_present = [ + c for c, dir in zip(classes, class_dirs) + if file_exists(dir, include_dir=True) and len(list_paths(dir)) > 0 + ] + class_to_id = {c: classes.index(c) for c in classes_present} + class ImageFolder(DatasetFolder): def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]: """Override to force mapping from class name to class index.""" - return classes, {c: i for (i, c) in enumerate(classes)} + return classes_present, class_to_id return ImageFolder(data_dir, loader=load_image, extensions=IMG_EXTENSIONS)