TTAX is an implementation of a Tensor-Train toolbox in JAX, providing routines for working with tensors in TT format in Python.
TTAX requires JAX and Flax; see their documentation for installation instructions. Then install TTAX from PyPI:
pip install ttax
This quick-start guide covers the basics of working with the ttax library. The library provides routines for the Tensor-Train object, a compact (factorized) representation of a tensor.
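To make "compact (factorized) representation" concrete, here is a plain NumPy sketch of the TT format that is independent of ttax. It assumes the common convention that a d-dimensional tensor is stored as d cores, where core k has shape (r_{k-1}, n_k, r_k) and r_0 = r_d = 1. Each entry of the full tensor is then a product of matrices sliced from the cores. The helper tt_to_full is hypothetical and written only for illustration; ttax's own TT class and ttax.full may store and contract cores differently.

```python
import numpy as np

def tt_to_full(cores):
    """Contract TT-cores of shape (r_{k-1}, n_k, r_k), r_0 = r_d = 1, into a dense tensor.

    Hypothetical helper for illustration; not part of the ttax API.
    """
    full = cores[0]  # shape (1, n_1, r_1)
    for core in cores[1:]:
        # Contract the trailing rank index with the next core's leading rank index.
        full = np.tensordot(full, core, axes=([-1], [0]))
    return full[0, ..., 0]  # drop the two boundary rank dimensions of size 1

rng = np.random.default_rng(0)
shape, rank = [10, 10, 10, 10], 3
ranks = [1] + [rank] * (len(shape) - 1) + [1]
cores = [rng.standard_normal((ranks[k], n, ranks[k + 1]))
         for k, n in enumerate(shape)]

full = tt_to_full(cores)
# A single entry is a chain of small matrix-vector products over the cores.
i, j, k, l = 3, 1, 0, 7
entry = cores[0][0, i, :] @ cores[1][:, j, :] @ cores[2][:, k, :] @ cores[3][:, l, 0]
n_tt = sum(c.size for c in cores)
print(full.shape, n_tt, full.size)  # (10, 10, 10, 10) 240 10000
```

Here 240 numbers describe a tensor with 10000 entries; the savings grow with the number of dimensions as long as the TT-ranks stay small.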
In the example below, we create a TT-tensor, multiply it by a constant, and convert it to the full tensor format.
import ttax
import numpy as np
import jax
rng = jax.random.PRNGKey(42)
t = ttax.random.tensor(rng, [10, 5, 2], tt_rank=3)
print(ttax.full(2 * t))
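Multiplying a TT-tensor by a constant does not require forming the full tensor: scaling any single core scales every entry, so the result keeps the same TT-ranks. The NumPy sketch below checks this identity. It is independent of ttax; tt_scale and tt_to_full are hypothetical helpers, and this is not a claim about how ttax implements scalar multiplication internally.

```python
import numpy as np

def tt_to_full(cores):
    # Hypothetical helper: contract cores of shape (r_{k-1}, n_k, r_k) into a dense tensor.
    full = cores[0]
    for core in cores[1:]:
        full = np.tensordot(full, core, axes=([-1], [0]))
    return full[0, ..., 0]

def tt_scale(cores, alpha):
    # Hypothetical helper: scaling one core scales the whole tensor,
    # while the TT-ranks and the remaining cores stay unchanged.
    return [alpha * cores[0]] + list(cores[1:])

rng = np.random.default_rng(42)
shape, rank = [10, 5, 2], 3
ranks = [1] + [rank] * (len(shape) - 1) + [1]
cores = [rng.standard_normal((ranks[k], n, ranks[k + 1]))
         for k, n in enumerate(shape)]

scaled = tt_scale(cores, 2.0)
print(np.allclose(tt_to_full(scaled), 2.0 * tt_to_full(cores)))  # True
```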
Detailed information is available in the documentation.
The main classes representing TT-tensors and TT-matrices are TT and TTMatrix.
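A TT-matrix follows the same idea as a TT-tensor, but each core carries both a row index and a column index. The NumPy sketch below assumes the common convention that core k has shape (r_{k-1}, m_k, n_k, r_k) and that the dense matrix has shape (m_1 ... m_d, n_1 ... n_d). It illustrates the format only; it does not describe the internal layout of ttax's TTMatrix, and ttmatrix_to_full is a hypothetical helper. With all TT-ranks equal to 1, a TT-matrix reduces to a Kronecker product, which the example checks.

```python
import numpy as np

def ttmatrix_to_full(cores):
    """Contract TT-matrix cores of shape (r_{k-1}, m_k, n_k, r_k) into a dense matrix.

    Hypothetical helper for illustration; not part of the ttax API.
    """
    d = len(cores)
    full = cores[0]
    for core in cores[1:]:
        full = np.tensordot(full, core, axes=([-1], [0]))
    full = full[0, ..., 0]  # shape (m_1, n_1, m_2, n_2, ..., m_d, n_d)
    # Move all row indices first, then all column indices.
    perm = list(range(0, 2 * d, 2)) + list(range(1, 2 * d, 2))
    full = full.transpose(perm)
    m = int(np.prod([c.shape[1] for c in cores]))
    n = int(np.prod([c.shape[2] for c in cores]))
    return full.reshape(m, n)

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3))
B = rng.standard_normal((4, 5))
# Rank-1 TT-matrix: each core is a plain matrix padded with size-1 rank axes.
cores = [A[None, :, :, None], B[None, :, :, None]]
M = ttmatrix_to_full(cores)
print(M.shape)                        # (8, 15)
print(np.allclose(M, np.kron(A, B)))  # True
```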
Operations based on the einsum method are built on the TTEinsum and WrappedTT classes.
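To give a feel for operations expressed as einsum over TT-cores, here is a NumPy sketch of the elementwise (Hadamard) product of two TT-tensors: each output core is an einsum of the corresponding input cores that shares the mode index and pairs up the rank indices, so the TT-ranks multiply. This illustrates the general technique only. It does not use TTEinsum or WrappedTT, whose interfaces are not described here, and tt_hadamard, random_tt and tt_to_full are hypothetical helpers.

```python
import numpy as np

def tt_to_full(cores):
    # Hypothetical helper: contract cores of shape (r_{k-1}, n_k, r_k) into a dense tensor.
    full = cores[0]
    for core in cores[1:]:
        full = np.tensordot(full, core, axes=([-1], [0]))
    return full[0, ..., 0]

def tt_hadamard(cores_a, cores_b):
    """Elementwise product of two TT-tensors, computed core by core with einsum.

    Hypothetical helper; the result's TT-ranks are products of the input ranks.
    """
    out = []
    for a, b in zip(cores_a, cores_b):
        ra0, n, ra1 = a.shape
        rb0, _, rb1 = b.shape
        # Share the mode index i, keep both pairs of rank indices.
        c = np.einsum('aib,cid->acibd', a, b)
        out.append(c.reshape(ra0 * rb0, n, ra1 * rb1))
    return out

def random_tt(rng, shape, rank):
    # Hypothetical helper: random cores with uniform inner TT-rank.
    ranks = [1] + [rank] * (len(shape) - 1) + [1]
    return [rng.standard_normal((ranks[k], n, ranks[k + 1]))
            for k, n in enumerate(shape)]

rng = np.random.default_rng(0)
x = random_tt(rng, [4, 3, 5], rank=2)
y = random_tt(rng, [4, 3, 5], rank=3)
z = tt_hadamard(x, y)
print([c.shape for c in z])  # [(1, 4, 6), (6, 3, 6), (6, 5, 1)]
print(np.allclose(tt_to_full(z), tt_to_full(x) * tt_to_full(y)))  # True
```

Because the ranks multiply, repeated products are usually followed by a rank-reduction (rounding) step in practice.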
MIT License