In [18]:
from IPython.display import clear_output
import matplotlib.pyplot as plt
from collections import deque
from pathlib import Path
import gymnasium as gym
import numpy as np
import random
import ale_py
import time
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

In [83]:
# Report the gymnasium and torch versions followed by torch's intra-op and
# inter-op thread counts, one value per line.
print('\n'.join(str(item) for item in (
    gym.__version__,
    torch.__version__,
    torch.get_num_threads(),
    torch.get_num_interop_threads(),
)))

1.0.0
2.6.0+cpu
10
10


In [84]:
# Make sure the output directories exist (no error if they are already there).
Path("checkpoints").mkdir(parents=True, exist_ok=True)
Path("runs").mkdir(parents=True, exist_ok=True)

In [62]:
def remake_env(render_mode=None):
    """(Re)create the module-level ``env`` and return the result of its first ``reset()``.

    Any environment already stored in the global ``env`` is closed and dropped
    before the new one is built.

    Args:
        render_mode: forwarded to ``gym.make`` (``None`` by default).

    Returns:
        Whatever ``env.reset()`` returns for the freshly created environment.
    """
    global env
    # Pop first so a failing close() cannot leave a stale env bound to the
    # global name. The old env is no longer reset right before being closed:
    # that restarted an episode only to discard it.
    stale_env = globals().pop('env', None)
    if stale_env is not None:
        stale_env.close()
    env = gym.make('Breakout-ramDeterministic-v4', render_mode=render_mode)  # default Breakout-ramDeterministic-v4 has frameskip of 4
    return env.reset()

In [74]:
class ReplayBuffer:
    """Bounded FIFO store of (state, action, reward, next_state, terminated) transitions.

    Items are stored exactly as they are pushed (no dtype conversion happens
    here), so uint8 observations stay uint8 and the network's forward pass can
    convert them to float if desired. Once ``maxlen`` transitions are held, the
    oldest one is dropped for every new push.
    """

    def __init__(self, maxlen):
        self.buffer = deque(maxlen=maxlen)

    def __len__(self):
        return len(self.buffer)

    def push(self, state, action, reward, next_state, terminated):
        """Append one transition to the buffer."""
        transition = (state, action, reward, next_state, terminated)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions as five batched arrays.

        Returns:
            (states, actions, rewards, next_states, terminateds); states and
            next_states are built with ``np.stack``, the rest with ``np.array``.
        """
        assert batch_size <= len(self), 'sample size is greater than population of buffer'
        picked = random.sample(self.buffer, batch_size)  # without replacement
        # Transpose the list of transitions into one tuple per field.
        states, actions, rewards, next_states, terminateds = zip(*picked)
        batch = (
            np.stack(states),
            np.array(actions),
            np.array(rewards),
            np.stack(next_states),
            np.array(terminateds),
        )
        return batch

In [75]:
class QNetwork(nn.Module):
    """Fully connected network: ``obs_dim`` inputs -> 512 -> 256 -> 64 -> ``n_actions`` outputs.

    Args:
        obs_dim: size of the last input dimension (default 128).
        n_actions: number of values produced per input row (default 4).
    """

    def __init__(self, obs_dim=128, n_actions=4):
        super().__init__()
        # Attribute names fc1..fc4 and their creation order are unchanged, so
        # existing state_dicts still load and seeded initialisation is the same.
        self.fc1 = nn.Linear(obs_dim, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 64)
        self.fc4 = nn.Linear(64, n_actions)

    def forward(self, x):  # expect uint8 tensor as input
        """Scale byte-valued input by 1/255 and return the output of the last linear layer."""
        # Cast to the parameter dtype before scaling: uint8 input gives the same
        # result as before, and other dtypes (e.g. float64) no longer fail in fc1.
        x = x.to(self.fc1.weight.dtype) / 255.0
        x = F.leaky_relu(self.fc1(x))
        x = F.leaky_relu(self.fc2(x))
        x = F.leaky_relu(self.fc3(x))
        x = self.fc4(x)  # no activation on the output layer
        return x

In [72]:
# Use CUDA when torch reports it as available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("device:", device)

device: cpu


In [81]:
# Build the online network first and the target network second (same
# construction order as before), moving each to the selected device.
online_net = QNetwork().to(device)
target_net = QNetwork().to(device)
target_net.eval()  # the target network is kept in eval mode


def sync():
    """Copy the online network's current weights into the target network."""
    online_weights = online_net.state_dict()
    target_net.load_state_dict(online_weights)


# Start with both networks holding identical weights.
sync()

In [90]:
# --- replay buffer / batching ---
memory = ReplayBuffer(1_000_000)
batch_size = 32

