In [2]:
%matplotlib inline
import numpy as np
import scipy.linalg
import matplotlib.pylab as plt
import seaborn as sns
import time
sns.set_style('ticks')

In [4]:
class InstabilityException(Exception):
    def __init__(self, msg=None):
        super().__init__(msg)
        
class RankDegeneracyException(Exception):
    def __init__(self, msg=None):
        super().__init__(msg)

def _assert_matrix(M):
    assert len(M.shape) == 2
    
def _assert_square_matrix(M):
    _assert_matrix(M)
    assert M.shape[0] == M.shape[1]

def _assert_symmetric_matrix(M):
    _assert_square_matrix(M)
    assert np.allclose(M, M.T)
    
def svec(M):
    assert len(M.shape) == 2
    assert M.shape[0] == M.shape[1]
    diag = np.diag(M)
    off_diag = np.sqrt(2)*M[np.triu_indices(M.shape[0], k=1)]
    return np.hstack((diag, off_diag))

def smat(v):
    n = int((np.sqrt(1+8*v.shape[0])-1)/2)
    assert n*(n+1) == 2*v.shape[0]
    V = np.zeros((n, n))
    off_diag = v[n:]/np.sqrt(2.0)
    V[np.triu_indices(V.shape[0], k=1)] = off_diag
    V += V.T
    V[np.diag_indices(V.shape[0])] = v[:n]
    return V

def sym(M):
    return (M+M.T)/2

def spectral_radius(L):
    evs = np.linalg.eigvals(L)
    rho = max(np.abs(evs))
    return rho

def is_stable(L):
    return spectral_radius(L) < 1

def lambda_min(H):
    evs = np.linalg.eigvalsh(H)
    return min(evs)

def lambda_max(H):
    evs = np.linalg.eigvalsh(H)
    return max(evs)

