In [165]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import TensorDataset, DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import tensorkrowch as tk
import tensorkrowch.embeddings as embeddings
from tensorkrowch.decompositions import tt_rss

from math import log, pi, sqrt
import time

import warnings
from typing import Optional, Callable, List

import os
import sys

In [251]:
# Three 2x2 matrices, one per value of the middle (physical) index of the core.
amp_offdiag = sqrt(2 / 3)   # magnitude of the single off-diagonal entry
amp_diag = sqrt(1 / 3)      # magnitude of the diagonal entries

A_plus = torch.tensor([[0, amp_offdiag], [0, 0]])
A_zero = torch.tensor([[-amp_diag, 0], [0, amp_diag]])
A_minus = torch.tensor([[0, 0], [-amp_offdiag, 0]])

# Stacking along dim=1 gives a core of shape (2, 3, 2)
aklt_core = torch.stack((A_plus, A_zero, A_minus), dim=1)
aklt_core

tensor([[[ 0.0000,  0.8165],
         [-0.5774,  0.0000],
         [ 0.0000,  0.0000]],

        [[ 0.0000,  0.0000],
         [ 0.0000,  0.5774],
         [-0.8165,  0.0000]]])

In [252]:
aklt_core.norm() ** 2

tensor(2.0000)

In [253]:
# Model hyperparameters
# n_features is the number of sites (one core per site is built from it below);
# the physical and bond dimensions are read off the core's axes 1 and 0.
n_features = 100
phys_dim = aklt_core.shape[1]
bond_dim = aklt_core.shape[0]

# Base name used to build the file name of the stored cores
fn_name = 'aklt_mps'

# Set device
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = 'cpu'

In [254]:
# Indices selected on the open left/right bond of the first/last core
boundary_conditions = [0, 0]
# Every core is multiplied by sqrt(phys_dim)
scale = sqrt(phys_dim)
# One rescaled copy of the core per site, moved to `device`
aklt_cores = [aklt_core.to(device) * scale for _ in range(n_features)]
# Unscaled alternative:
# aklt_cores = [aklt_core.to(device) for _ in range(n_features)]
# Fix the left bond of the first core and the right bond of the last core,
# turning both boundary cores into 2-index tensors
aklt_cores[0] = aklt_cores[0][boundary_conditions[0], :, :]
aklt_cores[-1] = aklt_cores[-1][:, :, boundary_conditions[1]]

In [255]:
# def embedding(x):
#     x = tk.embeddings.poly(x, degree=phys_dim - 1)
#     return x

# def embedding(x):
#     x = tk.embeddings.unit(x, dim=phys_dim)
#     return x

def embedding(x):
    """Embed raw features for the MPS.

    The input is first discretized with ``tk.embeddings.discretize``
    (``base=phys_dim``, ``level=1``), the trailing axis is squeezed and the
    result cast to int; these values are then passed to
    ``tk.embeddings.basis`` with ``dim=phys_dim`` and returned as floats.
    """
    digits = tk.embeddings.discretize(x, base=phys_dim, level=1)
    indices = digits.squeeze(-1).int()
    # batch x n_features x dim
    return tk.embeddings.basis(indices, dim=phys_dim).float()

# def embedding(x):
#     x = tk.embeddings.basis(x, dim=phys_dim).float()
#     return x

In [256]:
# Initialize network
# Build the MPS directly from the list of per-site cores
mps = tk.models.MPS(tensors=aklt_cores)

# Disable parameters
mps.parameterize(set_param=False, override=True)

# In the recorded run the plain norm printed `nan` while the log-scale norm
# was finite.
# NOTE(review): presumably the plain contraction overflows float32 for this
# many rescaled cores - confirm.
print('norm:', mps.norm())
print('log-scale norm:', mps.norm(log_scale=True))

# Trace the model once with a dummy all-zeros batch of one sample
mps.trace(torch.zeros(1, n_features, phys_dim, device=device))

norm: tensor(nan, device='cuda:0')
log-scale norm: tensor(54.5840, device='cuda:0')


In [257]:
def numel_mps(model):
    """Return the total number of tensor elements over ``model.mats_env``.

    Only the nodes listed in ``model.mats_env`` are counted; each one
    contributes ``node.tensor.numel()``.
    """
    return sum(node.tensor.numel() for node in model.mats_env)

print('Nº params. mps:', numel_mps(mps))

Nº params. mps: 1200


In [258]:
def fn_aklt(input):
    """Evaluate ``mps`` on raw samples and return the result as a column.

    ``input`` is passed through ``embedding`` before being fed to the model;
    the model output is reshaped to ``(-1, 1)``.
    """
    embedded = embedding(input)
    return mps(embedded).view(-1, 1)

In [36]:
# Standard deviation of the Gaussian noise added to the model output
std = 1e-10

def fn_aklt(input):
    """Noisy variant of ``fn_aklt``: model output plus Gaussian noise.

    Evaluates ``mps`` on the embedded ``input``, reshapes the result to
    ``(-1, 1)`` and adds ``std``-scaled standard-normal noise of the same
    shape. This reuses the name of the noise-free ``fn_aklt`` from the
    previous cell, so whichever cell was executed last is the one in effect.
    """
    clean = mps(embedding(input)).view(-1, 1)
    noise = torch.randn_like(clean) * std
    return clean + noise

## Random dataset

In [10]:
# dataset samples
# Draw `samples_size` rows of `n_features` values uniformly from [0, 1).
# No random seed is set in this cell, so the draw changes between runs.
samples_size = 100
# Integer-valued alternative:
# samples = torch.randint(low=0, high=phys_dim, size=(samples_size, n_features))
samples = torch.rand(size=(samples_size, n_features))

In [11]:
fn_aklt(samples.to(device))

tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
      

In [12]:
# Decompose `fn_aklt` into a list of cores with `tt_rss`, using the random
# `samples` as sketch samples. The rank passed is 5 * bond_dim, i.e. five
# times the bond dimension of the reference cores.
# NOTE(review): `cum_percentage` presumably controls how much of the singular
# value spectrum is kept at each site - confirm against the tt_rss docs.
cores = tt_rss(function=fn_aklt,
               embedding=embedding,
               sketch_samples=samples,
               rank=5*bond_dim,
               cum_percentage=0.999,
               batch_size=500,
               device=device,
               verbose=True)

# Save cores
# torch.save(cores, f'cores/{fn_name}.pt')



Site: 0
* Max D_k: min(3, 10)
* T_k out dim: 100

Core 0:
-------
tensor([[1.],
        [0.],
        [0.]])
* Final D_k: 1
* S_k out dim: 100


Site: 1
* Max D_k: min(3, 10)
* T_k out dim: 100

Core 1:
-------
tensor([[[0.0385],
         [-0.0000],
         [-0.0000]]])
* Final D_k: 1
* S_k out dim: 100


Site: 2
* Max D_k: min(3, 10)
* T_k out dim: 100

Core 2:
-------
tensor([[[1.],
         [0.],
         [0.]]])
* Final D_k: 1
* S_k out dim: 100


Site: 3
* Max D_k: min(3, 10)
* T_k out dim: 100

Core 3:
-------
tensor([[[1.],
         [0.],
         [0.]]])
* Final D_k: 1
* S_k out dim: 100


Site: 4
* Max D_k: min(3, 10)
* T_k out dim: 100

Core 4:
-------
tensor([[[0.],
         [0.],
         [0.]]])
* Final D_k: 1
* S_k out dim: 100


Site: 5
* Max D_k: min(3, 10)
* T_k out dim: 100

Core 5:
-------
tensor([[[0.],
         [0.],
         [0.]]])
* Final D_k: 1
* S_k out dim: 100


Site: 6
* Max D_k: min(3, 10)
* T_k out dim: 100

