Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jgrss/features delayed #250

Merged
merged 21 commits into from
Mar 29, 2023
4 changes: 4 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[flake8]
extend-ignore = E203, E266, E501, W503, F403, F401, F841, C901
max-line-length = 79
max-complexity = 10
4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,7 @@ repos:
--wrap-descriptions,
'79',
]
- repo: https://github.com/pycqa/flake8
rev: 6.0.0
hooks:
- id: flake8
19 changes: 19 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,22 @@ requires = [
'Cython>=0.29.0,<3.0.0',
'numpy>=1.19.0'
]

[tool.black]
line-length = 79
include = '\.pyi?$'
exclude = '''
/(
\.git
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| \.vscode
| \.idea
| _build
| buck-out
| build
| dist
)/
'''
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ docs = numpydoc
tests = testfixtures
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the read the docs we have an "all" option is it worth putting that back in?

jax
jaxlib
pre-commit
coreg = earthpy
pyfftw==0.12.0
bottleneck
Expand Down
4 changes: 3 additions & 1 deletion src/geowombat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .core import calc_area
from .core import subset
from .core import clip
from .core import clip_by_polygon
from .core import mask
from .core import replace
from .core import recode
Expand Down Expand Up @@ -53,6 +54,7 @@
'calc_area',
'subset',
'clip',
'clip_by_polygon',
'mask',
'replace',
'recode',
Expand Down Expand Up @@ -81,5 +83,5 @@
'bounds_to_coords',
'lonlat_to_xy',
'xy_to_lonlat',
'__version__'
'__version__',
]
152 changes: 118 additions & 34 deletions src/geowombat/backends/rasterio_.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,31 @@
import numcodecs

ZARR_INSTALLED = True
except:
except ImportError:
ZARR_INSTALLED = False


logger = logging.getLogger(__name__)


def get_dims_from_bounds(
bounds: BoundingBox, res: T.Tuple[float, float]
) -> T.Tuple[int, int]:
width = int((bounds.right - bounds.left) / abs(res[0]))
height = int((bounds.top - bounds.bottom) / abs(res[1]))

return height, width


def get_file_info(src_obj):
src_bounds = src_obj.bounds
src_res = src_obj.res
src_width = src_obj.width
src_height = src_obj.height

FileInfo = namedtuple('FileInfo', 'src_bounds src_res src_width src_height')
FileInfo = namedtuple(
'FileInfo', 'src_bounds src_res src_width src_height'
)

