In [2]:
"""Construct CAS Hamiltonians with cropping
"""
import os
import saveload_utils as sl
import ferm_utils as feru
import csa_utils as csau
import var_utils as varu
import openfermion as of
import numpy as np
from sdstate import *
from itertools import product
import random
import h5py
import sys
from matrix_utils import construct_orthogonal
import pickle
import tensorflow as tf

### Parameters

# --- Planted-solution construction -------------------------------------
# Spatial orbitals per CAS block
block_size = 4
# Target electrons per block, before a random offset in [-ne_range, ne_range]
ne_per_block = 4
ne_range = 2
# Base strength of the electron-number penalty terms
balance_strength = 2
# Number of killer operators added for each CAS block
n_killer = 3
# Numerical tolerance (not referenced by the code in this notebook)
tol = 1e-5
# NOTE(review): the driver below does not consult this flag; results are
# always pickled.
save = False

# --- Optional, expensive checks ----------------------------------------
# Run full CI on the whole Hamiltonian; cost grows exponentially with size
FCI = False
# Verify Sz / S^2 symmetries of the planted Hamiltonian; very costly
check_symmetry = False
# Concatenate the per-block states and re-evaluate the total energy
check_state = False

# --- Input location ----------------------------------------------------
path = "hamiltonians_catalysts/"
file_name = "2_co2_6-311++G___12_9d464efb-b312-45f8-b0ba-8c42663059dc.hdf5"

def construct_blocks(b: int, spin_orbs: int, spin_orb = False):
    """Partition orbital indices ``0 .. spin_orbs - 1`` into consecutive blocks.

    Despite its name, ``spin_orbs`` is simply the number of orbitals to split
    (the driver passes the spatial-orbital count). When ``spin_orb`` is True
    the block size ``b`` is doubled, presumably so that ``b`` can be given in
    spatial orbitals while spin-orbital indices are partitioned.

    Args:
        b: block size.
        spin_orbs: number of orbital indices to partition.
        spin_orb: double ``b`` before partitioning.

    Returns:
        A list of lists of consecutive indices, each of length ``b`` except
        possibly the last, which holds the remainder. Returns ``[]`` when
        ``spin_orbs <= 0`` (the previous loop wrongly returned ``[[0]]``).

    Raises:
        ValueError: if the effective block size is not positive.
    """
    if spin_orb:
        b = b * 2
    if b <= 0:
        raise ValueError(f"block size must be positive, got {b}")
    return [
        list(range(start, min(start + b, spin_orbs)))
        for start in range(0, spin_orbs, b)
    ]
        
def get_truncated_cas_tbt(H, k, casnum):
    """Keep only the CAS-block-diagonal part of a Hamiltonian.

    Args:
        H: ``(obt, tbt)`` one- and two-body tensors.
        k: list of orbital blocks (lists of indices).
        casnum: length of the flat parameter vector to fill.

    Returns:
        ``(cas_obt, cas_tbt, cas_x)``. ``cas_obt`` and ``cas_tbt`` are
        zero-filled float arrays holding only the entries whose indices all lie
        in a single block. ``cas_x`` lists those entries block by block: first
        the block's one-body entries, then its two-body entries, each in
        ``itertools.product`` order.
    """
    one_body, two_body = H
    n_orb = two_body.shape[0]
    cas_obt = np.zeros([n_orb, n_orb])
    cas_tbt = np.zeros([n_orb, n_orb, n_orb, n_orb])
    cas_x = np.zeros(casnum)
    cursor = 0
    for orbs in k:
        for pair in product(orbs, repeat=2):
            value = one_body[pair]
            cas_obt[pair] = value
            cas_x[cursor] = value
            cursor += 1
        for quad in product(orbs, repeat=4):
            value = two_body[quad]
            cas_tbt[quad] = value
            cas_x[cursor] = value
            cursor += 1
    return cas_obt, cas_tbt, cas_x

def in_orbs(term, orbs):
    """Return whether a one- or two-body term acts only on orbitals in ``orbs``.

    ``term`` is a tuple of ``(index, action)`` pairs. Only terms of length 2 or
    4 can qualify; every other length yields False.
    """
    if len(term) not in (2, 4):
        return False
    return all(ladder[0] in orbs for ladder in term)

# def transform_orbs(term, orbs):
#     """Transform the operator term to align the orbs starting from 0"""
# #     pass
#     if len(term) == 2:
#         return ((orbs.index(term[0][0]), 1), (orbs.index(term[1][0]), 0))
#     if len(term) == 4:
#         return ((orbs.index(term[0][0]), 1), (orbs.index(term[1][0]), 0), 
#                (orbs.index(term[2][0]), 1), (orbs.index(term[3][0]), 0))   
#     return None

def _balance_terms(norbs, ne):
    """Return ``(obt, tbt)`` electron-number penalty tensors for one block.

    Sets ``tbt[p, p, q, q] = 1`` for all p, q and ``obt[p, p] = -2 * ne``. These
    are the coefficients of the (Ne - ne)^2 penalty without its constant
    ``ne**2`` part.
    """
    balance_obt = np.zeros([norbs, norbs])
    balance_tbt = np.zeros([norbs, norbs, norbs, norbs])
    for p, q in product(range(norbs), repeat = 2):
        balance_tbt[p, p, q, q] += 1
    for p in range(norbs):
        balance_obt[p, p] -= 2 * ne
    return balance_obt, balance_tbt


