In [1]:
import os
from math import log, sqrt

import torch

import tensorkrowch as tk
from tensorkrowch.decompositions import tt_rss

In [2]:
A_plus = torch.tensor([[0, sqrt(2 / 3)],
                       [0, 0]])
A_zero = torch.tensor([[-sqrt(1 / 3), 0],
                       [0, sqrt(1 / 3)]])
A_minus = torch.tensor([[0, 0],
                        [-sqrt(2 / 3), 0]])

aklt_core = torch.stack([A_plus, A_zero, A_minus], dim=1)
aklt_core

tensor([[[ 0.0000,  0.8165],
         [-0.5774,  0.0000],
         [ 0.0000,  0.0000]],

        [[ 0.0000,  0.0000],
         [ 0.0000,  0.5774],
         [-0.8165,  0.0000]]])

In [3]:
aklt_core.norm() ** 2

tensor(2.0000)

In [4]:
# Model hyperparameters
n_features = 100
phys_dim = aklt_core.shape[1]
bond_dim = aklt_core.shape[0]

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = 'cpu'

In [5]:
boundary_conditions = [0, 0]
scale = sqrt(phys_dim)

aklt_cores = [aklt_core.to(device) * scale for _ in range(n_features)]
# aklt_cores = [aklt_core.to(device) for _ in range(n_features)]
aklt_cores[0] = aklt_cores[0][boundary_conditions[0], :, :]
aklt_cores[-1] = aklt_cores[-1][:, :, boundary_conditions[1]]

In [6]:
def embedding(x):
    """One-hot encode inputs after discretizing them to one base-``phys_dim`` digit.

    NOTE(review): assumes inputs lie in [0, 1], as produced by ``torch.rand``
    or by integer levels divided by ``phys_dim`` elsewhere in this notebook.

    Returns a float tensor laid out as batch x n_features x phys_dim.
    """
    digits = tk.embeddings.discretize(x, base=phys_dim, level=1)
    digits = digits.squeeze(-1).int()
    one_hot = tk.embeddings.basis(digits, dim=phys_dim)
    return one_hot.float()  # batch x n_features x dim

In [7]:
# Initialize network
mps = tk.models.MPS(tensors=aklt_cores)

# Disable parameters
mps.parameterize(set_param=False, override=True)

print('norm:', mps.norm())
print('log-scale norm:', mps.norm(log_scale=True))

mps.trace(torch.zeros(1, n_features, phys_dim, device=device))

norm: tensor(nan, device='cuda:0')
log-scale norm: tensor(54.5840, device='cuda:0')


In [8]:
def numel_mps(model):
    """Return the total number of scalar entries held by ``model.mats_env``.

    Only the nodes listed in ``mats_env`` are counted. Other nodes of the
    network (e.g. ``left_node`` / ``right_node``) are not included.
    """
    return sum(node.tensor.numel() for node in model.mats_env)

print('Nº params. mps:', numel_mps(mps))

Nº params. mps: 1200


In [9]:
fn_name = 'aklt_mps'

def fn_aklt(input):
    """Evaluate the exact AKLT MPS on raw inputs.

    Inputs are passed through ``embedding`` and the MPS output is reshaped
    into a single column, i.e. ``view(-1, 1)``.
    """
    embedded = embedding(input)
    return mps(embedded).view(-1, 1)

In [167]:
fn_name = 'aklt_mps_noisy'
std = 1e-5

def fn_aklt(input):
    """Evaluate the AKLT MPS and add i.i.d. Gaussian noise.

    The noise scale is the global ``std``, which is read at call time. The
    result is a single column, ``view(-1, 1)``. This definition shadows the
    noiseless ``fn_aklt`` defined earlier.
    """
    clean = mps(embedding(input)).view(-1, 1)
    noise = torch.randn_like(clean) * std
    return clean + noise

## Random dataset

In [164]:
# dataset samples
samples_size = 100
samples = torch.rand(size=(samples_size, n_features))

In [165]:
fn_aklt(samples.to(device)).flatten()

tensor([ 1.1116e-05, -1.6783e-07, -1.0755e-05, -1.0175e-05, -6.0110e-06,
         7.8913e-06, -1.0995e-05,  6.4625e-06,  2.2697e-07, -6.8937e-06,
        -1.0534e-05, -8.8278e-06,  1.4955e-05, -1.0746e-05, -2.1080e-05,
        -1.5788e-05, -8.7712e-06,  7.5415e-06, -6.0559e-06,  1.5122e-05,
        -2.5964e-06, -8.1284e-06,  4.7058e-06, -1.4241e-05,  1.9640e-06,
         7.7230e-06, -8.8588e-06,  7.4468e-06,  1.3166e-05, -1.3281e-06,
         9.4061e-07, -8.9485e-06, -1.3359e-05, -7.8703e-06,  1.3613e-05,
         4.7160e-06,  6.4751e-06, -6.5520e-06,  1.6969e-05,  1.0420e-05,
        -7.6493e-06, -8.4733e-06,  1.8233e-05,  1.6111e-05, -1.2984e-05,
        -8.3915e-06, -9.2881e-06,  1.2876e-05,  1.1645e-05,  1.0034e-05,
        -1.8640e-05, -1.3328e-06, -5.0122e-06,  1.6861e-06,  1.2176e-05,
         3.8775e-06,  1.7179e-05, -7.4841e-06, -6.2458e-06,  3.0296e-06,
         9.8133e-06,  1.0920e-06,  1.1016e-05, -6.4182e-06, -9.9560e-06,
         1.0950e-06,  1.8989e-07, -6.4316e-06,  1.9

In [166]:
cores = tt_rss(function=fn_aklt,
               embedding=embedding,
               sketch_samples=samples,
               rank=5*bond_dim,
               cum_percentage=0.999,
               batch_size=500,
               device=device,
               verbose=True)

# Save cores
# torch.save(cores, f'cores/{fn_name}.pt')



|| Site: 1 / 100 ||
* Max D_k: min(3, 10)
* T_k out dim: 100

Core 1:
-------
tensor([[1.],
        [0.],
        [0.]])
* Final D_k: 1
* S_k out dim: 100


|| Site: 2 / 100 ||
* Max D_k: min(3, 10)
* T_k out dim: 100


KeyboardInterrupt: 

In [24]:
# Load cores
# cores = torch.load(f'cores/{fn_name}.pt')

In [61]:
# Initialize tensorized model
mps2 = tk.models.MPS(tensors=[c.to(device) for c in cores])

mps2.trace(torch.zeros(1, n_features, phys_dim, device=device))

# Disable parameters
for p in mps2.parameters():
    p.requires_grad_(False)

In [62]:
torch.tensor(mps.bond_dim)

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2])

In [63]:
torch.tensor(mps2.bond_dim)

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1])

In [64]:
mps.reset()
mps.norm(), mps.norm(log_scale=True)

(tensor(nan, device='cuda:0'), tensor(54.5840, device='cuda:0'))

In [65]:
mps2.reset()
mps2.norm(), mps2.norm(log_scale=True)

(tensor(0., device='cuda:0', grad_fn=<SqrtBackward0>),
 tensor(-inf, device='cuda:0', grad_fn=<DivBackward0>))

In [66]:
# dataset samples
samples_size = 100
samples = torch.rand(size=(samples_size, n_features))

In [67]:
mps.reset()
mps2.reset()

with torch.no_grad():
    mps_results = mps(embedding(samples.to(device)))
    mps2_results = mps2(embedding(samples.to(device)))

In [68]:
results = torch.stack([mps_results, mps2_results], dim=1)
(results[:, 0] - results[:, 1]).pow(2).sum().sqrt()

tensor(0., device='cuda:0')

In [69]:
results[:20]

tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]], device='cuda:0')

## Generated dataset

In [10]:
def aux_embedding(x):
    """One-hot encode integer level indices directly (no discretization step)."""
    return tk.embeddings.basis(x, dim=phys_dim).float()

In [12]:
# Draw samples from the MPS one site at a time. For site i, every candidate
# value of feature i is appended to each existing prefix and the MPS is
# evaluated. The result becomes a per-sample categorical distribution, and the
# next value is drawn by inverse-CDF sampling.
samples_size = 500
# Empty 1-D tensor: torch.cat skips empty (0,)-shaped tensors, so the first
# iteration effectively starts from no prefix.
samples = torch.tensor([]).int()

