Skip to content

Commit

Permalink
Merge 2ae6202 into 39b2a3a
Browse files Browse the repository at this point in the history
  • Loading branch information
sphamba committed Mar 22, 2024
2 parents 39b2a3a + 2ae6202 commit ff5e0e6
Show file tree
Hide file tree
Showing 10 changed files with 1,173 additions and 69 deletions.
47 changes: 24 additions & 23 deletions gpm/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

# -----------------------------------------------------------------------------.
"""This module defines functions providing GPM-API Dataset information."""
from itertools import chain

import numpy as np
import xarray as xr

Expand Down Expand Up @@ -79,7 +81,8 @@ def get_dataset_variables(ds, sort=False):
def _get_available_spatial_dims(xr_obj):
"""Get xarray object available spatial dimensions."""
dims = list(xr_obj.dims)
return tuple(np.array(SPATIAL_DIMS)[np.isin(SPATIAL_DIMS, dims)].tolist())
flattened_spatial_dims = list(chain.from_iterable(SPATIAL_DIMS))
return tuple(np.array(flattened_spatial_dims)[np.isin(flattened_spatial_dims, dims)].tolist())


def _get_available_vertical_dims(xr_obj):
Expand Down Expand Up @@ -199,13 +202,14 @@ def _is_spatial_2d_datarray(da, strict):
if not _is_expected_spatial_dims(spatial_dims):
return False
vertical_dims = _get_available_vertical_dims(da)
if not vertical_dims:
if strict:
if len(da.dims) == 2:
return True
else:
return True
return False

if vertical_dims:
return False

if strict and len(da.dims) != 2:
return False

return True


def _is_spatial_3d_datarray(da, strict):
Expand All @@ -214,33 +218,30 @@ def _is_spatial_3d_datarray(da, strict):
if not _is_expected_spatial_dims(spatial_dims):
return False
vertical_dims = _get_available_vertical_dims(da)

if not vertical_dims:
return False
else:
if strict:
if len(da.dims) == 2:
return True
else:
return True
return False

if strict and len(da.dims) != 3:
return False

return True


def _is_transect_datarray(da, strict):
"""Check if a DataArray is a spatial 3D array."""
spatial_dims = list(_get_available_spatial_dims(da))
if len(spatial_dims) != 1:
return False

vertical_dims = list(_get_available_vertical_dims(da))

if not vertical_dims:
return False
else:
if strict:
if len(da.dims) == 2:
return True
else:
return True
return False

if strict and len(da.dims) != 2:
return False

return True


def _is_spatial_2d_dataset(ds, strict):
Expand Down
12 changes: 4 additions & 8 deletions gpm/dataset/dimensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,10 @@
}

SPATIAL_DIMS = [
"along_track",
"cross_track",
"lat",
"lon", # choose whether to use instead latitude/longitude
"latitude",
"longitude",
"x",
"y", # compatibility with satpy/gpm_geo i.e.
["along_track", "cross_track"],
["lat", "lon"], # choose whether to use instead latitude/longitude
["latitude", "longitude"],
["x", "y"], # compatibility with satpy/gpm_geo i.e.
]
VERTICAL_DIMS = ["range", "nBnEnv", "height"]
FREQUENCY_DIMS = ["radar_frequency", "pmw_frequency"]
Expand Down
86 changes: 69 additions & 17 deletions gpm/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from gpm.utils import geospatial
import posixpath as pxp
import ntpath as ntp
import numpy as np
from pytest_mock import MockerFixture
import xarray as xr
from gpm.tests.utils.fake_datasets import get_orbit_dataarray, get_grid_dataarray
Expand Down Expand Up @@ -452,6 +453,9 @@ def prevent_pyplot_show(
mocker.patch("matplotlib.pyplot.show")


#### Orbit Data Array


@pytest.fixture(scope="function")
def orbit_dataarray() -> xr.DataArray:
"""Create orbit data array near 0 longitude and latitude"""
Expand Down Expand Up @@ -497,7 +501,23 @@ def orbit_pole_dataarray() -> xr.DataArray:
)


