In [None]:
# %pip install -U pip wheel
# %pip install -r ../requirements.txt

In [5]:
import numpy as np
import pandas as pd
import importlib
from typing import Dict, Tuple


In [2]:
import cartpole_util
_ = importlib.reload(cartpole_util)

In [8]:
class PidAgent(cartpole_util.CartPoleAgentABC):
    def __init__(self, KP:float, KI:float, KD:float) -> None:
        self.KP = KP
        self.KI = KI
        self.KD = KD

        self.desired_state = np.array([0, 0, 0, 0])
        self.desired_mask = np.array([0.1, 0, 1, 0])

    def reset(self) -> None:
        self.integral = 0
        self.derivative = 0
        self.prev_error = 0

    def step(self, env_state: np.ndarray, env_reward: float) -> Tuple[int, Dict[str, object]]:
        error = env_state - self.desired_state

        self.integral += error
        self.derivative = error - self.prev_error
        self.prev_error = error

        pid = np.dot(self.KP * error + self.KI * self.integral + self.KD * self.derivative, self.desired_mask)

        action = 0 if pid <= 0 else 1

        return action, {"pid": pid}

In [9]:
agent = PidAgent(0.1, 0.01, 0.5)
df = cartpole_util.execute_cartpole(agent)
df

Unnamed: 0,ep,t,cart_pos,cart_vel,pole_ang,pole_vel,agent_pid,env_info
0,0,0,-0.010715,0.243391,0.006926,-0.332203,0.004024,{}
1,0,1,-0.005847,0.438414,0.000282,-0.622694,0.000339,{}
2,0,2,0.002921,0.243288,-0.012172,-0.329922,-0.002987,{}
3,0,3,0.007787,0.048342,-0.018771,-0.041103,-0.006974,{}
4,0,4,0.008754,-0.146506,-0.019593,0.245600,-0.005032,{}
...,...,...,...,...,...,...,...,...
9995,19,495,-0.049858,-0.348243,-0.002122,0.495222,-0.002384,{}
9996,19,496,-0.056823,-0.153091,0.007782,0.201871,0.000641,{}
9997,19,497,-0.059885,0.041919,0.011820,-0.088346,0.004295,{}
9998,19,498,-0.059047,0.236869,0.010053,-0.377277,0.001988,{}