for i in range(n_features):
    # Feed data only to the first i + 1 sites; the remaining ones are
    # marginalized (marginalize_output=True below).
    mps.unset_data_nodes()
    mps.reset()
    mps.in_features = torch.arange(i + 1).tolist()
    
    # Candidate values [0, 1, ..., phys_dim-1] tiled once per sample:
    # shape (samples_size * phys_dim, 1).
    new_feature = torch.arange(phys_dim).view(-1, 1)
    new_feature = new_feature.repeat(samples_size, 1)
    
    if i > 0:
        # Repeat each prefix phys_dim times consecutively, so that row
        # s * phys_dim + k pairs prefix s with candidate value k.
        aux_samples = samples.repeat(1, phys_dim)
        aux_samples = aux_samples.reshape(samples_size * phys_dim, i)
    else:
        aux_samples = samples
    
    aux_samples = torch.cat([aux_samples, new_feature], dim=1)
    
    # NOTE(review): with marginalize_output=True the output appears to be a
    # batch x batch matrix whose diagonal holds the (unnormalized) marginal
    # probability of each candidate prefix. Confirm against the tensorkrowch docs.
    density = mps(aux_embedding(aux_samples.to(device)),
                  marginalize_output=True,
                  renormalize=True,
                  )
    
    # At the last site nothing is left to marginalize, so the output is
    # presumably a vector of amplitudes. Its outer product puts the squared
    # amplitudes on the diagonal.
    if i == (n_features - 1):
        density = torch.outer(density, density)
    
    # Per-sample distribution over the phys_dim candidates, then its CDF.
    distr = density.diagonal().reshape(samples_size, phys_dim)
    distr = distr / distr.sum(dim=-1, keepdim=True)
    distr = distr.cumsum(dim=-1)
    
    # Inverse-CDF draw: (u < cdf).sum() counts the CDF entries above u, so
    # phys_dim minus that count is the first index k with u < cdf[k].
    # Round-off can leave cdf[-1] slightly below 1; if u lands in that gap the
    # count is 0 and the index would be phys_dim (out of range), so clamp it.
    probs = torch.rand(samples_size, 1).to(device)
    new_samples = (phys_dim - (probs < distr).sum(dim=-1)).clamp(max=phys_dim - 1)
    
    if i > 0:
        samples = torch.cat([samples,
                             new_samples.cpu().int().reshape(-1, 1)], dim=1)
    else:
        samples = new_samples.cpu().int().reshape(-1, 1)

# Map integer levels {0, ..., phys_dim-1} to {0, 1/phys_dim, ...}, the domain
# that `embedding` discretizes back into levels.
samples = samples / phys_dim

# Restore the network with all features as inputs for later evaluation.
mps.reset()
mps.trace(torch.zeros(1, n_features, phys_dim, device=device))

KeyboardInterrupt: 

In [220]:
samples[0]

tensor([0.3333, 0.3333, 0.0000, 0.6667, 0.3333, 0.0000, 0.3333, 0.3333, 0.3333,
        0.6667, 0.0000, 0.3333, 0.3333, 0.6667, 0.0000, 0.3333, 0.6667, 0.0000,
        0.3333, 0.3333, 0.3333, 0.6667, 0.3333, 0.3333, 0.0000, 0.3333, 0.3333,
        0.3333, 0.3333, 0.6667, 0.0000, 0.6667, 0.3333, 0.3333, 0.3333, 0.3333,
        0.3333, 0.3333, 0.3333, 0.0000, 0.3333, 0.6667, 0.3333, 0.0000, 0.3333,
        0.3333, 0.3333, 0.3333, 0.6667, 0.3333, 0.0000, 0.6667, 0.3333, 0.3333,
        0.0000, 0.3333, 0.3333, 0.3333, 0.3333, 0.6667, 0.0000, 0.3333, 0.3333,
        0.3333, 0.3333, 0.3333, 0.6667, 0.0000, 0.3333, 0.3333, 0.6667, 0.3333,
        0.0000, 0.3333, 0.6667, 0.3333, 0.0000, 0.3333, 0.6667, 0.0000, 0.6667,
        0.3333, 0.0000, 0.6667, 0.0000, 0.3333, 0.3333, 0.3333, 0.3333, 0.6667,
        0.3333, 0.3333, 0.3333, 0.3333, 0.0000, 0.6667, 0.0000, 0.6667, 0.0000,
        0.6667])

In [221]:
fn_aklt(samples[:20].to(device))

tensor([[ 2.0971e+06],
        [ 3.3554e+07],
        [-2.0971e+06],
        [ 1.6777e+07],
        [-1.3422e+08],
        [-6.7109e+07],
        [ 5.2429e+05],
        [ 3.3554e+07],
        [ 6.7109e+07],
        [-2.0971e+06],
        [ 1.0737e+09],
        [ 6.7109e+07],
        [ 8.3886e+06],
        [-4.1943e+06],
        [ 6.7109e+07],
        [ 6.5536e+04],
        [-4.1943e+06],
        [-1.6777e+07],
        [-1.6777e+07],
        [-3.3554e+07]], device='cuda:0')

In [172]:
cores = tt_rss(function=fn_aklt,
               embedding=embedding,
               sketch_samples=samples[:(samples_size // 2)],
               domain=torch.arange(phys_dim).float() / phys_dim,
               rank=5*bond_dim,
               cum_percentage=0.9,
               batch_size=500,
               device=device,
               verbose=True)

# Save cores
cwd = os.getcwd()
results_dir = os.path.join(cwd, '..', '..', 'results', '4_interpretability')

# torch.save(cores, os.path.join(results_dir, f'{fn_name}.pt'))



|| Site: 1 / 100 ||
* Max D_k: min(3, 10)
* T_k out dim: 100

Core 1:
-------
tensor([[-1.0000e+00,  7.0109e-10],
        [-7.0109e-10, -1.0000e+00],
        [ 1.3795e-15,  6.1224e-16]])
* Final D_k: 2
* S_k out dim: 2


|| Site: 2 / 100 ||
* Max D_k: min(6, 10)
* T_k out dim: 100

Core 2:
-------
tensor([[[ 4.9574e-10,  2.2346e-15],
         [-7.0711e-01, -2.6964e-07],
         [ 3.3391e-07, -8.9443e-01]],

        [[ 7.0711e-01,  2.9137e-07],
         [-1.7139e-07,  4.4721e-01],
         [-4.7056e-16,  6.2707e-10]]])
* Final D_k: 2
* S_k out dim: 4


|| Site: 3 / 100 ||
* Max D_k: min(6, 10)
* T_k out dim: 100

Core 3:
-------
tensor([[[ 3.3545e-07],
         [ 5.3452e-01],
         [-6.3371e-10]],

        [[-8.4515e-01],
         [ 2.0175e-07],
         [ 1.0158e-16]]])
* Final D_k: 1
* S_k out dim: 8


|| Site: 4 / 100 ||
* Max D_k: min(3, 10)
* T_k out dim: 100

Core 4:
-------
tensor([[[-4.2103e-16, -9.1160e-10],
         [-1.8027e-07, -5.9161e-01],
         [ 8.2639e-01, -2.5

In [222]:
# Load cores
cwd = os.getcwd()
results_dir = os.path.join(cwd, '..', '..', 'results', '4_interpretability')

cores = torch.load(os.path.join(results_dir, f'{fn_name}.pt'),
                   weights_only=False)

In [223]:
# Initialize tensorized model
mps2 = tk.models.MPS(tensors=[c.to(device) for c in cores])

# Disable parameters
mps2.parameterize(set_param=False, override=True)

mps2.trace(torch.zeros(1, n_features, phys_dim, device=device))

In [224]:
torch.tensor(mps.bond_dim)

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2])

In [225]:
torch.tensor(mps2.bond_dim)

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2])

In [226]:
mps.reset()
mps.norm(), mps.norm(log_scale=True)

(tensor(nan, device='cuda:0'), tensor(54.5840, device='cuda:0'))

In [227]:
mps2.reset()
mps2.norm(), mps2.norm(log_scale=True)

(tensor(inf, device='cuda:0'), tensor(54.5840, device='cuda:0'))

In [228]:
mps.reset()
mps2.reset()

with torch.no_grad():
    mps_results = mps(embedding(samples.to(device)))
    mps2_results = mps2(embedding(samples.to(device)))

In [229]:
results = torch.stack([mps_results, mps2_results], dim=1)
((results[:, 0] - results[:, 1]) / results[:, 0]).pow(2).sum().sqrt()

tensor(2.3132e-05, device='cuda:0')

In [230]:
results[:20]

tensor([[ 2.0971e+06,  2.0971e+06],
        [ 3.3554e+07,  3.3554e+07],
        [-2.0971e+06, -2.0971e+06],
        [ 1.6777e+07,  1.6777e+07],
        [-1.3422e+08, -1.3422e+08],
        [-6.7109e+07, -6.7109e+07],
        [ 5.2429e+05,  5.2429e+05],
        [ 3.3554e+07,  3.3554e+07],
        [ 6.7109e+07,  6.7109e+07],
        [-2.0971e+06, -2.0971e+06],
        [ 1.0737e+09,  1.0737e+09],
        [ 6.7109e+07,  6.7109e+07],
        [ 8.3886e+06,  8.3886e+06],
        [-4.1943e+06, -4.1943e+06],
        [ 6.7109e+07,  6.7109e+07],
        [ 6.5536e+04,  6.5536e+04],
        [-4.1943e+06, -4.1943e+06],
        [-1.6777e+07, -1.6777e+07],
        [-1.6777e+07, -1.6777e+07],
        [-3.3554e+07, -3.3554e+07]], device='cuda:0')

In [231]:
((results[:, 0] - results[:, 1]).abs() / results[:, 0]).max()

tensor(5.2452e-06, device='cuda:0')

In [232]:
# dataset samples
samples_size = 20
samples = torch.rand(size=(samples_size, n_features))

mps.reset()
mps2.reset()

with torch.no_grad():
    mps_results = mps(embedding(samples.to(device)))
    mps2_results = mps2(embedding(samples.to(device)))

In [233]:
results = torch.stack([mps_results, mps2_results], dim=1)
(results[:, 0] - results[:, 1]).pow(2).sum().sqrt()

tensor(0., device='cuda:0')

In [234]:
results

tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]], device='cuda:0')

