# In [None]:
import json
import pennylane as qml
import pennylane.numpy as np

# Atomic symbols of the three hydrogen nuclei that make up H3+.
symbols = ["H"] * 3


def h3_ground_energy(bond_length, *, max_iterations=20, stepsize=0.4, conv_tol=1e-8):
    """
    Uses VQE to calculate the ground-state energy of the H3+ molecule with the given bond length.

    The nuclei sit on the vertices of an equilateral triangle in the z = 0 plane.
    The geometry is held fixed; only the two double-excitation angles of the
    ansatz are optimised with gradient descent.

    Args:
        - bond_length(float): Side length of the equilateral triangle. The
        coordinates are passed to ``qml.qchem.molecular_hamiltonian`` without an
        explicit ``unit``, so they are interpreted in its default unit (Bohr).
        - max_iterations(int): Maximum number of gradient-descent steps (default 20).
        - stepsize(float): Gradient-descent step size (default 0.4).
        - conv_tol(float): Stop early once the energy changes by less than this
        amount (Ha) between consecutive steps (default 1e-8).
    Returns:
        - float: The estimated ground-state energy in Hartree.
    Raises:
        - ValueError: If ``bond_length`` is not positive or ``max_iterations`` < 1.
    """
    if bond_length <= 0:
        raise ValueError(f"bond_length must be positive, got {bond_length}")
    if max_iterations < 1:
        raise ValueError(f"max_iterations must be at least 1, got {max_iterations}")

    # Adapted from https://pennylane.ai/qml/demos/tutorial_mol_geo_opt/
    # Fixed bond length, using a gradient-descent optimiser to find the ground state.
    symbols = ["H", "H", "H"]

    # Flattened (x, y, z) coordinates of an equilateral triangle with side
    # `bond_length`. The geometry is never optimised, so no gradient is tracked.
    x = np.array(
        [
            -bond_length / 2, 0.0, 0.0,
            bond_length / 2, 0.0, 0.0,
            0.0, np.sqrt(3) / 2 * bond_length, 0.0,
        ],
        requires_grad=False,
    )

    # The geometry is fixed, so the Hamiltonian only needs to be built once rather
    # than on every cost evaluation. charge=1 selects the H3+ cation.
    hamiltonian, num_wires = qml.qchem.molecular_hamiltonian(symbols, x, charge=1)

    # Hartree-Fock reference: three H nuclei with charge +1 leave 2 electrons.
    hf = qml.qchem.hf_state(electrons=2, orbitals=num_wires)

    dev = qml.device("lightning.qubit", wires=num_wires)

    @qml.qnode(dev, interface="autograd")
    def circuit(params):
        """Energy expectation of the two-parameter double-excitation ansatz."""
        qml.BasisState(hf, wires=range(num_wires))
        # Ansatz: two double excitations out of the occupied wires 0 and 1.
        qml.DoubleExcitation(params[0], wires=[0, 1, 2, 3])
        qml.DoubleExcitation(params[1], wires=[0, 1, 4, 5])
        return qml.expval(hamiltonian)

    opt_theta = qml.GradientDescentOptimizer(stepsize=stepsize)
    theta = np.array([0.0, 0.0], requires_grad=True)

    # Energy after each optimisation step.
    energy = []

    for n in range(max_iterations):
        # Optimise the circuit parameters.
        theta.requires_grad = True
        theta = opt_theta.step(circuit, theta)

        energy.append(circuit(theta))

        if n % 4 == 0:
            print(f"Step = {n},  E = {energy[-1]:.8f} Ha,  bond length = {bond_length:.5f} Bohr")

        # The geometry is fixed, so convergence is judged on the energy alone.
        if len(energy) > 1 and abs(energy[-1] - energy[-2]) < conv_tol:
            break

    print("\n" f"Final value of the ground-state energy = {energy[-1]:.8f} Ha")
    print("\n" "Fixed input geometry (Bohr)")
    print("%s %4s %8s %8s" % ("symbol", "x", "y", "z"))
    for i, atom in enumerate(symbols):
        print(f"  {atom}    {x[3 * i]:.4f}   {x[3 * i + 1]:.4f}   {x[3 * i + 2]:.4f}")

    return float(energy[-1])

# These functions are responsible for testing the solution.

def run(test_case_input: str) -> str:
    """Decode a JSON bond length, run the VQE on it and return the energy as text."""
    bond_length = json.loads(test_case_input)
    return str(h3_ground_energy(bond_length))


def check(solution_output: str, expected_output: str) -> None:
    """Verify that a computed ground energy matches the expected one.

    Both arguments are JSON-encoded numbers; they are compared with
    ``np.allclose`` using an absolute tolerance of 1e-4.

    Raises:
        AssertionError: If the values are not close. Raised explicitly instead of
        via an ``assert`` statement so the check is not stripped under ``python -O``.
    """
    have = json.loads(solution_output)
    want = json.loads(expected_output)
    if not np.allclose(have, want, atol=1e-4):
        raise AssertionError("Not the correct ground energy")


# These are the public test cases
test_cases = [
    ('1.5', '-1.232574'),
    ('0.8', '-0.3770325')
]

# This will run the public test cases locally
for i, (input_, expected_output) in enumerate(test_cases):
    print(f"Running test case {i} with input '{input_}'...")

    try:
        output = run(input_)
    except Exception as exc:
        # Report the failure and move on to the next case.
        print(f"Runtime Error. {exc}")
        continue

    # check() signals a mismatch by raising AssertionError (it returns None
    # otherwise), so it must be caught here rather than tested for truthiness.
    try:
        check(output, expected_output)
    except AssertionError:
        print(f"Wrong Answer. Have: '{output}'. Want: '{expected_output}'.")
    else:
        print("Correct!")