In [1]:
import os

import torch
import tntorch as tn

import tensorkrowch as tk
from tensorkrowch.decompositions import tt_rss

from math import sqrt

# AKLT MPS

In [2]:
# Spin-1 AKLT site matrices A^s (2x2), one per physical state s in (+, 0, -).
amp_pm = sqrt(2 / 3)   # coefficient of the A^+ and A^- matrices
amp_0 = sqrt(1 / 3)    # coefficient of the A^0 matrix

A_plus = torch.tensor([[0, amp_pm], [0, 0]])
A_zero = torch.tensor([[-amp_0, 0], [0, amp_0]])
A_minus = torch.tensor([[0, 0], [-amp_pm, 0]])

# Stack along dim=1 so the core is indexed (left bond, physical, right bond): (2, 3, 2).
aklt_core = torch.stack((A_plus, A_zero, A_minus), dim=1)
aklt_core

tensor([[[ 0.0000,  0.8165],
         [-0.5774,  0.0000],
         [ 0.0000,  0.0000]],

        [[ 0.0000,  0.0000],
         [ 0.0000,  0.5774],
         [-0.8165,  0.0000]]])

In [3]:
# Squared Frobenius norm of the core. It equals the bond dimension (2) because
# sum_s A^s (A^s)^T = I for the AKLT matrices.
aklt_core.norm().square()

tensor(2.0000)

In [4]:
# Model hyperparameters
n_features = 100                 # number of sites in the chain
phys_dim = aklt_core.shape[1]    # physical dimension of each site (3 for spin-1)
bond_dim = aklt_core.shape[0]    # bond dimension of the exact AKLT MPS (2)

fn_name = 'aklt_mps'

# Seed torch's RNGs (CPU and CUDA). The sampling loop and the noisy fn_aklt both
# draw random numbers, so results are otherwise not reproducible across runs.
SEED = 0
torch.manual_seed(SEED)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
boundary_conditions = [0, 0]
# Every core is rescaled by sqrt(phys_dim). The unscaled core satisfies
# sum_s A^s (A^s)^T = I, so the factor makes the squared norm grow roughly like
# phys_dim**n_features — presumably so that the average squared amplitude over
# all basis configurations is O(1).
scale = sqrt(phys_dim)

aklt_cores = [aklt_core.to(device).mul(scale) for _site in range(n_features)]

# Open boundaries: select one index of the outermost bonds, leaving a
# (phys, bond) first core and a (bond, phys) last core.
left_bc, right_bc = boundary_conditions
aklt_cores[0] = aklt_cores[0][left_bc]
aklt_cores[-1] = aklt_cores[-1][..., right_bc]

In [6]:
# Initialize network: the exact AKLT MPS, used as the target function.
mps = tk.models.MPS(tensors=aklt_cores)

# In the run shown the plain norm prints nan (presumably float32 overflow of the
# squared norm, given a log-norm of ~54.6), so the log-scale norm is printed too.
print('norm:', mps.norm())
print('log-scale norm:', mps.norm(log_scale=True))

dummy_input = torch.zeros(1, n_features, phys_dim, device=device)
mps.trace(dummy_input)

# Freeze the parameters: this MPS is only evaluated, never trained.
for param in mps.parameters():
    param.requires_grad_(False)

norm: tensor(nan, device='cuda:0', grad_fn=<SqrtBackward0>)
log-scale norm: tensor(54.5840, device='cuda:0', grad_fn=<DivBackward0>)


In [7]:
def numel_mps(model):
    """Count the scalar entries held by the nodes in ``model.mats_env``.

    Args:
        model: MPS-like object whose ``mats_env`` is an iterable of nodes,
            each exposing a ``tensor`` attribute.

    Returns:
        int: total number of tensor elements over those nodes (0 if none).
    """
    return sum(node.tensor.numel() for node in model.mats_env)

n_params_mps = numel_mps(mps)
print('Nº params. mps:', n_params_mps)

Nº params. mps: 1200


# TT-RSS

In [23]:
# Initialize network: re-create the exact AKLT MPS so this section does not
# depend on state left behind by earlier cells.
mps = tk.models.MPS(tensors=aklt_cores)

print('norm:', mps.norm())
print('log-scale norm:', mps.norm(log_scale=True))

trace_batch = torch.zeros(1, n_features, phys_dim, device=device)
mps.trace(trace_batch)