In [235]:
(results[:, 0] - results[:, 1]).max()

tensor(0., device='cuda:0')

### Fidelity

In [236]:
# Compute each model's own log-norm on a clean network (no data nodes), then
# reset both so their physical edges can be wired together for the overlap.
mps.reset()
mps.unset_data_nodes()
mps2.reset()
mps2.unset_data_nodes()

mps_norm = mps.norm(log_scale=True)
mps2_norm = mps2.norm(log_scale=True)  # log-norm of mps2 itself, not of mps

mps.reset()
mps2.reset()

# Connect the 'input' (physical) edges site by site, so that contracting the
# joint network gives the overlap between the two MPS.
for node1, node2 in zip(mps.mats_env, mps2.mats_env):
    node1['input'] ^ node2['input']

In [237]:
# Log of the overlap between mps and mps2, computed site by site. The running
# product is renormalized after each step so it cannot overflow or underflow;
# the removed scale is accumulated in log_scale.
log_scale = 0

# Contract mps with mps_rss
# NOTE(review): stacking lets all per-site contractions run as one batched op.
# `^` presumably connects the stacked 'input' edges. Confirm.
stack = tk.stack(mps.mats_env)
stack_rss = tk.stack(mps2.mats_env)
stack ^ stack_rss

# One transfer matrix per site, from contracting each pair of matching nodes.
mats_results = tk.unbind(stack @ stack_rss)

# Absorb the boundary nodes of both networks into the first and last
# transfer matrices.
mats_results[0] = mps.left_node @ (mps2.left_node @ mats_results[0])
mats_results[-1] = (mats_results[-1] @ mps.right_node) @ mps2.right_node

# Multiply the transfer matrices left to right. After each product, add the
# log of its norm to log_scale and renormalize the running result.
result = mats_results[0]
for mat in mats_results[1:]:
    result @= mat
    
    log_scale += result.norm().log()
    result = result.renormalize()

# Halve the log-overlap so it is on the same scale as the log-norms. Note that
# a negative overlap would make .log() return NaN.
approx_mps_norm = (result.tensor.log() + log_scale) / 2
print(approx_mps_norm.item())

54.58403015136719


In [238]:
(2*approx_mps_norm - mps_norm - mps2_norm).exp()

tensor(1.0000, device='cuda:0')

### Check if tensors are equal

In [34]:
mps.canonicalize(oc=0)
mps.canonicalize(oc=n_features - 1)

In [35]:
mps2.canonicalize(oc=0)
mps2.canonicalize(oc=n_features - 1)

In [36]:
for node, node2 in zip(mps.mats_env, mps2.mats_env):
    print(node.tensor, node2.tensor)
    print()

Parameter containing:
tensor([[[-1.0000e+00,  1.6182e-07],
         [-9.3363e-08, -1.0000e+00],
         [ 0.0000e+00,  0.0000e+00]],

        [[ 0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00]]], device='cuda:0', requires_grad=True) Parameter containing:
tensor([[[ 1.0000e+00,  4.7185e-09],
         [-1.0485e-07,  1.0000e+00],
         [ 0.0000e+00,  0.0000e+00]],

        [[ 0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00]]], device='cuda:0', requires_grad=True)

Parameter containing:
tensor([[[ 1.2621e-09,  1.2176e-07],
         [-9.1366e-07, -7.0711e-01],
         [ 8.9443e-01, -7.3831e-07]],

        [[ 3.6841e-07,  7.0711e-01],
         [-4.4721e-01,  5.5046e-07],
         [-9.6147e-09,  6.4437e-08]]], device='cuda:0', requires_grad=True) Parameter containing:
