Skip to content

Commit

Permalink
add spatial and temporal scale warnings #43 #44
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyangkang committed May 3, 2024
1 parent 69e96ec commit 8ee8e7b
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 4 deletions.
28 changes: 26 additions & 2 deletions stemflow/model/AdaSTEM.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,11 @@
check_prediciton_aggregation,
check_prediction_return,
check_random_state,
check_spatial_scale,
check_spatio_bin_jitter_magnitude,
check_task,
check_temporal_bin_start_jitter,
check_temporal_scale,
check_transform_njobs,
check_transform_spatio_bin_jitter_magnitude,
check_verbosity,
Expand Down Expand Up @@ -288,10 +290,32 @@ def split(
save_path = os.path.join(self.save_dir, "ensemble_quadtree_df.csv") if self.save_tmp else ""

if "grid_len" not in self.__dir__():
# We har using AdaSTEM
# We are using AdaSTEM
self.grid_len = None
check_spatial_scale(
X_train[self.Spatio1].min(),
X_train[self.Spatio1].max(),
X_train[self.Spatio2].min(),
X_train[self.Spatio2].max(),
self.grid_len_upper_threshold,
self.grid_len_lower_threshold,
)
# Compare the bin interval against the real temporal extent of the training
# data: the second argument must be the maximum, not the minimum again.
check_temporal_scale(
    X_train[self.Temporal1].min(), X_train[self.Temporal1].max(), self.temporal_bin_interval
)
else:
# We har using STEM
# We are using STEM
check_spatial_scale(
X_train[self.Spatio1].min(),
X_train[self.Spatio1].max(),
X_train[self.Spatio2].min(),
X_train[self.Spatio2].max(),
self.grid_len,
self.grid_len,
)
# Compare the bin interval against the real temporal extent of the training
# data: the second argument must be the maximum, not the minimum again.
check_temporal_scale(
    X_train[self.Temporal1].min(), X_train[self.Temporal1].max(), self.temporal_bin_interval
)
pass

spatio_bin_jitter_magnitude = check_transform_spatio_bin_jitter_magnitude(
Expand Down
28 changes: 26 additions & 2 deletions stemflow/model/SphereAdaSTEM.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@
from ..utils.validation import (
check_base_model,
check_prediciton_aggregation,
check_spatial_scale,
check_spatio_bin_jitter_magnitude,
check_task,
check_temporal_bin_start_jitter,
check_temporal_scale,
check_transform_njobs,
check_verbosity,
)
Expand Down Expand Up @@ -244,10 +246,32 @@ def split(
save_path = os.path.join(self.save_dir, "ensemble_quadtree_df.csv") if self.save_tmp else ""

if "grid_len" not in self.__dir__():
# We har using AdaSTEM
# We are using AdaSTEM
self.grid_len = None
check_spatial_scale(
X_train[self.Spatio1].min(),
X_train[self.Spatio1].max(),
X_train[self.Spatio2].min(),
X_train[self.Spatio2].max(),
self.grid_len_upper_threshold,
self.grid_len_lower_threshold,
)
# Compare the bin interval against the real temporal extent of the training
# data: the second argument must be the maximum, not the minimum again.
check_temporal_scale(
    X_train[self.Temporal1].min(), X_train[self.Temporal1].max(), self.temporal_bin_interval
)
else:
# We har using STEM
# We are using STEM
check_spatial_scale(
X_train[self.Spatio1].min(),
X_train[self.Spatio1].max(),
X_train[self.Spatio2].min(),
X_train[self.Spatio2].max(),
self.grid_len,
self.grid_len,
)
# Compare the bin interval against the real temporal extent of the training
# data: the second argument must be the maximum, not the minimum again.
check_temporal_scale(
    X_train[self.Temporal1].min(), X_train[self.Temporal1].max(), self.temporal_bin_interval
)
pass

partial_get_one_ensemble_sphere_quadtree = partial(
Expand Down
30 changes: 30 additions & 0 deletions stemflow/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,33 @@ def check_X_y_shape_match(X, y):
X_size = X.shape[0]
if not y_size == X_size:
raise ValueError(f"The shape of X and y should match. Got X: {X_size}, y: {y_size}")


def check_spatial_scale(x_min, x_max, y_min, y_max, grid_length_upper, grid_length_lower):
    """Warn when a grid length threshold is far out of scale with the data extent.

    Each threshold is compared against the x extent (``x_max - x_min``) and the
    y extent (``y_max - y_min``). A warning is issued when the threshold is at
    most 1/100 of either extent ("significantly smaller") or at least 100 times
    either extent ("significantly larger"). Nothing is raised and nothing is
    returned; the function only emits warnings.

    Args:
        x_min: Minimum of the first spatial coordinate.
        x_max: Maximum of the first spatial coordinate.
        y_min: Minimum of the second spatial coordinate.
        y_max: Maximum of the second spatial coordinate.
        grid_length_upper: Upper grid length threshold to validate.
        grid_length_lower: Lower grid length threshold to validate.
    """
    # Local import keeps the function self-contained regardless of what the
    # module imports at the top of the file.
    import warnings

    x_span = x_max - x_min
    y_span = y_max - y_min

    thresholds = (
        ("grid_len_upper_threshold", grid_length_upper),
        ("grid_len_lower_threshold", grid_length_lower),
    )
    for label, grid_length in thresholds:
        if grid_length <= x_span / 100 or grid_length <= y_span / 100:
            # `warnings` is a module; the callable is `warnings.warn`.
            warnings.warn(
                f"The {label} is significantly smaller than the scale of longitude and latitude (x and y). Be sure if this is desired.",
                stacklevel=2,
            )
        if grid_length >= x_span * 100 or grid_length >= y_span * 100:
            warnings.warn(
                f"The {label} is significantly larger than the scale of longitude and latitude (x and y). Be sure if this is desired.",
                stacklevel=2,
            )


def check_temporal_scale(t_min, t_max, temporal_bin_interval):
    """Warn when the temporal bin interval is far out of scale with the data.

    The interval is compared against the temporal extent ``t_max - t_min``. A
    warning is issued when the interval is at most 1/100 of the extent
    ("significantly smaller") or at least 100 times the extent
    ("significantly larger"). Nothing is raised and nothing is returned.

    Args:
        t_min: Minimum of the temporal variable in the provided data.
        t_max: Maximum of the temporal variable in the provided data.
        temporal_bin_interval: Temporal bin interval to validate.
    """
    # Local import keeps the function self-contained regardless of what the
    # module imports at the top of the file.
    import warnings

    t_span = t_max - t_min

    if temporal_bin_interval <= t_span / 100:
        # `warnings` is a module; the callable is `warnings.warn`.
        warnings.warn(
            "The temporal_bin_interval is significantly smaller than the scale of temporal parameters in provided data. Be sure if this is desired.",
            stacklevel=2,
        )
    if temporal_bin_interval >= t_span * 100:
        warnings.warn(
            "The temporal_bin_interval is significantly larger than the scale of temporal parameters in provided data. Be sure if this is desired.",
            stacklevel=2,
        )

0 comments on commit 8ee8e7b

Please sign in to comment.