# Using Q-Alchemy for Quantum Chemistry ab initio calculations

This is a wonderful [blog post](https://pennylane.ai/qml/demos/tutorial_initial_state_preparation) by Stepan Fomichev from Xanadu Inc., which we
reproduce here using Q-Alchemy.

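First, we create the $H_3^+$ molecule (three hydrogen atoms in a line, total charge $+1$) and compute its Hartree–Fock and CISD wavefunctions with PySCF.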

In [None]:
from pyscf import gto, scf, ci
from pennylane.qchem import import_state
import numpy as np

R = 1.2  # H-H spacing; PySCF reads atom coordinates in Angstrom by default
# Create the linear H3+ molecule (charge +1 leaves 2 electrons). The STO-3G
# basis is spelled out so it visibly matches the basis PennyLane uses for the
# qubit Hamiltonian built later.
mol = gto.M(
    atom=[["H", (0, 0, 0)], ["H", (0, 0, R)], ["H", (0, 0, 2 * R)]],
    charge=1,
    basis="sto-3g",
)
# perform restricted Hartree-Fock and then CISD
myhf = scf.RHF(mol).run()
myci = ci.CISD(myhf).run()
# Convert the PySCF wavefunctions into PennyLane state vectors; Slater
# determinants whose coefficients fall below `tol` are discarded.
wf_hf = import_state(myhf, tol=1e-1)
wf_cisd = import_state(myci, tol=1e-1)
# print() renders the newline; a bare f-string as the cell's last expression
# would be shown as its repr, with quotes and a literal "\n".
print(f"CISD-based state vector: \n{np.round(wf_cisd.real, 4)}")

We can now use these states as initial states to speed up the variational quantum eigensolver (VQE).

In [None]:
import pennylane as qml
from pennylane import qchem
from jax import numpy as jnp

# Enable 64-bit floats *before* any JAX array is created. JAX fixes an array's
# dtype at creation time and defaults to float32, so otherwise the geometry
# below and the Hamiltonian computed from it would be float32.
import jax

jax.config.update("jax_enable_x64", True)

# generate the molecular Hamiltonian for H3+ on the same geometry as the
# PySCF calculation above
symbols = ["H", "H", "H"]
# R is in Angstrom, as in the PySCF molecule. Let PennyLane do the exact
# Angstrom -> Bohr conversion instead of dividing by a rounded 0.529.
geometry = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, R], [0.0, 0.0, 2 * R]])
molecule = qchem.Molecule(symbols, geometry, charge=1, unit="angstrom")

H2mol, qubits = qchem.molecular_hamiltonian(molecule)
wires = list(range(qubits))
dev = qml.device("default.qubit", wires=qubits)

# Create all possible single and double excitations in H3+. The electron
# count (2 for H3+) comes from the molecule rather than being hard-coded.
singles, doubles = qchem.excitations(molecule.n_electrons, qubits)
excitations = singles + doubles

In [None]:
@qml.qnode(dev)
def circuit_VQE(theta, initial_state):
    """Load ``initial_state``, apply the excitation ansatz and return <H2mol>.

    ``theta[k]`` is the rotation angle for ``excitations[k]``. Four-wire
    excitations get a DoubleExcitation gate; every other excitation gets a
    SingleExcitation gate. Runs on the module-level device ``dev``.
    """
    qml.StatePrep(initial_state, wires=wires)
    for k, exc_wires in enumerate(excitations):
        gate = qml.DoubleExcitation if len(exc_wires) == 4 else qml.SingleExcitation
        gate(theta[k], wires=exc_wires)
    return qml.expval(H2mol)

def cost_fn_hf(param):
    """Expectation of H2mol for ansatz angles ``param``, starting from ``wf_hf``."""
    return circuit_VQE(param, initial_state=wf_hf)

def cost_fn_cisd(param):
    """Expectation of H2mol for ansatz angles ``param``, starting from ``wf_cisd``."""
    return circuit_VQE(param, initial_state=wf_cisd)

In [None]:
import optax
import jax
# Enable double precision. NOTE: JAX applies this flag when arrays are
# created, so arrays built before this line keep their original dtype.
jax.config.update("jax_enable_x64", True)

cost_fn = cost_fn_hf

LEARNING_RATE = 0.4
CONV_TOL = 1e-5  # stop once the energy change between steps is below this (Ha)
MAX_STEPS = 500  # safety cap so a non-converging run cannot loop forever

# optax.sgd applies plain gradient-descent updates; the cost here is
# deterministic, so there is no stochastic sampling involved.
opt = optax.sgd(learning_rate=LEARNING_RATE)
theta = jnp.zeros(len(excitations))
opt_state = opt.init(theta)

# value_and_grad returns the energy together with its gradient, so the loop
# does not re-evaluate the circuit separately just to read the energy.
energy_and_grad = jax.value_and_grad(cost_fn)
prev_energy, gradient = energy_and_grad(theta)
delta_E = float("inf")
results_hf = []

# run the VQE optimization loop until convergence threshold is reached
while abs(delta_E) > CONV_TOL and len(results_hf) < MAX_STEPS:
    updates, opt_state = opt.update(gradient, opt_state)
    theta = optax.apply_updates(theta, updates)
    # gradient at the new theta is reused by the next iteration's update
    new_energy, gradient = energy_and_grad(theta)
    delta_E = new_energy - prev_energy
    prev_energy = new_energy
    results_hf.append(new_energy)
    if len(results_hf) % 5 == 0:
        print(f"Step = {len(results_hf)},  Energy = {new_energy:.6f} Ha")

if abs(delta_E) > CONV_TOL:
    print(f"Starting with HF state did not converge within {MAX_STEPS} iterations.")
else:
    print(f"Starting with HF state took {len(results_hf)} iterations until convergence.")