class LQR_instance(object):
    
    def __init__(self, A, B, Q, R, gamma, sigma_w):
        _assert_square_matrix(A)
        self.A_ = np.array(A)
        self.state_dim_ = A.shape[0]
        _assert_matrix(B)
        assert B.shape[0] == self.state_dim_
        self.B_ = np.array(B)
        self.input_dim_ = B.shape[1]
        _assert_symmetric_matrix(Q) #assume PD
        assert Q.shape[0] == self.state_dim_
        self.Q_ = np.array(Q)
        _assert_symmetric_matrix(R) #assume PD
        assert R.shape[0] == self.input_dim_
        self.R_ = np.array(R)
        assert gamma > 0 and gamma < 1
        self.gamma_ = gamma
        self.eta_ = gamma/(1-gamma)
        assert sigma_w >= 0
        self.sigma_w_ = sigma_w
        
    def rollout_off_policy(self, T, sigma_u, x0=None, rng=None):
        """Explore with random Gaussian noise of variance sigma_u^2,
        starting with x0 given 
        
        """
        rng = np.random if rng is None else rng
        trajectory = []
        x0 = np.zeros((self.state_dim_,)) if x0 is None else x0
        xk = np.array(x0)
        for k in range(T):
            # draw exploration input
            uk = sigma_u*rng.normal(size=(self.input_dim_,))
            
            # receive reward
            rk = xk.dot(self.Q_).dot(xk) + uk.dot(self.R_).dot(uk)
            
            # take a step
            wk = self.sigma_w_*rng.normal(size=(self.state_dim_,))
            xkp1 = self.A_.dot(xk) + self.B_.dot(uk) + wk
            
            # emit (xk, uk, rk, xkp1)
            trajectory.append((xk, uk, rk, xkp1))
            
            xk = np.copy(xkp1)
        return trajectory

    def rollout(self, K, T, rng=None):
        rng = np.random if rng is None else rng
        trajectory = []
        L = self.A_ + self.B_.dot(K)
        x0 = np.zeros((self.state_dim_,))
        xk = np.array(x0)
        M = self.Q_ + K.T.dot(self.R_.dot(K))
        for k in range(T):
            # receive reward
            rk = xk.dot(M.dot(xk))

            # take a step
            wk = self.sigma_w_*rng.normal(size=(L.shape[0],))
            xkp1 = L.dot(xk) + wk

            # emit (xk, rk, xkp1)
            trajectory.append((xk, rk, xkp1))

            xk = np.copy(xkp1)
        return trajectory
        
    def phi(self, x):
        assert x.shape == (self.state_dim_,)
        return svec(np.outer(x, x) + (self.sigma_w_**2)*self.eta_*np.eye(self.state_dim_))
        
    def solve_stationary(self, K):
        """Compute the covariance of the stationary distribution of
        x_{k+1} = (A+BK) x_{k} + w_k, w_k ~ N(0, sigma^2 I)
        
        """
        L = self.A_ + self.B_.dot(K)
        P_infty = scipy.linalg.solve_discrete_lyapunov(L, self.sigma_w_*self.sigma_w_*np.eye(self.state_dim_))
        
        assert np.allclose(
            L.dot(P_infty).dot(L.T) - P_infty + self.sigma_w_*self.sigma_w_*np.eye(self.state_dim_),
            np.zeros_like(P_infty)
        )
        
        _assert_symmetric_matrix(P_infty)
        assert lambda_min(P_infty) >= -1e-10

        return P_infty
    
    def lstd_Q(self, trajectory, K):
        """Given a trajectory, use LSTD to compute 
        Q^pi(x, u)
        
        """
        
        eta = self.eta_
        IK = np.vstack((np.eye(self.state_dim_), K))
        feature_offset = (self.sigma_w_**2)*eta*IK.dot(IK.T)
        
        def phi_Q(x, u):
            xu = np.hstack((x, u))
            return svec(np.outer(xu, xu) + feature_offset)
        
        Phi = np.array([phi_Q(xk, uk) for xk, uk, _, _ in trajectory])
        Phi_plus = np.array([phi_Q(xkp1, K.dot(xkp1)) for _, _, _, xkp1 in trajectory])
        rvec = np.array([rk for _, _, rk, _ in trajectory])
        PhiTPhi = Phi.T.dot(Phi)
        Amat = PhiTPhi - self.gamma_*Phi.T.dot(Phi_plus)
    
        # Check if Phi is full column rank
        svals_Phi = np.linalg.eigvalsh(PhiTPhi)
        if min(svals_Phi) <= 1e-8:
            raise RankDegeneracyException(
                "Phi matrix is degenerate: s_min(Phi)={}".format(min(svals_Phi)))
        
        # check if Amat is full rank
        svals = scipy.linalg.svdvals(Amat)
        if min(svals) <= 1e-8:
            raise RankDegeneracyException(
                "Amat is degenerate: s_min(Amat)={}".format(min(svals)))
        
        Bmat = Phi.T.dot(rvec)
        phat = np.linalg.lstsq(Amat, Bmat)[0]
        return Phi, Phi_plus, rvec, phat
    
    def lstd(self, trajectory):
        Phi = np.array([self.phi(xk) for xk, _, _ in trajectory])
        Phi_plus = np.array([self.phi(xkp1) for _, _, xkp1 in trajectory])
        rvec = np.array([rk for _, rk, _ in trajectory])
        Amat = Phi.T.dot(Phi) - self.gamma_*Phi.T.dot(Phi_plus)
        Bmat = Phi.T.dot(rvec)
        phat = np.linalg.lstsq(Amat, Bmat)[0]
        return Phi, Phi_plus, rvec, phat
    
    def estimate_nominal(self, trajectory):
        """estimate Ahat, Bhat from the given trajectory
        
        """
        
        # number of data points by (n+p)
        BigX = np.array([np.hstack((xk, uk)) for xk, uk, _, _ in trajectory])

        # number of data points by n
        BigY = np.array([xkp1 for _, _, _, xkp1 in trajectory])
        
        soln, _, _, _ = np.linalg.lstsq(BigX, BigY) # soln = [A, B]^T
        
        Ahat_T = soln[:self.state_dim_,:]
        Bhat_T = soln[self.state_dim_:,:]
        
        Ahat, Bhat = Ahat_T.T, Bhat_T.T

        return Ahat, Bhat
    
    def solve_nominal(self, Ahat, Bhat, discount):
        return self._solve(
            np.array(Ahat),
            np.array(Bhat),
            np.array(self.Q_),
            np.array(self.R_),
            self.gamma_ if discount else 1.0)
    
    def _solve(self, A, B, Q, R, gamma):
        assert gamma >= 0 and gamma <= 1
        
        P_star = scipy.linalg.solve_discrete_are(np.sqrt(gamma)*A, B, Q, (1.0/gamma)*R)

        assert np.allclose(
            P_star,
            gamma*A.T.dot(P_star).dot(A) -
            gamma*gamma*A.T.dot(P_star).dot(B)
                .dot(scipy.linalg.solve(R + gamma*B.T.dot(P_star).dot(B), np.eye(R.shape[0])))
                .dot(B.T.dot(P_star).dot(A)) + Q
        )

        _assert_symmetric_matrix(P_star)
        assert lambda_min(P_star) >= -1e-10
        
        neg_K_star = gamma*scipy.linalg.solve(gamma*B.T.dot(P_star).dot(B) + R, np.eye(R.shape[0])).dot(B.T.dot(P_star).dot(A))
        K_star = -neg_K_star
        
        # TODO: should check controllability and observability assumptions
        L_star = A + B.dot(K_star)
        assert is_stable(np.sqrt(gamma)*L_star)
        
        return P_star, K_star
        
    def solve(self, discount):
        """Solve the problem
        
        Returns the parameterization P of an optimal value function,
        and returns an optimal policy pi(x) = K x
        
        """
        return self._solve(
            np.array(self.A_),
            np.array(self.B_),
            np.array(self.Q_),
            np.array(self.R_),
            self.gamma_ if discount else 1.0)

    def score(self, P, discount):
        """P is a parameterization of a value function
        
        """
        _assert_symmetric_matrix(P)
        if discount:
            return self.sigma_w_*self.sigma_w_/(1-self.gamma_)*np.trace(P)
        else:
            return self.sigma_w_*self.sigma_w_*np.trace(P)
        
    def solve_value_function(self, K, discount):
        """Compute the value function associated to the policy pi(x) = Kx
        
        It is given by the
        
        """
        _assert_matrix(K)
        assert K.shape == (self.input_dim_, self.state_dim_)
        
        effective_gamma = self.gamma_ if discount else 1.0
        
        L = self.A_ + self.B_.dot(K)
        # check that sqrt(gamma)*L is stable
        if not is_stable(np.sqrt(effective_gamma)*L):
            raise InstabilityException("closed loop sqrt(eff_gamma)*(A+BK) is not stable")

        # solve (sqrt(gamma) L)^T P (sqrt(gamma) L) - P + (Q+K^T R K) = 0
        P_pi = scipy.linalg.solve_discrete_lyapunov(np.sqrt(effective_gamma)*L.T, self.Q_+K.T.dot(self.R_.dot(K)))
        
        # check that this version of scipy is consistent w/ dlyap
        assert np.allclose(
            (np.sqrt(effective_gamma)*L.T).dot(P_pi).dot(np.sqrt(effective_gamma)*L) - P_pi 
            + self.Q_ + K.T.dot(self.R_).dot(K),
            np.zeros_like(P_pi)
        )

        # check that P_pi is symmetric positive definite
        _assert_symmetric_matrix(P_pi)
        assert lambda_min(P_pi) >= -1e-10

        return P_pi
    
    def greedy_update(self, P):
        """Given a parameterization of a Q function
        Q^pi(x,u) = - (x,u)^T P (x,u) - q,
        
        compute a new feedback policy by greedy maximization
        
        """
        
        _assert_symmetric_matrix(P)
        assert P.shape[0] == (self.state_dim_ + self.input_dim_)
        
        P12 = P[:self.state_dim_, self.state_dim_:]
        P22 = P[self.state_dim_:, self.state_dim_:]
        
        return -scipy.linalg.solve(P22, P12.T, sym_pos=True)
    
    def solve_q_function(self, K):
        """Compute the Q function associated to the policy pi(x) = Kx
        
        """

        _assert_matrix(K)
        assert K.shape == (self.input_dim_, self.state_dim_)
        
        L = self.A_ + self.B_.dot(K)
        # check that L is stable
        if not is_stable(np.sqrt(self.gamma_)*L):
            raise InstabilityException("sqrt(gamma)*(A+BK) is not stable")
            
        AB = np.hstack((self.A_, self.B_))
        IK = np.vstack((np.eye(self.state_dim_), K))
        
        QR = np.vstack((
            np.hstack((self.Q_, np.zeros((self.state_dim_, self.input_dim_)))),
            np.hstack((np.zeros((self.input_dim_, self.state_dim_)), self.R_))))

        P = scipy.linalg.solve_discrete_lyapunov(np.sqrt(self.gamma_)*(IK.dot(AB)).T, QR)
        _assert_symmetric_matrix(P)
        assert lambda_min(P) >= -1e-10
        return P
    
