In [1]:
import numpy as np
import numpy.typing as npt
import gym

from typing import List, Tuple, Literal, Any, Optional, cast, Callable, Union
import plotly.graph_objects as go
from utils.agent import Agent
from tqdm.autonotebook import tqdm
from utils.algorithm import AlgorithmInterface
from utils.preprocess import PreprocessInterface
import torch
from collections import deque
from torchvision import transforms
import math
from torch import nn
from gym.wrappers import FrameStack
from gym.spaces import Box
import sys
from torchvision import transforms as T
from copy import deepcopy
from utils.common import Step, Episode, TransitionGeneric
from gym.utils.play import play


pygame 2.1.2 (SDL 2.0.16, Python 3.9.7)
Hello from the pygame community. https://www.pygame.org/contribute.html


  logger.warn("failed to set matplotlib backend, plotting will not work: %s" % str(e))


In [2]:
RANDOM_SEED = 0
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

<torch._C.Generator at 0x7fdfbcad4110>

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

Using cuda device


In [4]:
env = gym.make("BoxingNoFrameskip-v4")
env.seed(RANDOM_SEED)
TOTAL_ACTIONS = env.action_space.n


A.L.E: Arcade Learning Environment (version 0.7.4+069f8bd)
[Powered by Stella]


In [5]:
o = env.reset()
(o.shape, o.dtype)

((210, 160, 3), dtype('uint8'))

In [6]:
env.observation_space.low

array([[[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]],

       [[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]],

       [[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]],

       ...,

       [[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]],

       [[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]],

       [[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]]], dtype=uint8)

In [7]:
env.reward_range

(-inf, inf)

In [8]:
TOTAL_ACTIONS

6

In [9]:
class SkipFrame(gym.Wrapper):
    def __init__(self, env: gym.Env, skip: int):
        assert skip >= 0

        """Return only every `skip`-th frame"""
        super().__init__(env)

        self.env = env
        self._skip = skip

    def step(self, action):
        """Repeat action, and sum reward"""
        total_reward = 0.0
        done = False
        obs = None
        info = None

        for _ in range(self._skip + 1):
            # Accumulate reward and repeat the same action
            obs, reward, done, info = self.env.step(action)
            total_reward += reward
            if done:
                break
        return obs, total_reward, done, info


class GrayScaleObservation(gym.ObservationWrapper):
    def __init__(self, env: gym.Env):
        super().__init__(env)

        self.obs_shape = env.observation_space.shape[:2]
        self.observation_space = Box(
            low=0, high=255, shape=(1,) + self.obs_shape, dtype=np.uint8)

        self.transform = T.Grayscale()

    def permute_orientation(self, observation):
        # permute [H, W, C] array to [C, H, W] tensor
        observation = np.transpose(observation, (2, 0, 1))
        observation = torch.tensor(observation.copy(), dtype=torch.float)
        return observation

    def observation(self, observation):
        observation = self.permute_orientation(observation)
        observation = self.transform(observation)
        assert observation.shape == (1,) + self.obs_shape
        return observation


class ResizeObservation(gym.ObservationWrapper):
    def __init__(self, env: gym.Env, _shape: Union[int, Tuple[int, int]]):
        super().__init__(env)

        self.env = env

        if isinstance(_shape, int):
            shape = (_shape, _shape)
        else:
            shape = _shape

        self.obs_shape = self.observation_space.shape[0:1] + shape

        # obs_low = self.observation_space.low
        # obs_high = self.observation_space.high

        self.observation_space = Box(
            low=0, high=255, shape=self.obs_shape, dtype=np.uint8)

        self.transforms = T.Compose(
            [T.Resize(shape)]
        )

    def observation(self, observation):
        observation = self.transforms(observation)
        assert observation.shape == self.observation_space.shape
        return observation


In [10]:
env = SkipFrame(env, skip=4)
env = GrayScaleObservation(env)
env = ResizeObservation(env, 84)
env = FrameStack(env, num_stack=4)
env


<FrameStack<ResizeObservation<GrayScaleObservation<SkipFrame<TimeLimit<AtariEnv<PongNoFrameskip-v4>>>>>>>

In [11]:
# shape is (210, 160, 3)
Observation = npt.NDArray[np.uint8]
Action = int

# shape is (4, 210, 160, 3)
State = torch.Tensor
Reward = int

Transition = TransitionGeneric[State, Action]

In [12]:
o = env.reset()
(o.shape,o.dtype, o)

((4, 1, 84, 84),
 torch.float32,
 <gym.wrappers.frame_stack.LazyFrames at 0x7f800868f4a0>)

In [13]:
EVALUATION_TIMES = 30
rwds: List[float] = []

obs: List[Any] = []
for _ in range(1):

    end = False

    rwd = 0
    while not end:
        (o, r, end, _) = env.step(env.action_space.sample())
        obs.append(o)
        # i += 1
        # env.render()
        if r != 0:
            print(r)
        rwd += r
        if end:
            print(f"at end: {r}")

        # if end:
        #     rwds.append(np.sum([r if r is not None else 0 for (_,
        #                                                        _, r) in cast(Episode, episode)]))
    rwds.append(rwd)


-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
at end: -1.0


In [22]:
obs[0].frame()

AttributeError: 'LazyFrames' object has no attribute 'frame'

In [15]:
rwds

[-21.0]

In [16]:
fig = go.Figure()
fig.add_trace(
    go.Scatter(x=[i + 1 for i in range(len(rwds))],
               y = [r for r in rwds])
)
# fig.update_yaxes(type="log")
# fig.update_layout(yaxis_type="log")
fig.show()

