In [1]:
%load_ext lab_black

In [2]:
import random
from typing import List

# For smooth animations

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.animation as animation

# Make animation objects render in the notebook as interactive JavaScript/HTML
# players (matplotlib's rcParams["animation.html"] = "jshtml").
mpl.rc("animation", html="jshtml")

## Set up the environment

In [3]:
class Environment:
    """Toy environment that lasts a fixed number of steps and pays random rewards.

    The observation is constant and the chosen action does not influence the
    reward; the class only illustrates the agent/environment interaction loop.
    """

    def __init__(self, steps: int = 10):
        """Start an episode that allows ``steps`` actions (10 by default)."""
        self.step_left = steps

    def get_observation(self) -> List[float]:
        """Return the current observation (always ``[0.0, 0.0, 0.0]`` here)."""
        return [0.0, 0.0, 0.0]

    def get_actions(self) -> List[int]:
        """Return the set of actions the agent may execute."""
        return [0, 1]

    def is_done(self) -> bool:
        """Return True once no steps are left in the episode."""
        # `<=` rather than `==` so a non-positive `steps` also counts as finished.
        return self.step_left <= 0

    def action(self, action: int) -> float:
        """Consume one step and return a random reward in ``[0.0, 1.0)``.

        ``action`` is part of the interface but does not affect the reward.

        Raises:
            RuntimeError: if the episode is already over.
        """
        if self.is_done():
            # RuntimeError subclasses Exception, so existing `except Exception`
            # handlers still catch it.
            raise RuntimeError("Game is over")
        self.step_left -= 1
        return random.random()

In [4]:
class Agent:
    """Agent that picks uniformly random actions and accumulates the rewards."""

    def __init__(self):
        self.total_reward = 0.0

    def step(self, env: "Environment") -> float:
        """Take one random action in ``env``, add its reward to the total, and return it.

        The observation is still requested, to follow the agent/environment
        protocol, but this random agent does not use it.
        """
        env.get_observation()
        reward = env.action(random.choice(env.get_actions()))
        self.total_reward += reward
        return reward

In [5]:
if __name__ == "__main__":
    # Let the random agent act until the toy environment has no steps left,
    # then report the accumulated reward.
    env = Environment()
    agent = Agent()
    while True:
        if env.is_done():
            break
        agent.step(env)
    print(f"Total Reward got: {agent.total_reward:.4f}")

Total Reward got: 4.6821


## The CartPole session

In [6]:
import gym

In [7]:
# Create a CartPole-v1 environment to explore interactively in the next cells
e = gym.make("CartPole-v1")

In [8]:
# Start a new episode; in this gym version reset() returns just the initial
# observation array (see the next cell's output)
obs = e.reset()

In [9]:
# Inspect the initial observation: a float32 array of 4 values
# (presumably cart position, cart velocity, pole angle, pole angular velocity
# per the Gym CartPole docs — confirm)
obs

array([-0.01965713,  0.01632802,  0.03129084,  0.0032428 ], dtype=float32)

In [10]:
# Discrete(2): action 0 pushes the cart to the left, action 1 pushes it to the right
e.action_space

Discrete(2)

In [11]:
# A Box of 4 float32 values; the output lists the low and high bound of each dimension
e.observation_space

Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)

In [12]:
# Inspect the tuple returned by e.step(action), which holds:
#   - the new observation
#   - the reward
#   - the done flag
#   - a dict of extra information
e.step(0)

(array([-0.01933057, -0.17922838,  0.0313557 ,  0.30563182], dtype=float32),
 1.0,
 False,
 {})

In [13]:
# Draw a random action from the Discrete(2) action space (0 or 1)
e.action_space.sample()

1

In [14]:
# Sample again: each call returns an independent random action
e.action_space.sample()

1

In [15]:
# Draw a random observation; note how large the values in the very wide
# (±3.4e38) dimensions can get
e.observation_space.sample()

array([-2.8465040e+00,  2.5458939e+38, -8.7421993e-03, -9.7522957e+37],
      dtype=float32)

## Random CartPole agent

In [17]:
if __name__ == "__main__":
    # Play one CartPole episode choosing every action uniformly at random,
    # counting steps and summing rewards until the episode ends.
    env = gym.make("CartPole-v1")
    total_reward = 0.0
    total_steps = 0
    obs = env.reset()

    done = False
    while not done:
        action = env.action_space.sample()
        obs, reward, done, _ = env.step(action)
        total_reward += reward
        total_steps += 1

    print(f"Episode done in {total_steps:d} steps, total reward {total_reward:.2f}")

