In [14]:
import torch
from gpytorch.utils.deprecation import bool_compat


def _default_preconditioner(x):
    return x.clone()


class GPyTorchCGSolver:
    """Batched conjugate-gradient solver that records the state of every iteration.

    Solves ``A X = b`` where each column of ``b`` is an independent right-hand
    side, using the module-level ``initialize_cg`` / ``take_cg_step`` /
    ``cond_fn`` helpers. CPU copies of the iterates, residuals, ``r^T z`` values
    and search directions are kept in ``Us`` / ``Rs`` / ``gammas`` / ``ps``.

    Usage::

        solver = GPyTorchCGSolver(tolerance=1e-5, max_iters=100)
        solver.set_matrix_and_probes(lambda v: A @ v, b)
        x = solver.run_mbcg_with_tracking()
    """

    def __init__(self, tolerance=1.e-1, max_iters=20, preconditioner=None):
        """
        Args:
            tolerance: stop once the mean (normalized) residual norm drops below
                this; ``cond_fn`` also enforces a minimum number of iterations.
            max_iters: hard cap on the number of CG steps.
            preconditioner: callable ``r -> M^{-1} r``. NOTE(review): stored but
                not applied -- ``initialize_cg`` and ``take_cg_step`` always use
                the identity preconditioner.
        """
        self.tolerance = tolerance
        self.max_iters = max_iters
        self.preconditioner = (
            _default_preconditioner if preconditioner is None else preconditioner
        )
        # Per-column residual norm below which a column is frozen (its step size
        # is forced to 0 by take_cg_step).
        self.stop_updating_after = 1.e-10
        # Threshold below which norms / denominators are treated as zero.
        self.eps = 1.e-10

    def set_matrix_and_probes(self, A_fn, b):
        """Store the matvec closure ``A_fn(v) -> A @ v`` and the right-hand sides ``b``."""
        self.A = A_fn
        self.b = b
        self.x0 = torch.zeros_like(b)

    def run_mbcg_with_tracking(self, verbose=True):
        """Run CG on the system from ``set_matrix_and_probes``, tracking every step.

        Each column of ``b`` (norm over dim -2) is scaled to unit norm before
        solving; zero columns are left unscaled. The solution is scaled back at
        the end, so the tracked iterates/residuals are in the normalized scale.

        Tracker entry ``j`` holds the state after ``j`` CG steps (entry 0 is the
        initial state). On return, ``self.k`` is the number of steps taken.

        Args:
            verbose: if True (default), print ``print_analysis`` diagnostics
                after every step.

        Returns:
            The approximate solution, with the same shape as ``b``.
        """
        rhs_norm = self.b.norm(2, dim=-2, keepdim=True)
        rhs_is_zero = rhs_norm.lt(self.eps)
        rhs_norm = rhs_norm.masked_fill_(rhs_is_zero, 1)
        rhs = self.b.div(rhs_norm)

        state, out = initialize_cg(self.A, rhs, self.stop_updating_after, self.eps)
        x0, has_converged, r0, batch_shape, residual_norm = state
        p0, gamma0, mul_storage, beta, alpha, is_zero, z0 = out

        self.initialize_trackers()
        self.update_trackers(x0, r0, gamma0, p0, k=0)
        for k in range(self.max_iters):
            Ap0 = self.A(p0)
            # Updates x0, r0, gamma0, p0, alpha and beta in place.
            take_cg_step(
                Ap0, x0, r0, gamma0, p0, alpha, beta,
                z0, mul_storage, has_converged, self.eps, is_zero)
            # Also refreshes residual_norm / has_converged in place.
            converged = cond_fn(
                k, self.max_iters, self.tolerance, r0, has_converged,
                residual_norm, self.stop_updating_after, rhs_is_zero)
            if verbose:
                print_analysis(k, alpha, residual_norm, gamma0, beta)
            # Record the post-step state *before* a possible break so the final
            # iterate is kept. The index is k + 1 because entry 0 is the initial state.
            self.update_trackers(x0, r0, gamma0, p0, k=k + 1)
            if converged:
                break
        return x0.mul(rhs_norm)

    def update_trackers(self, x0, r0, gamma0, p0, k):
        """Append CPU snapshots (clones) of the current CG state and record ``k``."""
        self.Us.append(x0.clone().cpu())
        self.Rs.append(r0.clone().cpu())
        self.gammas.append(gamma0.clone().cpu())
        self.ps.append(p0.clone().cpu())
        self.k = k

    def initialize_trackers(self):
        """Reset the tracker lists; ``k = -1`` means nothing has been recorded yet."""
        self.Us, self.Rs, self.gammas, self.ps, self.k = [], [], [], [], -1


