In [1]:
from gymnasium.wrappers import FrameStackObservation
from IPython.display import display, clear_output
import matplotlib.pyplot as plt
from collections import deque
from datetime import datetime
from pathlib import Path
import gymnasium as gym
from tqdm import tqdm
from PIL import Image
import numpy as np
import random
import ale_py
import time
import cv2
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

from utils import *

In [2]:
# Record library versions and torch's intra-op / inter-op thread counts,
# one value per line, so a run can be matched to its environment later.
print(
    gym.__version__,
    torch.__version__,
    torch.get_num_threads(),
    torch.get_num_interop_threads(),
    sep='\n',
)

1.0.0
2.6.0+cpu
10
10


In [3]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Training on device:", device)

Training on device: cpu


In [4]:
def remake_env(render_mode=None, env_id='BreakoutDeterministic-v4', stack_size=4):
    """(Re)create the module-level ``env``, closing any previous one first.

    Args:
        render_mode: forwarded to ``gym.make`` (``None`` disables rendering).
        env_id: id of the environment to create; defaults to the Breakout
            variant this notebook trains on.
        stack_size: number of most recent observations stacked together by
            ``FrameStackObservation``.

    Returns:
        None. The new environment is stored in the global ``env``.
    """
    global env
    if 'env' in globals():
        # release the old emulator/window before replacing it
        env.close()
        del env
    env = gym.make(env_id, render_mode=render_mode)
    # stack previous frames to simulate motion and mend Markov property
    env = FrameStackObservation(env, stack_size)

In [5]:
# hyperparameters from Nature (in the same order shown in the paper, just different names)
# 'action repeat' not included because it's handled by env frameskip
# RMSProp gradient parameters not included because we're using AdamW
batch_size = 32  # transitions sampled from the replay memory per gradient update
memory = ReplayBuffer(100_000)  # holding 1 million requires ~4*84*84*2*1e6 ~ 56 GB of memory, so 100k is used instead - hopefully it's enough
sync_freq = 10_000  # environment steps between target-network syncs
gamma = 0.99  # discount factor used in the TD target
learn_freq = 4  # environment steps between gradient updates
learning_rate = 0.00025
eps_max = 1.0  # initial epsilon
eps_min = 0.1  # final epsilon
eps_anneal_steps = 1_000_000  # steps over which epsilon is linearly annealed from eps_max to eps_min
learning_starts = 50_000  # uniform random policy is run until the replay memory holds this many transitions
noop_max = 30  # exclusive upper bound on the random number of start-of-episode steps; note: for breakout we should use action=1 (FIRE) instead of 0 so the ball releases

In [6]:
reward_clip = (-1, 1)  # (low, high) bounds handed to np.clip for every step reward
max_steps = 108000  # cap on environment steps within a single episode

In [18]:
# Global step at which the linear anneal (re)starts. The value equals the total
# step count printed at the end of the previous training session, so it
# presumably restarts exploration from there -- NOTE(review): confirm, and
# update it whenever the schedule should restart from a different step.
EPS_ANNEAL_START_STEP = 14_433_239


def get_epsilon(step=None, anneal_start=EPS_ANNEAL_START_STEP):  # epsilon schedule
    """Return the exploration rate for a given global step.

    Epsilon decays linearly from ``eps_max`` to ``eps_min`` over
    ``eps_anneal_steps`` steps, starting at ``anneal_start``. Outside that
    window it is clamped: steps before ``anneal_start`` give exactly
    ``eps_max`` (previously the formula extrapolated above ``eps_max``), and
    steps after the window give ``eps_min``.

    Args:
        step: global step to evaluate; defaults to the current ``step_count``.
        anneal_start: global step at which the decay begins.

    Returns:
        float in ``[eps_min, eps_max]``.
    """
    if step is None:
        step = step_count
    progress = (step - anneal_start) / eps_anneal_steps
    progress = min(1.0, max(0.0, progress))  # keep the schedule inside [eps_min, eps_max]
    return eps_min + (eps_max - eps_min) * (1 - progress)  # linear annealing

In [8]:
def sync():
    """Overwrite the target network's weights with the online network's."""
    target_net.load_state_dict(online_net.state_dict())


# The online network is the one being optimised; the target network is kept
# in eval mode and only refreshed through sync().
online_net = QNetwork().to(device)
target_net = QNetwork().to(device)
target_net.eval()

sync()  # start with both networks holding identical weights

