In [1]:
!pip install gym[atari,accept-rom-license]
!pip install moviepy
!pip install pysdl2
!pip install pyvirtualdisplay

Collecting gym[accept-rom-license,atari]
  Using cached gym-0.26.2-py3-none-any.whl
Collecting gym-notices>=0.0.4
  Using cached gym_notices-0.0.8-py3-none-any.whl (3.0 kB)
Collecting ale-py~=0.8.0
  Using cached ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (1.6 MB)
Collecting autorom[accept-rom-license]~=0.4.2
  Using cached AutoROM-0.4.2-py3-none-any.whl (16 kB)
Collecting AutoROM.accept-rom-license
  Using cached AutoROM.accept_rom_license-0.4.2-py3-none-any.whl
Installing collected packages: gym-notices, gym, ale-py, AutoROM.accept-rom-license, autorom
Successfully installed AutoROM.accept-rom-license-0.4.2 ale-py-0.8.0 autorom-0.4.2 gym-0.26.2 gym-notices-0.0.8
Collecting moviepy
  Using cached moviepy-1.0.3-py3-none-any.whl
Collecting proglog<=1.0.0
  Using cached proglog-0.1.10-py3-none-any.whl (6.1 kB)
Collecting decorator<5.0,>=4.0.2
  Using cached decorator-4.4.2-py2.py3-none-any.whl (9.2 kB)
Collecting imageio-ffmpeg>=0.2.0
  Using cached imageio_

In [2]:
!pip install opencv-python

Collecting opencv-python
  Using cached opencv_python-4.6.0.66-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (60.9 MB)
Installing collected packages: opencv-python
Successfully installed opencv-python-4.6.0.66


In [3]:
import gym
import random 
import torch
import numpy as np
from torch import nn
from torch.distributions import Categorical

from gym import wrappers
import matplotlib.pyplot as plt
from IPython import display
from tqdm.notebook import tqdm

env = gym.make("PongNoFrameskip-v4", render_mode="rgb_array")
# record the game as as an mp4 file
env = wrappers.RecordVideo(env, 'video/pong-base')

A.L.E: Arcade Learning Environment (version 0.8.0+919230b)
[Powered by Stella]
  logger.warn(


In [4]:
env.unwrapped.get_action_meanings()

['NOOP', 'FIRE', 'RIGHT', 'LEFT', 'RIGHTFIRE', 'LEFTFIRE']

In [5]:
from collections import deque
import cv2
cv2.ocl.setUseOpenCL(False)


class LazyFrames(object):
    def __init__(self, frames):
        """This object ensures that common frames between the observations are only stored once.
        It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay
        buffers.
        This object should only be converted to numpy array before being passed to the model.
        You'd not believe how complex the previous solution was."""
        self._frames = frames
        self._out = None

    def _force(self):
        if self._out is None:
            self._out = np.concatenate(self._frames, axis=2)
            self._frames = None
        return self._out

    def __array__(self, dtype=None):
        out = self._force()
        if dtype is not None:
            out = out.astype(dtype)
        return out

    def __len__(self):
        return len(self._force())

    def __getitem__(self, i):
        return self._force()[i]

class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        """Stack k last frames.
        Returns lazy array, which is much more memory efficient.
        See Also
        --------
        baselines.common.atari_wrappers.LazyFrames
        """
        gym.Wrapper.__init__(self, env)
        self.k = k
        self.frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=env.observation_space.dtype)

    def reset(self):
        ob, info = self.env.reset()
        for _ in range(self.k):
            self.frames.append(ob)
        return self._get_ob(), info

    def step(self, action):
        ob, reward, done, truncated, info = self.env.step(action)
        self.frames.append(ob)
        return self._get_ob(), reward, done, truncated, info

    def _get_ob(self):
        assert len(self.frames) == self.k
        return LazyFrames(list(self.frames))


