**Oбучение с подкреплением**

Евгений Борисов borisov.e@solarl.ru

In [1]:
# https://towardsdatascience.com/reinforcement-learning-w-keras-openai-dqns-1eed3a5338c

In [2]:
import gym
import numpy as np
import random
from collections import deque

from keras.models import Sequential
from keras.layers import Dense
# from keras.layers import Dropout
from keras.optimizers import Adam


Using TensorFlow backend.


In [4]:
env = gym.make('CartPole-v1')

In [5]:
# s = env.reset()
# env.render()
# for t in range(150):
#     a = env.action_space.sample()
#     sn,r,d,_ = env.step(a)
#     env.render()
#     print('%d: %.1f %d'%(t,r,d) )
#     s=sn
#     #if d: break
# env.close()  

In [6]:
class DQN:
    def __init__(self, env):
        self.env     = env
        self.memory  = deque(maxlen=2000)
        
        self.gamma = 0.85
        self.epsilon = 1.0
        self.epsilon_min = 0.1
        self.epsilon_decay = 0.995
        self.learning_rate = 0.002
        self.tau = .125

        self.model        = self.create_model()
        self.target_model = self.create_model()

    def create_model(self):
        model   = Sequential()
        state_shape  = self.env.observation_space.shape
        model.add(Dense(32, input_shape=state_shape, activation='relu'))
        model.add(Dense(64, activation='relu'))
        model.add(Dense(32, input_shape=state_shape, activation='relu'))
        model.add(Dense(self.env.action_space.n))
        model.compile(loss='mean_squared_error', optimizer='rmsprop')#Adam(lr=self.learning_rate))
        return model

    def act(self, state):
        self.epsilon *= self.epsilon_decay
        self.epsilon = max(self.epsilon_min, self.epsilon)
        if np.random.random() < self.epsilon: return self.env.action_space.sample()
        return np.argmax(self.model.predict(state)[0])

    def remember(self, state, action, reward, new_state, done):
        self.memory.append([state, action, reward, new_state, done])

    def replay(self):
        batch_size = 32
        if len(self.memory) < batch_size: return

        samples = random.sample(self.memory, batch_size)
        for sample in samples:
            state, action, reward, new_state, done = sample
            target = self.target_model.predict(state)
            if done:
                target[0][action] = reward
            else:
                Q_future = max(self.target_model.predict(new_state)[0])
                target[0][action] = reward + Q_future * self.gamma
            self.model.fit(state, target, epochs=1, verbose=0)

    def target_train(self):
        weights = self.model.get_weights()
        target_weights = self.target_model.get_weights()
        for i in range(len(target_weights)):
            target_weights[i] = weights[i] * self.tau + target_weights[i] * (1 - self.tau)
        self.target_model.set_weights(target_weights)

    def save_model(self, fn):
        self.model.save(fn)


In [7]:
# gamma   = 0.9
# epsilon = .95
trials  = 1500
trial_len = 400

In [8]:
def test_model(env,model,tt=1000):
    s = env.reset()[np.newaxis,:]
    for t in range(tt):    
        a = np.argmax(model.predict(s)[0])   
        s,r,d,_ = env.step(a)
        s=s[np.newaxis,:]
        if d: break
    return t

In [9]:
dqn_agent = DQN(env=env)

Instructions for updating:
Colocations handled automatically by placer.


In [10]:
%%time

# updateTargetNetwork = 1000
steps = []

for trial in range(trials):
    cur_state = env.reset()[np.newaxis,:] #.reshape(1,2)
    for step in range(trial_len):
        action = dqn_agent.act(cur_state)
        new_state, reward, done, _ = env.step(action)

        # reward = reward if not done else -20
        new_state = new_state[np.newaxis,:] # .reshape(1,2)
        dqn_agent.remember(cur_state, action, reward, new_state, done)

        dqn_agent.replay()       # internally iterates default (prediction) model
        dqn_agent.target_train() # iterates target model

        cur_state = new_state
        if done: break
        
    print('%d/%d : %d steps'%(trial+1,trials,step))   

    if trial%10==0:
        tst = test_model(env,dqn_agent.model)
        print('test:',tst)
        if tst>250: 
            dqn_agent.save_model('CartPole-%d.model'%(tst))
            break # Completed


1/1500 : 19 steps
test: 8
Instructions for updating:
Use tf.cast instead.
2/1500 : 40 steps
3/1500 : 41 steps
4/1500 : 81 steps
5/1500 : 65 steps
6/1500 : 72 steps
7/1500 : 72 steps
8/1500 : 58 steps
9/1500 : 118 steps
10/1500 : 64 steps
11/1500 : 63 steps
test: 77
12/1500 : 61 steps
13/1500 : 138 steps
14/1500 : 76 steps
15/1500 : 117 steps
16/1500 : 80 steps
17/1500 : 96 steps
18/1500 : 56 steps
19/1500 : 86 steps
20/1500 : 79 steps
21/1500 : 14 steps
test: 70
22/1500 : 58 steps
23/1500 : 128 steps
24/1500 : 116 steps
25/1500 : 218 steps
26/1500 : 77 steps
27/1500 : 64 steps
28/1500 : 106 steps
29/1500 : 74 steps
30/1500 : 15 steps
31/1500 : 101 steps
test: 163
32/1500 : 133 steps
33/1500 : 91 steps
34/1500 : 121 steps
35/1500 : 196 steps
36/1500 : 84 steps
37/1500 : 179 steps
38/1500 : 128 steps
39/1500 : 128 steps
40/1500 : 118 steps
41/1500 : 140 steps
test: 14
42/1500 : 103 steps
43/1500 : 132 steps
44/1500 : 12 steps
45/1500 : 132 steps
46/1500 : 134 steps
47/1500 : 23 steps
48/

