Skip to content

Commit

Permalink
refactor DatasetSource more
Browse files Browse the repository at this point in the history
  • Loading branch information
v0lat1le committed Nov 18, 2016
1 parent 866b809 commit 5329ba8
Showing 1 changed file with 123 additions and 69 deletions.
192 changes: 123 additions & 69 deletions datacube/storage/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@
}

def _version_tuple(version):
    """
    Convert a dotted version string into a tuple of ints for comparison.

    Only the leading digits of each dot-separated component are used and
    parsing stops at the first component without leading digits, so
    ``'1.0a12'`` becomes ``(1, 0, 0)``. The result is zero-padded to at
    least three components so ``'0.34'`` compares equal to ``(0, 34, 0)``.

    :param version: version string such as ``'1.10.0'``
    :rtype: tuple[int]
    """
    numbers = []
    for component in str(version).split('.'):
        digits = ''
        for char in component:
            if not char.isdigit():
                break
            digits += char
        if not digits:
            break
        numbers.append(int(digits))
    numbers.extend([0] * (3 - len(numbers)))
    return tuple(numbers)


# Versions are compared numerically: as plain strings '1.9.0' sorts after
# '1.10.0' and '0.100.0' sorts before '0.34.0'.
assert _version_tuple(rasterio.__version__) >= (0, 34, 0), "rasterio version 0.34.0 or higher is required"

# Prefix of the band tags GDAL's NetCDF driver uses for dimension values;
# the tag name depends on the GDAL version.
GDAL_NETCDF_DIM = ('NETCDF_DIM_'
                   if _version_tuple(rasterio.__gdal_version__) >= (1, 10, 0) else
                   'NETCDF_DIMENSION_')
# Tag name for the time dimension, kept for code that refers to it directly.
GDAL_NETCDF_TIME = GDAL_NETCDF_DIM + 'time'


