# In this notebook, we test each function of the CAS planted-solution preparation workflow (planted_solutions_workflow).

In [56]:
import CAS_Cropping.planted_solutions_workflow as planted_workflow
from pathlib import Path
import CAS_Cropping.ferm_utils as feru
import CAS_Cropping.csa_utils as csau
import CAS_Cropping.var_utils as varu
import openfermion as of
import numpy as np
from CAS_Cropping.sdstate import *
from itertools import product
import random
from pathlib import Path
from itertools import product
import h5py
import sys
from CAS_Cropping.matrix_utils import construct_orthogonal
import pickle
import CAS.dmrghandler.src.dmrghandler.pyscf_wrappers
import CAS.dmrghandler.src.dmrghandler.dmrg_calc_prepare
import CAS.dmrghandler.src.dmrghandler.qchem_dmrg_calc
import CAS.dmrghandler.src.dmrghandler as dmrghandler
import tensorflow as tf
import pyscf.tools.fcidump
import scipy
import pandas as pd
import os

### Parameters
# Presumably a numerical tolerance. NOTE(review): not referenced by any cell shown in this notebook.
tol = 1e-5
# Passed to planted_workflow.solve_enums as `balance_t`.
balance_strength = 2
# NOTE(review): not referenced by any cell shown here; the .pkl cell below writes its file unconditionally.
save = False
# Number of spatial orbitals in a block
block_size = 3
# Number of electrons per block
ne_per_block = 6
# Allowed +/- deviation from ne_per_block in the number of electrons per block
ne_range = 0
# Number of killer operators for each CAS block
n_killer = 3
# Run full CI (FCI) to compute the ground state; its cost grows exponentially, so it is off by default
FCI = False
# Check the symmetries of the planted Hamiltonian (very costly)
check_symmetry = False
# Concatenate the states in each block to compute the ground state solution
check_state = False

In [78]:
# Input FCIDUMP file loaded in the next cell; its suffix after "." also names the saved .pkl file.
fcidump_filename = "test_fcidumps/fcidump.2_co2_6-311++G__"
# NOTE(review): not referenced by any cell shown in this notebook.
fcidump_output_path = "fcidumps_catalysts_planted_solutions"

In [82]:
# Load the one-/two-body integrals and metadata from the FCIDUMP file
# (MOLPRO orbital-symmetry convention).
(
    one_body_tensor,
    two_body_tensor,
    nuc_rep_energy,
    num_orbitals,
    num_spin_orbitals,
    num_electrons,
    two_S,
    two_Sz,
    orb_sym,
    extra_attributes,
) = dmrghandler.dmrg_calc_prepare.load_tensors_from_fcidump(
    data_file_path=fcidump_filename,
    molpro_orbsym_convention=True,
)
# The notebook works in spatial orbitals; each one corresponds to two spin orbitals.
spatial_orbs = num_orbitals
spin_orbs = 2 * spatial_orbs
print(
    f"Spatial orbs: {spatial_orbs}, two_S: {two_S}, "
    f"two_Sz: {two_Sz}, # of electrons: {num_electrons}"
)

Parsing test_fcidumps/fcidump.2_co2_6-311++G__
Spatial orbs: 6, two_S: 0, two_Sz: 0, # of electrons: 8


In [84]:
# Everything below indexes the integrals by spatial orbital: the one-body
# tensor must be (n, n) and the two-body tensor (n, n, n, n), with n = spatial_orbs.
assert one_body_tensor.shape == (spatial_orbs,) * 2, "Tensor not in spatial orbital"
assert two_body_tensor.shape == (spatial_orbs,) * 4, "Two-body tensor not in spatial orbitals"

In [59]:
# Build the (obt, tbt) pair used by the workflow from the raw integrals:
#   obt_pq = h_pq - 1/2 * sum_r tbt_raw[p, r, r, q],   tbt = 1/2 * tbt_raw
# New arrays are created instead of updating in place, so `one_body_tensor`
# and `two_body_tensor` are left untouched and re-running this cell does not
# apply the transformation twice.
Hobt = one_body_tensor - 0.5 * np.einsum("prrq->pq", two_body_tensor)
Htbt = 0.5 * two_body_tensor
H = (Hobt, Htbt)

### Test block construction

