# Notebook

In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import neat
import matplotlib.pyplot as plt

import pickle
import multimodal_mazes

In [None]:
# Load the NEAT configuration file; every component (genome, reproduction,
# species set, stagnation) uses neat-python's default implementation.
config_file = '../neat_config.ini'
config = neat.Config(
    genome_type=neat.DefaultGenome,
    reproduction_type=neat.DefaultReproduction,
    species_set_type=neat.DefaultSpeciesSet,
    stagnation_type=neat.DefaultStagnation,
    filename=config_file,
)

In [None]:
# Load the saved results array (indexed by a 'fitness' field below) and the
# pickled object that unpacks into (genome_id, genome, channels).
x = np.load('../Results/test.npy')

# Index tuple of every record tied for the highest fitness. Single-argument
# np.nonzero returns the same tuple of index arrays as np.where(cond).
top_agent = np.nonzero(x['fitness'] == np.max(x['fitness']))

# NOTE(review): pickle.load can execute arbitrary code - only open result
# files produced by this project.
with open('../Results/test.pickle', mode='rb') as file:
    (genome_id, genome, channels) = pickle.load(file)

print(x[top_agent])
print(genome.size())


In [None]:
# Plotting

# Fitness over generations, split by species.
multimodal_mazes.plot_fitness_over_generations(x, plot_species=True)

# Network architecture. Negative node ids are labelled as sensor channels
# (two channels per direction, ordered L, R, U, D: -1/-2 -> L, -3/-4 -> R, ...);
# ids 0-3 are labelled as the four actions in the same direction order.
node_names = {
    -(2 * side_idx + ch + 1): f'Ch{ch} {side}'
    for side_idx, side in enumerate('LRUD')
    for ch in range(2)
}
node_names.update({out_idx: f'Act {side}' for out_idx, side in enumerate('LRUD')})
multimodal_mazes.plot_architecture(genome, config, node_names=node_names)

# Build the maze set evaluated in the next cell (48 is presumably the number
# of mazes to generate - confirm against TrackMaze.generate).
maze = multimodal_mazes.TrackMaze(size=11, n_channels=2)
maze.generate(48)

In [None]:
n_steps = 10
plot_maze_n = 15  # index of the maze whose path is plotted

# The time score below divides by (n_steps - 1 - fastest solution). If any
# maze's fastest solution needs n_steps - 1 steps or more, that denominator is
# <= 0 and the score silently becomes nan/inf or flips sign - fail loudly instead.
longest_fastest = max(maze.fastest_solutions, default=0)
if n_steps - 1 <= longest_fastest:
    raise ValueError(
        f'n_steps={n_steps} is too small: some maze needs {longest_fastest} '
        f'steps, so n_steps must be at least {longest_fastest + 2}'
    )

times, paths = [], []
# For each maze
for mz_n, mz in enumerate(maze.mazes):
    # Run the genome on this maze for at most n_steps
    time, path = multimodal_mazes.maze_trial(
        mz,
        maze.start_locations[mz_n],
        maze.goal_locations[mz_n],
        channels,
        n_steps,
        agnt=None,
        genome=genome,
        config=config,
    )

    if mz_n == plot_maze_n:
        multimodal_mazes.plot_path(path, mz, maze.goal_locations[mz_n], n_steps)

    # Time score: 1 when the trial time equals the fastest solution, falling
    # linearly to 0 when the time is n_steps - 1.
    fastest = maze.fastest_solutions[mz_n]
    times.append(1 - (time - fastest) / (n_steps - 1 - fastest))

    # Distance score: 1 - d_map[final position] / d_map.max(), using the last
    # position of the path (d_map presumably holds distance-to-goal - confirm).
    d_map = maze.d_maps[mz_n]
    paths.append((d_map.max() - d_map[path[-1][0], path[-1][1]]) / d_map.max())

# Per-maze fitness: the mean of the time and distance scores.
fitness = (np.array(times) + np.array(paths)) * 0.5

print(fitness)
print(fitness.mean())