Skip to content

Commit

Permalink
fix: Ml nodata (#310)
Browse files Browse the repository at this point in the history
* simplify mosaic procedure

* adding mosaic and save unit tests

* fix path tifs

* resolve ml install issues

* handle 0s in X[targ_name]

* add todo

* reverting to main

* pin scikit fix

* add tests

* throw error if stacked by time
resolves #311

* dropping test - meaningless output

---------

Co-authored-by: jgrss <jbgraesser@gmail.com>
  • Loading branch information
mmann1123 and jgrss committed May 1, 2024
1 parent a4fa5a6 commit 3cedad3
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 30 deletions.
5 changes: 3 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,10 @@ coreg = earthpy
zarr = zarr
numcodecs
ml = dask-ml>=2022.5.27
scikit-learn>=0.23.0,<=1.2.0
scikit-learn==1.2.0
lightgbm
sklearn-xarray@git+https://github.com/jgrss/sklearn-xarray.git
sklearn-xarray@git+https://github.com/mmann1123/sklearn-xarray.git
numpy_groupies
perf = rtree
pygeos
netCDF4
Expand Down
3 changes: 2 additions & 1 deletion src/geowombat/backends/xarray_.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,7 @@ def reduce_func(
xr.where(left != tmp_nodata, left, right),
)


# Open all the data pointers
data_arrays = [
open_rasterio(
Expand Down Expand Up @@ -547,9 +548,9 @@ def reduce_func(
attrs.update(tags)
darray = darray.assign_attrs(**attrs)


if dtype is not None:
attrs = darray.attrs.copy()

return darray.astype(dtype).assign_attrs(**attrs)

else:
Expand Down
10 changes: 9 additions & 1 deletion src/geowombat/ml/classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,11 @@ def _prepare_predictors(self, data, targ_name):

# drop nans
try:
# prep target axis
Xna = X[~X[targ_name].isnull()]
Xna = X[X[targ_name] != X.gw.nodataval]
Xna = X[X[targ_name] != 0] # Xtarg is being generated with meaningless 0s
# TODO: if X.gw.nodataval is not None:
# Xna = X[X!= X.gw.nodata ] # changes here would have to be reflected in y as well
except KeyError:
Xna = X

Expand Down Expand Up @@ -250,6 +253,11 @@ def fit(
>>> with gw.open(l8_224078_20200518) as src:
>>> X, Xy, clf = fit(src, cl)
"""
if data.gw.has_time_coord:
# throw error
raise ValueError(
"DataArray must not have a time coordinate. Use stack_dim='band' with gw.open() or use .isel(time=0) to select a single time slice."
)
if clf._estimator_type == "clusterer":
data = self._add_time_dim(data)
X, Xna = self._prepare_predictors(data, targ_name)
Expand Down
78 changes: 52 additions & 26 deletions tests/ml_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,39 @@ def test_tree_predict(self):
self.assertTrue(np.all(np.isnan(y2.values[0, 0:5, 0])))
self.assertTrue(
np.allclose(
y1.values[0, -5:-1, 0],
y2.values[0, -5:-1, 0],
y1.values,
y2.values,
equal_nan=True,
)
)

# Checks that fit_predict with tree_pipeline yields the same predictions
# whether the image is opened with an explicit nodata=10 override or with
# the default gw.open() arguments.
def test_tree_predict_nodata(self):
with gw.config.update(
ref_res=300,
):
# assigning invalid nodata value
# (first run: the raster is opened with nodata=10 passed to gw.open)
with gw.open(l8_224078_20200518, nodata=10) as src:
with warnings.catch_warnings():
# Silence deprecation/future/user warnings raised while fitting.
warnings.simplefilter(
"ignore",
(DeprecationWarning, FutureWarning, UserWarning),
)
y1 = fit_predict(src, tree_pipeline, aoi_poly, col="lc")
# Second run: same raster, no nodata argument given to gw.open.
with gw.open(l8_224078_20200518) as src:
with warnings.catch_warnings():
warnings.simplefilter(
"ignore",
(DeprecationWarning, FutureWarning, UserWarning),
)
y2 = fit_predict(src, tree_pipeline, aoi_poly, col="lc")

# Both prediction arrays must agree element-wise; NaNs in matching
# positions are treated as equal (equal_nan=True).
self.assertTrue(
np.allclose(
y1.values,
y2.values,
equal_nan=True,
)
)
def test_output_type_attri(self):

with gw.config.update(
Expand All @@ -135,7 +162,6 @@ def test_output_type_attri(self):

self.assertTrue(isinstance(y1, xr_da))
self.assertTrue(isinstance(y2, xr_da))
# self.assertTrue(isinstance(y1.chunks, tuple))
self.assertTrue(len(y1.attrs) > 0)
self.assertTrue(len(y2.attrs) > 0)

Expand All @@ -156,29 +182,6 @@ def test_fitpredict_eq_fit_predict_point(self):

self.assertTrue(np.allclose(y1.values, y2.values, equal_nan=True))

# Opens the same image twice stacked along a "time" dimension, runs
# fit_predict with pl_wo_feat on aoi_point, and asserts that the
# predictions selected at time=1 and time=2 are identical.
def test_fitpredict_time_point(self):

with gw.config.update(
ref_res=300,
):
# The same file is listed twice, so both time slices hold the same data.
with gw.open(
[l8_224078_20200518, l8_224078_20200518], stack_dim="time"
) as src:
with warnings.catch_warnings():
# Silence deprecation/future/user warnings raised while fitting.
warnings.simplefilter(
"ignore",
(DeprecationWarning, FutureWarning, UserWarning),
)
y1 = fit_predict(
src,
pl_wo_feat,
aoi_point,
col="lc",
mask_nodataval=False,
)

# Predictions for the two time labels are compared element-wise.
self.assertTrue(np.all(y1.sel(time=1).values == y1.sel(time=2).values))

def test_fitpredict_eq_fit_predict_cluster(self):

with gw.config.update(
Expand Down Expand Up @@ -249,6 +252,29 @@ def test_classes_match_prediction_b(self):
)
)

# def test_fitpredict_time_point(self):

# with gw.config.update(
# ref_res=300,
# ):
# with gw.open(
# [l8_224078_20200518, l8_224078_20200518], stack_dim="time"
# ) as src:
# with warnings.catch_warnings():
# warnings.simplefilter(
# "ignore",
# (DeprecationWarning, FutureWarning, UserWarning),
# )
# y1 = fit_predict(
# src,
# pl_wo_feat,
# aoi_point,
# col="lc",
# mask_nodataval=False,
# )

# self.assertTrue(np.all(y1.sel(time=1).values == y1.sel(time=2).values))

# def test_nodataval_replace(self):

# with gw.config.update(ref_res=300):
Expand Down

0 comments on commit 3cedad3

Please sign in to comment.