Episode done in 17 steps, total reward 17.00


## Gym Wrappers

In [18]:
from typing import TypeVar

In [19]:
# Generic type variable used to annotate an action that a wrapper receives and returns
Action = TypeVar("Action")

In [20]:
class RandomActionWrapper(gym.ActionWrapper):
    """Action wrapper that, with probability ``epsilon``, replaces the agent's action with a random one.

    Args:
        env: the environment to wrap.
        epsilon: probability of substituting a random action on each step.
        verbose: when True (the default, matching the original behavior),
            print "Random!!" every time a random action is substituted.
    """

    def __init__(self, env, epsilon=0.1, verbose=True):
        super().__init__(env)
        self.epsilon = epsilon
        self.verbose = verbose

    def action(self, action: Action) -> Action:
        """Return ``action`` unchanged, or a random action-space sample with probability ``epsilon``."""
        if random.random() < self.epsilon:
            if self.verbose:
                print("Random!!")
            return self.env.action_space.sample()
        return action

In [23]:
# Best-effort start of a virtual X display, presumably so that env.render()
# below works on a headless machine. If pyvirtualdisplay is missing we warn
# instead of silently continuing. The handle is named `virtual_display` (not
# `display`) so it does not shadow IPython's built-in display() function.
import warnings

try:
    import pyvirtualdisplay

    virtual_display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()
except ImportError:
    virtual_display = None
    warnings.warn(
        "pyvirtualdisplay is not installed; env.render() may fail without a display."
    )

In [21]:
def update_scene(num, frames, patch):
    """Animation callback: show image number ``num`` of ``frames`` in ``patch``.

    Returns a one-element tuple containing the updated artist.
    """
    current_image = frames[num]
    patch.set_data(current_image)
    updated_artists = (patch,)
    return updated_artists


def plot_animation(frames, repeat=False, interval=40):
    """Build a matplotlib animation that plays back a sequence of images.

    Args:
        frames: non-empty sequence of images (e.g. RGB arrays from env.render).
        repeat: whether the animation restarts after the last frame.
        interval: delay between frames, in milliseconds.

    Returns:
        The ``FuncAnimation`` object; leave it as a cell's last expression to
        display it.

    Raises:
        ValueError: if ``frames`` is empty.
    """
    # len() rather than truthiness so numpy arrays of frames also work.
    if len(frames) == 0:
        raise ValueError("plot_animation() needs at least one frame")
    fig, ax = plt.subplots()
    patch = ax.imshow(frames[0])
    ax.set_axis_off()
    anim = animation.FuncAnimation(
        fig,
        update_scene,
        fargs=(frames, patch),
        frames=len(frames),
        repeat=repeat,
        interval=interval,
    )
    # Close this specific figure so the inline backend does not also show a
    # static first frame; the animation keeps its own reference to `fig`.
    plt.close(fig)
    return anim

In [24]:
if __name__ == "__main__":
    # Run one episode through RandomActionWrapper while always requesting
    # action 0; the wrapper substitutes a random action with probability
    # epsilon. Every rendered RGB frame is kept for the animation cell.
    env = RandomActionWrapper(gym.make("CartPole-v1"))
    obs = env.reset()
    total_reward = 0.0
    frames = []
    try:
        while True:
            obs, reward, done, _ = env.step(0)
            img = env.render(mode="rgb_array")
            frames.append(img)
            total_reward += reward
            if done:
                break
    finally:
        # Release the environment (and its renderer) even if stepping fails.
        env.close()
    print("Reward got: %.2f" % total_reward)

Reward got: 9.00


ALSA lib confmisc.c:767:(parse_card) cannot find card '0'
ALSA lib conf.c:4732:(_snd_config_evaluate) function snd_func_card_driver returned error: No such file or directory
ALSA lib confmisc.c:392:(snd_func_concat) error evaluating strings
ALSA lib conf.c:4732:(_snd_config_evaluate) function snd_func_concat returned error: No such file or directory
ALSA lib confmisc.c:1246:(snd_func_refer) error evaluating name
ALSA lib conf.c:4732:(_snd_config_evaluate) function snd_func_refer returned error: No such file or directory
ALSA lib conf.c:5220:(snd_config_expand) Evaluate error: No such file or directory
ALSA lib pcm.c:2642:(snd_pcm_open_noupdate) Unknown PCM default


In [25]:
# Play back the frames recorded during the wrapped-environment episode
plot_animation(frames)