# We start from the Amplitude Amplification Tutorial from pennylane

# The main formular from the paper: [arXiv:2504.02385](https://arxiv.org/abs/2504.02385)

We implement it using RY as a rotational operator and use jax, as the clean python cell explodes with memory.

	
$$
\begin{bmatrix}
P(e^{i H}) & * \\
* & *
\end{bmatrix}
=
\Bigg(
\prod_{j=1}^d RY\big(\theta_{d+j}\big)\, \begin{bmatrix}
0 & 0 \\
0 & e^{-i H}
\end{bmatrix}

\Bigg)
\Bigg(
\prod_{j=1}^{d} RY\big(\theta_j\big)\,  \begin{bmatrix}
e^{i H} & 0 \\
0 & 0
\end{bmatrix} 
\Bigg)
RY\big(\theta_0)
$$




In [None]:
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)

# In this demo d = 500 is for speed, the result is far away from the exact solution, at about 10000 it might be close.

d = 500 # degree of the polynomial, which is more or less the number of steps to approximate the imagary time evolution

def P(x):
    return jnp.cos(x)**(d-2)

def R(theta):
    return jnp.array([
        [jnp.cos(theta/2), -jnp.sin(theta/2)],
        [jnp.sin(theta/2),  jnp.cos(theta/2)]
    ], dtype=jnp.complex128)

def U_signal(x, which):
    if which == 0:
        return jnp.array([[jnp.exp(1j*x), 0.0],
                          [0.0, 1.0]], dtype=jnp.complex128)
    else:
        return jnp.array([[1.0, 0.0],
                          [0.0, jnp.exp(-1j*x)]], dtype=jnp.complex128)

from jax import lax

@jax.jit
def qsp_matrix(angles, x):
    d = (angles.shape[0] - 1) // 2

    theta0 = angles[0]
    theta1 = angles[1:1+d]
    theta2 = angles[1+d:1+2*d]

    M = jnp.array([
        [jnp.cos(theta0/2), -jnp.sin(theta0/2)],
        [jnp.sin(theta0/2),  jnp.cos(theta0/2)]
    ], dtype=jnp.complex128)

    def loop0(M, theta):
        M = jnp.array([
        [jnp.cos(theta/2), -jnp.sin(theta/2)],
        [jnp.sin(theta/2),  jnp.cos(theta/2)]
    ], dtype=jnp.complex128) @ jnp.array([[jnp.exp(1j*x), 0.0],
                          [0.0, 1.0]], dtype=jnp.complex128) @ M
        return M, None

    def loop1(M, theta):
        M = jnp.array([
        [jnp.cos(theta/2), -jnp.sin(theta/2)],
        [jnp.sin(theta/2),  jnp.cos(theta/2)]
    ], dtype=jnp.complex128) @ jnp.array([[1.0, 0.0],
                          [0.0, jnp.exp(-1j*x)]], dtype=jnp.complex128) @ M
        return M, None

    M, _ = lax.scan(loop0, M, theta1)
    M, _ = lax.scan(loop1, M, theta2)

    return M