# --- step-based schedule ---
learning_starts = 50_000  # uniform random policy is run for X steps before learning starts
learn_freq = 4
sync_freq = 10_000  # environment steps
max_steps = 108_000
noop_max = 30

# --- exploration epsilon ---
eps_max = 1.0  # initial epsilon
eps_min = 0.1  # final
eps_anneal_steps = 1_000_000

# --- returns / rewards ---
gamma = 0.99
reward_clip = (-1, 1)

# --- optimisation ---
learning_rate = 2.5e-4
criterion = nn.HuberLoss()
optimizer = optim.AdamW(online_net.parameters(), lr=learning_rate)

In [None]:
# load latest checkpoint file if there is one, set step_count, episode_count
proj_name = 'breakout-ram_03-03'

# NOTE(review): get_checkpoint() is not defined in any cell above this one; it
# must already exist in the kernel for this cell to run - confirm where it lives.
checkpoint_file, checkpoint_version = get_checkpoint()
# checkpoint_version = 12
if checkpoint_file is not None:
    checkpoint_path = f'./checkpoints/{checkpoint_file}'
    memory_path = f'./checkpoints/mem-{checkpoint_file}'
    # SECURITY: weights_only=False unpickles arbitrary objects, so only load
    # checkpoint files that this project wrote itself.
    checkpoint = torch.load(checkpoint_path, weights_only=False, map_location=device)
    online_net.load_state_dict(checkpoint['online_state_dict'])
    target_net.load_state_dict(checkpoint['target_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Fall back to 0 when a counter is missing so later arithmetic and file
    # names never see None.
    episode_count = checkpoint.get('episode_count') or 0
    step_count = checkpoint.get('step_count') or 0
    try:
        memory = torch.load(memory_path, weights_only=False)
    except FileNotFoundError:
        # Best effort: the replay memory is optional, but say so instead of
        # silently continuing with whatever buffer is already in scope.
        print('no replay memory file at', memory_path, '- keeping the current buffer')
    print('loaded', checkpoint_file)
else:
    episode_count, step_count = 0, 0
    print('no checkpoint found')
    
# Path of the replay-memory file written by the previous save_checkpoint()
# call; a placeholder until the first save.
prev_mem_name = '--sentinel--'


def save_checkpoint():
    """Save networks, optimizer state and counters, plus the replay memory, under ./checkpoints.

    The model checkpoint and the replay memory go to two files that share a
    name stem (the memory file is prefixed with ``mem-``). Only the most recent
    replay-memory file written by this function is kept: the previous one is
    removed after the new one has been written, so a failed save does not
    delete the only replay memory on disk.
    """
    global prev_mem_name
    stem = f'{checkpoint_version}-{proj_name}-{episode_count}e-{step_count}s.pth'
    name = f'./checkpoints/{stem}'
    mem_name = f'./checkpoints/mem-{stem}'
    checkpoint = {
        'online_state_dict': online_net.state_dict(),
        'target_state_dict': target_net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'step_count': step_count,
        'episode_count': episode_count
    }
    torch.save(checkpoint, name)
    torch.save(memory, mem_name)
    # Remove the superseded memory file only now that the new one exists, and
    # never the file just written (two saves at the same episode/step share a name).
    if prev_mem_name != mem_name and os.path.exists(prev_mem_name):
        os.remove(prev_mem_name)
    prev_mem_name = mem_name
    print('saved', name)

In [64]:
# (Re)build the global env with render_mode=None and keep the pair returned by its first reset().
state, info = remake_env(None)

In [70]:
# Scratch check of the buffer: take one env step with a fixed action and push
# the same transition twice, so the buffer holds at least two entries.
action = 2
next_state, reward, terminated, truncated, info = env.step(action)
memory.push(state, action, reward, next_state, terminated)
memory.push(state, action, reward, next_state, terminated)
state = next_state  # carry the new observation forward

In [71]:
# Draw two stored transitions and display only element 0 of the result (the stacked states).
memory.sample(2)[0]

array([[ 63,  63,  63,  63,  63,  63, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 192, 192, 192, 192, 192, 192, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 240,   0,   0,
        255,   0,   0, 240,   0,   5,   0,   0,   6,   0,  70, 182, 134,
        198,  22,  38,  54,  70,  66,   2, 158,   0,   4,   0,   0,   0,
          0,   0,   0, 241,   0, 242,   0, 242,  25, 241,   5, 242,   4,
          0, 255,   0, 224,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   8,   0, 255, 255, 255, 255, 255, 255, 255,
          0,   0,   5,   0,   0, 186, 214, 117, 246, 219, 242],
       [ 63,  63,  63,  63,  63,  63, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 192, 192, 192, 192, 192, 192, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 24