In [None]:
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Ellipse
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.env_checker import check_env

from glider import Glider

In [None]:
# Environment instance that the next cells step and plot.
glider = Glider()

# Checkpoint settings collected in one mapping and then handed to
# stable-baselines3's CheckpointCallback.
# NOTE(review): no cell in this section passes `checkpoint_callback` to a
# training call; presumably it is consumed by a `learn(...)` call elsewhere.
checkpoint_kwargs = dict(
    save_freq=50_000,
    save_path="./big_state_models/",
    name_prefix="rl_model",
)
checkpoint_callback = CheckpointCallback(**checkpoint_kwargs)

In [None]:
# Drive the glider with uniformly random actions for a fixed number of steps.
# A locally seeded generator makes the action sequence identical on every
# Restart & Run All, independent of NumPy's global random state.
RANDOM_ROLLOUT_SEED = 0
RANDOM_ROLLOUT_STEPS = 40
rng = np.random.default_rng(RANDOM_ROLLOUT_SEED)

# NOTE(review): the value returned by step() is ignored, so stepping continues
# even if the environment signals the end of an episode; confirm this is intended.
for i in range(RANDOM_ROLLOUT_STEPS):
    # Each action is one of 0, 1, 2, drawn with equal probability.
    action = rng.choice([0, 1, 2])
    glider.step(action)

In [None]:
# Line plot of glider.y against glider.x, built through the explicit Axes API.
fig, ax = plt.subplots()
ax.plot(glider.x, glider.y)
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_title("Trajectory")
plt.show()

In [None]:
# Load a saved PPO checkpoint and attach it to a brand-new environment instance.
checkpoint_path = "big_state_models/rl_model_400000_steps.zip"
glider = Glider()
model = PPO.load(checkpoint_path, env=glider)

# Run one episode: keep applying the policy's deterministic action until the
# environment reports `done`.
obs, done = glider.reset(), False
while not done:
    action, _states = model.predict(obs, deterministic=True)
    # .item() turns the predicted action into a plain Python scalar for step().
    obs, reward, done, _ = glider.step(action.item())

In [None]:
# Plot the recorded theta values against their position in the list.
# Axis labels and a title let the figure stand on its own, and plt.show()
# keeps the Line2D repr out of the cell output.
fig, ax = plt.subplots()
ax.plot(glider.theta)
ax.set(xlabel="Sample index", ylabel="theta", title="theta history")
plt.show()

In [None]:
# Scatter of the recorded (x, y) positions. Labels and title mirror the
# line-plot trajectory figure earlier in the notebook so the two are directly
# comparable; plt.show() keeps the PathCollection repr out of the cell output.
fig, ax = plt.subplots()
ax.scatter(glider.x, glider.y)
ax.set(xlabel="X", ylabel="Y", title="Trajectory")
plt.show()

In [None]:
# `mass` is the product glider.a(beta) * glider.b(beta), evaluated on the
# recorded beta values converted to a NumPy array.
beta = np.array(glider.beta)
mass = glider.a(beta) * glider.b(beta)

# Plot mass against glider.t_hist with labelled axes; plt.show() keeps the
# Line2D repr out of the cell output.
fig, ax = plt.subplots()
ax.plot(glider.t_hist, mass)
ax.set(xlabel="t_hist", ylabel="mass = a(beta) * b(beta)", title="mass vs. t_hist")
plt.show()

In [None]:
# Extremes of the mass series, displayed as a (minimum, maximum) tuple.
mass_min = np.min(mass)
mass_max = np.max(mass)
mass_min, mass_max

In [None]:
def save_history(glider: "Glider", filename: "str | os.PathLike[str]") -> None:
    """
    Pickle the glider's recorded histories as a single dictionary.

    Parameters
    ----------
    glider
        Object exposing the attributes ``u``, ``v``, ``w``, ``x``, ``y``,
        ``theta`` and ``beta``. Each attribute is stored unchanged under the
        key of the same name.
    filename
        Output path *without* extension, given as a ``str`` or any
        ``os.PathLike`` (e.g. ``pathlib.Path``). ``".pkl"`` is always
        appended, so ``"history"`` produces ``history.pkl``.

    Returns
    -------
    None
        The only effect is the file written to disk (an existing file at
        that path is overwritten).

    Notes
    -----
    NOTE(review): ``glider.t_hist`` is used as a plot axis elsewhere in this
    notebook but is not part of the saved dictionary; confirm whether it
    should be included.
    """
    fields = ("u", "v", "w", "x", "y", "theta", "beta")
    history = {name: getattr(glider, name) for name in fields}

    # os.fspath() accepts both str and path-like objects, so a pathlib.Path
    # no longer fails on the string concatenation below.
    target = os.fspath(filename) + ".pkl"
    with open(target, "wb") as f:
        pickle.dump(history, f)

In [None]:
# Persist the current glider's histories. save_history appends ".pkl" to the
# name, so this writes "history.pkl" relative to the working directory.
save_history(glider, "history")