def linear_cg(
    matmul_closure,
    rhs,
    tolerance=None,
    eps=1e-10,
    stop_updating_after=1e-10,
    max_iter=None,
    initial_guess=None,
    preconditioner=None,
):
    """Solve ``A X = rhs`` for every column of ``rhs`` with batched CG.

    Each column (norm over dim -2) is normalized to unit length before solving,
    and the solution is rescaled afterwards. Columns with norm below ``eps`` are
    left unscaled.

    Args:
        matmul_closure: callable computing ``A @ v`` for ``v`` shaped like ``rhs``.
        rhs: right-hand sides.
        tolerance: stop once the mean normalized residual norm falls below this,
            subject to ``cond_fn``'s minimum iteration count. ``None`` -> 1.0.
        eps: threshold for treating norms / denominators as zero.
        stop_updating_after: per-column residual norm below which a column
            stops being updated.
        max_iter: maximum number of CG steps. ``None`` -> 1000.
        initial_guess: NOT SUPPORTED -- iterations always start from zero.
            A warning is emitted if this is given.
        preconditioner: NOT SUPPORTED -- the identity preconditioner is always
            used. A warning is emitted if this is given.

    Returns:
        The approximate solution, with the same shape as ``rhs``.
    """
    # NOTE(review): 1.0 / 1000 are chosen to match gpytorch's default
    # `settings.cg_tolerance` / `settings.max_cg_iterations`; confirm against
    # the installed version. Previously `None` crashed in range()/comparison.
    if tolerance is None:
        tolerance = 1.0
    if max_iter is None:
        max_iter = 1000
    if initial_guess is not None or preconditioner is not None:
        import warnings

        warnings.warn(
            "linear_cg ignores `initial_guess` and `preconditioner`; iterations "
            "start from zero with the identity preconditioner.",
            stacklevel=2,
        )

    rhs_norm = rhs.norm(2, dim=-2, keepdim=True)
    rhs_is_zero = rhs_norm.lt(eps)
    rhs_norm = rhs_norm.masked_fill_(rhs_is_zero, 1)
    rhs = rhs.div(rhs_norm)

    state, out = initialize_cg(matmul_closure, rhs, stop_updating_after, eps)
    x0, has_converged, r0, batch_shape, residual_norm = state
    (p0, gamma0, mul_storage, beta, alpha, is_zero, z0) = out

    for k in range(max_iter):
        Ap0 = matmul_closure(p0)
        # Updates x0, r0, gamma0, p0, alpha and beta in place.
        take_cg_step(
            Ap0=Ap0,
            x0=x0,
            r0=r0,
            gamma0=gamma0,
            p0=p0,
            alpha=alpha,
            beta=beta,
            z0=z0,
            mul_storage=mul_storage,
            has_converged=has_converged,
            eps=eps,
            is_zero=is_zero,
        )

        if cond_fn(k, max_iter, tolerance, r0, has_converged, residual_norm,
                   stop_updating_after, rhs_is_zero):
            break

    # Undo the per-column normalization.
    return x0.mul(rhs_norm)