def _rasterio_resampling_method(resampling):
Expand Down Expand Up @@ -190,7 +190,7 @@ def reproject(self, dest, dst_transform, dst_crs, dst_nodata, resampling, **kwar


class OverrideBandDataSource(object):
def __init__(self, source, nodata=None, crs=None, transform=None):
def __init__(self, source, nodata, crs, transform):
self.source = source
self.nodata = nodata
self.crs = crs
Expand Down Expand Up @@ -221,111 +221,165 @@ def reproject(self, dest, dst_transform, dst_crs, dst_nodata, resampling, **kwar
**kwargs)


class BaseRasterDataSource(object):
    """
    Interface used by fuse_sources and reproject

    Subclasses provide the band number to read and the transform / CRS to
    fall back on when the opened file does not carry its own.
    """
    def __init__(self, filename, nodata):
        """
        :param filename: name passed to ``rasterio.open``
        :param nodata: nodata value used when the file does not define one
        """
        self.filename = filename
        self.nodata = nodata

    def get_bandnumber(self, src):
        """Return the band number to read from the opened file ``src``."""
        raise NotImplementedError()

    def get_transform(self, shape):
        """Return the fallback transform for a raster of the given ``shape``."""
        raise NotImplementedError()

    def get_crs(self):
        """Return the fallback CRS."""
        raise NotImplementedError()

    @contextmanager
    def open(self):
        """
        Open the file and yield a band data source for the selected band.

        Yields ``OverrideBandDataSource`` when the transform or the CRS had
        to be taken from the fallbacks, ``BandDataSource`` otherwise. Any
        exception raised inside the ``try`` is logged and re-raised.
        """
        try:
            _LOG.debug("opening %s", self.filename)
            with rasterio.open(self.filename) as src:
                override = False

                transform = _rasterio_transform(src)
                if transform.is_identity:
                    # An identity transform is treated as "no georeferencing in the file".
                    _LOG.warning('No GeoTransform in %s. Falling back to dataset GeoTransform.',
                                 self.filename)
                    override = True
                    transform = self.get_transform(src.shape)

                try:
                    crs = CRS(_rasterio_crs_wkt(src))
                except ValueError:
                    _LOG.warning('No CRS in %s. Falling back to dataset CRS.', self.filename)
                    override = True
                    crs = self.get_crs()

                bandnumber = self.get_bandnumber(src)
                band = rasterio.band(src, bandnumber)

                # Prefer the nodata value stored in the file; cast to the band dtype.
                file_nodata = src.nodatavals[0]
                nodata = numpy.dtype(band.dtype).type(file_nodata if file_nodata is not None
                                                      else self.nodata)

                if override:
                    yield OverrideBandDataSource(band, nodata=nodata, crs=crs, transform=transform)
                else:
                    yield BandDataSource(band, nodata=nodata)

        except Exception:
            _LOG.error("Error opening source dataset: %s", self.filename)
            # Bare raise keeps the original traceback intact.
            raise

def wheres_my_data(self):
if self._descriptor['path']:
if is_url(self._descriptor['path']):
url_str = self._descriptor['path']
elif Path(self._descriptor['path']).is_absolute():
url_str = Path(self._descriptor['path']).as_uri()
else:
url_str = urljoin(self.local_uri, self._descriptor['path'])

class BasicRasterDataSource(BaseRasterDataSource):
    """
    Raster source described explicitly by file name and band number, with
    optional CRS and transform to return as fallbacks.
    """
    def __init__(self, filename, bandnumber, nodata=None, crs=None, transform=None):
        super(BasicRasterDataSource, self).__init__(filename, nodata)
        self.bandnumber = bandnumber
        self.crs = crs
        self.transform = transform

    def get_bandnumber(self, src):
        """Return the band number given at construction; ``src`` is not consulted."""
        return self.bandnumber

    def get_transform(self, shape):
        """Return the configured transform, or raise RuntimeError if none was given."""
        fallback = self.transform
        if fallback is not None:
            return fallback
        raise RuntimeError('No transform in the data and no fallback')

    def get_crs(self):
        """Return the configured CRS, or raise RuntimeError if none was given."""
        fallback = self.crs
        if fallback is not None:
            return fallback
        raise RuntimeError('No CRS in the data and no fallback')


def _resolve_url(base_url, path):
"""
If path is a URL or an absolute path return URL
If path is a relative path return base_url joined with path

>>> _resolve_url('file:///foo/abc', 'bar')
'file:///foo/bar'
>>> _resolve_url('file:///foo/abc', 'file:///bar')
'file:///bar'
>>> _resolve_url('file:///foo/abc', None)
'file:///foo/abc'
>>> _resolve_url('file:///foo/abc', '/bar')
'file:///bar'
"""
if path:
if is_url(path):
url_str = path
elif Path(path).is_absolute():
url_str = Path(path).as_uri()
else:
url_str = self.local_uri
url = urlparse(url_str)
assert url.scheme, "Expecting URL with scheme here"

# if format is NETCDF of HDF need to pass NETCDF:path:band as filename to rasterio/GDAL
for nasty_format in ('netcdf', 'hdf'):
if nasty_format in self.format.lower():
if url.scheme != 'file':
raise RuntimeError("Can't access %s over %s" % (self.format, url.scheme))
filename = '%s:%s:%s' % (self.format, uri_to_local_path(url_str), self._descriptor['layer'])
return filename, None

if url.scheme and url.scheme != 'file':
return url_str, self._descriptor.get('layer', 1)

# if local path strip scheme and other gunk
return str(uri_to_local_path(url_str)), self._descriptor.get('layer', 1)

def wheres_my_band(self, src, time):
if GDAL_NETCDF_TIME not in src.tags(1):
_LOG.warning("NetCDF dataset has no time dimension") # HACK: should support time-less datasets
url_str = urljoin(base_url, path)
else:
url_str = base_url
return url_str


def _url2rasterio(url_str, fmt, layer):
"""
turn URL into a string that could be passed to raterio.open
"""
url = urlparse(url_str)
assert url.scheme, "Expecting URL with scheme here"

# if format is NETCDF of HDF need to pass NETCDF:path:band as filename to rasterio/GDAL
for nasty_format in ('netcdf', 'hdf'):
if nasty_format in fmt.lower():
if url.scheme != 'file':
raise RuntimeError("Can't access %s over %s" % (fmt, url.scheme))
filename = '%s:%s:%s' % (fmt, uri_to_local_path(url_str), layer)
return filename

if url.scheme and url.scheme != 'file':
return url_str

# if local path strip scheme and other gunk
return str(uri_to_local_path(url_str))


class DatasetSource(BaseRasterDataSource):
    """Raster data source for one measurement of a dataset."""
    def __init__(self, dataset, measurement_id):
        """
        :type dataset: datacube.model.Dataset
        :param measurement_id: key of the measurement within the dataset
        """
        self._dataset = dataset
        self._measurement = dataset.measurements[measurement_id]
        url = _resolve_url(dataset.local_uri, self._measurement['path'])
        filename = _url2rasterio(url, dataset.format, self._measurement.get('layer'))
        nodata = dataset.type.measurements[measurement_id].get('nodata')
        super(DatasetSource, self).__init__(filename, nodata=nodata)

    def get_bandnumber(self, src):
        """
        Return the band number to read from the opened file ``src``.

        For non-NetCDF formats this is the measurement's ``layer`` (default 1).
        For NetCDF it is the band whose time tag is closest to the dataset's
        centre time, or 1 when the first band has no time tag.
        """
        if 'netcdf' not in self._dataset.format.lower():
            return self._measurement.get('layer', 1)

        tag_name = GDAL_NETCDF_DIM + 'time'
        if tag_name not in src.tags(1):  # TODO: support time-less datasets properly
            return 1

        time = self._dataset.center_time
        sec_since_1970 = datetime_to_seconds_since_1970(time)

        # Linear scan for the band with the smallest time difference;
        # the first band wins on ties.
        idx = 0
        dist = float('+inf')
        for i in range(1, src.count + 1):
            delta = abs(sec_since_1970 - float(src.tags(i)[tag_name]))
            if delta < dist:
                idx = i
                dist = delta
        return idx

    def get_transform(self, shape):
        """
        Build a fallback transform that spreads the dataset bounds over a
        raster of the given ``shape`` (rows, columns).

        The origin is the top-left corner and the y scale is negative, so
        row 0 maps to the top edge of the bounds.
        NOTE(review): assumes a north-up raster; the bounding box carries
        no orientation information - confirm against the callers.
        """
        bounds = self._dataset.bounds
        width = bounds.right - bounds.left
        height = bounds.top - bounds.bottom
        return (Affine.translation(bounds.left, bounds.top) *
                Affine.scale(width / shape[1], -height / shape[0]))

    def get_crs(self):
        """Return the CRS recorded on the dataset."""
        return self._dataset.crs


def create_netcdf_storage_unit(filename,
Expand Down

0 comments on commit 5329ba8

Please sign in to comment.