# QUIMB save boundary operators
Created 26/09/2024

Objectives:
* Use quimb to find the boundary-operator unitary circuits for the exact-diagonalisation (ED) states.

# Package imports

In [1]:
import sys

In [2]:
# Let this notebook import modules that live in its parent directory.
sys.path.append('../')

In [3]:
from itertools import chain
import re

In [4]:
import h5py
from tenpy.tools import hdf5_io
import tenpy
import tenpy.linalg.np_conserved as npc

import os
import pickle

In [5]:
import numpy as np
import jax.numpy as jnp

import matplotlib.pyplot as plt

In [6]:
import quimb as qu
import quimb.tensor as qtn
from quimb.tensor.optimize import TNOptimizer



# Load data

In [7]:
# Directory whose files are loaded with tenpy's HDF5 loader further below.
DATA_DIR = "../../../../data/transverse_cluster_14_site_ed"

In [8]:
def split_full_state_legs(full_state_psi):
    """Split the single leg of a dense state array into one leg per qubit.

    The length of leg 0 is assumed to be a power of two, ``2**N``. Leg 0 is
    relabelled ``'(p0.p1. ... .p{N-1})'`` *in place* on ``full_state_psi``
    (via ``iset_leg_labels``), then ``split_legs([0])`` is called and its
    result is returned.
    """
    num_qubits = int(np.log2(full_state_psi.shape[0]))
    site_labels = '.'.join('p{}'.format(q) for q in range(num_qubits))
    full_state_psi.iset_leg_labels(['({})'.format(site_labels)])
    return full_state_psi.split_legs([0])

In [9]:
# Load every file in DATA_DIR with tenpy's HDF5 loader.
# NOTE(review): assumes DATA_DIR contains only tenpy-written HDF5 files — TODO confirm.
loaded_data = []

for local_file_name in os.listdir(DATA_DIR):
    f_name = os.path.join(DATA_DIR, local_file_name)
    with h5py.File(f_name, 'r') as f:
        # ignore_unknown is an option of load_from_hdf5. It was previously
        # passed to str.format, which silently ignores unused keyword
        # arguments, so it never took effect.
        loaded_data.append(hdf5_io.load_from_hdf5(f, ignore_unknown=False))

In [10]:
# Every B value in the loaded records, ascending ('paramters' is the key
# spelling read from the records here).
b_parameters = sorted(record['paramters']['B'] for record in loaded_data)

In [11]:
psi_dict = {}  # B rounded to 1 decimal place -> loaded wavefunction

In [12]:
# Index the records by B in one pass instead of scanning loaded_data once per
# B value. setdefault keeps the FIRST record for each B, which is the record
# the previous next(...) linear search returned.
_psi_by_b = {}
for d in loaded_data:
    _psi_by_b.setdefault(d['paramters']['B'], d['wavefunction'])

for b in b_parameters:
    psi = _psi_by_b[b]
    # Keys are B rounded to one decimal place. If two B values round to the
    # same key, the larger one (later in the sorted b_parameters) wins.
    rounded_b = round(b, 1)
    psi_dict[rounded_b] = psi

del _psi_by_b

# Definitions

In [13]:
# Single-qubit identity and Pauli-X as integer NumPy matrices.
np_I = np.identity(2, dtype=int)
np_X = np.array([[0, 1], [1, 0]])

In [14]:
# Each entry is [op_a, op_b] with op_a, op_b in {I, X}, ordered
# II, IX, XI, XX (the same order as symmetry_labels). The driver loop applies
# op_a on even offsets and op_b on odd offsets inside the symmetry region.
symmetry_actions = [
    [first, second]
    for first in (np_I, np_X)
    for second in (np_I, np_X)
]

In [15]:
# Two-letter names for symmetry_actions, in the same order.
symmetry_labels = [first + second for first in 'IX' for second in 'IX']

