In [158]:
"""Construct CAS Hamiltonians with cropping
"""
import saveload_utils as sl
import ferm_utils as feru
import csa_utils as csau
import var_utils as varu
import openfermion as of
import numpy as np
from sdstate import *
from itertools import product
import random
import h5py
import sys
from matrix_utils import construct_orthogonal
import pickle
import tensorflow as tf

### Parameters
tol = 1e-5  # numerical tolerance (appears unused in this notebook section)
balance_strength = 2  # base strength of the electron-number penalty (solve_enums' balance_t)
save = False  # appears unused in this notebook section

# CAS block layout
block_size = 4  # spatial orbitals per CAS block
ne_per_block = 4  # mean number of electrons placed in each block
ne_range = 2  # each block's electron number deviates randomly by up to +/- ne_range
n_killer = 3  # killer operators attached to each CAS block

# Optional checks (all expensive)
FCI = False  # full-CI ground-state check; exponential time and memory
check_symmetry = False  # verify Sz / S^2 symmetry of the planted Hamiltonian; very costly
check_state = False  # concatenate the block states and recompute the ground-state energy

# Input location
path = "hamiltonians_catalysts/"
file_name = "2_co2_6-311++G___12_9d464efb-b312-45f8-b0ba-8c42663059dc.hdf5"

def construct_blocks(b: int, spin_orbs: int, spin_orb = False):
    """Split orbital indices ``0 .. spin_orbs - 1`` into consecutive CAS blocks.

    Parameters
    ----------
    b : int
        Block size in spatial orbitals. When ``spin_orb`` is True it is doubled
        (two spin orbitals per spatial orbital).
    spin_orbs : int
        Number of orbitals to partition. Despite the name, the driver in this
        notebook passes the number of spatial orbitals.
    spin_orb : bool
        Whether to double ``b``.

    Returns
    -------
    list[list[int]]
        Consecutive index blocks of the (possibly doubled) size; the last
        block holds the remainder and may be shorter. Empty when
        ``spin_orbs <= 0``.

    Raises
    ------
    ValueError
        If the (possibly doubled) block size is not positive.
    """
    size = 2 * b if spin_orb else b
    if size <= 0:
        raise ValueError(f"block size must be positive, got {size}")
    # Each block starts at a multiple of `size`; the final block is clipped
    # at `spin_orbs`. range() is empty for spin_orbs <= 0, so no phantom [0].
    return [list(range(start, min(start + size, spin_orbs)))
            for start in range(0, spin_orbs, size)]
        
def get_truncated_cas_tbt(H, k, casnum):
    """Truncate a Hamiltonian to its CAS block-diagonal part.

    Parameters
    ----------
    H : tuple(ndarray, ndarray)
        ``(Hobt, Htbt)``; the one-body tensor is indexed ``[p, q]`` and the
        two-body tensor ``[a, b, c, d]``, with ``n = Htbt.shape[0]``.
    k : list[list[int]]
        Orbital index blocks (e.g. from ``construct_blocks``).
    casnum : int
        Length of the packed vector; must be at least
        ``sum(len(b)**2 + len(b)**4 for b in k)`` (see ``get_param_num``).

    Returns
    -------
    cas_obt, cas_tbt : ndarray
        ``(n, n)`` and ``(n, n, n, n)`` tensors equal to ``Hobt``/``Htbt``
        inside the blocks and zero everywhere else.
    cas_x : ndarray
        Length-``casnum`` vector packing, block by block, the ``len(b)**2``
        one-body entries followed by the ``len(b)**4`` two-body entries, each
        in row-major (``itertools.product``) order. Entries past the packed
        length stay zero.

    Raises
    ------
    ValueError
        If ``casnum`` is too small to hold every block entry.
    """
    Hobt, Htbt = H
    n = Htbt.shape[0]
    cas_obt = np.zeros([n, n])
    cas_tbt = np.zeros([n, n, n, n])
    cas_x = np.zeros(casnum)
    idx = 0
    for block in k:
        m = len(block)
        if m == 0:
            continue
        if idx + m ** 2 + m ** 4 > casnum:
            raise ValueError(
                f"casnum={casnum} is too small for CAS blocks {k}")
        # np.ix_ selects the Cartesian product of the block indices; ravel()
        # of that sub-array is row-major, i.e. the same order as
        # itertools.product(block, repeat=...), which get_cas_matrix unpacks.
        obt_sel = np.ix_(block, block)
        tbt_sel = np.ix_(block, block, block, block)
        obt_block = Hobt[obt_sel]
        tbt_block = Htbt[tbt_sel]
        cas_obt[obt_sel] = obt_block
        cas_tbt[tbt_sel] = tbt_block
        cas_x[idx:idx + m ** 2] = obt_block.ravel()
        idx += m ** 2
        cas_x[idx:idx + m ** 4] = tbt_block.ravel()
        idx += m ** 4
    return cas_obt, cas_tbt, cas_x

def in_orbs(term, orbs):
    """Tell whether a one- or two-body operator term acts only inside ``orbs``.

    ``term`` is a term tuple ``((index, action), ...)``. Only terms made of
    exactly two or four ladder operators qualify; any other length gives False.
    """
    if len(term) not in (2, 4):
        return False
    return all(ladder[0] in orbs for ladder in term)

# def transform_orbs(term, orbs):
#     """Transform the operator term to align the orbs starting from 0"""
# #     pass
#     if len(term) == 2:
#         return ((orbs.index(term[0][0]), 1), (orbs.index(term[1][0]), 0))
#     if len(term) == 4:
#         return ((orbs.index(term[0][0]), 1), (orbs.index(term[1][0]), 0), 
#                (orbs.index(term[2][0]), 1), (orbs.index(term[3][0]), 0))   
#     return None