# The MPS is a fixed target function: turn off gradient tracking.
for weight in mps.parameters():
    weight.requires_grad_(False)

norm: tensor(nan, device='cuda:0', grad_fn=<SqrtBackward0>)
log-scale norm: tensor(54.5840, device='cuda:0', grad_fn=<DivBackward0>)


In [8]:
# Output folder for the dataset and fitted cores. The last component is fn_name
# ('aklt_mps') rather than a duplicated literal, so the two cannot drift apart.
# The path is relative to the notebook's working directory.
results_dir = os.path.join('..', '..', 'results', '1_performance', fn_name)
os.makedirs(results_dir, exist_ok=True)

In [9]:
def embedding(x):
    """Embed site values ``x`` with ``tk.embeddings.basis`` over ``phys_dim``
    basis states and return the result as float32."""
    embedded = tk.embeddings.basis(x, dim=phys_dim)
    return embedded.float()

In [10]:
def fn_aklt(input):
    """Evaluate the AKLT MPS on the embedded ``input``.

    The MPS output gets a trailing singleton dimension appended.
    """
    embedded = embedding(input)
    return mps(embedded).unsqueeze(-1)

In [11]:
# Standard deviation of the additive Gaussian noise in fn_aklt. It is absolute,
# not relative to the MPS values.
std = 1e-10

def fn_aklt(input):
    """Evaluate the AKLT MPS on ``input`` and add i.i.d. Gaussian noise of scale ``std``.

    Returns the MPS output with a trailing singleton dimension. Because the noise
    is absolute, it is negligible next to large outputs. It does, however, turn
    outputs that are exactly zero into tiny nonzero values.
    """
    clean = mps(embedding(input)).unsqueeze(-1)
    return clean + std * torch.randn_like(clean)

In [31]:
# Sequential sampling from the AKLT MPS. Sites are drawn one at a time, each one
# conditioned on the prefix already drawn for the same sample.
samples_size = 200
# Start from a (samples_size, 0) int32 tensor so every step, including the
# first, can repeat and concatenate it uniformly.
samples = torch.empty(samples_size, 0, dtype=torch.int)

for i in range(n_features):
    # Sites 0..i become inputs. With marginalize_output=True the remaining sites
    # are marginalized (presumably yielding the conditional law of site i).
    mps.unset_data_nodes()
    mps.reset()
    mps.in_features = list(range(i + 1))

    # Candidate rows: each current prefix repeated phys_dim times, each copy
    # extended with one value of site i (0, 1, ..., phys_dim - 1 cyclically).
    prefixes = samples.repeat_interleave(phys_dim, dim=0)
    candidates = torch.arange(phys_dim).repeat(samples_size).view(-1, 1)
    aux_samples = torch.cat([prefixes, candidates], dim=1)

    density = mps(embedding(aux_samples.to(device)),
                  marginalize_output=True,
                  renormalize=True,
                  )

    if i == (n_features - 1):
        # Nothing is left to marginalize, so the output is a vector. Its
        # elementwise square equals diagonal(outer(density, density)), without
        # building the (rows x rows) outer product.
        weights = density * density
    else:
        weights = density.diagonal()

    # Per-sample categorical distribution over the phys_dim values of site i.
    distr = weights.reshape(samples_size, phys_dim)
    distr = distr / distr.sum(dim=-1, keepdim=True)
    distr = distr.cumsum(dim=-1)

    # Inverse-CDF sampling: the drawn value is the first index whose CDF exceeds
    # probs. Round-off can leave distr[:, -1] slightly below 1. A draw above it
    # would then give the invalid value phys_dim, so clamp to the last index.
    probs = torch.rand(samples_size, 1).to(device)
    new_samples = phys_dim - (probs < distr).sum(dim=-1)
    new_samples = new_samples.clamp(max=phys_dim - 1)

    samples = torch.cat([samples,
                         new_samples.cpu().int().reshape(-1, 1)], dim=1)

mps.reset()
mps.trace(torch.zeros(1, n_features, phys_dim, device=device))

# The TT-RSS cell loads the dataset from this path. Write it when it is missing
# so a fresh run works, but never overwrite an existing dataset.
dataset_path = os.path.join(results_dir, f'dataset_{n_features}.pt')
if not os.path.exists(dataset_path):
    torch.save(samples, dataset_path)

