# Transfer matrix search class

Created 31/05/2024

Objectives:
* Implement an initial class for transfer matrix search optimization, to eventually be ported off to a .py file.
* Just for qubit sites for now.

# Package imports

In [12]:
from functools import reduce

import numpy as np

In [11]:
import h5py
from tenpy.tools import hdf5_io
import tenpy
import tenpy.linalg.np_conserved as npc

import os

In [13]:
from spt_classification import (
    get_transfer_matrix_from_unitary,
    get_transfer_matrices_from_unitary_list,
    multiply_transfer_matrices,
    to_npc_array
)

from super_fibonacci import super_fibonacci

In [14]:
np_I = np.array([[1,0],[0,1]])
np_X = np.array([[0,1],[1,0]])
np_Y = np.array([[0,-1j],[1j,0]])
np_Z = np.array([[1,0],[0,-1]])

In [15]:
base_unitaries = np.array([np_I, 1j*np_X, 1j*np_Y, 1j*np_Z])

In [16]:
def s3_to_unitary(p):
    X = p[0]*np_I + 1j*(p[1]*np_X + p[2]*np_Y + p[3]*np_Z)
    return X

In [17]:
def operator_norm(m):
    singular_values = np.linalg.svd(m).S
    singular_values_counts = Counter(np.round(singular_values, 3))

    norm = max(singular_values_counts.keys())
    count = singular_values_counts[norm]

    return (norm, count)

In [18]:
def get_left_environment(psi, index):
    left_leg = psi.get_B(index).legs[0]
    SL = npc.diag(psi.get_SL(index), left_leg, labels = ['vL', 'vR'])
    left_environment = (
        npc.tensordot(SL, SL.conj(), (['vL',], ['vL*',]))
        .combine_legs([['vR', 'vR*'],])
        .to_ndarray()
    )

    return left_environment

