# TT-decomposition

[tt_hse16_slides](https://bayesgroup.github.io/team/arodomanov/tt_hse16_slides.pdf)

[Tensorising Neural Networks](https://arxiv.org/pdf/1509.06569.pdf)

Unfolding matrices of a tensor $
    A \in \mathbb{R}^{n_0\times \ldots \times n_{d-1}}
$

$$
A_k = \bigl(A_{i_{:k}, i_{k:}}\bigr)_{i \in \prod_{j=0}^{d-1} [n_j]}
    \in \mathbb{R}^{
        [n_0 \times \ldots \times n_{k-1}]
        \times [n_k \times \ldots \times n_{d-1}]
    }
    \,. $$

where $i_{:k} = (i_j)_{j=0}^{k-1}$ and $i_{k:} = (i_j)_{j=k}^{d-1}$ -- zero-based slicing like in numpy.

TT-format:

$$
A_{i} = \sum_{\alpha}
    \prod_{j=0}^{d-1} G^{(j)}_{\alpha_j i_j \alpha_{j+1}}
    \,, $$

where $
    G^{(j)}_{\cdot i_j \cdot} \in \mathbb{R}^{r_j \times r_{j+1}}
$ and $r_0 = r_d = 1$. The rank of the TT-decomposition is $r = \max_{j=0}^d r_j$.

## Tensors

In [None]:
import numpy as np

import torch
import torch.nn.functional as F

%matplotlib inline
import matplotlib.pyplot as plt

Import Tensor-Train converters

In [None]:
from ttmodule import tensor_to_tt, tt_to_tensor

from ttmodule import matrix_to_tt, tt_to_matrix

A simple, run-of-the-mill training loop.
* imports from [`cplxmodule`](https://github.com/ivannz/cplxmodule.git)

In [None]:
import tqdm
from cplxmodule.relevance import penalties
from cplxmodule.utils.stats import sparsity

def train_model(X, y, model, n_steps=20000, threshold=1.0,
                klw=1e-3, verbose=False):
    """Fit `model` to `(X, y)` with full-batch Adamax on a penalized MSE.

    The objective at every step is `mse_loss(model(X), y)` plus `klw` times
    the sum of the terms yielded by `penalties(model)`.

    Parameters
    ----------
    X, y : inputs and targets fed to the model / the MSE loss as is.
    model : module to train; it must expose a `.weight` attribute, which
        is snapshotted after every optimizer step.
    n_steps : number of optimizer steps.
    threshold : accepted for signature parity with `test_model`; not used
        anywhere in this function.
    klw : multiplier of the summed penalty in the objective.
    verbose : show the tqdm progress bar when true.

    Returns
    -------
    The model switched to eval mode, the list of per-step objective values
    (python floats), and the list of per-step `.weight` clones.
    """
    model.train()
    optimizer = torch.optim.Adamax(model.parameters(), lr=2e-3)

    loss_history, weight_history = [], []
    with tqdm.tqdm(range(n_steps), disable=not verbose) as progress:
        for _ in progress:
            optimizer.zero_grad()

            fit_term = F.mse_loss(model(X), y)
            kl_term = sum(penalties(model))

            objective = fit_term + klw * kl_term
            objective.backward()
            optimizer.step()

            # the recorded value is the objective from *before* this step
            loss_history.append(float(objective))
            progress.set_postfix_str(
                f"{float(fit_term):.3e} {float(kl_term):.3e}")

            # the snapshot is taken *after* the step, outside of autograd
            with torch.no_grad():
                weight_history.append(model.weight.clone())

    return model.eval(), loss_history, weight_history

def test_model(X, y, model, threshold=1.0):
    """Print a one-line evaluation summary of `model` on `(X, y)`.

    The line holds, in this order: the value returned by
    `sparsity(model, threshold=threshold, hard=True)` formatted as a
    percentage, the MSE of `model(X)` against `y`, and the summed
    `penalties(model)`. The model is put into eval mode and returned.
    """
    model.eval()
    with torch.no_grad():
        fit_term = F.mse_loss(model(X), y)
        kl_term = sum(penalties(model))

    sparse_share = sparsity(model, threshold=threshold, hard=True)
    summary = (
        f"{sparse_share:.1%} {fit_term.item():.3e} {float(kl_term):.3e}"
    )
    print(summary)
    return model

<br>

In [None]:
from ttmodule import TTLinear

from torch.nn import Linear
from cplxmodule.relevance import LinearARD
from cplxmodule.relevance import LinearL0ARD

Specify the problem and device

In [None]:
# `threshold` is forwarded to `train_model` / `test_model` further down;
# `device_` is where the dataset tensors are moved to.
threshold, device_ = 3.0, "cpu"

Create a simple dataset: $
    (x_i, y_i)_{i=1}^n \in \mathbb{R}^{d}\times\mathbb{R}^{p}
$ and $y_i = -E_{:p} x_i$ with $E_{:p} = (e_j)_{j=1}^p$ the diagonal
projection matrix onto the first $p$ dimensions (the code below flips the
sign of the targets). We put $n\leq d$.

In [None]:
import torch.utils.data

# input and output dimensionality of the toy regression problem
n_features, n_output = 250, 50

# 10200 standard normal samples; the targets are the negated first
# `n_output` input coordinates
X = torch.randn(10200, n_features)
y = -X[:, :n_output].clone()

dataset = torch.utils.data.TensorDataset(X.to(device_), y.to(device_))

# the first 200 samples are used for fitting, the remaining 10000 for
# evaluation; each slice is later unpacked as `(X, y)` via `*train` / `*test`
train, test = dataset[:200], dataset[200:]

## A TT-linear layer

In [None]:
models = {}

A useful way of thinking about the TT-format of tensors is the following.
If we assume the lexicographic order of index traversal of the tensor $A$
(`C`-order, or row-major) then

$$
A_\mathbf{i}
    = \prod_{k=1}^d G^{(k)}_{i_k}
    = \sum_\mathbf{\alpha}
        \prod_{k=1}^d e_{\alpha_{k-1}}^\top G^{(k)}_{i_k} e_{\alpha_k}
    = \sum_\mathbf{\alpha}
        \prod_{k=1}^d g^k_{\alpha_{k-1} i_k \alpha_k}
    \,, \\
\mathop{vec} A
    = \sum_\mathbf{\alpha}
        g^1_{\alpha_{0} \alpha_1}
        \otimes g^2_{\alpha_{1} \alpha_2}
        \otimes \cdots
        \otimes g^d_{\alpha_{d-1} \alpha_d}
    \,, $$

with $
    \mathbf{i} = (i_k)_{k=1}^d
$ running over $
    [n_1\times \ldots \times n_d]
$,
$\alpha$ running over $\prod_{k=0}^d [r_k]$, $\otimes$ being the Kronecker product
and `vec` taken in the lexicographic (row-major) order. The cores are $
    G^{(k)}_{i_k} \in \mathbb{R}^{r_{k-1} \times r_k}
$
for $i_k \in [n_k]$, and their `vec`-versions are $
    g^k_{\alpha_{k-1} \alpha_k} \in \mathbb{R}^{n_k}
$
for $\alpha_{k-1} \in [r_{k-1}]$ and $\alpha_k \in [r_k]$.


In the case of a matrix TT-decomposition with shapes $(n_k)_{k=1}^d$
and $(m_k)_{k=1}^d$ we have:

$$
A = \sum_\mathbf{\alpha}
    B^1_{\alpha_{0} \alpha_1}
    \otimes \cdots
    \otimes B^d_{\alpha_{d-1} \alpha_d}
    \,, $$

with $
    B^k_{\alpha_{k-1} \alpha_k} \in \mathbb{R}^{n_k\times m_k}
$ and
$
    B^k_{\alpha_{k-1} \alpha_k p q} = G^{(k)}_{\alpha_{k-1} [p q] \alpha_k}
$, since each $i_k = [p q]$ is in fact a flattened index of the row-major
flattened dimension $n_k\times m_k$.

The matrix dimension factorization determines the block hierarchy of
the matrix and thus is crucial to the properties and success of a linear
layer with the weight in TT-format. If the linear layer is upstream,
i.e. close to the inputs of the network, then the factorization and
the induced hierarchy have semantic ties to the input features. In the
mid-stream layers any particular hierarchy has less rationale, albeit
it seems that the general-to-particular dimension factorization order
is still preferable.

#### Detailed deep factorization

In [None]:
# Four-factor split: 5*5*5*2 = 250 = n_features and 5*5*2*1 = 50 = n_output
# (presumably the first list factors the inputs and the second the outputs --
# confirm against the TTLinear signature). "lo"/"hi" differ only in `rank`.
models["detailed-deep-lo"] = TTLinear(
    [5, 5, 5, 2], [5, 5, 2, 1], rank=1, bias=False, reassemble=True)

models["detailed-deep-hi"] = TTLinear(
    [5, 5, 5, 2], [5, 5, 2, 1], rank=5, bias=False, reassemble=True)

#### Detailed shallow factorization

In [None]:
# Two-factor split: 25*10 = 250 and 25*2 = 50; "lo"/"hi" differ only in `rank`.
models["detailed-shallow-lo"] = TTLinear(
    [25, 10], [25, 2], rank=1, bias=False, reassemble=True)

models["detailed-shallow-hi"] = TTLinear(
    [25, 10], [25, 2], rank=5, bias=False, reassemble=True)

In [None]:
# models["detailed-lo"] = TTLinear(
#     [25, 5, 2], [5, 5, 2], rank=1, bias=False, reassemble=True)

# models["detailed-lo"] = TTLinear(
#     [5, 5, 5, 1, 2], [5, 5, 2, 1, 1], rank=3, bias=False, reassemble=True)

In [None]:
# Three-factor split: 25*5*2 = 250 and 5*5*2 = 50, rank 1. The commented-out
# line inside the call is an alternative factorization kept for reference.
models["dotted"] = TTLinear(
#     [25, 10, 1], [5, 5, 2], rank=1, bias=False, reassemble=True)
    [25, 5, 2], [5, 5, 2], rank=1, bias=False, reassemble=True)

#### Coarse deep factorization

This one, with the inverted hierarchy, fails.

In [None]:
# Same factors as the "detailed-deep" pair but listed in reverse order:
# 2*5*5*5 = 250 and 1*2*5*5 = 50; "lo"/"hi" differ only in `rank`.
models["coarse-deep-lo"] = TTLinear(
    [2, 5, 5, 5], [1, 2, 5, 5], rank=1, bias=False, reassemble=True)

models["coarse-deep-hi"] = TTLinear(
    [2, 5, 5, 5], [1, 2, 5, 5], rank=5, bias=False, reassemble=True)

#### Coarse shallow factorization

In [None]:
# Two-factor split: 10*25 = 250 and 5*10 = 50; "lo"/"hi" differ only in `rank`.
models["coarse-shallow-lo"] = TTLinear(
    [10, 25], [5, 10], rank=1, bias=False, reassemble=True)

models["coarse-shallow-hi"] = TTLinear(
    [10, 25], [5, 10], rank=5, bias=False, reassemble=True)

In [None]:
# Three-factor split: 5*25*2 = 250 and 5*5*2 = 50, rank 1.
models["striped"] = TTLinear(
    [5, 25, 2], [5, 5, 2], rank=1, bias=False, reassemble=True)

In [None]:
# model = LinearARD(n_features, n_output, bias=False)
# model = LinearL0ARD(n_features, n_output, bias=False, group=None)

In [None]:
# Four-factor split with unit factors: 5*25*1*2 = 250 and 5*5*2*1 = 50, rank 3.
models["blocked"] = TTLinear(
    [5, 25, 1, 2], [5, 5, 2, 1], rank=3, bias=False, reassemble=True)

Train

In [None]:
# Scratch candidate (10*25 = 250, 10*5 = 50). The "test" key is reassigned by
# later cells, so this one is superseded when the notebook runs top to bottom.
models["test"] = TTLinear(
    [10, 25], [10, 5], rank=1, bias=False, reassemble=True)

In [None]:
# Scratch candidate (5*5*10 = 250, 2*5*5 = 50); superseded by later
# assignments to the "test" key when the notebook runs top to bottom.
models["test"] = TTLinear(
    [5, 5, 10], [2, 5, 5], rank=1, bias=False, reassemble=True)

In [None]:
# Scratch candidate (25*10*1 = 250, 2*5*5 = 50); superseded by later
# assignments to the "test" key when the notebook runs top to bottom.
models["test"] = TTLinear(
    [25, 10, 1], [2, 5, 5], rank=1, bias=False, reassemble=True)

In [None]:
# Scratch candidate (25*10 = 250, 2*25 = 50); superseded by later
# assignments to the "test" key when the notebook runs top to bottom.
models["test"] = TTLinear(
    [25, 10], [2, 25], rank=1, bias=False, reassemble=True)

In [None]:
# Scratch candidate (2*5*25 = 250, 2*1*25 = 50); superseded by later
# assignments to the "test" key when the notebook runs top to bottom.
models["test"] = TTLinear(
    [2, 5, 25], [2, 1, 25], rank=1, bias=False, reassemble=True)

In [None]:
# Scratch candidate (2*5*25 = 250, 2*5*5 = 50) with rank 4; superseded by later
# assignments to the "test" key when the notebook runs top to bottom.
models["test"] = TTLinear(
    [2, 5, 25], [2, 5, 5], rank=4, bias=False, reassemble=True)

In [None]:
# Scratch candidate (5*1*25*2 = 250, 1*5*2*5 = 50) with rank 3; superseded by
# later assignments to the "test" key when the notebook runs top to bottom.
models["test"] = TTLinear(
    [5, 1, 25, 2], [1, 5, 2, 5], rank=3, bias=False, reassemble=True)

In [None]:
# Scratch candidate (25*5*2 = 250, 1*25*2 = 50); superseded by the next
# assignment to the "test" key when the notebook runs top to bottom.
models["test"] = TTLinear(
    [25, 5, 2], [1, 25, 2], rank=1, bias=False, reassemble=True)

In [None]:
# Last assignment to the "test" key, i.e. the configuration that the training
# call picks up when the notebook runs top to bottom (5*25*1*2 = 250 and
# 5*5*2*1 = 50). Same factors as "blocked", but with rank 2.
models["test"] = TTLinear(
    [5, 25, 1, 2], [5, 5, 2, 1], rank=2, bias=False, reassemble=True)

In [None]:
# Fit the current "test" model on the training slice for 2000 steps with a
# penalty weight of 1. The resulting `model`, `losses` and `weights` are
# notebook globals that the visualization cells read.
model, losses, weights = train_model(
    *train, models["test"], n_steps=2000,
    threshold=threshold, klw=1e0, verbose=True)

Test the model

In [None]:
test_model(*test, model, threshold=threshold)

In [None]:
# Show every core of the trained model as a grayscale image of absolute
# values, one figure per core with both axes hidden.
for core in model.cores:
    # keep index 0 along the first and the last axis, then transpose
    core_image = abs(core.detach()).numpy()[0, ..., 0].T
    plt.imshow(core_image, cmap=plt.cm.bone, interpolation=None)

    current_ax = plt.gca()
    current_ax.get_xaxis().set_visible(False)
    current_ax.get_yaxis().set_visible(False)
    plt.show()

<br>

## Simple visualization

... with not very simple setup

In [None]:
from matplotlib.gridspec import GridSpec


def canvas_setup(figsize, **kwargs):
    """Create the two-panel figure used by the weight visualizations.

    The figure is a 1x2 grid with width ratios 7:1. The wide panel shows the
    absolute values of the first snapshot in the notebook-level `weights`
    list, the narrow one shows the notebook-level `losses` on a log-scale
    y-axis. Extra keyword arguments are accepted but ignored.

    Returns the figure and the `(ax_main, ax_loss)` pair.
    """
    figure = plt.figure(figsize=figsize)
    grid = GridSpec(1, 2, figure=figure, width_ratios=[7, 1])
    weight_panel = figure.add_subplot(grid[0])
    loss_panel = figure.add_subplot(grid[1])

    with torch.no_grad():
        initial = abs(weights[0]).numpy()
        weight_panel.imshow(initial, cmap=plt.cm.bone)
        loss_panel.semilogy(losses)

    plt.tight_layout()
    return figure, (weight_panel, loss_panel)

In [None]:
def canvas_clear(*axes):
    """Wipe each axis while keeping its ticks, limits, z-order and alpha.

    For every axis the current properties are read, the axis is cleared, and
    the retained subset is written back through `update`. The axes are
    returned as the same tuple that was passed in.
    """
    # restored in this order: ticks first, then limits, then z-order and alpha
    retained = ("xticks", "yticks", "xlim", "ylim", "zorder", "alpha")
    for axis in axes:
        snapshot = axis.properties()
        axis.clear()

        restored = {}
        for key in retained:
            restored[key] = snapshot[key]
        axis.update(restored)
    return axes

In [None]:
def animate_weight(n_epoch, *axes):
    """Redraw both panels for the training snapshot with index `n_epoch`.

    Reads the notebook-level `weights` and `losses` lists. The axes are first
    wiped with `canvas_clear`, then the main panel gets the absolute values
    of `weights[n_epoch]` and an "it. N" title, and the loss panel gets the
    loss curve up to and including `n_epoch`, a marker and a vertical guide.

    Parameters
    ----------
    n_epoch : index into `weights` / `losses`.
    *axes : the `(ax_main, ax_loss)` pair, e.g. as returned by `canvas_setup`.

    Returns
    -------
    The list of created artists that expose `set_animated`.
    """
    ax_main, ax_loss = canvas_clear(*axes)

    artists = []
    with torch.no_grad():
        artists.append(ax_main.imshow(
            abs(weights[n_epoch]).numpy(),
            cmap=plt.cm.bone,
            interpolation=None
        ))
    artists.append(ax_main.set_title(f"it. {n_epoch}"))

    # `semilogy` returns a *list* of lines: extend rather than append, so that
    # the lines themselves (and not a list that the filter below would drop)
    # end up among the returned artists
    artists.extend(
        ax_loss.semilogy(losses[:n_epoch + 1], lw=2, color="fuchsia")
    )
    artists.append(
        ax_loss.scatter([n_epoch + 1], [losses[n_epoch]],
                        s=25, color="cyan")
    )
    artists.append(
        ax_loss.axvline(n_epoch + 1, c='cyan', lw=2, alpha=0.25, zorder=-10)
    )

    # keep only genuine artists
    return [
        artist_ for artist_ in artists
        if hasattr(artist_, "set_animated")
    ]

An interactive slider with ipywidgets

In [None]:
from ipywidgets import widgets

def int_slider(value, min, max, step):
    """Build an integer slider that only reports its value once released.

    The widget gets a flex layout that is at least 500px wide; `value`,
    `min`, `max` and `step` are passed to `widgets.IntSlider` unchanged.
    """
    wide_layout = widgets.Layout(min_width='500px', display='flex')
    return widgets.IntSlider(
        value=value,
        min=min,
        max=max,
        step=step,
        continuous_update=False,
        layout=wide_layout,
    )


In [None]:
def plot_weight(n_epoch=0):
    """Draw a fresh canvas showing the training snapshot `n_epoch`."""
    _, panels = canvas_setup(figsize=(16, 3))
    animate_weight(n_epoch, *panels)
    plt.show()


# Hook the plot up to a slider over all recorded snapshots (start 1000, step 10).
# The trailing semicolon keeps the notebook from echoing the returned object.
widgets.interact(
    plot_weight, n_epoch=int_slider(1000, 0, len(weights) - 1, 10));

<br>

In [None]:
import matplotlib.animation as animation

# Look up the file-based ffmpeg writer. Only a failed lookup is treated as
# "ffmpeg is unavailable": depending on the matplotlib version the registry
# raises `RuntimeError` or `KeyError` for a missing writer. Anything else
# (including KeyboardInterrupt) propagates instead of being swallowed.
try:
    FFMpegWriter = animation.writers['ffmpeg_file']

except (KeyError, RuntimeError):
    # placeholder so that the name exists; it adds nothing to the abstract base
    class PatchedFFMpegWriter(animation.AbstractMovieWriter):
        pass

else:
    class PatchedFFMpegWriter(FFMpegWriter):
        """File-based ffmpeg writer whose `setup` remembers earlier settings."""

        def setup(self, fig, outfile, *args, **kwargs):
            """Set up the writer, defaulting to values already stored on it.

            `dpi`, `frame_prefix` and `clear_temp` are taken from `kwargs`
            when given; otherwise from the `dpi`, `temp_prefix` and
            `clear_temp` attributes of the writer, if present; otherwise
            `None`, '_tmp' and `True` respectively. Extra positional
            arguments are accepted but not forwarded.

            NOTE(review): presumably this lets a repeated `setup` call made
            while saving keep the frame prefix chosen earlier -- confirm
            against the matplotlib version in use.
            """
            dpi = kwargs.get("dpi", getattr(self, "dpi", None))

            frame_prefix = kwargs.get(
                "frame_prefix", getattr(self, "temp_prefix", '_tmp'))

            clear_temp = kwargs.get(
                "clear_temp", getattr(self, "clear_temp", True))

            super().setup(fig, outfile, clear_temp=clear_temp,
                          frame_prefix=frame_prefix, dpi=dpi)

In [None]:
import os
import time
import tempfile

# timestamp that makes the output file name unique
dttm = time.strftime("%Y%m%d-%H%M%S")

fig, axes = canvas_setup(figsize=(16, 3))

fps, n_frames = 15, len(weights)

# frame schedule: every one of the first 25 snapshots, then every 10th
schedule = [
    *range(0, 25, 1)
] + [
    *range(25, n_frames, 10)
]

# the file name carries the model class, its `extra_repr()` and the timestamp
shape_tag = model.extra_repr()
outfile = os.path.join(".", f"weight-{model.__class__.__name__}{shape_tag}-{dttm}.mp4")

# dump the intermediate frames into a temporary dir
with tempfile.TemporaryDirectory() as tmp:
    print(f"temp dir at {tmp}", flush=True)

    # the writer is set up by hand first, so that the frame prefix points
    # into the temporary directory
    writer = PatchedFFMpegWriter(fps=fps, bitrate=-1, metadata={})
    writer.setup(fig, outfile, frame_prefix=os.path.join(tmp, f"_frame_"))

    # `animate_weight` is called with each scheduled snapshot index followed
    # by the two axes; the tqdm wrapper only adds a progress bar
    ani = animation.FuncAnimation(
        fig, animate_weight, tqdm.tqdm_notebook(schedule, unit="frm"),
        interval=1, repeat_delay=None, blit=False, fargs=axes)
    ani.save(outfile, writer=writer)

# close the figure so that the notebook does not also render it inline
plt.close()

In [None]:
from IPython.display import Video

print(outfile)
Video(data=outfile, embed=True, width=768)

In [None]:
# NOTE(review): unconditional failure -- presumably a deliberate stop so that
# "run all" halts before the experimental sections below; confirm.
assert False

## Matrix-vector product in TT-format

Suppose the TT representation of a matrix $W\in \mathbb{R}^{n\times m}$
with shapes $(n_k)_{k=1}^d$ and $(m_k)_{k=1}^d$ is given by $ \prod_{k=1}^d
G^{(k)}_{i_k j_k}$ with $
    G^{(k)}_{i_k j_k} \in \mathbb{R}^{r_{k-1}\times r_k}
$ with $r_0 = r_d = 1$. Then for index $
    \alpha \in \prod_{k=1}^{d-1} [r_k]
$ with $\alpha_0 = \alpha_d = 1$ we have:

$$
y_j = e_j^\top W^\top x
    = \sum_\alpha \sum_i 
          \prod_{k=1}^d g_{\alpha_{k-1} i_k j_k \alpha_k} x_i
    = \sum_{\alpha_0, \alpha_{1:}} \sum_{i_{2:}} 
          \prod_{k=2}^d g_{\alpha_{k-1} i_k j_k \alpha_k}
          \sum_{i_1} g_{\alpha_0 i_1 j_1 \alpha_1} x_{i_1 i_{2:}}
    = \sum_{\alpha_{1:}} \sum_{i_{2:}} 
          \prod_{k=2}^d g_{\alpha_{k-1} i_k j_k \alpha_k}
          \sum_{\alpha_0, i_1} g_{\alpha_0 i_1 j_1 \alpha_1} x_{i_1 i_{2:} \alpha_0}
    \,,\\
\dots
    = \sum_{\alpha_{1:}} \sum_{i_{2:}} 
          \prod_{k=2}^d g_{\alpha_{k-1} i_k j_k \alpha_k} z_{i_{2:} j_1 \alpha_1}
    = \sum_{\alpha_{2:}} \sum_{i_{3:}} 
          \prod_{k=3}^d g_{\alpha_{k-1} i_k j_k \alpha_k} z_{i_{3:} j_{:3} \alpha_2}
    \,. $$

<br>

## Tensor Rings

See [Tensor Ring Decomposition](https://arxiv.org/abs/1606.05535). Essentially the same idea but
with $r_0 = r_d \geq 1$. Tensors in TT-format are a special case of TR-format:

$$
A_\mathbf{i}
    = \mathop{Tr} \prod_{k=1}^d G^{(k)}_{i_k}
    = \sum_{\mathbf{\alpha}\colon \alpha_0=\alpha_d}
        \prod_{k=1}^d e_{\alpha_{k-1}}^\top G^{(k)}_{i_k} e_{\alpha_k}
    = \sum_{\mathbf{\alpha}\colon \alpha_0=\alpha_d}
        \prod_{k=1}^d g^k_{\alpha_{k-1} i_k \alpha_k}
    \,, $$

where $
    G^{(k)}_{i_k} \in \mathbb{R}^{r_{k-1} \times r_k}
$ and $r_0 = r_d$ and $
    \alpha \in \prod_{k=0}^d [r_k]
$.


In [None]:
# ranks = [2, 3, 4, 5, 5]
# shapes = [2, 3, 7, 4, 5], [3, 7, 7, 5, 2]

# Random tensor-ring cores in double precision. Core k has the shape
# (left rank, n_k, m_k, right rank); the left rank of the first core is the
# last entry of `ranks`, which closes the ring.
ranks = [3, 2, 1, 5]
shape = ([2, 3, 7, 5], [3, 7, 1, 2])

cores = [
    torch.randn(*core_shape, dtype=torch.double)
    for core_shape in zip(ranks[-1:] + ranks[:-1], *shape, ranks)
]

In [None]:
def tr_to_tensor_zero(*cores):
    """Assemble a tensor from ring cores by splitting off the first core.

    Every core but the first is contracted by `tt_to_tensor` with
    `squeeze=False`. The ring is then closed with a single `tensordot`,
    which pairs the first axis of the leading core with the last axis of
    that partial result, and the last axis of the leading core with its
    first axis.
    """
    remainder = tt_to_tensor(*cores[1:], squeeze=False)
    leading = cores[0]

    # tensordot was measured to beat reshape + einsum("i...j, j...i->...")
    return torch.tensordot(leading, remainder, dims=[[0, -1], [-1, 0]])

In [None]:
from ttmodule.tensor import tr_to_tensor

# Time the "split off the first core" baseline (25 repeats of 100 loops).
# The `-o` flag makes the IPython magic return its result object.
res = %timeit -o -n 100 -r 25 tr_to_tensor_zero(*cores)

# timing[0] is the baseline; timing[k + 1] is `tr_to_tensor` called with `k=k`
timing = [res]
for k in range(len(cores)):

    res = %timeit -o -n 100 -r 25 tr_to_tensor(*cores, k=k)
    timing.append(res)
    # report the rank and the core shape associated with this k
    print(f">>> ({k}) {ranks[k]} {cores[k].shape}")

$$
    W_{ij} = \mathop{tr}
        \prod_{k=1}^d G^{(k)}_{i_k j_k}
    \,, \\
    y_j = \sum_i W_{ij} x_i
        = \sum_i \mathop{tr} \prod_{k=1}^d G^{(k)}_{i_k j_k} x_i
        = \mathop{tr} \sum_i \prod_{k=1}^d G^{(k)}_{i_k j_k} x_i
        = \mathop{tr} \sum_{i_{1:}} \sum_{i_1} \prod_{k=1}^d G^{(k)}_{i_k j_k} x_i
    \,, \\
    y_j = \mathop{tr} \sum_{i_{1:}}
         \prod_{k=2}^d G^{(k)}_{i_k j_k} \sum_{i_1} G^{(1)}_{i_1 j_1} x_i
    \,. $$

In [None]:
from ttmodule.matrix import tr_to_matrix

# Dense matrix built from the ring cores; used further down as the ground
# truth for the matrix-vector checks.
# NOTE(review): `k=2` presumably selects where the ring is cut -- confirm.
weight = tr_to_matrix(shape, *cores, k=2)

With $\alpha \in \prod_{k=1}^d [r_k]$ and $\alpha_0 = \alpha_d$
and broadcasting $x_{i \alpha_d} = x_{i}$
$$
    y = W^\top x
        = \bigl( \sum_i \sum_\alpha
          \prod_{k=1}^d g_{\alpha_{k-1} i_k j_k \alpha_k} x_i \bigr)_j
        = \bigl(\sum_\alpha \sum_{i_{:d}} 
          \prod_{k=1}^{d-1} g_{\alpha_{k-1} i_k j_k \alpha_k}
               \sum_{i_d} g_{\alpha_{d-1} i_d j_d \alpha_d} x_i \bigr)_j
    \,, \\
    y = W^\top x
        = \bigl(\sum_{\alpha_d \alpha_1} \sum_{\alpha_{2:d}} \sum_{i_{2:}} 
          \prod_{k=2}^d g_{\alpha_{k-1} i_k j_k \alpha_k}
               \sum_{i_1} g_{\alpha_d i_1 j_1 \alpha_1} x_i \bigr)_j
        = \bigl(\sum_{\alpha_d \alpha_1} \sum_{i_{2:}} 
          Z_{\alpha_1 i_{2:} j_{2:} \alpha_d}
               \sum_{i_1} g_{\alpha_d i_1 j_1 \alpha_1} x_i \bigr)_j
    \,, \\
    y_j = e_j^\top W^\top x
        = \sum_{\alpha_d} \sum_i Z_{\alpha_d i j \alpha_d} x_i
    \,. $$

In [None]:
def ttmv(shape, input, *cores):
    """Multiply a batch of vectors by a matrix given by its TT cores.

    Parameters
    ----------
    shape : pair of factor lists; only `shape[0]` is used, to unflatten the
        last axis of `input`.
    input : tensor whose last axis is the flattened input; all leading axes
        are treated as batch axes and restored on return.
    *cores : 4-axis cores, consumed in order. Each step contracts axis 1 and
        the last axis of the running tensor with axes 1 and 0 of the core.

    Returns
    -------
    Tensor with the leading axes of `input` and everything else flattened
    into one last axis.
    """
    *batch_dims, _ = input.shape

    # the trailing unit axis acts as the boundary rank for the first core
    running = input.view(-1, *shape[0], 1)
    for factor in cores:
        running = torch.tensordot(running, factor, dims=[[1, -1], [1, 0]])

    return running.reshape(*batch_dims, -1)

In [None]:
# a batch of 100 random double-precision vectors, one per flattened input
input = torch.randn(100, np.prod(shape[0])).double()

# Ring product as a sum of ordinary TT products: for every value `a` of the
# ring-closing rank, slice the first core on its leading axis and the last
# core on its trailing axis at `a`, then add up the results.
reference = sum([
    ttmv(shape, input, cores[ 0][[a], ...],
         *cores[1:-1], cores[-1][..., [a]])
    for a in range(ranks[-1])
])

# must agree with multiplying by the dense matrix
assert torch.allclose(reference, torch.mm(input, weight))

In [None]:
from ttmodule.matrix import invert


def tr_vec(shape, input, *cores, k=0):
    """Multiply a batch of vectors by a tensor-ring matrix, cut at core `k`.

    The cores and the unflattened input axes are both rotated so that core
    `k` comes first. For every index `a` of that core's leading rank the
    rotated chain is contracted like an ordinary TT chain, with the first
    core sliced at `a` on its leading axis and the last core sliced at `a`
    on its trailing axis. The partial results are summed.

    Parameters
    ----------
    shape : pair of factor lists; only `shape[0]` is used, to unflatten the
        last axis of `input`.
    input : tensor whose last axis is the flattened input; leading axes are
        batch axes and are restored on return.
    *cores : 4-axis ring cores.
    k : index of the core at which the ring is cut; negative values count
        from the end. An out-of-range value fails the assertion.

    Returns
    -------
    Tensor with the leading axes of `input` and one flattened last axis.
    """
    k = (len(cores) + k) if k < 0 else k
    assert 0 <= k < len(cores)

    *head, _ = input.shape
    data = input.view(-1, *shape[0])

    # rotate the non-batch axes by k and add a unit boundary-rank axis
    shuffle = list(range(1, data.dim()))
    shuffle = 0, *shuffle[k:], *shuffle[:k]
    data = data.permute(shuffle).unsqueeze(-1)

    cores, output = cores[k:] + cores[:k], 0
    for a in range(cores[0].shape[0]):
        cyc = cores[0][[a], ...], *cores[1:-1], cores[-1][..., [a]]

        # tensordot is out-of-place, so `data` can be reused as is: there is
        # no need to copy it for every slice
        interm = data
        for core in cyc:
            interm = torch.tensordot(interm, core, dims=[[1, -1], [1, 0]])
        output += interm

    # NOTE(review): `invert` presumably returns the inverse of the `shuffle`
    # permutation, undoing the rotation -- confirm in ttmodule.matrix.
    return output.squeeze(-1).permute(invert(*shuffle)).reshape(*head, -1)

In [None]:
# the result must not depend on where the ring is cut
for k in range(len(cores)):
    assert torch.allclose(tr_vec(shape, input, *cores, k=k), reference)

In [None]:
# Counter-example: run the plain TT contraction over the ring cores and sum
# out the leftover trailing rank axis. The assertion at the end records that
# this does NOT reproduce `reference`.
# NOTE(review): with the ring `ranks` defined above, `data` starts with a
# trailing axis of size 1 while `cores[0]` has a leading rank of 5, so this
# relies on `torch.tensordot` tolerating that size mismatch -- confirm
# against the installed torch version.
*head, tail = input.shape
data = input.view(-1, *shape[0], 1)
for core in cores:
    data = torch.tensordot(data, core, dims=[[1, -1], [1, 0]])
data = data.sum(dim=-1).reshape(*head, -1)

assert not torch.allclose(data, reference)

<br>

### Transposed shape for TTLinear

In [None]:
# ranks = [2, 3, 4, 5, 5]
# shapes = [2, 3, 7, 4, 5], [3, 7, 7, 5, 2]

# Random tensor-train cores in double precision. `ranks` lists the d + 1
# boundary and bond ranks, so core k has the shape
# (ranks[k], n_k, m_k, ranks[k + 1]) with unit ranks at both ends.
ranks = [1, 3, 2, 5, 1]
shape = ([2, 3, 7, 5], [3, 7, 1, 2])

cores = [
    torch.randn(*core_shape, dtype=torch.double)
    for core_shape in zip(ranks[:-1], *shape, ranks[1:])
]

In [None]:
# transposed layout: swap the two factor lists and, in every core, the two
# middle axes (the rank axes stay where they are)
shape_t = shape[1], shape[0]
cores_t = [core.permute(0, 2, 1, 3) for core in cores]

In [None]:
def ttmv_t(shape, input, *cores):
    """TT matrix-vector product for cores stored in the transposed layout.

    It differs from `ttmv` in two places: the last axis of `input` is
    unflattened with `shape[1]` instead of `shape[0]`, and each step
    contracts with axis 2 of the core instead of axis 1 (the rank is still
    matched against axis 0 of the core).

    Returns a tensor with the leading axes of `input` and everything else
    flattened into one last axis.
    """
    *batch_dims, _ = input.shape

    # the trailing unit axis acts as the boundary rank for the first core
    running = input.view(-1, *shape[1], 1)
    for factor in cores:
        running = torch.tensordot(running, factor, dims=[[1, -1], [2, 0]])

    return running.reshape(*batch_dims, -1)

In [None]:
# the transposed-layout product must match the regular one on the same input
assert torch.allclose(ttmv_t(shape_t, input, *cores_t),
                      ttmv(shape, input, *cores))

In [None]:
# the dense matrix built from the transposed cores must be the transpose of
# the one built from the original cores
assert torch.allclose(tt_to_matrix(shape_t, *cores_t).t(),
                      tt_to_matrix(shape, *cores))