Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 65 additions & 28 deletions quantecon/_matrix_eqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,6 @@ def solve_discrete_riccati(A, B, Q, R, N=None, tolerance=1e-10, max_iter=500,
X = sp_solve_discrete_are(A, B, Q, R, e=I, s=N.T)
return X

# if method == 'doubling'

# == Set up == #
error = tolerance + 1
fail_msg = "Convergence failed after {} iterations."
Expand All @@ -174,21 +172,38 @@ def solve_discrete_riccati(A, B, Q, R, N=None, tolerance=1e-10, max_iter=500,
candidates = (0.01, 0.1, 0.25, 0.5, 1.0, 2.0, 10.0, 100.0, 10e5)
BB = B.T @ B
BTA = B.T @ A
for gamma in candidates:
Z = R + gamma * BB
cn = np.linalg.cond(Z)
if cn * EPS < 1:
Q_tilde = - Q + (N.T @ solve(Z, N + gamma * BTA)) + gamma * I
G0 = B @ solve(Z, B.T)
A0 = (I - gamma * G0) @ A - (B @ solve(Z, N))

# Second name for the identity matrix: `I_k = I` only rebinds, it does not copy or precompute anything
I_k = I
# Allocate outside for efficiency
Z_array = [R + gamma * BB for gamma in candidates]
cn_array = [np.linalg.cond(Z) for Z in Z_array]
# Keep only the indices of candidates whose Z passes the `cn * EPS < 1` conditioning filter, so the loop below skips the rest
valid_idx = [i for i, cn in enumerate(cn_array) if cn * EPS < 1]

for idx in valid_idx:
gamma = candidates[idx]
Z = Z_array[idx]
try:
solve_Z_BTA = solve(Z, N + gamma * BTA)
Q_tilde = -Q + (N.T @ solve_Z_BTA) + gamma * I_k
solve_Z_BT = solve(Z, B.T)
G0 = B @ solve_Z_BT
solve_Z_N = solve(Z, N)
A0 = (I_k - gamma * G0) @ A - (B @ solve_Z_N)
H0 = gamma * (A.T @ A0) - Q_tilde
f1 = np.linalg.cond(Z, np.inf)

f1 = cn_array[idx] if Z.shape[0] == Z.shape[1] else np.linalg.cond(Z, np.inf)
f2 = gamma * f1
f3 = np.linalg.cond(I + (G0 @ H0))
GH = I_k + (G0 @ H0)
f3 = np.linalg.cond(GH)

f_gamma = max(f1, f2, f3)
if f_gamma < current_min:
best_gamma = gamma
current_min = f_gamma
except np.linalg.LinAlgError:
continue # Skip ill-posed gamma values

# == If no candidate successful then fail == #
if current_min == np.inf:
Expand All @@ -199,30 +214,52 @@ def solve_discrete_riccati(A, B, Q, R, N=None, tolerance=1e-10, max_iter=500,
R_hat = R + gamma * BB

# == Initial conditions == #
Q_tilde = - Q + (N.T @ solve(R_hat, N + gamma * BTA)) + gamma * I
G0 = B @ solve(R_hat, B.T)
A0 = (I - gamma * G0) @ A - (B @ solve(R_hat, N))
solve_Rhat_N_gammaBTA = solve(R_hat, N + gamma * BTA)
Q_tilde = -Q + (N.T @ solve_Rhat_N_gammaBTA) + gamma * I_k
solve_Rhat_BT = solve(R_hat, B.T)
G0 = B @ solve_Rhat_BT
solve_Rhat_N = solve(R_hat, N)
A0 = (I_k - gamma * G0) @ A - (B @ solve_Rhat_N)
H0 = gamma * (A.T @ A0) - Q_tilde
i = 1

# NOTE(review): despite the original intent stated here, nothing is preallocated
# and no array is updated in place -- each iteration below builds fresh arrays
# and then rebinds A0, G0, H0 to them.

# == Main loop == #
while error > tolerance:

if i > max_iter:
raise ValueError(fail_msg.format(i))

else:
A1 = A0 @ solve(I + (G0 @ H0), A0)
G1 = G0 + ((A0 @ G0) @ solve(I + (H0 @ G0), A0.T))
H1 = H0 + (A0.T @ solve(I + (H0 @ G0), (H0 @ A0)))

error = np.max(np.abs(H1 - H0))
A0 = A1
G0 = G1
H0 = H1
i += 1

return H1 + gamma * I # Return X
# Avoid recomputing expressions (reuse solves when possible)
GH0 = G0 @ H0
I_GH0 = I_k + GH0
try:
solve_IGH0_A0 = solve(I_GH0, A0)
A1 = A0 @ solve_IGH0_A0

HG0 = H0 @ G0
I_HG0 = I_k + HG0
solve_IHG0_A0T = solve(I_HG0, A0.T)
A0_G0 = A0 @ G0
G1 = G0 + (A0_G0 @ solve_IHG0_A0T)

H0_A0 = H0 @ A0
solve_IHG0_H0A0 = solve(I_HG0, H0_A0)
A0T = A0.T
H1 = H0 + (A0T @ solve_IHG0_H0A0)

except np.linalg.LinAlgError:
raise ValueError("Matrix inversion failed during iteration.")

error = np.max(np.abs(H1 - H0))
# Avoid data copies: rebind, original code semantics preserved
A0 = A1
G0 = G1
H0 = H1
i += 1

return H1 + gamma * I_k # Return X


def solve_discrete_riccati_system(Π, As, Bs, Cs, Qs, Rs, Ns, beta,
Expand Down