Skip to content

Commit

Permalink
Fixed #31766 -- Avoided uncessary computation in GDALRaster
Browse files Browse the repository at this point in the history
Implemented a clone method using ds_copy to clone GDALRaster objects. Added conditional to simply clone GDALRaster if SRID matches with transform srs argument.
  • Loading branch information
bartondc authored and felixxm committed Sep 8, 2020
1 parent 5ea1621 commit c239491
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 0 deletions.
1 change: 1 addition & 0 deletions AUTHORS
Expand Up @@ -110,6 +110,7 @@ answer newbie questions, and generally made Django that much better:
Baptiste Mispelon <bmispelon@gmail.com>
Barry Pederson <bp@barryp.org>
Bartolome Sanchez Salado <i42sasab@uco.es>
Barton Ip <notbartonip@gmail.com>
Bartosz Grabski <bartosz.grabski@gmail.com>
Bashar Al-Abdulhadi
Bastian Kleineidam <calvin@debian.org>
Expand Down
18 changes: 18 additions & 0 deletions django/contrib/gis/gdal/raster/source.py
Expand Up @@ -425,6 +425,21 @@ def warp(self, ds_input, resampling='NearestNeighbour', max_error=0.0):

return target

def clone(self):
"""Return a clone of this GDALRaster."""
return GDALRaster(
capi.copy_ds(
self.driver._ptr,
force_bytes(self.name),
self._ptr,
c_int(),
c_char_p(),
c_void_p(),
c_void_p(),
),
write=self._write,
)

def transform(self, srs, driver=None, name=None, resampling='NearestNeighbour',
max_error=0.0):
"""
Expand All @@ -443,6 +458,9 @@ def transform(self, srs, driver=None, name=None, resampling='NearestNeighbour',
'Transform only accepts SpatialReference, string, and integer '
'objects.'
)

if target_srs.srid == self.srid:
return self.clone()
# Create warped virtual dataset in the target reference system
target = capi.auto_create_warped_vrt(
self._ptr, self.srs.wkt.encode(), target_srs.wkt.encode(),
Expand Down
53 changes: 53 additions & 0 deletions tests/gis_tests/gdal_tests/test_raster.py
Expand Up @@ -2,6 +2,7 @@
import shutil
import struct
import tempfile
from unittest import mock

from django.contrib.gis.gdal import GDAL_VERSION, GDALRaster, SpatialReference
from django.contrib.gis.gdal.error import GDALException
Expand Down Expand Up @@ -470,6 +471,33 @@ def test_raster_warp_nodata_zone(self):
# The result is an empty raster filled with the correct nodata value.
self.assertEqual(result, [23] * 16)

def test_raster_clone(self):
# Create in memory raster.
source = GDALRaster({
'datatype': 1,
'driver': 'MEM',
'width': 4,
'height': 4,
'srid': 3086,
'origin': (500000, 400000),
'scale': (100, -100),
'skew': (0, 0),
'bands': [{
'data': range(16),
'nodata_value': 23,
}],
})
clone = source.clone()
self.assertEqual(clone.name, source.name)
self.assertEqual(clone._write, source._write)
self.assertEqual(clone.srs.srid, source.srs.srid)
self.assertEqual(clone.width, source.width)
self.assertEqual(clone.height, source.height)
self.assertEqual(clone.origin, source.origin)
self.assertEqual(clone.scale, source.scale)
self.assertEqual(clone.skew, source.skew)
self.assertIsNot(clone, source)

def test_raster_transform(self):
tests = [
3086,
Expand Down Expand Up @@ -531,6 +559,31 @@ def test_raster_transform(self):
],
)

def test_raster_transform_clone(self):
with mock.patch.object(GDALRaster, 'clone') as mocked_clone:
# Create in file based raster.
rstfile = tempfile.NamedTemporaryFile(suffix='.tif')
ndv = 99
source = GDALRaster({
'datatype': 1,
'driver': 'tif',
'name': rstfile.name,
'width': 5,
'height': 5,
'nr_of_bands': 1,
'srid': 4326,
'origin': (-5, 5),
'scale': (2, -2),
'skew': (0, 0),
'bands': [{
'data': range(25),
'nodata_value': ndv,
}],
})
# transform() returns a clone because it is the same SRID.
target = source.transform(4326)
self.assertEqual(mocked_clone.call_count, 1)


class GDALBandTests(SimpleTestCase):
rs_path = os.path.join(os.path.dirname(__file__), '../data/rasters/raster.tif')
Expand Down

0 comments on commit c239491

Please sign in to comment.