# DQN notebook

In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import neat
import matplotlib.pyplot as plt

import pickle
import multimodal_mazes
from tqdm import tqdm

import itertools
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
# Enumerate every 0/1 pattern of length 7 (2**7 = 128 rows) and take the
# first row, which is the all-zeros pattern, as this agent's wm_flags.
flag_patterns = np.array(list(itertools.product([0, 1], repeat=7)))
wm_flags = flag_patterns[0]

# The agent is bound to the module-level name `self`; the cells below
# refer to it by that name.
agent_kwargs = dict(
    location=[5, 5],
    channels=[1, 1],
    sensor_noise_scale=0.05,
    n_hidden_units=4,
    wm_flags=wm_flags,
)
self = multimodal_mazes.AgentDQN(**agent_kwargs)

# Build the maze set: 20000 noise-free, gap-free track mazes of size 11
# with 2 channels.
maze = multimodal_mazes.TrackMaze(size=11, n_channels=2)
maze.generate(number=20000, noise_scale=0.0, gaps=0)

In [3]:
# Run the agent's generate_policy routine over the maze set with n_steps=10.
# NOTE(review): presumably this is the DQN training pass - confirm against
# AgentDQN.generate_policy.
self.generate_policy(maze, n_steps=10)

In [None]:
# Roll the current policy out on a single maze and print the agent's
# location after every step. Nothing here calls backward(), so this cell
# only evaluates the network; it does not train it.
epsilon = 0.0  # exploration probability; torch.rand is in [0, 1), so 0.0 is purely greedy
n = -1  # index of the maze to run on (-1 selects the last one)
n_steps = 10  # maximum number of steps before the trial is cut off

# Reset agent
prev_input = torch.zeros(self.n_input_units)
hidden = torch.zeros(self.n_hidden_units)
prev_output = torch.zeros(self.n_output_units)

self.location = np.copy(maze.start_locations[n])
self.outputs = torch.zeros(self.n_output_units)

# Trial
for step in range(n_steps):
    # Sense
    self.sense(maze.mazes[n])

    # Draw the exploration coin before the forward pass, matching the
    # original order of random draws.
    explore = torch.rand(1) < epsilon

    # One forward pass per step, without autograd tracking. The same
    # call both scores the actions and advances the recurrent state.
    # NOTE(review): assumes self.forward is deterministic for fixed
    # arguments and sensor readings, so a single call is equivalent to
    # the previous select-then-recompute pair - confirm.
    with torch.no_grad():
        q_values, prev_input, hidden, prev_output = self.forward(
            prev_input, hidden, prev_output
        )

    # Epsilon-greedy action selection
    if explore:
        action = torch.randint(
            low=0, high=self.n_output_units, size=(1,)
        ).item()
    else:
        action = torch.argmax(q_values).item()

    # Q-value of the chosen action; not used further in this cell, kept
    # so it can be inspected after the cell has run.
    predicted = q_values[action]

    # Act: one-hot encode the chosen action in self.outputs, then move
    self.outputs *= 0.0
    self.outputs[action] = 1.0
    self.act(maze.mazes[n])

    print(self.location)

    # Stop as soon as the agent reaches the goal
    if np.array_equal(self.location, maze.goal_locations[n]):
        break

In [None]:
# Heat map of the input-to-hidden weight matrix with a colour scale.
# detach() drops the autograd link so matplotlib can convert the tensor.
fig, ax = plt.subplots()
weights = self.input_to_hidden.weight.detach()
image = ax.imshow(weights)
fig.colorbar(image, ax=ax)