In [None]:
import torch
import torch.autograd
import numml.sparse as sp

In [None]:
# Build a small 2x3 sparse matrix in torch's built-in COO format:
#     [[0, 0, 3],
#      [4, 0, 5]]
# `v` requires grad so gradients can later flow back to the stored values.
i = torch.tensor([
    [0, 1, 1],  # row indices
    [2, 0, 2],  # column indices
])
v = torch.tensor([3, 4, 5], dtype=torch.float32, requires_grad=True)

# The shape is spelled out explicitly; it equals what torch would infer from
# the largest indices, (max_row + 1, max_col + 1) = (2, 3).
T_coo = torch.sparse_coo_tensor(i, v, size=(2, 3)).coalesce()
T_coo

In [None]:
# Convert the COO tensor from the previous cell into numml's CSR class.
# (This previously referenced an undefined name `A`, which raised NameError
# on a fresh kernel; the intended source is `T_coo`.)
T_csr = sp.SparseCSRTensor(T_coo)

In [None]:
# Dense copy of the CSR matrix, kept as a reference for testing the sparse ops.
# clone() + detach() give an independent leaf tensor (no link to the sparse
# graph), and requires_grad_ lets us take gradients with respect to it directly.
T_dense = T_csr.to_dense().clone().detach().requires_grad_(True)

print(T_dense)

In [None]:
# Sparse mat-vec, checked against the dense reference.
ramp = torch.arange(3).float()  # the vector [0., 1., 2.]

y_sparse = T_csr @ ramp
y_dense = T_dense @ ramp
print(y_sparse)
print(y_dense)

# The dense copy exists for testing: fail loudly if the two products disagree
# instead of relying on a reader to compare the printed values by eye.
torch.testing.assert_close(y_sparse, y_dense)

In [None]:
# Mat-vec gradient with respect to the matrix.
## The sparse representation keeps sparse gradients: gradient information is
## accumulated in *nonzero entries only*. The first result is therefore shaped
## like `T_csr.data`, while the dense result is shaped like the full matrix.

matvec_rhs = torch.arange(3).float()

sparse_loss = (T_csr @ matvec_rhs).sum()
print(torch.autograd.grad(sparse_loss, T_csr.data))

dense_loss = (T_dense @ matvec_rhs).sum()
print(torch.autograd.grad(dense_loss, T_dense))

In [None]:
# Mat-vec gradient with respect to the vector.
# For loss = sum(A @ x), d(loss)/dx = A^T @ 1. That gradient has the shape of
# `x` for both representations, so the sparse and dense results must match.

x = (torch.arange(3).float() + 2.).requires_grad_()

grad_x_sparse = torch.autograd.grad((T_csr @ x).sum(), x)
grad_x_dense = torch.autograd.grad((T_dense @ x).sum(), x)
print(grad_x_sparse)
print(grad_x_dense)

# autograd.grad returns tuples; assert_close compares them element-wise.
torch.testing.assert_close(grad_x_sparse, grad_x_dense)