def solve_enums(H, k, ne_per_block = 0, ne_range = 0, balance_t = 10):
    """Choose an electron number for each CAS block and solve the block by FCI.

    When ``ne_per_block != 0`` a target ``ne`` is drawn for every block as
    ``ne_per_block + randint(-ne_range, ne_range)``, clipped to
    ``[0, 2 * len(block) - 1]``. A penalty ``strength * (Ne^2 - 2 ne Ne)``
    (i.e. ``(Ne - ne)^2`` without its constant ``ne^2``) is then added to the
    block repeatedly, doubling ``strength`` each time, until every determinant
    in the block ground state has exactly ``ne`` electrons.

    Parameters
    ----------
    H : sequence(ndarray, ndarray)
        ``(obt, tbt)`` Hamiltonian in spatial orbitals. **Modified in place**:
        the accumulated penalty terms are added to the diagonal blocks, and the
        driver in this notebook reuses the modified arrays afterwards.
    k : list[list[int]]
        CAS blocks. Each block is addressed as the slice
        ``block[0]:block[-1] + 1``, so it must be a contiguous increasing range.
    ne_per_block : int
        Mean target electron number per block. ``0`` disables the penalty:
        each block is solved once, unmodified, and its ``ne`` is read off the
        largest-amplitude determinant of the ground state.
    ne_range : int
        Maximum random deviation from ``ne_per_block``.
    balance_t : float
        Base penalty strength; the first strength applied is
        ``2 * balance_t * (1 + random.random())``.

    Returns
    -------
    e_nums : list[int]
        Electron number assigned to each block.
    states : list[sdstate]
        Normalized block ground states (amplitudes at or below float32
        epsilon are dropped before normalizing).
    E_cas : float
        Sum over blocks of the state expectation values, evaluated with the
        penalty terms included.
    """
    cas_obt = H[0]
    cas_tbt = H[1]
    enforce = ne_per_block != 0
    e_nums = []
    states = []
    E_cas = 0
    for orbs in k:
        s = orbs[0]
        t = orbs[-1] + 1
        norbs = len(orbs)
        if enforce:
            # Clip to a reachable count: a negative target can never be met and
            # would make the loop below run forever.
            ne = max(0, min(ne_per_block + random.randint(-ne_range, ne_range), norbs * 2 - 1))
            print(f"Ne within current block: {ne}")
            # Tensors for Ne^2 (tbt[p, p, q, q] = 1) and -2 ne Ne (obt[p, p]).
            identity = np.eye(norbs)
            balance_tbt = np.einsum("pq,rs->pqrs", identity, identity)
            balance_obt = -2 * ne * identity
            strength = balance_t * (1 + random.random())
        while True:
            if enforce:
                # Penalties accumulate: each retry adds a twice-as-strong copy.
                strength *= 2
                cas_tbt[s:t, s:t, s:t, s:t] += strength * balance_tbt
                cas_obt[s:t, s:t] += strength * balance_obt
            # spin_orb=False: the block tensors are in the spatial-orbital basis.
            tmp = feru.get_ferm_op(cas_tbt[s:t, s:t, s:t, s:t], spin_orb = False)
            tmp += feru.get_ferm_op(cas_obt[s:t, s:t], spin_orb = False)
            sparse_H_tmp = of.get_sparse_operator(tmp)
            tmp_E_min, t_sol = of.get_ground_state(sparse_H_tmp)
            st = sdstate(n_qubit = norbs * 2)
            for i, amp in enumerate(t_sol):
                if abs(amp) > np.finfo(np.float32).eps:
                    st += sdstate(s = i, coeff = amp)
            st.normalize()
            E_st = st.exp(tmp)
            if not enforce:
                # No penalty applied: report the dominant determinant's count.
                ne = bin(int(np.argmax(np.abs(t_sol))))[2:].count('1')
                print(f"Ne within current block: {ne}")
                break
            wrong = next((sd for sd in st.dic if bin(sd)[2:].count('1') != ne), None)
            if wrong is None:
                break
            print("Not enough balance, adding more terms")
            print(bin(wrong))
        print(f"E_min: {tmp_E_min} for orbs: {orbs}")
        print(f"current state Energy: {E_st}")
        E_cas += E_st
        states.append(st)
        e_nums.append(ne)
    return e_nums, states, E_cas
    
    
def _coo_to_sparse_tensor(keys, vals, shape):
    """Build a canonically ordered ``tf.sparse.SparseTensor`` from COO lists.

    ``np.asarray`` keeps float64/complex128 coefficients (TensorFlow converts a
    plain list of Python floats to float32). The ``reshape`` gives an empty
    term list the rank-2 ``(0, rank)`` index array TensorFlow requires.
    ``tf.sparse.reorder`` sorts the indices into the lexicographic order that
    ``tf.sparse.to_dense`` validates.
    """
    indices = np.asarray(keys, dtype=np.int64).reshape(-1, len(shape))
    values = np.asarray(vals)
    sparse = tf.sparse.SparseTensor(indices = indices, values = values, dense_shape = shape)
    return tf.sparse.reorder(sparse)


def H_to_sparse(H: of.FermionOperator, n):
    """Split a FermionOperator into sparse one- and two-body coefficient tensors.

    Term ``a_p^ a_q`` is stored at ``[p, q]`` of the one-body tensor and
    ``a_p^ a_q a_r^ a_s`` at ``[p, q, r, s]`` of the two-body tensor. The
    constant term (empty key) is not stored; callers read it from
    ``H.terms[()]`` themselves.

    Parameters
    ----------
    H : of.FermionOperator
        Operator to convert.
    n : int
        Number of orbitals; dense shapes are ``[n, n]`` and ``[n, n, n, n]``.

    Returns
    -------
    sparse_h1, sparse_h2 : tf.sparse.SparseTensor
        Canonically ordered one- and two-body coefficient tensors.

    Raises
    ------
    ValueError
        If ``H`` has a non-constant term of any other shape. Such a term cannot
        be represented here and would otherwise be silently dropped or stored
        with the wrong meaning.
    """
    h1e_keys = []
    h1e_vals = []
    h2e_keys = []
    h2e_vals = []
    for key, val in H.terms.items():
        actions = tuple(action for _, action in key)
        if not actions:
            continue
        if actions == (1, 0):
            h1e_keys.append([key[0][0], key[1][0]])
            h1e_vals.append(val)
        elif actions == (1, 0, 1, 0):
            h2e_keys.append([key[0][0], key[1][0], key[2][0], key[3][0]])
            h2e_vals.append(val)
        else:
            raise ValueError(f"Cannot store term {key} as a^ a or a^ a a^ a")
    sparse_h1 = _coo_to_sparse_tensor(h1e_keys, h1e_vals, [n, n])
    sparse_h2 = _coo_to_sparse_tensor(h2e_keys, h2e_vals, [n, n, n, n])
    return sparse_h1, sparse_h2

