In [1]:
import pickle

n_reps = 4

filepath = f"n2_sto-6g_10e8o_{n_reps}_gradient.pickle"
with open(filepath, "rb") as f:
    data_gradient = pickle.load(f)


filepath = f"n2_sto-6g_10e8o_{n_reps}_gradient-t2_dagger.pickle"
with open(filepath, "rb") as f:
    data_gradient_dagger = pickle.load(f)

filepath = f"n2_sto-6g_10e8o_{n_reps}_gradient_multi_stage.pickle"
with open(filepath, "rb") as f:
    data_gradient_multi_stage = pickle.load(f)

filepath = f"n2_sto-6g_10e8o_{n_reps}.pickle"
with open(filepath, "rb") as f:
    data_numerical_gradient = pickle.load(f)

filepath = f"n2_sto-6g_10e8o_{n_reps}-t2_dagger.pickle"
with open(filepath, "rb") as f:
    data_numerical_gradient_dagger = pickle.load(f)


In [2]:
print('energy')
print(f"data_gradient: {data_gradient['energy']}")
print(f"data_gradient_dagger: {data_gradient_dagger['energy']}")
print(f"data_gradient_gradient_multi_stage: {data_gradient_multi_stage['energy']}")
print(f"data_numerical_gradient: {data_numerical_gradient['energy']}")
print(f"data_numerical_gradient_dagger: {data_numerical_gradient_dagger['energy']}")

energy
data_gradient: -108.575632043211
data_gradient_dagger: -108.4639160716306
data_gradient_gradient_multi_stage: -106.39931406103938
data_numerical_gradient: -108.53675524617424
data_numerical_gradient_dagger: -108.4647527312714


In [3]:
print('error')
print(f"data_gradient: {data_gradient['error']}")
print(f"data_gradient_dagger: {data_gradient_dagger['error']}")
print(f"data_gradient_gradient_multi_stage: {data_gradient_multi_stage['error']}")
print(f"data_numerical_gradient: {data_numerical_gradient['error']}")
print(f"data_numerical_gradient_dagger: {data_numerical_gradient_dagger['error']}")

error
data_gradient: 0.02035530776394978
data_gradient_dagger: 0.13207127934434482
data_gradient_gradient_multi_stage: 2.1966732899355748
data_numerical_gradient: 0.059232104800713614
data_numerical_gradient_dagger: 0.1312346197035481


In [4]:
print('final_loss')
print(f"data_gradient: {data_gradient['final_loss']}")
print(f"data_gradient_dagger: {data_gradient_dagger['final_loss']}")
print(f"data_gradient_gradient_multi_stage: {data_gradient_multi_stage['final_loss']}")
print(f"data_numerical_gradient: {data_numerical_gradient['final_loss']}")
print(f"data_numerical_gradient_dagger: {data_numerical_gradient_dagger['final_loss']}")

final_loss
data_gradient: 0.0038108772820247333
data_gradient_dagger: 0.05763456970453262
data_gradient_gradient_multi_stage: 5.9841466281795874e-05
data_numerical_gradient: 0.0030629563885514813
data_numerical_gradient_dagger: 0.05407686477092247


In [5]:
operator_gradient = data_gradient['operator']
operator_numerical_gradient = data_numerical_gradient['operator']
operator_gradient_multi_stage= data_gradient_multi_stage['operator']
print(operator_gradient.diag_coulomb_mats.shape == operator_numerical_gradient.diag_coulomb_mats.shape)
print(operator_gradient.orbital_rotations.shape == operator_numerical_gradient.orbital_rotations.shape)
print(operator_gradient.orbital_rotations.shape)
print(operator_gradient.diag_coulomb_mats.shape)

True
True
(4, 8, 8)
(4, 2, 8, 8)


In [6]:
from opt_einsum import contract
nocc = 5

def fun(diag_coulomb_mats, orbital_rotations):
    reconstructed = (
            1j
            * contract(
                "mpq,map,mip,mbq,mjq->ijab",
                diag_coulomb_mats,
                orbital_rotations,
                orbital_rotations.conj(),
                orbital_rotations,
                orbital_rotations.conj(),
                # optimize="greedy"
            )[:nocc, :nocc, nocc:, nocc:]
        )
    return reconstructed

