Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 69 additions & 22 deletions quantecon/_matrix_eqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from numpy.linalg import solve
from scipy.linalg import solve_discrete_lyapunov as sp_solve_discrete_lyapunov
from scipy.linalg import solve_discrete_are as sp_solve_discrete_are
from numba import njit


EPS = np.finfo(float).eps
Expand Down Expand Up @@ -273,16 +274,27 @@ def solve_discrete_riccati_system(Π, As, Bs, Cs, Qs, Rs, Ns, beta,
"""
m = Qs.shape[0]
k, n = Qs.shape[1], Rs.shape[1]

# Validate input dimensions to match original behavior
# This will trigger IndexError if dimensions are incompatible
# Access array elements to trigger same IndexError as original
for i in range(m):
_ = Qs[i]
_ = Rs[i]
_ = As[i]
_ = Bs[i]
_ = Ns[i]
for j in range(m):
_ = Π[i, j]

# Create the Ps matrices, initialize as identity matrix
Ps = np.array([np.eye(n) for i in range(m)])
Ps1 = np.copy(Ps)

# == Set up for iteration on Riccati equations system == #
error = tolerance + 1
fail_msg = "Convergence failed after {} iterations."

# == Prepare array for iteration == #
sum1, sum2 = np.empty((n, n)), np.empty((n, n))
fail_msg = "Convergence failed after {} iterations."
error = tolerance + 1

# == Main loop == #
iteration = 0
Expand All @@ -291,23 +303,58 @@ def solve_discrete_riccati_system(Π, As, Bs, Cs, Qs, Rs, Ns, beta,
if iteration > max_iter:
raise ValueError(fail_msg.format(max_iter))

else:
error = 0
for i in range(m):
# Initialize arrays
sum1[:, :] = 0.
sum2[:, :] = 0.
for j in range(m):
sum1 += beta * Π[i, j] * As[i].T @ Ps[j] @ As[i]
sum2 += Π[i, j] * \
(beta * As[i].T @ Ps[j] @ Bs[i] + Ns[i].T) @ \
solve(Qs[i] + beta * Bs[i].T @ Ps[j] @ Bs[i],
beta * Bs[i].T @ Ps[j] @ As[i] + Ns[i])

Ps1[i][:, :] = Rs[i] + sum1 - sum2
error += np.max(np.abs(Ps1[i] - Ps[i]))

Ps[:, :, :] = Ps1[:, :, :]
iteration += 1
# Use the numba-compiled step function for each iteration.
error = _iterate_riccati(Ps, Ps1, Π, As, Bs, Qs, Rs, Ns, beta)

Ps[:, :, :] = Ps1[:, :, :]
iteration += 1

return Ps


@njit(cache=True)
def _iterate_riccati(Ps: np.ndarray, Ps1: np.ndarray, Π: np.ndarray,
As: np.ndarray, Bs: np.ndarray, Qs: np.ndarray,
Rs: np.ndarray, Ns: np.ndarray, beta: float) -> float:
"""
Helper function to iterate over the Riccati equations system.
Returns the total error after updating Ps1.
"""
m = Qs.shape[0]
n = Rs.shape[1]
error = 0.0

for i in range(m):
# Initialize arrays
sum1 = np.zeros((n, n), dtype=np.float64)
sum2 = np.zeros((n, n), dtype=np.float64)
for j in range(m):
# Using explicit variable to ensure proper float multiplication
beta_Pi = beta * Π[i, j]

# Calculate common terms
AsT = As[i].T
Psj = Ps[j]
Bs_i = Bs[i]
Qs_i = Qs[i]
Rs_i = Rs[i]
Ns_i = Ns[i]

A_P_B = Bs_i.T @ Psj @ Bs_i
A_P_A = AsT @ Psj @ As[i]
A_P_B2 = Bs_i.T @ Psj @ As[i]

# == SUM1 COMPONENT ==
sum1 += beta_Pi * A_P_A

# == SUM2 COMPONENT ==
mat_for_solve = Qs_i + beta * A_P_B
vec_for_solve = beta * A_P_B2 + Ns_i

# Use solve for the linear system (numba supported since 0.57)
solve_result = solve(mat_for_solve, vec_for_solve)
sum2 += Π[i, j] * (beta * AsT @ Psj @ Bs_i + Ns_i.T) @ solve_result

Ps1[i][:, :] = Rs[i] + sum1 - sum2
error += np.max(np.abs(Ps1[i] - Ps[i]))
return error