-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for int64, uint64, and int8 #353
Changes from all commits
7478d46
de1eb9b
8729926
9caee65
25e0416
d71b1d9
c745334
dafc108
be1866c
b11c550
ddee08f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,6 +26,7 @@ | |
import shapely.prepared | ||
import shapely.wkb | ||
from osgeo import gdal | ||
from osgeo import gdal_array | ||
from osgeo import gdalconst | ||
from osgeo import ogr | ||
from osgeo import osr | ||
|
@@ -40,19 +41,7 @@ | |
if sys.version_info >= (3, 8): | ||
import multiprocessing.shared_memory | ||
|
||
NUMPY_TO_GDAL_TYPE = { | ||
numpy.dtype(bool): gdal.GDT_Byte, | ||
numpy.dtype(numpy.int8): gdal.GDT_Byte, | ||
numpy.dtype(numpy.uint8): gdal.GDT_Byte, | ||
numpy.dtype(numpy.int16): gdal.GDT_Int16, | ||
numpy.dtype(numpy.int32): gdal.GDT_Int32, | ||
numpy.dtype(numpy.uint16): gdal.GDT_UInt16, | ||
numpy.dtype(numpy.uint32): gdal.GDT_UInt32, | ||
numpy.dtype(numpy.float32): gdal.GDT_Float32, | ||
numpy.dtype(numpy.float64): gdal.GDT_Float64, | ||
numpy.dtype(numpy.csingle): gdal.GDT_CFloat32, | ||
numpy.dtype(numpy.complex64): gdal.GDT_CFloat64, | ||
} | ||
GDAL_VERSION = tuple(int(_) for _ in gdal.__version__.split('.')) | ||
|
||
|
||
class ReclassificationMissingValuesError(Exception): | ||
|
@@ -86,18 +75,6 @@ def __init__(self, missing_values, raster_path, value_map): | |
_LOGGING_PERIOD = 5.0 # min 5.0 seconds per update log message for the module | ||
_LARGEST_ITERBLOCK = 2**16 # largest block for iterblocks to read in cells | ||
|
||
_GDAL_TYPE_TO_NUMPY_LOOKUP = { | ||
gdal.GDT_Byte: numpy.uint8, | ||
gdal.GDT_Int16: numpy.int16, | ||
gdal.GDT_Int32: numpy.int32, | ||
gdal.GDT_UInt16: numpy.uint16, | ||
gdal.GDT_UInt32: numpy.uint32, | ||
gdal.GDT_Float32: numpy.float32, | ||
gdal.GDT_Float64: numpy.float64, | ||
gdal.GDT_CFloat32: numpy.csingle, | ||
gdal.GDT_CFloat64: numpy.complex64, | ||
} | ||
|
||
_GDAL_WARP_ALGORITHMS = [] | ||
for _warp_algo in (_attrname for _attrname in dir(gdalconst) | ||
if _attrname.startswith('GRA_')): | ||
|
@@ -772,8 +749,8 @@ def raster_map(op, rasters, target_path, target_nodata=None, target_dtype=None, | |
f'the target dtype {target_dtype}') | ||
|
||
driver, options = raster_driver_creation_tuple | ||
if target_dtype == numpy.int8 and 'PIXELTYPE=SIGNEDBYTE' not in options: | ||
options += ('PIXELTYPE=SIGNEDBYTE',) | ||
gdal_type, type_creation_options = _numpy_to_gdal_type(target_dtype) | ||
options = list(options) + type_creation_options | ||
|
||
def apply_op(*arrays): | ||
"""Apply the function ``op`` to the input arrays. | ||
|
@@ -802,7 +779,7 @@ def apply_op(*arrays): | |
[(path, 1) for path in rasters], # assume the first band | ||
apply_op, | ||
target_path, | ||
NUMPY_TO_GDAL_TYPE[numpy.dtype(target_dtype)], | ||
gdal_type, | ||
target_nodata, | ||
raster_driver_creation_tuple=(driver, options)) | ||
|
||
|
@@ -1221,19 +1198,8 @@ def new_raster_from_base( | |
driver = gdal.GetDriverByName(raster_driver_creation_tuple[0]) | ||
|
||
local_raster_creation_options = list(raster_driver_creation_tuple[1]) | ||
# PIXELTYPE is sometimes used to define signed vs. unsigned bytes and | ||
# the only place that is stored is in the IMAGE_STRUCTURE metadata | ||
# copy it over if it exists and it not already defined by the input | ||
# creation options. It's okay to get this info from the first band since | ||
# all bands have the same datatype | ||
numpy_dtype = _gdal_to_numpy_type(datatype, local_raster_creation_options) | ||
base_band = base_raster.GetRasterBand(1) | ||
metadata = base_band.GetMetadata('IMAGE_STRUCTURE') | ||
if 'PIXELTYPE' in metadata and not any( | ||
['PIXELTYPE' in option for option in | ||
local_raster_creation_options]): | ||
local_raster_creation_options.append( | ||
'PIXELTYPE=' + metadata['PIXELTYPE']) | ||
|
||
block_size = base_band.GetBlockSize() | ||
# It's not clear how or IF we can determine if the output should be | ||
# striped or tiled. Here we leave it up to the default inputs or if its | ||
|
@@ -1283,7 +1249,6 @@ def new_raster_from_base( | |
timed_logger = TimedLoggingAdapter(_LOGGING_PERIOD) | ||
pixels_processed = 0 | ||
n_pixels = n_cols * n_rows | ||
numpy_dtype = _GDAL_TYPE_TO_NUMPY_LOOKUP[datatype] | ||
if fill_value_list is not None: | ||
for index, fill_value in enumerate(fill_value_list): | ||
if fill_value is None: | ||
|
@@ -2167,15 +2132,9 @@ def get_raster_info(raster_path): | |
|
||
# datatype is the same for the whole raster, but is associated with band | ||
band = raster.GetRasterBand(1) | ||
band_datatype = band.DataType | ||
raster_properties['datatype'] = band_datatype | ||
raster_properties['numpy_type'] = ( | ||
_GDAL_TYPE_TO_NUMPY_LOOKUP[band_datatype]) | ||
# this part checks to see if the byte is signed or not | ||
if band_datatype == gdal.GDT_Byte: | ||
metadata = band.GetMetadata('IMAGE_STRUCTURE') | ||
if 'PIXELTYPE' in metadata and metadata['PIXELTYPE'] == 'SIGNEDBYTE': | ||
raster_properties['numpy_type'] = numpy.int8 | ||
raster_properties['datatype'] = band.DataType | ||
raster_properties['numpy_type'] = _gdal_to_numpy_type( | ||
band.DataType, band.GetMetadata('IMAGE_STRUCTURE')) | ||
band = None | ||
raster = None | ||
return raster_properties | ||
|
@@ -2413,11 +2372,14 @@ def reclassify_raster( | |
keys = sorted(numpy.array(list(value_map_copy.keys()))) | ||
values = numpy.array([value_map_copy[x] for x in keys]) | ||
|
||
numpy_dtype = _gdal_to_numpy_type( | ||
target_datatype, raster_driver_creation_tuple[1]) | ||
|
||
def _map_dataset_to_value_op(original_values): | ||
"""Convert a block of original values to the lookup values.""" | ||
out_array = numpy.full( | ||
original_values.shape, target_nodata, | ||
dtype=_GDAL_TYPE_TO_NUMPY_LOOKUP[target_datatype]) | ||
dtype=numpy_dtype) | ||
if nodata is None: | ||
valid_mask = numpy.full(original_values.shape, True) | ||
else: | ||
|
@@ -2619,9 +2581,9 @@ def warp_raster( | |
base_raster = gdal.OpenEx(base_raster_path, gdal.OF_RASTER) | ||
|
||
raster_creation_options = list(raster_driver_creation_tuple[1]) | ||
if (base_raster_info['numpy_type'] == numpy.int8 and | ||
'PIXELTYPE' not in ' '.join(raster_creation_options)): | ||
raster_creation_options.append('PIXELTYPE=SIGNEDBYTE') | ||
_, type_creation_options = _numpy_to_gdal_type( | ||
base_raster_info['numpy_type']) | ||
raster_creation_options += type_creation_options | ||
|
||
if resample_method.lower() not in _GDAL_WARP_ALGORITHMS: | ||
raise ValueError( | ||
|
@@ -3703,40 +3665,54 @@ def mask_op(base_array, mask_array): | |
os.remove(mask_raster_path) | ||
|
||
|
||
def _gdal_to_numpy_type(band): | ||
"""Calculate the equivalent numpy datatype from a GDAL raster band type. | ||
|
||
This function doesn't handle complex or unknown types. If they are | ||
passed in, this function will raise a ValueError. | ||
def _gdal_to_numpy_type(gdal_type, metadata): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since the function interface is changing, could you update the docstring here as well? |
||
"""Calculate the equivalent numpy datatype from a GDAL type and metadata. | ||
|
||
Args: | ||
band (gdal.Band): GDAL Band | ||
gdal_type: GDAL.GDT_* data type code | ||
metadata: mapping or list of strings to check for the existence of | ||
the 'PIXELTYPE=SIGNEDBYTE' flag | ||
|
||
Return: | ||
numpy_datatype (numpy.dtype): equivalent of band.DataType | ||
Returns: | ||
numpy.dtype that is the equivalent of the input gdal type | ||
|
||
Raises: | ||
ValueError if an unsupported data type is entered | ||
""" | ||
# doesn't include GDT_Byte because that's a special case | ||
base_gdal_type_to_numpy = { | ||
gdal.GDT_Int16: numpy.int16, | ||
gdal.GDT_Int32: numpy.int32, | ||
gdal.GDT_UInt16: numpy.uint16, | ||
gdal.GDT_UInt32: numpy.uint32, | ||
gdal.GDT_Float32: numpy.float32, | ||
gdal.GDT_Float64: numpy.float64, | ||
} | ||
if (GDAL_VERSION < (3, 7, 0) and gdal_type == gdal.GDT_Byte and | ||
(('PIXELTYPE=SIGNEDBYTE' in metadata) or | ||
('PIXELTYPE' in metadata and metadata['PIXELTYPE'] == 'SIGNEDBYTE'))): | ||
return numpy.int8 | ||
|
||
if band.DataType in base_gdal_type_to_numpy: | ||
return base_gdal_type_to_numpy[band.DataType] | ||
numpy_type = gdal_array.GDALTypeCodeToNumericTypeCode(gdal_type) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well how about that! This is a nice function to know about. |
||
if numpy_type is None: | ||
raise ValueError(f"Unsupported DataType: {gdal_type}") | ||
return numpy_type | ||
|
||
if band.DataType != gdal.GDT_Byte: | ||
raise ValueError("Unsupported DataType: %s" % str(band.DataType)) | ||
|
||
# band must be GDT_Byte type, check if it is signed/unsigned | ||
metadata = band.GetMetadata('IMAGE_STRUCTURE') | ||
if 'PIXELTYPE' in metadata and metadata['PIXELTYPE'] == 'SIGNEDBYTE': | ||
return numpy.int8 | ||
return numpy.uint8 | ||
def _numpy_to_gdal_type(numpy_type): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add a docstring here that includes the input and output types? Thanks! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added a docstring! |
||
"""Calculate the equivalent GDAL type and metadata from a numpy type. | ||
|
||
Args: | ||
numpy_type: numpy data type | ||
|
||
Returns: | ||
(gdal type, metadata) tuple. gdal type is a gdal.GDT_* type code. | ||
metadata is an empty list in most cases, or ['PIXELTYPE=SIGNEDBYTE'] | ||
if needed to indicate a signed byte type. | ||
|
||
Raises: | ||
ValueError if an unsupported data type is entered | ||
""" | ||
numpy_dtype = numpy.dtype(numpy_type) | ||
|
||
if GDAL_VERSION < (3, 7, 0) and numpy_dtype == numpy.dtype(numpy.int8): | ||
return gdal.GDT_Byte, ['PIXELTYPE=SIGNEDBYTE'] | ||
|
||
gdal_type = gdal_array.NumericTypeCodeToGDALTypeCode(numpy_dtype) | ||
if gdal_type is None: | ||
raise ValueError(f"Unsupported DataType: {numpy_type}") | ||
return gdal_type, [] | ||
|
||
|
||
def merge_bounding_box_list(bounding_box_list, bounding_box_mode): | ||
|
@@ -4195,12 +4171,13 @@ def numpy_array_to_raster( | |
Return: | ||
None | ||
""" | ||
|
||
raster_driver = gdal.GetDriverByName(raster_driver_creation_tuple[0]) | ||
driver_name, creation_options = raster_driver_creation_tuple | ||
raster_driver = gdal.GetDriverByName(driver_name) | ||
ny, nx = base_array.shape | ||
gdal_type, type_creation_options = _numpy_to_gdal_type(base_array.dtype) | ||
new_raster = raster_driver.Create( | ||
target_path, nx, ny, 1, NUMPY_TO_GDAL_TYPE[base_array.dtype], | ||
options=raster_driver_creation_tuple[1]) | ||
target_path, nx, ny, 1, gdal_type, | ||
options=list(creation_options) + type_creation_options) | ||
if projection_wkt is not None: | ||
new_raster.SetProjection(projection_wkt) | ||
if origin is not None and pixel_size is not None: | ||
|
@@ -4420,14 +4397,16 @@ def _mult_op(base_array, base_nodata, scale, datatype): | |
scaled_raster_path = os.path.join( | ||
workspace_dir, | ||
f'scaled_{os.path.basename(base_stitch_raster_path)}') | ||
gdal_type = _gdal_to_numpy_type( | ||
target_band.DataType, | ||
target_band.GetMetadata('IMAGE_STRUCTURE')) | ||
# multiply the pixels in the resampled raster by the ratio of | ||
# the pixel area in the wgs84 units divided by the area of the | ||
# original pixel | ||
raster_calculator( | ||
[(base_stitch_raster_path, 1), (base_stitch_nodata, 'raw'), | ||
m2_area_per_lat/base_pixel_area_m2, | ||
(_GDAL_TYPE_TO_NUMPY_LOOKUP[ | ||
target_raster_info['datatype']], 'raw')], _mult_op, | ||
(gdal_type, 'raw')], _mult_op, | ||
scaled_raster_path, | ||
target_raster_info['datatype'], base_stitch_nodata) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -656,22 +656,20 @@ def raster_band_percentile( | |
will select the next element higher than the percentile cutoff). | ||
|
||
""" | ||
raster_type = pygeoprocessing.get_raster_info( | ||
base_raster_path_band[0])['datatype'] | ||
if raster_type in ( | ||
gdal.GDT_Byte, gdal.GDT_Int16, gdal.GDT_UInt16, gdal.GDT_Int32, | ||
gdal.GDT_UInt32): | ||
numpy_type = pygeoprocessing.get_raster_info( | ||
base_raster_path_band[0])['numpy_type'] | ||
if numpy.issubdtype(numpy_type, numpy.integer): | ||
Comment on lines
+659
to
+661
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a really nice update - one less thing to have to worry about updating if GDAL adds more int or float types! |
||
return _raster_band_percentile_int( | ||
base_raster_path_band, working_sort_directory, percentile_list, | ||
heap_buffer_size, ffi_buffer_size) | ||
elif raster_type in (gdal.GDT_Float32, gdal.GDT_Float64): | ||
elif numpy.issubdtype(numpy_type, numpy.floating): | ||
return _raster_band_percentile_double( | ||
base_raster_path_band, working_sort_directory, percentile_list, | ||
heap_buffer_size, ffi_buffer_size) | ||
else: | ||
raise ValueError( | ||
'Cannot process raster type %s (not a known integer nor float ' | ||
'type)', raster_type) | ||
'type)', numpy_type) | ||
|
||
|
||
def _raster_band_percentile_int( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice. This is a great idea and I'm glad they're supporting this in
pyproject.toml
now.