In [2]:
import pickle

n_reps = 4

filepath = f"n2_sto-6g_10e8o_{n_reps}_gradient.pickle"
with open(filepath, "rb") as f:
    data_gradient = pickle.load(f)


filepath = f"n2_sto-6g_10e8o_{n_reps}_gradient-t2_dagger.pickle"
with open(filepath, "rb") as f:
    data_gradient_dagger = pickle.load(f)

filepath = f"n2_sto-6g_10e8o_{n_reps}_gradient_multi_stage.pickle"
with open(filepath, "rb") as f:
    data_gradient_multi_stage = pickle.load(f)



In [3]:
print('energy')
print(f"data_gradient: {data_gradient['energy']}")
print(f"data_gradient_dagger: {data_gradient_dagger['energy']}")
print(f"data_gradient_gradient_multi_stage: {data_gradient_multi_stage['energy']}")

energy
data_gradient: -108.575632043211
data_gradient_dagger: -108.4639160716306
data_gradient_gradient_multi_stage: -106.39931406103938


In [4]:
print('error')
print(f"data_gradient: {data_gradient['error']}")
print(f"data_gradient_dagger: {data_gradient_dagger['error']}")
print(f"data_gradient_gradient_multi_stage: {data_gradient_multi_stage['error']}")

error
data_gradient: 0.02035530776394978
data_gradient_dagger: 0.13207127934434482
data_gradient_gradient_multi_stage: 2.1966732899355748


In [5]:
print('final_loss')
print(f"data_gradient: {data_gradient['final_loss']}")
print(f"data_gradient_dagger: {data_gradient_dagger['final_loss']}")
print(f"data_gradient_gradient_multi_stage: {data_gradient_multi_stage['final_loss']}")

final_loss
data_gradient: 0.0038108772820247333
data_gradient_dagger: 0.05763456970453262
data_gradient_gradient_multi_stage: 5.9841466281795874e-05


In [6]:
operator_gradient = data_gradient['operator']
operator_gradient_dagger= data_gradient_dagger['operator']
operator_gradient_multi_stage= data_gradient_multi_stage['operator']

In [7]:
from opt_einsum import contract
nocc = 5

def fun(diag_coulomb_mats, orbital_rotations):
    reconstructed = (
            1j
            * contract(
                "mpq,map,mip,mbq,mjq->ijab",
                diag_coulomb_mats,
                orbital_rotations,
                orbital_rotations.conj(),
                orbital_rotations,
                orbital_rotations.conj(),
                # optimize="greedy"
            )[:nocc, :nocc, nocc:, nocc:]
        )
    return reconstructed

In [8]:
molecule_name = "n2"
basis = "sto-6g"
nelectron, norb = 10, 8


molecule_basename = f"{molecule_name}_{basis}_{nelectron}e{norb}o"

bond_distance = 1.0

from molecules_catalog.util import load_molecular_data
from pathlib import Path
import os
from ffsim.variational.util import interaction_pairs_spin_balanced

# Get molecular data and molecular Hamiltonian
molecules_catalog_dir = "../../molecules-catalog"

mol_data = load_molecular_data(
    f"{molecule_basename}_d-{bond_distance:.5f}",
    molecules_catalog_dir=molecules_catalog_dir,
)
norb = mol_data.norb
nelec = mol_data.nelec
t2 = mol_data.ccsd_t2

In [9]:
import numpy as np
diag_coulomb_mats_gradient, _ = np.unstack(operator_gradient.diag_coulomb_mats, axis=1)
reconstrcuted_operator_gradient = fun(diag_coulomb_mats_gradient, operator_gradient.orbital_rotations)

print(f"t2 norm: {np.linalg.norm(t2)}")
print()
print(np.linalg.norm(reconstrcuted_operator_gradient))
print(np.max(np.abs((reconstrcuted_operator_gradient))))
print()
diag_coulomb_mats_gradient_multi_stage, _ = np.unstack(operator_gradient_multi_stage.diag_coulomb_mats, axis=1)
reconstrcuted_operator_gradient_multi_stage = fun(diag_coulomb_mats_gradient_multi_stage, operator_gradient_multi_stage.orbital_rotations)