def sparse_to_H(H_sparse):
    """Rebuild a FermionOperator from a sparse coefficient tensor.

    A rank-2 index ``[p, q]`` becomes ``a_p^ a_q`` and a rank-4 index
    ``[p, q, r, s]`` becomes ``a_p^ a_q a_r^ a_s``, the layout written by
    ``H_to_sparse``. Indices of any other length are skipped. The result
    carries no constant term.
    """
    result = of.FermionOperator.zero()
    coords = H_sparse.indices.numpy()
    coeffs = H_sparse.values.numpy()
    for row, coefficient in zip(coords, coeffs):
        orbitals = [int(x) for x in row]
        if len(orbitals) not in (2, 4):
            continue
        # Alternate creation (1) and annihilation (0), starting with creation.
        ladder = tuple((orb, 1 - pos % 2) for pos, orb in enumerate(orbitals))
        result += of.FermionOperator(term = ladder, coefficient = coefficient)
    return result
# Killer Construction
def construct_killer(k, e_num, n = 0, const = 1e-2, t = 1e2, n_killer = 3):
    """Build a random killer operator that vanishes on the planted CAS state.

    For each block ``k[i]``, let ``Ne_i = sum_{p in k[i]} a_p^ a_p`` be the
    block number operator and ``D_i = Ne_i - e_num[i]``. The killer
    accumulates two kinds of term:

    * ``n_killer`` terms ``c * O * D_i``. Here
      ``O = 2 * (a_p^ a_q + a_q^ a_p)`` for random orbitals ``p, q`` outside
      the block with ``|p - q| > 1``. These terms are added only when at least
      four orbitals lie outside the block, which guarantees such a pair exists.
    * one term ``t * c' * D_i**2``.

    The random prefactors ``c, c'`` lie in ``[const, 2 * const)``. Every term
    has the factor ``D_i``, which is zero on states with exactly ``e_num[i]``
    electrons in block ``i``. ``O`` acts outside the block, so it commutes
    with ``Ne_i``.

    Parameters
    ----------
    k : list[list[int]]
        CAS blocks.
    e_num : sequence[int]
        Planted electron number of each block; same length as ``k``.
    n : int
        Number of orbitals. ``0`` (default) infers ``max index + 1`` from ``k``.
    const, t : float
        Scale of the random prefactors and extra weight of the quadratic term.
    n_killer : int
        Number of ``O * D_i`` terms per block.

    Returns
    -------
    c : number
        Constant part of the killer (0.0 if it has none).
    killer_obt, killer_tbt : tf.sparse.SparseTensor
        One- and two-body parts as produced by ``H_to_sparse``.

    Raises
    ------
    ValueError
        If ``e_num`` and ``k`` have different lengths.
    """
    if len(e_num) != len(k):
        raise ValueError(f"need one electron number per block: {len(e_num)} != {len(k)}")
    if not n:
        # Orbital count = largest (0-based) index + 1.
        n = max(max(orbs) for orbs in k) + 1
    killer = of.FermionOperator.zero()
    for i, orbs in enumerate(k):
        outside_orbs = [j for j in range(n) if j not in orbs]
        Ne = sum((of.FermionOperator(f"{p}^ {p}") for p in orbs), of.FermionOperator.zero())
        shift = Ne - e_num[i]
        if len(outside_orbs) >= 4:
            added = 0
            while added < n_killer:
                p, q = random.sample(outside_orbs, 2)
                if abs(p - q) <= 1:
                    continue
                ferm_op = of.FermionOperator(f"{p}^ {q}") + of.FermionOperator(f"{q}^ {p}")
                # ferm_op is already Hermitian, so this is simply 2 * ferm_op.
                O = ferm_op + of.hermitian_conjugated(ferm_op)
                k_const = const * (1 + np.random.rand())
                killer += k_const * O * shift
                added += 1
        killer += t * (1 + np.random.rand()) * const * (shift ** 2)
    killer_obt, killer_tbt = H_to_sparse(killer, n)
    # The constant can be absent, e.g. when every e_num[i] is 0.
    c = killer.terms.get((), 0.0)
    return c, killer_obt, killer_tbt

def construct_orbs(key: str):
    """Rebuild the CAS block list ``k`` from a block-size key such as ``"4-4-2"``.

    Each dash-separated field is a block size; blocks cover consecutive
    orbital indices from 0, e.g. ``"4-4-2"`` -> ``[[0..3], [4..7], [8, 9]]``.
    """
    offset = 0
    blocks = []
    for width in map(int, key.split("-")):
        blocks.append(list(range(offset, offset + width)))
        offset += width
    return blocks

def get_param_num(n, k, complex = False):
    '''
    Count the parameters of the planted model.

    n is the number of spatial orbitals and k the list of CAS blocks.
    Returns (upnum, casnum, pnum):
      upnum  - orbital-rotation parameters: n(n-1)/2, or n(n-1) when
               ``complex`` is True;
      casnum - one-body (len(b)**2) plus two-body (len(b)**4) entries of
               every block b in k;
      pnum   - upnum + casnum.
    '''
    upnum = n * (n - 1) if complex else int(n * (n - 1) / 2)
    casnum = sum(len(block) ** 4 + len(block) ** 2 for block in k)
    return upnum, casnum, upnum + casnum
    
def check_for_incorrect_spin_terms(tbt_to_check):
    """Count two-body entries that mix index parities within a pair.

    Entry ``[p, q, r, s]`` is allowed to be nonzero only when ``p`` and ``q``
    share parity and ``r`` and ``s`` share parity. Parity presumably encodes
    spin in an interleaved spin-orbital basis (TODO confirm). Every other
    entry must satisfy ``np.isclose(x, 0.0)``.

    Returns
    -------
    (bool, int)
        Whether no forbidden entry is nonzero, and how many forbidden entries
        are nonzero.
    """
    parity = np.arange(tbt_to_check.shape[0]) % 2
    same_parity = parity[:, None] == parity[None, :]
    # allowed[p, q, r, s] <=> p ~ q and r ~ s (same parity).
    allowed = same_parity[:, :, None, None] & same_parity[None, None, :, :]
    nonzero = ~np.isclose(tbt_to_check, 0.0)
    num_incorrect_terms = int(np.count_nonzero(nonzero & ~allowed))
    return num_incorrect_terms == 0, num_incorrect_terms

