Skip to content

Commit

Permalink
feat: Add 2D plot to graphing callback
Browse files Browse the repository at this point in the history
  • Loading branch information
iwishiwasaneagle committed Mar 29, 2024
1 parent f429dd1 commit 38b4a5e
Showing 1 changed file with 35 additions and 4 deletions.
39 changes: 35 additions & 4 deletions docs/examples/drl_3d_wp/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import optuna
import pandas as pd
from drl_3d_wp.state import State
from jdrones.plotting import plot_2d_path
from loguru import logger
from matplotlib import pyplot as plt
from stable_baselines3.common.callbacks import BaseCallback
Expand All @@ -28,7 +29,7 @@ def _on_step(self):
env = self.model.get_env()
obs = env.reset()
states = None
t = 0
t = -env.unwrapped.get_attr("dt")[0]
episode_starts = np.ones((env.num_envs,), dtype=bool)
while True:
action, states = self.model.predict(
Expand All @@ -42,6 +43,10 @@ def _on_step(self):
obs_ = info_["state"]
tx, ty, tz = info_["target"]
action_ = action[0]
energy_ = info_["energy"]
distance_from_target_ = info_["distance_from_target"]
control_action_ = info_["control_action"]
dcontrol_action_ = info_["dcontrol_action"]

if action_.shape == 1:
action_ = np.array([action_])
Expand All @@ -59,11 +64,14 @@ def _on_step(self):
vx, vy, vz = obs_i.vel
p1, p2, p3, p4 = obs_i.prop_omega
log.append(
info_
| dict(
dict(
time=t,
reward=reward_,
action=action_,
energy=energy_,
distance_from_target=distance_from_target_,
control_action=control_action_,
dcontrol_action=dcontrol_action_,
x=x,
y=y,
z=z,
Expand All @@ -90,6 +98,29 @@ def _on_step(self):

df = pd.DataFrame(log)

fig, ax = plt.subplots()
df_long = (
df[["time", "x", "y", "z"]]
.melt(
var_name="variable",
value_name="value",
id_vars="time",
)
.sort_values(by=["time"])
.rename({"time": "t"}, axis="columns")
)
df_long["tag"] = "PPO+LQR"
plot_2d_path(df_long, ax)
ax.scatter(*df[["x", "y"]].iloc[0].to_list(), zorder=10, c="g")
ax.scatter(*df[["x", "y"]].iloc[-1].to_list(), zorder=10, c="r")
fig.tight_layout()
self.logger.record(
"data/position_2d",
Figure(fig, close=True),
exclude=("stdout", "log", "json", "csv"),
)
plt.close(fig)

fig, ax = plt.subplots()
ax.plot(df.time, df.energy)
ax2 = ax.twinx()
Expand Down Expand Up @@ -408,4 +439,4 @@ def _energy_callback(self, buffer):
self._generic_mean_callback("eval/step_energy", buffer)

def _distance_from_target_callback(self, buffer):
self._generic_mean_callback("eval/mean+step_distance_from_target", buffer)
self._generic_mean_callback("eval/mean_step_distance_from_target", buffer)

0 comments on commit 38b4a5e

Please sign in to comment.