# DQN prey notebook

In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import neat
import matplotlib.pyplot as plt

from matplotlib import cm

import pickle
import multimodal_mazes
from tqdm import tqdm

import itertools
import torch
import torch.nn as nn
import torch.optim as optim

## Ideas

## Fitness vs noise

In [None]:
# Hyperparameters

# Task: experiment settings read from the prey config file
exp_config = multimodal_mazes.load_prey_config("../prey_config.ini")

# Agent: hidden-layer width, and seven flags (all zero) that are later
# forwarded to AgentDQN as its wm_flags argument
n_hidden_units = 8
wm_flag = np.array([0] * 7)

In [None]:
# Build the DQN agent from the task config and the hyperparameters above.
# No location is assigned at construction time (location=None).
agnt = multimodal_mazes.AgentDQN(
    location=None,
    channels=exp_config["channels"],
    sensor_noise_scale=exp_config["sensor_noise_scale"],
    n_hidden_units=n_hidden_units,
    wm_flags=wm_flag,
)

# Training call, currently disabled:
# agnt.generate_predator_policy(n_train_trials=100, n_test_trials=None, exp_config=exp_config)

In [None]:
# Evaluate the agent over 1000 trials. Every task setting is taken directly
# from exp_config under the same name as the keyword it is passed to.
fitness, _, _, _ = multimodal_mazes.eval_predator_fitness(
    n_trials=1000,
    agnt=agnt,
    sensor_noise_scale=agnt.sensor_noise_scale,
    **{
        key: exp_config[key]
        for key in (
            "size",
            "n_prey",
            "pk",
            "n_steps",
            "scenario",
            "motion",
            "pc",
            "pm",
            "pe",
        )
    },
)

print(fitness)

In [None]:
# Run a single trial with the same task settings and keep the first two
# returned values (bound to `time` and `path`); the rest are discarded.
time, path, _, _, _ = multimodal_mazes.predator_trial(
    agnt=agnt,
    sensor_noise_scale=agnt.sensor_noise_scale,
    **{
        key: exp_config[key]
        for key in (
            "size",
            "n_prey",
            "pk",
            "n_steps",
            "scenario",
            "motion",
            "pc",
            "pm",
            "pe",
        )
    },
)