In [153]:


if __name__ == "__main__":   
    # Sanity check on one catalyst Hamiltonian: convert it to sparse tensors
    # and compare the sparse two-body tensor with the dense one. The names
    # bound here are notebook globals that later cells may overwrite or reuse.
#     for file_name in os.listdir(path):
    file_name = "2_co2_6-311++G___12_9d464efb-b312-45f8-b0ba-8c42663059dc.hdf5"
    with h5py.File(path + file_name, mode="r") as h5f:
        attributes = dict(h5f.attrs.items())
        Hobt = np.array(h5f["one_body_tensor"])
        Htbt = np.array(h5f["two_body_tensor"])
        spatial_orbs = Hobt.shape[0]
    print(f"Number of spatial orbitals: {spatial_orbs}")
    # Subtract 0.5 * sum_r Htbt[p, r, r, q] from the one-body tensor and halve
    # the two-body tensor (presumably the convention feru.get_ferm_op
    # expects -- TODO confirm against feru).
    Hobt -= 0.5*np.einsum("prrq->pq",Htbt)
    Htbt *= 0.5
    H = feru.get_ferm_op(Htbt, spin_orb = True) + feru.get_ferm_op(Hobt, spin_orb = True)
#     print(H)
#     for t, val in H.terms.items():
#         print(t)
#         print(val)
    sparse_obt, sparse_tbt = H_to_sparse(H, Hobt.shape[0])
    # Round trip back to an operator; tmp_H is not compared with H below.
    tmp_H = sparse_to_H(sparse_obt) + sparse_to_H(sparse_tbt)
    # Prints 0.0 when the sparse two-body tensor reproduces Htbt exactly.
    print(np.linalg.norm(Htbt - tf.sparse.to_dense(sparse_tbt)))
    # NOTE(review): prints True unconditionally; it does not check anything.
    print(True)

  Hobt = np.array(h5f["one_body_tensor"])
  Htbt = np.array(h5f["two_body_tensor"])


Number of spatial orbitals: 12
0.0
True


## Construct CAS Hamiltonians for all catalyst systems

In [None]:
if __name__ == "__main__":
    # `os` is used below; import it explicitly instead of relying on
    # `from sdstate import *` happening to re-export it.
    import os

    # For every catalyst Hamiltonian in `path`:
    #   1. truncate it to CAS blocks;
    #   2. plant a block-product ground state with chosen electron numbers;
    #   3. build a killer operator that vanishes on that state;
    #   4. store the result in planted_solutions/<name>.pkl under the
    #      block-size key (e.g. "4-4-4").
    # The module-level names bound here (ps_path, f_name, ...) are reused by
    # later notebook cells.
    for file_name in os.listdir(path):
        if not file_name.endswith(".hdf5"):
            # Skip anything in the directory that is not an HDF5 Hamiltonian.
            continue
        with h5py.File(path + file_name, mode="r") as h5f:
            attributes = dict(h5f.attrs.items())
            Hobt = np.array(h5f["one_body_tensor"])
            Htbt = np.array(h5f["two_body_tensor"])
        spatial_orbs = Hobt.shape[0]
        print(f"Number of spatial orbitals: {spatial_orbs}")
        # Subtract 0.5 * sum_r Htbt[p, r, r, q] from the one-body tensor and
        # halve the two-body tensor before truncation.
        Hobt -= 0.5*np.einsum("prrq->pq",Htbt)
        Htbt *= 0.5
        H = (Hobt, Htbt)
        k = construct_blocks(block_size, spatial_orbs)
        print(f"orbital splliting: {k}")
        upnum, casnum, pnum = get_param_num(spatial_orbs, k, complex = False)
        cas_obt, cas_tbt, cas_x = get_truncated_cas_tbt(H, k, casnum)
        H_cas = [cas_obt, cas_tbt]
        # NOTE: solve_enums adds its electron-number penalty terms to
        # cas_obt / cas_tbt in place; cas_x was packed before that.
        e_nums, states, E_cas = solve_enums(H_cas, k, ne_per_block = ne_per_block,
                                            ne_range = ne_range, balance_t = balance_strength)
        print(f"e_nums:{e_nums}")
        print(f"E_cas: {E_cas}")

        # Operator forms are only needed by the optional (expensive) checks.
        need_ops = check_state or check_symmetry or FCI
        if need_ops:
            H_cas_op = feru.get_ferm_op(cas_tbt, False) + feru.get_ferm_op(cas_obt, False)

        # For a block-diagonal, number-conserving Hamiltonian the energy of the
        # product of normalized block states is the sum of block energies, so
        # E_cas is used unless check_state recomputes it explicitly.
        E_sol = E_cas
        sd_sol = None
        if check_state:
            # Concatenating the block states takes exponential space and time
            # in the number of blocks.
            sd_sol = sdstate()
            for st in states:
                sd_sol = sd_sol.concatenate(st)
            print(sd_sol.n_qubit)
            E_sol = sd_sol.exp(H_cas_op)
            print(f"Double check ground state energy: {E_sol}")

        # Checking H_cas symmetries (very costly).
        if check_symmetry:
            Sz = of.hamiltonians.sz_operator(spatial_orbs)
            S2 = of.hamiltonians.s_squared_operator(spatial_orbs)
            assert of.FermionOperator.zero() == of.normal_ordered(of.commutator(Sz, H_cas_op)), "Sz symmetry broken"
            assert of.FermionOperator.zero() == of.normal_ordered(of.commutator(S2, H_cas_op)), "S2 symmetry broken"

        # Checking ground state with FCI.
        # Warning: this takes exponential time and space to run.
        if FCI:
            E_min, sol = of.get_ground_state(of.get_sparse_operator(H_cas_op))
            print(f"FCI Energy: {E_min}")
            tmp_st = sdstate(n_qubit = spatial_orbs * 2)
            for s, amp in enumerate(sol):
                if abs(amp) > np.finfo(np.float32).eps:
                    tmp_st += sdstate(s, amp)
            print(f"truncated wavefunction norm: {tmp_st.norm()}")
            tmp_st.normalize()
            print(tmp_st.exp(H_cas_op))

        killer_c, killer_obt, killer_tbt = construct_killer(k, e_nums, n = spatial_orbs, n_killer = n_killer)
        if need_ops:
            # The killer tensors are spatial-orbital tensors (they are added to
            # the CAS obt/tbt directly when the Hamiltonian is rebuilt), so
            # convert them the same way as H_cas_op.
            k_obt_dense = tf.sparse.to_dense(tf.sparse.reorder(killer_obt)).numpy()
            k_tbt_dense = tf.sparse.to_dense(tf.sparse.reorder(killer_tbt)).numpy()
            cas_killer = (feru.get_ferm_op(k_tbt_dense, False)
                          + feru.get_ferm_op(k_obt_dense, False) + killer_c)

        if check_symmetry:
            assert of.FermionOperator.zero() == of.normal_ordered(of.commutator(Sz, cas_killer)), "Killer broke Sz symmetry"
            assert of.FermionOperator.zero() == of.normal_ordered(of.commutator(S2, cas_killer)), "S2 symmetry broken"

        # Checking whether FCI with the killer gives the same result.
        # Warning: takes exponential time.
        if FCI:
            sparse_with_killer = of.get_sparse_operator(cas_killer + H_cas_op)
            killer_Emin, killer_sol = of.get_ground_state(sparse_with_killer)
            print(f"FCI Energy solution with killer: {killer_Emin}")
            if sd_sol is not None:
                sd_Emin = sd_sol.exp(H_cas_op) + sd_sol.exp(cas_killer)
                print(f"difference with CAS energy: {sd_Emin - killer_Emin}")

        # Checking that the killer does not change the ground state.
        if check_state:
            killer_error = sd_sol.exp(cas_killer)
            print(f"Solution Energy shift by killer: {killer_error}")
            killer_E_sol = sd_sol.exp(H_cas_op + cas_killer)
            print(f"Solution Energy with killer: {killer_E_sol}")

        planted_sol = {}
        planted_sol["E_min"] = E_sol
        planted_sol["e_nums"] = e_nums
        planted_sol["killer"] = (killer_c, killer_obt, killer_tbt)
        planted_sol["k"] = k
        planted_sol["casnum"] = casnum
        planted_sol["pnum"] = pnum
        planted_sol["upnum"] = upnum
        planted_sol["spatial_orbs"] = spatial_orbs
        planted_sol["cas_x"] = cas_x
        planted_sol["sol"] = states
        if check_state:
            planted_sol["solution"] = sd_sol

        ps_path = "planted_solutions/"
        os.makedirs(ps_path, exist_ok=True)
        f_name = file_name.split(".")[0] + ".pkl"
        print(ps_path + f_name)

        # Key = block sizes joined by "-", e.g. "4-4-4" (see construct_orbs).
        key = "-".join(str(len(orbs)) for orbs in k)
        print(key)
        # Merge into previously saved solutions for this molecule. These are
        # trusted, locally generated pickles; never load untrusted ones.
        if os.path.exists(ps_path + f_name):
            with open(ps_path + f_name, 'rb') as handle:
                dic = pickle.load(handle)
        else:
            dic = {}
        dic[key] = planted_sol
        with open(ps_path + f_name, 'wb') as handle:
            pickle.dump(dic, handle, protocol=pickle.HIGHEST_PROTOCOL)

  Hobt = np.array(h5f["one_body_tensor"])
  Htbt = np.array(h5f["two_body_tensor"])


