In [1]:
import math
import itertools
import gym
from gym import spaces
from gym.utils import seeding
import numpy as np

from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten
from keras.optimizers import Adam

from rl.agents.dqn import DQNAgent
from rl.policy import EpsGreedyQPolicy
from rl.memory import SequentialMemory
from gym.envs.classic_control import rendering

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
class Cell:
    def __init__(self, x, y):
        self.x = x
        self.y = y
        self.light = y * y
        self.therm = abs((x-5) + (y-5))
        self.toxic = 0
    
    def add_toxic(self):
        self.toxic += 1
    
class Aqua(gym.Env):
    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second': 1
    }

    def __init__(self):
        self.min_position = 0
        self.max_position = 10
        self.position_y = 5
        self.position_x = 5
        self.energy = 50
        self.viewer = None
        cells = list(itertools.product(range(self.max_position), range(self.max_position)))
        self.cells = {cell: Cell(cell[0], cell[1]) for cell in cells}

        self.action_space = spaces.Discrete(5)

        self.seed()
        self.reset()

    def seed(self, seed=None):
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def step(self, real_action):
        assert self.action_space.contains(real_action), "%r (%s) invalid" % (action, type(real_action))

#         if np.random.random_sample() > 0.90:
#             real_action = np.random.sample([0,1,2,3,4])
#         else:
#             real_action = action
        
        if real_action == 1:
            self.position_x += 1
        elif real_action == 2:
            self.position_y += 1
        elif real_action == 3:
            self.position_x += -1
        elif real_action == 4:
            self.position_y += -1
        self.position_x = np.clip(self.position_x, self.min_position, self.max_position-1)
        self.position_y = np.clip(self.position_y, self.min_position, self.max_position-1)
                    
        
        #штраф за верхний слой
        risk = round(np.random.random_sample()*self.position_y - 8, 1)
        if risk >= 0:
            print('instant death')
            delta_energy = -self.energy
        else:
            delta_energy = round(self.calc_delta(self.position_x, self.position_y, real_action), 1)
        self.energy = self.energy + delta_energy
        #observations
        observations_light = self.get_visible_cells(self.position_x, self.position_y, self.get_light)
        observations_temp = self.get_visible_cells(self.position_x, self.position_y, self.get_therm)
        observations_toxic = self.get_visible_cells(self.position_x, self.position_y, self.get_toxic)
        self.state = observations_light + observations_temp + observations_toxic + [self.energy, delta_energy]
        print('({},{}). light: {}, temp: {}, energy: {}, risk: {}, delta_energy: {}, toxic: {}'.
              format(self.position_x, self.position_y, observations_light[4], observations_temp[4], self.energy + delta_energy, risk, delta_energy, observations_toxic[4]))
        self.cells[(self.position_x, self.position_y)].add_toxic()
        
        reward = self.get_reward(self.energy, delta_energy)
        
        return np.array(self.state), reward, self.energy <= 0 or self.energy >= 600, {}
    
    def get_reward(self, energy, delta_energy):
        if energy <= 0:
            return -100 + delta_energy
        elif energy >= 600:
            return 100
        elif delta_energy > 5:
            return 1
        elif delta_energy < -5:
            return -1
        else:
            return 0
    
    def get_visible_cells(self, x, y, param_fn):
        return list(map(param_fn, [(x, y - 1), (x - 1, y), (x,y), (x + 1, y), (x, y + 1)]))
    
    def get_light(self, cell):
        try:
            return self.cells[cell].light
        except:
            return 0
    
    def get_therm(self, cell):
        try:
            return self.cells[cell].therm
        except:
            return 0
    
    def get_toxic(self, cell):
        try:
            return self.cells[cell].toxic
        except:
            return 0
    
    def calc_delta(self, position_x, position_y, action):
        cell = self.cells[(position_x, position_y)]
        food_koef = 0.2
        thermo_koef = 0.3
        toxic_koef = 0.2
        move_koef = 10
        move_cost = move_koef if action != 0 else 0
        metabolism = 2
        return food_koef*cell.light - metabolism - thermo_koef*cell.therm - move_cost - toxic_koef*cell.toxic
                    
    def reset(self):
        self.position_y = np.random.randint(0,9)
        self.position_x = np.random.randint(0,9)
        self.energy = 50
        cells = list(itertools.product(range(self.max_position), range(self.max_position)))
        self.cells = {cell: Cell(cell[0], cell[1]) for cell in cells}
        self.state = self.get_visible_cells(self.position_x, self.position_y, self.get_light) \
                   + self.get_visible_cells(self.position_x, self.position_y, self.get_therm) \
                   + self.get_visible_cells(self.position_x, self.position_y, self.get_toxic) \
                   + [self.energy, 0]
        return np.array(self.state)

    def render(self, mode='human'):
        screen_width = 600
        screen_height = 600
        world_width = self.max_position - self.min_position
        scale = screen_width/world_width
        agent_radius=20

        energy_to_color = self.energy/255.0 if self.energy <= 255 else 1
        color = (1-energy_to_color, energy_to_color, 0)
        
        if self.viewer is None:
            self.viewer = rendering.Viewer(screen_width, screen_height)

        for cell in self.cells.values():
            j = rendering.Transform(translation=(cell.x*scale + scale/2, cell.y*scale + scale/2))
            self.viewer.draw_circle(scale/2, 20, color=(1-cell.light/100, 1-cell.therm/10, 1-cell.toxic/40)).add_attr(j)

        t = rendering.Transform(translation=(self.position_x*scale + scale/2, self.position_y*scale + scale/2))
        self.viewer.draw_circle(agent_radius, 30, color=color).add_attr(t)


        return self.viewer.render(return_rgb_array = mode=='rgb_array')

    def close(self):
        if self.viewer: self.viewer.close()


