In [1]:
import gym
import tensorflow as tf
import numpy as np
import random
import matplotlib.pyplot as plt


In [2]:
env = gym.make('CartPole-v1')

In [3]:
#param

LEARNING_RATE = 0.001

GAMMA = 0.992

EXPLORATION_MAX = 1.0 # epsilon greedy
EXPLORATION_MIN = 0.01
EXPLORATION_DECAY = 0.95

MEMORY_SIZE = 1000000 #50000
BATCH_SIZE = 10 #20

In [4]:
class Agent(tf.keras.Model):

  def __init__(self):
    super(Agent, self).__init__(name='mon_agent')
    # Define your layers here.
    #self.dense1 = tf.keras.layers.Dense(16, activation='relu', input_shape=(4,))
    self.dense1 = tf.keras.layers.Dense(24, activation='relu')
    #self.dropout1 = tf.keras.layers.Dropout(0.5)
    self.dense2 = tf.keras.layers.Dense(24, activation='relu')
    self.dropout2 = tf.keras.layers.Dropout(0.5)
    self.dense3 = tf.keras.layers.Dense(2, activation='linear')

  def call(self, inputs):
    tmp = inputs
    tmp = self.dense1(tmp)
    #tmp = self.dropout1(tmp)
    tmp = self.dense2(tmp)
    tmp = self.dropout2(tmp)
    return self.dense3(tmp)


In [5]:
agent = Agent()
agent.compile(optimizer=tf.keras.optimizers.Adam())
agent.build(tf.TensorShape([None,4]))
agent.compile(optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),
    loss="mse",
    metrics=['accuracy'])

agent.summary()

Model: "mon_agent"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                multiple                  120       
_________________________________________________________________
dense_1 (Dense)              multiple                  600       
_________________________________________________________________
dropout (Dropout)            multiple                  0         
_________________________________________________________________
dense_2 (Dense)              multiple                  50        
Total params: 770
Trainable params: 770
Non-trainable params: 0
_________________________________________________________________


In [6]:
def showHistory(history) :
    #print(history.history.keys())
    plt.figure(figsize=(15.0,10.0))
    fig, axes = plt.subplots(nrows=1, ncols=2) 
    fig.set_size_inches(15.0, 7.0)         
    axes[0].plot(history.history['loss'], label="loss")
    axes[0].legend()
    axes[1].plot(history.history['accuracy'], label="accuracy")
    axes[1].legend()
    plt.show()


In [7]:
#env.close()
#agent.save_weights('CartPoleV1.tf')
#agent.load_weights('CartPoleV1.tf')

def drawAgent(agent, p=0.0, dp=0.0) :
    rx = np.arange(-0.24,0.24,0.02, np.float32)
    ry = np.arange(-2,2,0.2, np.float32)
    def agentDir(x,y) :
        [a ,b] = agent([[p, dp, x, y]])[0]
        #return '<' if b<a else '>'
        #return (b-a).numpy()
        return np.sign(b-a)
        
    pix = [
        #[ [[type(0.0), type(0.0), type(x), type(y)]] for x in rx]
        [ agentDir(x, y) for x in rx]
        for y in ry
        ]
    #print(pix)
    #return
    fig, ax = plt.subplots()
    fig.set_size_inches(15.0, 7.0)         
    im = ax.imshow(pix)
    ax.set_xticks(np.arange(len(rx)))
    ax.set_yticks(np.arange(len(ry)))
    ax.set_xticklabels([round(l,2) for l in rx])
    ax.set_yticklabels([round(l,2) for l in ry])
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")    
    for y in range(len(ry)):
        for x in range(len(rx)):
            #text = ax.text(x, y, pix[y][x], ha="center", va="center", color="w")
            text = ax.text(x, y, '<' if pix[y][x]<0 else '>' if pix[y][x]>0 else '0', ha="center", va="center", color="w")
    plt.show()

#print(pix)


In [8]:
memory = [] #{ "states":[], "masks":[], "actions":[], "values":[]}
loss = []
accuracy = []
eg = EXPLORATION_MAX
scores = []

