In [1]:
from functools import partial

import numpy as np
import ale_py
import mlflow
import gymnasium as gym
import torch
from tqdm import tqdm

from preprocess import PreprocessWrapper
from params import load_params
from dqn_eval import dqn_agent, play

In [2]:
# Hyperparameters come from the 'double-dqn-tuned-adam' profile of dqn_train.toml.
params = load_params('dqn_train.toml', profile='double-dqn-tuned-adam')

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('device:', device)

# Make the ALE environments available to gym.make.
gym.register_envs(ale_py)

# The raw emulator is created with frameskip=1 and repeat_action_probability=0,
# then wrapped by PreprocessWrapper, which receives params.skip_frames.
# NOTE(review): presumably the wrapper does the frame skipping itself and the
# zero repeat probability keeps evaluation free of sticky actions — confirm
# against preprocess.py and the ALE documentation.
env = PreprocessWrapper(
    gym.make(
        params.gym_env_id,
        render_mode="rgb_array",
        frameskip=1,
        repeat_action_probability=0,
    ),
    params.skip_frames,
    device,
    processed_only=False,
)
num_actions = env.action_space.n

device: cpu


A.L.E: Arcade Learning Environment (version 0.10.1+unknown)
[Powered by Stella]


In [3]:
def eval_model(uri, n=30):
    """Score a logged Q-network by playing ``n`` episodes with it.

    Parameters
    ----------
    uri : str
        MLflow model URI handed to ``mlflow.pytorch.load_model``.
    n : int, default 30
        Number of episodes to play.

    Returns
    -------
    tuple
        ``(mean, std)`` of the per-episode values returned by ``play``,
        computed with ``np.mean`` and ``np.std`` (NumPy defaults).

    Notes
    -----
    Reads the notebook-level ``device``, ``env``, ``params`` and
    ``num_actions`` defined in the setup cell.
    """
    network = mlflow.pytorch.load_model(uri, map_location=device)
    network.eval()  # put the loaded module into evaluation mode

    # Bind everything except the per-step inputs; the exploration rate used
    # here is params.eps_eval.
    policy = partial(
        dqn_agent,
        q0=network,
        num_actions=num_actions,
        eps=params.eps_eval,
    )

    # One play() call per episode; tqdm only adds the progress bar.
    episode_scores = [play(env, policy, params) for _ in tqdm(range(n))]
    return np.mean(episode_scores), np.std(episode_scores)

In [4]:
# Checkpoints logged under a single MLflow run; each artifact name ends in an
# episode number.
RUN_ID = '8d3c31e65a3240eda1e0e92890057f46'
CHECKPOINT_EPISODES = (10000, 15000, 22000)
uris = [f'runs:/{RUN_ID}/q0_episode_{episode}' for episode in CHECKPOINT_EPISODES]

# Print each URI before evaluating it so the progress bar and the resulting
# (mean, std) tuple appear directly beneath the checkpoint they belong to.
for uri in uris:
    print(uri)
    score_mean_std = eval_model(uri)
    print(score_mean_std)

runs:/8d3c31e65a3240eda1e0e92890057f46/q0_episode_10000


100%|██████████| 30/30 [01:08<00:00,  2.28s/it]


(35.96666666666667, 3.459126415087421)
runs:/8d3c31e65a3240eda1e0e92890057f46/q0_episode_15000


100%|██████████| 30/30 [01:18<00:00,  2.62s/it]


(266.6333333333333, 34.07979199206213)
runs:/8d3c31e65a3240eda1e0e92890057f46/q0_episode_22000


100%|██████████| 30/30 [01:12<00:00,  2.41s/it]

(337.76666666666665, 86.6805181238681)



