In [None]:
import lderiv_control as ld
import copy
import gym
import numpy as np
import scipy.integrate as si
import matplotlib.pyplot as plt
from stable_baselines3 import PPO, DDPG, A2C, TD3
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.callbacks import CheckpointCallback

In [None]:
# Build the Swing environment from lderiv_control; the PPO model below is trained on it.
env = ld.Swing()

In [None]:
# Checkpointing during training: the callback is configured to write a model
# snapshot every 50_000 steps into ./logs/, using the "rl_model" file prefix.
checkpoint_callback = CheckpointCallback(
    name_prefix="rl_model",
    save_path="./logs/",
    save_freq=50_000,
)

In [None]:
# Train a PPO agent with an MLP policy on the Swing environment.
model = PPO("MlpPolicy", env, verbose=1)
# Store the untrained policy as a "0 steps" snapshot next to the checkpoints;
# the name appears to mirror the "<prefix>_<n>_steps" pattern of the callback.
model.save("logs/rl_model_0_steps")
# total_timesteps is documented as an int, so use an integer literal rather
# than the float 2e5 (same number of steps).
model.learn(total_timesteps=200_000, callback=checkpoint_callback)

In [None]:
# Write the trained policy to disk, then drop the in-memory model so that the
# evaluation further down runs on the copy reloaded from the file.
model.save("trained_model_new")
del model
# Recreate the environment and load the saved policy against it.
# NOTE(review): the save call omits the ".zip" suffix while the load call
# includes it — presumably stable-baselines3 appends it on save; confirm.
env = ld.Swing()
model = PPO.load("trained_model_new.zip", env=env)

In [None]:
# Run one episode with the reloaded policy until the environment reports `done`.
obs, done = env.reset(), False
while not done:
    action, _states = model.predict(obs)
    obs, reward, done, _ = env.step(action)

# Snapshot the angle (phi) and length (L) sequences held by the environment.
# Presumably these are the per-step histories of the episode just run —
# TODO confirm against ld.Swing.
phi_hist, l_hist = np.array(env.phi), np.array(env.L)

# Cartesian coordinates of the trajectory; phi = 0 maps to (0, -L).
x_t, y_t = l_hist * np.sin(phi_hist), -l_hist * np.cos(phi_hist)

# Same conversion evaluated at radius env.lmax instead of the recorded length.
lmax_arr = np.array(env.lmax)
ref_x, ref_y = lmax_arr * np.sin(phi_hist), -lmax_arr * np.cos(phi_hist)

In [None]:
# Plot the recorded angle against time and save the figure as theta.png.
fontdict = {"fontsize": 16}  # text size reused for axis labels and titles

# Sample k is placed at k * env.tau / 2.
# NOTE(review): the tau / 2 spacing is taken as given — confirm against ld.Swing.
time_axis = (env.tau / 2) * np.arange(len(env.phi))

plt.plot(time_axis, env.phi)
# Label the angle axis at 0, pi and 2*pi.
plt.yticks([0, np.pi, 2 * np.pi], ["0", r"$\pi$", r"2$\pi$"])
plt.title("Angle over time", fontdict=fontdict)
plt.xlabel("Time", fontdict=fontdict)
plt.ylabel("Angle", fontdict=fontdict)
plt.savefig("theta.png")

In [None]:
# Plot the recorded length against time, save it as length.png and display it.
# Sample k is placed at k * env.tau / 2.
# NOTE(review): the tau / 2 spacing is taken as given — confirm against ld.Swing.
time_axis = (env.tau / 2) * np.arange(len(env.phi))

plt.plot(time_axis, env.L)
plt.title("Length over time", fontdict=fontdict)
plt.xlabel("Time", fontdict=fontdict)
plt.ylabel("Length", fontdict=fontdict)
# Save before show so the file is written from the populated figure.
plt.savefig("length.png")
plt.show()

In [None]:
# Overlay angle and length on one figure with a shared time axis and two
# independent y-axes; the result is saved as overlay.png.
fig, ax1 = plt.subplots(figsize=(14, 10))
ax2 = ax1.twinx()  # second y-axis sharing ax1's x-axis

# Sample k is placed at k * env.tau / 2.
# NOTE(review): the tau / 2 spacing is taken as given — confirm against ld.Swing.
time_axis = (env.tau / 2) * np.arange(len(env.phi))

# Left axis: angle, solid green.
ax1.plot(time_axis, env.phi, "g-")
ax1.set_xlabel("Time")
ax1.set_ylabel("Angles", color="g")

# Right axis: length, dashed black.
ax2.plot(time_axis, env.L, "k--")
ax2.set_ylabel("Lengths", color="k")

plt.title("Lengths and Angles over time")
plt.savefig("overlay.png")