In [1]:
use_tensorboard = False

In [11]:
env_name = "Pendulum-v0"

In [17]:
from datetime import datetime
from pathlib import Path
import tensorboardX
import gym
from evofuzzy.gymrunner import GymRunner

timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
print("Run starting at", timestamp )

if use_tensorboard:
    tensorboard_dir = f"tb_logs/gym/{env_name}"
    logdir = Path(f"{tensorboard_dir}/{timestamp}")
    logdir.mkdir(parents=True, exist_ok=True)
    tensorboard_writer = tensorboardX.SummaryWriter(str(logdir))
else:
    tensorboard_writer = None

env = gym.make(env_name)
runner = GymRunner(
    population_size=100,
    hall_of_fame_size=5,
    max_generation=10,
    mutation_prob=0.9,
    crossover_prob=0.2,
    min_tree_height=1,
    max_tree_height=3,
    min_rules=5,
    max_rules=8,
    whole_rule_prob=0.2,
    tree_height_limit=5,
)

runner.train(env, tensorboard_writer, inf_limit=10)
print(runner.best_str)
runner.play(env)

Run starting at 20210815-125740

   	        fitness         	        size        
   	------------------------	--------------------
gen	max     	avg     	min	avg  	best
0  	-793.586	-1326.03	15 	23.35	0   

1  	-725.626	-1265.74	15 	21.33	0   

2  	-730.022	-1241.85	15 	22.33	0   

3  	-737.706	-1233.24	15 	21.89	0   

4  	-761.811	-1214.55	15 	22.12	0   

5  	-769.228	-1205.02	15 	22.09	0   

6  	-730.145	-1236.06	15 	22.36	0   

7  	-769.228	-1234.86	15 	22.34	0   

8  	-730.145	-1239.48	15 	22.94	0   

9  	-730.141	-1210.79	15 	22.83	0   
IF obs_2[higher] THEN action_0[average]
IF obs_1[low] AND obs_2[high] THEN action_0[lower]
IF obs_2[average] THEN action_0[lower]
IF obs_0[higher] THEN action_0[higher]
IF obs_1[high] THEN action_0[high]
IF obs_2[higher] THEN action_0[average]
IF obs_0[high] THEN action_0[average]
Finished with reward of -730.1413351847325


In [18]:
runner.play(env)

Finished with reward of -760.9114072169135


In [23]:
env.observation_space.high

array([1., 1., 8.], dtype=float32)

In [36]:
for _ in range(20):
    print(env.step(env.action_space.sample()))

(array([ 0.44128245,  0.04087888, -0.2672821 , -0.2855685 , -1.0125861 ,
       -5.1052217 ,  1.        ,  0.        ], dtype=float32), -100, True, {})
(array([ 0.43755674,  0.03877238, -0.26927763, -0.31207648, -1.2713304 ,
       -4.85952   ,  1.        ,  0.        ], dtype=float32), -100, True, {})
(array([ 0.43447572,  0.03600262, -0.2796007 , -0.3380725 , -1.5302945 ,
       -4.8443975 ,  1.        ,  0.        ], dtype=float32), -100, True, {})
(array([ 0.42921573,  0.03929644, -0.5251856 ,  0.04547491, -1.6736537 ,
       -2.8091803 ,  0.        ,  0.        ], dtype=float32), -100, True, {})
(array([ 0.42417574,  0.04196936, -0.52576935,  0.02652681, -1.8117449 ,
       -2.758056  ,  0.        ,  0.        ], dtype=float32), -100, True, {})
(array([ 0.41932917,  0.04408201, -0.5250378 ,  0.0067347 , -1.9468755 ,
       -2.7041862 ,  0.        ,  0.        ], dtype=float32), -100, True, {})
(array([ 0.41463137,  0.04550308, -0.5273431 , -0.01941235, -2.0818098 ,
       -2.70082