Number of spatial orbitals: 32
orbital splliting: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23], [24, 25, 26, 27], [28, 29, 30, 31]]
Ne within current block: 6
Not enough balance, adding more terms
0b11111111
Not enough balance, adding more terms
0b1111111
E_min: -1256.690192383282 for orbs: [0, 1, 2, 3]
current state Energy: -1256.6901923832813
Ne within current block: 5
Not enough balance, adding more terms
0b1111111
Not enough balance, adding more terms
0b1011111
E_min: -1162.2666360395735 for orbs: [4, 5, 6, 7]
current state Energy: -1162.2666360395276
Ne within current block: 5
Not enough balance, adding more terms
0b1011111
E_min: -602.1533749471962 for orbs: [8, 9, 10, 11]
current state Energy: -602.1533749471909
Ne within current block: 2
Not enough balance, adding more terms
0b111
E_min: -71.45659676019197 for orbs: [12, 13, 14, 15]
current state Energy: -71.45659676019001
Ne within current block: 6
Not enough balance, adding

Not enough balance, adding more terms
0b1011111
Not enough balance, adding more terms
0b1010111
E_min: -560.7215190357991 for orbs: [16, 17, 18, 19]
current state Energy: -560.7215190357999
Ne within current block: 3
Not enough balance, adding more terms
0b1111
E_min: -232.68281735875294 for orbs: [20, 21, 22, 23]
current state Energy: -232.6828173587511
Ne within current block: 5
Not enough balance, adding more terms
0b1111111
Not enough balance, adding more terms
0b1011111
E_min: -957.6544127084794 for orbs: [24, 25, 26, 27]
current state Energy: -957.6544127084459
Ne within current block: 4
Not enough balance, adding more terms
0b1011111
Not enough balance, adding more terms
0b1010111
E_min: -707.7115341374171 for orbs: [28, 29, 30, 31]
current state Energy: -707.7115341374038
Ne within current block: 4
Not enough balance, adding more terms
0b11111
E_min: -446.91474166346575 for orbs: [32, 33, 34, 35]
current state Energy: -446.91474166346626
Ne within current block: 4
Not enough ba

E_min: -830.1409878042125 for orbs: [0, 1, 2, 3]
current state Energy: -830.1409878041956
Ne within current block: 6
Not enough balance, adding more terms
0b11111111
Not enough balance, adding more terms
0b1111111
E_min: -1780.1557805174357 for orbs: [4, 5, 6, 7]
current state Energy: -1780.1557805173722
Ne within current block: 2
Not enough balance, adding more terms
0b111
E_min: -120.20307470046636 for orbs: [8, 9, 10, 11]
current state Energy: -120.20307470046528
Ne within current block: 2
Not enough balance, adding more terms
0b1111
Not enough balance, adding more terms
0b111
E_min: -159.7391264921832 for orbs: [12, 13, 14, 15]
current state Energy: -159.7391264921831
Ne within current block: 6
Not enough balance, adding more terms
0b1111111
E_min: -926.1808095242874 for orbs: [16, 17, 18, 19]
current state Energy: -926.1808095242744
Ne within current block: 3
Not enough balance, adding more terms
0b11111
Not enough balance, adding more terms
0b1111
E_min: -396.09900179258483 for o