In [3]:
env = Aqua()
np.random.seed(123)
env.seed(123)
nb_actions = env.action_space.n

model = Sequential()
input_shape = (1,17,)
print(input_shape)
model.add(Flatten(input_shape=input_shape))
model.add(Dense(30))
model.add(Activation('relu'))
model.add(Dense(nb_actions))
model.add(Activation('linear'))
print(model.summary())

policy = EpsGreedyQPolicy()
memory = SequentialMemory(limit=100000, window_length=1)
dqn = DQNAgent(model=model, nb_actions=nb_actions, memory=memory, nb_steps_warmup=10,
target_model_update=1e-2, policy=policy)
dqn.compile(Adam(lr=1e-3), metrics=['mae'])

(1, 17)
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten_1 (Flatten)          (None, 17)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 30)                540       
_________________________________________________________________
activation_1 (Activation)    (None, 30)                0         
_________________________________________________________________
dense_2 (Dense)              (None, 5)                 155       
_________________________________________________________________
activation_2 (Activation)    (None, 5)                 0         
Total params: 695
Trainable params: 695
Non-trainable params: 0
_________________________________________________________________
None


In [4]:
dqn.fit(env, nb_steps=1000, visualize=False, verbose=0)

(6,2). light: 9, temp: 1, energy: 46.400000000000006, risk: -7.0, delta_energy: -1.8, toxic: 0
(6,2). light: 9, temp: 1, energy: 44.2, risk: -7.2, delta_energy: -2.0, toxic: 0
(6,2). light: 9, temp: 1, energy: 41.8, risk: -7.7, delta_energy: -2.2, toxic: 0
(6,2). light: 9, temp: 1, energy: 39.2, risk: -6.7, delta_energy: -2.4, toxic: 0
(6,2). light: 9, temp: 1, energy: 36.4, risk: -7.5, delta_energy: -2.6, toxic: 0
(6,2). light: 9, temp: 1, energy: 33.400000000000006, risk: -6.8, delta_energy: -2.8, toxic: 0
(6,2). light: 9, temp: 1, energy: 30.200000000000003, risk: -7.1, delta_energy: -3.0, toxic: 0
(6,3). light: 16, temp: 0, energy: 12.200000000000003, risk: -6.4, delta_energy: -10.5, toxic: 0
(6,3). light: 16, temp: 0, energy: 21.300000000000004, risk: -5.5, delta_energy: -0.7, toxic: 0
(6,3). light: 16, temp: 0, energy: 20.200000000000006, risk: -6.2, delta_energy: -0.9, toxic: 0
(6,3). light: 16, temp: 0, energy: 18.900000000000002, risk: -7.0, delta_energy: -1.1, toxic: 0
(6,3).



