From c282c45c1b2f5677f74a02fb40ac5ad4990be4b1 Mon Sep 17 00:00:00 2001 From: Steven Hiscocks Date: Tue, 7 Apr 2020 11:46:31 +0100 Subject: [PATCH 1/2] Modify detection and ground truth CSV readers - Modifications try to reuse code more now where possible - Groundtruth now includes id - Groundtruth and detections of the same timestamp are yielded together --- stonesoup/reader/generic.py | 181 +++++++++++++------------ stonesoup/reader/tests/test_generic.py | 117 +++++++++++----- 2 files changed, 176 insertions(+), 122 deletions(-) diff --git a/stonesoup/reader/generic.py b/stonesoup/reader/generic.py index 1857d2774..a4ed0ce73 100644 --- a/stonesoup/reader/generic.py +++ b/stonesoup/reader/generic.py @@ -6,22 +6,21 @@ """ import csv -from collections import defaultdict from datetime import datetime, timedelta from math import modf import numpy as np from dateutil.parser import parse -from ..base import Property -from ..types.detection import Detection from .base import GroundTruthReader, DetectionReader from .file import TextFileReader -from stonesoup.buffered_generator import BufferedGenerator +from ..base import Property +from ..buffered_generator import BufferedGenerator +from ..types.detection import Detection from ..types.groundtruth import GroundTruthPath, GroundTruthState -class CSVGroundTruthReader(GroundTruthReader, TextFileReader): +class _CSVReader(TextFileReader): state_vector_fields = Property( [str], doc='List of columns names to be used in state vector') time_field = Property( @@ -30,103 +29,109 @@ class CSVGroundTruthReader(GroundTruthReader, TextFileReader): str, default=None, doc='Optional datetime format') timestamp = Property( bool, default=False, doc='Treat time field as a timestamp from epoch') - path_id_field = Property( - str, doc='Name of column to be used as path ID') + metadata_fields = Property( + [str], default=None, doc='List of columns to be saved as metadata, default all') csv_options = Property( - dict, default={}, - doc='Keyword 
arguments for the underlying csv reader') + dict, default={}, doc='Keyword arguments for the underlying csv reader') + + def _get_metadata(self, row): + if self.metadata_fields is None: + local_metadata = dict(row) + for key in list(local_metadata): + if key == self.time_field or key in self.state_vector_fields: + del local_metadata[key] + else: + local_metadata = {field: row[field] + for field in self.metadata_fields + if field in row} + return local_metadata + + def _get_time(self, row): + if self.time_field_format is not None: + time_field_value = datetime.strptime(row[self.time_field], self.time_field_format) + elif self.timestamp is True: + fractional, timestamp = modf(float(row[self.time_field])) + time_field_value = datetime.utcfromtimestamp(int(timestamp)) + time_field_value += timedelta(microseconds=fractional * 1E6) + else: + time_field_value = parse(row[self.time_field], ignoretz=True) + return time_field_value + + +class CSVGroundTruthReader(GroundTruthReader, _CSVReader): + """A simple reader for csv files of truth data. - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._groundtruth_paths = set() + CSV file must have headers, as these are used to determine which fields + to use to generate the ground truth state. Those states with the same ID will be put into + a :class:`~.GroundTruthPath` in sequence, and all paths that are updated at the same time + are yielded together, and such assumes file is in time order. 
- @property - def groundtruth_paths(self): - return self._groundtruth_paths.copy() + Parameters + ---------- + """ + path_id_field = Property( + str, doc='Name of column to be used as path ID') @BufferedGenerator.generator_method def groundtruth_paths_gen(self): with self.path.open(encoding=self.encoding, newline='') as csv_file: - groundtruth_dict = defaultdict(GroundTruthPath) - - reader = csv.DictReader(csv_file, **self.csv_options) - for row in reader: - if self.time_field_format is not None: - time_field_value = datetime.strptime( - row[self.time_field], self.time_field_format) - elif self.timestamp is True: - fractional, timestamp = modf(float(row[self.time_field])) - time_field_value = datetime.utcfromtimestamp( - int(timestamp)) - time_field_value += timedelta(microseconds=fractional*1E6) - else: - time_field_value = parse(row[self.time_field]) - - state = GroundTruthState(np.array( - [[row[col_name]] for col_name in self.state_vector_fields], - dtype=np.float32), time_field_value) - - groundtruth_dict[row[self.path_id_field]].append(state) - self._groundtruth_paths = set(groundtruth_dict.values()) - - yield time_field_value, self._groundtruth_paths - - -class CSVDetectionReader(DetectionReader, TextFileReader): + groundtruth_dict = {} + updated_paths = set() + previous_time = None + for row in csv.DictReader(csv_file, **self.csv_options): + + time = self._get_time(row) + if previous_time is not None and previous_time != time: + yield previous_time, updated_paths + updated_paths = set() + previous_time = time + + state = GroundTruthState( + np.array([[row[col_name]] for col_name in self.state_vector_fields], + dtype=np.float_), + timestamp=self._get_time(row), + metadata=self._get_metadata(row)) + + id_ = row[self.path_id_field] + if id_ not in groundtruth_dict: + groundtruth_dict[id_] = GroundTruthPath(id=id_) + groundtruth_path = groundtruth_dict[id_] + groundtruth_path.append(state) + updated_paths.add(groundtruth_path) + + # Yield remaining + yield 
previous_time, updated_paths + + +class CSVDetectionReader(DetectionReader, _CSVReader): """A simple detection reader for csv files of detections. - CSV file must have headers, as these are used to determine which fields - to use to generate the detection. + CSV file must have headers, as these are used to determine which fields to use to generate + the detection. Detections at the same time are yielded together, and such assume file is in + time order. Parameters ---------- """ - state_vector_fields = Property( - [str], doc='List of columns names to be used in state vector') - time_field = Property( - str, doc='Name of column to be used as time field') - time_field_format = Property( - str, default=None, doc='Optional datetime format') - timestamp = Property( - bool, default=False, doc='Treat time field as a timestamp from epoch') - metadata_fields = Property( - [str], default=None, doc='List of columns to be saved as metadata, ' - 'default all') - csv_options = Property( - dict, default={}, - doc='Keyword arguments for the underlying csv reader') - @BufferedGenerator.generator_method def detections_gen(self): with self.path.open(encoding=self.encoding, newline='') as csv_file: - reader = csv.DictReader(csv_file, **self.csv_options) - for row in reader: - if self.time_field_format is not None: - time_field_value = datetime.strptime( - row[self.time_field], self.time_field_format) - elif self.timestamp is True: - fractional, timestamp = modf(float(row[self.time_field])) - time_field_value = datetime.utcfromtimestamp( - int(timestamp)) - time_field_value += timedelta(microseconds=fractional*1E6) - else: - time_field_value = parse(row[self.time_field]) - - if self.metadata_fields is None: - local_metadata = dict(row) - copy_local_metadata = dict(local_metadata) - for (key, value) in copy_local_metadata.items(): - if (key == self.time_field) or \ - (key in self.state_vector_fields): - del local_metadata[key] - else: - local_metadata = {field: row[field] - for field in 
self.metadata_fields - if field in row} - - detect = Detection(np.array( - [[row[col_name]] for col_name in self.state_vector_fields], - dtype=np.float32), time_field_value, - metadata=local_metadata) - yield time_field_value, {detect} + detections = set() + previous_time = None + for row in csv.DictReader(csv_file, **self.csv_options): + + time = self._get_time(row) + if previous_time is not None and previous_time != time: + yield previous_time, detections + detections = set() + previous_time = time + + detections.add(Detection( + np.array([[row[col_name]] for col_name in self.state_vector_fields], + dtype=np.float_), + timestamp=time, + metadata=self._get_metadata(row))) + + # Yield remaining + yield previous_time, detections diff --git a/stonesoup/reader/tests/test_generic.py b/stonesoup/reader/tests/test_generic.py index 2775309bf..a0d37849f 100644 --- a/stonesoup/reader/tests/test_generic.py +++ b/stonesoup/reader/tests/test_generic.py @@ -4,11 +4,13 @@ from textwrap import dedent import numpy as np +import pytest from ..generic import CSVDetectionReader, CSVGroundTruthReader -def test_csv_gt(tmpdir): +@pytest.fixture() +def csv_gt_filename(tmpdir): csv_filename = tmpdir.join("test.csv") with csv_filename.open('w') as csv_file: csv_file.write(dedent("""\ @@ -19,24 +21,22 @@ def test_csv_gt(tmpdir): 13,23,33,32018332,2018-01-01T14:03:00Z 14,24,34,32018332,2018-01-01T14:04:00Z """)) + return csv_filename + +def test_csv_gt_2d(csv_gt_filename): # run test with: # - 2d co-ordinates # - default time field format # - no special csv options - csv_reader = CSVGroundTruthReader(csv_filename.strpath, + csv_reader = CSVGroundTruthReader(csv_gt_filename.strpath, state_vector_fields=["x", "y"], time_field="t", path_id_field="identifier") - all_gt_paths = [ - gt_paths_at_timestep - for timestep, gt_paths_at_timestep - in csv_reader.groundtruth_paths_gen()] - - final_gt_paths = [ - gt_path - for gt_path in all_gt_paths[len(all_gt_paths) - 1]] + final_gt_paths = set() + for 
_, gt_paths_at_timestep in csv_reader: + final_gt_paths.update(gt_paths_at_timestep) assert len(final_gt_paths) == 2 ground_truth_states = [ @@ -52,23 +52,20 @@ def test_csv_gt(tmpdir): assert gt_state.timestamp.minute == n assert gt_state.timestamp.date() == datetime.date(2018, 1, 1) + +def test_csv_gt_3d_time(csv_gt_filename): # run test with: # - 3d co-ordinates # - time field format specified - csv_reader = CSVGroundTruthReader(csv_filename.strpath, + csv_reader = CSVGroundTruthReader(csv_gt_filename.strpath, state_vector_fields=["x", "y", "z"], time_field="t", time_field_format="%Y-%m-%dT%H:%M:%SZ", path_id_field="identifier") - all_gt_paths = [ - gt_paths_at_timestep - for timestep, gt_paths_at_timestep - in csv_reader.groundtruth_paths_gen()] - - final_gt_paths = [ - gt_path - for gt_path in all_gt_paths[len(all_gt_paths) - 1]] + final_gt_paths = set() + for _, gt_paths_at_timestep in csv_reader: + final_gt_paths.update(gt_paths_at_timestep) assert len(final_gt_paths) == 2 ground_truth_states = [ @@ -84,10 +81,13 @@ def test_csv_gt(tmpdir): assert gt_state.timestamp.minute == n assert gt_state.timestamp.date() == datetime.date(2018, 1, 1) + +def test_csv_gt_3d_timestamp_csv_opt(tmpdir): # run test with: # - time field represented as a Unix epoch timestamp # - csv options specified - with csv_filename.open('w') as csv_file: + csv_gt_filename = tmpdir.join("test.csv") + with csv_gt_filename.open('w') as csv_file: csv_file.write(dedent("""\ 10,20,30,22018332,1514815200 11,21,31,22018332,1514815260 @@ -96,7 +96,7 @@ def test_csv_gt(tmpdir): 14,24,34,32018332,1514815440 """)) - csv_reader = CSVGroundTruthReader(csv_filename.strpath, + csv_reader = CSVGroundTruthReader(csv_gt_filename.strpath, state_vector_fields=["x", "y", "z"], time_field="t", timestamp=True, @@ -105,14 +105,9 @@ def test_csv_gt(tmpdir): ['x', 'y', 'z', 'identifier', 't']}) - all_gt_paths = [ - gt_paths_at_timestep - for timestep, gt_paths_at_timestep - in csv_reader.groundtruth_paths_gen()] - - 
final_gt_paths = [ - gt_path - for gt_path in all_gt_paths[len(all_gt_paths) - 1]] + final_gt_paths = set() + for _, gt_paths_at_timestep in csv_reader: + final_gt_paths.update(gt_paths_at_timestep) assert len(final_gt_paths) == 2 ground_truth_states = [ @@ -129,7 +124,32 @@ def test_csv_gt(tmpdir): assert gt_state.timestamp.date() == datetime.date(2018, 1, 1) -def test_csv(tmpdir): +def test_csv_gt_multi_per_timestep(tmpdir): + csv_gt_filename = tmpdir.join("test.csv") + with csv_gt_filename.open('w') as csv_file: + csv_file.write(dedent("""\ + x,y,z,identifier,t + 10,20,30,22018332,2018-01-01T14:00:00Z + 11,21,31,22018332,2018-01-01T14:01:00Z + 12,22,32,22018332,2018-01-01T14:02:00Z + 13,23,33,32018332,2018-01-01T14:02:00Z + 14,24,34,32018332,2018-01-01T14:03:00Z + """)) + + csv_reader = CSVGroundTruthReader(csv_gt_filename.strpath, + state_vector_fields=["x", "y"], + time_field="t", + path_id_field="identifier") + + for time, ground_truth_paths in csv_reader: + if time == datetime.datetime(2018, 1, 1, 14, 2): + assert len(ground_truth_paths) == 2 + else: + assert len(ground_truth_paths) == 1 + + +@pytest.fixture() +def csv_det_filename(tmpdir): csv_filename = tmpdir.join("test.csv") with csv_filename.open('w') as csv_file: csv_file.write(dedent("""\ @@ -138,11 +158,14 @@ def test_csv(tmpdir): 11,21,31,22018332,2018-01-01T14:01:00Z 12,22,32,22018332,2018-01-01T14:02:00Z """)) + return csv_filename + +def test_csv_default(csv_det_filename): # run test with: # - 'metadata_fields' for 'CSVDetectionReader' == default # - copy all metadata items - csv_reader = CSVDetectionReader(csv_filename.strpath, ["x", "y"], "t") + csv_reader = CSVDetectionReader(csv_det_filename.strpath, ["x", "y"], "t") detections = [ detection for _, detections in csv_reader @@ -161,11 +184,13 @@ def test_csv(tmpdir): assert int(detection.metadata['z']) == 30 + n assert detection.metadata['identifier'] == '22018332' + +def test_csv_metadata_time(csv_det_filename): # run test with: # - 
'metadata_fields' for 'CSVDetectionReader' contains # 'z' but not 'identifier' # - 'time_field_format' is specified - csv_reader = CSVDetectionReader(csv_filename.strpath, ["x", "y"], "t", + csv_reader = CSVDetectionReader(csv_det_filename.strpath, ["x", "y"], "t", time_field_format="%Y-%m-%dT%H:%M:%SZ", metadata_fields=["z"]) detections = [ @@ -184,11 +209,14 @@ def test_csv(tmpdir): assert 'z' in detection.metadata.keys() assert int(detection.metadata['z']) == 30 + n + +def test_csv_missing_metadata_timestamp(tmpdir): # run test with: # - 'metadata_fields' for 'CSVDetectionReader' contains # column names that do not exist in CSV file # - 'time' field represented as a Unix epoch timestamp - with csv_filename.open('w') as csv_file: + csv_det_filename = tmpdir.join("test.csv") + with csv_det_filename.open('w') as csv_file: csv_file.write(dedent("""\ x,y,z,identifier,t 10,20,30,22018332,1514815200 @@ -196,7 +224,7 @@ def test_csv(tmpdir): 12,22,32,22018332,1514815320 """)) - csv_reader = CSVDetectionReader(csv_filename.strpath, ["x", "y"], "t", + csv_reader = CSVDetectionReader(csv_det_filename.strpath, ["x", "y"], "t", metadata_fields=["heading"], timestamp=True) detections = [ @@ -214,6 +242,27 @@ def test_csv(tmpdir): assert len(detection.metadata) == 0 +def test_csv_multi_per_timestep(tmpdir): + csv_det_filename = tmpdir.join("test.csv") + with csv_det_filename.open('w') as csv_file: + csv_file.write(dedent("""\ + x,y,z,identifier,t + 10,20,30,22018332,2018-01-01T14:00:00Z + 11,21,31,22018332,2018-01-01T14:01:00Z + 12,22,32,22018332,2018-01-01T14:02:00Z + 13,23,33,32018332,2018-01-01T14:02:00Z + 14,24,34,32018332,2018-01-01T14:03:00Z + """)) + + csv_reader = CSVDetectionReader(csv_det_filename.strpath, ["x", "y"], "t") + + for time, detections in csv_reader: + if time == datetime.datetime(2018, 1, 1, 14, 2): + assert len(detections) == 2 + else: + assert len(detections) == 1 + + def test_tsv(tmpdir): csv_filename = tmpdir.join("test.csv") with 
csv_filename.open('w') as csv_file: From 5be1e81eadb2f0eb9bdcc68b8fd619d8ff8c1fc2 Mon Sep 17 00:00:00 2001 From: Steven Hiscocks Date: Tue, 19 May 2020 12:12:18 +0100 Subject: [PATCH 2/2] Avoid duplicate time conversion in CSVGroundTruthReader --- stonesoup/reader/generic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stonesoup/reader/generic.py b/stonesoup/reader/generic.py index a4ed0ce73..6e792c0df 100644 --- a/stonesoup/reader/generic.py +++ b/stonesoup/reader/generic.py @@ -89,7 +89,7 @@ def groundtruth_paths_gen(self): state = GroundTruthState( np.array([[row[col_name]] for col_name in self.state_vector_fields], dtype=np.float_), - timestamp=self._get_time(row), + timestamp=time, metadata=self._get_metadata(row)) id_ = row[self.path_id_field]