diff --git a/docs/examples/KalmanFilterOOSMExample.py b/docs/examples/KalmanFilterOOSMExample.py index 934dcfb6b..56c90efda 100644 --- a/docs/examples/KalmanFilterOOSMExample.py +++ b/docs/examples/KalmanFilterOOSMExample.py @@ -154,9 +154,9 @@ plotter = AnimatedPlotterly(timesteps=time_steps) plotter.plot_ground_truths(truth, [0, 2]) -plotter.plot_measurements(measurements1, [0, 2], marker=dict(color='blue', symbol='0'), +plotter.plot_measurements(measurements1, [0, 2], marker=dict(color='blue'), measurements_label='Detections with no lag') -plotter.plot_measurements(measurements2, [0, 2], marker=dict(color='orange', symbol='0'), +plotter.plot_measurements(measurements2, [0, 2], marker=dict(color='orange'), measurements_label='Detections with lag') plotter.plot_sensors([sensor1_platform, sensor2_platform], marker=dict(color='black', symbol='129', size=15), diff --git a/docs/examples/track_fusion_example.py b/docs/examples/track_fusion_example.py index facedd5a0..56b335ff0 100644 --- a/docs/examples/track_fusion_example.py +++ b/docs/examples/track_fusion_example.py @@ -54,7 +54,6 @@ from stonesoup.types.array import CovarianceMatrix from stonesoup.simulator.simple import SingleTargetGroundTruthSimulator from stonesoup.models.clutter.clutter import ClutterModel -from stonesoup.types.detection import Clutter # Instantiate the radars to collect measurements - Use a :class:`~.RadarBearingRange` radar.# from stonesoup.sensor.radar.radar import RadarBearingRange @@ -185,19 +184,10 @@ # Plot the detections from the two radars plotter = Plotterly() -plotter.plot_measurements([d for ds in s1_detections for d in ds if not isinstance(d, Clutter)], - [0, 2], marker=dict(color='red'), measurements_label='Sensor 1 measurements') - -plotter.plot_measurements([d for ds in s1_detections for d in ds if isinstance(d, Clutter)], - [0, 2], marker=dict(color='red', symbol='star-triangle-up'), +plotter.plot_measurements(s1_detections, [0, 2], marker=dict(color='red'), measurements_label='Sensor 1 measurements') - -plotter.plot_measurements([d for ds in s2_detections for d in ds if not isinstance(d, Clutter)], - [0, 2], marker=dict(color='blue'), measurements_label='Sensor 2 measurements') -plotter.plot_measurements([d for ds in s2_detections for d in ds if isinstance(d, Clutter)], - [0, 2], marker=dict(color='blue', symbol='star-triangle-up'), +plotter.plot_measurements(s2_detections, [0, 2], marker=dict(color='blue'), measurements_label='Sensor 2 measurements') - plotter.plot_sensors({sensor1_platform, sensor2_platform}, [0, 1], marker=dict(color='black', symbol='1', size=10)) plotter.plot_ground_truths(truths, [0, 2]) diff --git a/setup.cfg b/setup.cfg index eb32a9f7b..cd83b9b44 100644 --- a/setup.cfg +++ b/setup.cfg @@ -20,6 +20,7 @@ python_requires = >=3.8 packages = find: install_requires = matplotlib + mergedeep numpy>=1.17 ordered-set pymap3d diff --git a/stonesoup/plotter.py b/stonesoup/plotter.py index d85d2e21c..e447da759 100644 --- a/stonesoup/plotter.py +++ b/stonesoup/plotter.py @@ -1,14 +1,17 @@ import warnings from abc import ABC, abstractmethod from datetime import datetime, timedelta +from enum import IntEnum from itertools import chain from typing import Collection, Iterable, Union, List, Optional, Tuple, Dict + import numpy as np from matplotlib import animation as animation from matplotlib import pyplot as plt from matplotlib.legend_handler import HandlerPatch from matplotlib.lines import Line2D from matplotlib.patches import Ellipse +from mergedeep import merge from scipy.integrate import quad from scipy.optimize import brentq from scipy.stats import kde @@ -21,19 +24,15 @@ except ImportError: go = None +from .base import Base, Property +from .models.base import LinearModel, Model from .types import detection -from .types.groundtruth import GroundTruthPath from .types.array import StateVector +from .types.groundtruth import GroundTruthPath from .types.metric import SingleTimeMetric from .types.state import State, StateMutableSequence from .types.update import Update -from .base import Base, Property - -from .models.base import LinearModel, Model - -from enum import IntEnum - class Dimension(IntEnum): """Dimension Enum class for specifying plotting parameters in the Plotter class. @@ -285,9 +284,11 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, self.legend_dict[measurements_label] = measurements_handle if plot_clutter: + clutter_kwargs = kwargs.copy() + clutter_kwargs.update(dict(marker='2')) clutter_array = np.array(list(plot_clutter.values())) - artists.append(self.ax.scatter(*clutter_array.T, color='y', marker='2')) - clutter_handle = Line2D([], [], linestyle='', marker='2', color='y') + artists.append(self.ax.scatter(*clutter_array.T, **clutter_kwargs)) + clutter_handle = Line2D([], [], linestyle='', **clutter_kwargs) clutter_label = "Clutter" # Generate legend items for clutter @@ -998,7 +999,7 @@ def __init__(self, dimension=Dimension.TWO, axis_labels=None, **kwargs): if self.dimension == 3: layout_kwargs.update(dict(scene_aspectmode='data')) # auto shapes fig to fit data well - layout_kwargs.update(kwargs) + merge(layout_kwargs, kwargs) # Generate plot axes self.fig = go.Figure(layout=layout_kwargs) @@ -1054,7 +1055,7 @@ def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwa if self.dimension == 3: # make ground truth line thicker so easier to see in 3d plot truths_kwargs.update(dict(line=dict(width=8, dash="longdashdot"))) - truths_kwargs.update(kwargs) + merge(truths_kwargs, kwargs) add_legend = truths_kwargs['legendgroup'] not in {trace.legendgroup for trace in self.fig.data} @@ -1142,7 +1143,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, if self.dimension == 3: # make markers smaller in 3d plot measurement_kwargs.update(dict(marker=dict(size=4, color='#636EFA'))) - measurement_kwargs.update(kwargs) + merge(measurement_kwargs, kwargs) if measurement_kwargs['legendgroup'] not in {trace.legendgroup for trace in self.fig.data}: measurement_kwargs['showlegend'] = True @@ -1175,7 +1176,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, if plot_clutter: name = measurements_label + "
(Clutter)" - measurement_kwargs = dict( + clutter_kwargs = dict( mode='markers', marker=dict(symbol="star-triangle-up", color='#FECB52'), name=name, legendgroup=name, legendrank=210) @@ -1183,12 +1184,12 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, measurement_kwargs.update(dict(marker=dict(size=4, symbol="diamond", color='#FECB52'))) - measurement_kwargs.update(kwargs) - if measurement_kwargs['legendgroup'] not in {trace.legendgroup - for trace in self.fig.data}: - measurement_kwargs['showlegend'] = True + merge(clutter_kwargs, kwargs) + if clutter_kwargs['legendgroup'] not in {trace.legendgroup + for trace in self.fig.data}: + clutter_kwargs['showlegend'] = True else: - measurement_kwargs['showlegend'] = False + clutter_kwargs['showlegend'] = False clutter_array = np.asarray(list(plot_clutter.values()), dtype=np.float64) if self.dimension == 1: @@ -1196,14 +1197,14 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, x=[state.timestamp for state in plot_clutter.keys()], y=clutter_array[:, 0], text=[self._format_state_text(state) for state in plot_clutter.keys()], - **measurement_kwargs, + **clutter_kwargs, ) elif self.dimension == 2: self.fig.add_scatter( x=clutter_array[:, 0], y=clutter_array[:, 1], text=[self._format_state_text(state) for state in plot_clutter.keys()], - **measurement_kwargs, + **clutter_kwargs, ) elif self.dimension == 3: self.fig.add_scatter3d( @@ -1211,7 +1212,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, y=clutter_array[:, 1], z=clutter_array[:, 2], text=[self._format_state_text(state) for state in plot_clutter.keys()], - **measurement_kwargs, + **clutter_kwargs, ) def get_next_color(self): @@ -1283,7 +1284,7 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_ if self.dimension == 3: # change visuals to work well in 3d track_kwargs.update(dict(line=dict(width=7)), marker=dict(size=4)) - track_kwargs.update(kwargs) + merge(track_kwargs, kwargs) add_legend = track_kwargs['legendgroup'] not in {trace.legendgroup for trace in self.fig.data} @@ -1484,7 +1485,7 @@ def plot_sensors(self, sensors, mapping=[0, 1], sensor_label="Sensors", **kwargs sensor_kwargs = dict(mode='markers', marker=dict(symbol='x', color='black'), legendgroup=sensor_label, legendrank=50) - sensor_kwargs.update(kwargs) + merge(sensor_kwargs, kwargs) sensor_kwargs['name'] = sensor_label if sensor_kwargs['legendgroup'] not in {trace.legendgroup @@ -1573,7 +1574,7 @@ def plot_state_sequence(self, state_sequences, angle_mapping: int, range_mapping plotting_kwargs = dict( mode="markers", legendgroup=label, legendrank=200, name=label, thetaunit="radians") - plotting_kwargs.update(kwargs) + merge(plotting_kwargs, kwargs) add_legend = plotting_kwargs['legendgroup'] not in {trace.legendgroup for trace in self.fig.data} @@ -1620,7 +1621,7 @@ def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwa ``line=dict(dash="dash")``. """ truths_kwargs = dict(mode="lines", line=dict(dash="dash"), legendrank=100) - truths_kwargs.update(kwargs) + merge(truths_kwargs, kwargs) angle_mapping = mapping[0] if len(mapping) > 1: range_mapping = mapping[1] @@ -1680,7 +1681,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, if plot_detections: name = measurements_label + "
(Detections)" measurement_kwargs = dict(mode='markers', marker=dict(color='#636EFA'), legendrank=200) - measurement_kwargs.update(kwargs) + merge(measurement_kwargs, kwargs) plotting_data = [State(state_vector=plotting_state_vector, timestamp=det.timestamp) for det, plotting_state_vector in plot_detections.items()] @@ -1691,16 +1692,16 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, if plot_clutter: name = measurements_label + "
(Clutter)" - measurement_kwargs = dict(mode='markers', legendrank=210, - marker=dict(symbol="star-triangle-up", color='#FECB52')) - measurement_kwargs.update(kwargs) + clutter_kwargs = dict(mode='markers', legendrank=210, + marker=dict(symbol="star-triangle-up", color='#FECB52')) + merge(clutter_kwargs, kwargs) plotting_data = [State(state_vector=plotting_state_vector, timestamp=det.timestamp) for det, plotting_state_vector in plot_clutter.items()] self.plot_state_sequence(state_sequences=[plotting_data], angle_mapping=angle_mapping, range_mapping=range_mapping, label=name, - **measurement_kwargs) + **clutter_kwargs) def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_label="Tracks", **kwargs): @@ -1734,7 +1735,7 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_ raise NotImplementedError track_kwargs = dict(mode='markers+lines', legendrank=300) - track_kwargs.update(kwargs) + merge(track_kwargs, kwargs) angle_mapping = mapping[0] if len(mapping) > 1: range_mapping = mapping[1] @@ -2460,7 +2461,7 @@ def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", truth_kwargs = dict(x=[], y=[], mode="lines", hoverinfo='none', legendgroup=truths_label, line=dict(dash="dash", color=self.colorway[0]), legendrank=100, name=truths_label, showlegend=True) - truth_kwargs.update(kwargs) + merge(truth_kwargs, kwargs) # legend dummy trace self.fig.add_trace(go.Scatter(truth_kwargs)) @@ -2469,9 +2470,8 @@ def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", for n, _ in enumerate(truths): # change the colour of each truth and include n in its name - truth_kwargs.update({ - "line": dict(dash="dash", color=self.colorway[n % len(self.colorway)])}) - truth_kwargs.update(kwargs) + merge(truth_kwargs, dict(line=dict(color=self.colorway[n % len(self.colorway)]))) + merge(truth_kwargs, kwargs) self.fig.add_trace(go.Scatter(truth_kwargs)) # add to traces for frame in self.fig.frames: @@ -2613,7 +2613,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, legendgroup='Detections (Measurements)', legendrank=200, showlegend=True, marker=dict(color="#636EFA"), hoverinfo='none') - measurement_kwargs.update(kwargs) + merge(measurement_kwargs, kwargs) self.fig.add_trace(go.Scatter(measurement_kwargs)) # trace for legend @@ -2622,11 +2622,12 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, # change necessary kwargs to initialise clutter trace name = measurements_label + "
(Clutter)" - measurement_kwargs.update({"legendgroup": 'Clutter', "legendrank": 300, - "marker": dict(symbol="star-triangle-up", color='#FECB52'), - "name": name, 'showlegend': True}) + clutter_kwargs = {"legendgroup": 'Clutter', "legendrank": 300, + "marker": dict(symbol="star-triangle-up", color='#FECB52'), + "name": name, 'showlegend': True} + merge(clutter_kwargs, kwargs) - self.fig.add_trace(go.Scatter(measurement_kwargs)) # trace for plotting clutter + self.fig.add_trace(go.Scatter(clutter_kwargs)) # trace for plotting clutter # add data to frames for frame in self.fig.frames: @@ -2968,7 +2969,7 @@ def plot_sensors(self, sensors, sensor_label="Sensors", resize=True, **kwargs): sensor_kwargs = dict(mode='markers', marker=dict(symbol='x', color='black'), legendgroup=sensor_label, legendrank=50, name=sensor_label, showlegend=True) - sensor_kwargs.update(kwargs) + merge(sensor_kwargs, kwargs) self.fig.add_trace(go.Scatter(sensor_kwargs)) # initialises trace