In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import gym
from collections import deque
import random
import time
import yaml

In [None]:
# TO DO:
# add gradient clipping?

# TO TEST:
# enforcing norm constraint on P
# how to do exploration
# whether this works at all?

class klqr:
    # not currently doing value updates at varying rates
    # not currently doing double Q learning (what would this look like?)
    
    def __init__(self,config,sess):
        self.sess = sess
        
        self.x_dim = config['x_dim']
        self.z_dim = config['z_dim']
        self.a_dim = config['a_dim']
        self.lr = config['lr']
        self.horizon = config['horizon']
        self.gamma = config['discount_rate']

        
        ou_theta = config['ou_theta']
        ou_sigma = config['ou_sigma']
        self.config = config
        
        # Ornstein-Uhlenbeck noise for exploration -- code from Yuke Zhu
        self.noise_var = tf.Variable(tf.zeros([self.a_dim,1]))
        noise_random = tf.random_normal([self.a_dim,1], stddev=ou_sigma)
        self.noise = self.noise_var.assign_sub((ou_theta) * self.noise_var - noise_random)

        self.max_riccati_updates = config['max_riccati_updates']
        self.train_batch_size = config['train_batch_size']
        self.replay_buffer = ReplayBuffer(buffer_size=config['replay_buffer_size'])
        
        self.experience_count = 0
        
    def build_model(self):        

        with tf.variable_scope('model',reuse=tf.AUTO_REUSE):
            
            self.x_ = tf.placeholder(tf.float32,shape=[None, self.x_dim])
            self.xp_ = tf.placeholder(tf.float32,shape=[None, self.x_dim])
            self.a_ = tf.placeholder(tf.float32,shape=[None, self.a_dim])
            self.r_ = tf.placeholder(tf.float32,shape=[None])

            self.z = self.encoder(self.x_)
            self.zp = self.encoder(self.xp_)

            print('z shape:', self.z.get_shape())

            #init R

            self.R_asym = tf.get_variable('R_asym',shape=[self.a_dim,self.a_dim])
    #         self.R_asym = tf.Variable(np.random.rand(self.a_dim,self.a_dim) - 0.5)

            # working with Ra.T Ra so that inner product is norm(Rx) and not norm(R.T x)
            self.R = tf.matmul(tf.transpose(self.R_asym),self.R_asym)

            #init Q -- shape: z_dim * z_dim
            self.Q_asym = tf.get_variable('Q_asym',shape=[self.z_dim,self.z_dim])
            self.Q = tf.matmul(tf.transpose(self.Q_asym),self.Q_asym)

            #init P -- shape: z_dim * z_dim
            self.P = tf.get_variable('P_asym',shape=[self.z_dim,self.z_dim],trainable=False,initializer=tf.initializers.identity)
            self.P_asym = tf.transpose(tf.cholesky(self.P)) #this might need to have the transpose removed?

            #init B -- shape: z_dim * u_dim
            self.B = tf.get_variable('B',shape=[self.z_dim,self.a_dim])
    #         self.B = tf.Variable(np.random.rand(self.z_dim,self.u_dim) - 0.5)

            #init A -- shape: z_dim * z_dim
            self.A = tf.get_variable('A',shape=[self.z_dim,self.z_dim])
    #         self.A = tf.Variable(np.random.rand(self.z_dim,self.z_dim) - 0.5)

            #define K -- shape: u_dim * z_dim
            term1 = tf.matrix_inverse(self.R + tf.matmul(tf.matmul(tf.transpose(self.B),self.Q),self.B))
            term2 = tf.matmul(tf.matmul(tf.transpose(self.B),self.P),self.A)
            self.K = tf.matmul(term1,term2)
            self.policy_action = tf.transpose(tf.matmul(self.K,tf.transpose(self.z)))

            #make reward negative to convert to cost
            self.bootstrapped_value = -self.r_ + self.gamma*tf.square(tf.norm(tf.transpose(tf.matmul(self.P_asym,tf.transpose(self.zp))),axis=1))

            action_cost = tf.square(tf.norm(tf.transpose(tf.matmul(self.R_asym,tf.transpose(self.a_))),axis=1))#can simplify this by taking norm on other axis
            state_cost = tf.square(tf.norm(tf.transpose(tf.matmul(self.Q_asym,tf.transpose(self.z))),axis=1)) 
            self.PABK = tf.matmul(self.P_asym, self.A + tf.matmul(self.B,self.K))
            Vzp = tf.square(tf.norm(tf.transpose(tf.matmul(self.PABK,tf.transpose(self.zp))),axis=1))
            self.Qsa = action_cost + state_cost + Vzp

            self.td_loss = tf.nn.l2_loss(tf.reduce_mean(self.bootstrapped_value - self.Qsa))
            #can add regularization via P, dynamics, sparsity, etc
            self.loss = self.td_loss 
            global_step = tf.Variable(0, trainable=False, name='global_step')
            optimizer = tf.train.AdamOptimizer(self.lr)
            self.train_op = optimizer.minimize(self.loss, global_step=global_step)
    
            self.sess.run(tf.global_variables_initializer())

    
    def update_model(self):        
        #this function is mostly taken from Yuke's code