return FileInfo(
src_bounds=src_bounds,
Expand All @@ -60,7 +71,15 @@ def get_file_info(src_obj):


def to_gtiff(
filename, data, window, indexes, transform, n_workers, separate, tags, kwargs
filename,
data,
window,
indexes,
transform,
n_workers,
separate,
tags,
kwargs,
):
"""Writes data to a GeoTiff file.

Expand Down Expand Up @@ -140,7 +159,11 @@ def to_gtiff(

class RasterioStore(object):
def __init__(
self, filename: T.Union[str, Path], mode: str = 'w', tags: dict = None, **kwargs
self,
filename: T.Union[str, Path],
mode: str = 'w',
tags: dict = None,
**kwargs,
):
self.filename = Path(filename)
self.mode = mode
Expand All @@ -153,7 +176,9 @@ def __setitem__(self, key, item):
index_range, y, x = key
indexes = list(
range(
index_range.start + 1, index_range.stop + 1, index_range.step or 1
index_range.start + 1,
index_range.stop + 1,
index_range.step or 1,
)
)
else:
Expand Down Expand Up @@ -204,10 +229,13 @@ def close_delayed(self, store):

def write(self, data: xr.DataArray, compute: bool = False) -> Delayed:
if isinstance(data.data, da.Array):
return da.store(da.squeeze(data.data), self, lock=True, compute=compute)
return da.store(
da.squeeze(data.data), self, lock=True, compute=compute
)
else:
self.dst.write(
data.squeeze().data, indexes=list(range(1, data.data.shape[0] + 1))
data.squeeze().data,
indexes=list(range(1, data.data.shape[0] + 1)),
)

def close(self):
Expand Down Expand Up @@ -294,7 +322,9 @@ def __setitem__(self, key, item):
index_range, y, x = key
indexes = list(
range(
index_range.start + 1, index_range.stop + 1, index_range.step or 1
index_range.start + 1,
index_range.stop + 1,
index_range.step or 1,
)
)

Expand All @@ -307,12 +337,14 @@ def __setitem__(self, key, item):

if self.out_block_type.lower() == 'zarr':

group_name = '{BASE}_y{Y:09d}_x{X:09d}_h{H:09d}_w{W:09d}'.format(
BASE=self.f_base,
Y=y.start,
X=x.start,
H=y.stop - y.start,
W=x.stop - x.start,
group_name = (
'{BASE}_y{Y:09d}_x{X:09d}_h{H:09d}_w{W:09d}'.format(
BASE=self.f_base,
Y=y.start,
X=x.start,
H=y.stop - y.start,
W=x.stop - x.start,
)
)

group = self.root.create_group(group_name)
Expand All @@ -322,7 +354,10 @@ def __setitem__(self, key, item):
item,
compressor=self.compressor,
dtype=item.dtype.name,
chunks=(self.kwargs['blockysize'], self.kwargs['blockxsize']),
chunks=(
self.kwargs['blockysize'],
self.kwargs['blockxsize'],
),
)

group.attrs['row_off'] = y.start
Expand Down Expand Up @@ -405,7 +440,9 @@ def __setitem__(self, key, item):
self.separate and self.out_block_type.lower() == 'gtiff'
):

with rio.open(out_filename, mode=io_mode, sharing=False, **kwargs) as dst_:
with rio.open(
out_filename, mode=io_mode, sharing=False, **kwargs
) as dst_:

dst_.write(item, window=w, indexes=indexes)

Expand Down Expand Up @@ -605,7 +642,13 @@ def window_to_bounds(filenames, w):
return left, bottom, right, top


def align_bounds(minx, miny, maxx, maxy, res):
def align_bounds(
minx: float,
miny: float,
maxx: float,
maxy: float,
res: T.Tuple[float, float],
) -> T.Tuple[Affine, int, int]:
"""Aligns bounds to resolution.

Args:
Expand All @@ -620,9 +663,8 @@ def align_bounds(minx, miny, maxx, maxy, res):
"""
xres, yres = res

new_height = (maxy - miny) / yres
new_width = (maxx - minx) / xres

new_height = int(np.floor((maxy - miny) / yres))
new_width = int(np.floor((maxx - minx) / xres))
new_transform = Affine(xres, 0.0, minx, 0.0, -yres, maxy)

return aligned_target(new_transform, new_width, new_height, res)
Expand Down Expand Up @@ -666,7 +708,12 @@ def get_file_bounds(
dst_res = src_info.src_res

# Transform the extent to the reference CRS
bounds_left, bounds_bottom, bounds_right, bounds_top = transform_bounds(
(
bounds_left,
bounds_bottom,
bounds_right,
bounds_top,
) = transform_bounds(
src_crs,
dst_crs,
src_info.src_bounds.left,
Expand Down Expand Up @@ -785,7 +832,11 @@ def warp_images(
# Get the union bounds of all images.
# *Target-aligned-pixels are returned.
warp_kwargs['bounds'] = get_file_bounds(
filenames, bounds_by=bounds_by, crs=crs, res=res, return_bounds=True
filenames,
bounds_by=bounds_by,
crs=crs,
res=res,
return_bounds=True,
)

return [warp(fn, **warp_kwargs) for fn in filenames]
Expand Down Expand Up @@ -859,7 +910,12 @@ def warp(
# Check if the data need to be subset
if (bounds is None) or (tuple(bounds) == tuple(src_info.src_bounds)):
if crs:
left_coord, bottom_coord, right_coord, top_coord = transform_bounds(
(
left_coord,
bottom_coord,
right_coord,
top_coord,
) = transform_bounds(
src_crs,
dst_crs,
src_info.src_bounds.left,
Expand Down Expand Up @@ -905,7 +961,10 @@ def warp(

elif isinstance(bounds, (list, np.ndarray, tuple)):
dst_bounds = BoundingBox(
left=bounds[0], bottom=bounds[1], right=bounds[2], top=bounds[3]
left=bounds[0],
bottom=bounds[1],
right=bounds[2],
top=bounds[3],
)

else:
Expand All @@ -915,8 +974,7 @@ def warp(
)
raise TypeError

dst_width = int((dst_bounds.right - dst_bounds.left) / dst_res[0])
dst_height = int((dst_bounds.top - dst_bounds.bottom) / dst_res[1])
dst_height, dst_width = get_dims_from_bounds(dst_bounds, dst_res)

# Do not warp if all the key metadata match the reference information
if (
Expand All @@ -940,7 +998,12 @@ def warp(
src_info.src_bounds.top,
)
dst_transform = Affine(
dst_res[0], 0.0, dst_bounds.left, 0.0, -dst_res[1], dst_bounds.top
dst_res[0],
0.0,
dst_bounds.left,
0.0,
-dst_res[1],
dst_bounds.top,
)

if tac:
Expand Down Expand Up @@ -992,7 +1055,9 @@ def reproject_array(
num_threads: int,
) -> np.ndarray:
"""Reprojects a DataArray and translates to a numpy ndarray."""
dst_array = np.zeros((data.gw.nbands, dst_height, dst_width), dtype=data.dtype)
dst_array = np.zeros(
(data.gw.nbands, dst_height, dst_width), dtype=data.dtype
)
dst_array, dst_transform = reproject(
data.gw.compute(num_workers=num_threads),
dst_array,
Expand Down Expand Up @@ -1109,20 +1174,32 @@ def transform_crs(
if isinstance(dst_width, int) and isinstance(dst_height, int):
xs = (
dst_transform
* (np.arange(0, dst_width) + 0.5, np.arange(0, dst_width) + 0.5)
* (
np.arange(0, dst_width) + 0.5,
np.arange(0, dst_width) + 0.5,
)
)[0]
ys = (
dst_transform
* (np.arange(0, dst_height) + 0.5, np.arange(0, dst_height) + 0.5)
* (
np.arange(0, dst_height) + 0.5,
np.arange(0, dst_height) + 0.5,
)
)[1]
else:
xs = (
dst_transform
* (np.arange(0, dst_width_) + 0.5, np.arange(0, dst_width_) + 0.5)
* (
np.arange(0, dst_width_) + 0.5,
np.arange(0, dst_width_) + 0.5,
)
)[0]
ys = (
dst_transform
* (np.arange(0, dst_height_) + 0.5, np.arange(0, dst_height_) + 0.5)
* (
np.arange(0, dst_height_) + 0.5,
np.arange(0, dst_height_) + 0.5,
)
)[1]

XYCoords = namedtuple('XYCoords', 'xs ys')
Expand Down Expand Up @@ -1162,10 +1239,17 @@ def transform_crs(

# Ensure the final transform is set based on adjusted bounds
dst_transform = Affine(
abs(dst_res[0]), 0.0, dst_bounds.left, 0.0, -abs(dst_res[1]), dst_bounds.top
abs(dst_res[0]),
0.0,
dst_bounds.left,
0.0,
-abs(dst_res[1]),
dst_bounds.top,
)

proj_func = dask.delayed(reproject_array) if delayed_array else reproject_array
proj_func = (
dask.delayed(reproject_array) if delayed_array else reproject_array
)
transformed_array = proj_func(
data_src,
dst_height,
Expand Down
Loading