In [1]:
from gymnasium.wrappers import Autoreset, TimeAwareObservation, FrameStackObservation, AtariPreprocessing, ClipReward
from collections import defaultdict, deque
from IPython.display import clear_output
import matplotlib.pyplot as plt
import gymnasium as gym
import numpy as np
import random
import ale_py
import wandb
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical

In [2]:
# Environment id handed to gym.make / gym.make_vec by the factory helpers below.
env_name = 'BreakoutNoFrameskip-v4'

In [14]:
def make_env(render_mode=None):
    """Create one preprocessed Atari environment for `env_name`.

    The base environment is wrapped, innermost first, with AtariPreprocessing,
    a 4-frame FrameStackObservation and ClipReward limited to [-1, 1].

    Args:
        render_mode: forwarded unchanged to ``gym.make``.

    Returns:
        The fully wrapped environment.

    Notes:
        * No Autoreset wrapper is applied, so the caller must call ``reset()``
          itself once an episode terminates or is truncated.
        * TimeAwareObservation is not applied either.  If both wrappers are ever
          re-enabled, TimeAwareObservation has to be applied before Autoreset,
          otherwise timestep tracking will be erroneous.
        * NOTE(review): ``max_episode_steps`` is passed to ``gym.make``, i.e.
          below AtariPreprocessing, so the limit presumably counts emulator
          steps rather than agent steps -- confirm.
    """
    base_env = gym.make(
        env_name,
        render_mode=render_mode,
        max_episode_steps=10_000,
    )
    return ClipReward(
        FrameStackObservation(AtariPreprocessing(base_env), 4),
        -1,
        1,
    )

In [4]:
def make_envs(n_envs=32):  # >32 might cause memory issues
    """Create an asynchronous vector of preprocessed Atari environments.

    Every sub-environment receives the same wrapper stack, applied in order:
    AtariPreprocessing, a 4-frame FrameStackObservation, and ClipReward
    limited to [-1, 1].

    Args:
        n_envs: number of parallel sub-environments.

    Returns:
        The vector environment produced by ``gym.make_vec``.
    """
    def stack_frames(env):
        # Observation becomes the 4 most recent preprocessed frames.
        return FrameStackObservation(env, 4)

    def clip_reward(env):
        # Rewards are clamped into [-1, 1].
        return ClipReward(env, -1, 1)

    per_env_wrappers = [AtariPreprocessing, stack_frames, clip_reward]

    # Intentionally empty: make_vec autoresets by default.  An
    # 'autoreset_mode': gym.vector.AutoresetMode.DISABLED entry was tried here
    # and left disabled.
    vector_options = {}

    return gym.make_vec(
        env_name,
        max_episode_steps=10_000,
        num_envs=n_envs,
        vectorization_mode='async',
        vector_kwargs=vector_options,
        wrappers=per_env_wrappers,
    )

In [5]:
class ActorCritic(nn.Module):
    """Actor-critic network with a shared convolutional backbone.

    The backbone encodes a stack of 84x84 frames into a 512-dimensional feature
    vector; two linear heads turn that vector into action logits (policy) and a
    scalar state-value estimate (critic).

    Args:
        in_channels: number of stacked frames per observation (default 4).
        n_actions: number of discrete actions, i.e. logits produced (default 4).

    The layer order inside ``backbone`` is fixed, so state-dict keys do not
    depend on the constructor arguments.
    """

    def __init__(self, in_channels=4, n_actions=4):
        super().__init__()
        self.backbone = nn.Sequential(
        #   nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, ...)
            nn.Conv2d(in_channels, 32, 8, 4),  # (m, C, 84, 84) -> (m, 32, 20, 20)
            nn.SiLU(),
            nn.Conv2d(32, 64, 4, 2),     # (m, 32, 20, 20) -> (m, 64, 9, 9)
            nn.SiLU(),
            nn.Conv2d(64, 64, 3, 1),     # (m, 64, 9, 9) -> (m, 64, 7, 7)
            nn.SiLU(),
            nn.Flatten(),                # (m, 64, 7, 7) -> (m, 3136)
            nn.Linear(3136, 512),        # (m, 3136) -> (m, 512)
            nn.SiLU()
        )
        self.policy_head = nn.Linear(512, n_actions)  # -> output (m, n_actions)
        self.value_head = nn.Linear(512, 1)           # -> output (m, 1)

    @staticmethod
    def preprocess(x):
        """Turn raw pixel observations into a float batch scaled to [0, 1].

        Args:
            x: tensor of shape (C, H, W) or (m, C, H, W) holding pixel
               intensities in [0, 255].

        Returns:
            A float tensor of shape (m, C, H, W).  A missing batch dimension is
            added.  The result is detached from any autograd history of ``x``
            and does not require grad: observations are data, so neither an
            input gradient nor a graph rooted at the input is needed.
        """
        with torch.no_grad():
            x = x.float()
            if x.dim() == 3:
                x = x.unsqueeze(0)  # add the batch dim
            x = x / 255.0
        return x

    def value_only(self, x):
        """Return only the critic's value estimate, shape (m, 1)."""
        features = self.backbone(self.preprocess(x))
        return self.value_head(features)

    def forward(self, x):
        """Compute policy logits and value estimates for a batch of observations.

        Args:
            x: raw observation(s), see ``preprocess``.

        Returns:
            Tuple ``(logits, value)`` with shapes (m, n_actions) and (m, 1).
        """
        features = self.backbone(self.preprocess(x))
        logits = self.policy_head(features)
        value = self.value_head(features)

        return logits, value

    def sample(self, state, stochastic=True):
        """Choose an action for each observation in ``state``.

        Args:
            state: raw observation(s), see ``preprocess``.
            stochastic: if True, draw the action from the categorical policy;
                otherwise take the greedy (arg-max logit) action.

        Returns:
            Tuple ``(action, log_prob, value)`` with shapes (m,), (m,) and
            (m, 1).  ``log_prob`` is the policy's log-probability of the
            returned action in both modes.
        """
        logits, value = self.forward(state)
        dist = Categorical(logits=logits)
        if stochastic:
            action = dist.sample()
        else:
            action = torch.argmax(logits, dim=-1)
        log_prob = dist.log_prob(action)

        return action, log_prob, value

