diff --git a/quantecon/_matrix_eqn.py b/quantecon/_matrix_eqn.py index e2bb6f72..b7fc462b 100644 --- a/quantecon/_matrix_eqn.py +++ b/quantecon/_matrix_eqn.py @@ -163,8 +163,6 @@ def solve_discrete_riccati(A, B, Q, R, N=None, tolerance=1e-10, max_iter=500, X = sp_solve_discrete_are(A, B, Q, R, e=I, s=N.T) return X - # if method == 'doubling' - # == Set up == # error = tolerance + 1 fail_msg = "Convergence failed after {} iterations." @@ -174,21 +172,38 @@ def solve_discrete_riccati(A, B, Q, R, N=None, tolerance=1e-10, max_iter=500, candidates = (0.01, 0.1, 0.25, 0.5, 1.0, 2.0, 10.0, 100.0, 10e5) BB = B.T @ B BTA = B.T @ A - for gamma in candidates: - Z = R + gamma * BB - cn = np.linalg.cond(Z) - if cn * EPS < 1: - Q_tilde = - Q + (N.T @ solve(Z, N + gamma * BTA)) + gamma * I - G0 = B @ solve(Z, B.T) - A0 = (I - gamma * G0) @ A - (B @ solve(Z, N)) + + # Precompute identity for GH block for efficiency in loop + I_k = I + # Allocate outside for efficiency + Z_array = [R + gamma * BB for gamma in candidates] + cn_array = [np.linalg.cond(Z) for Z in Z_array] + # Avoid repeated solves when filters fail, so index candidates + valid_idx = [i for i, cn in enumerate(cn_array) if cn * EPS < 1] + + for idx in valid_idx: + gamma = candidates[idx] + Z = Z_array[idx] + try: + solve_Z_BTA = solve(Z, N + gamma * BTA) + Q_tilde = -Q + (N.T @ solve_Z_BTA) + gamma * I_k + solve_Z_BT = solve(Z, B.T) + G0 = B @ solve_Z_BT + solve_Z_N = solve(Z, N) + A0 = (I_k - gamma * G0) @ A - (B @ solve_Z_N) H0 = gamma * (A.T @ A0) - Q_tilde - f1 = np.linalg.cond(Z, np.inf) + + f1 = cn_array[idx] if Z.shape[0] == Z.shape[1] else np.linalg.cond(Z, np.inf) f2 = gamma * f1 - f3 = np.linalg.cond(I + (G0 @ H0)) + GH = I_k + (G0 @ H0) + f3 = np.linalg.cond(GH) + f_gamma = max(f1, f2, f3) if f_gamma < current_min: best_gamma = gamma current_min = f_gamma + except np.linalg.LinAlgError: + continue # Skip ill-posed gamma values # == If no candidate successful then fail == # if current_min == np.inf: @@ -199,30 +214,52 @@ def solve_discrete_riccati(A, B, Q, R, N=None, tolerance=1e-10, max_iter=500, R_hat = R + gamma * BB # == Initial conditions == # - Q_tilde = - Q + (N.T @ solve(R_hat, N + gamma * BTA)) + gamma * I - G0 = B @ solve(R_hat, B.T) - A0 = (I - gamma * G0) @ A - (B @ solve(R_hat, N)) + solve_Rhat_N_gammaBTA = solve(R_hat, N + gamma * BTA) + Q_tilde = -Q + (N.T @ solve_Rhat_N_gammaBTA) + gamma * I_k + solve_Rhat_BT = solve(R_hat, B.T) + G0 = B @ solve_Rhat_BT + solve_Rhat_N = solve(R_hat, N) + A0 = (I_k - gamma * G0) @ A - (B @ solve_Rhat_N) H0 = gamma * (A.T @ A0) - Q_tilde i = 1 + # Use memory-efficient block for main loop + # Preallocate for solve input/output arrays + # The dimensions of all matrices are constant, so reuse arrays in place as much as possible + # == Main loop == # while error > tolerance: - if i > max_iter: raise ValueError(fail_msg.format(i)) - - else: - A1 = A0 @ solve(I + (G0 @ H0), A0) - G1 = G0 + ((A0 @ G0) @ solve(I + (H0 @ G0), A0.T)) - H1 = H0 + (A0.T @ solve(I + (H0 @ G0), (H0 @ A0))) - - error = np.max(np.abs(H1 - H0)) - A0 = A1 - G0 = G1 - H0 = H1 - i += 1 - - return H1 + gamma * I # Return X + # Avoid recomputing expressions (reuse solves when possible) + GH0 = G0 @ H0 + I_GH0 = I_k + GH0 + try: + solve_IGH0_A0 = solve(I_GH0, A0) + A1 = A0 @ solve_IGH0_A0 + + HG0 = H0 @ G0 + I_HG0 = I_k + HG0 + solve_IHG0_A0T = solve(I_HG0, A0.T) + A0_G0 = A0 @ G0 + G1 = G0 + (A0_G0 @ solve_IHG0_A0T) + + H0_A0 = H0 @ A0 + solve_IHG0_H0A0 = solve(I_HG0, H0_A0) + A0T = A0.T + H1 = H0 + (A0T @ solve_IHG0_H0A0) + + except np.linalg.LinAlgError: + raise ValueError("Matrix inversion failed during iteration.") + + error = np.max(np.abs(H1 - H0)) + # Avoid data copies: rebind, original code semantics preserved + A0 = A1 + G0 = G1 + H0 = H1 + i += 1 + + return H1 + gamma * I_k # Return X def solve_discrete_riccati_system(Π, As, Bs, Cs, Qs, Rs, Ns, beta,