In [1]:
%matplotlib inline

import os
import numpy as np

import gym
import neat
import cv2

import utils.PyPlotReporter
import utils.PyPlotReporter
import utils.CSVReporter
from utils.EnvEvaluator import EnvEvaluator
from utils.Atari import CONTROLLER_TO_ACTION,CONTROLLER_TO_ACTION_SHORT, CONTROLLER_TO_ACTION_FORCE

In [2]:
# Target size of the downsampled colour masks, handed to cv2.resize as `dsize`
# (OpenCV orders dsize as (width, height)).
resize = (16, 21)

class CartPoleEnvEvaluator(EnvEvaluator):
    """NEAT evaluator that picks actions from two colour masks of each frame.

    NOTE(review): despite its name, this class is instantiated for
    "Riverraid-v0" in this notebook, not CartPole; the name is kept so existing
    references keep working.
    """

    def make_net(self, genome, config):
        """Build a NEAT feed-forward network for ``genome`` under ``config``."""
        network = neat.nn.FeedForwardNetwork.create(genome, config)
        return network

    def activate_net(self, net, observation):
        """Map one frame to an action using ``net``.

        Two boolean masks are extracted by exact RGB match: "grass" (either of
        two shades) and "ship". Each mask is cast to float32, area-downsampled
        to the module-level ``resize`` with cv2, and flattened. The two vectors
        are concatenated (grass first) and fed to the network. Every network
        output is thresholded at 0.5, and the resulting boolean tuple is looked
        up in ``CONTROLLER_TO_ACTION_FORCE``.

        NOTE(review): assumes ``observation`` is an (H, W, 3) RGB array, as
        implied by the per-pixel compare on the last axis. TODO confirm.
        """
        def colour_mask(rgb):
            # True wherever a pixel matches `rgb` exactly on every channel.
            return np.all(observation == rgb, axis=-1)

        grass_mask = colour_mask([110, 156, 66]) | colour_mask([53, 95, 24])
        ship_mask = colour_mask([232, 232, 74])

        # INTER_AREA averages each cell, so the downsampled value is the
        # fraction of matching pixels rather than a single sampled pixel.
        features = [
            cv2.resize(mask.astype(np.float32), dsize=resize,
                       interpolation=cv2.INTER_AREA).flatten()
            for mask in (grass_mask, ship_mask)
        ]

        outputs = np.array(net.activate(np.concatenate(features)))
        controller_state = tuple(outputs > 0.5)
        return CONTROLLER_TO_ACTION_FORCE[controller_state]

# Evaluator for Riverraid with 8 workers, a single batch, and two fixed seeds.
# NOTE(review): the positional 5000 is interpreted by EnvEvaluator (presumably
# a per-episode step cap). TODO confirm.
evaluator = CartPoleEnvEvaluator(
    "Riverraid-v0",
    5000,
    n_workers=8,
    n_batches=1,
    seed=[1111, 2222],
)

In [9]:

# NEAT hyper-parameters are read from an external config file.
config_path = "./configs/neatatari.cfg"

config = neat.Config(
    neat.DefaultGenome,
    neat.DefaultReproduction,
    neat.DefaultSpeciesSet,
    neat.DefaultStagnation,
    config_path,
)

# Every reporter below writes under ./saves. Create it up front so a fresh
# checkout does not crash on the first CSV row or checkpoint.
os.makedirs("saves", exist_ok=True)

pop = neat.Population(config)
pop.add_reporter(utils.CSVReporter.CSVReporter("saves/neat_prog.csv"))
# Checkpoint every generation, with no time-based interval, using the prefix saves/cp_neat_rr.
pop.add_reporter(neat.Checkpointer(1, None, "saves/cp_neat_rr"))
pop.add_reporter(utils.PyPlotReporter.PyPlotReporter())



In [None]:
# Short run of up to 10 generations. pop.run returns the best genome found;
# `gnome` (sic) is reused by the visualization cells below.
gnome = pop.run(
    evaluator.eval_all_genomes,
    n=10,
)

In [None]:
from IPython.display import clear_output
import pickle
import gzip

location = "saves2/"
# The CSV log, video, and gzip checkpoint are all written into `location`.
# Make sure it exists so a fresh checkout does not fail at gzip.open.
os.makedirs(location, exist_ok=True)

# Five independent NEAT runs, each from a fresh population. Each run logs its
# learning curve, renders a video of its best genome, and pickles (genome, config).
for i in range(5):
    # Replace the previous run's output instead of appending to it.
    clear_output(wait=True)
    pop = neat.Population(config)
    # os.path.join avoids the doubled slash that `location + "/lc_..."` produced.
    pop.add_reporter(utils.CSVReporter.CSVReporter(os.path.join(location, f"lc_{i}.csv")))

    try:
        # NOTE(review): 100000 generations presumably relies on the config's
        # fitness criterion to stop the run earlier. TODO confirm.
        gnome = pop.run(evaluator.eval_all_genomes, 100000)
    except neat.CompleteExtinctionException:
        # Every species went extinct: keep the best genome seen so far.
        gnome = pop.best_genome

    # A bare `create_video` is not defined or imported anywhere in this
    # notebook. Call it through the evaluator, as the visualization cell does.
    evaluator.create_video(
        evaluator,
        gnome,
        config,
        fps=60,
        fout=os.path.join(location, f"vis_{int(gnome.fitness)}_{i}.mp4"),
    )
    with gzip.open(os.path.join(location, f"cp_{i}.cp"), "wb", compresslevel=5) as f:
        pickle.dump((gnome, config), f, protocol=pickle.HIGHEST_PROTOCOL)
    

## Visualization

In [15]:
# Replay the current best genome deterministically with a short per-step delay.
evaluator.show(
    gnome,
    config,
    delay=0.005,
    random=False,
    i_seed=0,  # presumably an index into the evaluator's seed list; TODO confirm
)

In [8]:
# Render the current best genome to a video (presumably written to f.mp4).
evaluator.create_video(
    evaluator,
    gnome,
    config,
    fps=99,
    fout="f.mp4",
)