In [None]:
# Testing CUDA matvec implementation

In [None]:
import torch
import torch.linalg as tla
import numml.sparse as sp
import time

In [None]:
# Large(r) Poisson problem: A = 2*I minus the k=-1 and k=+1 shifted identities
# (presumably the first sub-/super-diagonals, as with numpy.eye's `k`), plus a
# random vector x to multiply it with.

# Seed the global RNG so that `x`, and every number printed from it later on,
# is the same on each Restart & Run All.
torch.manual_seed(0)

N = 1024
A = sp.eye(N)*2 - sp.eye(N, k=-1) - sp.eye(N, k=1)
A.requires_grad = True  # gradients of A are read back via A.grad further down
x = torch.rand(N)
x.requires_grad = True
print(repr(A))
print(repr(x))

In [None]:
# Build GPU-resident copies of the vector and the operator.
# The CSR tensor is moved with the same `.to(device)` call as a torch tensor.

cuda = torch.device('cuda:0')

# Detach after the transfer so the GPU copies are cut off from the CPU
# originals' autograd history, then re-enable gradient tracking on each copy
# so it collects its own .grad.
x_c = x.to(cuda).detach()
x_c.requires_grad = True
A_c = A.to(cuda).detach()
A_c.requires_grad = True

print(f'{A_c!r}')
print(f'{x_c!r}')

In [None]:
# Forward-pass parity: the GPU product, copied back to the host, should match
# the CPU product to within torch.allclose's default tolerances.
# The bare expression is left last so the notebook displays the boolean.
torch.allclose(
    (A_c @ x_c).cpu(),  # GPU matvec, result moved to host memory
    A @ x,              # CPU reference matvec
)

In [None]:
# Backward-pass parity: back-propagate the same scalar (the sum of the
# product) on each device, then compare the resulting gradients.
(A_c @ x_c).sum().backward()
(A @ x).sum().backward()

# One line of output per operand: first the matrix gradient, then the vector's.
for host_grad, device_grad in (
    (A.grad.data, A_c.grad.data),
    (x.grad, x_c.grad),
):
    print(torch.allclose(host_grad, device_grad.cpu()))

In [None]:
# Zero the accumulated gradients in place so the timing loops below start
# from a clean slate on both devices.

# CPU operands
A.data.grad.zero_()
x.grad.zero_()

# GPU operands
A_c.data.grad.zero_()
x_c.grad.zero_()

In [None]:
# Timing test: wall-clock comparison of the CPU and GPU sparse matvec, first
# forward-only and then forward + backward.
#
# Clock: time.perf_counter() is a monotonic, high-resolution timer intended
# for measuring short intervals; time.time() follows the system clock and can
# jump if that clock is adjusted mid-measurement.
#
# CUDA work is queued asynchronously, so every GPU measurement is bracketed by
# torch.cuda.synchronize(): once before the clock starts (so work still queued
# from earlier code is not billed to the loop) and once before it stops (so
# the kernels launched by the loop have actually finished).
#
# No extra warm-up iterations are added: the GPU forward and backward paths
# were already exercised by the parity checks above, and an additional
# backward pass on only one device would unbalance the accumulated gradients.

N_it = 1_000
print(f'Performing {N_it} sparse matvecs (forward pass)')

t_start = time.perf_counter()
for i in range(N_it):
    b = A@x
t_cpu = time.perf_counter() - t_start
print('CPU time:', t_cpu)

torch.cuda.synchronize()  # drain pending GPU work before starting the clock
t_start = time.perf_counter()
for i in range(N_it):
    b_c = A_c@x_c
torch.cuda.synchronize()  # wait for the queued matvecs to complete
t_cuda = time.perf_counter() - t_start
print('GPU time:', t_cuda)
print()

N_it = 100
print(f'Performing {N_it} sparse matvecs (backward pass)')

# Both loops below run the same number of backward passes, so the gradients
# accumulated on the CPU and GPU operands remain directly comparable.
t_start = time.perf_counter()
for i in range(N_it):
    b = A@x
    b.sum().backward()
t_cpu = time.perf_counter() - t_start
print('CPU time:', t_cpu)

torch.cuda.synchronize()  # drain pending GPU work before starting the clock
t_start = time.perf_counter()
for i in range(N_it):
    b_c = A_c@x_c
    b_c.sum().backward()
torch.cuda.synchronize()  # wait for forward and backward kernels to complete
t_cuda = time.perf_counter() - t_start
print('GPU time:', t_cuda)

In [None]:
# Relative CPU-vs-GPU gradient error: norm of the difference divided by the
# norm of the CPU gradient, reported for the matrix and then for the vector.

# Matrix gradient
A_grad_cpu = A.grad.data
A_grad_gpu = A_c.grad.data.cpu()
grad_err = tla.norm(A_grad_cpu - A_grad_gpu) / tla.norm(A_grad_cpu)
print('Relative error in CPU and GPU gradients (A)', grad_err.item())

# Vector gradient
x_grad_cpu = x.grad
x_grad_gpu = x_c.grad.cpu()
grad_err = tla.norm(x_grad_cpu - x_grad_gpu) / tla.norm(x_grad_cpu)
print('Relative error in CPU and GPU gradients (x)', grad_err.item())