(5,3). light: 16, temp: 1, energy: -2.8999999999999986, risk: -6.7, delta_energy: -10.8, toxic: 0
(5,3). light: 16, temp: 1, energy: 5.900000000000002, risk: -5.9, delta_energy: -1.0, toxic: 0
(5,3). light: 16, temp: 1, energy: 4.500000000000002, risk: -5.4, delta_energy: -1.2, toxic: 0
(5,3). light: 16, temp: 1, energy: 2.9000000000000026, risk: -5.8, delta_energy: -1.4, toxic: 0
(5,3). light: 16, temp: 1, energy: 1.1000000000000023, risk: -6.9, delta_energy: -1.6, toxic: 0
(5,3). light: 16, temp: 1, energy: -0.8999999999999977, risk: -7.8, delta_energy: -1.8, toxic: 0
(5,3). light: 16, temp: 1, energy: -3.099999999999998, risk: -5.5, delta_energy: -2.0, toxic: 0
(1,6). light: 49, temp: 2, energy: 38.599999999999994, risk: -4.9, delta_energy: -5.7, toxic: 0
(1,6). light: 49, temp: 2, energy: 52.5, risk: -7.9, delta_energy: 4.1, toxic: 0
(1,6). light: 49, temp: 2, energy: 56.199999999999996, risk: -6.1, delta_energy: 3.9, toxic: 0
(1,6). light: 49, temp: 2, energy: 59.7, risk: -3.9, de

(0,6). light: 49, temp: 3, energy: 58.5, risk: -4.8, delta_energy: -0.4, toxic: 0
(0,6). light: 49, temp: 3, energy: 57.699999999999996, risk: -6.6, delta_energy: -0.6, toxic: 0
(0,6). light: 49, temp: 3, energy: 56.7, risk: -3.0, delta_energy: -0.8, toxic: 0
(0,6). light: 49, temp: 3, energy: 55.5, risk: -7.2, delta_energy: -1.0, toxic: 0
(0,6). light: 49, temp: 3, energy: 54.099999999999994, risk: -6.4, delta_energy: -1.2, toxic: 0
(0,6). light: 49, temp: 3, energy: 52.5, risk: -3.2, delta_energy: -1.4, toxic: 0
(0,6). light: 49, temp: 3, energy: 50.699999999999996, risk: -7.7, delta_energy: -1.6, toxic: 0
(0,6). light: 49, temp: 3, energy: 48.7, risk: -6.4, delta_energy: -1.8, toxic: 0
(0,6). light: 49, temp: 3, energy: 46.5, risk: -3.5, delta_energy: -2.0, toxic: 0
(0,6). light: 49, temp: 3, energy: 44.099999999999994, risk: -6.1, delta_energy: -2.2, toxic: 0
(0,6). light: 49, temp: 3, energy: 41.5, risk: -7.0, delta_energy: -2.4, toxic: 0
(0,6). light: 49, temp: 3, energy: 38.6999

(7,6). light: 49, temp: 4, energy: 276.7, risk: -3.1, delta_energy: 2.7, toxic: 15
(7,7). light: 64, temp: 5, energy: 261.20000000000005, risk: -2.3, delta_energy: -6.4, toxic: 0
(7,7). light: 64, temp: 5, energy: 274.4, risk: -7.9, delta_energy: 3.4, toxic: 0
(7,7). light: 64, temp: 5, energy: 277.4, risk: -2.8, delta_energy: 3.2, toxic: 0
(7,7). light: 64, temp: 5, energy: 280.2, risk: -2.7, delta_energy: 3.0, toxic: 0
(7,7). light: 64, temp: 5, energy: 282.8, risk: -5.5, delta_energy: 2.8, toxic: 0
(7,7). light: 64, temp: 5, energy: 285.20000000000005, risk: -5.2, delta_energy: 2.6, toxic: 0
(7,7). light: 64, temp: 5, energy: 287.4, risk: -3.9, delta_energy: 2.4, toxic: 0
(7,7). light: 64, temp: 5, energy: 289.4, risk: -7.4, delta_energy: 2.2, toxic: 0
(7,7). light: 64, temp: 5, energy: 291.2, risk: -1.8, delta_energy: 2.0, toxic: 0
(7,7). light: 64, temp: 5, energy: 292.8, risk: -1.0, delta_energy: 1.8, toxic: 0
(7,7). light: 64, temp: 5, energy: 294.20000000000005, risk: -6.3, del

