In [None]:
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Ellipse
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.env_checker import check_env

from glider import Glider

In [None]:
# Environment instance the agent is trained on.
glider = Glider(u0=0.0, v0=0.01, w0=0.0)

# Checkpointing options, kept together so they are easy to tweak in one place:
# models are written under ./big_state_models/ with the "rl_model" prefix,
# with save_freq=50_000 controlling how often a checkpoint is written.
checkpoint_options = {
    "save_freq": 50_000,
    "save_path": "./big_state_models/",
    "name_prefix": "rl_model",
}
checkpoint_callback = CheckpointCallback(**checkpoint_options)

In [None]:
# Build a PPO agent with the built-in MLP policy on the `glider` environment;
# training metrics are logged for TensorBoard under big_state_logs/.
model = PPO("MlpPolicy", glider, verbose=1, tensorboard_log="big_state_logs/")

# `total_timesteps` is documented as an int, so pass an integer literal rather
# than the float 5e5 (same value, correct type, and consistent with the
# 50_000 literal used for the checkpoint frequency).
# `checkpoint_callback` periodically saves the model while training runs.
model.learn(total_timesteps=500_000, callback=checkpoint_callback)

In [None]:
# Roll out the trained policy for one episode from a different initial
# condition than the one used for training.
glider = Glider(u0=0.25, v0=-0.25, w0=0.1)

# Restore the checkpoint saved after 500000 steps and attach the new env.
model = PPO.load("big_state_models/rl_model_500000_steps.zip", env=glider)

obs, done = glider.reset(), False
while not done:
    # deterministic=True asks the policy for its deterministic action
    # instead of sampling one.
    action, _states = model.predict(obs, deterministic=True)
    # .item() unwraps the predicted action into a plain Python scalar
    # before it is handed to the environment.
    obs, reward, done, _ = glider.step(action.item())

In [None]:
# Plot the recorded `theta` history of the last rollout against its sample
# index, using the explicit Axes interface so the figure is labelled and
# stands on its own when the notebook is skimmed.
fig, ax = plt.subplots()
ax.plot(glider.theta)
ax.set(title="Glider theta history", xlabel="sample index", ylabel="theta")
plt.show()

In [None]:
# Scatter the recorded (x, y) pairs of the last rollout, using the explicit
# Axes interface so the figure carries a title and axis labels.
fig, ax = plt.subplots()
ax.scatter(glider.x, glider.y)
ax.set(title="Glider x-y history", xlabel="x", ylabel="y")
plt.show()

In [None]:
# Plot the recorded `beta` history of the last rollout against its sample
# index, using the explicit Axes interface so the figure is labelled and
# stands on its own when the notebook is skimmed.
fig, ax = plt.subplots()
ax.plot(glider.beta)
ax.set(title="Glider beta history", xlabel="sample index", ylabel="beta")
plt.show()

In [None]:
def save_history(glider: "Glider", filename: "str | os.PathLike[str]") -> None:
    """
    Save the full history of the glider in a dictionary.

    The attributes ``u``, ``v``, ``w``, ``x``, ``y``, ``theta`` and ``beta``
    are read from ``glider``, collected in a dict keyed by attribute name
    (in that order) and pickled to ``<filename>.pkl``.

    Parameters
    ----------
    glider : Glider
        Object exposing the seven history attributes listed above.
    filename : str or os.PathLike[str]
        Destination path *without* extension. ``".pkl"`` is always appended,
        even when the given name already ends in ``.pkl``.

    Raises
    ------
    AttributeError
        If ``glider`` lacks one of the recorded attributes.
    OSError
        If the destination file cannot be opened for writing.
    """
    recorded_attributes = ("u", "v", "w", "x", "y", "theta", "beta")
    history = {name: getattr(glider, name) for name in recorded_attributes}

    # os.fspath lets callers pass pathlib.Path objects as well as plain
    # strings; for str input the resulting path is unchanged.
    target = os.fspath(filename) + ".pkl"
    with open(target, "wb") as f:
        pickle.dump(history, f)

In [None]:
# Persist the rollout recorded on `glider`; this writes "history.pkl"
# relative to the current working directory.
save_history(glider, "history")

In [None]:
# Disabled scratch cell (kept for reference, not executed): draws a single
# matplotlib Ellipse patch centred at (1, 1), 1 wide and 0.3 high, rotated by
# pi/2.5 rad (converted to degrees), on axes spanning [-5, 5] in x and y.
# ell = Ellipse(
#         xy = (1, 1),
#         width = 1,
#         height = 0.3,
#         angle = np.rad2deg(np.pi/2.5)
#         )
# fig, ax = plt.subplots()
# ax.add_artist(ell)
# ax.set_xlim(-5, 5)
# ax.set_ylim(-5, 5)
# plt.show()