Skip to content

Commit

Permalink
[python/plot] Fix various visual glitches and improve layout.
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexis Duburcq committed Sep 18, 2023
1 parent 3a53f5b commit b01b080
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 94 deletions.
16 changes: 12 additions & 4 deletions python/gym_jiminy/common/gym_jiminy/common/envs/env_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@

from .internal import loop_interactive

try:
from jiminy_py.plot import TabbedFigure

Check notice on line 60 in python/gym_jiminy/common/gym_jiminy/common/envs/env_generic.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

python/gym_jiminy/common/gym_jiminy/common/envs/env_generic.py#L60

Imports from package jiminy_py are not grouped
except ImportError:
TabbedFigure = type(None) # type: ignore[misc,assignment]


# Define universal bounds for the observation space
FREEFLYER_POS_TRANS_MAX = 1000.0
Expand Down Expand Up @@ -941,7 +946,7 @@ def render(self) -> Optional[Union[RenderFrame, List[RenderFrame]]]:
return self.simulator.render( # type: ignore[return-value]
return_rgb_array=self.render_mode == 'rgb_array')

def plot(self, **kwargs: Any) -> None:
def plot(self, **kwargs: Any) -> TabbedFigure:

Check notice on line 949 in python/gym_jiminy/common/gym_jiminy/common/envs/env_generic.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

python/gym_jiminy/common/gym_jiminy/common/envs/env_generic.py#L949

