In [1]:
import operator
import gym
from NN import NN 

In [2]:
# Single shared environment instance; the Agent constructor, the training loop,
# the evaluation loop and the rendering cell below all read from / step this object.
env = gym.make('CartPole-v1')

In [3]:
class Agent:
    """One individual of the evolutionary search: an ``NN`` model plus its fitness score.

    Attributes:
        action_size: ``env.action_space.n`` at construction time.
        observation_size: First dimension of ``env.observation_space.shape``.
        fitness: Last score stored through ``set_fitness``; starts at 0.
        NN: The wrapped ``NN`` instance used for prediction and mutation.
    """

    def __init__(self, env, w=None):
        """Create an agent sized for ``env``.

        Args:
            env: Object exposing ``action_space.n`` and ``observation_space.shape``.
            w: Optional weights handed to ``NN.from_weights``. When omitted
                (``None``) a new network is built with ``NN.from_params``.
        """
        self.action_size = env.action_space.n
        self.observation_size = env.observation_space.shape[0]
        self.fitness = 0
        # Test for presence, not truthiness: array-like weights raise on bool()
        # ("truth value is ambiguous"), and an empty-but-valid container would
        # otherwise be silently discarded in favour of a fresh network.
        if w is not None:
            self.NN = NN.from_weights(w)
        else:
            # NOTE(review): [2] is presumably the hidden-layer layout (one layer
            # of 2 units) -- confirm against NN.from_params.
            self.NN = NN.from_params(self.observation_size, [2], self.action_size)

    def get_action(self, observation):
        """Return whatever ``NN.predict`` yields for ``observation``."""
        return self.NN.predict(observation)

    def set_fitness(self, fitness):
        """Store ``fitness`` as this agent's score."""
        self.fitness = fitness

    def get_mutated_weights(self, rate):
        """Return the result of ``NN.mutate(rate)``.

        NOTE(review): whether ``mutate`` also alters this agent's own network
        in place cannot be determined from this notebook -- verify in ``NN``.
        """
        return self.NN.mutate(rate)
        

In [19]:
# Evolutionary training: evaluate every agent, keep the top ~10% unchanged
# (elitism) and refill the rest of the population with mutated copies of them.
total_run = 15         # episodes averaged into each agent's fitness
agent_number = 50      # population size, held constant across generations
total_generation = 10  # number of evaluate -> select -> mutate rounds
agents = [Agent(env) for _ in range(agent_number)]

for generation in range(total_generation):
    print("Generation: " + str(generation))
    print("testing...")

    # Fitness = mean total reward over `total_run` episodes (each capped at 1000 steps).
    for agent in agents:
        fitness = 0
        for run in range(total_run):
            state = env.reset()
            for t in range(1000):
                action = agent.get_action(state)
                state, reward, done, info = env.step(action)
                fitness += reward
                if done:
                    break
        agent.set_fitness(fitness / total_run)

    agents.sort(key=operator.attrgetter('fitness'), reverse=True)
    elites = agents[:max(agent_number // 10, 1)]

    new_agents = []
    print("Creating new agents")
    for agent in elites:
        print(agent.fitness)
        new_agents.append(agent)

    # Fill only the remaining slots, cycling through the elites as parents.
    # Previously every elite spawned agent_number - 1 children, so the
    # population silently grew to len(elites) * agent_number after the first
    # generation instead of staying at agent_number.
    for i in range(agent_number - len(new_agents)):
        parent = elites[i % len(elites)]
        new_agents.append(Agent(env, parent.get_mutated_weights(0.01)))
    agents = new_agents

print("End training")
# The freshly created children still carry the initial fitness of 0 and the
# elites sit at the front of the list, so after this stable sort agents[0] is
# the best elite of the final generation.
agents.sort(key=operator.attrgetter('fitness'), reverse=True)

Generation: 0
testing...
Creating new agents
196.8
35.46666666666667
11.066666666666666
9.866666666666667
9.866666666666667
Generation: 1
testing...
Creating new agents
206.4
206.26666666666668
205.93333333333334
204.86666666666667
204.73333333333332
Generation: 2
testing...
Creating new agents
313.0
213.33333333333334
212.4
211.66666666666666
210.0
Generation: 3
testing...
Creating new agents
478.8
336.6666666666667
336.53333333333336
329.53333333333336
327.73333333333335
Generation: 4
testing...
Creating new agents
500.0
500.0
496.0
490.3333333333333
486.6666666666667
Generation: 5
testing...
Creating new agents
500.0
500.0
500.0
500.0
500.0
Generation: 6
testing...
Creating new agents
500.0
500.0
500.0
500.0
500.0
Generation: 7
testing...
Creating new agents
500.0
500.0
500.0
500.0
500.0
Generation: 8
testing...
Creating new agents
500.0
500.0
500.0
500.0
500.0
Generation: 9
testing...
Creating new agents
500.0
500.0
500.0
500.0
500.0
End training


In [20]:
# Score the top-ranked agent: mean total reward over `evaluation_runs` episodes.
evaluation_score = 0
evaluation_runs = 100
for run in range(evaluation_runs):
    fitness = 0
    state = env.reset()
    for t in range(1000):
        action = agents[0].get_action(state)
        state, reward, done, info = env.step(action)
        fitness += reward
        if done:
            break
    # Accumulate after the step loop so an episode that reaches the 1000-step
    # cap without `done` is still counted; adding only inside `if done:` would
    # drop it from the sum while the divisor below still includes it.
    evaluation_score += fitness
print(evaluation_score/evaluation_runs)


500.0


In [24]:
# Play one rendered episode with the top-ranked agent and print the index of
# the step on which the environment reported `done`.
state = env.reset()
champion = agents[0]
for t in range(1000):
    state, reward, done, info = env.step(champion.get_action(state))
    env.render()
    if not done:
        continue
    print(t)
    break

499


In [7]:
# Close the shared environment once it is no longer needed; cells that call
# env.reset()/env.step()/env.render() should not be run after this one.
env.close()