# Prey notebook

In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import neat
import matplotlib.pyplot as plt

import pickle
import multimodal_mazes
from tqdm import tqdm

## Ideas 

* Odour - constant but noisy. 
* Sound - reliable but infrequent.
* Rather than resetting the env to zero every step, you could decay it, e.g. `env[:, :, :-1] *= 0.8` followed by a blur.
* Analysis: number of prey caught, speed, costs (e.g. movement vs food).
* Evolve prey against different predator algorithms, then evolve predators against these evolved prey.

## Prey trial

In [None]:
# Hyperparameters
size = 11  # side length of the square arena (cells)
n_channels = 2  # number of cue channels (prey emit on channel prey.cues)
n_prey = 10
n_steps = 50
n_trials = 1000
pk = 5  # the width of the prey's Gaussian signal (in rc); must be odd

# The emission window r - pk//2 : r + pk//2 + 1 is exactly pk cells wide only
# when pk is odd; an even pk would fail later with a broadcast error.
if pk < 1 or pk % 2 == 0:
    raise ValueError(f"pk must be a positive odd integer, got {pk}")

# Prey signal: separable 2D Gaussian kernel (peak value 1 at the centre).
# scipy.signal.gaussian was deprecated and removed in SciPy 1.13;
# the window lives in scipy.signal.windows.
from scipy import signal

k1d = signal.windows.gaussian(pk, std=1)
k2d = np.outer(k1d, k1d)

In [None]:
def _catch_prey(preys, location):
    """Mark every free prey at ``location`` as caught and return how many were caught.

    A prey is free when ``prey.state == 1`` and caught when ``prey.state == 0``.
    """
    n_caught = 0
    for prey in preys:
        if prey.state == 1 and np.array_equal(prey.location, location):
            prey.state = 0
            n_caught += 1
    return n_caught


half = pk // 2  # half-width of the prey signal kernel; also the wall thickness

results = []  # one row per trial: final state of each prey (1 free, 0 caught)
for trial in range(n_trials):

    # Create environment: the last channel is the track mask (1. track, 0. walls).
    # Walls are `half` cells thick so a prey's kernel always fits inside the arena.
    # Slicing to `size - half` (rather than `-half`) keeps the track non-empty when half == 0.
    env = np.zeros((size, size, n_channels + 1))
    env[half:size - half, half:size - half, -1] = 1

    # Define an agent (the predator), starting at the centre of the arena
    agnt = multimodal_mazes.AgentRuleBased(location=[size // 2, size // 2], channels=[1, 1], policy='Linear fusion')
    agnt.sensor_noise_scale = 0.2

    # Define prey at distinct, randomly chosen track cells
    rcs = np.argwhere(env[:, :, -1])
    prey_idx = np.random.choice(len(rcs), size=n_prey, replace=False)
    preys = []
    for i, idx in enumerate(prey_idx):
        prey = multimodal_mazes.AgentRandom(location=rcs[idx], channels=[0, 0])
        prey.state = 1  # free (1) or caught (0)
        prey.cues = i % 2  # channel for emitting cues
        preys.append(prey)
    # NOTE(review): prey never sense/act in this loop, so they stay where they were placed.

    # Sensation-action loop
    path = []
    # Prey placed on the predator's start cell are caught immediately.
    prey_counter = n_prey - _catch_prey(preys, agnt.location)
    for step in range(n_steps):

        # If all prey have been caught
        if prey_counter == 0:
            break

        # Reset the cue channels, then let each free prey emit its Gaussian cue
        env[:, :, :-1] = 0.0
        for prey in preys:
            if prey.state == 1:
                r, c = prey.location
                env[r - half:r + half + 1, c - half:c + half + 1, prey.cues] += k2d

        # Confine cues to the track (broadcast the mask over all cue channels)
        env[:, :, :-1] *= env[:, :, -1:]

        # Predator
        agnt.sense(env)
        agnt.policy()
        agnt.act(env)
        path.append(list(agnt.location))

        # Check for catches after every move, so the final move also counts
        prey_counter -= _catch_prey(preys, agnt.location)

    results.append([prey.state for prey in preys])

results = np.array(results)
# Fraction of prey still free (state 1) over all trials; 1 minus this is the catch rate.
print(results.mean())

In [None]:
# Plotting: the track, the predator's path and the prey from the last trial run above
# (uses `env`, `path`, `preys`, `n_steps`, `pk` and `size` from earlier cells).
from matplotlib import colors

path = np.array(path)
prey_markers = ['P', 'X']  # one marker per cue channel (indexed by prey.cues)

# Walls drawn dark, track light
plt.imshow(1 - env[:, :, -1], cmap="binary", alpha=0.25)

# Plot path, coloured by time step (teal -> off white -> coral)
cmap = colors.LinearSegmentedColormap.from_list(
    "", ["xkcd:teal blue", "xkcd:off white", "xkcd:coral"], N=n_steps
)
for t, (start, end) in enumerate(zip(path[:-1], path[1:])):
    # Locations are (row, col), so col is x and row is y
    plt.plot([start[1], end[1]], [start[0], end[0]], c=cmap(t), zorder=0)
    plt.scatter(end[1], end[0], s=30, color=cmap(t), zorder=1)

plt.axis("off")

# Add prey
for prey in preys:
    plt.scatter(prey.location[1], prey.location[0], color='k', alpha=0.5, marker=prey_markers[prey.cues], zorder=2)

# Adjust axes to the track plus a one-cell wall border.
# imshow draws row 0 at the top (inverted y axis); ylim takes (bottom, top),
# so pass the larger value first to keep that orientation instead of flipping the image.
ax_lo, ax_hi = (pk // 2) - 1, size - pk // 2
plt.xlim([ax_lo, ax_hi])
plt.ylim([ax_hi, ax_lo])