In [None]:
# %pip install -U pip wheel
# %pip install -r ../requirements.txt

In [None]:
import numpy as np
import pandas as pd
import importlib

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from typing import Dict, Tuple

In [None]:
import cartpole_util
_ = importlib.reload(cartpole_util)

In [None]:
class PidAgent(cartpole_util.CartPoleAgentABC):
    """PID controller agent producing a binary CartPole action.

    PID terms are computed element-wise on the full state error
    ``env_state - [cartpos_setpoint, 0, 0, 0]`` and then collapsed into one
    scalar control signal with a weighted dot product (``desired_mask``).
    The action is ``0`` when that signal is ``<= 0`` and ``1`` otherwise.

    Args:
        KP: Proportional gain, applied to every state component.
        KI: Integral gain, applied to every state component.
        KD: Derivative gain, applied to every state component.
        desired_mask: Length-4 weights (tuple or array-like) that combine the
            per-component PID terms into the control signal. The default
            ``(0, 0, 1, 0)`` keeps the original behaviour: only state index 2
            (reported as ``pid_ang``) drives the action, so ``cartpos_setpoint``
            has no effect. Give index 0 a non-zero weight to let the cart
            position error (``pid_pos``) contribute.

    Raises:
        ValueError: If ``desired_mask`` does not have shape ``(4,)``.
    """

    def __init__(
        self,
        KP: float,
        KI: float,
        KD: float,
        desired_mask: Tuple[float, ...] = (0.0, 0.0, 1.0, 0.0),
    ) -> None:
        self.KP = KP
        self.KI = KI
        self.KD = KD

        self.desired_mask = np.asarray(desired_mask, dtype=float)
        # The mask is dotted with a 4-element PID vector (see desired_state in step()).
        if self.desired_mask.shape != (4,):
            raise ValueError(
                f"desired_mask must have shape (4,), got {self.desired_mask.shape}"
            )

        # Initialise controller memory so step() is usable before any explicit reset().
        self.reset()

    def reset(self) -> None:
        """Clear the integral, derivative and previous-error memory.

        These start as scalar zeros. They become arrays after the first step()
        via broadcasting against the state error.
        """
        self.integral = 0
        self.derivative = 0
        self.prev_error = 0

    def step(
        self, env_state: np.ndarray, env_reward: float, cartpos_setpoint: float
    ) -> Tuple[int, Dict[str, object]]:
        """Advance the controller by one step and choose an action.

        Args:
            env_state: Current environment state. It is subtracted from a
                4-element target, so presumably shape ``(4,)``; index 0 is
                compared against ``cartpos_setpoint``.
            env_reward: Unused here. Presumably required by the base class's
                ``step`` signature.
            cartpos_setpoint: Target value for state index 0.

        Returns:
            ``(action, info)``, where ``action`` is 0 or 1. ``info`` holds the
            combined signal (``pid_total``) and the per-component PID terms for
            index 0 (``pid_pos``) and index 2 (``pid_ang``).
        """
        desired_state = np.array([cartpos_setpoint, 0, 0, 0])

        error = env_state - desired_state

        self.integral += error
        # NOTE: right after reset() prev_error is 0, so the first derivative equals
        # the full error (a "derivative kick"). This is kept as-is to preserve results.
        self.derivative = error - self.prev_error
        self.prev_error = error

        pids = self.KP * error + self.KI * self.integral + self.KD * self.derivative
        pid = np.dot(pids, self.desired_mask)

        action = 0 if pid <= 0 else 1

        return action, {"pid_total": pid, "pid_pos": pids[0], "pid_ang": pids[2]}

In [None]:
# Build the PID agent with explicit gains. execute_cartpole then (presumably)
# runs the episodes and returns a per-step log with at least the "ep" and "t"
# columns used by the cells below.
agent = PidAgent(KP=0.1, KI=0.01, KD=0.5)
df = cartpole_util.execute_cartpole(agent)
df

In [None]:
def show_state(ep: int, t: int) -> None:
    """Render the recorded state for episode ``ep`` at step ``t``.

    Looks up the row of the module-level ``df`` whose ``ep`` and ``t`` columns
    match. If exactly one row matches, it is passed to
    ``cartpole_util.render_cartpole_state``. Otherwise a short message is
    printed, so a slider position with no recorded state is visible rather
    than silently showing nothing.
    """
    matches = df.loc[(df["ep"] == ep) & (df["t"] == t)]

    if matches.shape[0] != 1:
        print(f"No unique recorded state for ep={ep}, t={t} ({matches.shape[0]} matching rows).")
        return

    # The helper's return value was never used here, so it is not bound.
    cartpole_util.render_cartpole_state(matches.iloc[0])

# Bound the sliders by the data actually recorded in df, not by hard-coded
# guesses, so slider positions correspond to logged episodes and steps.
assert not df.empty, "df has no recorded states to browse"
_ = interact(
    show_state,
    ep=widgets.IntSlider(
        min=int(df["ep"].min()), max=int(df["ep"].max()), step=1, value=int(df["ep"].min())
    ),
    t=widgets.IntSlider(
        min=int(df["t"].min()), max=int(df["t"].max()), step=1, value=int(df["t"].min())
    ),
)

In [None]:
# Episode to plot; presumably should be a value present in df["ep"].
EP = 0

# Plot episode EP from df with the project's plotting helper and display the figure.
fig = cartpole_util.lineplot(df, ep=EP)
fig.show()