In [1]:
from IPython.display import clear_output
import matplotlib.pyplot as plt
from collections import deque
from pathlib import Path
import gymnasium as gym
import numpy as np
import random
import ale_py
import time
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

In [2]:
# Report the gymnasium and torch versions plus torch's intra-op / inter-op CPU thread counts, one per line.
print('\n'.join(str(item) for item in (gym.__version__, torch.__version__, torch.get_num_threads(), torch.get_num_interop_threads())))

1.0.0
2.6.0+cpu
10
10


In [3]:
# Output folders: model/replay checkpoints and TensorBoard event files.
Path("checkpoints").mkdir(parents=True, exist_ok=True)
Path("runs").mkdir(parents=True, exist_ok=True)

In [4]:
def remake_env(render_mode=None):
    """Replace the global `env` with a freshly created Breakout RAM environment.

    If a global `env` already exists it is reset, closed and deleted first.

    Args:
        render_mode: forwarded unchanged to `gym.make` (e.g. None or 'human').

    Returns:
        Whatever `env.reset()` returns for the new environment.
    """
    global env
    has_previous = 'env' in globals()
    if has_previous:
        env.reset()
        env.close()
        del env
    # per the original note: default Breakout-ramDeterministic-v4 has frameskip of 4
    env = gym.make('Breakout-ramDeterministic-v4', render_mode=render_mode)
    reset_result = env.reset()
    return reset_result

In [5]:
def get_checkpoint(v=-1, path='./checkpoints'):
    ''' Locate a model checkpoint in `path`.

    Checkpoint files are named '<version>-...'. Entries whose first '-'-separated
    field is not an integer (e.g. 'mem-...' replay-memory dumps, '.ipynb_checkpoints')
    are ignored.

    Args:
        v: version to look for. If a file with exactly this version exists it is
           returned; otherwise (and with the default -1) the highest version is returned.
        path: directory to search. A missing directory counts as "no checkpoint".

    Returns:
        (file_name, version) if any versioned checkpoint exists, otherwise (None, 0).
    '''
    if not os.path.isdir(path):
        return (None, 0)
    best_file, best_version = None, -1
    for f in os.listdir(path):
        try:
            cur = int(f.split('-')[0])
        except ValueError:  # not a versioned checkpoint, e.g. 'mem-...'
            continue
        if cur == v:
            return f, v
        if cur > best_version:
            best_file, best_version = f, cur
    if best_file is None:
        # only non-versioned entries (or an empty dir): honour the documented (None, 0)
        return (None, 0)
    return best_file, best_version

