Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set dims to longitude, latitude if projection passed #328

Open
wants to merge 7 commits into
base: 2.0-dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions verde/base/base_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,9 @@ class BaseGridder(BaseEstimator):
# (pd.DataFrame, xr.Dataset, etc)
dims = ("northing", "easting")

# The default dimension names for generated outputs if projection is passed
unproj_dims = ("latitude", "longitude")

# The default name for any extra coordinates given to methods below
# through the `extra_coords` keyword argument. Coordinates are
# included in the outputs (pandas.DataFrame or xarray.Dataset)
Expand Down Expand Up @@ -482,7 +485,7 @@ def grid(
self.predict(project_coordinates(coordinates, projection))
)
# Get names for dims, data and any extra coordinates
dims = self._get_dims(dims)
dims = self._get_dims(dims, projection=projection)
data_names = self._get_data_names(data, data_names)
extra_coords_names = self._get_extra_coords_names(coordinates)
# Create xarray.Dataset
Expand Down Expand Up @@ -563,7 +566,7 @@ def scatter(
The interpolated values on a random set of points.

"""
dims = self._get_dims(dims)
dims = self._get_dims(dims, projection=projection)
region = get_instance_region(self, region)
coordinates = scatter_points(region, size, random_state=random_state, **kwargs)
if projection is None:
Expand Down Expand Up @@ -669,7 +672,7 @@ def profile(
The interpolated values along the profile.

"""
dims = self._get_dims(dims)
dims = self._get_dims(dims, projection=projection)
# Project the input points to generate the profile in Cartesian
# coordinates (the distance calculation doesn't make sense in
# geographic coordinates since we don't do actual distances on a
Expand All @@ -694,12 +697,14 @@ def profile(
columns.extend(zip(data_names, data))
return pd.DataFrame(dict(columns), columns=[c[0] for c in columns])

def _get_dims(self, dims):
def _get_dims(self, dims, projection=None):
"""
Get default dimension names.
"""
if dims is not None:
return dims
if projection is not None:
return self.unproj_dims
return self.dims

def _get_extra_coords_names(self, coordinates):
Expand Down
45 changes: 31 additions & 14 deletions verde/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,28 @@ def test_check_coordinates():

def test_get_dims():
"Tests that get_dims returns the expected results"
# Define a dummy projection function
def proj(lon, lat, inverse=False):
return None

gridder = BaseGridder()
assert gridder._get_dims(dims=None) == ("northing", "easting")
assert gridder._get_dims(dims=("john", "paul")) == ("john", "paul")
assert gridder._get_dims(dims=None, projection=None) == ("northing", "easting")
assert gridder._get_dims(dims=("john", "paul"), projection=None) == ("john", "paul")
gridder.dims = ("latitude", "longitude")
assert gridder._get_dims(dims=None) == ("latitude", "longitude")
assert gridder._get_dims(dims=None, projection=None) == ("latitude", "longitude")
# Test with a projection
gridder = BaseGridder()
assert gridder._get_dims(dims=None, projection=proj) == ("latitude", "longitude")
assert gridder._get_dims(dims=("john", "paul"), projection=proj) == (
"john",
"paul",
)
gridder.unproj_dims = ("latitude_1", "longitude_1")
assert gridder._get_dims(dims=None, projection=proj) == (
"latitude_1",
"longitude_1",
)


def test_get_data_names():
Expand Down Expand Up @@ -214,14 +231,14 @@ def proj(lon, lat, inverse=False):
# Check the scatter
scat = grd.scatter(region, 1000, random_state=0, projection=proj)
npt.assert_allclose(scat.scalars, data)
npt.assert_allclose(scat.easting, coordinates[0])
npt.assert_allclose(scat.northing, coordinates[1])
npt.assert_allclose(scat.longitude, coordinates[0])
npt.assert_allclose(scat.latitude, coordinates[1])

# Check the grid
grid = grd.grid(region=region, shape=shape, projection=proj)
npt.assert_allclose(grid.scalars.values, data_true)
npt.assert_allclose(grid.easting.values, coordinates_true[0][0, :])
npt.assert_allclose(grid.northing.values, coordinates_true[1][:, 0])
npt.assert_allclose(grid.longitude.values, coordinates_true[0][0, :])
npt.assert_allclose(grid.latitude.values, coordinates_true[1][:, 0])

# Check the profile
prof = grd.profile(
Expand All @@ -230,8 +247,8 @@ def proj(lon, lat, inverse=False):
npt.assert_allclose(prof.scalars, data_true[-1, :])
# Coordinates should still be evenly spaced since the projection is a
# multiplication.
npt.assert_allclose(prof.easting, coordinates_true[0][0, :])
npt.assert_allclose(prof.northing, coordinates_true[1][-1, :])
npt.assert_allclose(prof.longitude, coordinates_true[0][0, :])
npt.assert_allclose(prof.latitude, coordinates_true[1][-1, :])
# Distance should still be in the projected coordinates. If the projection
# is from geographic, we shouldn't be returning distances in degrees but in
# projected meters. The distances will be evenly spaced in unprojected
Expand Down Expand Up @@ -339,14 +356,14 @@ def proj(lon, lat, inverse=False):
region, 1000, random_state=0, projection=proj, extra_coords=(13, 17)
)
npt.assert_allclose(scat.scalars, data)
npt.assert_allclose(scat.easting, coordinates[0])
npt.assert_allclose(scat.northing, coordinates[1])
npt.assert_allclose(scat.longitude, coordinates[0])
npt.assert_allclose(scat.latitude, coordinates[1])

# Check the grid
grid = grd.grid(region=region, shape=shape, projection=proj, extra_coords=(13, 17))
npt.assert_allclose(grid.scalars.values, data_true)
npt.assert_allclose(grid.easting.values, coordinates_true[0][0, :])
npt.assert_allclose(grid.northing.values, coordinates_true[1][:, 0])
npt.assert_allclose(grid.longitude.values, coordinates_true[0][0, :])
npt.assert_allclose(grid.latitude.values, coordinates_true[1][:, 0])

# Check the profile
prof = grd.profile(
Expand All @@ -359,8 +376,8 @@ def proj(lon, lat, inverse=False):
npt.assert_allclose(prof.scalars, data_true[-1, :])
# Coordinates should still be evenly spaced since the projection is a
# multiplication.
npt.assert_allclose(prof.easting, coordinates_true[0][0, :])
npt.assert_allclose(prof.northing, coordinates_true[1][-1, :])
npt.assert_allclose(prof.longitude, coordinates_true[0][0, :])
npt.assert_allclose(prof.latitude, coordinates_true[1][-1, :])
# Distance should still be in the projected coordinates. If the projection
# is from geographic, we shouldn't be returning distances in degrees but in
# projected meters. The distances will be evenly spaced in unprojected
Expand Down