tensor([[[ 1.2206e-09, -2.4723e-07],
         [-2.0721e-06, -7.0711e-01],
         [ 8.9443e-01, -2.3528e-06]],

        [

In [37]:
for node, node2 in zip(mps.mats_env[(n_features//2 - 5):-(n_features//2 - 5)],
                       mps2.mats_env[(n_features//2 - 5):-(n_features//2 - 5)]):
    print(node.tensor, node2.tensor)
    print()

Parameter containing:
tensor([[[-0.4149, -0.4158],
         [ 0.5773, -0.0095],
         [-0.4015,  0.4006]],

        [[ 0.4006,  0.4015],
         [ 0.0095,  0.5773],
         [-0.4158,  0.4149]]], device='cuda:0', requires_grad=True) Parameter containing:
tensor([[[ 0.0537, -0.3016],
         [ 0.3071, -0.4889],
         [ 0.7451,  0.1327]],

        [[-0.1327,  0.7451],
         [-0.4889, -0.3071],
         [ 0.3016,  0.0537]]], device='cuda:0', requires_grad=True)

Parameter containing:
tensor([[[ 0.4102, -0.4071],
         [ 0.5773,  0.0028],
         [ 0.4063,  0.4094]],

        [[-0.4094,  0.4063],
         [ 0.0028, -0.5773],
         [ 0.4071,  0.4102]]], device='cuda:0', requires_grad=True) Parameter containing:
tensor([[[-0.7704,  0.2296],
         [ 0.2594,  0.5158],
         [-0.0409, -0.1372]],

        [[-0.1372,  0.0409],
         [-0.5158,  0.2594],
         [ 0.2296,  0.7704]]], device='cuda:0', requires_grad=True)

Parameter containing:
tensor([[[-0.4039,  0.4095],

In [38]:
m1 = mps.mats_env[n_features//2]
m2 = mps2.mats_env[n_features//2]

m1.tensor, m2.tensor

(Parameter containing:
 tensor([[[ 0.5803,  0.0069],
          [ 0.4010,  0.4156],
          [-0.0071,  0.5776]],
 
         [[ 0.5743,  0.0069],
          [-0.4153,  0.4008],
          [ 0.0071, -0.5770]]], device='cuda:0', requires_grad=True),
 Parameter containing:
 tensor([[[-0.3812,  0.3945],
          [ 0.5761, -0.0379],
          [-0.4349, -0.4202]],
 
         [[-0.4202,  0.4349],
          [ 0.0379,  0.5761],
          [ 0.3945,  0.3812]]], device='cuda:0', requires_grad=True))

In [39]:
m1.norm(), m2.norm()

(tensor(1.4142, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>),
 tensor(1.4142, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>))

In [40]:
m1.norm() * m2.norm()

tensor(2., device='cuda:0', grad_fn=<MulBackward0>)

In [41]:
out = torch.einsum('lir,lir->', m1.tensor, m2.tensor)
out

tensor(-0.4832, device='cuda:0', grad_fn=<ViewBackward0>)

In [42]:
out = torch.einsum('lir,ljr->ij', m1.tensor, m1.tensor)
out

tensor([[ 6.6667e-01, -1.6084e-04, -1.4589e-06],
        [-1.6084e-04,  6.6668e-01,  3.0558e-03],
        [-1.4589e-06,  3.0558e-03,  6.6665e-01]], device='cuda:0',
       grad_fn=<ViewBackward0>)

### Check degeneracies

In [44]:
# Initialize network
mps = tk.models.MPS(tensors=aklt_cores)
mps = mps.parameterize(set_param=False, override=True)

print('norm:', mps.norm())
print('log-scale norm:', mps.norm(log_scale=True))

mps.trace(torch.zeros(1, n_features, phys_dim, device=device))

norm: tensor(nan, device='cuda:0')
log-scale norm: tensor(54.5840, device='cuda:0')


In [46]:
# Initialize tensorized model
mps2 = tk.models.MPS(tensors=[c.to(device) for c in cores])
mps2.canonicalize(oc=0, renormalize=True)
mps2 = mps2.parameterize(set_param=False, override=True)

print('norm:', mps2.norm())
print('log-scale norm:', mps2.norm(log_scale=True))

mps2.trace(torch.zeros(1, n_features, phys_dim, device=device))

norm: tensor(nan, device='cuda:0')
log-scale norm: tensor(54.5841, device='cuda:0')


In [47]:
def _top_singular_gap(two_site_tensor):
    """Return ``|s_0 - s_1|`` for the two largest singular values of a tensor.

    The tensor is matricized by grouping its first two axes against the rest.
    For bond_dim=2 / phys_dim=3 two-site tensors this is a 6 x 6 matrix;
    because the split is derived from the shape, other bond and physical
    dimensions work too.

    Raises:
        AssertionError: if fewer than two singular values are available.
    """
    n_rows = two_site_tensor.shape[0] * two_site_tensor.shape[1]
    _, s, _ = torch.linalg.svd(two_site_tensor.reshape(n_rows, -1))
    assert s.numel() >= 2, 'need at least two singular values to measure a gap'
    # torch.linalg.svd returns singular values in descending order
    return (s[0] - s[1]).abs()


@torch.no_grad()
def degeneracies(mps):
    """Canonicalize ``mps`` in place and record the singular-value gap at each bond.

    For every bond visited by the SVD sweep, the two-site tensor across that
    bond is matricized. The absolute difference between its two largest
    singular values is recorded; a value near zero signals a (near-)degenerate
    spectrum.

    Side effects: ``mps`` is modified in place.
    - For open boundaries, boundary tensors are zeroed outside index 0 of
      their outer bond.
    - Nodes are replaced by their SVD factors.
    - ``mps._mats_env`` and the bond dimensions are updated.
    The previous ``auto_stack`` setting is restored even if an error is raised.

    Args:
        mps: a tensorkrowch MPS.

    Returns:
        list of 0-dim tensors, one gap per bond swept.
    """
    mps.reset()

    prev_auto_stack = mps._auto_stack
    mps.auto_stack = False
    try:
        # Orthogonality centre at the last site, so the left-to-right sweep
        # visits every bond.
        oc = mps._n_features - 1
        log_norm = 0

        nodes = mps._mats_env[:]
        if mps._boundary == 'obc':
            # NOTE(review): keep only index 0 of the outer bonds, presumably
            # because the boundary vectors select index 0. Confirm.
            nodes[0].tensor[1:] = torch.zeros_like(
                nodes[0].tensor[1:])
            nodes[-1].tensor[..., 1:] = torch.zeros_like(
                nodes[-1].tensor[..., 1:])

        diff_s = []

        # Left-to-right sweep up to the orthogonality centre.
        for i in range(oc):
            diff_s.append(
                _top_singular_gap(nodes[i]['right'].contract().tensor))

            result1, result2 = nodes[i]['right'].svd_(
                side='right',
                rank=nodes[i]['right'].size())

            # Divide the norm out of the node carrying the singular values
            # and remember it in log space (zero/inf norms are skipped).
            aux_norm = result2.norm()
            if not aux_norm.isinf() and (aux_norm > 0):
                result2.tensor = result2.tensor / aux_norm
                log_norm += aux_norm.log()

            nodes[i] = result1.parameterize()
            nodes[i + 1] = result2

        # Right-to-left sweep down to the orthogonality centre. With oc at the
        # last site this range is normally empty (len(nodes) - 1 == oc).
        for i in range(len(nodes) - 1, oc, -1):
            diff_s.append(
                _top_singular_gap(nodes[i]['left'].contract().tensor))

            result1, result2 = nodes[i]['left'].svd_(
                side='left',
                rank=nodes[i]['left'].size())

            aux_norm = result1.norm()
            if not aux_norm.isinf() and (aux_norm > 0):
                result1.tensor = result1.tensor / aux_norm
                log_norm += aux_norm.log()

            nodes[i] = result2.parameterize()
            nodes[i - 1] = result1

        nodes[oc] = nodes[oc].parameterize()

        # Multiply every node by exp(log_norm / n). The product of these
        # factors is exp(log_norm), restoring the scale removed in the sweeps.
        if log_norm != 0:
            rescale = (log_norm / len(nodes)).exp()
            for node in nodes:
                node.tensor = node.tensor * rescale

        mps.reset()

        # Install the rewritten nodes.
        mps._mats_env = nodes
        mps.update_bond_dim()
    finally:
        mps.auto_stack = prev_auto_stack

    return diff_s

In [48]:
degeneracies(mps)

[tensor(0.7174, device='cuda:0'),
 tensor(0.1363, device='cuda:0'),
 tensor(0.0454, device='cuda:0'),
 tensor(0.0151, device='cuda:0'),
 tensor(0.0050, device='cuda:0'),
 tensor(0.0017, device='cuda:0'),
 tensor(0.0006, device='cuda:0'),
 tensor(0.0002, device='cuda:0'),
 tensor(6.2585e-05, device='cuda:0'),
 tensor(2.0742e-05, device='cuda:0'),
 tensor(7.2718e-06, device='cuda:0'),
 tensor(2.7418e-06, device='cuda:0'),
 tensor(3.5763e-07, device='cuda:0'),
 tensor(1.0729e-06, device='cuda:0'),
 tensor(7.1526e-07, device='cuda:0'),
 tensor(7.1526e-07, device='cuda:0'),
 tensor(0., device='cuda:0'),
 tensor(3.5763e-07, device='cuda:0'),
 tensor(0., device='cuda:0'),
 tensor(2.3842e-07, device='cuda:0'),
 tensor(4.7684e-07, device='cuda:0'),
 tensor(4.7684e-07, device='cuda:0'),
 tensor(3.5763e-07, device='cuda:0'),
 tensor(0., device='cuda:0'),
 tensor(1.1921e-07, device='cuda:0'),
 tensor(1.5497e-06, device='cuda:0'),
 tensor(4.7684e-07, device='cuda:0'),
 tensor(5.9605e-07, device='cu

In [49]:
degeneracies(mps2)

[tensor(0.7125, device='cuda:0'),
 tensor(0.1358, device='cuda:0'),
 tensor(0.0452, device='cuda:0'),
 tensor(0.0151, device='cuda:0'),
 tensor(0.0050, device='cuda:0'),
 tensor(0.0017, device='cuda:0'),
 tensor(0.0006, device='cuda:0'),
 tensor(0.0002, device='cuda:0'),
 tensor(6.1274e-05, device='cuda:0'),
 tensor(2.2292e-05, device='cuda:0'),
 tensor(7.3910e-06, device='cuda:0'),
 tensor(5.1260e-06, device='cuda:0'),
 tensor(3.3379e-06, device='cuda:0'),
 tensor(2.1458e-06, device='cuda:0'),
 tensor(9.5367e-07, device='cuda:0'),
 tensor(3.5763e-07, device='cuda:0'),
 tensor(5.9605e-07, device='cuda:0'),
 tensor(8.3447e-07, device='cuda:0'),
 tensor(1.3113e-06, device='cuda:0'),
 tensor(7.1526e-07, device='cuda:0'),
 tensor(1.3113e-06, device='cuda:0'),
 tensor(1.5497e-06, device='cuda:0'),
 tensor(1.1921e-06, device='cuda:0'),
 tensor(1.4305e-06, device='cuda:0'),
 tensor(1.3113e-06, device='cuda:0'),
 tensor(1.0729e-06, device='cuda:0'),
 tensor(1.4305e-06, device='cuda:0'),
 tenso

### Get uniform tensor

In [None]:
# Initialize network
mps = tk.models.MPS(tensors=aklt_cores)

print('norm:', mps.norm())
print('log-scale norm:', mps.norm(log_scale=True))

mps.trace(torch.zeros(1, n_features, phys_dim, device=device))

# Disable parameters
for p in mps.parameters():
    p.requires_grad_(False)

norm: tensor(nan, device='cuda:0', grad_fn=<SqrtBackward0>)
log-scale norm: tensor(54.5840, device='cuda:0', grad_fn=<DivBackward0>)


In [None]:
# Initialize tensorized model
mps2 = tk.models.MPS(tensors=[c.to(device) for c in cores])

mps2.trace(torch.zeros(1, n_features, phys_dim, device=device))

# Disable parameters
for p in mps2.parameters():
    p.requires_grad_(False)

mps2.canonicalize(oc=0, renormalize=True)

In [None]:
@torch.no_grad()
def canonicalize(mps):
    """Bring ``mps`` into a canonical form with a fixed bond gauge, in place.

    Steps:
      1. For open boundaries, zero every entry of the boundary tensors outside
         index 0 of their outer bond. NOTE(review): presumably this matches
         boundary vectors that select index 0 (cf. ``boundary_conditions = [0, 0]``);
         confirm.
      2. Sweep right to left with SVDs, moving the orthogonality centre to
         site 0. The norm of each new left node is divided out and its log
         accumulated in ``log_norm``.
      3. Multiply every node by ``exp(log_norm / n_nodes)``. The product of the
         factors is ``exp(log_norm)``, so the scale removed in step 2 is
         restored, spread evenly over the sites.
      4. Sweep left to right. At each bond, decompose the left environment
         and rotate the bond basis into its singular basis to fix the
         remaining gauge freedom. The result is presumably only unique up to
         phases/orderings when singular values are degenerate.
    Finally the rewritten nodes are installed into ``mps`` and its bond
    dimensions are updated. The previous ``auto_stack`` setting is restored.
    """
    mps.reset()

    prev_auto_stack = mps._auto_stack
    mps.auto_stack = False

    # Orthogonality centre at the first site.
    oc = 0
    log_norm = 0
    
    nodes = mps._mats_env[:]
    if mps._boundary == 'obc':
        # Keep only index 0 on the outer bond of each boundary tensor.
        nodes[0].tensor[1:] = torch.zeros_like(
            nodes[0].tensor[1:])
        nodes[-1].tensor[..., 1:] = torch.zeros_like(
            nodes[-1].tensor[..., 1:])

    # SVDs from right to left
    for i in range(len(nodes) - 1, oc, -1):
        result1, result2 = nodes[i]['left'].svd_(
            side='left',
            rank=nodes[i]['left'].size())
        
        # Renormalize: divide the norm out of the node that carries the
        # singular values (result1) and track it in log space; zero/inf norms
        # are skipped.
        aux_norm = result1.norm()
        if not aux_norm.isinf() and (aux_norm > 0):
            result1.tensor = result1.tensor / aux_norm
            log_norm += aux_norm.log()

        result2 = result2.parameterize()
        nodes[i] = result2
        nodes[i - 1] = result1

    nodes[oc] = nodes[oc].parameterize()
    
    # Rescale: spread the accumulated norm evenly over all nodes.
    if log_norm != 0:
        rescale = (log_norm / len(nodes)).exp()
        
        for node in nodes:
            node.tensor = node.tensor * rescale
    
    # Fix remaining gauge freedom
    for i in range(len(nodes)):
        tensor = nodes[i].tensor
        
        if i == 0:
            # Left environment at the first bond: T^dagger T, summed over the
            # left and physical indices.
            left_tensor = torch.einsum('lir,lik->rk', tensor.conj(), tensor)
        else:
            # Propagate the previous environment, kept as V diag(s) V^dagger,
            # through the current (not yet rotated) tensor.
            left_tensor = torch.einsum('lir,lm,m,mn,nik->rk',
                                       tensor.conj(),
                                       prev_v,
                                       prev_s,
                                       prev_v_dagger,
                                       tensor)
            
        # For a Hermitian PSD environment the SVD is presumably an
        # eigendecomposition (v_dagger ~= v^H); the commented-out assert
        # checked this.
        v, s, v_dagger = torch.linalg.svd(left_tensor, full_matrices=False)
        # print(v)
        # print(s)
        # print(v_dagger)
        # assert torch.dist(v, v_dagger.H) < 1e-5
        # assert (s[0] - s[1]).abs() > 1e-5
        
        # Rotate the right bond by v. For i > 0, also apply prev_v_dagger on
        # the left bond so that it cancels the previous site's rotation
        # (assuming prev_v_dagger @ prev_v == I).
        if i == 0:
            new_tensor = torch.einsum('ijk,kl->ijl', tensor, v)
        else:
            new_tensor = torch.einsum('hi,ijk,kl->hjl', prev_v_dagger, tensor, v)
            
        nodes[i].tensor = new_tensor
        
        prev_v = v
        prev_s = s
        prev_v_dagger = v_dagger
    
    mps.reset()
    
    # Update variables
    mps._mats_env = nodes
    mps.update_bond_dim()

    mps.auto_stack = prev_auto_stack

In [None]:
canonicalize(mps)

In [None]:
canonicalize(mps2)

In [None]:
for node, node2 in zip(mps.mats_env[(n_features//2 - 5):-(n_features//2 - 5)],
                       mps2.mats_env[(n_features//2 - 5):-(n_features//2 - 5)]):
    print(node.tensor, node2.tensor)
    print()

Parameter containing:
tensor([[[ 0.2472,  0.4805],
         [-0.7608, -0.6437],
         [ 1.1574, -0.5954]],

        [[ 0.5954,  1.1574],
         [-0.6437,  0.7608],
         [-0.4805,  0.2472]]], device='cuda:0', requires_grad=True) Parameter containing:
tensor([[[ 1.3002,  0.4201],
         [-0.5293,  0.8443],
         [ 0.1061, -0.3285]],

        [[ 0.3285,  0.1061],
         [ 0.8443,  0.5293],
         [-0.4201,  1.3002]]], device='cuda:0', requires_grad=True)

Parameter containing:
tensor([[[ 1.1831, -0.4132],
         [ 0.7226,  0.6863],
         [ 0.2126,  0.6086]],

        [[-0.6086,  0.2126],
         [ 0.6863, -0.7226],
         [ 0.4132,  1.1831]]], device='cuda:0', requires_grad=True) Parameter containing:
tensor([[[ 4.3713e-04, -4.3330e-01],
         [ 3.0735e-01, -9.4796e-01],
         [ 1.3411e+00,  1.3546e-03]],

        [[-1.3543e-03,  1.3411e+00],
         [-9.4796e-01, -3.0735e-01],
         [ 4.3330e-01,  4.3769e-04]]], device='cuda:0', requires_grad=True)

Pa

## Topological phase estimation (mps)

In [70]:
A_plus = torch.tensor([[0, sqrt(2 / 3)],
                       [0, 0]])
A_zero = torch.tensor([[-sqrt(1 / 3), 0],
                       [0, sqrt(1 / 3)]])
A_minus = torch.tensor([[0, 0],
                        [-sqrt(2 / 3), 0]])

aklt_core = torch.stack([A_plus, A_zero, A_minus], dim=1).to(torch.complex64)

# Model hyperparameters
n_features = 100
phys_dim = aklt_core.shape[1]
bond_dim = aklt_core.shape[0]

fn_name = 'aklt_mps'

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = 'cpu'

boundary_conditions = [0, 0]
scale = sqrt(phys_dim)

# aklt_cores = [aklt_core.to(device) * scale for _ in range(n_features)]
aklt_cores = [aklt_core.to(device) for _ in range(n_features)]
aklt_cores[0] = aklt_cores[0][boundary_conditions[0], :, :]
aklt_cores[-1] = aklt_cores[-1][:, :, boundary_conditions[1]]


# Initialize network
mps = tk.models.MPS(tensors=aklt_cores)

# Disable parameters
_ = mps.parameterize(set_param=False, override=True)

print(mps.norm())
mps.reset()

# mps.trace(torch.zeros(1, n_features, phys_dim, device=device))

tensor(0.7071, device='cuda:0')


In [71]:
sigma_x = torch.tensor([[0, sqrt(1 / 3)],
                        [sqrt(1 / 3), 0]])
sigma_y = torch.tensor([[0, -sqrt(1 / 3)*1j],
                        [sqrt(1 / 3)*1j, 0]])
sigma_z = torch.tensor([[sqrt(1 / 3), 0],
                        [0, -sqrt(1 / 3)]])

aklt_core_pauli = torch.stack([sigma_x, sigma_y, sigma_z], dim=1)
aklt_core_pauli

tensor([[[ 0.0000+0.0000j,  0.5774+0.0000j],
         [ 0.0000+0.0000j, -0.0000-0.5774j],
         [ 0.5774+0.0000j,  0.0000+0.0000j]],

        [[ 0.5774+0.0000j,  0.0000+0.0000j],
         [ 0.0000+0.5774j,  0.0000+0.0000j],
         [ 0.0000+0.0000j, -0.5774+0.0000j]]])

In [72]:
U = torch.linalg.lstsq(
    aklt_core.permute(0, 2, 1).reshape(4, 3).to(aklt_core_pauli.dtype),
    aklt_core_pauli.permute(0, 2, 1).reshape(4, 3)).solution

U

tensor([[ 0.7071+0.0000j,  0.0000-0.7071j,  0.0000-0.0000j],
        [ 0.0000-0.0000j,  0.0000-0.0000j, -1.0000-0.0000j],
        [-0.7071-0.0000j,  0.0000-0.7071j,  0.0000-0.0000j]])

In [73]:
U @ U.H

tensor([[1.0000+0.j, 0.0000+0.j, 0.0000+0.j],
        [0.0000+0.j, 1.0000+0.j, 0.0000+0.j],
        [0.0000+0.j, 0.0000+0.j, 1.0000+0.j]])

In [74]:
L = 10
aux_mps = mps
aux_mps_copy = aux_mps.copy(share_tensors=True)
aux_mps_copy.left_node.move_to_network(aux_mps)

ux = torch.Tensor([[1, 0, 0],
                   [0, -1, 0],
                   [0, 0, -1]]).to(torch.complex64)
uz = torch.Tensor([[-1, 0, 0],
                   [0, -1, 0],
                   [0, 0, 1]]).to(torch.complex64)

ux = (U @ ux @ U.H).to(device)
uz = (U @ uz @ U.H).to(device)


start = (n_features // 2) - (L * 5 // 2)
node_blocks = [aux_mps.mats_env[(start + i*L):(start + (i+1)*L)]
               for i in range(5)]
node_blocks_copy = [aux_mps_copy.mats_env[(start + i*L):(start + (i+1)*L)]
                    for i in range(5)]

In [75]:
for i in [1, 2]:
    for j, node in enumerate(node_blocks[i]):
        node_z = tk.Node(tensor=uz,
                         name=f'node_z_({i}_{j})',
                         axes_names=('input', 'output'),
                         network=aux_mps)
        node['input'] ^ node_z['output']

In [76]:
for i in [1, 2]:
    for j, node in enumerate(node_blocks_copy[i]):
        node_x = tk.Node(tensor=ux,
                         name=f'node_x_({i}_{j})',
                         axes_names=('input', 'output'),
                         network=aux_mps)
        node['input'] ^ node_x['output']

In [77]:
# Block i1 of the original MPS is contracted against block i2 of the copy.
# Blocks 0, 2 and 4 face themselves; blocks 1 and 3 are swapped between the
# two copies.
# NOTE(review): neither copy is complex-conjugated, so this matches the usual
# <psi|...|psi> overlap only for real-valued cores -- confirm.
i1_lst = [0, 1, 2, 3, 4]
i2_lst = [0, 3, 2, 1, 4]

contracted_blocks = []
for i1, i2 in zip(i1_lst, i2_lst):
    
    # Connect all nodes of each block
    # (blocks 1 and 2 go through the dangling 'input' of their uz/ux node;
    # the other blocks connect the sites' physical edges directly)
    for j in range(L):
        if i1 in [1, 2]:
            edge1 = node_blocks[i1][j].neighbours('input')['input']
        else:
            edge1 = node_blocks[i1][j]['input']
        
        if i2 in [1, 2]:
            edge2 = node_blocks_copy[i2][j].neighbours('input')['input']
        else:
            edge2 = node_blocks_copy[i2][j]['input']
            
        edge1 ^ edge2
    
    # Contract with ux, uz
    # (non-in-place `contract()`; the list entry is replaced by the node
    # resulting from contracting the site with its uz/ux node)
    for j in range(L):
        if i1 in [1, 2]:
            node_blocks[i1][j] = node_blocks[i1][j]['input'].contract()
        
        if i2 in [1, 2]:
            node_blocks_copy[i2][j] = node_blocks_copy[i2][j]['input'].contract()
    
    # Contract each node of a block with the corresponding copy
    # -> one node per site, with bond edges of both copies left open
    aux_results = []
    for j in range(L):
        aux_results.append(node_blocks[i1][j] @ node_blocks_copy[i2][j])
    
    # Contract all nodes in each block in line
    # -> a single node per block pair (its shape is inspected below)
    result = aux_results[0]
    for j in range(1, L):
        result @= aux_results[j]
    
    contracted_blocks.append(result)

In [78]:
# Each contracted block pair keeps two bond indices on each side.
contracted_blocks[0].shape

torch.Size([2, 2, 2, 2])

In [79]:
# Group the first block's axes as (left_0, left_1) x (right_0, right_1) and
# look at its singular values: a single dominant value means the block is
# (numerically) rank 1.
# NOTE(review): assumes the node's axes are ordered left_0, left_1, right_0,
# right_1 -- consistent with the axis names used by split() below; confirm.
# flatten() groups the axes for any bond dimension instead of hard-coding 4.
u, s, vh = torch.linalg.svd(contracted_blocks[0].tensor.flatten(0, 1).flatten(1),
                           full_matrices=False)
s

tensor([1.0000e+00, 1.6950e-05, 1.6935e-05, 1.6935e-05], device='cuda:0',
       grad_fn=<LinalgSvdBackward0>)

In [80]:
# Leading left singular vector (column 0 of u).
u[:, 0]

tensor([0.7071+0.j, 0.0000+0.j, 0.0000+0.j, 0.7071+0.j], device='cuda:0',
       grad_fn=<SelectBackward0>)

In [81]:
# Leading right singular vector (row 0 of vh).
vh[0, :]

tensor([0.7071-0.j, 0.0000-0.j, 0.0000-0.j, 0.7071-0.j], device='cuda:0',
       grad_fn=<SliceBackward0>)

In [82]:
# Same rank check for the last block pair: group its axes as
# (left_0, left_1) x (right_0, right_1) and inspect the singular values.
# NOTE(review): assumes the node's axes are ordered left_0, left_1, right_0,
# right_1 -- consistent with the axis names used by split() below; confirm.
# flatten() groups the axes for any bond dimension instead of hard-coding 4.
u, s, vh = torch.linalg.svd(contracted_blocks[-1].tensor.flatten(0, 1).flatten(1),
                           full_matrices=False)
s

tensor([1.0000e+00, 1.6950e-05, 1.6935e-05, 1.6935e-05], device='cuda:0',
       grad_fn=<LinalgSvdBackward0>)

In [83]:
# Leading left singular vector of the last block pair.
u[:, 0]

tensor([0.7071+0.j, 0.0000+0.j, 0.0000+0.j, 0.7071+0.j], device='cuda:0',
       grad_fn=<SelectBackward0>)

In [86]:
# Reshape the leading left singular vector into a 2x2 matrix M and compute
# M @ M (matrix_power with exponent 2).
torch.linalg.matrix_power(u[:, 0].view(2, 2), 2)

tensor([[0.5000+0.j, 0.0000+0.j],
        [0.0000+0.j, 0.5000+0.j]], device='cuda:0', grad_fn=<MmBackward0>)

In [84]:
# Leading right singular vector of the last block pair.
vh[0, :]

tensor([0.7071-0.j, 0.0000-0.j, 0.0000-0.j, 0.7071-0.j], device='cuda:0',
       grad_fn=<SliceBackward0>)

In [87]:
# Rank-1 splits of the outermost block pairs. Only the half that faces the
# inner blocks is kept: `left_node` keeps the right_* axes of the first block
# and `right_node` keeps the left_* axes of the last one. The other half of
# each split is discarded. These rank-1 factors serve as boundary vectors for
# the final contraction.
_, left_node = contracted_blocks[0].split(node1_axes=['left_0', 'left_1'],
                                          node2_axes=['right_0', 'right_1'],
                                          side='left',
                                          rank=1)

right_node, _ = contracted_blocks[-1].split(node1_axes=['left_0', 'left_1'],
                                            node2_axes=['right_0', 'right_1'],
                                            side='right',
                                            rank=1)

In [88]:
# Inspect the left boundary factor (tensor, axis names and edges).
left_node

Node(
 	name: split_1
	tensor:
		tensor([[[0.7071-0.j, 0.0000-0.j],
		         [0.0000-0.j, 0.7071-0.j]]], device='cuda:0', grad_fn=<ViewBackward0>)
	axes:
		[split
		 right_0
		 right_1]
	edges:
		[split_0[split] <-> split_1[split]
		 mats_env_node_(34)_0[right] <-> mats_env_node_(35)_0[left]
		 mats_env_node_(34)_1[right] <-> mats_env_node_(35)_1[left]])

In [89]:
# Inspect the right boundary factor (tensor, axis names and edges).
right_node

Node(
 	name: split_2
	tensor:
		tensor([[[0.7071+0.j],
		         [0.0000+0.j]],
		
		        [[0.0000+0.j],
		         [0.7071+0.j]]], device='cuda:0', grad_fn=<ViewBackward0>)
	axes:
		[left_0
		 left_1
		 split]
	edges:
		[mats_env_node_(64)_0[right] <-> mats_env_node_(65)_0[left]
		 mats_env_node_(64)_1[right] <-> mats_env_node_(65)_1[left]
		 split_2[split] <-> split_3[split]])

In [90]:
# Sandwich the three inner block pairs between the rank-1 boundary factors;
# the result is a scalar node.
result = left_node @ contracted_blocks[1] @ contracted_blocks[2] @ contracted_blocks[3] @ right_node

In [94]:
# Scalar value of the full contraction, rescaled by bond_dim**2.
# NOTE(review): the origin of the bond_dim**2 factor (normalization of the
# boundary factors?) is not visible here -- confirm.
result.tensor.item() * bond_dim**2

(-1.0000026226043701+0j)

## Topological phase estimation (mps2)

In [120]:
# Load cores
cwd = os.getcwd()
results_dir = os.path.join(cwd, '..', '..', 'results', '4_interpretability')

# NOTE(review): weights_only=False unpickles arbitrary objects; only load
# files produced by this project.
cores = torch.load(os.path.join(results_dir, f'{fn_name}.pt'),
                   weights_only=False)

# Initialize tensorized model
mps2 = tk.models.MPS(tensors=[c.to(device).to(torch.complex64) for c in cores])

mps2.canonicalize(oc=0, renormalize=True)

# Disable parameters
mps2 = mps2.parameterize(set_param=False, override=True)

# The canonical form doesn't give us the exact normalization we want.
# We need all tensors except the ones at the extremes to have the same norm
# (sqrt(2), the norm of the unscaled AKLT core), and the extreme tensors to
# have norm 1.
#
# Renormalize cores:
#  1. strip every core's norm, accumulating log(norm / sqrt(phys_dim)),
#  2. give the bulk cores norm sqrt(2) and remove that from the accumulator,
#  3. split the remaining global scale evenly between the two extreme cores.
# The extreme cores therefore end with norm exp(log_norm / 2), which is 1
# only when the loaded model has exactly the AKLT normalization.
log_norm = 0

for node in mps2.mats_env:
    node_norm = node.norm()  # computed once; the tensor is unchanged until rescaled
    log_norm += (node_norm / sqrt(phys_dim)).log()
    node.tensor = node.tensor / node_norm

for node in mps2.mats_env[1:-1]:
    log_norm -= log(sqrt(2))
    node.tensor = node.tensor * sqrt(2)

for node in mps2.mats_env[:1] + mps2.mats_env[-1:]:
    node.tensor = node.tensor * (log_norm / 2).exp()

print(mps2.norm())
mps2.reset()

# mps2.trace(torch.zeros(1, n_features, phys_dim, device=device))

# for node in mps2.mats_env:
#     print(node.norm() ** 2)
# print()

tensor(0.7071, device='cuda:0')


In [121]:
# AKLT matrices A^{+}, A^{0}, A^{-} (bond_dim x bond_dim, float32).
A_plus = torch.tensor([[0.0, sqrt(2 / 3)], [0.0, 0.0]])
A_zero = torch.tensor([[-sqrt(1 / 3), 0.0], [0.0, sqrt(1 / 3)]])
A_minus = torch.tensor([[0.0, 0.0], [-sqrt(2 / 3), 0.0]])

# Stack along the middle axis -> (left bond, physical, right bond), cast to
# complex so it can be compared with the Pauli-basis core.
aklt_core = torch.stack((A_plus, A_zero, A_minus), dim=1).to(torch.complex64)

In [122]:
# AKLT core written with Pauli matrices: sigma_{x,y,z} / sqrt(3), stacked
# along the middle (physical) axis in the order x, y, z. Stacking the real
# sigma_x/sigma_z with the complex sigma_y gives a complex core.
sigma_x = torch.tensor([[0, sqrt(1 / 3)],
                        [sqrt(1 / 3), 0]])
sigma_y = torch.tensor([[0, -sqrt(1 / 3)*1j],
                        [sqrt(1 / 3)*1j, 0]])
sigma_z = torch.tensor([[sqrt(1 / 3), 0],
                        [0, -sqrt(1 / 3)]])

aklt_core_pauli = torch.stack([sigma_x, sigma_y, sigma_z], dim=1)
aklt_core_pauli

tensor([[[ 0.0000+0.0000j,  0.5774+0.0000j],
         [ 0.0000+0.0000j, -0.0000-0.5774j],
         [ 0.5774+0.0000j,  0.0000+0.0000j]],

        [[ 0.5774+0.0000j,  0.0000+0.0000j],
         [ 0.0000+0.5774j,  0.0000+0.0000j],
         [ 0.0000+0.0000j, -0.5774+0.0000j]]])

In [123]:
# Change of physical basis U (phys x phys) mapping the AKLT core onto the
# Pauli-basis core:  sum_k aklt_core[l, k, r] * U[k, p] == aklt_core_pauli[l, p, r].
# Both cores are permuted to (left, right, phys) and flattened to
# (bond_dim**2, phys), which turns this into the least-squares problem A @ U = B.
# The shapes are read from the cores instead of hard-coding (4, 3).
U = torch.linalg.lstsq(
    aklt_core.permute(0, 2, 1)
             .reshape(-1, aklt_core.shape[1])
             .to(aklt_core_pauli.dtype),
    aklt_core_pauli.permute(0, 2, 1)
                   .reshape(-1, aklt_core_pauli.shape[1])).solution

U

tensor([[ 0.7071+0.0000j,  0.0000-0.7071j,  0.0000-0.0000j],
        [ 0.0000-0.0000j,  0.0000-0.0000j, -1.0000-0.0000j],
        [-0.7071-0.0000j,  0.0000-0.7071j,  0.0000-0.0000j]])

In [124]:
# Sanity check: U should be unitary, i.e. U U^dagger equals the identity.
U @ U.mH

tensor([[1.0000+0.j, 0.0000+0.j, 0.0000+0.j],
        [0.0000+0.j, 1.0000+0.j, 0.0000+0.j],
        [0.0000+0.j, 0.0000+0.j, 1.0000+0.j]])

In [125]:
L = 10        # number of sites in each block
n_blocks = 5  # number of consecutive blocks taken from the middle of the chain

# Two copies of the same MPS that share their tensors. Moving the copy's left
# node into `aux_mps` puts the nodes of both copies in one network, which is
# required to connect their edges below
# (NOTE(review): assumes `move_to_network` carries the connected nodes along).
aux_mps = mps2
aux_mps_copy = aux_mps.copy(share_tensors=True)
aux_mps_copy.left_node.move_to_network(aux_mps)

# Diagonal +-1 operators (presumably pi-rotations about the x and z axes in the
# Cartesian/Pauli basis -- confirm), conjugated by U so they act on the
# physical index of the MPS.
ux = torch.Tensor([[1, 0, 0],
                   [0, -1, 0],
                   [0, 0, -1]]).to(torch.complex64)
uz = torch.Tensor([[-1, 0, 0],
                   [0, -1, 0],
                   [0, 0, 1]]).to(torch.complex64)

ux = (U @ ux @ U.H).to(device)
uz = (U @ uz @ U.H).to(device)

# `n_blocks` consecutive blocks of `L` sites, centred in the chain.
start = (n_features // 2) - (L * n_blocks // 2)
# A negative `start`, or blocks running past the end, would make the slices
# below silently wrap around or come out short instead of failing.
assert 0 <= start and start + n_blocks * L <= n_features, (
    f'{n_blocks} blocks of {L} sites do not fit in a chain of '
    f'{n_features} sites')
node_blocks = [aux_mps.mats_env[(start + i*L):(start + (i+1)*L)]
               for i in range(n_blocks)]
node_blocks_copy = [aux_mps_copy.mats_env[(start + i*L):(start + (i+1)*L)]
                    for i in range(n_blocks)]

In [126]:
# Attach a `uz` node to the physical ('input') edge of every site in blocks 1
# and 2 of the original MPS. The uz node's 'output' axis is wired to the site;
# its 'input' axis is left dangling and is connected to the copy in the
# contraction loop further down.
for i in [1, 2]:
    for j, node in enumerate(node_blocks[i]):
        node_z = tk.Node(tensor=uz,
                         name=f'node_z_({i}_{j})',
                         axes_names=('input', 'output'),
                         network=aux_mps)
        node['input'] ^ node_z['output']

In [127]:
# Attach a `ux` node to the physical ('input') edge of every site in blocks 1
# and 2 of the copy. The nodes are created in `aux_mps`, the network the copy
# was moved into. Their 'input' axes are left dangling and are connected to
# the original MPS in the contraction loop further down.
for i in [1, 2]:
    for j, node in enumerate(node_blocks_copy[i]):
        node_x = tk.Node(tensor=ux,
                         name=f'node_x_({i}_{j})',
                         axes_names=('input', 'output'),
                         network=aux_mps)
        node['input'] ^ node_x['output']

In [128]:
# Block i1 of the original MPS is contracted against block i2 of the copy.
# Blocks 0, 2 and 4 face themselves; blocks 1 and 3 are swapped between the
# two copies. Unlike the earlier version, this one uses the in-place
# operations (`contract_`, `contract_between_`).
# NOTE(review): neither copy is complex-conjugated, so this matches the usual
# <psi|...|psi> overlap only for real-valued cores -- confirm for mps2.
i1_lst = [0, 1, 2, 3, 4]
i2_lst = [0, 3, 2, 1, 4]

contracted_blocks = []
for i1, i2 in zip(i1_lst, i2_lst):
    
    # Connect all nodes of each block
    # (blocks 1 and 2 go through the dangling 'input' of their uz/ux node;
    # the other blocks connect the sites' physical edges directly)
    for j in range(L):
        if i1 in [1, 2]:
            edge1 = node_blocks[i1][j].neighbours('input')['input']
        else:
            edge1 = node_blocks[i1][j]['input']
        
        if i2 in [1, 2]:
            edge2 = node_blocks_copy[i2][j].neighbours('input')['input']
        else:
            edge2 = node_blocks_copy[i2][j]['input']
            
        edge1 ^ edge2
    
    # Contract with ux, uz
    # (the list entry is replaced by the node resulting from the in-place
    # contraction of the site with its uz/ux node)
    for j in range(L):
        if i1 in [1, 2]:
            node_blocks[i1][j] = node_blocks[i1][j]['input'].contract_()
        
        if i2 in [1, 2]:
            node_blocks_copy[i2][j] = node_blocks_copy[i2][j]['input'].contract_()
    
    # Contract each node of a block with the corresponding copy
    # -> one node per site, with bond edges of both copies left open
    aux_results = []
    for j in range(L):
        aux_results.append(tk.contract_between_(node_blocks[i1][j],
                                                node_blocks_copy[i2][j]))
    
    # Contract all nodes in each block in line
    # -> a single node per block pair (its shape is inspected below)
    result = aux_results[0]
    for j in range(1, L):
        result = tk.contract_between_(result, aux_results[j])
    
    contracted_blocks.append(result)

In [129]:
# Each contracted block pair keeps two bond indices on each side.
contracted_blocks[0].shape

torch.Size([2, 2, 2, 2])

In [130]:
# Group the first block's axes as (left_0, left_1) x (right_0, right_1) and
# look at its singular values: a single dominant value means the block is
# (numerically) rank 1.
# NOTE(review): assumes the node's axes are ordered left_0, left_1, right_0,
# right_1 -- consistent with the axis names used by split_() below; confirm.
# flatten() groups the axes for any bond dimension instead of hard-coding 4.
u, s, vh = torch.linalg.svd(contracted_blocks[0].tensor.flatten(0, 1).flatten(1),
                           full_matrices=False)
s

tensor([1.0000e+00, 1.6935e-05, 1.6935e-05, 1.6926e-05], device='cuda:0')

In [131]:
# Leading left singular vector (column 0 of u).
u[:, 0]

tensor([7.0711e-01+0.j, 2.0224e-09+0.j, 2.0548e-09+0.j, 7.0711e-01+0.j],
       device='cuda:0')

In [132]:
# Leading right singular vector (row 0 of vh).
vh[0, :]

tensor([7.0711e-01-0.j, 1.2375e-07-0.j, 1.2387e-07-0.j, 7.0711e-01-0.j],
       device='cuda:0')

In [133]:
# Same rank check for the last block pair: group its axes as
# (left_0, left_1) x (right_0, right_1) and inspect the singular values.
# NOTE(review): assumes the node's axes are ordered left_0, left_1, right_0,
# right_1 -- consistent with the axis names used by split_() below; confirm.
# flatten() groups the axes for any bond dimension instead of hard-coding 4.
u, s, vh = torch.linalg.svd(contracted_blocks[-1].tensor.flatten(0, 1).flatten(1),
                           full_matrices=False)
s

tensor([1.0000e+00, 1.6935e-05, 1.6935e-05, 1.6918e-05], device='cuda:0')

In [134]:
# Leading left singular vector of the last block pair.
u[:, 0]

tensor([ 7.0711e-01+0.j, -2.0888e-09+0.j, -1.9283e-09+0.j,  7.0711e-01+0.j],
       device='cuda:0')

In [135]:
# Reshape the leading left singular vector into a 2x2 matrix M and compute
# M @ M (matrix_power with exponent 2).
torch.linalg.matrix_power(u[:, 0].view(2, 2), 2)

tensor([[ 5.0000e-01+0.j, -2.9540e-09+0.j],
        [-2.7270e-09+0.j,  5.0000e-01+0.j]], device='cuda:0')

In [136]:
# Leading right singular vector of the last block pair.
vh[0, :]

tensor([7.0711e-01-0.j, 4.9514e-08-0.j, 4.9602e-08-0.j, 7.0711e-01-0.j],
       device='cuda:0')

In [137]:
# Rank-1 in-place splits of the outermost block pairs. Only the half that
# faces the inner blocks is kept: `left_node` keeps the right_* axes of the
# first block and `right_node` keeps the left_* axes of the last one. These
# rank-1 factors serve as boundary vectors for the final contraction.
_, left_node = contracted_blocks[0].split_(node1_axes=['left_0', 'left_1'],
                                           node2_axes=['right_0', 'right_1'],
                                           side='left',
                                           rank=1)

right_node, _ = contracted_blocks[-1].split_(node1_axes=['left_0', 'left_1'],
                                             node2_axes=['right_0', 'right_1'],
                                             side='right',
                                             rank=1)

# If left_node/right_node is not positive semidefinite, we can choose another
# left_node/right_node that is positive semidefinite, as one always exists.
# Therefore, if we see that u and vh are positive multiples of -I, we can
# multiply both by -1 to make them positive multiples of I.

In [138]:
# Inspect the left boundary factor (tensor, axis names and edges).
left_node

Node(
 	name: split_ip_1
	tensor:
		tensor([[[7.0711e-01-0.j, 1.2375e-07-0.j],
		         [1.2387e-07-0.j, 7.0711e-01-0.j]]], device='cuda:0')
	axes:
		[split
		 right_0
		 right_1]
	edges:
		[split_ip_0[split] <-> split_ip_1[split]
		 split_ip_1[right_0] <-> contract_edges_ip_0[left_0]
		 split_ip_1[right_1] <-> contract_edges_ip_2[left_1]])

In [139]:
# Inspect the right boundary factor (tensor, axis names and edges).
right_node

Node(
 	name: split_ip_2
	tensor:
		tensor([[[ 7.0711e-01+0.j],
		         [-2.0888e-09+0.j]],
		
		        [[-1.9283e-09+0.j],
		         [ 7.0711e-01+0.j]]], device='cuda:0')
	axes:
		[left_0
		 left_1
		 split]
	edges:
		[contract_edges_ip_2[right_0] <-> split_ip_2[left_0]
		 contract_edges_ip_0[right_1] <-> split_ip_2[left_1]
		 split_ip_2[split] <-> split_ip_3[split]])

In [140]:
# Sandwich the three inner block pairs between the rank-1 boundary factors;
# the result is a scalar node.
result = left_node @ contracted_blocks[1] @ contracted_blocks[2] @ contracted_blocks[3] @ right_node

In [141]:
# Scalar value of the full contraction, rescaled by bond_dim**2.
# NOTE(review): the origin of the bond_dim**2 factor (normalization of the
# boundary factors?) is not visible here -- confirm.
result.tensor.item() * bond_dim**2

(-0.9999990463256836+0j)

# Neural Quantum State

In [1]:
import os

# Select the JAX backend; this must happen before jax is imported (next cell).
os.environ.update(JAX_PLATFORM_NAME="gpu")

In [2]:
# Import netket library
import netket as nk

# Import Json, this will be needed to load log files
import json

# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
import time

import flax.linen as nn
import jax.numpy as jnp
import jax

In [3]:
# Define a 1d chain
# L: number of spin-1 sites. The chain has open boundaries (pbc=False).
# NOTE(review): the MPS section used n_features = 100 sites; this run uses a
# much shorter chain.
L = 10  # number of sites (n_features in the MPS section)
g = nk.graph.Hypercube(length=L, n_dim=1, pbc=False)

In [4]:
# Define the Hilbert space based on this graph
# Spin-1 Hilbert space with one site per graph node (no total-magnetization
# constraint is passed here).
hi = nk.hilbert.Spin(s=1, N=g.n_nodes)

In [5]:
# Spin-1 operators S^z, S^+ and S^- in the basis ordered (m = +1, 0, -1).
# S.S is invariant under relabelling m -> -m, so the bond operator below does
# not depend on which of the two orderings the Hilbert space uses.
sigmaz     = [[1, 0, 0], [0, 0, 0], [0, 0, -1]]
sigmaplus  = [[0, np.sqrt(2), 0], [0, 0, np.sqrt(2)], [0, 0, 0]]
sigmaminus = [[0, 0, 0], [np.sqrt(2), 0, 0], [0, np.sqrt(2), 0]]

# Two-site exchange  S_i . S_j = S^z S^z + (S^+ S^- + S^- S^+) / 2
heisenberg = (np.kron(sigmaz, sigmaz)
              + 0.5 * np.kron(sigmaplus, sigmaminus)
              + 0.5 * np.kron(sigmaminus, sigmaplus))

# AKLT bond term: projector onto total spin 2,
#   P_2 = (S.S) / 2 + (S.S)^2 / 6 + 1 / 3
operator = 0.5 * heisenberg + (heisenberg @ heisenberg) / 6. + np.identity(9) / 3.
operator

array([[1.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.5       , 0.        , 0.5       , 0.        ,
        0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.16666667, 0.        , 0.33333333,
        0.        , 0.16666667, 0.        , 0.        ],
       [0.        , 0.5       , 0.        , 0.5       , 0.        ,
        0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.33333333, 0.        , 0.66666667,
        0.        , 0.33333333, 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.5       , 0.        , 0.5       , 0.        ],
       [0.        , 0.        , 0.16666667, 0.        , 0.33333333,
        0.        , 0.16666667, 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.5       , 0.        , 0.5       , 0.        ],


In [6]:
# AKLT Hamiltonian: one spin-2 projector per bond of the graph `g`, so the
# Hamiltonian and the sampler (which also uses `g`) share the same boundary
# conditions. The previous `(i + 1) % L` indexing added a wrap-around bond
# (L-1, 0), i.e. periodic boundaries, which contradicted `pbc=False` in `g`.
# To study the periodic chain, build `g` with pbc=True instead.
bonds = sorted(tuple(sorted(edge)) for edge in g.edges())
ha = nk.operator.LocalOperator(hilbert=hi,
                               operators=[operator for _ in bonds],
                               acting_on=[list(edge) for edge in bonds])

In [7]:
# RBM ansatz with alpha=1
# NOTE(review): `ma` is reassigned to the custom `Model` in the next cell, so
# this RBM is not the model used by the VMC run below.
ma = nk.models.RBM(alpha=1)

In [8]:
class Model(nn.Module):
    """RBM-like log-amplitude without visible bias: sum_h log cosh(W x + b).

    Attributes:
        alpha: hidden-unit density; the hidden layer has ``alpha * N`` units
            for ``N`` input sites (default 2, as before).
        init_std: standard deviation of the normal initializers used for the
            kernel and the bias (default 0.1, as before).
    """
    alpha: int = 2
    init_std: float = 0.1

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.alpha * x.shape[-1],
                     # NOTE(review): with real parameters the output (log psi)
                     # is real, so psi > 0 for every configuration and no sign
                     # structure can be represented. Uncomment param_dtype to
                     # use complex parameters if the target state needs signs.
                     #  param_dtype=np.complex128,
                     kernel_init=nn.initializers.normal(stddev=self.init_std),
                     bias_init=nn.initializers.normal(stddev=self.init_std))(x)
        x = nk.nn.activation.log_cosh(x)
        # Sum over hidden units -> one log-amplitude per input configuration.
        return jnp.sum(x, axis=-1)
    
ma = Model()

In [None]:
# Build the sampler
sa = nk.sampler.MetropolisExchange(hilbert=hi, graph=g)

# Optimizer
op = nk.optimizer.Sgd(learning_rate=0.05)

# Stochastic Reconfiguration
sr = nk.optimizer.SR(diag_shift=0.1)

# The variational state. Fixed seeds for the parameter initialization and the
# Markov chains make the run reproducible.
SEED = 0
vs = nk.vqs.MCState(sa, ma, n_samples=1000, seed=SEED, sampler_seed=SEED)


gs = nk.VMC(
    hamiltonian=ha,
    optimizer=op,
    preconditioner=sr,
    variational_state=vs)

# perf_counter is the monotonic clock intended for measuring elapsed time.
# `out` is the prefix of the log files written by the run.
start = time.perf_counter()
gs.run(out='RBM', n_iter=600)
end = time.perf_counter()

print('### RBM calculation')
print('Has',vs.n_parameters,'parameters')
print('The RBM calculation took',end-start,'seconds')

  self.n_samples = n_samples


  0%|          | 0/600 [00:00<?, ?it/s]

### RBM calculation
Has 220 parameters
The RBM calculation took 43.622886419296265 seconds
