In [1]:
from GridEnv_cma_modified import GridWorld
#from custom_eaGenerateUpdate import eaGenerateUpdate
from torchsummary import summary

import numpy as np
import os
import datetime
from collections import defaultdict
from itertools import chain
from typing import Dict, DefaultDict, List, Optional
from deap import creator, base, cma, tools
#from deap.algorithms import eaGenerateUpdate
from cma_model import Model

pygame 2.1.2 (SDL 2.0.16, Python 3.8.10)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
def eaGenerateUpdate(toolbox, ngen, halloffame=None, stats=None, save_path=None,
                     verbose=__debug__):
    """Run a generate -> evaluate -> update (ask/tell) evolutionary loop.

    Local copy of DEAP's ``eaGenerateUpdate`` that additionally checkpoints the
    hall-of-fame leader to ``save_path`` with ``np.save`` after every generation.

    :param toolbox: object providing ``generate()``, ``evaluate``, ``map`` and
        ``update(population)``.
    :param ngen: number of generations to run.
    :param halloffame: optional hall of fame, updated with every population.
    :param stats: optional statistics object compiled over each population.
    :param save_path: file passed to ``np.save`` for the best individual so far;
        checkpointing is skipped when this (or ``halloffame``) is ``None``.
    :param verbose: if true, print the logbook stream after each generation.
    :returns: ``(population, logbook)`` -- the last generated population (an
        empty list when ``ngen`` is 0) and the per-generation logbook.
    """
    logbook = tools.Logbook()
    logbook.header = ['gen', 'nevals'] + (stats.fields if stats is not None else [])

    # Bound up front so ngen == 0 returns an empty population instead of
    # raising UnboundLocalError at the return statement.
    population = []
    for gen in range(ngen):
        # Generate a new population (ask).
        population = toolbox.generate()
        # Evaluate the individuals.
        fitnesses = toolbox.map(toolbox.evaluate, population)
        for ind, fit in zip(population, fitnesses):
            ind.fitness.values = fit

        if halloffame is not None:
            halloffame.update(population)
            # Checkpoint every generation so an interrupted run still leaves the
            # best-so-far weights on disk; skip when there is nowhere to save
            # or nothing in the hall of fame yet.
            if save_path is not None and halloffame.items:
                best_weights: np.ndarray = np.array(halloffame.items[0])
                np.save(save_path, best_weights)

        # Update the strategy with the evaluated individuals (tell).
        toolbox.update(population)

        record = stats.compile(population) if stats is not None else {}
        logbook.record(gen=gen, nevals=len(population), **record)
        if verbose:
            print(logbook.stream)

    return population, logbook

In [3]:

# Number of players sharing the environment (presumably agents; only 1 is used
# here). The CMA-ES offspring count per generation (lambda_) scales with it.
NUM_PLAYERS: int = 1
LAMBDA: int = NUM_PLAYERS * 100
NGEN: int = 1500


def make_run_dir(root: str = os.path.join('.', 'CMA')) -> str:
    """Create and return a brand-new run directory under ``root``.

    The directory is named after the current local time (``YYYYmmdd-HHMMSS``).
    If that name is already taken -- e.g. this cell was re-run within the same
    second -- a ``-1``, ``-2``, ... suffix is tried instead of failing with
    ``FileExistsError``. An existing directory is never reused, so an earlier
    run's weights are never overwritten.
    """
    stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    candidate = os.path.join(root, stamp)
    attempt = 0
    while True:
        try:
            # exist_ok=False makes creation itself the uniqueness check, so two
            # concurrent runs cannot both claim the same directory.
            os.makedirs(candidate, exist_ok=False)
            return candidate
        except FileExistsError:
            attempt += 1
            candidate = os.path.join(root, f"{stamp}-{attempt}")


# Best-so-far weights are checkpointed here with np.save during training.
SAVE_PATH: str = os.path.join(make_run_dir(), 'weights.npy')

def env_creator():
    """Build and return a fresh ``GridWorld`` environment.

    The constant ``8`` is forwarded unchanged to the constructor; it is
    presumably the grid size -- TODO confirm against ``GridWorld``'s signature.
    """
    grid_env = GridWorld(8)
    return grid_env


