In [1]:
import os

import torch
import tntorch as tn

import tensorkrowch as tk
from tensorkrowch.decompositions import tt_rss

In [11]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

n_features = 100
phys_dim = 2
bond_dim = 10

mps = tk.models.MPS(n_features=n_features,
                    phys_dim=phys_dim,
                    bond_dim=bond_dim,
                    init_method='unit',
                    device=device)

# Save cores
results_dir = os.path.join('..', '..', 'results', '1_performance', 'random_mps')
os.makedirs(results_dir, exist_ok=True)

# torch.save(mps.tensors, os.path.join(results_dir, 'cores.pt'))

mps.trace(torch.zeros(1, n_features, phys_dim, device=device))

# TT-RSS

In [12]:
results_dir = os.path.join('..', '..', 'results', '1_performance', 'random_mps')
os.makedirs(results_dir, exist_ok=True)

In [13]:
# Load cores
cores = torch.load(os.path.join(results_dir, 'cores.pt'),
                   weights_only=False)
mps = tk.models.MPS(tensors=cores)

mps.trace(torch.zeros(1, n_features, phys_dim, device=device))

In [14]:
sketch_size = 50

domain = [torch.arange(2) for _ in range(n_features)]

# Create dataset to compare results and check accuracy
sketch_samples = torch.randint(low=0, high=phys_dim, size=(sketch_size, n_features))

def embedding(x): return tk.embeddings.basis(x, dim=phys_dim).float()
def fun(x): return mps(embedding(x)).unsqueeze(-1)

cores_rss, info = tt_rss(function=fun,
                         embedding=embedding,
                         sketch_samples=sketch_samples,
                         domain=domain,
                         rank=bond_dim,
                         cum_percentage=1 - 1e-5,
                         batch_size=500,
                         device=device,
                         verbose=False,
                         return_info=True)

mps.reset()

# Save cores
# torch.save(cores_rss,
#            os.path.join(results_dir, f'cores_rss_{info["total_time"]:.2f}.pt'))

In [15]:
info

{'total_time': 27.98005485534668,
 'val_eps': tensor(1.8876e-06, device='cuda:0')}

In [16]:
cores_rss = torch.load(os.path.join(results_dir, 'cores_rss_27.98.pt'),
                       weights_only=False)
mps_rss = tk.models.MPS(tensors=[c.to(device) for c in cores_rss])
mps_rss.canonicalize(renormalize=True)

In [17]:
mps_norm = mps.norm(log_scale=True)
mps_rss_norm = mps_rss.norm(log_scale=True)

print(f'MPS: {mps_norm.item()}')
print(f'MPS RSS: {mps_rss_norm.item()}')

mps.reset()
mps_rss.reset()

MPS: 33.506038665771484
MPS RSS: 33.50614929199219


In [18]:
for node1, node2 in zip(mps.mats_env, mps_rss.mats_env):
    node1['input'] ^ node2['input']

In [19]:
log_scale = 0

# Contract mps with mps_rss
stack = tk.stack(mps.mats_env)
stack_rss = tk.stack(mps_rss.mats_env)
stack ^ stack_rss

mats_results = tk.unbind(stack @ stack_rss)

mats_results[0] = mps.left_node @ (mps_rss.left_node @ mats_results[0])
mats_results[-1] = (mats_results[-1] @ mps.right_node) @ mps_rss.right_node

result = mats_results[0]
for mat in mats_results[1:]:
    result @= mat
    
    log_scale += result.norm().log()
    result = result.renormalize()

approx_mps_norm = (result.tensor.log() + log_scale) / 2
print(approx_mps_norm.item())

33.506103515625


In [23]:
(2*approx_mps_norm - mps_norm - mps_rss_norm).exp()

tensor(1.0000, device='cuda:0', grad_fn=<ExpBackward0>)

# TT-CROSS

In [24]:
results_dir = os.path.join('..', '..', 'results', '1_performance', 'random_mps')
os.makedirs(results_dir, exist_ok=True)

In [26]:
# Load cores
cores = torch.load(os.path.join(results_dir, 'cores.pt'),
                   weights_only=False)
mps = tk.models.MPS(tensors=cores)

mps.trace(torch.zeros(1, n_features, phys_dim, device=device))

In [27]:
domain = [torch.arange(2, device=device) for _ in range(n_features)]

def embedding(x): return tk.embeddings.basis(x.int(), dim=phys_dim).float()
def fun(x): return mps(embedding(x))