In [6]:
class ReplayBuffer:
    """Bounded FIFO store of (state, action, reward, next_state, terminated) transitions.

    Frames are meant to be stored as uint8 to save memory; the network's forward pass
    can convert them to float32 if desired. Once `maxlen` transitions are held, each
    push evicts the oldest one.
    """

    def __init__(self, maxlen):
        self.buffer = deque(maxlen=maxlen)

    def __len__(self):
        return len(self.buffer)

    def push(self, state, action, reward, next_state, terminated):
        transition = (state, action, reward, next_state, terminated)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw `batch_size` distinct transitions uniformly at random (without replacement).

        Returns:
            (states, actions, rewards, next_states, terminateds) as numpy arrays; states
            and next_states are stacked along a new leading batch axis.
        """
        assert batch_size <= len(self), 'sample size is greater than population of buffer'
        batch = random.sample(self.buffer, batch_size)
        states = np.stack([t[0] for t in batch])
        actions = np.array([t[1] for t in batch])
        rewards = np.array([t[2] for t in batch])
        next_states = np.stack([t[3] for t in batch])
        terminateds = np.array([t[4] for t in batch])
        return states, actions, rewards, next_states, terminateds

In [7]:
class QNetwork(nn.Module):
    """Fully connected Q-network mapping an observation vector to one Q-value per action.

    Architecture: n_inputs -> 512 -> 256 -> 64 -> n_actions, with LeakyReLU after each
    hidden layer.

    Args:
        n_inputs: length of the observation vector (default 128, matching the RAM
            observations used in this notebook).
        n_actions: number of discrete actions (default 4, as for Breakout).
    """
    def __init__(self, n_inputs=128, n_actions=4):
        super().__init__()
        self.fc1 = nn.Linear(n_inputs, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 64)
        self.fc4 = nn.Linear(64, n_actions)

    def forward(self, x):
        """Return Q-values whose last dimension has size `n_actions`.

        `x` holds raw byte values in [0, 255]. It may be a uint8 tensor or a float tensor
        of the same values; it is rescaled to [0, 1] before the first layer.
        """
        x = x / 255.0
        x = F.leaky_relu(self.fc1(x))
        x = F.leaky_relu(self.fc2(x))
        x = F.leaky_relu(self.fc3(x))
        x = self.fc4(x)
        return x

In [8]:
# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"device: {device}")

device: cpu


In [9]:
# Online network (trained) and target network (periodically synced copy used for TD targets).
online_net = QNetwork().to(device)
target_net = QNetwork().to(device)
target_net.eval()


def sync():
    """Copy every weight of the online network into the target network."""
    target_net.load_state_dict(online_net.state_dict())


sync()  # start with identical networks

# Hyperparameters

In [10]:
# --- replay memory / minibatches ---
memory = ReplayBuffer(1_000_000)   # bounded replay memory (capacity 1M transitions)
batch_size = 32
learning_starts = 50_000           # uniform random policy is run for this many transitions before learning starts

# --- learning ---
gamma = 0.99                       # discount factor in the TD target
learning_rate = 2.5e-4
learn_freq = 4                     # environment steps between gradient updates
sync_freq = 10_000                 # environment steps between target-network syncs
reward_clip = (-1, 1)              # (low, high) bounds passed to np.clip

# --- exploration: epsilon decays linearly from eps_max to eps_min ---
eps_max = 1.0                      # initial epsilon
eps_min = 0.1                      # final epsilon
eps_anneal_steps = 1_000_000       # steps taken to reach eps_min

# --- episodes ---
# noop_max = 30  # might be better off just not training on no-op starts
# Per-episode cap counted in env.step calls (agent steps, each spanning `frameskip` frames).
# NOTE(review): the common Atari convention is 108k *frames* (~27k agent steps at frameskip 4) — confirm intent.
max_steps = 108_000

# AdamW applies decoupled weight decay (torch default 0.01) on top of Adam.
optimizer = optim.AdamW(online_net.parameters(), lr=learning_rate)
criterion = nn.HuberLoss(delta=1.0)

In [11]:
def get_epsilon():
    """Exploration rate for the current global `step_count`.

    Decays linearly from `eps_max` at step 0 to `eps_min` at `eps_anneal_steps`,
    then stays at `eps_min`.
    """
    progress = step_count / eps_anneal_steps
    linear = eps_min + (eps_max - eps_min) * (1 - progress)
    return linear if linear > eps_min else eps_min

In [12]:
# Load the latest checkpoint (if any): restores both networks, the optimizer, the replay
# memory (when its 'mem-' dump exists) and the episode_count / step_count counters.
# Also defines save_checkpoint(), used by the training loop.
proj_name = 'breakout-ram_03-03'
checkpoint_dir = './checkpoints'

# Path of the most recent replay-memory dump written or loaded. save_checkpoint() removes it
# after writing a newer one, so only one (large) memory file is kept on disk.
prev_mem_name = '--sentinel--'

checkpoint_file, checkpoint_version = get_checkpoint(path=checkpoint_dir)
# checkpoint_version = 12
if checkpoint_file:  # None (or an empty name) means there is nothing to load
    # weights_only=False unpickles arbitrary objects -- only load checkpoints you created yourself
    checkpoint = torch.load(f'{checkpoint_dir}/{checkpoint_file}', weights_only=False, map_location=device)
    online_net.load_state_dict(checkpoint['online_state_dict'])
    target_net.load_state_dict(checkpoint['target_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # default to 0 so a checkpoint without counters doesn't leave them as None
    episode_count = checkpoint.get('episode_count', 0)
    step_count = checkpoint.get('step_count', 0)
    mem_path = f'{checkpoint_dir}/mem-{checkpoint_file}'
    try:
        memory = torch.load(mem_path, weights_only=False)
        prev_mem_name = mem_path  # superseded (and removed) by the next save_checkpoint()
    except FileNotFoundError:
        pass  # no memory dump for this checkpoint: keep the current `memory`
    print('loaded', checkpoint_file)
else:
    episode_count, step_count = 0, 0
    print('no checkpoint found')


def save_checkpoint():
    """Save networks, optimizer and counters, plus the replay memory as a separate 'mem-' file.

    Files are named '<checkpoint_version>-<proj_name>-<episode_count>e-<step_count>s.pth'
    (and 'mem-' + that name) inside `checkpoint_dir`. The previously recorded memory dump
    is deleted only after the new one has been written.
    """
    global prev_mem_name
    base = f'{checkpoint_version}-{proj_name}-{episode_count}e-{step_count}s.pth'
    name = f'{checkpoint_dir}/{base}'
    mem_name = f'{checkpoint_dir}/mem-{base}'
    checkpoint = {
        'online_state_dict': online_net.state_dict(),
        'target_state_dict': target_net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'step_count': step_count,
        'episode_count': episode_count
    }
    torch.save(checkpoint, name)
    torch.save(memory, mem_name)
    # never delete the file we just wrote (same name when nothing changed since the last save/load)
    if mem_name != prev_mem_name and os.path.exists(prev_mem_name):
        os.remove(prev_mem_name)
    prev_mem_name = mem_name
    print('saved', name)

no checkpoint found


In [17]:
def greedy(s):
    """Index of the action with the largest online-network Q-value for observation `s`."""
    with torch.no_grad():
        obs = torch.tensor(s, dtype=torch.float, device=device)
        q_values = online_net(obs)
        return q_values.argmax().item()


def epsilon_greedy(s):
    """Random action with probability get_epsilon(), otherwise greedy(s)."""
    explore = np.random.random() < get_epsilon()
    if explore:
        return env.action_space.sample()
    return greedy(s)

In [14]:
def do_action(action, update_memory=True):
    """Step the global `env`, optionally record the transition, and advance the global `state`.

    The reward is clipped to `reward_clip` *before* it is stored, so the replay memory
    (and therefore the TD targets) see the same clipped reward that is returned.

    Args:
        action: action passed to `env.step`.
        update_memory: if True, push (state, action, clipped_reward, next_state, terminated)
            into the global replay `memory`.

    Returns:
        (reward, terminated, truncated) where `reward` is the clipped reward.
    """
    global state
    next_state, reward, terminated, truncated, info = env.step(action)
    reward = np.clip(reward, *reward_clip)
    if update_memory:
        memory.push(state, action, reward, next_state, terminated)
    state = next_state
    return reward, terminated, truncated

In [15]:
# Summarise what was restored (or not) before training starts.
print(
    f'checkpoint {checkpoint_file}',
    f'checkpoint version {checkpoint_version}',
    f'episode_count {episode_count}',
    f'step_count {step_count}',
    sep='\n',
)

checkpoint None
checkpoint version 0
episode_count 0
step_count 0


In [19]:
# Training: pre-fill the replay memory with `learning_starts` uniform-random transitions, then
# run `train_episodes` epsilon-greedy episodes, doing a DQN update every `learn_freq` steps and
# a target-network sync every `sync_freq` steps. Metrics are written to TensorBoard (./runs).
train_episodes = 1_000

writer = SummaryWriter(log_dir='./runs')
return_history = deque(maxlen=100)  # most recent episode returns, for the running average
remake_env(None)
start_episode, start_step = episode_count, step_count


def _learn_step():
    """One DQN gradient update on a random minibatch from `memory`; logs loss and gradient norm."""
    states, actions, rewards, next_states, terminateds = memory.sample(batch_size)
    states = torch.tensor(states, dtype=torch.float, device=device)            # (m, 128)
    next_states = torch.tensor(next_states, dtype=torch.float, device=device)  # (m, 128)
    actions = torch.tensor(actions, dtype=torch.long, device=device).reshape(-1, 1)           # (m, 1)
    rewards = torch.tensor(rewards, dtype=torch.float, device=device).reshape(-1, 1)          # (m, 1)
    terminateds = torch.tensor(terminateds, dtype=torch.float, device=device).reshape(-1, 1)  # (m, 1)

    pred = online_net(states).gather(1, actions)  # predicted Q-values of the selected actions

    with torch.no_grad():
        # one-step TD target from the target net; terminal transitions get no bootstrap term
        next_q = target_net(next_states).max(dim=1, keepdim=True).values
        y = rewards + gamma * next_q * (1 - terminateds)

    loss = criterion(pred, y)  # y was built under no_grad, so no gradient flows into the target

    optimizer.zero_grad()
    loss.backward()
    # Norm of the *gradients*; passing the parameters themselves would measure the weights instead.
    grad_norm = torch.nn.utils.get_total_norm(
        [p.grad for p in online_net.parameters() if p.grad is not None])
    writer.add_scalar('loss', loss.item(), step_count)
    writer.add_scalar('grad_norm', grad_norm.item(), step_count)
    optimizer.step()


try:
    t0 = time.time()
    while len(memory) < learning_starts:
        state, info = env.reset()
        for step in range(max_steps):
            reward, terminated, truncated = do_action(env.action_space.sample())
            if len(memory) % 100 == 0:
                clear_output(wait=True)
                print(f'collecting initial training samples {len(memory)}/{learning_starts} ({time.time() - t0:.2f} s)')
            if terminated or truncated:
                break

    t0 = time.time()
    for episode in range(train_episodes):
        state, info = env.reset()
        episode_return = 0
        episode_count += 1

        for step in range(max_steps):
            action = epsilon_greedy(state)
            reward, terminated, truncated = do_action(action)  # same order do_action returns them
            episode_return += reward
            step_count += 1

            if step_count % sync_freq == 0:  # update target network
                sync()

            if step_count % learn_freq == 0:  # update online net
                _learn_step()

            if terminated or truncated:
                break

        writer.add_scalar('episode_steps', step + 1, step_count)  # `step` is the 0-based index of the last step
        writer.add_scalar('episode_return', episode_return, step_count)
        writer.add_scalar('epsilon', get_epsilon(), step_count)
        return_history.append(episode_return)

        if episode % 500 == 0 and episode != 0:
            checkpoint_version += 1
            save_checkpoint()
            writer.flush()

        if episode % 10 == 0 or episode == train_episodes - 1:
            tt = time.time() - t0
            et = episode_count - start_episode
            st = step_count - start_step

            clear_output(wait=True)
            print(f'episode {et}\t({et/tt:.1f}/s)  [total {episode_count}]')
            print(f'step {st}\t({st/tt:.0f}/s)  [total {step_count}]')
            print(f'time {tt:.2f} s')
            print('---')
            print(f'avg. return: {np.mean(return_history):.5f}  (last {len(return_history)} episodes)')
            print(f'epsilon {get_epsilon():.5f}')

except KeyboardInterrupt:
    print('keyboard interrupt')
finally:
    # only checkpoint runs long enough to be worth keeping
    if episode_count - start_episode > 20:
        checkpoint_version += 1
        save_checkpoint()
    writer.close()
    env.reset()
    env.close()

episode 1000	(0.5/s)  [total 3000]
step 366484	(182/s)  [total 791651]
time 2014.70 s
---
avg. return: 9.30000  (last 500 episodes)
epsilon 0.28751
saved ./checkpoints/6-breakout-ram_03-03-3000e-791651s.pth
