In [1]:
# Use Braket SDK Cost Tracking to estimate the cost to run this example
from braket.tracking import Tracker
# Create the tracker and start it immediately; `t` keeps whatever `start()`
# returns so the tracked usage can be queried later in the notebook.
t = Tracker().start()

In [2]:
import time
import warnings

import pennylane as qml
from braket.devices import Devices
from pennylane import numpy as np
from pennylane import qchem

In [3]:
# Number of electrons handed to the Hamiltonian builder as `active_electrons`;
# the same value is reused later for the Hartree-Fock state and excitations.
n_electrons = 8
symbols, coordinates = qchem.read_structure('qchem/co2.xyz')
# Suppress a HDF5 warning. Note that `simplefilter('ignore')` silences *every*
# warning raised inside this `with` block, not only the HDF5 one.
# (`warnings` is imported in the notebook's import cell.)
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    H, qubits = qchem.molecular_hamiltonian(
        symbols,
        coordinates,
        method="pyscf",
        active_electrons=n_electrons,
        name="co2",
    )
print(f"Number of qubits: {qubits}")

Number of qubits: 16


In [4]:
# Hartree-Fock state
# Used by the circuits below as the basis state the ansatz starts from.
hf_state = qchem.hf_state(n_electrons, qubits)
# generate single- and double-excitations
# Each entry is later passed as `wires=` to an excitation gate; the circuits
# treat entries of length 4 as double excitations.
singles, doubles = qchem.excitations(n_electrons, qubits)
print("Number of single excitations: ", len(singles))
print("Number of double excitations: ", len(doubles))

Number of single excitations:  32
Number of double excitations:  328


# Adaptive VQE

