In [252]:
import torch
from torch import Tensor
import typing
from torch.nn.functional import normalize
from torch import tensordot as dot
from torch.linalg import vector_norm as norm

We first make an implementation of Gram-Schmidt

In [312]:
def gram_schmidt(x: Tensor, B: Tensor):
    """
    Tensor x: shape (1, M) <- Representing NxC vectors in R^M
    Tensor B: shape (R, M) <- Representing NxC vectors in R^M
    
    h0 = B' * x (pointwise dot products) -> (C, R)
    """
    
    h0 = dot(x, B, dims=[[1], [1]])             # dim (1, R)
    x0 = x - dot(h0, B, dims=[[1], [0]])        # dim (1, M) - ((1, R) x (R, M)) = # dim (1, M)

    h1 = dot(x0, B, dims=[[1], [1]])            # dim (1, R)
    x1 = x0 - dot(h1, B, dims=[[1], [0]])       # dim (1, M) - ((1, R) x (R, M)) = # dim (1, M) 
    
    h = h0 + h1                                             # dim (1, R)
    b = norm(x1, dim=1, keepdim=True)   # dim (1, 1)        
    
    x1 = x1 / b
    
    return x1, h, b

In [313]:
def arnoldi(op, b, steps, callback=None):
    """
    nn.Module op: takes (C, M) input, outputs of shape (C, M).
    Tensor b: (1, M)
    """
    assert steps < b.shape[1], f"Hey, steps must be at most {b.shape[1]}"
    
    q = b / norm(b, dim=1)  # dim (C, M)
    Q = torch.zeros(steps+1, b.shape[1]) # dim (C, M)
    Q[0, :] = torch.ones(q.shape[1])
    Q = Q * q
    for m in range(steps):
        # Next element in krylov space
        x = op(q)
        
        # Orthogonalise against Q and update
        (q, h, beta) = gram_schmidt(x, Q[:(m+1), :])
        
        # Create H
        # Hij = dot(q_i, (Aq)_;j) = q_i' * Aq_j
        # Behöver endast beräkna q_m' * Aq_j  och q_j * Aqm
        AQ = op(Q[:(m+1), :]) # -> (C, M)
        H = dot(Q[:(m+1), :], AQ, dims=[[1],[1]]) # (C, C) # unnecessary double counting but ok
        
        # Do callback
        stop = False
        if callback is not None:
            stop = callback(Q[:(m+1),:], q, H, beta, m+1)
        if stop:
            break
        mask = torch.zeros_like(Q)
        mask[m+1, :] = torch.ones_like(q)
        Q = Q + mask * q
    return Q, q, H, beta



def gmres(op, b, steps, callback=None, verbose=False):
    """
    GMRES implementation in pytorch
    """
    Q, q, H, beta = arnoldi(op, b, steps, callback=callback)
    
    normb = norm(b, dim=1)
    m = steps
    Q = Q[:m, :]
    
    em = torch.zeros(m, m)
    em[m-1, m-1] = 1

    e1 = torch.zeros(1, m)
    e1[0, 0] = 1

    HmTHm = dot(H, H, dims=[[0], [0]]) + em * beta**2
    HmTbe = dot(normb*e1, H, dims=[[1], [0]])

    z = torch.linalg.solve(HmTHm[None,:,:], HmTbe)
    x = dot(z, Q, dims=[[1],[0]])
    return x

In [315]:
n = 10
A = 0.1*torch.rand(n,n) + torch.eye(n)
x = torch.ones(n)[None,:]
b.requires_grad = True

op = lambda x: torch.tensordot(x, A, dims=[[1],[1]]) # (C, M) x (M, M) -> (C, M)
y = gmres(op, b, steps=5, verbose=True)

print(dot(y,A,dims=[[1],[1]]))

tensor([[3.5018e-07, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
         1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00]],
       grad_fn=<ReshapeAliasBackward0>)


Let's try GMRES on our integral operator

Below is an in-place version of gmres which is faster, but not differentiable.

In [None]:

def gmres(op, b, steps, tol=0., callback=None, verbose=False):
    """
    GMRES implementation in pytorch
    """
    
    normb = torch.linalg.vector_norm(b, dim=1)
    x = torch.zeros_like(b)
    
    def gmres_callback(Q, q, H, beta, m):
        em = torch.zeros(1, m)
        em[0, m-1] = 1
        
        e1 = torch.zeros(1, m+1)
        e1[0, 0] = 1
        
        Hm = torch.zeros(m+1, m)
        Hm[:m, :m] = H
        Hm[m,   :] = em * beta
        
        be1 = normb  * e1
        
        HmTHm = torch.tensordot(Hm, Hm , dims=[[0], [0]])
        HmTbe = torch.tensordot(be1, Hm, dims=[[1], [0]])
        
        z = torch.linalg.solve(HmTHm[None,:,:], HmTbe)
        x[:] = torch.tensordot(z, Q, dims=[[1],[0]])
        
        res = torch.linalg.vector_norm((torch.tensordot(z, Hm, dims=[[1], [1]]) - be1)[0])
        stop = False
        
        if verbose:
            print(f"residual at step {m}: {res:.2e}")
        if res < tol:
            stop = True
        if callback is not None:
            stop = stop or callback(Q, q, H, beta, m)
        return stop
    
    arnoldi(op, b, steps, callback=gmres_callback)
    return x