386/1500 : 61 steps
387/1500 : 78 steps
388/1500 : 15 steps
389/1500 : 75 steps
390/1500 : 70 steps
391/1500 : 68 steps
test: 57
392/1500 : 68 steps
393/1500 : 19 steps
394/1500 : 9 steps
395/1500 : 65 steps
396/1500 : 12 steps
397/1500 : 84 steps
398/1500 : 21 steps
399/1500 : 19 steps
400/1500 : 68 steps
401/1500 : 75 steps
test: 16
402/1500 : 76 steps
403/1500 : 58 steps
404/1500 : 61 steps
405/1500 : 14 steps
406/1500 : 102 steps
407/1500 : 70 steps
408/1500 : 23 steps
409/1500 : 22 steps
410/1500 : 11 steps
411/1500 : 124 steps
test: 12
412/1500 : 105 steps
413/1500 : 83 steps
414/1500 : 123 steps
415/1500 : 360 steps
416/1500 : 165 steps
417/1500 : 163 steps
418/1500 : 166 steps
419/1500 : 104 steps
420/1500 : 145 steps
421/1500 : 179 steps
test: 114
422/1500 : 59 steps
423/1500 : 35 steps
424/1500 : 38 steps
425/1500 : 71 steps
426/1500 : 9 steps
427/1500 : 25 steps
428/1500 : 15 steps
429/1500 : 9 steps
430/1500 : 10 steps
431/1500 : 11 steps
test: 93
432/1500 : 19 steps
433/15

780/1500 : 9 steps
781/1500 : 7 steps
test: 9
782/1500 : 8 steps
783/1500 : 8 steps
784/1500 : 9 steps
785/1500 : 10 steps
786/1500 : 8 steps
787/1500 : 8 steps
788/1500 : 8 steps
789/1500 : 12 steps
790/1500 : 9 steps
791/1500 : 10 steps
test: 9
792/1500 : 9 steps
793/1500 : 9 steps
794/1500 : 9 steps
795/1500 : 8 steps
796/1500 : 12 steps
797/1500 : 9 steps
798/1500 : 10 steps
799/1500 : 10 steps
800/1500 : 9 steps
801/1500 : 9 steps
test: 9
802/1500 : 12 steps
803/1500 : 8 steps
804/1500 : 14 steps
805/1500 : 12 steps
806/1500 : 11 steps
807/1500 : 7 steps
808/1500 : 9 steps
809/1500 : 11 steps
810/1500 : 8 steps
811/1500 : 11 steps
test: 7
812/1500 : 16 steps
813/1500 : 9 steps
814/1500 : 21 steps
815/1500 : 11 steps
816/1500 : 13 steps
817/1500 : 29 steps
818/1500 : 29 steps
819/1500 : 12 steps
820/1500 : 21 steps
821/1500 : 13 steps
test: 10
822/1500 : 17 steps
823/1500 : 23 steps
824/1500 : 23 steps
825/1500 : 14 steps
826/1500 : 11 steps
827/1500 : 9 steps
828/1500 : 11 steps
8

1165/1500 : 33 steps
1166/1500 : 30 steps
1167/1500 : 23 steps
1168/1500 : 24 steps
1169/1500 : 18 steps
1170/1500 : 17 steps
1171/1500 : 21 steps
test: 15
1172/1500 : 19 steps
1173/1500 : 25 steps
1174/1500 : 50 steps
1175/1500 : 21 steps
1176/1500 : 61 steps
1177/1500 : 18 steps
1178/1500 : 15 steps
1179/1500 : 26 steps
1180/1500 : 62 steps
1181/1500 : 59 steps
test: 16
1182/1500 : 65 steps
1183/1500 : 14 steps
1184/1500 : 29 steps
1185/1500 : 19 steps
1186/1500 : 23 steps
1187/1500 : 41 steps
1188/1500 : 20 steps
1189/1500 : 23 steps
1190/1500 : 26 steps
1191/1500 : 26 steps
test: 17
1192/1500 : 11 steps
1193/1500 : 17 steps
1194/1500 : 13 steps
1195/1500 : 28 steps
1196/1500 : 30 steps
1197/1500 : 21 steps
1198/1500 : 13 steps
1199/1500 : 56 steps
1200/1500 : 17 steps
1201/1500 : 12 steps
test: 19
1202/1500 : 12 steps
1203/1500 : 64 steps
1204/1500 : 46 steps
1205/1500 : 28 steps
1206/1500 : 80 steps
1207/1500 : 68 steps
1208/1500 : 65 steps
1209/1500 : 17 steps
1210/1500 : 53 step

In [19]:
# from keras.models import load_model 
# model = load_model('CartPole-288.model')
# dqn_agent.model = model

In [22]:
s = env.reset()[np.newaxis,:]
env.render()
#while True:
for t in range(350):    
    a = np.argmax( dqn_agent.model.predict(s)[0] )   
    s,r,d,_ = env.step(a)
    s=s[np.newaxis,:]
    env.render()
    if d: 
        print(t) 
        break

298


In [23]:
env.close()    