def initialize_cg(matmul_closure, rhs, stop_updating_after, eps):
    """Set up the starting state of a batched CG solve from ``x0 = 0``.

    Args:
        matmul_closure: callable computing ``A @ v``.
        rhs: right-hand sides (the callers in this file pass them normalized).
        stop_updating_after: residual-norm threshold below which a column is
            marked converged from the start.
        eps: accepted for signature compatibility; not used here.

    Returns:
        ``(state, out)`` where
        ``state = (x0, has_converged, r0, batch_shape, residual_norm)`` and
        ``out`` is the tuple from ``create_placeholders``, which callers unpack
        as ``(p0, gamma0, mul_storage, beta, alpha, is_zero, z0)``.
    """
    initial_guess = torch.zeros_like(rhs)
    # Always the identity preconditioner; take_cg_step hard-codes the same choice.
    preconditioner = _default_preconditioner

    residual = rhs - matmul_closure(initial_guess)
    batch_shape = residual.shape[:-2]

    result = initial_guess.expand_as(residual).contiguous()

    residual_norm = residual.norm(2, dim=-2, keepdim=True)
    has_converged = torch.lt(residual_norm, stop_updating_after)

    state = (result, has_converged, residual, batch_shape, residual_norm)
    out = create_placeholders(rhs, residual, preconditioner, batch_shape)
    return state, out


def take_cg_step(
        Ap0, x0, r0, gamma0, p0, alpha, beta, z0, mul_storage, has_converged, eps,
        is_zero):
    """Advance every column of a batched CG solve by one iteration, in place.

    Nothing is returned. The state tensors are overwritten through ``out=`` and
    in-place (``*_``) operations, so callers see the updates through the tensors
    they passed in.

    Args:
        Ap0: ``A @ p0`` for the current search directions.
        x0: current iterate; overwritten with ``x0 + alpha * p0``.
        r0: current residual; overwritten with ``r0 - alpha * Ap0``.
        gamma0: per-column ``r^T z`` from the previous step; overwritten with
            the value for the new residual.
        p0: search directions; overwritten with ``z + beta * p0``.
        alpha: scratch buffer; ends holding the step size used for this step.
        beta: scratch buffer; ends holding the direction-update coefficient.
        z0: not read here. NOTE(review): the preconditioned residual is
            recomputed below as ``r0.clone()``, i.e. the identity preconditioner.
        mul_storage: scratch buffer shaped like ``r0`` for elementwise products.
        has_converged: boolean mask. Masked columns get ``alpha = 0``, so their
            ``x0`` and ``r0`` stop changing.
        eps: denominators (``p^T A p`` and the previous ``gamma``) below this
            are treated as zero, and the corresponding coefficient is set to 0.
        is_zero: boolean scratch buffer for those zero-denominator masks.
    """

    # alpha_{k} = gamma_{k-1} / (p_{k-1}^T A p_{k-1}), computed per column.
    torch.mul(p0, Ap0, out=mul_storage)
    torch.sum(mul_storage, dim=-2, keepdim=True, out=alpha)

    # Guard the division: tiny (or negative) curvature -> alpha = 0 for that column.
    torch.lt(alpha, eps, out=is_zero)
    alpha.masked_fill_(is_zero, 1)
    torch.div(gamma0, alpha, out=alpha)
    alpha.masked_fill_(is_zero, 0)
    # Converged columns take no step.
    alpha.masked_fill_(has_converged, 0)

    # residual_{k} = residual_{k-1} - alpha_{k} mat p_vec_{k-1}
    torch.addcmul(r0, -alpha, Ap0, out=r0)

    # precon_residual{k} = M^-1 residual_{k}
    # NOTE(review): M is the identity here, so this is just a copy of r0.
    precond_residual = r0.clone()

    # x_{k} = x_{k-1} + alpha_{k} p_{k-1}. Written in place via out=, so the
    # rebinding of the local name x0 has no further effect.
    x0 = torch.addcmul(x0, alpha, p0, out=x0)

    # beta_{k} = (precon_residual{k}^T r_vec_{k}) / (precon_residual{k-1}^T r_vec_{k-1})
    # Save the old gamma into beta, then overwrite gamma0 with the new r^T z.
    beta.resize_as_(gamma0).copy_(gamma0)
    torch.mul(r0, precond_residual, out=mul_storage)
    torch.sum(mul_storage, -2, keepdim=True, out=gamma0)
    # Same zero-denominator guard as for alpha.
    torch.lt(beta, eps, out=is_zero)
    beta.masked_fill_(is_zero, 1)
    torch.div(gamma0, beta, out=beta)
    beta.masked_fill_(is_zero, 0)

    # curr_conjugate_vec_{k} = precon_residual{k} + beta_{k} curr_conjugate_vec_{k-1}
    p0.mul_(beta).add_(precond_residual)