#         print('updating model')
        if self.replay_buffer.count() < self.train_batch_size:
            return
        
        batch           = self.replay_buffer.getBatch(self.train_batch_size)
        
        states          = np.zeros((self.train_batch_size, self.x_dim))
        rewards         = np.zeros((self.train_batch_size))
        actions         = np.zeros((self.train_batch_size, self.a_dim))
        next_states     = np.zeros((self.train_batch_size, self.x_dim))

        for k, (s0, a, r, s1, done) in enumerate(batch):
            #currently throwing away done states; should fix this
            states[k] = s0
            rewards[k] = r
            actions[k] = a
            next_states[k] = s1
            # check terminal state
#             if not done:
#                 next_states[k] = s1
#                 next_state_mask[k] = 1

        cost, _ = self.sess.run([self.loss, self.train_op],
        {
        self.x_:  states,
        self.xp_: next_states,
        self.a_:  actions,
        self.r_:  rewards
        })
    
        #possibly update target via Riccati recursion? or do standard target separation? 
    
    def update_P(self):
        print('updating P')
        self.P = tf.identity(self.Q)
        for k in range(self.max_riccati_updates):
            #do Riccati backup in tensorflow oh god why
            ABK = self.A + tf.matmul(self.B,self.K)
            APA = tf.matmul(tf.matmul(tf.transpose(ABK),self.P),ABK) #
            self.P = self.Q + tf.matmul(tf.matmul(tf.transpose(self.K),self.R),self.K) + self.gamma*APA
        
        self.P_asym = tf.transpose(tf.cholesky(self.P))
        print(sess.run(self.P))
            #TODO add a termination criterion for norm of Riccati update difference?
        
    def pi(self,x,explore=True):
        self.experience_count += 1
        x = np.reshape(x,(1,3))
        
        a,w = self.sess.run([self.policy_action,self.noise], {self.x_: x})
        
        a = a + w if explore else a
        # TODO check the dimension of the output of this
        return [a[0,0]]
        
    def store_experience(self,s,a,r,sp,done):
        # currently storing experience for every iteration
        self.replay_buffer.add(s, a, r, sp, done)
    
    def encoder(self,x,name="encoder",batch_norm=False):
        layer_sizes = self.config['encoder_layers']
        with tf.variable_scope(name,reuse=tf.AUTO_REUSE):
            inp = x
            for units in layer_sizes: 
                inp = tf.layers.dense(inputs=inp, units=units,activation=tf.nn.relu)

            z = tf.layers.dense(inputs=inp, units=self.z_dim,activation=None)

        if batch_norm:
            z = tf.layers.batch_normalization(z)

        return z

class ReplayBuffer:
    # taken from Yuke Zhu's Q learning implementation
    
    def __init__(self, buffer_size):

        self.buffer_size = buffer_size
        self.num_experiences = 0
        self.buffer = deque()

    def getBatch(self, batch_size):
        # random draw N
        return random.sample(self.buffer, batch_size)

    def size(self):
        return self.buffer_size

    def add(self, state, action, reward, next_action, done):
        new_experience = (state, action, reward, next_action, done)
        if self.num_experiences < self.buffer_size:
          self.buffer.append(new_experience)
          self.num_experiences += 1
        else:
          self.buffer.popleft()
          self.buffer.append(new_experience)

    def count(self):
        # if buffer is full, return buffer size
        # otherwise, return experience counter
        return self.num_experiences

    def erase(self):
        self.buffer = deque()
        self.num_experiences = 0

