# QUIMB save boundary operators
Created 03/07/2024

Objectives:
* Use QUIMB to extract boundary operators.

# Package imports

In [1]:
import sys

In [2]:
sys.path.append("../../")

In [3]:
from itertools import chain
import re

In [4]:
import h5py
from tenpy.tools import hdf5_io
import tenpy
import tenpy.linalg.np_conserved as npc

import os
import pickle

In [5]:
import numpy as np
import jax.numpy as jnp

import matplotlib.pyplot as plt

In [6]:
import quimb as qu
import quimb.tensor as qtn
from quimb.tensor.optimize import TNOptimizer



# Load data

In [7]:
DATA_DIR = r"../../data/transverse_cluster_200_site_dmrg"

In [8]:
loaded_data = list()

for local_file_name in os.listdir(DATA_DIR):
    f_name = r"{}/{}".format(DATA_DIR, local_file_name, ignore_unknown=False)
    with h5py.File(f_name, 'r') as f:
        data = hdf5_io.load_from_hdf5(f)
        loaded_data.append(data)

In [9]:
b_parameters = sorted(list(d['paramters']['B'] for d in loaded_data))

In [10]:
psi_dict = dict()

In [11]:
for b in b_parameters:
    psi = next(
        d['wavefunction']
        for d in loaded_data
        if d['paramters']['B'] == b
    )

    rounded_b = round(b, 1)
    psi_dict[rounded_b] = psi

# Definitions

In [12]:
np_I = np.array([[1,0],[0,1]])
np_X = np.array([[0,1],[1,0]])

In [13]:
symmetry_actions = [
    [np_I, np_I],
    [np_I, np_X],
    [np_X, np_I],
    [np_X, np_X]
]

In [14]:
symmetry_labels = [
    'II',
    'IX',
    'XI',
    'XX'
]

In [15]:
def generate_problem_rdm(quimb_psi, symmetry_site_pairs, leftmost_symmetry_site,
                         num_symmetry_sites, num_boundary_sites):
    q_top = quimb_psi.copy(deep=True)
    for i, s in symmetry_site_pairs:
        q_top.gate(
            s,
            where=i,
            contract=False,
            inplace=True
        )

    
    indices_to_map = list(chain(
        range(leftmost_symmetry_site-num_boundary_sites, leftmost_symmetry_site),
        range(leftmost_symmetry_site+num_symmetry_sites, leftmost_symmetry_site+num_symmetry_sites+num_boundary_sites)
    ))

    index_mapping = {f'k{i}': f'b{i}' for i in indices_to_map}

    q_bottom = (
        quimb_psi
        .copy()
        .reindex(index_mapping, inplace=True)
        .conj()
    )

    sites_to_contract = {
        'left': list(range(leftmost_symmetry_site-num_boundary_sites)),
        'middle': list(range(leftmost_symmetry_site, leftmost_symmetry_site+num_symmetry_sites)),
        'right': list(range(leftmost_symmetry_site+num_symmetry_sites+num_boundary_sites, quimb_psi.L))
    }

    tags_to_contract = {
        k: [f'I{i}' for i in v]
        for k, v in sites_to_contract.items()
    }

    tn = (q_top & q_bottom)

    tnc = (
        tn
        .contract(tags_to_contract['left'])
        .contract(tags_to_contract['middle'])
        .contract(tags_to_contract['right'])
    )

    return tnc

In [16]:
def initialize_mpo(left_most_symmetry_site, num_symmetry_sites,
                   num_boundary_sites, bond_dimension, phys_dim=2):
    # Kind of hard coding the physical dimensions.
    ml = qtn.MPO_rand(
        num_boundary_sites,
        bond_dimension,
        phys_dim=phys_dim,
        normalize=True,
        sites=list(range(left_most_symmetry_site-num_boundary_sites, left_most_symmetry_site)),
        dtype=np.complex128,
        tags='left_mpo'
    )

    mr = qtn.MPO_rand(
        num_boundary_sites,
        bond_dimension,
        phys_dim=phys_dim,
        normalize=True,
        sites=list(range(
            left_most_symmetry_site + num_symmetry_sites,
            left_most_symmetry_site + num_symmetry_sites + num_boundary_sites
        )),
        dtype=np.complex128,
        tags='right_mpo'
    )

    mpo = (ml & mr)

    return mpo

