In [1]:
import random
import gym
import numpy as np
from collections import deque
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam
import math


Using Theano backend.


In [2]:
class Agent:
    
    def __init__(self, state_size, action_size, memory=20000, 
                 gamma=0.99, max_eps=1, min_eps=0.1, decay=0.001,
                 lr=0.00025, batch_size=32):
        
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=memory)
        
        self.max_eps = max_eps
        self.epsilon = max_eps
        self.min_eps = min_eps
        
        self.decay = decay
        self.lr = lr
        self.gamma = gamma
        self.batch_size = batch_size
        self.steps = 0
                
        self.model = Sequential()
        self.model.add(Dense(400, input_shape=(self.state_size,),
                            activation = 'relu'))
        self.model.add(Dense(200, activation='relu'))
        self.model.add(Dense(self.action_size, activation='linear'))
        

        self.model.compile(optimizer=Adam(lr=lr), loss='mse')
        
        
        
    def act(self, state, test=False):
        
        if test:
            probs = self.model.predict(state)
            return np.argmax(probs[0])         
        
        # explore
        if np.random.rand() < self.epsilon:
            return random.randrange(self.action_size)
        
        # exploit
        else:
            probs = self.model.predict(state)
            return np.argmax(probs[0])
        
    def add_memory(self, state, action, reward, next_state, done):
        
        self.memory.append((state, action, reward, next_state, done))
        
    
    def batch_run(self):
        batch_size = min(self.batch_size, len(self.memory))
        batch = random.sample(self.memory, batch_size)
        
        states = []
        targets = []
        
        for i in range(len(batch)):
            state = batch[i][0]
            action = batch[i][1]
            reward = batch[i][2]
            next_state = batch[i][3]
            done = batch[i][4]
            
            probs = self.model.predict(next_state)[0]
            target_n = reward + self.gamma * np.amax(probs)
            
            if done:
                target_n = reward
                
            target = self.model.predict(state)
            target[0][action] = target_n
            
            states.append(state[0])
            targets.append(target[0])
            
        self.model.fit(np.array(states), np.array(targets), epochs=1, verbose=0)
        
        self.steps += 1
        self.epsilon = self.min_eps + (self.max_eps - self.min_eps \
                                      ) * math.exp(-self.decay * self.steps)
            
            

In [3]:
env = gym.make('LunarLander-v2')
state_size = env.observation_space.shape[0]
action_size = env.action_space.n
agent = Agent(state_size, action_size)

In [4]:

done = False
batch_size = 32
EPISODES = 1000
r = np.zeros(EPISODES)
for e in range(EPISODES):
    state = env.reset()
    state = np.reshape(state, [1, state_size])
    R = 0
    for time in range(150000):
        # env.render()
        action = agent.act(state)
        next_state, reward, done, _ = env.step(action)
        reward = reward
        R += reward
        next_state = np.reshape(next_state, [1, state_size])
        agent.add_memory(state, action, reward, next_state, done)
        state = next_state
        if done:
            print("episode: {}/{}, score: {}, last:{}, e: {}".format(e, EPISODES, R, reward, agent.epsilon))
            r[e] = R

            break
        agent.batch_run()

# r_test = np.zeros(100)
# for e in range(100):
#     state = env.reset()
#     state = np.reshape(state, [1, state_size])
#     R = 0
#     for time in range(150000):
#         tet = True
#         # env.render()
#         action = agent.act(state, test=True)
#         next_state, reward, done, _ = env.step(action)
#         reward = reward
#         R += reward
#         next_state = np.reshape(next_state, [1, state_size])
#         agent.add_memory(state, action, reward, next_state, done)
#         state = next_state
#         if done:
#             print("episode: {}/{}, score: {}, last:{}, e: {}".format(e, EPISODES, R, reward, agent.epsilon))
#             r_test[e] = R

#             break
#         agent.batch_run()

episode: 0/1000, score: -109.43545965163113, last:-100, e: 0.9022295295161482
episode: 1/1000, score: -179.3034088184037, last:-100, e: 0.8346504172830489
episode: 2/1000, score: -318.473641010156, last:-100, e: 0.763411036952465
episode: 3/1000, score: -108.47520803365379, last:-100, e: 0.6907510912682541
episode: 4/1000, score: -225.36027496852938, last:-100, e: 0.6271023611303144
episode: 5/1000, score: -102.41569296348476, last:-100, e: 0.5429224788062375
episode: 6/1000, score: -160.2921165855907, last:-100, e: 0.5023791339034208
episode: 7/1000, score: -80.78915385942241, last:-100, e: 0.42518516655190464
episode: 8/1000, score: -79.67220470596907, last:-100, e: 0.3980897940177829
episode: 9/1000, score: -63.995911600102836, last:-100, e: 0.35376147812530245
episode: 10/1000, score: -195.2680840562902, last:-100, e: 0.29178879937487634
episode: 11/1000, score: -131.50764442143088, last:-100, e: 0.24640763668626528
episode: 12/1000, score: -146.5269113757403, last:-100, e: 0.18346

