Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion quantecon/_lqcontrol.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,8 @@ def stationary_values(self, method='doubling'):
Q, R, A, B, N, C = self.Q, self.R, self.A, self.B, self.N, self.C

# === solve Riccati equation, obtain P === #
A0, B0 = np.sqrt(self.beta) * A, np.sqrt(self.beta) * B
sqrt_beta = np.sqrt(self.beta)
A0, B0 = sqrt_beta * A, sqrt_beta * B
P = solve_discrete_riccati(A0, B0, R, Q, N, method=method)

# == Compute F == #
Expand Down
65 changes: 45 additions & 20 deletions quantecon/_matrix_eqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,24 @@ def solve_discrete_riccati(A, B, Q, R, N=None, tolerance=1e-10, max_iter=500,
candidates = (0.01, 0.1, 0.25, 0.5, 1.0, 2.0, 10.0, 100.0, 10e5)
BB = B.T @ B
BTA = B.T @ A

best_gamma = None
gamma_data_cache = {}

for gamma in candidates:
Z = R + gamma * BB
cn = np.linalg.cond(Z)
if cn * EPS < 1:
Q_tilde = - Q + (N.T @ solve(Z, N + gamma * BTA)) + gamma * I
G0 = B @ solve(Z, B.T)
A0 = (I - gamma * G0) @ A - (B @ solve(Z, N))
# Cache solutions to avoid repeated solves for same Z
solve_cache = {}
N_gBTA = N + gamma * BTA
sol_Z_NgBTA = solve(Z, N_gBTA)
sol_Z_BT = solve(Z, B.T)
sol_Z_N = solve(Z, N)

Q_tilde = - Q + (N.T @ sol_Z_NgBTA) + gamma * I
G0 = B @ sol_Z_BT
A0 = (I - gamma * G0) @ A - (B @ sol_Z_N)
H0 = gamma * (A.T @ A0) - Q_tilde
f1 = np.linalg.cond(Z, np.inf)
f2 = gamma * f1
Expand All @@ -189,38 +200,46 @@ def solve_discrete_riccati(A, B, Q, R, N=None, tolerance=1e-10, max_iter=500,
if f_gamma < current_min:
best_gamma = gamma
current_min = f_gamma
gamma_data_cache['R_hat'] = Z
gamma_data_cache['Q_tilde'] = Q_tilde
gamma_data_cache['G0'] = G0
gamma_data_cache['A0'] = A0
gamma_data_cache['H0'] = H0

# == If no candidate successful then fail == #
if current_min == np.inf:
msg = "Unable to initialize routine due to ill conditioned arguments"
raise ValueError(msg)

gamma = best_gamma
R_hat = R + gamma * BB

# == Initial conditions == #
Q_tilde = - Q + (N.T @ solve(R_hat, N + gamma * BTA)) + gamma * I
G0 = B @ solve(R_hat, B.T)
A0 = (I - gamma * G0) @ A - (B @ solve(R_hat, N))
H0 = gamma * (A.T @ A0) - Q_tilde
R_hat = gamma_data_cache['R_hat']
Q_tilde = gamma_data_cache['Q_tilde']
G0 = gamma_data_cache['G0']
A0 = gamma_data_cache['A0']
H0 = gamma_data_cache['H0']
i = 1

# == Main loop == #
while error > tolerance:

if i > max_iter:
raise ValueError(fail_msg.format(i))
# Precompute solves for in-loop performance
I_GH = I + (G0 @ H0)
solve_IGH_A0 = solve(I_GH, A0)
A1 = A0 @ solve_IGH_A0

else:
A1 = A0 @ solve(I + (G0 @ H0), A0)
G1 = G0 + ((A0 @ G0) @ solve(I + (H0 @ G0), A0.T))
H1 = H0 + (A0.T @ solve(I + (H0 @ G0), (H0 @ A0)))
I_HG = I + (H0 @ G0)
solve_IHG_A0T = solve(I_HG, A0.T)
G1 = G0 + ((A0 @ G0) @ solve_IHG_A0T)

solve_IHG_HA0 = solve(I_HG, (H0 @ A0))
H1 = H0 + (A0.T @ solve_IHG_HA0)

error = np.max(np.abs(H1 - H0))
A0 = A1
G0 = G1
H0 = H1
i += 1
error = np.max(np.abs(H1 - H0))
A0 = A1
G0 = G1
H0 = H1
i += 1

return H1 + gamma * I # Return X

Expand Down Expand Up @@ -311,3 +330,9 @@ def solve_discrete_riccati_system(Π, As, Bs, Cs, Qs, Rs, Ns, beta,
iteration += 1

return Ps

def _solve_cached(Z, cache, arr):
    """
    Return ``solve(Z, arr)``, memoizing the result in `cache`.

    Parameters
    ----------
    Z : ndarray
        Coefficient matrix handed to ``solve``. It is identified in the
        cache key only by ``id(Z)``, so `cache` must not outlive `Z` and
        `Z` must not be mutated in place between calls; otherwise a stale
        entry could be returned.
    cache : dict
        Mutable mapping used as the memo table. A new entry is inserted
        on a miss.
    arr : ndarray
        Right-hand side handed to ``solve``.

    Returns
    -------
    x : ndarray
        Whatever ``solve(Z, arr)`` returned the first time this
        ``(Z, arr)`` pair was seen. The cached object itself is returned
        (no copy), so callers should not modify it in place.

    Notes
    -----
    ``solve`` is the solver already in scope at module level
    (presumably ``numpy.linalg.solve`` -- confirm against the file's
    imports).

    """
    # The raw bytes alone do not determine the right-hand side: a (n,)
    # vector and a (n, 1) column, or a float32 array and its int32 view,
    # share identical bytes but yield different solutions.  Shape and
    # dtype are therefore part of the key.
    key = (id(Z), arr.shape, arr.dtype.str, arr.tobytes())
    if key not in cache:
        cache[key] = solve(Z, arr)
    return cache[key]