class WarpFrame(gym.ObservationWrapper):
    def __init__(self, env):
        """Warp frames to 84x84 as done in the Nature paper and later work."""
        gym.ObservationWrapper.__init__(self, env)
        self.width = 84
        self.height = 84
        self.observation_space = gym.spaces.Box(low=0, high=255,
            shape=(self.height, self.width, 1), dtype=np.uint8)

    def observation(self, frame):
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
        return frame[:, :, None]

class EpisodicLifeEnv(gym.Wrapper):
    def __init__(self, env=None):
        """Make end-of-life == end-of-episode, but only reset on true game over.
        Done by DeepMind for the DQN and co. since it helps value estimation.
        """
        super(EpisodicLifeEnv, self).__init__(env)
        self.lives = 0
        self.was_real_done = True
        self.was_real_reset = False

    def step(self, action):
        obs, reward, done, truncated, info = self.env.step(action)
        self.was_real_done = done
        # check current lives, make loss of life terminal,
        # then update lives to handle bonus lives
        lives = self.env.unwrapped.ale.lives()
        if lives < self.lives and lives > 0:
            # for Qbert somtimes we stay in lives == 0 condtion for a few frames
            # so its important to keep lives > 0, so that we only reset once
            # the environment advertises done.
            done = True
        self.lives = lives
        return obs, reward, done, truncated, info

    def reset(self):
        """Reset only when lives are exhausted.
        This way all states are still reachable even though lives are episodic,
        and the learner need not know about any of this behind-the-scenes.
        """
        if self.was_real_done:
            obs, info = self.env.reset()
            self.was_real_reset = True
        else:
            # no-op step to advance from terminal/lost life state
            obs, _, _, _, info = self.env.step(0)
            self.was_real_reset = False
        self.lives = self.env.unwrapped.ale.lives()
        return obs, info
    
class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env=None, skip=4):
        """Return only every `skip`-th frame"""
        super(MaxAndSkipEnv, self).__init__(env)
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = deque(maxlen=2)
        self._skip = skip

    def step(self, action):
        total_reward = 0.0
        done = None
        for _ in range(self._skip):
            obs, reward, done, truncated, info = self.env.step(action)
            self._obs_buffer.append(obs)
            total_reward += reward
            if done:
                break

        max_frame = np.max(np.stack(self._obs_buffer), axis=0)

        return max_frame, total_reward, done, truncated, info

    def reset(self):
        """Clear past frame buffer and init. to first obs. from inner env."""
        self._obs_buffer.clear()
        obs, info = self.env.reset()
        self._obs_buffer.append(obs)
        return obs, info

    
class NoopResetEnv(gym.Wrapper):
    def __init__(self, env=None, noop_max=30):
        """Sample initial states by taking random number of no-ops on reset.
        No-op is assumed to be action 0.
        """
        super(NoopResetEnv, self).__init__(env)
        self.noop_max = noop_max
        self.override_num_noops = None
        assert env.unwrapped.get_action_meanings()[0] == 'NOOP'

    def step(self, action):
        return self.env.step(action)

    def reset(self):
        """ Do no-op action for a number of steps in [1, noop_max]."""
        self.env.reset()
        if self.override_num_noops is not None:
            noops = self.override_num_noops
        else:
            noops = np.random.randint(1, self.noop_max + 1)
        assert noops > 0
        obs = None
        info = None
        for _ in range(noops):
            obs, _, done, _, info = self.env.step(0)
            if done:
                obs, info = self.env.reset()
        return obs, info

In [6]:
def make_env(env, stack_frames=True, episodic_life=True):
    if episodic_life:
        env = EpisodicLifeEnv(env)

    env = NoopResetEnv(env, noop_max=30)
    env = MaxAndSkipEnv(env, skip=4)
    #if 'FIRE' in env.unwrapped.get_action_meanings():
    #    env = FireResetEnv(env)

    env = WarpFrame(env)
    if stack_frames:
        env = FrameStack(env, 4)
  
    return env


In [7]:
env = make_env(env)

In [8]:
def render(env):
    img.set_data(env.render())
    display.display(plt.gcf())
    display.clear_output(wait=True)