In [60]:
def test_block_construction(size, orbs):
    """Check that ``construct_blocks`` splits ``orbs`` orbitals into blocks of ``size``.

    Every block except the last must contain exactly ``size`` orbitals; the last
    one may instead contain the remainder ``orbs % size``. In addition, at least
    one block must be produced, no orbital may appear in two blocks, and every
    orbital index must lie in ``range(orbs)``.

    Raises:
        AssertionError: if any of the conditions above is violated.
    """
    blocks = planted_workflow.construct_blocks(size, orbs, False)
    assert len(blocks) > 0, "No blocks constructed"

    last_block_size = orbs % size
    for block in blocks[:-1]:
        assert len(block) == size, "Block size mismatch"
    assert len(blocks[-1]) in (size, last_block_size), "Last block size mismatch"

    # The CAS truncation treats blocks as disjoint orbital sets.
    flat = [orb for block in blocks for orb in block]
    assert len(flat) == len(set(flat)), "Blocks share orbitals"
    assert all(0 <= orb < orbs for orb in flat), "Orbital index out of range"

test_block_construction(block_size, spatial_orbs)
# NOTE(review): the test above calls construct_blocks(..., False), while `k` is built with the
# third argument left at its default. Confirm the default is also False; otherwise the blocks
# used for the rest of the notebook are not the ones that were tested.
k = planted_workflow.construct_blocks(block_size, spatial_orbs)

### Test getting param num

In [61]:
def test_get_param_num(spatial_orbs, k):
    """Sanity-check the parameter counts reported by ``get_param_num``.

    The CAS parameter count must be a multiple of the squared size of the first
    block, and the upper-triangular count must equal n(n-1)/2 for n spatial orbitals.
    """
    upnum, casnum, _ = planted_workflow.get_param_num(spatial_orbs, k, complex=False)
    first_block_len = len(k[0])
    expected_upnum = int(spatial_orbs * (spatial_orbs - 1) / 2)
    assert casnum % (first_block_len ** 2) == 0, "parameter num failed"
    assert upnum == expected_upnum, "Upper triangular parameters num failed"
test_get_param_num(spatial_orbs, k)
# Recompute the counts at module level; they are stored in planted_sol in the .pkl cell.
upnum, casnum, pnum = planted_workflow.get_param_num(spatial_orbs, k, complex=False)

### Test getting truncated CAS tbt

In [62]:
def test_truncated_cas_tbt(H, blocks: list[list[int]], casnum):
    """Verify that ``get_truncated_cas_tbt`` zeroes couplings between different blocks.

    For every ordered pair of distinct blocks (a, b) and every p in a, q in b,
    both ``cas_obt[p, q]`` and ``cas_tbt[p, p, q, q]`` must be exactly zero.
    """
    cas_obt, cas_tbt, _ = planted_workflow.get_truncated_cas_tbt(H, blocks, casnum)
    for first in blocks:
        for second in blocks:
            if first == second:
                continue
            for p in first:
                for q in second:
                    assert cas_obt[p, q] == 0, "Truncation failed"
                    assert cas_tbt[p, p, q, q] == 0, "Truncation failed"

test_truncated_cas_tbt(H, k, casnum)
cas_obt, cas_tbt, cas_x = planted_workflow.get_truncated_cas_tbt(H, k, casnum)

planted_sol = {}
# H_cas holds references to cas_obt / cas_tbt, not copies.
H_cas = [cas_obt, cas_tbt]
# Snapshots taken before solve_enums runs. test_tensor_changes later compares cas_obt / cas_tbt
# against these and requires them to differ, i.e. it assumes solve_enums modifies the tensors
# in H_cas in place.
cas_obt_copy = cas_obt.copy()
cas_tbt_copy = cas_tbt.copy()

### Test solving for the block electron numbers (solve_enums)

In [63]:
def test_solve_nums(H_cas, k, num_electrons, ne_per_block = ne_per_block,
                                    ne_range = ne_range, balance_t = balance_strength):
    """Run ``solve_enums`` and check the electron count assigned to each block.

    Every block except the last must hold ``ne_per_block`` electrons. The last
    block may instead hold the remainder ``num_electrons % ne_per_block``. The
    per-block counts must add up to ``num_electrons``.

    Returns:
        The ``(e_nums, states, E_cas)`` tuple returned by ``solve_enums``.

    Raises:
        AssertionError: if the electron distribution is inconsistent.
    """
    # Forward the caller's balance_t (previously the global balance_strength
    # was always passed, silently ignoring this parameter).
    e_nums, states, E_cas = planted_workflow.solve_enums(H_cas, k, num_electrons, ne_per_block = ne_per_block,
                                    ne_range = ne_range, balance_t = balance_t)

    remaining = num_electrons % ne_per_block
    for count in e_nums[:-1]:
        assert count == ne_per_block, "# of electrons in each block is wrong"
    if len(e_nums) > 0:
        assert e_nums[-1] in (ne_per_block, remaining), "# of electrons in each block is wrong"
    # Distributing electrons over blocks must conserve the total count.
    assert sum(e_nums) == num_electrons, "Total # of electrons over blocks is wrong"

    return e_nums, states, E_cas