Not enough balance, adding more terms
0b111
E_min: -212.07315296623972 for orbs: [8, 9, 10, 11]
current state Energy: -212.07315296621817
Ne within current block: 6
Not enough balance, adding more terms
0b11111111
Not enough balance, adding more terms
0b1111111
E_min: -1256.3981538594694 for orbs: [12, 13, 14, 15]
current state Energy: -1256.3981538594599
Ne within current block: 4
Not enough balance, adding more terms
0b1011111
Not enough balance, adding more terms
0b11111
E_min: -581.526304391129 for orbs: [16, 17, 18, 19]
current state Energy: -581.5263043911291
Ne within current block: 4
Not enough balance, adding more terms
0b11111
E_min: -416.856127615378 for orbs: [20, 21, 22, 23]
current state Energy: -416.8561276153768
Ne within current block: 6
Not enough balance, adding more terms
0b1111111
E_min: -848.3781123225832 for orbs: [24, 25, 26, 27]
current state Energy: -848.3781123225722
Ne within current block: 2
Not enough balance, adding more terms
0b1111
Not enough balance, a

Not enough balance, adding more terms
0b11111
E_min: -427.8903108474222 for orbs: [16, 17, 18, 19]
current state Energy: -427.8903108474228
Ne within current block: 2
Not enough balance, adding more terms
0b1111
Not enough balance, adding more terms
0b111
E_min: -159.04441984444804 for orbs: [20, 21, 22, 23]
current state Energy: -159.04441984444748
Ne within current block: 2
Not enough balance, adding more terms
0b111
E_min: -130.2866413274721 for orbs: [24, 25, 26, 27]
current state Energy: -130.28664132747232
Ne within current block: 2
Not enough balance, adding more terms
0b110100
E_min: -128.48106955352176 for orbs: [28, 29, 30, 31]
current state Energy: -128.48106955351983
Ne within current block: 3
Not enough balance, adding more terms
0b111101
Not enough balance, adding more terms
0b111100
E_min: -356.500739794384 for orbs: [32, 33, 34, 35]
current state Energy: -356.50073979435825
Ne within current block: 4
Not enough balance, adding more terms
0b11111
E_min: -433.841749395262

Not enough balance, adding more terms
0b1111101
E_min: -561.2071720255416 for orbs: [24, 25, 26, 27]
current state Energy: -561.2071720255408
Ne within current block: 5
Not enough balance, adding more terms
0b1011111
E_min: -620.2335235377216 for orbs: [28, 29, 30, 31]
current state Energy: -620.2335235377109
Ne within current block: 3
Not enough balance, adding more terms
0b11111
Not enough balance, adding more terms
0b1111
E_min: -411.7758947051406 for orbs: [32, 33, 34, 35]
current state Energy: -411.77589470512487
Ne within current block: 5
Not enough balance, adding more terms
0b1111111
Not enough balance, adding more terms
0b1011111
E_min: -1152.646306180057 for orbs: [36, 37, 38, 39]
current state Energy: -1152.6463061800318
Ne within current block: 5
Not enough balance, adding more terms
0b1011111
E_min: -655.3840803199572 for orbs: [40, 41, 42, 43]
current state Energy: -655.3840803199352
Ne within current block: 3
Not enough balance, adding more terms
0b1111
Not enough balanc

Not enough balance, adding more terms
0b10101
E_min: -127.30980417751324 for orbs: [60, 61, 62, 63]
current state Energy: -127.30980417750193
Ne within current block: 6
Not enough balance, adding more terms
0b1111111
E_min: -940.0950996238568 for orbs: [64, 65, 66, 67]
current state Energy: -940.0950996238158
Ne within current block: 6
Not enough balance, adding more terms
0b11111111
Not enough balance, adding more terms
0b1111111
E_min: -1295.478605773666 for orbs: [68, 69, 70, 71]
current state Energy: -1295.478605773649
Ne within current block: 3
Not enough balance, adding more terms
0b11111
Not enough balance, adding more terms
0b1111
E_min: -516.5684212255568 for orbs: [72, 73, 74, 75]
current state Energy: -516.5684212255214
Ne within current block: 3
Not enough balance, adding more terms
0b11111
Not enough balance, adding more terms
0b1111
E_min: -403.29850825118785 for orbs: [76, 77, 78, 79]
current state Energy: -403.29850825118336
Ne within current block: 6
Not enough balance

Not enough balance, adding more terms
0b111
E_min: -136.386240343506 for orbs: [56, 57, 58, 59]
current state Energy: -136.38624034350508
Ne within current block: 6
Not enough balance, adding more terms
0b11111111
Not enough balance, adding more terms
0b1111111
E_min: -1345.417666273731 for orbs: [60, 61, 62, 63]
current state Energy: -1345.4176662736736
Ne within current block: 6
Not enough balance, adding more terms
0b1111111
E_min: -907.5342678573412 for orbs: [64, 65, 66, 67]
current state Energy: -907.5342678573317
Ne within current block: 3
Not enough balance, adding more terms
0b1010111
Not enough balance, adding more terms
0b1010101
E_min: -492.2348759336164 for orbs: [68, 69, 70, 71]
current state Energy: -492.2348759335657
Ne within current block: 4
Not enough balance, adding more terms
0b11111
E_min: -421.2352492321953 for orbs: [72, 73, 74, 75]
current state Energy: -421.2352492321933
Ne within current block: 5
Not enough balance, adding more terms
0b1111111
Not enough bala