In [55]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

class ActorCritic(nn.Module):
 
    def __init__(self, in_channels=4, n_actions=14):
        """
        Initialize Actor Network
        Args:
            in_channels (int): number of input channels
            n_actions (int): number of outputs
        """
        super(ActorCritic, self).__init__()
        self.actor_features = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(7 * 7 * 64, 256),
            nn.ReLU()
        )
        self.critic_features = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(7 * 7 * 64, 256),
            nn.ReLU()
        )
        self.actor_head = nn.Linear(256, n_actions)
        self.critic_head = nn.Linear(256, 1)
        
    def forward(self, x):
        x = x.float() / 255
        x_a = self.actor_features(x)
        x_c = self.critic_features(x)
        return F.softmax(self.actor_head(x_a), dim=-1), self.critic_head(x_c)

In [56]:
from collections import namedtuple
import random

SavedAction = namedtuple('SavedAction', 
                        ('log_prob', 'value'))


In [65]:
import math

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class ActorCriticPolicy:
    
    GAMMA = 0.99
    EPS = 1e-7
    RENDER = False

    
    def __init__(self, lr=1e-4):
        self.n_actions = 4
        self.steps_done = 0
        self.mean_reward = None
        self.model = ActorCritic(n_actions=self.n_actions).to(device)
        self.saved_actions = []
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.postfix = {}
        
    def get_state(self, obs):
        state = np.array(obs)
        state = state.transpose((2, 0, 1))
        state = torch.from_numpy(state)
        return state.unsqueeze(0).to(device)

    def __call__(self, observation):
        state = self.get_state(observation)
        probs, value = self.model(state)
        
        m = Categorical(probs)
        action = m.sample()
        
        self.saved_actions.append(SavedAction(m.log_prob(action), value))
        self.steps_done += 1
        
        return action.item()
  
        
    def init_game(self, observation):
        self.state = self.get_state(observation)
        self.total_reward = 0.0
        self.rewards = []
    
    def update(self, obs, reward, terminated, truncated, info, pbar):
        self.total_reward += reward
        self.rewards.append(reward)
        if not terminated:
            self.next_state = self.get_state(obs)
        else:
            if self.mean_reward is None:
                self.mean_reward = self.total_reward
            else:
                self.mean_reward = self.mean_reward * 0.95 + self.total_reward * (1.0 - 0.95)
            self.postfix['total_reward'] = self.total_reward
            self.postfix['mean_reward'] = self.mean_reward
            self.postfix['steps'] = self.steps_done
            pbar.set_postfix(self.postfix)
            self.next_state = None
            
            self.finish_episode()    
            
        if self.steps_done % 100_000 == 0:
            self.save(f'model_{self.steps_done}.pt')

    def finish_episode(self):
        R = 0
        policy_losses = []
        value_losses = []
        returns = []
        
        for r in self.rewards[::-1]:
            if r != 0:
                # Game boundary (Pong specific) !
                R = 0
            R = r + ActorCriticPolicy.GAMMA * R
            returns.insert(0, R)
            
        returns = torch.tensor(returns)
        #returns = (returns - returns.mean()) / (returns.std() + ActorCriticPolicy.EPS)
        
        for (log_prob, value), R in zip(self.saved_actions, returns):
            advantage = R - value.item()

            # calculate actor (policy) loss
            policy_losses.append(-log_prob * advantage)

            # calculate critic (value) loss using L1 smooth loss
            value_losses.append(F.smooth_l1_loss(value, torch.tensor([R]).unsqueeze(0).to(device)))
            
        self.optimizer.zero_grad()
        # sum up all the values of policy_losses and value_losses
        policy_loss = torch.stack(policy_losses).sum()
        value_loss = torch.stack(value_losses).sum()
        loss = policy_loss + value_loss
        
        self.postfix['policy_loss'] = policy_loss.item()
        self.postfix['value_loss'] = value_loss.item()
        self.postfix['loss'] = loss.item()

        # perform backprop
        loss.backward()
        for param in self.model.parameters():
            param.grad.data.clamp_(-1, 1)
        self.optimizer.step()

        # reset rewards and action buffer
        del self.rewards[:]
        del self.saved_actions[:]

        
    def load(self, PATH):
        checkpoint = torch.load(PATH)
        print(checkpoint.keys())
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.steps_done = checkpoint['steps_done']
        if "mean_reward" in checkpoint:
            self.mean_reward = checkpoint['mean_reward']
        
    def save(self, PATH):
        state = {
                    'steps_done': self.steps_done,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'mean_reward': self.mean_reward
        }
        print(state.keys())
        torch.save(state, PATH)

