Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug: Ml nodata #310

Merged
merged 13 commits into from
May 1, 2024
5 changes: 3 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,10 @@ coreg = earthpy
zarr = zarr
numcodecs
ml = dask-ml>=2022.5.27
scikit-learn>=0.23.0,<=1.2.0
scikit-learn==1.2.0
lightgbm
sklearn-xarray@git+https://github.com/jgrss/sklearn-xarray.git
sklearn-xarray@git+https://github.com/mmann1123/sklearn-xarray.git
numpy_groupies
perf = rtree
pygeos
netCDF4
Expand Down
3 changes: 2 additions & 1 deletion src/geowombat/backends/xarray_.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,7 @@ def reduce_func(
xr.where(left != tmp_nodata, left, right),
)


# Open all the data pointers
data_arrays = [
open_rasterio(
Expand Down Expand Up @@ -547,9 +548,9 @@ def reduce_func(
attrs.update(tags)
darray = darray.assign_attrs(**attrs)


if dtype is not None:
attrs = darray.attrs.copy()

return darray.astype(dtype).assign_attrs(**attrs)

else:
Expand Down
10 changes: 9 additions & 1 deletion src/geowombat/ml/classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,11 @@ def _prepare_predictors(self, data, targ_name):

# drop nans
try:
# prep target axis
Xna = X[~X[targ_name].isnull()]
Xna = X[X[targ_name] != X.gw.nodataval]
Xna = X[X[targ_name] != 0] # Xtarg is being generated with meaningless 0s
# TODO: if X.gw.nodataval is not None:
# Xna = X[X!= X.gw.nodata ] # changes here would have to be reflected in y as well
except KeyError:
Xna = X

Expand Down Expand Up @@ -250,6 +253,11 @@ def fit(
>>> with gw.open(l8_224078_20200518) as src:
>>> X, Xy, clf = fit(src, cl)
"""
if data.gw.has_time_coord:
# throw error
raise ValueError(
"DataArray must not have a time coordinate. Use stack_dim='band' with gw.open() or use .isel(time=0) to select a single time slice."
)
if clf._estimator_type == "clusterer":
data = self._add_time_dim(data)
X, Xna = self._prepare_predictors(data, targ_name)
Expand Down
78 changes: 52 additions & 26 deletions tests/ml_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,39 @@ def test_tree_predict(self):
self.assertTrue(np.all(np.isnan(y2.values[0, 0:5, 0])))
self.assertTrue(
np.allclose(
y1.values[0, -5:-1, 0],
y2.values[0, -5:-1, 0],
y1.values,
y2.values,
equal_nan=True,
)
)

def test_tree_predict_nodata(self):
with gw.config.update(
ref_res=300,
):
# assigning invalid nodata value
with gw.open(l8_224078_20200518, nodata=10) as src:
with warnings.catch_warnings():
warnings.simplefilter(
"ignore",
(DeprecationWarning, FutureWarning, UserWarning),
)
y1 = fit_predict(src, tree_pipeline, aoi_poly, col="lc")
with gw.open(l8_224078_20200518) as src:
with warnings.catch_warnings():
warnings.simplefilter(
"ignore",
(DeprecationWarning, FutureWarning, UserWarning),
)
y2 = fit_predict(src, tree_pipeline, aoi_poly, col="lc")

self.assertTrue(
np.allclose(
y1.values,
y2.values,
equal_nan=True,
)
)
def test_output_type_attri(self):

with gw.config.update(
Expand All @@ -135,7 +162,6 @@ def test_output_type_attri(self):

self.assertTrue(isinstance(y1, xr_da))
self.assertTrue(isinstance(y2, xr_da))
# self.assertTrue(isinstance(y1.chunks, tuple))
self.assertTrue(len(y1.attrs) > 0)
self.assertTrue(len(y2.attrs) > 0)

Expand All @@ -156,29 +182,6 @@ def test_fitpredict_eq_fit_predict_point(self):

self.assertTrue(np.allclose(y1.values, y2.values, equal_nan=True))

def test_fitpredict_time_point(self):

with gw.config.update(
ref_res=300,
):
with gw.open(
[l8_224078_20200518, l8_224078_20200518], stack_dim="time"
) as src:
with warnings.catch_warnings():
warnings.simplefilter(
"ignore",
(DeprecationWarning, FutureWarning, UserWarning),
)
y1 = fit_predict(
src,
pl_wo_feat,
aoi_point,
col="lc",
mask_nodataval=False,
)

self.assertTrue(np.all(y1.sel(time=1).values == y1.sel(time=2).values))

def test_fitpredict_eq_fit_predict_cluster(self):

with gw.config.update(
Expand Down Expand Up @@ -249,6 +252,29 @@ def test_classes_match_prediction_b(self):
)
)

# def test_fitpredict_time_point(self):

# with gw.config.update(
# ref_res=300,
# ):
# with gw.open(
# [l8_224078_20200518, l8_224078_20200518], stack_dim="time"
# ) as src:
# with warnings.catch_warnings():
# warnings.simplefilter(
# "ignore",
# (DeprecationWarning, FutureWarning, UserWarning),
# )
# y1 = fit_predict(
# src,
# pl_wo_feat,
# aoi_point,
# col="lc",
# mask_nodataval=False,
# )

# self.assertTrue(np.all(y1.sel(time=1).values == y1.sel(time=2).values))

# def test_nodataval_replace(self):

# with gw.config.update(ref_res=300):
Expand Down
Loading