def lqr_tests():
    A = np.array([[1, 0], [1, 1]])
    B = np.array([[0, 0], [1, 0]])
    Q = 1e-3 * np.eye(2)
    R = np.eye(2)
    gamma = 0.9
    
    K = np.array([
        [0, -0.5],
        [0.5, 1]])
    
    assert is_stable(np.sqrt(gamma)*(A+B.dot(K)))
    
    n, p = 2, 2
    n_params_q = int((n+p)*(n+p+1)/2) 
    
    prob = LQR_instance(A, B, Q, R, gamma, sigma_w=0)
    P0 = prob.solve_q_function(K)
    x0 = np.random.normal(size=(n,))
    traj = prob.rollout_off_policy(T=n_params_q, sigma_u=1, x0=x0)
    _, _, _, phat = prob.lstd_Q(traj, K)
    Phat0 = smat(phat)
    
    assert np.allclose(P0, Phat0)
    
    _, Kstar = prob.solve(discount=True)
    Jstar = np.trace(prob.solve_value_function(Kstar, discount=True))
    
    # check policy iteration
    Jcur = np.trace(prob.solve_value_function(K, discount=True))
    Kcur = K
    Pcur = P0
    while 1:
        # update
        Knext = prob.greedy_update(Pcur)
        
        # score next
        Pnext = prob.solve_q_function(Knext)
        Jnext = np.trace(prob.solve_value_function(Knext, discount=True))
        
        assert Jnext <= Jcur
        
        if (Jcur - Jnext)/Jcur <= 1e-5:
            break
            
        Jcur, Kcur, Pcur = Jnext, Knext, Pnext
        
    # check final score
    assert (Jcur - Jstar)/Jstar <= 1e-5

