fasghq/ttax

TTAX: Tensor-Train toolbox on JAX

An implementation of a Tensor-Train (TT) toolbox in JAX, providing routines for working with tensors in TT format in Python.

Installation

TTAX requires JAX and Flax; install them first by following their official installation instructions. Then install TTAX from PyPI:

pip install ttax

Quick start

This quick-start guide covers the basics of working with the ttax library. The library provides routines for the Tensor-Train (TT) object, a compact (factorized) representation of a tensor.
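To make the idea of a factorized representation concrete, here is a minimal NumPy sketch. It does not use ttax, and the helper names `random_tt_cores` and `tt_to_full` are hypothetical. It assumes the standard TT layout in which core k has shape (r_{k-1}, n_k, r_k) with boundary ranks r_0 = r_d = 1, so every entry of the tensor is a product of small matrices taken from the cores.

```python
import numpy as np


def random_tt_cores(shape, tt_rank, seed=0):
    """Hypothetical helper: random TT-cores for a tensor of the given shape.

    Core k has shape (r_{k-1}, n_k, r_k); the boundary ranks are 1.
    """
    rng = np.random.default_rng(seed)
    d = len(shape)
    ranks = [1] + [tt_rank] * (d - 1) + [1]
    return [rng.standard_normal((ranks[k], shape[k], ranks[k + 1])) for k in range(d)]


def tt_to_full(cores):
    """Contract the cores left to right into the dense tensor."""
    res = cores[0]  # shape (1, n_1, r_1)
    for core in cores[1:]:
        # Sum over the shared rank index: (..., r) x (r, n, r') -> (..., n, r')
        res = np.tensordot(res, core, axes=([-1], [0]))
    # Drop the two boundary rank dimensions of size 1.
    return res[0, ..., 0]


cores = random_tt_cores([10, 10, 10, 10], tt_rank=3)
full = tt_to_full(cores)
print(full.shape)  # (10, 10, 10, 10)
print(sum(c.size for c in cores), full.size)  # 240 10000
```

Here the four cores store 240 numbers instead of the 10,000 entries of the dense tensor.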

In the example below we create a random TT-tensor of shape 10 × 5 × 2 with TT-rank 3, multiply it by a constant and convert it to the full (dense) tensor format.

import ttax
import numpy as np
import jax

rng = jax.random.PRNGKey(42)
t = ttax.random.tensor(rng, [10, 5, 2], tt_rank=3)
print(ttax.full(2 * t))
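Because a TT-tensor is a product of its cores, multiplying it by a scalar only requires scaling a single core, and the TT-ranks stay unchanged. The sketch below checks this with plain NumPy on cores of the same shape and rank as in the example above. It illustrates the math only and makes no claim about how ttax implements `2 * t` internally; the helpers `tt_scale` and `tt_to_full` are hypothetical.

```python
import numpy as np


def tt_to_full(cores):
    """Contract TT-cores of shape (r_{k-1}, n_k, r_k) into a dense tensor."""
    res = cores[0]
    for core in cores[1:]:
        res = np.tensordot(res, core, axes=([-1], [0]))
    return res[0, ..., 0]


def tt_scale(cores, alpha):
    """Multiply a TT-tensor by a scalar by scaling only its first core."""
    return [alpha * cores[0]] + list(cores[1:])


rng = np.random.default_rng(42)
shape, ranks = [10, 5, 2], [1, 3, 3, 1]
cores = [rng.standard_normal((ranks[k], shape[k], ranks[k + 1])) for k in range(3)]

scaled = tt_scale(cores, 2.0)
print(np.allclose(tt_to_full(scaled), 2.0 * tt_to_full(cores)))  # True
```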

More detailed information is available in the project documentation.

Structure overview

The main classes representing TT-tensors and TT-matrices are TT and TTMatrix. Operations based on the einsum method are built on TTEinsum and WrappedTT.
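As an illustration of why einsum is a natural building block for TT operations, the sketch below computes the elementwise (Hadamard) product of two TT-tensors core by core with `np.einsum`. It is not the TTEinsum or WrappedTT API: the function names are hypothetical, and the cores follow the assumed (r_{k-1}, n_k, r_k) layout.

```python
import numpy as np


def tt_to_full(cores):
    """Contract TT-cores of shape (r_{k-1}, n_k, r_k) into a dense tensor."""
    res = cores[0]
    for core in cores[1:]:
        res = np.tensordot(res, core, axes=([-1], [0]))
    return res[0, ..., 0]


def tt_hadamard(cores_a, cores_b):
    """Elementwise product of two TT-tensors, computed core by core.

    Each result core is the Kronecker product of the input cores over the
    rank indices, so the TT-ranks multiply.
    """
    out = []
    for a, b in zip(cores_a, cores_b):
        ra, n, ra_next = a.shape
        rb, _, rb_next = b.shape
        # Pair up rank indices (p, q) and (s, t) while sharing the mode index i.
        c = np.einsum('pis,qit->pqist', a, b)
        out.append(c.reshape(ra * rb, n, ra_next * rb_next))
    return out


def random_cores(rng, shape, ranks):
    return [rng.standard_normal((ranks[k], shape[k], ranks[k + 1])) for k in range(len(shape))]


rng = np.random.default_rng(0)
x = random_cores(rng, [4, 3, 5], [1, 2, 2, 1])
y = random_cores(rng, [4, 3, 5], [1, 3, 3, 1])
z = tt_hadamard(x, y)
print([c.shape for c in z])  # [(1, 4, 6), (6, 3, 6), (6, 5, 1)]
print(np.allclose(tt_to_full(z), tt_to_full(x) * tt_to_full(y)))  # True
```

The TT-ranks of the result are the products of the input ranks (here 2 × 3 = 6).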

License

MIT License