In [32]:
# First sampled configuration: one site value per feature.
samples[0, :]

tensor([0, 2, 1, 0, 1, 1, 2, 1, 1, 0, 2, 0, 1, 2, 1, 1, 0, 1, 2, 0, 1, 1, 2, 0,
        1, 1, 2, 0, 2, 1, 1, 0, 1, 2, 0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 0, 1, 2, 0,
        2, 1, 1, 0, 1, 2, 1, 0, 1, 1, 2, 0, 2, 1, 1, 1, 0, 1, 1, 1, 1, 2, 0, 1,
        2, 1, 1, 0, 1, 2, 1, 0, 1, 2, 0, 2, 0, 1, 1, 2, 1, 1, 1, 0, 2, 0, 1, 2,
        0, 1, 1, 2], dtype=torch.int32)

In [33]:
# Evaluate the target function on the first 20 sampled configurations.
first_configs = samples[:20].to(device)
fn_aklt(first_configs)

tensor([[ 1.3422e+08],
        [ 8.3886e+06],
        [ 2.0971e+06],
        [ 2.6843e+08],
        [ 2.0971e+06],
        [ 4.1943e+06],
        [ 6.7109e+07],
        [ 1.0486e+06],
        [ 3.3554e+07],
        [-2.6843e+08],
        [ 8.3886e+06],
        [-3.3554e+07],
        [-8.3886e+06],
        [-8.3886e+06],
        [-1.3422e+08],
        [-1.3422e+08],
        [ 1.3422e+08],
        [ 1.3422e+08],
        [ 2.1475e+09],
        [-3.3554e+07]], device='cuda:0', grad_fn=<UnsqueezeBackward0>)

In [35]:
sketch_size = 50

domain = [torch.arange(phys_dim) for _ in range(n_features)]

# Load the dataset. It is assumed to be the `samples` tensor saved by the
# sampling cell, i.e. a plain tensor, so the safe weights-only unpickler is
# enough. weights_only=False would allow arbitrary code execution from a
# tampered file.
dataset = torch.load(os.path.join(results_dir, f'dataset_{n_features}.pt'),
                     weights_only=True)
sketch_samples = dataset[:sketch_size]

cores_rss, info = tt_rss(function=fn_aklt,
                         embedding=embedding,
                         sketch_samples=sketch_samples,
                         domain=domain,
                         rank=bond_dim,
                         cum_percentage=1 - 1e-5,
                         batch_size=500,
                         device=device,
                         verbose=False,
                         return_info=True)

mps.reset()

# Save cores
# torch.save(cores_rss,
#            os.path.join(results_dir, f'cores_rss_{info["total_time"]:.2f}.pt'))

In [36]:
# Diagnostics returned by tt_rss (return_info=True): total run time and validation error.
info

{'total_time': 16.279383659362793,
 'val_eps': tensor(2.5839e-06, device='cuda:0')}

In [12]:
# Reload TT-RSS cores saved from an earlier run (the filename embeds that run's
# total_time). The file holds a list of tensors, so the safe weights-only
# unpickler suffices.
cores_rss = torch.load(os.path.join(results_dir, 'cores_rss_16.28.pt'),
                       weights_only=True)
mps_rss = tk.models.MPS(tensors=[c.to(device) for c in cores_rss])
mps_rss.canonicalize(renormalize=True)

In [13]:
# Log-scale norms of the exact MPS and of its TT-RSS reconstruction.
mps_norm, mps_rss_norm = (net.norm(log_scale=True) for net in (mps, mps_rss))

print(f'MPS: {mps_norm.item()}')
print(f'MPS RSS: {mps_rss_norm.item()}')

for net in (mps, mps_rss):
    net.reset()

MPS: 54.584007263183594
MPS RSS: 54.58405685424805


In [14]:
# Join each site's 'input' edge of the exact MPS to the matching edge of the
# TT-RSS MPS, so the two networks can be contracted site by site.
for ref_node, rss_node in zip(mps.mats_env, mps_rss.mats_env):
    ref_node['input'] ^ rss_node['input']

In [15]:
# Half of log |<mps|mps_rss>|, computed with renormalization after every
# matrix product so the contraction never overflows. log_scale accumulates
# the logs of the removed norms.
log_scale = 0