In [None]:
with open('config.yml','r') as ymlfile:
    config = yaml.load(ymlfile)
    
tf.reset_default_graph()
sess = tf.InteractiveSession(config=tf.ConfigProto(log_device_placement=True))

In [4]:
# simulates the agent acting in env, yielding every N steps
# (decouples episode reseting mechanics from the training alg)
def experience_generator(agent, env, N):
    s = env.reset()
    n_steps = 0
    n_eps = 0
    last_cum_rew = 0
    cum_rew = 0
    while True:
        n_steps += 1
        a = agent.pi(s)
        sp, r, done,_ = env.step(a)
        cum_rew += r
        if done:
            n_eps += 1
            last_cum_rew = cum_rew
            cum_rew = 0
            s = env.reset()
        else:
            agent.store_experience(s, a, r, sp, done)
            s = sp

        if n_steps % N == 0:
            yield (n_steps, n_eps, last_cum_rew)



def train_agent(agent, env,
                max_timesteps=0, max_episodes=0, max_iters=0, max_seconds=0, # time constraint
                n_transitions_between_updates=100,
                n_optim_steps_per_update=100,
                n_optim_steps_per_p_update=5,
                ):

    # run an episode, and feed data to model
    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()

    assert sum([max_iters>0, max_timesteps>0, max_episodes>0, max_seconds>0])==1, "Only one time constraint permitted"

    exp_gen = experience_generator(agent, env, n_transitions_between_updates)

    while True:
        iters_so_far += 1
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        print("********** Iteration %i ************"%iters_so_far)

        # gather experience
        episodes_so_far, timesteps_so_far, last_cum_rew = exp_gen.__next__()

        # optimize the model from collected data:
        for i in range(n_optim_steps_per_update):
            agent.update_model()

            if (i+1) % n_optim_steps_per_p_update == 0:
                agent.update_P()

        print("\tLast Episode Reward: %d"%last_cum_rew)
        # add other logging stuff here
        # add saving checkpoints here


In [5]:
env = gym.make('Pendulum-v0')
agent = klqr(config,sess)
agent.build_model()
train_agent(agent,env,max_timesteps=100)