@jax.jit
def loss(angles):
    xs = jnp.linspace(0.0, 1.5, d * 7 // 11)

    def loss_x(x):
        M = qsp_matrix(angles, x)
        amp = M[0,0]
        return jnp.abs(amp - P(x))**2

    return jnp.sum(jax.vmap(loss_x)(xs))

loss_grad = jax.jit(jax.grad(loss))


import numpy as np
from scipy.optimize import minimize

loss_np = lambda a: np.array(loss(jnp.array(a)), dtype=float)

init_angles = np.random.uniform(0, np.pi, 2*d+1)

try:
    opt_angles = np.load(f'opt_angles_d{d}.npy')
    print(f"Loaded angles from opt_angles_d{d}.npy")
    print(f"Shape: {opt_angles.shape}")
except FileNotFoundError:
    print("No saved angles found, starting optimization.")
    res = minimize(loss, init_angles, method="BFGS", options={"maxiter":500}, jac=loss_grad)
    opt_angles = res.x
    print(res)
print("Optimized angles:", opt_angles)



Loaded angles from opt_angles_d500.npy
Shape: (1001,)
Optimized angles: [0.77669598 1.2442304  0.36506457 ... 1.29953095 0.24041993 1.60163658]


In [2]:
import pennylane as qml
import numpy as np
from scipy.sparse.linalg import eigsh

L_x = 2
L_y = 2
boundary_conditions = [False, False]

t = 1
UU = 8

# Parameters for Trotterization and filter alignment
optimal_x_for_filter = 0.0
n_steps = 1
order = 1

H_penny = qml.spin.fermi_hubbard('rectangle',[L_x, L_y], hopping=t, coulomb=UU, boundary_condition=boundary_conditions)

H_original = H_penny.sparse_matrix()

eig_vals_orig, eig_vecs_orig = eigsh(H_original, k=6, which='SA')
print(f"PennyLane Ground state energy: {np.min(eig_vals_orig)}")

H_matrix = H_original
eig_max_orig = eig_max = eigsh(H_matrix, k=1, which='LA', return_eigenvectors=False)[0]
eig_min = np.min(eig_vals_orig)
print(f"Max eigenvalue: {eig_max}")
print(f"Min eigenvalue: {eig_min}")

print("Shifting Hamiltonian...")
ops = H_penny.operands
coefs = [op.scalar for op in ops]
obs = [op.base for op in ops]
H_penny = qml.Hamiltonian(coefs, obs)

H_penny = (H_penny - qml.Identity(wires=H_penny.wires) * (eig_min - optimal_x_for_filter * (eig_max - eig_min))) / (eig_max - eig_min)
H_penny = qml.Hamiltonian(H_penny.coeffs, H_penny.ops)

H_matrix = H_penny.sparse_matrix()
eig_max_check = eigsh(H_matrix, k=1, which='LA', return_eigenvectors=False)[0]
eig_min_check = eigsh(H_matrix, k=1, which='SA', return_eigenvectors=False)[0]
print(f"Max eigenvalue: {eig_max_check}")
print(f"Min eigenvalue: {eig_min_check}")

ops = H_penny.operands
coefs = [op.scalar for op in ops]
print(f"Number of terms: {len(ops)}")
print(f"Coefficients: {coefs[:5]}...")  # Show first 5
print(f"Operators: {ops[:5]}...")  # Show first 5

obs = [op.base for op in ops]
grouped = qml.pauli.group_observables(obs, coefs, grouping_type='commuting')
print(f"Number of groups: {len(grouped)}")
print(f"Shape of first group: {len(grouped[0])}")
print(f"Shape of second group: {len(grouped[1])}")
for g in grouped[0]:
    print(len(g))
for g in grouped[1]:
    print(len(g))
    
obs_flat = [item for sublist in grouped[0] for item in sublist]
coeffs_flat = [item for sublist in grouped[1] for item in sublist]

grouped_H = qml.Hamiltonian(coeffs_flat, obs_flat)
eig_vals_grouped, _ = eigsh(grouped_H.sparse_matrix(), k=6, which='SA')
print(f"Check Ground state energy: {np.min(eig_vals_grouped)}")
print(f"shape of H {grouped_H.sparse_matrix().shape[0]}")

# Remap wires of grouped_H to make wire 0 free for control
def our_R(theta, wire):
    qml.RY(theta, wires=wire) # @ qml.RZ(np.pi, wires=wire)

wire_map = {i: i+1 for i in range(grouped_H.num_wires)}
grouped_H_new = qml.map_wires(grouped_H, wire_map)

def iterate(angles):
    d = len(angles)//2
    theta0 = angles[0]
    theta1 = angles[1:d+1]
    theta2 = angles[d+1:2*d+1]

    our_R(theta0, wire=0)
    
    # Erste Schleife: wirkt auf |0> Komponente
    for t in theta1:
        qml.ctrl(qml.TrotterProduct(grouped_H_new, time=1.0, n=n_steps, order=order, check_hermitian=True), control=0, control_values=0)
        our_R(t, wire=0)
        
    # Zweite Schleife: wirkt auf |1> Komponente
    for t in theta2:
        qml.ctrl(qml.TrotterProduct(grouped_H_new, time=-1.0, n=n_steps, order=order, check_hermitian=True), control=0, control_values=1)
        our_R(t, wire=0)

devtest = qml.device("lightning.qubit", wires=grouped_H.num_wires + 1)

@qml.qnode(devtest)
def my_circ_ctrl(angles=None):
    # Prepare some state
    for w in range(1, grouped_H.num_wires + 1):
        qml.Hadamard(w)
    iterate(angles)
    return qml.state()     

rr = my_circ_ctrl(opt_angles)

psi_0 = np.ones(2**grouped_H.num_wires) / np.sqrt(2**grouped_H.num_wires)
r_now = psi_0  # rr.reshape(-1,2)[:,0]
energy = np.vdot(r_now, H_penny.sparse_matrix() @ r_now).real / np.vdot(r_now, r_now).real
print(f"Expectation value of initial state: {energy}")
r_now = rr.reshape(2,-1)[0]
energy = np.vdot(r_now, H_penny.sparse_matrix() @ r_now).real / np.vdot(r_now, r_now).real
print(f"Expectation value of Hamiltonian: {energy}")
energy = np.vdot(r_now, H_original @ r_now).real / np.vdot(r_now, r_now).real
print(f"Expectation value of original Hamiltonian: {energy}")
print(f"Norm r_now: {np.vdot(r_now, r_now).real}, Norm rr: {np.vdot(rr, rr).real}")

PennyLane Ground state energy: -3.2077509432193443
Max eigenvalue: 32.00000000000003
Min eigenvalue: -3.2077509432193443
Shifting Hamiltonian...
Max eigenvalue: 1.0000000000000004
Min eigenvalue: -2.6142342450735645e-16
Number of terms: 30
Coefficients: [np.float64(-0.01420141834127279), np.float64(-0.01420141834127279), np.float64(0.22722269346036464), np.float64(-0.01420141834127279), np.float64(-0.01420141834127279)]...
Operators: (-0.01420141834127279 * (Y(0) @ Z(1) @ Y(2)), -0.01420141834127279 * (X(0) @ Z(1) @ X(2)), 0.22722269346036464 * I([0, 1, 2, 3, 4, 5, 6, 7]), -0.01420141834127279 * (Y(1) @ Z(2) @ Y(3)), -0.01420141834127279 * (X(1) @ Z(2) @ X(3)))...
Number of groups: 2
Shape of first group: 3
Shape of second group: 3
8
14
8
8
14
8
Check Ground state energy: -2.6141460551409927e-16
shape of H 256
Expectation value of initial state: 0.3183319196189052
Expectation value of Hamiltonian: 0.0379650542404368
Expectation value of original Hamiltonian: -1.8710867689762316
Norm r_

## Here we do Amplitude Amplification to get the filtered state with high probability

In [3]:
dev = qml.device("lightning.qubit", wires=grouped_H_new.num_wires + 2)
@qml.prod
def U2(wires):
    for wire in wires:
        qml.Hadamard(wires=wire)
    iterate(opt_angles)
    
@qml.prod
def oracle():
    qml.FlipSign(0, wires=0)

@qml.qnode(dev)
def circuit_state(iters):
    U2(wires=range(1, grouped_H_new.num_wires + 1))
    qml.AmplitudeAmplification(U = U2(wires=range(1, grouped_H_new.num_wires + 1)),
                               O = oracle(),
                               iters = iters,
                               fixed_point=True,
                               work_wire=grouped_H_new.num_wires + 1)

    return qml.state()

test_output = circuit_state(iters=10)

tt = test_output.reshape(2,-1,2)[0,:,:].sum(axis=1)
print(f"Energy test_output: {np.vdot(tt, H_original @ tt).real / np.vdot(tt, tt).real}, with probability {np.vdot(tt, tt).real}")

Energy test_output: -1.8710867689761725, with probability 0.9814039917179023


## Here we measure the phi, which is the energy of the shifted hamiltonian

$$
\phi = arc(<\psi|e^{i H}|\psi>),
$$
which is done with TrotterProduct and hadamard test.

In [4]:
dev_with_measurement = qml.device("default.qubit", wires=H_penny.num_wires + 3)  # the lightning device does not support post selection in this way yet
@qml.qnode(dev_with_measurement)
def circuit_energy():
    qml.StatePrep(test_output, wires=range(H_penny.num_wires + 2))
    return qml.expval(grouped_H_new)

@qml.qnode(dev_with_measurement)
def hadamard_test(measure="X"):
    ancilla = H_penny.num_wires + 2
    system_wires = range(H_penny.num_wires + 2)
    qml.StatePrep(test_output, wires=system_wires)
    
    qml.Hadamard(wires=ancilla)
    qml.measure(wires=0, postselect=0) # The good from amplitude amplification
    
    qml.ctrl(
        qml.TrotterProduct(grouped_H_new, time=1.0, n=2, order=order),
        control=ancilla
    )
    
    if measure == "X":
        return qml.expval(qml.PauliX(ancilla))
    else:
        return qml.expval(qml.PauliY(ancilla))
    
    
energy = circuit_energy()
print(f"Energy from PennyLane QNode: {energy} (no good postselection)")
print(f"Energy from sparse matrix: {np.vdot(tt, grouped_H.sparse_matrix() @ tt).real / np.vdot(tt, tt).real}")


Re = hadamard_test("X")
Im = hadamard_test("Y")
z = Re + 1j*Im
phi = np.angle(z)  #  arg(<ψ|U|ψ>)
print(f"Energy from Hadamard test: {phi}")

Energy from PennyLane QNode: 0.043538558179383596 (no good postselection)
Energy from sparse matrix: 0.03796505424043849
Energy from Hadamard test: 0.037959451443756134


In [5]:
eig_max = eig_max_orig
eig_min = np.min(eig_vals_orig)

E_phi_original = (phi * (eig_max - eig_min)) + (eig_min - optimal_x_for_filter * (eig_max - eig_min))
print(f"Energy from Hadamard test, original Hamiltonian: {E_phi_original}, which compares to the exact solution {eig_min}")

Energy from Hadamard test, original Hamiltonian: -1.8712840308463492, which compares to the exact solution -3.2077509432193443


In [6]:
# One might want to save the angles for later use, they are loaded at the beginning of this notebook.

# np.save(f'opt_angles_d{d}.npy', opt_angles)
# print(f"Saved angles to opt_angles_d{d}.npy")

The circuit_state and hadamard_test could be combined into one final circuit which should run on a shot enabled device