Merge pull request #922 from azavea/lf/train-limit

Support groups of chips and limits on training size

lewfish committed May 27, 2020
2 parents c617c49 + 61e8f29 commit ae3d68a

Showing 6 changed files with 75 additions and 19 deletions.
4 changes: 2 additions & 2 deletions rastervision2/pytorch_learner/classification_learner.py
@@ -29,12 +29,12 @@ def build_model(self):
         model.fc = nn.Linear(in_features, num_labels)
         return model
 
-    def get_datasets(self):
+    def _get_datasets(self, uri):
         cfg = self.cfg
         class_names = cfg.data.class_names
 
         if cfg.data.data_format == ClassificationDataFormat.image_folder:
-            data_dirs = self.unzip_data()
+            data_dirs = self.unzip_data(uri)
 
         transform, aug_transform = self.get_data_transforms()

66 changes: 55 additions & 11 deletions rastervision2/pytorch_learner/learner.py
@@ -12,7 +12,9 @@
 from subprocess import Popen
 import numbers
 import zipfile
-from typing import Optional, List, Tuple, Dict
+from typing import Optional, List, Tuple, Dict, Union
+import random
+import uuid
 
 import click
 import matplotlib
@@ -25,7 +27,7 @@
 import torch.optim as optim
 from torch.optim.lr_scheduler import CyclicLR, MultiStepLR, _LRScheduler
 from torch.utils.tensorboard import SummaryWriter
-from torch.utils.data import DataLoader, Subset, Dataset
+from torch.utils.data import DataLoader, Subset, Dataset, ConcatDataset
 from albumentations.augmentations.transforms import (
     Blur, RandomRotate90, HorizontalFlip, VerticalFlip, GaussianBlur,
     GaussNoise, RGBShift, ToGray, Resize)
@@ -195,23 +197,27 @@ def build_model(self) -> nn.Module:
         """Build a PyTorch model."""
         pass
 
-    def unzip_data(self) -> List[str]:
+    def unzip_data(self, uri: Union[str, List[str]]) -> List[str]:
         """Unzip dataset zip files.
 
+        Args:
+            uri: a list of URIs of zip files or the URI of a directory containing
+                zip files
+
         Returns:
             paths to directories that each contain contents of one zip file
         """
         cfg = self.cfg
         data_dirs = []
 
-        if isinstance(cfg.data.uri, list):
-            zip_uris = cfg.data.uri
+        if isinstance(uri, list):
+            zip_uris = uri
         else:
-            if cfg.data.uri.startswith('s3://') or cfg.data.uri.startswith(
-                    '/'):
-                data_uri = cfg.data.uri
+            # TODO generalize this to work with any file system
+            if uri.startswith('s3://') or uri.startswith('/'):
+                data_uri = uri
             else:
-                data_uri = join(cfg.base_uri, cfg.data.uri)
+                data_uri = join(cfg.base_uri, uri)
             zip_uris = ([data_uri]
                         if data_uri.endswith('.zip') else list_paths(
                             data_uri, 'zip'))
@@ -221,7 +227,7 @@ def unzip_data(self) -> List[str]:
             if not isfile(zip_path):
                 zip_path = download_if_needed(zip_uri, self.data_cache_dir)
             with zipfile.ZipFile(zip_path, 'r') as zipf:
-                data_dir = join(self.tmp_dir, 'data', str(zip_ind))
+                data_dir = join(self.tmp_dir, 'data', str(uuid.uuid4()), str(zip_ind))
                 data_dirs.append(data_dir)
                 zipf.extractall(data_dir)
 
@@ -277,9 +283,41 @@ def get_collate_fn(self) -> Optional[callable]:
         """
         return None
 
+    def _get_datasets(self, uri: Union[str, List[str]]) -> Tuple[Dataset, Dataset, Dataset]:  # noqa
+        """Gets Datasets for a single group of chips.
+
+        This should be overridden for each Learner subclass.
+
+        Args:
+            uri: a list of URIs of zip files or the URI of a directory containing
+                zip files
+
+        Returns:
+            train, validation, and test DataSets."""
+        raise NotImplementedError()
+
     def get_datasets(self) -> Tuple[Dataset, Dataset, Dataset]:
         """Returns train, validation, and test DataSets."""
-        raise NotImplementedError()
+        if self.cfg.data.group_uris:
+            train_ds_lst, valid_ds_lst, test_ds_lst = [], [], []
+            for group_uri in self.cfg.data.group_uris:
+                train_ds, valid_ds, test_ds = self._get_datasets(group_uri)
+                group_train_sz = self.cfg.data.group_train_sz
+                if group_train_sz is not None:
+                    train_inds = list(range(len(train_ds)))
+                    random.shuffle(train_inds)
+                    train_inds = train_inds[0:group_train_sz]
+                    train_ds = Subset(train_ds, train_inds)
+                train_ds_lst.append(train_ds)
+                valid_ds_lst.append(valid_ds)
+                test_ds_lst.append(test_ds)
+
+            train_ds, valid_ds, test_ds = (
+                ConcatDataset(train_ds_lst), ConcatDataset(valid_ds_lst),
+                ConcatDataset(test_ds_lst))
+            return train_ds, valid_ds, test_ds
+        else:
+            return self._get_datasets(self.cfg.data.uri)
 
     def setup_data(self):
         """Set the the DataSet and DataLoaders for train, validation, and test sets."""
@@ -307,6 +345,12 @@ def setup_data(self):
             valid_ds = Subset(valid_ds, range(batch_sz))
             test_ds = Subset(test_ds, range(batch_sz))
 
+        if cfg.data.train_sz is not None:
+            train_inds = list(range(len(train_ds)))
+            random.shuffle(train_inds)
+            train_inds = train_inds[0:cfg.data.train_sz]
+            train_ds = Subset(train_ds, train_inds)
+
         collate_fn = self.get_collate_fn()
         train_dl = DataLoader(
             train_ds,

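Side note for readers of the get_datasets() change above: the grouping and size-limit logic is plain PyTorch. The following is a minimal standalone sketch of the same pattern (illustrative only, not Raster Vision code; the toy TensorDatasets stand in for per-group chip datasets): each group's training set is shuffled and truncated with Subset, then the groups are combined with ConcatDataset.

# Illustrative sketch of the pattern used in get_datasets() above (not Raster Vision
# code): limit each group's training set with Subset, then merge groups with
# ConcatDataset. The TensorDatasets are stand-ins for per-group chip datasets.
import random

import torch
from torch.utils.data import ConcatDataset, Subset, TensorDataset

group_datasets = [
    TensorDataset(torch.arange(10).float().unsqueeze(1)),  # "group" with 10 chips
    TensorDataset(torch.arange(50).float().unsqueeze(1)),  # "group" with 50 chips
]

group_train_sz = 5
limited = []
for ds in group_datasets:
    inds = list(range(len(ds)))
    random.shuffle(inds)
    limited.append(Subset(ds, inds[:group_train_sz]))  # keep at most 5 chips per group

train_ds = ConcatDataset(limited)
print(len(train_ds))  # 10: each of the 2 groups contributes 5 samples
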
12 changes: 12 additions & 0 deletions rastervision2/pytorch_learner/learner_config.py
@@ -135,6 +135,18 @@ class DataConfig(Config):
         description=
         ('URI of the dataset. This can be a zip file, a list of zip files, or a '
          'directory which contains a set of zip files.'))
+    train_sz: Optional[int] = Field(
+        None, description=(
+            'If set, the number of training images to use. If fewer images exist, '
+            'then an exception will be raised.'))
+    group_uris: Union[None, List[Union[str, List[str]]]] = Field(None, description=(
+        'This can be set instead of uri in order to specify groups of chips. Each '
+        'element in the list is expected to be an object of the same form accepted by '
+        'the uri field. The purpose of separating chips into groups is to be able to '
+        'use the group_train_sz field.'))
+    group_train_sz: Optional[int] = Field(None, description=(
+        'If group_uris is set, this can be used to specify the number of chips to use '
+        'per group.'))
     data_format: Optional[str] = Field(
         None, description='Name of dataset format.')
     class_names: List[str] = Field([], description='Names of classes.')

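To illustrate how these new fields relate to each other, here is a hypothetical configuration sketch. The field names are the ones added above; the bucket paths and sizes are made up, and a plain dict is used instead of constructing a real DataConfig instance.

# Hypothetical values only: shows the intended shape of the new DataConfig fields.
# Each element of group_uris takes the same form as the existing uri field.
data_cfg = dict(
    group_uris=[
        's3://my-bucket/chips/group-0.zip',  # a single zip URI...
        ['s3://my-bucket/chips/a.zip',       # ...or a list of zip URIs
         's3://my-bucket/chips/b.zip'],
    ],
    group_train_sz=500,  # cap each group's contribution to the training set at 500 chips
    train_sz=800,        # cap the combined training set at 800 images
)
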
4 changes: 2 additions & 2 deletions rastervision2/pytorch_learner/object_detection_learner.py
@@ -40,11 +40,11 @@ def get_bbox_params(self):
     def get_collate_fn(self):
         return collate_fn
 
-    def get_datasets(self):
+    def _get_datasets(self, uri):
         cfg = self.cfg
 
         if cfg.data.data_format == ObjectDetectionDataFormat.default:
-            data_dirs = self.unzip_data()
+            data_dirs = self.unzip_data(uri)
 
         transform, aug_transform = self.get_data_transforms()
 
4 changes: 2 additions & 2 deletions rastervision2/pytorch_learner/regression_learner.py
@@ -83,9 +83,9 @@ def build_model(self):
             pos_out_inds=pos_out_inds)
         return model
 
-    def get_datasets(self):
+    def _get_datasets(self, uri):
         cfg = self.cfg
-        data_dirs = self.unzip_data()
+        data_dirs = self.unzip_data(uri)
         transform, aug_transform = self.get_data_transforms()
 
         train_ds, valid_ds, test_ds = [], [], []

4 changes: 2 additions & 2 deletions rastervision2/pytorch_learner/semantic_segmentation_learner.py
@@ -60,10 +60,10 @@ def build_model(self):
             pretrained_backbone=pretrained)
         return model
 
-    def get_datasets(self):
+    def _get_datasets(self, uri):
         cfg = self.cfg
 
-        data_dirs = self.unzip_data()
+        data_dirs = self.unzip_data(uri)
         transform, aug_transform = self.get_data_transforms()
 
         train_ds, valid_ds, test_ds = [], [], []
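Taken together, the learner changes follow one pattern: each Learner subclass now implements _get_datasets(uri) for a single group of chips, while the base class's get_datasets() decides whether to call it once or once per group. A toy sketch of that split follows (illustrative only, not Raster Vision code; the class names and return values are made up).

# Toy sketch of the template-method split introduced in this PR (not Raster Vision
# code): the base class handles groups, subclasses handle a single group of chips.
class ToyBaseLearner:
    def __init__(self, uri=None, group_uris=None):
        self.uri, self.group_uris = uri, group_uris

    def _get_datasets(self, uri):
        raise NotImplementedError()

    def get_datasets(self):
        if self.group_uris:
            groups = [self._get_datasets(u) for u in self.group_uris]
            # concatenate the train/valid/test lists across groups
            return tuple(sum(split, []) for split in zip(*groups))
        return self._get_datasets(self.uri)


class ToyClassificationLearner(ToyBaseLearner):
    def _get_datasets(self, uri):
        # pretend unzipping `uri` yields 3 train, 1 valid, and 1 test chip
        return [f'{uri}/train'] * 3, [f'{uri}/valid'], [f'{uri}/test']


train, valid, test = ToyClassificationLearner(group_uris=['g0', 'g1']).get_datasets()
print(len(train), len(valid), len(test))  # 6 2 2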
