From 61e8f29021d771afd1d9f72c3da764b5e522dc7a Mon Sep 17 00:00:00 2001 From: Lewis Fishgold Date: Wed, 27 May 2020 15:52:49 -0400 Subject: [PATCH] Support groups of chips and limits on training size --- .../pytorch_learner/classification_learner.py | 4 +- rastervision2/pytorch_learner/learner.py | 66 +++++++++++++++---- .../pytorch_learner/learner_config.py | 12 ++++ .../object_detection_learner.py | 4 +- .../pytorch_learner/regression_learner.py | 4 +- .../semantic_segmentation_learner.py | 4 +- 6 files changed, 75 insertions(+), 19 deletions(-) diff --git a/rastervision2/pytorch_learner/classification_learner.py b/rastervision2/pytorch_learner/classification_learner.py index c997ba224..0c35e65d4 100644 --- a/rastervision2/pytorch_learner/classification_learner.py +++ b/rastervision2/pytorch_learner/classification_learner.py @@ -29,12 +29,12 @@ def build_model(self): model.fc = nn.Linear(in_features, num_labels) return model - def get_datasets(self): + def _get_datasets(self, uri): cfg = self.cfg class_names = cfg.data.class_names if cfg.data.data_format == ClassificationDataFormat.image_folder: - data_dirs = self.unzip_data() + data_dirs = self.unzip_data(uri) transform, aug_transform = self.get_data_transforms() diff --git a/rastervision2/pytorch_learner/learner.py b/rastervision2/pytorch_learner/learner.py index f0d6d310a..cc2f534d9 100644 --- a/rastervision2/pytorch_learner/learner.py +++ b/rastervision2/pytorch_learner/learner.py @@ -12,7 +12,9 @@ from subprocess import Popen import numbers import zipfile -from typing import Optional, List, Tuple, Dict +from typing import Optional, List, Tuple, Dict, Union +import random +import uuid import click import matplotlib @@ -25,7 +27,7 @@ import torch.optim as optim from torch.optim.lr_scheduler import CyclicLR, MultiStepLR, _LRScheduler from torch.utils.tensorboard import SummaryWriter -from torch.utils.data import DataLoader, Subset, Dataset +from torch.utils.data import DataLoader, Subset, Dataset, ConcatDataset from albumentations.augmentations.transforms import ( Blur, RandomRotate90, HorizontalFlip, VerticalFlip, GaussianBlur, GaussNoise, RGBShift, ToGray, Resize) @@ -195,23 +197,27 @@ def build_model(self) -> nn.Module: """Build a PyTorch model.""" pass - def unzip_data(self) -> List[str]: + def unzip_data(self, uri: Union[str, List[str]]) -> List[str]: """Unzip dataset zip files. + Args: + uri: a list of URIs of zip files or the URI of a directory containing + zip files + Returns: paths to directories that each contain contents of one zip file """ cfg = self.cfg data_dirs = [] - if isinstance(cfg.data.uri, list): - zip_uris = cfg.data.uri + if isinstance(uri, list): + zip_uris = uri else: - if cfg.data.uri.startswith('s3://') or cfg.data.uri.startswith( - '/'): - data_uri = cfg.data.uri + # TODO generalize this to work with any file system + if uri.startswith('s3://') or uri.startswith('/'): + data_uri = uri else: - data_uri = join(cfg.base_uri, cfg.data.uri) + data_uri = join(cfg.base_uri, uri) zip_uris = ([data_uri] if data_uri.endswith('.zip') else list_paths( data_uri, 'zip')) @@ -221,7 +227,7 @@ def unzip_data(self) -> List[str]: if not isfile(zip_path): zip_path = download_if_needed(zip_uri, self.data_cache_dir) with zipfile.ZipFile(zip_path, 'r') as zipf: - data_dir = join(self.tmp_dir, 'data', str(zip_ind)) + data_dir = join(self.tmp_dir, 'data', str(uuid.uuid4()), str(zip_ind)) data_dirs.append(data_dir) zipf.extractall(data_dir) @@ -277,9 +283,41 @@ def get_collate_fn(self) -> Optional[callable]: """ return None + def _get_datasets(self, uri: Union[str, List[str]]) -> Tuple[Dataset, Dataset, Dataset]: # noqa + """Gets Datasets for a single group of chips. + + This should be overridden for each Learner subclass. + + Args: + uri: a list of URIs of zip files or the URI of a directory containing + zip files + + Returns: + train, validation, and test DataSets.""" + raise NotImplementedError() + def get_datasets(self) -> Tuple[Dataset, Dataset, Dataset]: """Returns train, validation, and test DataSets.""" - raise NotImplementedError() + if self.cfg.data.group_uris: + train_ds_lst, valid_ds_lst, test_ds_lst = [], [], [] + for group_uri in self.cfg.data.group_uris: + train_ds, valid_ds, test_ds = self._get_datasets(group_uri) + group_train_sz = self.cfg.data.group_train_sz + if group_train_sz is not None: + train_inds = list(range(len(train_ds))) + random.shuffle(train_inds) + train_inds = train_inds[0:group_train_sz] + train_ds = Subset(train_ds, train_inds) + train_ds_lst.append(train_ds) + valid_ds_lst.append(valid_ds) + test_ds_lst.append(test_ds) + + train_ds, valid_ds, test_ds = ( + ConcatDataset(train_ds_lst), ConcatDataset(valid_ds_lst), + ConcatDataset(test_ds_lst)) + return train_ds, valid_ds, test_ds + else: + return self._get_datasets(self.cfg.data.uri) def setup_data(self): """Set the the DataSet and DataLoaders for train, validation, and test sets.""" @@ -307,6 +345,12 @@ def setup_data(self): valid_ds = Subset(valid_ds, range(batch_sz)) test_ds = Subset(test_ds, range(batch_sz)) + if cfg.data.train_sz is not None: + train_inds = list(range(len(train_ds))) + random.shuffle(train_inds) + train_inds = train_inds[0:cfg.data.train_sz] + train_ds = Subset(train_ds, train_inds) + collate_fn = self.get_collate_fn() train_dl = DataLoader( train_ds, diff --git a/rastervision2/pytorch_learner/learner_config.py b/rastervision2/pytorch_learner/learner_config.py index 2554e00d6..abac68c98 100644 --- a/rastervision2/pytorch_learner/learner_config.py +++ b/rastervision2/pytorch_learner/learner_config.py @@ -135,6 +135,18 @@ class DataConfig(Config): description= ('URI of the dataset. This can be a zip file, a list of zip files, or a ' 'directory which contains a set of zip files.')) + train_sz: Optional[int] = Field( + None, description=( + 'If set, the number of training images to use. If fewer images exist, ' + 'then an exception will be raised.')) + group_uris: Union[None, List[Union[str, List[str]]]] = Field(None, description=( + 'This can be set instead of uri in order to specify groups of chips. Each ' + 'element in the list is expected to be an object of the same form accepted by ' + 'the uri field. The purpose of separating chips into groups is to be able to ' + 'use the group_train_sz field.')) + group_train_sz: Optional[int] = Field(None, description=( + 'If group_uris is set, this can be used to specify the number of chips to use ' + 'per group.')) data_format: Optional[str] = Field( None, description='Name of dataset format.') class_names: List[str] = Field([], description='Names of classes.') diff --git a/rastervision2/pytorch_learner/object_detection_learner.py b/rastervision2/pytorch_learner/object_detection_learner.py index c5b96b071..95ee28936 100644 --- a/rastervision2/pytorch_learner/object_detection_learner.py +++ b/rastervision2/pytorch_learner/object_detection_learner.py @@ -40,11 +40,11 @@ def get_bbox_params(self): def get_collate_fn(self): return collate_fn - def get_datasets(self): + def _get_datasets(self, uri): cfg = self.cfg if cfg.data.data_format == ObjectDetectionDataFormat.default: - data_dirs = self.unzip_data() + data_dirs = self.unzip_data(uri) transform, aug_transform = self.get_data_transforms() diff --git a/rastervision2/pytorch_learner/regression_learner.py b/rastervision2/pytorch_learner/regression_learner.py index afae1ff10..94cee9ae9 100644 --- a/rastervision2/pytorch_learner/regression_learner.py +++ b/rastervision2/pytorch_learner/regression_learner.py @@ -83,9 +83,9 @@ def build_model(self): pos_out_inds=pos_out_inds) return model - def get_datasets(self): + def _get_datasets(self, uri): cfg = self.cfg - data_dirs = self.unzip_data() + data_dirs = self.unzip_data(uri) transform, aug_transform = self.get_data_transforms() train_ds, valid_ds, test_ds = [], [], [] diff --git a/rastervision2/pytorch_learner/semantic_segmentation_learner.py b/rastervision2/pytorch_learner/semantic_segmentation_learner.py index 59a454977..3ea395c8c 100644 --- a/rastervision2/pytorch_learner/semantic_segmentation_learner.py +++ b/rastervision2/pytorch_learner/semantic_segmentation_learner.py @@ -60,10 +60,10 @@ def build_model(self): pretrained_backbone=pretrained) return model - def get_datasets(self): + def _get_datasets(self, uri): cfg = self.cfg - data_dirs = self.unzip_data() + data_dirs = self.unzip_data(uri) transform, aug_transform = self.get_data_transforms() train_ds, valid_ds, test_ds = [], [], []