In [16]:
# Single evaluation environment (default render_mode=None) and a network with
# freshly initialised parameters.
env = make_env()
network = ActorCritic()

In [17]:
# Restore the network weights from a local checkpoint, mapped onto the CPU.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code embedded in the file -- only load checkpoints
# from a trusted source.  weights_only=True would be safer if the checkpoint
# holds nothing but tensors and plain containers -- TODO confirm what is stored
# alongside 'network_state_dict'.
path = './saved/breakout-51853600s.pth'
data = torch.load(path, weights_only=False, map_location='cpu')
network.load_state_dict(data['network_state_dict'])

<All keys matched successfully>

In [24]:
# Greedy evaluation rollout: act with the loaded policy for a fixed number of
# agent steps, record the RGB screen after every step, and keep the frames of
# the highest-scoring *finished* episode for export.
# Rewards pass through ClipReward(-1, 1) in make_env, so the totals below are
# sums of clipped rewards rather than raw game scores.
best_reward = float('-inf')
best_frames = []
episode_rewards = 0
episode_frames = []

num_iterations = 100000

# The emulator handle lives on the innermost environment.  `unwrapped` reaches
# it no matter how many wrappers are stacked on top, unlike a hard-coded
# `.env.env...` chain that breaks whenever the wrapper stack changes.
ale = env.unwrapped.ale

state = env.reset()[0]

# Pure inference: disable autograd so no graph is built for every step.
with torch.no_grad():
    for i in range(num_iterations):
        if i % 200 == 0:
            print('i', i, ' best', best_reward, ' cur', episode_rewards)
        state = torch.as_tensor(state)
        action, _, _ = network.sample(state, stochastic=False)
        # sample() adds a batch dimension for a single observation, so `action`
        # is a one-element tensor; the environment expects a plain int.
        state, reward, terminated, truncated, _ = env.step(action.item())
        frame = ale.getScreenRGB()
        episode_rewards += reward
        episode_frames.append(frame)
        if terminated or truncated:
            # make_env applies no autoreset, so start the next episode manually.
            state = env.reset()[0]
            if episode_rewards > best_reward:
                best_reward = episode_rewards
                best_frames = episode_frames
            episode_rewards = 0
            episode_frames = []

i 0  best -inf  cur 0
i 200  best -inf  cur 7.0
i 400  best -inf  cur 24.0
i 600  best -inf  cur 53.0
i 800  best -inf  cur 72.0
i 1000  best -inf  cur 83.0
i 1200  best -inf  cur 83.0
i 1400  best -inf  cur 83.0
i 1600  best -inf  cur 83.0
i 1800  best -inf  cur 83.0
i 2000  best -inf  cur 83.0
i 2200  best -inf  cur 83.0
i 2400  best -inf  cur 83.0
i 2600  best 83.0  cur 3.0
i 2800  best 83.0  cur 13.0
i 3000  best 83.0  cur 25.0
i 3200  best 83.0  cur 58.0
i 3400  best 83.0  cur 80.0
i 3600  best 83.0  cur 83.0
i 3800  best 84.0  cur 2.0
i 4000  best 84.0  cur 9.0
i 4200  best 84.0  cur 19.0
i 4400  best 84.0  cur 39.0
i 4600  best 84.0  cur 54.0
i 4800  best 84.0  cur 0.0
i 5000  best 84.0  cur 8.0
i 5200  best 84.0  cur 28.0
i 5400  best 84.0  cur 55.0
i 5600  best 84.0  cur 74.0
i 5800  best 84.0  cur 83.0
i 6000  best 84.0  cur 83.0
i 6200  best 84.0  cur 83.0
i 6400  best 84.0  cur 83.0
i 6600  best 84.0  cur 83.0
i 6800  best 84.0  cur 83.0
i 7000  best 84.0  cur 83.0
i 7200  

KeyboardInterrupt: 

In [25]:
import imageio

# Export the best recorded episode as an animated GIF named after its reward.
if best_frames:
    imageio.mimsave(f"{best_reward}_episode.gif", best_frames, fps=20)
    print(f"Saved best episode with reward {best_reward}")
else:
    # No episode finished during the rollout: the frame list is still empty and
    # best_reward is still -inf, so there is nothing meaningful to encode.
    print("No completed episode recorded; nothing to save")

Saved best episode with reward 90.0


In [26]:
# Number of frames (agent steps) captured for the best episode.
len(best_frames)

2500