(7,8). light: 81, temp: 6, energy: 415.5, risk: -1.7, delta_energy: 6.5, toxic: 0
(7,8). light: 81, temp: 6, energy: 421.6, risk: -1.4, delta_energy: 6.3, toxic: 0
(7,9). light: 0, temp: 0, energy: 420.09999999999997, risk: -7.0, delta_energy: 2.4, toxic: 0
(7,9). light: 0, temp: 0, energy: 442.09999999999997, risk: -3.3, delta_energy: 12.2, toxic: 0
(7,9). light: 0, temp: 0, energy: 433.9, risk: -5.9, delta_energy: 2.0, toxic: 0
(7,9). light: 0, temp: 0, energy: 435.5, risk: -3.4, delta_energy: 1.8, toxic: 0
(7,9). light: 0, temp: 0, energy: 456.90000000000003, risk: -5.3, delta_energy: 11.6, toxic: 0
(7,9). light: 0, temp: 0, energy: 448.09999999999997, risk: -7.4, delta_energy: 1.4, toxic: 0
(7,9). light: 0, temp: 0, energy: 449.09999999999997, risk: -7.1, delta_energy: 1.2, toxic: 0
(6,9). light: 0, temp: 0, energy: 453.29999999999995, risk: -5.5, delta_energy: 2.7, toxic: 0
(6,8). light: 81, temp: 5, energy: 449.8, risk: -4.8, delta_energy: -0.4, toxic: 1
(6,7). light: 64, temp: 4

(1,2). light: 9, temp: 6, energy: 1.1000000000000014, risk: -7.3, delta_energy: -13.3, toxic: 1
(1,1). light: 4, temp: 7, energy: -13.999999999999996, risk: -7.2, delta_energy: -14.2, toxic: 1
(2,1). light: 4, temp: 6, energy: -27.599999999999998, risk: -7.3, delta_energy: -13.9, toxic: 0
(7,0). light: 1, temp: 2, energy: 24.200000000000003, risk: -8.0, delta_energy: -12.9, toxic: 0
(6,0). light: 1, temp: 3, energy: 10.700000000000003, risk: -8.0, delta_energy: -13.2, toxic: 0
(6,0). light: 1, temp: 3, energy: -2.8999999999999986, risk: -8.0, delta_energy: -13.4, toxic: 0
(6,0). light: 1, temp: 3, energy: -16.699999999999996, risk: -8.0, delta_energy: -13.6, toxic: 0
(6,2). light: 9, temp: 1, energy: 26.400000000000002, risk: -7.3, delta_energy: -11.8, toxic: 0
(6,2). light: 9, temp: 1, energy: 34.2, risk: -7.1, delta_energy: -2.0, toxic: 0
(5,2). light: 9, temp: 2, energy: 12.000000000000002, risk: -7.8, delta_energy: -12.1, toxic: 0
(5,2). light: 9, temp: 2, energy: 19.5, risk: -6.8,

(3,6). light: 49, temp: 0, energy: 39.8, risk: -6.8, delta_energy: -5.1, toxic: 0
(3,5). light: 36, temp: 1, energy: 29.699999999999996, risk: -4.1, delta_energy: -7.6, toxic: 1
(4,5). light: 36, temp: 0, energy: 22.699999999999996, risk: -4.5, delta_energy: -7.3, toxic: 0
(5,5). light: 36, temp: 1, energy: 15.999999999999996, risk: -3.3, delta_energy: -7.0, toxic: 0
(5,4). light: 25, temp: 0, energy: 4.799999999999997, risk: -6.7, delta_energy: -9.1, toxic: 1
(6,4). light: 25, temp: 1, energy: -3.7000000000000046, risk: -4.5, delta_energy: -8.8, toxic: 0
(5,4). light: 25, temp: 0, energy: -13.500000000000005, risk: -4.5, delta_energy: -9.3, toxic: 1
(4,2). light: 9, temp: 3, energy: 25.200000000000003, risk: -6.0, delta_energy: -12.4, toxic: 0
(4,2). light: 9, temp: 3, energy: 32.4, risk: -7.8, delta_energy: -2.6, toxic: 0
(3,2). light: 9, temp: 4, energy: 9.600000000000001, risk: -6.0, delta_energy: -12.7, toxic: 0
(2,2). light: 9, temp: 5, energy: -3.6999999999999993, risk: -7.2, de

