In [1]:
import copy
import gym
import numpy as np
import scipy.integrate as si
import matplotlib.pyplot as plt
import matplotlib.ticker as tck
from stable_baselines3 import PPO, DDPG
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.evaluation import evaluate_policy

In [2]:
class Swing(gym.Env):
    def __init__(self):
        super(Swing, self).__init__()
        self.lmin = 2
        self.lmax = 2.5
        self.phidot_0 = -0.1
        self.target = np.pi
        self.time = 0
        self.pumps = 0
        self.tau = np.sqrt(self.lmin / 9.81) / 4
        self.ldot_max = (self.lmax - self.lmin) / self.tau
        self.observation_space = gym.spaces.Box(
            low=np.array([0, -10, self.lmin]),
            high=np.array([2 * np.pi, 10, self.lmax]),  # phi, phi dot, L
        )
        self.observation_space = gym.spaces.Box(
            low=np.array([-1, self.lmin]), high=np.array([10, self.lmax])
        )
        self.action_space = gym.spaces.Box(low=-np.array([1]), high=np.array([1]))
        self.phi = [np.pi / 8]
        self.phi_dot = [self.phidot_0]
        self.L = [self.lmin]
        self.Ldot_hist = []

    def fun(self, t, y, ldot, g=9.81):
        """Define system of equations to simulate"""
        y0_dot = y[1]
        y1_dot = -(2 * ldot / y[2]) * y[1] - (g / y[2]) * np.sin(y[0])
        y2_dot = ldot
        y_dot = np.hstack((y0_dot, y1_dot, y2_dot))
        return y_dot

    def forward(self, ldot):
        """Simulate the swing forward"""
        sol = si.solve_ivp(
            self.fun,
            [self.time, self.time + self.tau],
            y0=[self.phi[-1], self.phi_dot[-1], self.L[-1]],
            args=[ldot],
        )
        phi = np.mod(sol.y[0], 2 * np.pi)
        self.phi.extend(list(phi[1:]))
        phi_dot = sol.y[1]
        self.phi_dot.extend(list(phi_dot[1:]))
        L = sol.y[2]
        self.L.extend(list(L[1:]))
        self.time += self.tau
        self.pumps += 1
        pass

    def check_valid_action(self, ldot):
        """Check if an action will take us out of bounds. if so don't allow it."""
        next_l = self.L[-1] + self.tau * ldot
        if next_l > self.lmax:
            ldot = (self.lmax - self.L[-1]) / self.tau
        elif next_l < self.lmin:
            ldot = (self.lmin - self.L[-1]) / self.tau
        else:
            ldot = ldot
        return ldot

    def step(self, action):
        """Take action and simulate"""
        ldot = self.ldot_max * action[0]
        ldot = self.check_valid_action(ldot)
        self.Ldot_hist.append(ldot)
        self.forward(ldot)
        # state = np.array([self.phi[-1], self.phi_dot[-1], self.L[-1]], dtype=np.float32)
        state = np.array([self.phi_dot[-1], self.L[-1]], dtype=np.float32)
        if np.isclose(state[0], self.target, rtol=0.05):
            reward = 10
            done = True
        elif self.pumps > 10_000:
            reward = -1
            done = True
        else:
            reward = -1
            done = False
        info = {}
        return state, reward, done, info

    def reset(self):
        """Reset system to beginning of episode."""
        self.time = 0
        self.pumps = 0
        self.L = [self.lmin]
        self.phi = [np.pi / 8]
        self.phi_dot = [self.phidot_0]
        self.Ldot_hist.clear()
        # state = np.array([self.phi[-1], self.phi_dot[-1], self.L[-1]], dtype=np.float32)
        state = np.array([self.phi_dot[-1], self.L[-1]], dtype=np.float32)
        return state

    def render(self):
        pass

In [3]:
env = Swing()

  "Box bound precision lowered by casting to {}".format(self.dtype)


In [4]:
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=3.5e5)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 470      |
|    ep_rew_mean     | -459     |
| time/              |          |
|    fps             | 1027     |
|    iterations      | 1        |
|    time_elapsed    | 1        |
|    total_timesteps | 2048     |
---------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 585          |
|    ep_rew_mean          | -574         |
| time/                   |              |
|    fps                  | 912          |
|    iterations           | 2            |
|    time_elapsed         | 4            |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0054652635 |
|    clip_fraction        | 0.0352       |
|    clip_range           | 0.2          |
|    en

<stable_baselines3.ppo.ppo.PPO at 0x7fb7822cc550>

In [5]:
model.save("ldot_controller")