diff --git a/stonesoup/dataassociator/_assignment.py b/stonesoup/dataassociator/_assignment.py index f29229eda..b335bb8df 100644 --- a/stonesoup/dataassociator/_assignment.py +++ b/stonesoup/dataassociator/_assignment.py @@ -1,4 +1,5 @@ import copy +from itertools import islice import numpy as np from scipy.optimize import linear_sum_assignment @@ -92,8 +93,8 @@ def multidimensional_deconfliction(association_set): def conflicts(assoc1, assoc2): if hasattr(assoc1, 'time_range') and hasattr(assoc2, 'time_range') and \ - len(assoc1.objects.intersection(assoc2.objects)) > 0 and \ - (assoc1.time_range & assoc2.time_range).duration.total_seconds() > 0 and \ + assoc1.objects.intersection(assoc2.objects) and \ + assoc1.time_range & assoc2.time_range and \ assoc1 != assoc2: return True else: @@ -101,8 +102,8 @@ def conflicts(assoc1, assoc2): def check_if_no_conflicts(association_set): - for assoc1 in range(0, len(association_set)): - for assoc2 in range(assoc1, len(association_set)): - if conflicts(list(association_set)[assoc1], list(association_set)[assoc2]): + for n, assoc1 in enumerate(association_set): + for assoc2 in islice(association_set, n, None): + if conflicts(assoc1, assoc2): return False return True diff --git a/stonesoup/dataassociator/tracktotrack.py b/stonesoup/dataassociator/tracktotrack.py index 21b87561b..9e271d334 100644 --- a/stonesoup/dataassociator/tracktotrack.py +++ b/stonesoup/dataassociator/tracktotrack.py @@ -1,6 +1,8 @@ from operator import attrgetter from typing import Set +from ordered_set import OrderedSet + from .base import TrackToTrackAssociator from ..base import Property from ..measures import Measure, Euclidean, EuclideanWeighted @@ -106,7 +108,7 @@ def associate_tracks(self, tracks_set_1: Set[Track], tracks_set_2: Set[Track]): raise ValueError("Must provide mapping of position components to pos_map") if not self.measure: - state1 = list(tracks_set_1)[0][0] + state1 = next(iter(tracks_set_1))[0] total = len(state1.state_vector) if not self.pos_map: self.pos_map = [i for i in range(total)] @@ -175,7 +177,7 @@ def associate_tracks(self, tracks_set_1: Set[Track], tracks_set_2: Set[Track]): if n_unsuccessful >= self.consec_misses_end and \ start_timestamp: associations.add(TimeRangeAssociation( - (track1, track2), + OrderedSet((track1, track2)), TimeRange(start_timestamp, end_timestamp))) start_timestamp = None @@ -183,7 +185,7 @@ def associate_tracks(self, tracks_set_1: Set[Track], tracks_set_2: Set[Track]): if start_timestamp: end_timestamp = track1_states[-1].timestamp associations.add(TimeRangeAssociation( - (track1, track2), + OrderedSet((track1, track2)), TimeRange(start_timestamp, end_timestamp))) if self.one_to_one: @@ -347,7 +349,7 @@ def associate_tracks(self, tracks_set: Set[Track], truth_set: Set[GroundTruthPat # in a row end the association and record if n_failures >= self.consec_misses_end: associations.add(TimeRangeAssociation( - (track, current_truth), + OrderedSet((track, current_truth)), TimeRange(start_timestamp, end_timestamp))) # If the current potential association @@ -369,7 +371,7 @@ def associate_tracks(self, tracks_set: Set[Track], truth_set: Set[GroundTruthPat if current_truth: associations.add(TimeRangeAssociation( - (track, current_truth), + OrderedSet((track, current_truth)), TimeRange(start_timestamp, end_timestamp))) return AssociationSet(associations) @@ -410,13 +412,13 @@ def associate_tracks(self, tracks_set, truths_set): if track.id == truth.id: try: associations.add( - TimeRangeAssociation((track, truth), + TimeRangeAssociation(OrderedSet((track, truth)), TimeRange(max(track[0].timestamp, truth[0].timestamp), min(track[-1].timestamp, truth[-1].timestamp)))) except (TypeError, ValueError): # A timestamp is None, or non-overlapping timestamps (start > end) - associations.add(Association((track, truth))) + associations.add(Association(OrderedSet((track, truth)))) return AssociationSet(associations)