From 0f2d7c863b3e93bfa70890d50801512f7106cdf5 Mon Sep 17 00:00:00 2001 From: Oliver Rosoman Date: Fri, 10 Jun 2022 14:42:10 +0100 Subject: [PATCH 01/26] Implement CompoundTimeRange class --- stonesoup/dataassociator/tracktotrack.py | 10 ++- stonesoup/types/association.py | 27 +++++- stonesoup/types/time.py | 110 +++++++++++++++++++++++ 3 files changed, 144 insertions(+), 3 deletions(-) diff --git a/stonesoup/dataassociator/tracktotrack.py b/stonesoup/dataassociator/tracktotrack.py index d6e4a622a..64454ac92 100644 --- a/stonesoup/dataassociator/tracktotrack.py +++ b/stonesoup/dataassociator/tracktotrack.py @@ -77,6 +77,11 @@ class TrackToTrackCounting(TrackToTrackAssociator): "position components compared to others (such as velocity). " "Default is 0.6" ) + one_to_one: bool = Property( + default=False, + doc="If True, the Hungarian Algorithm is applied to the results so that no track is " + "associated with more than one other at any given time step" + ) def associate_tracks(self, tracks_set_1: Set[Track], tracks_set_2: Set[Track]): """Associate two sets of tracks together. @@ -180,7 +185,10 @@ def associate_tracks(self, tracks_set_1: Set[Track], tracks_set_2: Set[Track]): (track1, track2), TimeRange(start_timestamp, end_timestamp))) - return AssociationSet(associations) + if self.one_to_one: + return AssociationSet(associations).association_deconflicter() + else: + return AssociationSet(associations) class TrackToTruth(TrackToTrackAssociator): diff --git a/stonesoup/types/association.py b/stonesoup/types/association.py index 40516e0d9..2bed3e39e 100644 --- a/stonesoup/types/association.py +++ b/stonesoup/types/association.py @@ -3,7 +3,7 @@ from ..base import Property from .base import Type -from .time import TimeRange +from .time import CompoundTimeRange class Association(Type): @@ -48,7 +48,7 @@ class TimeRangeAssociation(Association): range of times """ - time_range: TimeRange = Property( + time_range: CompoundTimeRange = Property( default=None, doc="Range of times that association exists over. Default is None") @@ -119,6 +119,29 @@ def associations_including_objects(self, objects): for object_ in objects if object_ in association.objects} + def get_key_times(self): + """Return all times at which an association from the set begins or ends + + This method will return an ordered list of the times at which an association in the set + begins or ends. Note that in the case of a :class:`~.CompoundTimeRange`, + there are potentially multiple start and end times, + and for :class:`~.SingleTimeAssociation` associations, the start and end time are the same + + Returns + ------- + : list of :class:`datetime.datetime` + A list of times at which an association starts or ends + """ + key_times = [] + for association in self.associations: + if isinstance(association, SingleTimeAssociation): + key_times.append(association.timestamp) + else: + key_times.extend(association.time_range.get_key_times()) + + return key_times + + def __contains__(self, item): return item in self.associations diff --git a/stonesoup/types/time.py b/stonesoup/types/time.py index 0c674eb2a..540dfd3ba 100644 --- a/stonesoup/types/time.py +++ b/stonesoup/types/time.py @@ -1,4 +1,5 @@ import datetime +from typing import Union from ..base import Property from .base import Type @@ -35,6 +36,11 @@ def duration(self): return self.end_timestamp - self.start_timestamp + @property + def key_times(self): + """Returns the start and end timestamp in a list""" + return [self.start_timestamp, self.end_timestamp] + def __contains__(self, timestamp): """Checks if timestamp is within range @@ -51,3 +57,107 @@ def __contains__(self, timestamp): """ return self.start_timestamp <= timestamp <= self.end_timestamp + + def overlap(self, time_range): + """Finds the intersection between this instance and another :class:`~.TimeRange` + + Parameters + ---------- + time_range: TimeRange + + Returns + ------- + TimeRange + The times contained by both this and time_range + """ + start_timestamp = max(self.start_timestamp, time_range.start_timestamp) + end_timestamp = min(self.end_timestamp, time_range.end_timestamp) + if end_timestamp > start_timestamp: + return TimeRange(start_timestamp, end_timestamp) + else: + return None + + +class CompoundTimeRange(Type): + """CompoundTimeRange type + + A container class representing one or more :class:`TimeRange` objects together + """ + time_ranges: list[TimeRange] = Property(doc="List of TimeRange objects", default=None) + + def __init__(self, time_ranges, *args, **kwargs): + super().__init__(*args, **kwargs) + if not time_ranges: + self.time_ranges = [] + self.check_overlap() + + def check_overlap(self): + """Returns a :class:`~.CompoundTimeRange` with overlap removed""" + overlap_check = CompoundTimeRange() + for time_range in self.time_ranges: + overlap_check.add(time_range.minus(overlap_check.overlap(time_range))) + self.time_ranges = overlap_check + + def add(self, time_range): + """Add a :class:`~.TimeRange` or :class:`~.CompoundTimeRange` object to `time_ranges`""" + if time_range is None: + return + if isinstance(time_range, CompoundTimeRange): + for component in time_range.time_ranges: + self.add(component) + else: + self.time_ranges.append(time_range) + self.check_overlap() + + def __contains__(self, time): + """Checks if timestamp or is within range + + Parameters + ---------- + time : Union[datetime.datetime, TimeRange, CompoundTimeRange] + Time stamp or range to check if contained within this instance + + Returns + ------- + bool + `True` if time is fully contained within this instance + """ + + if isinstance(time, datetime): + for component in self.time_ranges: + if datetime in component: + return True + return False + elif isinstance(time, TimeRange) or isinstance(time, CompoundTimeRange): + return True if self.overlap(time) == time else False + else: + raise TypeError("Supplied parameter must be an instance of either " + "datetime, TimeRange, or CompoundTimeRange") + + def overlap(self, time_range): + """Finds the intersection between this instance and another time range + + In the case of an input :class:`~.CompoundTimeRange` this is done recursively. + + Parameters + ---------- + time_range: Union[TimeRange, CompoundTimeRange] + + Returns + ------- + CompoundTimeRange + The times contained by both this and time_range + """ + total_overlap = CompoundTimeRange() + if isinstance(time_range, CompoundTimeRange): + for component in time_range.time_ranges: + total_overlap.add(self.overlap(component)) + return total_overlap + elif isinstance(time_range, TimeRange): + for component in self.time_ranges: + total_overlap.add(component.overlap(time_range)) + return total_overlap + else: + raise TypeError("Supplied parameter must be an instance of either " + "TimeRange, or CompoundTimeRange") + From 31d0de7ed5c2c7216be13190e2b36162b9b3f286 Mon Sep 17 00:00:00 2001 From: Oliver Rosoman Date: Fri, 10 Jun 2022 15:02:38 +0100 Subject: [PATCH 02/26] Implement key_times and duration into CompoundTimeRange --- stonesoup/types/time.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/stonesoup/types/time.py b/stonesoup/types/time.py index 540dfd3ba..5891f2a28 100644 --- a/stonesoup/types/time.py +++ b/stonesoup/types/time.py @@ -36,11 +36,6 @@ def duration(self): return self.end_timestamp - self.start_timestamp - @property - def key_times(self): - """Returns the start and end timestamp in a list""" - return [self.start_timestamp, self.end_timestamp] - def __contains__(self, timestamp): """Checks if timestamp is within range @@ -91,6 +86,22 @@ def __init__(self, time_ranges, *args, **kwargs): self.time_ranges = [] self.check_overlap() + @property + def duration(self): + """Duration of the time range""" + total_duration = 0 + for component in self.time_ranges: + total_duration += component.duration + + @property + def key_times(self): + """Returns all timestamps at which a component starts or ends""" + key_times = set() + for component in self.time_ranges: + key_times.add(component.start_timestamp) + key_times.add(component.end_timestamp) + return list(key_times).sort() + def check_overlap(self): """Returns a :class:`~.CompoundTimeRange` with overlap removed""" overlap_check = CompoundTimeRange() From 113f191e4b6ed72f1a5b2fa5eeab406934977528 Mon Sep 17 00:00:00 2001 From: Oliver Rosoman Date: Fri, 17 Jun 2022 13:27:10 +0100 Subject: [PATCH 03/26] Start the multidimensioanal deconfliction algorithm --- stonesoup/dataassociator/_assignment.py | 32 +++++++++++ stonesoup/types/association.py | 74 ++++++++++++++----------- stonesoup/types/time.py | 6 +- 3 files changed, 78 insertions(+), 34 deletions(-) diff --git a/stonesoup/dataassociator/_assignment.py b/stonesoup/dataassociator/_assignment.py index d21f4ce7e..f95c00ae8 100644 --- a/stonesoup/dataassociator/_assignment.py +++ b/stonesoup/dataassociator/_assignment.py @@ -317,3 +317,35 @@ def assign2DBasic(C): # SOFTWARE AND ANY RELATED MATERIALS, AND AGREES TO INDEMNIFY THE NAVAL # RESEARCH LABORATORY FOR ALL THIRD-PARTY CLAIMS RESULTING FROM THE ACTIONS # OF RECIPIENT IN THE USE OF THE SOFTWARE. + + +def multidimensional_deconfliction(association_set): + """Solves the Multidimensional Assignment Problem (MAP) + + The assignment problem becomes more complex when time is added as a dimension. + This basic solution finds all the conflicts in an association set and then creates a + matrix of sums of conflicts in seconds, which is then passed to assign2D to solve as a + simple 2D assignment problem. Therefore, each object will only ever be assigned to one other, + throughout the relevant time range. + + Parameters + ---------- + association_set: The :class:`AssociationSet` to de-conflict + + Returns + ------- + : :class:`AssociationSet` + The association set without contradictory associations + """ + objects = list(association_set.object_set) + length = len(objects) + conflict_totals = numpy.zeros((length,length)) + association_on = numpy.full((length, length), False) + key_times = association_set.key_times + for obj in objects: + obj_ass_set = association_set.associations_including_objects(obj) + for time in key_times: + time_ass_set = association_set.associations_at_timestamp + + + diff --git a/stonesoup/types/association.py b/stonesoup/types/association.py index 2bed3e39e..38c79675c 100644 --- a/stonesoup/types/association.py +++ b/stonesoup/types/association.py @@ -67,6 +67,38 @@ def __init__(self, associations=None, *args, **kwargs): if self.associations is None: self.associations = set() + @property + def key_times(self): + key_times = set(self.overall_time_range()) + for association in self.associations: + if isinstance(association, SingleTimeAssociation): + key_times.add(association.timestamp) + return list(key_times).order() + + @property + def overall_time_range(self): + """Return a :class:`~.CompoundTimeRange` of :class:`~.TimeRange` + objects in this instance. + + :class:`SingleTimeAssociation`s are discarded + """ + overall_range = CompoundTimeRange() + for association in self.associations: + if not isinstance(association, SingleTimeAssociation): + overall_range.add(association.time_range) + return overall_range + + @property + def object_set(self): + """Return all objects in the set + Returned as a set + """ + object_set = {} + for objects in self.associations.objects: + for obj in objects: + object_set.add(obj) + return object_set + def associations_at_timestamp(self, timestamp): """Return the associations that exist at a given timestamp @@ -80,7 +112,7 @@ def associations_at_timestamp(self, timestamp): Returns ------- - : set of :class:`~.Association` + : :class:`~.AssociationSet` Associations which occur at specified timestamp """ ret_associations = set() @@ -92,7 +124,7 @@ def associations_at_timestamp(self, timestamp): else: if timestamp in association.time_range: ret_associations.add(association) - return ret_associations + return AssociationSet(ret_associations) def associations_including_objects(self, objects): """Return associations that include all the given objects @@ -106,41 +138,21 @@ def associations_including_objects(self, objects): Set of objects to look for in associations Returns ------- - : set of :class:`~.Association` + : class:`~.AssociationSet` A set of objects which have been associated """ # Ensure objects is iterable if not isinstance(objects, list) and not isinstance(objects, set): objects = {objects} - - return {association - for association in self.associations - for object_ in objects - if object_ in association.objects} - - def get_key_times(self): - """Return all times at which an association from the set begins or ends - - This method will return an ordered list of the times at which an association in the set - begins or ends. Note that in the case of a :class:`~.CompoundTimeRange`, - there are potentially multiple start and end times, - and for :class:`~.SingleTimeAssociation` associations, the start and end time are the same - - Returns - ------- - : list of :class:`datetime.datetime` - A list of times at which an association starts or ends - """ - key_times = [] - for association in self.associations: - if isinstance(association, SingleTimeAssociation): - key_times.append(association.timestamp) - else: - key_times.extend(association.time_range.get_key_times()) - - return key_times - + print(type(objects)) + print(objects) + print(type(association for association in self.associations)) + + return AssociationSet({association + for association in self.associations + for object_ in objects + if object_ in association.objects}) def __contains__(self, item): return item in self.associations diff --git a/stonesoup/types/time.py b/stonesoup/types/time.py index 5891f2a28..6a5a860ec 100644 --- a/stonesoup/types/time.py +++ b/stonesoup/types/time.py @@ -84,7 +84,7 @@ def __init__(self, time_ranges, *args, **kwargs): super().__init__(*args, **kwargs) if not time_ranges: self.time_ranges = [] - self.check_overlap() + self.remove_overlap() @property def duration(self): @@ -102,7 +102,7 @@ def key_times(self): key_times.add(component.end_timestamp) return list(key_times).sort() - def check_overlap(self): + def remove_overlap(self): """Returns a :class:`~.CompoundTimeRange` with overlap removed""" overlap_check = CompoundTimeRange() for time_range in self.time_ranges: @@ -118,7 +118,7 @@ def add(self, time_range): self.add(component) else: self.time_ranges.append(time_range) - self.check_overlap() + self.remove_overlap() def __contains__(self, time): """Checks if timestamp or is within range From a83e8c67bd7154281cb7a6682aa933c328b75dad Mon Sep 17 00:00:00 2001 From: Oliver Rosoman <95758965+orosoman-dstl@users.noreply.github.com> Date: Mon, 20 Jun 2022 18:00:57 +0100 Subject: [PATCH 04/26] Finished multidimensional_assignment --- stonesoup/dataassociator/_assignment.py | 91 +++++++++++++++++++++++-- stonesoup/types/association.py | 22 ++++++ stonesoup/types/time.py | 61 +++++++++++++++++ 3 files changed, 167 insertions(+), 7 deletions(-) diff --git a/stonesoup/dataassociator/_assignment.py b/stonesoup/dataassociator/_assignment.py index f95c00ae8..f48c9f25e 100644 --- a/stonesoup/dataassociator/_assignment.py +++ b/stonesoup/dataassociator/_assignment.py @@ -1,4 +1,6 @@ import numpy +import datetime +from ..types.association import AssociationSet def assign2D(C, maximize=False): @@ -325,8 +327,13 @@ def multidimensional_deconfliction(association_set): The assignment problem becomes more complex when time is added as a dimension. This basic solution finds all the conflicts in an association set and then creates a matrix of sums of conflicts in seconds, which is then passed to assign2D to solve as a - simple 2D assignment problem. Therefore, each object will only ever be assigned to one other, - throughout the relevant time range. + simple 2D assignment problem. Therefore, each object will only ever be assigned to one other + at any one time. In the case of an association that only partially overlaps, the "weaker" one + (the one eliminated by assign2D) will be trimmed until there is no conflict. + + Due to the possibility of more than two conflicting associations at the same time, + this algorithm is recursive, but it is not expected many (if any) recursions will be required + for most uses. Parameters ---------- @@ -339,13 +346,83 @@ def multidimensional_deconfliction(association_set): """ objects = list(association_set.object_set) length = len(objects) - conflict_totals = numpy.zeros((length,length)) + totals = numpy.zeros((length, length)) # Time objects i and j are associated for in total association_on = numpy.full((length, length), False) + association_start = numpy.full(length, length, datetime.datetime.min) key_times = association_set.key_times - for obj in objects: - obj_ass_set = association_set.associations_including_objects(obj) - for time in key_times: - time_ass_set = association_set.associations_at_timestamp + for time in key_times: + associations_to_end = [] + for i in range(length): + for j in range(length): + if association_on[i][j]: + associations_to_end.append({i, j}) + time_ass_set = association_set.associations_at_timestamp + for association in time_ass_set: + obj_indices = [objects.index(list(association.objects)[0]), + objects.index(list(association.objects)[1])] + if len(association.objects) != 2: + raise ValueError("Supplied set must only contain pairs of associated objects") + if not association_on[obj_indices[0], obj_indices[1]]: + association_on[obj_indices[0], obj_indices[1]] = True + association_start[obj_indices[0], obj_indices[1]] = time + else: + associations_to_end.remove({obj_indices[0], obj_indices[1]}) + for indices in associations_to_end: + association_on[indices[0], indices[1]] = False + totals[indices[0], indices[1]] = totals[indices[0, indices[1]]] + \ + time - association_start[indices[0], indices[1]] + association_start = make_symmetric(association_start) + totals = make_symmetric(totals) + association_on = make_symmetric(association_on) + + solved_2d = assign2D(totals, maximize=True)[1] + winning_indices = [] # Pairs that are chosen by assign2D + for i in range(length): + winning_indices.append({i, solved_2d[i]}) + + cleaned_set = AssociationSet() + for winner in winning_indices: + assoc = association_set.associations_including_objects({objects[list(winner)[0]], + objects[list(winner)[1]]}) + cleaned_set.add(assoc) + association_set.remove(assoc) + + for assoc1 in association_set: + for assoc2 in association_set: + if conflicts(assoc1, assoc2): + association_set = multidimensional_deconfliction(association_set) + + # At this point, none of association_set should conflict with one another + for association in association_set: + for winner in cleaned_set: + if conflicts(association, winner): + association.time_range.minus(winner.time_range) + if association.time_range is not None: + cleaned_set.add(association) + + return cleaned_set + + + + + + + + + +def conflicts(assoc1, assoc2): + if not hasattr(assoc1, time_range) or not hasattr(assoc2, time_range): + raise TypeError("Associations must have a time_range property") + if assoc1.time_range.overlap(assoc2.time_range) and assoc1 != assoc2 \ + and len(assoc1.objects.intersection(assoc2.objects)) > 0: + return True + else: + return False + + +def make_symmetric(matrix): + return numpy.tril(matrix) + numpy.triu(matrix.T, 1) + diff --git a/stonesoup/types/association.py b/stonesoup/types/association.py index 38c79675c..3148d88d7 100644 --- a/stonesoup/types/association.py +++ b/stonesoup/types/association.py @@ -67,6 +67,28 @@ def __init__(self, associations=None, *args, **kwargs): if self.associations is None: self.associations = set() + def add(self, association): + if association is None: + return + elif isinstance(association, Association): + self.associations.add(association) + elif isinstance(association, AssociationSet): + for component in association: + self.add(component) + else: + raise TypeError("Supplied parameter must be an Association or AssociationSet") + + def remove(self, association): + if association is None: + return + elif isinstance(association, Association): + self.associations.remove(association) + elif isinstance(association, AssociationSet): + for component in association: + self.remove(component) + else: + raise TypeError("Supplied parameter must be an Association or AssociationSet") + @property def key_times(self): key_times = set(self.overall_time_range()) diff --git a/stonesoup/types/time.py b/stonesoup/types/time.py index 6a5a860ec..5e1a7ca35 100644 --- a/stonesoup/types/time.py +++ b/stonesoup/types/time.py @@ -53,6 +53,45 @@ def __contains__(self, timestamp): return self.start_timestamp <= timestamp <= self.end_timestamp + def minus(self, time_range): + """Removes the overlap between this instance and another :class:`~.TimeRange`, or + :class:`~.CompoundTimeRange`. + + Parameters + ---------- + time_range: Union[TimeRange, CompoundTimeRange] + + Returns + ------- + TimeRange + This instance less the overlap + """ + if isinstance(time_range, CompoundTimeRange): + ans = self + for t_range in time_range: + ans = ans.minus(t_range) + if not ans: + return None + return ans + else: + overlap = self.overlap(time_range) + if self == overlap: + return None + if self.start_timestamp < overlap.start_timestamp: + start = self.start_timestamp + else: + start = overlap.end_timestamp + if self.end_timestamp > overlap.end_timestamp: + end = self.end_timestamp + else: + end = overlap.start_timestamp + if self.start_timestamp < overlap.start_timestamp and \ + self.end_timestamp > overlap.end_timestamp: + return CompoundTimeRange(TimeRange(self.start_timestamp, overlap.start_timestamp), + TimeRange(self.end_timestamp, overlap.end_timestamp)) + else: + return TimeRange(start, end) + def overlap(self, time_range): """Finds the intersection between this instance and another :class:`~.TimeRange` @@ -145,6 +184,24 @@ def __contains__(self, time): raise TypeError("Supplied parameter must be an instance of either " "datetime, TimeRange, or CompoundTimeRange") + def minus(self, time_range): + """Removes any overlap between this and another :class:`~.TimeRange` or + :class:`.~CompoundTimeRange` + + Parameters + ---------- + time_range: Union[TimeRange, CompoundTimeRange] + + Returns + ------- + CompoundTimeRange + The times contained by both this and time_range + """ + ans = CompoundTimeRange() + for component in self.time_ranges: + ans.add(component.minus(time_range)) + return ans + def overlap(self, time_range): """Finds the intersection between this instance and another time range @@ -163,10 +220,14 @@ def overlap(self, time_range): if isinstance(time_range, CompoundTimeRange): for component in time_range.time_ranges: total_overlap.add(self.overlap(component)) + if total_overlap == CompoundTimeRange(): + return None return total_overlap elif isinstance(time_range, TimeRange): for component in self.time_ranges: total_overlap.add(component.overlap(time_range)) + if total_overlap == CompoundTimeRange(): + return None return total_overlap else: raise TypeError("Supplied parameter must be an instance of either " From d0fb2efc37ae1062d881a9053b5e78aeb483619d Mon Sep 17 00:00:00 2001 From: Oliver Rosoman <95758965+orosoman-dstl@users.noreply.github.com> Date: Tue, 21 Jun 2022 17:47:55 +0100 Subject: [PATCH 05/26] tests for time begun --- stonesoup/types/tests/test_time.py | 63 +++++++++++++++++++++--------- stonesoup/types/time.py | 22 ++++++++--- 2 files changed, 62 insertions(+), 23 deletions(-) diff --git a/stonesoup/types/tests/test_time.py b/stonesoup/types/tests/test_time.py index 9630f977b..d911f11a2 100644 --- a/stonesoup/types/tests/test_time.py +++ b/stonesoup/types/tests/test_time.py @@ -2,50 +2,73 @@ import pytest -from ..time import TimeRange - - -def test_timerange(): - # Test time range initialisation +from ..time import TimeRange, CompoundTimeRange +@pytest.fixture +def times(): + before = datetime.datetime(year=2018, month=3, day=1, hour=3, minute=10, second=3) t1 = datetime.datetime(year=2018, month=3, day=1, hour=5, minute=3, second=35, microsecond=500) + inside = datetime.datetime(year=2018, month=3, day=1, hour=5, minute=10, second=3) t2 = datetime.datetime(year=2018, month=3, day=1, hour=6, minute=5, second=41, microsecond=500) + after = datetime.datetime(year=2019, month=3, day=1, hour=6, minute=5, + second=41, microsecond=500) + return [before, t1, inside, t2, after] + + +def test_timerange(times): # Test creating without times with pytest.raises(TypeError): TimeRange() # Without start time with pytest.raises(TypeError): - TimeRange(start_timestamp=t1) + TimeRange(start_timestamp=times[1]) # Without end time with pytest.raises(TypeError): - TimeRange(end_timestamp=t2) + TimeRange(end_timestamp=times[3]) # Test an error is caught when end is after start with pytest.raises(ValueError): - TimeRange(start_timestamp=t2, end_timestamp=t1) + TimeRange(start_timestamp=times[3], end_timestamp=times[1]) + + # Test with wrong types for time_ranges + with pytest.raises(TypeError): + CompoundTimeRange(42) + with pytest.raises(TypeError): + CompoundTimeRange([times[1], times[3]]) + + test_range = test_range = TimeRange(start_timestamp=times[1], end_timestamp=times[3]) + + test_compound = CompoundTimeRange() - test_range = test_range = TimeRange(start_timestamp=t1, end_timestamp=t2) + test_compound2 = CompoundTimeRange([test_range]) - assert test_range.start_timestamp == t1 - assert test_range.end_timestamp == t2 + assert test_range.start_timestamp == times[1] + assert test_range.end_timestamp == times[3] + assert len(test_compound.time_ranges) == 0 + assert test_compound2.time_ranges[0] == test_range -def test_duration(): +def test_duration(times): # Test that duration is calculated properly - t1 = datetime.datetime(year=2018, month=3, day=1, hour=5, minute=3, - second=35, microsecond=500) - t2 = datetime.datetime(year=2018, month=3, day=1, hour=6, minute=5, - second=41, microsecond=500) - test_range = TimeRange(start_timestamp=t1, end_timestamp=t2) + + # TimeRange + test_range = TimeRange(start_timestamp=times[1], end_timestamp=times[3]) + + # CompoundTimeRange + + # times[2] is inside of [1] and [3], so should be equivalent to a TimeRange(times[1], times[4]) + test_range2 = CompoundTimeRange(TimeRange(start_timestamp=times[1], end_timestamp=times[3]), + TimeRange(start_timestamp=times[2], end_timestamp=times[4])) assert test_range.duration == datetime.timedelta(seconds=3726) + assert test_range.duration == datetime.timedelta(seconds=31539726) -def test_contains(): +def test_timerange_contains(): # Test that timestamps are correctly determined to be in the range t1 = datetime.datetime(year=2018, month=3, day=1, hour=5, minute=3, @@ -68,3 +91,7 @@ def test_contains(): # Upper edge assert t2 in test_range + +def test_timerange_minus(): + # Test the minus function + test1 = TimeRange() diff --git a/stonesoup/types/time.py b/stonesoup/types/time.py index 5e1a7ca35..fb6a4c01e 100644 --- a/stonesoup/types/time.py +++ b/stonesoup/types/time.py @@ -1,5 +1,6 @@ import datetime from typing import Union +from itertools import combinations from ..base import Property from .base import Type @@ -68,7 +69,7 @@ def minus(self, time_range): """ if isinstance(time_range, CompoundTimeRange): ans = self - for t_range in time_range: + for t_range in time_range.time_ranges: ans = ans.minus(t_range) if not ans: return None @@ -117,12 +118,18 @@ class CompoundTimeRange(Type): A container class representing one or more :class:`TimeRange` objects together """ - time_ranges: list[TimeRange] = Property(doc="List of TimeRange objects", default=None) + time_ranges: list[TimeRange] = Property(doc="List of TimeRange objects. Can be empty", + default=None) - def __init__(self, time_ranges, *args, **kwargs): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - if not time_ranges: + if self.time_ranges is None: self.time_ranges = [] + if not isinstance(self.time_ranges, list): + raise TypeError("Time_ranges must be a list") + for component in self.time_ranges: + if not isinstance(component, TimeRange): + raise TypeError("Time_ranges must contain only TimeRange objects") self.remove_overlap() @property @@ -142,7 +149,12 @@ def key_times(self): return list(key_times).sort() def remove_overlap(self): - """Returns a :class:`~.CompoundTimeRange` with overlap removed""" + """Removes overlap between components of time_ranges""" + if len(self.time_ranges) == 0: + return + if all([component.overlap(component2) is None for (component, component2) in + combinations(self.time_ranges, 2)]): + return overlap_check = CompoundTimeRange() for time_range in self.time_ranges: overlap_check.add(time_range.minus(overlap_check.overlap(time_range))) From 20c83019b8a5d806118ba57332eeffaf2b5a010c Mon Sep 17 00:00:00 2001 From: Oliver Rosoman <95758965+orosoman-dstl@users.noreply.github.com> Date: Wed, 22 Jun 2022 13:19:40 +0100 Subject: [PATCH 06/26] time tests complete --- stonesoup/types/tests/test_time.py | 214 +++++++++++++++++++++++++---- stonesoup/types/time.py | 93 +++++++++++-- 2 files changed, 272 insertions(+), 35 deletions(-) diff --git a/stonesoup/types/tests/test_time.py b/stonesoup/types/tests/test_time.py index d911f11a2..7b93ec1d3 100644 --- a/stonesoup/types/tests/test_time.py +++ b/stonesoup/types/tests/test_time.py @@ -6,6 +6,7 @@ @pytest.fixture def times(): + # Note times are returned chronologically for ease of reading before = datetime.datetime(year=2018, month=3, day=1, hour=3, minute=10, second=3) t1 = datetime.datetime(year=2018, month=3, day=1, hour=5, minute=3, second=35, microsecond=500) @@ -14,7 +15,9 @@ def times(): second=41, microsecond=500) after = datetime.datetime(year=2019, month=3, day=1, hour=6, minute=5, second=41, microsecond=500) - return [before, t1, inside, t2, after] + long_after = datetime.datetime(year=2022, month=6, day=1, hour=6, minute=5, + second=41, microsecond=500) + return [before, t1, inside, t2, after, long_after] def test_timerange(times): @@ -40,16 +43,20 @@ def test_timerange(times): with pytest.raises(TypeError): CompoundTimeRange([times[1], times[3]]) - test_range = test_range = TimeRange(start_timestamp=times[1], end_timestamp=times[3]) + test_range = TimeRange(start_timestamp=times[1], end_timestamp=times[3]) test_compound = CompoundTimeRange() test_compound2 = CompoundTimeRange([test_range]) + # Tests fuse_components method + fuse_test = CompoundTimeRange([test_range, TimeRange(times[3], times[4])]) + assert test_range.start_timestamp == times[1] assert test_range.end_timestamp == times[3] assert len(test_compound.time_ranges) == 0 assert test_compound2.time_ranges[0] == test_range + assert fuse_test.time_ranges == [TimeRange(times[1], times[4])] def test_duration(times): @@ -61,37 +68,198 @@ def test_duration(times): # CompoundTimeRange # times[2] is inside of [1] and [3], so should be equivalent to a TimeRange(times[1], times[4]) - test_range2 = CompoundTimeRange(TimeRange(start_timestamp=times[1], end_timestamp=times[3]), - TimeRange(start_timestamp=times[2], end_timestamp=times[4])) + test_range2 = CompoundTimeRange([TimeRange(start_timestamp=times[1], end_timestamp=times[3]), + TimeRange(start_timestamp=times[2], end_timestamp=times[4])]) assert test_range.duration == datetime.timedelta(seconds=3726) - assert test_range.duration == datetime.timedelta(seconds=31539726) + assert test_range2.duration == datetime.timedelta(seconds=31539726) -def test_timerange_contains(): +def test_contains(times): # Test that timestamps are correctly determined to be in the range - t1 = datetime.datetime(year=2018, month=3, day=1, hour=5, minute=3, - second=35, microsecond=500) - t2 = datetime.datetime(year=2018, month=3, day=1, hour=6, minute=5, - second=41, microsecond=500) - test_range = TimeRange(start_timestamp=t1, end_timestamp=t2) + test_range = TimeRange(start_timestamp=times[1], end_timestamp=times[3]) + test2 = TimeRange(times[1], times[2]) + test3 = TimeRange(times[1], times[4]) - # Inside - assert datetime.datetime( - year=2018, month=3, day=1, hour=5, minute=10, second=3) in test_range + with pytest.raises(TypeError): + 16 in test3 + + assert times[2] in test_range + assert not times[4] in test_range + assert not times[0] in test_range + assert times[1] in test_range + assert times[3] in test_range + + assert test2 in test_range + assert test_range not in test2 + assert test2 in test2 + assert test3 not in test_range - # Outside + # CompoundTimeRange + + compound_test = CompoundTimeRange([test_range]) + # Should be in neither + test_range2 = TimeRange(times[4], times[5]) + # Should be in compound_range2 but not 1 + test_range3 = TimeRange(times[2], times[4]) + compound_test2 = CompoundTimeRange([test_range, TimeRange(times[3], times[4])]) - assert not datetime.datetime( - year=2018, month=3, day=1, hour=3, minute=10, second=3) in test_range + assert compound_test in compound_test2 + assert times[2] in compound_test + assert times[2] in compound_test2 + assert test_range2 not in compound_test + assert test_range2 not in compound_test2 + assert test_range3 not in compound_test + assert test_range3 in compound_test2 - # Lower edge - assert t1 in test_range - # Upper edge - assert t2 in test_range +def test_equality(times): + test1 = TimeRange(times[1], times[2]) + test2 = TimeRange(times[1], times[2]) + test3 = TimeRange(times[1], times[3]) -def test_timerange_minus(): + with pytest.raises(TypeError): + test1 == "stonesoup" + + assert test1 == test2 + assert test2 == test1 + assert test3 != test1 and test1 != test3 + + ctest1 = CompoundTimeRange([test1, test3]) + ctest2 = CompoundTimeRange([TimeRange(times[1], times[3])]) + + with pytest.raises(TypeError): + ctest2 == "Stonesoup is the best!" + + assert ctest1 == ctest2 + assert ctest2 == ctest1 + ctest2.add(TimeRange(times[3], times[4])) + assert ctest1 != ctest2 + assert ctest2 != ctest1 + + assert CompoundTimeRange() == CompoundTimeRange() + assert ctest1 != CompoundTimeRange() + assert CompoundTimeRange() != ctest1 + + +def test_minus(times): # Test the minus function - test1 = TimeRange() + test1 = TimeRange(times[1], times[3]) + test2 = TimeRange(times[1], times[2]) + test3 = TimeRange(times[2], times[3]) + test4 = TimeRange(times[4], times[5]) + + with pytest.raises(TypeError): + test1.minus(15) + + assert test1.minus(test2) == test3 + assert test1.minus(None) == test1 + assert test2.minus(test1) is None + + ctest1 = CompoundTimeRange([test2, test4]) + ctest2 = CompoundTimeRange([test1, test2]) + ctest3 = CompoundTimeRange([test4]) + + with pytest.raises(TypeError): + ctest1.minus(15) + + assert ctest1.minus(ctest2) == ctest3 + assert ctest1.minus(ctest1) == CompoundTimeRange() + assert ctest3.minus(ctest1) == CompoundTimeRange() + + assert test1.minus(ctest1) == TimeRange(times[2], times[3]) + assert test4.minus(ctest2) == test4 + assert ctest1.minus(test2) == ctest3 + + +def test_overlap(times): + test1 = TimeRange(times[1], times[3]) + test2 = TimeRange(times[1], times[2]) + test3 = TimeRange(times[4], times[5]) + + ctest1 = CompoundTimeRange([test2, test3]) + ctest2 = CompoundTimeRange([test1, test2]) + + with pytest.raises(TypeError): + test2.overlap(ctest1) + + assert test1.overlap(test1) == test1 + assert test1.overlap(None) is None + assert test1.overlap(test2) == test2 + assert test2.overlap(test1) == test2 + assert ctest1.overlap(None) is None + assert ctest1.overlap(test2) == CompoundTimeRange([test2]) + assert ctest1.overlap(ctest2) == CompoundTimeRange([test2]) + assert ctest1.overlap(ctest2) == ctest2.overlap(ctest1) + + +def test_key_times(times): + test1 = CompoundTimeRange([TimeRange(times[0], times[1]), + TimeRange(times[3], times[4])]) + test2 = CompoundTimeRange([TimeRange(times[3], times[4]), + TimeRange(times[0], times[1])]) + test3 = CompoundTimeRange() + test4 = CompoundTimeRange([TimeRange(times[0], times[4])]) + + assert test1.key_times == [times[0], times[1], times[3], times[4]] + assert test2.key_times == [times[0], times[1], times[3], times[4]] + assert test3.key_times == [] + assert test4.key_times == [times[0], times[4]] + + +def test_remove_overlap(times): + test1_ro = CompoundTimeRange([TimeRange(times[0], times[1]), + TimeRange(times[3], times[4])]) + test1_ro.remove_overlap() + test2_ro = CompoundTimeRange([TimeRange(times[3], times[4]), + TimeRange(times[0], times[4])]) + test2_ro.remove_overlap() + test3_ro = CompoundTimeRange() + test3_ro.remove_overlap() + + test1 = CompoundTimeRange([TimeRange(times[0], times[1]), + TimeRange(times[3], times[4])]) + test3 = CompoundTimeRange() + test4 = CompoundTimeRange([TimeRange(times[0], times[4])]) + + assert test1_ro == test1 + assert test2_ro == test4 + assert test3_ro == test3 + + +def test_fuse_components(times): + # Note this is called inside the __init__ method, but is tested here explicitly + test1 = CompoundTimeRange([TimeRange(times[1], times[2])]).fuse_components() + test2 = CompoundTimeRange([TimeRange(times[1], times[2]), + TimeRange(times[2], times[4])]).fuse_components() + assert test1.time_ranges == {TimeRange(times[1], times[2])} + assert test2.time_ranges == {TimeRange(times[1], times[4])} + +def test_add(times): + test1 = CompoundTimeRange([TimeRange(times[1], times[2])]) + test2 = CompoundTimeRange([TimeRange(times[0], times[1])]) + test3 = CompoundTimeRange([TimeRange(times[0], times[2])]) + test4 = CompoundTimeRange([TimeRange(times[0], times[2]), + TimeRange(times[4], times[5])]) + with pytest.raises(TypeError): + test1.add(True) + assert test1 != test2 + test1_copy = test1 + test1.add(None) + assert test1 == test1_copy + test1.add(test2) + assert test1 == test3 + test3.add(TimeRange(times[4], times[5])) + assert test3 == test4 + +def test_remove(times): + test1 = CompoundTimeRange([TimeRange(times[0], times[2]), + TimeRange(times[4], times[5])]) + with pytest.raises(TypeError): + test1.remove(0.4) + with pytest.raises(ValueError): + test1.remove(TimeRange(times[0], times[1])) + + test1.remove(TimeRange(times[0], times[2])) + assert test1 == CompoundTimeRange([TimeRange(times[4], times[5])]) diff --git a/stonesoup/types/time.py b/stonesoup/types/time.py index fb6a4c01e..89db658ea 100644 --- a/stonesoup/types/time.py +++ b/stonesoup/types/time.py @@ -1,6 +1,6 @@ import datetime from typing import Union -from itertools import combinations +from itertools import combinations, permutations from ..base import Property from .base import Type @@ -37,13 +37,13 @@ def duration(self): return self.end_timestamp - self.start_timestamp - def __contains__(self, timestamp): + def __contains__(self, time): """Checks if timestamp is within range Parameters ---------- - timestamp : datetime.datetime - Time stamp to check if within range + time : Union[datetime.datetime, TimeRange] + Time stamp or range to check if within range Returns ------- @@ -51,8 +51,24 @@ def __contains__(self, timestamp): `True` if timestamp within :attr:`start_timestamp` and :attr:`end_timestamp` (inclusive) """ + if isinstance(time, datetime.datetime): + return self.start_timestamp <= time <= self.end_timestamp + elif isinstance(time, TimeRange): + return self.start_timestamp <= time.start_timestamp and \ + self.end_timestamp >= time.end_timestamp + else: + raise TypeError("Supplied parameter must be a datetime.datetime or TimeRange object") - return self.start_timestamp <= timestamp <= self.end_timestamp + def __eq__(self, other): + if other is None: + return False + if not isinstance(other, TimeRange): + raise TypeError(f"Cannot compare between a CompoundTimeRange and {type(other)}") + if self.start_timestamp == other.start_timestamp and \ + self.end_timestamp == other.end_timestamp: + return True + else: + return False def minus(self, time_range): """Removes the overlap between this instance and another :class:`~.TimeRange`, or @@ -67,6 +83,10 @@ def minus(self, time_range): TimeRange This instance less the overlap """ + if time_range is None: + return self + if not isinstance(time_range, TimeRange) and not isinstance(time_range, CompoundTimeRange): + raise TypeError("Supplied parameter must be a TimeRange or CompoundTimeRange object") if isinstance(time_range, CompoundTimeRange): ans = self for t_range in time_range.time_ranges: @@ -76,6 +96,8 @@ def minus(self, time_range): return ans else: overlap = self.overlap(time_range) + if overlap is None: + return self if self == overlap: return None if self.start_timestamp < overlap.start_timestamp: @@ -105,6 +127,10 @@ def overlap(self, time_range): TimeRange The times contained by both this and time_range """ + if time_range is None: + return None + if not isinstance(time_range, TimeRange): + raise TypeError("Supplied parameter must be a TimeRange object") start_timestamp = max(self.start_timestamp, time_range.start_timestamp) end_timestamp = min(self.end_timestamp, time_range.end_timestamp) if end_timestamp > start_timestamp: @@ -131,13 +157,17 @@ def __init__(self, *args, **kwargs): if not isinstance(component, TimeRange): raise TypeError("Time_ranges must contain only TimeRange objects") self.remove_overlap() + self.fuse_components() @property def duration(self): """Duration of the time range""" - total_duration = 0 + if len(self.time_ranges) == 0: + return datetime.timedelta(0) + total_duration = datetime.timedelta(0) for component in self.time_ranges: - total_duration += component.duration + total_duration = total_duration + component.duration + return total_duration @property def key_times(self): @@ -146,7 +176,7 @@ def key_times(self): for component in self.time_ranges: key_times.add(component.start_timestamp) key_times.add(component.end_timestamp) - return list(key_times).sort() + return sorted(list(key_times)) def remove_overlap(self): """Removes overlap between components of time_ranges""" @@ -158,7 +188,16 @@ def remove_overlap(self): overlap_check = CompoundTimeRange() for time_range in self.time_ranges: overlap_check.add(time_range.minus(overlap_check.overlap(time_range))) - self.time_ranges = overlap_check + self.time_ranges = overlap_check.time_ranges + + def fuse_components(self): + """Fuses two time ranges [a,b], [b,c] into [a,c] for all such pairs in this instance""" + for (component, component2) in permutations(self.time_ranges, 2): + if component.end_timestamp == component2.start_timestamp: + fused_component = TimeRange(component.start_timestamp, component2.end_timestamp) + self.remove(component) + self.remove(component2) + self.add(fused_component) def add(self, time_range): """Add a :class:`~.TimeRange` or :class:`~.CompoundTimeRange` object to `time_ranges`""" @@ -170,6 +209,17 @@ def add(self, time_range): else: self.time_ranges.append(time_range) self.remove_overlap() + self.fuse_components() + + def remove(self, time_range): + """Remove a :class:`.~TimeRange` object from the time ranges. + It must be a member of self.time_ranges""" + if not isinstance(time_range, TimeRange): + raise TypeError("Supplied parameter must be a TimeRange object") + if time_range in self.time_ranges: + self.time_ranges.remove(time_range) + else: + raise ValueError("Supplied parameter must be a member of time_ranges") def __contains__(self, time): """Checks if timestamp or is within range @@ -185,20 +235,35 @@ def __contains__(self, time): `True` if time is fully contained within this instance """ - if isinstance(time, datetime): + if isinstance(time, datetime.datetime): for component in self.time_ranges: if datetime in component: return True return False elif isinstance(time, TimeRange) or isinstance(time, CompoundTimeRange): + print(self.overlap(time)) + print(time) + print(self.overlap(time) == time) return True if self.overlap(time) == time else False else: raise TypeError("Supplied parameter must be an instance of either " "datetime, TimeRange, or CompoundTimeRange") + def __eq__(self, other): + if other is None: + return False + if not isinstance(other, CompoundTimeRange): + raise TypeError(f"Cannot compare between a CompoundTimeRange and {type(other)}") + if len(self.time_ranges) == 0: + return True if len(other.time_ranges) == 0 else False + for component in self.time_ranges: + if all([component != component2 for component2 in other.time_ranges]): + return False + return True + def minus(self, time_range): """Removes any overlap between this and another :class:`~.TimeRange` or - :class:`.~CompoundTimeRange` + :class:`.~CompoundTimeRange` from this instance Parameters ---------- @@ -207,8 +272,10 @@ def minus(self, time_range): Returns ------- CompoundTimeRange - The times contained by both this and time_range + The times contained by this but not time_range. May be empty. """ + if time_range is None: + return self ans = CompoundTimeRange() for component in self.time_ranges: ans.add(component.minus(time_range)) @@ -228,6 +295,8 @@ def overlap(self, time_range): CompoundTimeRange The times contained by both this and time_range """ + if time_range is None: + return None total_overlap = CompoundTimeRange() if isinstance(time_range, CompoundTimeRange): for component in time_range.time_ranges: From a644c3ee40c8ed8da4c9b8def61bf6ffa58cf451 Mon Sep 17 00:00:00 2001 From: Oliver Rosoman <95758965+orosoman-dstl@users.noreply.github.com> Date: Fri, 24 Jun 2022 17:03:05 +0100 Subject: [PATCH 07/26] bug fixes to multidimensional_assignment --- stonesoup/dataassociator/_assignment.py | 75 +++++++++------- .../dataassociator/tests/test_assignment.py | 33 +++++++ stonesoup/types/association.py | 59 ++++++++++--- stonesoup/types/tests/test_association.py | 86 ++++++++++++++++--- stonesoup/types/tests/test_detection.py | 1 + stonesoup/types/tests/test_time.py | 10 +-- stonesoup/types/time.py | 20 ++--- 7 files changed, 215 insertions(+), 69 deletions(-) create mode 100644 stonesoup/dataassociator/tests/test_assignment.py diff --git a/stonesoup/dataassociator/_assignment.py b/stonesoup/dataassociator/_assignment.py index f48c9f25e..bffbcda09 100644 --- a/stonesoup/dataassociator/_assignment.py +++ b/stonesoup/dataassociator/_assignment.py @@ -1,6 +1,7 @@ import numpy import datetime from ..types.association import AssociationSet +import warnings def assign2D(C, maximize=False): @@ -321,7 +322,7 @@ def assign2DBasic(C): # OF RECIPIENT IN THE USE OF THE SOFTWARE. -def multidimensional_deconfliction(association_set): +def multidimensional_deconfliction(association_set, low_diff_warning=None): """Solves the Multidimensional Assignment Problem (MAP) The assignment problem becomes more complex when time is added as a dimension. @@ -338,6 +339,10 @@ def multidimensional_deconfliction(association_set): Parameters ---------- association_set: The :class:`AssociationSet` to de-conflict + low_diff_warning: If the longest association between objects minus the shortest (in seconds) is + less than this, a warning will be given. This may occur if the range of times covered includes + only a low number of seconds. + Returns ------- @@ -346,9 +351,11 @@ def multidimensional_deconfliction(association_set): """ objects = list(association_set.object_set) length = len(objects) + if length <= 1: + return association_set totals = numpy.zeros((length, length)) # Time objects i and j are associated for in total association_on = numpy.full((length, length), False) - association_start = numpy.full(length, length, datetime.datetime.min) + association_start = numpy.full((length, length), datetime.datetime.min) key_times = association_set.key_times for time in key_times: associations_to_end = [] @@ -356,37 +363,44 @@ def multidimensional_deconfliction(association_set): for j in range(length): if association_on[i][j]: associations_to_end.append({i, j}) - time_ass_set = association_set.associations_at_timestamp + time_ass_set = association_set.associations_at_timestamp(time) for association in time_ass_set: - obj_indices = [objects.index(list(association.objects)[0]), - objects.index(list(association.objects)[1])] if len(association.objects) != 2: raise ValueError("Supplied set must only contain pairs of associated objects") + obj_indices = [objects.index(list(association.objects)[0]), + objects.index(list(association.objects)[1])] if not association_on[obj_indices[0], obj_indices[1]]: association_on[obj_indices[0], obj_indices[1]] = True association_start[obj_indices[0], obj_indices[1]] = time - else: + elif time != max(association.time_range.key_times): associations_to_end.remove({obj_indices[0], obj_indices[1]}) - for indices in associations_to_end: - association_on[indices[0], indices[1]] = False - totals[indices[0], indices[1]] = totals[indices[0, indices[1]]] + \ - time - association_start[indices[0], indices[1]] - association_start = make_symmetric(association_start) - totals = make_symmetric(totals) - association_on = make_symmetric(association_on) + for inds in associations_to_end: + print("loop") + association_on[inds[0], inds[1]] = False + totals[inds[0], inds[1]] += (time - + association_start[inds[0], inds[1]]).total_seconds() + association_start = _make_symmetric(association_start) + totals = _make_symmetric(totals) + association_on = _make_symmetric(association_on) + + totals = numpy.rint(totals).astype(int) + if low_diff_warning and numpy.max(totals) - numpy.min(totals) <= low_diff_warning: + warnings.warn(f"Difference between longest association and shortest is low after rounding" + f"({numpy.max(totals) - numpy.min(totals)} seconds)") + numpy.fill_diagonal(totals, 0) # Don't want to count associations of an object with itself solved_2d = assign2D(totals, maximize=True)[1] - winning_indices = [] # Pairs that are chosen by assign2D + winning_indices = [] # Pairs that are chosen by assign2D for i in range(length): - winning_indices.append({i, solved_2d[i]}) - + winning_indices.append([i, solved_2d[i]]) cleaned_set = AssociationSet() for winner in winning_indices: - assoc = association_set.associations_including_objects({objects[list(winner)[0]], - objects[list(winner)[1]]}) + print(winner) + assoc = association_set.associations_including_objects({objects[winner[0]], + objects[winner[1]]}) cleaned_set.add(assoc) association_set.remove(assoc) - + # Recursive step for assoc1 in association_set: for assoc2 in association_set: if conflicts(assoc1, assoc2): @@ -403,15 +417,8 @@ def multidimensional_deconfliction(association_set): return cleaned_set - - - - - - - def conflicts(assoc1, assoc2): - if not hasattr(assoc1, time_range) or not hasattr(assoc2, time_range): + if not hasattr(assoc1, 'time_range') or not hasattr(assoc2, 'time_range'): raise TypeError("Associations must have a time_range property") if assoc1.time_range.overlap(assoc2.time_range) and assoc1 != assoc2 \ and len(assoc1.objects.intersection(assoc2.objects)) > 0: @@ -420,8 +427,18 @@ def conflicts(assoc1, assoc2): return False -def make_symmetric(matrix): - return numpy.tril(matrix) + numpy.triu(matrix.T, 1) +def _make_symmetric(matrix): + if isinstance(matrix[0, 0], datetime.datetime): + ans = matrix + for i in range(matrix.shape[0]): + for j in range(matrix.shape[0]): + if matrix[i, j] >= matrix[j, i]: + ans[j, i] = matrix[i, j] + else: + ans[i, j] = matrix[j, i] + return ans + else: + return numpy.tril(matrix) + numpy.triu(matrix.T, k=1) diff --git a/stonesoup/dataassociator/tests/test_assignment.py b/stonesoup/dataassociator/tests/test_assignment.py new file mode 100644 index 000000000..dd6ecc93d --- /dev/null +++ b/stonesoup/dataassociator/tests/test_assignment.py @@ -0,0 +1,33 @@ +from ...types.association import AssociationSet, TimeRangeAssociation +from ...types.time import TimeRange, CompoundTimeRange +from ...types.track import Track +from .._assignment import multidimensional_deconfliction +import datetime +import pytest + + +def test_multi_deconfliction(): + test = AssociationSet() + tested = multidimensional_deconfliction(test) + assert test == tested + tracks = [Track(), Track(), Track(), Track()] + times = [datetime.datetime(year=2022, month=6, day=1, hour=0), + datetime.datetime(year=2022, month=6, day=1, hour=1), + datetime.datetime(year=2022, month=6, day=1, hour=5), + datetime.datetime(year=2022, month=6, day=2, hour=5), + datetime.datetime(year=2022, month=6, day=1, hour=9)] + ranges = [TimeRange(times[0], times[1]), + TimeRange(times[1], times[2]), + TimeRange(times[2], times[3]), + TimeRange(times[0], times[4]), + TimeRange(times[2], times[4])] + + assoc1 = TimeRangeAssociation({tracks[0], tracks[1]}, time_range=ranges[0]) + assoc2 = TimeRangeAssociation({tracks[2], tracks[3]}, time_range=ranges[0]) + assoc_fail = TimeRangeAssociation({tracks[0]}, time_range=ranges[0]) + with pytest.raises(ValueError): + multidimensional_deconfliction(AssociationSet({assoc_fail, assoc1, assoc2})) + + # Objects do not conflict, so should do nothing + test2 = AssociationSet({assoc1, assoc2}) + assert multidimensional_deconfliction(test2).associations == test2.associations diff --git a/stonesoup/types/association.py b/stonesoup/types/association.py index 3148d88d7..a5c613315 100644 --- a/stonesoup/types/association.py +++ b/stonesoup/types/association.py @@ -1,9 +1,10 @@ import datetime -from typing import Set +from typing import Set, Union +from itertools import combinations from ..base import Property from .base import Type -from .time import CompoundTimeRange +from .time import TimeRange, CompoundTimeRange class Association(Type): @@ -48,7 +49,7 @@ class TimeRangeAssociation(Association): range of times """ - time_range: CompoundTimeRange = Property( + time_range: Union[CompoundTimeRange, TimeRange] = Property( default=None, doc="Range of times that association exists over. Default is None") @@ -66,6 +67,14 @@ def __init__(self, associations=None, *args, **kwargs): super().__init__(associations, *args, **kwargs) if self.associations is None: self.associations = set() + if not isinstance(self.associations, Set): + raise TypeError("Supplied parameter must be a set") + if not all([isinstance(member, Association) for member in self.associations]): + raise TypeError("Association set must contain only Association instances") + self._simplify() + + def __eq__(self, other): + return True if self.associations == other.associations else False def add(self, association): if association is None: @@ -77,11 +86,35 @@ def add(self, association): self.add(component) else: raise TypeError("Supplied parameter must be an Association or AssociationSet") + self._simplify() + + def _simplify(self): + """Where multiple associations describe the same pair of objects, combine them into one. + Note this is only implemented for pairs with a time_range attribute- others will be skipped + """ + to_remove = [] + for (assoc1, assoc2) in combinations(self.associations, 2): + if not (len(assoc1.objects) == 2 and len(assoc2. objects) == 2): + continue + if not(hasattr(assoc1, 'time_range') and hasattr(assoc2, 'time_range')): + continue + if assoc1.objects == assoc2.objects: + if isinstance(assoc1.time_range, CompoundTimeRange): + assoc1.time_range.add(assoc2.time_range) + elif isinstance(assoc2.time_range, CompoundTimeRange): + assoc1.time_range = assoc2.time_range.add(assoc1.time_range) + else: + assoc1.time_range = CompoundTimeRange([assoc1.time_range, assoc2.time_range]) + to_remove.append(assoc2) + for assoc in to_remove: + self.remove(assoc) def remove(self, association): if association is None: return elif isinstance(association, Association): + if association not in self.associations: + raise ValueError("Supplied parameter must be contained by this instance") self.associations.remove(association) elif isinstance(association, AssociationSet): for component in association: @@ -91,11 +124,13 @@ def remove(self, association): @property def key_times(self): - key_times = set(self.overall_time_range()) + """Returns all timestamps at which a component starts or ends, or where there is a + :class:`.~SingleTimeAssociation`""" + key_times = list(self.overall_time_range.key_times) for association in self.associations: if isinstance(association, SingleTimeAssociation): key_times.add(association.timestamp) - return list(key_times).order() + return sorted(list(key_times)) @property def overall_time_range(self): @@ -106,7 +141,7 @@ def overall_time_range(self): """ overall_range = CompoundTimeRange() for association in self.associations: - if not isinstance(association, SingleTimeAssociation): + if hasattr(association, 'time_range'): overall_range.add(association.time_range) return overall_range @@ -115,9 +150,9 @@ def object_set(self): """Return all objects in the set Returned as a set """ - object_set = {} - for objects in self.associations.objects: - for obj in objects: + object_set = set() + for assoc in self.associations: + for obj in assoc.objects: object_set.add(obj) return object_set @@ -137,6 +172,8 @@ def associations_at_timestamp(self, timestamp): : :class:`~.AssociationSet` Associations which occur at specified timestamp """ + if not isinstance(timestamp, datetime.datetime): + raise TypeError("Supplied parameter must be a datetime.datetime object") ret_associations = set() for association in self.associations: # If the association is at a single time @@ -163,13 +200,9 @@ def associations_including_objects(self, objects): : class:`~.AssociationSet` A set of objects which have been associated """ - # Ensure objects is iterable if not isinstance(objects, list) and not isinstance(objects, set): objects = {objects} - print(type(objects)) - print(objects) - print(type(association for association in self.associations)) return AssociationSet({association for association in self.associations diff --git a/stonesoup/types/tests/test_association.py b/stonesoup/types/tests/test_association.py index 7fd9ed564..c6e4948a0 100644 --- a/stonesoup/types/tests/test_association.py +++ b/stonesoup/types/tests/test_association.py @@ -6,7 +6,7 @@ from ..association import Association, AssociationPair, AssociationSet, \ SingleTimeAssociation, TimeRangeAssociation from ..detection import Detection -from ..time import TimeRange +from ..time import TimeRange, CompoundTimeRange def test_association(): @@ -40,7 +40,8 @@ def test_associationpair(): # 2 objects assoc = AssociationPair(set(objects[:2])) - np.array_equal(assoc.objects, set(objects[:2])) + assert np.array_equal(assoc.objects, set(objects[:2])) + def test_singletimeassociation(): @@ -68,18 +69,24 @@ def test_timerangeassociation(): timestamp1 = datetime.datetime(2018, 3, 1, 5, 3, 35) timestamp2 = datetime.datetime(2018, 3, 1, 5, 8, 35) timerange = TimeRange(start_timestamp=timestamp1, end_timestamp=timestamp2) + ctimerange = CompoundTimeRange([timerange]) assoc = TimeRangeAssociation(objects=objects, time_range=timerange) + cassoc = TimeRangeAssociation(objects=objects, time_range=ctimerange) assert assoc.objects == objects assert assoc.time_range == timerange + assert cassoc.objects == objects + assert cassoc.time_range == ctimerange def test_associationset(): # Set up some dummy variables timestamp1 = datetime.datetime(2018, 3, 1, 5, 3, 35) timestamp2 = datetime.datetime(2018, 3, 1, 5, 8, 35) - timerange = TimeRange(start_timestamp=timestamp1, end_timestamp=timestamp2) + timestamp3 = datetime.datetime(2020, 3, 1, 1, 1, 1) + time_range = TimeRange(start_timestamp=timestamp1, end_timestamp=timestamp2) + time_range2 = TimeRange(timestamp2, timestamp3) objects_list = [Detection(np.array([[1], [2]])), Detection(np.array([[3], [4]])), @@ -89,7 +96,9 @@ def test_associationset(): timestamp=timestamp1) assoc2 = TimeRangeAssociation(objects=set(objects_list[1:]), - time_range=timerange) + time_range=time_range) + assoc_duplicate = TimeRangeAssociation(objects=set(objects_list[1:]), + time_range=time_range2) assoc_set = AssociationSet({assoc1, assoc2}) @@ -105,26 +114,79 @@ def test_associationset(): assert len(assoc_set) == 2 + # test _simplify method + + simplify_test = AssociationSet({assoc1, assoc2, assoc_duplicate}) + + assert len(simplify_test.associations) == 2 + # Test associations including objects # Object only present in object 1 assert assoc_set.associations_including_objects(objects_list[0]) \ - == {assoc1} + == AssociationSet({assoc1}) # Object present in both assert assoc_set.associations_including_objects(objects_list[1]) \ - == {assoc1, assoc2} + == AssociationSet({assoc1, assoc2}) # Object present in neither - assert not assoc_set.associations_including_objects( - Detection(np.array([[0], [0]]))) + assert assoc_set.associations_including_objects(Detection(np.array([[0], [0]]))) \ + == AssociationSet() # Test associations including timestamp # Timestamp present in one object assert assoc_set.associations_at_timestamp(timestamp2) \ - == {assoc2} + == AssociationSet({assoc2}) # Timestamp present in both assert assoc_set.associations_at_timestamp(timestamp1) \ - == {assoc1, assoc2} + == AssociationSet({assoc1, assoc2}) # Timestamp not present in either - timestamp3 = datetime.datetime(2018, 3, 1, 6, 8, 35) - assert not assoc_set.associations_at_timestamp(timestamp3) + timestamp4 = datetime.datetime(2022, 3, 1, 6, 8, 35) + assert assoc_set.associations_at_timestamp(timestamp4) \ + == AssociationSet() + + +def test_association_set_add_remove(): + test = AssociationSet() + with pytest.raises(TypeError): + test.add("banana") + with pytest.raises(TypeError): + test.remove("banana") + objects = {Detection(np.array([[1], [2]])), + Detection(np.array([[3], [4]])), + Detection(np.array([[5], [6]]))} + + assoc = Association(objects) + with pytest.raises(ValueError): + test.remove(assoc) + assert assoc not in test.associations + test.add(assoc) + assert assoc in test.associations + test.remove(assoc) + assert assoc not in test.associations + + +def test_association_set_properties(): + test = AssociationSet() + assert test.key_times == [] + assert test.overall_time_range == CompoundTimeRange() + assert test.object_set == set() + objects = [Detection(np.array([[1], [2]])), + Detection(np.array([[3], [4]])), + Detection(np.array([[5], [6]]))] + assoc = Association(set(objects)) + test2 = AssociationSet({assoc}) + assert test2.key_times == [] + assert test2.overall_time_range == CompoundTimeRange() + assert test2.object_set == set(objects) + timestamp1 = datetime.datetime(2018, 3, 1, 5, 3, 35) + timestamp2 = datetime.datetime(2018, 3, 1, 5, 8, 35) + time_range = TimeRange(start_timestamp=timestamp1, end_timestamp=timestamp2) + com_time_range = CompoundTimeRange([time_range]) + assoc2 = TimeRangeAssociation(objects=set(objects[1:]), + time_range=com_time_range) + test3 = AssociationSet({assoc, assoc2}) + assert test3.key_times == [timestamp1, timestamp2] + assert test3.overall_time_range == com_time_range + assert test3.object_set == set(objects) + diff --git a/stonesoup/types/tests/test_detection.py b/stonesoup/types/tests/test_detection.py index 5d0078551..cd696bde6 100644 --- a/stonesoup/types/tests/test_detection.py +++ b/stonesoup/types/tests/test_detection.py @@ -33,3 +33,4 @@ def test_composite_detection(): mapping=[1, 0, 2, 3]) # Last detection should overwrite metadata of earlier assert detection.metadata == {'colour': 'red', 'speed': 'fast', 'size': 'big'} + diff --git a/stonesoup/types/tests/test_time.py b/stonesoup/types/tests/test_time.py index 7b93ec1d3..a84857617 100644 --- a/stonesoup/types/tests/test_time.py +++ b/stonesoup/types/tests/test_time.py @@ -211,12 +211,12 @@ def test_key_times(times): def test_remove_overlap(times): test1_ro = CompoundTimeRange([TimeRange(times[0], times[1]), TimeRange(times[3], times[4])]) - test1_ro.remove_overlap() + test1_ro._remove_overlap() test2_ro = CompoundTimeRange([TimeRange(times[3], times[4]), TimeRange(times[0], times[4])]) - test2_ro.remove_overlap() + test2_ro._remove_overlap() test3_ro = CompoundTimeRange() - test3_ro.remove_overlap() + test3_ro._remove_overlap() test1 = CompoundTimeRange([TimeRange(times[0], times[1]), TimeRange(times[3], times[4])]) @@ -230,9 +230,9 @@ def test_remove_overlap(times): def test_fuse_components(times): # Note this is called inside the __init__ method, but is tested here explicitly - test1 = CompoundTimeRange([TimeRange(times[1], times[2])]).fuse_components() + test1 = CompoundTimeRange([TimeRange(times[1], times[2])])._fuse_components() test2 = CompoundTimeRange([TimeRange(times[1], times[2]), - TimeRange(times[2], times[4])]).fuse_components() + TimeRange(times[2], times[4])])._fuse_components() assert test1.time_ranges == {TimeRange(times[1], times[2])} assert test2.time_ranges == {TimeRange(times[1], times[4])} diff --git a/stonesoup/types/time.py b/stonesoup/types/time.py index 89db658ea..4afb74e6e 100644 --- a/stonesoup/types/time.py +++ b/stonesoup/types/time.py @@ -57,13 +57,13 @@ def __contains__(self, time): return self.start_timestamp <= time.start_timestamp and \ self.end_timestamp >= time.end_timestamp else: - raise TypeError("Supplied parameter must be a datetime.datetime or TimeRange object") + raise TypeError("Supplied parameter must be a datetime or TimeRange object") def __eq__(self, other): if other is None: return False if not isinstance(other, TimeRange): - raise TypeError(f"Cannot compare between a CompoundTimeRange and {type(other)}") + return False if self.start_timestamp == other.start_timestamp and \ self.end_timestamp == other.end_timestamp: return True @@ -156,8 +156,8 @@ def __init__(self, *args, **kwargs): for component in self.time_ranges: if not isinstance(component, TimeRange): raise TypeError("Time_ranges must contain only TimeRange objects") - self.remove_overlap() - self.fuse_components() + self._remove_overlap() + self._fuse_components() @property def duration(self): @@ -178,7 +178,7 @@ def key_times(self): key_times.add(component.end_timestamp) return sorted(list(key_times)) - def remove_overlap(self): + def _remove_overlap(self): """Removes overlap between components of time_ranges""" if len(self.time_ranges) == 0: return @@ -190,7 +190,7 @@ def remove_overlap(self): overlap_check.add(time_range.minus(overlap_check.overlap(time_range))) self.time_ranges = overlap_check.time_ranges - def fuse_components(self): + def _fuse_components(self): """Fuses two time ranges [a,b], [b,c] into [a,c] for all such pairs in this instance""" for (component, component2) in permutations(self.time_ranges, 2): if component.end_timestamp == component2.start_timestamp: @@ -208,8 +208,8 @@ def add(self, time_range): self.add(component) else: self.time_ranges.append(time_range) - self.remove_overlap() - self.fuse_components() + self._remove_overlap() + self._fuse_components() def remove(self, time_range): """Remove a :class:`.~TimeRange` object from the time ranges. @@ -237,7 +237,7 @@ def __contains__(self, time): if isinstance(time, datetime.datetime): for component in self.time_ranges: - if datetime in component: + if time in component: return True return False elif isinstance(time, TimeRange) or isinstance(time, CompoundTimeRange): @@ -253,7 +253,7 @@ def __eq__(self, other): if other is None: return False if not isinstance(other, CompoundTimeRange): - raise TypeError(f"Cannot compare between a CompoundTimeRange and {type(other)}") + return False if len(self.time_ranges) == 0: return True if len(other.time_ranges) == 0 else False for component in self.time_ranges: From 4a1eaba64baa1cc8f9ad8e62b75676c7cb8c005a Mon Sep 17 00:00:00 2001 From: Oliver Rosoman <95758965+orosoman-dstl@users.noreply.github.com> Date: Mon, 27 Jun 2022 00:47:07 +0100 Subject: [PATCH 08/26] everything hopefully works! --- stonesoup/dataassociator/_assignment.py | 122 +++++++++--------- .../dataassociator/tests/test_assignment.py | 23 +++- stonesoup/types/association.py | 11 +- stonesoup/types/tests/test_time.py | 21 +-- stonesoup/types/time.py | 32 +++-- 5 files changed, 118 insertions(+), 91 deletions(-) diff --git a/stonesoup/dataassociator/_assignment.py b/stonesoup/dataassociator/_assignment.py index bffbcda09..abdc98990 100644 --- a/stonesoup/dataassociator/_assignment.py +++ b/stonesoup/dataassociator/_assignment.py @@ -322,15 +322,16 @@ def assign2DBasic(C): # OF RECIPIENT IN THE USE OF THE SOFTWARE. -def multidimensional_deconfliction(association_set, low_diff_warning=None): +def multidimensional_deconfliction(association_set): """Solves the Multidimensional Assignment Problem (MAP) The assignment problem becomes more complex when time is added as a dimension. This basic solution finds all the conflicts in an association set and then creates a matrix of sums of conflicts in seconds, which is then passed to assign2D to solve as a simple 2D assignment problem. Therefore, each object will only ever be assigned to one other - at any one time. In the case of an association that only partially overlaps, the "weaker" one - (the one eliminated by assign2D) will be trimmed until there is no conflict. + at any one time. In the case of an association that only partially overlaps, the time range + of the "weaker" one (the one eliminated by assign2D) will be trimmed + until there is no conflict. Due to the possibility of more than two conflicting associations at the same time, this algorithm is recursive, but it is not expected many (if any) recursions will be required @@ -339,9 +340,6 @@ def multidimensional_deconfliction(association_set, low_diff_warning=None): Parameters ---------- association_set: The :class:`AssociationSet` to de-conflict - low_diff_warning: If the longest association between objects minus the shortest (in seconds) is - less than this, a warning will be given. This may occur if the range of times covered includes - only a low number of seconds. Returns @@ -349,70 +347,66 @@ def multidimensional_deconfliction(association_set, low_diff_warning=None): : :class:`AssociationSet` The association set without contradictory associations """ + # Check if there are any conflicts + nonecheck(2,association_set) + no_conflicts = True + for assoc1 in association_set: + for assoc2 in association_set: + if conflicts(assoc1, assoc2): + no_conflicts = False + if no_conflicts: + return association_set + nonecheck(1,association_set) objects = list(association_set.object_set) length = len(objects) - if length <= 1: - return association_set totals = numpy.zeros((length, length)) # Time objects i and j are associated for in total - association_on = numpy.full((length, length), False) - association_start = numpy.full((length, length), datetime.datetime.min) - key_times = association_set.key_times - for time in key_times: - associations_to_end = [] - for i in range(length): - for j in range(length): - if association_on[i][j]: - associations_to_end.append({i, j}) - time_ass_set = association_set.associations_at_timestamp(time) - for association in time_ass_set: - if len(association.objects) != 2: - raise ValueError("Supplied set must only contain pairs of associated objects") - obj_indices = [objects.index(list(association.objects)[0]), - objects.index(list(association.objects)[1])] - if not association_on[obj_indices[0], obj_indices[1]]: - association_on[obj_indices[0], obj_indices[1]] = True - association_start[obj_indices[0], obj_indices[1]] = time - elif time != max(association.time_range.key_times): - associations_to_end.remove({obj_indices[0], obj_indices[1]}) - for inds in associations_to_end: - print("loop") - association_on[inds[0], inds[1]] = False - totals[inds[0], inds[1]] += (time - - association_start[inds[0], inds[1]]).total_seconds() - association_start = _make_symmetric(association_start) - totals = _make_symmetric(totals) - association_on = _make_symmetric(association_on) + + for association in association_set.associations: + if len(association.objects) != 2: + raise ValueError("Supplied set must only contain pairs of associated objects") + obj_indices = [objects.index(list(association.objects)[0]), + objects.index(list(association.objects)[1])] + totals[obj_indices[0], obj_indices[1]] = association.time_range.duration.total_seconds() + make_symmetric(totals) + nonecheck(2, association_set) totals = numpy.rint(totals).astype(int) - if low_diff_warning and numpy.max(totals) - numpy.min(totals) <= low_diff_warning: - warnings.warn(f"Difference between longest association and shortest is low after rounding" - f"({numpy.max(totals) - numpy.min(totals)} seconds)") numpy.fill_diagonal(totals, 0) # Don't want to count associations of an object with itself solved_2d = assign2D(totals, maximize=True)[1] winning_indices = [] # Pairs that are chosen by assign2D for i in range(length): - winning_indices.append([i, solved_2d[i]]) + if i != solved_2d[i]: + winning_indices.append([i, solved_2d[i]]) cleaned_set = AssociationSet() + nonecheck(3, association_set) + if len(winning_indices) == 0: + raise ValueError("Problem unsolvable using this method") for winner in winning_indices: - print(winner) assoc = association_set.associations_including_objects({objects[winner[0]], objects[winner[1]]}) cleaned_set.add(assoc) association_set.remove(assoc) + nonecheck(4, association_set) + # Recursive step - for assoc1 in association_set: - for assoc2 in association_set: + runners_up = set() + for assoc1 in association_set.associations: + for assoc2 in association_set.associations: if conflicts(assoc1, assoc2): - association_set = multidimensional_deconfliction(association_set) + runners_up = multidimensional_deconfliction(association_set).associations + nonecheck(5, association_set) # At this point, none of association_set should conflict with one another - for association in association_set: + for runner_up in runners_up: for winner in cleaned_set: - if conflicts(association, winner): - association.time_range.minus(winner.time_range) - if association.time_range is not None: - cleaned_set.add(association) + if conflicts(runner_up, winner): + runner_up.time_range.minus(winner.time_range) + if runner_up.time_range is not None: + cleaned_set.add(runner_up) + else: + runners_up.remove(runner_up) + nonecheck(6, association_set) return cleaned_set @@ -420,6 +414,8 @@ def multidimensional_deconfliction(association_set, low_diff_warning=None): def conflicts(assoc1, assoc2): if not hasattr(assoc1, 'time_range') or not hasattr(assoc2, 'time_range'): raise TypeError("Associations must have a time_range property") + if assoc1.time_range is None or assoc2.time_range is None: + return False if assoc1.time_range.overlap(assoc2.time_range) and assoc1 != assoc2 \ and len(assoc1.objects.intersection(assoc2.objects)) > 0: return True @@ -427,19 +423,17 @@ def conflicts(assoc1, assoc2): return False -def _make_symmetric(matrix): - if isinstance(matrix[0, 0], datetime.datetime): - ans = matrix - for i in range(matrix.shape[0]): - for j in range(matrix.shape[0]): - if matrix[i, j] >= matrix[j, i]: - ans[j, i] = matrix[i, j] - else: - ans[i, j] = matrix[j, i] - return ans - else: - return numpy.tril(matrix) + numpy.triu(matrix.T, k=1) - - - +def make_symmetric(matrix): + """Matrix must be square""" + length = matrix.shape[0] + for i in range(length): + for j in range(length): + if matrix[i, j] >= matrix[j, i]: + matrix[j, i] = matrix[i, j] + else: + matrix[i, j] = matrix[j, i] +def nonecheck(flag, association_set): + for assoc in association_set.associations: + if assoc.time_range is None: + print(f"NONETYPE AT {flag}") \ No newline at end of file diff --git a/stonesoup/dataassociator/tests/test_assignment.py b/stonesoup/dataassociator/tests/test_assignment.py index dd6ecc93d..60a3f7f92 100644 --- a/stonesoup/dataassociator/tests/test_assignment.py +++ b/stonesoup/dataassociator/tests/test_assignment.py @@ -10,7 +10,7 @@ def test_multi_deconfliction(): test = AssociationSet() tested = multidimensional_deconfliction(test) assert test == tested - tracks = [Track(), Track(), Track(), Track()] + tracks = [Track(id=0), Track(id=1), Track(id=2), Track(id=3)] times = [datetime.datetime(year=2022, month=6, day=1, hour=0), datetime.datetime(year=2022, month=6, day=1, hour=1), datetime.datetime(year=2022, month=6, day=1, hour=5), @@ -30,4 +30,23 @@ def test_multi_deconfliction(): # Objects do not conflict, so should do nothing test2 = AssociationSet({assoc1, assoc2}) - assert multidimensional_deconfliction(test2).associations == test2.associations + assert multidimensional_deconfliction(test2).associations == {assoc1, assoc2} + + # Objects do conflict, so remove the shorter one + assoc3 = TimeRangeAssociation({tracks[0], tracks[3]}, + time_range=CompoundTimeRange([ranges[0], ranges[4]])) + test3 = AssociationSet({assoc1, assoc3}) + # Should entirely remove assoc1 + assert multidimensional_deconfliction(test3).associations == {assoc3} + + assoc4 = TimeRangeAssociation({tracks[0], tracks[1]}, + time_range=CompoundTimeRange([ranges[1], ranges[4]])) + test4 = AssociationSet({assoc1, assoc2, assoc3, assoc4}) + # assoc1 and assoc4 should merge together, assoc3 should be removed, and assoc2 should remain + tested4 = multidimensional_deconfliction(test4) + assert len(tested4) == 2 + assert assoc2 in tested4 + merged = tested4.associations_including_objects({tracks[0], tracks[1]}) + assert len(merged) == 1 + merged = merged.associations.pop() + assert merged.time_range == CompoundTimeRange([ranges[0], ranges[1], ranges[4]]) diff --git a/stonesoup/types/association.py b/stonesoup/types/association.py index a5c613315..f3257c6e3 100644 --- a/stonesoup/types/association.py +++ b/stonesoup/types/association.py @@ -94,7 +94,7 @@ def _simplify(self): """ to_remove = [] for (assoc1, assoc2) in combinations(self.associations, 2): - if not (len(assoc1.objects) == 2 and len(assoc2. objects) == 2): + if not (len(assoc1.objects) == 2 and len(assoc2.objects) == 2): continue if not(hasattr(assoc1, 'time_range') and hasattr(assoc2, 'time_range')): continue @@ -102,7 +102,8 @@ def _simplify(self): if isinstance(assoc1.time_range, CompoundTimeRange): assoc1.time_range.add(assoc2.time_range) elif isinstance(assoc2.time_range, CompoundTimeRange): - assoc1.time_range = assoc2.time_range.add(assoc1.time_range) + assoc2.time_range.add(assoc1.time_range) + assoc1.time_range = assoc2.time_range else: assoc1.time_range = CompoundTimeRange([assoc1.time_range, assoc2.time_range]) to_remove.append(assoc2) @@ -198,7 +199,7 @@ def associations_including_objects(self, objects): Returns ------- : class:`~.AssociationSet` - A set of objects which have been associated + A set of associations containing every member of objects """ # Ensure objects is iterable if not isinstance(objects, list) and not isinstance(objects, set): @@ -206,8 +207,8 @@ def associations_including_objects(self, objects): return AssociationSet({association for association in self.associations - for object_ in objects - if object_ in association.objects}) + if all([object_ in association.objects + for object_ in objects])}) def __contains__(self, item): return item in self.associations diff --git a/stonesoup/types/tests/test_time.py b/stonesoup/types/tests/test_time.py index a84857617..96bdf0fa4 100644 --- a/stonesoup/types/tests/test_time.py +++ b/stonesoup/types/tests/test_time.py @@ -111,6 +111,8 @@ def test_contains(times): assert test_range2 not in compound_test assert test_range2 not in compound_test2 assert test_range3 not in compound_test + print(" \n range3", test_range3, "test2", compound_test2) + print(compound_test2.overlap(test_range3)) assert test_range3 in compound_test2 @@ -119,8 +121,7 @@ def test_equality(times): test2 = TimeRange(times[1], times[2]) test3 = TimeRange(times[1], times[3]) - with pytest.raises(TypeError): - test1 == "stonesoup" + assert test1 != "stonesoup" assert test1 == test2 assert test2 == test1 @@ -129,8 +130,7 @@ def test_equality(times): ctest1 = CompoundTimeRange([test1, test3]) ctest2 = CompoundTimeRange([TimeRange(times[1], times[3])]) - with pytest.raises(TypeError): - ctest2 == "Stonesoup is the best!" + assert ctest2 != "Stonesoup is the best!" assert ctest1 == ctest2 assert ctest2 == ctest1 @@ -181,8 +181,7 @@ def test_overlap(times): ctest1 = CompoundTimeRange([test2, test3]) ctest2 = CompoundTimeRange([test1, test2]) - with pytest.raises(TypeError): - test2.overlap(ctest1) + assert test2.overlap(ctest1) == ctest1.overlap(test2) assert test1.overlap(test1) == test1 assert test1.overlap(None) is None @@ -230,11 +229,13 @@ def test_remove_overlap(times): def test_fuse_components(times): # Note this is called inside the __init__ method, but is tested here explicitly - test1 = CompoundTimeRange([TimeRange(times[1], times[2])])._fuse_components() + test1 = CompoundTimeRange([TimeRange(times[1], times[2])]) + test1._fuse_components() test2 = CompoundTimeRange([TimeRange(times[1], times[2]), - TimeRange(times[2], times[4])])._fuse_components() - assert test1.time_ranges == {TimeRange(times[1], times[2])} - assert test2.time_ranges == {TimeRange(times[1], times[4])} + TimeRange(times[2], times[4])]) + test2._fuse_components() + assert test1 == CompoundTimeRange([TimeRange(times[1], times[2])]) + assert test2 == CompoundTimeRange([TimeRange(times[1], times[4])]) def test_add(times): test1 = CompoundTimeRange([TimeRange(times[1], times[2])]) diff --git a/stonesoup/types/time.py b/stonesoup/types/time.py index 4afb74e6e..064db8522 100644 --- a/stonesoup/types/time.py +++ b/stonesoup/types/time.py @@ -110,17 +110,18 @@ def minus(self, time_range): end = overlap.start_timestamp if self.start_timestamp < overlap.start_timestamp and \ self.end_timestamp > overlap.end_timestamp: - return CompoundTimeRange(TimeRange(self.start_timestamp, overlap.start_timestamp), - TimeRange(self.end_timestamp, overlap.end_timestamp)) + return CompoundTimeRange([TimeRange(self.start_timestamp, overlap.start_timestamp), + TimeRange(overlap.end_timestamp, self.end_timestamp)]) else: return TimeRange(start, end) def overlap(self, time_range): - """Finds the intersection between this instance and another :class:`~.TimeRange` + """Finds the intersection between this instance and another :class:`~.TimeRange` or + :class:`.~CompoundTimeRange` Parameters ---------- - time_range: TimeRange + time_range: Union[TimeRange, CompoundTimeRange] Returns ------- @@ -129,6 +130,8 @@ def overlap(self, time_range): """ if time_range is None: return None + if isinstance(time_range, CompoundTimeRange): + return time_range.overlap(self) if not isinstance(time_range, TimeRange): raise TypeError("Supplied parameter must be a TimeRange object") start_timestamp = max(self.start_timestamp, time_range.start_timestamp) @@ -180,7 +183,7 @@ def key_times(self): def _remove_overlap(self): """Removes overlap between components of time_ranges""" - if len(self.time_ranges) == 0: + if len(self.time_ranges) in {0, 1}: return if all([component.overlap(component2) is None for (component, component2) in combinations(self.time_ranges, 2)]): @@ -198,6 +201,8 @@ def _fuse_components(self): self.remove(component) self.remove(component2) self.add(fused_component) + # To avoid issues with having removed objects from the permutations + self._fuse_components() def add(self, time_range): """Add a :class:`~.TimeRange` or :class:`~.CompoundTimeRange` object to `time_ranges`""" @@ -206,8 +211,10 @@ def add(self, time_range): if isinstance(time_range, CompoundTimeRange): for component in time_range.time_ranges: self.add(component) - else: + elif isinstance(time_range, TimeRange): self.time_ranges.append(time_range) + else: + raise TypeError("Supplied parameter must be a TimeRange or CompoundTimeRange object") self._remove_overlap() self._fuse_components() @@ -218,6 +225,12 @@ def remove(self, time_range): raise TypeError("Supplied parameter must be a TimeRange object") if time_range in self.time_ranges: self.time_ranges.remove(time_range) + elif time_range in self: + for component in self.time_ranges: + if time_range in component: + new = component.minus(time_range) + self.time_ranges.remove(component) + self.add(new) else: raise ValueError("Supplied parameter must be a member of time_ranges") @@ -240,10 +253,9 @@ def __contains__(self, time): if time in component: return True return False - elif isinstance(time, TimeRange) or isinstance(time, CompoundTimeRange): - print(self.overlap(time)) - print(time) - print(self.overlap(time) == time) + elif isinstance(time, TimeRange): + return True if self.overlap(time) == CompoundTimeRange([time]) else False + elif isinstance(time, CompoundTimeRange): return True if self.overlap(time) == time else False else: raise TypeError("Supplied parameter must be an instance of either " From addea40a26d590d410630d5de1db07b357c844f9 Mon Sep 17 00:00:00 2001 From: Oliver Rosoman <95758965+orosoman-dstl@users.noreply.github.com> Date: Mon, 27 Jun 2022 12:23:26 +0100 Subject: [PATCH 09/26] flake8 fixes mostly --- stonesoup/dataassociator/_assignment.py | 14 -------------- .../dataassociator/tests/test_tracktotrack.py | 19 +++++++++++++++++++ stonesoup/dataassociator/tracktotrack.py | 4 ++-- .../tests/test_tracktotruthmetrics.py | 1 + .../metricgenerator/tracktotruthmetrics.py | 5 ++--- stonesoup/types/tests/test_association.py | 2 -- stonesoup/types/tests/test_detection.py | 1 - stonesoup/types/tests/test_time.py | 15 ++++++++++++--- stonesoup/types/time.py | 7 +++++-- 9 files changed, 41 insertions(+), 27 deletions(-) diff --git a/stonesoup/dataassociator/_assignment.py b/stonesoup/dataassociator/_assignment.py index abdc98990..fe6fb39bf 100644 --- a/stonesoup/dataassociator/_assignment.py +++ b/stonesoup/dataassociator/_assignment.py @@ -1,7 +1,5 @@ import numpy -import datetime from ..types.association import AssociationSet -import warnings def assign2D(C, maximize=False): @@ -348,7 +346,6 @@ def multidimensional_deconfliction(association_set): The association set without contradictory associations """ # Check if there are any conflicts - nonecheck(2,association_set) no_conflicts = True for assoc1 in association_set: for assoc2 in association_set: @@ -356,7 +353,6 @@ def multidimensional_deconfliction(association_set): no_conflicts = False if no_conflicts: return association_set - nonecheck(1,association_set) objects = list(association_set.object_set) length = len(objects) totals = numpy.zeros((length, length)) # Time objects i and j are associated for in total @@ -368,7 +364,6 @@ def multidimensional_deconfliction(association_set): objects.index(list(association.objects)[1])] totals[obj_indices[0], obj_indices[1]] = association.time_range.duration.total_seconds() make_symmetric(totals) - nonecheck(2, association_set) totals = numpy.rint(totals).astype(int) numpy.fill_diagonal(totals, 0) # Don't want to count associations of an object with itself @@ -379,7 +374,6 @@ def multidimensional_deconfliction(association_set): if i != solved_2d[i]: winning_indices.append([i, solved_2d[i]]) cleaned_set = AssociationSet() - nonecheck(3, association_set) if len(winning_indices) == 0: raise ValueError("Problem unsolvable using this method") for winner in winning_indices: @@ -387,7 +381,6 @@ def multidimensional_deconfliction(association_set): objects[winner[1]]}) cleaned_set.add(assoc) association_set.remove(assoc) - nonecheck(4, association_set) # Recursive step runners_up = set() @@ -395,7 +388,6 @@ def multidimensional_deconfliction(association_set): for assoc2 in association_set.associations: if conflicts(assoc1, assoc2): runners_up = multidimensional_deconfliction(association_set).associations - nonecheck(5, association_set) # At this point, none of association_set should conflict with one another for runner_up in runners_up: @@ -406,7 +398,6 @@ def multidimensional_deconfliction(association_set): cleaned_set.add(runner_up) else: runners_up.remove(runner_up) - nonecheck(6, association_set) return cleaned_set @@ -432,8 +423,3 @@ def make_symmetric(matrix): matrix[j, i] = matrix[i, j] else: matrix[i, j] = matrix[j, i] - -def nonecheck(flag, association_set): - for assoc in association_set.associations: - if assoc.time_range is None: - print(f"NONETYPE AT {flag}") \ No newline at end of file diff --git a/stonesoup/dataassociator/tests/test_tracktotrack.py b/stonesoup/dataassociator/tests/test_tracktotrack.py index 9f885dc57..18cad54c1 100644 --- a/stonesoup/dataassociator/tests/test_tracktotrack.py +++ b/stonesoup/dataassociator/tests/test_tracktotrack.py @@ -94,6 +94,16 @@ def test_euclidiantracktotrack(tracks): association_set_4 = complete_associator.associate_tracks({tracks[0]}, {tracks[5]}) + complete_associator_one2one = TrackToTrackCounting( + association_threshold=10, + consec_pairs_confirm=3, + consec_misses_end=2, + use_positional_only=False) + start_time = datetime.datetime(2019, 1, 1, 14, 0, 0) + + association_set_one2one = complete_associator_one2one.associate_tracks( + {tracks[0], tracks[2]}, {tracks[1], tracks[3], tracks[4]}) + assert len(association_set_1.associations) == 1 assoc1 = list(association_set_1.associations)[0] assert set(assoc1.objects) == {tracks[0], tracks[1]} @@ -121,6 +131,15 @@ def test_euclidiantracktotrack(tracks): assert assoc4.time_range.end_timestamp \ == start_time + datetime.timedelta(seconds=7) + assert len(association_set_one2one) == 1 + assoc5 = list(association_set_one2one)[0] + # assoc5 should be equal to assoc1 + assert set(assoc5.objects) == {tracks[0], tracks[1]} + assert assoc5.time_range.start_timestamp \ + == start_time + datetime.timedelta(seconds=1) + assert assoc5.time_range.end_timestamp \ + == start_time + datetime.timedelta(seconds=6) + def test_euclidiantracktotruth(tracks): associator = TrackToTruth( diff --git a/stonesoup/dataassociator/tracktotrack.py b/stonesoup/dataassociator/tracktotrack.py index 64454ac92..f0d5e72c7 100644 --- a/stonesoup/dataassociator/tracktotrack.py +++ b/stonesoup/dataassociator/tracktotrack.py @@ -79,8 +79,8 @@ class TrackToTrackCounting(TrackToTrackAssociator): ) one_to_one: bool = Property( default=False, - doc="If True, the Hungarian Algorithm is applied to the results so that no track is " - "associated with more than one other at any given time step" + doc="If True, it is ensured no two associations ever contain the same track " + "at the same time" ) def associate_tracks(self, tracks_set_1: Set[Track], tracks_set_2: Set[Track]): diff --git a/stonesoup/metricgenerator/tests/test_tracktotruthmetrics.py b/stonesoup/metricgenerator/tests/test_tracktotruthmetrics.py index 9c7226eb5..1a832eae9 100644 --- a/stonesoup/metricgenerator/tests/test_tracktotruthmetrics.py +++ b/stonesoup/metricgenerator/tests/test_tracktotruthmetrics.py @@ -50,6 +50,7 @@ def test_siap(trial_manager, trial_truths, trial_tracks, trial_associations, mea # Test total_time_tracked assert siap_generator.total_time_tracked(trial_manager, trial_truths[0]) == 3 # seconds + print(trial_truths[1]) assert siap_generator.total_time_tracked(trial_manager, trial_truths[1]) == 2 assert siap_generator.total_time_tracked(trial_manager, trial_truths[2]) == 1 assert siap_generator.total_time_tracked(trial_manager, GroundTruthPath()) == 0 diff --git a/stonesoup/metricgenerator/tracktotruthmetrics.py b/stonesoup/metricgenerator/tracktotruthmetrics.py index 14862faa0..2ae037827 100644 --- a/stonesoup/metricgenerator/tracktotruthmetrics.py +++ b/stonesoup/metricgenerator/tracktotruthmetrics.py @@ -278,7 +278,7 @@ def num_associated_tracks_at_time(manager, timestamp): Number of associated tracks held by `manager` at `timestamp`. """ associations = manager.association_set.associations_at_timestamp(timestamp) - association_objects = {thing for assoc in associations for thing in assoc.objects} + association_objects = associations.object_set return sum(1 for track in manager.tracks if track in association_objects) @@ -355,7 +355,6 @@ def total_time_tracked(manager, truth): return 0 truth_timestamps = sorted(state.timestamp for state in truth.states) - total_time = 0 for current_time, next_time in zip(truth_timestamps[:-1], truth_timestamps[1:]): for assoc in assocs: @@ -385,7 +384,7 @@ def min_num_tracks_needed_to_track(manager, truth): Minimum number of tracks needed to track `truth` """ assocs = sorted(manager.association_set.associations_including_objects([truth]), - key=attrgetter('time_range.end_timestamp'), + key=attrgetter('time_range.key_times[-1]'), reverse=True) if len(assocs) == 0: diff --git a/stonesoup/types/tests/test_association.py b/stonesoup/types/tests/test_association.py index c6e4948a0..337a9105d 100644 --- a/stonesoup/types/tests/test_association.py +++ b/stonesoup/types/tests/test_association.py @@ -43,7 +43,6 @@ def test_associationpair(): assert np.array_equal(assoc.objects, set(objects[:2])) - def test_singletimeassociation(): with pytest.raises(TypeError): SingleTimeAssociation() @@ -189,4 +188,3 @@ def test_association_set_properties(): assert test3.key_times == [timestamp1, timestamp2] assert test3.overall_time_range == com_time_range assert test3.object_set == set(objects) - diff --git a/stonesoup/types/tests/test_detection.py b/stonesoup/types/tests/test_detection.py index cd696bde6..5d0078551 100644 --- a/stonesoup/types/tests/test_detection.py +++ b/stonesoup/types/tests/test_detection.py @@ -33,4 +33,3 @@ def test_composite_detection(): mapping=[1, 0, 2, 3]) # Last detection should overwrite metadata of earlier assert detection.metadata == {'colour': 'red', 'speed': 'fast', 'size': 'big'} - diff --git a/stonesoup/types/tests/test_time.py b/stonesoup/types/tests/test_time.py index 96bdf0fa4..996493094 100644 --- a/stonesoup/types/tests/test_time.py +++ b/stonesoup/types/tests/test_time.py @@ -4,6 +4,7 @@ from ..time import TimeRange, CompoundTimeRange + @pytest.fixture def times(): # Note times are returned chronologically for ease of reading @@ -200,11 +201,13 @@ def test_key_times(times): TimeRange(times[0], times[1])]) test3 = CompoundTimeRange() test4 = CompoundTimeRange([TimeRange(times[0], times[4])]) + test5 = TimeRange(times[0], times[4]) assert test1.key_times == [times[0], times[1], times[3], times[4]] assert test2.key_times == [times[0], times[1], times[3], times[4]] assert test3.key_times == [] assert test4.key_times == [times[0], times[4]] + assert test5.key_times == [times[0], times[4]] def test_remove_overlap(times): @@ -237,6 +240,7 @@ def test_fuse_components(times): assert test1 == CompoundTimeRange([TimeRange(times[1], times[2])]) assert test2 == CompoundTimeRange([TimeRange(times[1], times[4])]) + def test_add(times): test1 = CompoundTimeRange([TimeRange(times[1], times[2])]) test2 = CompoundTimeRange([TimeRange(times[0], times[1])]) @@ -254,13 +258,18 @@ def test_add(times): test3.add(TimeRange(times[4], times[5])) assert test3 == test4 + def test_remove(times): test1 = CompoundTimeRange([TimeRange(times[0], times[2]), TimeRange(times[4], times[5])]) with pytest.raises(TypeError): test1.remove(0.4) with pytest.raises(ValueError): - test1.remove(TimeRange(times[0], times[1])) - - test1.remove(TimeRange(times[0], times[2])) + test1.remove(TimeRange(times[2], times[3])) + # Remove part of a component + test1.remove(TimeRange(times[0], times[1])) + assert test1 == CompoundTimeRange([TimeRange(times[1], times[2]), + TimeRange(times[4], times[5])]) + # Remove whole component + test1.remove(TimeRange(times[1], times[2])) assert test1 == CompoundTimeRange([TimeRange(times[4], times[5])]) diff --git a/stonesoup/types/time.py b/stonesoup/types/time.py index 064db8522..bc9dc414a 100644 --- a/stonesoup/types/time.py +++ b/stonesoup/types/time.py @@ -1,5 +1,4 @@ import datetime -from typing import Union from itertools import combinations, permutations from ..base import Property @@ -37,6 +36,11 @@ def duration(self): return self.end_timestamp - self.start_timestamp + @property + def key_times(self): + """Times the TimeRange begins and ends""" + return [self.start_timestamp, self.end_timestamp] + def __contains__(self, time): """Checks if timestamp is within range @@ -325,4 +329,3 @@ def overlap(self, time_range): else: raise TypeError("Supplied parameter must be an instance of either " "TimeRange, or CompoundTimeRange") - From 3ab27e960f76597a08f1758a9bb3d7e8a2117a37 Mon Sep 17 00:00:00 2001 From: Oliver Rosoman Date: Mon, 27 Jun 2022 12:40:26 +0100 Subject: [PATCH 10/26] flake8 fixes mostly --- stonesoup/metricgenerator/tracktotruthmetrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stonesoup/metricgenerator/tracktotruthmetrics.py b/stonesoup/metricgenerator/tracktotruthmetrics.py index 2ae037827..50e012894 100644 --- a/stonesoup/metricgenerator/tracktotruthmetrics.py +++ b/stonesoup/metricgenerator/tracktotruthmetrics.py @@ -384,7 +384,7 @@ def min_num_tracks_needed_to_track(manager, truth): Minimum number of tracks needed to track `truth` """ assocs = sorted(manager.association_set.associations_including_objects([truth]), - key=attrgetter('time_range.key_times[-1]'), + key=attrgetter('time_range.key_times')[-1], reverse=True) if len(assocs) == 0: From d6918a4c9be9cbbe8b58c54918c32c5e7e4fdd23 Mon Sep 17 00:00:00 2001 From: Oliver Rosoman Date: Tue, 30 Aug 2022 15:12:49 +0100 Subject: [PATCH 11/26] improvements --- stonesoup/dataassociator/_assignment.py | 69 +++++++++---------- .../dataassociator/tests/test_assignment.py | 17 +++-- .../tests/test_tracktotruthmetrics.py | 1 - stonesoup/types/association.py | 9 +-- stonesoup/types/time.py | 2 +- 5 files changed, 51 insertions(+), 47 deletions(-) diff --git a/stonesoup/dataassociator/_assignment.py b/stonesoup/dataassociator/_assignment.py index fe6fb39bf..e7a047245 100644 --- a/stonesoup/dataassociator/_assignment.py +++ b/stonesoup/dataassociator/_assignment.py @@ -345,14 +345,10 @@ def multidimensional_deconfliction(association_set): : :class:`AssociationSet` The association set without contradictory associations """ - # Check if there are any conflicts - no_conflicts = True - for assoc1 in association_set: - for assoc2 in association_set: - if conflicts(assoc1, assoc2): - no_conflicts = False - if no_conflicts: + # Check if there are any conflicts. If none we can simply return the input + if check_if_no_conflicts(association_set): return association_set + objects = list(association_set.object_set) length = len(objects) totals = numpy.zeros((length, length)) # Time objects i and j are associated for in total @@ -360,52 +356,47 @@ def multidimensional_deconfliction(association_set): for association in association_set.associations: if len(association.objects) != 2: raise ValueError("Supplied set must only contain pairs of associated objects") - obj_indices = [objects.index(list(association.objects)[0]), - objects.index(list(association.objects)[1])] - totals[obj_indices[0], obj_indices[1]] = association.time_range.duration.total_seconds() - make_symmetric(totals) + i, j = (objects.index(object_) for object_ in association.objects) + totals[i, j] = association.time_range.duration.total_seconds() + make_symmetric(totals) totals = numpy.rint(totals).astype(int) numpy.fill_diagonal(totals, 0) # Don't want to count associations of an object with itself solved_2d = assign2D(totals, maximize=True)[1] - winning_indices = [] # Pairs that are chosen by assign2D - for i in range(length): - if i != solved_2d[i]: - winning_indices.append([i, solved_2d[i]]) cleaned_set = AssociationSet() - if len(winning_indices) == 0: - raise ValueError("Problem unsolvable using this method") - for winner in winning_indices: - assoc = association_set.associations_including_objects({objects[winner[0]], - objects[winner[1]]}) + for i, j in enumerate(solved_2d): + if i == j: + # Can't associate with self + continue + assoc = association_set.associations_including_objects({objects[i], objects[j]}) cleaned_set.add(assoc) association_set.remove(assoc) + if len(cleaned_set) == 0: + raise ValueError("Problem unsolvable using this method") - # Recursive step - runners_up = set() - for assoc1 in association_set.associations: - for assoc2 in association_set.associations: - if conflicts(assoc1, assoc2): - runners_up = multidimensional_deconfliction(association_set).associations + if check_if_no_conflicts(cleaned_set) and len(association_set) == 0: + # If no conflicts after this iteration and all objects return + return cleaned_set + else: + # Recursive step + runners_up = multidimensional_deconfliction(association_set).associations # At this point, none of association_set should conflict with one another for runner_up in runners_up: + runner_up_remaining_time = runner_up.time_range for winner in cleaned_set: if conflicts(runner_up, winner): - runner_up.time_range.minus(winner.time_range) - if runner_up.time_range is not None: - cleaned_set.add(runner_up) - else: - runners_up.remove(runner_up) - + runner_up_remaining_time = runner_up_remaining_time.minus(winner.time_range) + if runner_up_remaining_time and runner_up_remaining_time.duration.total_seconds() > 0: + runner_up_copy = runner_up + runner_up_copy.time_range = runner_up_remaining_time + cleaned_set.add(runner_up_copy) return cleaned_set def conflicts(assoc1, assoc2): - if not hasattr(assoc1, 'time_range') or not hasattr(assoc2, 'time_range'): - raise TypeError("Associations must have a time_range property") - if assoc1.time_range is None or assoc2.time_range is None: + if getattr(assoc1, 'time_range', None) is None or getattr(assoc2, 'time_range', None) is None: return False if assoc1.time_range.overlap(assoc2.time_range) and assoc1 != assoc2 \ and len(assoc1.objects.intersection(assoc2.objects)) > 0: @@ -414,6 +405,14 @@ def conflicts(assoc1, assoc2): return False +def check_if_no_conflicts(association_set): + for assoc1 in association_set: + for assoc2 in association_set: + if conflicts(assoc1, assoc2): + return False + return True + + def make_symmetric(matrix): """Matrix must be square""" length = matrix.shape[0] diff --git a/stonesoup/dataassociator/tests/test_assignment.py b/stonesoup/dataassociator/tests/test_assignment.py index 60a3f7f92..45465216e 100644 --- a/stonesoup/dataassociator/tests/test_assignment.py +++ b/stonesoup/dataassociator/tests/test_assignment.py @@ -24,23 +24,28 @@ def test_multi_deconfliction(): assoc1 = TimeRangeAssociation({tracks[0], tracks[1]}, time_range=ranges[0]) assoc2 = TimeRangeAssociation({tracks[2], tracks[3]}, time_range=ranges[0]) + assoc3 = TimeRangeAssociation({tracks[0], tracks[3]}, + time_range=CompoundTimeRange([ranges[0], ranges[4]])) + assoc4 = TimeRangeAssociation({tracks[0], tracks[1]}, + time_range=CompoundTimeRange([ranges[1], ranges[4]])) + # Will fail as there is only one track, rather than two assoc_fail = TimeRangeAssociation({tracks[0]}, time_range=ranges[0]) with pytest.raises(ValueError): multidimensional_deconfliction(AssociationSet({assoc_fail, assoc1, assoc2})) - # Objects do not conflict, so should do nothing + # Objects do not conflict, so should do nothing test2 = AssociationSet({assoc1, assoc2}) assert multidimensional_deconfliction(test2).associations == {assoc1, assoc2} # Objects do conflict, so remove the shorter one - assoc3 = TimeRangeAssociation({tracks[0], tracks[3]}, - time_range=CompoundTimeRange([ranges[0], ranges[4]])) test3 = AssociationSet({assoc1, assoc3}) # Should entirely remove assoc1 - assert multidimensional_deconfliction(test3).associations == {assoc3} + tested3 = multidimensional_deconfliction(test3) + assert len(tested3) == 1 + test_assoc3 = next(iter(tested3.associations)) + for var in vars(test_assoc3): + assert getattr(test_assoc3, var) == getattr(assoc3, var) - assoc4 = TimeRangeAssociation({tracks[0], tracks[1]}, - time_range=CompoundTimeRange([ranges[1], ranges[4]])) test4 = AssociationSet({assoc1, assoc2, assoc3, assoc4}) # assoc1 and assoc4 should merge together, assoc3 should be removed, and assoc2 should remain tested4 = multidimensional_deconfliction(test4) diff --git a/stonesoup/metricgenerator/tests/test_tracktotruthmetrics.py b/stonesoup/metricgenerator/tests/test_tracktotruthmetrics.py index 1a832eae9..9c7226eb5 100644 --- a/stonesoup/metricgenerator/tests/test_tracktotruthmetrics.py +++ b/stonesoup/metricgenerator/tests/test_tracktotruthmetrics.py @@ -50,7 +50,6 @@ def test_siap(trial_manager, trial_truths, trial_tracks, trial_associations, mea # Test total_time_tracked assert siap_generator.total_time_tracked(trial_manager, trial_truths[0]) == 3 # seconds - print(trial_truths[1]) assert siap_generator.total_time_tracked(trial_manager, trial_truths[1]) == 2 assert siap_generator.total_time_tracked(trial_manager, trial_truths[2]) == 1 assert siap_generator.total_time_tracked(trial_manager, GroundTruthPath()) == 0 diff --git a/stonesoup/types/association.py b/stonesoup/types/association.py index f3257c6e3..3c8f54971 100644 --- a/stonesoup/types/association.py +++ b/stonesoup/types/association.py @@ -101,12 +101,13 @@ def _simplify(self): if assoc1.objects == assoc2.objects: if isinstance(assoc1.time_range, CompoundTimeRange): assoc1.time_range.add(assoc2.time_range) + to_remove.append(assoc2) elif isinstance(assoc2.time_range, CompoundTimeRange): assoc2.time_range.add(assoc1.time_range) - assoc1.time_range = assoc2.time_range + to_remove.append(assoc1) else: assoc1.time_range = CompoundTimeRange([assoc1.time_range, assoc2.time_range]) - to_remove.append(assoc2) + to_remove.append(assoc2) for assoc in to_remove: self.remove(assoc) @@ -130,8 +131,8 @@ def key_times(self): key_times = list(self.overall_time_range.key_times) for association in self.associations: if isinstance(association, SingleTimeAssociation): - key_times.add(association.timestamp) - return sorted(list(key_times)) + key_times.append(association.timestamp) + return sorted(key_times) @property def overall_time_range(self): diff --git a/stonesoup/types/time.py b/stonesoup/types/time.py index bc9dc414a..bd0f762c0 100644 --- a/stonesoup/types/time.py +++ b/stonesoup/types/time.py @@ -183,7 +183,7 @@ def key_times(self): for component in self.time_ranges: key_times.add(component.start_timestamp) key_times.add(component.end_timestamp) - return sorted(list(key_times)) + return sorted(key_times) def _remove_overlap(self): """Removes overlap between components of time_ranges""" From 9bf3646f1ce973874f919836af408af123caaab3 Mon Sep 17 00:00:00 2001 From: Oliver Rosoman Date: Wed, 31 Aug 2022 18:56:04 +0100 Subject: [PATCH 12/26] fix knock on bugs created by TimeRange changes --- .../tests/test_tracktotruthmetrics.py | 4 ++- .../metricgenerator/tracktotruthmetrics.py | 27 ++++++++++++------- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/stonesoup/metricgenerator/tests/test_tracktotruthmetrics.py b/stonesoup/metricgenerator/tests/test_tracktotruthmetrics.py index 9c7226eb5..595405be4 100644 --- a/stonesoup/metricgenerator/tests/test_tracktotruthmetrics.py +++ b/stonesoup/metricgenerator/tests/test_tracktotruthmetrics.py @@ -70,7 +70,9 @@ def test_siap(trial_manager, trial_truths, trial_tracks, trial_associations, mea # Test longest_track_time_on_truth assert siap_generator.longest_track_time_on_truth(trial_manager, trial_truths[0]) == 2 - assert siap_generator.longest_track_time_on_truth(trial_manager, trial_truths[1]) == 1 + # Associations 1 and 2 (starting from 0) will join together + # because of the AssociationSet._simplify method, so this will be 2 + assert siap_generator.longest_track_time_on_truth(trial_manager, trial_truths[1]) == 2 assert siap_generator.longest_track_time_on_truth(trial_manager, trial_truths[2]) == 1 # Test compute_metric diff --git a/stonesoup/metricgenerator/tracktotruthmetrics.py b/stonesoup/metricgenerator/tracktotruthmetrics.py index 50e012894..5bf4c0de1 100644 --- a/stonesoup/metricgenerator/tracktotruthmetrics.py +++ b/stonesoup/metricgenerator/tracktotruthmetrics.py @@ -1,5 +1,3 @@ -from operator import attrgetter - from .base import MetricGenerator from ..base import Property from ..measures import Measure @@ -357,12 +355,14 @@ def total_time_tracked(manager, truth): truth_timestamps = sorted(state.timestamp for state in truth.states) total_time = 0 for current_time, next_time in zip(truth_timestamps[:-1], truth_timestamps[1:]): + time_range = TimeRange(current_time, next_time) for assoc in assocs: - # If both timestamps are in one association then add the difference to the total - # difference and stop looking - if current_time in assoc.time_range and next_time in assoc.time_range: - total_time += (next_time - current_time).total_seconds() - break + # If there is some overlap between time ranges, add this to total_time + if time_range.overlap(assoc.time_range): + total_time += time_range.overlap(assoc.time_range).duration.total_seconds() + time_range = time_range.minus(time_range.overlap(assoc.time_range)) + if not time_range: + break return total_time @staticmethod @@ -384,7 +384,7 @@ def min_num_tracks_needed_to_track(manager, truth): Minimum number of tracks needed to track `truth` """ assocs = sorted(manager.association_set.associations_including_objects([truth]), - key=attrgetter('time_range.key_times')[-1], + key=lambda assoc: assoc.time_range.key_times[-1], reverse=True) if len(assocs) == 0: @@ -401,7 +401,16 @@ def min_num_tracks_needed_to_track(manager, truth): if not assoc_at_time: timestamp_index += 1 else: - end_time = assoc_at_time.time_range.end_timestamp + key_times = assoc_at_time.time_range.key_times + # If the current time is a start of a TimeRange we need strict inequality + try: + is_start_time = key_times.index(current_time) % 2 == 0 + except ValueError: + is_start_time = False + if is_start_time: + end_time = min([time for time in key_times if time > current_time]) + else: + end_time = min([time for time in key_times if time >= current_time]) num_tracks_needed += 1 # If not yet at the end of the truth timestamps indices, move on to the next From ab4aa376d74e7d297ca9d976fe940b3f432f9a00 Mon Sep 17 00:00:00 2001 From: Oliver Rosoman Date: Wed, 31 Aug 2022 19:28:23 +0100 Subject: [PATCH 13/26] Use capitalised List type for python backwards compatibility --- stonesoup/types/time.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/stonesoup/types/time.py b/stonesoup/types/time.py index bd0f762c0..f16916ea7 100644 --- a/stonesoup/types/time.py +++ b/stonesoup/types/time.py @@ -1,5 +1,6 @@ import datetime from itertools import combinations, permutations +from typing import List from ..base import Property from .base import Type @@ -151,7 +152,7 @@ class CompoundTimeRange(Type): A container class representing one or more :class:`TimeRange` objects together """ - time_ranges: list[TimeRange] = Property(doc="List of TimeRange objects. Can be empty", + time_ranges: List[TimeRange] = Property(doc="List of TimeRange objects. Can be empty", default=None) def __init__(self, *args, **kwargs): From e09857fefcf82a883f6c7c07c93cdf65ec4b6733 Mon Sep 17 00:00:00 2001 From: Oliver Rosoman Date: Thu, 1 Sep 2022 15:09:40 +0100 Subject: [PATCH 14/26] documentation --- stonesoup/types/association.py | 11 +++++------ stonesoup/types/time.py | 10 +++++----- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/stonesoup/types/association.py b/stonesoup/types/association.py index 3c8f54971..6ef4f4757 100644 --- a/stonesoup/types/association.py +++ b/stonesoup/types/association.py @@ -127,7 +127,7 @@ def remove(self, association): @property def key_times(self): """Returns all timestamps at which a component starts or ends, or where there is a - :class:`.~SingleTimeAssociation`""" + :class:`.~SingleTimeAssociation`.""" key_times = list(self.overall_time_range.key_times) for association in self.associations: if isinstance(association, SingleTimeAssociation): @@ -136,10 +136,10 @@ def key_times(self): @property def overall_time_range(self): - """Return a :class:`~.CompoundTimeRange` of :class:`~.TimeRange` - objects in this instance. + """Returns a :class:`~.CompoundTimeRange` covering all times at which at least + one association is active. - :class:`SingleTimeAssociation`s are discarded + Note: :class:`SingleTimeAssociation`s are not counted """ overall_range = CompoundTimeRange() for association in self.associations: @@ -149,8 +149,7 @@ def overall_time_range(self): @property def object_set(self): - """Return all objects in the set - Returned as a set + """Returns a set of all objects contained by this instance. """ object_set = set() for assoc in self.associations: diff --git a/stonesoup/types/time.py b/stonesoup/types/time.py index f16916ea7..42dcb9749 100644 --- a/stonesoup/types/time.py +++ b/stonesoup/types/time.py @@ -131,7 +131,7 @@ def overlap(self, time_range): Returns ------- TimeRange - The times contained by both this and time_range + The times contained by both this and `time_range` """ if time_range is None: return None @@ -150,7 +150,7 @@ def overlap(self, time_range): class CompoundTimeRange(Type): """CompoundTimeRange type - A container class representing one or more :class:`TimeRange` objects together + A container class representing one or more :class:`~.TimeRange` objects together """ time_ranges: List[TimeRange] = Property(doc="List of TimeRange objects. Can be empty", default=None) @@ -187,7 +187,7 @@ def key_times(self): return sorted(key_times) def _remove_overlap(self): - """Removes overlap between components of time_ranges""" + """Removes overlap between components of `time_ranges`""" if len(self.time_ranges) in {0, 1}: return if all([component.overlap(component2) is None for (component, component2) in @@ -210,7 +210,7 @@ def _fuse_components(self): self._fuse_components() def add(self, time_range): - """Add a :class:`~.TimeRange` or :class:`~.CompoundTimeRange` object to `time_ranges`""" + """Add a :class:`~.TimeRange` or :class:`~.CompoundTimeRange` object to `time_ranges`.""" if time_range is None: return if isinstance(time_range, CompoundTimeRange): @@ -224,7 +224,7 @@ def add(self, time_range): self._fuse_components() def remove(self, time_range): - """Remove a :class:`.~TimeRange` object from the time ranges. + """Removes a :class:`.~TimeRange` object from the time ranges. It must be a member of self.time_ranges""" if not isinstance(time_range, TimeRange): raise TypeError("Supplied parameter must be a TimeRange object") From 4e1696a07c63d7c5efc27622466b3e7955dee8fe Mon Sep 17 00:00:00 2001 From: Oliver Rosoman Date: Mon, 5 Sep 2022 18:42:22 +0100 Subject: [PATCH 15/26] make TimeRange/s inherit from Interval/s --- stonesoup/metricgenerator/basicmetrics.py | 12 +- .../tests/test_basicmetrics.py | 12 +- .../tests/test_tracktotruthmetrics.py | 8 +- stonesoup/types/association.py | 2 +- stonesoup/types/interval.py | 81 ++++++----- stonesoup/types/tests/test_association.py | 6 +- stonesoup/types/tests/test_interval.py | 22 +-- stonesoup/types/tests/test_metric.py | 8 +- stonesoup/types/tests/test_time.py | 16 +-- stonesoup/types/time.py | 127 +++++++----------- 10 files changed, 130 insertions(+), 164 deletions(-) diff --git a/stonesoup/metricgenerator/basicmetrics.py b/stonesoup/metricgenerator/basicmetrics.py index 0c3f81092..badc7886b 100644 --- a/stonesoup/metricgenerator/basicmetrics.py +++ b/stonesoup/metricgenerator/basicmetrics.py @@ -34,24 +34,24 @@ def compute_metric(self, manager, *args, **kwargs): title='Number of targets', value=len(manager.groundtruth_paths), time_range=TimeRange( - start_timestamp=min(timestamps), - end_timestamp=max(timestamps)), + start=min(timestamps), + end=max(timestamps)), generator=self)) metrics.append(TimeRangeMetric( title='Number of tracks', value=len(manager.tracks), time_range=TimeRange( - start_timestamp=min(timestamps), - end_timestamp=max(timestamps)), + start=min(timestamps), + end=max(timestamps)), generator=self)) metrics.append(TimeRangeMetric( title='Track-to-target ratio', value=len(manager.tracks) / len(manager.groundtruth_paths), time_range=TimeRange( - start_timestamp=min(timestamps), - end_timestamp=max(timestamps)), + start=min(timestamps), + end=max(timestamps)), generator=self)) return metrics diff --git a/stonesoup/metricgenerator/tests/test_basicmetrics.py b/stonesoup/metricgenerator/tests/test_basicmetrics.py index a4b8092ff..91d453a74 100644 --- a/stonesoup/metricgenerator/tests/test_basicmetrics.py +++ b/stonesoup/metricgenerator/tests/test_basicmetrics.py @@ -33,22 +33,22 @@ def test_basicmetrics(): correct_metrics = {TimeRangeMetric(title='Number of targets', value=3, time_range=TimeRange( - start_timestamp=start_time, - end_timestamp=start_time + + start=start_time, + end=start_time + datetime.timedelta(seconds=4)), generator=generator), TimeRangeMetric(title='Number of tracks', value=4, time_range=TimeRange( - start_timestamp=start_time, - end_timestamp=start_time + + start=start_time, + end=start_time + datetime.timedelta(seconds=4)), generator=generator), TimeRangeMetric(title='Track-to-target ratio', value=4 / 3, time_range=TimeRange( - start_timestamp=start_time, - end_timestamp=start_time + + start=start_time, + end=start_time + datetime.timedelta(seconds=4)), generator=generator)} for metric_name in ["Number of targets", diff --git a/stonesoup/metricgenerator/tests/test_tracktotruthmetrics.py b/stonesoup/metricgenerator/tests/test_tracktotruthmetrics.py index 595405be4..bc0e26824 100644 --- a/stonesoup/metricgenerator/tests/test_tracktotruthmetrics.py +++ b/stonesoup/metricgenerator/tests/test_tracktotruthmetrics.py @@ -90,8 +90,8 @@ def test_siap(trial_manager, trial_truths, trial_tracks, trial_associations, mea for metric in metrics: assert isinstance(metric, TimeRangeMetric) - assert metric.time_range.start_timestamp == timestamps[0] - assert metric.time_range.end_timestamp == timestamps[3] + assert metric.time_range.start == timestamps[0] + assert metric.time_range.end == timestamps[3] assert metric.generator == siap_generator if metric.title.endswith(" at times"): @@ -175,8 +175,8 @@ def test_id_siap(trial_manager, trial_truths, trial_tracks, trial_associations, for metric in metrics: assert isinstance(metric, TimeRangeMetric) - assert metric.time_range.start_timestamp == timestamps[0] - assert metric.time_range.end_timestamp == timestamps[3] + assert metric.time_range.start == timestamps[0] + assert metric.time_range.end == timestamps[3] assert metric.generator == siap_generator if metric.title.endswith(" at times"): diff --git a/stonesoup/types/association.py b/stonesoup/types/association.py index 6ef4f4757..7730f4ea4 100644 --- a/stonesoup/types/association.py +++ b/stonesoup/types/association.py @@ -139,7 +139,7 @@ def overall_time_range(self): """Returns a :class:`~.CompoundTimeRange` covering all times at which at least one association is active. - Note: :class:`SingleTimeAssociation`s are not counted + Note: :class:`~.SingleTimeAssociation` are not counted """ overall_range = CompoundTimeRange() for association in self.associations: diff --git a/stonesoup/types/interval.py b/stonesoup/types/interval.py index 6a47ca8b1..b53799e73 100644 --- a/stonesoup/types/interval.py +++ b/stonesoup/types/interval.py @@ -15,37 +15,45 @@ class Interval(Type): Represents a continuous, closed interval of real numbers. Represented by a lower and upper bound. """ - left: Union[int, float] = Property(doc="Lower bound of interval") - right: Union[int, float] = Property(doc="Upper bound of interval") + start: Union[int, float] = Property(doc="Lower bound of interval") + end: Union[int, float] = Property(doc="Upper bound of interval") def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - if self.left >= self.right: + if self.start >= self.end: raise ValueError('Must have left < right') def __hash__(self): - return hash((self.left, self.right)) + return hash((self.start, self.end)) + + @property + def left(self): + return self.start + + @property + def right(self): + return self.end @property def length(self): - return self.right - self.left + return self.end - self.start def __contains__(self, item): if isinstance(item, Real): - return self.left <= item <= self.right + return self.start <= item <= self.end elif isinstance(item, Interval): return self & item == item else: return False def __str__(self): - return '[{left}, {right}]'.format(left=self.left, right=self.right) + return '[{left}, {right}]'.format(left=self.start, right=self.end) - def __repr__(self): - return 'Interval{interval}'.format(interval=str(self)) + # def __repr__(self): + # return 'Interval{interval}'.format(interval=str(self)) def __eq__(self, other): - return isinstance(other, Interval) and (self.left, self.right) == (other.left, other.right) + return isinstance(other, Interval) and (self.start, self.end) == (other.start, other.end) def __and__(self, other): """Set-like intersection""" @@ -54,11 +62,11 @@ def __and__(self, other): raise ValueError("Can only intersect with Interval types") if not self.isdisjoint(other): - new_interval = (max(self.left, other.left), min(self.right, other.right)) + new_interval = max(self.start, other.start), min(self.end, other.end) if new_interval[0] == new_interval[1]: return None else: - return Interval(*new_interval) + return type(self)(new_interval[0], new_interval[1]) else: return None @@ -69,7 +77,7 @@ def __or__(self, other): raise ValueError("Can only union with Interval types") if not self.isdisjoint(other): - return [Interval(min(self.left, other.left), max(self.right, other.right))] + return [type(self)(min(self.start, other.start), max(self.end, other.end))] else: return [copy.copy(self), copy.copy(other)] @@ -82,14 +90,14 @@ def __sub__(self, other): raise ValueError("Can only subtract Interval types from Interval types") elif self.isdisjoint(other): return [copy.copy(self)] - elif other.left <= self.left and self.right <= other.right: + elif other.start <= self.start and self.end <= other.end: return [None] - elif other.left <= self.left: - return [Interval(other.right, self.right)] - elif self.right <= other.right: - return [Interval(self.left, other.left)] + elif other.start <= self.start: + return [Interval(other.end, self.end)] + elif self.end <= other.end: + return [Interval(self.start, other.start)] else: - return [Interval(self.left, other.left), Interval(other.right, self.right)] + return [Interval(self.start, other.start), Interval(other.end, self.end)] def __xor__(self, other): """Set-like symmetric difference""" @@ -117,7 +125,7 @@ def __lt__(self, other): if not isinstance(other, Interval): raise ValueError("Can only compare Interval types to Interval types") - return other.left < self.left and self.right < other.right + return other.start < self.start and self.end < other.end def __ge__(self, other): """Superset check""" @@ -144,12 +152,10 @@ def isdisjoint(self, other): if not isinstance(other, Interval): raise ValueError("Interval types can only overlap with Interval types") - lb = min(self.left, other.left) - ub = max(self.right, other.right) + max_start = max(self.start, other.start) + min_end = min(self.end, other.end) - return not ((ub - lb < self.length + other.length) - or (self.right == other.left) - or (other.right == self.left)) + return max_start > min_end class Intervals(Type): @@ -173,7 +179,7 @@ def __init__(self, *args, **kwargs): if isinstance(self.intervals, (Interval, Tuple)): self.intervals = [self.intervals] else: - raise ValueError("Must contain Interval types") + raise TypeError("Must contain Interval types") elif len(self.intervals) == 2 and all(isinstance(elem, Real) for elem in self.intervals): self.intervals = [self.intervals] @@ -183,7 +189,7 @@ def __init__(self, *args, **kwargs): if isinstance(interval, Sequence) and len(interval) == 2: self.intervals[i] = Interval(*interval) else: - raise ValueError("Individual intervals must be an Interval or Sequence type") + raise TypeError("Individual intervals must be an Interval or Sequence type") self.intervals = self.get_merged_intervals(self.intervals) @@ -209,7 +215,6 @@ def isdisjoint(self, other): if not isinstance(other, Intervals): raise ValueError("Can only compare Intervals to Intervals") - return all(interval.isdisjoint(other_int) for other_int in other for interval in self) @staticmethod @@ -244,10 +249,10 @@ def __contains__(self, item): return any(item in interval for interval in self) def __str__(self): - return str([[interval.left, interval.right] for interval in self]) + return str([[interval.start, interval.end] for interval in self]) - def __repr__(self): - return 'Intervals{intervals}'.format(intervals=str(self)) + # def __repr__(self): + # return 'Intervals{intervals}'.format(intervals=str(self)) @property def length(self): @@ -264,7 +269,8 @@ def __reversed__(self): return self._iter(reverse=True) def __eq__(self, other): - + if len(self) == 0: + return True if len(other) == 0 else False if isinstance(other, Interval): other = Intervals(other) @@ -274,6 +280,8 @@ def __eq__(self, other): def __and__(self, other): """Set-like intersection""" + if other is None: + return False if not isinstance(other, (Interval, Intervals)): raise ValueError("Can only intersect with Intervals types") @@ -287,12 +295,11 @@ def __and__(self, other): if new_interval: new_intervals.append(new_interval) new_intervals = self.get_merged_intervals(new_intervals) - new_intervals = Intervals(new_intervals) + new_intervals = type(self)(new_intervals) return new_intervals def __or__(self, other): """Set-like union""" - if not isinstance(other, (Interval, Intervals)): raise ValueError('Can only union with Intervals types') @@ -301,7 +308,7 @@ def __or__(self, other): new_intervals = self.intervals + other.intervals new_intervals = self.get_merged_intervals(new_intervals) - new_intervals = Intervals(new_intervals) + new_intervals = type(self)(new_intervals) return new_intervals def __sub__(self, other): @@ -325,7 +332,7 @@ def __sub__(self, other): if diff[0] is not None: temp_intervals.extend(diff) new_intervals = temp_intervals - new_intervals = Intervals(new_intervals) + new_intervals = type(self)(new_intervals) return new_intervals def __xor__(self, other): @@ -335,7 +342,7 @@ def __xor__(self, other): raise ValueError("Can only compare Intervals from Intervals") if isinstance(other, Interval): - other = Intervals(other) + other = type(self)(other) return (self | other) - (self & other) diff --git a/stonesoup/types/tests/test_association.py b/stonesoup/types/tests/test_association.py index 337a9105d..bc2e650e8 100644 --- a/stonesoup/types/tests/test_association.py +++ b/stonesoup/types/tests/test_association.py @@ -67,7 +67,7 @@ def test_timerangeassociation(): Detection(np.array([[5], [6]]))} timestamp1 = datetime.datetime(2018, 3, 1, 5, 3, 35) timestamp2 = datetime.datetime(2018, 3, 1, 5, 8, 35) - timerange = TimeRange(start_timestamp=timestamp1, end_timestamp=timestamp2) + timerange = TimeRange(start=timestamp1, end=timestamp2) ctimerange = CompoundTimeRange([timerange]) assoc = TimeRangeAssociation(objects=objects, time_range=timerange) @@ -84,7 +84,7 @@ def test_associationset(): timestamp1 = datetime.datetime(2018, 3, 1, 5, 3, 35) timestamp2 = datetime.datetime(2018, 3, 1, 5, 8, 35) timestamp3 = datetime.datetime(2020, 3, 1, 1, 1, 1) - time_range = TimeRange(start_timestamp=timestamp1, end_timestamp=timestamp2) + time_range = TimeRange(start=timestamp1, end=timestamp2) time_range2 = TimeRange(timestamp2, timestamp3) objects_list = [Detection(np.array([[1], [2]])), @@ -180,7 +180,7 @@ def test_association_set_properties(): assert test2.object_set == set(objects) timestamp1 = datetime.datetime(2018, 3, 1, 5, 3, 35) timestamp2 = datetime.datetime(2018, 3, 1, 5, 8, 35) - time_range = TimeRange(start_timestamp=timestamp1, end_timestamp=timestamp2) + time_range = TimeRange(start=timestamp1, end=timestamp2) com_time_range = CompoundTimeRange([time_range]) assoc2 = TimeRangeAssociation(objects=set(objects[1:]), time_range=com_time_range) diff --git a/stonesoup/types/tests/test_interval.py b/stonesoup/types/tests/test_interval.py index 0da6672a2..ef741380e 100644 --- a/stonesoup/types/tests/test_interval.py +++ b/stonesoup/types/tests/test_interval.py @@ -38,11 +38,6 @@ def test_interval_contains(): assert 'a string' not in a -def test_interval_str(): - assert str(a) == '[{left}, {right}]'.format(left=0, right=1) - assert a.__repr__() == 'Interval{interval}'.format(interval=str(a)) - - def test_interval_eq(): assert a == Interval(0, 1) assert a != (0, 1) @@ -122,6 +117,7 @@ def test_interval_disjoint(): assert a.isdisjoint(b) # No overlap assert not a.isdisjoint(c) # Overlap assert not a.isdisjoint(d) # Meet and no overlap + assert not a.isdisjoint(a) # a is not disjoint with itself def test_intervals_init(): @@ -134,7 +130,7 @@ def test_intervals_init(): assert len(temp.intervals) == 1 assert temp.intervals[0] == a - with pytest.raises(ValueError, match="Must contain Interval types"): + with pytest.raises(TypeError, match="Must contain Interval types"): Intervals('a string') temp = Intervals((0, 1)) @@ -146,7 +142,7 @@ def test_intervals_init(): assert len(A.intervals) == 2 assert A.intervals == [a, b] # Converts lists of tuples to lists of Interval types - with pytest.raises(ValueError, + with pytest.raises(TypeError, match="Individual intervals must be an Interval or Sequence type"): Intervals([(0, 1), 'a string']) @@ -171,11 +167,12 @@ def test_intervals_overlap(): def test_intervals_disjoint(): - with pytest.raises(ValueError, match="Can only compare Intervals to Intervals"): - A.isdisjoint('a string') - assert not A.isdisjoint(B) # Overlap + # with pytest.raises(ValueError, match="Can only compare Intervals to Intervals"): + # A.isdisjoint('a string') + # assert not A.isdisjoint(B) # Overlap assert not A.isdisjoint(C) # Meet with no overlap assert A.isdisjoint(D) + assert not A.isdisjoint(A) def test_intervals_merge(): @@ -196,11 +193,6 @@ def test_intervals_contains(): assert Intervals([Interval(0.25, 0.75)]) in A -def test_intervals_str(): - assert str(A) == str([[interval.left, interval.right] for interval in A]) - assert A.__repr__() == 'Intervals{intervals}'.format(intervals=str(A)) - - def test_intervals_len(): assert Intervals([]).length == 0 assert A.length == 2 diff --git a/stonesoup/types/tests/test_metric.py b/stonesoup/types/tests/test_metric.py index 3058fdb97..aae9b2212 100644 --- a/stonesoup/types/tests/test_metric.py +++ b/stonesoup/types/tests/test_metric.py @@ -85,8 +85,8 @@ def test_timerangemetric(): value = 5 timestamp1 = datetime.datetime.now() timestamp2 = timestamp1 + datetime.timedelta(seconds=10) - time_range = TimeRange(start_timestamp=timestamp1, - end_timestamp=timestamp2) + time_range = TimeRange(start=timestamp1, + end=timestamp2) class temp_generator(MetricGenerator): @@ -113,8 +113,8 @@ def test_timerangeplottingmetric(): value = 5 timestamp1 = datetime.datetime.now() timestamp2 = timestamp1 + datetime.timedelta(seconds=10) - time_range = TimeRange(start_timestamp=timestamp1, - end_timestamp=timestamp2) + time_range = TimeRange(start=timestamp1, + end=timestamp2) class temp_generator(MetricGenerator): diff --git a/stonesoup/types/tests/test_time.py b/stonesoup/types/tests/test_time.py index 996493094..2cfa3f725 100644 --- a/stonesoup/types/tests/test_time.py +++ b/stonesoup/types/tests/test_time.py @@ -28,15 +28,15 @@ def test_timerange(times): # Without start time with pytest.raises(TypeError): - TimeRange(start_timestamp=times[1]) + TimeRange(start=times[1]) # Without end time with pytest.raises(TypeError): - TimeRange(end_timestamp=times[3]) + TimeRange(end=times[3]) # Test an error is caught when end is after start with pytest.raises(ValueError): - TimeRange(start_timestamp=times[3], end_timestamp=times[1]) + TimeRange(start=times[3], end=times[1]) # Test with wrong types for time_ranges with pytest.raises(TypeError): @@ -44,7 +44,7 @@ def test_timerange(times): with pytest.raises(TypeError): CompoundTimeRange([times[1], times[3]]) - test_range = TimeRange(start_timestamp=times[1], end_timestamp=times[3]) + test_range = TimeRange(start=times[1], end=times[3]) test_compound = CompoundTimeRange() @@ -64,13 +64,13 @@ def test_duration(times): # Test that duration is calculated properly # TimeRange - test_range = TimeRange(start_timestamp=times[1], end_timestamp=times[3]) + test_range = TimeRange(start=times[1], end=times[3]) # CompoundTimeRange # times[2] is inside of [1] and [3], so should be equivalent to a TimeRange(times[1], times[4]) - test_range2 = CompoundTimeRange([TimeRange(start_timestamp=times[1], end_timestamp=times[3]), - TimeRange(start_timestamp=times[2], end_timestamp=times[4])]) + test_range2 = CompoundTimeRange([TimeRange(start=times[1], end=times[3]), + TimeRange(start=times[2], end=times[4])]) assert test_range.duration == datetime.timedelta(seconds=3726) assert test_range2.duration == datetime.timedelta(seconds=31539726) @@ -79,7 +79,7 @@ def test_duration(times): def test_contains(times): # Test that timestamps are correctly determined to be in the range - test_range = TimeRange(start_timestamp=times[1], end_timestamp=times[3]) + test_range = TimeRange(start=times[1], end=times[3]) test2 = TimeRange(times[1], times[2]) test3 = TimeRange(times[1], times[4]) diff --git a/stonesoup/types/time.py b/stonesoup/types/time.py index 42dcb9749..9232951cb 100644 --- a/stonesoup/types/time.py +++ b/stonesoup/types/time.py @@ -1,12 +1,11 @@ import datetime from itertools import combinations, permutations -from typing import List from ..base import Property -from .base import Type +from ..types.interval import Interval, Intervals -class TimeRange(Type): +class TimeRange(Interval): """TimeRange type An object representing a time range between two timestamps. @@ -23,24 +22,27 @@ class TimeRange(Type): True """ - start_timestamp: datetime.datetime = Property(doc="Start of the time range") - end_timestamp: datetime.datetime = Property(doc="End of the time range") + start: datetime.datetime = Property(doc="Start of the time range") + end: datetime.datetime = Property(doc="End of the time range") - def __init__(self, start_timestamp, end_timestamp, *args, **kwargs): - if end_timestamp < start_timestamp: - raise ValueError("start_timestamp must be before end_timestamp") - super().__init__(start_timestamp, end_timestamp, *args, **kwargs) + @property + def start_timestamp(self): + return self.start + + @property + def end_timestamp(self): + return self.end @property def duration(self): """Duration of the time range""" - return self.end_timestamp - self.start_timestamp + return self.length @property def key_times(self): """Times the TimeRange begins and ends""" - return [self.start_timestamp, self.end_timestamp] + return [self.start, self.end] def __contains__(self, time): """Checks if timestamp is within range @@ -53,27 +55,16 @@ def __contains__(self, time): Returns ------- bool - `True` if timestamp within :attr:`start_timestamp` and - :attr:`end_timestamp` (inclusive) + `True` if timestamp within :attr:`start` and + :attr:`end` (inclusive) """ if isinstance(time, datetime.datetime): - return self.start_timestamp <= time <= self.end_timestamp - elif isinstance(time, TimeRange): - return self.start_timestamp <= time.start_timestamp and \ - self.end_timestamp >= time.end_timestamp + return self.start <= time <= self.end else: - raise TypeError("Supplied parameter must be a datetime or TimeRange object") + return super().__contains__(time) def __eq__(self, other): - if other is None: - return False - if not isinstance(other, TimeRange): - return False - if self.start_timestamp == other.start_timestamp and \ - self.end_timestamp == other.end_timestamp: - return True - else: - return False + return isinstance(other, TimeRange) and super().__eq__(other) def minus(self, time_range): """Removes the overlap between this instance and another :class:`~.TimeRange`, or @@ -105,18 +96,18 @@ def minus(self, time_range): return self if self == overlap: return None - if self.start_timestamp < overlap.start_timestamp: - start = self.start_timestamp + if self.start < overlap.start: + start = self.start else: - start = overlap.end_timestamp - if self.end_timestamp > overlap.end_timestamp: - end = self.end_timestamp + start = overlap.end + if self.end > overlap.end: + end = self.end else: - end = overlap.start_timestamp - if self.start_timestamp < overlap.start_timestamp and \ - self.end_timestamp > overlap.end_timestamp: - return CompoundTimeRange([TimeRange(self.start_timestamp, overlap.start_timestamp), - TimeRange(overlap.end_timestamp, self.end_timestamp)]) + end = overlap.start + if self.start < overlap.start and \ + self.end > overlap.end: + return CompoundTimeRange([TimeRange(self.start, overlap.start), + TimeRange(overlap.end, self.end)]) else: return TimeRange(start, end) @@ -139,26 +130,20 @@ def overlap(self, time_range): return time_range.overlap(self) if not isinstance(time_range, TimeRange): raise TypeError("Supplied parameter must be a TimeRange object") - start_timestamp = max(self.start_timestamp, time_range.start_timestamp) - end_timestamp = min(self.end_timestamp, time_range.end_timestamp) - if end_timestamp > start_timestamp: - return TimeRange(start_timestamp, end_timestamp) - else: - return None + return super().__and__(time_range) + def __or__(self, other): + return super().__or__(other) -class CompoundTimeRange(Type): + +class CompoundTimeRange(Intervals): """CompoundTimeRange type A container class representing one or more :class:`~.TimeRange` objects together """ - time_ranges: List[TimeRange] = Property(doc="List of TimeRange objects. Can be empty", - default=None) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - if self.time_ranges is None: - self.time_ranges = [] if not isinstance(self.time_ranges, list): raise TypeError("Time_ranges must be a list") for component in self.time_ranges: @@ -167,6 +152,10 @@ def __init__(self, *args, **kwargs): self._remove_overlap() self._fuse_components() + @property + def time_ranges(self): + return self.intervals + @property def duration(self): """Duration of the time range""" @@ -182,8 +171,8 @@ def key_times(self): """Returns all timestamps at which a component starts or ends""" key_times = set() for component in self.time_ranges: - key_times.add(component.start_timestamp) - key_times.add(component.end_timestamp) + key_times.add(component.start) + key_times.add(component.end) return sorted(key_times) def _remove_overlap(self): @@ -201,8 +190,8 @@ def _remove_overlap(self): def _fuse_components(self): """Fuses two time ranges [a,b], [b,c] into [a,c] for all such pairs in this instance""" for (component, component2) in permutations(self.time_ranges, 2): - if component.end_timestamp == component2.start_timestamp: - fused_component = TimeRange(component.start_timestamp, component2.end_timestamp) + if component.end == component2.start: + fused_component = TimeRange(component.start, component2.end) self.remove(component) self.remove(component2) self.add(fused_component) @@ -258,25 +247,14 @@ def __contains__(self, time): if time in component: return True return False - elif isinstance(time, TimeRange): - return True if self.overlap(time) == CompoundTimeRange([time]) else False - elif isinstance(time, CompoundTimeRange): - return True if self.overlap(time) == time else False + elif isinstance(time, (TimeRange, CompoundTimeRange)): + return super().__contains__(time) else: raise TypeError("Supplied parameter must be an instance of either " "datetime, TimeRange, or CompoundTimeRange") def __eq__(self, other): - if other is None: - return False - if not isinstance(other, CompoundTimeRange): - return False - if len(self.time_ranges) == 0: - return True if len(other.time_ranges) == 0 else False - for component in self.time_ranges: - if all([component != component2 for component2 in other.time_ranges]): - return False - return True + return isinstance(other, CompoundTimeRange) and super().__eq__(other) def minus(self, time_range): """Removes any overlap between this and another :class:`~.TimeRange` or @@ -314,19 +292,8 @@ def overlap(self, time_range): """ if time_range is None: return None - total_overlap = CompoundTimeRange() - if isinstance(time_range, CompoundTimeRange): - for component in time_range.time_ranges: - total_overlap.add(self.overlap(component)) - if total_overlap == CompoundTimeRange(): - return None - return total_overlap - elif isinstance(time_range, TimeRange): - for component in self.time_ranges: - total_overlap.add(component.overlap(time_range)) - if total_overlap == CompoundTimeRange(): - return None - return total_overlap - else: + if not isinstance(time_range, (TimeRange, CompoundTimeRange)): raise TypeError("Supplied parameter must be an instance of either " "TimeRange, or CompoundTimeRange") + else: + return super().__and__(time_range) From c7430d9bddaefb770aaa97481393ab98ae3e84a6 Mon Sep 17 00:00:00 2001 From: Oliver Rosoman <95758965+orosoman-dstl@users.noreply.github.com> Date: Fri, 25 Nov 2022 15:37:11 +0000 Subject: [PATCH 16/26] making names consistent with inbuilt/Interval/s --- stonesoup/dataassociator/_assignment.py | 11 ++--- .../metricgenerator/tracktotruthmetrics.py | 6 +-- stonesoup/types/association.py | 18 ++++---- stonesoup/types/interval.py | 3 -- stonesoup/types/tests/test_time.py | 44 +++++++++---------- stonesoup/types/time.py | 27 ++++++------ 6 files changed, 52 insertions(+), 57 deletions(-) diff --git a/stonesoup/dataassociator/_assignment.py b/stonesoup/dataassociator/_assignment.py index e7a047245..337c75074 100644 --- a/stonesoup/dataassociator/_assignment.py +++ b/stonesoup/dataassociator/_assignment.py @@ -1,4 +1,5 @@ import numpy +import copy from ..types.association import AssociationSet @@ -365,13 +366,14 @@ def multidimensional_deconfliction(association_set): solved_2d = assign2D(totals, maximize=True)[1] cleaned_set = AssociationSet() + association_set_reduced = copy.copy(association_set) for i, j in enumerate(solved_2d): if i == j: # Can't associate with self continue assoc = association_set.associations_including_objects({objects[i], objects[j]}) cleaned_set.add(assoc) - association_set.remove(assoc) + association_set_reduced.remove(assoc) if len(cleaned_set) == 0: raise ValueError("Problem unsolvable using this method") @@ -380,14 +382,13 @@ def multidimensional_deconfliction(association_set): return cleaned_set else: # Recursive step - runners_up = multidimensional_deconfliction(association_set).associations + runners_up = multidimensional_deconfliction(association_set_reduced).associations - # At this point, none of association_set should conflict with one another for runner_up in runners_up: runner_up_remaining_time = runner_up.time_range for winner in cleaned_set: if conflicts(runner_up, winner): - runner_up_remaining_time = runner_up_remaining_time.minus(winner.time_range) + runner_up_remaining_time = runner_up_remaining_time - winner.time_range if runner_up_remaining_time and runner_up_remaining_time.duration.total_seconds() > 0: runner_up_copy = runner_up runner_up_copy.time_range = runner_up_remaining_time @@ -398,7 +399,7 @@ def multidimensional_deconfliction(association_set): def conflicts(assoc1, assoc2): if getattr(assoc1, 'time_range', None) is None or getattr(assoc2, 'time_range', None) is None: return False - if assoc1.time_range.overlap(assoc2.time_range) and assoc1 != assoc2 \ + if assoc1.time_range & assoc2.time_range and assoc1 != assoc2 \ and len(assoc1.objects.intersection(assoc2.objects)) > 0: return True else: diff --git a/stonesoup/metricgenerator/tracktotruthmetrics.py b/stonesoup/metricgenerator/tracktotruthmetrics.py index 5bf4c0de1..9d7515664 100644 --- a/stonesoup/metricgenerator/tracktotruthmetrics.py +++ b/stonesoup/metricgenerator/tracktotruthmetrics.py @@ -358,9 +358,9 @@ def total_time_tracked(manager, truth): time_range = TimeRange(current_time, next_time) for assoc in assocs: # If there is some overlap between time ranges, add this to total_time - if time_range.overlap(assoc.time_range): - total_time += time_range.overlap(assoc.time_range).duration.total_seconds() - time_range = time_range.minus(time_range.overlap(assoc.time_range)) + if time_range & assoc.time_range: + total_time += (time_range & assoc.time_range).duration.total_seconds() + time_range = time_range - (time_range & assoc.time_range) if not time_range: break return total_time diff --git a/stonesoup/types/association.py b/stonesoup/types/association.py index 7730f4ea4..cafb37583 100644 --- a/stonesoup/types/association.py +++ b/stonesoup/types/association.py @@ -67,14 +67,12 @@ def __init__(self, associations=None, *args, **kwargs): super().__init__(associations, *args, **kwargs) if self.associations is None: self.associations = set() - if not isinstance(self.associations, Set): - raise TypeError("Supplied parameter must be a set") - if not all([isinstance(member, Association) for member in self.associations]): + if not all(isinstance(member, Association) for member in self.associations): raise TypeError("Association set must contain only Association instances") self._simplify() def __eq__(self, other): - return True if self.associations == other.associations else False + return self.associations == other.associations def add(self, association): if association is None: @@ -92,7 +90,7 @@ def _simplify(self): """Where multiple associations describe the same pair of objects, combine them into one. Note this is only implemented for pairs with a time_range attribute- others will be skipped """ - to_remove = [] + to_remove = set() for (assoc1, assoc2) in combinations(self.associations, 2): if not (len(assoc1.objects) == 2 and len(assoc2.objects) == 2): continue @@ -101,13 +99,13 @@ def _simplify(self): if assoc1.objects == assoc2.objects: if isinstance(assoc1.time_range, CompoundTimeRange): assoc1.time_range.add(assoc2.time_range) - to_remove.append(assoc2) + to_remove.add(assoc2) elif isinstance(assoc2.time_range, CompoundTimeRange): assoc2.time_range.add(assoc1.time_range) - to_remove.append(assoc1) + to_remove.add(assoc1) else: assoc1.time_range = CompoundTimeRange([assoc1.time_range, assoc2.time_range]) - to_remove.append(assoc2) + to_remove.add(assoc2) for assoc in to_remove: self.remove(assoc) @@ -207,8 +205,8 @@ def associations_including_objects(self, objects): return AssociationSet({association for association in self.associations - if all([object_ in association.objects - for object_ in objects])}) + if all(object_ in association.objects + for object_ in objects)}) def __contains__(self, item): return item in self.associations diff --git a/stonesoup/types/interval.py b/stonesoup/types/interval.py index b53799e73..e37d9419e 100644 --- a/stonesoup/types/interval.py +++ b/stonesoup/types/interval.py @@ -251,9 +251,6 @@ def __contains__(self, item): def __str__(self): return str([[interval.start, interval.end] for interval in self]) - # def __repr__(self): - # return 'Intervals{intervals}'.format(intervals=str(self)) - @property def length(self): return sum(interval.length for interval in self) diff --git a/stonesoup/types/tests/test_time.py b/stonesoup/types/tests/test_time.py index 2cfa3f725..aa3d3a04e 100644 --- a/stonesoup/types/tests/test_time.py +++ b/stonesoup/types/tests/test_time.py @@ -112,8 +112,6 @@ def test_contains(times): assert test_range2 not in compound_test assert test_range2 not in compound_test2 assert test_range3 not in compound_test - print(" \n range3", test_range3, "test2", compound_test2) - print(compound_test2.overlap(test_range3)) assert test_range3 in compound_test2 @@ -152,29 +150,29 @@ def test_minus(times): test4 = TimeRange(times[4], times[5]) with pytest.raises(TypeError): - test1.minus(15) + test1 - 15 - assert test1.minus(test2) == test3 - assert test1.minus(None) == test1 - assert test2.minus(test1) is None + assert test1 - test2 == test3 + assert test1 == test1 - None + assert test2 - test1 is None ctest1 = CompoundTimeRange([test2, test4]) ctest2 = CompoundTimeRange([test1, test2]) ctest3 = CompoundTimeRange([test4]) with pytest.raises(TypeError): - ctest1.minus(15) + ctest1 - 15 - assert ctest1.minus(ctest2) == ctest3 - assert ctest1.minus(ctest1) == CompoundTimeRange() - assert ctest3.minus(ctest1) == CompoundTimeRange() + assert ctest1 - ctest2 == ctest3 + assert ctest1 - ctest1 == CompoundTimeRange() + assert ctest3 - ctest1 == CompoundTimeRange() - assert test1.minus(ctest1) == TimeRange(times[2], times[3]) - assert test4.minus(ctest2) == test4 - assert ctest1.minus(test2) == ctest3 + assert test1 - ctest1 == TimeRange(times[2], times[3]) + assert test4 - ctest2 == test4 + assert ctest1 - test2 == ctest3 -def test_overlap(times): +def test_and(times): test1 = TimeRange(times[1], times[3]) test2 = TimeRange(times[1], times[2]) test3 = TimeRange(times[4], times[5]) @@ -182,16 +180,16 @@ def test_overlap(times): ctest1 = CompoundTimeRange([test2, test3]) ctest2 = CompoundTimeRange([test1, test2]) - assert test2.overlap(ctest1) == ctest1.overlap(test2) + assert test2 & ctest1 == ctest1 & test2 - assert test1.overlap(test1) == test1 - assert test1.overlap(None) is None - assert test1.overlap(test2) == test2 - assert test2.overlap(test1) == test2 - assert ctest1.overlap(None) is None - assert ctest1.overlap(test2) == CompoundTimeRange([test2]) - assert ctest1.overlap(ctest2) == CompoundTimeRange([test2]) - assert ctest1.overlap(ctest2) == ctest2.overlap(ctest1) + assert test1 & test1 == test1 + assert test1 & None is None + assert test1 & test2 == test2 + assert test2 & test1 == test2 + assert ctest1 & None is None + assert ctest1 & test2 == CompoundTimeRange([test2]) + assert ctest1 & ctest2 == CompoundTimeRange([test2]) + assert ctest1 & ctest2 == ctest2 & ctest1 def test_key_times(times): diff --git a/stonesoup/types/time.py b/stonesoup/types/time.py index 9232951cb..0b09d7057 100644 --- a/stonesoup/types/time.py +++ b/stonesoup/types/time.py @@ -1,4 +1,5 @@ import datetime +import copy from itertools import combinations, permutations from ..base import Property @@ -66,7 +67,7 @@ def __contains__(self, time): def __eq__(self, other): return isinstance(other, TimeRange) and super().__eq__(other) - def minus(self, time_range): + def __sub__(self, time_range): """Removes the overlap between this instance and another :class:`~.TimeRange`, or :class:`~.CompoundTimeRange`. @@ -80,18 +81,18 @@ def minus(self, time_range): This instance less the overlap """ if time_range is None: - return self + return copy.copy(self) if not isinstance(time_range, TimeRange) and not isinstance(time_range, CompoundTimeRange): raise TypeError("Supplied parameter must be a TimeRange or CompoundTimeRange object") if isinstance(time_range, CompoundTimeRange): ans = self for t_range in time_range.time_ranges: - ans = ans.minus(t_range) + ans = ans - t_range if not ans: return None return ans else: - overlap = self.overlap(time_range) + overlap = self & time_range if overlap is None: return self if self == overlap: @@ -111,7 +112,7 @@ def minus(self, time_range): else: return TimeRange(start, end) - def overlap(self, time_range): + def __and__(self, time_range): """Finds the intersection between this instance and another :class:`~.TimeRange` or :class:`.~CompoundTimeRange` @@ -127,7 +128,7 @@ def overlap(self, time_range): if time_range is None: return None if isinstance(time_range, CompoundTimeRange): - return time_range.overlap(self) + return time_range & self if not isinstance(time_range, TimeRange): raise TypeError("Supplied parameter must be a TimeRange object") return super().__and__(time_range) @@ -179,12 +180,12 @@ def _remove_overlap(self): """Removes overlap between components of `time_ranges`""" if len(self.time_ranges) in {0, 1}: return - if all([component.overlap(component2) is None for (component, component2) in + if all([component & component2 is None for (component, component2) in combinations(self.time_ranges, 2)]): return overlap_check = CompoundTimeRange() for time_range in self.time_ranges: - overlap_check.add(time_range.minus(overlap_check.overlap(time_range))) + overlap_check.add(time_range - overlap_check & time_range) self.time_ranges = overlap_check.time_ranges def _fuse_components(self): @@ -222,7 +223,7 @@ def remove(self, time_range): elif time_range in self: for component in self.time_ranges: if time_range in component: - new = component.minus(time_range) + new = component - time_range self.time_ranges.remove(component) self.add(new) else: @@ -256,7 +257,7 @@ def __contains__(self, time): def __eq__(self, other): return isinstance(other, CompoundTimeRange) and super().__eq__(other) - def minus(self, time_range): + def __sub__(self, time_range): """Removes any overlap between this and another :class:`~.TimeRange` or :class:`.~CompoundTimeRange` from this instance @@ -270,13 +271,13 @@ def minus(self, time_range): The times contained by this but not time_range. May be empty. """ if time_range is None: - return self + return copy.copy(self) ans = CompoundTimeRange() for component in self.time_ranges: - ans.add(component.minus(time_range)) + ans.add(component - time_range) return ans - def overlap(self, time_range): + def __and__(self, time_range): """Finds the intersection between this instance and another time range In the case of an input :class:`~.CompoundTimeRange` this is done recursively. From 5b2c133272741c4b4ffed391c18d97e2b5246425 Mon Sep 17 00:00:00 2001 From: Oliver Rosoman Date: Wed, 28 Jun 2023 16:44:32 +0100 Subject: [PATCH 17/26] fixes in response to RG comments --- stonesoup/dataassociator/_assignment.py | 23 ++++-------- .../dataassociator/tests/test_assignment.py | 18 +++++++-- stonesoup/dataassociator/tracktotrack.py | 3 +- .../metricgenerator/tracktotruthmetrics.py | 6 +-- stonesoup/types/association.py | 7 ++-- stonesoup/types/interval.py | 5 +-- stonesoup/types/tests/test_association.py | 12 +++--- stonesoup/types/tests/test_interval.py | 8 ++-- stonesoup/types/tests/test_time.py | 37 ++++++++----------- stonesoup/types/time.py | 4 +- 10 files changed, 59 insertions(+), 64 deletions(-) diff --git a/stonesoup/dataassociator/_assignment.py b/stonesoup/dataassociator/_assignment.py index 337c75074..2dbc3f4af 100644 --- a/stonesoup/dataassociator/_assignment.py +++ b/stonesoup/dataassociator/_assignment.py @@ -377,7 +377,7 @@ def multidimensional_deconfliction(association_set): if len(cleaned_set) == 0: raise ValueError("Problem unsolvable using this method") - if check_if_no_conflicts(cleaned_set) and len(association_set) == 0: + if check_if_no_conflicts(cleaned_set) and len(association_set_reduced) == 0: # If no conflicts after this iteration and all objects return return cleaned_set else: @@ -397,29 +397,22 @@ def multidimensional_deconfliction(association_set): def conflicts(assoc1, assoc2): - if getattr(assoc1, 'time_range', None) is None or getattr(assoc2, 'time_range', None) is None: - return False - if assoc1.time_range & assoc2.time_range and assoc1 != assoc2 \ - and len(assoc1.objects.intersection(assoc2.objects)) > 0: + if hasattr(assoc1, 'time_range') and hasattr(assoc2, 'time_range') and \ + len(assoc1.objects.intersection(assoc2.objects)) > 0 and \ + len(assoc1.time_range.intersection(assoc2.time_range)) > 0: return True else: return False def check_if_no_conflicts(association_set): - for assoc1 in association_set: - for assoc2 in association_set: - if conflicts(assoc1, assoc2): + for assoc1 in range(0, len(association_set)): + for assoc2 in range(assoc1, len(association_set)): + if conflicts(list(association_set)[assoc1], list(association_set)[assoc2]): return False return True def make_symmetric(matrix): """Matrix must be square""" - length = matrix.shape[0] - for i in range(length): - for j in range(length): - if matrix[i, j] >= matrix[j, i]: - matrix[j, i] = matrix[i, j] - else: - matrix[i, j] = matrix[j, i] + return numpy.maximum(matrix, matrix.transpose()) diff --git a/stonesoup/dataassociator/tests/test_assignment.py b/stonesoup/dataassociator/tests/test_assignment.py index 45465216e..0a511951c 100644 --- a/stonesoup/dataassociator/tests/test_assignment.py +++ b/stonesoup/dataassociator/tests/test_assignment.py @@ -28,16 +28,18 @@ def test_multi_deconfliction(): time_range=CompoundTimeRange([ranges[0], ranges[4]])) assoc4 = TimeRangeAssociation({tracks[0], tracks[1]}, time_range=CompoundTimeRange([ranges[1], ranges[4]])) + assoc4_again = TimeRangeAssociation({tracks[0], tracks[1]}, + time_range=CompoundTimeRange([ranges[1], ranges[4]])) # Will fail as there is only one track, rather than two assoc_fail = TimeRangeAssociation({tracks[0]}, time_range=ranges[0]) with pytest.raises(ValueError): multidimensional_deconfliction(AssociationSet({assoc_fail, assoc1, assoc2})) - # Objects do not conflict, so should do nothing + # Objects do not conflict, so should return input association set test2 = AssociationSet({assoc1, assoc2}) assert multidimensional_deconfliction(test2).associations == {assoc1, assoc2} - # Objects do conflict, so remove the shorter one + # Objects do conflict, so remove the shorter time range test3 = AssociationSet({assoc1, assoc3}) # Should entirely remove assoc1 tested3 = multidimensional_deconfliction(test3) @@ -53,5 +55,15 @@ def test_multi_deconfliction(): assert assoc2 in tested4 merged = tested4.associations_including_objects({tracks[0], tracks[1]}) assert len(merged) == 1 - merged = merged.associations.pop() + merged = next(iter(merged.associations)) + assert merged.time_range == CompoundTimeRange([ranges[0], ranges[1], ranges[4]]) + + test5 = AssociationSet({assoc1, assoc2, assoc3, assoc4, assoc4_again}) + # Very similar to above, but we add a duplicate assoc4, which should have no effect on the result. + tested5 = multidimensional_deconfliction(test5) + assert len(tested5) == 2 + assert assoc2 in tested5 + merged = tested5.associations_including_objects({tracks[0], tracks[1]}) + assert len(merged) == 1 + merged = next(iter(merged.associations)) assert merged.time_range == CompoundTimeRange([ranges[0], ranges[1], ranges[4]]) diff --git a/stonesoup/dataassociator/tracktotrack.py b/stonesoup/dataassociator/tracktotrack.py index f0d5e72c7..21b87561b 100644 --- a/stonesoup/dataassociator/tracktotrack.py +++ b/stonesoup/dataassociator/tracktotrack.py @@ -8,6 +8,7 @@ from ..types.groundtruth import GroundTruthPath from ..types.track import Track from ..types.time import TimeRange +from ._assignment import multidimensional_deconfliction class TrackToTrackCounting(TrackToTrackAssociator): @@ -186,7 +187,7 @@ def associate_tracks(self, tracks_set_1: Set[Track], tracks_set_2: Set[Track]): TimeRange(start_timestamp, end_timestamp))) if self.one_to_one: - return AssociationSet(associations).association_deconflicter() + return multidimensional_deconfliction(AssociationSet(associations)) else: return AssociationSet(associations) diff --git a/stonesoup/metricgenerator/tracktotruthmetrics.py b/stonesoup/metricgenerator/tracktotruthmetrics.py index 9d7515664..6db40e1e8 100644 --- a/stonesoup/metricgenerator/tracktotruthmetrics.py +++ b/stonesoup/metricgenerator/tracktotruthmetrics.py @@ -403,11 +403,7 @@ def min_num_tracks_needed_to_track(manager, truth): else: key_times = assoc_at_time.time_range.key_times # If the current time is a start of a TimeRange we need strict inequality - try: - is_start_time = key_times.index(current_time) % 2 == 0 - except ValueError: - is_start_time = False - if is_start_time: + if current_time in key_times and key_times.index(current_time) % 2 == 0: end_time = min([time for time in key_times if time > current_time]) else: end_time = min([time for time in key_times if time >= current_time]) diff --git a/stonesoup/types/association.py b/stonesoup/types/association.py index cafb37583..f1aaa29b0 100644 --- a/stonesoup/types/association.py +++ b/stonesoup/types/association.py @@ -88,13 +88,12 @@ def add(self, association): def _simplify(self): """Where multiple associations describe the same pair of objects, combine them into one. - Note this is only implemented for pairs with a time_range attribute- others will be skipped + Note this is only implemented for pairs with a time_range attribute - others will be skipped """ to_remove = set() for (assoc1, assoc2) in combinations(self.associations, 2): - if not (len(assoc1.objects) == 2 and len(assoc2.objects) == 2): - continue - if not(hasattr(assoc1, 'time_range') and hasattr(assoc2, 'time_range')): + if not (len(assoc1.objects) == 2 and len(assoc2.objects) == 2) or \ + not(hasattr(assoc1, 'time_range') and hasattr(assoc2, 'time_range')): continue if assoc1.objects == assoc2.objects: if isinstance(assoc1.time_range, CompoundTimeRange): diff --git a/stonesoup/types/interval.py b/stonesoup/types/interval.py index e37d9419e..4bdd840ca 100644 --- a/stonesoup/types/interval.py +++ b/stonesoup/types/interval.py @@ -49,9 +49,6 @@ def __contains__(self, item): def __str__(self): return '[{left}, {right}]'.format(left=self.start, right=self.end) - # def __repr__(self): - # return 'Interval{interval}'.format(interval=str(self)) - def __eq__(self, other): return isinstance(other, Interval) and (self.start, self.end) == (other.start, other.end) @@ -267,7 +264,7 @@ def __reversed__(self): def __eq__(self, other): if len(self) == 0: - return True if len(other) == 0 else False + return len(other) == 0 if isinstance(other, Interval): other = Intervals(other) diff --git a/stonesoup/types/tests/test_association.py b/stonesoup/types/tests/test_association.py index bc2e650e8..acc0b1510 100644 --- a/stonesoup/types/tests/test_association.py +++ b/stonesoup/types/tests/test_association.py @@ -85,7 +85,7 @@ def test_associationset(): timestamp2 = datetime.datetime(2018, 3, 1, 5, 8, 35) timestamp3 = datetime.datetime(2020, 3, 1, 1, 1, 1) time_range = TimeRange(start=timestamp1, end=timestamp2) - time_range2 = TimeRange(timestamp2, timestamp3) + time_range2 = TimeRange(start=timestamp2, end=timestamp3) objects_list = [Detection(np.array([[1], [2]])), Detection(np.array([[3], [4]])), @@ -96,7 +96,7 @@ def test_associationset(): assoc2 = TimeRangeAssociation(objects=set(objects_list[1:]), time_range=time_range) - assoc_duplicate = TimeRangeAssociation(objects=set(objects_list[1:]), + assoc2_same_objects = TimeRangeAssociation(objects=set(objects_list[1:]), time_range=time_range2) assoc_set = AssociationSet({assoc1, assoc2}) @@ -115,7 +115,7 @@ def test_associationset(): # test _simplify method - simplify_test = AssociationSet({assoc1, assoc2, assoc_duplicate}) + simplify_test = AssociationSet({assoc1, assoc2, assoc2_same_objects}) assert len(simplify_test.associations) == 2 @@ -148,17 +148,17 @@ def test_associationset(): def test_association_set_add_remove(): test = AssociationSet() with pytest.raises(TypeError): - test.add("banana") + test.add("a string") with pytest.raises(TypeError): - test.remove("banana") + test.remove("a string") objects = {Detection(np.array([[1], [2]])), Detection(np.array([[3], [4]])), Detection(np.array([[5], [6]]))} assoc = Association(objects) + assert assoc not in test.associations with pytest.raises(ValueError): test.remove(assoc) - assert assoc not in test.associations test.add(assoc) assert assoc in test.associations test.remove(assoc) diff --git a/stonesoup/types/tests/test_interval.py b/stonesoup/types/tests/test_interval.py index ef741380e..df9dfa716 100644 --- a/stonesoup/types/tests/test_interval.py +++ b/stonesoup/types/tests/test_interval.py @@ -167,9 +167,6 @@ def test_intervals_overlap(): def test_intervals_disjoint(): - # with pytest.raises(ValueError, match="Can only compare Intervals to Intervals"): - # A.isdisjoint('a string') - # assert not A.isdisjoint(B) # Overlap assert not A.isdisjoint(C) # Meet with no overlap assert A.isdisjoint(D) assert not A.isdisjoint(A) @@ -199,6 +196,11 @@ def test_intervals_len(): assert Intervals([Interval(0.5, 0.75), Interval(0.1, 0.2)]).length == 0.35 +def test_intervals_str(): + assert str(A) == str([[interval.left, interval.right] for interval in A]) + assert A.__repr__() == 'Intervals{intervals}'.format(intervals=str(A)) + + def test_intervals_iter(): intervals = iter([a, b]) A_iter = iter(A) diff --git a/stonesoup/types/tests/test_time.py b/stonesoup/types/tests/test_time.py index aa3d3a04e..fa4c287d8 100644 --- a/stonesoup/types/tests/test_time.py +++ b/stonesoup/types/tests/test_time.py @@ -28,11 +28,11 @@ def test_timerange(times): # Without start time with pytest.raises(TypeError): - TimeRange(start=times[1]) + TimeRange(end=times[3]) # Without end time with pytest.raises(TypeError): - TimeRange(end=times[3]) + TimeRange(start=times[1]) # Test an error is caught when end is after start with pytest.raises(ValueError): @@ -50,7 +50,7 @@ def test_timerange(times): test_compound2 = CompoundTimeRange([test_range]) - # Tests fuse_components method + # Test fuse_components method fuse_test = CompoundTimeRange([test_range, TimeRange(times[3], times[4])]) assert test_range.start_timestamp == times[1] @@ -72,8 +72,8 @@ def test_duration(times): test_range2 = CompoundTimeRange([TimeRange(start=times[1], end=times[3]), TimeRange(start=times[2], end=times[4])]) - assert test_range.duration == datetime.timedelta(seconds=3726) - assert test_range2.duration == datetime.timedelta(seconds=31539726) + assert test_range.duration == times[3] - times[1] + assert test_range2.duration == times[4] - times[1] def test_contains(times): @@ -86,11 +86,11 @@ def test_contains(times): with pytest.raises(TypeError): 16 in test3 - assert times[2] in test_range - assert not times[4] in test_range - assert not times[0] in test_range assert times[1] in test_range + assert times[2] in test_range assert times[3] in test_range + assert not times[0] in test_range + assert not times[4] in test_range assert test2 in test_range assert test_range not in test2 @@ -120,7 +120,7 @@ def test_equality(times): test2 = TimeRange(times[1], times[2]) test3 = TimeRange(times[1], times[3]) - assert test1 != "stonesoup" + assert test1 != "a string" assert test1 == test2 assert test2 == test1 @@ -129,13 +129,11 @@ def test_equality(times): ctest1 = CompoundTimeRange([test1, test3]) ctest2 = CompoundTimeRange([TimeRange(times[1], times[3])]) - assert ctest2 != "Stonesoup is the best!" + assert ctest2 != "a string" - assert ctest1 == ctest2 - assert ctest2 == ctest1 + assert ctest1 == ctest2 and ctest2 == ctest1 ctest2.add(TimeRange(times[3], times[4])) - assert ctest1 != ctest2 - assert ctest2 != ctest1 + assert ctest1 != ctest2 and ctest2 != ctest1 assert CompoundTimeRange() == CompoundTimeRange() assert ctest1 != CompoundTimeRange() @@ -218,14 +216,11 @@ def test_remove_overlap(times): test3_ro = CompoundTimeRange() test3_ro._remove_overlap() - test1 = CompoundTimeRange([TimeRange(times[0], times[1]), - TimeRange(times[3], times[4])]) - test3 = CompoundTimeRange() - test4 = CompoundTimeRange([TimeRange(times[0], times[4])]) + test2 = CompoundTimeRange([TimeRange(times[0], times[4])]) - assert test1_ro == test1 - assert test2_ro == test4 - assert test3_ro == test3 + assert test1_ro.duration == TimeRange(times[0], times[1]).duration + TimeRange(times[3], times[4]).duration + assert test2_ro == test2 + assert test3_ro == CompoundTimeRange() def test_fuse_components(times): diff --git a/stonesoup/types/time.py b/stonesoup/types/time.py index 0b09d7057..7cc6cc698 100644 --- a/stonesoup/types/time.py +++ b/stonesoup/types/time.py @@ -78,7 +78,7 @@ def __sub__(self, time_range): Returns ------- TimeRange - This instance less the overlap + This instance less the overlap with the other time_range """ if time_range is None: return copy.copy(self) @@ -87,7 +87,7 @@ def __sub__(self, time_range): if isinstance(time_range, CompoundTimeRange): ans = self for t_range in time_range.time_ranges: - ans = ans - t_range + ans -= t_range if not ans: return None return ans From 7ccb3344494137e55e0f9f1f095a4476f831ab73 Mon Sep 17 00:00:00 2001 From: Oliver Rosoman Date: Wed, 28 Jun 2023 17:14:50 +0100 Subject: [PATCH 18/26] flake8, minor bugs --- stonesoup/dataassociator/_assignment.py | 2 +- stonesoup/dataassociator/tests/test_assignment.py | 6 +++--- stonesoup/types/tests/test_time.py | 3 ++- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/stonesoup/dataassociator/_assignment.py b/stonesoup/dataassociator/_assignment.py index 2dbc3f4af..bc5eaeaa8 100644 --- a/stonesoup/dataassociator/_assignment.py +++ b/stonesoup/dataassociator/_assignment.py @@ -399,7 +399,7 @@ def multidimensional_deconfliction(association_set): def conflicts(assoc1, assoc2): if hasattr(assoc1, 'time_range') and hasattr(assoc2, 'time_range') and \ len(assoc1.objects.intersection(assoc2.objects)) > 0 and \ - len(assoc1.time_range.intersection(assoc2.time_range)) > 0: + len(assoc1.time_range & assoc2.time_range)) > 0: return True else: return False diff --git a/stonesoup/dataassociator/tests/test_assignment.py b/stonesoup/dataassociator/tests/test_assignment.py index 0a511951c..03b544c5a 100644 --- a/stonesoup/dataassociator/tests/test_assignment.py +++ b/stonesoup/dataassociator/tests/test_assignment.py @@ -28,7 +28,7 @@ def test_multi_deconfliction(): time_range=CompoundTimeRange([ranges[0], ranges[4]])) assoc4 = TimeRangeAssociation({tracks[0], tracks[1]}, time_range=CompoundTimeRange([ranges[1], ranges[4]])) - assoc4_again = TimeRangeAssociation({tracks[0], tracks[1]}, + a4_clone = TimeRangeAssociation({tracks[0], tracks[1]}, time_range=CompoundTimeRange([ranges[1], ranges[4]])) # Will fail as there is only one track, rather than two assoc_fail = TimeRangeAssociation({tracks[0]}, time_range=ranges[0]) @@ -58,8 +58,8 @@ def test_multi_deconfliction(): merged = next(iter(merged.associations)) assert merged.time_range == CompoundTimeRange([ranges[0], ranges[1], ranges[4]]) - test5 = AssociationSet({assoc1, assoc2, assoc3, assoc4, assoc4_again}) - # Very similar to above, but we add a duplicate assoc4, which should have no effect on the result. + test5 = AssociationSet({assoc1, assoc2, assoc3, assoc4, a4_clone}) + # Very similar to above, but we add a duplicate assoc4 - should have no effect on the result. tested5 = multidimensional_deconfliction(test5) assert len(tested5) == 2 assert assoc2 in tested5 diff --git a/stonesoup/types/tests/test_time.py b/stonesoup/types/tests/test_time.py index fa4c287d8..d408b461c 100644 --- a/stonesoup/types/tests/test_time.py +++ b/stonesoup/types/tests/test_time.py @@ -218,7 +218,8 @@ def test_remove_overlap(times): test2 = CompoundTimeRange([TimeRange(times[0], times[4])]) - assert test1_ro.duration == TimeRange(times[0], times[1]).duration + TimeRange(times[3], times[4]).duration + assert test1_ro.duration == TimeRange(times[0], times[1]).duration + \ + TimeRange(times[3], times[4]).duration assert test2_ro == test2 assert test3_ro == CompoundTimeRange() From 3478bbbf4761689b8e93da0a04f49ccf536bcd29 Mon Sep 17 00:00:00 2001 From: Oliver Rosoman Date: Wed, 28 Jun 2023 17:26:32 +0100 Subject: [PATCH 19/26] remove extra bracket --- stonesoup/dataassociator/_assignment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stonesoup/dataassociator/_assignment.py b/stonesoup/dataassociator/_assignment.py index bc5eaeaa8..be552b6e2 100644 --- a/stonesoup/dataassociator/_assignment.py +++ b/stonesoup/dataassociator/_assignment.py @@ -399,7 +399,7 @@ def multidimensional_deconfliction(association_set): def conflicts(assoc1, assoc2): if hasattr(assoc1, 'time_range') and hasattr(assoc2, 'time_range') and \ len(assoc1.objects.intersection(assoc2.objects)) > 0 and \ - len(assoc1.time_range & assoc2.time_range)) > 0: + len(assoc1.time_range & assoc2.time_range) > 0: return True else: return False From 9bcde6d089afbcb667dbdd36b2d572e93748e9c2 Mon Sep 17 00:00:00 2001 From: Oliver Rosoman Date: Wed, 28 Jun 2023 18:17:23 +0100 Subject: [PATCH 20/26] added two spaces --- stonesoup/dataassociator/tests/test_assignment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stonesoup/dataassociator/tests/test_assignment.py b/stonesoup/dataassociator/tests/test_assignment.py index 03b544c5a..7a1ac7593 100644 --- a/stonesoup/dataassociator/tests/test_assignment.py +++ b/stonesoup/dataassociator/tests/test_assignment.py @@ -29,7 +29,7 @@ def test_multi_deconfliction(): assoc4 = TimeRangeAssociation({tracks[0], tracks[1]}, time_range=CompoundTimeRange([ranges[1], ranges[4]])) a4_clone = TimeRangeAssociation({tracks[0], tracks[1]}, - time_range=CompoundTimeRange([ranges[1], ranges[4]])) + time_range=CompoundTimeRange([ranges[1], ranges[4]])) # Will fail as there is only one track, rather than two assoc_fail = TimeRangeAssociation({tracks[0]}, time_range=ranges[0]) with pytest.raises(ValueError): From 3a77c726efeb9b73812caa53c7cca30f5e1c9223 Mon Sep 17 00:00:00 2001 From: Oliver Rosoman Date: Wed, 28 Jun 2023 18:23:35 +0100 Subject: [PATCH 21/26] fix bug with conflicts method --- stonesoup/dataassociator/_assignment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stonesoup/dataassociator/_assignment.py b/stonesoup/dataassociator/_assignment.py index be552b6e2..fd6ec3e1a 100644 --- a/stonesoup/dataassociator/_assignment.py +++ b/stonesoup/dataassociator/_assignment.py @@ -399,7 +399,7 @@ def multidimensional_deconfliction(association_set): def conflicts(assoc1, assoc2): if hasattr(assoc1, 'time_range') and hasattr(assoc2, 'time_range') and \ len(assoc1.objects.intersection(assoc2.objects)) > 0 and \ - len(assoc1.time_range & assoc2.time_range) > 0: + (assoc1.time_range & assoc2.time_range).duration.total_seconds() > 0: return True else: return False From 7129c0a9411d9410b3d2785f23f2999c6aa8c4bb Mon Sep 17 00:00:00 2001 From: Oliver Rosoman Date: Wed, 12 Jul 2023 17:55:39 +0100 Subject: [PATCH 22/26] fix another bug with conflicts method --- stonesoup/dataassociator/_assignment.py | 13 +++++++------ stonesoup/types/association.py | 2 +- stonesoup/types/tests/test_association.py | 2 +- stonesoup/types/time.py | 5 +++-- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/stonesoup/dataassociator/_assignment.py b/stonesoup/dataassociator/_assignment.py index fd6ec3e1a..cb19e31d7 100644 --- a/stonesoup/dataassociator/_assignment.py +++ b/stonesoup/dataassociator/_assignment.py @@ -348,7 +348,7 @@ def multidimensional_deconfliction(association_set): """ # Check if there are any conflicts. If none we can simply return the input if check_if_no_conflicts(association_set): - return association_set + return copy.copy(association_set) objects = list(association_set.object_set) length = len(objects) @@ -371,8 +371,8 @@ def multidimensional_deconfliction(association_set): if i == j: # Can't associate with self continue - assoc = association_set.associations_including_objects({objects[i], objects[j]}) - cleaned_set.add(assoc) + assoc = association_set_reduced.associations_including_objects({objects[i], objects[j]}) + cleaned_set.add(copy.copy(assoc)) association_set_reduced.remove(assoc) if len(cleaned_set) == 0: raise ValueError("Problem unsolvable using this method") @@ -388,9 +388,9 @@ def multidimensional_deconfliction(association_set): runner_up_remaining_time = runner_up.time_range for winner in cleaned_set: if conflicts(runner_up, winner): - runner_up_remaining_time = runner_up_remaining_time - winner.time_range + runner_up_remaining_time -= winner.time_range if runner_up_remaining_time and runner_up_remaining_time.duration.total_seconds() > 0: - runner_up_copy = runner_up + runner_up_copy = copy.copy(runner_up) runner_up_copy.time_range = runner_up_remaining_time cleaned_set.add(runner_up_copy) return cleaned_set @@ -399,7 +399,8 @@ def multidimensional_deconfliction(association_set): def conflicts(assoc1, assoc2): if hasattr(assoc1, 'time_range') and hasattr(assoc2, 'time_range') and \ len(assoc1.objects.intersection(assoc2.objects)) > 0 and \ - (assoc1.time_range & assoc2.time_range).duration.total_seconds() > 0: + (assoc1.time_range & assoc2.time_range).duration.total_seconds() > 0 and \ + assoc1 != assoc2: return True else: return False diff --git a/stonesoup/types/association.py b/stonesoup/types/association.py index f1aaa29b0..3fc817845 100644 --- a/stonesoup/types/association.py +++ b/stonesoup/types/association.py @@ -88,7 +88,7 @@ def add(self, association): def _simplify(self): """Where multiple associations describe the same pair of objects, combine them into one. - Note this is only implemented for pairs with a time_range attribute - others will be skipped + This is only implemented for pairs with a time_range attribute - others will be skipped """ to_remove = set() for (assoc1, assoc2) in combinations(self.associations, 2): diff --git a/stonesoup/types/tests/test_association.py b/stonesoup/types/tests/test_association.py index acc0b1510..7056223b9 100644 --- a/stonesoup/types/tests/test_association.py +++ b/stonesoup/types/tests/test_association.py @@ -97,7 +97,7 @@ def test_associationset(): assoc2 = TimeRangeAssociation(objects=set(objects_list[1:]), time_range=time_range) assoc2_same_objects = TimeRangeAssociation(objects=set(objects_list[1:]), - time_range=time_range2) + time_range=time_range2) assoc_set = AssociationSet({assoc1, assoc2}) diff --git a/stonesoup/types/time.py b/stonesoup/types/time.py index 7cc6cc698..97dc95da9 100644 --- a/stonesoup/types/time.py +++ b/stonesoup/types/time.py @@ -185,8 +185,9 @@ def _remove_overlap(self): return overlap_check = CompoundTimeRange() for time_range in self.time_ranges: - overlap_check.add(time_range - overlap_check & time_range) - self.time_ranges = overlap_check.time_ranges + if time_range - overlap_check: + overlap_check.add(time_range - overlap_check & time_range) + self.intervals = copy.copy(overlap_check.time_ranges) def _fuse_components(self): """Fuses two time ranges [a,b], [b,c] into [a,c] for all such pairs in this instance""" From c2b645a0047e6142f7c987c5b702a9dc95d991f6 Mon Sep 17 00:00:00 2001 From: Oliver Rosoman Date: Tue, 18 Jul 2023 12:07:17 +0100 Subject: [PATCH 23/26] Final(?) bug fixes, clean-up --- stonesoup/dataassociator/_assignment.py | 1 - stonesoup/types/tests/test_interval.py | 8 +++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/stonesoup/dataassociator/_assignment.py b/stonesoup/dataassociator/_assignment.py index cb19e31d7..96fa6ae7a 100644 --- a/stonesoup/dataassociator/_assignment.py +++ b/stonesoup/dataassociator/_assignment.py @@ -346,7 +346,6 @@ def multidimensional_deconfliction(association_set): : :class:`AssociationSet` The association set without contradictory associations """ - # Check if there are any conflicts. If none we can simply return the input if check_if_no_conflicts(association_set): return copy.copy(association_set) diff --git a/stonesoup/types/tests/test_interval.py b/stonesoup/types/tests/test_interval.py index df9dfa716..026115e32 100644 --- a/stonesoup/types/tests/test_interval.py +++ b/stonesoup/types/tests/test_interval.py @@ -198,7 +198,13 @@ def test_intervals_len(): def test_intervals_str(): assert str(A) == str([[interval.left, interval.right] for interval in A]) - assert A.__repr__() == 'Intervals{intervals}'.format(intervals=str(A)) + assert A.__repr__() == ('Intervals(\n' + ' intervals=[Interval(\n' + ' start=0,\n' + ' end=1),\n' + ' Interval(\n' + ' start=2,\n' + ' end=3)])') def test_intervals_iter(): From 2e234fc086281381388d4b804cee8f0a90edcbd6 Mon Sep 17 00:00:00 2001 From: Oliver Rosoman Date: Wed, 9 Aug 2023 16:03:57 +0100 Subject: [PATCH 24/26] Bug fixes in multidimensional deconfliction method and related --- stonesoup/dataassociator/_assignment.py | 26 +++++++++++-------- .../dataassociator/tests/test_assignment.py | 13 +++++++--- stonesoup/types/association.py | 7 ++--- 3 files changed, 28 insertions(+), 18 deletions(-) diff --git a/stonesoup/dataassociator/_assignment.py b/stonesoup/dataassociator/_assignment.py index 96fa6ae7a..1974456e5 100644 --- a/stonesoup/dataassociator/_assignment.py +++ b/stonesoup/dataassociator/_assignment.py @@ -358,11 +358,10 @@ def multidimensional_deconfliction(association_set): raise ValueError("Supplied set must only contain pairs of associated objects") i, j = (objects.index(object_) for object_ in association.objects) totals[i, j] = association.time_range.duration.total_seconds() - make_symmetric(totals) + totals = numpy.maximum(totals, totals.transpose()) # make symmetric totals = numpy.rint(totals).astype(int) numpy.fill_diagonal(totals, 0) # Don't want to count associations of an object with itself - solved_2d = assign2D(totals, maximize=True)[1] cleaned_set = AssociationSet() association_set_reduced = copy.copy(association_set) @@ -370,13 +369,23 @@ def multidimensional_deconfliction(association_set): if i == j: # Can't associate with self continue - assoc = association_set_reduced.associations_including_objects({objects[i], objects[j]}) - cleaned_set.add(copy.copy(assoc)) - association_set_reduced.remove(assoc) + try: + assoc = next(iter(association_set_reduced.associations_including_objects( + {objects[i], objects[j]}))) # There should only be 1 association in this set + except StopIteration: + # We took the association out previously in the loop + continue + if all(assoc.duration > clean_assoc.duration or not conflicts(assoc, clean_assoc) + for clean_assoc in cleaned_set): + cleaned_set.add(copy.copy(assoc)) + association_set_reduced.remove(assoc) + if len(cleaned_set) == 0: raise ValueError("Problem unsolvable using this method") - if check_if_no_conflicts(cleaned_set) and len(association_set_reduced) == 0: + if len(association_set_reduced) == 0: + if check_if_no_conflicts(cleaned_set): + raise RuntimeError("Conflicts still present in cleaned set") # If no conflicts after this iteration and all objects return return cleaned_set else: @@ -411,8 +420,3 @@ def check_if_no_conflicts(association_set): if conflicts(list(association_set)[assoc1], list(association_set)[assoc2]): return False return True - - -def make_symmetric(matrix): - """Matrix must be square""" - return numpy.maximum(matrix, matrix.transpose()) diff --git a/stonesoup/dataassociator/tests/test_assignment.py b/stonesoup/dataassociator/tests/test_assignment.py index 7a1ac7593..280903d17 100644 --- a/stonesoup/dataassociator/tests/test_assignment.py +++ b/stonesoup/dataassociator/tests/test_assignment.py @@ -1,15 +1,20 @@ from ...types.association import AssociationSet, TimeRangeAssociation from ...types.time import TimeRange, CompoundTimeRange from ...types.track import Track -from .._assignment import multidimensional_deconfliction +from .._assignment import multidimensional_deconfliction, check_if_no_conflicts import datetime import pytest +def is_assoc_in_assoc_set(assoc, assoc_set): + return any(assoc.time_range == set_assoc.time_range and + assoc.objects == set_assoc.objects for set_assoc in assoc_set) + + def test_multi_deconfliction(): test = AssociationSet() tested = multidimensional_deconfliction(test) - assert test == tested + assert test.associations == tested.associations tracks = [Track(id=0), Track(id=1), Track(id=2), Track(id=3)] times = [datetime.datetime(year=2022, month=6, day=1, hour=0), datetime.datetime(year=2022, month=6, day=1, hour=1), @@ -52,7 +57,7 @@ def test_multi_deconfliction(): # assoc1 and assoc4 should merge together, assoc3 should be removed, and assoc2 should remain tested4 = multidimensional_deconfliction(test4) assert len(tested4) == 2 - assert assoc2 in tested4 + assert is_assoc_in_assoc_set(assoc2, tested4) merged = tested4.associations_including_objects({tracks[0], tracks[1]}) assert len(merged) == 1 merged = next(iter(merged.associations)) @@ -62,7 +67,7 @@ def test_multi_deconfliction(): # Very similar to above, but we add a duplicate assoc4 - should have no effect on the result. tested5 = multidimensional_deconfliction(test5) assert len(tested5) == 2 - assert assoc2 in tested5 + assert is_assoc_in_assoc_set(assoc2, tested5) merged = tested5.associations_including_objects({tracks[0], tracks[1]}) assert len(merged) == 1 merged = next(iter(merged.associations)) diff --git a/stonesoup/types/association.py b/stonesoup/types/association.py index 3fc817845..0e89c5747 100644 --- a/stonesoup/types/association.py +++ b/stonesoup/types/association.py @@ -52,6 +52,10 @@ class TimeRangeAssociation(Association): time_range: Union[CompoundTimeRange, TimeRange] = Property( default=None, doc="Range of times that association exists over. Default is None") + @property + def duration(self): + return self.time_range.duration.total_seconds() + class AssociationSet(Type): """AssociationSet type @@ -71,9 +75,6 @@ def __init__(self, associations=None, *args, **kwargs): raise TypeError("Association set must contain only Association instances") self._simplify() - def __eq__(self, other): - return self.associations == other.associations - def add(self, association): if association is None: return From 7251e5fe6b0f288e99edb569c1de5f9b14f88f0f Mon Sep 17 00:00:00 2001 From: Oliver Rosoman Date: Wed, 9 Aug 2023 16:16:53 +0100 Subject: [PATCH 25/26] Flake-8 fixes --- stonesoup/dataassociator/tests/test_assignment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stonesoup/dataassociator/tests/test_assignment.py b/stonesoup/dataassociator/tests/test_assignment.py index 280903d17..091b5201b 100644 --- a/stonesoup/dataassociator/tests/test_assignment.py +++ b/stonesoup/dataassociator/tests/test_assignment.py @@ -1,7 +1,7 @@ from ...types.association import AssociationSet, TimeRangeAssociation from ...types.time import TimeRange, CompoundTimeRange from ...types.track import Track -from .._assignment import multidimensional_deconfliction, check_if_no_conflicts +from .._assignment import multidimensional_deconfliction import datetime import pytest From 5045a47470aa47691113e63e3deb4195d04d5a5c Mon Sep 17 00:00:00 2001 From: Oliver Rosoman Date: Wed, 9 Aug 2023 16:23:49 +0100 Subject: [PATCH 26/26] Add back in AssociationSet.__eq__ --- stonesoup/types/association.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/stonesoup/types/association.py b/stonesoup/types/association.py index 0e89c5747..cd9cabcfb 100644 --- a/stonesoup/types/association.py +++ b/stonesoup/types/association.py @@ -75,6 +75,9 @@ def __init__(self, associations=None, *args, **kwargs): raise TypeError("Association set must contain only Association instances") self._simplify() + def __eq__(self, other): + return self.associations == other.associations + def add(self, association): if association is None: return