# Linear solvers in the TT format

This tutorial addresses solving multilinear systems $\mathsf{Ax}=\mathsf{b}$ in the TT format.

Imports:

In [None]:
import torch as tn
import datetime
try: 
    import torchtt as tntt
except:
    print('Installing torchTT...')
    %pip install git+https://github.com/ion-g-ion/torchTT
    import torchtt as tntt

### Small example

A random tensor operator $\mathsf{A}$ is created in the TT format and replaced by $\mathsf{A}\mathsf{A}^T + 10\,\mathsf{I}$, which makes it symmetric positive definite. We then create a random tensor $\mathsf{x}$ in the TT format and the right-hand side $\mathsf{b} = \mathsf{Ax}$. This way the solution of $\mathsf{Ax}=\mathsf{b}$ is known and can be used as a reference.
This approach is only practical for small random tensors.

In [None]:
# Random TT operator on modes 4 x 5 x 6. A A^T + 10 I is symmetric positive
# definite for a real A. Then a random reference solution x and the matching
# right-hand side b = A x are built.
modes = [4, 5, 6]
A = tntt.random([(m, m) for m in modes], [1, 2, 3, 1])
A = A @ A.t() + 10 * tntt.eye(modes)
x = tntt.random(modes, [1, 2, 3, 1])
b = A @ x

Solve the multilinear system $\mathsf{Ax}=\mathsf{b}$ using `torchtt.solvers.amen_solve()`, with the right-hand side as the initial guess.


In [None]:
xs = tntt.solvers.amen_solve(
    A,
    b,
    x0=b,      # start the iteration from the right-hand side
    eps=1e-7,
)

The relative residual norm and the relative error of the solution are reported:

In [None]:
# Show the computed TT solution, then compare it with the known reference x.
print(xs)
rel_residual = (A @ xs - b).norm() / b.norm()
print('Relative residual error ', rel_residual)
rel_error = (xs - x).norm() / x.norm()
print('Relative error of the solution  ', rel_error)

### Finite differences

We now solve the problem $\Delta u = 1$ in $[0,1]^d$ with $u = 0$ on the entire boundary using finite differences.
First, set the size of the problem (n is the mode size and d is the number of dimensions):

In [None]:
# Double precision, mode size n per dimension, d dimensions.
dtype, n, d = tn.float64, 256, 8

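Create the finite difference matrix corresponding to the problem. The operator is constructed directly in the TT format as follows. First, the 1D second-difference matrix `L1d` is built. Then the $d$-dimensional operator is assembled as the sum of the Kronecker products $I \otimes \cdots \otimes L_{1d} \otimes \cdots \otimes I$, with $L_{1d}$ placed in each of the $d$ positions.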

In [None]:
# 1D second-difference matrix on n equispaced nodes of [0, 1], i.e. grid
# spacing h = 1 / (n - 1). The off-diagonal entries of the first and last rows
# are removed, so those rows act only on the boundary values themselves.
L1d = -2 * tn.eye(n, dtype=dtype) + tn.diag(tn.ones(n - 1, dtype=dtype), -1) + tn.diag(tn.ones(n - 1, dtype=dtype), 1)
L1d[0, 1] = 0
L1d[-1, -2] = 0
# Scale by 1 / h^2 = (n - 1)^2 so that the stencil approximates u''.
# (Dividing by (n - 1) would approximate h^3 * u'' instead of the Laplacian.)
L1d *= (n - 1) ** 2
L1d = tntt.TT(L1d, [(n, n)])

# Assemble the d-dimensional operator as the sum over k of
#   eye(n^k) ** L1d ** eye(n^(d-1-k))
# (`**` joins TT objects mode by mode, as a Kronecker product). The interior
# positions k = 1, ..., d-2 are accumulated first, starting from the zero
# operator. The two end positions (k = 0 and k = d-1) are added afterwards.
L_tt = tntt.zeros([(n, n)] * d)
for k in range(1, d - 1):
    leading_id = tntt.eye([n] * k)
    trailing_id = tntt.eye([n] * (d - 1 - k))
    L_tt = L_tt + leading_id ** (L1d ** trailing_id)

first_term = L1d ** tntt.eye([n] * (d - 1))
last_term = tntt.eye([n] * (d - 1)) ** L1d
L_tt = L_tt + first_term + last_term
# Recompress the TT representation with tolerance 1e-14.
L_tt = L_tt.round(1e-14)

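The right-hand side of the finite difference system is also computed in the TT format. It is the $d$-fold Kronecker product of a 1D vector that equals 1 at interior nodes and 0 at the two boundary nodes.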

In [None]:
# 1D right-hand side: ones in the interior, zeros at both boundary nodes.
rhs_1d = tn.ones(n, dtype=dtype)
rhs_1d[0] = rhs_1d[-1] = 0
b1d = tntt.TT(rhs_1d)
# d-fold Kronecker power of b1d, accumulated from left to right.
b_tt = b1d
for _ in range(d - 1):
    b_tt = b_tt ** b1d

Solve the system with AMEn (`nswp=20`, `eps=1e-8`), using the right-hand side as the initial guess:

In [None]:
t_start = datetime.datetime.now()
x = tntt.solvers.amen_solve(
    L_tt, b_tt,
    x0=b_tt, nswp=20, eps=1e-8, verbose=True,
)
t_solve = datetime.datetime.now() - t_start
rel_res = (L_tt @ x - b_tt).norm() / b_tt.norm()
print('Relative residual: ', rel_res)
print('Solver time: ', t_solve)

Display the TT structure of the computed solution:

In [None]:
# Show the TT representation of the computed solution.
print(str(x))

Try one more time on the GPU (if available).

In [None]:
# Repeat the solve on the GPU, if one is available. The operator and the
# right-hand side are copied to the device before the timer starts, so only the
# solver is timed (the CPU runs above do not include any data transfer either).
if tn.cuda.is_available():
    cuda_dev = 'cuda:0'
    L_gpu = L_tt.to(cuda_dev)
    b_gpu = b_tt.to(cuda_dev)
    # Make sure the copies have finished before starting the clock.
    tn.cuda.synchronize(tn.device(cuda_dev))
    time = datetime.datetime.now()
    x = tntt.solvers.amen_solve(L_gpu, b_gpu, x0=b_gpu, nswp=20, eps=1e-8, verbose=True, preconditioner=None)
    # CUDA kernels run asynchronously: wait for them before reading the clock.
    tn.cuda.synchronize(tn.device(cuda_dev))
    time = datetime.datetime.now() - time
    x = x.cpu()
    print('Relative residual: ', (L_tt @ x - b_tt).norm() / b_tt.norm())
    print('Solver time: ', time)
else:
    print('GPU not available...')

The banded structure of the matrix can be exploited by passing the per-core bandwidths through the `bandsMatrices` argument.

Non-banded solver without preconditioning

In [None]:
t_start = datetime.datetime.now()
x = tntt.solvers.amen_solve(
    L_tt,
    b_tt,
    x0=b_tt,
    nswp=20,
    eps=1e-8,
    verbose=False,
    preconditioner=None,
)
t_solve = datetime.datetime.now() - t_start
rel_res = (L_tt @ x - b_tt).norm() / b_tt.norm()
print('Relative residual: ', rel_res)
print('Solver time: ', t_solve)

Banded solver without preconditioning

In [None]:
# Bandwidth 1 for every TT core, matching the tridiagonal 1D stencil L1d
# from which L_tt is assembled.
bands_L_tt = [1 for _core in L_tt.cores]

In [None]:
t_start = datetime.datetime.now()
x = tntt.solvers.amen_solve(
    L_tt,
    b_tt,
    x0=b_tt,
    nswp=20,
    eps=1e-8,
    verbose=False,
    bandsMatrices=bands_L_tt,
    preconditioner=None,
)
t_solve = datetime.datetime.now() - t_start
rel_res = (L_tt @ x - b_tt).norm() / b_tt.norm()
print('Relative residual: ', rel_res)
print('Solver time: ', t_solve)

Non-banded solver with central preconditioning

In [None]:
t_start = datetime.datetime.now()
x = tntt.solvers.amen_solve(
    L_tt,
    b_tt,
    x0=b_tt,
    nswp=20,
    eps=1e-8,
    verbose=False,
    preconditioner='c',
)
t_solve = datetime.datetime.now() - t_start
rel_res = (L_tt @ x - b_tt).norm() / b_tt.norm()
print('Relative residual: ', rel_res)
print('Solver time: ', t_solve)

Banded solver with central preconditioning

In [None]:
t_start = datetime.datetime.now()
x = tntt.solvers.amen_solve(
    L_tt,
    b_tt,
    x0=b_tt,
    nswp=20,
    eps=1e-8,
    verbose=False,
    bandsMatrices=bands_L_tt,
    preconditioner='c',
)
t_solve = datetime.datetime.now() - t_start
rel_res = (L_tt @ x - b_tt).norm() / b_tt.norm()
print('Relative residual: ', rel_res)
print('Solver time: ', t_solve)

### Sum of matrices

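Solving the system $(A + B + C + D)x = b$, where each of $A, B, C, D$ has the form $GG^T + 100\,I$ for a random TT-matrix $G$ (the separately defined `I` is not used in this section).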

In [None]:
# Mode sizes, the matching square (row, column) modes of the TT-matrices,
# and the TT ranks.
N = [32, 16, 16]
MN = [(size, size) for size in N]
R = [1, 5, 5, 1]

In [None]:
def _shifted_gram(mn, ranks, shape):
    """Return G @ G.t() + 100 * eye(shape) for a fresh random float64 TT-matrix G."""
    g = tntt.random(mn, ranks, dtype=tn.float64)
    return g @ g.t() + 100 * tntt.eye(shape)


# The four summands are drawn one after another, in the order A, B, C, D.
A, B, C, D = (_shifted_gram(MN, R, N) for _ in range(4))
# NOTE: I is defined here, but none of the solves in this section use it.
I = 100 * tntt.eye(N)
b = tntt.random(N, R)

Solving with the assembled full matrix

In [None]:
t_start = datetime.datetime.now()
# The assembly happens inside the timed region, so it is part of the reported time.
Mat = sum([B, C, D], A).round(1e-8)
x = tntt.solvers.amen_solve(Mat, b, eps=1e-8, verbose=False,
                            local_iterations=40, resets=2)
t_solve = datetime.datetime.now() - t_start
print('Relative residual: ', (Mat @ x - b).norm() / b.norm())
print('Solver time: ', t_solve)

Solving with the list of summands

In [None]:
t_start = datetime.datetime.now()
summands = [A, B, C, D]
x = tntt.solvers.amen_solve(summands, b, eps=1e-8, verbose=False,
                            local_iterations=40, resets=2)
t_solve = datetime.datetime.now() - t_start
# The residual is checked against the assembled matrix Mat built earlier.
print('Relative residual: ', (Mat @ x - b).norm() / b.norm())
print('Solver time: ', t_solve)

Solving with the assembled full matrix and the central preconditioner

In [None]:
t_start = datetime.datetime.now()
# The assembly happens inside the timed region, so it is part of the reported time.
Mat = sum([B, C, D], A).round(1e-8)
x = tntt.solvers.amen_solve(Mat, b, eps=1e-8, verbose=False,
                            local_iterations=40, resets=2, preconditioner='c')
t_solve = datetime.datetime.now() - t_start
print('Relative residual: ', (Mat @ x - b).norm() / b.norm())
print('Solver time: ', t_solve)

Solving with the list of summands and the central preconditioner

In [None]:
t_start = datetime.datetime.now()
summands = [A, B, C, D]
x = tntt.solvers.amen_solve(summands, b, eps=1e-8, verbose=False,
                            local_iterations=40, resets=2, preconditioner='c')
t_solve = datetime.datetime.now() - t_start
# The residual is checked against the assembled matrix Mat built earlier.
print('Relative residual: ', (Mat @ x - b).norm() / b.norm())
print('Solver time: ', t_solve)

Solving with the assembled full matrix and the right preconditioner

In [None]:
t_start = datetime.datetime.now()
# The assembly happens inside the timed region, so it is part of the reported time.
Mat = sum([B, C, D], A).round(1e-8)
x = tntt.solvers.amen_solve(Mat, b, eps=1e-8, verbose=False,
                            local_iterations=40, resets=2, preconditioner='r')
t_solve = datetime.datetime.now() - t_start
print('Relative residual: ', (Mat @ x - b).norm() / b.norm())
print('Solver time: ', t_solve)

Solving with the list of summands and the right preconditioner

In [None]:
t_start = datetime.datetime.now()
summands = [A, B, C, D]
x = tntt.solvers.amen_solve(summands, b, eps=1e-8, verbose=False,
                            local_iterations=40, resets=2, preconditioner='r')
t_solve = datetime.datetime.now() - t_start
# The residual is checked against the assembled matrix Mat built earlier.
print('Relative residual: ', (Mat @ x - b).norm() / b.norm())
print('Solver time: ', t_solve)

The banded structure can be taken into account via `bandsMatrices`, which takes one list of per-core bandwidths for each summand:

In [None]:
t_start = datetime.datetime.now()
# One list of per-core bands per summand: [31, 15, 15] (mode size - 1 for each
# core) for A, and -1 for B, C and D. -1 presumably means "no band information";
# TODO confirm in the amen_solve documentation.
bands = [[31, 15, 15]] + [[-1, -1, -1] for _ in range(3)]
x = tntt.solvers.amen_solve([A, B, C, D], b, eps=1e-8, verbose=False,
                            local_iterations=40, resets=2, preconditioner='c',
                            bandsMatrices=bands)
t_solve = datetime.datetime.now() - t_start
print('Relative residual: ', (Mat @ x - b).norm() / b.norm())
print('Solver time: ', t_solve)

### Sum of matrix products

Solving the system $(AA^T + BB^T + CC^T + DD^T + I)x = b$, where $I$ denotes 100 times the identity.

In [None]:
# Mode sizes, the matching square (row, column) modes of the TT-matrices,
# and the TT ranks.
N = [32, 32, 16]
MN = [(size, size) for size in N]
R = [1, 5, 5, 1]

In [None]:
# Four independent random float64 TT-matrices, drawn in the order A, B, C, D,
# followed by the scaled identity and a random right-hand side.
A, B, C, D = [tntt.random(MN, R, dtype=tn.float64) for _ in range(4)]
I = 100 * tntt.eye(N)
b = tntt.random(N, R)

Solving with the assembled full matrix

In [None]:
t_start = datetime.datetime.now()
# Assemble A A^T + B B^T + C C^T + D D^T + I inside the timed region.
Mat = (sum((M @ M.t() for M in (B, C, D)), A @ A.t()) + I).round(1e-8)
x = tntt.solvers.amen_solve([(Mat,)], b, eps=1e-8, verbose=False,
                            local_iterations=100, resets=2)
t_solve = datetime.datetime.now() - t_start
print('Relative residual: ', (Mat @ x - b).norm() / b.norm())
print('Solver time: ', t_solve)

Solving with the list of products

In [None]:
t_start = datetime.datetime.now()
# One (M, M^T) factor pair per random matrix, plus the scaled identity alone.
products = [(M, M.t()) for M in (A, B, C, D)] + [(I,)]
x = tntt.solvers.amen_solve(products, b, eps=1e-8, verbose=False,
                            local_iterations=100, resets=2)
t_solve = datetime.datetime.now() - t_start
print('Relative residual: ', (Mat @ x - b).norm() / b.norm())
print('Solver time: ', t_solve)

Solving with the assembled full matrix and the central preconditioner

In [None]:
t_start = datetime.datetime.now()
# Assemble A A^T + B B^T + C C^T + D D^T + I inside the timed region.
Mat = (sum((M @ M.t() for M in (B, C, D)), A @ A.t()) + I).round(1e-8)
x = tntt.solvers.amen_solve([(Mat,)], b, eps=1e-8, verbose=False,
                            local_iterations=100, resets=2, preconditioner='c')
t_solve = datetime.datetime.now() - t_start
print('Relative residual: ', (Mat @ x - b).norm() / b.norm())
print('Solver time: ', t_solve)

Solving with the list of products and the central preconditioner

In [None]:
t_start = datetime.datetime.now()
# One (M, M^T) factor pair per random matrix, plus the scaled identity alone.
products = [(M, M.t()) for M in (A, B, C, D)] + [(I,)]
x = tntt.solvers.amen_solve(products, b, eps=1e-8, verbose=False,
                            local_iterations=100, resets=2, preconditioner='c')
t_solve = datetime.datetime.now() - t_start
print('Relative residual: ', (Mat @ x - b).norm() / b.norm())
print('Solver time: ', t_solve)

Solving with the assembled full matrix and the right preconditioner

In [None]:
t_start = datetime.datetime.now()
# Assemble A A^T + B B^T + C C^T + D D^T + I inside the timed region.
Mat = (sum((M @ M.t() for M in (B, C, D)), A @ A.t()) + I).round(1e-8)
x = tntt.solvers.amen_solve([(Mat,)], b, eps=1e-8, verbose=False,
                            local_iterations=100, resets=2, preconditioner='r')
t_solve = datetime.datetime.now() - t_start
print('Relative residual: ', (Mat @ x - b).norm() / b.norm())
print('Solver time: ', t_solve)

Solving with the list of products and the right preconditioner

In [None]:
t_start = datetime.datetime.now()
# One (M, M^T) factor pair per random matrix, plus the scaled identity alone.
products = [(M, M.t()) for M in (A, B, C, D)] + [(I,)]
x = tntt.solvers.amen_solve(products, b, eps=1e-8, verbose=False,
                            local_iterations=100, resets=2, preconditioner='r')
t_solve = datetime.datetime.now() - t_start
print('Relative residual: ', (Mat @ x - b).norm() / b.norm())
print('Solver time: ', t_solve)