def evaluate(individuals) -> List:
    """Play one episode with a greedy policy built from a CMA-ES individual.

    :param individuals: a single individual -- the flat weight list produced by
        the strategy (parameter name kept for backward compatibility) -- which
        is handed to ``Model.from_weights``.
    :returns: a one-element list ``[(episode_return,)]``; ``custom_map_func``
        flattens these so each individual receives one fitness tuple.
    """
    env = env_creator()
    observation = env.reset()

    policy = Model.from_weights(env.observation_space, env.action_space.n, individuals)

    reward_sum: float = 0
    done: bool = False
    while not done:
        # Greedy: always take the action with the highest policy output.
        action = np.argmax(policy(observation))
        observation, reward, done, _ = env.step(action)
        reward_sum += reward
    # NOTE(review): the env is never closed (env.close() was commented out
    # previously); confirm GridWorld holds no resources that need releasing.
    return [(reward_sum,)]

def custom_map_func(evaluate_func, population):
    """Apply ``evaluate_func`` to every individual and flatten the results.

    Each call returns an iterable of fitness tuples (``evaluate`` returns a
    one-element list), so the per-individual results are concatenated into a
    single lazy stream; nothing is evaluated until the stream is consumed.
    """
    per_individual = map(evaluate_func, population)
    return (fitness for results in per_individual for fitness in results)


def train(seed: Optional[int] = None):
    """Optimise the flattened weights of a ``Model`` policy with CMA-ES.

    A throw-away model is built only to learn the parameter count. The DEAP
    ask/tell operators are then registered and ``eaGenerateUpdate`` runs for
    ``NGEN`` generations of ``LAMBDA`` individuals, checkpointing the best
    weights to ``SAVE_PATH``.

    :param seed: optional seed for NumPy's global RNG, which DEAP's
        ``cma.Strategy`` samples from; ``None`` keeps the previous unseeded
        behaviour.
    :returns: ``(population, logbook, hall_of_fame)`` so results can be
        inspected afterwards; callers that ignore the return value are
        unaffected.
    """
    if seed is not None:
        np.random.seed(seed)

    env = env_creator()
    observation_space = env.observation_space
    act_space = env.action_space
    # Instantiated only to read num_parameters (the CMA-ES search dimension).
    temp_model: Model = Model(observation_space, act_space.n)
    num_parameters = temp_model.num_parameters
    del temp_model

    # creator.create registers classes globally on deap.creator; re-running the
    # cell would otherwise warn and overwrite them each time.
    if not hasattr(creator, "FitnessMax"):
        creator.create("FitnessMax", base.Fitness, weights=(1.0,))
    if not hasattr(creator, "Individual"):
        creator.create("Individual", list, fitness=creator.FitnessMax)

    # Every weight starts at 5.0 with initial step size 5.0; LAMBDA offspring
    # are sampled per generation.
    strategy = cma.Strategy(centroid=[5.0] * num_parameters, sigma=5.0, lambda_=LAMBDA)

    toolbox = base.Toolbox()
    toolbox.register("generate", strategy.generate, creator.Individual)
    toolbox.register("update", strategy.update)
    toolbox.register("evaluate", evaluate)
    toolbox.register("map", custom_map_func)

    hof = tools.HallOfFame(1)
    stats = tools.Statistics(lambda ind: ind.fitness.values)
    stats.register("avg", np.mean)
    stats.register("std", np.std)
    stats.register("min", np.min)
    stats.register("max", np.max)

    population, logbook = eaGenerateUpdate(toolbox, ngen=NGEN, stats=stats,
                                           halloffame=hof, save_path=SAVE_PATH)
    return population, logbook, hof


# Inside a notebook kernel __name__ is '__main__', so executing this cell
# starts the full NGEN-generation training run immediately.
if __name__ == '__main__':
    train()

gen	nevals	avg    	std    	min  	max 
0  	100   	-488.17	250.972	-1185	-100
1  	100   	-552.23	234.407	-1100	0   
2  	100   	-532.63	265.664	-1297	-100
3  	100   	-531.92	248.963	-1200	-100
4  	100   	-505.2 	237.097	-1300	-100
5  	100   	-483.73	295.552	-1294	-100
6  	100   	-521.22	249.118	-1200	0   
7  	100   	-572.83	283.774	-1300	-100
8  	100   	-524.82	242.321	-1100	-100
9  	100   	-467.72	256.029	-1400	0   
10 	100   	-525.24	266.686	-1290	0   
11 	100   	-535.55	243.404	-1249	-100
12 	100   	-551.71	266.266	-1194	-100
13 	100   	-528.03	292.801	-1299	0   
14 	100   	-570.19	295.638	-1279	0   
15 	100   	-516.47	275.072	-1392	0   
16 	100   	-513.54	244.68 	-1179	-100
17 	100   	-543.41	248.062	-1200	-101
18 	100   	-546.2 	267.249	-1199	-103
19 	100   	-554.64	249.905	-1379	-199
20 	100   	-535.47	244.951	-1199	-100
21 	100   	-531.73	279.95 	-1297	-100
22 	100   	-501.31	250.222	-1100	0   
23 	100   	-555.26	251.442	-1297	-100
24 	100   	-552.34	260.182	-1200	-100
25 	100   	-