episode: 99/1000, score: -16.668113993256455, last:1.1235997048036956, e: 0.1
episode: 100/1000, score: -18.535001611888237, last:0.9712754558269836, e: 0.1
episode: 101/1000, score: 32.212090190182934, last:-2.4749394063547556, e: 0.1
episode: 102/1000, score: 226.20306900933045, last:100, e: 0.1
episode: 103/1000, score: 190.12932974318323, last:100, e: 0.1
episode: 104/1000, score: 53.6767239091362, last:-0.2436015768607642, e: 0.1
episode: 105/1000, score: 271.9614963263852, last:100, e: 0.1
episode: 106/1000, score: 269.45281176618107, last:100, e: 0.1
episode: 107/1000, score: 163.75990420189154, last:0.0006138536360147384, e: 0.1
episode: 108/1000, score: 260.4279572003445, last:100, e: 0.1
episode: 109/1000, score: 282.42668241909234, last:100, e: 0.1
episode: 110/1000, score: 253.60893047108718, last:100, e: 0.1
episode: 111/1000, score: 203.89929523782385, last:100, e: 0.1
episode: 112/1000, score: 222.80125611012477, last:100, e: 0.1
episode: 113/1000, score: 261.15671211077

episode: 228/1000, score: 241.49089515458297, last:100, e: 0.1
episode: 229/1000, score: 261.92372565184274, last:100, e: 0.1
episode: 230/1000, score: 254.82816177492282, last:100, e: 0.1
episode: 231/1000, score: 269.3388592861276, last:100, e: 0.1
episode: 232/1000, score: 303.7519163942471, last:100, e: 0.1
episode: 233/1000, score: 275.23699710981884, last:100, e: 0.1
episode: 234/1000, score: 268.55977861784226, last:100, e: 0.1
episode: 235/1000, score: 278.9508484282879, last:100, e: 0.1
episode: 236/1000, score: 268.88741950537474, last:100, e: 0.1
episode: 237/1000, score: 260.4703128908401, last:100, e: 0.1
episode: 238/1000, score: 241.7418730917867, last:100, e: 0.1
episode: 239/1000, score: 283.25572288465764, last:100, e: 0.1
episode: 240/1000, score: 250.58876367176828, last:100, e: 0.1
episode: 241/1000, score: 280.6154157255032, last:100, e: 0.1
episode: 242/1000, score: 252.4775800053367, last:100, e: 0.1
episode: 243/1000, score: 272.1504856563656, last:100, e: 0.1


episode: 358/1000, score: 248.52343400119628, last:100, e: 0.1
episode: 359/1000, score: 275.7108409934349, last:100, e: 0.1
episode: 360/1000, score: 224.7750098505976, last:100, e: 0.1
episode: 361/1000, score: 282.99979366198784, last:100, e: 0.1
episode: 362/1000, score: -118.34601491504608, last:-100, e: 0.1
episode: 363/1000, score: 257.08462328740245, last:100, e: 0.1
episode: 364/1000, score: 225.56604904716164, last:100, e: 0.1
episode: 365/1000, score: 276.27507641896443, last:100, e: 0.1
episode: 366/1000, score: 266.7688745186133, last:100, e: 0.1
episode: 367/1000, score: 259.79161750270333, last:100, e: 0.1
episode: 368/1000, score: 264.63369147886544, last:100, e: 0.1
episode: 369/1000, score: 249.9975545666593, last:100, e: 0.1
episode: 370/1000, score: 56.821065276684436, last:-100, e: 0.1
episode: 371/1000, score: 48.758893521387904, last:-100, e: 0.1
episode: 372/1000, score: 13.56451190497107, last:-100, e: 0.1
episode: 373/1000, score: 249.27379374175075, last:100,

episode: 489/1000, score: 266.58739537448105, last:100, e: 0.1
episode: 490/1000, score: 297.35700962653283, last:100, e: 0.1
episode: 491/1000, score: 262.79136629050004, last:100, e: 0.1
episode: 492/1000, score: 241.8288484720227, last:100, e: 0.1
episode: 493/1000, score: 250.87554409372564, last:100, e: 0.1
episode: 494/1000, score: 263.29607633926685, last:100, e: 0.1
episode: 495/1000, score: 267.40052274795573, last:100, e: 0.1
episode: 496/1000, score: 289.8480968317384, last:100, e: 0.1
episode: 497/1000, score: 260.5705598300072, last:100, e: 0.1
episode: 498/1000, score: 256.9324490308036, last:100, e: 0.1
episode: 499/1000, score: 264.9999762780731, last:100, e: 0.1
episode: 500/1000, score: 271.251844313835, last:100, e: 0.1
episode: 501/1000, score: 259.6879870855364, last:100, e: 0.1
episode: 502/1000, score: 275.1910272836735, last:100, e: 0.1
episode: 503/1000, score: 242.7705075923672, last:100, e: 0.1
episode: 504/1000, score: 49.018037361291704, last:-100, e: 0.1
e