In [16]:
def generate_problem_rdm(quimb_psi, symmetry_site_pairs, leftmost_symmetry_site,
                         num_symmetry_sites, num_boundary_sites):
    """Contract <psi| S |psi> for a symmetry string S, leaving boundary legs open.

    A deep copy of ``quimb_psi`` (the ket) gets every operator in
    ``symmetry_site_pairs`` attached, uncontracted, at its site via ``gate``.
    A second copy (the bra) has the physical indices of the
    ``num_boundary_sites`` sites immediately left and right of the symmetry
    region renamed from ``k{i}`` to ``b{i}``, and is then conjugated.
    Contracting ket & bra sums over every shared index, so the boundary
    sites stay open as ``k{i}`` (ket side) and ``b{i}`` (bra side).

    Parameters
    ----------
    quimb_psi : quimb tensor network
        State whose physical indices are named ``k{i}`` and which exposes
        its length as ``.L`` (a ``qtn.Dense1D`` in this notebook).
    symmetry_site_pairs : iterable of (int, array)
        ``(site, single-site operator)`` pairs applied to the ket.
    leftmost_symmetry_site : int
        First site of the symmetry region.
    num_symmetry_sites : int
        Number of sites in the symmetry region.
    num_boundary_sites : int
        Number of open sites kept on each side of the symmetry region.

    Returns
    -------
    The result of ``TensorNetwork.contract()`` on ket & bra. Presumably this
    is a single tensor carrying the open boundary indices (TODO confirm for
    the quimb version in use).

    Raises
    ------
    ValueError
        If either boundary region extends past an end of the chain. Such
        indices would not exist, and ``reindex`` would silently skip them.
    """
    left_boundary_start = leftmost_symmetry_site - num_boundary_sites
    right_boundary_start = leftmost_symmetry_site + num_symmetry_sites
    right_boundary_stop = right_boundary_start + num_boundary_sites

    if left_boundary_start < 0 or right_boundary_stop > quimb_psi.L:
        raise ValueError(
            f"boundary sites [{left_boundary_start}, {leftmost_symmetry_site}) "
            f"and [{right_boundary_start}, {right_boundary_stop}) must lie "
            f"within the chain of length {quimb_psi.L}"
        )

    # Ket: operators are attached as separate tensors (contract=False).
    q_top = quimb_psi.copy(deep=True)
    for site, op in symmetry_site_pairs:
        q_top.gate(op, where=site, contract=False, inplace=True)

    # Bra: rename only the boundary physical indices so they stay open.
    boundary_sites = chain(
        range(left_boundary_start, leftmost_symmetry_site),
        range(right_boundary_start, right_boundary_stop),
    )
    index_mapping = {f'k{i}': f'b{i}' for i in boundary_sites}

    q_bottom = (
        quimb_psi
        .copy()
        .reindex(index_mapping, inplace=True)
        .conj()
    )

    return (q_top & q_bottom).contract()

## Optimisation functions

In [17]:
def loss(circ_pair, problem_rdm):
    """Return ``1 - Re(c * conj(c))``, where ``c`` contracts rdm with the circuits.

    ``c`` is obtained by fully contracting (``^ ...``) the combined network
    ``problem_rdm & circ_pair``. Presumably it is a scalar once every open
    index of ``problem_rdm`` is matched by ``circ_pair`` (TODO confirm). The
    loss is therefore zero when ``|c| == 1``.
    """
    overlap = (problem_rdm & circ_pair) ^ ...
    overlap_sq = jnp.real(overlap * jnp.conjugate(overlap))
    return 1 - overlap_sq

## Gate & circuit functions

In [18]:
def single_qubit_layer(circ, gate_round=None):
    """Append one parametrized ``U3`` gate to every qubit of ``circ``.

    Each gate's three angles start from ``qu.randn(3, dist='uniform')``.
    ``parametrize=True`` makes them parameters that can be optimized, and
    every gate is labelled with ``gate_round``.
    """
    for qubit in range(circ.N):
        initial_angles = qu.randn(3, dist='uniform')
        circ.apply_gate('U3', *initial_angles, qubit,
                        gate_round=gate_round, parametrize=True)

In [19]:
def two_qubit_layer(circ, gate2='CZ', start=0, gate_round=None):
    """Apply one layer of fixed (non-parametrized) ``gate2`` gates to ``circ``.

    Qubits are taken in the cyclic order ``start, start+1, ...`` (modulo
    ``circ.N``) and paired consecutively: (1st, 2nd), (3rd, 4th), and so on.
    For even ``circ.N`` and ``start=1``, the last pair wraps around to
    qubit 0. For odd ``circ.N``, the final qubit in that order is left
    unpaired.
    """
    n_qubits = circ.N
    ordered = [(start + offset) % n_qubits for offset in range(n_qubits)]

    for first, second in zip(ordered[0::2], ordered[1::2]):
        circ.apply_gate(gate2, first, second, gate_round=gate_round)

