Skip to content

Commit

Permalink
perf: avoid double computation in Sanitizer
Browse files Browse the repository at this point in the history
Defer computation of valid samples to `transform`. (part of #109)
  • Loading branch information
nicrie committed Nov 8, 2023
1 parent 1f38a5b commit 0c42251
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions xeofs/preprocessing/sanitizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,36 +67,39 @@ def fit(
# sample dimensions, certain grid cells may be masked (e.g., due to
# ocean areas). To ensure correct reconstruction of scores,
# we need to identify the sample positions of NaNs in the fitted
# dataset. Keep in mind that when transforming new data,
# we have to recheck for valid samples, as the new dataset may have
# different samples.
X_valid = X.sel({self.feature_name: self.is_valid_feature})
self.is_valid_sample = ~X_valid.isnull().all(self.feature_name).compute()
# dataset. Since we have to recheck valid sample locations for each new data
# that is transformed, we defer this check to the transform method to
# avoid unnecessary computation.

return self

def transform(self, X: DataArray) -> DataArray:
# Check if input is a DataArray
self._check_input_type(X)

# Check if input has the correct dimensions
self._check_input_dims(X)

# Check if input has the correct coordinates
self._check_input_coords(X)

# Store sample coordinates for inverse transform
self.sample_coords_transform = X.coords[self.sample_name]

# Remove NaN entries; only consider full-dimensional NaNs
# We already know valid features from the fitted dataset
X = X.isel({self.feature_name: self.is_valid_feature})
# However, we need to recheck for valid samples, as the new dataset may
# Check for valid samples, as the new dataset may
# have different samples
is_valid_sample = ~X.isnull().all(self.feature_name).compute()
X = X.isel({self.sample_name: is_valid_sample})

if not hasattr(self, "is_valid_sample"):
# For first transform, store valid sample locations so that we can
# always reconstruct the fitted scores.
self.is_valid_sample = is_valid_sample
# Store valid sample locations for inverse transform
self.is_valid_sample_transform = is_valid_sample

X = X.isel({self.sample_name: is_valid_sample})

return X

def inverse_transform_data(self, X: DataArray) -> DataArray:
Expand Down

0 comments on commit 0c42251

Please sign in to comment.