## Optimisation functions

In [17]:
def split_mpo_pair(mpo_pair):
    ml = qtn.TensorNetwork(
        list(map(mpo_pair.tensor_map.__getitem__, mpo.tag_map['left_mpo']))
    )

    mr = qtn.TensorNetwork(
        list(map(mpo_pair.tensor_map.__getitem__, mpo.tag_map['right_mpo']))
    )

    return (ml, mr)

In [18]:
def overlap_loss_function(ml, mr, rdm_tn, epsilon=0):
    c = (rdm_tn & ml & mr) ^ ...

    c_abs_squared = (
        c
        *jnp.conjugate(c)
    )
    #c_abs_squared = c_abs_squared.astype('float32')
    c_abs = (jnp.sqrt(c_abs_squared+epsilon))

    target = jnp.sqrt(1+epsilon)
    loss = (c_abs - target)**2

    return loss

In [19]:
def overlap_loss_function_mpo_pair(mpo_pair, rdm_tn):
    ml, mr = split_mpo_pair(mpo_pair)

    return overlap_loss_function(ml, mr, rdm_tn)

In [20]:
regex_s = r"^I\d+$"
regex_p = re.compile(regex_s)

In [21]:
def relabel_mpo(mpo, k_label, b_label):
    site_locs = [
        int(k[1:]) for k in mpo.tag_map
        if bool(re.search(regex_p, k))
    ]

    k_in_indices = [f'k{i}' for i in site_locs]
    j_in_indices = [f'b{i}' for i in site_locs]

    k_out_indices = [f'{k_label}{i}' for i in site_locs]
    j_out_indices = [f'{b_label}{i}' for i in site_locs]

    mapping = dict(
        chain(
            zip(k_in_indices, k_out_indices),
            zip(j_in_indices, j_out_indices)
        )
    )

    mpo.reindex(mapping, inplace=True)

In [22]:
def unitarity_tn(tn, total_physical_dim):
    ms = [tn.copy(), tn.copy(), tn.copy()]

    relabel_mpo(ms[0], 'k', 'l')
    relabel_mpo(ms[1], 'm', 'l')
    relabel_mpo(ms[2], 'm', 'b')

    ms[0] = ms[0].conj()
    ms[2] = ms[2].conj()

    n2tn = (tn & tn.conj())
    n2 = n2tn.contract(n2tn.tag_map)
    n4tn = (tn & ms[0] & ms[1] & ms[2])
    n4 = n4tn.contract(n4tn.tag_map)

    out = jnp.real(n4 - 2*n2 + total_physical_dim)

    return out

In [23]:
def overall_loss_function(mpo_pair, rdm_tn, total_physical_dimension,
    unitary_cost_coefficient=1, overlap_cost_coefficient=1, losses=None):
    ml, mr = split_mpo_pair(mpo_pair)

    o_loss = overlap_loss_function(ml, mr, rdm_tn)
    ul_loss = unitarity_tn(ml, total_physical_dimension)
    ur_loss = unitarity_tn(mr, total_physical_dimension)

    out = (
        unitary_cost_coefficient*(ul_loss+ur_loss)
        + overlap_cost_coefficient*o_loss
    )

    out = jnp.real(out)

    if losses is not None:
        losses.append((o_loss, ul_loss, ur_loss))
    return out

# Loop and solve for boundary operators

In [24]:
num_boundary_sites=6
left_most_symmetry_site=60
num_symmetry_sites=80
bond_dimension=5

total_physical_dim = 2**num_boundary_sites

In [25]:
num_iterations = int(5e3) 
num_seeds = 10

alpha=3e-4
beta_1 = 0.4
beta_2 = 0.4
overlap_learning_rate=10

score_target=0.1

