diff --git a/docs/examples/KalmanFilterOOSMExample.py b/docs/examples/KalmanFilterOOSMExample.py
index 934dcfb6b..56c90efda 100644
--- a/docs/examples/KalmanFilterOOSMExample.py
+++ b/docs/examples/KalmanFilterOOSMExample.py
@@ -154,9 +154,9 @@
plotter = AnimatedPlotterly(timesteps=time_steps)
plotter.plot_ground_truths(truth, [0, 2])
-plotter.plot_measurements(measurements1, [0, 2], marker=dict(color='blue', symbol='0'),
+plotter.plot_measurements(measurements1, [0, 2], marker=dict(color='blue'),
measurements_label='Detections with no lag')
-plotter.plot_measurements(measurements2, [0, 2], marker=dict(color='orange', symbol='0'),
+plotter.plot_measurements(measurements2, [0, 2], marker=dict(color='orange'),
measurements_label='Detections with lag')
plotter.plot_sensors([sensor1_platform, sensor2_platform],
marker=dict(color='black', symbol='129', size=15),
diff --git a/docs/examples/track_fusion_example.py b/docs/examples/track_fusion_example.py
index facedd5a0..56b335ff0 100644
--- a/docs/examples/track_fusion_example.py
+++ b/docs/examples/track_fusion_example.py
@@ -54,7 +54,6 @@
from stonesoup.types.array import CovarianceMatrix
from stonesoup.simulator.simple import SingleTargetGroundTruthSimulator
from stonesoup.models.clutter.clutter import ClutterModel
-from stonesoup.types.detection import Clutter
# Instantiate the radars to collect measurements - Use a :class:`~.RadarBearingRange` radar.#
from stonesoup.sensor.radar.radar import RadarBearingRange
@@ -185,19 +184,10 @@
# Plot the detections from the two radars
plotter = Plotterly()
-plotter.plot_measurements([d for ds in s1_detections for d in ds if not isinstance(d, Clutter)],
- [0, 2], marker=dict(color='red'), measurements_label='Sensor 1 measurements')
-
-plotter.plot_measurements([d for ds in s1_detections for d in ds if isinstance(d, Clutter)],
- [0, 2], marker=dict(color='red', symbol='star-triangle-up'),
+plotter.plot_measurements(s1_detections, [0, 2], marker=dict(color='red'),
measurements_label='Sensor 1 measurements')
-
-plotter.plot_measurements([d for ds in s2_detections for d in ds if not isinstance(d, Clutter)],
- [0, 2], marker=dict(color='blue'), measurements_label='Sensor 2 measurements')
-plotter.plot_measurements([d for ds in s2_detections for d in ds if isinstance(d, Clutter)],
- [0, 2], marker=dict(color='blue', symbol='star-triangle-up'),
+plotter.plot_measurements(s2_detections, [0, 2], marker=dict(color='blue'),
measurements_label='Sensor 2 measurements')
-
plotter.plot_sensors({sensor1_platform, sensor2_platform}, [0, 1],
marker=dict(color='black', symbol='1', size=10))
plotter.plot_ground_truths(truths, [0, 2])
diff --git a/setup.cfg b/setup.cfg
index eb32a9f7b..cd83b9b44 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -20,6 +20,7 @@ python_requires = >=3.8
packages = find:
install_requires =
matplotlib
+ mergedeep
numpy>=1.17
ordered-set
pymap3d
diff --git a/stonesoup/plotter.py b/stonesoup/plotter.py
index d85d2e21c..e447da759 100644
--- a/stonesoup/plotter.py
+++ b/stonesoup/plotter.py
@@ -1,14 +1,17 @@
import warnings
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
+from enum import IntEnum
from itertools import chain
from typing import Collection, Iterable, Union, List, Optional, Tuple, Dict
+
import numpy as np
from matplotlib import animation as animation
from matplotlib import pyplot as plt
from matplotlib.legend_handler import HandlerPatch
from matplotlib.lines import Line2D
from matplotlib.patches import Ellipse
+from mergedeep import merge
from scipy.integrate import quad
from scipy.optimize import brentq
from scipy.stats import kde
@@ -21,19 +24,15 @@
except ImportError:
go = None
+from .base import Base, Property
+from .models.base import LinearModel, Model
from .types import detection
-from .types.groundtruth import GroundTruthPath
from .types.array import StateVector
+from .types.groundtruth import GroundTruthPath
from .types.metric import SingleTimeMetric
from .types.state import State, StateMutableSequence
from .types.update import Update
-from .base import Base, Property
-
-from .models.base import LinearModel, Model
-
-from enum import IntEnum
-
class Dimension(IntEnum):
"""Dimension Enum class for specifying plotting parameters in the Plotter class.
@@ -285,9 +284,11 @@ def plot_measurements(self, measurements, mapping, measurement_model=None,
self.legend_dict[measurements_label] = measurements_handle
if plot_clutter:
+ clutter_kwargs = kwargs.copy()
+ clutter_kwargs.update(dict(marker='2'))
clutter_array = np.array(list(plot_clutter.values()))
- artists.append(self.ax.scatter(*clutter_array.T, color='y', marker='2'))
- clutter_handle = Line2D([], [], linestyle='', marker='2', color='y')
+ artists.append(self.ax.scatter(*clutter_array.T, **clutter_kwargs))
+ clutter_handle = Line2D([], [], linestyle='', **clutter_kwargs)
clutter_label = "Clutter"
# Generate legend items for clutter
@@ -998,7 +999,7 @@ def __init__(self, dimension=Dimension.TWO, axis_labels=None, **kwargs):
if self.dimension == 3:
layout_kwargs.update(dict(scene_aspectmode='data')) # auto shapes fig to fit data well
- layout_kwargs.update(kwargs)
+ merge(layout_kwargs, kwargs)
# Generate plot axes
self.fig = go.Figure(layout=layout_kwargs)
@@ -1054,7 +1055,7 @@ def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwa
if self.dimension == 3: # make ground truth line thicker so easier to see in 3d plot
truths_kwargs.update(dict(line=dict(width=8, dash="longdashdot")))
- truths_kwargs.update(kwargs)
+ merge(truths_kwargs, kwargs)
add_legend = truths_kwargs['legendgroup'] not in {trace.legendgroup
for trace in self.fig.data}
@@ -1142,7 +1143,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None,
if self.dimension == 3: # make markers smaller in 3d plot
measurement_kwargs.update(dict(marker=dict(size=4, color='#636EFA')))
- measurement_kwargs.update(kwargs)
+ merge(measurement_kwargs, kwargs)
if measurement_kwargs['legendgroup'] not in {trace.legendgroup
for trace in self.fig.data}:
measurement_kwargs['showlegend'] = True
@@ -1175,7 +1176,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None,
if plot_clutter:
name = measurements_label + "
(Clutter)"
- measurement_kwargs = dict(
+ clutter_kwargs = dict(
mode='markers', marker=dict(symbol="star-triangle-up", color='#FECB52'),
name=name, legendgroup=name, legendrank=210)
@@ -1183,12 +1184,12 @@ def plot_measurements(self, measurements, mapping, measurement_model=None,
measurement_kwargs.update(dict(marker=dict(size=4, symbol="diamond",
color='#FECB52')))
- measurement_kwargs.update(kwargs)
- if measurement_kwargs['legendgroup'] not in {trace.legendgroup
- for trace in self.fig.data}:
- measurement_kwargs['showlegend'] = True
+ merge(clutter_kwargs, kwargs)
+ if clutter_kwargs['legendgroup'] not in {trace.legendgroup
+ for trace in self.fig.data}:
+ clutter_kwargs['showlegend'] = True
else:
- measurement_kwargs['showlegend'] = False
+ clutter_kwargs['showlegend'] = False
clutter_array = np.asarray(list(plot_clutter.values()), dtype=np.float64)
if self.dimension == 1:
@@ -1196,14 +1197,14 @@ def plot_measurements(self, measurements, mapping, measurement_model=None,
x=[state.timestamp for state in plot_clutter.keys()],
y=clutter_array[:, 0],
text=[self._format_state_text(state) for state in plot_clutter.keys()],
- **measurement_kwargs,
+ **clutter_kwargs,
)
elif self.dimension == 2:
self.fig.add_scatter(
x=clutter_array[:, 0],
y=clutter_array[:, 1],
text=[self._format_state_text(state) for state in plot_clutter.keys()],
- **measurement_kwargs,
+ **clutter_kwargs,
)
elif self.dimension == 3:
self.fig.add_scatter3d(
@@ -1211,7 +1212,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None,
y=clutter_array[:, 1],
z=clutter_array[:, 2],
text=[self._format_state_text(state) for state in plot_clutter.keys()],
- **measurement_kwargs,
+ **clutter_kwargs,
)
def get_next_color(self):
@@ -1283,7 +1284,7 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_
if self.dimension == 3: # change visuals to work well in 3d
track_kwargs.update(dict(line=dict(width=7)), marker=dict(size=4))
- track_kwargs.update(kwargs)
+ merge(track_kwargs, kwargs)
add_legend = track_kwargs['legendgroup'] not in {trace.legendgroup
for trace in self.fig.data}
@@ -1484,7 +1485,7 @@ def plot_sensors(self, sensors, mapping=[0, 1], sensor_label="Sensors", **kwargs
sensor_kwargs = dict(mode='markers', marker=dict(symbol='x', color='black'),
legendgroup=sensor_label, legendrank=50)
- sensor_kwargs.update(kwargs)
+ merge(sensor_kwargs, kwargs)
sensor_kwargs['name'] = sensor_label
if sensor_kwargs['legendgroup'] not in {trace.legendgroup
@@ -1573,7 +1574,7 @@ def plot_state_sequence(self, state_sequences, angle_mapping: int, range_mapping
plotting_kwargs = dict(
mode="markers", legendgroup=label, legendrank=200,
name=label, thetaunit="radians")
- plotting_kwargs.update(kwargs)
+ merge(plotting_kwargs, kwargs)
add_legend = plotting_kwargs['legendgroup'] not in {trace.legendgroup
for trace in self.fig.data}
@@ -1620,7 +1621,7 @@ def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwa
``line=dict(dash="dash")``.
"""
truths_kwargs = dict(mode="lines", line=dict(dash="dash"), legendrank=100)
- truths_kwargs.update(kwargs)
+ merge(truths_kwargs, kwargs)
angle_mapping = mapping[0]
if len(mapping) > 1:
range_mapping = mapping[1]
@@ -1680,7 +1681,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None,
if plot_detections:
name = measurements_label + "
(Detections)"
measurement_kwargs = dict(mode='markers', marker=dict(color='#636EFA'), legendrank=200)
- measurement_kwargs.update(kwargs)
+ merge(measurement_kwargs, kwargs)
plotting_data = [State(state_vector=plotting_state_vector,
timestamp=det.timestamp)
for det, plotting_state_vector in plot_detections.items()]
@@ -1691,16 +1692,16 @@ def plot_measurements(self, measurements, mapping, measurement_model=None,
if plot_clutter:
name = measurements_label + "
(Clutter)"
- measurement_kwargs = dict(mode='markers', legendrank=210,
- marker=dict(symbol="star-triangle-up", color='#FECB52'))
- measurement_kwargs.update(kwargs)
+ clutter_kwargs = dict(mode='markers', legendrank=210,
+ marker=dict(symbol="star-triangle-up", color='#FECB52'))
+ merge(clutter_kwargs, kwargs)
plotting_data = [State(state_vector=plotting_state_vector,
timestamp=det.timestamp)
for det, plotting_state_vector in plot_clutter.items()]
self.plot_state_sequence(state_sequences=[plotting_data], angle_mapping=angle_mapping,
range_mapping=range_mapping, label=name,
- **measurement_kwargs)
+ **clutter_kwargs)
def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_label="Tracks",
**kwargs):
@@ -1734,7 +1735,7 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_
raise NotImplementedError
track_kwargs = dict(mode='markers+lines', legendrank=300)
- track_kwargs.update(kwargs)
+ merge(track_kwargs, kwargs)
angle_mapping = mapping[0]
if len(mapping) > 1:
range_mapping = mapping[1]
@@ -2460,7 +2461,7 @@ def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth",
truth_kwargs = dict(x=[], y=[], mode="lines", hoverinfo='none', legendgroup=truths_label,
line=dict(dash="dash", color=self.colorway[0]), legendrank=100,
name=truths_label, showlegend=True)
- truth_kwargs.update(kwargs)
+ merge(truth_kwargs, kwargs)
# legend dummy trace
self.fig.add_trace(go.Scatter(truth_kwargs))
@@ -2469,9 +2470,8 @@ def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth",
for n, _ in enumerate(truths):
# change the colour of each truth and include n in its name
- truth_kwargs.update({
- "line": dict(dash="dash", color=self.colorway[n % len(self.colorway)])})
- truth_kwargs.update(kwargs)
+ merge(truth_kwargs, dict(line=dict(color=self.colorway[n % len(self.colorway)])))
+ merge(truth_kwargs, kwargs)
self.fig.add_trace(go.Scatter(truth_kwargs)) # add to traces
for frame in self.fig.frames:
@@ -2613,7 +2613,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None,
legendgroup='Detections (Measurements)',
legendrank=200, showlegend=True,
marker=dict(color="#636EFA"), hoverinfo='none')
- measurement_kwargs.update(kwargs)
+ merge(measurement_kwargs, kwargs)
self.fig.add_trace(go.Scatter(measurement_kwargs)) # trace for legend
@@ -2622,11 +2622,12 @@ def plot_measurements(self, measurements, mapping, measurement_model=None,
# change necessary kwargs to initialise clutter trace
name = measurements_label + "
(Clutter)"
- measurement_kwargs.update({"legendgroup": 'Clutter', "legendrank": 300,
- "marker": dict(symbol="star-triangle-up", color='#FECB52'),
- "name": name, 'showlegend': True})
+ clutter_kwargs = {"legendgroup": 'Clutter', "legendrank": 300,
+ "marker": dict(symbol="star-triangle-up", color='#FECB52'),
+ "name": name, 'showlegend': True}
+ merge(clutter_kwargs, kwargs)
- self.fig.add_trace(go.Scatter(measurement_kwargs)) # trace for plotting clutter
+ self.fig.add_trace(go.Scatter(clutter_kwargs)) # trace for plotting clutter
# add data to frames
for frame in self.fig.frames:
@@ -2968,7 +2969,7 @@ def plot_sensors(self, sensors, sensor_label="Sensors", resize=True, **kwargs):
sensor_kwargs = dict(mode='markers', marker=dict(symbol='x', color='black'),
legendgroup=sensor_label, legendrank=50,
name=sensor_label, showlegend=True)
- sensor_kwargs.update(kwargs)
+ merge(sensor_kwargs, kwargs)
self.fig.add_trace(go.Scatter(sensor_kwargs)) # initialises trace