def test_tensor_changes(obt, obt_copy, tbt, tbt_copy):
    if obt.all() == obt_copy.all():
        assert "Tensor hasn't been modified"
    if  tbt.all() == tbt_copy.all():
        assert "Tensor hasn't been modified"
    
    obt_diff = obt - obt_copy
    tbt_diff = tbt - tbt_copy
    assert scipy.linalg.norm(obt_diff) > 1, "Tensor change failed" 
    assert scipy.linalg.norm(tbt_diff) > 1, "Tensor change failed" 
    
    # We also want to test only the diagonals were modified.
    indices = [i for i in range(obt.shape[0])]
    
    for pair in product(indices, indices):
        if pair[0] != pair[1]:
            assert obt_diff[pair[0], pair[1]] == 0, "Non diagonal terms modified"
        if pair[0] == pair[1]:
            assert obt_diff[pair[0], pair[1]] != 0, "Diagonal terms not modified"
    
    count_non_zero_in_tbt = 0
    for pair in product(indices, indices, indices, indices):
        if pair[0] == pair[1] and pair[2] == pair[3]:
            if tbt_diff[pair[0], pair[1], pair[2], pair[3]] != 0:
                count_non_zero_in_tbt += 1
        else:
            assert tbt_diff[pair[0], pair[1], pair[2], pair[3]] == 0, "Non diagonal terms modified"
    
    assert count_non_zero_in_tbt > 0, "Diagonal terms in blocks have been modified"
    
    
# Solve each CAS block; test_solve_nums checks the per-block electron numbers and returns the results.
e_nums, states, E_cas = test_solve_nums(H_cas, k, num_electrons, ne_per_block = ne_per_block,
                                    ne_range = ne_range, balance_t = balance_strength)

# cas_obt / cas_tbt are expected to differ from the snapshots taken before solve_enums ran.
test_tensor_changes(cas_obt, cas_obt_copy, cas_tbt, cas_tbt_copy)

Ne within current block: 6
E_min: -294.46613886008566 for orbs: [0, 1, 2]
current state Energy: -294.4661388600851
Ne within current block: 2
E_min: -35.48504756736056 for orbs: [3, 4, 5]
current state Energy: -35.485047567360596


In [None]:
# `copy` is not among the notebook's top-level imports (it may only arrive
# implicitly via `from CAS_Cropping.sdstate import *`); import it here so the
# cell does not depend on that.
import copy

print(f"e_nums:{e_nums}")
print(f"E_cas: {E_cas}")
# sdstate constructed with no arguments; it is only saved (as
# planted_sol["solution"]) when check_state is True.
sd_sol = sdstate()
# Independent deep copies of the CAS tensors as they are after solve_enums.
obt_2 = copy.deepcopy(cas_obt)
tbt_2 = copy.deepcopy(cas_tbt)

### Test killer term construction

In [76]:
def test_killer_construction(k, e_nums, n = spatial_orbs, n_killer = n_killer):
    """Build the killer operator with ``construct_killer`` and check it is diagonal.

    Every stored index of the one-body part must have the form (p, p) and every
    stored index of the two-body part the form (p, p, q, q).

    Returns:
        The ``(killer_c, killer_obt, killer_tbt)`` tuple from ``construct_killer``
        (returning it is backward compatible; callers that ignore it are unaffected).

    Raises:
        AssertionError: if a non-diagonal index is present.
    """
    # Forward the caller's `n` (previously the global spatial_orbs was always
    # passed, silently ignoring this parameter).
    killer_c, killer_obt, killer_tbt = planted_workflow.construct_killer(k, e_nums, n = n, n_killer = n_killer)

    # `.indices.numpy()` lists the coordinates of the stored entries
    # (presumably tf.SparseTensor objects - TODO confirm).
    for pair in killer_obt.indices.numpy():
        assert pair[0] == pair[1], "Killer obt not diagonal"

    for pair in killer_tbt.indices.numpy():
        assert pair[0] == pair[1] and pair[2] == pair[3], "Killer tbt not diagonal"

    return killer_c, killer_obt, killer_tbt
    
test_killer_construction(k, e_nums, n = spatial_orbs, n_killer = n_killer)
# Build the killer operator at module level; it is stored as planted_sol["killer"] in the next cell.
killer_c, killer_obt, killer_tbt = planted_workflow.construct_killer(k, e_nums, n = spatial_orbs, n_killer = n_killer)


### Test temporary .pkl file saving

In [80]:
# Collect everything needed to reload and verify the planted solution.
planted_sol.update(
    {
        "E_min": E_cas,
        "e_nums": e_nums,
        "killer": (killer_c, killer_obt, killer_tbt),
        "k": k,
        "casnum": casnum,
        "pnum": pnum,
        "upnum": upnum,
        "spatial_orbs": spatial_orbs,
        "cas_x": cas_x,
        "sol": states,
        "cas_obt": cas_obt,
        "cas_tbt": cas_tbt,
    }
)