In [26]:
for k1, mps_psi in psi_dict.items():

    psi_arrays = list()
    psi_arrays.append(mps_psi.get_B(0, 'Th')[0, ...].to_ndarray())
    for i in range(1, mps_psi.L-1):
        psi_arrays.append(mps_psi.get_B(i).to_ndarray())
    psi_arrays.append(mps_psi.get_B(mps_psi.L-1)[..., 0].to_ndarray())

    q1 = (
        qtn
        .tensor_1d
        .MatrixProductState(psi_arrays, shape='lpr')
    )

    for k2, bs in enumerate(symmetry_actions[1:], start=1):
        symmetry_site_pairs = (
            [(i, bs[0]) for i in range(left_most_symmetry_site, left_most_symmetry_site+num_symmetry_sites, 2)]
            + [(i, bs[1]) for i in range(left_most_symmetry_site+1, left_most_symmetry_site+num_symmetry_sites+1, 2)]
        )

        problem_rdm = generate_problem_rdm(
            q1,
            symmetry_site_pairs,
            left_most_symmetry_site,
            num_symmetry_sites,
            num_boundary_sites
        )

        seed_number = 0
        new_seed_needed=True
        while new_seed_needed:
            mpo = initialize_mpo(
                left_most_symmetry_site,
                num_symmetry_sites,
                num_boundary_sites,
                bond_dimension
            )
        
            optmzr = qtn.optimize.TNOptimizer(
                mpo,                                
                loss_fn=overall_loss_function,
                loss_kwargs={
                    'rdm_tn': problem_rdm,
                    'total_physical_dimension': total_physical_dim,
                    'unitary_cost_coefficient': 1,
                    'overlap_cost_coefficient': overlap_learning_rate
                },
                autodiff_backend='jax',
                optimizer='cadam',
            )

            optmzr.optimize(
                num_iterations,
                learning_rate=alpha,
                beta1=beta_1,
                beta2=beta_2,
                tol=score_target
            )

            final_score = jnp.min(np.array(optmzr.losses))

            if (seed_number == new_seed_needed - 1) or (final_score < score_target):
                new_seed_needed = False

            seed_number += 1

            file_name = rf'solutions/{k1}_{k2}_{seed_number}.pickle'
            print('Saving: ' + file_name)

            with open(file_name, 'wb') as file:
                pickle.dump([final_score, optmzr.get_tn_opt()], file)

+0.003898471827 [best: +0.001938064466] : : 5001it [00:07, 625.70it/s]                                                                                                            


Saving: solutions/0.0_1_1.pickle


+0.006448279601 [best: +0.001184285968] : : 5001it [00:07, 629.74it/s]                                                                                                            


Saving: solutions/0.0_2_1.pickle


+0.004438123666 [best: +0.001424031332] : : 5001it [00:07, 630.97it/s]                                                                                                            


Saving: solutions/0.0_3_1.pickle


+0.005682895891 [best: +0.002508692676] : : 5001it [00:08, 592.70it/s]                                                                                                            


Saving: solutions/0.1_1_1.pickle


+0.003970586229 [best: +0.001492828131] : : 5001it [00:08, 591.48it/s]                                                                                                            


Saving: solutions/0.1_2_1.pickle


+0.007192279678 [best: +0.002078964375] : : 5001it [00:08, 590.32it/s]                                                                                                            


Saving: solutions/0.1_3_1.pickle


+0.004022007342 [best: +0.003442064160] : : 5001it [00:08, 595.05it/s]                                                                                                            


Saving: solutions/0.2_1_1.pickle


+0.008704075590 [best: +0.006880163681] : : 5001it [00:08, 590.88it/s]                                                                                                            


Saving: solutions/0.2_2_1.pickle


+0.005190647673 [best: +0.004386923276] : : 5001it [00:08, 596.82it/s]                                                                                                            


Saving: solutions/0.2_3_1.pickle


+0.017653260380 [best: +0.016506722197] : : 5001it [00:08, 595.04it/s]                                                                                                            


Saving: solutions/0.3_1_1.pickle


+0.023728003725 [best: +0.021521393210] : : 5001it [00:08, 570.51it/s]                                                                                                            


Saving: solutions/0.3_2_1.pickle


+0.031231982633 [best: +0.027453035116] : : 5001it [00:08, 577.44it/s]                                                                                                            


Saving: solutions/0.3_3_1.pickle


+0.040320117027 [best: +0.039184376597] : : 5001it [00:08, 567.02it/s]                                                                                                            