#### Orbit Dataarray with NaN values
@pytest.fixture(scope="function")
def orbit_spatial_3d_dataarray(orbit_dataarray: xr.DataArray) -> xr.DataArray:
"""Return a 3D orbit data array"""

# Add a vertical dimension with shape larger than 1 to prevent squeezing
return orbit_dataarray.expand_dims(dim={"height": 2})


@pytest.fixture
def orbit_transect_dataarray(orbit_dataarray: xr.DataArray) -> xr.DataArray:
"""Return a transect orbit data array"""

orbit_dataarray = orbit_dataarray.expand_dims(dim={"height": 2})
return orbit_dataarray.isel(along_track=0)


#### Orbit Data Array with NaN values


@pytest.fixture(scope="function")
Expand Down Expand Up @@ -526,7 +546,7 @@ def orbit_data_nan_along_track_dataarray(orbit_dataarray) -> xr.DataArray:
return orbit_dataarray


#### Orbit Dataarray with NaN coordinates
#### Orbit Data Array with NaN coordinates


@pytest.fixture(scope="function")
Expand Down Expand Up @@ -573,7 +593,7 @@ def orbit_nan_inner_cross_track_dataarray(orbit_dataarray) -> xr.DataArray:
return orbit_dataarray


#### Grid DataArray
#### Grid Data Array


@pytest.fixture(scope="function")
Expand All @@ -590,20 +610,6 @@ def grid_dataarray() -> xr.DataArray:
)


@pytest.fixture(scope="function")
def grid_antimeridian_dataarray() -> xr.DataArray:
"""Create grid data array going over the antimeridian"""

return get_grid_dataarray(
start_lon=160,
start_lat=-5,
end_lon=-170,
end_lat=15,
n_lon=20,
n_lat=15,
)


@pytest.fixture(scope="function")
def grid_nan_lon_dataarray(grid_dataarray) -> xr.DataArray:
"""Create grid data array near 0 longitude and latitude with some NaN longitudes"""
Expand All @@ -616,3 +622,49 @@ def grid_nan_lon_dataarray(grid_dataarray) -> xr.DataArray:
grid_dataarray["lon"] = lon

return grid_dataarray


@pytest.fixture(scope="function")
def grid_spatial_3d_dataarray(grid_dataarray: xr.DataArray) -> xr.DataArray:
"""Return a 3D grid data array"""

# Add a vertical dimension with shape larger than 1 to prevent squeezing
return grid_dataarray.expand_dims(dim={"height": 2})


@pytest.fixture
def grid_transect_dataarray(grid_dataarray: xr.DataArray) -> xr.DataArray:
"""Return a transect grid data array"""

grid_dataarray = grid_dataarray.expand_dims(dim={"height": 2})
return grid_dataarray.isel(lat=0)


#### Datasets


@pytest.fixture
def dataset_collection(
orbit_dataarray: xr.DataArray,
grid_dataarray: xr.DataArray,
orbit_spatial_3d_dataarray: xr.DataArray,
grid_spatial_3d_dataarray: xr.DataArray,
orbit_transect_dataarray: xr.DataArray,
grid_transect_dataarray: xr.DataArray,
) -> xr.Dataset:
"""Return a dataset with a variety of data arrays"""

da_frequency = xr.DataArray(np.zeros((0, 0)), dims=["other", "radar_frequency"])

return xr.Dataset(
{
"variable_0": orbit_dataarray,
"variable_1": grid_dataarray,
"variable_2": orbit_spatial_3d_dataarray,
"variable_3": grid_spatial_3d_dataarray,
"variable_4": orbit_transect_dataarray,
"variable_5": grid_transect_dataarray,
"variable_6": da_frequency,
"variable_7": xr.DataArray(),
}
)
Loading

0 comments on commit ff5e0e6

Please sign in to comment.