In [None]:
import sys
import pickle
import threading
if sys.version_info.major == 2:
    from Queue import Queue
else:
    from queue import Queue

import theano
import theano.tensor as T

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

from environment import Environment
from priority_buffer import PriorityBuffer

In [None]:
def random_in_range(a, b):
    """Draw a float uniformly from ``[a, b)`` using NumPy's global RNG.

    Raises ValueError when ``b`` is smaller than ``a``; ``a == b`` returns ``a``.
    """
    if a > b:
        raise ValueError('b should not be less than a')
    span = b - a
    return a + span * np.random.rand()

# Build an environment with goal=(0.0, 0.2) and plot its initial state.
e = Environment(goal=(0.0, 0.2))
e.plot()
plt.show()

def create_state_vector(eef_x, eef_y, goal_x, goal_y):
    """Pack end-effector and goal coordinates into a single float32 row.

    Returns an array of shape (1, 4) laid out as
    ``[eef_x, eef_y, goal_x, goal_y]``.
    """
    row = [eef_x, eef_y, goal_x, goal_y]
    return np.array(row, dtype=np.float32, ndmin=2)

In [None]:
# Manual check: apply the move (0.04, 0.00) to `e`, print the reward that
# e.move returns, then plot the resulting environment.
print(e.move(0.04, 0.00))
e.plot()
plt.show()

In [None]:
class NAFNet():
    """Normalized Advantage Function (NAF) network built with Theano.

    A state row ``[eef_x, eef_y, goal_x, goal_y]`` (the column order produced
    by ``create_state_vector``) goes through two fully connected ReLU layers
    of ``hidden_size`` units that feed three heads:

    * ``V(s)``: state-value estimate, one column;
    * ``mu(s)``: greedy control, ``0.15 * tanh(.)``, so each component lies
      in ``[-0.15, 0.15]``;
    * ``L(s)``: ``control_size x control_size`` lower-triangular matrix with
      an ``exp``-activated (positive) diagonal, giving ``P(s) = L L^T``.

    The action value is ``Q(s, u) = V(s) - (u - mu(s))^T P(s) (u - mu(s))``.

    The compiled functions accept any number of rows. ``batch_size`` is kept
    as an attribute for callers that size their batches from it.

    Parameters
    ----------
    batch_size : int
        Nominal training batch size.
    gamma : float
        Discount factor in the bootstrapped target ``r + gamma * V(s')``.
    learning_rate : float
        Step size of the semi-gradient update used by :meth:`train`.
    learning_rate_sq : float
        Step size of the squared-TD-error descent used by
        :meth:`train_squared`.
    """

    def __init__(self, batch_size=256, gamma=0.5, learning_rate=5e-7,
                 learning_rate_sq=1e-8):
        # State input columns: eef x, eef y, goal x, goal y.
        self.batch_size = batch_size
        self.control_size = 2
        # Initial weights are N(0, 1) samples DIVIDED by this value, i.e. their
        # standard deviation is 1 / weight_std.
        self.weight_std = 4.0
        self.gamma = gamma
        self.hidden_size = 100
        self.learning_rate = learning_rate
        self.learning_rate_sq = learning_rate_sq
        self._setup()

    def _random_weight(self, n_in, n_out, name):
        """Shared ``(n_in, n_out)`` weight matrix: N(0, 1) / ``weight_std``."""
        return theano.shared(
            value=np.random.randn(n_in, n_out) / self.weight_std,
            name=name
        )

    @staticmethod
    def _zero_bias(size, name):
        """Shared zero bias vector of length ``size``."""
        return theano.shared(value=np.zeros(size), name=name)

    def _setup(self):
        """Build the symbolic graph and compile the Theano functions.

        Weights are drawn in a fixed order (W1, W2, WV, Wmu, WL_diag,
        WL_lower), so a given NumPy seed reproduces the initial parameters.
        """
        X = T.fmatrix('X')                # states, shape (rows, 4)
        U = T.fmatrix('U')                # controls, shape (rows, control_size)
        R = T.fmatrix('R')                # rewards, shape (rows, 1)
        Q_p = T.fmatrix('Q_p')            # bootstrapped targets, shape (rows, 1)
        td_error = T.fmatrix('td_error')  # precomputed TD errors, shape (rows, 1)
        n_rows = X.shape[0]               # symbolic: graph works for any row count

        # Two fully connected ReLU layers.
        W1 = self._random_weight(4, self.hidden_size, 'W1')
        b1 = self._zero_bias(self.hidden_size, 'b1')
        Y1 = T.nnet.nnet.relu(T.dot(X, W1) + b1)

        W2 = self._random_weight(self.hidden_size, self.hidden_size, 'W2')
        b2 = self._zero_bias(self.hidden_size, 'b2')
        Y2 = T.nnet.nnet.relu(T.dot(Y1, W2) + b2)

        # Value head V(s), shape (rows, 1).
        WV = self._random_weight(self.hidden_size, 1, 'WV')
        bv = theano.shared(np.float32(0.0))
        V = T.dot(Y2, WV) + bv

        # Policy head mu(s). NOTE(review): bmu is a single scalar shared by all
        # control components, unlike the per-component bl_diag.
        Wmu = self._random_weight(self.hidden_size, self.control_size, 'Wmu')
        bmu = theano.shared(np.float32(0.0))
        mu = 0.15 * T.tanh(T.dot(Y2, Wmu) + bmu)

        # Entries of L: positive diagonal (exp) and unconstrained strict lower triangle.
        WL_diag = self._random_weight(self.hidden_size, self.control_size, 'WL_diag')
        bl_diag = self._zero_bias(self.control_size, 'bl_diag')
        YL_diag = T.exp(T.dot(Y2, WL_diag) + bl_diag)

        l_lower_size = self.control_size * (self.control_size - 1) // 2
        WL_lower = self._random_weight(self.hidden_size, l_lower_size, 'WL_lower')
        bl_lower = self._zero_bias(l_lower_size, 'bl_lower')
        YL_lower = T.dot(Y2, WL_lower) + bl_lower

        # Scatter the entries into a (rows, c, c) tensor. The leading dimension
        # follows the input instead of a fixed batch size, and the zeros use the
        # dtype of the values written into them (avoids floatX mismatches).
        # tril_indices(k=-1) yields (1, 0), (2, 0), (2, 1), ... in row-major order.
        diag_idx = np.arange(self.control_size)
        lower_idx1, lower_idx2 = np.tril_indices(self.control_size, k=-1)
        L = T.zeros((n_rows, self.control_size, self.control_size), dtype=YL_diag.dtype)
        L = T.set_subtensor(L[:, diag_idx, diag_idx], YL_diag)
        L = T.set_subtensor(L[:, lower_idx1, lower_idx2], YL_lower)
        P = T.batched_dot(L, L.dimshuffle([0, 2, 1]))

        # Advantage A = -(u - mu)^T P (u - mu), one value per row -> (rows, 1).
        d = U - mu
        A = -T.sum(T.batched_dot(d, P) * d, axis=1, keepdims=True)

        Q = A + V

        # Bootstrapped target; _f_y evaluates it on next states: r + gamma * V(x').
        Q_p_out = R + self.gamma * V
        td_error_out = Q_p - Q

        # Semi-gradient objective: td_error is fed in as a constant, so
        # d(loss)/d(theta) = mean(td_error * dQ/dtheta); _f_train ASCENDS it.
        # TODO Add importance sampling weights
        loss = T.mean(td_error * Q)
        # Squared TD error; _f_train_sq DESCENDS it.
        loss_sq = T.mean((Q - Q_p) ** 2)

        self._f_v = theano.function([X], V, name='value', allow_input_downcast=True)
        self._f_q = theano.function([X, U], Q, name='q', allow_input_downcast=True)
        self._f_y = theano.function([X, R], Q_p_out, name='q_prime', allow_input_downcast=True)
        self._f_mu = theano.function([X], mu, name='mu', allow_input_downcast=True)
        self._td_error = theano.function([X, U, Q_p], td_error_out, name='td_error', allow_input_downcast=True)

        self.params = [W1, b1, W2, b2, WV, bv, Wmu, bmu, WL_diag, bl_diag, WL_lower, bl_lower]

        gradients = T.grad(loss, wrt=self.params)
        rate = np.float32(self.learning_rate)
        self._f_train = theano.function(
            inputs=[X, U, td_error],
            outputs=loss,
            updates=[(param, param + rate * grad) for param, grad in zip(self.params, gradients)],
            allow_input_downcast=True,
            name='f_train'
        )

        gradients_sq = T.grad(loss_sq, wrt=self.params)
        rate_sq = np.float32(self.learning_rate_sq)
        self._f_train_sq = theano.function(
            inputs=[X, U, Q_p],
            outputs=[loss_sq, td_error_out],
            updates=[(param, param - rate_sq * grad) for param, grad in zip(self.params, gradients_sq)],
            allow_input_downcast=True,
            name='f_train_sq'
        )

    def q_function(self, state_x, control_u):
        """Q(s, u) for each (state row, control row) pair; shape (rows, 1)."""
        return self._f_q(state_x, control_u)

    def value_function(self, state_x):
        """V(s) for each state row; shape (rows, 1)."""
        return self._f_v(state_x)

    def mu(self, state_x):
        """Greedy control for the FIRST state row only; shape (control_size,)."""
        return self._f_mu(state_x)[0, :]

    def train(self, x, x_prime, u, r):
        """One semi-gradient TD update on a batch.

        The target ``r + gamma * V(x_prime)`` and the TD error
        ``target - Q(x, u)`` are computed first and passed back in as
        constants, so the update only differentiates through ``Q(x, u)``.

        Parameters
        ==========
        x, x_prime : numpy.ndarray
            States and next states, one row per sample (4 columns).
        u : numpy.ndarray
            Controls, one row per sample (``control_size`` columns).
        r : numpy.ndarray
            Rewards, shape ``(rows, 1)``.

        Returns
        =======
        loss : float
        td_error : numpy.ndarray
            temporal-difference errors, one per row
        """
        q_prime = self._f_y(x_prime, r)
        td_error = self._td_error(x, u, q_prime)
        return (
            float(self._f_train(
                x,
                u,
                td_error
            )),
            td_error[:, 0]
        )

    def train_squared(self, x, x_prime, u, r):
        """One gradient-descent step on the mean squared TD error.

        Parameters
        ==========
        x, x_prime : numpy.ndarray
            States and next states, one row per sample (4 columns).
        u : numpy.ndarray
            Controls, one row per sample (``control_size`` columns).
        r : numpy.ndarray
            Rewards, shape ``(rows, 1)``.

        Returns
        =======
        loss : float
        td_error : numpy.ndarray
            temporal-difference errors, one per row (computed before the update)
        """
        q_prime = self._f_y(x_prime, r)
        loss, td_error = self._f_train_sq(
            x,
            u,
            q_prime
        )
        return float(loss), td_error[:, 0]
        
batch_size = 256
nn = NAFNet(batch_size=batch_size)

# Smoke test: one squared-TD update on random inputs checks that the compiled
# functions accept batches of the expected shapes. Parameter values are
# snapshotted and restored afterwards so this update on pure noise does not
# carry over into the network that is trained below.
x = np.random.randn(batch_size, 4).astype(np.float32)
xp = np.random.randn(batch_size, 4).astype(np.float32)
u = np.random.randn(batch_size, 2).astype(np.float32)
r = np.random.randn(batch_size, 1).astype(np.float32)
initial_param_values = [p.get_value() for p in nn.params]
err, errs = nn.train_squared(x, xp, u, r)
for param, value in zip(nn.params, initial_param_values):
    param.set_value(value)
err

In [None]:
def test_score(environment, policy, goal_x=0.0, goal_y=0.2):
    np.random.seed(0)
    n_tries = 64
    n_steps = 4
    score_total = 0.0
    for i in range(n_tries):
        environment.reset()
        environment.goal_x = goal_x
        environment.goal_y = goal_y
        score_run = 0.0
        for _ in range(n_steps):
            action = nn.mu(create_state_vector(
                environment.eef_x,
                environment.eef_y,
                environment.goal_x,
                environment.goal_y
            ))
            score_total += environment.move(*action)
    return score_total

# Baseline greedy-policy score of the (essentially untrained) network on a
# fresh Environment.
e = Environment()
test_score(e, nn)

In [None]:
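# Optional: restore (first three lines) or save (last three lines) the
# network parameters. Only unpickle files you trust -- pickle.load can
# execute arbitrary code.
#with open('params.pkl', 'rb') as fh:
#    param_values = pickle.load(fh)
#[a.set_value(b) for a, b in zip(nn.params, param_values)]
#param_values = [p.get_value() for p in nn.params]
#with open('params.pkl', 'wb') as f:
#    pickle.dump(param_values, f, protocol=2)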

In [None]:
def plot_q(nn, batch_size=256, filename=None):
    """Plot Q(s, u) as a surface over a square grid of controls for one state.

    The state is fixed to ``[0.1, 0.25, -0.1, 0.15]`` (eef_x, eef_y, goal_x,
    goal_y), and both control components sweep ``[-0.2, 0.2]``.

    Parameters
    ----------
    nn : NAFNet
        Network whose ``q_function`` is evaluated.
    batch_size : int
        Number of (state, control) rows evaluated. Must be a perfect square,
        because the grid has ``sqrt(batch_size)`` points per control axis; this
        keeps the state and control inputs the same length.
        NOTE(review): the network may additionally require this to equal
        ``nn.batch_size``.
    filename : str, optional
        If given, the figure is also saved there (parent directory created
        if missing).

    Raises
    ------
    ValueError
        If ``batch_size`` is not a perfect square.
    """
    import os

    n_grid = int(round(np.sqrt(batch_size)))
    if n_grid * n_grid != batch_size:
        raise ValueError('batch_size must be a perfect square, got {}'.format(batch_size))

    control_range = np.linspace(-0.2, 0.2, n_grid)
    states = np.tile(np.array([[0.1, 0.25, -0.1, 0.15]], dtype=np.float32), (batch_size, 1))
    # Row i * n_grid + j holds (u_1, u_2) = (control_range[i], control_range[j]).
    controls = np.array([[a, b] for a in control_range for b in control_range], dtype=np.float32)
    q = nn.q_function(states, controls)

    # indexing='ij' gives u1_grid[i, j] = control_range[i] and
    # u2_grid[i, j] = control_range[j], matching the row order above. The
    # default 'xy' indexing would put u_2 values on the axis labelled u_1.
    u1_grid, u2_grid = np.meshgrid(control_range, control_range, indexing='ij')
    zs = np.asarray(q, dtype=float).reshape(n_grid, n_grid)

    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_surface(u1_grid, u2_grid, zs, cmap='viridis', shade=True)
    ax.set_xlabel('$u_1$')
    ax.set_ylabel('$u_2$')
    # Raw string: '\m' is not a valid escape sequence in a normal literal.
    ax.set_zlabel(r'$Q(\mathbf{s}, \mathbf{u})$')
    plt.axis('equal')
    if filename:
        directory = os.path.dirname(filename)
        if directory and not os.path.isdir(directory):
            os.makedirs(directory)
        plt.savefig(filename)
    plt.show()
    
# Q(s, u) surface of the current network for one fixed state.
plot_q(nn)

In [None]:
def plot_v(nn, batch_size=256, x=0.0, y=0.2, filename=None):
    """Plot V(s) as a surface over end-effector positions for a fixed goal.

    The end effector sweeps a 16 x 16 grid with x in [-0.20, 0.20] and
    y in [0.10, 0.30]; the goal is fixed at ``(x, y)``.

    Parameters
    ----------
    nn : NAFNet
        Network whose ``value_function`` is evaluated.
    batch_size : int
        Unused; kept for call compatibility.
    x, y : float
        GOAL coordinates (not end-effector coordinates).
    filename : str, optional
        If given, the figure is also saved there (parent directory created
        if missing).
    """
    import os

    eef_xs = np.linspace(-0.20, 0.20, 16)
    eef_ys = np.linspace(0.10, 0.30, 16)
    # Row i * len(eef_ys) + j is the state (eef_xs[i], eef_ys[j], x, y).
    states = np.array([[a, b, x, y] for a in eef_xs for b in eef_ys], dtype=np.float32)
    v = nn.value_function(states)

    # indexing='ij' gives grid_x[i, j] = eef_xs[i] and grid_y[i, j] = eef_ys[j],
    # matching the row order above. The default 'xy' indexing transposes the
    # grid and draws each value at the wrong (x, y) position.
    grid_x, grid_y = np.meshgrid(eef_xs, eef_ys, indexing='ij')
    zs = np.asarray(v, dtype=float).reshape(len(eef_xs), len(eef_ys))

    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_surface(grid_x, grid_y, zs, cmap='viridis', shade=True)
    ax.set_xlabel('$x$')
    ax.set_ylabel('$y$')
    # Raw string: '\m' is not a valid escape sequence in a normal literal.
    ax.set_zlabel(r'$V(\mathbf{s})$')
    plt.axis('equal')
    if filename:
        directory = os.path.dirname(filename)
        if directory and not os.path.isdir(directory):
            os.makedirs(directory)
        plt.savefig(filename)
    plt.show()
    
# V(s) surface of the current network for the default goal.
plot_v(nn)

In [None]:
def plot_pi(nn, batch_size=256, goal_x=0.0, goal_y=0.2, filename=None):
    """Draw the greedy policy mu(s) as arrows over end-effector positions.

    For each point of a 16 x 16 grid (x in [-0.2, 0.2], y in [0.10, 0.30]) an
    arrow along ``nn.mu`` of the state ``(x, y, goal_x, goal_y)`` is drawn,
    shortened by a factor of 6 for display. The goal is marked with a red dot.

    Parameters
    ----------
    nn : NAFNet
        Network whose ``mu`` is evaluated.
    batch_size : int
        Unused; kept for call compatibility.
    goal_x, goal_y : float
        Goal position used for every state.
    filename : str, optional
        If given, the figure is also saved there (parent directory created
        if missing).
    """
    import os

    arrow_shrink = 6.0  # display-only scaling of the control vectors
    eef_xs = np.linspace(-0.2, 0.2, 16)
    eef_ys = np.linspace(0.10, 0.30, 16)
    for eef_x in eef_xs:
        for eef_y in eef_ys:
            dx, dy = nn.mu(create_state_vector(eef_x, eef_y, goal_x, goal_y))
            plt.arrow(eef_x, eef_y, dx / arrow_shrink, dy / arrow_shrink, color='k')
    plt.plot(goal_x, goal_y, 'ro')
    plt.xlabel('x')
    plt.ylabel('y')
    # Raw string: '\m' is not a valid escape sequence in a normal literal.
    plt.title(r'$\mathbf{\mu(s)}$')
    plt.xlim((-0.2, 0.2))
    plt.ylim((0.10, 0.30))
    if filename:
        directory = os.path.dirname(filename)
        if directory and not os.path.isdir(directory):
            os.makedirs(directory)
        plt.savefig(filename)
    plt.show()
    
# Greedy-policy arrows of the current network for the default goal.
plot_pi(nn)

In [None]:
# Training state shared by the cells below.
replay_buffer = PriorityBuffer(1 << 19)  # 524288; presumably the buffer capacity
best_score = float('-inf')  # best test_score seen so far
best_params = None          # parameter values recorded for the best score
scores = []                 # one test_score per training iteration
epsilon = 0.1               # added to |td_error| when recomputing replay priorities

In [None]:
# Bounded hand-off queue between the batch_queue_filler worker threads and
# the training loop, plus the flags that control the workers.
batch_queue = Queue(64)
batch_queue_stop = False      # workers loop while this is False
exists_batch_workers = False  # set True once the training loop starts workers

def batch_queue_filler(put_timeout=1.0):
    """Worker loop: sample replay batches into ``batch_queue`` until stopped.

    Each batch holds ``batch_size`` nodes from ``replay_buffer.sample()``,
    whose ``data`` tuples ``(state, state_prim, c_x, c_y, reward)`` are
    unpacked into float32 arrays ``X`` (states), ``Xp`` (next states),
    ``U`` (controls) and ``R`` (rewards). The queued item is
    ``(exp_nodes, X, Xp, U, R)``.

    The loop exits once the module-level flag ``batch_queue_stop`` is set.
    ``batch_queue.put`` is called with a timeout and retried, so a worker
    waiting on a full queue still notices the flag instead of blocking forever.

    NOTE(review): several workers call ``replay_buffer.sample()`` while the
    training loop adds nodes and calls ``set_value``; whether PriorityBuffer
    is thread-safe is not visible here -- verify.

    Parameters
    ----------
    put_timeout : float
        Seconds to wait on a full queue before re-checking ``batch_queue_stop``.
    """
    try:
        from queue import Full
    except ImportError:  # Python 2
        from Queue import Full

    while not batch_queue_stop:
        X = np.zeros((batch_size, 4), dtype=np.float32)
        Xp = np.zeros((batch_size, 4), dtype=np.float32)
        U = np.zeros((batch_size, 2), dtype=np.float32)
        R = np.zeros((batch_size, 1), dtype=np.float32)
        exp_nodes = []
        for j in range(batch_size):
            exp_nodes.append(replay_buffer.sample())
            state, state_prim, c_x, c_y, reward = exp_nodes[-1].data
            X[j, :] = state
            Xp[j, :] = state_prim
            U[j, :] = [c_x, c_y]
            R[j, :] = reward
        batch = (exp_nodes, X, Xp, U, R)
        while not batch_queue_stop:
            try:
                batch_queue.put(batch, timeout=put_timeout)
                break
            except Full:
                continue
    print('batch worker stopping')

In [None]:
def reset_env(e, goal_x=0.0, goal_y=0.20):
    """Reset environment ``e`` and pin its goal to ``(goal_x, goal_y)``.

    The defaults are the fixed goal used throughout this notebook and match
    the defaults of ``test_score``.
    """
    e.reset()
    e.goal_x = goal_x
    e.goal_y = goal_y

# Main loop. Each iteration: collect `batch_size` noisy transitions with the
# current policy, run prioritized-replay updates, then evaluate and plot.
import os

n_iterations = 1048
n_train_steps = 512      # replay batches (gradient steps) per iteration
n_batch_workers = 8      # batch_queue_filler threads
alpha = 2.0              # replay priority = (|td_error| + epsilon) ** alpha
initial_priority = 100.0 # priority given to freshly added transitions

# plot_pi / plot_v save into these directories; savefig fails if they are missing.
for plot_dir in ('pi_plots', 'v_plots'):
    if not os.path.isdir(plot_dir):
        os.makedirs(plot_dir)

first_iteration = 1
for n in range(first_iteration, n_iterations + 1):
    # Gather a batch from the current policy plus Gaussian noise (std 1/20 per axis).
    reset_env(e)
    for _ in range(batch_size):
        # Restart once the end effector is within 0.01 of the goal.
        if np.sqrt((e.goal_x - e.eef_x) ** 2 + (e.goal_y - e.eef_y) ** 2) < 0.01:
            reset_env(e)
        noise_x = np.random.randn() / 20.0
        noise_y = np.random.randn() / 20.0
        state = create_state_vector(e.eef_x, e.eef_y, e.goal_x, e.goal_y)
        c_x, c_y = nn.mu(state)
        c_x += noise_x
        c_y += noise_y
        reward = e.move(c_x, c_y)
        state_prim = create_state_vector(e.eef_x, e.eef_y, e.goal_x, e.goal_y)
        exp_node = replay_buffer.add((state, state_prim, c_x, c_y, reward))
        exp_node.set_value(initial_priority)
        # NOTE(review): a reward of -1 presumably signals a failed episode -- confirm.
        if reward == -1:
            reset_env(e)

    if not exists_batch_workers:
        for _ in range(n_batch_workers):
            worker = threading.Thread(target=batch_queue_filler)
            # Daemon threads do not keep the interpreter alive if the loop is
            # interrupted while workers are blocked on the queue.
            worker.daemon = True
            worker.start()
        print('starting batch workers')
        exists_batch_workers = True

    # Train on prioritized replay batches and refresh each sampled node's priority.
    for _ in range(n_train_steps):
        exp_nodes, X, Xp, U, R = batch_queue.get()
        err, td_errors = nn.train_squared(X, Xp, U, R)
        for exp_node, td_error in zip(exp_nodes, td_errors):
            exp_node.set_value((abs(td_error) + epsilon) ** alpha)

    score = test_score(e, nn)
    if score > best_score:
        # best_score must be updated too; otherwise every score beats -inf and
        # best_params would simply track the latest parameters.
        best_score = score
        best_params = [p.get_value() for p in nn.params]
    scores.append(score)
    print(replay_buffer)
    print('approximate queue size:', batch_queue.qsize())
    plt.plot(scores)
    plt.show()
    plot_pi(nn, goal_x=e.goal_x, goal_y=e.goal_y, filename='pi_plots/pi{:05d}.pdf'.format(n))
    plot_v(nn, x=e.goal_x, y=e.goal_y, filename='v_plots/v{:05d}.pdf'.format(n))

In [None]:
# Inspect the greedy control mu(s) for one hand-picked state:
# end effector at (-0.1, 0.25), goal at (0.0, 0.2).
eef_x = -0.1
eef_y = 0.25
goal_x = 0.0
goal_y = 0.2
dx, dy = nn.mu(create_state_vector(eef_x, eef_y, goal_x, goal_y))
dx, dy