tt_cross, info = tn.cross(function=fun,
                          domain=domain,
                          device=device,
                          function_arg='matrix',
                          rmax=bond_dim,
                          # max_iter=5,
                          eps=1e-3,
                          verbose=True,
                          return_info=True)

cores_cross = tt_cross.cores
cores_cross[0] = cores_cross[0][0]
cores_cross[-1] = cores_cross[-1][..., 0]

mps.reset()

# Save cores
# torch.save(cores_cross,
#            os.path.join(results_dir, f'cores_cross_{info["total_time"]:.2f}.pt'))

cross device is cuda
Functions that require cross-approximation can be accelerated with the optional maxvolpy package, which can be installed by 'pip install maxvolpy'. More info is available at https://bitbucket.org/muxas/maxvolpy.
Cross-approximation over a 100D domain containing 1.26765e+30 grid points:
iter: 0  | eps: 1.000e+00 | time:  13.6642 | largest rank:   1
iter: 1  | eps: 1.000e+00 | time: 143.9003 | largest rank:   4
iter: 2  | eps: 1.106e+00 | time: 237.5115 | largest rank:   7
iter: 3  | eps: 3.988e-06 | time: 333.9764 | largest rank:  10 <- converged: eps < 0.001
Did 63114 function evaluations, which took 5.613s (1.124e+04 evals/s)



In [28]:
info

{'nsamples': 63114,
 'eval_time': 5.612637519836426,
 'val_epss': [tensor(1., device='cuda:0', grad_fn=<DivBackward0>),
  tensor(1.0000, device='cuda:0', grad_fn=<DivBackward0>),
  tensor(1.1060, device='cuda:0', grad_fn=<DivBackward0>),
  tensor(3.9881e-06, device='cuda:0', grad_fn=<DivBackward0>)],
 'min': 0,
 'argmin': None,
 'lsets': [array([[0]]),
  array([[0, 0],
         [0, 1]]),
  array([[0, 0, 0],
         [0, 0, 1],
         [0, 1, 0],
         [0, 1, 1]]),
  array([[0, 0, 0, 0],
         [0, 0, 0, 1],
         [0, 0, 1, 0],
         [0, 0, 1, 1],
         [0, 1, 0, 0],
         [0, 1, 0, 1],
         [0, 1, 1, 0],
         [0, 1, 1, 1]]),
  array([[0, 0, 1, 0, 1],
         [0, 0, 0, 1, 1],
         [0, 1, 1, 1, 1],
         [0, 1, 0, 0, 0],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0],
         [0, 0, 1, 1, 1],
         [0, 1, 0, 1, 1],
         [0, 1, 1, 0, 1],
         [0, 1, 1, 0, 0]]),
  array([[0, 1, 0, 1, 1, 1],
         [0, 1, 0, 0, 0, 1],
         [0, 1, 1, 0,

In [30]:
cores_cross = torch.load(os.path.join(results_dir, 'cores_cross_333.98.pt'),
                         weights_only=False)
mps_cross = tk.models.MPS(tensors=[c.to(device) for c in cores_cross])
mps_cross.canonicalize(renormalize=True)

In [31]:
mps_norm = mps.norm(log_scale=True)
mps_cross_norm = mps_cross.norm(log_scale=True)

print(f'MPS: {mps_norm.item()}')
print(f'MPS RSS: {mps_cross_norm.item()}')

mps.reset()
mps_cross.reset()

MPS: 33.506038665771484
MPS RSS: 33.50614547729492


In [32]:
for node1, node2 in zip(mps.mats_env, mps_cross.mats_env):
    node1['input'] ^ node2['input']

In [33]:
log_scale = 0

# Contract mps with mps_cross
stack = tk.stack(mps.mats_env)
stack_cross = tk.stack(mps_cross.mats_env)
stack ^ stack_cross

mats_results = tk.unbind(stack @ stack_cross)

mats_results[0] = mps.left_node @ (mps_cross.left_node @ mats_results[0])
mats_results[-1] = (mats_results[-1] @ mps.right_node) @ mps_cross.right_node

result = mats_results[0]
for mat in mats_results[1:]:
    result @= mat
    
    log_scale += result.norm().log()
    result = result.renormalize()

approx_mps_norm = (result.tensor.log() + log_scale) / 2
print(approx_mps_norm.item())

33.506099700927734


In [34]:
(2*approx_mps_norm - mps_norm - mps_rss_norm).exp()

tensor(1.0000, device='cuda:0', grad_fn=<ExpBackward0>)