In [9]:
# AdamW is used in place of the paper's RMSProp. When a checkpoint is loaded,
# the checkpoint-loading cell overwrites this optimizer's state.
optimizer = optim.AdamW(online_net.parameters(), lr=learning_rate)
# NOTE(review): plain MSE grows quadratically with large TD errors; the Nature
# DQN paper clips the error term instead (Huber-like, e.g. nn.SmoothL1Loss) --
# confirm whether MSE is intentional before changing it.
criterion = nn.MSELoss()

In [10]:
# Global progress counters; both are overwritten from the checkpoint when one is loaded.
episode_count = 0  # incremented once per training episode
step_count = 0  # incremented once per training environment step

In [13]:
# load latest checkpoint file if there is one
proj_name = 'breakout_feb_27'
checkpoint_dir = './checkpoints/'

checkpoint_file, checkpoint_version = get_checkpoint()
# checkpoint_version = 12
if checkpoint_file and os.path.isfile(checkpoint_dir + checkpoint_file):
    # SECURITY: weights_only=False unpickles arbitrary Python objects -- only
    # load checkpoint files that you created yourself.
    checkpoint = torch.load(checkpoint_dir + checkpoint_file, weights_only=False, map_location=torch.device(device))
    online_net.load_state_dict(checkpoint['online_state_dict'])
    target_net.load_state_dict(checkpoint['target_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # fall back to 0 so the counters stay usable with `+=` if a key is missing
    step_count = checkpoint.get('step_count') or 0
    episode_count = checkpoint.get('episode_count') or 0
    memory = ReplayBuffer(100_000)
    try:
        memory = torch.load(checkpoint_dir + 'mem-' + checkpoint_file, weights_only=False)
    except Exception as err:
        # Restoring the replay memory is best-effort (the dump may not exist),
        # but say so instead of failing silently.
        print(f'replay memory not restored ({type(err).__name__}: {err}); starting with an empty buffer')
    print('loaded')
else:
    print('no checkpoint found; starting from scratch')
    # assumes get_checkpoint() may return None for the version when nothing
    # exists -- TODO confirm; the training loop needs an integer to increment
    if checkpoint_version is None:
        checkpoint_version = 0

prev_mem_name = '--sentinel--'
def save_checkpoint():
    """Write the training state and the replay memory to ./checkpoints/.

    Two files are written, both named after the current
    ``checkpoint_version``, ``proj_name``, ``episode_count`` and
    ``step_count``: one with the network/optimizer state and counters, and one
    (prefixed ``mem-``) with the pickled replay memory. The replay-memory file
    written by the previous call in this session is deleted first, so only the
    most recent dump is kept.
    """
    global prev_mem_name
    os.makedirs(checkpoint_dir, exist_ok=True)
    name = f'./checkpoints/{checkpoint_version}-{proj_name}-{episode_count}e-{step_count}s.pth'
    mem_name = f'./checkpoints/mem-{checkpoint_version}-{proj_name}-{episode_count}e-{step_count}s.pth'
    # The old dump is removed BEFORE the new one is written so two dumps never
    # have to fit on disk at once; the cost is that a failed save leaves none.
    if os.path.exists(prev_mem_name):
        os.remove(prev_mem_name)
    prev_mem_name = mem_name
    checkpoint = {
        'online_state_dict': online_net.state_dict(),
        'target_state_dict': target_net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'step_count': step_count,
        'episode_count': episode_count
    }
    torch.save(checkpoint, name)
    torch.save(memory, mem_name)  # large file; comment this line out to save space (e.g. on kaggle)
    print('saved', name)

loaded


In [14]:
def greedy(preprocessed_state):
    """Return the action index with the highest online-network output.

    Args:
        preprocessed_state: input passed unchanged to ``online_net``.

    Returns:
        int: flat argmax over the network output.
    """
    # Action selection never needs gradients; disabling autograd avoids
    # building a graph on every environment step.
    with torch.no_grad():
        return online_net(preprocessed_state).argmax().item()

def epsilon_greedy(preprocessed_state):
    """Pick a random action with probability epsilon, else the greedy one.

    Epsilon is read from ``get_epsilon()``; random actions come from
    ``env.action_space.sample()`` and greedy ones from ``greedy``.
    """
    explore = np.random.random() < get_epsilon()
    return env.action_space.sample() if explore else greedy(preprocessed_state)

In [15]:
def do_action(action, update_memory=True):
    """Step the environment once and advance the global ``state``.

    The raw reward is clipped to ``reward_clip`` *before* the transition is
    stored, so the replay memory (and therefore the TD targets) holds the same
    clipped reward that is returned. Earlier versions clipped after the push,
    so transitions stored by them carry unclipped rewards.

    Args:
        action: action passed to ``env.step``.
        update_memory: when True, push
            ``(state, action, clipped_reward, next_state, terminated)`` to
            ``memory``.

    Returns:
        tuple: ``(clipped_reward, terminated, truncated)``.
    """
    global state
    next_state, reward, terminated, truncated, info = env.step(action)
    next_state = preprocess_state(next_state)
    reward = np.clip(reward, *reward_clip)
    if update_memory:
        memory.push(state, action, float(reward), next_state, terminated)
    state = next_state
    return reward, terminated, truncated

In [16]:
# Show which checkpoint and counters the session is starting from.
print(
    f'checkpoint {checkpoint_file}',
    f'checkpoint version {checkpoint_version}',
    f'episode_count {episode_count}',
    f'step_count {step_count}',
    sep='\n',
)

checkpoint 61-breakout_feb_27-51126e-14386023s.pth
checkpoint version 61
episode_count 51126
step_count 14386023


In [17]:
train_episodes = 100_000  # upper bound for this session; interrupt the cell to stop earlier

writer = SummaryWriter(log_dir='./runs/1st-run')
return_history = []
remake_env(None)
t0 = time.time()
start_episode, start_step = episode_count, step_count
try:
    # Phase 1: run a uniform random policy until the replay memory holds
    # `learning_starts` transitions (skipped when a restored memory is full enough).
    while len(memory) < learning_starts:
        state, info = env.reset()
        state = preprocess_state(state)
        for step in range(max_steps):
            reward, terminated, truncated = do_action(env.action_space.sample())
            if len(memory) % 100 == 0:
                clear_output(wait=True)
                print(f'collecting initial training samples {len(memory)}/{learning_starts}')
            if terminated or truncated:
                break

    # Phase 2: epsilon-greedy interaction with periodic learning and target syncs.
    for episode in range(train_episodes):
        episode_return = 0
        state, info = env.reset()
        episode_count += 1

        # NOTE(review): randint's upper bound is exclusive, so at most
        # noop_max - 1 start steps are taken -- confirm that is intended.
        for noop_step in range(np.random.randint(0, noop_max)):  # no-op start
            state, reward, terminated, truncated, info = env.step(1)
            if terminated or truncated:
                print('noop start ended episode? probably should not be happening')
                # append so earlier warnings are not overwritten by later ones
                with open('_warnings.txt', 'a') as f:
                    f.write(f'noop start ended episode @ step {step_count}\n')
                state, info = env.reset()
                break
        state = preprocess_state(state)  # make sure to preprocess before running do_action

        for step in range(max_steps):  # step
            action = epsilon_greedy(state)
            reward, terminated, truncated = do_action(action)
            episode_return += reward
            step_count += 1

            if step_count % sync_freq == 0:  # update target net
                sync()

            if step_count % learn_freq == 0:  # update online net
                states, actions, rewards, next_states, terminateds = memory.sample(batch_size)
                # note: states and next_states should be kept uint8; they are fed to
                # the networks exactly as memory.sample returns them
                actions = torch.tensor(actions, dtype=torch.long, device=device).reshape(-1, 1)           # (m, 1), m = batch_size
                rewards = torch.tensor(rewards, dtype=torch.float, device=device).reshape(-1, 1)          # (m, 1)
                terminateds = torch.tensor(terminateds, dtype=torch.float, device=device).reshape(-1, 1)  # (m, 1)

                pred = online_net(states).gather(1, actions)  # predicted Q-values of the selected action

                with torch.no_grad():
                    # TD target; (1 - terminateds) removes the bootstrap term on terminal transitions
                    y = rewards + gamma * target_net(next_states).max(axis=1, keepdim=True).values * (1 - terminateds)

                loss = criterion(pred, y)  # don't need to detach but y.requires_grad is False

                optimizer.zero_grad()
                loss.backward()
                max_grad = max(p.grad.abs().max().item() for p in online_net.parameters() if p.grad is not None)
                # max_norm=inf: only measures the total gradient norm, nothing is clipped
                total_norm = torch.nn.utils.clip_grad_norm_(online_net.parameters(), float('inf'))
                writer.add_scalar('loss', loss.item(), step_count)
                writer.add_scalar('max_grad', max_grad, step_count)
                writer.add_scalar('total_norm', total_norm, step_count)
                optimizer.step()

            if terminated or truncated:
                break

        # `step` is the 0-based index of the last step, so the episode took step + 1 steps
        writer.add_scalar('episode_steps', step + 1, step_count)
        writer.add_scalar('episode_return', episode_return, step_count)
        writer.add_scalar('epsilon', get_epsilon(), step_count)
        return_history.append(episode_return)

        if episode % 1000 == 0 and episode != 0:
            checkpoint_version += 1
            save_checkpoint()
            writer.flush()

        if episode % 10 == 0 or episode == train_episodes-1:
            tt = time.time() - t0             # wall-clock seconds since this cell started
            et = episode_count - start_episode  # episodes completed in this session
            st = step_count - start_step        # training steps taken in this session

            clear_output(wait=True)
            print(f'episode {et}\t({et/tt:.1f}/s)  [total {episode_count}]')
            print(f'step {st}\t({st/tt:.0f}/s)  [total {step_count}]')
            print(f'time {tt:.2f} s')
            print('---')
            print(f'avg. return: {np.mean(return_history[-100:]):.5f}  (last 100 episodes)')
            print(f'epsilon {get_epsilon():.5f}')

except KeyboardInterrupt:
    print('\033[31m Interrupt signal detected \033[32m training state is saved as checkpoint \033[0m')
finally:
    # Always checkpoint on the way out (normal end, interrupt, or error) ...
    checkpoint_version += 1
    try:
        save_checkpoint()
    finally:
        # ... and close the TensorBoard writer even if saving fails.
        writer.close()
    
    
                

episode 191	(0.1/s)  [total 51317]
step 47216	(34/s)  [total 14433239]
time 1404.47 s
---
avg. return: 1.80000  (last 100 episodes)
epsilon 0.10000
[31m Interrupt signal detected [32m training state is saved as checkpoint [0m
saved ./checkpoints/62-breakout_feb_27-51322e-14434276s.pth


In [17]:
# Benchmarking Summary
# 424 env steps/s, ::2 downsample + normalization + torch.tensor
# 46 passes/s, with AdamW

In [18]:
# n = QNetwork()
# t0 = time.time()
# for k in range(1001):
#     pred = n(torch.rand(1, 4, 84, 84))
#     if k % 50 == 0:
#         t = time.time() - t0
#         clear_output(wait=True)
#         print(f'{k} forward passes completed in {t:.2f} s  ({k/t :.2f}/s)')

# 326 forwards/s        

In [19]:
# n = QNetwork()
# x = preprocess_state(state)
# optimizer = torch.optim.AdamW(n.parameters(), lr=0.001)
# t0 = time.time()
# for k in range(1001):
#     l = (n(torch.rand(1, 4, 84, 84)) + k).sum()
#     optimizer.zero_grad()
#     l.backward(retain_graph=True)
#     optimizer.step()
#     if k % 50 == 0:
#         t = time.time() - t0
#         clear_output(wait=True)
#         print(f'{k} backward passes completed in {t:.2f} s  ({k/t :.2f}/s)')

# 146 backwards/s, no forward pass, no optimizer step
# 46 backwards/s, with forward pass and AdamW

In [20]:
# benchmark env time vs network inference time
# remake_env(None)
# steps = 0
# t0 = time.time()
# for episode in range(100):
#     state, info = env.reset()
#     while True:
#         steps += 1
#         # action = env.action_space.sample()
#         action = 1
#         state, reward, done, truncated, info = env.step(action)
#         _ = preprocess_state(state)
#         if done or truncated:
#             break
#         if steps % 1000 == 0:
#             t = time.time() - t0
#             clear_output(wait=True)
#             print(f'{episode} episodes, {steps} steps ({steps/t :.2f}/s)')
# t = time.time() - t0
# print(f'{steps} steps completed in {t:.2f}  ({steps/t :.2f}/s)')

# 676 steps/s, no preprocessing
# 117 steps/s, np.dot grayscale, no cv2 resize
# 111 steps/s, np.dot grayscale and cv2 resize
# 566 steps/s, cv2 grayscale and cv2 resize (!)
# 396 steps/s, cv2 grayscale and cv2 resize + normalization + torch.tensor
# 424 steps/s, ::2 downsample + normalization + torch.tensor