Skip to content

Commit

Permalink
Implement activation to avoid keeping files open.
Browse files Browse the repository at this point in the history
  • Loading branch information
lossyrob committed Oct 26, 2018
1 parent f5badc3 commit 35ac8b3
Show file tree
Hide file tree
Showing 20 changed files with 427 additions and 172 deletions.
26 changes: 12 additions & 14 deletions rastervision/core/raster_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,18 @@ def compute(self, raster_sources):

def chip_stream(channel):
for raster_source in raster_sources:
windows = raster_source.get_extent().get_windows(
chip_size, stride)
for window in windows:
chip = raster_source.get_raw_chip(window).astype(
np.float32)
chip = chip[:, :, channel].ravel()
# Ignore NODATA values.
chip[chip == 0.0] = np.nan
yield chip

# Sniff the number of channels.
window = raster_sources[0].get_extent().get_windows(chip_size,
stride)[0]
nb_channels = raster_sources[0].get_raw_chip(window).shape[2]
with raster_source.activate():
windows = raster_source.get_extent().get_windows(
chip_size, stride)
for window in windows:
chip = raster_source.get_raw_chip(window).astype(
np.float32)
chip = chip[:, :, channel].ravel()
# Ignore NODATA values.
chip[chip == 0.0] = np.nan
yield chip

nb_channels = len(raster_sources[0].channel_order)

self.means = []
self.stds = []
Expand Down
1 change: 1 addition & 0 deletions rastervision/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# flake8: noqa

from rastervision.data.activate_mixin import *
from rastervision.data.raster_transformer import *
from rastervision.data.raster_source import *
from rastervision.data.crs_transformer import *
Expand Down
98 changes: 98 additions & 0 deletions rastervision/data/activate_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from abc import abstractmethod


class ActivationError(Exception):
    """Raised when an activatable object is activated or used incorrectly.

    For example, activating an object that is already activated, or
    reading from a source that has not been activated yet.
    """


class ActivateMixin:
    """Defines a mixin for data that can activate and deactivate.

    These methods can open and close files, download files, and do
    whatever has to be done to make the entity usable, and cleanup
    after the entity is not needed anymore.

    Subclasses implement ``_activate`` and ``_deactivate`` and may override
    ``_subcomponents_to_activate``. Callers use ``with obj.activate():``.
    """

    class ActivateContextManager:
        """Context manager that runs one callable on entry and one on exit."""

        def __init__(self, activate, deactivate):
            """Store the zero-argument callables to run on enter and exit."""
            self.activate = activate
            self.deactivate = deactivate

        def __enter__(self):
            self.activate()
            return self

        def __exit__(self, type, value, traceback):
            # Returns None, so an exception raised in the with-body is
            # never suppressed.
            self.deactivate()

        @classmethod
        def dummy(cls):
            """Return a context manager that does nothing on enter or exit."""

            def noop():
                pass

            return cls(noop, noop)

    class CompositContextManager:
        """Context manager that enters and exits several managers as one."""

        def __init__(self, *managers):
            self.managers = managers

        def __enter__(self):
            entered = []
            try:
                for manager in self.managers:
                    manager.__enter__()
                    entered.append(manager)
            except BaseException as exc:
                # A later manager failed to enter: unwind the ones that
                # already succeeded so nothing is left activated.
                for manager in reversed(entered):
                    manager.__exit__(exc.__class__, exc, exc.__traceback__)
                raise
            return self

        def __exit__(self, type, value, traceback):
            # Exit in reverse (LIFO) order, and keep going if one manager
            # fails so the remaining ones are still released.
            first_error = None
            for manager in reversed(self.managers):
                try:
                    manager.__exit__(type, value, traceback)
                except Exception as exc:
                    if first_error is None:
                        first_error = exc
            if first_error is not None:
                raise first_error

    def activate(self):
        """Return a context manager that activates this object.

        Entering the returned manager calls ``_activate`` (and activates
        any subcomponents); exiting it calls ``_deactivate``.

        Raises:
            ActivationError: if this object is already activated, either
                when this method is called or when the manager is entered.
        """

        def ensure_inactive():
            if getattr(self, '_mixin_activated', False):
                raise ActivationError('This {} is already activated'.format(
                    type(self)))

        ensure_inactive()

        def do_activate():
            # Re-check at enter time: two managers may have been created
            # before either one was entered.
            ensure_inactive()
            self._mixin_activated = True
            try:
                self._activate()
            except BaseException:
                # Activation failed, so __exit__ will never run; reset the
                # flag or the object could never be activated again.
                self._mixin_activated = False
                raise

        def do_deactivate():
            try:
                self._deactivate()
            finally:
                # Always clear the flag, even if cleanup raised.
                self._mixin_activated = False

        a = ActivateMixin.ActivateContextManager(do_activate, do_deactivate)
        subcomponents = self._subcomponents_to_activate()
        if subcomponents:
            return ActivateMixin.CompositContextManager(
                a, ActivateMixin.compose(*subcomponents))
        return a

    @abstractmethod
    def _activate(self):
        """Make this object usable. Implemented by subclasses."""

    @abstractmethod
    def _deactivate(self):
        """Clean up after this object. Implemented by subclasses."""

    def _subcomponents_to_activate(self):
        """Subclasses override this if they have subcomponents
        that may need to be activated when this class is activated
        """
        return []

    @staticmethod
    def with_activation(obj):
        """Return a context manager for obj whether or not it is activatable.

        If obj mixes in ActivateMixin, this returns ``obj.activate()``.
        Otherwise (including when obj is None) it returns a dummy context
        manager that does nothing.
        """
        if isinstance(obj, ActivateMixin):
            return obj.activate()
        # dummy() is defined on the nested ActivateContextManager class,
        # not on ActivateMixin itself.
        return ActivateMixin.ActivateContextManager.dummy()

    @staticmethod
    def compose(*objs):
        """Return one context manager that activates every activatable obj.

        Objects that are None or do not mix in ActivateMixin are skipped.
        """
        managers = [
            obj.activate() for obj in objs if isinstance(obj, ActivateMixin)
        ]
        return ActivateMixin.CompositContextManager(*managers)
