diff --git a/python/gym_jiminy/common/gym_jiminy/common/envs/env_generic.py b/python/gym_jiminy/common/gym_jiminy/common/envs/env_generic.py index 726247072..6b6065306 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/envs/env_generic.py +++ b/python/gym_jiminy/common/gym_jiminy/common/envs/env_generic.py @@ -30,7 +30,7 @@ ImuSensor as imu) from jiminy_py.dynamics import compute_freeflyer_state_from_fixed_body from jiminy_py.log import extract_variables_from_log -from jiminy_py.simulator import Simulator +from jiminy_py.simulator import Simulator, TabbedFigure from jiminy_py.viewer.viewer import (DEFAULT_CAMERA_XYZRPY_REL, interactive_mode, get_default_backend, @@ -941,7 +941,7 @@ def render(self) -> Optional[Union[RenderFrame, List[RenderFrame]]]: return self.simulator.render( # type: ignore[return-value] return_rgb_array=self.render_mode == 'rgb_array') - def plot(self, **kwargs: Any) -> None: + def plot(self, **kwargs: Any) -> TabbedFigure: """Display common simulation data and action over time. .. Note: @@ -950,7 +950,7 @@ def plot(self, **kwargs: Any) -> None: :param kwargs: Extra keyword arguments to forward to `simulator.plot`. """ # Call base implementation - self.simulator.plot(**kwargs) + figure = self.simulator.plot(**kwargs) # Extract log data log_vars = self.simulator.log_data.get("variables", {}) @@ -969,25 +969,28 @@ def plot(self, **kwargs: Any) -> None: if action_fieldnames is None: # It was impossible to register the action to the telemetry, likely # because of incompatible dtype. Early return without adding tab. - return + return figure if isinstance(action_fieldnames, dict): for group, fieldnames in action_fieldnames.items(): if not isinstance(fieldnames, list): LOGGER.error( "Action space not supported by this method.") - return + return figure tab_data[group] = { - ".".join(key.split(".")[1:]): value + key.split(".", 2)[2]: value for key, value in extract_variables_from_log( log_vars, fieldnames, as_dict=True).items()} elif isinstance(action_fieldnames, list): tab_data.update({ - ".".join(key.split(".")[1:]): value + key.split(".", 2)[2]: value for key, value in extract_variables_from_log( log_vars, action_fieldnames, as_dict=True).items()}) # Add action tab - self.simulator.figure.add_tab("Action", t, tab_data) + figure.add_tab(" ".join(("Env", "Action")), t, tab_data) + + # Return figure for convenience and consistency with Matplotlib + return figure def replay(self, **kwargs: Any) -> None: """Replay the current episode until now. diff --git a/python/gym_jiminy/examples/reinforcement_learning/rllib/acrobot_ppo.py b/python/gym_jiminy/examples/reinforcement_learning/rllib/acrobot_ppo.py index f6734ab94..c99c07972 100644 --- a/python/gym_jiminy/examples/reinforcement_learning/rllib/acrobot_ppo.py +++ b/python/gym_jiminy/examples/reinforcement_learning/rllib/acrobot_ppo.py @@ -316,4 +316,4 @@ policy_fn = build_policy_wrapper( policy_map, clip_actions=False, explore=False) for seed in (1, 1, 2): - env.evaluate(policy_fn, seed=seed) + env.evaluate(policy_fn, seed=seed, horizon=env._max_episode_steps) diff --git a/python/jiminy_py/setup.py b/python/jiminy_py/setup.py index e7ffa33df..36af20131 100644 --- a/python/jiminy_py/setup.py +++ b/python/jiminy_py/setup.py @@ -116,7 +116,8 @@ def finalize_options(self) -> None: extras_require={ "plot": [ # Standard library to generate figures. - "matplotlib>=3.5.0" + # - 3.7.0: introduces 'outside' keyword for legend location + "matplotlib>=3.7.0" ], "meshcat": [ # Web-based mesh visualizer used as Viewer's backend. diff --git a/python/jiminy_py/src/jiminy_py/plot.py b/python/jiminy_py/src/jiminy_py/plot.py index 59648a9d3..c291eb854 100644 --- a/python/jiminy_py/src/jiminy_py/plot.py +++ b/python/jiminy_py/src/jiminy_py/plot.py @@ -133,7 +133,7 @@ def __init__(self, # pylint: disable=unused-argument logging.warning(msg) # Internal state buffers - self.figure = plt.figure() + self.figure = plt.figure(layout="constrained") self.legend: Optional[Legend] = None self.ref_ax: Optional[Axes] = None self.tabs_data: Dict[str, TabData] = {} @@ -143,9 +143,16 @@ def __init__(self, # pylint: disable=unused-argument # Set window title self.figure.canvas.manager.set_window_title(window_title) - # Set window size for offscreen rendering + # Customize figure subplot layout and reserve space for buttons + # self.figure.get_layout_engine().set(w_pad=0.1, h_pad=0.1) + self.subfigs = self.figure.subfigures( + 2, 1, wspace=0.1, height_ratios=[0.94, 0.06]) + + # Set window size if self.offscreen: self.figure.set_size_inches(18, 12) + else: + self.figure.set_size_inches(14, 8) # Register 'on resize' event callback to adjust layout self.figure.canvas.mpl_connect('resize_event', self.adjust_layout) @@ -173,16 +180,16 @@ def adjust_layout(self, return # Compute figure area for later export - bbox = Bbox([[0.0, 0.065], [1.0, 1.0]]) + bbox = Bbox(((0.0, 0.07), (1.0, 1.0))) bbox_pixels = bbox.transformed(self.figure.transFigure) self.bbox_inches = bbox_pixels.transformed( self.figure.dpi_scale_trans.inverted()) # Refresh button size, in case the number of tabs has changed - buttons_width = 1.0 / (len(self.tabs_data) + 1) + buttons_width = (1.0 - 0.006) / len(self.tabs_data) for i, tab in enumerate(self.tabs_data.values()): tab["button_axcut"].set_position( - [buttons_width * (i + 0.5), 0.01, buttons_width, 0.05]) + (buttons_width * i + 0.003, 0.1, buttons_width, 1.0)) # Re-arrange subplots in case figure aspect ratio has changed axes = self.tab_active["axes"] @@ -197,20 +204,9 @@ def adjust_layout(self, num_rows, num_cols = map(int, (num_rows_1, num_cols_1)) else: num_rows, num_cols = map(int, (num_rows_2, num_cols_2)) + grid_spec = self.subfigs[0].add_gridspec(num_rows, num_cols) for i, ax in enumerate(axes, 1): - ax.set_subplotspec(plt.GridSpec( - num_rows, num_cols, figure=self.figure)[i - 1]) - - # Adjust layout: namely margins between subplots - right_margin = 0.03 - if self.legend is not None: - legend_extent = self.legend.get_window_extent() - legend_width_rel = legend_extent.transformed( - self.figure.transFigure.inverted()).width - right_margin += legend_width_rel - self.figure.subplots_adjust( - bottom=0.10, top=0.92, left=0.04, right=1.0-right_margin, - wspace=0.40, hspace=0.30) + ax.set_subplotspec(grid_spec[i - 1]) # Refresh figure canvas if requested if refresh_canvas: @@ -244,21 +240,20 @@ def __click(self, event: Event, force_update: bool = False) -> None: # Update axes and title for ax in self.tab_active["axes"]: - self.figure.delaxes(ax) + self.subfigs[0].delaxes(ax) if self.legend is not None: self.legend.remove() self.legend = None self.tab_active = self.tabs_data[tab_name] - self.figure.suptitle(tab_name) + self.subfigs[0].suptitle(tab_name) for ax in self.tab_active["axes"]: - self.figure.add_subplot(ax) + self.subfigs[0].add_subplot(ax) handles, labels = self.tab_active["legend_data"] if labels: - self.legend = self.figure.legend( - handles, labels, loc='center right', - bbox_to_anchor=(0.99, 0.5)) + self.legend = self.subfigs[0].legend( + handles, labels, ncol=len(handles), loc='outside lower center') - # Restore navigation history and toolbar state if necessary + # # Restore navigation history and toolbar state if necessary if not self.offscreen: cur_stack._elements = self.tab_active["nav_stack"] cur_stack._pos = self.tab_active["nav_pos"] @@ -321,10 +316,15 @@ def add_tab(self, # pylint: disable=unused-argument ref_ax = self.ref_ax if self.sync_tabs else None for i, plot_name in enumerate(data.keys()): uniq_label = '_'.join((tab_name, plot_name)) - ax = self.figure.add_subplot( + ax = self.subfigs[0].add_subplot( n_rows, n_cols, i+1, label=uniq_label) + ax.autoscale(True, axis='x', tight=True) + ax.autoscale(True, axis='y', tight=False) + ax.ticklabel_format(axis='x', style='plain', useOffset=True) + ax.ticklabel_format( + axis='y', style='sci', scilimits=(-3, 3), useOffset=False) if self.tabs_data: - self.figure.delaxes(ax) + self.subfigs[0].delaxes(ax) if ref_ax is not None: ax.sharex(ref_ax) else: @@ -343,14 +343,15 @@ def add_tab(self, # pylint: disable=unused-argument else: plot_method(ax, time, plot_data) ax.set_title(plot_name, fontsize='medium') - ax.grid() + ax.grid(True) else: # Draw single figure instead of subplot - ax = self.figure.add_subplot(1, 1, 1, label=tab_name) + ax = self.subfigs[0].add_subplot(1, 1, 1, label=tab_name) plot_method(ax, time, data) if self.tabs_data: - self.figure.delaxes(ax) - ax.grid() + self.subfigs[0].delaxes(ax) + ax.autoscale(enable=True, axis='both', tight=True) + ax.grid(True) axes = [ax] # Get unique legend for every subplots @@ -358,7 +359,8 @@ def add_tab(self, # pylint: disable=unused-argument # Add buttons to show/hide information uniq_label = '_'.join((tab_name, "button")) - button_axcut = plt.axes([0.0, 0.0, 0.0, 0.0], label=uniq_label) + button_axcut = self.subfigs[1].add_axes( + [0.0, 0.0, 0.0, 0.0], label=uniq_label) button = _ButtonBlit(button_axcut, tab_name.replace(' ', '\n'), color='white') @@ -384,10 +386,11 @@ def add_tab(self, # pylint: disable=unused-argument # Show tab without blocking for ax in axes: ax.set_visible(True) - self.figure.suptitle(tab_name) + self.subfigs[0].suptitle(tab_name) handles, labels = legend_data if labels: - self.legend = self.figure.legend(handles, labels) + self.legend = self.subfigs[0].legend( + handles, labels, loc='outside lower center') button.ax.set_facecolor('green') button.color = 'green' button.hovercolor = 'green' @@ -430,9 +433,9 @@ def remove_tab(self, # Remove axes and legend manually is not more tabs available if not self.tabs_data: - if self.figure._suptitle is not None: - self.figure._suptitle.remove() - self.figure._suptitle = None + if self.subfigs[0]._suptitle is not None: + self.subfigs[0]._suptitle.remove() + self.subfigs[0]._suptitle = None for ax in tab["axes"]: ax.remove() if self.legend is not None: @@ -456,7 +459,7 @@ def save_tab(self, pdf_path: str) -> None: :param pdf_path: Desired location for generated pdf file. """ pdf_path = str(pathlib.Path(pdf_path).with_suffix('.png')) - self.figure.savefig( + self.subfigs[0].savefig( pdf_path, format='pdf', bbox_inches=self.bbox_inches) def save_all_tabs(self, pdf_path: str) -> None: @@ -557,7 +560,7 @@ def plot_log(log_data: Dict[str, Any], values = extract_variables_from_log( log_vars, fieldnames, as_dict=True) tabs_data[' '.join(("State", fields_type))] = OrderedDict( - (field[len("current"):].replace(fields_type, ""), elem) + (field.split(".", 1)[1][7:].replace(fields_type, ""), elem) for field, elem in values.items()) except ValueError: # Variable has not been recorded and is missing in log file @@ -669,7 +672,8 @@ def plot_log_interactive() -> None: main_arguments, plotting_commands = parser.parse_known_args() # Load log file - log_vars = read_log(main_arguments.input)["variables"] + main_fullpath = main_arguments.input + log_vars = read_log(main_fullpath)["variables"] # If no plotting commands, display the list of headers instead if len(plotting_commands) == 0: @@ -677,16 +681,19 @@ def plot_log_interactive() -> None: lambda s: f"- {s}", log_vars.keys()), sep="\n") sys.exit(0) - # Load comparision logs, if any. - compare_data = OrderedDict() + # Load all comparison logs, if any + compare_data: Dict[str, Dict[str, np.ndarray]] = OrderedDict() if main_arguments.compare is not None: for fullpath in main_arguments.compare.split(':'): - compare_data[fullpath], _ = read_log(fullpath) + if fullpath == main_fullpath or fullpath in compare_data.keys(): + raise RuntimeError( + "All log files must be unique when comparing them.") + compare_data[fullpath] = read_log(fullpath)["variables"] - # Define linestyle cycle that will be used for comparison logs - linestyles = ["--", "-.", ":"] + # Define line style cycle used for logs comparison + linestyles = ("--", "-.", ":") - # Parse plotting arguments. + # Parse plotting arguments plotted_elements = [] for cmd in plotting_commands: # Check that the command is valid, i.e. that all elements exits. @@ -715,41 +722,37 @@ def plot_log_interactive() -> None: plotted_elements.append( [header[i] for header in matching_fieldnames]) - # Create figure. + # Create figure n_plot = len(plotted_elements) - - if n_plot == 0: + if not n_plot: print("Nothing to plot. Exiting...") return + fig = plt.figure(layout="constrained") + + # Set window title + fig.canvas.manager.set_window_title(main_arguments.input) - fig = plt.figure() + # Set window size + fig.set_size_inches(14, 8) # Create subplots, arranging them in a rectangular fashion. # Do not allow for n_cols to be more than n_rows + 2. n_cols = n_plot n_rows = 1 while n_cols > n_rows + 2: - n_rows = n_rows + 1 - n_cols = np.ceil(n_plot / (1.0 * n_rows)) + n_rows = int(n_rows + 1) + n_cols = int(np.ceil(n_plot / (1.0 * n_rows))) + axes = fig.subplots(n_rows, n_cols, sharex=True, squeeze=False).flat[:] - axes: List[plt.Axes] = [] - for i in range(n_plot): - ax = fig.add_subplot(int(n_rows), int(n_cols), i+1) - if i > 0: - ax.sharex(axes[0]) - axes.append(ax) - - # Store lines in dictionnary {file_name: plotted lines}, to enable to + # Store lines in dictionary {file_name: plotted lines}, to enable to # toggle individually the visibility the data related to each of them. main_name = os.path.basename(main_arguments.input) plotted_lines: Dict[str, List[Line2D]] = {main_name: []} for c in compare_data: plotted_lines[os.path.basename(c)] = [] - plt.gcf().canvas.manager.set_window_title(main_arguments.input) + # Plot each element t = log_vars['Global.Time'] - - # Plot each element. for ax, plotted_elem in zip(axes, plotted_elements): for name in plotted_elem: line = ax.step(t, log_vars[name], label=name) @@ -763,47 +766,35 @@ def plot_log_interactive() -> None: color=line[0].get_color()) plotted_lines[os.path.basename(c)].append(line[0]) - # Add legend and grid for each plot. - for ax in axes: + # Add legend and grid for each plot + for ax, plotted_elem in zip(axes, plotted_elements): ax.set_xlabel('time (s)') - ax.legend() - ax.grid() + if len(plotted_elem) > 1: + ax.legend() + else: + ax.set_title(plotted_elem[0], fontsize='medium') + ax.grid(True) # If a compare plot is present, add overall legend specifying line types - plt.subplots_adjust( - bottom=0.05, - top=0.98, - left=0.06, - right=0.98, - wspace=0.1, - hspace=0.12) if len(compare_data) > 0: linecycler = cycle(linestyles) - # Dictionnary: line in legend to log name + # Dictionary: line in legend to log name legend_lines = {Line2D([0], [0], color='k'): main_name} for data_str in compare_data: legend_line_object = Line2D( [0], [0], color='k', linestyle=next(linecycler)) legend_lines[legend_line_object] = os.path.basename(data_str) legend = fig.legend( - legend_lines.keys(), legend_lines.values(), loc='upper center', - ncol=3) + legend_lines.keys(), legend_lines.values(), ncol=3, + loc='outside lower center') - # Create a dict {picker: log name} for both the lines and the legend + # Create a dict {picker: log name} for legend lines and labels picker_to_name = {} - for legline, name in zip(legend.get_lines(), legend_lines.values()): - legline.set_picker(10) # 10 pts tolerance - picker_to_name[legline] = name - for legline, name in zip(legend.get_texts(), legend_lines.values()): + for legline, legtxt, name in zip( + legend.get_lines(), legend.get_texts(), legend_lines.values()): legline.set_picker(10) # 10 pts tolerance - picker_to_name[legline] = name - - # Increase top margin to fit legend - fig.canvas.draw() - legend_height = legend.get_window_extent().inverse_transformed( - fig.transFigure).height - plt.subplots_adjust(top=0.98-legend_height) + picker_to_name.update({legline: name, legtxt: name}) # Make legend interactive def legend_clicked(event: Event) -> None: diff --git a/python/jiminy_py/src/jiminy_py/simulator.py b/python/jiminy_py/src/jiminy_py/simulator.py index 565d48c03..f0379ab2e 100644 --- a/python/jiminy_py/src/jiminy_py/simulator.py +++ b/python/jiminy_py/src/jiminy_py/simulator.py @@ -125,7 +125,7 @@ def callback_wrapper(t: float, self.__pbar: Optional[tqdm] = None # Figure holder - self.figure: Optional[TabbedFigure] = None + self._figure: Optional[TabbedFigure] = None # Reset the low-level jiminy engine self.reset() @@ -634,14 +634,14 @@ def close(self) -> None: if hasattr(self, "viewer") and self.viewer is not None: self.viewer.close() self.viewer = None - if hasattr(self, "figure") and self.figure is not None: - self.figure.close() - self.figure = None + if hasattr(self, "figure") and self._figure is not None: + self._figure.close() + self._figure = None def plot(self, enable_flexiblity_data: bool = False, block: Optional[bool] = None, - **kwargs: Any) -> None: + **kwargs: Any) -> TabbedFigure: """Display common simulation data over time. The figure features several tabs: @@ -671,9 +671,11 @@ def plot(self, ) from e # Create figure, without closing the existing one - self.figure = plot_log( + self._figure = plot_log( self.log_data, self.robot, enable_flexiblity_data, block, **kwargs) + return self._figure + def get_controller_options(self) -> dict: """Getter of the options of Jiminy Controller. """