In [20]:
def ansatz_circuit(n, depth, first_site=0, gate2='CZ',
                   gate_tag=None, **kwargs):
    """Build an ``n``-qubit layered ansatz and return it as a unitary network.

    The circuit consists of ``depth`` rounds. Each round is a parametrized
    ``U3`` layer followed by a fixed ``gate2`` entangling layer, whose
    starting qubit alternates between 0 and 1. One final ``U3`` layer
    follows the last round.

    Parameters
    ----------
    n : int
        Number of qubits.
    depth : int
        Number of (U3 + entangling) rounds; must be >= 0.
    first_site : int, optional
        Offset added to the site number of every ``k{i}`` / ``b{i}`` index
        of the returned network.
    gate2 : str, optional
        Two-qubit gate name, forwarded to ``two_qubit_layer``.
    gate_tag : str or None, optional
        If given, added as a tag to every tensor of the result.
    **kwargs
        Forwarded to ``qtn.Circuit``.

    Returns
    -------
    The network from ``circ.get_uni(transposed=True)``, with the optional
    tag and index shift applied.

    Raises
    ------
    ValueError
        If ``depth`` is negative.
    """
    if depth < 0:
        raise ValueError(f"depth must be non-negative, got {depth}")

    circ = qtn.Circuit(n, **kwargs)

    for r in range(depth):
        single_qubit_layer(circ, gate_round=r)
        # Alternate the entangling-layer offset between rounds.
        two_qubit_layer(circ, gate2=gate2, gate_round=r, start=r % 2)

    # Final single-qubit layer. ``depth`` equals the last loop value + 1.
    # Using it instead of the loop variable avoids a NameError when
    # depth == 0 (the loop never runs, so ``r`` would be unbound).
    single_qubit_layer(circ, gate_round=depth)

    uni = circ.get_uni(transposed=True)

    if gate_tag is not None:
        for t in uni.tensors:
            t.add_tag(gate_tag)

    if first_site != 0:
        # Shift the site number in both index families at once.
        index_map = {
            f'{label}{i}': f'{label}{i + first_site}'
            for i in range(n)
            for label in ('k', 'b')
        }
        uni.reindex(index_map, inplace=True)

    return uni

In [21]:
def ansatz_circuit_pair(leftmost_symmetry_site, num_symmetry_sites,
                        num_boundary_sites, depth, gate2='CZ', **kwargs):
    """Build the left and right boundary ansatz unitaries as one network.

    Each side is an ``ansatz_circuit`` on ``num_boundary_sites`` qubits.
    The left one starts at ``leftmost_symmetry_site - num_boundary_sites``
    and is tagged ``'left'``. The right one starts at
    ``leftmost_symmetry_site + num_symmetry_sites`` and is tagged
    ``'right'``. ``depth``, ``gate2`` and ``**kwargs`` are passed to both.
    The left circuit is built first.

    Returns
    -------
    ``left & right``: the two unitaries combined in one tensor network.
    """
    first_sites = {
        'left': leftmost_symmetry_site - num_boundary_sites,
        'right': leftmost_symmetry_site + num_symmetry_sites,
    }

    left_circuit, right_circuit = [
        ansatz_circuit(
            num_boundary_sites,
            depth,
            first_site=first_sites[side],
            gate2=gate2,
            gate_tag=side,
            **kwargs,
        )
        for side in ('left', 'right')
    ]

    return left_circuit & right_circuit

In [22]:
def split_circ_pair(circ_pair):
    """Separate a combined circuit pair into its two tagged halves.

    Returns ``(left, right)``: new ``qtn.TensorNetwork`` objects built from
    the tensors of ``circ_pair`` tagged ``'left'`` and ``'right'``
    respectively.
    """
    def _tensors_tagged(tag):
        return [circ_pair.tensor_map[tid] for tid in circ_pair.tag_map[tag]]

    return (
        qtn.TensorNetwork(_tensors_tagged('left')),
        qtn.TensorNetwork(_tensors_tagged('right')),
    )

# Loop and solve for boundary operators

In [23]:
# Geometry: a symmetry region of num_symmetry_sites sites starting at
# left_most_symmetry_site, with num_boundary_sites open sites on each side.
num_boundary_sites = 2
left_most_symmetry_site = 4
num_symmetry_sites = 6

# Dimension of one boundary region (2 states per site).
total_physical_dim = 2 ** num_boundary_sites