In [5]:
lqr_tests()


In [6]:
A = np.array([
    [1.01, 0.01, 0],
    [0.01, 1.01, 0.01],
    [0, 0.01, 1.01]
])
B = np.eye(3)
Q = 1e-3 * np.eye(3)
R = np.eye(3)

gamma = 0.98
prob = LQR_instance(A, B, Q, R, gamma, 1.0)
P_star, K_star = prob.solve(discount=True)
P_avg_star, K_avg_star = prob.solve(discount=False)
J_star, J_star_avg = prob.score(P_star, discount=True), prob.score(P_avg_star, discount=False)
print("J_star (gamma):", J_star, "J_star (avg):", J_star_avg)
print("rho(A+BK)", spectral_radius(A+B.dot(K_star)))

# check how good optimal discount soln is for avg cost
P_hat = prob.solve_value_function(K_star, discount=False)
J_hat = prob.score(P_hat, discount=False)
assert J_hat >= J_star_avg
print("(J_hat - J_star_avg)/J_star_avg", (J_hat - J_star_avg)/J_star_avg)

# arbitrary starting point
K0 = -np.array([
    [0.41, 0.01, 0],
    [0.01, 0.41, 0.01],
    [0, 0.01, 0.41]
])
print("rho(A+BK0)", spectral_radius(A+B.dot(K0)))

# naive score
P_K0 = prob.solve_value_function(K0, discount=True)
J_K0 = prob.score(P_K0, discount=True)
print("J_{K0}",J_K0, "(J_{K0} - J_star)/J_star", (J_K0 - J_star)/J_star)
print(A+B.dot(K0))

P_K0 = prob.solve_value_function(K0, discount=False)
J_K0 = prob.score(P_K0, discount=False)
print("J_{K0}",J_K0, "(J_{K0} - J_star_avg)/J_star_avg", (J_K0 - J_star_avg)/J_star_avg)

J_star (gamma): 5.15859259193 J_star (avg): 0.137287165978
rho(A+BK) 0.979018198485
(J_hat - J_star_avg)/J_star_avg 0.0694714351996
rho(A+BK0) 0.6
J_{K0} 39.2228059333 (J_{K0} - J_star)/J_star 6.60339283133
[[ 0.6  0.   0. ]
 [ 0.   0.6  0. ]
 [ 0.   0.   0.6]]
J_{K0} 0.79328125 (J_{K0} - J_star_avg)/J_star_avg 4.77826226034


In [7]:
def lspi(prob, traj, K_init):
    def check_weighted_stable(K):
        if spectral_radius(np.sqrt(prob.gamma_)*(prob.A_ + prob.B_.dot(K))) >= 1:
            raise InstabilityException()
    
    K_cur = np.array(K_init)
    check_weighted_stable(K_cur)
    cur_cost = prob.score(prob.solve_value_function(K_cur, discount=True), discount=True)
    while 1:
        #print("rho(gamma*CLP)", spectral_radius(np.sqrt(gamma)*(A+B.dot(K_cur))))
        _, _, _, qhat = prob.lstd_Q(traj, K_cur)
        K_propose = prob.greedy_update(smat(qhat))
        check_weighted_stable(K_propose)
        
        # compute new
        new_cost = prob.score(prob.solve_value_function(K_propose, discount=True), discount=True)
        #print("cur_cost", cur_cost, "new_cost", new_cost)

        if np.abs(cur_cost-new_cost)/np.abs(cur_cost) <= 1e-5:
            #print("exiting")
            break

        K_cur = K_propose
        cur_cost = new_cost
        
    return K_cur, cur_cost

In [8]:
def collect_samples(prob, NT, T):
    assert T <= NT
    samples = []
    while len(samples) < NT:
        remaining = NT - len(samples)
        samples.extend(prob.rollout_off_policy(min(T, remaining), 1.0))
    return samples

In [9]:
T = 20
#N_values = [20, 40, 60, 80, 100]
#NT_values = [N*T for N in N_values]
NT_values = [250, 500, 750, 1000, 2000, 3000, 4000, 5000]
n_trials = 100
all_trajectories = []
for _ in range(n_trials):
    all_trajectories.append(collect_samples(prob, max(NT_values), T))

In [10]:
# sanity
Q_pizero = prob.solve_q_function(K0)
error_lspi = np.zeros((n_trials, len(NT_values)))
for idx, traj in enumerate(all_trajectories):
    for jdx, NT_value in enumerate(NT_values):
        
        iter_start = time.time()
        
        sample = traj[:NT_value]
        _, _, _, qhat = prob.lstd_Q(sample, K0)
        Qhat = smat(qhat)
        error_lspi[idx, jdx] = np.linalg.norm(Q_pizero - Qhat, ord='fro')/np.linalg.norm(Q_pizero, ord='fro')
        print("done with pair (idx, jdx)", (idx, jdx), "in {} seconds".format(time.time() - iter_start))

