In [39]:
# %pip install -U pip wheel
# %pip install -r ../requirements.txt

In [40]:
import numpy as np
import pandas as pd
import importlib

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from typing import Dict, Tuple

In [41]:
import cpagent
import cprender
import cpenvs

# Re-import the local helper modules so edits to their source files are picked
# up without restarting the kernel. Assigning to `_` keeps the module repr
# returned by reload() out of the cell output.
for _module in (cpagent, cprender, cpenvs):
    _ = importlib.reload(_module)
del _module

In [42]:
from numpy import ndarray


class PidAgent(cpagent.CartPoleAgentABC):
    """Cart-pole agent that picks a discrete action from a PID signal.

    A PID term is computed element-wise on the error between the observed
    environment state and a setpoint. The terms are then reduced to a single
    scalar by a dot product with ``desired_mask``. Action ``0`` is chosen
    when that scalar is ``<= 0``, action ``1`` otherwise.

    Args:
        KP: Proportional gain.
        KI: Integral gain.
        KD: Derivative gain.
        desired_mask: Weights used to combine the per-component PID terms.
            Defaults to ``[0, 0, 1, 0, 0.1]``, so only the third and fifth
            state components contribute. These are presumably pole angle and
            position deviation, going by the column order of the results
            frame; confirm against cpenvs.
        desired_state: Setpoint subtracted from the observed state. Defaults
            to zeros with the same length as ``desired_mask``.
    """

    def __init__(
        self,
        KP: float,
        KI: float,
        KD: float,
        desired_mask=None,
        desired_state=None,
    ) -> None:
        self.KP = KP
        self.KI = KI
        self.KD = KD

        self.desired_mask = np.asarray(
            [0, 0, 1, 0, 0.1] if desired_mask is None else desired_mask,
            dtype=float,
        )
        # Built once here instead of on every step() call.
        self.desired_state = (
            np.zeros_like(self.desired_mask)
            if desired_state is None
            else np.asarray(desired_state, dtype=float)
        )

        # Initialise the controller memory so step() works even if the caller
        # never invokes reset() first (previously an AttributeError).
        self.reset()

    def reset(self) -> None:
        """Clear the integral, derivative and previous-error memory."""
        self.integral = 0
        self.derivative = 0
        self.prev_error = 0

    def after_step(
        self,
        old_env_state: np.ndarray,
        new_env_state: np.ndarray,
        action: int,
        env_reward: float,
    ) -> None:
        """No-op: this controller does not update itself from transitions."""

    def step(self, env_state: np.ndarray) -> int:
        """Return the action for ``env_state`` and update the PID memory.

        Args:
            env_state: Observed state. Must have the same length as
                ``desired_mask``.

        Returns:
            ``0`` when the combined PID signal is ``<= 0``, otherwise ``1``.
        """
        error = env_state - self.desired_state

        self.integral += error
        # Right after reset() prev_error is 0, so the first derivative equals
        # the full error (a "derivative kick").
        self.derivative = error - self.prev_error
        self.prev_error = error

        pids = self.KP * error + self.KI * self.integral + self.KD * self.derivative
        pid = np.dot(pids, self.desired_mask)

        # TODO(review): an (action, info) tuple with diagnostics such as
        # {"pid_total": pid, "pid_pos": pids[0], "pid_ang": pids[2]} was once
        # intended. Only the bare action is returned, which is what
        # cpagent.execute_cartpole consumed in the recorded run. Confirm the
        # contract of CartPoleAgentABC.step.
        return 0 if pid <= 0 else 1

In [43]:
# Run the PID controller with proportional / integral / derivative gains and
# collect the per-step trajectory as a DataFrame.
agent = PidAgent(KP=0.1, KI=0.01, KD=0.5)
df = cpagent.execute_cartpole(agent)
df

Unnamed: 0,ep,t,cart_pos,cart_vel,pole_ang,pole_vel,pos_deviation,reward
0,0,0,-0.044338,0.018650,0.041078,-0.027629,-0.044338,
1,0,1,-0.043965,0.213160,0.040525,-0.307074,-0.043965,1.0
2,0,2,-0.039702,0.407682,0.034383,-0.586706,-0.039702,1.0
3,0,3,-0.031549,0.602306,0.022649,-0.868363,-0.031549,1.0
4,0,4,-0.019502,0.406883,0.005282,-0.568646,-0.019502,1.0
...,...,...,...,...,...,...,...,...
10015,19,496,0.012719,-0.347941,-0.013654,0.503214,0.012719,1.0
10016,19,497,0.005761,-0.152629,-0.003589,0.206259,0.005761,1.0
10017,19,498,0.002708,0.042544,0.000536,-0.087554,0.002708,1.0
10018,19,499,0.003559,0.237659,-0.001215,-0.380067,0.003559,1.0


In [44]:
def show_state(ep: int, t: int) -> None:
    """Render the cart-pole state recorded at episode ``ep``, step ``t``.

    Looks up the row of the global results frame ``df`` matching ``ep`` and
    ``t`` and draws it with ``cprender.render_cartpole_state``. The sliders
    can select (ep, t) pairs that are not present in ``df``. That case is
    reported instead of silently leaving the output area blank.
    """
    matches = df.loc[(df["ep"] == ep) & (df["t"] == t)]

    if len(matches) != 1:
        print(f"No unique state for ep={ep}, t={t} ({len(matches)} matching rows).")
        return

    cprender.render_cartpole_state(matches.iloc[0])

# Slider ranges are derived from the data so they span exactly the recorded
# episodes and time steps. Hard-coded maxima (20 / 500) overshoot the last
# episode and step present in `df`, leaving dead slider positions.
_ = interact(
    show_state,
    ep=widgets.IntSlider(min=0, max=int(df["ep"].max()), step=1, value=0),
    t=widgets.IntSlider(min=0, max=int(df["t"].max()), step=1, value=0),
)

interactive(children=(IntSlider(value=0, description='ep', max=20), IntSlider(value=0, description='t', max=50…

In [45]:
# Episode to plot; must be one of the episodes recorded in `df`.
EP = 0
assert (df["ep"] == EP).any(), f"episode {EP} not found in df"

fig = cprender.lineplot(df, ep=EP)
fig.show()