In [10]:
from xai import *
import torch
import numpy as np
from collections import deque
import matplotlib.pyplot as plt
import gc
import cv2

In [17]:
# Device string handed to the agent both when loading and when constructing it.
device = "cpu"

# Prefer the saved checkpoint; if loading raises FileNotFoundError, fall back
# to building a brand-new agent on top of the pre-trained autoencoder file.
try:
    dqn = DQN.load("git-ignore/dqn-model.pt", device=device)
except FileNotFoundError:
    print("Creating new agent...")
    # NOTE(review): translate/rotate presumably enable input augmentations or
    # invariances inside DQN -- confirm against the xai package.
    dqn = DQN(
        autoencoder_path="git-ignore/asteroids-autoencoder-l32.pt",
        translate=True,
        rotate=True,
        device=device,
    )

Creating new agent...


In [None]:
# Launch training. The keyword arguments are identical to before, only grouped
# by theme. The grouping comments are the reviewer's reading of the names;
# confirm the exact semantics against DQN.train.
dqn.train(
    # --- run length ---
    total_time_steps=1_000_000,
    max_episodes=201,
    frame_skip=4,
    # --- replay buffer / optimisation ---
    replay_buffer_size=5_000_000,
    learning_starts=3 * 6500,
    learning_rate=1e-4,
    batch_size=64,
    train_frequency=64,
    gradient_steps=1,
    gamma=0.99,
    # --- target network ---
    tau=1.0,
    target_update_frequency=2000,
    # --- exploration schedule ---
    # NOTE(review): the initial rate looks like a value carried over from an
    # earlier, interrupted run -- confirm before reusing.
    initial_exploration_rate=0.2898721500000001,
    final_exploration_rate=0.05,
    final_exploration_rate_progress=1.0,
    # --- logging / persistence ---
    verbose=True,
    episode_save_freq=25,
    save_path="git-ignore/dqn-model.pt",
    q_value_head_background_path="git-ignore/states.npy",
)

In [None]:
# Plot the total reward of every episode together with a centred moving
# average, so the trend is visible through the per-episode noise.

# Episodes taken on each side of the current one: a window of up to
# 2 * HALF_WINDOW + 1 = 9 episodes, truncated at both ends of the series.
HALF_WINDOW = 4

episode_rewards = dqn.rewards_per_episode
episode_count = len(episode_rewards)

avg_rewards = []
for episode in range(episode_count):
    # A slice clamps an out-of-range upper bound by itself, so only the lower
    # bound has to be guarded against going negative (which would wrap around).
    nearby_rewards = episode_rewards[max(0, episode - HALF_WINDOW):episode + HALF_WINDOW + 1]
    # The slice always contains at least the current episode, so no division by zero.
    avg_rewards.append(sum(nearby_rewards) / len(nearby_rewards))

fig, ax = plt.subplots(dpi=250)
ax.set_title("Rewards per episode")
ax.plot(range(episode_count), episode_rewards, label="Total reward")
ax.plot(range(len(avg_rewards)), avg_rewards, label="Moving average reward")
ax.set_xlabel("Episode")
ax.set_ylabel("Reward")
ax.legend()
ax.grid()

In [None]:
# Overlay per-episode reward, its centred moving average and the exploration
# rate on one logarithmic axis, so the two differently scaled series can be
# compared by eye.

# Episodes taken on each side of the current one: a window of up to
# 2 * HALF_WINDOW + 1 = 9 episodes, truncated at both ends of the series.
HALF_WINDOW = 4

episode_rewards = dqn.rewards_per_episode
exploration_rates = dqn.exploration_rate_per_episode

avg_rewards = []
for episode in range(len(episode_rewards)):
    # A slice clamps an out-of-range upper bound by itself, so only the lower
    # bound has to be guarded against going negative (which would wrap around).
    nearby_rewards = episode_rewards[max(0, episode - HALF_WINDOW):episode + HALF_WINDOW + 1]
    # The slice always contains at least the current episode, so no division by zero.
    avg_rewards.append(sum(nearby_rewards) / len(nearby_rewards))

fig, ax = plt.subplots(dpi=250)
# NOTE(review): a log axis cannot show zero or negative values; episodes with
# such rewards will be missing from the curve -- confirm this is acceptable.
ax.set_yscale("log")
ax.set_title("Rewards versus exploration correlation")
ax.plot(range(len(episode_rewards)), episode_rewards, label="Total reward")
ax.plot(range(len(exploration_rates)), exploration_rates, label="Exploration rate")
ax.plot(range(len(avg_rewards)), avg_rewards, label="Moving average reward")
ax.set_xlabel("Episode")
# Both rewards and exploration rates share this axis, hence the generic label.
ax.set_ylabel("Value (log scale)")
ax.legend()
ax.grid()

In [None]:
# Show how the exploration rate evolved over the recorded episodes.
exploration_rates = dqn.exploration_rate_per_episode
episode_axis = range(len(exploration_rates))

fig, ax = plt.subplots(dpi=250)
ax.set_title("Exploration rate per episode")
ax.set_xlabel("Episode")
ax.set_ylabel("Exploration rate")
ax.plot(episode_axis, exploration_rates, label="Exploration rate")
ax.legend()
ax.grid()

In [None]:
# Render a rollout of the current agent in a live window.
# NOTE(review): Window's positional arguments look like (title, fps, scale) and
# rollout's like (exploration rate, frame skips) -- confirm against xai.
with Window("Asteroids", 60, 4.0) as window:
    playback = dqn.rollout(0.3, 4).take(50_000)
    for step in playback:
        frame = step.observation.numpy(False)
        window(frame)

In [None]:
# Quick look: per-episode reward and exploration rate on the same default axes.
episode_rewards = dqn.rewards_per_episode
plt.plot(range(len(episode_rewards)), episode_rewards, label="Total reward")

exploration_rates = dqn.exploration_rate_per_episode
plt.plot(range(len(exploration_rates)), exploration_rates, label="Exploration rate")

plt.legend()

In [None]:
# Quick look: per-episode reward only, with the episode index on the x-axis.
episode_rewards = dqn.rewards_per_episode
plt.plot(range(len(episode_rewards)), episode_rewards)

In [None]:
# Render a shorter rollout in a live window, using a different first argument
# to rollout than the earlier viewer cell.
# NOTE(review): the first rollout argument is presumably the exploration rate --
# confirm against xai.
with Window("Asteroids", 60, 4.0) as window:
    playback = dqn.rollout(0.7, frame_skips=4).take(3000).enumerate()
    # The index produced by enumerate() is not used; only the step is rendered.
    for frame_index, step in playback:
        frame = step.observation.numpy(False)
        window(frame)