## Install Libraries and Dependencies

#### Virtual Display

In [2]:
%%capture
!apt install python-opengl # Python binding to OpenGL and related APIs
!pip install pyglet==1.5.1 
!apt install python-opengl
!apt install ffmpeg
!apt install xvfb
!pip3 install pyvirtualdisplay

# Virtual display
from pyvirtualdisplay import Display

virtual_display = Display(visible=0, size=(1400, 900))
virtual_display.start()

#### Gym [Atari]

In [3]:
%%capture
# Install a pinned Gym release with the `atari` and `accept-rom-license` extras.
# NOTE(review): the wrappers below use the 4-tuple `step` return
# (obs, reward, done, info), so they depend on this older Gym API.
!pip install gym[atari,accept-rom-license]==0.21.0

## Imports

In [4]:
import torch
import numpy as np
import cv2

import gym

## Gym Environment

In [5]:
# Gym environment id used throughout the notebook.
# NOTE(review): the "NoFrameskip" variant presumably leaves frame skipping to
# the FrameSkip wrapper defined below — confirm.
ENV_NAME = "BreakoutNoFrameskip-v4"
# Base (unwrapped-by-us) environment instance created from the id above.
env = gym.make(ENV_NAME)

In [6]:
# Shape of a single raw observation, as reported by the environment's observation space.
OBSERVATION_SHAPE = env.observation_space.shape
# Number of available actions (the `n` attribute of the action space).
ACTION_SPACE_SIZE = env.action_space.n

#### Environment Wrappers

Wrappers doing the following will be implemented:
- Downscaling the observation
- Grayscaling the observation
- Stacking the last `n` frames into a single observation
- Perform `no-op` operations at the start of an episode
- Rescaling the reward to discrete values -1, 0, 1
- Perform Frame Skipping
- Convert the observations to PyTorch Image shapes and `torch.Tensor` object
- Make atari end-of-life the end-of-episode, and reset only on game-over.


In [7]:
cv2.ocl.setUseOpenCL(False) # tell OpenCV not to use its OpenCL code path for the image ops below

# Ref: Extended Data Table 1 | List of hyperparameters and their values
class NoOpOnReset(gym.Wrapper):
  """Begin every episode with a random number of no-op steps."""

  def __init__(self, env, noop_max=30, noop_action=0):
    """Applies `noop_action` for a random number of steps in the range 
    `[1, noop_max]` at the start of an episode

    Args:
      env: environment to wrap; its unwrapped form must expose
        `get_action_meanings()`.
      noop_max: largest number of no-op steps taken on reset (must be >= 1).
      noop_action: index of the action whose meaning is "NOOP".
    """
    super().__init__(env)

    self.noop_max = noop_max
    self.noop_action = noop_action

    assert env.unwrapped.get_action_meanings()[noop_action] == "NOOP", "Action meaning for noop_action doesn't match 'NOOP'"
    assert noop_max >= 1, "noop_max must be >= 1"

  def reset(self, **kwargs):
    """Do no-op action for a random number of steps in [1, noop_max]

    Returns:
      The observation after the last no-op step (or after a fresh reset if a
      no-op step happened to end the episode).
    """
    rng = self.unwrapped.np_random
    # `integers` exists on numpy's Generator, `randint` on the legacy
    # RandomState; both treat the upper bound as exclusive. Supporting both
    # keeps the wrapper working whichever one the installed Gym hands out.
    # NOTE(review): the pinned gym 0.21.0 appears to return a RandomState,
    # on which `.integers` would raise AttributeError — confirm.
    draw = getattr(rng, "integers", None) or rng.randint
    num_noop_actions = int(draw(1, self.noop_max + 1))

    obs = self.env.reset(**kwargs)

    for _ in range(num_noop_actions):
      obs, _, done, _ = self.env.step(self.noop_action)

      if done:
        # The episode ended during the no-ops: start over and keep going.
        obs = self.env.reset(**kwargs)

    return obs

  def step(self, action):
    """Perform a single `action` in the environment"""
    # Plain pass-through to the wrapped environment, kept for explicitness.
    return self.env.step(action)

In [8]:
class FireOnReset(gym.Wrapper):
  """Press FIRE right after each reset so the game actually starts."""

  def __init__(self, env):
    """Take action on reset for environments that are fixed until firing.

    Args:
      env: environment to wrap; action 1 must mean "FIRE" and there must be
        at least three actions in total.
    """
    super().__init__(env)

    assert env.unwrapped.get_action_meanings()[1] == "FIRE", "Action 1 should be 'FIRE'"

    # there should be at least one more action apart from NOOP and FIRE
    assert len(env.unwrapped.get_action_meanings()) >= 3

  def reset(self, **kwargs):
    """Reset the wrapped environment, then take action 1 followed by action 2.

    If either of those steps ends the episode, the environment is reset again
    and the observation from that reset is the one carried forward, so the
    returned observation never comes from a finished episode.
    """
    self.env.reset(**kwargs)

    obs, _, done, _ = self.env.step(1)
    if done:
      obs = self.env.reset(**kwargs)

    # why take action 2? (this follows from openai baselines implementation)
    obs, _, done, _ = self.env.step(2)
    if done:
      # Previously the result of this reset was dropped and the terminal
      # observation from the step above was returned instead.
      obs = self.env.reset(**kwargs)

    return obs

  def step(self, act):
    """Forward `act` to the wrapped environment unchanged."""
    return self.env.step(act)