In [219]:
class TransferMatrixSearch:
    def __init__(self, psi, symmetry_operations, index=None,
                 max_num_virtual_points=10000, num_search_points=1000):
        self.psi = psi
        self.symmetry_operations = symmetry_operations

        if index is None:
            self.left_symmetry_index = (self.psi.L - len(self.symmetry_operations))//2 
        else:
            self.left_symmetry_index = index
        self.right_symmetry_index = self.left_symmetry_index + len(self.symmetry_operations) - 1

        self.symmetry_transfer_matrices = (
            get_transfer_matrices_from_unitary_list(
                self.psi,
                self.symmetry_operations,
                self.left_symmetry_index
            )
        )

        self.npc_symmetry_transfer_matrix = reduce(
            multiply_transfer_matrices,
            self.symmetry_transfer_matrices
        )

        self.np_symmetry_transfer_matrix = (
            self.npc_symmetry_transfer_matrix
            .combine_legs([['vL', 'vL*'], ['vR', 'vR*']])
            .to_ndarray()
        )

        U, S, Vh = np.linalg.svd(self.np_symmetry_transfer_matrix)
        assert S[0]/np.sum(S) > 0.9999

        self.left_projected_symmetry_state = U[:,0]
        self.right_projected_symmetry_state = Vh[0]
        self.symmetry_transfer_matrix_op_norm = S[0]

        self.right_virtual_points = list()
        self.left_virtual_points = list()
        self.right_s3_points = list()
        self.left_s3_points = list()
        self.right_ovelraps = list()
        self.left_overlaps = list()

        self.right_max_overlap = 0
        self.left_max_overlap = 0
        self.max_overlap = 0

        self.right_max_s3_points = None
        self.left_max_s3_points = None

        self.max_num_virtual_points = max_num_virtual_points
        self.num_search_points = num_search_points

        s3_points = super_fibonacci(2*num_search_points)
        self.s3_search_points = s3_points[s3_points[:, 0] >=  0]
        self.num_s3_search_points = len(self.s3_search_points)

        self.current_right_depth = 0
        self.current_left_depth = 0

    def update_max_overlap(self):
        self.max_overlap = (
            self.symmetry_transfer_matrix_op_norm
            *self.right_max_overlap
            *self.left_max_overlap
        )

    def search_step_right(self):
        previous_depth = self.current_right_depth
        self.current_right_depth += 1

        site_index = self.right_symmetry_index + self.current_right_depth

        bond_dimension = self.psi.chi[site_index]

        right_environment = np.identity(bond_dimension).reshape((bond_dimension**2,))
        
        base_transfer_matrices = np.array([
            get_transfer_matrix_from_unitary(self.psi, u, site_index)
            .combine_legs([['vL', 'vL*'], ['vR', 'vR*']])
            .to_ndarray()
            for u in base_unitaries
        ])

        if self.current_right_depth == 1:
            previous_points = self.right_projected_symmetry_state[np.newaxis, :]
        elif self.current_right_depth > 1:
            previous_points = self.right_virtual_points[previous_depth-1]

        base_vectors = np.matmul(previous_points, base_transfer_matrices)
        base_overlaps = np.dot(base_vectors, right_environment)

        overlaps = np.abs(np.tensordot(self.s3_search_points, base_overlaps, [[1,], [0,]]))
        target_percentile = 100.0*(1.0 - min(1, self.max_num_virtual_points/(overlaps.size)))
        overlap_threshold = np.percentile(overlaps, target_percentile)

        overlaps_filter = (overlaps > overlap_threshold)

        all_next_points = np.tensordot(
            self.s3_search_points,
            base_vectors,
            [[1,], [0,]]
        )

        if self.current_right_depth == 1:
            assert previous_points.shape[0] == 1
            all_next_s3_points = self.s3_search_points[:, np.newaxis, np.newaxis, :]
        elif self.current_right_depth > 1:
            prev_s3_points = self.right_s3_points[previous_depth-1]
            prev_num_s3_points = prev_s3_points.shape[0]

            all_next_s3_points = np.zeros(
                (
                    self.num_s3_search_points,
                    prev_num_s3_points,
                    self.current_right_depth,
                    4
                )
            )

            all_next_s3_points[:, :, :-1, :] = prev_s3_points[np.newaxis, ...]
            all_next_s3_points[:, :, -1, :] = self.s3_search_points[:, np.newaxis, :]

        filtered_next_points = np.reshape(
            all_next_points[overlaps_filter],
            (-1, bond_dimension**2)
        )
        filtered_next_s3_points = np.reshape(
            all_next_s3_points[overlaps_filter],
            (-1, self.current_right_depth, 4)
        )
        filtered_overlaps = overlaps[overlaps_filter].flatten()

        self.right_virtual_points.append(filtered_next_points)
        self.right_s3_points.append(filtered_next_s3_points)
        self.right_ovelraps.append(filtered_overlaps)

        max_overlap = np.max(filtered_overlaps)
        if max_overlap > self.right_max_overlap:
            self.right_max_overlap = max_overlap
            max_arg = np.argmax(filtered_overlaps)
            self.right_max_s3_points = filtered_next_s3_points[max_arg]
            self.update_max_overlap()

    def search_step_left(self):
        previous_depth = self.current_left_depth
        self.current_left_depth += 1

        site_index = self.left_symmetry_index - self.current_left_depth

        bond_dimension = self.psi.chi[site_index]

        left_environment = get_left_environment(self.psi, site_index)
        
        base_transfer_matrices = np.array([
            get_transfer_matrix_from_unitary(self.psi, u, site_index)
            .combine_legs([['vL', 'vL*'], ['vR', 'vR*']])
            .to_ndarray()
            for u in base_unitaries
        ])

        if self.current_left_depth == 1:
            previous_points = self.left_projected_symmetry_state[np.newaxis, :]
        elif self.current_left_depth > 1:
            previous_points = self.left_virtual_points[previous_depth-1]

        base_vectors = np.tensordot(
            previous_points,
            base_transfer_matrices,
            [[-1,], [2,]]
        )
        base_overlaps = np.dot(base_vectors, left_environment)

        overlaps = np.abs(np.tensordot(self.s3_search_points, base_overlaps, [[1,], [1,]]))
        target_percentile = 100.0*(1.0 - min(1, self.max_num_virtual_points/(overlaps.size)))
        overlap_threshold = np.percentile(overlaps, target_percentile)

        overlaps_filter = (overlaps > overlap_threshold)

        all_next_points = np.tensordot(
            self.s3_search_points,
            base_vectors,
            [[1,], [1,]]
        )

        if self.current_left_depth == 1:
            assert previous_points.shape[0] == 1
            all_next_s3_points = self.s3_search_points[:, np.newaxis, np.newaxis, :]
        elif self.current_left_depth > 1:
            prev_s3_points = self.left_s3_points[previous_depth-1]
            prev_num_s3_points = prev_s3_points.shape[0]

            all_next_s3_points = np.zeros(
                (
                    self.num_s3_search_points,
                    prev_num_s3_points,
                    self.current_left_depth,
                    4
                )
            )
            
            all_next_s3_points[:, :, :-1, :] = prev_s3_points[np.newaxis, ...]
            all_next_s3_points[:, :, -1, :] = self.s3_search_points[:, np.newaxis, :]

        filtered_next_points = np.reshape(
            all_next_points[overlaps_filter],
            (-1, bond_dimension**2)
        )
        filtered_next_s3_points = np.reshape(
            all_next_s3_points[overlaps_filter],
            (-1, self.current_left_depth, 4)
        )
        filtered_overlaps = overlaps[overlaps_filter].flatten()

        self.left_virtual_points.append(filtered_next_points)
        self.left_s3_points.append(filtered_next_s3_points)
        self.left_overlaps.append(filtered_overlaps)

        max_overlap = np.max(filtered_overlaps)
        if max_overlap > self.left_max_overlap:
            self.left_max_overlap = max_overlap
            max_arg = np.argmax(filtered_overlaps)
            self.left_max_s3_points = filtered_next_s3_points[max_arg]
            self.update_max_overlap()