215	100   	-485.2 	232.234	-1194	-100
216	100   	-508.78	237.152	-1100	-100
217	100   	-478.32	278.91 	-1300	0   
218	100   	-588   	267.548	-1200	-100
219	100   	-500.92	268.153	-1200	-100
220	100   	-518.03	285.91 	-1400	-100
221	100   	-544.25	252.226	-1097	-100
222	100   	-487.23	256.508	-1200	-100
223	100   	-529.92	258.249	-1200	-100
224	100   	-509.82	251.326	-1379	-100
225	100   	-464.22	223.743	-1085	-100
226	100   	-486.32	262.795	-1300	-100
227	100   	-509.85	250.435	-1200	-100
228	100   	-510.6 	273.519	-1300	-100
229	100   	-506.18	242.513	-1400	-121
230	100   	-542.69	250.557	-1285	-100
231	100   	-506.95	242.131	-1085	-100
232	100   	-519.11	260.244	-1200	0   
233	100   	-523.35	231.456	-1100	-100
234	100   	-514.48	256.983	-1279	-100
235	100   	-478.83	264.114	-1385	-100
236	100   	-543.1 	250.699	-1190	-100
237	100   	-530.12	289.107	-1299	0   
238	100   	-545.41	279.794	-1297	-100
239	100   	-523.24	255.199	-1379	-101
240	100   	-520.59	258.288	-1297	0   
241	100   	-

431	100   	-448.27	255.544	-1200	-100
432	100   	-462.38	266.848	-1394	-100
433	100   	-515.67	236.143	-1200	-100
434	100   	-477.91	226.963	-1079	-106
435	100   	-489.75	242.547	-1000	0   
436	100   	-465.08	238.568	-1297	-101
437	100   	-505.76	258.155	-1299	0   
438	100   	-429   	243.359	-1199	-100
439	100   	-451.23	241.392	-1046	-100
440	100   	-466.97	245.528	-1285	-100
441	100   	-447.68	226.099	-1100	-100
442	100   	-472.84	259.003	-1190	0   
443	100   	-502.34	285.429	-1399	-100
444	100   	-480.1 	237.624	-1297	-100
445	100   	-451.66	233.236	-1185	-100
446	100   	-472.22	252.982	-1300	0   
447	100   	-455.12	229.403	-1000	0   
448	100   	-513.65	264.968	-1277	-150
449	100   	-467.03	248.637	-1285	-100
450	100   	-492.85	257.296	-1300	-100
451	100   	-486.86	222.23 	-1050	-100
452	100   	-451.19	217.242	-1200	-100
453	100   	-453.37	227.573	-997 	0   
454	100   	-444.3 	210.278	-894 	0   
455	100   	-423.18	229.547	-1048	-100
456	100   	-473.37	233.717	-1100	-100
457	100   	-

647	100   	-490.77	223.889	-1050	-100
648	100   	-482.13	249.895	-1285	0   
649	100   	-466.2 	263.348	-1200	-100
650	100   	-465.4 	216.357	-1042	-100
651	100   	-437.4 	208.141	-1099	-150
652	100   	-422.13	209.468	-900 	0   
653	100   	-443.68	208.043	-1200	-100
654	100   	-461.42	253.167	-1242	0   
655	100   	-444.57	237.011	-1003	-100
656	100   	-460.45	232.581	-1182	-100
657	100   	-443.77	232.368	-1185	-100
658	100   	-446.38	229.862	-1300	-100
659	100   	-429.92	227.379	-1063	0   
660	100   	-432.74	216.217	-1000	-100
661	100   	-421.07	228.254	-1100	0   
662	100   	-483.09	242.957	-1200	-107
663	100   	-441.39	228.609	-1200	-100
664	100   	-450.65	201.127	-1142	-100
665	100   	-410.68	228.343	-1199	-100
666	100   	-466.08	222.827	-1300	-100
667	100   	-452.94	238.733	-1394	-101
668	100   	-431.16	217.667	-1279	-100
669	100   	-435.46	250.645	-1099	-100
670	100   	-468.4 	253.202	-1299	-101
671	100   	-420.67	234.472	-1248	-100
672	100   	-457.76	220.415	-1050	-100
673	100   	-