In [9]:
class EpisodicLifeEnv(gym.Wrapper):
  """Report the loss of a life as the end of an episode."""

  def __init__(self, env):
    """Make end-of-life == end-of-episode, but only resets on true game over

    Args:
      env: environment to wrap; its unwrapped form must expose `ale.lives()`
        and `get_action_meanings()`, with action 0 meaning "NOOP".
    """
    super().__init__(env)

    # number of lives the agent has left
    self.lives = env.unwrapped.ale.lives()
    # True when the wrapped env needs a real reset. It starts as True so the
    # very first `reset()` call resets the wrapped env instead of stepping an
    # environment that has never been reset.
    self.real_done = True

    assert env.unwrapped.get_action_meanings()[0] == "NOOP", "Meaning of action 0 is not NOOP"

  def step(self, act):
    """Step the wrapped env and flag `done` whenever a life was lost.

    Returns:
      `(obs, rew, done, info)` where `done` is True on a real game over or
      when the life counter dropped but is still above zero.
    """
    obs, rew, done, info = self.env.step(act)

    # if env returns done=True, then it's game over
    self.real_done = done

    # get how many lives the agent has left
    lives = self.env.unwrapped.ale.lives()

    # if there is a reduction in `lives` compared with previous value stored
    # in `self.lives`, that's an end-of-life transition
    if lives < self.lives and lives > 0:
      # for Qbert sometimes we stay in lives == 0 condition for a few frames
      # so it's important to keep lives > 0
      done = True

    # update the number of lives the agent has left
    self.lives = lives
    return obs, rew, done, info

  def reset(self, **kwargs):
    """Reset only when lives are exhausted.
    This way all states are still reachable even though lives are episodic,
    and the learner need not know about any of this behind-the-scenes.
    """
    if self.real_done:
      obs = self.env.reset(**kwargs)
      # The wrapped env is fresh again; the next real reset happens only
      # after `step` observes another true game over.
      self.real_done = False
    else:
      # take no-op action to advance from lost life state
      obs, _, _, _ = self.env.step(0)

    # update the number of lives the agent has left
    # in-case the no-op action led to end-of-life
    self.lives = self.env.unwrapped.ale.lives()
    return obs
  

In [10]:
class FrameSkip(gym.Wrapper): # MaxAndSkipEnv
  """Repeat each action for several frames and return a max-pooled frame."""

  def __init__(self, env, skip=4):
    """Return only every `skip`-th frame

    Args:
      env: environment to wrap.
      skip: number of consecutive frames each action is repeated for (>= 1).
    """
    # Validate before touching any state so a bad value fails immediately.
    assert skip >= 1, "Number of frames to skip should be at least 1"

    super().__init__(env)
    self.skip = skip
    # Holds the two most recent raw frames of a step; their element-wise max
    # is what `step` returns.
    self.obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8)

  def step(self, act):
    """Repeat action for `skip` frames, sum reward, and max over last observations.

    Returns:
      `(max_frame, total_reward, done, info)` where `done` and `info` come
      from the last inner step that was executed.
    """
    total_reward = 0.0
    for i in range(self.skip):
      obs, rew, done, info = self.env.step(act)
      # Only the last two frames of the repeat window are kept.
      if i == self.skip - 2: self.obs_buffer[0] = obs
      if i == self.skip - 1: self.obs_buffer[1] = obs

      total_reward += rew

      if done:
        # note that the observation after done does not matter
        # so it doesn't matter if we have refilled the obs_buffer or not
        break

    max_frame = self.obs_buffer.max(axis=0)
    return max_frame, total_reward, done, info

  def reset(self, **kwargs):
    """Reset the wrapped environment, forwarding any keyword arguments."""
    # Single definition: a second `reset(self)` used to follow this one and
    # silently replaced it, discarding `**kwargs`.
    return self.env.reset(**kwargs)


In [11]:
class ClipReward(gym.RewardWrapper):
  """Reward wrapper that keeps only the sign of every reward."""

  def __init__(self, env):
    super().__init__(env)

  def reward(self, rew):
    """Collapse `rew` to its sign.

    Negative rewards become -1, zero stays 0 and positive rewards become +1.
    """
    clipped = np.sign(rew)
    return clipped

In [12]:
from gym import spaces

