Skip to content

Commit

Permalink
Merge pull request #807 from dstl/kdtree_tutorial_doc_fix
Browse files Browse the repository at this point in the history
Fix issue with kD-tree example not building
  • Loading branch information
sdhiscocks committed May 25, 2023
2 parents 683600a + 35ac949 commit 8448405
Showing 1 changed file with 30 additions and 30 deletions.
60 changes: 30 additions & 30 deletions docs/tutorials/dataassociation/KDTree_Tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,10 @@
# We will simulate a large number of targets moving in the :math:`x`, :math:`y` Cartesian plane. We will then add
# detections with a high clutter rate at each time step.

import numpy as np
import datetime
from itertools import tee

import numpy as np

from stonesoup.types.array import StateVector, CovarianceMatrix
from stonesoup.types.state import GaussianState
Expand All @@ -171,15 +173,16 @@

from stonesoup.simulator.simple import MultiTargetGroundTruthSimulator

groundtruth_sim = MultiTargetGroundTruthSimulator(
transition_model=transition_model,
initial_state=initial_state,
timestep=timestep_size,
number_steps=number_of_steps,
birth_rate=birth_rate,
death_probability=death_probability,
initial_number_targets=200
)
groundtruth_sims = [
MultiTargetGroundTruthSimulator(
transition_model=transition_model,
initial_state=initial_state,
timestep=timestep_size,
number_steps=number_of_steps,
birth_rate=birth_rate,
death_probability=death_probability,
initial_number_targets=100)
for _ in range(3)]

# %%
# Initialise the measurement models
Expand All @@ -199,13 +202,18 @@
clutter_area = np.array([[-1, 1], [-1, 1]]) * 500
clutter_rate = 50

detection_sim = SimpleDetectionSimulator(
groundtruth=groundtruth_sim,
measurement_model=measurement_model,
detection_probability=probability_detection,
meas_range=clutter_area,
clutter_rate=clutter_rate
)
detection_sims = [
SimpleDetectionSimulator(
groundtruth=groundtruth_sim,
measurement_model=measurement_model,
detection_probability=probability_detection,
meas_range=clutter_area,
clutter_rate=clutter_rate)
for groundtruth_sim in groundtruth_sims]

# Use tee to create 3 identical versions, for GNN, k-D tree and TPR-Tree
sim_sets = [tee(sim, 3) for sim in detection_sims]
sim_sets = list(zip(*sim_sets))

# %%
# Import tracker components
Expand Down Expand Up @@ -257,7 +265,7 @@
run_times_KDTree = []

# run loop to calculate average run time
for _ in range(0, 3):
for n, detection_sim in enumerate(sim_sets[0]):
start_time = timer.perf_counter()

# create tracker components
Expand Down Expand Up @@ -296,8 +304,8 @@
tracks = set()

for time, ctracks in tracker:
groundtruth.update(groundtruth_sim.groundtruth_paths)
detections.update(detection_sim.detections)
groundtruth.update(groundtruth_sims[n].groundtruth_paths)
detections.update(detection_sims[n].detections)
tracks.update(ctracks)

end_time = timer.perf_counter()
Expand Down Expand Up @@ -332,7 +340,7 @@

run_times_TPRTree = []

for _ in range(0, 3):
for detection_sim in sim_sets[1]:
start_time = timer.perf_counter()

TPR_data_associator = TPRTreeGNN2D(hypothesiser=hypothesiser,
Expand Down Expand Up @@ -370,13 +378,9 @@
)

# run tracker
groundtruth = set()
detections = set()
tracks = set()

for time, ctracks in tracker:
groundtruth.update(groundtruth_sim.groundtruth_paths)
detections.update(detection_sim.detections)
tracks.update(ctracks)

end_time = timer.perf_counter()
Expand Down Expand Up @@ -408,7 +412,7 @@
run_times_GNN = []

# run loop to calculate average run time
for _ in range(0, 3):
for detection_sim in sim_sets[2]:
start_time = timer.perf_counter()

# create tracker components
Expand Down Expand Up @@ -442,13 +446,9 @@
)

# run tracker
groundtruth = set()
detections = set()
tracks = set()

for time, ctracks in tracker:
groundtruth.update(groundtruth_sim.groundtruth_paths)
detections.update(detection_sim.detections)
tracks.update(ctracks)

end_time = timer.perf_counter()
Expand Down

0 comments on commit 8448405

Please sign in to comment.