def solve_enums(H, k, ne_per_block = 0, ne_range = 0, balance_t = 10):
    """Solve each CAS block exactly while pinning its ground-state electron number.

    For every block a random target electron number ``ne`` is drawn. A penalty
    ``strength * (Ne - ne)^2`` (without the constant term) is then added to
    the block, doubling ``strength`` each time, until every determinant in the
    block's ground state has exactly ``ne`` electrons.

    Args:
        H: ``(obt, tbt)`` tensors in spatial orbitals. **Mutated in place**:
            the accumulated penalty terms are added to the block-diagonal
            slices of both tensors.
        k: list of orbital blocks. Each block is addressed as the slice
            ``orbs[0]:orbs[-1] + 1``, so blocks must be contiguous ascending
            index ranges.
        ne_per_block: target electrons per block before the random offset.
        ne_range: the offset is drawn uniformly from ``[-ne_range, ne_range]``.
            The target is then clamped to ``[0, 2 * len(orbs) - 1]``.
        balance_t: base penalty strength. It is multiplied by a random factor
            in ``[1, 2)`` and doubled before every attempt.

    Returns:
        ``(e_nums, states, E_cas)``: the target electron number per block, the
        normalised ``sdstate`` per block, and the sum of the per-block values
        of ``st.exp(block_operator)`` (presumably the expectation value,
        penalty included).
    """
    cas_obt = H[0]
    cas_tbt = H[1]
    e_nums = []
    states = []
    E_cas = 0
    for orbs in k:
        s = orbs[0]
        t = orbs[-1] + 1
        norbs = len(orbs)
        # Clamp at 0 as well: a negative target can never be met, and the
        # loop below would then never terminate.
        ne = max(0, min(ne_per_block + random.randint(-ne_range, ne_range), norbs * 2 - 1))
        print(f"Ne within current block: {ne}")
        # Always build the penalty tensors. The old `if ne_per_block != 0`
        # guard left them (and `strength`) undefined, causing a NameError.
        balance_obt, balance_tbt = _balance_terms(norbs, ne)
        strength = balance_t * (1 + random.random())
        flag = True
        while flag:
            # Double the strength and add the penalty again (cumulatively).
            strength *= 2
            cas_tbt[s:t, s:t, s:t, s:t] = np.add(cas_tbt[s:t, s:t, s:t, s:t], strength * balance_tbt)
            cas_obt[s:t, s:t] = np.add(cas_obt[s:t, s:t], strength * balance_obt)
            # spin_orb=False: the tensors are given in the spatial-orbital basis.
            tmp = feru.get_ferm_op(cas_tbt[s:t, s:t, s:t, s:t], spin_orb = False)
            tmp += feru.get_ferm_op(cas_obt[s:t, s:t], spin_orb = False)
            sparse_H_tmp = of.get_sparse_operator(tmp)
            tmp_E_min, t_sol = of.get_ground_state(sparse_H_tmp)
            # Keep amplitudes above float32 epsilon as a determinant expansion
            # on 2 * norbs qubits.
            st = sdstate(n_qubit = len(orbs) * 2)
            for i in range(len(t_sol)):
                if np.linalg.norm(t_sol[i]) > np.finfo(np.float32).eps:
                    st += sdstate(s = i, coeff = t_sol[i])
            st.normalize()
            E_st = st.exp(tmp)
            # Accept only if every determinant's bit count equals the target.
            flag = False
            for sd in st.dic:
                ne_computed = bin(sd)[2:].count('1')
                if ne_computed != ne:
                    print("Not enough balance, adding more terms")
                    print(bin(sd))
                    flag = True
                    break
        print(f"E_min: {tmp_E_min} for orbs: {orbs}")
        print(f"current state Energy: {E_st}")
        E_cas += E_st
        states.append(st)
        e_nums.append(ne)
    return e_nums, states, E_cas
    
    
def H_to_sparse(H: of.FermionOperator, n):
    """Split a FermionOperator into sparse one- and two-body coefficient tensors.

    A term with 2 ladder operators ``((p, _), (q, _))`` becomes entry
    ``[p, q]`` of an ``n x n`` sparse tensor. A term with 4 ladder operators
    ``((p, _), (q, _), (r, _), (s, _))`` becomes entry ``[p, q, r, s]`` of an
    ``n x n x n x n`` sparse tensor. Only the orbital indices are read. Terms
    of any other length, including the constant term ``()``, are not included.

    Returns:
        ``(sparse_h1, sparse_h2)`` as ``tf.sparse.SparseTensor`` objects.
    """
    h1e_keys = []
    h1e_vals = []
    h2e_keys = []
    h2e_vals = []
    for key, val in H.terms.items():
        if len(key) == 2:
            h1e_keys.append([key[0][0], key[1][0]])
            h1e_vals.append(val)
        elif len(key) == 4:
            h2e_keys.append([key[0][0], key[1][0], key[2][0], key[3][0]])
            h2e_vals.append(val)
    # Build indices and values as NumPy arrays:
    #  * reshape(-1, rank) keeps an operator with no such terms valid, whereas
    #    a bare [] has rank 1 and is rejected by SparseTensor;
    #  * float64 / complex128 values avoid TensorFlow's default conversion of
    #    Python floats to float32.
    sparse_h1 = tf.sparse.SparseTensor(
        indices = np.asarray(h1e_keys, dtype = np.int64).reshape(-1, 2),
        values = np.asarray(h1e_vals),
        dense_shape = [n, n],
    )
    sparse_h2 = tf.sparse.SparseTensor(
        indices = np.asarray(h2e_keys, dtype = np.int64).reshape(-1, 4),
        values = np.asarray(h2e_vals),
        dense_shape = [n, n, n, n],
    )
    return sparse_h1, sparse_h2

