In [None]:
from xai import *
import torch
import numpy as np
from collections import deque
import matplotlib.pyplot as plt

In [None]:
# Prefer the GPU when one is available, but fall back to the CPU so the notebook
# still runs on machines without CUDA (a hard-coded "cuda" fails there).
device = "cuda" if torch.cuda.is_available() else "cpu"

MODEL_PATH = "dqn-model.pt"
AUTOENCODER_PATH = "asteroids-autoencoder-l32.pt"

# Resume from a previously saved agent when the checkpoint exists; otherwise
# build a fresh DQN on top of the pre-trained autoencoder.
try:
    dqn = DQN.load(MODEL_PATH, device=device)
except FileNotFoundError:
    print("Creating new agent...")
    dqn = DQN(autoencoder_path=AUTOENCODER_PATH, translate=True, rotate=True, device=device)

In [None]:
# Hyperparameters for the training run, gathered in a single mapping so they can
# be inspected on their own before being handed to `dqn.train`.
train_config = dict(
    total_time_steps=1_000_000,
    replay_buffer_size=int(5e6),
    learning_rate=1e-4,
    learning_starts=6500,
    batch_size=64,
    tau=1.0,
    gamma=0.99,
    train_frequency=64,
    frame_skip=4,
    gradient_steps=1,
    episode_save_freq=10,
    target_update_frequency=2000,
    final_exploration_rate_progress=0.3,
    initial_exploration_rate=0.6,
    final_exploration_rate=0.05,
    verbose=True,
    save_path="dqn-model.pt",
    q_value_head_background_path="states.npy",
)
dqn.train(**train_config)

In [None]:
# Stream 50000 rollout steps into a live window, showing each raw observation.
with Window("Asteroids", 60, 4.0) as asteroids_window:
    trajectory = dqn.rollout(0.4, 4).take(50000)
    for transition in trajectory:
        frame = transition.observation.numpy(False)
        asteroids_window(frame)

In [None]:
# The exploration rate is a small fraction (0.6 -> 0.05 in the training config),
# while episode rewards are on their own, typically much larger, scale. Sharing
# one y-axis flattens the exploration curve, so give it a secondary axis.
fig, reward_ax = plt.subplots()
reward_line, = reward_ax.plot(dqn.rewards_per_episode, color="tab:blue", label="Total reward")
reward_ax.set(xlabel="Episode", ylabel="Total reward", title="DQN training progress")

rate_ax = reward_ax.twinx()
rate_line, = rate_ax.plot(dqn.exploration_rate_per_episode, color="tab:orange", label="Exploration rate")
rate_ax.set_ylabel("Exploration rate")

# twinx() axes keep separate legends; merge both lines into a single legend.
reward_ax.legend(handles=[reward_line, rate_line]);

In [None]:
# Per-episode rewards in RL tend to be noisy, so overlay a trailing moving
# average to make the learning trend readable.
REWARD_SMOOTHING_WINDOW = 20

rewards = np.asarray(dqn.rewards_per_episode, dtype=float)
window_size = min(REWARD_SMOOTHING_WINDOW, len(rewards))

fig, ax = plt.subplots()
ax.plot(rewards, alpha=0.4, label="Episode reward")
if window_size > 1:
    # mode="valid" yields len(rewards) - window_size + 1 points; each point is the
    # mean of the window ending at that episode.
    moving_avg = np.convolve(rewards, np.ones(window_size) / window_size, mode="valid")
    ax.plot(np.arange(window_size - 1, len(rewards)), moving_avg, label=f"{window_size}-episode mean")
ax.set(xlabel="Episode", ylabel="Total reward", title="Reward per episode")
ax.legend();

In [None]:
# Live view, side by side: affine-normalised observation | autoencoder
# reconstruction | heat-map of the summed EAP SHAP attributions.
# The backgrounds are loaded once, before the window opens, so the window is not
# left blank while they load.
# NOTE(review): torch.load unpickles its input -- only load trusted files.
obs_background = dqn._autoencoder.encoder(torch.load("observations.pt")).output()
state_background = torch.from_numpy(np.load("states.npy"))