# Contract mps with mps_rss
stack = tk.stack(mps.mats_env)
stack_rss = tk.stack(mps_rss.mats_env)
stack ^ stack_rss

mats_results = tk.unbind(stack @ stack_rss)

mats_results[0] = mps.left_node @ (mps_rss.left_node @ mats_results[0])
mats_results[-1] = (mats_results[-1] @ mps.right_node) @ mps_rss.right_node

result = mats_results[0]
for mat in mats_results[1:]:
    result @= mat
    
    log_scale += result.norm().log()
    result = result.renormalize()

# The overlap's sign is arbitrary (a reconstruction may equal -psi). Take the
# absolute value before the log, since the log of a negative value is nan.
approx_mps_norm = (result.tensor.abs().log() + log_scale) / 2
print(approx_mps_norm.item())

54.58402633666992


In [16]:
# Fidelity |<mps|mps_rss>| / (||mps|| * ||mps_rss||), assembled in log space.
log_fidelity = 2*approx_mps_norm - mps_norm - mps_rss_norm
log_fidelity.exp()

tensor(1.0000, device='cuda:0', grad_fn=<ExpBackward0>)

# TT-CROSS

In [24]:
# Initialize network: re-create the exact AKLT MPS. The TT-RSS comparison
# connected the previous network's input edges to mps_rss.
mps = tk.models.MPS(tensors=aklt_cores)

print('norm:', mps.norm())
print('log-scale norm:', mps.norm(log_scale=True))

zeros_batch = torch.zeros(1, n_features, phys_dim, device=device)
mps.trace(zeros_batch)

# Used purely as a fixed target function: no gradients needed.
for tensor_param in mps.parameters():
    tensor_param.requires_grad_(False)

norm: tensor(nan, device='cuda:0', grad_fn=<SqrtBackward0>)
log-scale norm: tensor(54.5840, device='cuda:0', grad_fn=<DivBackward0>)


In [17]:
# Output folder for the fitted cores. The last component is fn_name
# ('aklt_mps') rather than a duplicated literal, so the two cannot drift apart.
# The path is relative to the notebook's working directory.
results_dir = os.path.join('..', '..', 'results', '1_performance', fn_name)
os.makedirs(results_dir, exist_ok=True)

In [18]:
def embedding(x):
    """Cast ``x`` to int32, embed it with ``tk.embeddings.basis`` over
    ``phys_dim`` basis states, and return float32.

    NOTE(review): the int cast presumably guards against tn.cross passing index
    values in a non-integer dtype — TODO confirm.
    """
    as_int = x.int()
    return tk.embeddings.basis(as_int, dim=phys_dim).float()

In [19]:
def fn_aklt(input):
    """Noise-free evaluation of the AKLT MPS on ``input``.

    A trailing singleton dimension is appended to the MPS output.
    """
    values = mps(embedding(input))
    return values.unsqueeze(-1)

In [68]:
domain = [torch.arange(phys_dim, device=device) for _ in range(n_features)]

# rmax is tied to the exact AKLT bond dimension, as `rank` is for TT-RSS.
# NOTE(review): in the run shown, eps was nan at every iteration.
tt_cross, info = tn.cross(function=fn_aklt,
                          domain=domain,
                          device=device,
                          function_arg='matrix',
                          rmax=bond_dim,
                          max_iter=5,
                          eps=1e-3,
                          verbose=True,
                          return_info=True)

# Drop the rank-1 outer axis of the first and last tntorch cores to get
# TensorKrowch-style boundary cores. Copy the list first: assigning into
# tt_cross.cores itself would also replace the cores stored in tt_cross,
# leaving it with invalid 2-D boundary cores.
cores_cross = list(tt_cross.cores)
cores_cross[0] = cores_cross[0][0]
cores_cross[-1] = cores_cross[-1][..., 0]

mps.reset()

# Save cores
# torch.save(cores_cross,
#            os.path.join(results_dir, f'cores_cross_{info["total_time"]:.2f}.pt'))

cross device is cuda
Functions that require cross-approximation can be accelerated with the optional maxvolpy package, which can be installed by 'pip install maxvolpy'. More info is available at https://bitbucket.org/muxas/maxvolpy.
Cross-approximation over a 100D domain containing 5.15378e+47 grid points:
iter: 0 | eps: nan | time:  11.0036 | largest rank:   1
iter: 1 | eps: nan | time: 116.3506 | largest rank:   2
iter: 2 | eps: nan | time: 234.8218 | largest rank:   2
iter: 3 | eps: nan | time: 325.9593 | largest rank:   2
iter: 4 | eps: nan | time: 582.8404 | largest rank:   2 <- max_iter was reached: 5
Did 10077 function evaluations, which took 5.039s (2000 evals/s)