In [7]:
molecule_name = "n2"
basis = "sto-6g"
nelectron, norb = 10, 8


molecule_basename = f"{molecule_name}_{basis}_{nelectron}e{norb}o"

bond_distance = 1.0

from molecules_catalog.util import load_molecular_data
from pathlib import Path
import os
from ffsim.variational.util import interaction_pairs_spin_balanced

# Get molecular data and molecular Hamiltonian
molecules_catalog_dir = "../../molecules-catalog"

mol_data = load_molecular_data(
    f"{molecule_basename}_d-{bond_distance:.5f}",
    molecules_catalog_dir=molecules_catalog_dir,
)
norb = mol_data.norb
nelec = mol_data.nelec
t2 = mol_data.ccsd_t2

In [8]:
import numpy as np
diag_coulomb_mats_gradient, _ = np.unstack(operator_gradient.diag_coulomb_mats, axis=1)
reconstrcuted_operator_gradient = fun(diag_coulomb_mats_gradient, operator_gradient.orbital_rotations)

print(f"t2 norm: {np.linalg.norm(t2)}")
print()
print(np.linalg.norm(reconstrcuted_operator_gradient))
print(np.max(np.abs((reconstrcuted_operator_gradient))))
print()
diag_coulomb_mats_gradient_multi_stage, _ = np.unstack(operator_gradient_multi_stage.diag_coulomb_mats, axis=1)
reconstrcuted_operator_gradient_multi_stage = fun(diag_coulomb_mats_gradient_multi_stage, operator_gradient_multi_stage.orbital_rotations)

print(np.linalg.norm(reconstrcuted_operator_gradient_multi_stage))
print(np.max(np.abs((reconstrcuted_operator_gradient_multi_stage))))
diff = reconstrcuted_operator_gradient_multi_stage - t2
loss = 0.5 * np.sum(np.abs(diff) ** 2)
print(f"loss: {loss}")

diag_coulomb_mats_numerical_gradient, _ = np.unstack(operator_numerical_gradient.diag_coulomb_mats, axis=1)
reconstrcuted_operator_numerical_gradient = fun(diag_coulomb_mats_numerical_gradient, operator_numerical_gradient.orbital_rotations)

print(np.linalg.norm(reconstrcuted_operator_numerical_gradient))
print(np.max(np.abs((reconstrcuted_operator_numerical_gradient))))

t2 norm: 0.2325443258525995

0.21553559262736588
0.11822834762201273

0.23212473501833156
0.11462342079201747
loss: 5.9841298593169374e-05
0.22374107158110804
0.11935705541379081


In [25]:
import numpy as np
import jax.numpy as jnp
import jax

indices_0 = [[i,i,j,j] for i in range(2) for j in range(2)]
indices_1 = [[i,j] for i in range(2) for j in range(2)]

print(indices_0)

zipped_lists = list(zip(*indices_0))

list1 = [sublist[0] for sublist in indices_0]  # Extract elements at index 0
list2 = [sublist[1] for sublist in indices_0]  # Extract elements at index 1
list3 = [sublist[2] for sublist in indices_0]  # Extract elements at index 1
list4 = [sublist[3] for sublist in indices_0]  # Extract elements at index 1

indices = np.arange(2)

a = np.array([1.0, 2.0, 3.0, 4.0])
b = np.zeros((2, 2, 2, 2))

b[zipped_lists[0], zipped_lists[1], zipped_lists[2], zipped_lists[3]] = a

print(b)

print(b[0][0][1][1])
# def fun(a):    
#     # b[:, :2, :2] = a
#     b = b.at[indices].set(a)
#     # b = jnp.concatenate([b,b,b,b])
#     print(b)
#     return jnp.sum(b)



# value_and_grad_func = jax.value_and_grad(fun)

# value_and_grad_func(a)

[[0, 0, 0, 0], [0, 0, 1, 1], [1, 1, 0, 0], [1, 1, 1, 1]]
[[[[1. 0.]
   [0. 2.]]

  [[0. 0.]
   [0. 0.]]]


 [[[0. 0.]
   [0. 0.]]

  [[3. 0.]
   [0. 4.]]]]
2.0