def create_placeholders(rhs, residual, preconditioner, batch_shape):
    """Allocate the initial search direction, ``r^T z`` and scratch buffers for CG.

    Args:
        rhs: right-hand sides; only ``rhs.size(-1)`` (the number of columns) is used.
        residual: initial residual ``r0``.
        preconditioner: callable applied to ``residual`` to obtain ``z0``.
        batch_shape: leading batch dimensions of ``residual``.

    Returns:
        ``(p0, gamma0, mul_storage, beta, alpha, is_zero, z0)`` where:
        - ``p0`` and ``z0`` are the *same* tensor object (the preconditioned residual);
        - ``gamma0 = sum(z0 * r0, dim=-2, keepdim=True)``;
        - the rest are uninitialized buffers: ``mul_storage`` shaped like
          ``residual``, ``alpha``/``beta`` of shape ``(*batch_shape, 1, rhs.size(-1))``,
          and ``is_zero`` a boolean mask of that same shape.
    """
    z = preconditioner(residual)
    gamma = (z * residual).sum(dim=-2, keepdim=True)

    coeff_shape = (*batch_shape, 1, rhs.size(-1))
    alpha_buf = residual.new_empty(coeff_shape)
    beta_buf = torch.empty_like(alpha_buf)
    zero_mask = torch.empty(coeff_shape, dtype=bool_compat, device=residual.device)
    scratch = torch.empty_like(residual)

    # The search direction starts as z itself: the first and last slots hold
    # one and the same tensor.
    return (z, gamma, scratch, beta_buf, alpha_buf, zero_mask, z)


def cond_fn(k, max_iter, tolerance, residual, has_converged, residual_norm,
            stop_updating_after, rhs_is_zero):
    """Refresh convergence bookkeeping after a CG step and decide whether to stop.

    Side effects (in place):
        residual_norm: overwritten with the 2-norm of ``residual`` over dim -2.
            Columns flagged in ``rhs_is_zero`` are set to 0.
        has_converged: overwritten with ``residual_norm < stop_updating_after``.

    Returns:
        ``True`` when ``k >= min(10, max_iter - 1)`` and the mean residual norm
        is below ``tolerance``; ``False`` otherwise.
    """
    torch.norm(residual, 2, dim=-2, keepdim=True, out=residual_norm)
    residual_norm.masked_fill_(rhs_is_zero, 0)
    torch.lt(residual_norm, stop_updating_after, out=has_converged)

    warmup = min(10, max_iter - 1)
    if k < warmup:
        return False
    return bool(residual_norm.mean() < tolerance)


def print_analysis(k, alpha, residual_norm, gamma0, beta):
    """Print a human-readable snapshot of CG iteration ``k`` to stdout.

    Shows summary statistics (mean/max) and the full tensors for the residual
    norms, step sizes ``alpha``, ``r^T z`` values ``gamma0`` and direction
    coefficients ``beta``.
    """
    report = [
        '\n===================================================',
        f'Iter {k}',
        f'Residual norm mean: {torch.mean(residual_norm)}',
        f'Residual norm max: {torch.max(residual_norm)}',
        f'Residual norm: {residual_norm}',
        'alpha',
        str(alpha),
        f'Alpha mean: {torch.mean(alpha)}',
        'gamma',
        f'Gamma mean: {torch.mean(gamma0)}',
        str(gamma0),
        'beta',
        f'Beta mean: {torch.mean(beta)}',
        str(beta),
    ]
    print('\n'.join(report))

In [15]:
import torch