class WarpFrame(gym.ObservationWrapper):
  """Observation wrapper producing 84x84 single-channel grayscale frames."""

  def __init__(self, env):
    """Set the target size and advertise the matching observation space.

    Incoming frames are expected as height x width x num_channels.
    """
    super().__init__(env)
    self.width = self.height = 84
    warped_shape = (self.height, self.width, 1)
    self.observation_space = spaces.Box(
        low=0, high=255, shape=warped_shape, dtype=np.uint8)

  def observation(self, frame):
    """Grayscale `frame`, shrink it to 84x84 and add a trailing channel axis.

    The returned array has shape 84x84x1.
    """
    gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
    target_size = (self.width, self.height)
    shrunk = cv2.resize(gray, target_size, interpolation=cv2.INTER_AREA)
    return shrunk[:, :, np.newaxis]

In [24]:
from collections import deque

class StackFrame(gym.Wrapper):
  """Turn the last `k` frames into a single stacked observation."""

  def __init__(self, env, k):
    """Stack k last frames.
    Returns lazy array, which is much more memory efficient.
    Expects inputs to be of shape num_channels x height x width.

    Args:
      env: environment to wrap.
      k: number of most recent frames kept in the stack.
    """
    # This method was previously spelled `_init__`, so it never ran as the
    # constructor and none of the attributes below were ever set.
    super().__init__(env)

    self.k = k
    # Bounded deque: appending a new frame drops the oldest one automatically.
    self.frames = deque([], maxlen=k)
    frame_shape = env.observation_space.shape
    # Frames are joined along the first axis, so only that dimension grows.
    self.observation_space = spaces.Box(low=0, high=255, 
                                        shape=(frame_shape[0]*k, frame_shape[1], frame_shape[2]),
                                        dtype=np.uint8)

  def step(self, act):
    """Step the wrapped env and return the updated frame stack."""
    obs, rew, done, info = self.env.step(act)
    self.frames.append(obs)
    return self.get_obs(), rew, done, info

  def reset(self, **kwargs):
    """Reset the wrapped env and fill the whole stack with the first frame."""
    obs = self.env.reset(**kwargs)

    for _ in range(self.k):
      self.frames.append(obs)

    return self.get_obs()

  def get_obs(self):
    """Return the current stack wrapped in a `LazyFrames` object."""
    assert len(self.frames) == self.k, "Length of frame stack array less than k"
    return LazyFrames(list(self.frames))
    

In [23]:
class ScaleFrameToFloat(gym.ObservationWrapper):
  """Convert observations to float32 and divide them by 255."""

  def __init__(self, env):
    """Wrap `env` and advertise a float32 observation space bounded by [0, 1].

    Args:
      env: environment to wrap; the shape of its observation space is reused.
    """
    # `super()` must be called to get the proxy object; the previous
    # `super.__init__(env)` raised a TypeError on construction.
    super().__init__(env)

    self.observation_space = spaces.Box(low=0, high=1, 
                                        shape=env.observation_space.shape, 
                                        dtype=np.float32)

  def observation(self, obs):
    """Return `obs` as a float32 array scaled by 1/255."""
    # `np.array(obs)` also materialises array-likes before the cast.
    return np.array(obs).astype(np.float32)/255.0

In [None]:
class LazyFrames(object):
  """Holds a list of frames and joins them only when an array is requested."""

  def __init__(self, frames):
    # The list is stored as given; nothing is copied or concatenated here.
    self.frames = frames

  def __array__(self, dtype=None):
    """Concatenate the stored frames along axis 0, optionally casting to `dtype`."""
    joined = np.concatenate(self.frames, axis=0)
    return joined if dtype is None else joined.astype(dtype)

  def __len__(self):
    """Number of stored frames."""
    return len(self.frames)

  def __getitem__(self, i):
    """Return the `i`-th stored frame (an element of the list, not of the joined array)."""
    return self.frames[i]
    

In [None]:
class PyTorchImageFrame(gym.ObservationWrapper):
  """Move the channel axis of each observation to the front."""

  def __init__(self, env):
    """Image shape to num_channels x height x width

    Args:
      env: environment to wrap; its observations are expected to have the
        channel axis last.
    """
    # Three fixes relative to the earlier version: `super()` is now called,
    # the parameter is the wrapped `env` (it used to be named `frames` while
    # the body read a global `env`), and the space is stored on
    # `observation_space` rather than an unused `observation_shape` attribute.
    super().__init__(env)
    frame_shape = env.observation_space.shape
    # NOTE(review): the bounds and dtype assume observations were already
    # scaled to [0, 1] float32 upstream — confirm against the wrapper order.
    self.observation_space = spaces.Box(low=0, high=1, 
                                        shape=((frame_shape[-1],) + frame_shape[:-1]),
                                        dtype=np.float32)

  def observation(self, obs):
    """Return `obs` with axis 2 moved to position 0."""
    return np.moveaxis(obs,2,0)