done with pair (idx, jdx) (0, 0) in 0.0306549072265625 seconds
done with pair (idx, jdx) (0, 1) in 0.06129789352416992 seconds
done with pair (idx, jdx) (0, 2) in 0.06448221206665039 seconds
done with pair (idx, jdx) (0, 3) in 0.06575703620910645 seconds
done with pair (idx, jdx) (0, 4) in 0.15710997581481934 seconds
done with pair (idx, jdx) (0, 5) in 0.19704484939575195 seconds
done with pair (idx, jdx) (0, 6) in 0.2632269859313965 seconds
done with pair (idx, jdx) (0, 7) in 0.32811594009399414 seconds
done with pair (idx, jdx) (1, 0) in 0.01672816276550293 seconds
done with pair (idx, jdx) (1, 1) in 0.03158402442932129 seconds
done with pair (idx, jdx) (1, 2) in 0.04832601547241211 seconds
done with pair (idx, jdx) (1, 3) in 0.06038808822631836 seconds
done with pair (idx, jdx) (1, 4) in 0.1443328857421875 seconds
done with pair (idx, jdx) (1, 5) in 0.18512701988220215 seconds
done with pair (idx, jdx) (1, 6) in 0.2656741142272949 seconds
done with pair (idx, jdx) (1, 7) in 0.354438

