In [4]:
import os,sys

sys.path.append(os.path.join(sys.path[0], ".."))

from environment.GridEnv import GridWorld
from torchsummary import summary

import numpy as np
import os
import datetime
from collections import defaultdict
from itertools import chain
from typing import Dict, DefaultDict, List, Optional
from deap import creator, base, cma, tools
from agents.cma_model import Model

In [5]:
def eaGenerateUpdate(toolbox, ngen, halloffame=None, stats=None, save_path=None,
                     verbose=__debug__):
    """Run an ask-and-tell (generate / evaluate / update) evolutionary loop.

    Variant of ``deap.algorithms.eaGenerateUpdate`` that additionally writes
    the current best individual to ``save_path`` after every generation.

    Args:
        toolbox: Toolbox providing ``generate``, ``evaluate``, ``map`` and
            ``update`` (e.g. bound to a ``deap.cma.Strategy``).
        ngen: Number of generations to run.
        halloffame: Optional hall of fame, updated with every population.
        stats: Optional ``tools.Statistics`` compiled on every population.
        save_path: File the hall-of-fame leader is written to with
            ``np.save``. Nothing is written when it is ``None`` or when no
            hall of fame is given.
        verbose: Print the logbook stream after each generation.

    Returns:
        ``(population, logbook)``: the last evaluated population (``[]`` when
        ``ngen`` is 0) and the per-generation logbook.
    """
    logbook = tools.Logbook()
    logbook.header = ['gen', 'nevals'] + (stats.fields if stats else [])

    # Bound up-front so that ngen == 0 returns an empty population instead of
    # raising UnboundLocalError at the return statement.
    population = []
    for gen in range(ngen):
        # Generate a new population
        population = toolbox.generate()
        # Evaluate the individuals
        fitnesses = toolbox.map(toolbox.evaluate, population)
        for ind, fit in zip(population, fitnesses):
            ind.fitness.values = fit

        if halloffame is not None:
            halloffame.update(population)
            # np.save(None, ...) would raise, so only checkpoint when a path was given.
            if save_path is not None:
                # Written every generation so an interrupted run keeps its best-so-far weights.
                best_weights: np.ndarray = np.array(halloffame.items[0])
                np.save(save_path, best_weights)

        # Update the strategy with the evaluated individuals
        toolbox.update(population)

        record = stats.compile(population) if stats is not None else {}
        logbook.record(gen=gen, nevals=len(population), **record)
        if verbose:
            print(logbook.stream)

    return population, logbook

In [None]:

# --- Training configuration ---------------------------------------------------
NUM_PLAYERS: int = 1
# CMA-ES offspring per generation (lambda_), scaled by the number of players.
LAMBDA: int = NUM_PLAYERS * 150
# Number of CMA-ES generations to run.
NGEN: int = 15000

# Each run gets its own timestamped directory; the best weights are checkpointed there.
RUN_DIR: str = os.path.join('.', 'CMA', datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
SAVE_PATH: str = os.path.join(RUN_DIR, 'weights.npy')
# exist_ok=False: fail loudly rather than write into an existing run directory.
os.makedirs(RUN_DIR, exist_ok=False)

def env_creator():
    """Construct and return a new ``GridWorld`` environment (argument ``8``).

    Each call yields a separate instance; ``evaluate`` builds one per episode.
    """
    grid_world = GridWorld(8)
    return grid_world


def evaluate(individuals) -> List:
    """Score one candidate weight vector by playing a single greedy episode.

    Despite its plural name, ``individuals`` is a single flat parameter vector
    produced by the CMA strategy. It is loaded into a policy via
    ``Model.from_weights``, and at every step the argmax action of the policy
    output is taken until the environment reports ``done``.

    Returns:
        ``[(total_reward,)]``: a one-element list wrapping the DEAP fitness
        tuple, so ``custom_map_func`` can flatten results across individuals.
    """
    # NOTE(review): the environment is never closed (env.close() was disabled upstream).
    env = env_creator()
    obs = env.reset()
    policy = Model.from_weights(env.observation_space, env.action_space.n, individuals)

    total = 0
    finished = False
    while not finished:
        obs, step_reward, finished, _ = env.step(np.argmax(policy(obs)))
        total += step_reward
    return [(total,)]

def custom_map_func(evaluate_func, population):
    """Evaluate every individual and flatten the per-individual result lists.

    ``evaluate_func`` returns a list of fitness tuples for each individual.
    These are concatenated lazily into one iterator, which the generate/update
    loop zips against the population.
    """
    per_individual = (evaluate_func(member) for member in population)
    return chain.from_iterable(per_individual)


def train(seed: Optional[int] = None):
    """Optimise the weights of a ``Model`` policy on GridWorld with CMA-ES.

    Sizes the search space from a throwaway ``Model``, registers the DEAP
    fitness/individual classes and builds a CMA-ES strategy. The strategy is
    centred at 5.0 in every dimension, with sigma 5.0 and ``LAMBDA`` offspring
    per generation. It then runs ``eaGenerateUpdate`` for ``NGEN`` generations,
    checkpointing the best individual to ``SAVE_PATH``.

    Args:
        seed: Optional seed for NumPy's global RNG, which DEAP's CMA strategy
            samples offspring from. ``None`` (default) leaves the RNG untouched.

    Returns:
        ``(population, logbook, hall_of_fame)`` of the finished run.
    """
    if seed is not None:
        np.random.seed(seed)

    env = env_creator()
    observation_space = env.observation_space
    act_space = env.action_space
    # Only needed to find out how many parameters the policy has.
    temp_model: Model = Model(observation_space, act_space.n)
    num_parameters = temp_model.num_parameters
    del temp_model

    # creator.create registers classes module-wide; re-running this cell would
    # otherwise overwrite them (DEAP emits a RuntimeWarning when that happens).
    if not hasattr(creator, "FitnessMax"):
        creator.create("FitnessMax", base.Fitness, weights=(1.0,))
    if not hasattr(creator, "Individual"):
        creator.create("Individual", list, fitness=creator.FitnessMax)

    toolbox = base.Toolbox()
    strategy = cma.Strategy(centroid=[5.0] * num_parameters, sigma=5.0, lambda_=LAMBDA)
    toolbox.register("generate", strategy.generate, creator.Individual)
    toolbox.register("update", strategy.update)
    toolbox.register("evaluate", evaluate)
    # evaluate returns a list per individual; custom_map_func flattens them.
    toolbox.register("map", custom_map_func)

    hof = tools.HallOfFame(1)
    stats = tools.Statistics(lambda ind: ind.fitness.values)
    stats.register("avg", np.mean)
    stats.register("std", np.std)
    stats.register("min", np.min)
    stats.register("max", np.max)

    population, logbook = eaGenerateUpdate(toolbox, ngen=NGEN, stats=stats,
                                           halloffame=hof, save_path=SAVE_PATH)
    return population, logbook, hof


if __name__ == "__main__":
    # A Jupyter kernel executes cells with __name__ == "__main__", so running
    # this cell starts training immediately.
    train()

gen	nevals	avg	std	min	max
0  	150   	-50	0  	-50	-50
1  	150   	-50	0  	-50	-50
2  	150   	-50	0  	-50	-50
