In [7]:
# Optional one-time setup: uncomment to install the dependencies into this kernel's environment.
# %pip install -U pip wheel
# %pip install -r ../requirements.txt

In [8]:
import numpy as np
import pandas as pd
import importlib

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from typing import Dict, Tuple

In [9]:
# Project-local helper module; this notebook uses its CartPoleAgentABC,
# execute_cartpole and _render_cartpole_state names.
import cartpole_util
# Reload so that edits to the module are picked up without restarting the
# kernel. The result is bound to `_` so the cell produces no output.
_ = importlib.reload(cartpole_util)

In [10]:
class PidAgent(cartpole_util.CartPoleAgentABC):
    """PID-style controller that chooses a discrete cart-pole action.

    A proportional, integral and derivative term is computed independently for
    every component of the observation error. The per-component outputs are
    then collapsed into one scalar control signal with a fixed weight vector
    (``desired_mask``), and the sign of that scalar selects the action.
    """

    def __init__(self, KP: float, KI: float, KD: float) -> None:
        """Store the controller gains and initialise the controller state.

        Args:
            KP: Proportional gain, applied to the current error.
            KI: Integral gain, applied to the running sum of errors.
            KD: Derivative gain, applied to the step-to-step error difference.
        """
        self.KP = KP
        self.KI = KI
        self.KD = KD

        # Weights for combining the per-component PID outputs: component 0
        # (reported as "pid_pos") counts 0.1, component 2 (reported as
        # "pid_ang") counts 1, components 1 and 3 are ignored.
        self.desired_mask = np.array([0.1, 0, 1, 0])

        # Create the integral/derivative/prev_error attributes up front so that
        # step() works on a freshly constructed agent even when the caller
        # never calls reset() explicitly.
        self.reset()

    def reset(self) -> None:
        """Clear the accumulated controller state (e.g. between episodes)."""
        # Scalar zeros; they broadcast into arrays of the observation's shape
        # on the first step().
        self.integral = 0
        self.derivative = 0
        self.prev_error = 0

    def step(
        self, env_state: np.ndarray, env_reward: float, cartpos_setpoint: float
    ) -> Tuple[int, Dict[str, object]]:
        """Advance the controller by one observation and pick an action.

        Args:
            env_state: Current observation; must have four components so it
                lines up with the four-element target and weight vectors.
            env_reward: Unused by this controller.
            cartpos_setpoint: Target value for observation component 0; the
                other three components are driven towards zero.

        Returns:
            A tuple ``(action, info)`` where ``action`` is 0 when the combined
            control signal is <= 0 and 1 otherwise, and ``info`` holds the
            combined signal ("pid_total") and the raw PID outputs of
            components 0 ("pid_pos") and 2 ("pid_ang").
        """
        desired_state = np.array([cartpos_setpoint, 0, 0, 0])

        error = env_state - desired_state

        # Plain per-step sum and difference: no time-step scaling is applied.
        self.integral += error
        # NOTE: on the first step after reset() prev_error is 0, so the
        # derivative term equals the full error for that step.
        self.derivative = error - self.prev_error
        self.prev_error = error

        pids = self.KP * error + self.KI * self.integral + self.KD * self.derivative
        pid = np.dot(pids, self.desired_mask)

        action = 0 if pid <= 0 else 1

        return action, {"pid_total": pid, "pid_pos": pids[0], "pid_ang": pids[2]}

In [11]:
# Gains for this run, spelled out by name: proportional, integral, derivative.
agent = PidAgent(KP=0.1, KI=0.01, KD=0.5)

# Run the agent through the cart-pole helper; the resulting log is kept in
# `df`, which the cells below read.
df = cartpole_util.execute_cartpole(agent)
df

Unnamed: 0,ep,t,cart_pos_setpoint,cart_pos,cart_vel,pole_ang,pole_vel,agent_pid_total,agent_pid_pos,agent_pid_ang,env_info
0,0,0,0.0,0.014409,-0.197921,-0.047399,0.256019,-0.027768,0.008832,-0.028651,{}
1,0,1,0.0,0.010450,-0.392335,-0.042278,0.533383,-0.005729,0.001695,-0.005898,{}
2,0,2,0.0,0.002603,-0.586838,-0.031611,0.812450,-0.003088,-0.000541,-0.003034,{}
3,0,3,0.0,-0.009133,-0.391297,-0.015362,0.509994,0.000166,-0.003244,0.000490,{}
4,0,4,0.0,-0.016959,-0.195962,-0.005162,0.212510,0.004107,-0.006454,0.004752,{}
...,...,...,...,...,...,...,...,...,...,...,...
9995,19,495,0.0,0.041388,-0.024680,0.013774,0.044434,0.006850,-0.028873,0.009737,{}
9996,19,496,0.0,0.040895,0.170242,0.014663,-0.243871,0.004906,-0.026947,0.007601,{}
9997,19,497,0.0,0.044299,0.365151,0.009785,-0.531894,0.002468,-0.024637,0.004932,{}
9998,19,498,0.0,0.051602,0.169893,-0.000853,-0.236144,-0.000532,-0.021904,0.001659,{}


In [14]:
def show_state(ep: int, t: int):
    """Render the logged cart-pole state of episode ``ep`` at timestep ``t``.

    The row is looked up in the module-level ``df``. Unless exactly one row
    matches, nothing is drawn. The rendered figure is stored in the global
    ``fig``.
    """
    global fig

    is_selected = (df["ep"] == ep) & (df["t"] == t)
    matches = df.loc[is_selected]

    if len(matches) == 1:
        state = matches.iloc[0]
        # NOTE(review): the last argument is a hard-coded 0.0; presumably it is
        # the cart position setpoint (df has a "cart_pos_setpoint" column) --
        # confirm against _render_cartpole_state.
        fig, _ = cartpole_util._render_cartpole_state(
            state["cart_pos"],
            state["cart_vel"],
            state["pole_ang"],
            state["pole_vel"],
            0.0,
        )


# Slider limits come from the logged data so that the largest selectable
# episode/timestep actually occurs in `df`. Hard-coded maxima of 20 and 500 let
# the slider reach positions past the log (e.g. ep 20 when the last logged
# episode is 19), for which show_state silently draws nothing.
_ = interact(
    show_state,
    ep=widgets.IntSlider(min=0, max=int(df["ep"].max()), step=1, value=0),
    t=widgets.IntSlider(min=0, max=int(df["t"].max()), step=1, value=0),
)

interactive(children=(IntSlider(value=0, description='ep', max=20), IntSlider(value=0, description='t', max=50…