In [1]:
# %pip install -U pip wheel
# %pip install -r ../requirements.txt

In [2]:
import numpy as np
import pandas as pd
import importlib

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from typing import Dict, Tuple

In [3]:
import cartpole_util
_ = importlib.reload(cartpole_util)

In [4]:
class PidAgent(cartpole_util.CartPoleAgentABC):
    def __init__(self, KP: float, KI: float, KD: float) -> None:
        self.KP = KP
        self.KI = KI
        self.KD = KD

        self.desired_mask = np.array([0.1, 0, 1, 0])

    def reset(self) -> None:
        self.integral = 0
        self.derivative = 0
        self.prev_error = 0

    def step(
        self, env_state: np.ndarray, env_reward: float, cartpos_setpoint: float
    ) -> Tuple[int, Dict[str, object]]:
        desired_state = np.array([cartpos_setpoint, 0, 0, 0])

        error = env_state - desired_state

        self.integral += error
        self.derivative = error - self.prev_error
        self.prev_error = error

        pids = self.KP * error + self.KI * self.integral + self.KD * self.derivative
        pid = np.dot(pids, self.desired_mask)

        action = 0 if pid <= 0 else 1

        return action, {"pid_total": pid, "pid_pos": pids[0], "pid_ang": pids[2]}

In [5]:
agent = PidAgent(0.1, 0.01, 0.5)
df = cartpole_util.execute_cartpole(agent)
df

Unnamed: 0,ep,t,cart_pos_setpoint,cart_pos,cart_vel,pole_ang,pole_vel,agent_pid_total,agent_pid_pos,agent_pid_ang,env_info
0,0,0,0.0,-0.006295,-0.238398,-0.034220,0.304245,-0.021481,-0.003306,-0.021151,{}
1,0,1,0.0,-0.011063,-0.433016,-0.028135,0.585942,-0.004002,-0.001185,-0.003884,{}
2,0,2,0.0,-0.019724,-0.627733,-0.016416,0.869632,-0.001113,-0.003718,-0.000741,{}
3,0,3,0.0,-0.032278,-0.432391,0.000977,0.571833,0.002411,-0.006728,0.003083,{}
4,0,4,0.0,-0.040926,-0.237283,0.012413,0.279458,0.006644,-0.010253,0.007669,{}
...,...,...,...,...,...,...,...,...,...,...,...
9995,19,495,0.0,0.045061,-0.049540,-0.002492,0.091361,0.000209,-0.024169,0.002626,{}
9996,19,496,0.0,0.044070,0.145618,-0.000664,-0.202107,0.003648,-0.026158,0.006263,{}
9997,19,497,0.0,0.046982,0.340749,-0.004706,-0.494999,0.001094,-0.023864,0.003481,{}
9998,19,498,0.0,0.053797,0.145694,-0.014606,-0.203803,-0.002020,-0.021151,0.000095,{}


In [6]:
def show_state(ep: int, t: int):
    dff = df.loc[(df["ep"] == ep) & (df["t"] == t)]

    if dff.shape[0] != 1:
        return

    fig, ax = cartpole_util.render_cartpole_state(dff.iloc[0])

_ = interact(
    show_state,
    ep=widgets.IntSlider(min=0, max=20, step=1, value=0),
    t=widgets.IntSlider(min=0, max=500, step=1, value=0),
)

interactive(children=(IntSlider(value=0, description='ep', max=20), IntSlider(value=0, description='t', max=50…

In [7]:
EP = 0

fig = cartpole_util.lineplot(df, ep=EP)
fig.show()