with Window("Asteroids", 60, 4.0) as window:
    for step in dqn.rollout(0.6, frame_skips=4).take(3000).monitor("Frame:", expected_length=3000):
        obs = step.observation.translated().rotated()
        reconstruction = dqn._autoencoder(obs.numpy(True)).output().numpy(force=True)
        reconstruction = (reconstruction * 255).astype(np.uint8)

        # Sum over the leading axis to get a single attribution map per frame.
        shap_values = step.explain_eap(
            algorithm="permutation",
            decoder_background=obs_background[:5],
            q_background=state_background[:5],
        ).shap_values.sum(0)

        # Positive attributions -> red channel, negative -> blue, both scaled by
        # the largest magnitude in this frame. An all-zero map would make
        # norm == 0, so render it black instead of dividing by zero.
        norm = np.max(np.abs(shap_values))
        scaled = shap_values / norm if norm > 0 else np.zeros_like(shap_values)
        heatmap = np.zeros((*shap_values.shape, 3), dtype=np.uint8)
        heatmap[..., 0] = (np.clip(scaled, 0, None) * 255).astype(np.uint8)
        heatmap[..., 2] = (np.clip(-scaled, 0, None) * 255).astype(np.uint8)

        window(np.hstack([obs.numpy(False), reconstruction, heatmap]))

In [None]:
# Pre-compute frames for the EAP-SHAP videos: the raw observation, its
# translated + rotated version, one heat-map per action and the heat-map of the
# summed attributions. The insertion order of `videos` ("Original", "Affine",
# action indices, "Sum") is the panel order of the combined video.
N_VIDEO_FRAMES = 240


def _shap_to_rgb(attribution, norm):
    """Render a 2-D attribution map as an RGB uint8 image.

    Positive values go to the red channel and negative values to the blue
    channel, both divided by ``norm``. A non-positive ``norm`` (all-zero
    attributions) produces a black image instead of dividing by zero.
    """
    scaled = attribution / norm if norm > 0 else np.zeros_like(attribution)
    image = np.zeros((*attribution.shape, 3), dtype=np.uint8)
    image[..., 0] = (np.clip(scaled, 0, None) * 255).astype(np.uint8)
    image[..., 2] = (np.clip(-scaled, 0, None) * 255).astype(np.uint8)
    return image


videos = {}

# NOTE(review): torch.load unpickles its input -- only load trusted files.
obs_background = dqn._autoencoder.encoder(torch.load("observations.pt")).output()
state_background = torch.from_numpy(np.load("states.npy"))

# Keep the progress monitor's expected length in sync with the number of frames taken.
rollout = dqn.rollout(0.7, frame_skips=4).take(N_VIDEO_FRAMES)
for step in rollout.monitor("Frame:", expected_length=N_VIDEO_FRAMES):
    videos.setdefault("Original", []).append(step.observation.numpy(False))
    videos.setdefault("Affine", []).append(step.observation.translated().rotated().numpy(False))

    shap_values = step.explain_eap(
        "permutation",
        decoder_background=obs_background[:50],
        q_background=state_background[:50],
    ).shap_values
    shap_sum = shap_values.sum(0)

    # One shared scale for every panel of this frame, so the per-action maps and
    # their sum are directly comparable.
    norm = max(np.max(np.abs(shap_values)), np.max(np.abs(shap_sum)))

    for action, action_explanation in enumerate(shap_values):
        videos.setdefault(action, []).append(_shap_to_rgb(action_explanation, norm))
    videos.setdefault("Sum", []).append(_shap_to_rgb(shap_sum, norm))

In [None]:
# Write each pre-computed frame sequence in `videos` to its own MP4, then a
# combined MP4 with every panel side by side.
from pathlib import Path

VIDEO_DIR = Path("Videos/DQN-EAP-SHAP")
VIDEO_FPS = 24

id_to_action = {
    0: "Noop.mp4",
    1: "Up.mp4",
    2: "Left.mp4",
    3: "Right.mp4",
    4: "Fire.mp4"
}

# Create the output directory up front so the recorders have somewhere to write.
VIDEO_DIR.mkdir(parents=True, exist_ok=True)

# Per-action heat-maps plus the three non-action sequences, one file each.
single_videos = {
    **id_to_action,
    "Original": "Original.mp4",
    "Affine": "Affine.mp4",
    "Sum": "Sum.mp4",
}
for key, filename in single_videos.items():
    with Recorder(str(VIDEO_DIR / filename), fps=VIDEO_FPS, scale=4) as recorder:
        for frame in videos[key]:
            recorder(frame)

# Combined video: at each time step, stack every panel horizontally in the
# insertion order of `videos`. zip() walks the sequences frame by frame instead
# of materialising all videos as one large array first.
with Recorder(str(VIDEO_DIR / "Combined.mp4"), fps=VIDEO_FPS, scale=3) as recorder:
    for frames in zip(*videos.values()):
        recorder(np.hstack(frames))

In [None]:
# Watch the agent for 3000 steps. The first rollout argument is presumably the
# exploration rate -- TODO confirm against DQN.rollout.
# The step index from enumerate() was never used, so iterate the steps directly.
with Window("Asteroids", 60, 4.0) as window:
    for step in dqn.rollout(0.7, frame_skips=4).take(3000):
        window(step.observation.numpy(False))