# Load data

In [19]:
DATA_DIR = r"data/transverse_cluster_200_site_dmrg"

In [20]:
loaded_data = list()

for local_file_name in os.listdir(DATA_DIR):
    f_name = r"{}/{}".format(DATA_DIR, local_file_name, ignore_unknown=False)
    with h5py.File(f_name, 'r') as f:
        data = hdf5_io.load_from_hdf5(f)
        loaded_data.append(data)

In [21]:
b_parameters = sorted(list(d['paramters']['B'] for d in loaded_data))

In [22]:
psi_dict = dict()

In [23]:
for b in b_parameters:
    psi = next(
        d['wavefunction']
        for d in loaded_data
        if d['paramters']['B'] == b
    )

    rounded_b = round(b, 1)
    psi_dict[rounded_b] = psi

In [24]:
list(psi_dict)

[0.0,
 0.1,
 0.2,
 0.3,
 0.4,
 0.5,
 0.6,
 0.7,
 0.8,
 0.9,
 1.0,
 1.1,
 1.2,
 1.3,
 1.4,
 1.5,
 1.6,
 1.7,
 1.8,
 1.9,
 2.0]

In [25]:
test_psi = psi_dict[0.5]

# Testing

In [185]:
test_search = TransferMatrixSearch(test_psi, [np_X, np_I]*30)

In [186]:
test_search.current_right_depth

0

In [187]:
test_search.search_step_right()

In [188]:
test_search.current_right_depth

1

In [189]:
test_search.right_virtual_points[0].shape

(1002, 64)

In [190]:
test_search.search_step_left()

In [191]:
test_search.max_overlap

5.3802909910866957e-11

In [192]:
test_search.num_s3_search_points

1003

In [193]:
test_search.current_right_depth

1

In [194]:
test_search.search_step_right()

In [195]:
test_search.search_step_left()

In [196]:
test_search.max_overlap

0.922647511813931

In [197]:
test_search.current_right_depth

2

## Via expecation
Should make a function for this.
(Ideally need to make a "solution" class...)

In [198]:
r_unitaries = test_search.right_max_s3_points

In [199]:
r_unitaries

array([[ 0.02476986, -0.99731462, -0.06574199,  0.02068793],
       [ 0.0558371 ,  0.05130515,  0.01280051, -0.99703869]])

In [200]:
l_unitaries = test_search.left_max_s3_points

In [201]:
l_unitaries

array([[ 0.0558371 ,  0.05130515,  0.01280051, -0.99703869]])

In [202]:
test_search.left_max_overlap

0.3400574864005279

In [203]:
test_search.right_max_overlap

1.357477473882176

In [204]:
test_search.max_overlap

0.922647511813931

In [206]:
r_unitary_1 = s3_to_unitary(r_unitaries[0])
r_unitary_2 = s3_to_unitary(r_unitaries[1])

In [172]:
l_unitary = s3_to_unitary(l_unitaries[0])

In [207]:
np_operators = [l_unitary, *([np_X, np_I]*30), r_unitary_1, r_unitary_2]

In [208]:
operators = [to_npc_array(X) for X in np_operators]

In [209]:
expectation = test_psi.expectation_value_multi_sites(operators, (test_psi.L//2)-30-1)

In [210]:
np.abs(expectation)

0.9226475118139226

Agreement!

In [220]:
test_search = TransferMatrixSearch(test_psi, [np_X, np_I]*30)

In [223]:
test_search.search_step_right()
test_search.right_max_overlap

1.357477473882176

In [227]:
test_search.search_step_left()
test_search.left_max_overlap

0.3400574864005279

In [228]:
test_search.current_right_depth

3

In [229]:
test_search.right_virtual_points[0].shape

(1002, 64)

In [230]:
test_search.search_step_left()

In [231]:
test_search.max_overlap

0.922647511813931

_Still_ getting stuck in local optima... Should try out alternate sampling strategies.
Should also try running other optimization strategies on this.