In [11]:
# %pip install -U pip wheel
# %pip install -r ../requirements.txt

Collecting ipywidgets (from -r ../requirements.txt (line 10))
  Downloading ipywidgets-8.1.1-py3-none-any.whl.metadata (2.4 kB)
Collecting widgetsnbextension~=4.0.9 (from ipywidgets->-r ../requirements.txt (line 10))
  Downloading widgetsnbextension-4.0.9-py3-none-any.whl.metadata (1.6 kB)
Collecting jupyterlab-widgets~=3.0.9 (from ipywidgets->-r ../requirements.txt (line 10))
  Downloading jupyterlab_widgets-3.0.9-py3-none-any.whl.metadata (4.1 kB)
Downloading ipywidgets-8.1.1-py3-none-any.whl (139 kB)
   ---------------------------------------- 0.0/139.4 kB ? eta -:--:--
   ---------------------------------------- 139.4/139.4 kB 4.2 MB/s eta 0:00:00
Downloading jupyterlab_widgets-3.0.9-py3-none-any.whl (214 kB)
   ---------------------------------------- 0.0/214.9 kB ? eta -:--:--
   -------------------------------------- - 204.8/214.9 kB 6.3 MB/s eta 0:00:01
   ---------------------------------------- 214.9/214.9 kB 4.4 MB/s eta 0:00:00
Downloading widgetsnbextension-4.0.9-py3-none-

In [12]:
import numpy as np
import pandas as pd
import importlib

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from typing import Dict, Tuple


In [13]:
# Import the project-local helper module, then explicitly reload it so that
# re-running this cell re-executes the module's code (presumably to pick up
# edits to cartpole_util.py without restarting the kernel).
import cartpole_util
# Bind the returned module to `_` so the cell does not echo the module repr.
_ = importlib.reload(cartpole_util)

In [8]:
class PidAgent(cartpole_util.CartPoleAgentABC):
    """PID controller that turns a 4-element environment state into a binary action.

    A PID term is computed independently for every component of the state
    error. The per-component terms are then collapsed into a single scalar by
    a weighted sum (``desired_mask``), and the sign of that scalar selects the
    action.

    Component 0 is treated as the cart position (it is the only component
    with a non-zero set-point and is reported as ``pid_pos``); component 2 is
    reported as ``pid_ang``. Components 1 and 3 are presumably the matching
    velocities -- they carry zero weight in ``desired_mask``.
    """

    def __init__(self, KP: float, KI: float, KD: float) -> None:
        """Store the controller gains and initialise the controller memory.

        Args:
            KP: Proportional gain.
            KI: Integral gain.
            KD: Derivative gain.
        """
        self.KP = KP
        self.KI = KI
        self.KD = KD

        # Weights used to mix the per-component PID terms into one control
        # signal: 0.1 on component 0, 1 on component 2, nothing on the rest.
        self.desired_mask = np.array([0.1, 0, 1, 0])

        # Create integral/derivative/prev_error right away so that step() can
        # be called on a fresh instance even if the caller never calls
        # reset() first (previously that raised AttributeError).
        self.reset()

    def reset(self) -> None:
        """Clear the accumulated integral, last derivative and previous error."""
        # Scalar zeros are enough: they broadcast against the 4-element error
        # vector on the first call to step().
        self.integral = 0
        self.derivative = 0
        self.prev_error = 0

    def step(self, env_state: np.ndarray, env_reward: float, cartpos_setpoint: float) -> Tuple[int, Dict[str, object]]:
        """Advance the controller by one step and choose an action.

        Args:
            env_state: Current environment state; must have 4 elements to
                match ``desired_mask``.
            env_reward: Reward from the environment. Unused by this agent.
            cartpos_setpoint: Target value for state component 0; every other
                component is driven towards 0.

        Returns:
            A tuple ``(action, info)``. ``action`` is 0 when the combined PID
            signal is <= 0 and 1 otherwise. ``info`` holds the combined signal
            (``pid_total``) and the unweighted PID terms of components 0 and 2
            (``pid_pos`` and ``pid_ang``).
        """
        desired_state = np.array([cartpos_setpoint, 0, 0, 0])

        error = env_state - desired_state

        # Discrete PID update with an implicit time step of 1: the integral is
        # a running sum and the derivative is a first difference.
        self.integral += error
        self.derivative = error - self.prev_error
        self.prev_error = error

        # Per-component PID terms, then a weighted sum into a single scalar.
        pids = self.KP * error + self.KI * self.integral + self.KD * self.derivative
        pid = np.dot(pids, self.desired_mask)

        action = 0 if pid <= 0 else 1

        return action, {"pid_total": pid, "pid_pos": pids[0], "pid_ang": pids[2]}

