Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 81 additions & 26 deletions quantecon/_matrix_eqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,41 +273,96 @@ def solve_discrete_riccati_system(Π, As, Bs, Cs, Qs, Rs, Ns, beta,
"""
m = Qs.shape[0]
k, n = Qs.shape[1], Rs.shape[1]
# Create the Ps matrices, initialize as identity matrix
Ps = np.array([np.eye(n) for i in range(m)])
Ps1 = np.copy(Ps)

# == Set up for iteration on Riccati equations system == #
error = tolerance + 1
fail_msg = "Convergence failed after {} iterations."
# Precompute the transposes of A(s), B(s) and (when provided) N(s), since they are reused on every iteration
As_T = np.array([A.T for A in As])
Bs_T = np.array([B.T for B in Bs])
if Ns is not None:
Ns_T = np.array([N.T for N in Ns])

# == Prepare array for iteration == #
sum1, sum2 = np.empty((n, n)), np.empty((n, n))
Ps = np.array([np.eye(n) for _ in range(m)])
Ps1 = np.empty_like(Ps)

# == Main loop == #
error = tolerance + 1
fail_msg = "Convergence failed after {} iterations."
iteration = 0

# Preallocate possible storage for sum arrays
sum1 = np.empty((n, n))
sum2 = np.empty((n, n))

while error > tolerance:

if iteration > max_iter:
raise ValueError(fail_msg.format(max_iter))

else:
error = 0
for i in range(m):
# Initialize arrays
sum1[:, :] = 0.
sum2[:, :] = 0.
for j in range(m):
sum1 += beta * Π[i, j] * As[i].T @ Ps[j] @ As[i]
sum2 += Π[i, j] * \
(beta * As[i].T @ Ps[j] @ Bs[i] + Ns[i].T) @ \
solve(Qs[i] + beta * Bs[i].T @ Ps[j] @ Bs[i],
beta * Bs[i].T @ Ps[j] @ As[i] + Ns[i])

Ps1[i][:, :] = Rs[i] + sum1 - sum2
error += np.max(np.abs(Ps1[i] - Ps[i]))

Ps[:, :, :] = Ps1[:, :, :]
iteration += 1
error = 0

for i in range(m):
sum1.fill(0.)
sum2.fill(0.)

# Preload matrices needed across the inner loop
A_T = As_T[i]
B = Bs[i]
B_T = Bs_T[i]
R = Rs[i]
Q = Qs[i]
if Ns is not None:
N = Ns[i]
N_T = Ns_T[i]
else:
N = np.zeros((B.shape[1], B.shape[0])) # (k, n)
N_T = N.T # (n, k)

# Vectorize inner loop for performance
# Stack across all j and do all at once when possible

# Precompute terms for all j
A_T_Ps = np.tensordot(A_T, Ps, axes=(1,1)) # shape (n, m, n)
# tensordot contracts axis 1 of A_T (n, n) with axis 1 of Ps (m, n, n),
# giving an array of shape (n, m, n) in which A_T_Ps[:, j, :] equals A_T @ Ps[j]

B_T_Ps = np.tensordot(B_T, Ps, axes=(1,1)) # (k, m, n)

sum1_tmp = beta * np.sum(
(Π[i, :, None, None] * np.matmul(A_T[None, :, :], np.matmul(Ps, As[i]))),
axis=0
) # shape (n, n)
sum1[:,:] = sum1_tmp

# sum2 requires a separate linear solve for every j, so it is not vectorized with a single matmul;
# instead we loop over j, reuse the shared intermediate products, and accumulate the terms into sum2
# Each j:
# Mj = Qs[i] + beta * Bs[i].T @ Ps[j] @ Bs[i]
# vj = beta * Bs[i].T @ Ps[j] @ As[i] + Ns[i]
# outj = Π[i,j] * (beta * As[i].T @ Ps[j] @ Bs[i] + Ns[i].T) @ solve(Mj, vj)

for j in range(m):
Pj = Ps[j]
Pi_j = Π[i, j]
# Precompute blocks for reuse
A_T_Pj = A_T @ Pj # (n,n)
B_T_Pj = B_T @ Pj # (k, n)
A_T_Pj_B = A_T_Pj @ B # (n, k)
Pj_A = Pj @ As[i] # (n, n)
B_T_Pj_A = B_T @ Pj_A # (k, n)
Mj = Q + beta * B_T_Pj @ B # (k, k)
vj = beta * B_T_Pj_A + N # (k, n)
left = beta * A_T_Pj_B + N_T # (n, k)
# Solve for each column of vj, so solve(Mj, vj) returns (k, n)
solve_Mj_vj = solve(Mj, vj)
sum2 += Pi_j * left @ solve_Mj_vj

Ps1[i][:,:] = R + sum1 - sum2

error += np.max(np.abs(Ps1[i] - Ps[i]))

# Copy the updated iterates from Ps1 into Ps in place (an element-wise copy, not a swap), so no new array is allocated
Ps[:,:,:] = Ps1[:,:,:]
iteration += 1

return Ps