Skip to content

Commit

Permalink
Minor improvements to orbit determination types (#149)
Browse files Browse the repository at this point in the history
* Add calculate_max_outliers
* Use adam_core's OrbitDeterminationObservations
* Change differential correction, merge and extend to accept a Propagator type
* Change to FittedOrbits and FittedOrbitMembers to match adam_core's class definitions
  • Loading branch information
moeyensj committed Jan 30, 2024
1 parent e29d426 commit 4d093c0
Show file tree
Hide file tree
Showing 10 changed files with 169 additions and 58 deletions.
4 changes: 2 additions & 2 deletions thor/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def link_test_orbit(
rchi2_threshold=config.od_rchi2_threshold,
delta=config.od_delta,
max_iter=config.od_max_iter,
propagator=config.propagator,
propagator=propagator,
propagator_kwargs={},
chunk_size=config.od_chunk_size,
max_processes=config.max_processes,
Expand Down Expand Up @@ -427,7 +427,7 @@ def link_test_orbit(
radius=config.arc_extension_radius,
delta=config.od_delta,
max_iter=config.od_max_iter,
propagator=config.propagator,
propagator=propagator,
propagator_kwargs={},
orbits_chunk_size=config.arc_extension_chunk_size,
max_processes=config.max_processes,
Expand Down
8 changes: 7 additions & 1 deletion thor/orbit_determination/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,7 @@
from .fitted_orbits import *
# noqa: F401
from .fitted_orbits import (
FittedOrbits,
FittedOrbitMembers,
assign_duplicate_observations,
)
from .outliers import calculate_max_outliers
6 changes: 5 additions & 1 deletion thor/orbit_determination/fitted_orbits.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ def drop_duplicate_orbits(
return filtered, filtered_orbit_members


# FittedOrbits and FittedOrbit members currently match
# the schema of adam_core.orbit_determination's FittedOrbits and FittedOrbitMembers
class FittedOrbits(qv.Table):

orbit_id = qv.LargeStringColumn(default=lambda: uuid.uuid4().hex)
Expand All @@ -143,7 +145,9 @@ class FittedOrbits(qv.Table):
num_obs = qv.Int64Column()
chi2 = qv.Float64Column()
reduced_chi2 = qv.Float64Column()
improved = qv.BooleanColumn(nullable=True)
iterations = qv.Int64Column(nullable=True)
success = qv.BooleanColumn(nullable=True)
status_code = qv.Int64Column(nullable=True)

def to_orbits(self) -> Orbits:
"""
Expand Down
37 changes: 37 additions & 0 deletions thor/orbit_determination/outliers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import numpy as np


def calculate_max_outliers(
num_obs: int, min_obs: int, contamination_percentage: float
) -> int:
"""
Calculate the maximum number of allowable outliers. Linkages may contain err
oneuos observations that need to be removed. This function calculates the maximum number of
observations that can be removed before the linkage no longer has the minimum number
of observations required. The contamination percentage is the maximum percentage of observations
that allowed to be erroneous.
Parameters
----------
num_obs : int
Number of observations in the linkage.
min_obs : int
Minimum number of observations required for a valid linkage.
contamination_percentage : float
Maximum percentage of observations that allowed to be erroneous. Range is [0, 100].
Returns
-------
outliers : int
Maximum number of allowable outliers.
"""
assert (
num_obs >= min_obs
), "Number of observations must be greater than or equal to the minimum number of observations."
assert (
contamination_percentage >= 0 and contamination_percentage <= 100
), "Contamination percentage must be between 0 and 100."

max_outliers = num_obs * (contamination_percentage / 100)
outliers = np.min([max_outliers, num_obs - min_obs]).astype(int)
return outliers
Empty file.
4 changes: 3 additions & 1 deletion thor/orbit_determination/tests/test_fitted_orbits.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ def simple_orbits():
num_obs=[10, 20, 15, 25, 5], # Specific values
chi2=np.random.rand(5),
reduced_chi2=[0.5, 0.4, 0.3, 0.2, 0.1], # Specific values
improved=pa.repeat(False, 5),
iterations=[100, 200, 300, 400, 500], # Specific values
success=[True, True, True, True, True], # Specific values
status_code=[1, 1, 1, 1, 1], # Specific values
)


Expand Down
49 changes: 49 additions & 0 deletions thor/orbit_determination/tests/test_outliers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pytest

from ..outliers import calculate_max_outliers


def test_calculate_max_outliers():
# Test that the function returns the correct number of outliers given
# the number of observations, minimum number of observations, and
# contamination percentage in a few different cases.
min_obs = 3
num_obs = 10
contamination_percentage = 50
outliers = calculate_max_outliers(num_obs, min_obs, contamination_percentage)
assert outliers == 5

min_obs = 6
num_obs = 10
contamination_percentage = 50
outliers = calculate_max_outliers(num_obs, min_obs, contamination_percentage)
assert outliers == 4

min_obs = 6
num_obs = 6
contamination_percentage = 50
outliers = calculate_max_outliers(num_obs, min_obs, contamination_percentage)
assert outliers == 0


def test_calculate_max_outliers_assertions():
# Test that the function raises an assertion error when the number of observations
# is less than the minimum number of observations.
min_obs = 10
num_obs = 6
contamination_percentage = 50
with pytest.raises(
AssertionError,
match=r"Number of observations must be greater than or equal to the minimum number of observations.",
):
outliers = calculate_max_outliers(num_obs, min_obs, contamination_percentage)

# Test that the function raises an assertion error when the contamination percentage
# is less than 0.
min_obs = 10
num_obs = 10
contamination_percentage = -50
with pytest.raises(
AssertionError, match=r"Contamination percentage must be between 0 and 1."
):
outliers = calculate_max_outliers(num_obs, min_obs, contamination_percentage)
21 changes: 9 additions & 12 deletions thor/orbits/attribution.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import gc
import logging
import multiprocessing as mp
import time
from typing import List, Literal, Optional, Tuple, Union
from typing import List, Literal, Optional, Tuple, Type, Union

import numpy as np
import numpy.typing as npt
Expand All @@ -12,13 +11,13 @@
import ray
from adam_core.coordinates.residuals import Residuals
from adam_core.orbits import Orbits
from adam_core.propagator import PYOORB
from adam_core.propagator import PYOORB, Propagator
from adam_core.propagator.utils import _iterate_chunk_indices, _iterate_chunks
from adam_core.ray_cluster import initialize_use_ray
from sklearn.neighbors import BallTree

from ..observations.observations import Observations
from ..orbit_determination import (
from ..orbit_determination.fitted_orbits import (
FittedOrbitMembers,
FittedOrbits,
assign_duplicate_observations,
Expand Down Expand Up @@ -121,13 +120,11 @@ def attribution_worker(
orbits: Union[Orbits, FittedOrbits],
observations: Observations,
radius: float = 1 / 3600,
propagator: Literal["PYOORB"] = "PYOORB",
propagator: Type[Propagator] = PYOORB,
propagator_kwargs: dict = {},
) -> Attributions:
if propagator == "PYOORB":
prop = PYOORB(**propagator_kwargs)
else:
raise ValueError(f"Invalid propagator '{propagator}'.")
# Initialize the propagator
prop = propagator(**propagator_kwargs)

if isinstance(orbits, FittedOrbits):
orbits = orbits.to_orbits()
Expand Down Expand Up @@ -255,7 +252,7 @@ def attribute_observations(
orbits: Union[Orbits, FittedOrbits, ray.ObjectRef],
observations: Union[Observations, ray.ObjectRef],
radius: float = 5 / 3600,
propagator: Literal["PYOORB"] = "PYOORB",
propagator: Type[Propagator] = PYOORB,
propagator_kwargs: dict = {},
orbits_chunk_size: int = 10,
observations_chunk_size: int = 100000,
Expand Down Expand Up @@ -388,7 +385,7 @@ def merge_and_extend_orbits(
max_iter: int = 20,
method: Literal["central", "finite"] = "central",
fit_epoch: bool = False,
propagator: Literal["PYOORB"] = "PYOORB",
propagator: Type[Propagator] = PYOORB,
propagator_kwargs: dict = {},
orbits_chunk_size: int = 10,
observations_chunk_size: int = 100000,
Expand Down Expand Up @@ -541,7 +538,7 @@ def merge_and_extend_orbits(
# Remove the orbits that were not improved from the pool of available orbits. Orbits that were not improved
# are orbits that have already iterated to their best-fit solution given the observations available. These orbits
# are unlikely to recover more observations in subsequent iterations and so can be saved for output.
not_improved_mask = pc.equal(orbits.improved, False)
not_improved_mask = pc.equal(orbits.success, False)
orbits_out = orbits.apply_mask(not_improved_mask)
orbit_members_out = orbit_members.apply_mask(
pc.is_in(orbit_members.orbit_id, orbits_out.orbit_id)
Expand Down
38 changes: 25 additions & 13 deletions thor/orbits/iod.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import quivr as qv
import ray
from adam_core.coordinates.residuals import Residuals
from adam_core.orbit_determination import OrbitDeterminationObservations
from adam_core.propagator import PYOORB, Propagator
from adam_core.propagator.utils import _iterate_chunk_indices, _iterate_chunks
from adam_core.ray_cluster import initialize_use_ray
Expand All @@ -22,6 +23,7 @@
FittedOrbits,
drop_duplicate_orbits,
)
from ..orbit_determination.outliers import calculate_max_outliers
from ..utils.linkages import sort_by_id_and_time
from .gauss import gaussIOD

Expand Down Expand Up @@ -133,7 +135,6 @@ def iod_worker(
propagator: Type[Propagator] = PYOORB,
propagator_kwargs: dict = {},
) -> Tuple[FittedOrbits, FittedOrbitMembers]:
prop = propagator(**propagator_kwargs)

iod_orbits = FittedOrbits.empty()
iod_orbit_members = FittedOrbitMembers.empty()
Expand All @@ -148,6 +149,12 @@ def iod_worker(
pc.is_in(observations.id, obs_ids)
)

observations_linkage = OrbitDeterminationObservations.from_kwargs(
id=observations_linkage.id,
coordinates=observations_linkage.coordinates,
observers=observations_linkage.get_observers().observers,
)

iod_orbit, iod_orbit_orbit_members = iod(
observations_linkage,
min_obs=min_obs,
Expand All @@ -157,7 +164,8 @@ def iod_worker(
observation_selection_method=observation_selection_method,
iterate=iterate,
light_time=light_time,
propagator=prop,
propagator=propagator,
propagator_kwargs=propagator_kwargs,
)
if len(iod_orbit) > 0:
iod_orbit = iod_orbit.set_column("orbit_id", pa.array([linkage_id]))
Expand Down Expand Up @@ -225,7 +233,7 @@ def iod_worker_remote(


def iod(
observations: Observations,
observations: OrbitDeterminationObservations,
min_obs: int = 6,
min_arc_length: float = 1.0,
contamination_percentage: float = 0.0,
Expand All @@ -235,7 +243,8 @@ def iod(
] = "combinations",
iterate: bool = False,
light_time: bool = True,
propagator: Propagator = PYOORB(),
propagator: Type[Propagator] = PYOORB,
propagator_kwargs: dict = {},
) -> Tuple[FittedOrbits, FittedOrbitMembers]:
"""
Run initial orbit determination on a set of observations believed to belong to a single
Expand Down Expand Up @@ -313,13 +322,16 @@ def iod(
"outlier" : Flag to indicate which observations are potential outliers (their chi2 is higher than
the chi2 threshold) [float]
"""
# Initialize the propagator
prop = propagator(**propagator_kwargs)

processable = True
if len(observations) == 0:
processable = False

obs_ids_all = observations.id.to_numpy(zero_copy_only=False)
coords_all = observations.coordinates
observers_with_states = observations.get_observers()
observers = observations.observers

observations = observations.sort_by(
[
Expand All @@ -328,7 +340,7 @@ def iod(
"coordinates.origin.code",
]
)
observers = observers_with_states.observers.sort_by(
observers = observers.sort_by(
["coordinates.time.days", "coordinates.time.nanos", "coordinates.origin.code"]
)

Expand All @@ -344,15 +356,15 @@ def iod(
num_obs = len(observations)
if num_obs < min_obs:
processable = False
num_outliers = int(num_obs * contamination_percentage / 100.0)
num_outliers = np.maximum(np.minimum(num_obs - min_obs, num_outliers), 0)

max_outliers = calculate_max_outliers(num_obs, min_obs, contamination_percentage)

# Select observation IDs to use for IOD
obs_ids = select_observations(
observations,
method=observation_selection_method,
)
obs_ids = obs_ids[: (3 * (num_outliers + 1))]
obs_ids = obs_ids[: (3 * (max_outliers + 1))]

if len(obs_ids) == 0:
processable = False
Expand Down Expand Up @@ -386,7 +398,7 @@ def iod(
continue

# Propagate initial orbit to all observation times
ephemeris = propagator.generate_ephemeris(
ephemeris = prop.generate_ephemeris(
iod_orbits, observers, chunk_size=1, max_processes=1
)

Expand All @@ -408,7 +420,7 @@ def iod(
# The reduced chi2 is above the threshold and no outliers are
# allowed, this cannot be improved by outlier rejection
# so continue to the next IOD orbit
if rchi2 > rchi2_threshold and num_outliers == 0:
if rchi2 > rchi2_threshold and max_outliers == 0:
# If we have iterated through all iod orbits and no outliers
# are allowed for this linkage then no other combination of
# observations will make it acceptable, so exit here.
Expand Down Expand Up @@ -436,9 +448,9 @@ def iod(
# anticipate that we get to this stage if the three selected observations
# belonging to one object yield a good initial orbit but the presence of outlier
# observations is skewing the sum total of the residuals and chi2
elif num_outliers > 0:
elif max_outliers > 0:
logger.debug("Attempting to identify possible outliers.")
for o in range(num_outliers):
for o in range(max_outliers):
# Select i highest observations that contribute to
# chi2 (and thereby the residuals)
remove = chi2[~mask].argsort()[-(o + 1) :]
Expand Down

0 comments on commit 4d093c0

Please sign in to comment.