In [9]:
# Build a PID agent (gains passed by keyword for readability), run it through
# the cartpole helper, and show the resulting DataFrame as the cell output.
agent = PidAgent(KP=0.1, KI=0.01, KD=0.5)
df = cartpole_util.execute_cartpole(agent)
df

Unnamed: 0,ep,t,cart_pos_setpoint,cart_pos,cart_vel,pole_ang,pole_vel,agent_pid_total,agent_pid_pos,agent_pid_ang,env_info
0,0,0,0.0,0.004805,-0.222855,-0.003217,0.328225,-0.002084,0.003270,-0.002411,{}
1,0,1,0.0,0.000348,-0.027687,0.003347,0.034529,0.000005,0.000304,-0.000026,{}
2,0,2,0.0,-0.000206,0.167386,0.004038,-0.257096,0.003370,-0.002089,0.003579,{}
3,0,3,0.0,0.003142,0.362450,-0.001104,-0.548502,0.000732,-0.000194,0.000751,{}
4,0,4,0.0,0.010391,0.167344,-0.012074,-0.256168,-0.002478,0.002123,-0.002690,{}
...,...,...,...,...,...,...,...,...,...,...,...
9995,19,495,0.0,-0.012508,-0.003636,-0.002007,-0.058894,-0.000287,0.023101,-0.002597,{}
9996,19,496,0.0,-0.012580,-0.198729,-0.003185,0.233155,-0.003699,0.025310,-0.006230,{}
9997,19,497,0.0,-0.016555,-0.393806,0.001478,0.524832,-0.001115,0.023225,-0.003437,{}
9998,19,498,0.0,-0.024431,-0.198704,0.011975,0.232615,0.002035,0.020711,-0.000036,{}


In [15]:
def show_state(ep:int, t:int):
    """Render the recorded cartpole state for one episode/time-step pair.

    Looks up the row of the module-level ``df`` whose ``ep`` and ``t`` columns
    match the arguments and passes its four state columns to
    ``cartpole_util.render_cartpole_state``.

    Args:
        ep: Episode number to look up in ``df['ep']``.
        t: Time step to look up in ``df['t']``.

    Returns:
        Whatever ``cartpole_util.render_cartpole_state`` returns, or ``None``
        when the (ep, t) pair does not match exactly one row.
    """
    dff = df.loc[(df['ep'] == ep) & (df['t'] == t)]

    # Nothing to draw unless the lookup is unambiguous (exactly one row).
    if dff.shape[0] != 1:
        return None

    row = dff.iloc[0]
    # np.array() takes ONE array-like as its first argument; passing the four
    # values positionally made NumPy read the second one as `dtype` and raise
    # TypeError. Wrap them in a list to build the 4-element state vector.
    state = np.array(
        [row['cart_pos'], row['cart_vel'], row['pole_ang'], row['pole_vel']],
        dtype=float,
    )
    # NOTE(review): the second argument is hard-coded to 0.0; it is presumably
    # the cart-position set-point -- confirm against render_cartpole_state.
    return cartpole_util.render_cartpole_state(state, 0.0)

# IntSlider bounds are inclusive, so take the upper limits from the recorded
# data instead of hard-coding 20 / 500, which were one past the last episode
# and time step and made the final slider positions render nothing.
max_ep = int(df['ep'].max()) if len(df) else 0
max_t = int(df['t'].max()) if len(df) else 0

# Trailing ';' keeps the notebook from echoing the function repr that
# interact() returns.
interact(
    show_state,
    ep=widgets.IntSlider(min=0, max=max_ep, step=1, value=0),
    t=widgets.IntSlider(min=0, max=max_t, step=1, value=0),
);