def run_experiment(n=50, t=5, max_iter=100, tolerance=1e-5, seed=42, dtype=None):
    """Solve a random SPD system with GPyTorchCGSolver and report the error.

    Builds ``A = M^T M + 1e-3 I`` from a random ``n x n`` matrix ``M``. It then
    solves for ``y`` plus ``t - 1`` extra unit-norm random probe columns, and
    compares the CG solution for ``y`` against ``torch.linalg.solve``.

    Args:
        n: system size.
        t: total number of right-hand-side columns (``y`` plus ``t - 1`` probes).
        max_iter: CG iteration cap.
        tolerance: CG stopping tolerance.
        seed: seed passed to ``torch.manual_seed`` for reproducibility.
        dtype: dtype for all generated tensors; ``None`` uses torch's default.

    Returns:
        The relative 2-norm error of the ``y`` column, as a Python float.
        It is also printed.
    """
    torch.manual_seed(seed)
    M = torch.randn(n, n, dtype=dtype)
    # M^T M is PSD; the small ridge makes it strictly positive definite.
    A = M.t().mm(M) + 1e-3 * torch.eye(n, dtype=dtype)
    y = torch.randn(n, 1, dtype=dtype)
    if t > 1:
        # Extra unit-norm probe columns solved alongside y (not error-checked).
        rand_vectors = torch.randn(n, t - 1, dtype=dtype)
        rand_vectors = rand_vectors / rand_vectors.norm(p=2, dim=0, keepdim=True)
        b = torch.cat([y, rand_vectors], dim=1)
    else:
        b = y

    def A_fn(x):
        return A.mm(x)

    cg_solver = GPyTorchCGSolver(tolerance=tolerance, max_iters=max_iter)
    cg_solver.set_matrix_and_probes(A_fn, b)
    x_cg = cg_solver.run_mbcg_with_tracking()
    x_ref = torch.linalg.solve(A, y)
    relative_error = (x_cg[:, 0:1] - x_ref).norm(p=2) / x_ref.norm(p=2)
    print(f"\nRelative error: {relative_error.item():.5f}")
    return relative_error.item()

run_experiment()


Iter 0
Residual norm mean: 1.0314686298370361
Residual norm max: 1.2431721687316895
Residual norm: tensor([[1.2432, 0.6902, 1.1852, 0.9194, 1.1194]])
alpha
tensor([[0.0259, 0.0161, 0.0297, 0.0181, 0.0210]])
Alpha mean: 0.022153671830892563
gamma
Gamma mean: 1.1049745082855225
tensor([[1.5455, 0.4763, 1.4048, 0.8453, 1.2530]])
beta
Beta mean: 1.1049745082855225
tensor([[1.5455, 0.4763, 1.4048, 0.8453, 1.2530]])

Iter 1
Residual norm mean: 1.0119779109954834
Residual norm max: 1.203175663948059
Residual norm: tensor([[1.0799, 0.7183, 1.0830, 1.2032, 0.9755]])
alpha
tensor([[0.0208, 0.0222, 0.0173, 0.0289, 0.0212]])
Alpha mean: 0.022067328914999962
gamma
Gamma mean: 1.0508644580841064
tensor([[1.1662, 0.5159, 1.1730, 1.4476, 0.9516]])
beta
Beta mean: 1.0289448499679565
tensor([[0.7546, 1.0830, 0.8350, 1.7127, 0.7594]])

