In [1]:
"""Construct CAS Hamiltonians with cropping
"""
import saveload_utils as sl
import ferm_utils as feru
import csa_utils as csau
import var_utils as varu
import openfermion as of
import numpy as np
from sdstate import *
from itertools import product
import random

### Parameters
mol = 'h2'
tol = 1e-5
save = False
method_name = 'CAS-Cropping'
k = [[0,1,2,3,4,5],[6,7,8,9,10,11]]
spin_symmetry_check = False
FCI = False
# Get two-body tensor
Hf = sl.load_fermionic_hamiltonian(mol, prefix = "./")
spin_orbs = of.count_qubits(Hf)  
print(spin_orbs)
spatial_orb = spin_orbs // 2
Sz = of.hamiltonians.sz_operator(spatial_orb)
S2 = of.hamiltonians.s_squared_operator(spatial_orb)
def get_truncated_cas_tbt(Htbt, k, casnum):
#     Trunctate the original Hamiltonian two body tensor into the cas block structures
    cas_tbt = np.zeros(Htbt.shape)
    cas_x = np.zeros(casnum)
    idx = 0
    for block in k:
        for a in block:
            for b in block:
                for c in block:
                    for d in block:
                        cas_tbt [a,b,c,d] = Htbt [a,b,c,d]
                        cas_x[idx] = Htbt[a,b,c,d]
                        idx += 1
    return cas_tbt, cas_x

Htbt = feru.get_chemist_tbt(Hf, spin_orbs, spin_orb = True)
one_body = of.normal_ordered(Hf - feru.get_ferm_op(Htbt, spin_orb=True))
# print(one_body)
onebody_matrix = feru.get_obt(one_body, n = spin_orbs, spin_orb = True)
# for term,val in one_body.terms.items():
#     if len(term) == 2:
#         assert onebody_matrix[term[0][0], term[1][0]] == val
#         print(f"Position {[term[0][0], term[1][0]]}: checked")
assert np.transpose(onebody_matrix.any()) == onebody_matrix.any()
onebody_tbt = feru.onebody_to_twobody(onebody_matrix)
r = feru.get_ferm_op(onebody_tbt, True)
Htbt = np.add(Htbt, onebody_tbt)
recombined = feru.get_ferm_op(Htbt, True)
# print(of.normal_ordered(Hf - recombined))
upnum, casnum, pnum = csau.get_param_num(spin_orbs, k, complex = False)
# print(Htbt.shape)
cas_tbt, cas_x = get_truncated_cas_tbt(Htbt, k, casnum)
H_cas = feru.get_ferm_op(cas_tbt, True)
#     Checking H_cas symmetries
assert of.FermionOperator.zero() == of.normal_ordered(of.commutator(Sz, H_cas)), "Sz symmetry broken"
assert of.FermionOperator.zero() == of.normal_ordered(of.commutator(S2, H_cas)), "S2 symmetry broken"

ModuleNotFoundError: No module named 'saveload_utils'

In [5]:
# Checking ground state with FCI
# Warning: This takes exponential time to run
E_min, sol = of.get_ground_state(of.get_sparse_operator(H_cas))
print(f"FCI Energy: {E_min}")
tmp_st = sdstate()
for s in range(len(sol)):
    if sol[s] > np.finfo(np.float32).eps:
        tmp_st += sdstate(s, sol[s])
#         print(bin(s))
print(tmp_st.norm())
tmp_st.normalize()
print(tmp_st.exp(H_cas))

FCI Energy: -12.17213924421073
(0.7602294488916079+0j)
-12.145338224562707


In [15]:
def in_orbs(term, orbs):
    """Return if the term is a local excitation operator within orbs"""
    if len(term) == 2:
        return term[0][0] in orbs and term[1][0] in orbs
    elif len(term) == 4:
        return term[0][0] in orbs and term[1][0] in orbs and term[2][0] in orbs and term[3][0] in orbs
    return False

def transform_orbs(term, orbs):
    """Transform the operator term to align the orbs starting from 0"""
#     pass
    if len(term) == 2:
        return ((orbs.index(term[0][0]), 1), (orbs.index(term[1][0]), 0))
    if len(term) == 4:
        return ((orbs.index(term[0][0]), 1), (orbs.index(term[1][0]), 0), 
               (orbs.index(term[2][0]), 1), (orbs.index(term[3][0]), 0))   
    return None