Saving: solutions/0.4_1_1.pickle


+0.026783876121 [best: +0.017438109964] : : 5001it [00:08, 560.05it/s]                                                                                                            


Saving: solutions/0.4_2_1.pickle


+0.053051672876 [best: +0.051368758082] : : 5001it [00:08, 572.22it/s]                                                                                                            


Saving: solutions/0.4_3_1.pickle


+0.022013479844 [best: +0.021134942770] : : 5001it [00:09, 548.75it/s]                                                                                                            


Saving: solutions/0.5_1_1.pickle


+0.023339021951 [best: +0.022112097591] : : 5001it [00:08, 559.23it/s]                                                                                                            


Saving: solutions/0.5_2_1.pickle


+0.040078446269 [best: +0.038137994707] : : 5001it [00:09, 516.64it/s]                                                                                                            


Saving: solutions/0.5_3_1.pickle


+0.081430822611 [best: +0.080513983965] : : 5001it [00:09, 547.07it/s]                                                                                                            


Saving: solutions/0.6_1_1.pickle


+0.053346335888 [best: +0.051973123103] : : 5001it [00:09, 536.88it/s]                                                                                                            


Saving: solutions/0.6_2_1.pickle


+0.085760399699 [best: +0.080425724387] : : 5001it [00:09, 509.96it/s]                                                                                                            


Saving: solutions/0.6_3_1.pickle


+0.139464452863 [best: +0.138883680105] : : 5001it [00:09, 504.24it/s]                                                                                                            


Saving: solutions/0.7_1_1.pickle


+0.044113557786 [best: +0.044113557786] : : 5001it [00:09, 540.02it/s]                                                                                                            


Saving: solutions/0.7_2_1.pickle


+0.071526318789 [best: +0.070974737406] : : 5001it [00:09, 549.57it/s]                                                                                                            


Saving: solutions/0.7_3_1.pickle


+0.173008039594 [best: +0.172082111239] : : 5001it [00:09, 546.57it/s]                                                                                                            


Saving: solutions/0.8_1_1.pickle


+0.093530736864 [best: +0.088498041034] : : 5001it [00:09, 524.71it/s]                                                                                                            


Saving: solutions/0.8_2_1.pickle


+0.216257795691 [best: +0.215820282698] : : 5001it [00:11, 439.05it/s]                                                                                                            


Saving: solutions/0.8_3_1.pickle


+0.144475668669 [best: +0.143489629030] : : 5001it [00:10, 472.13it/s]                                                                                                            


Saving: solutions/0.9_1_1.pickle


+0.269100278616 [best: +0.269100278616] : : 5001it [00:10, 461.70it/s]                                                                                                            


Saving: solutions/0.9_2_1.pickle


+0.607961535454 [best: +0.606830894947] : : 5001it [00:09, 528.11it/s]                                                                                                            


Saving: solutions/0.9_3_1.pickle


+9.829246520996 [best: +9.823372840881] : : 5001it [00:06, 719.28it/s]                                                                                                            


Saving: solutions/1.0_1_1.pickle


+0.404903799295 [best: +0.404817163944] : : 5001it [00:07, 690.88it/s]                                                                                                            


Saving: solutions/1.0_2_1.pickle


+9.981700897217 [best: +9.978163719177] : : 5001it [00:06, 801.89it/s]                                                                                                            


Saving: solutions/1.0_3_1.pickle


+0.243786871433 [best: +0.239970430732] : : 5001it [00:06, 761.26it/s]                                                                                                            


Saving: solutions/1.1_1_1.pickle


+0.232743993402 [best: +0.230619132519] : : 5001it [00:06, 731.73it/s]                                                                                                            


Saving: solutions/1.1_2_1.pickle


+0.640667378902 [best: +0.635825872421] : : 5001it [00:08, 555.88it/s]                                                                                                            


Saving: solutions/1.1_3_1.pickle


+0.094384580851 [best: +0.093509227037] : : 5001it [00:09, 531.39it/s]                                                                                                            


Saving: solutions/1.2_1_1.pickle


+0.091952055693 [best: +0.088064789772] : : 5001it [00:11, 428.83it/s]                                                                                                            


Saving: solutions/1.2_2_1.pickle


