Skip to content

Commit

Permalink
Redo of PR #1427
Browse files Browse the repository at this point in the history
  • Loading branch information
Sean C. Gillies committed Sep 13, 2018
1 parent 41bbc36 commit 8821bc3
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 175 deletions.
7 changes: 7 additions & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
Changes
=======

Next
----

Refactoring:

- Use of InMemoryRaster eliminates redundant code in the _warp module (#1427,
#816).

1.0.3 (2018-08-01)
------------------
Expand Down
2 changes: 1 addition & 1 deletion rasterio/_io.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ cdef class WarpedVRTReaderBase(DatasetReaderBase):
cdef class InMemoryRaster:
cdef GDALDatasetH _hds
cdef double gdal_transform[6]
cdef int band_ids[1]
cdef int* band_ids
cdef np.ndarray _image
cdef object crs
cdef object transform # this is an Affine object.
Expand Down
98 changes: 60 additions & 38 deletions rasterio/_io.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1619,7 +1619,6 @@ cdef class InMemoryRaster:
This class is only intended for internal use within rasterio to support
IO with GDAL. Other memory based operations should use numpy arrays.
"""

def __cinit__(self, image=None, dtype='uint8', count=1, width=None,
height=None, transform=None, gcps=None, crs=None):
"""
Expand All @@ -1634,19 +1633,16 @@ cdef class InMemoryRaster:
(see rasterio.dtypes.dtype_rev)
:param transform: Affine transform object
"""

self._image = image

cdef int i = 0 # avoids Cython warning in for loop below
cdef const char *srcwkt = NULL
cdef OGRSpatialReferenceH osr = NULL
cdef GDALDriverH mdriver = NULL
cdef GDAL_GCP *gcplist = NULL

if image is not None:
if len(image.shape) == 3:
if image.ndim == 3:
count, height, width = image.shape
elif len(image.shape) == 2:
elif image.ndim == 2:
count = 1
height, width = image.shape
dtype = image.dtype.name
Expand All @@ -1657,9 +1653,18 @@ cdef class InMemoryRaster:
if width is None or width == 0:
raise ValueError("width must be > 0")

self.band_ids[0] = 1
self.band_ids = <int *>CPLMalloc(count*sizeof(int))
for i in range(1, count + 1):
self.band_ids[i-1] = i

try:
memdriver = exc_wrap_pointer(GDALGetDriverByName("MEM"))
except Exception:
raise DriverRegistrationError(
"'MEM' driver not found. Check that this call is contained "
"in a `with rasterio.Env()` or `with rasterio.open()` "
"block.")

memdriver = exc_wrap_pointer(GDALGetDriverByName("MEM"))
datasetname = str(uuid.uuid4()).encode('utf-8')
self._hds = exc_wrap_pointer(
GDALCreate(memdriver, <const char *>datasetname, width, height,
Expand All @@ -1670,46 +1675,52 @@ cdef class InMemoryRaster:
gdal_transform = transform.to_gdal()
for i in range(6):
self.gdal_transform[i] = gdal_transform[i]
err = GDALSetGeoTransform(self._hds, self.gdal_transform)
if err:
raise ValueError("transform not set: %s" % transform)

exc_wrap_int(GDALSetGeoTransform(self._hds, self.gdal_transform))
if crs:
osr = _osr_from_crs(crs)
OSRExportToWkt(osr, <char**>&srcwkt)
GDALSetProjection(self._hds, srcwkt)
log.debug("Set CRS on temp source dataset: %s", srcwkt)
CPLFree(<void *>srcwkt)
_safe_osr_release(osr)
try:
OSRExportToWkt(osr, &srcwkt)
exc_wrap_int(GDALSetProjection(self._hds, srcwkt))
log.debug("Set CRS on temp dataset: %s", srcwkt)
finally:
CPLFree(srcwkt)
_safe_osr_release(osr)

elif gcps and crs:
gcplist = <GDAL_GCP *>CPLMalloc(len(gcps) * sizeof(GDAL_GCP))
for i, obj in enumerate(gcps):
ident = str(i).encode('utf-8')
info = "".encode('utf-8')
gcplist[i].pszId = ident
gcplist[i].pszInfo = info
gcplist[i].dfGCPPixel = obj.col
gcplist[i].dfGCPLine = obj.row
gcplist[i].dfGCPX = obj.x
gcplist[i].dfGCPY = obj.y
gcplist[i].dfGCPZ = obj.z or 0.0
try:
gcplist = <GDAL_GCP *>CPLMalloc(len(gcps) * sizeof(GDAL_GCP))
for i, obj in enumerate(gcps):
ident = str(i).encode('utf-8')
info = "".encode('utf-8')
gcplist[i].pszId = ident
gcplist[i].pszInfo = info
gcplist[i].dfGCPPixel = obj.col
gcplist[i].dfGCPLine = obj.row
gcplist[i].dfGCPX = obj.x
gcplist[i].dfGCPY = obj.y
gcplist[i].dfGCPZ = obj.z or 0.0

osr = _osr_from_crs(crs)
OSRExportToWkt(osr, <char**>&srcwkt)
GDALSetGCPs(self._hds, len(gcps), gcplist, srcwkt)
CPLFree(gcplist)
CPLFree(<void *>srcwkt)
osr = _osr_from_crs(crs)
OSRExportToWkt(osr, &srcwkt)
exc_wrap_int(GDALSetGCPs(self._hds, len(gcps), gcplist, srcwkt))
finally:
CPLFree(gcplist)
CPLFree(srcwkt)
_safe_osr_release(osr)

if self._image is not None:
self.write(self._image)
self._image = None
if image is not None:
self.write(image)

def __enter__(self):
return self

def __exit__(self, *args, **kwargs):
self.close()

def __dealloc__(self):
CPLFree(self.band_ids)

cdef GDALDatasetH handle(self) except NULL:
"""Return the object's GDAL dataset handle"""
return self._hds
Expand All @@ -1735,11 +1746,22 @@ cdef class InMemoryRaster:
self._hds = NULL

def read(self):
io_auto(self._image, self.band(1), False)
if self._image is None:
raise IOError("You need to write data before you can read the data.")

if self._image.ndim == 2:
exc_wrap_int(io_auto(self._image, self.band(1), False))
else:
exc_wrap_int(io_auto(self._image, self._hds, False))
return self._image

def write(self, image):
io_auto(image, self.band(1), True)
def write(self, np.ndarray image):
self._image = image
if image.ndim == 2:
exc_wrap_int(io_auto(self._image, self.band(1), True))
else:
exc_wrap_int(io_auto(self._image, self._hds, True))



cdef class BufferedDatasetWriterBase(DatasetWriterBase):
Expand Down

0 comments on commit 8821bc3

Please sign in to comment.