Skip to content

Commit

Permalink
Merge pull request #148 from decargroup/feature/147-add-error-plottin…
Browse files Browse the repository at this point in the history
…g-to-plot_predicted_trajectory

Add option to plot error
  • Loading branch information
sdahdah committed Aug 21, 2023
2 parents 317fd7e + 0b47866 commit 18ec6e8
Showing 1 changed file with 56 additions and 23 deletions.
79 changes: 56 additions & 23 deletions pykoop/koopman_pipeline.py
Expand Up @@ -2691,6 +2691,7 @@ def plot_predicted_trajectory(
relift_state: bool = True,
plot_lifted: bool = False,
plot_input: bool = False,
plot_error: bool = False,
episode_feature: Optional[bool] = None,
plot_ground_truth: bool = True,
episode_style: Optional[str] = None,
Expand Down Expand Up @@ -2719,6 +2720,9 @@ def plot_predicted_trajectory(
plot_input : bool
If true, plot the input as well as the state. If false, plot
only the original state (default).
plot_error : bool
If true, plot the prediction error instead of the state. If false,
plot the predicted state and ground truth (default).
episode_feature : Optional[bool]
True if first feature indicates which episode a timestep is from.
If ``None``, ``self.episode_feature_`` is used.
Expand Down Expand Up @@ -2797,33 +2801,62 @@ def plot_predicted_trajectory(
# Plot results
for row in range(n_row):
for ep in range(n_eps):
if episode_style == 'overlay':
line_pred = ax[row, 0].plot(
eps[ep][1][:, row],
label=f'Ep. {int(eps[ep][0])} prediction',
**plot_args,
)
if eps_gt is not None and row < n_states:
ax[row, 0].plot(
eps_gt[ep][1][:, row],
label=f'Ep. {int(eps[ep][0])} ground truth',
linestyle='--',
color=line_pred[0].get_color(),
if plot_error:
if episode_style == 'overlay':
if eps_gt is not None and row < n_states:
ax[row, 0].plot(
eps_gt[ep][1][:, row] - eps[ep][1][:, row],
label=f'Ep. {int(eps[ep][0])} error',
**plot_args,
)
else:
ax[row, 0].plot(
eps[ep][1][:, row],
label=f'Ep. {int(eps[ep][0])} input',
**plot_args,
)
else:
if eps_gt is not None and row < n_states:
ax[row, ep].plot(
eps_gt[ep][1][:, row] - eps[ep][1][:, row],
label=f'Prediction error',
linestyle='--',
**plot_args,
)
else:
ax[row, ep].plot(
eps[ep][1][:, row],
label=f'Input',
**plot_args,
)
else:
if episode_style == 'overlay':
line_pred = ax[row, 0].plot(
eps[ep][1][:, row],
label=f'Ep. {int(eps[ep][0])} prediction',
**plot_args,
)
else:
line_pred = ax[row, ep].plot(
eps[ep][1][:, row],
label=f'Prediction',
**plot_args,
)
if eps_gt is not None and row < n_states:
ax[row, ep].plot(
eps_gt[ep][1][:, row],
label=f'Ground truth',
linestyle='--',
if eps_gt is not None and row < n_states:
ax[row, 0].plot(
eps_gt[ep][1][:, row],
label=f'Ep. {int(eps[ep][0])} ground truth',
linestyle='--',
color=line_pred[0].get_color(),
**plot_args,
)
else:
line_pred = ax[row, ep].plot(
eps[ep][1][:, row],
label=f'Prediction',
**plot_args,
)
if eps_gt is not None and row < n_states:
ax[row, ep].plot(
eps_gt[ep][1][:, row],
label=f'Ground truth',
linestyle='--',
**plot_args,
)
# Set y labels
if plot_lifted:
names = self.get_feature_names_out(
Expand Down

0 comments on commit 18ec6e8

Please sign in to comment.