In [1]:
%pylab inline
import tensorflow as tf
import numpy as np
import gym
from tqdm import tqdm, trange
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import *
np.set_printoptions(suppress=True)
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''

Populating the interactive namespace from numpy and matplotlib


In [2]:
# Build the CartPole environment the agent acts in, and display its
# observation and action spaces as the cell output.
env = gym.make("CartPole-v0")
spaces = (env.observation_space, env.action_space)
spaces

(Box(4,), Discrete(2))

In [3]:
# Here we create the 3 functions that make up MuZero.
# Each network gets a thin `*t` wrapper that runs it on a single (unbatched)
# example: it adds a batch axis, calls predict(), and strips the batch axis.

S_DIM = 4  # size of the learned hidden state s_k

# h: representation function -- observation -> hidden state
# MuZero defines s_0 = h(o_1...o_t); this version only sees the current observation.
x = o_0 = Input(env.observation_space.shape)
x = Dense(64)(x)
x = Activation('elu')(x)
s_0 = Dense(S_DIM, name='s_0')(x)
h = Model(o_0, s_0, name="h")
def ht(o_0):
  """Return h's hidden state (length S_DIM) for one observation `o_0` (array-like)."""
  return h.predict(np.asarray(o_0)[None])[0]

# g: dynamics function -- (previous hidden state, one-hot action) -> (reward, next state)
# r_k, s_k = g(s_k-1, a_k)
s_km1 = Input((S_DIM,))
a_k = Input((env.action_space.n,))
x = Concatenate()([s_km1, a_k])
x = Dense(64)(x)
x = Activation('elu')(x)
x = Dense(64)(x)
x = Activation('elu')(x)
s_k = Dense(S_DIM, name='s_k')(x)
r_k = Dense(1, name='r_k')(x)
g = Model([s_km1, a_k], [r_k, s_k], name="g")
# Not needed for predict() or for training through `mu`; kept for standalone fit().
g.compile('adam', 'mse')
def gt(s_km1, a_k):
  """Run g on one (hidden state, one-hot action) pair.

  Both arguments may be any array-like. Returns (r_k, s_k) with the batch
  axis removed: r_k has shape (1,), s_k has shape (S_DIM,).
  """
  r_k, s_k = g.predict([np.asarray(s_km1)[None], np.asarray(a_k)[None]])
  return r_k[0], s_k[0]

# f: prediction function -- hidden state -> (policy, value)
# p_k, v_k = f(s_k)
x = s_k = Input((S_DIM,))
x = Dense(32)(x)
x = Activation('elu')(x)
p_k = Dense(env.action_space.n)(x)
p_k = Activation('softmax', name='p_k')(p_k)
v_k = Dense(1, name='v_k')(x)
f = Model(s_k, [p_k, v_k], name="f")
# Not needed for predict() or for training through `mu`; kept for standalone fit().
f.compile('adam', 'mse')
def ft(s_k):
  """Run f on one hidden state (array-like).

  Returns (p_k, v_k) with the batch axis removed: p_k is the softmax policy
  over env actions, v_k has shape (1,).
  """
  p_k, v_k = f.predict(np.asarray(s_k)[None])
  return p_k[0], v_k[0]

In [4]:
# Here we create the MuZero function: encode the observation with h, then
# unroll the dynamics g for K steps, predicting (policy, value) with f at
# every hidden state.
#
# Inputs (flat, in order):  o_0, a_0, ..., a_{K-1}   (actions are one-hot)
# Outputs (in order):       p_0, v_0, r_1, p_1, v_1, ..., r_K, p_K, v_K
# i.e. 2 + 3*K tensors. Callers (search, training) rely on this ordering.

K = 5  # number of unrolled dynamics steps

# represent
o_0 = Input(env.observation_space.shape, name="o_0")
s_km1 = h(o_0)

a_all, mu_all = [], []

# run f on the first (represented) state
p_km1, v_km1 = f(s_km1)
mu_all += [p_km1, v_km1]

# rollout with dynamics
for k in range(K):
  a_k = Input((env.action_space.n,), name="a_%d" % k)
  r_k, s_k = g([s_km1, a_k])

  # predict
  p_k, v_k = f(s_k)

  # store
  a_all.append(a_k)
  mu_all += [r_k, p_k, v_k]
  s_km1 = s_k

