# TT decomposition in torchtt

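This notebook shows how to work with tensors in the TT (tensor train) format using `torchtt`. It covers decomposing a full tensor, recovering the full tensor, slicing, and rounding the TT rank.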


Imports

In [None]:
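import torch as tn
try:
    import torchtt as tntt
except ImportError:
    # torchtt is not installed: install it with pip and re-run this cell
    raise ImportError("torchtt is required for this notebook; install it with pip first")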

### Decomposition of a full tensor in TT format

We first create a 4d `torch.tensor` of shape $32\times 16\times 8\times 10$, which is used in the following examples.

In [None]:
tens_full = tn.reshape(tn.arange(32*16*8*10, dtype = tn.float64),[32,16,8,10])

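The TT approximation of a $d$-dimensional tensor $\mathsf{x}$ of shape $n_1\times n_2\times\cdots\times n_d$ reads
$\mathsf{x}_{i_1i_2\ldots i_d} \approx \sum\limits_{r_1,\ldots,r_{d-1}=1}^{R_1,\ldots,R_{d-1}} \mathsf{g}^{(1)}_{1i_1r_1}\mathsf{g}^{(2)}_{r_1i_2r_2}\cdots\mathsf{g}^{(d)}_{r_{d-1}i_d1}$,
where the 3d tensors $\mathsf{g}^{(k)}$ of shape $R_{k-1}\times n_k\times R_k$ are the TT cores and $\mathbf{R}=(R_0,R_1,\ldots,R_d)$ with $R_0=R_d=1$ is the TT rank. Passing a full `torch.tensor` to the `torchtt.TT` constructor computes such a decomposition.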
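To make the formula concrete, here is a minimal NumPy sketch (not torchtt code) that evaluates a single entry of a TT tensor. It assumes each core $\mathsf{g}^{(k)}$ is stored as a 3d array of shape $(R_{k-1}, n_k, R_k)$. Fixing $i_k$ turns every core into an $R_{k-1}\times R_k$ matrix, so the sum over $r_1,\ldots,r_{d-1}$ is simply a product of matrices. The helper `tt_entry` is hypothetical.

```python
import numpy as np


def tt_entry(cores, index):
    """Evaluate one entry of a TT tensor (hypothetical helper, NumPy sketch).

    cores[k] is assumed to have shape (R_{k-1}, n_k, R_k) with R_0 = R_d = 1.
    """
    v = np.ones((1, 1))
    for g, i in zip(cores, index):
        # Fixing i_k selects the matrix g[:, i_k, :] of shape (R_{k-1}, R_k);
        # the matrix product carries out the sum over r_k.
        v = v @ g[:, i, :]
    return v[0, 0]  # the product is 1x1 because R_0 = R_d = 1


rng = np.random.default_rng(0)
ranks, shape = [1, 3, 4, 1], [5, 6, 7]
cores = [rng.standard_normal((ranks[k], shape[k], ranks[k + 1])) for k in range(3)]
# Reference: contract all cores at once with einsum, then index the result.
full = np.einsum("aib,bjc,ckd->ijk", *cores)
print(np.isclose(tt_entry(cores, (1, 2, 3)), full[1, 2, 3]))  # True
```

Evaluating one entry this way costs $d$ small matrix products and never forms the full tensor.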

In [None]:
tens_tt = tntt.TT(tens_full)
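The notebook does not state which algorithm the constructor uses. A standard way to decompose a full tensor is the TT-SVD algorithm: unfold the tensor into a matrix, take a truncated SVD, keep the left singular vectors as the next core, and repeat on the remainder. The NumPy sketch below implements TT-SVD with hypothetical helpers `tt_svd` and `tt_full` and the core layout $(R_{k-1}, n_k, R_k)$ assumed above; it is an illustration, not torchtt's implementation. Each SVD is truncated with the threshold $\epsilon\|\mathsf{x}\|_F/\sqrt{d-1}$, which keeps the total relative error below $\epsilon$.

```python
import numpy as np


def tt_svd(x, eps=1e-10):
    """TT-SVD sketch: returns cores of shape (R_{k-1}, n_k, R_k)."""
    shape = x.shape
    d = len(shape)
    # Per-step truncation threshold so that the total relative error is <= eps.
    delta = eps / np.sqrt(max(d - 1, 1)) * np.linalg.norm(x)
    cores, r_prev, c = [], 1, np.asarray(x)
    for k in range(d - 1):
        # Unfold the remainder into an (R_{k-1} n_k) x (n_{k+1} ... n_d) matrix.
        c = c.reshape(r_prev * shape[k], -1)
        u, s, vt = np.linalg.svd(c, full_matrices=False)
        # tail[r] = Frobenius norm of the discarded singular values s[r:].
        tail = np.sqrt(np.concatenate([np.cumsum(s[::-1] ** 2)[::-1], [0.0]]))
        r = max(1, int(np.argmax(tail <= delta)))  # smallest admissible rank
        cores.append(u[:, :r].reshape(r_prev, shape[k], r))
        c = s[:r, None] * vt[:r, :]  # remainder carried to the next step
        r_prev = r
    cores.append(c.reshape(r_prev, shape[-1], 1))
    return cores


def tt_full(cores):
    """Contract the cores back into a full array."""
    res = cores[0]
    for g in cores[1:]:
        res = np.tensordot(res, g, axes=([-1], [0]))
    return res[0, ..., 0]  # drop the boundary ranks R_0 = R_d = 1


x = np.arange(32 * 16 * 8 * 10, dtype=np.float64).reshape(32, 16, 8, 10)
cores = tt_svd(x, 1e-10)
print([g.shape[0] for g in cores] + [1])  # [1, 2, 2, 2, 1]
print(np.linalg.norm(x - tt_full(cores)) / np.linalg.norm(x))
```

For this particular tensor the TT rank is $(1,2,2,2,1)$: every entry equals $1280 i_1 + 80 i_2 + 10 i_3 + i_4$, a sum of terms that each depend on a single index, and every unfolding of such a tensor has rank at most 2.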

The original tensor can be recovered with the `torchtt.TT.full()` method. We also check that the reconstruction matches the original tensor by computing the relative error in the Frobenius norm:

In [None]:
tens_full_rec = tens_tt.full()
print(tn.linalg.norm(tens_full-tens_full_rec)/tn.linalg.norm(tens_full))

Printing a `torchtt.TT` instance displays information about it, such as its mode sizes and TT rank:

In [None]:
print(tens_tt)

Slicing can be performed on a tensor in the TT format. If every dimension is indexed with an integer and the multi-index is valid, a `torch.tensor` containing the corresponding entry is returned. Slices can also be used; in that case the returned object is again a `torchtt.TT` instance.

In [None]:
print(tens_tt[1,2,3,4])
print(tens_tt[1,1:4,2,:])
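Slicing in the TT format does not require the full tensor. The NumPy sketch below (hypothetical helper `tt_getitem`, same assumed core layout, not torchtt's code) indexes each core separately: an integer $i_k$ turns core $k$ into an $R_{k-1}\times R_k$ matrix that is multiplied into a neighbouring core, while a slice only restricts the middle mode of the core. If all indices are integers, the matrices multiply to a $1\times 1$ matrix, i.e. a single entry; otherwise the result is again a list of TT cores.

```python
import numpy as np


def tt_full(cores):
    """Contract TT cores of shape (R_{k-1}, n_k, R_k) into a full array."""
    res = cores[0]
    for g in cores[1:]:
        res = np.tensordot(res, g, axes=([-1], [0]))
    return res[0, ..., 0]


def tt_getitem(cores, index):
    """Index a TT tensor with a tuple of ints and slices (one per dimension)."""
    new_cores = []
    carry = np.ones((1, 1))  # product of integer-indexed cores not yet absorbed
    for g, i in zip(cores, index):
        if isinstance(i, (int, np.integer)):
            carry = carry @ g[:, i, :]  # an (R_{k-1}, R_k) matrix
        else:
            sub = g[:, i, :]  # keep the sliced mode
            new_cores.append(np.tensordot(carry, sub, axes=([1], [0])))
            carry = np.eye(sub.shape[2])
    if not new_cores:
        return carry[0, 0]  # all indices were integers: a scalar
    # Absorb trailing integer-indexed cores into the last kept core.
    new_cores[-1] = np.tensordot(new_cores[-1], carry, axes=([2], [0]))
    return new_cores


rng = np.random.default_rng(1)
ranks, shape = [1, 2, 3, 2, 1], [4, 5, 6, 7]
cores = [rng.standard_normal((ranks[k], shape[k], ranks[k + 1])) for k in range(4)]
full = tt_full(cores)
print(np.isclose(tt_getitem(cores, (1, 2, 3, 4)), full[1, 2, 3, 4]))  # True
sliced = tt_getitem(cores, (1, slice(1, 4), 2, slice(None)))
print(np.allclose(tt_full(sliced), full[1, 1:4, 2, :]))  # True
```

The sliced result keeps one core per sliced dimension, so its TT rank never exceeds that of the original tensor.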

### TT rank rounding

In some cases, for example after adding TT tensors, the TT rank becomes larger than necessary and should be reduced. The goal is to reduce the rank while keeping the error within a prescribed relative accuracy.
The rounding problem is stated as follows: given a tensor $\mathsf{x}$ in the TT format with TT rank $\mathbf{R}$ and $\epsilon>0$, find a tensor $\tilde{\mathsf{x}}$ with TT rank $\tilde{\mathbf{R}}\leq \mathbf{R}$ (componentwise) such that $\|\mathsf{x}-\tilde{\mathsf{x}}\|_F\leq \epsilon \|\mathsf{x}\|_F$.
This is implemented by the `torchtt.TT.round()` method. The accuracy $\epsilon$ is passed to the method, together with the optional argument `rmax`, which additionally bounds the TT rank of the result.
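A common way to solve this problem is the standard TT-rounding algorithm: first orthogonalize the cores from right to left with QR decompositions, then sweep from left to right and truncate each core with an SVD using the threshold $\epsilon\|\mathsf{x}\|_F/\sqrt{d-1}$ (keeping at most `rmax` singular values). The NumPy sketch below illustrates this. The helper `tt_round` and its parameters are hypothetical; they mirror the description above but are not the `torchtt.TT.round()` implementation.

```python
import numpy as np


def tt_full(cores):
    """Contract TT cores of shape (R_{k-1}, n_k, R_k) into a full array."""
    res = cores[0]
    for g in cores[1:]:
        res = np.tensordot(res, g, axes=([-1], [0]))
    return res[0, ..., 0]


def tt_round(cores, eps=1e-12, rmax=None):
    """TT-rounding sketch: returns new cores with a possibly smaller TT rank."""
    cores = [g.copy() for g in cores]
    d = len(cores)
    # 1) Right-to-left QR sweep: cores 1..d-1 become right-orthogonal,
    #    so the norm of the whole tensor equals the norm of cores[0].
    for k in range(d - 1, 0, -1):
        r0, n, r1 = cores[k].shape
        q, rfac = np.linalg.qr(cores[k].reshape(r0, n * r1).T)
        cores[k] = q.T.reshape(-1, n, r1)
        cores[k - 1] = np.tensordot(cores[k - 1], rfac.T, axes=([2], [0]))
    delta = eps / np.sqrt(max(d - 1, 1)) * np.linalg.norm(cores[0])
    # 2) Left-to-right SVD sweep with truncation.
    for k in range(d - 1):
        r0, n, r1 = cores[k].shape
        u, s, vt = np.linalg.svd(cores[k].reshape(r0 * n, r1), full_matrices=False)
        tail = np.sqrt(np.concatenate([np.cumsum(s[::-1] ** 2)[::-1], [0.0]]))
        rank = max(1, int(np.argmax(tail <= delta)))  # smallest rank within delta
        if rmax is not None:
            rank = min(rank, rmax)
        cores[k] = u[:, :rank].reshape(r0, n, rank)
        # Push the remaining factor S V^T into the next core.
        cores[k + 1] = np.tensordot(s[:rank, None] * vt[:rank], cores[k + 1], axes=([1], [0]))
    return cores


rng = np.random.default_rng(2)
# A TT tensor whose first rank (4) is larger than necessary (2):
# the last two slices of the first core along its rank mode are zero.
g1 = np.concatenate([rng.standard_normal((1, 5, 2)), np.zeros((1, 5, 2))], axis=2)
cores = [g1, rng.standard_normal((4, 6, 3)), rng.standard_normal((3, 7, 1))]
rounded = tt_round(cores, eps=1e-12)
print([g.shape[2] for g in rounded])  # [2, 3, 1]
```

The QR sweep is what makes the per-core SVD truncation errors equal to errors in the whole tensor, which is why a single left-to-right pass suffices.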

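We create a tensor of TT rank $(1,6,6,6,1)$ as a weighted sum of three normalized random TT tensors of rank $(1,2,2,2,1)$, with weights $1$, $10^{-3}$ and $10^{-6}$: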

In [None]:
t1 = tntt.randn([10,20,30,40],[1,2,2,2,1])
t2 = tntt.randn([10,20,30,40],[1,2,2,2,1])
t3 = tntt.randn([10,20,30,40],[1,2,2,2,1])
t1, t2, t3 = t1/t1.norm(), t2/t2.norm(), t3/t3.norm()
t = t1+1e-3*t2+1e-6*t3
t_full = t.full()
print(t)
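The rank $(1,6,6,6,1)$ follows from how sums are formed in the TT format: the cores of $\mathsf{a}+\mathsf{b}$ are obtained by stacking the cores of the summands block-diagonally, so the TT ranks add ($2+2+2=6$ here), while multiplying by a scalar only rescales one core. The NumPy sketch below (hypothetical helpers `tt_add`, `tt_scale` and `tt_randn`, same assumed core layout) reproduces this construction. It illustrates the standard formula and does not claim to show how torchtt implements `+`; unlike the notebook, the summands are not normalized, which does not affect the ranks.

```python
import numpy as np


def tt_full(cores):
    """Contract TT cores of shape (R_{k-1}, n_k, R_k) into a full array."""
    res = cores[0]
    for g in cores[1:]:
        res = np.tensordot(res, g, axes=([-1], [0]))
    return res[0, ..., 0]


def tt_scale(cores, alpha):
    """Multiply a TT tensor by a scalar by rescaling its first core."""
    return [alpha * cores[0]] + [g.copy() for g in cores[1:]]


def tt_add(a, b):
    """Sum of two TT tensors with equal shapes; the TT ranks add."""
    d = len(a)
    if d == 1:
        return [a[0] + b[0]]
    out = []
    for k, (ga, gb) in enumerate(zip(a, b)):
        if k == 0:
            out.append(np.concatenate([ga, gb], axis=2))  # row block [G_a G_b]
        elif k == d - 1:
            out.append(np.concatenate([ga, gb], axis=0))  # column block
        else:
            (ra0, n, ra1), (rb0, _, rb1) = ga.shape, gb.shape
            g = np.zeros((ra0 + rb0, n, ra1 + rb1), dtype=np.result_type(ga, gb))
            g[:ra0, :, :ra1] = ga  # block-diagonal middle core
            g[ra0:, :, ra1:] = gb
            out.append(g)
    return out


def tt_randn(shape, ranks, rng):
    """Random TT tensor with the given mode sizes and TT rank."""
    return [rng.standard_normal((ranks[k], shape[k], ranks[k + 1])) for k in range(len(shape))]


rng = np.random.default_rng(3)
shape, ranks = [10, 20, 30, 40], [1, 2, 2, 2, 1]
t1, t2, t3 = (tt_randn(shape, ranks, rng) for _ in range(3))
t = tt_add(tt_add(t1, tt_scale(t2, 1e-3)), tt_scale(t3, 1e-6))
print([g.shape[0] for g in t] + [1])  # [1, 6, 6, 6, 1]
expected = tt_full(t1) + 1e-3 * tt_full(t2) + 1e-6 * tt_full(t3)
print(np.allclose(tt_full(t), expected))  # True
```

Because the ranks simply add, repeated additions quickly inflate the TT rank, which is the situation the rounding operation described above is meant to address.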