In [69]:
# Diagnostics returned by tn.cross (return_info=True). In the run shown,
# val_epss is nan at every iteration, suggesting the cross approximation failed.
info

{'nsamples': 10077,
 'eval_time': 5.038981199264526,
 'val_epss': [tensor(nan, device='cuda:0'),
  tensor(nan, device='cuda:0'),
  tensor(nan, device='cuda:0'),
  tensor(nan, device='cuda:0'),
  tensor(nan, device='cuda:0')],
 'min': 0,
 'argmin': None,
 'lsets': [array([[0]]),
  array([[0, 0],
         [0, 1]]),
  array([[0, 0, 0],
         [0, 0, 1]]),
  array([[0, 0, 0, 0],
         [0, 0, 0, 1]]),
  array([[0, 0, 0, 0, 0],
         [0, 0, 0, 0, 1]]),
  array([[0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 1]]),
  array([[0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 1]]),
  array([[0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1]]),
  array([[0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 1]]),
  array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]]),
  array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]]),
  array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1

In [25]:
# Reload TT-CROSS cores saved from an earlier run (the filename embeds that
# run's time). The file holds a list of tensors, so the safe weights-only
# unpickler suffices.
cores_cross = torch.load(os.path.join(results_dir, 'cores_cross_582.84.pt'),
                         weights_only=True)
mps_cross = tk.models.MPS(tensors=[c.to(device) for c in cores_cross])
mps_cross.canonicalize(renormalize=True)

In [26]:
# Log-scale norms of the exact MPS and of its TT-CROSS reconstruction.
# A log-norm of -inf means the reconstructed MPS is the zero vector.
mps_norm = mps.norm(log_scale=True)
mps_cross_norm = mps_cross.norm(log_scale=True)

print(f'MPS: {mps_norm.item()}')
print(f'MPS CROSS: {mps_cross_norm.item()}')

mps.reset()
mps_cross.reset()

MPS: 54.584007263183594
MPS RSS: -inf


In [27]:
# Join each site's 'input' edge of the exact MPS to the matching edge of the
# TT-CROSS MPS, so the two networks can be contracted site by site.
for exact_node, cross_node in zip(mps.mats_env, mps_cross.mats_env):
    exact_node['input'] ^ cross_node['input']

In [28]:
# Half of log |<mps|mps_cross>|, computed with renormalization after every
# matrix product so the contraction never overflows. log_scale accumulates
# the logs of the removed norms.
log_scale = 0

# Contract mps with mps_cross
stack = tk.stack(mps.mats_env)
stack_cross = tk.stack(mps_cross.mats_env)
stack ^ stack_cross

mats_results = tk.unbind(stack @ stack_cross)

mats_results[0] = mps.left_node @ (mps_cross.left_node @ mats_results[0])
mats_results[-1] = (mats_results[-1] @ mps.right_node) @ mps_cross.right_node

result = mats_results[0]
for mat in mats_results[1:]:
    result @= mat
    
    log_scale += result.norm().log()
    result = result.renormalize()

# The overlap's sign is arbitrary (a reconstruction may equal -psi). Take the
# absolute value before the log, since the log of a negative value is nan.
# A zero reconstruction still gives -inf here.
approx_mps_norm = (result.tensor.abs().log() + log_scale) / 2
print(approx_mps_norm.item())

-inf


In [29]:
# Fidelity |<mps|mps_cross>| / (||mps|| * ||mps_cross||), assembled in log space.
# This must use the TT-CROSS norm, not the TT-RSS one.
log_fidelity = 2*approx_mps_norm - mps_norm - mps_cross_norm
# A zero MPS has log-norm -inf, which turns the expression above into
# -inf + inf = nan. Its overlap with any state is 0, so report fidelity 0.
if torch.isneginf(mps_cross_norm):
    fidelity = torch.zeros_like(log_fidelity)
else:
    fidelity = log_fidelity.exp()
fidelity

tensor(0., device='cuda:0', grad_fn=<ExpBackward0>)