In [None]:
def generate_predator_policy(self, n_train_trials, n_test_trials, exp_config):
    """
    Uses deep Q-learning to optimise model weights.

    Each training trial simulates one episode of the prey task inline,
    accumulates the MSE between predicted and target Q-values over the
    steps of that episode, and then applies a single clipped Adam update.

    Arguments:
        n_train_trials: number of training trials (episodes) to run.
        n_test_trials: number of trials passed to
            multimodal_mazes.eval_predator_fitness whenever fitness is
            recorded, or None to skip recording. Fitness is recorded every
            max(1, n_train_trials // 100) trials, i.e. roughly 100 times
            throughout training.
        exp_config: mapping with the task settings read here: "size",
            "n_prey", "pk", "n_steps", "scenario", "motion", "pc", "pm"
            and "pe".
    Updates:
        self.parameters.
        self.training_fitness (if n_test_trials is provided).
        self.gradient_norms (reset to an empty list; the code that would
            fill it is commented out below).
        self.location and self.outputs (overwritten during simulation).
    """
    # Local import: the kernel below needs scipy.signal, which this file
    # does not import at the top level.
    from scipy import signal

    optimizer = optim.Adam(self.parameters(), lr=0.001)
    criterion = nn.MSELoss()
    gamma = 0.9

    # Exploration schedule: 10 equal-length plateaus from 0.95 down to 0.25.
    # Indexing the plateau per trial (instead of materialising one epsilon
    # per trial) stays valid when n_train_trials is not a multiple of 10.
    epsilon_levels = np.linspace(start=0.95, stop=0.25, num=10)
    trials_per_level = max(1, n_train_trials // 10)

    # Guard against a zero modulus when n_train_trials < 100.
    record_every = max(1, n_train_trials // 100)

    self.gradient_norms = []
    self.training_fitness = []

    # Loop-invariant prey signal: half width (in rc) and 2D Gaussian kernel.
    pk_hw = exp_config["pk"] // 2
    k1d = signal.windows.gaussian(exp_config["pk"], std=1)
    k2d = np.outer(k1d, k1d)

    for a in range(n_train_trials):
        epsilon = epsilon_levels[
            min(a // trials_per_level, len(epsilon_levels) - 1)
        ]

        # Record fitness
        if n_test_trials is not None and a % record_every == 0:
            with torch.no_grad():
                fitness, _, _, _ = multimodal_mazes.eval_predator_fitness(
                    n_trials=n_test_trials,
                    size=exp_config["size"],
                    agnt=self,
                    sensor_noise_scale=self.sensor_noise_scale,
                    n_prey=exp_config["n_prey"],
                    pk=exp_config["pk"],
                    n_steps=exp_config["n_steps"],
                    scenario=exp_config["scenario"],
                    motion=exp_config["motion"],
                    pc=exp_config["pc"],
                    pm=exp_config["pm"],
                    pe=exp_config["pe"],
                )
                self.training_fitness.append(fitness)
                print(fitness)

        # Reset agent
        prev_input = torch.zeros(self.n_input_units)
        hidden = torch.zeros(self.n_hidden_units)
        prev_output = torch.zeros(self.n_output_units)

        # Start at the centre of the (padded) environment
        self.location = np.array(
            [pk_hw + (exp_config["size"] // 2), pk_hw + (exp_config["size"] // 2)]
        )
        self.outputs = torch.zeros(self.n_output_units)

        loss = 0.0

        # Create environment with track (1.) and walls (0.)
        env = np.zeros(
            (exp_config["size"], exp_config["size"], len(self.channels) + 1)
        )
        env[:, :, -1] = 1.0
        env = np.pad(env, pad_width=((pk_hw, pk_hw), (pk_hw, pk_hw), (0, 0)))

        # Per-trial copy of the kernel; it is shuffled in place to emit noise
        k2d_noise = np.copy(k2d)

        # Define prey at distinct random track cells
        rcs = np.stack(np.argwhere(env[:, :, -1]))
        prey_rcs = np.random.choice(
            range(len(rcs)), size=exp_config["n_prey"], replace=False
        )
        preys = []
        for n in range(exp_config["n_prey"]):
            preys.append(
                multimodal_mazes.AgentRandom(
                    location=rcs[prey_rcs[n]],
                    channels=[0, 0],
                    motion=exp_config["motion"],
                )
            )
            preys[n].state = 1  # free (1) or caught (0)
            preys[n].path = [list(preys[n].location)]

            if exp_config["scenario"] == "Foraging":
                preys[n].cues = n % 2  # channel for emitting cues

        # Trial
        prey_counter = np.copy(exp_config["n_prey"])
        for _step in range(exp_config["n_steps"]):

            env[:, :, :-1] *= exp_config["pc"]  # reset channels

            # Prey
            for prey in preys:
                if prey.state == 1:
                    if (prey.location == self.location).all():  # caught
                        prey.state = 0
                        prey_counter -= 1

                    else:  # free
                        # Movement
                        if exp_config["scenario"] == "Hunting":
                            if np.random.rand() < exp_config["pm"]:
                                prey.policy()
                                prey.act(env)
                        prey.path.append(list(prey.location))

                        # Emit cues
                        r, c = prey.location
                        if exp_config["scenario"] == "Foraging":
                            env[
                                r - pk_hw : r + pk_hw + 1,
                                c - pk_hw : c + pk_hw + 1,
                                prey.cues,
                            ] += np.copy(k2d)

                        elif exp_config["scenario"] == "Hunting":
                            for ch in range(len(prey.channels)):
                                if np.random.rand() < exp_config["pe"]:
                                    env[
                                        r - pk_hw : r + pk_hw + 1,
                                        c - pk_hw : c + pk_hw + 1,
                                        ch,
                                    ] += np.copy(k2d)  # emit cues
                                else:
                                    np.random.shuffle(k2d_noise.reshape(-1))
                                    env[
                                        r - pk_hw : r + pk_hw + 1,
                                        c - pk_hw : c + pk_hw + 1,
                                        ch,
                                    ] += np.copy(k2d_noise)  # emit noise

            # Apply edges
            for ch in range(len(self.channels)):
                env[:, :, ch] *= env[:, :, -1]

            # If all prey have been caught
            if prey_counter == 0:
                break

            # Predator
            # Sense
            self.sense(env)

            # Epsilon-greedy action selection
            if torch.rand(1) < epsilon:
                action = torch.randint(
                    low=0, high=self.n_output_units, size=(1,)
                ).item()
            else:
                with torch.no_grad():
                    q_values, _, _, _ = self.forward(
                        prev_input, hidden, prev_output
                    )
                    action = torch.argmax(q_values).item()

            # Predicted Q-value (state is detached so gradients do not flow
            # back through earlier steps)
            q_values, prev_input, hidden, prev_output = self.forward(
                prev_input.detach(), hidden.detach(), prev_output.detach()
            )
            predicted = q_values[action]

            # Act
            self.outputs *= 0.0
            self.outputs[action] = 1.0
            self.act(env)

            # Reward: mean of the channel inputs from the last sense() call
            reward = np.sum(self.channel_inputs) / self.channel_inputs.size

            # Target Q-value (includes a constant -0.1 per-step penalty)
            self.sense(env)
            with torch.no_grad():
                next_q_values, _, _, _ = self.forward(
                    prev_input, hidden, prev_output
                )
                target = reward + (gamma * torch.max(next_q_values)) - 0.1

            # Loss
            loss = loss + criterion(predicted, target)

        # The trial can end before any step contributes to the loss (e.g. a
        # prey spawned on the predator's start cell, or n_steps == 0). Then
        # loss is still the float 0.0 and there is nothing to backpropagate.
        if not torch.is_tensor(loss):
            continue

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()

        # Clip gradients
        torch.nn.utils.clip_grad_norm_(self.parameters(), 10)

        # Check for exploding gradients
        # with torch.no_grad():
        #     total_norm = 0
        #     for p in self.parameters():
        #         param_norm = p.grad.data.norm(2)
        #         total_norm += param_norm.item() ** 2
        #     total_norm = total_norm ** (1.0 / 2)
        #     self.gradient_norms.append(total_norm)

        optimizer.step()