Iter 2
Residual norm mean: 0.9281632304191589
Residual norm max: 1.100314974784851
Residual norm: tensor([[0.9796, 0.7229, 0.9545, 1.1003, 0.8834]])
alpha
tensor([[0.0226

In [3]:
import torch

# A simple kernel function for demonstration.
def simple_kernel(x, y):
    """Squared-exponential (unit lengthscale) kernel between one point and many.

    Args:
        x: a single point shaped ``(1, d)``, so it broadcasts against ``y``.
        y: ``(n, d)`` batch of points.

    Returns:
        ``(n,)`` tensor ``outputscale**2 * exp(-||x - y_i||^2 / 2)``, where
        ``outputscale`` is read from the ``simple_kernel.outputscale`` attribute.
    """
    sq_dist = (x - y).pow(2).sum(dim=1)
    return simple_kernel.outputscale ** 2 * torch.exp(-0.5 * sq_dist)

# The kernel's amplitude lives on the function object, so callers such as
# build_cholesky can read it as ``kernel.outputscale``.
simple_kernel.outputscale = 2.0

def _pivoted_cholesky(x, kernel_fn, outputscale, rank=3):
    n = x.shape[0]
    L = torch.zeros((rank, n), device=x.device, dtype=x.dtype)
    d = (outputscale ** 2) * torch.ones(n, device=x.device, dtype=x.dtype)
    pi = torch.arange(n, device=x.device)
    
    for m in range(rank):
        pivot_relative = torch.argmax(d[m:]).item()
        i = m + pivot_relative
        
        # Swap the pivot indices
        temp = pi[m].clone()
        pi[m] = pi[i]
        pi[i] = temp
        
        # Set L[m, pi[m]] = sqrt(d[pi[m]])
        L[m, pi[m]] = torch.sqrt(d[pi[m]])
        
        # Evaluate the kernel between the pivot and remaining points.
        a = kernel_fn(x[pi[m]].unsqueeze(0), x[pi[m+1:]])
        a = a.squeeze(0)  # shape: (n-m-1,)
        
        if m > 0:
            correction = torch.sum(L[:m, pi[m]].unsqueeze(1) * L[:m, pi[m+1:]], dim=0)
        else:
            correction = 0.0

        l = (a - correction) / L[m, pi[m]]
        L[m, pi[m+1:]] = l
        
        d[pi[m+1:]] = d[pi[m+1:]] - (l ** 2)
    
    return L

def psd_safe_cholesky(M, jitter=1e-6, max_tries=10):
    """Cholesky factor of ``M``, adding diagonal jitter if the plain factorization fails.

    First tries ``torch.linalg.cholesky(M)``. On failure it retries with
    ``M + jitter * 10**t * I`` for ``t = 0 .. max_tries - 1``. Each jitter is
    applied to the original ``M``, not accumulated.

    Args:
        M: square matrix (or batch of square matrices) expected to be symmetric PSD.
        jitter: first diagonal jitter to try.
        max_tries: number of jittered retries. This is bounded so that a matrix
            that can never be factorized (e.g. one containing NaNs) raises
            instead of looping forever.

    Returns:
        Lower-triangular ``L`` with ``L @ L^T`` equal to the (possibly jittered) ``M``.

    Raises:
        RuntimeError: if every attempt fails, chained to the last factorization error.
    """
    try:
        return torch.linalg.cholesky(M)
    except RuntimeError as err:  # torch.linalg.LinAlgError subclasses RuntimeError
        last_err = err

    eye = torch.eye(M.shape[-1], device=M.device, dtype=M.dtype)
    current = 0.0
    for attempt in range(max_tries):
        current = jitter * 10 ** attempt
        try:
            return torch.linalg.cholesky(M + current * eye)
        except RuntimeError as err:
            last_err = err
    raise RuntimeError(
        f"psd_safe_cholesky: matrix not positive definite even with jitter {current:.1e}"
    ) from last_err

def build_cholesky(X, kernel, noise, rank):
    """Build a Woodbury-based inverse for ``noise^2 I + L^T L``, with ``L`` from pivoted Cholesky.

    ``L`` (shape ``(rank, n)``) is the pivoted Cholesky factor of the kernel
    matrix on ``X``. The returned ``precond_inv`` applies
    ``(noise^2 I + L^T L)^{-1}`` via the Woodbury identity
    ``noise^-2 v - noise^-4 L^T (I + noise^-2 L L^T)^{-1} L v``,
    so only a ``rank x rank`` system is factorized.

    Args:
        X: ``(n, d)`` inputs.
        kernel: kernel function that exposes an ``outputscale`` attribute.
        noise: noise *standard deviation* (it is squared on the diagonal).
        rank: pivoted-Cholesky rank.

    Returns:
        ``(precond_inv, A, L)``:
        - ``precond_inv(v)`` accepts ``(n,)`` or ``(n, k)`` input. Any result
          with a single column is squeezed to ``(n,)``.
        - ``A = noise^2 I + L^T L``, the matrix ``precond_inv`` inverts. It is
          meant for verification; it is the low-rank approximation, not the
          exact kernel matrix.
        - ``L`` itself.
    """
    low_rank = _pivoted_cholesky(X, kernel, kernel.outputscale, rank=rank)

    inv_var = noise ** -2
    inv_var_sq = noise ** -4

    # Capacitance matrix I + noise^-2 (L L^T), shape (rank, rank).
    capacitance = torch.eye(rank, device=low_rank.device, dtype=low_rank.dtype) + inv_var * (low_rank @ low_rank.T)
    cap_chol = psd_safe_cholesky(capacitance)

    def precond_inv(v):
        col = v.unsqueeze(1) if v.ndim == 1 else v
        solved = torch.cholesky_solve(low_rank @ col, cap_chol, upper=False)
        out = inv_var * col - inv_var_sq * (low_rank.T @ solved)
        if out.ndim == 2 and out.shape[1] == 1:
            return out.squeeze(-1)
        return out

    size = X.shape[0]
    A = noise ** 2 * torch.eye(size, device=X.device, dtype=X.dtype) + (low_rank.T @ low_rank)

    return precond_inv, A, low_rank

# ---- Demonstration ----

# Generate a random input matrix X of size (n, d)
n, d = 10, 3
X = torch.randn(n, d)

# Choose a noise scale and rank for the pivoted Cholesky.
noise = 0.1
rank = 3

# Build the preconditioner and obtain A (= noise^2 I + L^T L, the matrix that
# precond_inv inverts exactly).
precond_inv, A, L = build_cholesky(X, simple_kernel, noise, rank)

# Verify that precond_inv is A^{-1} by applying it to every column of A:
# column i of A is A @ e_i, so precond_inv(A @ e_i) should give back e_i.
# The basis vectors are the columns of I_true, so they share A's dtype/device.
I_true = torch.eye(n, device=A.device, dtype=A.dtype)
I_approx = torch.stack([precond_inv(A @ e) for e in I_true.unbind(dim=1)], dim=1)

# Compute the difference from the true identity matrix.
error_norm = torch.norm(I_approx - I_true)

print("Approximated I from precond_inv * A:")
print(I_approx)
print("\nError norm ||I_approx - I||:", error_norm.item())

Approximated I from precond_inv * A:
tensor([[ 1.0000e+00, -9.5367e-06, -1.2666e-07,  5.7742e-08, -1.7136e-07,
         -1.0431e-07, -1.0710e-08, -1.4901e-08,  9.5367e-07,  1.9073e-06],
        [ 0.0000e+00,  1.0000e+00, -1.7881e-07,  0.0000e+00,  9.5367e-07,
         -4.7684e-07,  2.9802e-08,  0.0000e+00, -1.5259e-05, -7.6294e-06],
        [ 1.8626e-08, -4.7684e-07,  1.0001e+00,  1.1921e-07,  1.5259e-05,
          2.2888e-05,  1.4901e-08,  0.0000e+00,  3.0547e-07,  2.8610e-06],
        [-2.1420e-07,  3.8147e-06,  3.5763e-07,  1.0000e+00,  0.0000e+00,
          1.1921e-07,  1.8626e-09,  0.0000e+00,  9.5367e-07,  0.0000e+00],
        [-2.2352e-08,  4.7684e-07,  3.0518e-05,  0.0000e+00,  1.0000e+00,
          1.5259e-05,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.9073e-06],
        [-1.4901e-08, -2.3842e-07,  1.5259e-05,  1.1921e-07,  7.6294e-06,
          1.0000e+00,  0.0000e+00,  0.0000e+00,  1.1921e-07,  4.7684e-07],
        [-2.5611e-09,  2.9802e-08,  1.4901e-08,  0.0000e+00,  0.0000e