diff --git a/stonesoup/dataassociator/_assignment.py b/stonesoup/dataassociator/_assignment.py new file mode 100644 index 000000000..f29229eda --- /dev/null +++ b/stonesoup/dataassociator/_assignment.py @@ -0,0 +1,108 @@ +import copy + +import numpy as np +from scipy.optimize import linear_sum_assignment + +from ..types.association import AssociationSet + + +def multidimensional_deconfliction(association_set): + """Solves the Multidimensional Assignment Problem (MAP) + + The assignment problem becomes more complex when time is added as a dimension. + This basic solution finds all the conflicts in an association set and then creates a + matrix of sums of conflicts in seconds, which is then passed to linear_sum_assignment to + solve as a simple 2D assignment problem. + Therefore, each object will only ever be assigned to one other + at any one time. In the case of an association that only partially overlaps, the time range + of the "weaker" one (the one eliminated by assign2D) will be trimmed + until there is no conflict. + + Due to the possibility of more than two conflicting associations at the same time, + this algorithm is recursive, but it is not expected many (if any) recursions will be required + for most uses. + + Parameters + ---------- + association_set: The :class:`AssociationSet` to de-conflict + + + Returns + ------- + : :class:`AssociationSet` + The association set without contradictory associations + """ + if check_if_no_conflicts(association_set): + return copy.copy(association_set) + + objects = list(association_set.object_set) + length = len(objects) + totals = np.zeros((length, length)) # Time objects i and j are associated for in total + + for association in association_set.associations: + if len(association.objects) != 2: + raise ValueError("Supplied set must only contain pairs of associated objects") + i, j = (objects.index(object_) for object_ in association.objects) + totals[i, j] = association.time_range.duration.total_seconds() + totals = np.maximum(totals, totals.transpose()) # make symmetric + + totals = np.rint(totals).astype(int) + np.fill_diagonal(totals, 0) # Don't want to count associations of an object with itself + solved_2d = linear_sum_assignment(totals, maximize=True)[1] + cleaned_set = AssociationSet() + association_set_reduced = copy.copy(association_set) + for i, j in enumerate(solved_2d): + if i == j: + # Can't associate with self + continue + try: + assoc = next(iter(association_set_reduced.associations_including_objects( + {objects[i], objects[j]}))) # There should only be 1 association in this set + except StopIteration: + # We took the association out previously in the loop + continue + if all(assoc.duration > clean_assoc.duration or not conflicts(assoc, clean_assoc) + for clean_assoc in cleaned_set): + cleaned_set.add(copy.copy(assoc)) + association_set_reduced.remove(assoc) + + if len(cleaned_set) == 0: + raise ValueError("Problem unsolvable using this method") + + if len(association_set_reduced) == 0: + if check_if_no_conflicts(cleaned_set): + raise RuntimeError("Conflicts still present in cleaned set") + # If no conflicts after this iteration and all objects return + return cleaned_set + else: + # Recursive step + runners_up = multidimensional_deconfliction(association_set_reduced).associations + + for runner_up in runners_up: + runner_up_remaining_time = runner_up.time_range + for winner in cleaned_set: + if conflicts(runner_up, winner): + runner_up_remaining_time -= winner.time_range + if runner_up_remaining_time and runner_up_remaining_time.duration.total_seconds() > 0: + runner_up_copy = copy.copy(runner_up) + runner_up_copy.time_range = runner_up_remaining_time + cleaned_set.add(runner_up_copy) + return cleaned_set + + +def conflicts(assoc1, assoc2): + if hasattr(assoc1, 'time_range') and hasattr(assoc2, 'time_range') and \ + len(assoc1.objects.intersection(assoc2.objects)) > 0 and \ + (assoc1.time_range & assoc2.time_range).duration.total_seconds() > 0 and \ + assoc1 != assoc2: + return True + else: + return False + + +def check_if_no_conflicts(association_set): + for assoc1 in range(0, len(association_set)): + for assoc2 in range(assoc1, len(association_set)): + if conflicts(list(association_set)[assoc1], list(association_set)[assoc2]): + return False + return True diff --git a/stonesoup/dataassociator/tests/test_assignment.py b/stonesoup/dataassociator/tests/test_assignment.py new file mode 100644 index 000000000..091b5201b --- /dev/null +++ b/stonesoup/dataassociator/tests/test_assignment.py @@ -0,0 +1,74 @@ +from ...types.association import AssociationSet, TimeRangeAssociation +from ...types.time import TimeRange, CompoundTimeRange +from ...types.track import Track +from .._assignment import multidimensional_deconfliction +import datetime +import pytest + + +def is_assoc_in_assoc_set(assoc, assoc_set): + return any(assoc.time_range == set_assoc.time_range and + assoc.objects == set_assoc.objects for set_assoc in assoc_set) + + +def test_multi_deconfliction(): + test = AssociationSet() + tested = multidimensional_deconfliction(test) + assert test.associations == tested.associations + tracks = [Track(id=0), Track(id=1), Track(id=2), Track(id=3)] + times = [datetime.datetime(year=2022, month=6, day=1, hour=0), + datetime.datetime(year=2022, month=6, day=1, hour=1), + datetime.datetime(year=2022, month=6, day=1, hour=5), + datetime.datetime(year=2022, month=6, day=2, hour=5), + datetime.datetime(year=2022, month=6, day=1, hour=9)] + ranges = [TimeRange(times[0], times[1]), + TimeRange(times[1], times[2]), + TimeRange(times[2], times[3]), + TimeRange(times[0], times[4]), + TimeRange(times[2], times[4])] + + assoc1 = TimeRangeAssociation({tracks[0], tracks[1]}, time_range=ranges[0]) + assoc2 = TimeRangeAssociation({tracks[2], tracks[3]}, time_range=ranges[0]) + assoc3 = TimeRangeAssociation({tracks[0], tracks[3]}, + time_range=CompoundTimeRange([ranges[0], ranges[4]])) + assoc4 = TimeRangeAssociation({tracks[0], tracks[1]}, + time_range=CompoundTimeRange([ranges[1], ranges[4]])) + a4_clone = TimeRangeAssociation({tracks[0], tracks[1]}, + time_range=CompoundTimeRange([ranges[1], ranges[4]])) + # Will fail as there is only one track, rather than two + assoc_fail = TimeRangeAssociation({tracks[0]}, time_range=ranges[0]) + with pytest.raises(ValueError): + multidimensional_deconfliction(AssociationSet({assoc_fail, assoc1, assoc2})) + + # Objects do not conflict, so should return input association set + test2 = AssociationSet({assoc1, assoc2}) + assert multidimensional_deconfliction(test2).associations == {assoc1, assoc2} + + # Objects do conflict, so remove the shorter time range + test3 = AssociationSet({assoc1, assoc3}) + # Should entirely remove assoc1 + tested3 = multidimensional_deconfliction(test3) + assert len(tested3) == 1 + test_assoc3 = next(iter(tested3.associations)) + for var in vars(test_assoc3): + assert getattr(test_assoc3, var) == getattr(assoc3, var) + + test4 = AssociationSet({assoc1, assoc2, assoc3, assoc4}) + # assoc1 and assoc4 should merge together, assoc3 should be removed, and assoc2 should remain + tested4 = multidimensional_deconfliction(test4) + assert len(tested4) == 2 + assert is_assoc_in_assoc_set(assoc2, tested4) + merged = tested4.associations_including_objects({tracks[0], tracks[1]}) + assert len(merged) == 1 + merged = next(iter(merged.associations)) + assert merged.time_range == CompoundTimeRange([ranges[0], ranges[1], ranges[4]]) + + test5 = AssociationSet({assoc1, assoc2, assoc3, assoc4, a4_clone}) + # Very similar to above, but we add a duplicate assoc4 - should have no effect on the result. + tested5 = multidimensional_deconfliction(test5) + assert len(tested5) == 2 + assert is_assoc_in_assoc_set(assoc2, tested5) + merged = tested5.associations_including_objects({tracks[0], tracks[1]}) + assert len(merged) == 1 + merged = next(iter(merged.associations)) + assert merged.time_range == CompoundTimeRange([ranges[0], ranges[1], ranges[4]]) diff --git a/stonesoup/dataassociator/tests/test_tracktotrack.py b/stonesoup/dataassociator/tests/test_tracktotrack.py index 9f885dc57..18cad54c1 100644 --- a/stonesoup/dataassociator/tests/test_tracktotrack.py +++ b/stonesoup/dataassociator/tests/test_tracktotrack.py @@ -94,6 +94,16 @@ def test_euclidiantracktotrack(tracks): association_set_4 = complete_associator.associate_tracks({tracks[0]}, {tracks[5]}) + complete_associator_one2one = TrackToTrackCounting( + association_threshold=10, + consec_pairs_confirm=3, + consec_misses_end=2, + use_positional_only=False) + start_time = datetime.datetime(2019, 1, 1, 14, 0, 0) + + association_set_one2one = complete_associator_one2one.associate_tracks( + {tracks[0], tracks[2]}, {tracks[1], tracks[3], tracks[4]}) + assert len(association_set_1.associations) == 1 assoc1 = list(association_set_1.associations)[0] assert set(assoc1.objects) == {tracks[0], tracks[1]} @@ -121,6 +131,15 @@ def test_euclidiantracktotrack(tracks): assert assoc4.time_range.end_timestamp \ == start_time + datetime.timedelta(seconds=7) + assert len(association_set_one2one) == 1 + assoc5 = list(association_set_one2one)[0] + # assoc5 should be equal to assoc1 + assert set(assoc5.objects) == {tracks[0], tracks[1]} + assert assoc5.time_range.start_timestamp \ + == start_time + datetime.timedelta(seconds=1) + assert assoc5.time_range.end_timestamp \ + == start_time + datetime.timedelta(seconds=6) + def test_euclidiantracktotruth(tracks): associator = TrackToTruth( diff --git a/stonesoup/dataassociator/tracktotrack.py b/stonesoup/dataassociator/tracktotrack.py index d6e4a622a..21b87561b 100644 --- a/stonesoup/dataassociator/tracktotrack.py +++ b/stonesoup/dataassociator/tracktotrack.py @@ -8,6 +8,7 @@ from ..types.groundtruth import GroundTruthPath from ..types.track import Track from ..types.time import TimeRange +from ._assignment import multidimensional_deconfliction class TrackToTrackCounting(TrackToTrackAssociator): @@ -77,6 +78,11 @@ class TrackToTrackCounting(TrackToTrackAssociator): "position components compared to others (such as velocity). " "Default is 0.6" ) + one_to_one: bool = Property( + default=False, + doc="If True, it is ensured no two associations ever contain the same track " + "at the same time" + ) def associate_tracks(self, tracks_set_1: Set[Track], tracks_set_2: Set[Track]): """Associate two sets of tracks together. @@ -180,7 +186,10 @@ def associate_tracks(self, tracks_set_1: Set[Track], tracks_set_2: Set[Track]): (track1, track2), TimeRange(start_timestamp, end_timestamp))) - return AssociationSet(associations) + if self.one_to_one: + return multidimensional_deconfliction(AssociationSet(associations)) + else: + return AssociationSet(associations) class TrackToTruth(TrackToTrackAssociator): diff --git a/stonesoup/metricgenerator/basicmetrics.py b/stonesoup/metricgenerator/basicmetrics.py index 0c3f81092..badc7886b 100644 --- a/stonesoup/metricgenerator/basicmetrics.py +++ b/stonesoup/metricgenerator/basicmetrics.py @@ -34,24 +34,24 @@ def compute_metric(self, manager, *args, **kwargs): title='Number of targets', value=len(manager.groundtruth_paths), time_range=TimeRange( - start_timestamp=min(timestamps), - end_timestamp=max(timestamps)), + start=min(timestamps), + end=max(timestamps)), generator=self)) metrics.append(TimeRangeMetric( title='Number of tracks', value=len(manager.tracks), time_range=TimeRange( - start_timestamp=min(timestamps), - end_timestamp=max(timestamps)), + start=min(timestamps), + end=max(timestamps)), generator=self)) metrics.append(TimeRangeMetric( title='Track-to-target ratio', value=len(manager.tracks) / len(manager.groundtruth_paths), time_range=TimeRange( - start_timestamp=min(timestamps), - end_timestamp=max(timestamps)), + start=min(timestamps), + end=max(timestamps)), generator=self)) return metrics diff --git a/stonesoup/metricgenerator/tests/test_basicmetrics.py b/stonesoup/metricgenerator/tests/test_basicmetrics.py index a4b8092ff..91d453a74 100644 --- a/stonesoup/metricgenerator/tests/test_basicmetrics.py +++ b/stonesoup/metricgenerator/tests/test_basicmetrics.py @@ -33,22 +33,22 @@ def test_basicmetrics(): correct_metrics = {TimeRangeMetric(title='Number of targets', value=3, time_range=TimeRange( - start_timestamp=start_time, - end_timestamp=start_time + + start=start_time, + end=start_time + datetime.timedelta(seconds=4)), generator=generator), TimeRangeMetric(title='Number of tracks', value=4, time_range=TimeRange( - start_timestamp=start_time, - end_timestamp=start_time + + start=start_time, + end=start_time + datetime.timedelta(seconds=4)), generator=generator), TimeRangeMetric(title='Track-to-target ratio', value=4 / 3, time_range=TimeRange( - start_timestamp=start_time, - end_timestamp=start_time + + start=start_time, + end=start_time + datetime.timedelta(seconds=4)), generator=generator)} for metric_name in ["Number of targets", diff --git a/stonesoup/metricgenerator/tests/test_tracktotruthmetrics.py b/stonesoup/metricgenerator/tests/test_tracktotruthmetrics.py index 9c7226eb5..bc0e26824 100644 --- a/stonesoup/metricgenerator/tests/test_tracktotruthmetrics.py +++ b/stonesoup/metricgenerator/tests/test_tracktotruthmetrics.py @@ -70,7 +70,9 @@ def test_siap(trial_manager, trial_truths, trial_tracks, trial_associations, mea # Test longest_track_time_on_truth assert siap_generator.longest_track_time_on_truth(trial_manager, trial_truths[0]) == 2 - assert siap_generator.longest_track_time_on_truth(trial_manager, trial_truths[1]) == 1 + # Associations 1 and 2 (starting from 0) will join together + # because of the AssociationSet._simplify method, so this will be 2 + assert siap_generator.longest_track_time_on_truth(trial_manager, trial_truths[1]) == 2 assert siap_generator.longest_track_time_on_truth(trial_manager, trial_truths[2]) == 1 # Test compute_metric @@ -88,8 +90,8 @@ def test_siap(trial_manager, trial_truths, trial_tracks, trial_associations, mea for metric in metrics: assert isinstance(metric, TimeRangeMetric) - assert metric.time_range.start_timestamp == timestamps[0] - assert metric.time_range.end_timestamp == timestamps[3] + assert metric.time_range.start == timestamps[0] + assert metric.time_range.end == timestamps[3] assert metric.generator == siap_generator if metric.title.endswith(" at times"): @@ -173,8 +175,8 @@ def test_id_siap(trial_manager, trial_truths, trial_tracks, trial_associations, for metric in metrics: assert isinstance(metric, TimeRangeMetric) - assert metric.time_range.start_timestamp == timestamps[0] - assert metric.time_range.end_timestamp == timestamps[3] + assert metric.time_range.start == timestamps[0] + assert metric.time_range.end == timestamps[3] assert metric.generator == siap_generator if metric.title.endswith(" at times"): diff --git a/stonesoup/metricgenerator/tracktotruthmetrics.py b/stonesoup/metricgenerator/tracktotruthmetrics.py index 14862faa0..6db40e1e8 100644 --- a/stonesoup/metricgenerator/tracktotruthmetrics.py +++ b/stonesoup/metricgenerator/tracktotruthmetrics.py @@ -1,5 +1,3 @@ -from operator import attrgetter - from .base import MetricGenerator from ..base import Property from ..measures import Measure @@ -278,7 +276,7 @@ def num_associated_tracks_at_time(manager, timestamp): Number of associated tracks held by `manager` at `timestamp`. """ associations = manager.association_set.associations_at_timestamp(timestamp) - association_objects = {thing for assoc in associations for thing in assoc.objects} + association_objects = associations.object_set return sum(1 for track in manager.tracks if track in association_objects) @@ -355,15 +353,16 @@ def total_time_tracked(manager, truth): return 0 truth_timestamps = sorted(state.timestamp for state in truth.states) - total_time = 0 for current_time, next_time in zip(truth_timestamps[:-1], truth_timestamps[1:]): + time_range = TimeRange(current_time, next_time) for assoc in assocs: - # If both timestamps are in one association then add the difference to the total - # difference and stop looking - if current_time in assoc.time_range and next_time in assoc.time_range: - total_time += (next_time - current_time).total_seconds() - break + # If there is some overlap between time ranges, add this to total_time + if time_range & assoc.time_range: + total_time += (time_range & assoc.time_range).duration.total_seconds() + time_range = time_range - (time_range & assoc.time_range) + if not time_range: + break return total_time @staticmethod @@ -385,7 +384,7 @@ def min_num_tracks_needed_to_track(manager, truth): Minimum number of tracks needed to track `truth` """ assocs = sorted(manager.association_set.associations_including_objects([truth]), - key=attrgetter('time_range.end_timestamp'), + key=lambda assoc: assoc.time_range.key_times[-1], reverse=True) if len(assocs) == 0: @@ -402,7 +401,12 @@ def min_num_tracks_needed_to_track(manager, truth): if not assoc_at_time: timestamp_index += 1 else: - end_time = assoc_at_time.time_range.end_timestamp + key_times = assoc_at_time.time_range.key_times + # If the current time is a start of a TimeRange we need strict inequality + if current_time in key_times and key_times.index(current_time) % 2 == 0: + end_time = min([time for time in key_times if time > current_time]) + else: + end_time = min([time for time in key_times if time >= current_time]) num_tracks_needed += 1 # If not yet at the end of the truth timestamps indices, move on to the next diff --git a/stonesoup/types/association.py b/stonesoup/types/association.py index 40516e0d9..cd9cabcfb 100644 --- a/stonesoup/types/association.py +++ b/stonesoup/types/association.py @@ -1,9 +1,10 @@ import datetime -from typing import Set +from typing import Set, Union +from itertools import combinations from ..base import Property from .base import Type -from .time import TimeRange +from .time import TimeRange, CompoundTimeRange class Association(Type): @@ -48,9 +49,13 @@ class TimeRangeAssociation(Association): range of times """ - time_range: TimeRange = Property( + time_range: Union[CompoundTimeRange, TimeRange] = Property( default=None, doc="Range of times that association exists over. Default is None") + @property + def duration(self): + return self.time_range.duration.total_seconds() + class AssociationSet(Type): """AssociationSet type @@ -66,6 +71,92 @@ def __init__(self, associations=None, *args, **kwargs): super().__init__(associations, *args, **kwargs) if self.associations is None: self.associations = set() + if not all(isinstance(member, Association) for member in self.associations): + raise TypeError("Association set must contain only Association instances") + self._simplify() + + def __eq__(self, other): + return self.associations == other.associations + + def add(self, association): + if association is None: + return + elif isinstance(association, Association): + self.associations.add(association) + elif isinstance(association, AssociationSet): + for component in association: + self.add(component) + else: + raise TypeError("Supplied parameter must be an Association or AssociationSet") + self._simplify() + + def _simplify(self): + """Where multiple associations describe the same pair of objects, combine them into one. + This is only implemented for pairs with a time_range attribute - others will be skipped + """ + to_remove = set() + for (assoc1, assoc2) in combinations(self.associations, 2): + if not (len(assoc1.objects) == 2 and len(assoc2.objects) == 2) or \ + not(hasattr(assoc1, 'time_range') and hasattr(assoc2, 'time_range')): + continue + if assoc1.objects == assoc2.objects: + if isinstance(assoc1.time_range, CompoundTimeRange): + assoc1.time_range.add(assoc2.time_range) + to_remove.add(assoc2) + elif isinstance(assoc2.time_range, CompoundTimeRange): + assoc2.time_range.add(assoc1.time_range) + to_remove.add(assoc1) + else: + assoc1.time_range = CompoundTimeRange([assoc1.time_range, assoc2.time_range]) + to_remove.add(assoc2) + for assoc in to_remove: + self.remove(assoc) + + def remove(self, association): + if association is None: + return + elif isinstance(association, Association): + if association not in self.associations: + raise ValueError("Supplied parameter must be contained by this instance") + self.associations.remove(association) + elif isinstance(association, AssociationSet): + for component in association: + self.remove(component) + else: + raise TypeError("Supplied parameter must be an Association or AssociationSet") + + @property + def key_times(self): + """Returns all timestamps at which a component starts or ends, or where there is a + :class:`.~SingleTimeAssociation`.""" + key_times = list(self.overall_time_range.key_times) + for association in self.associations: + if isinstance(association, SingleTimeAssociation): + key_times.append(association.timestamp) + return sorted(key_times) + + @property + def overall_time_range(self): + """Returns a :class:`~.CompoundTimeRange` covering all times at which at least + one association is active. + + Note: :class:`~.SingleTimeAssociation` are not counted + """ + overall_range = CompoundTimeRange() + for association in self.associations: + if hasattr(association, 'time_range'): + overall_range.add(association.time_range) + return overall_range + + @property + def object_set(self): + """Returns a set of all objects contained by this instance. + """ + object_set = set() + for assoc in self.associations: + for obj in assoc.objects: + object_set.add(obj) + return object_set def associations_at_timestamp(self, timestamp): """Return the associations that exist at a given timestamp @@ -80,9 +171,11 @@ def associations_at_timestamp(self, timestamp): Returns ------- - : set of :class:`~.Association` + : :class:`~.AssociationSet` Associations which occur at specified timestamp """ + if not isinstance(timestamp, datetime.datetime): + raise TypeError("Supplied parameter must be a datetime.datetime object") ret_associations = set() for association in self.associations: # If the association is at a single time @@ -92,7 +185,7 @@ def associations_at_timestamp(self, timestamp): else: if timestamp in association.time_range: ret_associations.add(association) - return ret_associations + return AssociationSet(ret_associations) def associations_including_objects(self, objects): """Return associations that include all the given objects @@ -106,18 +199,17 @@ def associations_including_objects(self, objects): Set of objects to look for in associations Returns ------- - : set of :class:`~.Association` - A set of objects which have been associated + : class:`~.AssociationSet` + A set of associations containing every member of objects """ - # Ensure objects is iterable if not isinstance(objects, list) and not isinstance(objects, set): objects = {objects} - return {association - for association in self.associations - for object_ in objects - if object_ in association.objects} + return AssociationSet({association + for association in self.associations + if all(object_ in association.objects + for object_ in objects)}) def __contains__(self, item): return item in self.associations diff --git a/stonesoup/types/interval.py b/stonesoup/types/interval.py index 6a47ca8b1..4bdd840ca 100644 --- a/stonesoup/types/interval.py +++ b/stonesoup/types/interval.py @@ -15,37 +15,42 @@ class Interval(Type): Represents a continuous, closed interval of real numbers. Represented by a lower and upper bound. """ - left: Union[int, float] = Property(doc="Lower bound of interval") - right: Union[int, float] = Property(doc="Upper bound of interval") + start: Union[int, float] = Property(doc="Lower bound of interval") + end: Union[int, float] = Property(doc="Upper bound of interval") def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - if self.left >= self.right: + if self.start >= self.end: raise ValueError('Must have left < right') def __hash__(self): - return hash((self.left, self.right)) + return hash((self.start, self.end)) + + @property + def left(self): + return self.start + + @property + def right(self): + return self.end @property def length(self): - return self.right - self.left + return self.end - self.start def __contains__(self, item): if isinstance(item, Real): - return self.left <= item <= self.right + return self.start <= item <= self.end elif isinstance(item, Interval): return self & item == item else: return False def __str__(self): - return '[{left}, {right}]'.format(left=self.left, right=self.right) - - def __repr__(self): - return 'Interval{interval}'.format(interval=str(self)) + return '[{left}, {right}]'.format(left=self.start, right=self.end) def __eq__(self, other): - return isinstance(other, Interval) and (self.left, self.right) == (other.left, other.right) + return isinstance(other, Interval) and (self.start, self.end) == (other.start, other.end) def __and__(self, other): """Set-like intersection""" @@ -54,11 +59,11 @@ def __and__(self, other): raise ValueError("Can only intersect with Interval types") if not self.isdisjoint(other): - new_interval = (max(self.left, other.left), min(self.right, other.right)) + new_interval = max(self.start, other.start), min(self.end, other.end) if new_interval[0] == new_interval[1]: return None else: - return Interval(*new_interval) + return type(self)(new_interval[0], new_interval[1]) else: return None @@ -69,7 +74,7 @@ def __or__(self, other): raise ValueError("Can only union with Interval types") if not self.isdisjoint(other): - return [Interval(min(self.left, other.left), max(self.right, other.right))] + return [type(self)(min(self.start, other.start), max(self.end, other.end))] else: return [copy.copy(self), copy.copy(other)] @@ -82,14 +87,14 @@ def __sub__(self, other): raise ValueError("Can only subtract Interval types from Interval types") elif self.isdisjoint(other): return [copy.copy(self)] - elif other.left <= self.left and self.right <= other.right: + elif other.start <= self.start and self.end <= other.end: return [None] - elif other.left <= self.left: - return [Interval(other.right, self.right)] - elif self.right <= other.right: - return [Interval(self.left, other.left)] + elif other.start <= self.start: + return [Interval(other.end, self.end)] + elif self.end <= other.end: + return [Interval(self.start, other.start)] else: - return [Interval(self.left, other.left), Interval(other.right, self.right)] + return [Interval(self.start, other.start), Interval(other.end, self.end)] def __xor__(self, other): """Set-like symmetric difference""" @@ -117,7 +122,7 @@ def __lt__(self, other): if not isinstance(other, Interval): raise ValueError("Can only compare Interval types to Interval types") - return other.left < self.left and self.right < other.right + return other.start < self.start and self.end < other.end def __ge__(self, other): """Superset check""" @@ -144,12 +149,10 @@ def isdisjoint(self, other): if not isinstance(other, Interval): raise ValueError("Interval types can only overlap with Interval types") - lb = min(self.left, other.left) - ub = max(self.right, other.right) + max_start = max(self.start, other.start) + min_end = min(self.end, other.end) - return not ((ub - lb < self.length + other.length) - or (self.right == other.left) - or (other.right == self.left)) + return max_start > min_end class Intervals(Type): @@ -173,7 +176,7 @@ def __init__(self, *args, **kwargs): if isinstance(self.intervals, (Interval, Tuple)): self.intervals = [self.intervals] else: - raise ValueError("Must contain Interval types") + raise TypeError("Must contain Interval types") elif len(self.intervals) == 2 and all(isinstance(elem, Real) for elem in self.intervals): self.intervals = [self.intervals] @@ -183,7 +186,7 @@ def __init__(self, *args, **kwargs): if isinstance(interval, Sequence) and len(interval) == 2: self.intervals[i] = Interval(*interval) else: - raise ValueError("Individual intervals must be an Interval or Sequence type") + raise TypeError("Individual intervals must be an Interval or Sequence type") self.intervals = self.get_merged_intervals(self.intervals) @@ -209,7 +212,6 @@ def isdisjoint(self, other): if not isinstance(other, Intervals): raise ValueError("Can only compare Intervals to Intervals") - return all(interval.isdisjoint(other_int) for other_int in other for interval in self) @staticmethod @@ -244,10 +246,7 @@ def __contains__(self, item): return any(item in interval for interval in self) def __str__(self): - return str([[interval.left, interval.right] for interval in self]) - - def __repr__(self): - return 'Intervals{intervals}'.format(intervals=str(self)) + return str([[interval.start, interval.end] for interval in self]) @property def length(self): @@ -264,7 +263,8 @@ def __reversed__(self): return self._iter(reverse=True) def __eq__(self, other): - + if len(self) == 0: + return len(other) == 0 if isinstance(other, Interval): other = Intervals(other) @@ -274,6 +274,8 @@ def __eq__(self, other): def __and__(self, other): """Set-like intersection""" + if other is None: + return False if not isinstance(other, (Interval, Intervals)): raise ValueError("Can only intersect with Intervals types") @@ -287,12 +289,11 @@ def __and__(self, other): if new_interval: new_intervals.append(new_interval) new_intervals = self.get_merged_intervals(new_intervals) - new_intervals = Intervals(new_intervals) + new_intervals = type(self)(new_intervals) return new_intervals def __or__(self, other): """Set-like union""" - if not isinstance(other, (Interval, Intervals)): raise ValueError('Can only union with Intervals types') @@ -301,7 +302,7 @@ def __or__(self, other): new_intervals = self.intervals + other.intervals new_intervals = self.get_merged_intervals(new_intervals) - new_intervals = Intervals(new_intervals) + new_intervals = type(self)(new_intervals) return new_intervals def __sub__(self, other): @@ -325,7 +326,7 @@ def __sub__(self, other): if diff[0] is not None: temp_intervals.extend(diff) new_intervals = temp_intervals - new_intervals = Intervals(new_intervals) + new_intervals = type(self)(new_intervals) return new_intervals def __xor__(self, other): @@ -335,7 +336,7 @@ def __xor__(self, other): raise ValueError("Can only compare Intervals from Intervals") if isinstance(other, Interval): - other = Intervals(other) + other = type(self)(other) return (self | other) - (self & other) diff --git a/stonesoup/types/tests/test_association.py b/stonesoup/types/tests/test_association.py index 7fd9ed564..7056223b9 100644 --- a/stonesoup/types/tests/test_association.py +++ b/stonesoup/types/tests/test_association.py @@ -6,7 +6,7 @@ from ..association import Association, AssociationPair, AssociationSet, \ SingleTimeAssociation, TimeRangeAssociation from ..detection import Detection -from ..time import TimeRange +from ..time import TimeRange, CompoundTimeRange def test_association(): @@ -40,7 +40,7 @@ def test_associationpair(): # 2 objects assoc = AssociationPair(set(objects[:2])) - np.array_equal(assoc.objects, set(objects[:2])) + assert np.array_equal(assoc.objects, set(objects[:2])) def test_singletimeassociation(): @@ -67,19 +67,25 @@ def test_timerangeassociation(): Detection(np.array([[5], [6]]))} timestamp1 = datetime.datetime(2018, 3, 1, 5, 3, 35) timestamp2 = datetime.datetime(2018, 3, 1, 5, 8, 35) - timerange = TimeRange(start_timestamp=timestamp1, end_timestamp=timestamp2) + timerange = TimeRange(start=timestamp1, end=timestamp2) + ctimerange = CompoundTimeRange([timerange]) assoc = TimeRangeAssociation(objects=objects, time_range=timerange) + cassoc = TimeRangeAssociation(objects=objects, time_range=ctimerange) assert assoc.objects == objects assert assoc.time_range == timerange + assert cassoc.objects == objects + assert cassoc.time_range == ctimerange def test_associationset(): # Set up some dummy variables timestamp1 = datetime.datetime(2018, 3, 1, 5, 3, 35) timestamp2 = datetime.datetime(2018, 3, 1, 5, 8, 35) - timerange = TimeRange(start_timestamp=timestamp1, end_timestamp=timestamp2) + timestamp3 = datetime.datetime(2020, 3, 1, 1, 1, 1) + time_range = TimeRange(start=timestamp1, end=timestamp2) + time_range2 = TimeRange(start=timestamp2, end=timestamp3) objects_list = [Detection(np.array([[1], [2]])), Detection(np.array([[3], [4]])), @@ -89,7 +95,9 @@ def test_associationset(): timestamp=timestamp1) assoc2 = TimeRangeAssociation(objects=set(objects_list[1:]), - time_range=timerange) + time_range=time_range) + assoc2_same_objects = TimeRangeAssociation(objects=set(objects_list[1:]), + time_range=time_range2) assoc_set = AssociationSet({assoc1, assoc2}) @@ -105,26 +113,78 @@ def test_associationset(): assert len(assoc_set) == 2 + # test _simplify method + + simplify_test = AssociationSet({assoc1, assoc2, assoc2_same_objects}) + + assert len(simplify_test.associations) == 2 + # Test associations including objects # Object only present in object 1 assert assoc_set.associations_including_objects(objects_list[0]) \ - == {assoc1} + == AssociationSet({assoc1}) # Object present in both assert assoc_set.associations_including_objects(objects_list[1]) \ - == {assoc1, assoc2} + == AssociationSet({assoc1, assoc2}) # Object present in neither - assert not assoc_set.associations_including_objects( - Detection(np.array([[0], [0]]))) + assert assoc_set.associations_including_objects(Detection(np.array([[0], [0]]))) \ + == AssociationSet() # Test associations including timestamp # Timestamp present in one object assert assoc_set.associations_at_timestamp(timestamp2) \ - == {assoc2} + == AssociationSet({assoc2}) # Timestamp present in both assert assoc_set.associations_at_timestamp(timestamp1) \ - == {assoc1, assoc2} + == AssociationSet({assoc1, assoc2}) # Timestamp not present in either - timestamp3 = datetime.datetime(2018, 3, 1, 6, 8, 35) - assert not assoc_set.associations_at_timestamp(timestamp3) + timestamp4 = datetime.datetime(2022, 3, 1, 6, 8, 35) + assert assoc_set.associations_at_timestamp(timestamp4) \ + == AssociationSet() + + +def test_association_set_add_remove(): + test = AssociationSet() + with pytest.raises(TypeError): + test.add("a string") + with pytest.raises(TypeError): + test.remove("a string") + objects = {Detection(np.array([[1], [2]])), + Detection(np.array([[3], [4]])), + Detection(np.array([[5], [6]]))} + + assoc = Association(objects) + assert assoc not in test.associations + with pytest.raises(ValueError): + test.remove(assoc) + test.add(assoc) + assert assoc in test.associations + test.remove(assoc) + assert assoc not in test.associations + + +def test_association_set_properties(): + test = AssociationSet() + assert test.key_times == [] + assert test.overall_time_range == CompoundTimeRange() + assert test.object_set == set() + objects = [Detection(np.array([[1], [2]])), + Detection(np.array([[3], [4]])), + Detection(np.array([[5], [6]]))] + assoc = Association(set(objects)) + test2 = AssociationSet({assoc}) + assert test2.key_times == [] + assert test2.overall_time_range == CompoundTimeRange() + assert test2.object_set == set(objects) + timestamp1 = datetime.datetime(2018, 3, 1, 5, 3, 35) + timestamp2 = datetime.datetime(2018, 3, 1, 5, 8, 35) + time_range = TimeRange(start=timestamp1, end=timestamp2) + com_time_range = CompoundTimeRange([time_range]) + assoc2 = TimeRangeAssociation(objects=set(objects[1:]), + time_range=com_time_range) + test3 = AssociationSet({assoc, assoc2}) + assert test3.key_times == [timestamp1, timestamp2] + assert test3.overall_time_range == com_time_range + assert test3.object_set == set(objects) diff --git a/stonesoup/types/tests/test_interval.py b/stonesoup/types/tests/test_interval.py index 0da6672a2..026115e32 100644 --- a/stonesoup/types/tests/test_interval.py +++ b/stonesoup/types/tests/test_interval.py @@ -38,11 +38,6 @@ def test_interval_contains(): assert 'a string' not in a -def test_interval_str(): - assert str(a) == '[{left}, {right}]'.format(left=0, right=1) - assert a.__repr__() == 'Interval{interval}'.format(interval=str(a)) - - def test_interval_eq(): assert a == Interval(0, 1) assert a != (0, 1) @@ -122,6 +117,7 @@ def test_interval_disjoint(): assert a.isdisjoint(b) # No overlap assert not a.isdisjoint(c) # Overlap assert not a.isdisjoint(d) # Meet and no overlap + assert not a.isdisjoint(a) # a is not disjoint with itself def test_intervals_init(): @@ -134,7 +130,7 @@ def test_intervals_init(): assert len(temp.intervals) == 1 assert temp.intervals[0] == a - with pytest.raises(ValueError, match="Must contain Interval types"): + with pytest.raises(TypeError, match="Must contain Interval types"): Intervals('a string') temp = Intervals((0, 1)) @@ -146,7 +142,7 @@ def test_intervals_init(): assert len(A.intervals) == 2 assert A.intervals == [a, b] # Converts lists of tuples to lists of Interval types - with pytest.raises(ValueError, + with pytest.raises(TypeError, match="Individual intervals must be an Interval or Sequence type"): Intervals([(0, 1), 'a string']) @@ -171,11 +167,9 @@ def test_intervals_overlap(): def test_intervals_disjoint(): - with pytest.raises(ValueError, match="Can only compare Intervals to Intervals"): - A.isdisjoint('a string') - assert not A.isdisjoint(B) # Overlap assert not A.isdisjoint(C) # Meet with no overlap assert A.isdisjoint(D) + assert not A.isdisjoint(A) def test_intervals_merge(): @@ -196,17 +190,23 @@ def test_intervals_contains(): assert Intervals([Interval(0.25, 0.75)]) in A -def test_intervals_str(): - assert str(A) == str([[interval.left, interval.right] for interval in A]) - assert A.__repr__() == 'Intervals{intervals}'.format(intervals=str(A)) - - def test_intervals_len(): assert Intervals([]).length == 0 assert A.length == 2 assert Intervals([Interval(0.5, 0.75), Interval(0.1, 0.2)]).length == 0.35 +def test_intervals_str(): + assert str(A) == str([[interval.left, interval.right] for interval in A]) + assert A.__repr__() == ('Intervals(\n' + ' intervals=[Interval(\n' + ' start=0,\n' + ' end=1),\n' + ' Interval(\n' + ' start=2,\n' + ' end=3)])') + + def test_intervals_iter(): intervals = iter([a, b]) A_iter = iter(A) diff --git a/stonesoup/types/tests/test_metric.py b/stonesoup/types/tests/test_metric.py index 3058fdb97..aae9b2212 100644 --- a/stonesoup/types/tests/test_metric.py +++ b/stonesoup/types/tests/test_metric.py @@ -85,8 +85,8 @@ def test_timerangemetric(): value = 5 timestamp1 = datetime.datetime.now() timestamp2 = timestamp1 + datetime.timedelta(seconds=10) - time_range = TimeRange(start_timestamp=timestamp1, - end_timestamp=timestamp2) + time_range = TimeRange(start=timestamp1, + end=timestamp2) class temp_generator(MetricGenerator): @@ -113,8 +113,8 @@ def test_timerangeplottingmetric(): value = 5 timestamp1 = datetime.datetime.now() timestamp2 = timestamp1 + datetime.timedelta(seconds=10) - time_range = TimeRange(start_timestamp=timestamp1, - end_timestamp=timestamp2) + time_range = TimeRange(start=timestamp1, + end=timestamp2) class temp_generator(MetricGenerator): diff --git a/stonesoup/types/tests/test_time.py b/stonesoup/types/tests/test_time.py index 9630f977b..d408b461c 100644 --- a/stonesoup/types/tests/test_time.py +++ b/stonesoup/types/tests/test_time.py @@ -2,69 +2,268 @@ import pytest -from ..time import TimeRange +from ..time import TimeRange, CompoundTimeRange -def test_timerange(): - # Test time range initialisation - +@pytest.fixture +def times(): + # Note times are returned chronologically for ease of reading + before = datetime.datetime(year=2018, month=3, day=1, hour=3, minute=10, second=3) t1 = datetime.datetime(year=2018, month=3, day=1, hour=5, minute=3, second=35, microsecond=500) + inside = datetime.datetime(year=2018, month=3, day=1, hour=5, minute=10, second=3) t2 = datetime.datetime(year=2018, month=3, day=1, hour=6, minute=5, second=41, microsecond=500) + after = datetime.datetime(year=2019, month=3, day=1, hour=6, minute=5, + second=41, microsecond=500) + long_after = datetime.datetime(year=2022, month=6, day=1, hour=6, minute=5, + second=41, microsecond=500) + return [before, t1, inside, t2, after, long_after] + + +def test_timerange(times): # Test creating without times with pytest.raises(TypeError): TimeRange() # Without start time with pytest.raises(TypeError): - TimeRange(start_timestamp=t1) + TimeRange(end=times[3]) # Without end time with pytest.raises(TypeError): - TimeRange(end_timestamp=t2) + TimeRange(start=times[1]) # Test an error is caught when end is after start with pytest.raises(ValueError): - TimeRange(start_timestamp=t2, end_timestamp=t1) + TimeRange(start=times[3], end=times[1]) + + # Test with wrong types for time_ranges + with pytest.raises(TypeError): + CompoundTimeRange(42) + with pytest.raises(TypeError): + CompoundTimeRange([times[1], times[3]]) + + test_range = TimeRange(start=times[1], end=times[3]) + + test_compound = CompoundTimeRange() - test_range = test_range = TimeRange(start_timestamp=t1, end_timestamp=t2) + test_compound2 = CompoundTimeRange([test_range]) - assert test_range.start_timestamp == t1 - assert test_range.end_timestamp == t2 + # Test fuse_components method + fuse_test = CompoundTimeRange([test_range, TimeRange(times[3], times[4])]) + assert test_range.start_timestamp == times[1] + assert test_range.end_timestamp == times[3] + assert len(test_compound.time_ranges) == 0 + assert test_compound2.time_ranges[0] == test_range + assert fuse_test.time_ranges == [TimeRange(times[1], times[4])] -def test_duration(): + +def test_duration(times): # Test that duration is calculated properly - t1 = datetime.datetime(year=2018, month=3, day=1, hour=5, minute=3, - second=35, microsecond=500) - t2 = datetime.datetime(year=2018, month=3, day=1, hour=6, minute=5, - second=41, microsecond=500) - test_range = TimeRange(start_timestamp=t1, end_timestamp=t2) - assert test_range.duration == datetime.timedelta(seconds=3726) + # TimeRange + test_range = TimeRange(start=times[1], end=times[3]) + + # CompoundTimeRange + + # times[2] is inside of [1] and [3], so should be equivalent to a TimeRange(times[1], times[4]) + test_range2 = CompoundTimeRange([TimeRange(start=times[1], end=times[3]), + TimeRange(start=times[2], end=times[4])]) + + assert test_range.duration == times[3] - times[1] + assert test_range2.duration == times[4] - times[1] -def test_contains(): +def test_contains(times): # Test that timestamps are correctly determined to be in the range - t1 = datetime.datetime(year=2018, month=3, day=1, hour=5, minute=3, - second=35, microsecond=500) - t2 = datetime.datetime(year=2018, month=3, day=1, hour=6, minute=5, - second=41, microsecond=500) - test_range = TimeRange(start_timestamp=t1, end_timestamp=t2) + test_range = TimeRange(start=times[1], end=times[3]) + test2 = TimeRange(times[1], times[2]) + test3 = TimeRange(times[1], times[4]) + + with pytest.raises(TypeError): + 16 in test3 + + assert times[1] in test_range + assert times[2] in test_range + assert times[3] in test_range + assert not times[0] in test_range + assert not times[4] in test_range + + assert test2 in test_range + assert test_range not in test2 + assert test2 in test2 + assert test3 not in test_range + + # CompoundTimeRange + + compound_test = CompoundTimeRange([test_range]) + # Should be in neither + test_range2 = TimeRange(times[4], times[5]) + # Should be in compound_range2 but not 1 + test_range3 = TimeRange(times[2], times[4]) + compound_test2 = CompoundTimeRange([test_range, TimeRange(times[3], times[4])]) + + assert compound_test in compound_test2 + assert times[2] in compound_test + assert times[2] in compound_test2 + assert test_range2 not in compound_test + assert test_range2 not in compound_test2 + assert test_range3 not in compound_test + assert test_range3 in compound_test2 + + +def test_equality(times): + test1 = TimeRange(times[1], times[2]) + test2 = TimeRange(times[1], times[2]) + test3 = TimeRange(times[1], times[3]) + + assert test1 != "a string" + + assert test1 == test2 + assert test2 == test1 + assert test3 != test1 and test1 != test3 - # Inside - assert datetime.datetime( - year=2018, month=3, day=1, hour=5, minute=10, second=3) in test_range + ctest1 = CompoundTimeRange([test1, test3]) + ctest2 = CompoundTimeRange([TimeRange(times[1], times[3])]) - # Outside + assert ctest2 != "a string" - assert not datetime.datetime( - year=2018, month=3, day=1, hour=3, minute=10, second=3) in test_range + assert ctest1 == ctest2 and ctest2 == ctest1 + ctest2.add(TimeRange(times[3], times[4])) + assert ctest1 != ctest2 and ctest2 != ctest1 - # Lower edge - assert t1 in test_range + assert CompoundTimeRange() == CompoundTimeRange() + assert ctest1 != CompoundTimeRange() + assert CompoundTimeRange() != ctest1 - # Upper edge - assert t2 in test_range + +def test_minus(times): + # Test the minus function + test1 = TimeRange(times[1], times[3]) + test2 = TimeRange(times[1], times[2]) + test3 = TimeRange(times[2], times[3]) + test4 = TimeRange(times[4], times[5]) + + with pytest.raises(TypeError): + test1 - 15 + + assert test1 - test2 == test3 + assert test1 == test1 - None + assert test2 - test1 is None + + ctest1 = CompoundTimeRange([test2, test4]) + ctest2 = CompoundTimeRange([test1, test2]) + ctest3 = CompoundTimeRange([test4]) + + with pytest.raises(TypeError): + ctest1 - 15 + + assert ctest1 - ctest2 == ctest3 + assert ctest1 - ctest1 == CompoundTimeRange() + assert ctest3 - ctest1 == CompoundTimeRange() + + assert test1 - ctest1 == TimeRange(times[2], times[3]) + assert test4 - ctest2 == test4 + assert ctest1 - test2 == ctest3 + + +def test_and(times): + test1 = TimeRange(times[1], times[3]) + test2 = TimeRange(times[1], times[2]) + test3 = TimeRange(times[4], times[5]) + + ctest1 = CompoundTimeRange([test2, test3]) + ctest2 = CompoundTimeRange([test1, test2]) + + assert test2 & ctest1 == ctest1 & test2 + + assert test1 & test1 == test1 + assert test1 & None is None + assert test1 & test2 == test2 + assert test2 & test1 == test2 + assert ctest1 & None is None + assert ctest1 & test2 == CompoundTimeRange([test2]) + assert ctest1 & ctest2 == CompoundTimeRange([test2]) + assert ctest1 & ctest2 == ctest2 & ctest1 + + +def test_key_times(times): + test1 = CompoundTimeRange([TimeRange(times[0], times[1]), + TimeRange(times[3], times[4])]) + test2 = CompoundTimeRange([TimeRange(times[3], times[4]), + TimeRange(times[0], times[1])]) + test3 = CompoundTimeRange() + test4 = CompoundTimeRange([TimeRange(times[0], times[4])]) + test5 = TimeRange(times[0], times[4]) + + assert test1.key_times == [times[0], times[1], times[3], times[4]] + assert test2.key_times == [times[0], times[1], times[3], times[4]] + assert test3.key_times == [] + assert test4.key_times == [times[0], times[4]] + assert test5.key_times == [times[0], times[4]] + + +def test_remove_overlap(times): + test1_ro = CompoundTimeRange([TimeRange(times[0], times[1]), + TimeRange(times[3], times[4])]) + test1_ro._remove_overlap() + test2_ro = CompoundTimeRange([TimeRange(times[3], times[4]), + TimeRange(times[0], times[4])]) + test2_ro._remove_overlap() + test3_ro = CompoundTimeRange() + test3_ro._remove_overlap() + + test2 = CompoundTimeRange([TimeRange(times[0], times[4])]) + + assert test1_ro.duration == TimeRange(times[0], times[1]).duration + \ + TimeRange(times[3], times[4]).duration + assert test2_ro == test2 + assert test3_ro == CompoundTimeRange() + + +def test_fuse_components(times): + # Note this is called inside the __init__ method, but is tested here explicitly + test1 = CompoundTimeRange([TimeRange(times[1], times[2])]) + test1._fuse_components() + test2 = CompoundTimeRange([TimeRange(times[1], times[2]), + TimeRange(times[2], times[4])]) + test2._fuse_components() + assert test1 == CompoundTimeRange([TimeRange(times[1], times[2])]) + assert test2 == CompoundTimeRange([TimeRange(times[1], times[4])]) + + +def test_add(times): + test1 = CompoundTimeRange([TimeRange(times[1], times[2])]) + test2 = CompoundTimeRange([TimeRange(times[0], times[1])]) + test3 = CompoundTimeRange([TimeRange(times[0], times[2])]) + test4 = CompoundTimeRange([TimeRange(times[0], times[2]), + TimeRange(times[4], times[5])]) + with pytest.raises(TypeError): + test1.add(True) + assert test1 != test2 + test1_copy = test1 + test1.add(None) + assert test1 == test1_copy + test1.add(test2) + assert test1 == test3 + test3.add(TimeRange(times[4], times[5])) + assert test3 == test4 + + +def test_remove(times): + test1 = CompoundTimeRange([TimeRange(times[0], times[2]), + TimeRange(times[4], times[5])]) + with pytest.raises(TypeError): + test1.remove(0.4) + with pytest.raises(ValueError): + test1.remove(TimeRange(times[2], times[3])) + # Remove part of a component + test1.remove(TimeRange(times[0], times[1])) + assert test1 == CompoundTimeRange([TimeRange(times[1], times[2]), + TimeRange(times[4], times[5])]) + # Remove whole component + test1.remove(TimeRange(times[1], times[2])) + assert test1 == CompoundTimeRange([TimeRange(times[4], times[5])]) diff --git a/stonesoup/types/time.py b/stonesoup/types/time.py index 0c674eb2a..97dc95da9 100644 --- a/stonesoup/types/time.py +++ b/stonesoup/types/time.py @@ -1,10 +1,12 @@ import datetime +import copy +from itertools import combinations, permutations from ..base import Property -from .base import Type +from ..types.interval import Interval, Intervals -class TimeRange(Type): +class TimeRange(Interval): """TimeRange type An object representing a time range between two timestamps. @@ -21,33 +23,279 @@ class TimeRange(Type): True """ - start_timestamp: datetime.datetime = Property(doc="Start of the time range") - end_timestamp: datetime.datetime = Property(doc="End of the time range") + start: datetime.datetime = Property(doc="Start of the time range") + end: datetime.datetime = Property(doc="End of the time range") - def __init__(self, start_timestamp, end_timestamp, *args, **kwargs): - if end_timestamp < start_timestamp: - raise ValueError("start_timestamp must be before end_timestamp") - super().__init__(start_timestamp, end_timestamp, *args, **kwargs) + @property + def start_timestamp(self): + return self.start + + @property + def end_timestamp(self): + return self.end @property def duration(self): """Duration of the time range""" - return self.end_timestamp - self.start_timestamp + return self.length + + @property + def key_times(self): + """Times the TimeRange begins and ends""" + return [self.start, self.end] - def __contains__(self, timestamp): + def __contains__(self, time): """Checks if timestamp is within range Parameters ---------- - timestamp : datetime.datetime - Time stamp to check if within range + time : Union[datetime.datetime, TimeRange] + Time stamp or range to check if within range Returns ------- bool - `True` if timestamp within :attr:`start_timestamp` and - :attr:`end_timestamp` (inclusive) + `True` if timestamp within :attr:`start` and + :attr:`end` (inclusive) + """ + if isinstance(time, datetime.datetime): + return self.start <= time <= self.end + else: + return super().__contains__(time) + + def __eq__(self, other): + return isinstance(other, TimeRange) and super().__eq__(other) + + def __sub__(self, time_range): + """Removes the overlap between this instance and another :class:`~.TimeRange`, or + :class:`~.CompoundTimeRange`. + + Parameters + ---------- + time_range: Union[TimeRange, CompoundTimeRange] + + Returns + ------- + TimeRange + This instance less the overlap with the other time_range + """ + if time_range is None: + return copy.copy(self) + if not isinstance(time_range, TimeRange) and not isinstance(time_range, CompoundTimeRange): + raise TypeError("Supplied parameter must be a TimeRange or CompoundTimeRange object") + if isinstance(time_range, CompoundTimeRange): + ans = self + for t_range in time_range.time_ranges: + ans -= t_range + if not ans: + return None + return ans + else: + overlap = self & time_range + if overlap is None: + return self + if self == overlap: + return None + if self.start < overlap.start: + start = self.start + else: + start = overlap.end + if self.end > overlap.end: + end = self.end + else: + end = overlap.start + if self.start < overlap.start and \ + self.end > overlap.end: + return CompoundTimeRange([TimeRange(self.start, overlap.start), + TimeRange(overlap.end, self.end)]) + else: + return TimeRange(start, end) + + def __and__(self, time_range): + """Finds the intersection between this instance and another :class:`~.TimeRange` or + :class:`.~CompoundTimeRange` + + Parameters + ---------- + time_range: Union[TimeRange, CompoundTimeRange] + + Returns + ------- + TimeRange + The times contained by both this and `time_range` """ + if time_range is None: + return None + if isinstance(time_range, CompoundTimeRange): + return time_range & self + if not isinstance(time_range, TimeRange): + raise TypeError("Supplied parameter must be a TimeRange object") + return super().__and__(time_range) + + def __or__(self, other): + return super().__or__(other) + + +class CompoundTimeRange(Intervals): + """CompoundTimeRange type + + A container class representing one or more :class:`~.TimeRange` objects together + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if not isinstance(self.time_ranges, list): + raise TypeError("Time_ranges must be a list") + for component in self.time_ranges: + if not isinstance(component, TimeRange): + raise TypeError("Time_ranges must contain only TimeRange objects") + self._remove_overlap() + self._fuse_components() + + @property + def time_ranges(self): + return self.intervals + + @property + def duration(self): + """Duration of the time range""" + if len(self.time_ranges) == 0: + return datetime.timedelta(0) + total_duration = datetime.timedelta(0) + for component in self.time_ranges: + total_duration = total_duration + component.duration + return total_duration - return self.start_timestamp <= timestamp <= self.end_timestamp + @property + def key_times(self): + """Returns all timestamps at which a component starts or ends""" + key_times = set() + for component in self.time_ranges: + key_times.add(component.start) + key_times.add(component.end) + return sorted(key_times) + + def _remove_overlap(self): + """Removes overlap between components of `time_ranges`""" + if len(self.time_ranges) in {0, 1}: + return + if all([component & component2 is None for (component, component2) in + combinations(self.time_ranges, 2)]): + return + overlap_check = CompoundTimeRange() + for time_range in self.time_ranges: + if time_range - overlap_check: + overlap_check.add(time_range - overlap_check & time_range) + self.intervals = copy.copy(overlap_check.time_ranges) + + def _fuse_components(self): + """Fuses two time ranges [a,b], [b,c] into [a,c] for all such pairs in this instance""" + for (component, component2) in permutations(self.time_ranges, 2): + if component.end == component2.start: + fused_component = TimeRange(component.start, component2.end) + self.remove(component) + self.remove(component2) + self.add(fused_component) + # To avoid issues with having removed objects from the permutations + self._fuse_components() + + def add(self, time_range): + """Add a :class:`~.TimeRange` or :class:`~.CompoundTimeRange` object to `time_ranges`.""" + if time_range is None: + return + if isinstance(time_range, CompoundTimeRange): + for component in time_range.time_ranges: + self.add(component) + elif isinstance(time_range, TimeRange): + self.time_ranges.append(time_range) + else: + raise TypeError("Supplied parameter must be a TimeRange or CompoundTimeRange object") + self._remove_overlap() + self._fuse_components() + + def remove(self, time_range): + """Removes a :class:`.~TimeRange` object from the time ranges. + It must be a member of self.time_ranges""" + if not isinstance(time_range, TimeRange): + raise TypeError("Supplied parameter must be a TimeRange object") + if time_range in self.time_ranges: + self.time_ranges.remove(time_range) + elif time_range in self: + for component in self.time_ranges: + if time_range in component: + new = component - time_range + self.time_ranges.remove(component) + self.add(new) + else: + raise ValueError("Supplied parameter must be a member of time_ranges") + + def __contains__(self, time): + """Checks if timestamp or is within range + + Parameters + ---------- + time : Union[datetime.datetime, TimeRange, CompoundTimeRange] + Time stamp or range to check if contained within this instance + + Returns + ------- + bool + `True` if time is fully contained within this instance + """ + + if isinstance(time, datetime.datetime): + for component in self.time_ranges: + if time in component: + return True + return False + elif isinstance(time, (TimeRange, CompoundTimeRange)): + return super().__contains__(time) + else: + raise TypeError("Supplied parameter must be an instance of either " + "datetime, TimeRange, or CompoundTimeRange") + + def __eq__(self, other): + return isinstance(other, CompoundTimeRange) and super().__eq__(other) + + def __sub__(self, time_range): + """Removes any overlap between this and another :class:`~.TimeRange` or + :class:`.~CompoundTimeRange` from this instance + + Parameters + ---------- + time_range: Union[TimeRange, CompoundTimeRange] + + Returns + ------- + CompoundTimeRange + The times contained by this but not time_range. May be empty. + """ + if time_range is None: + return copy.copy(self) + ans = CompoundTimeRange() + for component in self.time_ranges: + ans.add(component - time_range) + return ans + + def __and__(self, time_range): + """Finds the intersection between this instance and another time range + + In the case of an input :class:`~.CompoundTimeRange` this is done recursively. + + Parameters + ---------- + time_range: Union[TimeRange, CompoundTimeRange] + + Returns + ------- + CompoundTimeRange + The times contained by both this and time_range + """ + if time_range is None: + return None + if not isinstance(time_range, (TimeRange, CompoundTimeRange)): + raise TypeError("Supplied parameter must be an instance of either " + "TimeRange, or CompoundTimeRange") + else: + return super().__and__(time_range)