In [None]:
for epi in range(1,1000) :
    print(epi, " => ", eg)
    state = env.reset()
    
    step = 0
    #for step in range(1000):
    while True :
        step += 1
        if eg<0.03 :
            env.render()

        if (random.random() > eg) :
            Qs = agent([state])[0];
            action = np.argmax(Qs);
            #Q = Qs[action]
            #print("State: ", state, " Qs: ", Qs.numpy(), " Action: ", action, " Q: ", Q.numpy())
        else :
            #Qs = agent([state])[0];
            action = random.randrange(2)
            #Q = Qs[action]
            #print("State: ", state, " Action: ", action, " Q: ", Q.numpy())

        state_next, reward, done, info = env.step(action)
        if done : reward = -1.0
            
        mem = { "s": state, "a": action, "r": reward, "d": done, "sn": state_next, }        
        memory.append(mem)
        
        state = state_next

        if done : 
            print('-' if abs(state[0])>2.4 else 'a' if abs(state[2])>0.2094 else '?')
            break
        else :
            print('.', end='')
        
        state = state_next
        
    if len(memory) >= BATCH_SIZE :
        if len(memory) > MEMORY_SIZE :
            memory = random.sample(memory, MEMORY_SIZE)

        batch = random.sample(memory, BATCH_SIZE)
        bx = [ b["s"] for b in batch ]
        def applyAt(a, index, value) :
            a[index] = value
            return a
        
        by0 = [ agent([b['s']])[0].numpy() for b in batch ]
        by = [ applyAt(
                agent([b['s']])[0].numpy(),
                b["a"],
                b['r'] if b['d'] else b['r'] + GAMMA * np.max(agent([b['sn']])[0]) 
            ) for b in batch ]

        #print("!!batch: ", batch[:4], "\r\n!!bx: ",bx[:4], "\r\n!!by0: ", by0[:4], "\r\n!!by:", by[:4], "\r\n")
        
        #print("bs: ", bs[:4], " bm: ",bm[:4], "by: ", by[:4], "\r\n")
        history = agent.fit(
            np.array(bx),
            np.array(by),
            verbose=0)
        loss.append(history.history['loss'])
        accuracy.append(history.history['accuracy'])
        
        #break
    #drawAgent(agent)
    scores.append(step)
    print("step : ", step, " (", np.min(scores), " / ", np.average(scores[max(0,len(scores)-100):]), " / ", np.max(scores), ")")
    #print(memory)
    eg = max(EXPLORATION_MIN, eg * EXPLORATION_DECAY)
    #print("\r\neG: ", eg, "; ", c0,"-",c1,"/",c0+c1)
    #print("\r\neG: ", eg)
    
    #plt.figure(figsize=(15.0,7.0))
    #fig, axes = plt.subplots(nrows=1, ncols=2) 
    #fig.set_size_inches(15.0, 7.0)         
    #axes[0].plot(loss, label="loss")
    #axes[0].legend()
    #axes[1].plot(accuracy, label="accuracy")
    #axes[1].legend()
    #plt.show()
    #loss = []
    #accuracy = []

    #showHistory(history)
    
    #print(agent([s0])[0].numpy(), " <= ", o0)
    #break
    
    
    

1  =>  1.0
.........a
step :  10  ( 10  /  10.0  /  10 )
2  =>  0.95
...........................a
step :  28  ( 10  /  19.0  /  28 )
3  =>  0.9025
..................a
step :  19  ( 10  /  19.0  /  28 )
4  =>  0.8573749999999999
...........a
step :  12  ( 10  /  17.25  /  28 )
5  =>  0.8145062499999999
.................a
step :  18  ( 10  /  17.4  /  28 )
6  =>  0.7737809374999999
.....................a
step :  22  ( 10  /  18.166666666666668  /  28 )
7  =>  0.7350918906249998
...........a
step :  12  ( 10  /  17.285714285714285  /  28 )
8  =>  0.6983372960937497
..............a
step :  15  ( 10  /  17.0  /  28 )
9  =>  0.6634204312890623
...........a
step :  12  ( 10  /  16.444444444444443  /  28 )
10  =>  0.6302494097246091
............a
step :  13  ( 10  /  16.1  /  28 )
11  =>  0.5987369392383786
.............a
step :  14  ( 10  /  15.909090909090908  /  28 )
12  =>  0.5688000922764596
...................................a
step :  36  ( 10  /  17.583333333333332  /  36 )
13  =>  0.54

.......a
step :  8  ( 8  /  11.09  /  36 )
101  =>  0.01
.......a
step :  8  ( 8  /  11.07  /  36 )
102  =>  0.01
.........a
step :  10  ( 8  /  10.89  /  36 )
103  =>  0.01
........a
step :  9  ( 8  /  10.79  /  36 )
104  =>  0.01
........a
step :  9  ( 8  /  10.76  /  36 )
105  =>  0.01
..........a
step :  11  ( 8  /  10.69  /  36 )
106  =>  0.01
.......a
step :  8  ( 8  /  10.55  /  36 )
107  =>  0.01
........a
step :  9  ( 8  /  10.52  /  36 )
108  =>  0.01
.......a
step :  8  ( 8  /  10.45  /  36 )
109  =>  0.01
..........a
step :  11  ( 8  /  10.44  /  36 )
110  =>  0.01
.........a
step :  10  ( 8  /  10.41  /  36 )
111  =>  0.01
........a
step :  9  ( 8  /  10.36  /  36 )
112  =>  0.01
.......a
step :  8  ( 8  /  10.08  /  36 )
113  =>  0.01
........a
step :  9  ( 8  /  10.06  /  36 )
114  =>  0.01
.........a
step :  10  ( 8  /  10.03  /  36 )
115  =>  0.01
........a
step :  9  ( 8  /  10.0  /  36 )
116  =>  0.01
.........a
step :  10  ( 8  /  9.98  /  36 )
117  =>  0.01
.......