In [66]:
policy = ActorCriticPolicy(lr=1e-5)
#policy.load("ac_model.pt")

In [None]:
observation, info = env.reset()
policy.init_game(observation)

plt.ion()
plt.axis('off')
img = plt.imshow(env.render())

with tqdm(total=10000) as pbar:
    while True:
        try:
            action = policy(observation)
            observation, reward, terminated, truncated, info = env.step(action)
            #render(env)
            policy.update(observation, reward, terminated, truncated, info, pbar)

            if terminated or truncated:
                pbar.update()
                observation, info = env.reset()
                policy.init_game(observation)
                
        except KeyboardInterrupt:
            break
env.close()

  0%|          | 0/10000 [00:00<?, ?it/s]

dict_keys(['steps_done', 'model_state_dict', 'optimizer_state_dict', 'mean_reward'])
dict_keys(['steps_done', 'model_state_dict', 'optimizer_state_dict', 'mean_reward'])
dict_keys(['steps_done', 'model_state_dict', 'optimizer_state_dict', 'mean_reward'])
dict_keys(['steps_done', 'model_state_dict', 'optimizer_state_dict', 'mean_reward'])
dict_keys(['steps_done', 'model_state_dict', 'optimizer_state_dict', 'mean_reward'])
dict_keys(['steps_done', 'model_state_dict', 'optimizer_state_dict', 'mean_reward'])
dict_keys(['steps_done', 'model_state_dict', 'optimizer_state_dict', 'mean_reward'])
dict_keys(['steps_done', 'model_state_dict', 'optimizer_state_dict', 'mean_reward'])
dict_keys(['steps_done', 'model_state_dict', 'optimizer_state_dict', 'mean_reward'])
dict_keys(['steps_done', 'model_state_dict', 'optimizer_state_dict', 'mean_reward'])
dict_keys(['steps_done', 'model_state_dict', 'optimizer_state_dict', 'mean_reward'])
dict_keys(['steps_done', 'model_state_dict', 'optimizer_state_dic

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



dict_keys(['steps_done', 'model_state_dict', 'optimizer_state_dict', 'mean_reward'])
dict_keys(['steps_done', 'model_state_dict', 'optimizer_state_dict', 'mean_reward'])
dict_keys(['steps_done', 'model_state_dict', 'optimizer_state_dict', 'mean_reward'])
dict_keys(['steps_done', 'model_state_dict', 'optimizer_state_dict', 'mean_reward'])
dict_keys(['steps_done', 'model_state_dict', 'optimizer_state_dict', 'mean_reward'])
dict_keys(['steps_done', 'model_state_dict', 'optimizer_state_dict', 'mean_reward'])
dict_keys(['steps_done', 'model_state_dict', 'optimizer_state_dict', 'mean_reward'])
dict_keys(['steps_done', 'model_state_dict', 'optimizer_state_dict', 'mean_reward'])
dict_keys(['steps_done', 'model_state_dict', 'optimizer_state_dict', 'mean_reward'])
dict_keys(['steps_done', 'model_state_dict', 'optimizer_state_dict', 'mean_reward'])


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [19]:
policy.save("ac_model.pt")

dict_keys(['steps_done', 'model_state_dict', 'optimizer_state_dict', 'mean_reward'])
