Skip to content

Commit

Permalink
backpressure queues
Browse files Browse the repository at this point in the history
  • Loading branch information
akoumjian committed Jan 28, 2024
1 parent 71dd202 commit c449041
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 6 deletions.
12 changes: 12 additions & 0 deletions thor/clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,7 @@ def cluster_and_link(
for vxi_chunk, vyi_chunk in zip(
_iterate_chunks(vxx, chunk_size), _iterate_chunks(vyy, chunk_size)
):

futures.append(
cluster_velocity_remote.remote(
vxi_chunk,
Expand All @@ -813,6 +814,17 @@ def cluster_and_link(
)
)

if len(futures) >= max_processes * 1.5:
finished, futures = ray.wait(futures, num_returns=1)
clusters_chunk, cluster_members_chunk = ray.get(finished[0])
clusters = qv.concatenate([clusters, clusters_chunk])
if clusters.fragmented():
clusters = qv.defragment(clusters)

cluster_members = qv.concatenate([cluster_members, cluster_members_chunk])
if cluster_members.fragmented():
cluster_members = qv.defragment(cluster_members)

while futures:
finished, futures = ray.wait(futures, num_returns=1)
clusters_chunk, cluster_members_chunk = ray.get(finished[0])
Expand Down
4 changes: 1 addition & 3 deletions thor/observations/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,16 +321,14 @@ def filter_observations(
for observations_chunk in observations_iterator(
observations, chunk_size=chunk_size
):
print("sending in chunk")
futures.append(
filter_observations_worker_remote.remote(
observations_chunk,
test_orbit,
filters,
)
)
if len(futures) > max_processes + 1:
print("retrieving chunk")
if len(futures) > max_processes * 1.5:
finished, futures = ray.wait(futures, num_returns=1)
filtered_observations = qv.concatenate(
[filtered_observations, ray.get(finished[0])]
Expand Down
2 changes: 1 addition & 1 deletion thor/observations/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def convert_input_observations_to_observations(
if use_ray:
futures: List[ray.ObjectRef] = []
for input_observation_chunk in input_iterator:
if len(futures) > max_processes * 2:
if len(futures) > max_processes * 1.5:
futures, output_observations = _process_next_future_result(
futures, output_observations, output_writer
)
Expand Down
9 changes: 9 additions & 0 deletions thor/orbit.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,15 @@ def range_observations(
)
)

if len(futures) >= max_processes * 1.5:
finished, futures = ray.wait(futures, num_returns=1)
ranged_detections_chunk = ray.get(finished[0])
ranged_detections = qv.concatenate(
[ranged_detections, ranged_detections_chunk]
)
if ranged_detections.fragmented():
ranged_detections = qv.defragment(ranged_detections)

while futures:
finished, futures = ray.wait(futures, num_returns=1)
ranged_detections_chunk = ray.get(finished[0])
Expand Down
6 changes: 6 additions & 0 deletions thor/orbits/attribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,12 @@ def attribute_observations(
)
)

if len(futures) >= max_processes * 1.5:
finished, futures = ray.wait(futures, num_returns=1)
attributions_chunk = ray.get(finished[0])
attributions = qv.concatenate([attributions, attributions_chunk])
attributions = qv.defragment(attributions)

while futures:
finished, futures = ray.wait(futures, num_returns=1)
attributions_chunk = ray.get(finished[0])
Expand Down
13 changes: 13 additions & 0 deletions thor/orbits/iod.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,19 @@ def initial_orbit_determination(
)
)

if len(futures) >= max_processes * 1.5:
finished, futures = ray.wait(futures, num_returns=1)
result = ray.get(finished[0])
iod_orbits_chunk, iod_orbit_members_chunk = result
iod_orbits = qv.concatenate([iod_orbits, iod_orbits_chunk])
iod_orbit_members = qv.concatenate(
[iod_orbit_members, iod_orbit_members_chunk]
)
if iod_orbits.fragmented():
iod_orbits = qv.defragment(iod_orbits)
if iod_orbit_members.fragmented():
iod_orbit_members = qv.defragment(iod_orbit_members)

while futures:
finished, futures = ray.wait(futures, num_returns=1)
result = ray.get(finished[0])
Expand Down
12 changes: 12 additions & 0 deletions thor/orbits/od.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,18 @@ def differential_correction(
)
)

if len(futures) >= max_processes * 1.5:
finished, futures = ray.wait(futures, num_returns=1)
od_orbits_chunk, od_orbit_members_chunk = ray.get(finished[0])
od_orbits = qv.concatenate([od_orbits, od_orbits_chunk])
if od_orbits.fragmented():
od_orbits = qv.defragment(od_orbits)
od_orbit_members = qv.concatenate(
[od_orbit_members, od_orbit_members_chunk]
)
if od_orbit_members.fragmented():
od_orbit_members = qv.defragment(od_orbit_members)

while futures:
finished, futures = ray.wait(futures, num_returns=1)
od_orbits_chunk, od_orbit_members_chunk = ray.get(finished[0])
Expand Down
10 changes: 8 additions & 2 deletions thor/range_and_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@ def range_and_transform(
ranged_detections_spherical_ref = ray.put(ranged_detections_spherical)

# Get state IDs
# state_ids = observations.state_id.unique().sort()
state_ids = observations.state_id.unique()
futures = []
for state_id in state_ids:
Expand All @@ -191,6 +190,14 @@ def range_and_transform(
)
)

if len(futures) >= max_processes * 1.5:
finished, futures = ray.wait(futures, num_returns=1)
transformed_detections = qv.concatenate(
[transformed_detections, ray.get(finished[0])]
)
if transformed_detections.fragmented():
transformed_detections = qv.defragment(transformed_detections)

while futures:
finished, futures = ray.wait(futures, num_returns=1)
transformed_detections = qv.concatenate(
Expand All @@ -207,7 +214,6 @@ def range_and_transform(

else:
# Get state IDs
# state_ids = observations.state_id.unique().sort()
state_ids = observations.state_id.unique()
for state_id in state_ids:
# mask = pc.equal(state_id, observations.state_id)
Expand Down

0 comments on commit c449041

Please sign in to comment.