# Introduction

Tensor decoupling methods allow us to *decouple* multivariate, vector-valued functions.

More specifically, they allow us to express (exactly or approximately) a function $\mathbf{f} : \mathbb{R}^{m} \to \mathbb{R}^{n}$ in the following form:
$$\mathbf{f}(\mathbf{x}) = \mathbf{W}\mathbf{g}(\mathbf{V}^T\mathbf{x})$$

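where $\mathbf{x} \in \mathbb{R}^m$ is the input, $\mathbf{W} \in \mathbb{R}^{n \times r}$ and $\mathbf{V} \in \mathbb{R}^{m \times r}$ are matrices, and $\mathbf{g}: \mathbb{R}^r \to \mathbb{R}^r$ acts element-wise, $\mathbf{g}(\mathbf{z}) = [g_1(z_1), \dots, g_r(z_r)]^T$, with each $g_i$ a univariate function. Writing $\mathbf{v}_i$ for the $i$-th column of $\mathbf{V}$ and $w_{ji}$ for the entries of $\mathbf{W}$, each output is a weighted sum of $r$ univariate "branches":
$$f_j(\mathbf{x}) = \sum_{i=1}^{r} w_{ji}\, g_i(\mathbf{v}_i^T\mathbf{x}).$$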

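This form consists of two linear transformations with an element-wise nonlinear transformation in between. It has the same structure as a [2-layer neural network](https://en.wikipedia.org/wiki/Neural_network_(machine_learning)) with $r$ hidden units: $\mathbf{V}^T$ holds the input weights, the $g_i$ play the role of activation functions (here each unit may have its own), and $\mathbf{W}$ holds the output weights.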
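To make the structure concrete, here is a minimal JAX sketch that evaluates $\mathbf{W}\mathbf{g}(\mathbf{V}^T\mathbf{x})$ as a one-hidden-layer network. It is not part of `untangle`: the helper name `decoupled_forward` and the toy values of `W`, `V` and the branch functions are made up for illustration.

```python
import jax.numpy as jnp


def decoupled_forward(x, W, V, branches):
    """Evaluate W g(V^T x), where g applies branches[i] to the i-th entry of V^T x."""
    z = V.T @ x                                                   # r hidden pre-activations
    g = jnp.stack([g_i(z_i) for g_i, z_i in zip(branches, z)])    # one univariate function per branch
    return W @ g                                                  # linear output layer


# Toy example (arbitrary values): m = 2 inputs, r = 3 branches, n = 2 outputs
W = jnp.array([[1.0, 0.0, 2.0],
               [0.0, 1.0, -1.0]])
V = jnp.array([[1.0, 0.0, 1.0],
               [0.0, 1.0, 1.0]])
branches = [jnp.tanh, lambda z: z**2, jnp.sin]
x = jnp.array([0.5, -1.0])
y = decoupled_forward(x, W, V, branches)
```

Each hidden unit $i$ computes $z_i = \mathbf{v}_i^T\mathbf{x}$ and passes it through its own $g_i$; unlike a typical network layer, the activation functions may differ from branch to branch.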

Let's walk through a simple example of how to obtain such a representation with `untangle`.

```python
%%capture
! uv pip install -e ..
import jax
import jax.numpy as jnp
```

Let's begin by defining a simple polynomial function.

```python
def function(x):
    x1, x2 = x
    return jnp.array([
        x1**2 + x2 + 2,
        x2**3 - x1 + 1,
    ])

# example of input/output
x = jnp.array([-1.0, 0.5])
function(x)
```

```
Array([3.5  , 2.125], dtype=float32)
```

Now, we need to **collect information** about this function.

We can use the utility function `collect_information`, which generates random input points and returns them together with the corresponding function outputs and the stacked Jacobian tensor (one $2 \times 2$ Jacobian matrix per sample point, stacked along the last axis).

```python
from untangle.utils import collect_information

N = 100
X, Y, J = collect_information(function, N, 2, range=(-2, 2))

X.shape, Y.shape, J.shape
```

```
((100, 2), (100, 2), (2, 2, 100))
```
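The notebook does not show how `collect_information` works internally. The following is a hypothetical re-implementation, assuming uniform sampling over the given range and the `(n, m, N)` Jacobian layout suggested by the shapes above; `collect_information_sketch` and its `low`/`high`/`seed` parameters are illustrative names, not part of `untangle`. It relies only on `jax.vmap` and `jax.jacfwd`.

```python
import jax
import jax.numpy as jnp


def collect_information_sketch(f, N, m, low=-2.0, high=2.0, seed=0):
    """Hypothetical helper: sample N inputs in [low, high]^m, return inputs, outputs, Jacobians."""
    X = jax.random.uniform(jax.random.PRNGKey(seed), (N, m), minval=low, maxval=high)  # (N, m)
    Y = jax.vmap(f)(X)                    # (N, n): outputs at every sample
    J = jax.vmap(jax.jacfwd(f))(X)        # (N, n, m): one Jacobian per sample
    return X, Y, jnp.moveaxis(J, 0, -1)   # reorder Jacobians to (n, m, N)


# The polynomial from above, repeated so that this snippet runs on its own
def function(x):
    x1, x2 = x
    return jnp.array([x1**2 + x2 + 2, x2**3 - x1 + 1])


X, Y, J = collect_information_sketch(function, N=100, m=2)
print(X.shape, Y.shape, J.shape)  # (100, 2) (100, 2) (2, 2, 100)
```

In this layout `J[:, :, k]` is the Jacobian of $\mathbf{f}$ at the $k$-th sample, which for this polynomial is $\begin{bmatrix} 2x_1 & 1 \\ -1 & 3x_2^2 \end{bmatrix}$.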

We can now run the decoupling algorithm, but first we need to choose the **rank** of the CP (canonical polyadic) decomposition of the Jacobian tensor `J`, which corresponds to the number $r$ of branches in $\mathbf{g}$.
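The reason a CP decomposition appears is a short chain-rule argument. This is the standard derivation behind Jacobian-based decoupling; we assume `untangle` follows it, since the notebook does not show its internals. Differentiating $\mathbf{f}(\mathbf{x}) = \mathbf{W}\mathbf{g}(\mathbf{V}^T\mathbf{x})$ gives

$$\mathbf{J}_{\mathbf{f}}(\mathbf{x}) = \mathbf{W}\,\operatorname{diag}\!\big(g_1'(\mathbf{v}_1^T\mathbf{x}), \dots, g_r'(\mathbf{v}_r^T\mathbf{x})\big)\,\mathbf{V}^T,$$

where $\mathbf{v}_i$ and $\mathbf{w}_i$ denote the $i$-th columns of $\mathbf{V}$ and $\mathbf{W}$. Stacking the Jacobians at the $N$ sample points $\mathbf{x}^{(1)}, \dots, \mathbf{x}^{(N)}$ along a third mode yields

$$\mathcal{J}_{:,:,k} = \sum_{i=1}^{r} h_{ki}\,\mathbf{w}_i\mathbf{v}_i^T, \qquad h_{ki} = g_i'\big(\mathbf{v}_i^T\mathbf{x}^{(k)}\big),$$

that is, $\mathcal{J} = \sum_{i=1}^{r} \mathbf{w}_i \circ \mathbf{v}_i \circ \mathbf{h}_i$ is a CP decomposition with $r$ terms and factor matrices $\mathbf{W}$, $\mathbf{V}$ and $\mathbf{H} = [h_{ki}] \in \mathbb{R}^{N \times r}$. The number of CP terms is therefore the number of branches.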

Here the problem is small enough that we can simply try several ranks and keep one that yields a very low reconstruction error.

```python
from untangle.utils import search_rank

rank = search_rank(J, linesearch=True)
rank
```

```
3
```
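`search_rank` is `untangle`'s own utility and its internals are not shown here. As a rough sketch of the idea, assuming a plain CP-ALS (alternating least squares) fit and a relative-error threshold, one can fit CP models of increasing rank and stop at the first one whose reconstruction error falls below the threshold. The names `cp_als`, `relative_cp_error`, `search_rank_sketch` and the `tol` value are all hypothetical.

```python
import jax
import jax.numpy as jnp


def cp_als(T, rank, n_iter=100, seed=0):
    """Fit T[i, j, k] ≈ sum_r A[i, r] B[j, r] C[k, r] by alternating least squares."""
    I, J, K = T.shape
    kA, kB, kC = jax.random.split(jax.random.PRNGKey(seed), 3)
    A = jax.random.normal(kA, (I, rank))
    B = jax.random.normal(kB, (J, rank))
    C = jax.random.normal(kC, (K, rank))
    for _ in range(n_iter):
        # Each step solves a linear least-squares problem for one factor with the other two fixed.
        A = jnp.einsum("ijk,jr,kr->ir", T, B, C) @ jnp.linalg.pinv((B.T @ B) * (C.T @ C))
        B = jnp.einsum("ijk,ir,kr->jr", T, A, C) @ jnp.linalg.pinv((A.T @ A) * (C.T @ C))
        C = jnp.einsum("ijk,ir,jr->kr", T, A, B) @ jnp.linalg.pinv((A.T @ A) * (B.T @ B))
    return A, B, C


def relative_cp_error(T, A, B, C):
    """Frobenius-norm reconstruction error of the CP model, relative to ||T||."""
    approx = jnp.einsum("ir,jr,kr->ijk", A, B, C)
    return float(jnp.sqrt(jnp.sum((T - approx) ** 2) / jnp.sum(T ** 2)))


def search_rank_sketch(T, max_rank=6, tol=1e-3):
    """Return (rank, error) for the smallest rank whose CP fit reaches error < tol, else max_rank."""
    for r in range(1, max_rank + 1):
        err = relative_cp_error(T, *cp_als(T, r))
        if err < tol:
            return r, err
    return max_rank, err


# Usage: an exact rank-1 tensor should be detected as rank 1
ka, kb, kc = jax.random.split(jax.random.PRNGKey(1), 3)
T1 = jnp.einsum("i,j,k->ijk",
                jax.random.normal(ka, (3,)),
                jax.random.normal(kb, (4,)),
                jax.random.normal(kc, (5,)))
rank, err = search_rank_sketch(T1, max_rank=4)
```

On the notebook's data one would call `search_rank_sketch(J)` on the stacked Jacobian tensor. The result need not match `search_rank`, since ALS depends on its random initialization and on the chosen threshold.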

We can finally obtain our decoupled representation.

```python
from untangle.algorithms.basic import decoupling_basic, inference

W, V, coefs = decoupling_basic(X, Y, J, rank, degree=3, linesearch=True)
decoupling = inference(W, V, coefs)
```
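The notebook does not show what `decoupling_basic` does once $\mathbf{W}$ and $\mathbf{V}$ are known. One simple possibility, sketched below under explicit assumptions, is to keep $\mathbf{W}$ and $\mathbf{V}$ fixed and fit each $g_i$ as a polynomial by linear least squares on the sampled outputs `Y`, since $\mathbf{W}\mathbf{g}(\mathbf{V}^T\mathbf{x})$ is linear in the polynomial coefficients. Here `degree` is assumed to be the polynomial degree of every $g_i$, and the coefficients are assumed to form an $r \times (\text{degree}+1)$ array of monomial coefficients; `fit_branch_polynomials` and `evaluate_decoupled` are hypothetical names, and `untangle` may instead fit the derivatives $g_i'$ and integrate them.

```python
import jax
import jax.numpy as jnp


def fit_branch_polynomials(X, Y, W, V, degree):
    """Least-squares fit of polynomial branches g_i (monomial coefficients) with W, V held fixed."""
    Z = X @ V                                                        # (N, r): branch inputs v_i^T x_k
    powers = jnp.stack([Z**d for d in range(degree + 1)], axis=-1)   # (N, r, degree+1)
    # design[k, j, i, d] = W[j, i] * Z[k, i]**d, so Y[k, j] = sum_{i,d} design[k, j, i, d] * c[i, d]
    design = jnp.einsum("ji,kid->kjid", W, powers)
    N, n = Y.shape
    coefs, *_ = jnp.linalg.lstsq(design.reshape(N * n, -1), Y.reshape(-1))
    return coefs.reshape(W.shape[1], degree + 1)                     # (r, degree+1)


def evaluate_decoupled(x, W, V, coefs):
    """Evaluate W g(V^T x) for one input x, with g_i(z) = sum_d coefs[i, d] * z**d."""
    z = V.T @ x
    powers = jnp.stack([z**d for d in range(coefs.shape[1])], axis=-1)  # (r, degree+1)
    return W @ jnp.sum(coefs * powers, axis=1)


# Usage on a hypothetical decoupled function with known W, V and cubic branches
W_true = jnp.array([[1.0, 0.5], [-1.0, 2.0]])
V_true = jnp.array([[1.0, 0.3], [0.2, -1.0]])
c_true = jnp.array([[0.0, 1.0, 0.0, 0.5],    # g_1(z) = z + 0.5 z^3
                    [1.0, 0.0, -1.0, 0.0]])  # g_2(z) = 1 - z^2
X = jax.random.uniform(jax.random.PRNGKey(0), (200, 2), minval=-2.0, maxval=2.0)
batched = jax.vmap(evaluate_decoupled, in_axes=(0, None, None, None))
Y = batched(X, W_true, V_true, c_true)

coefs = fit_branch_polynomials(X, Y, W_true, V_true, degree=3)
Y_hat = batched(X, W_true, V_true, coefs)
```

With the true $\mathbf{W}$ and $\mathbf{V}$ the fit recovers the branches; when $\mathbf{W}$ and $\mathbf{V}$ come from an approximate CP fit, the least-squares residual is nonzero and the reconstruction is only approximate.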

Let's try it on the input `x` defined earlier.

```python
function(x), decoupling(x)
```

```
(Array([3.5  , 2.125], dtype=float32),
 Array([3.1674182, 2.1443737], dtype=float32))
```

```python
x = jnp.array([2, -2])
function(x), decoupling(x)
```

```
(Array([ 4, -9], dtype=int32),
 Array([  3.3900006, -10.250131 ], dtype=float32))
```

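Not too bad, although the errors are clearly visible (for example $3.39$ instead of $4$ and $-10.25$ instead of $-9$ at $\mathbf{x} = (2, -2)$), and better decoupling methods exist.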
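Spot checks at two points only give a rough impression. A more systematic check is the relative error over many test inputs, sketched below. `relative_error` is a hypothetical helper; in the notebook the call would be `relative_error(function, decoupling, X_test)`, assuming `decoupling` can be traced by `jax.vmap` (the notebook only shows it applied to single inputs).

```python
import jax
import jax.numpy as jnp


def relative_error(f, f_hat, X):
    """||f(X) - f_hat(X)||_F / ||f(X)||_F over the rows of X (both map one input to one output)."""
    Y = jax.vmap(f)(X)
    Y_hat = jax.vmap(f_hat)(X)
    return float(jnp.sqrt(jnp.sum((Y - Y_hat) ** 2) / jnp.sum(Y ** 2)))


# Toy check: an approximation that overshoots every output by 10% has relative error 0.1
def f(x):
    return jnp.array([x[0] ** 2 + x[1] + 2, x[1] ** 3 - x[0] + 1])


def f_hat(x):
    return 1.1 * f(x)


X_test = jax.random.uniform(jax.random.PRNGKey(42), (1000, 2), minval=-2.0, maxval=2.0)
err = relative_error(f, f_hat, X_test)
```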