Not enough balance, adding more terms
0b111
E_min: -192.39686826092307 for orbs: [84, 85, 86, 87]
current state Energy: -192.3968682609234
Ne within current block: 4
Not enough balance, adding more terms
0b1011111
Not enough balance, adding more terms
0b11111
E_min: -819.2362238364462 for orbs: [88, 89, 90, 91]
current state Energy: -819.2362238364442
Ne within current block: 5
Not enough balance, adding more terms
0b1111111
Not enough balance, adding more terms
0b1011111
E_min: -1458.6809386429472 for orbs: [92, 93, 94, 95]
current state Energy: -1458.6809386428952
Ne within current block: 3
Not enough balance, adding more terms
0b11111
Not enough balance, adding more terms
0b11110
E_min: -394.2488819409482 for orbs: [96, 97, 98, 99]
current state Energy: -394.2488819409441
Ne within current block: 3
Not enough balance, adding more terms
0b11111
Not enough balance, adding more terms
0b1111
E_min: -407.27286996338654 for orbs: [100, 101, 102, 103]
current state Energy: -407.27286996338

Number of spatial orbitals: 64
orbital splliting: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23], [24, 25, 26, 27], [28, 29, 30, 31], [32, 33, 34, 35], [36, 37, 38, 39], [40, 41, 42, 43], [44, 45, 46, 47], [48, 49, 50, 51], [52, 53, 54, 55], [56, 57, 58, 59], [60, 61, 62, 63]]
Ne within current block: 6
Not enough balance, adding more terms
0b11111111
Not enough balance, adding more terms
0b1111111
E_min: -1381.3739602524488 for orbs: [0, 1, 2, 3]
current state Energy: -1381.3739602522285
Ne within current block: 6
Not enough balance, adding more terms
0b11111111
Not enough balance, adding more terms
0b1111111
E_min: -1285.3082111423025 for orbs: [4, 5, 6, 7]
current state Energy: -1285.3082111422239
Ne within current block: 5
Not enough balance, adding more terms
0b1011111
E_min: -455.13328697141674 for orbs: [8, 9, 10, 11]
current state Energy: -455.1332869714153
Ne within current block: 6
Not enough balance, adding more terms
0b1111

E_min: -324.5221379443171 for orbs: [80, 81, 82, 83]
current state Energy: -324.52213794431856
Ne within current block: 3
Not enough balance, adding more terms
0b1111
E_min: -228.43098281765583 for orbs: [84, 85, 86, 87]
current state Energy: -228.430982817649
Ne within current block: 5
Not enough balance, adding more terms
0b1011111
E_min: -571.8258602925608 for orbs: [88, 89, 90, 91]
current state Energy: -571.8258602925564
Ne within current block: 2
Not enough balance, adding more terms
0b1111
Not enough balance, adding more terms
0b111
E_min: -159.11969172817552 for orbs: [92, 93, 94, 95]
current state Energy: -159.1196917281669
Ne within current block: 6
Not enough balance, adding more terms
0b1111111
E_min: -946.6922435625659 for orbs: [96, 97, 98, 99]
current state Energy: -946.6922435625568
Ne within current block: 3
Not enough balance, adding more terms
0b1111
E_min: -240.35701947244087 for orbs: [100, 101]
current state Energy: -240.35701947244067
e_nums:[3, 2, 5, 4, 2, 3, 3,

Not enough balance, adding more terms
0b1111111
Not enough balance, adding more terms
0b1011111
E_min: -1385.574540945722 for orbs: [0, 1, 2, 3]
current state Energy: -1385.5745409456442
Ne within current block: 5
Not enough balance, adding more terms
0b1111111
Not enough balance, adding more terms
0b1011111
E_min: -1495.7584142372405 for orbs: [4, 5, 6, 7]
current state Energy: -1495.75841423715
Ne within current block: 3
Not enough balance, adding more terms
0b111100
Not enough balance, adding more terms
0b111100
E_min: -597.2824174601642 for orbs: [8, 9, 10, 11]
current state Energy: -597.2824174601577
Ne within current block: 6
Not enough balance, adding more terms
0b1111111
E_min: -864.1647662747832 for orbs: [12, 13, 14, 15]
current state Energy: -864.1647662747862
Ne within current block: 4
Not enough balance, adding more terms
0b11111
E_min: -452.26121572767386 for orbs: [16, 17, 18, 19]
current state Energy: -452.26121572767295
Ne within current block: 6
Not enough balance, ad

E_min: -404.3741983332745 for orbs: [0, 1, 2, 3]
current state Energy: -404.37419833327175
Ne within current block: 2
Not enough balance, adding more terms
0b10101
E_min: -124.5058712524225 for orbs: [4, 5, 6, 7]
current state Energy: -124.5058712524126
Ne within current block: 3
Not enough balance, adding more terms
0b11111
Not enough balance, adding more terms
0b11110
E_min: -315.34378062810936 for orbs: [8, 9, 10, 11]
current state Energy: -315.34378062810754
Ne within current block: 2
Not enough balance, adding more terms
0b10101
E_min: -79.92497014146974 for orbs: [12, 13, 14, 15]
current state Energy: -79.92497014146947
Ne within current block: 5
Not enough balance, adding more terms
0b1111111
Not enough balance, adding more terms
0b1011111
E_min: -808.4084423599137 for orbs: [16, 17, 18, 19]
current state Energy: -808.4084423598756
Ne within current block: 4
Not enough balance, adding more terms
0b11111
E_min: -382.72345839543067 for orbs: [20, 21, 22, 23]
current state Energy: 

Number of spatial orbitals: 64
orbital splliting: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23], [24, 25, 26, 27], [28, 29, 30, 31], [32, 33, 34, 35], [36, 37, 38, 39], [40, 41, 42, 43], [44, 45, 46, 47], [48, 49, 50, 51], [52, 53, 54, 55], [56, 57, 58, 59], [60, 61, 62, 63]]
Ne within current block: 2
Not enough balance, adding more terms
0b1010101
Not enough balance, adding more terms
0b10101
E_min: -212.27878360491388 for orbs: [0, 1, 2, 3]
current state Energy: -212.27878360488506
Ne within current block: 3
Not enough balance, adding more terms
0b1010111
Not enough balance, adding more terms
0b1010101
E_min: -510.05352015397267 for orbs: [4, 5, 6, 7]
current state Energy: -510.0535201539575
Ne within current block: 4
Not enough balance, adding more terms
0b1011111
Not enough balance, adding more terms
0b11111
E_min: -545.0998573568522 for orbs: [8, 9, 10, 11]
current state Energy: -545.0998573568514
Ne within current block: 3
Not 