+0.250769913197 [best: +0.247783511877] : : 5001it [00:11, 439.74it/s]                                                                                                            


Saving: solutions/1.2_3_1.pickle


+0.075787790120 [best: +0.074904799461] : : 5001it [00:10, 467.82it/s]                                                                                                            


Saving: solutions/1.3_1_1.pickle


+0.087012603879 [best: +0.079761594534] : : 5001it [00:11, 429.55it/s]                                                                                                            


Saving: solutions/1.3_2_1.pickle


+0.147135227919 [best: +0.146109163761] : : 5001it [00:10, 491.38it/s]                                                                                                            


Saving: solutions/1.3_3_1.pickle


+0.057839706540 [best: +0.057208616287] : : 5001it [00:08, 585.18it/s]                                                                                                            


Saving: solutions/1.4_1_1.pickle


+0.090938776731 [best: +0.088450044394] : : 5001it [00:08, 600.72it/s]                                                                                                            


Saving: solutions/1.4_2_1.pickle


+0.075473308563 [best: +0.071600437164] : : 5001it [00:08, 601.88it/s]                                                                                                            


Saving: solutions/1.4_3_1.pickle


+0.041450724006 [best: +0.040829345584] : : 5001it [00:08, 576.68it/s]                                                                                                            


Saving: solutions/1.5_1_1.pickle


+0.053668156266 [best: +0.051601655781] : : 5001it [00:08, 594.13it/s]                                                                                                            


Saving: solutions/1.5_2_1.pickle


+0.093356281519 [best: +0.091599509120] : : 5001it [00:08, 595.56it/s]                                                                                                            


Saving: solutions/1.5_3_1.pickle


+0.043461836874 [best: +0.042221747339] : : 5001it [00:09, 522.84it/s]                                                                                                            


Saving: solutions/1.6_1_1.pickle


+0.057756900787 [best: +0.055823918432] : : 5001it [00:09, 538.46it/s]                                                                                                            


Saving: solutions/1.6_2_1.pickle


+0.053781963885 [best: +0.052410818636] : : 5001it [00:10, 489.35it/s]                                                                                                            


Saving: solutions/1.6_3_1.pickle


+0.050040226430 [best: +0.042611695826] : : 5001it [00:09, 511.99it/s]                                                                                                            


Saving: solutions/1.7_1_1.pickle


+0.058542393148 [best: +0.058542393148] : : 5001it [00:09, 545.85it/s]                                                                                                            


Saving: solutions/1.7_2_1.pickle


+0.056787073612 [best: +0.051851190627] : : 5001it [00:08, 565.28it/s]                                                                                                            


Saving: solutions/1.7_3_1.pickle


+0.054767098278 [best: +0.050603866577] : : 5001it [00:08, 559.51it/s]                                                                                                            


Saving: solutions/1.8_1_1.pickle


+0.033067487180 [best: +0.032013725489] : : 5001it [00:08, 558.93it/s]                                                                                                            


Saving: solutions/1.8_2_1.pickle


+0.043604157865 [best: +0.042830966413] : : 5001it [00:08, 571.34it/s]                                                                                                            


Saving: solutions/1.8_3_1.pickle


+0.053153976798 [best: +0.048554681242] : : 5001it [00:08, 556.69it/s]                                                                                                            


Saving: solutions/1.9_1_1.pickle


+0.056237153709 [best: +0.054473247379] : : 5001it [00:08, 559.12it/s]                                                                                                            


Saving: solutions/1.9_2_1.pickle


+0.035695116967 [best: +0.030741978437] : : 5001it [00:08, 560.69it/s]                                                                                                            


Saving: solutions/1.9_3_1.pickle


+0.043121814728 [best: +0.038444042206] : : 5001it [00:08, 583.08it/s]                                                                                                            


Saving: solutions/2.0_1_1.pickle


+0.051157221198 [best: +0.050833027810] : : 5001it [00:08, 555.96it/s]                                                                                                            


Saving: solutions/2.0_2_1.pickle


+0.033654704690 [best: +0.032682918012] : : 5001it [00:08, 557.63it/s]                                                                                                            


Saving: solutions/2.0_3_1.pickle
