In [99]:
import numpy as np
from scipy.sparse.linalg import cg

def conjugate_gradient(A, b, x0, tol=1e-6, max_iter=1000):
    """
    Solves Ax = b using the Conjugate Gradient method.

    Args:
        A (np.ndarray): The symmetric, positive-definite matrix.
        b (np.ndarray): The right-hand side vector.
        x0 (np.ndarray): The initial guess for the solution.
        tol (float): The tolerance for convergence.
        max_iter (int): The maximum number of iterations.

    Returns:
        np.ndarray: The approximate solution x.
    """
    x = x0.copy()
    r = b - A @ x  # Initial residual
    p = r.copy()    # Initial search direction
    residuals = [r.copy()]
    directions = [p.copy()]
    rs_old = np.dot(r, r)

    for i in range(max_iter):
        Ap = A @ p
        alpha = rs_old / np.dot(p, Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = np.dot(r, r)
        if np.sqrt(rs_new) < tol:
            print(f"Converged in {i+1} iterations.")
            break

        beta = rs_new / rs_old
        p = r + beta * p
        residuals.append(r)
        directions.append(p)
        rs_old = rs_new
    else:
        print(f"Did not converge within {max_iter} iterations.")

    return x, directions, np.vstack(residuals)

In [None]:
def spd_from_gram(n, eps=1e-6, rng=None):
    rng = np.random.default_rng(rng)
    B = rng.normal(size=(n, n))
    A = B.T @ B
    A += eps * np.eye(n)       
    return B,A # Matriz no SPD, Matriz SPD


In [104]:
n = 3
A_tilde, A = spd_from_gram(3)
b = np.array([-1, -0.5, -1 ])
x, d, r = conjugate_gradient(A, b, [0,0,0])

Converged in 3 iterations.


In [117]:
A@x - b

array([-2.22044605e-16, -2.22044605e-16, -2.22044605e-16])

In [116]:
r@r.T

array([[2.25000000e+00, 1.11022302e-16, 6.24500451e-17],
       [1.11022302e-16, 2.96703680e-02, 2.16418887e-18],
       [6.24500451e-17, 2.16418887e-18, 1.94002299e-03]])

In [None]:
np.linalg.norm(r @ r.T - np.diag(np.diag(r @ r.T))) 

1.801701612083672e-16

In [None]:
x, d, r = conjugate_gradient(A_tilde, b, [0,0,0]) # No da output ya que no converge luego de max_iters
A @ cg(A_tilde, b, [0,0,0])[0] - b # función de scipy retorna una mal solución.

Did not converge within 1000 iterations.


array([ -5664124.44535529, -13123559.76465118,   9079693.26388094])

In [126]:
# La idea es correr el algoritmo fuera de la función, permitiendo almacenar los residuos y direcciones.
A = A_tilde; b = b; x0 = [0,0,0]; max_iter = 1000; tol=1e-6

x = x0.copy()
r = b - A @ x  # Initial residual
p = r.copy()    # Initial search direction
residuals = [r.copy()]
directions = [p.copy()]
rs_old = np.dot(r, r)

for i in range(max_iter):
    Ap = A @ p
    alpha = rs_old / np.dot(p, Ap)
    x = x + alpha * p
    r = r - alpha * Ap
    rs_new = np.dot(r, r)
    if np.sqrt(rs_new) < tol:
        print(f"Converged in {i+1} iterations.")
        break

    beta = rs_new / rs_old
    p = r + beta * p
    residuals.append(r)
    directions.append(p)
    rs_old = rs_new
else:
    print(f"Did not converge within {max_iter} iterations.")

A@x - b

Did not converge within 1000 iterations.


array([-4.74753662e+13,  4.86343332e+13,  4.58657682e+13])

In [121]:
o = np.vstack(residuals) @ np.vstack(residuals).T

In [122]:
o - np.diag(np.diag(o))

array([[ 0.00000000e+00, -1.11022302e-16,  8.83738020e+00, ...,
         2.24918939e+13,  2.25995284e+13,  2.27075686e+13],
       [-1.11022302e-16,  0.00000000e+00, -2.52781610e+00, ...,
        -8.50182150e+12, -8.54250680e+12, -8.58334542e+12],
       [ 8.83738020e+00, -2.52781610e+00,  0.00000000e+00, ...,
         2.58950230e+15,  2.60189431e+15,  2.61433302e+15],
       ...,
       [ 2.24918939e+13, -8.50182150e+12,  2.58950230e+15, ...,
         0.00000000e+00,  6.62734120e+27,  6.65902411e+27],
       [ 2.25995284e+13, -8.54250680e+12,  2.60189431e+15, ...,
         6.62734120e+27,  0.00000000e+00,  6.69089074e+27],
       [ 2.27075686e+13, -8.58334542e+12,  2.61433302e+15, ...,
         6.65902411e+27,  6.69089074e+27,  0.00000000e+00]])

In [124]:
np.linalg.norm(o - np.diag(np.diag(o)))

6.399433456017062e+29