# Notebook

In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import neat
import matplotlib.pyplot as plt

import pickle
import multimodal_mazes

In [None]:
# Load the NEAT configuration. The four classes are neat-python's stock
# genome / reproduction / species-set / stagnation implementations; their
# hyper-parameters are read from the .ini file.
config_file = '../neat_config.ini'
config = neat.Config(
    neat.DefaultGenome,
    neat.DefaultReproduction,
    neat.DefaultSpeciesSet,
    neat.DefaultStagnation,
    config_file,
)

In [None]:
# Load the saved run. `x` must expose a 'fitness' field and is indexed in the
# same order as the pickled genome list (NOTE(review): confirm full record
# layout against the script that wrote these files).
x = np.load('../Results/test.npy')

# Indices of all agents that reached the maximum fitness; the first is used.
top_agent = np.where(x['fitness'] == x['fitness'].max())
best_idx = top_agent[0][0]

# NOTE: pickle.load can execute arbitrary code — only open result files you
# produced yourself.
with open('../Results/test.pickle', 'rb') as file:
    genomes = pickle.load(file)

# Each entry unpacks as (genome_id, genome, channels).
genome_id, genome, channels = genomes[best_idx]

print(x[best_idx])
print(genome.size())  # neat-python genome complexity summary

In [None]:
# Plotting

# Fitness over generations — disabled; uncomment to inspect the training curve.
# multimodal_mazes.plot_fitness_over_generations(x, plot_species=True)

# Architecture of the top agent. Labels map NEAT node keys to readable names:
# negative keys label the inputs (Ch0/Ch1 x L/R/U/D), keys 0-3 the four actions.
node_names = {-1: 'Ch0 L', -2: 'Ch1 L', -3 : 'Ch0 R', -4 : 'Ch1 R',
              -5: 'Ch0 U', -6: 'Ch1 U', -7 : 'Ch0 D', -8 : 'Ch1 D',
              0 : 'Act L', 1 : 'Act R', 2 : 'Act U', 3 : 'Act D'}
multimodal_mazes.plot_architecture(genome, config, node_names=node_names)

# Pruned ("filtered") architecture. Stored under its own name so `genome`
# keeps the unpruned top agent and re-running this cell reproduces the same
# two figures instead of re-plotting an already-pruned network as "unpruned".
pruned_genome = multimodal_mazes.prune_architecture(genome, config)
multimodal_mazes.plot_architecture(pruned_genome, config, node_names=node_names)

# Maze used to re-evaluate agents in the next cell (2 channels, matching the
# Ch0/Ch1 inputs labelled above).
maze = multimodal_mazes.TrackMaze(size=11, n_channels=2)
maze.generate(48)  # NOTE(review): meaning of 48 (seed? length?) is not visible here — confirm and name it

In [None]:
# WIP: architecture metrics
# Re-evaluate (after pruning) every agent whose recorded fitness exceeds
# FITNESS_THRESHOLD on `maze`, and collect the architecture metrics of those
# that still pass. Loop variables use their own names so the top agent's
# `genome` / `channels` from earlier cells are not overwritten, and `_` is not
# assigned (assigning it in IPython stops the `_`/`__`/`___` output history).
FITNESS_THRESHOLD = 0.95

top_agents_metrics = []
metric_names = None  # keys of architecture_metrics(), taken from the first passing agent
for g in np.where(x['fitness'] > FITNESS_THRESHOLD)[0]:
    _genome_id, candidate_genome, candidate_channels = genomes[g]
    candidate_genome = multimodal_mazes.prune_architecture(candidate_genome, config)
    fitness = multimodal_mazes.eval_fitness(candidate_genome, config, candidate_channels, maze)
    if fitness > FITNESS_THRESHOLD:
        arch_metrics = multimodal_mazes.architecture_metrics(candidate_genome, config, candidate_channels)
        if metric_names is None:
            metric_names = list(arch_metrics.keys())
        top_agents_metrics.append(list(arch_metrics.values()))

top_agents_metrics = np.array(top_agents_metrics)

# Plotting: the last two metrics go on a separate 'Ratio' axis (0-1), all
# others on the 'Number' axis. Skipped when no agent passed: there are then no
# rows to plot and no metric names to label the ticks with (previously this
# raised NameError/IndexError or silently reused stale labels).
if metric_names is None:
    print(f'No agent scored above {FITNESS_THRESHOLD} on re-evaluation; nothing to plot.')
else:
    fig, (a0, a1) = plt.subplots(1, 2, gridspec_kw={'width_ratios': [0.8, 0.2]}, figsize=(10, 5))
    fig.suptitle(f'Architecture metrics of {len(top_agents_metrics)} agents with fitness > {FITNESS_THRESHOLD}')

    a0.plot(top_agents_metrics[:, :-2].T, c='k', linewidth=1.5, alpha=0.25)
    a0.set_xticks(np.arange(len(metric_names) - 2), metric_names[:-2], rotation='vertical')
    a0.set_ylabel('Number')

    a1.plot(top_agents_metrics[:, -2:].T, c='k', linewidth=1.5, alpha=0.25)
    a1.set_xticks([0, 1], metric_names[-2:], rotation='vertical')
    a1.set_ylabel('Ratio')
    a1.set_xlim([-0.25, 1.25])
    a1.set_ylim([-0.025, 1.025])
    plt.show()