def sparse_to_H(H_sparse):
    """Rebuild a FermionOperator from a sparse coefficient tensor.

    A stored entry with index ``[p, q]`` contributes ``value * p^ q``. An entry
    with index ``[p, q, r, s]`` contributes ``value * p^ q r^ s``. Entries of
    any other rank are ignored.
    """
    H = of.FermionOperator.zero()
    for idx_tensor, val_tensor in zip(H_sparse.indices, H_sparse.values):
        orbitals = [int(x) for x in idx_tensor.numpy()]
        if len(orbitals) not in (2, 4):
            continue
        # Alternate creation (1) / annihilation (0): p^ q  or  p^ q r^ s
        ladder_ops = tuple((orb, 1 - pos % 2) for pos, orb in enumerate(orbitals))
        H += of.FermionOperator(term = ladder_ops, coefficient = val_tensor.numpy())
    return H
# Killer Construction
def construct_killer(k, e_num, n = 0, const = 1e-2, t = 1e2, n_killer = 3):
    """Build a killer operator for a CAS Hamiltonian with block structure ``k``.

    For block ``i`` with orbitals ``orbs`` and target electron number
    ``e_num[i]``, let ``Ne`` be the number operator summed over ``orbs``.

    * If at least four orbitals lie outside the block, ``n_killer`` terms
      ``c * O * (Ne - e_num[i])`` are added. Here ``O = 2 (p^ q + q^ p)`` for a
      random pair p, q outside the block with ``|p - q| > 1``, and
      ``c = const * (1 + U[0, 1))``.
    * One quadratic term ``t * const * (1 + U[0, 1)) * (Ne - e_num[i])**2`` is
      always added.

    Args:
        k: list of orbital blocks.
        e_num: target electron number per block (same order as ``k``).
        n: number of orbitals. If 0, it is inferred as the largest index in
            ``k`` plus one.
        const: base scale of the random killer coefficients.
        t: extra strength factor of the quadratic terms.
        n_killer: number of ``O * (Ne - e)`` terms per block.

    Returns:
        ``(c, killer_obt, killer_tbt)``: the constant term of the killer (0.0
        if it has none) and its one- and two-body parts as sparse tensors (see
        ``H_to_sparse``).
    """
    if not n:
        # Indices are 0-based, so the orbital count is the largest index + 1.
        n = max(max(orbs) for orbs in k) + 1
    killer = of.FermionOperator.zero()
    for block_idx, orbs in enumerate(k):
        # Use the argument, not a same-named global (previously `e_nums`).
        target = e_num[block_idx]
        outside_orbs = [j for j in range(n) if j not in orbs]
        # Number operator restricted to this block
        Ne = sum([of.FermionOperator("{}^ {}".format(orb, orb)) for orb in orbs])
        if len(outside_orbs) >= 4:
            added = 0
            while added < n_killer:
                p, q = random.sample(outside_orbs, 2)
                if abs(p - q) > 1:
                    O = of.FermionOperator.zero()
                    ferm_op = of.FermionOperator("{}^ {}".format(p, q)) + of.FermionOperator("{}^ {}".format(q, p))
                    O += ferm_op
                    O += of.hermitian_conjugated(ferm_op)
                    k_const = const * (1 + np.random.rand())
                    killer += k_const * O * (Ne - target)
                    added += 1
        killer += t * (1 + np.random.rand()) * const * ((Ne - target) ** 2)
    killer_obt, killer_tbt = H_to_sparse(killer, n)
    # Constant term; .get avoids a KeyError when there is none.
    c = killer.terms.get((), 0.0)
    return c, killer_obt, killer_tbt

def construct_orbs(key: str):
    """Rebuild the block partition ``k`` from a key such as ``"4-4-2"``.

    Each dash-separated integer is the size of the next block. The blocks are
    consecutive index ranges starting at 0.
    """
    sizes = [int(part) for part in key.split("-")]
    blocks = []
    start = 0
    for size in sizes:
        blocks.append(list(range(start, start + size)))
        start += size
    return blocks