(8,6). light: 49, temp: 5, energy: 76.9, risk: -3.8, delta_energy: 3.2, toxic: 0
(8,6). light: 49, temp: 5, energy: 79.7, risk: -5.6, delta_energy: 3.0, toxic: 0
(8,6). light: 49, temp: 5, energy: 82.3, risk: -6.0, delta_energy: 2.8, toxic: 0
(8,6). light: 49, temp: 5, energy: 84.69999999999999, risk: -2.2, delta_energy: 2.6, toxic: 0
(7,6). light: 49, temp: 4, energy: 70.69999999999999, risk: -4.3, delta_energy: -5.7, toxic: 0
(7,5). light: 36, temp: 3, energy: 55.999999999999986, risk: -7.9, delta_energy: -10.2, toxic: 1
(7,5). light: 36, temp: 3, energy: 65.39999999999998, risk: -5.7, delta_energy: -0.4, toxic: 1
(8,5). light: 36, temp: 4, energy: 47.19999999999999, risk: -4.1, delta_energy: -9.3, toxic: 8
(8,5). light: 36, temp: 4, energy: 57.499999999999986, risk: -6.3, delta_energy: 0.5, toxic: 8
(9,5). light: 36, temp: 5, energy: 40.59999999999998, risk: -3.1, delta_energy: -8.2, toxic: 0
(9,5). light: 36, temp: 5, energy: 51.999999999999986, risk: -3.1, delta_energy: 1.6, toxic

(5,4). light: 25, temp: 0, energy: 16.7, risk: -4.1, delta_energy: -0.7, toxic: 0
(5,4). light: 25, temp: 0, energy: 15.6, risk: -5.0, delta_energy: -0.9, toxic: 0
(5,4). light: 25, temp: 0, energy: 14.3, risk: -6.0, delta_energy: -1.1, toxic: 0
(6,4). light: 25, temp: 1, energy: -2.200000000000001, risk: -7.0, delta_energy: -8.8, toxic: 0
(6,4). light: 25, temp: 1, energy: 8.6, risk: -4.0, delta_energy: 1.0, toxic: 0
(6,4). light: 25, temp: 1, energy: 9.200000000000001, risk: -6.7, delta_energy: 0.8, toxic: 0
(6,4). light: 25, temp: 1, energy: 9.6, risk: -5.4, delta_energy: 0.6, toxic: 0
(6,4). light: 25, temp: 1, energy: 9.8, risk: -7.7, delta_energy: 0.4, toxic: 0
(6,4). light: 25, temp: 1, energy: 9.799999999999999, risk: -5.4, delta_energy: 0.2, toxic: 0
(6,4). light: 25, temp: 1, energy: 9.6, risk: -5.6, delta_energy: 0.0, toxic: 0
(6,4). light: 25, temp: 1, energy: 9.200000000000001, risk: -7.2, delta_energy: -0.2, toxic: 0
(6,4). light: 25, temp: 1, energy: 8.6, risk: -5.0, del

