In [1]:
# %pip install -U pip wheel
# %pip install -r ../requirements.txt

Note: you may need to restart the kernel to use updated packages.


In [2]:
import numpy as np
import pandas as pd
import importlib

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from typing import Dict, Tuple


In [3]:
import cartpole_util
_ = importlib.reload(cartpole_util)

In [4]:
class PidAgent(cartpole_util.CartPoleAgentABC):
    def __init__(self, KP:float, KI:float, KD:float) -> None:
        self.KP = KP
        self.KI = KI
        self.KD = KD

        self.desired_mask = np.array([0.1, 0, 1, 0])

    def reset(self) -> None:
        self.integral = 0
        self.derivative = 0
        self.prev_error = 0

    def step(self, env_state: np.ndarray, env_reward: float, cartpos_setpoint: float) -> Tuple[int, Dict[str, object]]:
        desired_state = np.array([cartpos_setpoint, 0, 0, 0])

        error = env_state - desired_state

        self.integral += error
        self.derivative = error - self.prev_error
        self.prev_error = error

        pids = self.KP * error + self.KI * self.integral + self.KD * self.derivative 
        pid = np.dot(pids, self.desired_mask)

        action = 0 if pid <= 0 else 1

        return action, {"pid_total": pid, "pid_pos": pids[0], "pid_ang": pids[2]}

In [5]:
agent = PidAgent(0.1, 0.01, 0.5)
df = cartpole_util.execute_cartpole(agent)
df

Unnamed: 0,ep,t,cart_pos_setpoint,cart_pos,cart_vel,pole_ang,pole_vel,agent_pid_total,agent_pid_pos,agent_pid_ang,env_info
0,0,0,0.0,-0.046289,-0.180237,-0.046584,0.236064,-0.030747,-0.028409,-0.027906,{}
1,0,1,0.0,-0.049893,-0.374663,-0.041863,0.513696,-0.006542,-0.005415,-0.006000,{}
2,0,2,0.0,-0.057387,-0.569171,-0.031589,0.792899,-0.003990,-0.008219,-0.003168,{}
3,0,3,0.0,-0.068770,-0.763846,-0.015731,1.075479,-0.000828,-0.011487,0.000320,{}
4,0,4,0.0,-0.084047,-0.568520,0.005778,0.777901,0.003015,-0.015258,0.004541,{}
...,...,...,...,...,...,...,...,...,...,...,...
9995,19,495,0.0,-0.042201,0.003070,0.010334,-0.045856,0.004243,0.009259,0.003317,{}
9996,19,496,0.0,-0.042139,0.198042,0.009417,-0.335261,0.002007,0.010405,0.000967,{}
9997,19,497,0.0,-0.038179,0.002787,0.002712,-0.039623,-0.000747,0.011941,-0.001941,{}
9998,19,498,0.0,-0.038123,-0.192374,0.001919,0.253914,-0.004088,0.013905,-0.005478,{}


In [6]:
@interact(ep=(0,20,1), t=(0,500,1))
def show_state(ep:int, t:int):
    dff = df.loc[(df['ep'] ==ep) & (df['t'] ==t)]

    if(dff.shape[0] != 1):
        return
    
    row = dff.iloc[0]
    global fig
    fig = cartpole_util._render_cartpole_state(row['cart_pos'], row['cart_vel'], row['pole_ang'], row['pole_vel'], 0.0)
    fig.show()

# interact(
#     show_state,
#     ep=widgets.IntSlider(min=0, max=20, step=1, value=0), 
#     t=widgets.IntSlider(min=0, max=500, step=1, value=0)
#     )

interactive(children=(IntSlider(value=10, description='ep', max=20), IntSlider(value=250, description='t', max…