def get_param_num(n, k, complex = False):
    '''
    Count the parameters for an n-orbital Hamiltonian split into CAS blocks.

    Args:
        n: number of orbitals.
        k: block partition, i.e. a list of orbital-index lists.
        complex: if True, count n*(n-1) orbital-rotation parameters instead of
            n*(n-1)/2.

    Returns:
        (upnum, casnum, pnum): the number of rotation parameters; the number of
        CAS tensor entries, which is the sum over blocks of len(block)**4 +
        len(block)**2 (two-body plus one-body); and their total.
    '''
    if not complex:
        # n*(n-1) is always even. Integer division stays exact, whereas
        # int(n*(n-1)/2) goes through a float and loses precision once
        # n*(n-1) exceeds 2**53.
        upnum = int(n * (n - 1) // 2)
    else:
        upnum = n * (n - 1)
    casnum = sum(len(block) ** 4 + len(block) ** 2 for block in k)
    pnum = upnum + casnum
    return upnum, casnum, pnum
    
def check_for_incorrect_spin_terms(tbt_to_check):
    """Count non-zero two-body entries that violate the parity-block pattern.

    Indices are grouped by parity (even / odd), presumably the two spin
    channels of interleaved spin orbitals. An entry ``[p, q, r, s]`` is allowed
    when ``p`` and ``q`` have equal parity and ``r`` and ``s`` have equal
    parity (the patterns eeee, oooo, eeoo, ooee). Every other entry should be
    ``np.isclose`` to 0.

    Vectorised with NumPy. The previous pure-Python quadruple loop was
    O(n**4) interpreter iterations.

    Args:
        tbt_to_check: 4-index tensor; its first dimension sets the size n.

    Returns:
        ``(no_incorrect_terms, num_incorrect_terms)``: a bool and an int count.
    """
    tbt = np.asarray(tbt_to_check)
    n = tbt.shape[0]
    tbt = tbt[:n, :n, :n, :n]
    parity = np.arange(n) % 2
    same_parity = parity[:, None] == parity[None, :]
    allowed = same_parity[:, :, None, None] & same_parity[None, None, :, :]
    nonzero = ~np.isclose(tbt, 0.0)
    num_incorrect_terms = int(np.count_nonzero(nonzero & ~allowed))
    no_incorrect_terms = num_incorrect_terms == 0
    return no_incorrect_terms, num_incorrect_terms

## Construct CAS Hamiltonians for all catalyst systems

In [None]:
if __name__ == "__main__":
    # Driver: for every Hamiltonian file in `path`, plant a CAS ground state,
    # build a killer operator and pickle the result into planted_solutions/.
    # NOTE(review): the module-level `save` flag is not consulted; results are
    # always written for molecules that do not have a pickle yet.
    ps_path = "planted_solutions/"
    for file_name in os.listdir(path):
        with h5py.File(path + file_name, mode="r") as h5f:
            attributes = dict(h5f.attrs.items())
            Hobt = np.array(h5f["one_body_tensor"])
            Htbt = np.array(h5f["two_body_tensor"])
        spatial_orbs = Hobt.shape[0]
        print(f"Number of spatial orbitals: {spatial_orbs}")
        k = construct_blocks(block_size, spatial_orbs)
        f_name = file_name.split(".")[0] + ".pkl"
        print(ps_path + f_name)

        # Partition key such as "4-4-4", used as the dict key inside the pickle
        key = "-".join(str(len(orbs)) for orbs in k)
        print(key)
        # Skip molecules that were already processed. (The pickle-loading
        # code that used to follow this `continue` was unreachable.)
        if os.path.exists(ps_path + f_name):
            print("already exists")
            continue

        # NOTE(review): subtracts 0.5 * sum_r Htbt[p, r, r, q] from the
        # one-body tensor and halves the two-body tensor, presumably to match
        # the tensor convention of feru.get_ferm_op -- confirm.
        Hobt -= 0.5*np.einsum("prrq->pq",Htbt)
        Htbt *= 0.5
        H = (Hobt, Htbt)
        print(f"orbital splliting: {k}")
        upnum, casnum, pnum = get_param_num(spatial_orbs, k, complex = False)
        cas_obt, cas_tbt, cas_x = get_truncated_cas_tbt(H, k, casnum)
        H_cas = [cas_obt, cas_tbt]
        # solve_enums adds the electron-number penalties to cas_obt / cas_tbt
        # in place, so cas_x is re-extracted afterwards to include them.
        e_nums, states, E_cas = solve_enums(H_cas, k, ne_per_block = ne_per_block,
                                            ne_range = ne_range, balance_t = balance_strength)
        _, _, cas_x = get_truncated_cas_tbt((cas_obt, cas_tbt), k, casnum)
        print(f"e_nums:{e_nums}")
        print(f"E_cas: {E_cas}")

        run_checks = check_state or check_symmetry or FCI
        if run_checks:
            # FermionOperator form of the CAS Hamiltonian, used by all checks
            H_cas_op = feru.get_ferm_op(cas_tbt, False) + feru.get_ferm_op(cas_obt, False)

        if check_state:
            # Concatenate the block states into the full solution. Exponential
            # space and time in the number of blocks.
            sd_sol = sdstate()
            for st in states:
                sd_sol = sd_sol.concatenate(st)
            print(sd_sol.n_qubit)
            E_sol = sd_sol.exp(H_cas_op)
            print(f"Double check ground state energy: {E_sol}")

        # Check H_cas symmetries (very costly)
        if check_symmetry:
            Sz = of.hamiltonians.sz_operator(spatial_orbs)
            S2 = of.hamiltonians.s_squared_operator(spatial_orbs)
            assert of.FermionOperator.zero() == of.normal_ordered(of.commutator(Sz, H_cas_op)), "Sz symmetry broken"
            assert of.FermionOperator.zero() == of.normal_ordered(of.commutator(S2, H_cas_op)), "S2 symmetry broken"

        # Ground state via full CI. Exponential time and space.
        if FCI:
            E_min, sol = of.get_ground_state(of.get_sparse_operator(H_cas_op))
            print(f"FCI Energy: {E_min}")
            tmp_st = sdstate(n_qubit = spatial_orbs * 2)
            for s in range(len(sol)):
                # Compare magnitudes: amplitudes may be negative or complex.
                if abs(sol[s]) > np.finfo(np.float32).eps:
                    tmp_st += sdstate(s = s, coeff = sol[s])
            print(f"truncated wavefunction norm: {tmp_st.norm()}")
            tmp_st.normalize()
            print(tmp_st.exp(H_cas_op))

        killer_c, killer_obt, killer_tbt = construct_killer(k, e_nums, n = spatial_orbs, n_killer = n_killer)
        if run_checks:
            # Densify the killer and convert it the same way as H_cas_op.
            # Assumes the killer tensors share the (obt, tbt) convention of
            # cas_obt / cas_tbt, as construct_Hamiltonian_with_solution adds
            # them directly.
            killer_obt_dense = tf.sparse.to_dense(tf.sparse.reorder(killer_obt)).numpy()
            killer_tbt_dense = tf.sparse.to_dense(tf.sparse.reorder(killer_tbt)).numpy()
            cas_killer = (feru.get_ferm_op(killer_tbt_dense, False)
                          + feru.get_ferm_op(killer_obt_dense, False) + killer_c)

        if check_symmetry:
            assert of.FermionOperator.zero() == of.normal_ordered(of.commutator(Sz, cas_killer)), "Killer broke Sz symmetry"
            assert of.FermionOperator.zero() == of.normal_ordered(of.commutator(S2, cas_killer)), "S2 symmetry broken"

        # Check that full CI with the killer reproduces the planted energy.
        # Exponential time.
        if FCI:
            sparse_with_killer = of.get_sparse_operator(cas_killer + H_cas_op)
            killer_Emin, killer_sol = of.get_ground_state(sparse_with_killer)
            print(f"FCI Energy solution with killer: {killer_Emin}")
            print(f"difference with CAS energy: {E_cas - killer_Emin}")

        # Check that the killer does not shift the planted solution's energy
        if check_state:
            killer_error = sd_sol.exp(cas_killer)
            print(f"Solution Energy shift by killer: {killer_error}")
            killer_E_sol = sd_sol.exp(H_cas_op + cas_killer)
            print(f"Solution Energy with killer: {killer_E_sol}")

        planted_sol = {}
        planted_sol["E_min"] = E_cas
        planted_sol["e_nums"] = e_nums
        planted_sol["killer"] = (killer_c, killer_obt, killer_tbt)
        planted_sol["k"] = k
        planted_sol["casnum"] = casnum
        planted_sol["pnum"] = pnum
        planted_sol["upnum"] = upnum
        planted_sol["spatial_orbs"] = spatial_orbs
        planted_sol["cas_x"] = cas_x
        planted_sol["sol"] = states
        if check_state:
            planted_sol["solution"] = sd_sol
            planted_sol["E_sol"] = E_sol

        with open(ps_path + f_name, 'wb') as handle:
            pickle.dump({key: planted_sol}, handle, protocol=pickle.HIGHEST_PROTOCOL)

  Hobt = np.array(h5f["one_body_tensor"])
  Htbt = np.array(h5f["two_body_tensor"])


Number of spatial orbitals: 32
planted_solutions/0_ru_macho_{'Ru'_ 'cc-pVTZ-PP', 'default'_ '6-311++G__'}_32_47784dd4-5750-4294-b5e5-80f487b9bf54.pkl
4-4-4-4-4-4-4-4
already exists
Number of spatial orbitals: 92
planted_solutions/10_fecp2+_s0.pkl
4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4
already exists
Number of spatial orbitals: 92
planted_solutions/11_fecp2_s0_def2-tzvp_92_baae7692-3b91-4399-8ede-b6a94b6e20f8.pkl
4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4
already exists
Number of spatial orbitals: 66
planted_solutions/12_mo_n2_{'Mo'_ 'def2-TZVP', 'default'_ 'def2-SVP'}_66_70dd7897-40ee-4e29-ba9c-b44d73cf7ccb.pkl
4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-2
already exists
Number of spatial orbitals: 140
planted_solutions/13_1_lut_ts_{'Mo'_ 'def2-SVP', 'default'_ '6-311+G(d,p)'}_140_f6df9e45-5b0c-4a56-aa70-2761f74d6a6f.pkl
4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4
already exists
Number of spatial orbitals: 140
planted_solutions/14_1_lut_prod_{'Mo'_ 'def2-SVP', '

Not enough balance, adding more terms
0b1111111
E_min: -869.4644518823918 for orbs: [108, 109, 110, 111]
current state Energy: -869.4644518823889
Ne within current block: 3
Not enough balance, adding more terms
0b11111
Not enough balance, adding more terms
0b1111
E_min: -346.4897992787345 for orbs: [112, 113, 114, 115]
current state Energy: -346.48979927873205
Ne within current block: 3
Not enough balance, adding more terms
0b11111
Not enough balance, adding more terms
0b1111
E_min: -317.2542973474789 for orbs: [116, 117, 118, 119]
current state Energy: -317.25429734747576
Ne within current block: 6
Not enough balance, adding more terms
0b11111111
Not enough balance, adding more terms
0b1111111
E_min: -1369.6697258225786 for orbs: [120, 121, 122, 123]
current state Energy: -1369.6697258225568
Ne within current block: 5
Not enough balance, adding more terms
0b1111111
Not enough balance, adding more terms
0b1011111
E_min: -1046.3874020899743 for orbs: [124, 125, 126, 127]
current state E

Not enough balance, adding more terms
0b1111
E_min: -256.2523240229754 for orbs: [108, 109, 110, 111]
current state Energy: -256.25232402297564
e_nums:[3, 2, 4, 3, 6, 6, 2, 6, 3, 2, 6, 4, 3, 5, 4, 4, 2, 2, 2, 4, 2, 3, 6, 3, 6, 2, 4, 3]
E_cas: -13473.02385304534
Number of spatial orbitals: 150
planted_solutions/19_I_{'Mo'_ 'def2-SVP', 'I'_ 'def2-SVP', 'Cl'_ 'def2-SVP', 'default'_ '6-311+G(d,p)'}_150_cf8b2ac2-f910-47b5-b9c7-0ae7034066da.pkl
4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-2
orbital splliting: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23], [24, 25, 26, 27], [28, 29, 30, 31], [32, 33, 34, 35], [36, 37, 38, 39], [40, 41, 42, 43], [44, 45, 46, 47], [48, 49, 50, 51], [52, 53, 54, 55], [56, 57, 58, 59], [60, 61, 62, 63], [64, 65, 66, 67], [68, 69, 70, 71], [72, 73, 74, 75], [76, 77, 78, 79], [80, 81, 82, 83], [84, 85, 86, 87], [88, 89, 90, 91], [92, 93, 94, 95], [96, 97, 98, 99], [100, 101, 102, 103],

Not enough balance, adding more terms
0b11111111
Not enough balance, adding more terms
0b1111111
E_min: -1873.351117815374 for orbs: [132, 133, 134, 135]
current state Energy: -1873.351117815372
Ne within current block: 5
Not enough balance, adding more terms
0b1111111
Not enough balance, adding more terms
0b1011111
E_min: -1261.8341639243765 for orbs: [136, 137, 138, 139]
current state Energy: -1261.8341639243465
Ne within current block: 2
Not enough balance, adding more terms
0b1111
Not enough balance, adding more terms
0b111
E_min: -208.52469939664158 for orbs: [140, 141, 142, 143]
current state Energy: -208.52469939664158
Ne within current block: 5
Not enough balance, adding more terms
0b1111111
Not enough balance, adding more terms
0b1011111
E_min: -1124.370095598895 for orbs: [144, 145, 146, 147]
current state Energy: -1124.3700955988757
Ne within current block: 3
Not enough balance, adding more terms
0b1111
Not enough balance, adding more terms
0b1111
E_min: -480.50984086579865 

Not enough balance, adding more terms
0b11111
E_min: -411.7901824271721 for orbs: [48, 49, 50, 51]
current state Energy: -411.7901824271699
Ne within current block: 6
Not enough balance, adding more terms
0b1111111
E_min: -650.9540571007923 for orbs: [52, 53, 54, 55]
current state Energy: -650.9540571007922
Ne within current block: 5
Not enough balance, adding more terms
0b1011111
E_min: -577.2908293468623 for orbs: [56, 57, 58, 59]
current state Energy: -577.2908293468256
Ne within current block: 3
Not enough balance, adding more terms
0b1010101
E_min: -241.74332723173566 for orbs: [60, 61, 62, 63]
current state Energy: -241.74332723172785
e_nums:[6, 4, 5, 3, 6, 2, 5, 6, 5, 4, 3, 5, 4, 6, 5, 3]
E_cas: -9572.751558210612
Number of spatial orbitals: 102
planted_solutions/21_rc_{'Mo'_ 'def2-SVP', 'I'_ 'def2-SVP', 'Cl'_ 'def2-SVP', 'default'_ '6-311+G(d,p)'}_102_5cdf24d4-86fb-48ea-9345-2e70f7e68de5.pkl
4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-2
orbital splliting: [[0, 1, 2, 3], [

E_min: -856.8344734947007 for orbs: [12, 13, 14, 15]
current state Energy: -856.8344734946692
Ne within current block: 4
Not enough balance, adding more terms
0b11111
E_min: -437.2725390035302 for orbs: [16, 17, 18, 19]
current state Energy: -437.2725390035303
Ne within current block: 6
Not enough balance, adding more terms
0b11111111
Not enough balance, adding more terms
0b1111111
E_min: -1167.9373737260705 for orbs: [20, 21, 22, 23]
current state Energy: -1167.937373726039
Ne within current block: 5
Not enough balance, adding more terms
0b1011111
E_min: -539.4123939649933 for orbs: [24, 25, 26, 27]
current state Energy: -539.4123939649492
Ne within current block: 3
Not enough balance, adding more terms
0b1110110
Not enough balance, adding more terms
0b1111000
E_min: -350.41985733837123 for orbs: [28, 29, 30, 31]
current state Energy: -350.41985733836464
Ne within current block: 5
Not enough balance, adding more terms
0b1011111
E_min: -654.685375009485 for orbs: [32, 33, 34, 35]
curre

Not enough balance, adding more terms
0b11111111
Not enough balance, adding more terms
0b1111111
E_min: -1487.7696163924913 for orbs: [48, 49, 50, 51]
current state Energy: -1487.7696163924552
Ne within current block: 2
Not enough balance, adding more terms
0b111
E_min: -128.1820689198601 for orbs: [52, 53, 54, 55]
current state Energy: -128.1820689198594
Ne within current block: 4
Not enough balance, adding more terms
0b1011111
Not enough balance, adding more terms
0b1010111
E_min: -668.3804771579762 for orbs: [56, 57, 58, 59]
current state Energy: -668.3804771579222
Ne within current block: 2
Not enough balance, adding more terms
0b10101
E_min: -131.4025575080977 for orbs: [60, 61, 62, 63]
current state Energy: -131.40255750809735
Ne within current block: 4
Not enough balance, adding more terms
0b1011111
Not enough balance, adding more terms
0b1010111
E_min: -725.4365632211662 for orbs: [64, 65, 66, 67]
current state Energy: -725.4365632211565
Ne within current block: 6
Not enough ba

Number of spatial orbitals: 102
planted_solutions/25_ts_1over2_{'Mo'_ 'def2-SVP', 'I'_ 'def2-SVP', 'Cl'_ 'def2-SVP', 'default'_ '6-311+G(d,p)'}_102_29e7d3a5-cb18-4678-9a1a-9f3bdde84243.pkl
4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-4-2
orbital splliting: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23], [24, 25, 26, 27], [28, 29, 30, 31], [32, 33, 34, 35], [36, 37, 38, 39], [40, 41, 42, 43], [44, 45, 46, 47], [48, 49, 50, 51], [52, 53, 54, 55], [56, 57, 58, 59], [60, 61, 62, 63], [64, 65, 66, 67], [68, 69, 70, 71], [72, 73, 74, 75], [76, 77, 78, 79], [80, 81, 82, 83], [84, 85, 86, 87], [88, 89, 90, 91], [92, 93, 94, 95], [96, 97, 98, 99], [100, 101]]
Ne within current block: 2
Not enough balance, adding more terms
0b1010111
Not enough balance, adding more terms
0b10101
E_min: -201.77230980729325 for orbs: [0, 1, 2, 3]
current state Energy: -201.7723098072827
Ne within current block: 6
Not enough balance, adding more terms
0b11111111

Not enough balance, adding more terms
0b11111111
Not enough balance, adding more terms
0b1111111
E_min: -1312.0685740743297 for orbs: [40, 41, 42, 43]
current state Energy: -1312.0685740742256
Ne within current block: 2
Not enough balance, adding more terms
0b10101
E_min: -121.20706167486446 for orbs: [44, 45, 46, 47]
current state Energy: -121.20706167485625
Ne within current block: 3
Not enough balance, adding more terms
0b1111
Not enough balance, adding more terms
0b1111
E_min: -390.4835775591659 for orbs: [48, 49, 50, 51]
current state Energy: -390.48357755914867
Ne within current block: 3
Not enough balance, adding more terms
0b1111
E_min: -219.6346018600939 for orbs: [52, 53, 54, 55]
current state Energy: -219.63460186009172
Ne within current block: 4
Not enough balance, adding more terms
0b11111
E_min: -306.15941942791176 for orbs: [56, 57, 58, 59]
current state Energy: -306.1594194279114
Ne within current block: 3
Not enough balance, adding more terms
0b1010101
E_min: -217.0648

Not enough balance, adding more terms
0b10101
E_min: -96.07839416831472 for orbs: [12, 13, 14, 15]
current state Energy: -96.07839416831449
Ne within current block: 6
Not enough balance, adding more terms
0b1111111
E_min: -793.3893585163987 for orbs: [16, 17, 18, 19]
current state Energy: -793.3893585163553
Ne within current block: 3
Not enough balance, adding more terms
0b1111
Not enough balance, adding more terms
0b1111
E_min: -443.9928057108431 for orbs: [20, 21, 22, 23]
current state Energy: -443.9928057108432
Ne within current block: 2
Not enough balance, adding more terms
0b111
E_min: -122.75752483107831 for orbs: [24, 25, 26, 27]
current state Energy: -122.757524831078
Ne within current block: 2
Not enough balance, adding more terms
0b111
E_min: -115.68812277770093 for orbs: [28, 29, 30, 31]
current state Energy: -115.68812277769827
Ne within current block: 2
Not enough balance, adding more terms
0b111
E_min: -117.30154875066367 for orbs: [32, 33, 34, 35]
current state Energy: -

Not enough balance, adding more terms
0b11111
Not enough balance, adding more terms
0b1111
E_min: -344.13023820110675 for orbs: [96, 97, 98, 99]
current state Energy: -344.13023820110755
Ne within current block: 5
Not enough balance, adding more terms
0b1111111
Not enough balance, adding more terms
0b1011111
E_min: -1257.1729339320566 for orbs: [100, 101, 102, 103]
current state Energy: -1257.1729339319945
e_nums:[3, 3, 6, 3, 3, 6, 3, 3, 4, 2, 6, 3, 5, 5, 4, 2, 5, 5, 6, 2, 5, 2, 3, 6, 3, 5]
E_cas: -16262.420718391335
Number of spatial orbitals: 12
planted_solutions/2_co2_6-311++G___12_9d464efb-b312-45f8-b0ba-8c42663059dc.pkl
4-4-4
orbital splliting: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
Ne within current block: 2
Not enough balance, adding more terms
0b10101
E_min: -94.87290487468708 for orbs: [0, 1, 2, 3]
current state Energy: -94.87290487468333
Ne within current block: 5
E_min: -211.89290556775893 for orbs: [4, 5, 6, 7]
current state Energy: -211.892905567757
Ne within current

Not enough balance, adding more terms
0b11111
E_min: -506.28406495151586 for orbs: [72, 73, 74, 75]
current state Energy: -506.28406495151603
Ne within current block: 6
Not enough balance, adding more terms
0b11111111
Not enough balance, adding more terms
0b1111111
E_min: -1395.643945702379 for orbs: [76, 77, 78, 79]
current state Energy: -1395.6439457021693
Ne within current block: 6
Not enough balance, adding more terms
0b11111111
Not enough balance, adding more terms
0b1111111
E_min: -1274.2307579384403 for orbs: [80, 81, 82, 83]
current state Energy: -1274.2307579384096
Ne within current block: 3
Not enough balance, adding more terms
0b1111
Not enough balance, adding more terms
0b1111
E_min: -343.66607986299573 for orbs: [84, 85]
current state Energy: -343.6660798629955
e_nums:[3, 6, 3, 4, 2, 2, 5, 4, 6, 6, 2, 5, 2, 2, 3, 2, 2, 3, 4, 6, 6, 3]
E_cas: -10883.189786335191
Number of spatial orbitals: 56
planted_solutions/3_ts_ru_macho_co2_{'Ru'_ 'cc-pVTZ-PP', 'default'_ '6-311++G__'}_5

Not enough balance, adding more terms
0b1111111
Not enough balance, adding more terms
0b1011111
E_min: -1118.3737441714186 for orbs: [68, 69, 70, 71]
current state Energy: -1118.3737441713758
Ne within current block: 5
Not enough balance, adding more terms
0b1111111
Not enough balance, adding more terms
0b1011111
E_min: -1173.8636140605633 for orbs: [72, 73, 74, 75]
current state Energy: -1173.8636140604917
Ne within current block: 4
Not enough balance, adding more terms
0b1011111
Not enough balance, adding more terms
0b11111
E_min: -716.1850934704581 for orbs: [76, 77, 78, 79]
current state Energy: -716.1850934704541
Ne within current block: 2
Not enough balance, adding more terms
0b1111
Not enough balance, adding more terms
0b111
E_min: -179.0377067719756 for orbs: [80, 81, 82, 83]
current state Energy: -179.03770677197568
Ne within current block: 5
Not enough balance, adding more terms
0b1011111
E_min: -642.8216951293725 for orbs: [84, 85, 86, 87]
current state Energy: -642.82169512

In [1]:
def get_cas_matrix(cas_x, n, k):
    """Unpack a flat CAS parameter vector into dense one- and two-body tensors.

    ``cas_x`` is read in the order used by ``get_truncated_cas_tbt``: block by
    block, first the block's one-body entries and then its two-body entries,
    each in ``itertools.product`` order. All other entries are zero.

    Returns:
        ``(obt, tbt)`` with shapes ``(n, n)`` and ``(n, n, n, n)``.
    """
    one_body = np.zeros((n, n))
    two_body = np.zeros((n,) * 4)
    pos = 0
    for block in k:
        for pair in product(block, repeat=2):
            one_body[pair] = cas_x[pos]
            pos += 1
        for quad in product(block, repeat=4):
            two_body[quad] = cas_x[pos]
            pos += 1
    return one_body, two_body

def orbtransf(tensor, U, complex = False):
    """Apply the orbital transformation ``U`` to every index of a 1e or 2e tensor.

    Rank 2: ``out[a, b] = sum_pq U[a, p] U[b, q] T[p, q]``, i.e. ``U T U^T``.
    Rank 4: ``out[a, b, c, d] = sum U[a, k] U[b, l] U[c, m] U[d, n] T[k, l, m, n]``.
    No complex conjugation is applied, so this equals ``U T U^dagger`` only for
    real ``U``.

    Args:
        tensor: rank-2 or rank-4 array.
        U: square transformation matrix.
        complex: currently unused; kept for interface compatibility.

    Raises:
        ValueError: if ``tensor`` is neither rank 2 nor rank 4 (previously
            the function silently returned None).
    """
    rank = len(tensor.shape)
    if rank == 4:
        subscripts = 'ak,bl,cm,dn,klmn->abcd'
        operands = (U, U, U, U, tensor)
    elif rank == 2:
        subscripts = 'ap,bq,pq->ab'
        operands = (U, U, tensor)
    else:
        raise ValueError(f"orbtransf expects a rank-2 or rank-4 tensor, got rank {rank}")
    p = np.einsum_path(subscripts, *operands)[0]
    return np.einsum(subscripts, *operands, optimize = p)

In [12]:
def construct_Hamiltonian_with_solution(path, file_name):
    """Load a planted-solution pickle and rebuild four Hamiltonian variants.

    Args:
        path: directory containing the pickle (e.g. ``"planted_solutions/"``).
        file_name: pickle file name inside ``path``.

    Returns:
        ``(H_cas, H_hidden, H_with_killer, H_killer_hidden)``, each a
        ``(constant, obt, tbt)`` triple:

        * ``H_cas``: CAS block tensors rebuilt from the stored ``cas_x``,
          constant 0.
        * ``H_hidden``: ``H_cas`` with every index transformed by ``U``.
        * ``H_with_killer``: ``H_cas`` plus the densified killer tensors, with
          the killer constant.
        * ``H_killer_hidden``: ``H_with_killer`` transformed by the same ``U``.

        ``U`` comes from ``construct_orthogonal`` with uniformly random
        parameters.
    """
    # Use the arguments; this previously read the globals ps_path / f_name.
    # NOTE(review): pickle.load can execute arbitrary code; only load pickles
    # produced by this pipeline.
    with open(os.path.join(path, file_name), 'rb') as handle:
        dic = pickle.load(handle)
    # The file maps partition keys to solutions; only the first entry is used.
    key = list(dic.keys())[0]
    dic = dic[key]

    cas_x = dic["cas_x"]
    killer = dic["killer"]
    killer_c = killer[0]
    # tf.sparse.to_dense expects canonically (lexicographically) ordered indices
    k_obt = tf.sparse.reorder(killer[1])
    k_tbt = tf.sparse.reorder(killer[2])
    k = dic["k"]
    upnum = dic["upnum"]
    spatial_orbs = dic["spatial_orbs"]
    # CAS one- and two-body tensors
    obt, tbt = get_cas_matrix(cas_x, spatial_orbs, k)
    H_cas = (0, obt, tbt)
    H_with_killer = (killer_c, obt+tf.sparse.to_dense(k_obt), tbt+tf.sparse.to_dense(k_tbt))
    # Random orbital transformation used to hide the block structure
    random_uparams = np.random.rand(upnum)
    U = construct_orthogonal(spatial_orbs, random_uparams)
    H_hidden = (0, orbtransf(obt, U), orbtransf(tbt, U))
    H_killer_hidden = (killer_c, orbtransf(H_with_killer[1], U), orbtransf(H_with_killer[2], U))
    return H_cas, H_hidden, H_with_killer, H_killer_hidden

# Rebuild the planted Hamiltonians for the CO2 example from its pickle.
ps_path = "planted_solutions/"
file_name = "2_co2_6-311++G___12_9d464efb-b312-45f8-b0ba-8c42663059dc.hdf5"
f_name = "{}.pkl".format(file_name.split(".")[0])
print(f_name)
(H_cas, H_hidden, H_with_killer, H_killer_hidden) = construct_Hamiltonian_with_solution(
    ps_path, f_name
)

2_co2_6-311++G___12_9d464efb-b312-45f8-b0ba-8c42663059dc.pkl
dict_keys(['E_min', 'e_nums', 'killer', 'k', 'casnum', 'pnum', 'upnum', 'spatial_orbs', 'cas_x', 'sol'])
