# TT-decomposition

[tt_hse16_slides](https://bayesgroup.github.io/team/arodomanov/tt_hse16_slides.pdf)

[Tensorising Neural Networks](https://arxiv.org/pdf/1509.06569.pdf)

The $k$-th unfolding matrix of a tensor $A \in \mathbb{R}^{n_0\times \ldots \times n_{d-1}}$ is
$$
    A_k = \bigl(A_{i_{:k}, i_{k:}}\bigr)_{i \in \prod_{j=0}^{d-1} [n_j]}
        \in \mathbb{R}^{[n_0 \times \ldots \times n_{k-1}] \times [n_k \times \ldots \times n_{d-1}]}
    \,, $$
where $i_{:k} = (i_j)_{j=0}^{k-1}$ and $i_{k:} = (i_j)_{j=k}^{d-1}$ are zero-based index slices, like in numpy.

TT-format:
$$
    A_{i} = \sum_{\alpha}
        \prod_{j=0}^{d-1} G^{(j)}_{\alpha_j i_j \alpha_{j+1}}
    \,, $$
where each core slice $G^{(j)}_{\cdot\, i_j\, \cdot} \in \mathbb{R}^{r_j \times r_{j+1}}$
and $r_0 = r_d = 1$. The rank of the TT-decomposition is $r = \max_{j=0}^d r_j$.

## Tensors

In [None]:
import numpy as np

import torch
import torch.nn.functional as F

%matplotlib inline
import matplotlib.pyplot as plt

Import Tensor-Train converters

In [None]:
from ttmodule import tensor_to_tt, tt_to_tensor

from ttmodule import matrix_to_tt, tt_to_matrix

A simple, run-of-the-mill training loop.
* imports from [`cplxmodule`](https://github.com/ivannz/cplxmodule.git)

In [None]:
import tqdm
from cplxmodule.relevance import penalties
from cplxmodule.utils.stats import sparsity

def train_model(X, y, model, n_steps=20000, threshold=1.0,
                klw=1e-3, verbose=False):
    """Fit `model` on `(X, y)` by minimizing a penalized MSE objective.

    Performs `n_steps` full-batch Adamax updates (lr=2e-3) of the loss
    `F.mse_loss(model(X), y) + klw * sum(penalties(model))`, recording the
    loss value and a copy of `model.weight` after every update.

    Parameters
    ----------
    X, y
        Inputs and targets; passed as is to `model` and `F.mse_loss`.
    model
        The module to train. It must expose a `.weight` attribute.
    n_steps : int
        The number of optimizer steps.
    threshold : float
        Unused by this function; kept so that existing calls keep working.
    klw : float
        The coefficient of the summed penalty term in the objective.
    verbose : bool
        Whether to show the tqdm progress bar.

    Returns
    -------
    model
        The same model, switched to eval mode.
    losses : list of float
        The value of the full objective at every step.
    weights : list of torch.Tensor
        Independent CPU copies of `model.weight` taken after every step.
    """
    model.train()
    optim = torch.optim.Adamax(model.parameters(), lr=2e-3)

    losses, weights = [], []
    with tqdm.tqdm(range(n_steps), disable=not verbose) as bar:
        for _ in bar:
            optim.zero_grad()

            y_pred = model(X)

            mse = F.mse_loss(y_pred, y)
            kl_d = sum(penalties(model))

            loss = mse + klw * kl_d
            loss.backward()

            optim.step()

            losses.append(float(loss))
            bar.set_postfix_str(f"{float(mse):.3e} {float(kl_d):.3e}")

            # Keep the snapshot on the host: the plotting code converts the
            # snapshots with `.numpy()`, which requires CPU tensors, and a
            # long history would otherwise pile up in accelerator memory.
            # `copy=True` guarantees an independent tensor even if the
            # weight already lives on the CPU.
            with torch.no_grad():
                weights.append(model.weight.detach().to("cpu", copy=True))
        # end for
    # end with
    return model.eval(), losses, weights

def test_model(X, y, model, threshold=1.0):
    """Print the sparsity, the MSE and the summed penalty of `model`.

    The model is switched to eval mode, evaluated on `(X, y)` without
    gradient tracking, and a single line `<sparsity> <mse> <penalty>` is
    printed. The sparsity is computed by `sparsity(...)` with the given
    `threshold` and `hard=True`.

    Returns the model itself.
    """
    model.eval()

    # forward pass and penalty evaluation do not need the autograd graph
    with torch.no_grad():
        y_pred = model(X)
        mse = F.mse_loss(y_pred, y)
        kl_d = sum(penalties(model))

    f_sparsity = sparsity(model, threshold=threshold, hard=True)

    report = f"{f_sparsity:.1%} {mse.item():.3e} {float(kl_d):.3e}"
    print(report)

    return model

<br>

In [None]:
from ttmodule import TTLinear

from torch.nn import Linear
from cplxmodule.relevance import LinearARD
from cplxmodule.relevance import LinearL0ARD

Specify the problem and device

In [None]:
# Threshold forwarded to the `train_model` / `test_model` calls below.
threshold = 3.0

# Device onto which the dataset tensors are moved.
device_ = "cpu"

Create a simple dataset: $(x_i, y_i)_{i=1}^n \in \mathbb{R}^{d}\times\mathbb{R}^{p}$
and $y_i = -E_{:p} x_i$ with $E_{:p} = (e_j)_{j=1}^p$ the projection
matrix onto the first $p$ dimensions (note the sign flip in the code below). We put $n\leq d$.

In [None]:
import torch.utils.data

# Problem size: the number of input features and of regression targets.
n_features = 250
n_output = 50

# Standard normal inputs; the targets are a negated copy of the leading
# `n_output` input columns.
X = torch.randn(10200, n_features)
y = X[:, :n_output].clone().neg()

dataset = torch.utils.data.TensorDataset(X.to(device_), y.to(device_))

# The first 200 samples are used for training, the remaining 10000 for
# testing. Each split is later unpacked as `(X, y)` into the helpers.
train = dataset[:200]
test = dataset[200:]

A TT-linear layer.

In [None]:
# Alternative configurations kept for reference:
# model = TTLinear([5, 5, 5, 2], [5, 5, 2, 1], rank=5, bias=False, reassemble=True)
# model = TTLinear([25, 5, 2], [5, 5, 2], rank=1, bias=False, reassemble=True)
# Active configuration: the first list multiplies out to 250 (= n_features)
# and the second to 50 (= n_output). Presumably these are the factorizations
# of the input and output dimensions -- TODO confirm against `ttmodule`.
model = TTLinear([5, 5, 5, 1, 2], [5, 5, 2, 1, 1], rank=3, bias=False, reassemble=True)

In [None]:
# model = LinearARD(n_features, n_output, bias=False)
# model = LinearL0ARD(n_features, n_output, bias=False, group=None)

Train

In [None]:
# Fit on the 200-sample training split for 2000 steps with a unit penalty
# weight. `losses` and `weights` hold the per-step history that the
# visualization cells below read as globals.
# NOTE: `train_model` accepts `threshold` but does not use it in its body.
model, losses, weights = train_model(
    *train, model, n_steps=2000, threshold=threshold,
    klw=1e0, verbose=False)

Test the model

In [None]:
# Prints sparsity, MSE and the summed penalty on the held-out split.
# `test_model` returns the model, so as the last expression of the cell its
# repr is presumably shown in the notebook output as well.
test_model(*test, model, threshold=threshold)

<br>

## Simple visualization

... with not very simple setup

In [None]:
from matplotlib.gridspec import GridSpec


def canvas_setup(figsize, **kwargs):
    """Create the two-panel canvas: weight magnitudes and the loss curve.

    Builds a figure of size `figsize` with a wide panel on the left and a
    narrow one on the right (width ratio 7:1). The left panel shows the
    absolute values of the first weight snapshot, and the right one the full
    loss history on a log scale. Both `weights` and `losses` are read from
    the notebook globals. Extra keyword arguments are accepted and ignored.

    Returns `(fig, (ax_main, ax_loss))`.
    """
    fig = plt.figure(figsize=figsize)

    # one row, two columns: the main panel is seven times wider
    grid = GridSpec(1, 2, figure=fig, width_ratios=[7, 1])
    ax_main = fig.add_subplot(grid[0])
    ax_loss = fig.add_subplot(grid[1])

    with torch.no_grad():
        magnitude = weights[0].abs().numpy()
        ax_main.imshow(magnitude, cmap=plt.cm.bone)
        ax_loss.semilogy(losses)

    plt.tight_layout()

    panels = ax_main, ax_loss
    return fig, panels

In [None]:
def canvas_clear(*axes):
    """Wipe each axis, then put back its ticks, limits, z-order and alpha.

    For every axis the current properties are captured first, the axis is
    cleared, and the selected subset of the captured properties is applied
    again through `update`. Returns the very same tuple of axes.
    """
    # the properties that must survive `clear()`, in application order
    preserved = ("xticks", "yticks", "xlim", "ylim", "zorder", "alpha")

    for axis in axes:
        snapshot = axis.properties()
        axis.clear()

        restored = {}
        for key in preserved:
            restored[key] = snapshot[key]

        axis.update(restored)

    return axes

In [None]:
def animate_weight(n_epoch, *axes):
    """Redraw both panels for the training step `n_epoch`.

    The axes are cleared with `canvas_clear`. The main panel then shows the
    absolute values of the `n_epoch`-th weight snapshot, and the loss panel
    shows the loss history up to and including step `n_epoch`, with a marker
    and a vertical line at that step. Both `weights` and `losses` are read
    from the notebook globals.

    Parameters
    ----------
    n_epoch : int
        Zero-based index into `weights` and `losses`.
    *axes
        The main and the loss axis, in this order.

    Returns
    -------
    list
        The artists created by this call that have a `set_animated` method.
    """
    ax_main, ax_loss = canvas_clear(*axes)

    artists = []
    with torch.no_grad():
        artists.append(ax_main.imshow(
            abs(weights[n_epoch]).numpy(),
            cmap=plt.cm.bone,
            interpolation=None
        ))
    artists.append(ax_main.set_title(f"it. {n_epoch}"))

    # `semilogy` returns a LIST of line artists. Appending the list itself
    # made the `hasattr` filter below silently drop the loss curve, so the
    # lines are added one by one instead.
    artists.extend(
        ax_loss.semilogy(losses[:n_epoch + 1], lw=2, color="fuchsia")
    )

    # The curve above spans x = 0..n_epoch, so the current-step marker and
    # the vertical guide sit at x = n_epoch to coincide with its last point.
    artists.append(
        ax_loss.scatter([n_epoch], [losses[n_epoch]],
                        s=25, color="cyan")
    )
    artists.append(
        ax_loss.axvline(n_epoch, c='cyan', lw=2, alpha=0.25, zorder=-10)
    )

    return [
        artist_ for artist_ in artists
        if hasattr(artist_, "set_animated")
    ]

An interactive slider with ipywidgets

In [None]:
from ipywidgets import widgets

def int_slider(value, min, max, step):
    """Create an integer slider widget, at least 500px wide.

    The slider starts at `value`, ranges over `[min, max]` with increment
    `step`, and is created with `continuous_update` switched off.
    """
    wide_flex = widgets.Layout(min_width='500px', display='flex')

    return widgets.IntSlider(
        value=value,
        min=min,
        max=max,
        step=step,
        continuous_update=False,
        layout=wide_flex,
    )


In [None]:
def plot_weight(n_epoch=0):
    """Draw a fresh 16x3 canvas showing the state at step `n_epoch`."""
    _, (ax_main, ax_loss) = canvas_setup(figsize=(16, 3))

    # repaint both panels for the requested step and display the figure
    animate_weight(n_epoch, ax_main, ax_loss)
    plt.show()


# Hook `plot_weight` to a slider over the recorded snapshots: it starts at
# step 10, covers 0..len(weights)-1 and moves in increments of 10. The
# trailing semicolon presumably suppresses the cell's return value output.
widgets.interact(plot_weight, n_epoch=int_slider(10, 0, len(weights)-1, 10));

<br>

In [None]:
import matplotlib.animation as animation

# The writer class registered in Matplotlib under the name 'ffmpeg_file'.
FFMpegWriter = animation.writers['ffmpeg_file']
class PatchedFFMpegWriter(FFMpegWriter):
    """An 'ffmpeg_file' writer whose `setup` falls back to earlier settings.

    Whenever `setup` is called without the `dpi`, `frame_prefix` or
    `clear_temp` keyword arguments, the values are taken from the attributes
    `dpi`, `temp_prefix` and `clear_temp` of the instance, if they exist,
    and only then from the hard-coded defaults below. Presumably this lets
    an explicit `setup(...)` call survive a later `setup` call issued
    without these keywords -- TODO confirm against the Matplotlib version
    in use.
    """
    def setup(self, fig, outfile, *args, **kwargs):
        """Call the parent `setup` with remembered dpi, prefix and cleanup.

        Extra positional arguments in `*args` are accepted but NOT forwarded
        to the parent, and neither are keyword arguments other than `dpi`,
        `frame_prefix` and `clear_temp`.
        """
        # explicit keyword > attribute already on the instance > default
        dpi = kwargs.get("dpi", getattr(self, "dpi", None))

        # NOTE: the keyword is `frame_prefix`, but the fallback attribute
        # read from the instance is `temp_prefix`.
        frame_prefix = kwargs.get(
            "frame_prefix", getattr(self, "temp_prefix", '_tmp'))

        clear_temp = kwargs.get(
            "clear_temp", getattr(self, "clear_temp", True))

        # NOTE(review): `clear_temp` is always passed on -- confirm that the
        # parent `setup` of the installed Matplotlib accepts this keyword.
        super().setup(fig, outfile, clear_temp=clear_temp,
                      frame_prefix=frame_prefix, dpi=dpi)

In [None]:
import os
import time
import tempfile

# Timestamp used to give the output file a distinct name.
dttm = time.strftime("%Y%m%d-%H%M%S")

fig, axes = canvas_setup(figsize=(16, 3))

# Frame rate of the video and the number of recorded weight snapshots.
fps, n_frames = 15, len(weights)

# Frame schedule: each of the steps 0..24, then every 10th step afterwards.
# NOTE(review): the first range is hard-coded, so fewer than 25 snapshots
# would make `animate_weight` index past the end of `weights` -- confirm
# that n_frames >= 25 always holds.
schedule = [
    *range(0, 25, 1)
] + [
    *range(25, n_frames, 10)
]

# The video is written to the current directory, named after the model class.
outfile = os.path.join(".", f"weight-{model.__class__.__name__}-{dttm}.mp4")

# dump the intermediate frames into a temporary dir
with tempfile.TemporaryDirectory() as CACHE_PATH:
    print(f"temp dir at {CACHE_PATH}", flush=True)

    # Point the frame files at the temporary directory up front. The patched
    # writer falls back to attributes already on the instance if `setup` is
    # called again without the keyword; presumably `ani.save` does so --
    # TODO confirm.
    writer = PatchedFFMpegWriter(fps=fps, bitrate=-1, metadata={})
    writer.setup(fig, outfile, frame_prefix=os.path.join(
        CACHE_PATH, f"_frame_"))

    # Each scheduled step is passed to `animate_weight` as `n_epoch`, with
    # the two axes as the extra arguments; tqdm reports the progress.
    ani = animation.FuncAnimation(
        fig, animate_weight, tqdm.tqdm_notebook(schedule, unit="frm"),
        interval=1, repeat_delay=None, blit=False, fargs=axes)
    ani.save(outfile, writer=writer)
# end with
plt.close()

In [None]:
# NOTE(review): this overrides `outfile` with a hard-coded absolute path,
# apparently to a video rendered in an earlier session; it exists only on
# the author's machine.
outfile = """/Users/user/Bitbox/weight-TTLinear-20190707-202453.mp4"""

In [1]:
from IPython.display import Video

# Show the video at `outfile` in the cell output, embedded and 768 wide.
Video(data=outfile, embed=True, width=768)

In [None]:
# Always fails: presumably a deliberate stopper so that running the whole
# notebook halts here, before the scratch cells below.
assert False

<br>

### Trunk: model grafting 

In [None]:
# Walk a dotted attribute path starting from `module`.
# NOTE(review): `module` is not defined in the visible part of the notebook.
mod, name = module, "columns.boost00.bricks.0.body.dense03"

# `path` collects the components that were resolved successfully.
path = []
# `partition` splits off the first component; `dot` is empty once `child`
# is the last component of the path.
child, dot, name = name.partition(".")
while dot:
    # descend one level; a missing attribute yields None and stops the walk
    mod = getattr(mod, child, None)
    if mod is None:
        break

    path.append(child)
    child, dot, name = name.partition(".")

# Resolve the remaining component. If the walk stopped early then `mod` is
# None here, and the lookup yields None again via the default.
mod = getattr(mod, child, None)

In [None]:
# Inspect the outcome of the walk: the resolved object (or None), the last
# component looked up, and the components traversed before it.
mod, child, path

In [None]:
# NOTE(review): the walk above already rebinds `mod` to the attribute named
# `child`, so this looks `child` up once more on that object, and without a
# default it raises AttributeError if absent -- confirm this is intended.
getattr(mod, child)

<br>