863	100   	-481.03	257.675	-1094	-100
864	100   	-441.4 	242.873	-1194	-100
865	100   	-427.51	247.048	-1200	-100
866	100   	-420.52	221.2  	-1000	0   
867	100   	-522.15	274.551	-1299	0   
868	100   	-494.64	223.792	-1094	-100
869	100   	-464.1 	250.522	-1181	-100
870	100   	-432.16	226.378	-1100	0   
871	100   	-493.33	256.962	-1200	-100
872	100   	-437.27	251.856	-985 	0   
873	100   	-485.19	231.352	-1000	-100
874	100   	-490.36	250.517	-1285	-100
875	100   	-468.9 	227.65 	-1090	-100
876	100   	-460.3 	253.855	-1300	-100
877	100   	-463.85	276.432	-1400	-100
878	100   	-494.39	241.966	-1381	0   
879	100   	-459.25	234.852	-1279	0   
880	100   	-454.15	222.879	-1100	-100
881	100   	-500.54	267.935	-1397	0   
882	100   	-469.93	210.413	-1094	-100
883	100   	-440.55	218.606	-1093	-100
884	100   	-398.81	283.707	-994 	1452
885	100   	-458.42	234.751	-1148	0   
886	100   	-434.94	235.566	-1193	0   
887	100   	-460.17	241.639	-1100	-100
888	100   	-474.64	254.876	-1250	-100
889	100   	-

1077	100   	-440.75	229.114	-1176	-100
1078	100   	-436.73	232.414	-1100	-101
1079	100   	-413.08	203.732	-950 	-100
1080	100   	-439.34	200.104	-1100	-100
1081	100   	-433.03	222.845	-1148	-100
1082	100   	-431.66	292.538	-1148	1109
1083	100   	-391.17	188.141	-1048	-100
1084	100   	-408.71	213.194	-1100	0   
1085	100   	-407.91	208.896	-950 	-101
1086	100   	-445.37	232.667	-961 	0   
1087	100   	-389.47	212.574	-950 	0   
1088	100   	-400.98	197.636	-1099	-100
1089	100   	-426.53	220.281	-1000	-100
1090	100   	-399.98	213.872	-979 	-100
1091	100   	-393   	204.026	-1050	-100
1092	100   	-430.18	210.416	-950 	-100
1093	100   	-456.18	238.542	-1200	0   
1094	100   	-513.62	235.377	-1048	-100
1095	100   	-451.74	233.953	-1250	-150
1096	100   	-436.09	223.85 	-1100	-119
1097	100   	-385.69	199.558	-1099	-100
1098	100   	-423   	198.184	-1034	-100
1099	100   	-447.13	210.175	-1000	-101
1100	100   	-476.52	245.286	-1085	-103
1101	100   	-465.07	245.301	-1200	-100
1102	100   	-419.1 	216.7

1288	100   	-453.03	219.167	-1150	-100
1289	100   	-401.91	210.677	-1148	-100
1290	100   	-391.83	217.822	-1300	-100
1291	100   	-429.86	229.9  	-1050	-127
1292	100   	-401.93	197.498	-1248	-150
1293	100   	-407.67	201.23 	-1281	-101
1294	100   	-403.35	203.55 	-1100	-150
1295	100   	-422.62	216.225	-1200	-100
1296	100   	-426.27	233.761	-1100	-100
1297	100   	-407.7 	209.939	-942 	-100
1298	100   	-437.06	204.703	-1050	-100
1299	100   	-402.61	204.221	-1042	-101
1300	100   	-407.28	214.238	-1001	-100
1301	100   	-422.01	240.437	-1200	-100
1302	100   	-434.17	214.511	-1245	-100
1303	100   	-409.94	201.463	-1200	-100
1304	100   	-437.74	205.617	-1300	-150
1305	100   	-419.89	211.597	-1146	-100
1306	100   	-430.77	198.736	-900 	-100
1307	100   	-395.09	189.099	-952 	-100
1308	100   	-416.23	196.878	-900 	0   
1309	100   	-374.28	164.947	-800 	-100
1310	100   	-407.62	220.476	-1097	-100
1311	100   	-417.61	198.41 	-850 	-100
1312	100   	-413.41	236.928	-1248	-150
1313	100   	-407.39	212.0

1499	100   	-413.95	213.666	-1099	-100
