In [1]:
from xai import *
import torch
import numpy as np
from collections import deque
import matplotlib.pyplot as plt

In [2]:
# Resume from the saved checkpoint when it is on disk; otherwise construct a
# new agent around the autoencoder stored at "asteroids-l32.pt".
try:
    dqn = DQN.load("dqn-model-v2.pt")
except FileNotFoundError:
    print("Creating new agent...")
    dqn = DQN(
        autoencoder_path="asteroids-l32.pt",
        translate=True,
        rotate=True,
        device="cpu",
    )

In [3]:
# Reset the `explainers` mapping on the autoencoder and on both of its halves,
# then write the result back over the same checkpoint file.
# NOTE(review): presumably this keeps cached explainer objects out of the
# saved file -- confirm against AutoEncoder.save.
ae = AutoEncoder.load("asteroids-l32.pt")
for part in (ae, ae.encoder, ae.decoder):
    part.explainers = {}
ae.save("asteroids-l32.pt")

In [6]:
# Training run. The group headers below are only a reading aid derived from
# the keyword names; see DQN.train for the exact meaning of each argument.
dqn.train(
    # Run length, warm-up and replay memory
    total_time_steps=1_000_000,
    learning_starts=500,
    replay_buffer_size=1000,
    batch_size=16,
    # Optimisation
    learning_rate=1e-3,
    gamma=0.99,
    train_frequency=32,
    gradient_steps=1,
    # Target network
    tau=1.0,
    target_update_frequency=2000,
    # Environment stepping
    frame_skip=4,
    # Exploration schedule
    # (presumably annealed from 0.7 to 0.05 over the first 30% of the run -- confirm)
    initial_exploration_rate=0.7,
    final_exploration_rate=0.05,
    final_exploration_rate_progress=0.3,
    # Logging and persistence
    verbose=True,
    episode_save_freq=3,
    save_path="dqn-model-v2.pt",
    q_value_head_background_path="states.npy",
)

Filling replay buffer: 100%|██████████| 500/500 [00:08<00:00, 58.65it/s]
time_step=209806, episode=243, total_reward=1, exploration=0.25:  21%|██        | 209807/1000000 [1:01:12<3:50:31, 57.13it/s] 


KeyboardInterrupt: 

In [None]:
# Training curves: total reward and exploration rate, one point per episode,
# drawn on a shared axis. Passing only the y-values lets matplotlib use the
# episode index (0, 1, 2, ...) as x, which is what range(len(...)) produced.
# NOTE(review): the two series share one y-axis; if rewards are much larger
# than 1 the exploration curve will look flat -- consider a twin axis.
fig, ax = plt.subplots()
ax.plot(dqn.rewards_per_episode, label="Total reward")
ax.plot(dqn.exploration_rate_per_episode, label="Exploration rate")
ax.set_xlabel("Episode")
ax.set_title("DQN training progress")
ax.legend();  # trailing ';' keeps the Legend repr out of the cell output

In [15]:
# Write up to 3000 rollout observations to test.mp4.
# NOTE(review): the first rollout argument (0.0 here, 0.7 in other cells) is
# presumably the exploration rate -- confirm against DQN.rollout.
Recorder("test.mp4").write(
    dqn.rollout(0.0, frame_skips=4)
    .take(3000)
    .map(lambda transition: transition.observation.numpy(False))
)

KeyboardInterrupt: 

In [3]:
# Live view: transformed observation | autoencoder reconstruction | SHAP heatmap.
#
# NOTE(review): the recorded run of this cell stopped with
# "AttributeError: 'AutoEncoderFeedForward' object has no attribute 'output'"
# after the progress bar had already appeared, so the failing call is
# presumably `dqn._autoencoder(...).output()` inside the loop. That call is
# left unchanged here -- confirm the intended accessor against the xai API.
with Window("Asteroids", 60, 4.0) as window:
    # Background samples for the two permutation explainers (first 50 of each).
    obs_background = dqn._autoencoder.encoder(torch.load("observations.pt")).output()
    state_background = torch.from_numpy(np.load("states.npy"))

    q_explainer = dqn._policy.explainer("permutation", state_background[:50])
    obs_explainer = dqn._autoencoder.decoder.explainer("permutation", obs_background[:50])

    # The label has no trailing colon: the progress bar adds its own
    # (the previous "Frame:" rendered as "Frame::").
    for step in dqn.rollout(0.0, frame_skips=4).take(3000).monitor("Frame", expected_length=3000):
        obs = step.observation.translated().rotated()
        reconstruction = dqn._autoencoder(obs.numpy(True)).output().numpy(force=True)
        reconstruction = (reconstruction * 255).astype(np.uint8)

        q_explanation = q_explainer.explain(step.next_state).item().flatten().shap_values
        # Only the first 32 entries of row 0 of the state go to the decoder explainer.
        obs_explanation = obs_explainer.explain(step.next_state[0, :32]).item().flatten().shap_values

        # Normalise the absolute decoder attributions along axis 0. A column
        # whose attributions are all zero would give 0/0 = NaN, and that NaN
        # would spread through the matmul below into every pixel of the
        # heatmap; dividing by 1 instead leaves such a column at zero.
        magnitude = np.abs(obs_explanation)
        totals = np.sum(magnitude, axis=0)
        weights = magnitude / np.where(totals == 0, 1, totals)

        # Push the Q-value attributions through the decoder weights.
        # NOTE(review): the reshapes assume 5 x 4 x 32 Q attributions
        # (presumably actions x stacked frames x latents) and 210 x 160 x 3
        # frames -- confirm.
        shap_values = q_explanation.reshape((5, 4, 32)).sum(1) @ weights.T
        shap_values = shap_values.reshape((5, 210, 160, 3))
        shap_values = shap_values.sum((0, 3))

        # Scale so the largest |value| maps to 255; an all-zero frame would
        # otherwise divide by zero.
        peak = np.max(np.abs(shap_values))
        norm = peak if peak != 0 else 1.0

        # Positive attributions go to the red channel, negative ones to blue.
        im = np.zeros((210, 160, 3), dtype=np.uint8)
        black = np.zeros_like(shap_values)
        red = np.where(shap_values > 0, shap_values / norm, black)
        blue = np.where(shap_values < 0, -shap_values / norm, black)
        im[:, :, 0] = (red * 255).astype(np.uint8)
        im[:, :, 2] = (blue * 255).astype(np.uint8)

        window(np.hstack([obs.numpy(False), reconstruction, im]))