print(np.linalg.norm(reconstrcuted_operator_gradient_multi_stage))
print(np.max(np.abs((reconstrcuted_operator_gradient_multi_stage))))
diff = reconstrcuted_operator_gradient_multi_stage - t2
loss = 0.5 * np.sum(np.abs(diff) ** 2)
print(f"loss: {loss}")


t2 norm: 0.2325443258525995

0.21553559262736588
0.11822834762201273

0.23212473501833156
0.11462342079201747
loss: 5.9841298593169374e-05


In [11]:
# Compute final state
import ffsim
reference_state = ffsim.hartree_fock_state(norb, nelec)
final_state_gradient = ffsim.apply_unitary(reference_state, operator_gradient, norb=norb, nelec=nelec)
final_state_gradient_dagger = ffsim.apply_unitary(reference_state, operator_gradient_dagger, norb=norb, nelec=nelec)
final_state_multi_stage = ffsim.apply_unitary(reference_state, operator_gradient_multi_stage, norb=norb, nelec=nelec)

In [14]:
# Run SQD
from qiskit.primitives import BitArray
from qiskit_addon_sqd.fermion import diagonalize_fermionic_hamiltonian, solve_sci_batch

shots = 100_000
samples_per_batch = 100
n_batches = 3
energy_tol = 1e-5
occupancies_tol = 1e-3
max_iterations = 100
symmetrize_spin = True
carryover_threshold = 1e-3
entropy = 0

mol_ham = mol_data.hamiltonian

final_states = [final_state_gradient, final_state_gradient_dagger, final_state_multi_stage]
names = ['gradient', 'gradient_dagger', 'multi_stage']
for final_state, name in zip(final_states, names):
    rng = np.random.default_rng(entropy)
    samples = ffsim.sample_state_vector(
        final_state,
        norb=norb,
        nelec=nelec,
        shots=shots,
        seed=rng,
        bitstring_type=ffsim.BitstringType.INT,
    )
    bit_array = BitArray.from_samples(samples, num_bits=2 * norb)
    result = diagonalize_fermionic_hamiltonian(
        mol_ham.one_body_tensor,
        mol_ham.two_body_tensor,
        bit_array,
        samples_per_batch=samples_per_batch,
        norb=norb,
        nelec=nelec,
        num_batches=n_batches,
        energy_tol=energy_tol,
        occupancies_tol=occupancies_tol,
        max_iterations=max_iterations,
        sci_solver=solve_sci_batch,
        symmetrize_spin=symmetrize_spin,
        carryover_threshold=carryover_threshold,
        seed=rng,
    )
    energy = result.energy + mol_data.core_energy
    sci_state = result.sci_state
    spin_squared = sci_state.spin_square()
    error = energy - mol_data.fci_energy
    print(f"sqd energy loss from {name}: {error}")

sqd energy loss from gradient: 0.01234564509971392
sqd energy loss from gradient_dagger: 0.018334499964836937
sqd energy loss from multi_stage: 3.355100182034221e-05


In [None]:
import numpy as np
import jax.numpy as jnp
import jax

indices = jnp.arange(2)
indices = [[i,j]]

def fun(a):
    b = jnp.zeros((4, 4, 4, 4))
    # b[:, :2, :2] = a
    b = b.at[indices].set(a)
    # b = jnp.concatenate([b,b,b,b])
    print(b)
    return jnp.sum(b)

a = jnp.array([[0.0, 1.0], [2.0, 3.0]])

value_and_grad_func = jax.value_and_grad(fun)

value_and_grad_func(a)


Traced<float64[4,4,4,4]>with<JVPTrace> with
  primal = Array([[[[0., 1., 0., 0.],
         [2., 3., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 1., 0., 0.],
         [2., 3., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]],


       [[[0., 1., 0., 0.],
         [2., 3., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 1., 0., 0.],
         [2., 3., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]],


       [[[0., 0., 0., 0.],
         [0., 0., 0., 0.],
       

(Array(24., dtype=float64),
 Array([[4., 4.],
        [4., 4.]], dtype=float64))