Core 6:
-------
tensor([[[0.],
         [0.]

In [24]:
# Load cores
# cores = torch.load(f'cores/{fn_name}.pt')

In [13]:
# Initialize tensorized model
# Build a second MPS from the cores returned by tt_rss, moved to `device`
mps2 = tk.models.MPS(tensors=[c.to(device) for c in cores])

# Trace the model once with a dummy all-zeros batch of one sample
mps2.trace(torch.zeros(1, n_features, phys_dim, device=device))

# Disable parameters
# Here gradients are switched off directly on each parameter (after tracing)
for p in mps2.parameters():
    p.requires_grad_(False)

In [14]:
mps.bond_dim

[2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2]

In [15]:
mps2.bond_dim

[1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1]

In [16]:
mps.reset()
mps.norm(), mps.norm(log_scale=True)

(tensor(nan, device='cuda:0', grad_fn=<SqrtBackward0>),
 tensor(54.5840, device='cuda:0', grad_fn=<DivBackward0>))

In [17]:
mps2.reset()
mps2.norm(), mps2.norm(log_scale=True)

(tensor(0., device='cuda:0', grad_fn=<SqrtBackward0>),
 tensor(-inf, device='cuda:0', grad_fn=<DivBackward0>))

In [18]:
# dataset samples
samples_size = 100
# samples = torch.randint(low=0, high=phys_dim, size=(samples_size, n_features))
samples = torch.rand(size=(samples_size, n_features))

In [19]:
# Reset both networks before evaluating them on the same batch
for model in (mps, mps2):
    model.reset()

# Evaluate the reference model and the reconstructed model without autograd
with torch.no_grad():
    device_samples = samples.to(device)
    mps_results = mps(embedding(device_samples))
    mps2_results = mps2(embedding(device_samples))

In [20]:
# Column 0: reference model outputs, column 1: reconstructed model outputs
results = torch.stack([mps_results, mps2_results], dim=1)
# Euclidean norm of the difference between the two columns
(results[:, 0] - results[:, 1]).pow(2).sum().sqrt()

tensor(0., device='cuda:0')

In [21]:
results

tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0

## Generated dataset

In [136]:
def aux_embedding(x):
    """Embed ``x`` with ``tk.embeddings.basis`` (``dim=phys_dim``) as floats.

    Unlike ``embedding``, no discretization step is applied first, so ``x``
    is handed to the basis embedding as given.
    """
    basis_vectors = tk.embeddings.basis(x, dim=phys_dim)
    return basis_vectors.float()

In [137]:
# Build `samples_size` samples one site at a time, each new value drawn from a
# distribution computed by the model given the values already drawn.
samples_size = 200
# Start from an empty integer tensor; one column is appended per site below
samples = torch.tensor([]).int()

for i in range(n_features):
    # Make the first i + 1 sites the input sites of the model
    mps.unset_data_nodes()
    mps.reset()
    mps.in_features = torch.arange(i + 1).tolist()
    
    # Candidate values 0..phys_dim-1 for site i, tiled once per sample
    new_feature = torch.arange(phys_dim).view(-1, 1)
    new_feature = new_feature.repeat(samples_size, 1)
    
    if i > 0:
        # Repeat every partial sample phys_dim times (in consecutive rows) so
        # that it can be paired with each candidate value of the new site
        aux_samples = samples.repeat(1, phys_dim)
        aux_samples = aux_samples.reshape(samples_size * phys_dim, i)
    else:
        aux_samples = samples
    
    aux_samples = torch.cat([aux_samples, new_feature], dim=1)
    
    # NOTE(review): with marginalize_output=True the model presumably returns
    # a (rows x rows) matrix whose diagonal holds one weight per input row -
    # confirm against the TensorKrowch MPS.forward documentation.
    density = mps(aux_embedding(aux_samples.to(device)),
                  marginalize_output=True,
                  renormalize=True,
                  )
    
    # At the last site the result is turned into a matrix via an outer product
    # so that the diagonal extraction below applies unchanged
    if i == (n_features - 1):
        density = torch.outer(density, density)
    
    # One row per sample, one column per candidate value; normalise each row
    # and accumulate it into a cumulative distribution
    distr = density.diagonal().reshape(samples_size, phys_dim)
    distr = distr / distr.sum(dim=-1, keepdim=True)
    distr = distr.cumsum(dim=-1)
    
    # Inverse-CDF draw: the number of cumulative thresholds above the uniform
    # draw is subtracted from phys_dim to obtain the sampled index
    probs = torch.rand(samples_size, 1).to(device)
    new_samples = phys_dim - (probs < distr).sum(dim=-1)
    
    if i > 0:
        samples = torch.cat([samples,
                             new_samples.cpu().int().reshape(-1, 1)], dim=1)
    else:
        samples = new_samples.cpu().int().reshape(-1, 1)

# Map the integer values to floats k / phys_dim
samples = samples / phys_dim

# Re-trace the model with a dummy all-zeros batch covering all sites
mps.reset()
mps.trace(torch.zeros(1, n_features, phys_dim, device=device))

In [138]:
samples[0]

tensor([0.3333, 0.3333, 0.0000, 0.6667, 0.3333, 0.3333, 0.0000, 0.3333, 0.3333,
        0.3333, 0.6667, 0.0000, 0.3333, 0.6667, 0.3333, 0.0000, 0.6667, 0.0000,
        0.6667, 0.0000, 0.6667, 0.0000, 0.3333, 0.3333, 0.3333, 0.3333, 0.6667,
        0.0000, 0.6667, 0.0000, 0.6667, 0.0000, 0.6667, 0.3333, 0.0000, 0.3333,
        0.3333, 0.3333, 0.3333, 0.3333, 0.3333, 0.6667, 0.0000, 0.3333, 0.3333,
        0.6667, 0.3333, 0.0000, 0.3333, 0.3333, 0.6667, 0.3333, 0.3333, 0.0000,
        0.3333, 0.3333, 0.6667, 0.3333, 0.0000, 0.6667, 0.3333, 0.3333, 0.0000,
        0.3333, 0.3333, 0.6667, 0.3333, 0.3333, 0.0000, 0.3333, 0.6667, 0.3333,
        0.3333, 0.0000, 0.6667, 0.3333, 0.3333, 0.3333, 0.3333, 0.0000, 0.3333,
        0.3333, 0.3333, 0.3333, 0.3333, 0.6667, 0.3333, 0.3333, 0.0000, 0.3333,
        0.6667, 0.0000, 0.3333, 0.6667, 0.0000, 0.3333, 0.3333, 0.3333, 0.3333,
        0.6667])

In [139]:
fn_aklt(samples[:20].to(device))

tensor([[ 4.1943e+06],
        [ 3.3554e+07],
        [ 8.3886e+06],
        [ 1.3422e+08],
        [ 3.3554e+07],
        [ 6.7109e+07],
        [-2.6843e+08],
        [ 2.6843e+08],
        [ 3.3554e+07],
        [ 2.6843e+08],
        [ 6.7109e+07],
        [ 1.3422e+08],
        [ 5.3687e+08],
        [ 1.3422e+08],
        [-3.3554e+07],
        [-6.7109e+07],
        [-3.3554e+07],
        [-3.3554e+07],
        [-2.0971e+06],
        [-1.3422e+08]], device='cuda:0')

In [22]:
# Decompose `fn_aklt` again, this time sketching with the first half of the
# generated samples and passing the explicit domain {k / phys_dim}.
cores = tt_rss(function=fn_aklt,
               embedding=embedding,
               sketch_samples=samples[:(samples_size // 2)],
               domain=torch.arange(phys_dim).float() / phys_dim,
               rank=5*bond_dim,
               cum_percentage=0.9,
               batch_size=500,
               device=device,
               verbose=True)

# Save cores
# The results directory is resolved relative to the current working directory
cwd = os.getcwd()
results_dir = os.path.join(cwd, '..', '..', 'results', '4_interpretability')

# torch.save(cores, os.path.join(results_dir, f'{fn_name}.pt'))



|| Site: 1 / 100 ||
* Max D_k: min(3, 10)
* T_k out dim: 100

Core 1:
-------
tensor([[-1.0000e+00,  4.2975e-08],
        [ 4.2975e-08,  1.0000e+00],
        [ 0.0000e+00,  0.0000e+00]])
* Final D_k: 2
* S_k out dim: 2


|| Site: 2 / 100 ||
* Max D_k: min(6, 10)
* T_k out dim: 100

Core 2:
-------
tensor([[[ 6.6229e-15, -3.0388e-08],
         [ 3.1423e-07, -7.0711e-01],
         [ 8.9443e-01,  7.3135e-08]],

        [[ 1.5411e-07, -7.0711e-01],
         [ 4.4721e-01,  5.2638e-08],
         [-3.8438e-08, -3.1430e-15]]])
* Final D_k: 2
* S_k out dim: 4


|| Site: 3 / 100 ||
* Max D_k: min(6, 10)
* T_k out dim: 100

Core 3:
-------
tensor([[[-8.4515e-01,  8.5630e-08],
         [-3.1860e-08, -6.2017e-01],
         [ 2.0139e-15,  1.2847e-08]],

        [[-2.4881e-07,  2.2705e-14],
         [-5.3452e-01, -1.2873e-07],
         [-7.7972e-08, -7.8446e-01]]])
* Final D_k: 2
* S_k out dim: 8


|| Site: 4 / 100 ||
* Max D_k: min(6, 10)
* T_k out dim: 100

Core 4:
-------
tensor([[[-3.9877e-15, 

In [154]:
# Load cores
# The results directory is resolved relative to the current working directory
cwd = os.getcwd()
results_dir = os.path.join(cwd, '..', '..', 'results', '4_interpretability')

# SECURITY: weights_only=False allows arbitrary pickled objects to be
# executed on load; only use this with files from a trusted source.
cores = torch.load(os.path.join(results_dir, f'{fn_name}.pt'),
                   weights_only=False)

In [155]:
# Initialize tensorized model
# Build the second MPS from the loaded cores, moved to `device`
mps2 = tk.models.MPS(tensors=[c.to(device) for c in cores])

# Disable parameters
mps2.parameterize(set_param=False, override=True)

# Trace the model once with a dummy all-zeros batch of one sample
mps2.trace(torch.zeros(1, n_features, phys_dim, device=device))

In [130]:
torch.tensor(mps.bond_dim)

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2])

In [131]:
torch.tensor(mps2.bond_dim)

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2])

In [132]:
mps.reset()
mps.norm(), mps.norm(log_scale=True)

(tensor(nan, device='cuda:0'), tensor(54.5840, device='cuda:0'))

In [133]:
mps2.reset()
mps2.norm(), mps2.norm(log_scale=True)

(tensor(inf, device='cuda:0'), tensor(54.5840, device='cuda:0'))

In [140]:
# Reset both networks before evaluating them on the same batch
for model in (mps, mps2):
    model.reset()

# Evaluate the reference model and the reconstructed model without autograd
with torch.no_grad():
    device_samples = samples.to(device)
    mps_results = mps(embedding(device_samples))
    mps2_results = mps2(embedding(device_samples))

In [141]:
# Column 0: reference model outputs, column 1: reconstructed model outputs
results = torch.stack([mps_results, mps2_results], dim=1)
# Euclidean norm of the element-wise relative differences; rows where the
# reference value is exactly zero produce inf/nan here
((results[:, 0] - results[:, 1]) / results[:, 0]).pow(2).sum().sqrt()

tensor(2.4105e-05, device='cuda:0')

In [142]:
results

tensor([[ 4.1943e+06,  4.1943e+06],
        [ 3.3554e+07,  3.3554e+07],
        [ 8.3886e+06,  8.3886e+06],
        [ 1.3422e+08,  1.3422e+08],
        [ 3.3554e+07,  3.3554e+07],
        [ 6.7109e+07,  6.7108e+07],
        [-2.6843e+08, -2.6843e+08],
        [ 2.6843e+08,  2.6843e+08],
        [ 3.3554e+07,  3.3554e+07],
        [ 2.6843e+08,  2.6843e+08],
        [ 6.7109e+07,  6.7109e+07],
        [ 1.3422e+08,  1.3422e+08],
        [ 5.3687e+08,  5.3687e+08],
        [ 1.3422e+08,  1.3422e+08],
        [-3.3554e+07, -3.3554e+07],
        [-6.7109e+07, -6.7109e+07],
        [-3.3554e+07, -3.3554e+07],
        [-3.3554e+07, -3.3554e+07],
        [-2.0971e+06, -2.0971e+06],
        [-1.3422e+08, -1.3422e+08],
        [ 1.3422e+08,  1.3422e+08],
        [-1.6777e+07, -1.6777e+07],
        [ 1.3422e+08,  1.3422e+08],
        [ 2.6843e+08,  2.6843e+08],
        [ 3.3554e+07,  3.3554e+07],
        [ 2.6843e+08,  2.6843e+08],
        [ 2.0971e+06,  2.0971e+06],
        [-3.3554e+07, -3.355

In [143]:
# Largest relative deviation between the two models. The denominator must be
# the magnitude of the reference value: dividing by the signed value makes the
# ratio negative for negative references, and `.max()` would then ignore them.
((results[:, 0] - results[:, 1]).abs() / results[:, 0].abs()).max()

tensor(4.1127e-06, device='cuda:0')

In [144]:
# dataset samples
# Fresh uniform random samples in [0, 1), as opposed to the samples generated
# from the model above; no random seed is set here
samples_size = 20
samples = torch.rand(size=(samples_size, n_features))

mps.reset()
mps2.reset()

# Evaluate both models on the same batch without tracking gradients
with torch.no_grad():
    mps_results = mps(embedding(samples.to(device)))
    mps2_results = mps2(embedding(samples.to(device)))

In [145]:
results = torch.stack([mps_results, mps2_results], dim=1)
(results[:, 0] - results[:, 1]).pow(2).sum().sqrt()

tensor(0., device='cuda:0')

In [146]:
results

tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]], device='cuda:0')

In [147]:
# Largest absolute deviation between the two models. Without `.abs()` the
# maximum of the signed differences would ignore negative deviations.
(results[:, 0] - results[:, 1]).abs().max()

tensor(0., device='cuda:0')

In [156]:
# Detach the data nodes of both networks so their input edges are free
mps.reset()
mps.unset_data_nodes()
mps2.reset()
mps2.unset_data_nodes()

# Log-scale norm of each model. The second one must come from `mps2`;
# it was previously computed from `mps`, which made the normalised overlap
# computed later compare `mps` against its own norm twice.
mps_norm = mps.norm(log_scale=True)
mps2_norm = mps2.norm(log_scale=True)

mps.reset()
mps2.reset()

# Connect the input edge of every site of `mps` to the matching site of `mps2`
for node1, node2 in zip(mps.mats_env, mps2.mats_env):
    node1['input'] ^ node2['input']

In [157]:
# Accumulates the logs of the norms factored out during the contraction
log_scale = 0

# Contract mps with mps_rss
# Stack the site tensors of each network and contract the stacks site-wise
stack = tk.stack(mps.mats_env)
stack_rss = tk.stack(mps2.mats_env)
stack ^ stack_rss

mats_results = tk.unbind(stack @ stack_rss)

# Absorb the boundary nodes of both networks into the first and last matrices
mats_results[0] = mps.left_node @ (mps2.left_node @ mats_results[0])
mats_results[-1] = (mats_results[-1] @ mps.right_node) @ mps2.right_node

# Multiply the per-site matrices in line, renormalising after every product
# and keeping track of the removed scale in log_scale
result = mats_results[0]
for mat in mats_results[1:]:
    result @= mat
    
    log_scale += result.norm().log()
    result = result.renormalize()

# Half of the log of the full contraction value.
# NOTE(review): `.log()` of the final value is nan if that value is negative.
approx_mps_norm = (result.tensor.log() + log_scale) / 2
print(approx_mps_norm.item())

54.58402633666992


In [158]:
# exp(log of the mps-mps2 contraction - mps_norm - mps2_norm).
# NOTE(review): this is the overlap divided by the product of the two norms
# only if `norm(log_scale=True)` returns the log of the norm - confirm.
(2*approx_mps_norm - mps_norm - mps2_norm).exp()

tensor(1.0000, device='cuda:0')

In [39]:
mps.canonicalize(oc=0)
mps.canonicalize(oc=n_features - 1)

In [40]:
mps2.canonicalize(oc=0)
mps2.canonicalize(oc=n_features - 1)

In [41]:
for node, node2 in zip(mps.mats_env, mps2.mats_env):
    print(node.tensor, node2.tensor)
    print()

Parameter containing:
tensor([[[-1.0000e+00,  1.6182e-07],
         [-9.3363e-08, -1.0000e+00],
         [ 0.0000e+00,  0.0000e+00]],

        [[ 0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00]]], device='cuda:0', requires_grad=True) Parameter containing:
tensor([[[ 1.0000e+00,  1.3671e-06],
         [-1.4672e-06,  1.0000e+00],
         [ 0.0000e+00,  0.0000e+00]],

        [[ 0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00]]], device='cuda:0', requires_grad=True)

Parameter containing:
tensor([[[ 1.2621e-09,  1.2176e-07],
         [-9.1366e-07, -7.0711e-01],
         [ 8.9443e-01, -7.3831e-07]],

        [[ 3.6841e-07,  7.0711e-01],
         [-4.4721e-01,  5.5046e-07],
         [-9.6147e-09,  6.4437e-08]]], device='cuda:0', requires_grad=True) Parameter containing:
tensor([[[ 2.5425e-09, -1.2315e-06],
         [ 1.5605e-06, -7.0711e-01],
         [ 8.9443e-01,  1.0476e-06]],

        [

In [42]:
# Print the tensors of both models side by side for the central sites only,
# dropping the same number of sites from each end
margin = n_features // 2 - 5
window = slice(margin, -margin)

for node, node2 in zip(mps.mats_env[window], mps2.mats_env[window]):
    print(node.tensor, node2.tensor)
    print()

Parameter containing:
tensor([[[-0.4149, -0.4158],
         [ 0.5773, -0.0095],
         [-0.4015,  0.4006]],

        [[ 0.4006,  0.4015],
         [ 0.0095,  0.5773],
         [-0.4158,  0.4149]]], device='cuda:0', requires_grad=True) Parameter containing:
tensor([[[ 0.6458,  0.0369],
         [-0.3779,  0.4365],
         [ 0.0284, -0.4975]],

        [[ 0.4975,  0.0284],
         [ 0.4365,  0.3779],
         [-0.0369,  0.6458]]], device='cuda:0', requires_grad=True)

Parameter containing:
tensor([[[ 0.4102, -0.4071],
         [ 0.5773,  0.0028],
         [ 0.4063,  0.4094]],

        [[-0.4094,  0.4063],
         [ 0.0028, -0.5773],
         [ 0.4071,  0.4102]]], device='cuda:0', requires_grad=True) Parameter containing:
tensor([[[ 0.0064, -0.0462],
         [ 0.1121, -0.5664],
         [ 0.8074,  0.1124]],

        [[-0.1124,  0.8074],
         [-0.5664, -0.1121],
         [ 0.0462,  0.0064]]], device='cuda:0', requires_grad=True)

Parameter containing:
tensor([[[-0.4039,  0.4095],

In [43]:
m1 = mps.mats_env[n_features//2]
m2 = mps2.mats_env[n_features//2]

m1.tensor, m2.tensor

(Parameter containing:
 tensor([[[ 0.5803,  0.0069],
          [ 0.4010,  0.4156],
          [-0.0071,  0.5776]],
 
         [[ 0.5743,  0.0069],
          [-0.4153,  0.4008],
          [ 0.0071, -0.5770]]], device='cuda:0', requires_grad=True),
 Parameter containing:
 tensor([[[ 0.5782, -0.2442],
          [-0.5129, -0.2651],
          [ 0.2032,  0.4811]],
 
         [[-0.4811,  0.2032],
          [-0.2651,  0.5129],
          [ 0.2442,  0.5782]]], device='cuda:0', requires_grad=True))

In [44]:
m1.norm(), m2.norm()

(tensor(1.4142, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>),
 tensor(1.4142, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>))

In [45]:
m1.norm() * m2.norm()

tensor(2.0000, device='cuda:0', grad_fn=<MulBackward0>)

In [46]:
out = torch.einsum('lir,lir->', m1.tensor, m2.tensor)
out

tensor(0.0033, device='cuda:0', grad_fn=<ViewBackward0>)

In [47]:
out = torch.einsum('lir,ljr->ij', m1.tensor, m1.tensor)
out

tensor([[ 6.6667e-01, -1.6084e-04, -1.4589e-06],
        [-1.6084e-04,  6.6668e-01,  3.0558e-03],
        [-1.4589e-06,  3.0558e-03,  6.6665e-01]], device='cuda:0',
       grad_fn=<ViewBackward0>)

In [70]:
mps.reset()
mps.norm(), mps.norm(log_scale=True)

(tensor(nan, device='cuda:0'), tensor(54.5840, device='cuda:0'))

In [71]:
mps2.reset()
mps2.norm(), mps2.norm(log_scale=True)

(tensor(inf, device='cuda:0'), tensor(54.5840, device='cuda:0'))

In [54]:
mps.reset(), mps2.reset()

(None, None)

In [72]:
# Initialize network
mps = tk.models.MPS(tensors=aklt_cores)

# Disable parameters
_ = mps.parameterize(set_param=False, override=True)

# mps.trace(torch.zeros(1, n_features, phys_dim, device=device))

In [73]:
# Load cores
# The results directory is resolved relative to the current working directory
cwd = os.getcwd()
results_dir = os.path.join(cwd, '..', '..', 'results', '4_interpretability')

# SECURITY: weights_only=False allows arbitrary pickled objects to be
# executed on load; only use this with files from a trusted source.
cores = torch.load(os.path.join(results_dir, f'{fn_name}.pt'),
                   weights_only=False)

# Initialize tensorized model
mps2 = tk.models.MPS(tensors=[c.to(device) for c in cores])

# Disable parameters
# The return value is bound to `_` only to keep it out of the cell output
_ = mps2.parameterize(set_param=False, override=True)

# mps2.trace(torch.zeros(1, n_features, phys_dim, device=device))

In [74]:
for node, node2 in zip(mps.mats_env, mps2.mats_env):
    node['input'] ^ node2['input']

In [75]:
# Accumulates the logs of the norms factored out during the contraction
log_scale = 0

# Contract mps with mps2
# Stack the site tensors of each network and contract the stacks site-wise
stack = tk.stack(mps.mats_env)
stack2 = tk.stack(mps2.mats_env)
stack ^ stack2

mats_results = tk.unbind(stack @ stack2)

# Absorb the boundary nodes of both networks into the first and last matrices
mats_results[0] = mps.left_node @ (mps2.left_node @ mats_results[0])
mats_results[-1] = (mats_results[-1] @ mps.right_node) @ mps2.right_node

# Multiply the per-site matrices in line, renormalising after every product
# and keeping track of the removed scale in log_scale
result = mats_results[0]
for mat in mats_results[1:]:
    result @= mat
    
    log_scale += result.norm().log()
    result = result.renormalize()

# Half of the log of the full contraction value.
# NOTE(review): `.log()` of the final value is nan if that value is negative.
approx_mps_norm = (result.tensor.log() + log_scale) / 2
print(approx_mps_norm.item())

54.58402633666992


In [76]:
# Contract mps with mps2
# Same contraction as the log-scale version, but without renormalising the
# intermediate products; the recorded output of this cell is `inf`.
stack = tk.stack(mps.mats_env)
stack2 = tk.stack(mps2.mats_env)
stack ^ stack2

mats_results = tk.unbind(stack @ stack2)

# Absorb the boundary nodes of both networks into the first and last matrices
mats_results[0] = mps.left_node @ (mps2.left_node @ mats_results[0])
mats_results[-1] = (mats_results[-1] @ mps.right_node) @ mps2.right_node

# Multiply the per-site matrices in line
result = mats_results[0]
for mat in mats_results[1:]:
    result @= mat

print(result.tensor)

tensor(inf, device='cuda:0')


In [77]:
0.7071 * 0.7071

0.49999040999999994

## Topological phase estimation (mps)

In [331]:
# Rebuild the reference model from scratch for this section, this time with
# complex64 cores and without rescaling them (scale = 1).
A_plus = torch.tensor([[0, sqrt(2 / 3)],
                       [0, 0]])
A_zero = torch.tensor([[-sqrt(1 / 3), 0],
                       [0, sqrt(1 / 3)]])
A_minus = torch.tensor([[0, 0],
                        [-sqrt(2 / 3), 0]])

# Stack along dim=1 and cast to complex64
aklt_core = torch.stack([A_plus, A_zero, A_minus], dim=1).to(torch.complex64)

# Model hyperparameters
n_features = 100 #500
phys_dim = aklt_core.shape[1]
bond_dim = aklt_core.shape[0]

fn_name = 'aklt_mps'

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# device = 'cpu'
# Indices selected on the open left/right bond of the first/last core
boundary_conditions = [0, 0]
scale = 1 #sqrt(phys_dim)
# One copy of the core per site, moved to `device`
aklt_cores = [aklt_core.to(device) * scale for _ in range(n_features)]
# aklt_cores = [aklt_core.to(device) for _ in range(n_features)]
aklt_cores[0] = aklt_cores[0][boundary_conditions[0], :, :]
aklt_cores[-1] = aklt_cores[-1][:, :, boundary_conditions[1]]


# Initialize network
mps = tk.models.MPS(tensors=aklt_cores)

# Disable parameters
# The return value is bound to `_` only to keep it out of the cell output
_ = mps.parameterize(set_param=False, override=True)

print(mps.norm())
mps.reset()

# mps.trace(torch.zeros(1, n_features, phys_dim, device=device))

tensor(0.7071, device='cuda:0')


In [332]:
# Three 2x2 matrices with the pattern of the Pauli matrices, each scaled by
# 1 / sqrt(3); sigma_y is the only one with imaginary entries.
inv_sqrt3 = sqrt(1 / 3)

sigma_x = torch.tensor([[0, inv_sqrt3], [inv_sqrt3, 0]])
sigma_y = torch.tensor([[0, -inv_sqrt3*1j], [inv_sqrt3*1j, 0]])
sigma_z = torch.tensor([[inv_sqrt3, 0], [0, -inv_sqrt3]])

# Stack along dim=1, matching the axis layout of `aklt_core`
aklt_core_pauli = torch.stack((sigma_x, sigma_y, sigma_z), dim=1)
aklt_core_pauli

tensor([[[ 0.0000+0.0000j,  0.5774+0.0000j],
         [ 0.0000+0.0000j, -0.0000-0.5774j],
         [ 0.5774+0.0000j,  0.0000+0.0000j]],

        [[ 0.5774+0.0000j,  0.0000+0.0000j],
         [ 0.0000+0.5774j,  0.0000+0.0000j],
         [ 0.0000+0.0000j, -0.5774+0.0000j]]])

In [333]:
# Least-squares solution U of A @ U = B, where A and B are the two cores
# reshaped to (4, 3): rows run over the (left, right) bond index pairs and
# columns over the physical index. U is therefore a 3x3 matrix acting on the
# physical index that maps `aklt_core` onto `aklt_core_pauli`.
U = torch.linalg.lstsq(
    aklt_core.permute(0, 2, 1).reshape(4, 3).to(aklt_core_pauli.dtype),
    aklt_core_pauli.permute(0, 2, 1).reshape(4, 3)).solution

U

tensor([[ 0.7071+0.0000j,  0.0000-0.7071j,  0.0000-0.0000j],
        [ 0.0000-0.0000j,  0.0000-0.0000j, -1.0000-0.0000j],
        [-0.7071-0.0000j,  0.0000-0.7071j,  0.0000-0.0000j]])

In [334]:
U @ U.H

tensor([[1.0000+0.j, 0.0000+0.j, 0.0000+0.j],
        [0.0000+0.j, 1.0000+0.j, 0.0000+0.j],
        [0.0000+0.j, 0.0000+0.j, 1.0000+0.j]])

In [335]:
# Number of sites per block
L = 10
aux_mps = mps
# Second copy of the model sharing the same tensors
aux_mps_copy = aux_mps.copy(share_tensors=True)
# NOTE(review): presumably moving the left node carries the connected nodes
# of the copy into `aux_mps`'s network as well - confirm in the tk docs.
aux_mps_copy.left_node.move_to_network(aux_mps)

# Diagonal sign matrices acting on the physical index
ux = torch.Tensor([[1, 0, 0],
                   [0, -1, 0],
                   [0, 0, -1]]).to(torch.complex64)
uz = torch.Tensor([[-1, 0, 0],
                   [0, -1, 0],
                   [0, 0, 1]]).to(torch.complex64)

# Conjugate both operators with U and move them to `device`
ux = (U @ ux @ U.H).to(device)
uz = (U @ uz @ U.H).to(device)


# Five consecutive blocks of L sites each, centred on the middle of the chain
start = (n_features // 2) - (L * 5 // 2)
node_blocks = [aux_mps.mats_env[(start + i*L):(start + (i+1)*L)]
               for i in range(5)]
node_blocks_copy = [aux_mps_copy.mats_env[(start + i*L):(start + (i+1)*L)]
                    for i in range(5)]

In [336]:
# Attach a `uz` operator node to the input edge of every site in blocks 1
# and 2 of the original network; the contraction itself happens later.
for i in [1, 2]:
    for j, node in enumerate(node_blocks[i]):
        node_z = tk.Node(tensor=uz,
                         name=f'node_z_({i}_{j})',
                         axes_names=('input', 'output'),
                         network=aux_mps)
        node['input'] ^ node_z['output']
        
        # node_blocks[i][j] = node @ node_z

In [337]:
# Attach a `ux` operator node to the input edge of every site in blocks 1
# and 2 of the copied network; the contraction itself happens later.
for i in [1, 2]:
    for j, node in enumerate(node_blocks_copy[i]):
        node_x = tk.Node(tensor=ux,
                         name=f'node_x_({i}_{j})',
                         axes_names=('input', 'output'),
                         network=aux_mps)
        node['input'] ^ node_x['output']
        
        # node_blocks_copy[i][j] = node @ node_x

In [338]:
# Block i1 of the original network is paired with block i2 of the copy:
# blocks 1 and 3 are swapped, the other blocks are paired with themselves.
i1_lst = [0, 1, 2, 3, 4]
i2_lst = [0, 3, 2, 1, 4]

contracted_blocks = []
for i1, i2 in zip(i1_lst, i2_lst):
    
    # Connect all nodes of each block
    # For blocks 1 and 2 an operator node is attached to the site's input
    # edge, so the free edge to connect is the operator's own 'input' edge.
    for j in range(L):
        if i1 in [1, 2]:
            edge1 = node_blocks[i1][j].neighbours('input')['input']
        else:
            edge1 = node_blocks[i1][j]['input']
        
        if i2 in [1, 2]:
            edge2 = node_blocks_copy[i2][j].neighbours('input')['input']
        else:
            edge2 = node_blocks_copy[i2][j]['input']
            
        edge1 ^ edge2
    
    # Contract with ux, uz
    # The stored node is replaced by the result of contracting its input edge
    for j in range(L):
        if i1 in [1, 2]:
            node_blocks[i1][j] = node_blocks[i1][j]['input'].contract()
        
        if i2 in [1, 2]:
            node_blocks_copy[i2][j] = node_blocks_copy[i2][j]['input'].contract()
    
    # Contract each node of a block with the corresponding copy
    aux_results = []
    for j in range(L):
        aux_results.append(node_blocks[i1][j] @ node_blocks_copy[i2][j])
    
    # Contract all nodes in each block in line
    result = aux_results[0]
    for j in range(1, L):
        result @= aux_results[j]
    
    contracted_blocks.append(result)

In [339]:
len(contracted_blocks)

5

In [340]:
contracted_blocks[0].shape

torch.Size([2, 2, 2, 2])

In [341]:
u, s, vh = torch.linalg.svd(contracted_blocks[0].tensor.reshape(4, 4),
                           full_matrices=False)
s

tensor([1.0000e+00, 1.6950e-05, 1.6935e-05, 1.6935e-05], device='cuda:0',
       grad_fn=<LinalgSvdBackward0>)

In [342]:
u[:, 0]

tensor([0.7071+0.j, 0.0000+0.j, 0.0000+0.j, 0.7071+0.j], device='cuda:0',
       grad_fn=<SelectBackward0>)

In [343]:
vh[0, :]

tensor([0.7071-0.j, 0.0000-0.j, 0.0000-0.j, 0.7071-0.j], device='cuda:0',
       grad_fn=<SliceBackward0>)

In [344]:
u, s, vh = torch.linalg.svd(contracted_blocks[-1].tensor.reshape(4, 4),
                           full_matrices=False)
s

tensor([1.0000e+00, 1.6950e-05, 1.6935e-05, 1.6935e-05], device='cuda:0',
       grad_fn=<LinalgSvdBackward0>)

In [345]:
u[:, 0]

tensor([0.7071+0.j, 0.0000+0.j, 0.0000+0.j, 0.7071+0.j], device='cuda:0',
       grad_fn=<SelectBackward0>)

In [346]:
u[:, 0].view(2, 2) @ u[:, 0].view(2, 2)

tensor([[0.5000+0.j, 0.0000+0.j],
        [0.0000+0.j, 0.5000+0.j]], device='cuda:0', grad_fn=<MmBackward0>)

In [347]:
vh[0, :]

tensor([0.7071-0.j, 0.0000-0.j, 0.0000-0.j, 0.7071-0.j], device='cuda:0',
       grad_fn=<SliceBackward0>)

In [348]:
# Replace the two outer blocks by rank-1 factors: split the first block and
# keep the factor carrying its right edges, then split the last block and
# keep the factor carrying its left edges.
_, left_node = contracted_blocks[0].split(node1_axes=['left_0', 'left_1'],
                                          node2_axes=['right_0', 'right_1'],
                                          side='left',
                                          rank=1)

right_node, _ = contracted_blocks[-1].split(node1_axes=['left_0', 'left_1'],
                                            node2_axes=['right_0', 'right_1'],
                                            side='right',
                                            rank=1)

In [349]:
left_node

Node(
 	name: split_1
	tensor:
		tensor([[[0.7071-0.j, 0.0000-0.j],
		         [0.0000-0.j, 0.7071-0.j]]], device='cuda:0', grad_fn=<ViewBackward0>)
	axes:
		[split
		 right_0
		 right_1]
	edges:
		[split_0[split] <-> split_1[split]
		 mats_env_node_(34)_0[right] <-> mats_env_node_(35)_0[left]
		 mats_env_node_(34)_1[right] <-> mats_env_node_(35)_1[left]])

In [350]:
right_node

Node(
 	name: split_2
	tensor:
		tensor([[[0.7071+0.j],
		         [0.0000+0.j]],
		
		        [[0.0000+0.j],
		         [0.7071+0.j]]], device='cuda:0', grad_fn=<ViewBackward0>)
	axes:
		[left_0
		 left_1
		 split]
	edges:
		[mats_env_node_(64)_0[right] <-> mats_env_node_(65)_0[left]
		 mats_env_node_(64)_1[right] <-> mats_env_node_(65)_1[left]
		 split_2[split] <-> split_3[split]])

In [351]:
# Contract the rank-1 boundary factors with the three inner blocks, left to right
result = left_node @ contracted_blocks[1] @ contracted_blocks[2] @ contracted_blocks[3] @ right_node

In [352]:
result.tensor * bond_dim**2

tensor([[-1.0000+0.j]], device='cuda:0', grad_fn=<MulBackward0>)

## Topological phase estimation (mps2)

In [435]:
# Load cores
# The results directory is resolved relative to the current working directory
cwd = os.getcwd()
results_dir = os.path.join(cwd, '..', '..', 'results', '4_interpretability')

# SECURITY: weights_only=False allows arbitrary pickled objects to be
# executed on load; only use this with files from a trusted source.
cores = torch.load(os.path.join(results_dir, f'{fn_name}.pt'),
                   weights_only=False)

# Initialize tensorized model
# Cores are cast to complex64 to match the reference model of this section
mps2 = tk.models.MPS(tensors=[c.to(device).to(torch.complex64) for c in cores])

mps2.canonicalize(oc=0, renormalize=True)

# Disable parameters
mps2 = mps2.parameterize(set_param=False, override=True)

# Canonical form doesn't give us the exact normalization we want.
# We need all tensors except the ones at the extremes to have the same norm
# (sqrt(2)), and the norm of the extreme tensors to be equal to 1

# Renormalize cores
# log_norm accumulates the log of every scale removed from the cores
log_norm = 0

# Step 1: normalise every core to unit norm, recording log(norm / sqrt(phys_dim))
for node in mps2.mats_env:
    log_norm += (node.norm() / sqrt(phys_dim)).log()
    node.tensor = node.tensor / node.norm()

# Step 2: give every interior core norm sqrt(2), removing that from log_norm
for node in mps2.mats_env[1:-1]:
    log_norm -= log(sqrt(2))
    node.tensor = node.tensor * sqrt(2)

# Step 3: split the remaining scale evenly between the two boundary cores
for node in mps2.mats_env[:1] + mps2.mats_env[-1:]:
    node.tensor = node.tensor * (log_norm / 2).exp()

print(mps2.norm())
mps2.reset()

# mps2.trace(torch.zeros(1, n_features, phys_dim, device=device))

tensor(0.7071, device='cuda:0')


In [436]:
# Squared norm of every core of mps2, one per line, followed by a blank line
for core_node in mps2.mats_env:
    squared_norm = core_node.norm() ** 2
    print(squared_norm)
print()

tensor(1.0000, device='cuda:0')
tensor(2.0000, device='cuda:0')
tensor(2.0000, device='cuda:0')
tensor(2.0000, device='cuda:0')
tensor(2.0000, device='cuda:0')
tensor(2.0000, device='cuda:0')
tensor(2.0000, device='cuda:0')
tensor(2.0000, device='cuda:0')
tensor(2.0000, device='cuda:0')
tensor(2.0000, device='cuda:0')
tensor(2.0000, device='cuda:0')
tensor(2.0000, device='cuda:0')
tensor(2.0000, device='cuda:0')
tensor(2.0000, device='cuda:0')
tensor(2.0000, device='cuda:0')
tensor(2.0000, device='cuda:0')
tensor(2.0000, device='cuda:0')
tensor(2.0000, device='cuda:0')
tensor(2.0000, device='cuda:0')
tensor(2.0000, device='cuda:0')
tensor(2.0000, device='cuda:0')
tensor(2.0000, device='cuda:0')
tensor(2.0000, device='cuda:0')
tensor(2.0000, device='cuda:0')
tensor(2.0000, device='cuda:0')
tensor(2.0000, device='cuda:0')
tensor(2.0000, device='cuda:0')
tensor(2.0000, device='cuda:0')
tensor(2.0000, device='cuda:0')
tensor(2.0000, device='cuda:0')
tensor(2.0000, device='cuda:0')
tensor(2

In [437]:
# AKLT bond matrices (2x2), one per physical index, built from the two
# amplitudes sqrt(2/3) and sqrt(1/3)
amp_pm = sqrt(2 / 3)
amp_0 = sqrt(1 / 3)

A_plus = torch.tensor([[0, amp_pm],
                       [0, 0]])
A_zero = torch.tensor([[-amp_0, 0],
                       [0, amp_0]])
A_minus = torch.tensor([[0, 0],
                        [-amp_pm, 0]])

# Stack along dim=1 -> shape (2, 3, 2), then cast to complex64
aklt_core = torch.stack((A_plus, A_zero, A_minus), dim=1)
aklt_core = aklt_core.to(torch.complex64)

In [438]:
# 2x2 matrices with the Pauli x / y / z pattern, each scaled by sqrt(1/3)
pauli_scale = sqrt(1 / 3)

sigma_x = torch.tensor([[0, pauli_scale],
                        [pauli_scale, 0]])
sigma_y = torch.tensor([[0, -pauli_scale*1j],
                        [pauli_scale*1j, 0]])
sigma_z = torch.tensor([[pauli_scale, 0],
                        [0, -pauli_scale]])

# Stack along dim=1 -> shape (2, 3, 2); sigma_y makes the result complex
pauli_matrices = (sigma_x, sigma_y, sigma_z)
aklt_core_pauli = torch.stack(pauli_matrices, dim=1)
aklt_core_pauli

tensor([[[ 0.0000+0.0000j,  0.5774+0.0000j],
         [ 0.0000+0.0000j, -0.0000-0.5774j],
         [ 0.5774+0.0000j,  0.0000+0.0000j]],

        [[ 0.5774+0.0000j,  0.0000+0.0000j],
         [ 0.0000+0.5774j,  0.0000+0.0000j],
         [ 0.0000+0.0000j, -0.5774+0.0000j]]])

In [439]:
# Flatten both cores from (left, phys, right) to a (left*right, phys) = (4, 3)
# matrix and solve the least-squares problem  aklt_mat @ U = pauli_mat
aklt_mat = aklt_core.permute(0, 2, 1).reshape(4, 3).to(aklt_core_pauli.dtype)
pauli_mat = aklt_core_pauli.permute(0, 2, 1).reshape(4, 3)
U = torch.linalg.lstsq(aklt_mat, pauli_mat).solution

U

tensor([[ 0.7071+0.0000j,  0.0000-0.7071j,  0.0000-0.0000j],
        [ 0.0000-0.0000j,  0.0000-0.0000j, -1.0000-0.0000j],
        [-0.7071-0.0000j,  0.0000-0.7071j,  0.0000-0.0000j]])

In [440]:
U @ U.H

tensor([[1.0000+0.j, 0.0000+0.j, 0.0000+0.j],
        [0.0000+0.j, 1.0000+0.j, 0.0000+0.j],
        [0.0000+0.j, 0.0000+0.j, 1.0000+0.j]])

In [441]:
# Number of consecutive sites in each block
L = 7

# Second copy of the MPS sharing the same tensors; its left node is moved
# into the network of the original MPS
aux_mps = mps2
aux_mps_copy = aux_mps.copy(share_tensors=True)
aux_mps_copy.left_node.move_to_network(aux_mps)

# Diagonal sign matrices diag(1, -1, -1) and diag(-1, -1, 1) ...
ux_diag = torch.diag(torch.Tensor([1, -1, -1])).to(torch.complex64)
uz_diag = torch.diag(torch.Tensor([-1, -1, 1])).to(torch.complex64)

# ... conjugated by U (the matrix obtained from the least-squares fit)
ux = (U @ ux_diag @ U.H).to(device)
uz = (U @ uz_diag @ U.H).to(device)


# Five adjacent blocks of L sites each, taken around the middle of the chain
start = (n_features // 2) - (L * 5 // 2)
block_bounds = [(start + i*L, start + (i+1)*L) for i in range(5)]
node_blocks = [aux_mps.mats_env[lo:hi] for lo, hi in block_bounds]
node_blocks_copy = [aux_mps_copy.mats_env[lo:hi] for lo, hi in block_bounds]

In [442]:
# Attach one uz node to the 'input' edge of every site in blocks 1 and 2 of
# the original MPS (the uz node's 'output' axis is the one that gets connected)
for i in (1, 2):
    block_sites = node_blocks[i]
    for j, site in enumerate(block_sites):
        node_z = tk.Node(tensor=uz,
                         name=f'node_z_({i}_{j})',
                         axes_names=('input', 'output'),
                         network=aux_mps)
        site['input'] ^ node_z['output']
        
        # node_blocks[i][j] = node @ node_z

In [443]:
# Attach one ux node to the 'input' edge of every site in blocks 1 and 2 of
# the MPS copy (the ux node's 'output' axis is the one that gets connected)
for i in (1, 2):
    block_sites = node_blocks_copy[i]
    for j, site in enumerate(block_sites):
        node_x = tk.Node(tensor=ux,
                         name=f'node_x_({i}_{j})',
                         axes_names=('input', 'output'),
                         network=aux_mps)
        site['input'] ^ node_x['output']
        
        # node_blocks_copy[i][j] = node @ node_x

In [444]:
# Pairing of blocks: block i1 of the original MPS is contracted with block i2
# of the copy. The copy's blocks 1 and 3 are swapped with respect to the
# original ordering; blocks 0, 2 and 4 are paired with themselves.
i1_lst = [0, 1, 2, 3, 4]
i2_lst = [0, 3, 2, 1, 4]

# One fully contracted node per (i1, i2) pair, in the order of i1_lst
contracted_blocks = []
for i1, i2 in zip(i1_lst, i2_lst):
    
    # Connect all nodes of each block
    # Blocks 1 and 2 had an extra operator node (uz on the original, ux on the
    # copy) attached to their 'input' edge, so for those the free physical
    # edge is the 'input' axis of that neighbouring node instead.
    for j in range(L):
        if i1 in [1, 2]:
            edge1 = node_blocks[i1][j].neighbours('input')['input']
        else:
            edge1 = node_blocks[i1][j]['input']
        
        if i2 in [1, 2]:
            edge2 = node_blocks_copy[i2][j].neighbours('input')['input']
        else:
            edge2 = node_blocks_copy[i2][j]['input']
            
        edge1 ^ edge2
    
    # Contract with ux, uz
    # The list entry is overwritten with the node returned by contract_(), so
    # later steps use the site with its operator already absorbed.
    for j in range(L):
        if i1 in [1, 2]:
            node_blocks[i1][j] = node_blocks[i1][j]['input'].contract_()
        
        if i2 in [1, 2]:
            node_blocks_copy[i2][j] = node_blocks_copy[i2][j]['input'].contract_()
    
    # Contract each node of a block with the corresponding copy
    aux_results = []
    for j in range(L):
        aux_results.append(tk.contract_between_(node_blocks[i1][j],
                                                node_blocks_copy[i2][j]))
    
    # Contract all nodes in each block in line
    # (left to right, accumulating into a single node)
    result = aux_results[0]
    for j in range(1, L):
        result = tk.contract_between_(result, aux_results[j])
    
    contracted_blocks.append(result)

### Optimize to have equal tensors at the boundaries

In [50]:
import torch.nn as nn
import torch.optim as optim
from math import pi

In [51]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Index (into contracted_blocks) of the block that is matched against block 0
block_idx = -1

aux_tn = tk.TensorNetwork()

# Plain nodes holding complex64 copies of the two block tensors
block_axes = ('left', 'left', 'right', 'right')
left_block = tk.Node(
    tensor=contracted_blocks[0].tensor.to(torch.complex64),
    name='left_block',
    axes_names=block_axes,
    network=aux_tn,
    device=device)
right_block = tk.Node(
    tensor=contracted_blocks[block_idx].tensor.to(torch.complex64),
    name='right_block',
    axes_names=block_axes,
    network=aux_tn,
    device=device)

# 2x2 parametric node (randn-initialized, then cast to complex64) for the
# left side of right_block, plus a copy that shares the same tensor
node_v1 = tk.ParamNode(
    shape=(2, 2),
    init_method='randn',
    name='node_v1',
    axes_names=('new_left', 'left'),
    network=aux_tn,
    device=device)
node_v1.tensor = node_v1.tensor.to(torch.complex64)
node_v1_copy = node_v1.copy(share_tensor=True)

# Same construction for the right side of right_block
node_v2H = tk.ParamNode(
    shape=(2, 2),
    init_method='randn',
    name='node_v2H',
    axes_names=('new_right', 'right'),
    network=aux_tn,
    device=device)
node_v2H.tensor = node_v2H.tensor.to(torch.complex64)
node_v2H_copy = node_v2H.copy(share_tensor=True)

# Wire each parametric node (and its copy) to one leg of right_block
connections = [(node_v1, 'left', 'left_0'),
               (node_v1_copy, 'left', 'left_1'),
               (node_v2H, 'right', 'right_0'),
               (node_v2H_copy, 'right', 'right_1')]
for gauge_node, own_axis, block_axis in connections:
    _ = gauge_node[own_axis] ^ right_block[block_axis]


# thetas = torch.randn(2, device=device) #* (pi/2)
# thetas = nn.Parameter(thetas)

# psis = torch.randn(2, device=device) #* (2*pi)
# psis = nn.Parameter(psis)


# def set_params(thetas, psis):
#     v1 = torch.stack([thetas[0].cos(),
#                       (psis[0] * 1j).exp() * thetas[0].sin(),
#                       -(psis[0] * -1j).exp() * thetas[0].sin(),
#                       thetas[0].cos()],
#                      dim=0).reshape(2, 2)
#     v2H = torch.stack([thetas[1].cos(),
#                        (psis[1] * 1j).exp() * thetas[1].sin(),
#                        -(psis[1] * -1j).exp() * thetas[1].sin(),
#                        thetas[1].cos()],
#                       dim=0).reshape(2, 2)
    
#     node_v1.tensor = v1
#     node_v2H.tensor = v2H


def criterion(left_block, right_block):
    """Return the norm of the difference between the two blocks' tensors.

    Both arguments only need a ``.tensor`` attribute holding tensors of
    compatible shapes; the result is whatever ``Tensor.norm()`` returns for
    ``left_block.tensor - right_block.tensor``.
    """
    return (left_block.tensor - right_block.tensor).norm()

# Adam hyperparameters; the optimizer acts on the parameters of aux_tn
lr, weight_decay = 1e-3, 1e-3
optimizer = optim.Adam(aux_tn.parameters(),
                       lr=lr,
                       weight_decay=weight_decay)

# Distance between the two unmodified blocks, before any optimization step
criterion(left_block, right_block)

tensor(0.0112, device='cuda:0')

In [52]:
n_epochs = 2000

for epoch in range(1, n_epochs + 1):
    optimizer.zero_grad()

    # Apply node_v1 / node_v2H to one pair of legs of right_block, then the
    # tensor-sharing copies to the remaining pair
    partially_rotated = node_v1 @ right_block @ node_v2H
    aux_right_block = node_v1_copy @ partially_rotated @ node_v2H_copy

    # Distance between the rotated right block and the left block
    loss = criterion(left_block, aux_right_block)
    loss.backward()

    # Gradient descent
    optimizer.step()

    # Report every 100 epochs
    if epoch % 100 == 0:
        print(f'Epoch {epoch}/{n_epochs} => Loss: {loss}')
    
    # if (thetas < 0).any() or (thetas > (pi/2)).any():
    #     print('Thetas out')
    #     aux_thetas = torch.where(thetas < 0, 0, thetas)
    #     aux_thetas = torch.where(thetas  > (pi/2), pi/2, aux_thetas)
    #     thetas = nn.Parameter(aux_thetas)
        
    #     optimizer = optim.Adam(params=[thetas, psis],
    #                            lr=lr,
    #                            weight_decay=weight_decay)
    
    # if (psis < 0).any() or (psis > (2*pi)).any():
    #     print('Psis out')
    #     aux_psis = torch.where(psis < 0, 0, psis)
    #     aux_psis = torch.where(psis  > (2*pi), 2*pi, aux_psis)
    #     psis = nn.Parameter(aux_psis)
        
    #     optimizer = optim.Adam(params=[thetas, psis],
    #                            lr=lr,
    #                            weight_decay=weight_decay)

Epoch 100/2000 => Loss: 4.967123508453369
Epoch 200/2000 => Loss: 3.7381439208984375
Epoch 300/2000 => Loss: 2.866499900817871
Epoch 400/2000 => Loss: 2.234808921813965
Epoch 500/2000 => Loss: 1.7687114477157593
Epoch 600/2000 => Loss: 1.4163049459457397
Epoch 700/2000 => Loss: 1.1341608762741089
Epoch 800/2000 => Loss: 0.8817145824432373
Epoch 900/2000 => Loss: 0.6392370462417603
Epoch 1000/2000 => Loss: 0.43162471055984497
Epoch 1100/2000 => Loss: 0.2694180905818939
Epoch 1200/2000 => Loss: 0.13367833197116852
Epoch 1300/2000 => Loss: 0.019314995035529137
Epoch 1400/2000 => Loss: 0.007953583262860775
Epoch 1500/2000 => Loss: 0.007953308522701263
Epoch 1600/2000 => Loss: 0.007953271269798279
Epoch 1700/2000 => Loss: 0.00795323308557272
Epoch 1800/2000 => Loss: 0.007953193038702011
Epoch 1900/2000 => Loss: 0.007953152060508728
Epoch 2000/2000 => Loss: 0.007953107357025146


In [53]:
# set_params(thetas, psis)

print(node_v1.tensor)
print(node_v2H.tensor)

# Apply the optimized matrices to the block selected by `block_idx`.
# The source tensor is read through `block_idx` too, so the block that is read
# and the block that is overwritten are always the same one, whatever value
# `block_idx` is set to.
# NOTE(review): this einsum contracts the FIRST axis of node_v2H ('ef' / 'gh')
# with the block's right legs, whereas the network wiring connects node_v2H's
# 'right' axis, which is its SECOND axis -- confirm which orientation is
# intended.
contracted_blocks[block_idx].tensor = torch.einsum(
    'ab,cd,bdeg,ef,gh->acfh',
    node_v1.tensor,
    node_v1_copy.tensor,
    contracted_blocks[block_idx].tensor,
    node_v2H.tensor,
    node_v2H_copy.tensor)

Parameter containing:
tensor([[ 1.2668+0.j,  0.1125+0.j],
        [ 0.1125+0.j, -1.2668+0.j]], device='cuda:0', requires_grad=True)
Parameter containing:
tensor([[-0.0539+0.j, -0.7845+0.j],
        [ 0.7845+0.j, -0.0539+0.j]], device='cuda:0', requires_grad=True)


### Continue

In [445]:
len(contracted_blocks)

5

In [446]:
contracted_blocks[0].shape

torch.Size([2, 2, 2, 2])

In [447]:
# Singular values of the first block, viewed as a 4x4 matrix (first two axes
# grouped as rows, last two as columns)
first_block_matrix = contracted_blocks[0].tensor.reshape(4, 4)
u, s, vh = torch.linalg.svd(first_block_matrix, full_matrices=False)
s

tensor([1.0000e+00, 4.5725e-04, 4.5725e-04, 4.5720e-04], device='cuda:0')

In [448]:
u[:, 0]

tensor([ 7.0711e-01+0.j, -4.9037e-08+0.j, -4.9037e-08+0.j,  7.0711e-01+0.j],
       device='cuda:0')

In [449]:
vh[0, :]

tensor([7.0711e-01-0.j, 8.8953e-08-0.j, 8.8953e-08-0.j, 7.0711e-01-0.j],
       device='cuda:0')

In [450]:
# Singular values of the last block, viewed as a 4x4 matrix (first two axes
# grouped as rows, last two as columns)
last_block_matrix = contracted_blocks[-1].tensor.reshape(4, 4)
u, s, vh = torch.linalg.svd(last_block_matrix, full_matrices=False)
s

tensor([1.0000e+00, 4.5725e-04, 4.5725e-04, 4.5725e-04], device='cuda:0')

In [451]:
u[:, 0]

tensor([-7.0711e-01+0.j,  1.0637e-09+0.j, -7.5815e-10+0.j, -7.0711e-01+0.j],
       device='cuda:0')

In [452]:
u[:, 0].view(2, 2) @ u[:, 0].view(2, 2)

tensor([[ 5.0000e-01+0.j, -1.5043e-09+0.j],
        [ 1.0722e-09+0.j,  5.0000e-01+0.j]], device='cuda:0')

In [453]:
vh[0, :]

tensor([-7.0711e-01+0.j,  1.6643e-07-0.j,  1.6615e-07-0.j, -7.0711e-01+0.j],
       device='cuda:0')

In [454]:
# Rank-1 split of the first block between its left and right legs; only the
# second returned node (the one carrying the right legs) is kept
left_node = contracted_blocks[0].split_(
    node1_axes=['left_0', 'left_1'],
    node2_axes=['right_0', 'right_1'],
    side='left',
    rank=1)[1]

# Rank-1 split of the last block; only the first returned node (the one
# carrying the left legs) is kept
right_node = contracted_blocks[-1].split_(
    node1_axes=['left_0', 'left_1'],
    node2_axes=['right_0', 'right_1'],
    side='right',
    rank=1)[0]

In [455]:
left_node

Node(
 	name: split_ip_1
	tensor:
		tensor([[[7.0711e-01-0.j, 8.8953e-08-0.j],
		         [8.8953e-08-0.j, 7.0711e-01-0.j]]], device='cuda:0')
	axes:
		[split
		 right_0
		 right_1]
	edges:
		[split_ip_0[split] <-> split_ip_1[split]
		 split_ip_1[right_0] <-> contract_edges_ip_0[left_0]
		 split_ip_1[right_1] <-> contract_edges_ip_2[left_1]])

In [456]:
right_node

Node(
 	name: split_ip_2
	tensor:
		tensor([[[-7.0711e-01+0.j],
		         [ 1.0637e-09+0.j]],
		
		        [[-7.5815e-10+0.j],
		         [-7.0711e-01+0.j]]], device='cuda:0')
	axes:
		[left_0
		 left_1
		 split]
	edges:
		[contract_edges_ip_2[right_0] <-> split_ip_2[left_0]
		 contract_edges_ip_0[right_1] <-> split_ip_2[left_1]
		 split_ip_2[split] <-> split_ip_3[split]])

In [457]:
result = left_node @ contracted_blocks[1] @ contracted_blocks[2] @ contracted_blocks[3] @ right_node

In [458]:
result.tensor * bond_dim**2

tensor([[1.0000+0.j]], device='cuda:0')

## Check degeneracies

In [78]:
# Initialize network
mps = tk.models.MPS(tensors=aklt_cores)

print('norm:', mps.norm())
print('log-scale norm:', mps.norm(log_scale=True))

mps.trace(torch.zeros(1, n_features, phys_dim, device=device))

# Disable parameters
for p in mps.parameters():
    p.requires_grad_(False)

norm: tensor(nan, device='cuda:0', grad_fn=<SqrtBackward0>)
log-scale norm: tensor(54.5840, device='cuda:0', grad_fn=<DivBackward0>)


In [79]:
# Initialize tensorized model
mps2 = tk.models.MPS(tensors=[c.to(device) for c in cores])

mps2.trace(torch.zeros(1, n_features, phys_dim, device=device))

# Disable parameters
for p in mps2.parameters():
    p.requires_grad_(False)

mps2.canonicalize(oc=0, renormalize=True)

In [80]:
@torch.no_grad()
def degeneracies(mps):
    """Sweep an MPS with SVDs and record the gap between the two largest
    singular values at every bond.

    For each pair of neighbouring cores the shared edge is contracted, the
    resulting two-site tensor is reshaped into a matrix and ``|s[0] - s[1]|``
    of its singular values is stored. The edge is then split again in place
    with ``svd_`` (keeping the full edge size as rank) and the non-orthogonal
    factor is renormalized, as in a canonicalization sweep towards the last
    site.

    The two-site tensor is reshaped from its own shape (first two axes as
    rows, the rest as columns) instead of a fixed ``(6, 6)``, so bond and
    physical dimensions other than 2 and 3 are accepted. Two singular values
    are still needed per bond.

    Parameters
    ----------
    mps
        MPS model exposing ``reset``, ``auto_stack``, ``update_bond_dim`` and
        the private attributes ``_auto_stack``, ``_n_features``,
        ``_mats_env`` and ``_boundary``. It is modified in place: with
        ``'obc'`` boundary all but the first slice of the outer bonds are
        zeroed, and ``_mats_env`` is replaced by the re-parameterized nodes.

    Returns
    -------
    list of torch.Tensor
        One scalar ``|s[0] - s[1]|`` per processed bond, in sweep order.
    """
    mps.reset()

    prev_auto_stack = mps._auto_stack
    mps.auto_stack = False

    # Orthogonality centre at the last site: the sweep below goes left to right
    oc = mps._n_features - 1
    
    log_norm = 0
    
    nodes = mps._mats_env[:]
    if mps._boundary == 'obc':
        # Keep only the first component of the outer (boundary) bonds
        nodes[0].tensor[1:] = torch.zeros_like(
            nodes[0].tensor[1:])
        nodes[-1].tensor[..., 1:] = torch.zeros_like(
            nodes[-1].tensor[..., 1:])
    
    diff_s = []
    
    for i in range(oc):
        # Get singular values
        # assumes the contracted tensor is ordered
        # (left, input_i, input_i+1, right) -- TODO confirm
        aux_tensor = nodes[i]['right'].contract().tensor
        n_rows = aux_tensor.shape[0] * aux_tensor.shape[1]
        _, s, _ = torch.linalg.svd(aux_tensor.reshape(n_rows, -1))
        s = s[:2]
        diff_s.append((s[0] - s[1]).abs())
        
        result1, result2 = nodes[i]['right'].svd_(
            side='right',
            rank=nodes[i]['right'].size())
        
        # Renormalize
        aux_norm = result2.norm()
        if not aux_norm.isinf() and (aux_norm > 0):
            result2.tensor = result2.tensor / aux_norm
            log_norm += aux_norm.log()

        result1 = result1.parameterize()
        nodes[i] = result1
        nodes[i + 1] = result2

    # Right-to-left sweep down to the orthogonality centre. With oc set to the
    # last site this range is empty whenever len(nodes) == mps._n_features.
    for i in range(len(nodes) - 1, oc, -1):
        # Get singular values
        aux_tensor = nodes[i]['left'].contract().tensor
        n_rows = aux_tensor.shape[0] * aux_tensor.shape[1]
        _, s, _ = torch.linalg.svd(aux_tensor.reshape(n_rows, -1))
        s = s[:2]
        assert len(s) == 2
        diff_s.append((s[0] - s[1]).abs())
        
        result1, result2 = nodes[i]['left'].svd_(
            side='left',
            rank=nodes[i]['left'].size())
        
        # Renormalize
        aux_norm = result1.norm()
        if not aux_norm.isinf() and (aux_norm > 0):
            result1.tensor = result1.tensor / aux_norm
            log_norm += aux_norm.log()

        result2 = result2.parameterize()
        nodes[i] = result2
        nodes[i - 1] = result1

    nodes[oc] = nodes[oc].parameterize()
    
    # Rescale
    # Spread the accumulated norm evenly over all cores
    if log_norm != 0:
        rescale = (log_norm / len(nodes)).exp()
        
        for node in nodes:
            node.tensor = node.tensor * rescale
    
    mps.reset()
    
    # Update variables
    mps._mats_env = nodes
    mps.update_bond_dim()

    mps.auto_stack = prev_auto_stack
    
    return diff_s

In [81]:
degeneracies(mps)

[tensor(0.7174, device='cuda:0'),
 tensor(0.1363, device='cuda:0'),
 tensor(0.0454, device='cuda:0'),
 tensor(0.0151, device='cuda:0'),
 tensor(0.0050, device='cuda:0'),
 tensor(0.0017, device='cuda:0'),
 tensor(0.0006, device='cuda:0'),
 tensor(0.0002, device='cuda:0'),
 tensor(6.2585e-05, device='cuda:0'),
 tensor(2.0742e-05, device='cuda:0'),
 tensor(7.2718e-06, device='cuda:0'),
 tensor(2.7418e-06, device='cuda:0'),
 tensor(3.5763e-07, device='cuda:0'),
 tensor(1.0729e-06, device='cuda:0'),
 tensor(7.1526e-07, device='cuda:0'),
 tensor(7.1526e-07, device='cuda:0'),
 tensor(0., device='cuda:0'),
 tensor(3.5763e-07, device='cuda:0'),
 tensor(0., device='cuda:0'),
 tensor(2.3842e-07, device='cuda:0'),
 tensor(4.7684e-07, device='cuda:0'),
 tensor(4.7684e-07, device='cuda:0'),
 tensor(3.5763e-07, device='cuda:0'),
 tensor(0., device='cuda:0'),
 tensor(1.1921e-07, device='cuda:0'),
 tensor(1.5497e-06, device='cuda:0'),
 tensor(4.7684e-07, device='cuda:0'),
 tensor(5.9605e-07, device='cu

In [79]:
degeneracies(mps2)

[tensor(0.7125, device='cuda:0'),
 tensor(0.1358, device='cuda:0'),
 tensor(0.0452, device='cuda:0'),
 tensor(0.0151, device='cuda:0'),
 tensor(0.0050, device='cuda:0'),
 tensor(0.0017, device='cuda:0'),
 tensor(0.0006, device='cuda:0'),
 tensor(0.0002, device='cuda:0'),
 tensor(6.0797e-05, device='cuda:0'),
 tensor(2.2411e-05, device='cuda:0'),
 tensor(8.7023e-06, device='cuda:0'),
 tensor(2.5034e-06, device='cuda:0'),
 tensor(7.1526e-07, device='cuda:0'),
 tensor(5.9605e-07, device='cuda:0'),
 tensor(8.3447e-07, device='cuda:0'),
 tensor(4.7684e-07, device='cuda:0'),
 tensor(1.3113e-06, device='cuda:0'),
 tensor(8.3447e-07, device='cuda:0'),
 tensor(4.7684e-07, device='cuda:0'),
 tensor(1.0729e-06, device='cuda:0'),
 tensor(1.7881e-06, device='cuda:0'),
 tensor(0., device='cuda:0'),
 tensor(1.0729e-06, device='cuda:0'),
 tensor(7.1526e-07, device='cuda:0'),
 tensor(4.7684e-06, device='cuda:0'),
 tensor(3.0994e-06, device='cuda:0'),
 tensor(1.6689e-06, device='cuda:0'),
 tensor(1.4305

## Get uniform tensor

In [121]:
# Initialize network
mps = tk.models.MPS(tensors=aklt_cores)

print('norm:', mps.norm())
print('log-scale norm:', mps.norm(log_scale=True))

mps.trace(torch.zeros(1, n_features, phys_dim, device=device))

# Disable parameters
for p in mps.parameters():
    p.requires_grad_(False)

norm: tensor(nan, device='cuda:0', grad_fn=<SqrtBackward0>)
log-scale norm: tensor(54.5840, device='cuda:0', grad_fn=<DivBackward0>)


In [124]:
# Initialize tensorized model
mps2 = tk.models.MPS(tensors=[c.to(device) for c in cores])

mps2.trace(torch.zeros(1, n_features, phys_dim, device=device))

# Disable parameters
for p in mps2.parameters():
    p.requires_grad_(False)

mps2.canonicalize(oc=0, renormalize=True)

In [122]:
@torch.no_grad()
def canonicalize(mps):
    """Canonicalize an MPS towards site 0 and then rotate its bonds.

    The function works in three stages:

    1. SVD sweep from the last site down to site 0 (``svd_`` on every 'left'
       edge, keeping the full edge size as rank), normalizing the factor that
       is passed to the left and accumulating its log-norm.
    2. Rescaling of all cores by the same factor ``exp(log_norm / n_sites)``.
    3. A left-to-right pass that, for every site, builds the matrix obtained
       by contracting the core with its conjugate (through the previous
       site's SVD factors), takes its SVD and rotates the core's right bond
       by ``v`` and its left bond by the previous ``v_dagger``.

    Parameters
    ----------
    mps
        MPS model exposing ``reset``, ``auto_stack``, ``update_bond_dim`` and
        the private attributes ``_auto_stack``, ``_mats_env`` and
        ``_boundary``. It is modified in place; nothing is returned.
    """
    mps.reset()

    prev_auto_stack = mps._auto_stack
    mps.auto_stack = False

    # Orthogonality centre at the first site
    oc = 0
    log_norm = 0
    
    nodes = mps._mats_env[:]
    if mps._boundary == 'obc':
        # Keep only the first component of the outer (boundary) bonds
        nodes[0].tensor[1:] = torch.zeros_like(
            nodes[0].tensor[1:])
        nodes[-1].tensor[..., 1:] = torch.zeros_like(
            nodes[-1].tensor[..., 1:])

    # SVDs from right to left
    for i in range(len(nodes) - 1, oc, -1):
        result1, result2 = nodes[i]['left'].svd_(
            side='left',
            rank=nodes[i]['left'].size())
        
        # Renormalize
        aux_norm = result1.norm()
        if not aux_norm.isinf() and (aux_norm > 0):
            result1.tensor = result1.tensor / aux_norm
            log_norm += aux_norm.log()

        result2 = result2.parameterize()
        nodes[i] = result2
        nodes[i - 1] = result1

    nodes[oc] = nodes[oc].parameterize()
    
    # Rescale
    # Spread the accumulated norm evenly over all cores
    if log_norm != 0:
        rescale = (log_norm / len(nodes)).exp()
        
        for node in nodes:
            node.tensor = node.tensor * rescale
    
    # Fix remaining gauge freedom
    for i in range(len(nodes)):
        tensor = nodes[i].tensor
        
        if i == 0:
            # Contract the core with its conjugate over the left and physical
            # axes, leaving a (right, right) matrix
            left_tensor = torch.einsum('lir,lik->rk', tensor.conj(), tensor)
        else:
            # Same contraction, but with prev_v @ diag(prev_s) @ prev_v_dagger
            # inserted between the two left axes
            left_tensor = torch.einsum('lir,lm,m,mn,nik->rk',
                                       tensor.conj(),
                                       prev_v,
                                       prev_s,
                                       prev_v_dagger,
                                       tensor)
            
        v, s, v_dagger = torch.linalg.svd(left_tensor, full_matrices=False)
        # print(v)
        # print(s)
        # print(v_dagger)
        # assert torch.dist(v, v_dagger.H) < 1e-5
        # assert (s[0] - s[1]).abs() > 1e-5
        
        # Rotate the right bond by v and, except at the first site, the left
        # bond by the v_dagger obtained at the previous site
        if i == 0:
            new_tensor = torch.einsum('ijk,kl->ijl', tensor, v)
        else:
            new_tensor = torch.einsum('hi,ijk,kl->hjl', prev_v_dagger, tensor, v)
            
        nodes[i].tensor = new_tensor
        
        # Keep this site's factors for the next iteration
        prev_v = v
        prev_s = s
        prev_v_dagger = v_dagger
    
    mps.reset()
    
    # Update variables
    mps._mats_env = nodes
    mps.update_bond_dim()

    mps.auto_stack = prev_auto_stack

In [123]:
canonicalize(mps)

In [125]:
canonicalize(mps2)

In [126]:
# Print, side by side, the cores of both models that remain after dropping
# (n_features//2 - 5) sites from each end of the chain
margin = n_features // 2 - 5
central = slice(margin, -margin)
for node, node2 in zip(mps.mats_env[central], mps2.mats_env[central]):
    print(node.tensor, node2.tensor)
    print()

Parameter containing:
tensor([[[ 0.2472,  0.4805],
         [-0.7608, -0.6437],
         [ 1.1574, -0.5954]],

        [[ 0.5954,  1.1574],
         [-0.6437,  0.7608],
         [-0.4805,  0.2472]]], device='cuda:0', requires_grad=True) Parameter containing:
tensor([[[ 1.3002,  0.4201],
         [-0.5293,  0.8443],
         [ 0.1061, -0.3285]],

        [[ 0.3285,  0.1061],
         [ 0.8443,  0.5293],
         [-0.4201,  1.3002]]], device='cuda:0', requires_grad=True)

Parameter containing:
tensor([[[ 1.1831, -0.4132],
         [ 0.7226,  0.6863],
         [ 0.2126,  0.6086]],

        [[-0.6086,  0.2126],
         [ 0.6863, -0.7226],
         [ 0.4132,  1.1831]]], device='cuda:0', requires_grad=True) Parameter containing:
tensor([[[ 4.3713e-04, -4.3330e-01],
         [ 3.0735e-01, -9.4796e-01],
         [ 1.3411e+00,  1.3546e-03]],

        [[-1.3543e-03,  1.3411e+00],
         [-9.4796e-01, -3.0735e-01],
         [ 4.3330e-01,  4.3769e-04]]], device='cuda:0', requires_grad=True)

Pa

# Neural Quantum State

In [1]:
import os
os.environ["JAX_PLATFORM_NAME"] = "gpu"

In [2]:
# Import netket library
import netket as nk

# Import Json, this will be needed to load log files
import json

# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
import time

import flax.linen as nn
import jax.numpy as jnp
import jax

In [3]:
# Define a 1d chain
L = 10  #n_features
g = nk.graph.Hypercube(length=L, n_dim=1, pbc=False)

In [4]:
# Define the Hilbert space based on this graph
hi = nk.hilbert.Spin(s=1, N=g.n_nodes)

In [5]:
# 3x3 single-site matrices: a diagonal one with entries (1, 0, -1) and the
# two off-diagonal ones with entries sqrt(2)
root2 = np.sqrt(2)
sigmaz     = [[1, 0, 0], [0, 0, 0], [0, 0, -1]]
sigmaplus  = [[0, root2, 0], [0, 0, root2], [0, 0, 0]]
sigmaminus = [[0, 0, 0], [root2, 0, 0], [0, root2, 0]]

# Two-site coupling: zz term plus half of each raising/lowering product
zz_term = np.kron(sigmaz, sigmaz)
plus_minus_term = 0.5 * np.kron(sigmaplus, sigmaminus)
minus_plus_term = 0.5 * np.kron(sigmaminus, sigmaplus)
heisenberg = zz_term + plus_minus_term + minus_plus_term

# 9x9 two-site operator: h/2 + h^2/6 + 1/3
heisenberg_squared = heisenberg.dot(heisenberg)
operator = (0.5*heisenberg + heisenberg_squared/6. + np.identity(9)/3.)
operator

array([[1.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.5       , 0.        , 0.5       , 0.        ,
        0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.16666667, 0.        , 0.33333333,
        0.        , 0.16666667, 0.        , 0.        ],
       [0.        , 0.5       , 0.        , 0.5       , 0.        ,
        0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.33333333, 0.        , 0.66666667,
        0.        , 0.33333333, 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.5       , 0.        , 0.5       , 0.        ],
       [0.        , 0.        , 0.16666667, 0.        , 0.33333333,
        0.        , 0.16666667, 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.5       , 0.        , 0.5       , 0.        ],


In [6]:
# One two-site term per edge of the graph `g`, so that the Hamiltonian has the
# same boundary conditions as the lattice. Building the bonds by hand as
# [i, (i + 1) % L] adds a term between the last and the first site even though
# `g` is created with pbc=False.
bonds = [[int(i), int(j)] for i, j in g.edges()]
ha = nk.operator.LocalOperator(hilbert=hi,
                               operators=[operator for _ in bonds],
                               acting_on=bonds)

In [7]:
# RBM ansatz with alpha=1
ma = nk.models.RBM(alpha=1)

In [8]:
class Model(nn.Module):
    """Single dense layer followed by log-cosh and a sum over the last axis.

    The dense layer has twice as many features as the last dimension of the
    input; kernel and bias are both drawn from a normal initializer with
    standard deviation 0.1.
    """

    @nn.compact
    def __call__(self, x):
        n_hidden = 2 * x.shape[-1]
        normal_init = nn.initializers.normal(stddev=0.1)

        # param_dtype=np.complex128 was tried here and is left disabled
        dense = nn.Dense(features=n_hidden,
                         kernel_init=normal_init,
                         bias_init=normal_init)
        hidden = dense(x)
        activated = nk.nn.activation.log_cosh(hidden)
        return jnp.sum(activated, axis=-1)
    
ma = Model()

In [None]:
# Build the sampler
sa = nk.sampler.MetropolisExchange(hilbert=hi, graph=g)

# Optimizer
op = nk.optimizer.Sgd(learning_rate=0.05)

# Stochastic Reconfiguration
sr = nk.optimizer.SR(diag_shift=0.1)

# The variational state
vs = nk.vqs.MCState(sa, ma, n_samples=1000)


gs = nk.VMC(
    hamiltonian=ha,
    optimizer=op,
    preconditioner=sr,
    variational_state=vs)

start = time.time()
gs.run(out='RBM', n_iter=600)
end = time.time()

print('### RBM calculation')
print('Has',vs.n_parameters,'parameters')
print('The RBM calculation took',end-start,'seconds')

  self.n_samples = n_samples


  0%|          | 0/600 [00:00<?, ?it/s]

### RBM calculation
Has 220 parameters
The RBM calculation took 43.622886419296265 seconds