# TODO: targets still to be built properly:
#   - policy from search
#   - values from sum of rewards + last state value (real state?)
#   - rewards
# Inputs are declared flat to match how predict()/fit() are called:
# [observations] + [one array per action step].
mu = Model([o_0] + a_all, mu_all, name="mu")
mu.compile('adam', 'mse')
mu.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
o_0 (InputLayer)                [(None, 4)]          0                                            
__________________________________________________________________________________________________
h (Model)                       (None, 4)            580         o_0[0][0]                        
__________________________________________________________________________________________________
a_0 (InputLayer)                [(None, 2)]          0                                            
__________________________________________________________________________________________________
g (Model)                       [(None, 1), (None, 4 4933        h[1][0]                          
                                                                 a_0[0][0]                    

In [5]:
def to_one_hot(x, n):
  """Encode index `x` as a length-`n` float vector: 1.0 at `x`, 0.0 elsewhere."""
  encoded = np.zeros(n, dtype=float)
  encoded[x] = 1.0
  return encoded

# Enumerate the whole action space: every length-K sequence of discrete
# actions (N_ACTIONS**K sequences), for brute-force search through `mu`.
import itertools
N_ACTIONS = env.action_space.n
aopts = list(itertools.product(range(N_ACTIONS), repeat=K))
# One-hot encode -> shape (len(aopts), K, N_ACTIONS), then put the step axis
# first so aoptss[k] is the batch of actions for step k, matching mu's a_k inputs.
aoptss = np.array([[to_one_hot(x, N_ACTIONS) for x in aa] for aa in aopts])
aoptss = aoptss.swapaxes(0, 1)
aoptss = [aoptss[k] for k in range(K)]

# TODO: this is naive search, replace with MCTS
def search(o_0):
  """Brute-force search over every enumerated action sequence from `o_0`.

  Runs `mu` once per sequence in `aopts` (as a single batch). Each rollout is
  scored by the value that f predicts after the final step (mu's last output).
  The scores of all rollouts that share the same first action are summed, and
  the sums are turned into a probability vector with a softmax.

  Args:
    o_0: a single observation (array-like).
  Returns:
    float64 array with one probability per action, summing to 1.
  """
  # Tile the observation once per action sequence so one predict call
  # evaluates every rollout.
  o_0s = np.repeat(np.asarray(o_0)[None], len(aopts), axis=0)
  ret = mu.predict([o_0s] + aoptss)
  # Last output of mu is v_K with shape (len(aopts), 1).
  v_s = np.asarray(ret[-1], dtype=np.float64)[:, 0]

  # Group the value with the first action of the rollout that produced it.
  n_actions = aoptss[0].shape[-1]
  first_actions = np.array([seq[0] for seq in aopts])
  av = np.zeros(n_actions)
  np.add.at(av, first_actions, v_s)

  # Numerically stable softmax: shifting by the max avoids exp() overflow
  # (and the resulting NaN probabilities) when the summed values are large.
  z = np.exp(av - av.max())
  return z / z.sum()

# Sanity check on a fresh episode: the search-derived policy shown next to
# the raw (policy, value) prediction f(h(o)) for the same starting state.
env.reset()
start_state = env.state
search(start_state), ft(ht(start_state))

(array([0.51951649, 0.48048351]),
 (array([0.5006866 , 0.49931344], dtype=float32),
  array([0.00405711], dtype=float32)))

In [6]:
gamma = 0.95  # discount factor for the n-step value targets

def bstack(bb):
  """Transpose a batch of per-sample lists into a list of stacked arrays.

  bb[i][j] is the j-th input (or target) of sample i. The j-th element of
  the result is np.array([bb[0][j], bb[1][j], ...]), i.e. one batched array
  per model input/output, the layout mu.fit() is called with.

  Returns [] for an empty batch. Raises ValueError if samples do not all
  have the same number of entries, since that would silently misalign
  columns.
  """
  if not bb:
    return []
  width = len(bb[0])
  for i, sample in enumerate(bb):
    if len(sample) != width:
      raise ValueError("sample %d has %d entries, expected %d" % (i, len(sample), width))
  return [np.array(col) for col in zip(*bb)]

# Self-play + training: each "epoch" collects BATCH_SIZE rollouts of K steps,
# builds MuZero targets for them, and runs one mu.fit() pass on the batch.
N_EPOCHS = 20
BATCH_SIZE = 16

env.reset()
sc = 0      # real env steps taken in the current episode
scs = []    # lengths of finished episodes
vs = []     # value targets v_0..v_{K-1}, kept for plotting
rs = []     # reward targets r_1..r_K, kept for plotting
for _ in range(N_EPOCHS):
  # accumulate a batch
  X, Y = [], []
  for _ in range(BATCH_SIZE):
    # inputs: starting observation, then the K one-hot actions taken
    x = [np.copy(env.state)]
    # targets in mu's output order: [p_0, v_0, r_1, p_1, v_1, r_2, ..., r_K, p_K, v_K]
    y = []
    done = False
    for _ in range(K):
      p_0 = search(env.state)
      a_1 = np.random.choice(len(p_0), p=p_0)
      if done:
        # The episode already ended inside this rollout: treat the terminal
        # state as absorbing (zero reward) instead of stepping a finished env.
        r_1 = 0.0
      else:
        _, r_1, done, _ = env.step(a_1)
        sc += 1

      # the value target (None) is filled in below, once the bootstrap is known
      y += [p_0, None, r_1]

      # append the real actions taken
      x.append(to_one_hot(a_1, len(p_0)))

    p_k = search(env.state)
    # There is no future return after a terminal state, so don't bootstrap
    # from f there. NOTE(review): if gym's time limit also reports done,
    # truncation is treated as terminal as well.
    if done:
      v_k = np.zeros(1, dtype=np.float32)
    else:
      _, v_k = ft(ht(env.state))
    y += [p_k, v_k]

    # fix values back to front: v_j = r_{j+1} + gamma * v_{j+1}, j = K-1 .. 0
    for i in range(K):
      y[-4 - i*3] = y[-3 - i*3] + gamma * y[-1 - i*3]

    vs += y[1::3][0:K]
    rs += y[2::3]

    X.append(x)
    Y.append(y)
    if done:
      env.reset()
      scs.append(sc)
      sc = 0

  ll = mu.fit(bstack(X), bstack(Y), verbose=1)
  loss = ll.history['loss']
  print(loss)
plot(vs)
plot(rs)



[42.358123779296875]
[43.68019104003906]
[41.97210693359375]
[38.11468505859375]


KeyboardInterrupt: 

In [9]:
# Can the agent act? Play one episode (at most 100 steps), printing the
# search policy, f's predicted value and the sampled action at each step.
env.reset()
for sn in range(100):
  obs = env.state
  p_0 = search(obs)
  v_0 = ft(ht(obs))[1]
  a_1 = np.random.choice([0,1], p=p_0)
  print(p_0, v_0, a_1)
  env.render()
  _, r, done, _ = env.step(a_1)
  if not done:
    continue
  print("DONE", sn)
  break

[0.51052331 0.48947669] [0.11920644] 0
[0.51062385 0.48937615] [0.23886201] 1
[0.51053848 0.48946152] [0.12409505] 0
[0.51063529 0.48936471] [0.2439914] 1
[0.51055707 0.48944293] [0.13041276] 0
[0.51064786 0.48935214] [0.25046456] 1
[0.51058009 0.48941991] [0.13830557] 1
[0.51037237 0.48962763] [0.02507184] 1
[0.51000986 0.48999014] [-0.0803359] 0
[0.51044351 0.48955649] [0.04015018] 0
[0.51068066 0.48931934] [0.1639403] 1
[0.51047768 0.48952232] [0.04766529] 1
[0.51009116 0.48990884] [-0.06143287] 0
[0.51052275 0.48947725] [0.0574681] 0
[0.51074462 0.48925538] [0.17879365] 0
[0.51078319 0.48921681] [0.29218638] 0
[0.51065427 0.48934573] [0.38931048] 1
[0.51074408 0.48925592] [0.2893029] 0
[0.51061515 0.48938485] [0.3878355] 1
[0.5107167 0.4892833] [0.29189497] 0
[0.51058155 0.48941845] [0.39080775] 0
[0.51039156 0.48960844] [0.46772557] 0
[0.51022316 0.48977684] [0.5239818] 1
[0.51032265 0.48967735] [0.46657318] 0
[0.51016626 0.48983374] [0.52139705] 1
[0.51026057 0.48973943] [0.47126