
Commit

Merges in improvements to generate test orbits, filtering, and more (#…
akoumjian committed Jan 29, 2024
1 parent fabc0b0 commit e29d426
Showing 13 changed files with 561 additions and 200 deletions.
18 changes: 18 additions & 0 deletions thor/clusters.py
@@ -1,4 +1,5 @@
 import logging
+import multiprocessing as mp
 import time
 import uuid
 from typing import List, Literal, Optional, Tuple, Union
@@ -781,6 +782,9 @@ def cluster_and_link(
     mjd0 = mjd[first][0]
     dt = mjd - mjd0
 
+    if max_processes is None:
+        max_processes = mp.cpu_count()
+
     use_ray = initialize_use_ray(num_cpus=max_processes)
     if use_ray:
         # Put all arrays (which can be large) in ray's
@@ -798,6 +802,7 @@
         for vxi_chunk, vyi_chunk in zip(
             _iterate_chunks(vxx, chunk_size), _iterate_chunks(vyy, chunk_size)
         ):
+
             futures.append(
                 cluster_velocity_remote.remote(
                     vxi_chunk,
@@ -813,6 +818,19 @@
                 )
             )
 
+            if len(futures) >= max_processes * 1.5:
+                finished, futures = ray.wait(futures, num_returns=1)
+                clusters_chunk, cluster_members_chunk = ray.get(finished[0])
+                clusters = qv.concatenate([clusters, clusters_chunk])
+                if clusters.fragmented():
+                    clusters = qv.defragment(clusters)
+
+                cluster_members = qv.concatenate(
+                    [cluster_members, cluster_members_chunk]
+                )
+                if cluster_members.fragmented():
+                    cluster_members = qv.defragment(cluster_members)
+
         while futures:
             finished, futures = ray.wait(futures, num_returns=1)
             clusters_chunk, cluster_members_chunk = ray.get(finished[0])
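
The new `if len(futures) >= max_processes * 1.5:` block applies backpressure: instead of queuing every velocity chunk and only collecting results at the end, the driver caps the number of in-flight Ray tasks and folds results in as tasks finish, which bounds driver memory. A minimal, self-contained sketch of the same pattern (the `process_chunk` task and the chunk contents are illustrative stand-ins, not code from THOR):

```python
import ray

ray.init(ignore_reinit_error=True)


@ray.remote
def process_chunk(chunk):
    # Stand-in for cluster_velocity_remote: any per-chunk computation.
    return [x * 2 for x in chunk]


max_processes = 4
results = []
futures = []

chunks = [list(range(i, i + 10)) for i in range(0, 100, 10)]
for chunk in chunks:
    futures.append(process_chunk.remote(chunk))

    # Backpressure: once ~1.5x as many tasks as workers are in flight,
    # block until one finishes and fold its result in immediately.
    if len(futures) >= max_processes * 1.5:
        finished, futures = ray.wait(futures, num_returns=1)
        results.extend(ray.get(finished[0]))

# Drain whatever is still in flight.
while futures:
    finished, futures = ray.wait(futures, num_returns=1)
    results.extend(ray.get(finished[0]))
```

The periodic `qv.defragment` calls in the hunk above serve the same goal on the accumulator side, keeping the repeatedly concatenated quivr tables compact as results stream in.
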
85 changes: 38 additions & 47 deletions thor/observations/filters.py
@@ -1,11 +1,10 @@
 import abc
 import logging
+import multiprocessing as mp
 import time
 from typing import TYPE_CHECKING, List, Optional, Union
 
 import numpy as np
-import pyarrow as pa
-import pyarrow.compute as pc
 import pyarrow.parquet as pq
 import quivr as qv
 import ray
@@ -14,7 +13,7 @@
 from adam_core.ray_cluster import initialize_use_ray
 
 from thor.config import Config
-from thor.observations.observations import Observations
+from thor.observations.observations import Observations, observations_iterator
 
 from ..orbit import TestOrbits
 
@@ -108,10 +107,9 @@ def apply(
         ephemeris = test_orbit.generate_ephemeris_from_observations(observations)
 
         filtered_observations = Observations.empty()
-        state_ids = observations.state_id.unique().sort()
+        state_ids = observations.state_id.unique()
 
         for state_id in state_ids:
-
             # Select the ephemeris and observations for this state
             ephemeris_state = ephemeris.select("id", state_id)
             observations_state = observations.select("state_id", state_id)
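
`TestOrbitRadiusObservationFilter.apply` keeps, for each state, only the observations within `config.cell_radius` of the test orbit's predicted position. The diff does not show `_within_radius` itself; the sketch below is a generic great-circle cone cut of the kind such a filter typically performs, with made-up coordinates:

```python
import numpy as np


def within_radius(obs_ra, obs_dec, center_ra, center_dec, radius_deg):
    """Great-circle angular separation test; all inputs in degrees."""
    ra1, dec1 = np.radians(obs_ra), np.radians(obs_dec)
    ra2, dec2 = np.radians(center_ra), np.radians(center_dec)
    # Spherical law of cosines; clip guards arccos against rounding.
    cos_sep = (
        np.sin(dec1) * np.sin(dec2)
        + np.cos(dec1) * np.cos(dec2) * np.cos(ra1 - ra2)
    )
    sep_deg = np.degrees(np.arccos(np.clip(cos_sep, -1.0, 1.0)))
    return sep_deg <= radius_deg


# Keep only detections within 1.5 degrees of the predicted position.
mask = within_radius(
    np.array([10.0, 10.4, 12.0]),   # detection RA
    np.array([-5.0, -5.1, -4.0]),   # detection Dec
    10.2, -5.05, 1.5,
)
```
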
@@ -206,8 +204,7 @@ def _within_radius(
 
 
 def filter_observations_worker(
-    state_id_chunk: List[int],
-    observations: Union[str, Observations],
+    observations: Observations,
     test_orbit: TestOrbits,
     filters: List[ObservationFilter],
 ) -> Observations:
@@ -230,28 +227,17 @@ filter_observations_worker
     filtered_observations : `~thor.observations.observations.Observations`
         Filtered observations.
     """
-    if isinstance(observations, str):
-        observations_chunk = pq.read_table(
-            observations, filters=pc.field("state_id").isin(pa.array(state_id_chunk))
-        )
-        observations_chunk = Observations.from_pyarrow(observations_chunk)
-    else:
-        observations_chunk = observations.apply_mask(
-            pc.is_in(observations.state_id, pa.array(state_id_chunk))
-        )
-
-    filtered_observations = observations_chunk
     for filter_i in filters:
-        filtered_observations = filter_i.apply(
-            filtered_observations,
+        observations = filter_i.apply(
+            observations,
             test_orbit,
         )
 
     # Defragment the observations
-    if len(filtered_observations) > 0:
-        filtered_observations = qv.defragment(filtered_observations)
+    if len(observations) > 0:
+        observations = qv.defragment(observations)
 
-    return filtered_observations
+    return observations


filter_observations_worker_remote = ray.remote(filter_observations_worker)
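
After the rewrite, the worker no longer re-reads parquet or masks by state ID; it is a plain function that threads one `Observations` chunk through each filter in turn, and `ray.remote(...)` wraps that same function for distributed use. A toy version of the sequential filter-chain shape, using plain dicts instead of quivr tables (`Filter` and `BrightnessFilter` are hypothetical stand-ins for `ObservationFilter` and its subclasses):

```python
import abc
from typing import List


class Filter(abc.ABC):
    """Interface mirroring ObservationFilter: take rows, return a subset."""

    @abc.abstractmethod
    def apply(self, rows: List[dict]) -> List[dict]: ...


class BrightnessFilter(Filter):
    def __init__(self, limit: float):
        self.limit = limit

    def apply(self, rows):
        return [r for r in rows if r["mag"] <= self.limit]


def run_filters(rows, filters: List[Filter]):
    # Same shape as the reworked worker: each filter consumes the
    # previous filter's output, so filter order can matter.
    for f in filters:
        rows = f.apply(rows)
    return rows


rows = [{"mag": 21.0}, {"mag": 24.5}]
assert run_filters(rows, [BrightnessFilter(22.0)]) == [{"mag": 21.0}]
```
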
@@ -263,7 +249,7 @@ def filter_observations(
     test_orbit: TestOrbits,
     config: Config,
     filters: Optional[List[ObservationFilter]] = None,
-    chunk_size: int = 100,
+    chunk_size: int = 1_000_000,
 ) -> Observations:
     """
     Filter observations by applying a list of filters. The input observations
@@ -300,16 +286,12 @@
     )
 
     if isinstance(observations, str):
-        if not observations.endswith(".parquet"):
-            raise ValueError("observations file should be a parquet file.")
-
-        state_ids = pq.read_table(observations, columns=["state_id"])["state_id"]
-        num_obs = len(state_ids)
-        state_ids = pc.unique(state_ids).sort()
+        num_obs = pq.read_metadata(observations).num_rows
+        logger.info(f"Filtering {num_obs} observations in parquet file.")
 
     elif isinstance(observations, Observations):
         num_obs = len(observations)
-        state_ids = pc.unique(observations.state_id).sort()
+        logger.info(f"Reading {num_obs} observations in memory.")
 
     else:
         raise ValueError(
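
The parquet branch now gets the row count from the file footer via `pq.read_metadata` instead of reading the whole `state_id` column and deduplicating it, which avoids loading any column data at all. A small standalone demonstration (the file name and contents are illustrative):

```python
import pyarrow as pa
import pyarrow.parquet as pq

# Write a tiny demo file, a stand-in for an observations parquet file.
pq.write_table(pa.table({"state_id": [1, 1, 2, 3]}), "observations.parquet")

# read_metadata touches only the footer, so counting rows this way is
# cheap even for very large files.
num_obs = pq.read_metadata("observations.parquet").num_rows
print(num_obs)  # 4
```
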
@@ -320,26 +302,34 @@
         # By default we always filter by radius from the predicted position of the test orbit
         filters = [TestOrbitRadiusObservationFilter(radius=config.cell_radius)]
 
-    filtered_observations = Observations.empty()
+    if config.max_processes is None:
+        max_processes = mp.cpu_count()
+    else:
+        max_processes = config.max_processes
 
-    use_ray = initialize_use_ray(num_cpus=config.max_processes)
+    filtered_observations = Observations.empty()
+    logger.info(f"{config.json()}")
+    use_ray = initialize_use_ray(num_cpus=max_processes)
     if use_ray:
-
         if isinstance(observations, Observations):
             observations = ray.put(observations)
             logger.info("Placed observations in the object store.")
 
-        futures = []
-        for state_id_chunk in _iterate_chunks(state_ids, chunk_size):
-
+        futures: List[ray.ObjectRef] = []
+        for observations_chunk in observations_iterator(
+            observations, chunk_size=chunk_size
+        ):
             futures.append(
                 filter_observations_worker_remote.remote(
-                    state_id_chunk,
-                    observations,
+                    observations_chunk,
                     test_orbit,
                     filters,
                 )
             )
+            if len(futures) > max_processes * 1.5:
+                finished, futures = ray.wait(futures, num_returns=1)
+                filtered_observations = qv.concatenate(
+                    [filtered_observations, ray.get(finished[0])]
+                )
+                if filtered_observations.fragmented():
+                    filtered_observations = qv.defragment(filtered_observations)
 
         while futures:
             finished, futures = ray.wait(futures, num_returns=1)
@@ -354,13 +344,14 @@
             logger.info("Removed observations from the object store.")
 
     else:
-
-        for state_id_chunk in _iterate_chunks(state_ids, chunk_size):
-
+        for observations_chunk in observations_iterator(
+            observations, chunk_size=chunk_size
+        ):
             filtered_observations_chunk = filter_observations_worker(
-                state_id_chunk, observations, test_orbit, filters
+                observations_chunk,
+                test_orbit,
+                filters,
             )
-
             filtered_observations = qv.concatenate(
                 [filtered_observations, filtered_observations_chunk]
             )
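
Both the Ray and serial paths now consume `observations_iterator`, whose implementation is not part of this diff. Assuming it yields fixed-size row chunks from either a parquet path or an in-memory table, a sketch of that kind of iterator in plain pyarrow might look like this (the function name `chunked_rows` is hypothetical):

```python
from typing import Iterator, Union

import pyarrow as pa
import pyarrow.parquet as pq


def chunked_rows(
    source: Union[str, pa.Table], chunk_size: int = 1_000_000
) -> Iterator[pa.Table]:
    if isinstance(source, str):
        # Parquet path: stream record batches without reading the file whole.
        parquet_file = pq.ParquetFile(source)
        for batch in parquet_file.iter_batches(batch_size=chunk_size):
            yield pa.Table.from_batches([batch])
    else:
        # In-memory table: yield zero-copy slices.
        for offset in range(0, len(source), chunk_size):
            yield source.slice(offset, chunk_size)
```

Iterating whole-row chunks pairs naturally with the new 1,000,000-row default `chunk_size`: workers receive ready-made `Observations` chunks instead of re-filtering the full dataset by state ID lists.
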
(Diffs for the remaining 11 changed files are not shown.)