def solve_enums(H_cas, k):
    """Solve for number of electrons in each CAS block with FCI within the block""" 
    e_nums = []
    states = []
    E_cas = 0
    for orbs in k:
        tmp = of.FermionOperator()
        for t in H_cas.terms:
            if in_orbs(t, orbs):
                tmp += of.FermionOperator(transform_orbs(t, orbs), H_cas.terms[t])
        sparse_H_tmp = of.get_sparse_operator(tmp)
        tmp_E_min, t_sol = of.get_ground_state(sparse_H_tmp)
        print(f"E_min: {tmp_E_min} for orbs: {orbs}")
        ne = -1
        st = sdstate(n_qubit = len(orbs))
        for i in range(len(t_sol)):
            if np.linalg.norm(t_sol[i]) > np.finfo(np.float32).eps:
                st += sdstate(s = i, coeff = t_sol[i])
        print(f"state norm: {st.norm()}")
        st.normalize()
        E_st = st.exp(tmp)
        E_cas += E_st
        print(f"current state Energy: {E_st}")
        states.append(st)
        for s in st.dic:
            ne = bin(s)[2:].count('1')
            e_nums.append(ne)
            break
    return e_nums, states, E_cas
            
e_nums, states, E_cas = solve_enums(H_cas, k)
print(e_nums)
print(f"E_cas: {E_cas}")
# for s in states:
#     print(s)

E_min: -8.929277154856539 for orbs: [0, 1, 2, 3, 4, 5]
state norm: (1+0j)
current state Energy: -8.92927715485656
E_min: -3.2428620893541313 for orbs: [6, 7, 8, 9, 10, 11]
state norm: (1+0j)
current state Energy: -3.2428620893541225
[4, 4]
E_cas: -12.172139244210683


In [7]:
# Checking the full solution energy in sdstate:
sd_sol = sdstate()
for st in states:
    sd_sol = sd_sol.concatenate(st)
# print(sd_sol.n_qubit)
E_sol = sd_sol.exp(H_cas)
print(f"solution energy: {E_sol}")

solution energy: -12.172139244210667


In [8]:
# Checking if UHU* still commutes with Sz and S^2, result: Does not commute in general
# rotation_x = np.random.rand(upnum)
rotation_x = np.ones(upnum)
# rotation_x = np.zeros(upnum)
x = np.concatenate((cas_x,rotation_x))
# Hidden CAS_tbt
CAS_hidden = csau.sum_cartans(x, spin_orbs, k, 1, False)
# Transform into FemionOperator
H_hidden = feru.get_ferm_op(CAS_hidden, spin_orb=True)
print(of.normal_ordered(of.commutator(Sz, H_hidden)) == of.FermionOperator.zero())
print(of.normal_ordered(of.commutator(S2, H_hidden)) == of.FermionOperator.zero())

False
False


In [12]:
def construct_killer(k, e_num, n = 0, const = 1e-2, t = 2, n_killer = 3):
    """ Construct a killer operator for CAS Hamiltonian, based on cas block structure of k and the size of killer is 
    given in k, the number of electrons in each CAS block of the ground state
    is specified by e_nums. t is the strength of quadratic balancing terms for the killer with respect to k,
    n_killer specifies the number of operators O to choose.
    """
    if not n:
        n = max([max(orbs) for orbs in k])
    killer = of.FermionOperator.zero()
    for i in range(len(k)):
        orbs = k[i]
        outside_orbs = [j for j in range(n) if j not in orbs]
    #     Construct Ne
        Ne = sum([of.FermionOperator("{}^ {}".format(i, i)) for i in orbs])
    #     Construct O, for O as combination of Epq which preserves Sz and S2
        if len(outside_orbs) >= 4:
            tmp = 0
            while tmp < n_killer:
                p, q = random.sample(outside_orbs, 2)
                if abs(p - q) > 1:
#                     Constructing symmetry conserved killers
                    O = of.FermionOperator.zero()
                    if p % 2 != 0:
                        p -= 1
                    if q % 2 != 0:
                        q -= 1
                    O += of.FermionOperator("{}^ {}".format(p, q)) + of.FermionOperator("{}^ {}".format(q, p))
                    O += of.FermionOperator("{}^ {}".format(p + 1, q + 1)) + of.FermionOperator("{}^ {}".format(q + 1, p + 1))
                    killer += (1 + np.random.rand()) * const * O * (Ne - e_nums[i])
                    tmp += 1
        killer += t * (1 + np.random.rand()) * const * ((Ne - e_nums[i]) ** 2)
    return killer
cas_killer = construct_killer(k, e_nums, n = spin_orbs)
assert of.FermionOperator.zero() == of.normal_ordered(of.commutator(Sz, cas_killer)), "Killer broke Sz symmetry"
assert of.FermionOperator.zero() == of.normal_ordered(of.commutator(S2, cas_killer)), "S2 symmetry broken"

In [14]:
# Checking: if killer does not change ground state
sparse_with_killer = of.get_sparse_operator(cas_killer + H_cas)
killer_Emin, killer_sol = of.get_ground_state(sparse_with_killer)
killer_E_sol = sd_sol.exp(H_cas + cas_killer)
print(f"Solution Energy with killer: {killer_E_sol}")
# print(np.linalg.norm(sol @ sparse_with_killer @ sol))
# print(killer_Emin - E_min)

Solution Energy with killer: -12.172139244210673