# The concatenated ground state is only stored when it was requested.
if check_state:
    planted_sol["solution"] = sd_sol

ps_path = "test_cas_planted_solutions/"
# Name the .pkl after the part of the FCIDUMP *file name* that follows its first
# "." (e.g. "fcidump.2_co2_6-311++G__" -> "2_co2_6-311++G__.pkl"). Splitting
# only the base name keeps dots in directory names (e.g. "./" or "../") from
# producing a wrong name.
f_name = Path(fcidump_filename).name.split(".", 1)[1] + ".pkl"
pkl_path = os.path.join(ps_path, f_name)
print(pkl_path)

# Entries in the pickle are keyed by the block sizes, e.g. "3-3".
key = "-".join(str(len(block)) for block in k)
print(key)

# NOTE: pickle.load can execute arbitrary code; only load files written by this workflow.
if os.path.exists(pkl_path):
    with open(pkl_path, 'rb') as handle:
        dic = pickle.load(handle)
else:
    dic = {}
dic[key] = planted_sol

# Write to a temporary file first and then swap it in, so a failed dump cannot
# truncate the existing file and lose previously saved entries.
os.makedirs(ps_path, exist_ok=True)
tmp_path = pkl_path + ".tmp"
with open(tmp_path, 'wb') as handle:
    pickle.dump(dic, handle, protocol=pickle.HIGHEST_PROTOCOL)
os.replace(tmp_path, pkl_path)

test_cas_planted_solutions/2_co2_6-311++G__.pkl
3-3


### Test loading Hamiltonians

In [85]:
def test_spatial_chem_to_phy(obt, tbt):
    """Check ``chem_spatial_orb_to_phys_spatial_orb`` against the input two-body tensor.

    Half of the converted two-body tensor must match ``tbt`` element-wise
    (within ``np.allclose`` tolerances). The converted one-body part is not checked.
    """
    _, converted_tbt = planted_workflow.chem_spatial_orb_to_phys_spatial_orb(obt, tbt)
    halved = converted_tbt / 2
    assert np.allclose(halved, tbt), "chem to phy indices wrong"

In [90]:
def test_load_Hamiltonians(ps_path, f_name):
    """Reload the saved planted Hamiltonians and check that the rotation preserves norms.

    ``load_Hamiltonian_with_solution`` yields the CAS Hamiltonian and its "hidden"
    counterpart, plus the Hamiltonian with killer terms and its hidden counterpart.
    Elements [1] and [2] of each are converted with
    ``chem_spatial_orb_to_phys_spatial_orb``, and the norms of the converted
    one-/two-body parts must agree pairwise (within ``np.allclose`` tolerances).
    """
    loaded = planted_workflow.load_Hamiltonian_with_solution(ps_path, f_name)
    _U, cas, cas_hidden, with_killer, killer_hidden, _sol, _E_min = loaded
    to_phys = planted_workflow.chem_spatial_orb_to_phys_spatial_orb

    def check_same_norms(reference, rotated):
        # Compare the norms of corresponding one-body and two-body parts.
        for ref_part, rot_part in zip(reference, rotated):
            assert np.allclose(np.linalg.norm(ref_part), np.linalg.norm(rot_part)), "unitary rotation wrong"

    cas_h, cas_g = to_phys(cas[1], cas[2])
    test_spatial_chem_to_phy(cas[1], cas[2])
    hidden_h, hidden_g = to_phys(cas_hidden[1], cas_hidden[2])
    check_same_norms((cas_h, cas_g), (hidden_h, hidden_g))

    # The killer Hamiltonian's converted parts expose `.numpy()`; both sides are
    # reduced to real float64 arrays before comparing.
    killer_h, killer_g = to_phys(with_killer[1], with_killer[2])
    killer_h = np.float64(np.real(killer_h.numpy()))
    killer_g = np.float64(np.real(killer_g.numpy()))
    killer_hidden_h, killer_hidden_g = to_phys(killer_hidden[1], killer_hidden[2])
    killer_hidden_h = np.float64(np.real(killer_hidden_h))
    killer_hidden_g = np.float64(np.real(killer_hidden_g))
    check_same_norms((killer_h, killer_g), (killer_hidden_h, killer_hidden_g))
    
# Round trip: read back the .pkl written to ps_path + f_name in the saving cell.
test_load_Hamiltonians(ps_path, f_name)
U, H_cas, H_hidden, H_with_killer, H_killer_hidden, sol, E_min = planted_workflow.load_Hamiltonian_with_solution(ps_path, f_name)