episode: 620/1000, score: 283.4538631935261, last:100, e: 0.1
episode: 621/1000, score: 263.3204064755304, last:100, e: 0.1
episode: 622/1000, score: 268.25319805955894, last:100, e: 0.1
episode: 623/1000, score: 292.87766760092666, last:100, e: 0.1
episode: 624/1000, score: 279.04057638202005, last:100, e: 0.1
episode: 625/1000, score: 247.5288138045395, last:100, e: 0.1
episode: 626/1000, score: 269.56783874143855, last:100, e: 0.1
episode: 627/1000, score: 306.92679918874677, last:100, e: 0.1
episode: 628/1000, score: 297.91008966267043, last:100, e: 0.1
episode: 629/1000, score: 300.8162798872948, last:100, e: 0.1
episode: 630/1000, score: 246.23537859092588, last:100, e: 0.1
episode: 631/1000, score: 292.85150400171466, last:100, e: 0.1
episode: 632/1000, score: 268.54303313802376, last:100, e: 0.1
episode: 633/1000, score: 271.8877154242513, last:100, e: 0.1
episode: 634/1000, score: 281.94399560737617, last:100, e: 0.1
episode: 635/1000, score: 292.13902574803683, last:100, e: 0

episode: 751/1000, score: 241.6756193444879, last:100, e: 0.1
episode: 752/1000, score: 277.3634059092767, last:100, e: 0.1
episode: 753/1000, score: 25.034349011264908, last:-100, e: 0.1
episode: 754/1000, score: 279.9035287224658, last:100, e: 0.1
episode: 755/1000, score: 76.56939216441404, last:-100, e: 0.1
episode: 756/1000, score: 279.55383110976004, last:100, e: 0.1
episode: 757/1000, score: 46.374591807060824, last:-100, e: 0.1
episode: 758/1000, score: 264.90814484388164, last:100, e: 0.1
episode: 759/1000, score: 276.0895614044245, last:100, e: 0.1
episode: 760/1000, score: 253.83412370012317, last:100, e: 0.1
episode: 761/1000, score: 209.3293469108196, last:100, e: 0.1
episode: 762/1000, score: 249.89214981448998, last:100, e: 0.1
episode: 763/1000, score: 233.14085702924294, last:100, e: 0.1
episode: 764/1000, score: 247.07019822012109, last:100, e: 0.1
episode: 765/1000, score: 259.85846397497687, last:100, e: 0.1
episode: 766/1000, score: 297.97565808799027, last:100, e:

episode: 882/1000, score: 15.85709430455411, last:-100, e: 0.1
episode: 883/1000, score: 267.2868941941721, last:100, e: 0.1
episode: 884/1000, score: 275.17557237024306, last:100, e: 0.1
episode: 885/1000, score: -156.47118437605047, last:-100, e: 0.1
episode: 886/1000, score: 260.3046062910802, last:100, e: 0.1
episode: 887/1000, score: 238.4288224865286, last:100, e: 0.1
episode: 888/1000, score: 234.2488417790342, last:100, e: 0.1
episode: 889/1000, score: 250.1519499768753, last:100, e: 0.1
episode: 890/1000, score: 255.01745241114125, last:100, e: 0.1
episode: 891/1000, score: 266.90987701811946, last:100, e: 0.1
episode: 892/1000, score: 275.8918566190823, last:100, e: 0.1
episode: 893/1000, score: 245.38552377129483, last:100, e: 0.1
episode: 894/1000, score: 282.3107459011814, last:100, e: 0.1
episode: 895/1000, score: 240.57285340600043, last:100, e: 0.1
episode: 896/1000, score: 231.1054427255899, last:100, e: 0.1
episode: 897/1000, score: 75.52961877828804, last:-100, e: 0.

In [6]:
r

array([-109.43545965, -179.30340882, -318.47364101, -108.47520803,
       -225.36027497, -102.41569296, -160.29211659,  -80.78915386,
        -79.67220471,  -63.9959116 , -195.26808406, -131.50764442,
       -146.52691138,  -85.38092427, -176.93020471,  -87.92573669,
        -78.20476273, -290.86418074,  -95.39883771,  -88.64937885,
       -247.54277325,  -78.57047334, -169.08622226, -144.35842573,
       -174.97758812,  -60.57101783,  -76.42520779,  -80.96883655,
        -63.47261053,  -23.77088049,  -45.13229257,  -32.51619086,
        -53.23337603,  -51.64518829,  -76.15377394,  -23.92336646,
        -56.23359142,  -32.16894184,  -53.44043613,  -23.92864396,
        -62.09540256,  -34.16774185,    1.78410776,  -40.6032581 ,
        -48.85473058, -165.66250242,  -91.97503186,  -45.23688472,
        -55.94549219,  -54.82633024,  -24.47434222,  -33.26584069,
        -28.38063888,  -66.27365907,  -58.19834695,  -52.38815385,
        -68.47946155,  -62.30036234,  -14.0647069 ,   -9.24766

In [5]:
import matplotlib as plt