(6,2). light: 9, temp: 1, energy: 26.4, risk: -6.4, delta_energy: -2.6, toxic: 0
(6,2). light: 9, temp: 1, energy: 23.4, risk: -6.4, delta_energy: -2.8, toxic: 0
(6,2). light: 9, temp: 1, energy: 20.2, risk: -6.1, delta_energy: -3.0, toxic: 0
(6,2). light: 9, temp: 1, energy: 16.8, risk: -7.8, delta_energy: -3.2, toxic: 0
(6,2). light: 9, temp: 1, energy: 13.200000000000001, risk: -7.4, delta_energy: -3.4, toxic: 0
(6,2). light: 9, temp: 1, energy: 9.400000000000002, risk: -7.0, delta_energy: -3.6, toxic: 0
(6,2). light: 9, temp: 1, energy: 5.400000000000003, risk: -8.0, delta_energy: -3.8, toxic: 0
(7,2). light: 9, temp: 0, energy: -13.799999999999997, risk: -6.1, delta_energy: -11.5, toxic: 0
(1,4). light: 25, temp: 4, energy: 49.400000000000006, risk: -5.3, delta_energy: -0.3, toxic: 0
(1,4). light: 25, temp: 4, energy: 48.7, risk: -4.1, delta_energy: -0.5, toxic: 0
(1,4). light: 25, temp: 4, energy: 47.8, risk: -7.5, delta_energy: -0.7, toxic: 0
(1,4). light: 25, temp: 4, energy: 4

(1,6). light: 49, temp: 2, energy: 58.599999999999994, risk: -4.5, delta_energy: 4.3, toxic: 0
(1,6). light: 49, temp: 2, energy: 62.5, risk: -3.6, delta_energy: 4.1, toxic: 0
(1,6). light: 49, temp: 2, energy: 66.2, risk: -3.6, delta_energy: 3.9, toxic: 0
(1,6). light: 49, temp: 2, energy: 69.7, risk: -3.5, delta_energy: 3.7, toxic: 0
(1,6). light: 49, temp: 2, energy: 73.0, risk: -3.4, delta_energy: 3.5, toxic: 0
(1,6). light: 49, temp: 2, energy: 76.1, risk: -7.4, delta_energy: 3.3, toxic: 0
(1,6). light: 49, temp: 2, energy: 78.99999999999999, risk: -4.4, delta_energy: 3.1, toxic: 0
(1,6). light: 49, temp: 2, energy: 81.7, risk: -4.2, delta_energy: 2.9, toxic: 0
(1,6). light: 49, temp: 2, energy: 84.2, risk: -7.2, delta_energy: 2.7, toxic: 0
(1,6). light: 49, temp: 2, energy: 86.5, risk: -7.3, delta_energy: 2.5, toxic: 0
(1,6). light: 49, temp: 2, energy: 88.6, risk: -3.4, delta_energy: 2.3, toxic: 0
(1,6). light: 49, temp: 2, energy: 90.49999999999999, risk: -6.6, delta_energy: 2.

<keras.callbacks.History at 0x7effef9fc518>

In [5]:
dqn.test(env, nb_episodes=5, visualize=True)

Testing for 5 episodes ...
(5,3). light: 16, temp: 1, energy: 48.400000000000006, risk: -6.2, delta_energy: -0.8, toxic: 0
(5,3). light: 16, temp: 1, energy: 47.2, risk: -5.8, delta_energy: -1.0, toxic: 0
(5,3). light: 16, temp: 1, energy: 45.8, risk: -6.7, delta_energy: -1.2, toxic: 0
(5,3). light: 16, temp: 1, energy: 44.2, risk: -5.6, delta_energy: -1.4, toxic: 0
(5,3). light: 16, temp: 1, energy: 42.4, risk: -6.7, delta_energy: -1.6, toxic: 0
(5,3). light: 16, temp: 1, energy: 40.400000000000006, risk: -6.9, delta_energy: -1.8, toxic: 0
(5,3). light: 16, temp: 1, energy: 38.2, risk: -7.1, delta_energy: -2.0, toxic: 0
(5,3). light: 16, temp: 1, energy: 35.8, risk: -7.7, delta_energy: -2.2, toxic: 0
(5,3). light: 16, temp: 1, energy: 33.2, risk: -5.9, delta_energy: -2.4, toxic: 0
(5,3). light: 16, temp: 1, energy: 30.4, risk: -7.6, delta_energy: -2.6, toxic: 0
(5,3). light: 16, temp: 1, energy: 27.4, risk: -7.9, delta_energy: -2.8, toxic: 0
(5,3). light: 16, temp: 1, energy: 24.2, ri

<keras.callbacks.History at 0x7effef9ed358>