In general, we do not know which excitations are important (contribute a lot to the final energy) and which are not (contribute little). [Grimsley et al.](https://www.nature.com/articles/s41467-019-10988-2) developed the ADAPT-VQE algorithm, which allows us to perform this desired filtering. The steps to be followed are:
  1. Compute derivatives with respect to all `doubles` excitations
  2. Filter out all `doubles` with derivatives below some cutoff
  3. Optimize the remaining `doubles` excitations
  4. Compute derivatives with respect to all `singles` excitations, keeping the filtered-and-optimized `doubles` fixed
  5. Filter out all `singles` with derivatives below some cutoff
  6. Optimize all remaining `singles` and `doubles` excitations
  7. Compute the final energy

# Running adaptive VQE with adjoint differentiation

We'll need to set up the device to use with PennyLane. Because the qubit count is 16, this workflow is a good candidate for SV1, the Amazon Braket on-demand state vector simulator. SV1 now supports two gradient computation methods in `shots=0` (exact) mode: adjoint differentiation, available by setting `diff_method='device'`, and parameter shift, available by setting `diff_method='parameter-shift'`. As shown in [the adjoint gradient example notebook](../../braket_features/Using_The_Adjoint_Gradient_Result_Type.ipynb), the adjoint differentiation method is an execution-frugal way to compute gradients. When using `parameter-shift`, each partial derivative in the gradient requires *two* circuit executions to compute, but with the adjoint method we can compute *all* partial derivatives (and thus the entire gradient) with one circuit execution and the "back-stepping" procedure, which is similar in runtime. The adjoint method can deliver a quadratic speedup in the number of parameters, making it a great choice when the number of parameterized gates is large, as it is for our problem.

In [5]:
# set the device and differentiation method
device_arn = Devices.Amazon.SV1
# `shots=0` selects the exact mode described above, which is the mode in
# which SV1 offers gradient computation.
dev = qml.device("braket.aws.qubit", device_arn=device_arn, wires=qubits, shots=0)
# 'device' requests adjoint differentiation (see the explanation above).
diff_method = 'device'

In [6]:
@qml.qnode(dev, diff_method=diff_method)
def circuit_1(params, excitations):
    """Expectation value of `H` for an excitation ansatz on the Hartree-Fock state.

    Args:
        params: gate angles; `params[k]` drives the k-th entry of `excitations`.
        excitations: wire groups. A group of four wires is applied as a
            `DoubleExcitation`; any other group is applied as a
            `SingleExcitation`.

    Returns:
        The measurement `qml.expval(H)`.
    """
    qml.BasisState(hf_state, wires=H.wires)
    for idx, wire_group in enumerate(excitations):
        is_double = len(wire_group) == 4
        gate = qml.DoubleExcitation if is_double else qml.SingleExcitation
        gate(params[idx], wires=wire_group)
    return qml.expval(H)

<div class="alert alert-block alert-warning">
<b>Caution</b> This cell may take about 30s to run on SV1.
</div>

In [7]:
# ADAPT-VQE step 1: differentiate the energy with respect to every double
# excitation angle, with all angles set to zero, and time the computation.
# (The former `doubles_select = []` initializer was removed: the next cell
# assigns `doubles_select` without ever reading that empty list.)
circuit_gradient = qml.grad(circuit_1, argnum=0)
params = [0.0] * len(doubles)

adjoint_doubles_start = time.time()

doubles_grads = circuit_gradient(params, excitations=doubles)

adjoint_doubles_stop = time.time()
adjoint_doubles_time = adjoint_doubles_stop - adjoint_doubles_start
print(f"Time to compute all double excitation derivatives with adjoint differentiation: {adjoint_doubles_time}")

Time to compute all double excitation derivatives with adjoint differentiation: 36.21157622337341


In [8]:
# Keep only the double excitations whose gradient magnitude exceeds the 1e-5 cutoff.
doubles_select = [
    excitation
    for idx, excitation in enumerate(doubles)
    if abs(doubles_grads[idx]) > 1.0e-5
]
print(f"Total number of doubles {len(doubles)}")
print(f"Total number of selected doubles {len(doubles_select)}")

Total number of doubles 328
Total number of selected doubles 84


In [9]:
# Optimizer settings; `opt` and `iterations` are reused by the later full optimization.
stepsize = 0.4
iterations = 10
opt = qml.GradientDescentOptimizer(stepsize=stepsize)

# Start every selected double-excitation angle at zero and mark it trainable.
params_doubles = np.zeros(len(doubles_select), requires_grad=True)

for step_index in range(iterations):
    print(f"Iteration {step_index} of doubles optimization.")
    params_doubles = opt.step(
        circuit_1,
        params_doubles,
        excitations=doubles_select,
    )

Iteration 0 of doubles optimization.
Iteration 1 of doubles optimization.
Iteration 2 of doubles optimization.
Iteration 3 of doubles optimization.
Iteration 4 of doubles optimization.
Iteration 5 of doubles optimization.
Iteration 6 of doubles optimization.
Iteration 7 of doubles optimization.
Iteration 8 of doubles optimization.
Iteration 9 of doubles optimization.


In [10]:
@qml.qnode(dev, diff_method=diff_method)
def circuit_2(params, excitations, gates_select, params_select):
    """Expectation value of `H` with a fixed gate prefix followed by candidate gates.

    The circuit starts from the Hartree-Fock state, applies `gates_select`
    with angles `params_select`, then applies `excitations` with angles
    `params`.

    Args:
        params: angles for `excitations` (same ordering).
        excitations: wire groups applied after the preselected gates.
        gates_select: wire groups applied first.
        params_select: angles for `gates_select` (same ordering).

    Returns:
        The measurement `qml.expval(H)`.
    """

    def apply_excitations(wire_groups, angles):
        # Four wires -> double excitation, two wires -> single excitation;
        # a group of any other size is skipped without applying a gate.
        for idx, wire_group in enumerate(wire_groups):
            n_wires = len(wire_group)
            if n_wires == 4:
                qml.DoubleExcitation(angles[idx], wires=wire_group)
            elif n_wires == 2:
                qml.SingleExcitation(angles[idx], wires=wire_group)

    qml.BasisState(hf_state, wires=H.wires)
    apply_excitations(gates_select, params_select)
    apply_excitations(excitations, params)
    return qml.expval(H)

In [11]:
# ADAPT-VQE step 4: differentiate with respect to every single-excitation
# angle (all zero) while the selected doubles keep their optimized angles.
circuit_gradient = qml.grad(circuit_2, argnum=0)
params = [0.0] * len(singles)

adjoint_singles_start = time.time()
singles_grads = circuit_gradient(
    params,
    excitations=singles,
    gates_select=doubles_select,
    params_select=params_doubles,
)
adjoint_singles_stop = time.time()

adjoint_singles_time = adjoint_singles_stop - adjoint_singles_start
print(f"Time to compute all singles derivatives with adjoint differentiation: {adjoint_singles_time}")

Time to compute all singles derivatives with adjoint differentiation: 25.915367364883423


In [12]:
# Keep only the single excitations whose gradient magnitude exceeds the 1e-5 cutoff.
singles_select = [
    excitation
    for idx, excitation in enumerate(singles)
    if abs(singles_grads[idx]) > 1.0e-5
]
print(f"Total number of singles {len(singles)}")
print(f"Total number of selected singles {len(singles_select)}")

Total number of singles 32
Total number of selected singles 4


In [13]:
# ADAPT-VQE step 6: optimize all selected excitations together. Doubles are
# listed first, so `params[k]` is the angle of `gates_select[k]`.
gates_select = doubles_select + singles_select
# NOTE: every angle restarts from zero here; the angles in `params_doubles`
# from the doubles-only optimization are not carried over.
params = np.zeros(len(gates_select), requires_grad=True)

# `step_and_cost` reports the cost at the parameters *before* the update, and
# the previous code simply overwrote `best_energy` on every iteration, so it
# held the last reported cost rather than the best one. Track the minimum.
best_energy = float("inf")
for n in range(iterations):
    print(f"Iteration {n} of full optimization.")
    params, energy = opt.step_and_cost(circuit_1, params, excitations=gates_select)
    best_energy = min(best_energy, float(energy.numpy()))


Iteration 0 of full optimization.
Iteration 1 of full optimization.
Iteration 2 of full optimization.
Iteration 3 of full optimization.
Iteration 4 of full optimization.
Iteration 5 of full optimization.
Iteration 6 of full optimization.
Iteration 7 of full optimization.
Iteration 8 of full optimization.
Iteration 9 of full optimization.


In [14]:
# Report the energy recorded by the full optimization loop above.
print(f"Best energy: {best_energy}")

Best energy: -184.90548216687063


In [15]:
# Switch the differentiation method for the QNodes defined from here on.
diff_method = 'parameter-shift'
# Only the first (at most 35) double excitations are differentiated with
# parameter shift; the measured time is extrapolated to all doubles later.
doubles_count = min(35, len(doubles))
doubles_ps = doubles[:doubles_count]

In [16]:
@qml.qnode(dev, diff_method=diff_method)
def circuit_1_ps_serial(params, excitations):
    """Same ansatz as `circuit_1`, bound to the `diff_method` set just above.

    Args:
        params: gate angles, one per entry of `excitations`.
        excitations: wire groups; four wires select `DoubleExcitation`,
            anything else selects `SingleExcitation`.

    Returns:
        The measurement `qml.expval(H)`.
    """
    qml.BasisState(hf_state, wires=H.wires)
    for position, wire_group in enumerate(excitations):
        angle = params[position]
        if len(wire_group) != 4:
            qml.SingleExcitation(angle, wires=wire_group)
        else:
            qml.DoubleExcitation(angle, wires=wire_group)
    return qml.expval(H)

In [17]:
# Time the parameter-shift gradient on the serial device for the
# `doubles_count` excitations in `doubles_ps`, with all angles at zero.
circuit_gradient = qml.grad(circuit_1_ps_serial, argnum=0)
params = [0.0] * doubles_count

doubles_ps_unbatched_start = time.time()

unbatched_grads = circuit_gradient(params, excitations=doubles_ps)

doubles_ps_unbatched_stop = time.time()
doubles_ps_unbatched_time = doubles_ps_unbatched_stop - doubles_ps_unbatched_start
print(f"Time to compute {doubles_count} double excitation derivatives using unbatched parameter shift: {doubles_ps_unbatched_time}")

Time to compute 35 double excitation derivatives using unbatched parameter shift: 3186.5143387317657


In [18]:
# Same SV1 simulator as `dev`, but created with `parallel=True` so the
# gradient circuits can be submitted as a batch (see the Braket plugin docs).
dev_parallel = qml.device(
    "braket.aws.qubit",
    device_arn=device_arn,
    wires=qubits,
    shots=0,
    parallel=True,
)


@qml.qnode(dev_parallel, diff_method=diff_method)
def circuit_1_ps_parallel(params, excitations):
    """Same ansatz as `circuit_1`, redefined because a QNode is bound to its device.

    Args:
        params: gate angles, one per entry of `excitations`.
        excitations: wire groups; four wires select `DoubleExcitation`,
            anything else selects `SingleExcitation`.

    Returns:
        The measurement `qml.expval(H)`.
    """
    qml.BasisState(hf_state, wires=H.wires)
    for i, excitation in enumerate(excitations):
        if len(excitation) == 4:
            qml.DoubleExcitation(params[i], wires=excitation)
        else:
            qml.SingleExcitation(params[i], wires=excitation)
    return qml.expval(H)


# Batched parameter-shift timing run. Later cells read `batched_grads` and
# `doubles_ps_batched_time`, but nothing in the notebook defined them, so a
# fresh "Restart & Run All" stopped with a NameError. This mirrors the
# unbatched timing cell. Caution: like that cell, it is a long-running,
# billable SV1 computation.
circuit_gradient = qml.grad(circuit_1_ps_parallel, argnum=0)
params = [0.0] * doubles_count

doubles_ps_batched_start = time.time()

batched_grads = circuit_gradient(params, excitations=doubles_ps)

doubles_ps_batched_stop = time.time()
doubles_ps_batched_time = doubles_ps_batched_stop - doubles_ps_batched_start
print(f"Time to compute {doubles_count} double excitation derivatives using batched parameter shift: {doubles_ps_batched_time}")

In [20]:
# Linear extrapolation: scale the time measured for `doubles_count`
# derivatives up to the full number of double excitations.
extrapolated_unbatched_ps_time = doubles_ps_unbatched_time * (len(doubles)/doubles_count)
print(f"Extrapolated time to compute all doubles derivatives with unbatched parameter shift: {extrapolated_unbatched_ps_time}")

Extrapolated time to compute all doubles derivatives with unbatched parameter shift: 29862.191517257692


In [21]:
# Same linear extrapolation for the batched run; expects
# `doubles_ps_batched_time` to have been measured by an earlier cell.
extrapolated_batched_ps_time = doubles_ps_batched_time * (len(doubles)/doubles_count)
print(f"Extrapolated time to compute all doubles derivatives with batched parameter shift: {extrapolated_batched_ps_time}")

Extrapolated time to compute all doubles derivatives with batched parameter shift: 16891.023650142124


In [22]:
# Ratio of extrapolated unbatched parameter-shift time to measured adjoint
# time; values above 1 mean the adjoint method was faster.
adjoint_vs_unbatched_ps = extrapolated_unbatched_ps_time / adjoint_doubles_time
print(f"Time to compute all doubles derivatives:\nRatio of (extrapolated) unbatched parameter shift time to adjoint time: {adjoint_vs_unbatched_ps}")

Time to compute all doubles derivatives:
Ratio of (extrapolated) unbatched parameter shift time to adjoint time: 824.6587039749627


In [23]:
# Ratio of extrapolated batched parameter-shift time to measured adjoint
# time; values above 1 mean the adjoint method was faster.
adjoint_vs_batched_ps = extrapolated_batched_ps_time / adjoint_doubles_time
print(f"Time to compute all doubles derivatives:\nRatio of (extrapolated) batched parameter shift time to adjoint time: {adjoint_vs_batched_ps}")

Time to compute all doubles derivatives:
Ratio of (extrapolated) batched parameter shift time to adjoint time: 466.453698285564


In [24]:
# Sanity check: on the first `doubles_count` excitations, the adjoint and
# unbatched parameter-shift derivatives must agree within `np.allclose`'s
# default tolerances.
adjoint_derivs = [grad.numpy() for grad in doubles_grads[:doubles_count]]
unbatched_derivs = [grad.numpy() for grad in unbatched_grads[:doubles_count]]

print(adjoint_derivs)
print(unbatched_derivs)

derivatives_match = np.allclose(adjoint_derivs, unbatched_derivs)
assert derivatives_match

[-0.044793067882551406, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.04479306788255156, 0.0, 0.0, 0.0, 0.0, -0.03548100882402628, 0.0, 0.0, -0.023288073123347845, 0.0, 0.0, 0.0, 0.0, 0.0, -0.012217382659441306, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[-0.044793067882545376, -6.1928479400457334e-15, 2.6531566187835427e-15, -8.942675704972513e-16, 1.7746123057175047e-15, 5.901095647171372e-15, 1.1868639193156705e-15, -0.044793067882549185, 4.7342881085964644e-15, 8.820892689029992e-16, 1.4863490802086347e-15, 5.90638513669108e-15, -0.03548100882403037, -6.2279833078444984e-15, -2.6752696290439975e-15, -0.023288073123352977, 1.6921423417461336e-15, 2.670274324449899e-15, -2.815484849928097e-15, -2.0406466188365425e-17, -5.810940290979032e-15, -0.01221738265940523, 6.8565100157100086e-15, 5.280719959261674e-15, 6.1969905855345365e-15, -2.171061239200851e-15, 5.089076788104875e-16, -2.6180781927668725e-15, 5.3732805495666735e-15, 6.131155462712434e-16, 3.811084445308944e-15, -6

In [25]:
# Sanity check: the batched parameter-shift derivatives must agree with the
# adjoint ones (`adjoint_derivs` comes from the previous cell).
batched_derivs = [grad.numpy() for grad in batched_grads[:doubles_count]]

print(batched_derivs)

derivatives_match = np.allclose(adjoint_derivs, batched_derivs)
assert derivatives_match

[-0.044793067882545376, -6.1928479400457334e-15, 2.6531566187835427e-15, -8.942675704972513e-16, 1.7746123057175047e-15, 5.901095647171372e-15, 1.1868639193156705e-15, -0.044793067882549185, 4.7342881085964644e-15, 8.820892689029992e-16, 1.4863490802086347e-15, 5.90638513669108e-15, -0.03548100882403037, -6.2279833078444984e-15, -2.6752696290439975e-15, -0.023288073123352977, 1.6921423417461336e-15, 2.670274324449899e-15, -2.815484849928097e-15, -2.0406466188365425e-17, -5.810940290979032e-15, -0.01221738265940523, 6.8565100157100086e-15, 5.280719959261674e-15, 6.1969905855345365e-15, -2.171061239200851e-15, 5.089076788104875e-16, -2.6180781927668725e-15, 5.3732805495666735e-15, 6.131155462712434e-16, 3.811084445308944e-15, -6.333324218826695e-15, -3.5324021440048863e-15, -4.675311186009487e-15, 5.167609105081913e-15]