24 changes: 17 additions & 7 deletions rastervision/data/crs_transformer/rasterio_crs_transformer.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
import pyproj

from rastervision.data.crs_transformer import CRSTransformer
from rasterio.transform import (rowcol, xy)

from rastervision.data.crs_transformer import (CRSTransformer,
IdentityCRSTransformer)


class RasterioCRSTransformer(CRSTransformer):
"""Transformer for a RasterioRasterSource."""

def __init__(self, image_dataset, map_crs='epsg:4326'):
def __init__(self, transform, image_crs, map_crs='epsg:4326'):
"""Construct transformer.
Args:
transform: Rasterio affine transform of the image
image_crs: CRS code of the image
map_crs: CRS code
"""
self.image_dataset = image_dataset
self.transform = transform
self.map_proj = pyproj.Proj(init=map_crs)
image_crs = image_dataset.crs['init']
self.image_proj = pyproj.Proj(init=image_crs)

super().__init__(image_crs, map_crs)
Expand All @@ -31,7 +33,7 @@ def map_to_pixel(self, map_point):
"""
image_point = pyproj.transform(self.map_proj, self.image_proj,
map_point[0], map_point[1])
pixel_point = self.image_dataset.index(image_point[0], image_point[1])
pixel_point = rowcol(self.transform, image_point[0], image_point[1])
pixel_point = (pixel_point[1], pixel_point[0])
return pixel_point

Expand All @@ -44,8 +46,16 @@ def pixel_to_map(self, pixel_point):
Returns:
(x, y) tuple in map coordinates
"""
image_point = self.image_dataset.xy(
int(pixel_point[1]), int(pixel_point[0]))
image_point = xy(self.transform, int(pixel_point[1]),
int(pixel_point[0]))
map_point = pyproj.transform(self.image_proj, self.map_proj,
image_point[0], image_point[1])
return map_point

