In [5]:
from IPython.display import clear_output
import matplotlib.pyplot as plt
from collections import deque
from pathlib import Path
import gymnasium as gym
import numpy as np
import random
import ale_py
import time
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
def remake_env(render_mode=None):
    '''Replace the global `env` with a newly created Breakout RAM environment.

    If a global `env` already exists it is reset, closed and deleted before
    the new environment is made.

    Args:
        render_mode: forwarded unchanged to `gym.make` (this notebook passes
            'human' for playback and None otherwise).

    Returns:
        The result of `env.reset()` on the new environment; callers in this
        notebook unpack it as `(state, info)`.
    '''
    global env
    previous_exists = 'env' in globals()
    if previous_exists:
        # tear down the old environment in the same order as before: reset, close, drop
        env.reset()
        env.close()
        del env
    # default Breakout-ramDeterministic-v4 has frameskip of 4
    env = gym.make('Breakout-ramDeterministic-v4', render_mode=render_mode)
    first_observation = env.reset()
    return first_observation

In [3]:
def get_checkpoint(v=-1, path='./checkpoints'):
    '''Find a checkpoint file inside `path`.

    A checkpoint file is any directory entry whose name starts with an
    integer version followed by '-' (e.g. '6-breakout-....pth'). Entries
    whose prefix is not an integer (e.g. 'mem-...') are ignored.

    Args:
        v: version to look for. If an entry with exactly this version is
            encountered it is returned immediately; otherwise the entry with
            the highest version is returned.
        path: directory to scan.

    Returns:
        (file_name, version) if a checkpoint was found. Otherwise (None, 0),
        which covers a missing directory, an empty directory, and a directory
        that contains no versioned files.
    '''
    try:
        entries = os.listdir(path)
    except FileNotFoundError:
        # no checkpoint directory yet means no checkpoint
        return (None, 0)

    best_file, best_version = None, -1
    for name in entries:
        try:
            version = int(name.split('-')[0])  # might be a 'mem-...' file
        except ValueError:
            continue
        if version > best_version:
            best_version, best_file = version, name
        if version == v:
            return name, v

    if best_file is None:
        # only non-versioned files were present
        return (None, 0)
    return best_file, best_version

In [6]:
class QNetwork(nn.Module):
    '''Fully connected network mapping 128 inputs to 4 outputs (128 -> 512 -> 256 -> 64 -> 4).

    The linear layers are registered as `fc1` .. `fc4`, in that order; these
    attribute names determine the keys of the module's state dict.
    '''

    def __init__(self):
        super().__init__()
        widths = (128, 512, 256, 64, 4)
        # register fc1..fc4 in order, one Linear per consecutive pair of widths
        for index, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, f'fc{index}', nn.Linear(n_in, n_out))

    def forward(self, x):  # expect uint8 tensor as input
        '''Scale `x` by 1/255, apply three leaky-ReLU hidden layers, and return the raw `fc4` output.'''
        hidden = x / 255.0
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.leaky_relu(layer(hidden))
        # final layer is linear: no activation on the outputs
        return self.fc4(hidden)

In [8]:
# Prefer CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"device: {device}")

device: cpu


In [10]:
# Q-network with freshly initialised weights; the checkpoint cell below
# overwrites them via load_state_dict when a checkpoint file is found.
online_net = QNetwork()

In [21]:
# Load the latest checkpoint file, if there is one, into online_net.
# NOTE(review): step_count / episode_count are only initialised when no
# checkpoint is found; confirm whether they should also be restored from
# the checkpoint when one is loaded.
proj_name = 'breakout-ram_03-03'
checkpoint_dir = './checkpoints'

# To load a specific version instead of the newest one, pass v=<version> here.
checkpoint_file, checkpoint_version = get_checkpoint(path=checkpoint_dir)
# Truthiness rather than `is not None`: an empty file name must also count as
# "no checkpoint", otherwise torch.load would be pointed at the directory itself.
if checkpoint_file:
    # SECURITY: weights_only=False unpickles arbitrary Python objects, so only
    # load checkpoint files that come from a trusted source.
    checkpoint = torch.load(os.path.join(checkpoint_dir, checkpoint_file), weights_only=False, map_location=device)
    online_net.load_state_dict(checkpoint['online_state_dict'])
    print('loaded', checkpoint_file)
else:
    episode_count, step_count = 0, 0
    print('no checkpoint found')

loaded 6-breakout-ram_03-03-3000e-791651s.pth


In [22]:
@torch.no_grad()
def greedy(s):
    '''Return the index of the largest output of the global `online_net` for state `s`.

    `s` is converted to a float tensor before being fed to the network, and
    the computation runs with gradient tracking disabled. The result is a
    plain Python int.
    '''
    observation = torch.tensor(s, dtype=torch.float)
    q_values = online_net(observation)
    best_action = q_values.argmax()
    return best_action.item()

In [23]:
# (low, high) bounds; the playback loop below passes them to np.clip to clip each raw reward.
reward_clip = (-1, 1)

In [24]:
# Play one episode with the greedy policy, using render_mode 'human'.
step = 0
state, info = remake_env('human')

try:
    t0 = time.time()  # wall-clock start; not read again within this cell
    while True:
        step = step + 1

        # pick the greedy action for the current observation and advance the env
        action = greedy(state)
        state, raw_reward, terminated, truncated, info = env.step(action)
        reward = np.clip(raw_reward, reward_clip[0], reward_clip[1])

        # replace the previous status line rather than appending a new one
        clear_output(wait=True)
        print(f'step {step}, took action {action}, got raw reward {raw_reward} (clipped {reward})')

        # stop as soon as the environment reports the episode is over
        if any((terminated, truncated)):
            break
except KeyboardInterrupt:
    print('keyboard interrupt')
finally:
    # swap the rendering env for one with render_mode None, then reset and close it
    remake_env(None)
    env.reset()
    env.close()

step 416, took action 3, got raw reward 0.0 (clipped 0.0)