[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
z shape: (?, 6)
********** Iteration 1 ************
updating P
[[ 9.9125219e+05  4.7940310e+06  1.9061870e+06 -1.3809962e+05
  -2.1658365e+06  1.5345481e+06]
 [ 4.7940310e+06  2.3185652e+07  9.2190070e+06 -6.6788562e+05
  -1.0474756e+07  7.4216310e+06]
 [ 1.9061868e+06  9.2190060e+06  3.6656378e+06 -2.6556341e+05
  -4.1649408e+06  2.9509655e+06]
 [-1.3810022e+05 -6.6788731e+05 -2.6556438e+05  1.9244092e+04
   3.0173681e+05 -2.1378564e+05]
 [-2.1658362e+06 -1.0474759e+07 -4.1649408e+06  3.0173594e+05
   4.7322635e+06 -3.3529265e+06]
 [ 1.5345478e+06  7.4216310e+06  2.9509652e+06 -2.1378511e+05
  -3.3529260e+06  2.3756352e+06]]
updating P
[[ 8.5861788e+05  4.1848048e+06  1.6668281e+06 -1.1383765e+05
  -1.8909962e+06  1.3476669e+06]
 [ 4.1848062e+06  2.0396358e+07  8.1239710e+

[[ 5.11512250e+05  2.57765300e+06  1.03521462e+06 -6.09846758e+04
  -1.16233812e+06  8.49988125e+05]
 [ 2.57765325e+06  1.29896150e+07  5.21678050e+06 -3.07305781e+05
  -5.85738900e+06  4.28336950e+06]
 [ 1.03521488e+06  5.21678000e+06  2.09512488e+06 -1.23418969e+05
  -2.35239575e+06  1.72025038e+06]
 [-6.09845391e+04 -3.07305312e+05 -1.23418633e+05  7.27548828e+03
   1.38573297e+05 -1.01332555e+05]
 [-1.16233838e+06 -5.85738900e+06 -2.35239600e+06  1.38573703e+05
   2.64126725e+06 -1.93149338e+06]
 [ 8.49988438e+05  4.28337100e+06  1.72025075e+06 -1.01332594e+05
  -1.93149362e+06  1.41245762e+06]]
updating P
[[ 5.11079188e+05  2.57563375e+06  1.03442106e+06 -6.09327266e+04
  -1.16141638e+06  8.49357312e+05]
 [ 2.57563375e+06  1.29802630e+07  5.21311100e+06 -3.07064344e+05
  -5.85311550e+06  4.28046200e+06]
 [ 1.03442106e+06  5.21311100e+06  2.09368575e+06 -1.23324359e+05
  -2.35071900e+06  1.71911100e+06]
 [-6.09327266e+04 -3.07064062e+05 -1.23324086e+05  7.26931494e+03
   1.38463203

[[ 5.0940844e+05  2.5679530e+06  1.0314028e+06 -6.0835613e+04
  -1.1578650e+06  8.4695506e+05]
 [ 2.5679535e+06  1.2945284e+07  5.1993970e+06 -3.0666319e+05
  -5.8368975e+06  4.2695840e+06]
 [ 1.0314031e+06  5.1993985e+06  2.0883132e+06 -1.2317091e+05
  -2.3443568e+06  1.7148532e+06]
 [-6.0835633e+04 -3.0666325e+05 -1.2317096e+05  7.2697930e+03
   1.3827186e+05 -1.0114053e+05]
 [-1.1578650e+06 -5.8368975e+06 -2.3443565e+06  1.3827180e+05
   2.6318015e+06 -1.9251120e+06]
 [ 8.4695512e+05  4.2695845e+06  1.7148531e+06 -1.0114039e+05
  -1.9251121e+06  1.4081865e+06]]
updating P
[[ 5.09530562e+05  2.56851075e+06  1.03161950e+06 -6.08407578e+04
  -1.15812525e+06  8.47128438e+05]
 [ 2.56851025e+06  1.29478000e+07  5.20037150e+06 -3.06681625e+05
  -5.83807650e+06  4.27036150e+06]
 [ 1.03161925e+06  5.20037150e+06  2.08868900e+06 -1.23177266e+05
  -2.34481325e+06  1.71515325e+06]
 [-6.08408633e+04 -3.06682250e+05 -1.23177531e+05  7.26924121e+03
   1.38281500e+05 -1.01145594e+05]
 [-1.15812500e

[[ 5.13086188e+05  2.58486975e+06  1.03800219e+06 -6.09500234e+04
  -1.16574175e+06  8.52264500e+05]
 [ 2.58486950e+06  1.30223810e+07  5.22938000e+06 -3.07047031e+05
  -5.87291700e+06  4.29365250e+06]
 [ 1.03800181e+06  5.22938050e+06  2.09996075e+06 -1.23301516e+05
  -2.35838050e+06  1.72419638e+06]
 [-6.09502070e+04 -3.07047375e+05 -1.23301969e+05  7.24502197e+03
   1.38474781e+05 -1.01235180e+05]
 [-1.16574162e+06 -5.87291700e+06 -2.35838050e+06  1.38474359e+05
   2.64860900e+06 -1.93637888e+06]
 [ 8.52264312e+05  4.29365300e+06  1.72419588e+06 -1.01235000e+05
  -1.93637888e+06  1.41567638e+06]]
updating P
[[ 5.13949469e+05  2.58883350e+06  1.03955044e+06 -6.09792461e+04
  -1.16758675e+06  8.53506562e+05]
 [ 2.58883425e+06  1.30404130e+07  5.23640350e+06 -3.07148719e+05
  -5.88133800e+06  4.29927200e+06]
 [ 1.03955056e+06  5.23640300e+06  2.10269325e+06 -1.23337406e+05
  -2.36166350e+06  1.72638075e+06]
 [-6.09789023e+04 -3.07147156e+05 -1.23337000e+05  7.23956250e+03
   1.38526562

[[ 5.13207562e+05  2.58555975e+06  1.03827688e+06 -6.09537070e+04
  -1.16604425e+06  8.52511062e+05]
 [ 2.58555950e+06  1.30262540e+07  5.23092300e+06 -3.07074469e+05
  -5.87461900e+06  4.29502700e+06]
 [ 1.03827656e+06  5.23092250e+06  2.10057500e+06 -1.23312547e+05
  -2.35905800e+06  1.72474325e+06]
 [-6.09538008e+04 -3.07074781e+05 -1.23312812e+05  7.24407617e+03
   1.38485953e+05 -1.01246391e+05]
 [-1.16604412e+06 -5.87461900e+06 -2.35905800e+06  1.38485891e+05
   2.64935625e+06 -1.93698338e+06]
 [ 8.52510938e+05  4.29502550e+06  1.72474312e+06 -1.01246531e+05
  -1.93698312e+06  1.41616088e+06]]
updating P
[[ 5.13379031e+05  2.58636750e+06  1.03859425e+06 -6.09615000e+04
  -1.16641612e+06  8.52768375e+05]
 [ 2.58636700e+06  1.30300410e+07  5.23240900e+06 -3.07107344e+05
  -5.87636350e+06  4.29622900e+06]
 [ 1.03859425e+06  5.23240850e+06  2.10115775e+06 -1.23325031e+05
  -2.35974325e+06  1.72521475e+06]
 [-6.09617656e+04 -3.07107500e+05 -1.23325070e+05  7.24387939e+03
   1.38501844

[[ 5.13555719e+05  2.58737800e+06  1.03899575e+06 -6.10217422e+04
  -1.16684675e+06  8.53116688e+05]
 [ 2.58737800e+06  1.30357420e+07  5.23467700e+06 -3.07424594e+05
  -5.87880950e+06  4.29818450e+06]
 [ 1.03899544e+06  5.23467650e+06  2.10205975e+06 -1.23451703e+05
  -2.36071550e+06  1.72599300e+06]
 [-6.10215938e+04 -3.07424344e+05 -1.23451680e+05  7.25535449e+03
   1.38641734e+05 -1.01362555e+05]
 [-1.16684675e+06 -5.87881050e+06 -2.36071600e+06  1.38641641e+05
   2.65120700e+06 -1.93837888e+06]
 [ 8.53116750e+05  4.29818600e+06  1.72599350e+06 -1.01362688e+05
  -1.93837912e+06  1.41721250e+06]]
updating P
[[ 5.13587031e+05  2.58754125e+06  1.03906081e+06 -6.10300391e+04
  -1.16691838e+06  8.53170688e+05]
 [ 2.58754075e+06  1.30365890e+07  5.23501350e+06 -3.07467562e+05
  -5.87918250e+06  4.29846500e+06]
 [ 1.03906044e+06  5.23501450e+06  2.10219425e+06 -1.23468750e+05
  -2.36086375e+06  1.72610500e+06]
 [-6.10299688e+04 -3.07467094e+05 -1.23469016e+05  7.25675391e+03
   1.38660547

[[ 5.32374062e+05  2.67372325e+06  1.07261475e+06 -6.18930273e+04
  -1.20699038e+06  8.80083500e+05]
 [ 2.67372375e+06  1.34282490e+07  5.38699850e+06 -3.10832438e+05
  -6.06186800e+06  4.42005200e+06]
 [ 1.07261462e+06  5.38699850e+06  2.16110250e+06 -1.24697477e+05
  -2.43183500e+06  1.77318812e+06]
 [-6.18932383e+04 -3.10831719e+05 -1.24697352e+05  7.20018311e+03
   1.40318094e+05 -1.02311266e+05]
 [-1.20699038e+06 -6.06186750e+06 -2.43183500e+06  1.40318344e+05
   2.73649100e+06 -1.99532838e+06]
 [ 8.80083188e+05  4.42005250e+06  1.77318825e+06 -1.02311062e+05
  -1.99532862e+06  1.45490962e+06]]
updating P
[[ 5.33451000e+05  2.67863650e+06  1.07451938e+06 -6.19420195e+04
  -1.20928025e+06  8.81610375e+05]
 [ 2.67863650e+06  1.34504400e+07  5.39556800e+06 -3.11018062e+05
  -6.07224500e+06  4.42690400e+06]
 [ 1.07451950e+06  5.39556800e+06  2.16440675e+06 -1.24764523e+05
  -2.43584825e+06  1.77582725e+06]
 [-6.19422031e+04 -3.11019250e+05 -1.24765156e+05  7.19739893e+03
   1.40411391

KeyboardInterrupt: 