In [47]:
def get_cas_matrix(cas_x, n, k):
    """Unpack a packed CAS parameter vector into dense one- and two-body tensors.

    For each block ``b`` in ``k``, in order, the next ``len(b)**2`` entries of
    ``cas_x`` fill ``obt`` on ``b x b``. The following ``len(b)**4`` entries
    fill ``tbt`` on ``b x b x b x b``. Both use row-major
    (``itertools.product``) order, matching the packing in
    ``get_truncated_cas_tbt``. All other entries are zero; extra trailing
    entries of ``cas_x`` are ignored.

    Parameters
    ----------
    cas_x : 1-D array-like
        Packed block entries.
    n : int
        Number of orbitals.
    k : list[list[int]]
        CAS blocks.

    Returns
    -------
    obt, tbt : ndarray
        Arrays of shape ``(n, n)`` and ``(n, n, n, n)``.

    Raises
    ------
    ValueError
        If ``cas_x`` is too short for the blocks in ``k``.
    """
    cas_x = np.asarray(cas_x)
    obt = np.zeros([n, n])
    tbt = np.zeros([n, n, n, n])
    idx = 0
    for orbs in k:
        m = len(orbs)
        if m == 0:
            continue
        if idx + m ** 2 + m ** 4 > len(cas_x):
            raise ValueError(
                f"cas_x has {len(cas_x)} entries, too few for CAS blocks {k}")
        obt[np.ix_(orbs, orbs)] = cas_x[idx:idx + m ** 2].reshape(m, m)
        idx += m ** 2
        tbt[np.ix_(orbs, orbs, orbs, orbs)] = cas_x[idx:idx + m ** 4].reshape(m, m, m, m)
        idx += m ** 4
    return obt, tbt

def orbtransf(tensor, U, complex = False):
    """Apply the orbital rotation ``U`` to a one- or two-body tensor.

    rank 2: ``out[a, b] = U[a, p] V[b, q] T[p, q]``
    rank 4: ``out[a, b, c, d] = U[a, k] V[b, l] U[c, m] V[d, n] T[k, l, m, n]``

    ``V`` is ``U`` by default, which is the previous behaviour and gives
    ``U T U^T`` (the same as ``U T U^H`` for a real orthogonal ``U``). With
    ``complex=True``, ``V = conj(U)``, giving ``U T U^H`` for a complex
    unitary ``U``.

    Raises
    ------
    ValueError
        If ``tensor`` is neither rank 2 nor rank 4.
    """
    V = np.conj(U) if complex else U
    rank = np.ndim(tensor)
    if rank == 4:
        return np.einsum('ak,bl,cm,dn,klmn->abcd', U, V, U, V, tensor, optimize = True)
    if rank == 2:
        return np.einsum('ap,bq,pq->ab', U, V, tensor, optimize = True)
    raise ValueError(f"orbtransf expects a rank-2 or rank-4 tensor, got rank {rank}")

In [78]:
def construct_Hamiltonian_with_solution(path, file_name, key = None):
    """Load a planted CAS solution and build plain and hidden Hamiltonians.

    Parameters
    ----------
    path : str
        Directory prefix of the planted-solution pickles. It is concatenated
        directly with ``file_name``, so it should end with a separator
        (e.g. ``"planted_solutions/"``).
    file_name : str
        Pickle file name. A name without the ``.pkl`` suffix (e.g. the source
        ``.hdf5`` file) is mapped as ``file_name.split(".")[0] + ".pkl"``,
        the naming the generation cell uses when saving.
    key : str, optional
        Block-size key (e.g. ``"4-4-4"``) selecting one stored solution.
        Defaults to the first key in the file.

    Returns
    -------
    H_cas, H_hidden, H_with_killer, H_killer_hidden : tuple
        ``(constant, obt, tbt)`` triples:

        * ``H_cas``: the CAS tensors rebuilt from ``cas_x``.
        * ``H_hidden``: ``H_cas`` rotated by a random orthogonal ``U``.
        * ``H_with_killer``: ``H_cas`` plus the killer tensors and constant.
        * ``H_killer_hidden``: ``H_with_killer`` rotated by the same ``U``.
    """
    if not file_name.endswith(".pkl"):
        file_name = file_name.split(".")[0] + ".pkl"
    # NOTE: pickle can execute arbitrary code; only load trusted, locally
    # generated solution files.
    with open(path + file_name, 'rb') as handle:
        stored = pickle.load(handle)
    if key is None:
        key = next(iter(stored))
    dic = stored[key]
    cas_x = dic["cas_x"]
    killer = dic["killer"]
    killer_c = killer[0]
    # reorder() makes to_dense robust to killers saved with unsorted indices;
    # .numpy() keeps everything below in plain NumPy.
    k_obt = tf.sparse.to_dense(tf.sparse.reorder(killer[1])).numpy()
    k_tbt = tf.sparse.to_dense(tf.sparse.reorder(killer[2])).numpy()
    k = dic["k"]
    upnum = dic["upnum"]
    spatial_orbs = dic["spatial_orbs"]
    # CAS one- and two-body tensors.
    obt, tbt = get_cas_matrix(cas_x, spatial_orbs, k)
    H_cas = (0, obt, tbt)
    H_with_killer = (killer_c, obt + k_obt, tbt + k_tbt)
    # Hide the block structure with a random orbital rotation.
    random_uparams = np.random.rand(upnum)
    U = construct_orthogonal(spatial_orbs, random_uparams)
    H_hidden = (0, orbtransf(obt, U), orbtransf(tbt, U))
    H_killer_hidden = (killer_c, orbtransf(H_with_killer[1], U), orbtransf(H_with_killer[2], U))
    return H_cas, H_hidden, H_with_killer, H_killer_hidden

# Rebuild the plain, killer-augmented and hidden Hamiltonians for one molecule.
# ps_path / f_name are defined here so this cell works even when the
# generation cell has not been run in the current session.
ps_path = "planted_solutions/"
file_name = "2_co2_6-311++G___12_9d464efb-b312-45f8-b0ba-8c42663059dc.hdf5"
f_name = file_name.split(".")[0] + ".pkl"
H_cas, H_hidden, H_with_killer, H_killer_hidden = construct_Hamiltonian_with_solution(ps_path, f_name)