done with pair (idx, jdx) (16, 4) in 0.15043997764587402 seconds
done with pair (idx, jdx) (16, 5) in 0.19939184188842773 seconds
done with pair (idx, jdx) (16, 6) in 0.26087379455566406 seconds
done with pair (idx, jdx) (16, 7) in 0.31549692153930664 seconds
done with pair (idx, jdx) (17, 0) in 0.016561269760131836 seconds
done with pair (idx, jdx) (17, 1) in 0.0336308479309082 seconds
done with pair (idx, jdx) (17, 2) in 0.05041384696960449 seconds
done with pair (idx, jdx) (17, 3) in 0.06103396415710449 seconds
done with pair (idx, jdx) (17, 4) in 0.15442991256713867 seconds
done with pair (idx, jdx) (17, 5) in 0.19541287422180176 seconds
done with pair (idx, jdx) (17, 6) in 0.28143906593322754 seconds
done with pair (idx, jdx) (17, 7) in 0.33109378814697266 seconds
done with pair (idx, jdx) (18, 0) in 0.021274805068969727 seconds
done with pair (idx, jdx) (18, 1) in 0.03287005424499512 seconds
done with pair (idx, jdx) (18, 2) in 0.06126284599304199 seconds
done with pair (idx, jdx

done with pair (idx, jdx) (32, 4) in 0.16286587715148926 seconds
done with pair (idx, jdx) (32, 5) in 0.21483206748962402 seconds
done with pair (idx, jdx) (32, 6) in 0.28977513313293457 seconds
done with pair (idx, jdx) (32, 7) in 0.358644962310791 seconds
done with pair (idx, jdx) (33, 0) in 0.02533888816833496 seconds
done with pair (idx, jdx) (33, 1) in 0.047119855880737305 seconds
done with pair (idx, jdx) (33, 2) in 0.062283992767333984 seconds
done with pair (idx, jdx) (33, 3) in 0.09398126602172852 seconds
done with pair (idx, jdx) (33, 4) in 0.1580359935760498 seconds
done with pair (idx, jdx) (33, 5) in 0.21433281898498535 seconds
done with pair (idx, jdx) (33, 6) in 0.28837108612060547 seconds
done with pair (idx, jdx) (33, 7) in 0.3384699821472168 seconds
done with pair (idx, jdx) (34, 0) in 0.026566743850708008 seconds
done with pair (idx, jdx) (34, 1) in 0.044792890548706055 seconds
done with pair (idx, jdx) (34, 2) in 0.05509305000305176 seconds
done with pair (idx, jdx)

done with pair (idx, jdx) (48, 3) in 0.08830904960632324 seconds
done with pair (idx, jdx) (48, 4) in 0.15325212478637695 seconds
done with pair (idx, jdx) (48, 5) in 0.21746277809143066 seconds
done with pair (idx, jdx) (48, 6) in 0.29014015197753906 seconds
done with pair (idx, jdx) (48, 7) in 0.3339877128601074 seconds
done with pair (idx, jdx) (49, 0) in 0.023767948150634766 seconds
done with pair (idx, jdx) (49, 1) in 0.0415189266204834 seconds
done with pair (idx, jdx) (49, 2) in 0.056771039962768555 seconds
done with pair (idx, jdx) (49, 3) in 0.1108250617980957 seconds
done with pair (idx, jdx) (49, 4) in 0.15156888961791992 seconds
done with pair (idx, jdx) (49, 5) in 0.221113920211792 seconds
done with pair (idx, jdx) (49, 6) in 0.31053733825683594 seconds
done with pair (idx, jdx) (49, 7) in 0.36801886558532715 seconds
done with pair (idx, jdx) (50, 0) in 0.018580913543701172 seconds
done with pair (idx, jdx) (50, 1) in 0.036277055740356445 seconds
done with pair (idx, jdx) 

done with pair (idx, jdx) (64, 4) in 0.15848994255065918 seconds
done with pair (idx, jdx) (64, 5) in 0.23051810264587402 seconds
done with pair (idx, jdx) (64, 6) in 0.303774356842041 seconds
done with pair (idx, jdx) (64, 7) in 0.364063024520874 seconds
done with pair (idx, jdx) (65, 0) in 0.017594099044799805 seconds
done with pair (idx, jdx) (65, 1) in 0.03495001792907715 seconds
done with pair (idx, jdx) (65, 2) in 0.05652284622192383 seconds
done with pair (idx, jdx) (65, 3) in 0.06540107727050781 seconds
done with pair (idx, jdx) (65, 4) in 0.15231609344482422 seconds
done with pair (idx, jdx) (65, 5) in 0.23190593719482422 seconds
done with pair (idx, jdx) (65, 6) in 0.2913179397583008 seconds
done with pair (idx, jdx) (65, 7) in 0.3464350700378418 seconds
done with pair (idx, jdx) (66, 0) in 0.020073890686035156 seconds
done with pair (idx, jdx) (66, 1) in 0.03882098197937012 seconds
done with pair (idx, jdx) (66, 2) in 0.05349993705749512 seconds
done with pair (idx, jdx) (66

done with pair (idx, jdx) (80, 4) in 0.15532922744750977 seconds
done with pair (idx, jdx) (80, 5) in 0.18437719345092773 seconds
done with pair (idx, jdx) (80, 6) in 0.2666590213775635 seconds
done with pair (idx, jdx) (80, 7) in 0.34415483474731445 seconds
done with pair (idx, jdx) (81, 0) in 0.02144598960876465 seconds
done with pair (idx, jdx) (81, 1) in 0.03860878944396973 seconds
done with pair (idx, jdx) (81, 2) in 0.1257641315460205 seconds
done with pair (idx, jdx) (81, 3) in 0.2958638668060303 seconds
done with pair (idx, jdx) (81, 4) in 0.281400203704834 seconds
done with pair (idx, jdx) (81, 5) in 0.21455788612365723 seconds
done with pair (idx, jdx) (81, 6) in 0.2797279357910156 seconds
done with pair (idx, jdx) (81, 7) in 0.32172298431396484 seconds
done with pair (idx, jdx) (82, 0) in 0.019855976104736328 seconds
done with pair (idx, jdx) (82, 1) in 0.04334688186645508 seconds
done with pair (idx, jdx) (82, 2) in 0.06693482398986816 seconds
done with pair (idx, jdx) (82,

done with pair (idx, jdx) (96, 4) in 0.14833903312683105 seconds
done with pair (idx, jdx) (96, 5) in 0.376507043838501 seconds
done with pair (idx, jdx) (96, 6) in 0.32263970375061035 seconds
done with pair (idx, jdx) (96, 7) in 0.3507039546966553 seconds
done with pair (idx, jdx) (97, 0) in 0.018340110778808594 seconds
done with pair (idx, jdx) (97, 1) in 0.046536922454833984 seconds
done with pair (idx, jdx) (97, 2) in 0.0594792366027832 seconds
done with pair (idx, jdx) (97, 3) in 0.07272696495056152 seconds
done with pair (idx, jdx) (97, 4) in 0.1474301815032959 seconds
done with pair (idx, jdx) (97, 5) in 0.19936203956604004 seconds
done with pair (idx, jdx) (97, 6) in 0.3195960521697998 seconds
done with pair (idx, jdx) (97, 7) in 0.3633451461791992 seconds
done with pair (idx, jdx) (98, 0) in 0.019035816192626953 seconds
done with pair (idx, jdx) (98, 1) in 0.03539013862609863 seconds
done with pair (idx, jdx) (98, 2) in 0.07105588912963867 seconds
done with pair (idx, jdx) (98

In [11]:
np.median(error_lspi, axis=0)

array([ 0.34243106,  0.22264834,  0.17188926,  0.14500774,  0.10292338,
        0.08007318,  0.07379356,  0.06435294])

In [12]:
data_lspi_dis = np.zeros((n_trials, len(NT_values)))
data_lspi_avg = np.zeros((n_trials, len(NT_values)))

loop_start = time.time()
for idx, traj in enumerate(all_trajectories):
    for jdx, NT_value in enumerate(NT_values):
        
        iter_start = time.time()
        sample = traj[:NT_value]

        try:
            K_lspi, cost_lspi_dis = lspi(prob, sample, K0)
            # LSPI succeeded, so we can at least score the discounted
            data_lspi_dis[idx, jdx] = cost_lspi_dis
            
            # sanity
            assert np.allclose(
                prob.score(prob.solve_value_function(K_lspi, discount=True), discount=True),
                cost_lspi_dis)
            
            # now we try to score K_lspi on the avg cost
            try:
                Phat_lspi_avg = prob.solve_value_function(K_lspi, discount=False)
                data_lspi_avg[idx, jdx] = prob.score(Phat_lspi_avg, discount=False)
            except InstabilityException:
                # the discounted controller did not work for the true system
                data_lspi_avg[idx, jdx] = np.inf
            
        except InstabilityException:
            # if LSPI fails, both of them will score inf
            data_lspi_dis[idx, jdx] = np.inf
            data_lspi_avg[idx, jdx] = np.inf
            
        print("done with pair (idx, jdx)", (idx, jdx), "in {} seconds".format(time.time() - iter_start))
        
print("loop took", time.time() - loop_start, "seconds")

done with pair (idx, jdx) (0, 0) in 0.07462787628173828 seconds
done with pair (idx, jdx) (0, 1) in 0.1343832015991211 seconds
done with pair (idx, jdx) (0, 2) in 0.2755279541015625 seconds
done with pair (idx, jdx) (0, 3) in 0.2826120853424072 seconds
done with pair (idx, jdx) (0, 4) in 1.174226999282837 seconds
done with pair (idx, jdx) (0, 5) in 1.5503978729248047 seconds
done with pair (idx, jdx) (0, 6) in 2.006042957305908 seconds
done with pair (idx, jdx) (0, 7) in 2.921441078186035 seconds
done with pair (idx, jdx) (1, 0) in 0.11652874946594238 seconds
done with pair (idx, jdx) (1, 1) in 0.309312105178833 seconds
done with pair (idx, jdx) (1, 2) in 0.3037397861480713 seconds
done with pair (idx, jdx) (1, 3) in 0.39354825019836426 seconds
done with pair (idx, jdx) (1, 4) in 2.029719829559326 seconds
done with pair (idx, jdx) (1, 5) in 0.9931039810180664 seconds
done with pair (idx, jdx) (1, 6) in 1.348289966583252 seconds
done with pair (idx, jdx) (1, 7) in 1.7423748970031738 sec

done with pair (idx, jdx) (16, 1) in 0.1353001594543457 seconds
done with pair (idx, jdx) (16, 2) in 0.5179383754730225 seconds
done with pair (idx, jdx) (16, 3) in 0.35563015937805176 seconds
done with pair (idx, jdx) (16, 4) in 1.0788240432739258 seconds
done with pair (idx, jdx) (16, 5) in 1.5793049335479736 seconds
done with pair (idx, jdx) (16, 6) in 1.7345898151397705 seconds
done with pair (idx, jdx) (16, 7) in 2.2125539779663086 seconds
done with pair (idx, jdx) (17, 0) in 0.16867685317993164 seconds
done with pair (idx, jdx) (17, 1) in 0.10425925254821777 seconds
done with pair (idx, jdx) (17, 2) in 0.2506539821624756 seconds
done with pair (idx, jdx) (17, 3) in 0.5781512260437012 seconds
done with pair (idx, jdx) (17, 4) in 0.9025001525878906 seconds
done with pair (idx, jdx) (17, 5) in 1.3259727954864502 seconds
done with pair (idx, jdx) (17, 6) in 1.8083269596099854 seconds
done with pair (idx, jdx) (17, 7) in 2.640531063079834 seconds
done with pair (idx, jdx) (18, 0) in 0

done with pair (idx, jdx) (32, 1) in 0.2980830669403076 seconds
done with pair (idx, jdx) (32, 2) in 0.34938597679138184 seconds
done with pair (idx, jdx) (32, 3) in 0.5382080078125 seconds
done with pair (idx, jdx) (32, 4) in 1.038193941116333 seconds
done with pair (idx, jdx) (32, 5) in 1.9205617904663086 seconds
done with pair (idx, jdx) (32, 6) in 2.1292548179626465 seconds
done with pair (idx, jdx) (32, 7) in 2.583268165588379 seconds
done with pair (idx, jdx) (33, 0) in 0.1442270278930664 seconds
done with pair (idx, jdx) (33, 1) in 0.17511200904846191 seconds
done with pair (idx, jdx) (33, 2) in 0.2126331329345703 seconds
done with pair (idx, jdx) (33, 3) in 0.3476848602294922 seconds
done with pair (idx, jdx) (33, 4) in 0.8510959148406982 seconds
done with pair (idx, jdx) (33, 5) in 1.2073681354522705 seconds
done with pair (idx, jdx) (33, 6) in 2.355823040008545 seconds
done with pair (idx, jdx) (33, 7) in 2.5176799297332764 seconds
done with pair (idx, jdx) (34, 0) in 0.08552

done with pair (idx, jdx) (48, 2) in 0.27971363067626953 seconds
done with pair (idx, jdx) (48, 3) in 0.3761889934539795 seconds
done with pair (idx, jdx) (48, 4) in 1.213641881942749 seconds
done with pair (idx, jdx) (48, 5) in 1.6348769664764404 seconds
done with pair (idx, jdx) (48, 6) in 1.8130712509155273 seconds
done with pair (idx, jdx) (48, 7) in 2.261863946914673 seconds
done with pair (idx, jdx) (49, 0) in 0.15156078338623047 seconds
done with pair (idx, jdx) (49, 1) in 0.17154288291931152 seconds
done with pair (idx, jdx) (49, 2) in 0.3892519474029541 seconds
done with pair (idx, jdx) (49, 3) in 0.5130088329315186 seconds
done with pair (idx, jdx) (49, 4) in 1.0803098678588867 seconds
done with pair (idx, jdx) (49, 5) in 1.6824710369110107 seconds
done with pair (idx, jdx) (49, 6) in 2.138415813446045 seconds
done with pair (idx, jdx) (49, 7) in 2.551625967025757 seconds
done with pair (idx, jdx) (50, 0) in 0.05859684944152832 seconds
done with pair (idx, jdx) (50, 1) in 0.1

done with pair (idx, jdx) (64, 2) in 0.21251201629638672 seconds
done with pair (idx, jdx) (64, 3) in 0.343533992767334 seconds
done with pair (idx, jdx) (64, 4) in 1.0256240367889404 seconds
done with pair (idx, jdx) (64, 5) in 1.6200149059295654 seconds
done with pair (idx, jdx) (64, 6) in 1.9096901416778564 seconds
done with pair (idx, jdx) (64, 7) in 2.286055088043213 seconds
done with pair (idx, jdx) (65, 0) in 0.09453821182250977 seconds
done with pair (idx, jdx) (65, 1) in 0.6340720653533936 seconds
done with pair (idx, jdx) (65, 2) in 0.2675449848175049 seconds
done with pair (idx, jdx) (65, 3) in 0.7628672122955322 seconds
done with pair (idx, jdx) (65, 4) in 1.0749807357788086 seconds
done with pair (idx, jdx) (65, 5) in 1.6015350818634033 seconds
done with pair (idx, jdx) (65, 6) in 2.0545520782470703 seconds
done with pair (idx, jdx) (65, 7) in 2.4872488975524902 seconds
done with pair (idx, jdx) (66, 0) in 0.13820195198059082 seconds
done with pair (idx, jdx) (66, 1) in 0.

done with pair (idx, jdx) (80, 2) in 0.47800397872924805 seconds
done with pair (idx, jdx) (80, 3) in 0.5702097415924072 seconds
done with pair (idx, jdx) (80, 4) in 0.9153618812561035 seconds
done with pair (idx, jdx) (80, 5) in 1.1853971481323242 seconds
done with pair (idx, jdx) (80, 6) in 2.0374832153320312 seconds
done with pair (idx, jdx) (80, 7) in 2.525002956390381 seconds
done with pair (idx, jdx) (81, 0) in 0.09316086769104004 seconds
done with pair (idx, jdx) (81, 1) in 0.3373749256134033 seconds
done with pair (idx, jdx) (81, 2) in 0.43072009086608887 seconds
done with pair (idx, jdx) (81, 3) in 0.546299934387207 seconds
done with pair (idx, jdx) (81, 4) in 0.9654641151428223 seconds
done with pair (idx, jdx) (81, 5) in 1.7057530879974365 seconds
done with pair (idx, jdx) (81, 6) in 2.0887861251831055 seconds
done with pair (idx, jdx) (81, 7) in 2.5578341484069824 seconds
done with pair (idx, jdx) (82, 0) in 0.0719139575958252 seconds
done with pair (idx, jdx) (82, 1) in 0.

done with pair (idx, jdx) (96, 2) in 0.4064030647277832 seconds
done with pair (idx, jdx) (96, 3) in 0.5749359130859375 seconds
done with pair (idx, jdx) (96, 4) in 1.0750458240509033 seconds
done with pair (idx, jdx) (96, 5) in 1.6311500072479248 seconds
done with pair (idx, jdx) (96, 6) in 2.172438859939575 seconds
done with pair (idx, jdx) (96, 7) in 2.5694808959960938 seconds
done with pair (idx, jdx) (97, 0) in 0.05854511260986328 seconds
done with pair (idx, jdx) (97, 1) in 0.2944800853729248 seconds
done with pair (idx, jdx) (97, 2) in 0.5824460983276367 seconds
done with pair (idx, jdx) (97, 3) in 0.5814371109008789 seconds
done with pair (idx, jdx) (97, 4) in 1.232879877090454 seconds
done with pair (idx, jdx) (97, 5) in 1.6198270320892334 seconds
done with pair (idx, jdx) (97, 6) in 1.807631015777588 seconds
done with pair (idx, jdx) (97, 7) in 2.307307004928589 seconds
done with pair (idx, jdx) (98, 0) in 0.08379793167114258 seconds
done with pair (idx, jdx) (98, 1) in 0.294