Either all return statements in a function should return an expression, or none of them should.
"""Display common simulation data and action over time.
.. Note:
Expand Down Expand Up @@ -977,17 +982,20 @@ def plot(self, **kwargs: Any) -> None:
"Action space not supported by this method.")
return
tab_data[group] = {
".".join(key.split(".")[1:]): value
key.split(".", 2)[2]: value
for key, value in extract_variables_from_log(
log_vars, fieldnames, as_dict=True).items()}
elif isinstance(action_fieldnames, list):
tab_data.update({
".".join(key.split(".")[1:]): value
key.split(".", 2)[2]: value
for key, value in extract_variables_from_log(
log_vars, action_fieldnames, as_dict=True).items()})

# Add action tab
self.simulator.figure.add_tab("Action", t, tab_data)
self.simulator.figure.add_tab(" ".join(("Env", "Action")), t, tab_data)

# Return figure for convenience and consistency with Matplotlib
return self.simulator.figure

def replay(self, **kwargs: Any) -> None:
"""Replay the current episode until now.
Expand Down
3 changes: 2 additions & 1 deletion python/jiminy_py/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ def finalize_options(self) -> None:
extras_require={
"plot": [
# Standard library to generate figures.
"matplotlib>=3.5.0"
# - 3.7.0: introduces 'outside' keyword for legend location
"matplotlib>=3.7.0"
],
"meshcat": [
# Web-based mesh visualizer used as Viewer's backend.
Expand Down
169 changes: 80 additions & 89 deletions python/jiminy_py/src/jiminy_py/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def __init__(self, # pylint: disable=unused-argument
logging.warning(msg)

# Internal state buffers
self.figure = plt.figure()
self.figure = plt.figure(layout="constrained")
self.legend: Optional[Legend] = None
self.ref_ax: Optional[Axes] = None
self.tabs_data: Dict[str, TabData] = {}
Expand All @@ -143,9 +143,16 @@ def __init__(self, # pylint: disable=unused-argument
# Set window title
self.figure.canvas.manager.set_window_title(window_title)

# Set window size for offscreen rendering
# Customize figure subplot layout and reserve space for buttons
# self.figure.get_layout_engine().set(w_pad=0.1, h_pad=0.1)
self.subfigs = self.figure.subfigures(
2, 1, wspace=0.1, height_ratios=[0.94, 0.06])

# Set window size
if self.offscreen:
self.figure.set_size_inches(18, 12)
else:
self.figure.set_size_inches(14, 8)

# Register 'on resize' event callback to adjust layout
self.figure.canvas.mpl_connect('resize_event', self.adjust_layout)
Expand Down Expand Up @@ -173,16 +180,16 @@ def adjust_layout(self,
return

# Compute figure area for later export
bbox = Bbox([[0.0, 0.065], [1.0, 1.0]])
bbox = Bbox(((0.0, 0.07), (1.0, 1.0)))
bbox_pixels = bbox.transformed(self.figure.transFigure)
self.bbox_inches = bbox_pixels.transformed(
self.figure.dpi_scale_trans.inverted())

# Refresh button size, in case the number of tabs has changed
buttons_width = 1.0 / (len(self.tabs_data) + 1)
buttons_width = (1.0 - 0.006) / len(self.tabs_data)
for i, tab in enumerate(self.tabs_data.values()):
tab["button_axcut"].set_position(
[buttons_width * (i + 0.5), 0.01, buttons_width, 0.05])
(buttons_width * i + 0.003, 0.1, buttons_width, 1.0))

# Re-arrange subplots in case figure aspect ratio has changed
axes = self.tab_active["axes"]
Expand All @@ -197,20 +204,9 @@ def adjust_layout(self,
num_rows, num_cols = map(int, (num_rows_1, num_cols_1))
else:
num_rows, num_cols = map(int, (num_rows_2, num_cols_2))
grid_spec = self.subfigs[0].add_gridspec(num_rows, num_cols)
for i, ax in enumerate(axes, 1):
ax.set_subplotspec(plt.GridSpec(
num_rows, num_cols, figure=self.figure)[i - 1])

# Adjust layout: namely margins between subplots
right_margin = 0.03
if self.legend is not None:
legend_extent = self.legend.get_window_extent()
legend_width_rel = legend_extent.transformed(
self.figure.transFigure.inverted()).width
right_margin += legend_width_rel
self.figure.subplots_adjust(
bottom=0.10, top=0.92, left=0.04, right=1.0-right_margin,
wspace=0.40, hspace=0.30)
ax.set_subplotspec(grid_spec[i - 1])

# Refresh figure canvas if requested
if refresh_canvas:
Expand Down Expand Up @@ -244,21 +240,20 @@ def __click(self, event: Event, force_update: bool = False) -> None:

# Update axes and title
for ax in self.tab_active["axes"]:
self.figure.delaxes(ax)
self.subfigs[0].delaxes(ax)
if self.legend is not None:
self.legend.remove()
self.legend = None
self.tab_active = self.tabs_data[tab_name]
self.figure.suptitle(tab_name)
self.subfigs[0].suptitle(tab_name)
for ax in self.tab_active["axes"]:
self.figure.add_subplot(ax)
self.subfigs[0].add_subplot(ax)
handles, labels = self.tab_active["legend_data"]
if labels:
self.legend = self.figure.legend(
handles, labels, loc='center right',
bbox_to_anchor=(0.99, 0.5))
self.legend = self.subfigs[0].legend(
handles, labels, ncol=len(handles), loc='outside lower center')

# Restore navigation history and toolbar state if necessary
# # Restore navigation history and toolbar state if necessary
if not self.offscreen:
cur_stack._elements = self.tab_active["nav_stack"]
cur_stack._pos = self.tab_active["nav_pos"]
Expand Down Expand Up @@ -321,10 +316,15 @@ def add_tab(self, # pylint: disable=unused-argument
ref_ax = self.ref_ax if self.sync_tabs else None
for i, plot_name in enumerate(data.keys()):
uniq_label = '_'.join((tab_name, plot_name))
ax = self.figure.add_subplot(
ax = self.subfigs[0].add_subplot(
n_rows, n_cols, i+1, label=uniq_label)
ax.autoscale(True, axis='x', tight=True)
ax.autoscale(True, axis='y', tight=False)
ax.ticklabel_format(axis='x', style='plain', useOffset=True)
ax.ticklabel_format(
axis='y', style='sci', scilimits=(-3, 3), useOffset=False)
if self.tabs_data:
self.figure.delaxes(ax)
self.subfigs[0].delaxes(ax)
if ref_ax is not None:
ax.sharex(ref_ax)
else:
Expand All @@ -343,22 +343,24 @@ def add_tab(self, # pylint: disable=unused-argument
else:
plot_method(ax, time, plot_data)
ax.set_title(plot_name, fontsize='medium')
ax.grid()
ax.grid(True)
else:
# Draw single figure instead of subplot
ax = self.figure.add_subplot(1, 1, 1, label=tab_name)
ax = self.subfigs[0].add_subplot(1, 1, 1, label=tab_name)
plot_method(ax, time, data)
if self.tabs_data:
self.figure.delaxes(ax)
ax.grid()
self.subfigs[0].delaxes(ax)
ax.autoscale(enable=True, axis='both', tight=True)
ax.grid(True)
axes = [ax]

# Get unique legend for every subplots
legend_data = ax.get_legend_handles_labels()

# Add buttons to show/hide information
uniq_label = '_'.join((tab_name, "button"))
button_axcut = plt.axes([0.0, 0.0, 0.0, 0.0], label=uniq_label)
button_axcut = self.subfigs[1].add_axes(
[0.0, 0.0, 0.0, 0.0], label=uniq_label)
button = _ButtonBlit(button_axcut,
tab_name.replace(' ', '\n'),
color='white')
Expand All @@ -384,10 +386,11 @@ def add_tab(self, # pylint: disable=unused-argument
# Show tab without blocking
for ax in axes:
ax.set_visible(True)
self.figure.suptitle(tab_name)
self.subfigs[0].suptitle(tab_name)
handles, labels = legend_data
if labels:
self.legend = self.figure.legend(handles, labels)
self.legend = self.subfigs[0].legend(
handles, labels, loc='outside lower center')
button.ax.set_facecolor('green')
button.color = 'green'
button.hovercolor = 'green'
Expand Down Expand Up @@ -430,9 +433,9 @@ def remove_tab(self,

# Remove axes and legend manually is not more tabs available
if not self.tabs_data:
if self.figure._suptitle is not None:
self.figure._suptitle.remove()
self.figure._suptitle = None
if self.subfigs[0]._suptitle is not None:
self.subfigs[0]._suptitle.remove()
self.subfigs[0]._suptitle = None
for ax in tab["axes"]:
ax.remove()
if self.legend is not None:
Expand All @@ -456,7 +459,7 @@ def save_tab(self, pdf_path: str) -> None:
:param pdf_path: Desired location for generated pdf file.
"""
pdf_path = str(pathlib.Path(pdf_path).with_suffix('.png'))
self.figure.savefig(
self.subfigs[0].savefig(
pdf_path, format='pdf', bbox_inches=self.bbox_inches)

def save_all_tabs(self, pdf_path: str) -> None:
Expand Down Expand Up @@ -557,7 +560,7 @@ def plot_log(log_data: Dict[str, Any],
values = extract_variables_from_log(
log_vars, fieldnames, as_dict=True)
tabs_data[' '.join(("State", fields_type))] = OrderedDict(
(field[len("current"):].replace(fields_type, ""), elem)
(field.split(".", 1)[1][7:].replace(fields_type, ""), elem)
for field, elem in values.items())
except ValueError:
# Variable has not been recorded and is missing in log file
Expand Down Expand Up @@ -669,24 +672,28 @@ def plot_log_interactive() -> None:
main_arguments, plotting_commands = parser.parse_known_args()

# Load log file
log_vars = read_log(main_arguments.input)["variables"]
main_fullpath = main_arguments.input
log_vars = read_log(main_fullpath)["variables"]

# If no plotting commands, display the list of headers instead
if len(plotting_commands) == 0:
print("Available data:", *map(
lambda s: f"- {s}", log_vars.keys()), sep="\n")
sys.exit(0)

# Load comparision logs, if any.
compare_data = OrderedDict()
# Load all comparison logs, if any
compare_data: Dict[str, Dict[str, np.ndarray]] = OrderedDict()
if main_arguments.compare is not None:
for fullpath in main_arguments.compare.split(':'):
compare_data[fullpath], _ = read_log(fullpath)
if fullpath == main_fullpath or fullpath in compare_data.keys():
raise RuntimeError(
"All log files must be unique when comparing them.")
compare_data[fullpath] = read_log(fullpath)["variables"]

# Define linestyle cycle that will be used for comparison logs
linestyles = ["--", "-.", ":"]
# Define line style cycle used for logs comparison
linestyles = ("--", "-.", ":")

# Parse plotting arguments.
# Parse plotting arguments
plotted_elements = []
for cmd in plotting_commands:
# Check that the command is valid, i.e. that all elements exits.
Expand Down Expand Up @@ -715,41 +722,37 @@ def plot_log_interactive() -> None:
plotted_elements.append(
[header[i] for header in matching_fieldnames])

# Create figure.
# Create figure
n_plot = len(plotted_elements)

if n_plot == 0:
if not n_plot:
print("Nothing to plot. Exiting...")
return
fig = plt.figure(layout="constrained")

# Set window title
fig.canvas.manager.set_window_title(main_arguments.input)

fig = plt.figure()
# Set window size
fig.set_size_inches(14, 8)

# Create subplots, arranging them in a rectangular fashion.
# Do not allow for n_cols to be more than n_rows + 2.
n_cols = n_plot
n_rows = 1
while n_cols > n_rows + 2:
n_rows = n_rows + 1
n_cols = np.ceil(n_plot / (1.0 * n_rows))
n_rows = int(n_rows + 1)
n_cols = int(np.ceil(n_plot / (1.0 * n_rows)))
axes = fig.subplots(n_rows, n_cols, sharex=True, squeeze=False).flat[:]

axes: List[plt.Axes] = []
for i in range(n_plot):
ax = fig.add_subplot(int(n_rows), int(n_cols), i+1)
if i > 0:
ax.sharex(axes[0])
axes.append(ax)

# Store lines in dictionnary {file_name: plotted lines}, to enable to
# Store lines in dictionary {file_name: plotted lines}, to enable to
# toggle individually the visibility the data related to each of them.
main_name = os.path.basename(main_arguments.input)
plotted_lines: Dict[str, List[Line2D]] = {main_name: []}
for c in compare_data:
plotted_lines[os.path.basename(c)] = []

plt.gcf().canvas.manager.set_window_title(main_arguments.input)
# Plot each element
t = log_vars['Global.Time']

# Plot each element.
for ax, plotted_elem in zip(axes, plotted_elements):
for name in plotted_elem:
line = ax.step(t, log_vars[name], label=name)
Expand All @@ -763,47 +766,35 @@ def plot_log_interactive() -> None:
color=line[0].get_color())
plotted_lines[os.path.basename(c)].append(line[0])

# Add legend and grid for each plot.
for ax in axes:
# Add legend and grid for each plot
for ax, plotted_elem in zip(axes, plotted_elements):
ax.set_xlabel('time (s)')
ax.legend()
ax.grid()
if len(plotted_elem) > 1:
ax.legend()
else:
ax.set_title(plotted_elem[0], fontsize='medium')
ax.grid(True)

# If a compare plot is present, add overall legend specifying line types
plt.subplots_adjust(
bottom=0.05,
top=0.98,
left=0.06,
right=0.98,
wspace=0.1,
hspace=0.12)
if len(compare_data) > 0:
linecycler = cycle(linestyles)

# Dictionnary: line in legend to log name
# Dictionary: line in legend to log name
legend_lines = {Line2D([0], [0], color='k'): main_name}
for data_str in compare_data:
legend_line_object = Line2D(
[0], [0], color='k', linestyle=next(linecycler))
legend_lines[legend_line_object] = os.path.basename(data_str)
legend = fig.legend(
legend_lines.keys(), legend_lines.values(), loc='upper center',
ncol=3)
legend_lines.keys(), legend_lines.values(), ncol=3,
loc='outside lower center')

# Create a dict {picker: log name} for both the lines and the legend
# Create a dict {picker: log name} for legend lines and labels
picker_to_name = {}
for legline, name in zip(legend.get_lines(), legend_lines.values()):
legline.set_picker(10) # 10 pts tolerance
picker_to_name[legline] = name
for legline, name in zip(legend.get_texts(), legend_lines.values()):
for legline, legtxt, name in zip(
legend.get_lines(), legend.get_texts(), legend_lines.values()):
legline.set_picker(10) # 10 pts tolerance
picker_to_name[legline] = name

# Increase top margin to fit legend
fig.canvas.draw()
legend_height = legend.get_window_extent().inverse_transformed(
fig.transFigure).height
plt.subplots_adjust(top=0.98-legend_height)
picker_to_name.update({legline: name, legtxt: name})

# Make legend interactive
def legend_clicked(event: Event) -> None:
Expand Down

0 comments on commit b01b080

Please sign in to comment.