In [None]:
import torch
import torch.autograd
import numpy as np
import matplotlib.pyplot as plt
import numml.sparse as sp

In [None]:
# Example triangular system: a 5x5 bidiagonal matrix with 2 on the diagonal and
# -1 on the off-diagonal selected by `lower`, stored in CSR form with gradients
# enabled on its entries.

lower = True  # switch this to False to test upper triangular

diag_offset = -1 if lower else 1
A_csr = sp.eye(5) * 2 - sp.eye(5, k=diag_offset)
A_csr.requires_grad = True

print(A_csr)
print(A_csr.to_dense())

In [None]:
# Right-hand side: the float32 vector [1, 2, 3, 4, 5].
b = torch.arange(1, 6, dtype=torch.float32)
b

In [None]:
# Solve the triangular system A x = b for x, using the orientation chosen by `lower`.
x = A_csr.solve_triangular(b=b, unit=False, upper=not lower)
x

In [None]:
# Assert that we have (numerically) zero residual: A times x should equal b.
# Compare with a tolerance rather than exact `==`, since floating-point triangular
# solves are not bit-exact in general, and actually assert so a failure is loud.
residual_ok = torch.allclose((A_csr @ x).detach(), b)
assert residual_ok, "A @ x does not match b"
residual_ok

In [None]:
# Interesting example:
# optimize the entries of A such that A^{-1} b = b (equivalently, A b = b).
# NOTE: this cell updates A_csr in place, so re-running it continues from the
# previously optimized matrix rather than starting from the original system.

LEARNING_RATE = 0.01
NUM_ITERS = 1_000
LOG_EVERY = 100  # print the loss every LOG_EVERY iterations

optimizer = torch.optim.Adam([A_csr.data], lr=LEARNING_RATE)
lh = []  # loss history, one entry per iteration

for i in range(NUM_ITERS):
    optimizer.zero_grad()

    # Use a separate name so the `x` from the earlier solve cell is not clobbered.
    x_hat = A_csr.solve_triangular(upper=(not lower), unit=False, b=b)
    loss = torch.sum((x_hat - b) ** 2)
    loss.backward()

    # Adam updates A_csr.data using the gradients computed by backward().
    optimizer.step()

    loss_value = loss.item()  # convert to a Python float once per iteration
    lh.append(loss_value)
    if i % LOG_EVERY == 0:
        print(i, loss_value)

In [None]:
# Loss per optimization iteration on a log scale, to see how training progressed.
fig, ax = plt.subplots()
ax.semilogy(lh)
ax.grid()
ax.set(title='Loss history', xlabel='Iteration', ylabel='Loss (sum of squared errors)');

In [None]:
# Inspect the optimized matrix in dense form.
A_dense = A_csr.to_dense()
A_dense

In [None]:
# Re-solve with the optimized matrix; if training converged this should be close to b.
x_learned = A_csr.solve_triangular(upper=(not lower), unit=False, b=b)
x_learned