From 4d093c0d76cd487611cc04af200dd9ca3ad47ad8 Mon Sep 17 00:00:00 2001 From: Joachim Moeyens Date: Tue, 30 Jan 2024 12:04:06 -0800 Subject: [PATCH] Minor improvements to orbit determination types (#149) * Add calculate_max_outliers * Use adam_core's OrbitDeterminationObservations * Change differential correction, merge and extend to accept a Propagator type * Change to FittedOrbits and FittedOrbitMembers to match adam_core's class definitions --- thor/main.py | 4 +- thor/orbit_determination/__init__.py | 8 ++- thor/orbit_determination/fitted_orbits.py | 6 +- thor/orbit_determination/outliers.py | 37 ++++++++++++ thor/orbit_determination/tests/__init__.py | 0 .../tests/test_fitted_orbits.py | 4 +- .../tests/test_outliers.py | 49 +++++++++++++++ thor/orbits/attribution.py | 21 +++---- thor/orbits/iod.py | 38 ++++++++---- thor/orbits/od.py | 60 ++++++++++--------- 10 files changed, 169 insertions(+), 58 deletions(-) create mode 100644 thor/orbit_determination/outliers.py create mode 100644 thor/orbit_determination/tests/__init__.py create mode 100644 thor/orbit_determination/tests/test_outliers.py diff --git a/thor/main.py b/thor/main.py index 81301f64..73984457 100644 --- a/thor/main.py +++ b/thor/main.py @@ -359,7 +359,7 @@ def link_test_orbit( rchi2_threshold=config.od_rchi2_threshold, delta=config.od_delta, max_iter=config.od_max_iter, - propagator=config.propagator, + propagator=propagator, propagator_kwargs={}, chunk_size=config.od_chunk_size, max_processes=config.max_processes, @@ -427,7 +427,7 @@ def link_test_orbit( radius=config.arc_extension_radius, delta=config.od_delta, max_iter=config.od_max_iter, - propagator=config.propagator, + propagator=propagator, propagator_kwargs={}, orbits_chunk_size=config.arc_extension_chunk_size, max_processes=config.max_processes, diff --git a/thor/orbit_determination/__init__.py b/thor/orbit_determination/__init__.py index 0e9f032d..bfb3c192 100644 --- a/thor/orbit_determination/__init__.py +++ b/thor/orbit_determination/__init__.py @@ -1 +1,7 @@ -from .fitted_orbits import * +# noqa: F401 +from .fitted_orbits import ( + FittedOrbits, + FittedOrbitMembers, + assign_duplicate_observations, +) +from .outliers import calculate_max_outliers diff --git a/thor/orbit_determination/fitted_orbits.py b/thor/orbit_determination/fitted_orbits.py index 22af251b..888af7bb 100644 --- a/thor/orbit_determination/fitted_orbits.py +++ b/thor/orbit_determination/fitted_orbits.py @@ -134,6 +134,8 @@ def drop_duplicate_orbits( return filtered, filtered_orbit_members +# FittedOrbits and FittedOrbit members currently match +# the schema of adam_core.orbit_determination's FittedOrbits and FittedOrbitMembers class FittedOrbits(qv.Table): orbit_id = qv.LargeStringColumn(default=lambda: uuid.uuid4().hex) @@ -143,7 +145,9 @@ class FittedOrbits(qv.Table): num_obs = qv.Int64Column() chi2 = qv.Float64Column() reduced_chi2 = qv.Float64Column() - improved = qv.BooleanColumn(nullable=True) + iterations = qv.Int64Column(nullable=True) + success = qv.BooleanColumn(nullable=True) + status_code = qv.Int64Column(nullable=True) def to_orbits(self) -> Orbits: """ diff --git a/thor/orbit_determination/outliers.py b/thor/orbit_determination/outliers.py new file mode 100644 index 00000000..dd8b7564 --- /dev/null +++ b/thor/orbit_determination/outliers.py @@ -0,0 +1,37 @@ +import numpy as np + + +def calculate_max_outliers( + num_obs: int, min_obs: int, contamination_percentage: float +) -> int: + """ + Calculate the maximum number of allowable outliers. Linkages may contain err + oneuos observations that need to be removed. This function calculates the maximum number of + observations that can be removed before the linkage no longer has the minimum number + of observations required. The contamination percentage is the maximum percentage of observations + that allowed to be erroneous. + + Parameters + ---------- + num_obs : int + Number of observations in the linkage. + min_obs : int + Minimum number of observations required for a valid linkage. + contamination_percentage : float + Maximum percentage of observations that allowed to be erroneous. Range is [0, 100]. + + Returns + ------- + outliers : int + Maximum number of allowable outliers. + """ + assert ( + num_obs >= min_obs + ), "Number of observations must be greater than or equal to the minimum number of observations." + assert ( + contamination_percentage >= 0 and contamination_percentage <= 100 + ), "Contamination percentage must be between 0 and 100." + + max_outliers = num_obs * (contamination_percentage / 100) + outliers = np.min([max_outliers, num_obs - min_obs]).astype(int) + return outliers diff --git a/thor/orbit_determination/tests/__init__.py b/thor/orbit_determination/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/thor/orbit_determination/tests/test_fitted_orbits.py b/thor/orbit_determination/tests/test_fitted_orbits.py index 0b52a63d..31c2313e 100644 --- a/thor/orbit_determination/tests/test_fitted_orbits.py +++ b/thor/orbit_determination/tests/test_fitted_orbits.py @@ -36,7 +36,9 @@ def simple_orbits(): num_obs=[10, 20, 15, 25, 5], # Specific values chi2=np.random.rand(5), reduced_chi2=[0.5, 0.4, 0.3, 0.2, 0.1], # Specific values - improved=pa.repeat(False, 5), + iterations=[100, 200, 300, 400, 500], # Specific values + success=[True, True, True, True, True], # Specific values + status_code=[1, 1, 1, 1, 1], # Specific values ) diff --git a/thor/orbit_determination/tests/test_outliers.py b/thor/orbit_determination/tests/test_outliers.py new file mode 100644 index 00000000..e10f8232 --- /dev/null +++ b/thor/orbit_determination/tests/test_outliers.py @@ -0,0 +1,49 @@ +import pytest + +from ..outliers import calculate_max_outliers + + +def test_calculate_max_outliers(): + # Test that the function returns the correct number of outliers given + # the number of observations, minimum number of observations, and + # contamination percentage in a few different cases. + min_obs = 3 + num_obs = 10 + contamination_percentage = 50 + outliers = calculate_max_outliers(num_obs, min_obs, contamination_percentage) + assert outliers == 5 + + min_obs = 6 + num_obs = 10 + contamination_percentage = 50 + outliers = calculate_max_outliers(num_obs, min_obs, contamination_percentage) + assert outliers == 4 + + min_obs = 6 + num_obs = 6 + contamination_percentage = 50 + outliers = calculate_max_outliers(num_obs, min_obs, contamination_percentage) + assert outliers == 0 + + +def test_calculate_max_outliers_assertions(): + # Test that the function raises an assertion error when the number of observations + # is less than the minimum number of observations. + min_obs = 10 + num_obs = 6 + contamination_percentage = 50 + with pytest.raises( + AssertionError, + match=r"Number of observations must be greater than or equal to the minimum number of observations.", + ): + outliers = calculate_max_outliers(num_obs, min_obs, contamination_percentage) + + # Test that the function raises an assertion error when the contamination percentage + # is less than 0. + min_obs = 10 + num_obs = 10 + contamination_percentage = -50 + with pytest.raises( + AssertionError, match=r"Contamination percentage must be between 0 and 1." + ): + outliers = calculate_max_outliers(num_obs, min_obs, contamination_percentage) diff --git a/thor/orbits/attribution.py b/thor/orbits/attribution.py index e3129485..2008dc9f 100644 --- a/thor/orbits/attribution.py +++ b/thor/orbits/attribution.py @@ -1,8 +1,7 @@ -import gc import logging import multiprocessing as mp import time -from typing import List, Literal, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Type, Union import numpy as np import numpy.typing as npt @@ -12,13 +11,13 @@ import ray from adam_core.coordinates.residuals import Residuals from adam_core.orbits import Orbits -from adam_core.propagator import PYOORB +from adam_core.propagator import PYOORB, Propagator from adam_core.propagator.utils import _iterate_chunk_indices, _iterate_chunks from adam_core.ray_cluster import initialize_use_ray from sklearn.neighbors import BallTree from ..observations.observations import Observations -from ..orbit_determination import ( +from ..orbit_determination.fitted_orbits import ( FittedOrbitMembers, FittedOrbits, assign_duplicate_observations, @@ -121,13 +120,11 @@ def attribution_worker( orbits: Union[Orbits, FittedOrbits], observations: Observations, radius: float = 1 / 3600, - propagator: Literal["PYOORB"] = "PYOORB", + propagator: Type[Propagator] = PYOORB, propagator_kwargs: dict = {}, ) -> Attributions: - if propagator == "PYOORB": - prop = PYOORB(**propagator_kwargs) - else: - raise ValueError(f"Invalid propagator '{propagator}'.") + # Initialize the propagator + prop = propagator(**propagator_kwargs) if isinstance(orbits, FittedOrbits): orbits = orbits.to_orbits() @@ -255,7 +252,7 @@ def attribute_observations( orbits: Union[Orbits, FittedOrbits, ray.ObjectRef], observations: Union[Observations, ray.ObjectRef], radius: float = 5 / 3600, - propagator: Literal["PYOORB"] = "PYOORB", + propagator: Type[Propagator] = PYOORB, propagator_kwargs: dict = {}, orbits_chunk_size: int = 10, observations_chunk_size: int = 100000, @@ -388,7 +385,7 @@ def merge_and_extend_orbits( max_iter: int = 20, method: Literal["central", "finite"] = "central", fit_epoch: bool = False, - propagator: Literal["PYOORB"] = "PYOORB", + propagator: Type[Propagator] = PYOORB, propagator_kwargs: dict = {}, orbits_chunk_size: int = 10, observations_chunk_size: int = 100000, @@ -541,7 +538,7 @@ def merge_and_extend_orbits( # Remove the orbits that were not improved from the pool of available orbits. Orbits that were not improved # are orbits that have already iterated to their best-fit solution given the observations available. These orbits # are unlikely to recover more observations in subsequent iterations and so can be saved for output. - not_improved_mask = pc.equal(orbits.improved, False) + not_improved_mask = pc.equal(orbits.success, False) orbits_out = orbits.apply_mask(not_improved_mask) orbit_members_out = orbit_members.apply_mask( pc.is_in(orbit_members.orbit_id, orbits_out.orbit_id) diff --git a/thor/orbits/iod.py b/thor/orbits/iod.py index c7f8d823..202514aa 100644 --- a/thor/orbits/iod.py +++ b/thor/orbits/iod.py @@ -11,6 +11,7 @@ import quivr as qv import ray from adam_core.coordinates.residuals import Residuals +from adam_core.orbit_determination import OrbitDeterminationObservations from adam_core.propagator import PYOORB, Propagator from adam_core.propagator.utils import _iterate_chunk_indices, _iterate_chunks from adam_core.ray_cluster import initialize_use_ray @@ -22,6 +23,7 @@ FittedOrbits, drop_duplicate_orbits, ) +from ..orbit_determination.outliers import calculate_max_outliers from ..utils.linkages import sort_by_id_and_time from .gauss import gaussIOD @@ -133,7 +135,6 @@ def iod_worker( propagator: Type[Propagator] = PYOORB, propagator_kwargs: dict = {}, ) -> Tuple[FittedOrbits, FittedOrbitMembers]: - prop = propagator(**propagator_kwargs) iod_orbits = FittedOrbits.empty() iod_orbit_members = FittedOrbitMembers.empty() @@ -148,6 +149,12 @@ def iod_worker( pc.is_in(observations.id, obs_ids) ) + observations_linkage = OrbitDeterminationObservations.from_kwargs( + id=observations_linkage.id, + coordinates=observations_linkage.coordinates, + observers=observations_linkage.get_observers().observers, + ) + iod_orbit, iod_orbit_orbit_members = iod( observations_linkage, min_obs=min_obs, @@ -157,7 +164,8 @@ def iod_worker( observation_selection_method=observation_selection_method, iterate=iterate, light_time=light_time, - propagator=prop, + propagator=propagator, + propagator_kwargs=propagator_kwargs, ) if len(iod_orbit) > 0: iod_orbit = iod_orbit.set_column("orbit_id", pa.array([linkage_id])) @@ -225,7 +233,7 @@ def iod_worker_remote( def iod( - observations: Observations, + observations: OrbitDeterminationObservations, min_obs: int = 6, min_arc_length: float = 1.0, contamination_percentage: float = 0.0, @@ -235,7 +243,8 @@ def iod( ] = "combinations", iterate: bool = False, light_time: bool = True, - propagator: Propagator = PYOORB(), + propagator: Type[Propagator] = PYOORB, + propagator_kwargs: dict = {}, ) -> Tuple[FittedOrbits, FittedOrbitMembers]: """ Run initial orbit determination on a set of observations believed to belong to a single @@ -313,13 +322,16 @@ def iod( "outlier" : Flag to indicate which observations are potential outliers (their chi2 is higher than the chi2 threshold) [float] """ + # Initialize the propagator + prop = propagator(**propagator_kwargs) + processable = True if len(observations) == 0: processable = False obs_ids_all = observations.id.to_numpy(zero_copy_only=False) coords_all = observations.coordinates - observers_with_states = observations.get_observers() + observers = observations.observers observations = observations.sort_by( [ @@ -328,7 +340,7 @@ def iod( "coordinates.origin.code", ] ) - observers = observers_with_states.observers.sort_by( + observers = observers.sort_by( ["coordinates.time.days", "coordinates.time.nanos", "coordinates.origin.code"] ) @@ -344,15 +356,15 @@ def iod( num_obs = len(observations) if num_obs < min_obs: processable = False - num_outliers = int(num_obs * contamination_percentage / 100.0) - num_outliers = np.maximum(np.minimum(num_obs - min_obs, num_outliers), 0) + + max_outliers = calculate_max_outliers(num_obs, min_obs, contamination_percentage) # Select observation IDs to use for IOD obs_ids = select_observations( observations, method=observation_selection_method, ) - obs_ids = obs_ids[: (3 * (num_outliers + 1))] + obs_ids = obs_ids[: (3 * (max_outliers + 1))] if len(obs_ids) == 0: processable = False @@ -386,7 +398,7 @@ def iod( continue # Propagate initial orbit to all observation times - ephemeris = propagator.generate_ephemeris( + ephemeris = prop.generate_ephemeris( iod_orbits, observers, chunk_size=1, max_processes=1 ) @@ -408,7 +420,7 @@ def iod( # The reduced chi2 is above the threshold and no outliers are # allowed, this cannot be improved by outlier rejection # so continue to the next IOD orbit - if rchi2 > rchi2_threshold and num_outliers == 0: + if rchi2 > rchi2_threshold and max_outliers == 0: # If we have iterated through all iod orbits and no outliers # are allowed for this linkage then no other combination of # observations will make it acceptable, so exit here. @@ -436,9 +448,9 @@ def iod( # anticipate that we get to this stage if the three selected observations # belonging to one object yield a good initial orbit but the presence of outlier # observations is skewing the sum total of the residuals and chi2 - elif num_outliers > 0: + elif max_outliers > 0: logger.debug("Attempting to identify possible outliers.") - for o in range(num_outliers): + for o in range(max_outliers): # Select i highest observations that contribute to # chi2 (and thereby the residuals) remove = chi2[~mask].argsort()[-(o + 1) :] diff --git a/thor/orbits/od.py b/thor/orbits/od.py index 28af3248..b109e86a 100644 --- a/thor/orbits/od.py +++ b/thor/orbits/od.py @@ -1,7 +1,7 @@ import logging import multiprocessing as mp import time -from typing import Literal, Optional, Tuple, Union +from typing import Literal, Optional, Tuple, Type, Union import numpy as np import numpy.typing as npt @@ -10,17 +10,19 @@ import ray from adam_core.coordinates import CartesianCoordinates, CoordinateCovariances from adam_core.coordinates.residuals import Residuals +from adam_core.orbit_determination import OrbitDeterminationObservations from adam_core.orbits import Orbits -from adam_core.propagator import PYOORB, _iterate_chunks +from adam_core.propagator import PYOORB, Propagator, _iterate_chunks from adam_core.propagator.utils import _iterate_chunk_indices from adam_core.ray_cluster import initialize_use_ray from scipy.linalg import solve from ..observations.observations import Observations -from ..orbit_determination import FittedOrbitMembers, FittedOrbits -from ..utils.linkages import sort_by_id_and_time +from ..orbit_determination.fitted_orbits import FittedOrbitMembers, FittedOrbits +from ..orbit_determination.outliers import calculate_max_outliers logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) __all__ = ["differential_correction"] @@ -38,12 +40,12 @@ def od_worker( max_iter: int = 20, method: Literal["central", "finite"] = "central", fit_epoch: bool = False, - propagator: Literal["PYOORB"] = "PYOORB", + propagator: Type[Propagator] = PYOORB, propagator_kwargs: dict = {}, ) -> Tuple[FittedOrbits, FittedOrbitMembers]: + od_orbits = FittedOrbits.empty() od_orbit_members = FittedOrbitMembers.empty() - for orbit_id in orbit_ids: time_start = time.time() logger.debug(f"Differentially correcting orbit {orbit_id}...") @@ -54,6 +56,12 @@ def od_worker( ).obs_id orbit_observations = observations.apply_mask(pc.is_in(observations.id, obs_ids)) + orbit_observations = OrbitDeterminationObservations.from_kwargs( + id=orbit_observations.id, + coordinates=orbit_observations.coordinates, + observers=orbit_observations.get_observers().observers, + ) + od_orbit, od_orbit_orbit_members = od( orbit, orbit_observations, @@ -97,7 +105,7 @@ def od_worker_remote( max_iter: int = 20, method: Literal["central", "finite"] = "central", fit_epoch: bool = False, - propagator: Literal["PYOORB"] = "PYOORB", + propagator: Type[Propagator] = PYOORB, propagator_kwargs: dict = {}, ) -> Tuple[FittedOrbits, FittedOrbitMembers]: orbit_ids_chunk = orbit_ids[orbit_ids_indices[0] : orbit_ids_indices[1]] @@ -124,7 +132,7 @@ def od_worker_remote( def od( orbit: FittedOrbits, - observations: Observations, + observations: OrbitDeterminationObservations, rchi2_threshold: float = 100, min_obs: int = 5, min_arc_length: float = 1.0, @@ -133,13 +141,11 @@ def od( max_iter: int = 20, method: Literal["central", "finite"] = "central", fit_epoch: bool = False, - propagator: Literal["PYOORB"] = "PYOORB", + propagator: Type[Propagator] = PYOORB, propagator_kwargs: dict = {}, ) -> Tuple[FittedOrbits, FittedOrbitMembers]: - if propagator == "PYOORB": - prop = PYOORB(**propagator_kwargs) - else: - raise ValueError(f"Invalid propagator '{propagator}'.") + # Intialize the propagator + prop = propagator(**propagator_kwargs) if method not in ["central", "finite"]: err = "method should be one of 'central' or 'finite'." @@ -148,8 +154,7 @@ def od( obs_ids_all = observations.id.to_numpy(zero_copy_only=False) coords = observations.coordinates coords_sigma = coords.covariance.sigmas[:, 1:3] - observers_with_states = observations.get_observers() - observers = observers_with_states.observers + observers = observations.observers times_all = coords.time.mjd().to_numpy(zero_copy_only=False) # FLAG: can we stop iterating to find a solution? @@ -170,9 +175,10 @@ def od( logger.debug("This orbit has fewer than {} observations.".format(min_obs)) processable = False else: - num_outliers = int(num_obs * contamination_percentage / 100.0) - num_outliers = np.maximum(np.minimum(num_obs - min_obs, num_outliers), 0) - logger.debug("Maximum number of outliers allowed: {}".format(num_outliers)) + max_outliers = calculate_max_outliers( + num_obs, min_obs, contamination_percentage + ) + logger.debug(f"Maximum number of outliers allowed: {max_outliers}") outliers_tried = 0 # Calculate chi2 for residuals on the given observations @@ -218,7 +224,7 @@ def od( DELTA_DECREASE_FACTOR = 100 max_iter_i = max_iter - max_iter_outliers = max_iter * (num_outliers + 1) + max_iter_outliers = max_iter * (max_outliers + 1) while not converged and processable: iterations += 1 @@ -231,7 +237,7 @@ def od( logger.debug(f"Maximum number of iterations completed.") break if iterations == max_iter_i + 1 and ( - solution_found or (num_outliers == outliers_tried) + solution_found or (max_outliers == outliers_tried) ): logger.debug(f"Maximum number of iterations completed.") break @@ -485,8 +491,8 @@ def od( converged = True elif ( - num_outliers > 0 - and outliers_tried <= num_outliers + max_outliers > 0 + and outliers_tried <= max_outliers and iterations > max_iter_i and not solution_found ): @@ -554,13 +560,11 @@ def od( num_obs=[num_obs], chi2=[chi2_total_prev], reduced_chi2=[rchi2_prev], - improved=[improved], + iterations=[iterations], + success=[improved], + status_code=[0], ) - # od_orbit["num_params"] = num_params - # od_orbit["num_iterations"] = iterations - # od_orbit["improved"] = improved - od_orbit_members = FittedOrbitMembers.from_kwargs( orbit_id=np.full( len(obs_ids_all), orbit_prev.orbit_id[0].as_py(), dtype="object" @@ -586,7 +590,7 @@ def differential_correction( max_iter: int = 20, method: Literal["central", "finite"] = "central", fit_epoch: bool = False, - propagator: Literal["PYOORB"] = "PYOORB", + propagator: Type[Propagator] = PYOORB, propagator_kwargs: dict = {}, chunk_size: int = 10, max_processes: Optional[int] = 1,