CUDA initialization: The NVIDIA driver on your system is too old (found version 11040). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver. (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:108.)
Frame::   0%|          | 0/3000 [00:00<?, ?it/s]


AttributeError: 'AutoEncoderFeedForward' object has no attribute 'output'

In [4]:
N_FRAMES = 240            # rollout steps to explain; also the progress-bar length
FRAME_SHAPE = (210, 160)  # height x width of every heatmap frame


def _signed_heatmap(values, norm):
    """Render a signed map as an RGB uint8 image of shape FRAME_SHAPE + (3,).

    Positive entries are written to the red channel and negative entries to
    the blue channel, each divided by ``norm`` and scaled by 255. The green
    channel stays zero.
    """
    image = np.zeros((*FRAME_SHAPE, 3), dtype=np.uint8)
    zeros = np.zeros_like(values)
    image[:, :, 0] = (np.where(values > 0, values / norm, zeros) * 255).astype(np.uint8)
    image[:, :, 2] = (np.where(values < 0, -values / norm, zeros) * 255).astype(np.uint8)
    return image


# One list of frames per output stream. Keys are "Original", "Affine", the
# integer index of each entry in shap_values, and "Sum". Insertion order is
# kept because a later cell stacks `videos.values()` side by side.
videos = {}

# Background samples for the permutation explainers (first 50 of each).
obs_background = dqn._autoencoder.encoder(torch.load("observations.pt")).output()
state_background = torch.from_numpy(np.load("states.npy"))

# expected_length now matches take(); it used to be 3000, so the bar ended at
# "240/3000". The label has no trailing colon because the bar adds its own
# ("Frame:" rendered as "Frame::").
for step in dqn.rollout(0.7, frame_skips=4).take(N_FRAMES).monitor("Frame", expected_length=N_FRAMES):
    videos.setdefault("Original", []).append(step.observation.numpy(False))
    videos.setdefault("Affine", []).append(step.observation.translated().rotated().numpy(False))

    shap_values = step.explain_eap(
        "permutation",
        decoder_background=obs_background[:50],
        q_background=state_background[:50],
    ).shap_values
    shap_sum = shap_values.sum(0)

    # One scale for all heatmaps of this step, so the per-entry maps and
    # their sum are comparable. An all-zero step would otherwise divide by 0.
    norm = max(np.max(np.abs(shap_values)), np.max(np.abs(shap_sum)))
    if norm == 0:
        norm = 1.0

    for i, action_explanation in enumerate(shap_values):
        videos.setdefault(i, []).append(_signed_heatmap(action_explanation, norm))
    videos.setdefault("Sum", []).append(_signed_heatmap(shap_sum, norm))

        

            

Frame::   8%|▊         | 240/3000 [20:59<4:01:24,  5.25s/it]


In [5]:
# File name used for each integer key of `videos`.
id_to_action = {
    0: "Noop.mp4",
    1: "Up.mp4",
    2: "Left.mp4",
    3: "Right.mp4",
    4: "Fire.mp4"
}

# (key in `videos`, output path) pairs, listed in the order the files are written.
exports = [(key, f"Videos/DQN-EAP-SHAP/{name}") for key, name in id_to_action.items()]
exports += [
    ("Original", "Videos/DQN-EAP-SHAP/Original.mp4"),
    ("Affine", "Videos/DQN-EAP-SHAP/Affine.mp4"),
    ("Sum", "Videos/DQN-EAP-SHAP/Sum.mp4"),
]

for key, path in exports:
    with Recorder(path, fps=24, scale=4) as recorder:
        for frame in videos[key]:
            recorder(frame)

# Combined video: all streams side by side, one column per entry of `videos`
# in its insertion order. Swapping the first two axes turns the
# (stream, time, ...) stack into (time, stream, ...), so each iteration
# yields the frames of every stream for one time step.
with Recorder("Videos/DQN-EAP-SHAP/Combined.mp4", fps=24, scale=3) as recorder:
    for frames in np.swapaxes(np.array(list(videos.values())), 0, 1):
        recorder(np.hstack(frames))

In [7]:
# Live preview: show up to 3000 rollout observations in a window.
with Window("Asteroids", 60, 4.0) as viewer:
    episode = dqn.rollout(0.7, frame_skips=4).take(3000)
    for _index, transition in episode.enumerate():
        viewer(transition.observation.numpy(False))

WindowClosed: 