@classmethod
def from_dataset(cls, dataset, map_crs='epsg:4326'):
    """Build a CRS transformer from an opened dataset.

    Args:
        dataset: object exposing ``crs`` and ``transform`` attributes
            (presumably a Rasterio dataset -- confirm against callers)
        map_crs: CRS code for map coordinates

    Returns:
        An IdentityCRSTransformer if the dataset has no CRS, otherwise an
        instance of this class built from the dataset's transform and the
        'init' entry of its CRS.
    """
    dataset_crs = dataset.crs
    if dataset_crs is None:
        # No georeferencing available; fall back to identity.
        return IdentityCRSTransformer()
    return cls(dataset.transform, dataset_crs['init'], map_crs)
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,26 @@

from rastervision.core.box import Box
from rastervision.core.class_map import ClassMap
from rastervision.data import ActivateMixin
from rastervision.data.label import SemanticSegmentationLabels
from rastervision.data.label_source import LabelSource, SegmentationClassTransformer
from rastervision.data.raster_source import RasterSource


class SemanticSegmentationRasterSource(LabelSource):
class SemanticSegmentationRasterSource(ActivateMixin, LabelSource):
"""A read-only label source for segmentation raster files.
"""

def __init__(self, source: RasterSource, rgb_class_map: ClassMap = None):
"""Constructor.
Args:
source: (RasterSource) assumed to have RGB values that are mapped to
class_ids using the rgb_class_map
source: (RasterSource) A raster source that returns a single channel
raster with class_ids as values, or a 3 channel raster with
RGB values that are mapped to class_ids using the rgb_class_map
rgb_class_map: (ClassMap) with color values filled in. Optional and used to
transform RGB values to class ids.
transform RGB values to class ids. Only use if the raster source
is RGB.
"""
self.source = source
self.class_transformer = None
Expand Down Expand Up @@ -77,3 +80,12 @@ def get_labels(self, window: Union[Box, None] = None) -> np.ndarray:
labels = np.squeeze(raw_labels)

return SemanticSegmentationLabels.from_array(labels)

def _subcomponents_to_activate(self):
return [self.source]

def _activate(self):
pass

def _deactivate(self):
pass
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,13 @@ def get_labels(self):
.with_uri(self.uri) \
.build() \
.create_source(self.tmp_dir)
raw_labels = source.get_raw_image_array()
if self.class_trans:
labels = self.class_trans.rgb_to_class(raw_labels)
else:
labels = np.squeeze(raw_labels)
return SemanticSegmentationLabels.from_array(labels)
with source.activate():
raw_labels = source.get_raw_image_array()
if self.class_trans:
labels = self.class_trans.rgb_to_class(raw_labels)
else:
labels = np.squeeze(raw_labels)
return SemanticSegmentationLabels.from_array(labels)