In [24]:
# Ansatz settings.
# NOTE(review): n is not referenced in the visible code; the boundary
# circuits are built on num_boundary_sites qubits instead.
n = 2
depth = 2
gate2 = 'CZ'

In [25]:
num_iterations = 300  # optimizer iterations per (B value, symmetry action)

In [26]:
# Output directory for the pickled solutions. Create it if missing so the
# first open(..., 'wb') does not fail with FileNotFoundError.
os.makedirs('solutions', exist_ok=True)

for k1, mps_psi in psi_dict.items():
    print(k1)
    q1 = qtn.Dense1D(mps_psi.to_ndarray(), phys_dim=2)

    symmetry_stop = left_most_symmetry_site + num_symmetry_sites

    # symmetry_actions[0] is [I, I] and is skipped. k2 keeps each action's
    # index in symmetry_actions (1, 2, 3) and is used in the file name.
    for k2, bs in enumerate(symmetry_actions[1:], start=1):
        print(f'-> {k2}')
        # bs[0] goes on even offsets and bs[1] on odd offsets from the
        # leftmost symmetry site. Both ranges stop at symmetry_stop, so an
        # odd num_symmetry_sites cannot place an operator outside the region.
        symmetry_site_pairs = (
            [(i, bs[0]) for i in range(left_most_symmetry_site, symmetry_stop, 2)]
            + [(i, bs[1]) for i in range(left_most_symmetry_site + 1, symmetry_stop, 2)]
        )

        problem_rdm = generate_problem_rdm(
            q1,
            symmetry_site_pairs,
            left_most_symmetry_site,
            num_symmetry_sites,
            num_boundary_sites,
        )

        circ_pair = ansatz_circuit_pair(
            left_most_symmetry_site,
            num_symmetry_sites,
            num_boundary_sites,
            depth,
            gate2=gate2,
        )

        optmzr = TNOptimizer(
            circ_pair,
            loss_fn=loss,
            loss_constants={'problem_rdm': problem_rdm},
            tags=['U3'],
            optimizer='COBYLA',
            progbar=False,
        )

        # COBYLA is derivative-free: no Jacobian or Hessian-vector product.
        optmzr.optimize(num_iterations, jac=False, hessp=False)

        # Best loss seen, computed with NumPy in float64. jnp.min would
        # downcast to float32 unless jax x64 mode is enabled, and it would
        # also make the pickled result depend on jax.
        final_score = np.min(optmzr.losses)

        circ_pair = optmzr.get_tn_opt()
        cl, cr = split_circ_pair(circ_pair)
        circ_params = (cl.get_params(), cr.get_params())

        file_name = f'solutions/{k1}_{k2}.pickle'
        print('Saving: ' + file_name)

        with open(file_name, 'wb') as file:
            pickle.dump([final_score, circ_params], file)

0.0
-> 1
Saving: solutions/0.0_1.pickle
-> 2
Saving: solutions/0.0_2.pickle
-> 3
Saving: solutions/0.0_3.pickle
0.1
-> 1
Saving: solutions/0.1_1.pickle
-> 2
Saving: solutions/0.1_2.pickle
-> 3
Saving: solutions/0.1_3.pickle
0.2
-> 1
Saving: solutions/0.2_1.pickle
-> 2
Saving: solutions/0.2_2.pickle
-> 3
Saving: solutions/0.2_3.pickle
0.3
-> 1
Saving: solutions/0.3_1.pickle
-> 2
Saving: solutions/0.3_2.pickle
-> 3
Saving: solutions/0.3_3.pickle
0.4
-> 1
Saving: solutions/0.4_1.pickle
-> 2
Saving: solutions/0.4_2.pickle
-> 3
Saving: solutions/0.4_3.pickle
0.5
-> 1
Saving: solutions/0.5_1.pickle
-> 2
Saving: solutions/0.5_2.pickle
-> 3
Saving: solutions/0.5_3.pickle
0.6
-> 1
Saving: solutions/0.6_1.pickle
-> 2
Saving: solutions/0.6_2.pickle
-> 3
Saving: solutions/0.6_3.pickle
0.7
-> 1
Saving: solutions/0.7_1.pickle
-> 2
Saving: solutions/0.7_2.pickle
-> 3
Saving: solutions/0.7_3.pickle
0.8
-> 1
Saving: solutions/0.8_1.pickle
-> 2
Saving: solutions/0.8_2.pickle
-> 3
Saving: solutions/0.8_3