.......a
step :  8  ( 8  /  9.37  /  36 )
244  =>  0.01
.........a
step :  10  ( 8  /  9.38  /  36 )
245  =>  0.01
.......a
step :  8  ( 8  /  9.36  /  36 )
246  =>  0.01
.........a
step :  10  ( 8  /  9.36  /  36 )
247  =>  0.01
.........a
step :  10  ( 8  /  9.38  /  36 )
248  =>  0.01
.........a
step :  10  ( 8  /  9.38  /  36 )
249  =>  0.01
.......a
step :  8  ( 8  /  9.37  /  36 )
250  =>  0.01
........a
step :  9  ( 8  /  9.37  /  36 )
251  =>  0.01
.........a
step :  10  ( 8  /  9.37  /  36 )
252  =>  0.01
........a
step :  9  ( 8  /  9.36  /  36 )
253  =>  0.01
.........a
step :  10  ( 8  /  9.36  /  36 )
254  =>  0.01
.........a
step :  10  ( 8  /  9.37  /  36 )
255  =>  0.01
.........a
step :  10  ( 8  /  9.38  /  36 )
256  =>  0.01
.........a
step :  10  ( 8  /  9.38  /  36 )
257  =>  0.01
.........a
step :  10  ( 8  /  9.4  /  36 )
258  =>  0.01
.......a
step :  8  ( 8  /  9.38  /  36 )
259  =>  0.01
.........a
step :  10  ( 8  /  9.4  /  36 )
260  =>  0.01
.........a
step

.........a
step :  10  ( 8  /  9.48  /  36 )
386  =>  0.01
.........a
step :  10  ( 8  /  9.48  /  36 )
387  =>  0.01
.........a
step :  10  ( 8  /  9.49  /  36 )
388  =>  0.01
........a
step :  9  ( 8  /  9.48  /  36 )
389  =>  0.01
........a
step :  9  ( 8  /  9.48  /  36 )
390  =>  0.01
.........a
step :  10  ( 8  /  9.49  /  36 )
391  =>  0.01
........a
step :  9  ( 8  /  9.48  /  36 )
392  =>  0.01
........a
step :  9  ( 8  /  9.46  /  36 )
393  =>  0.01
.........a
step :  10  ( 8  /  9.46  /  36 )
394  =>  0.01
........a
step :  9  ( 8  /  9.46  /  36 )
395  =>  0.01
........a
step :  9  ( 8  /  9.46  /  36 )
396  =>  0.01
........a
step :  9  ( 8  /  9.46  /  36 )
397  =>  0.01
........a
step :  9  ( 8  /  9.45  /  36 )
398  =>  0.01
.........a
step :  10  ( 8  /  9.45  /  36 )
399  =>  0.01
........a
step :  9  ( 8  /  9.44  /  36 )
400  =>  0.01
........a
step :  9  ( 8  /  9.44  /  36 )
401  =>  0.01
.......a
step :  8  ( 8  /  9.42  /  36 )
402  =>  0.01
........a
step :  9 

In [None]:
#eg = 0.05
env.close()

In [None]:
a = 0
b = '<' if a<0 else '>' if a>0 else '0'
print(b)

In [None]:
state = env.reset() #[2:]
while True :
    env.render()

    Qs = agent([state])[0];
    action = np.argmax(Qs);
    Q = Qs[action]
    print(state, " ", Qs.numpy(), " ", action)
    drawAgent(agent, state[0], state[1])
    
    state, rewards, done, info = env.step(action)
    input()
    if done :
        break

In [None]:
memory = [] #{ "states":[], "masks":[], "actions":[], "values":[]}
loss = []
accuracy = []

state = env.reset()
for step in range(1000):
    env.render()
    #print(state)
    if (random.random() > eg) :
        c0 +=1
        Qs = agent([state])[0];
        action = np.argmax(Qs);
        Q = Qs[action]
        #print("State: ", state, " Qs: ", Qs.numpy(), " Action: ", action, " Q: ", Q.numpy())
    else :
        c1 +=1
        Qs = agent([state])[0];
        action = random.randrange(2)
        Q = Qs[action]
        #print("State: ", state, " Action: ", action, " Q: ", Q.numpy())

    mask = np.zeros(2)
    mask[action] = 1
    mem = { "s": state, "a": action, "m": mask}
    print(state, " ", Qs.numpy(), " ", action)
    drawAgent(agent)
    state, rewards, done, info = env.step(action)

    if done : rewards = 0.0
    mem["r"] = rewards
    mem["sp"] = state
    mem["d"] = done
    memory.append(mem)
    
    if done : 
        print(memory)
        break


batch = memory #random.sample(memory, BATCH_SIZE)
bs = [ b["s"] for b in batch ]
bm = [ b["m"] for b in batch ]
by = [ b["m"] * ( b['r'] if b['d'] else b['r'] + a * np.max(agent([b['sp']])[0]) )
       for b in batch ]

print (bs, bm, by)
#history = coach.fit(
#    [np.array(bs), np.array(bm)],
#    np.array(by),
#    verbose=0)
#loss.append(history.history['loss'])
#accuracy.append(history.history['accuracy'])

drawAgent(agent)