def save(self, labels):
"""Save.
Expand All @@ -60,7 +61,7 @@ def save(self, labels):

# TODO: this only works if crs_transformer is RasterioCRSTransformer.
# Need more general way of computing transform for the more general case.
transform = self.crs_transformer.image_dataset.transform
transform = self.crs_transformer.transform
crs = self.crs_transformer.get_image_crs()
clipped_labels = labels.get_clipped_labels(self.extent)

Expand Down
26 changes: 19 additions & 7 deletions rastervision/data/raster_source/geojson_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import shapely

from rastervision.data import (ActivateMixin, ActivationError)
from rastervision.data.raster_source import RasterSource
from rastervision.utils.files import file_to_str
from rastervision.data.utils import geojson_to_shapes
Expand Down Expand Up @@ -32,7 +33,7 @@ def geojson_to_raster(geojson, rasterizer_options, extent, crs_transformer):
return raster


class GeoJSONSource(RasterSource):
class GeoJSONSource(ActivateMixin, RasterSource):
"""A RasterSource based on the rasterization of a GeoJSON file."""

def __init__(self, uri, rasterizer_options, extent, crs_transformer):
Expand All @@ -49,12 +50,9 @@ def __init__(self, uri, rasterizer_options, extent, crs_transformer):
self.rasterizer_options = rasterizer_options
self.extent = extent
self.crs_transformer = crs_transformer
geojson = json.loads(file_to_str(self.uri))
self.raster = geojson_to_raster(geojson, rasterizer_options, extent,
crs_transformer)
# Add third singleton dim since rasters must have >=1 channel.
self.raster = np.expand_dims(self.raster, 2)
super().__init__()
self.activated = False

super().__init__(channel_order=[0])

def get_extent(self):
"""Return the extent of the RasterSource.
Expand All @@ -81,4 +79,18 @@ def _get_chip(self, window):
Returns:
[height, width, channels] numpy array
"""
if not self.activated:
raise ActivationError('GeoJSONSource must be activated before use')
return self.raster[window.ymin:window.ymax, window.xmin:window.xmax, :]

def _activate(self):
    """Read the GeoJSON at ``self.uri`` and rasterize it into ``self.raster``.

    Marks the source as activated once the raster is ready.
    """
    geojson_text = file_to_str(self.uri)
    parsed = json.loads(geojson_text)
    rasterized = geojson_to_raster(parsed, self.rasterizer_options,
                                   self.extent, self.crs_transformer)
    # Add third singleton dim since rasters must have >=1 channel.
    self.raster = np.expand_dims(rasterized, 2)
    self.activated = True

def _deactivate(self):
self.raster = None
self.activated = False
16 changes: 7 additions & 9 deletions rastervision/data/raster_source/geotiff_source.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import subprocess
import os
import rasterio
import logging

from rastervision.data.raster_source.rasterio_source \
Expand All @@ -19,7 +18,7 @@ def build_vrt(vrt_path, image_paths):


def download_and_build_vrt(image_uris, temp_dir):
log.info('Downloading and building VRT...')
log.info('Building VRT...')
image_paths = [download_if_needed(uri, temp_dir) for uri in image_uris]
image_path = os.path.join(temp_dir, 'index.vrt')
build_vrt(image_path, image_paths)
Expand All @@ -32,13 +31,12 @@ def __init__(self, uris, raster_transformers, temp_dir,
self.uris = uris
super().__init__(raster_transformers, temp_dir, channel_order)

def build_image_dataset(self, temp_dir):
log.info('Loading GeoTiff files...')
def _download_data(self, temp_dir):
if len(self.uris) == 1:
imagery_path = download_if_needed(self.uris[0], temp_dir)
return download_if_needed(self.uris[0], temp_dir)
else:
imagery_path = download_and_build_vrt(self.uris, temp_dir)
return rasterio.open(imagery_path)
return download_and_build_vrt(self.uris, temp_dir)

def get_crs_transformer(self):
return RasterioCRSTransformer(self.image_dataset)
def _set_crs_transformer(self):
self.crs_transformer = RasterioCRSTransformer.from_dataset(
self.image_dataset)
11 changes: 4 additions & 7 deletions rastervision/data/raster_source/image_source.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import rasterio

from rastervision.data.raster_source.rasterio_source import (
RasterioRasterSource)
from rastervision.data.crs_transformer.identity_crs_transformer import (
Expand All @@ -12,9 +10,8 @@ def __init__(self, uri, raster_transformers, temp_dir, channel_order=None):
self.uri = uri
super().__init__(raster_transformers, temp_dir, channel_order)

def build_image_dataset(self, temp_dir):
imagery_path = download_if_needed(self.uri, self.temp_dir)
return rasterio.open(imagery_path)
def _download_data(self, temp_dir):
return download_if_needed(self.uri, self.temp_dir)

def get_crs_transformer(self):
return IdentityCRSTransformer()
def _set_crs_transformer(self):
self.crs_transformer = IdentityCRSTransformer()
4 changes: 1 addition & 3 deletions rastervision/data/raster_source/raster_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,14 @@ class RasterSource(ABC):
a set of files, an API, a TMS URI schema, etc.
"""

def __init__(self, raster_transformers=[], channel_order=None):
def __init__(self, channel_order, raster_transformers=[]):
"""Construct a new RasterSource.
Args:
raster_transformers: RasterTransformers used to transform chips
whenever they are retrieved.
channel_order: numpy array of length n where n is the number of
channels to use and the values are channel indices.
Default: None, which will take all the raster's bands as is.
"""
self.raster_transformers = raster_transformers
self.channel_order = channel_order
Expand Down

0 comments on commit 35ac8b3

Please sign in to comment.