In [None]:
import matplotlib.pyplot as plt
from matplotlib import cycler


plt.rcParams.update({
    "text.usetex": True, # enable latex font
    "font.family": "Helvetica", # set font style
    "text.latex.preamble": r'\usepackage{amsmath}', # add latex packages
    "font.size": "18", # set font size
    "figure.figsize": [10, 6], # set figure size
    "lines.linewidth": 2, # set line width
})
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

linestyles = ["-", ":", "-.", "--", (0, (3, 1, 1, 1, 1, 1)), "-", "--", "-.", ":"]
markers = ["o", "v", "s", "d", "X", "p", "*", "o", "v", "s", "d", "*"]
colors = ['#c65742', '#9ad0bb', '#e4bf44', '#87584E', '#aba18d', '#332737','#c65742', '#9ad0bb', '#e4bf44', '#87584E', '#aba18d', '#332737']
colors2 = ['#003122', '#225544', '#497c6a', '#71a591', '#9ad0bb', '#c5fde7']
plt.rcParams['axes.prop_cycle'] = cycler('color', colors)

In [None]:
import pennylane as qml
import numpy as np
import jax
import jax.numpy as jnp
import netket as nk
from cqd.expectation import PauliSum, PauliString
from cqd.utils import zero_tree_like
from cqd.models import LogAmplitudes, JastrowPlusSingle
from cqd.tdvp import HybridTDVP, CQDCallback
import netket.experimental as nkex

# Correcting Trotter error
This notebook contains the code to correct Trotter errors in the transverse-field Ising model (TFIM). We start by defining the Hamiltonian using PennyLane. The final line converts the PennyLane operator into a `cqd.expectation.PauliSum` object, which is the format required by the TDVP class used later.

In [None]:
def tfim_hamiltonian(L, J=1.0, h=1.0, periodic=True):
    """Build the 1D transverse-field Ising Hamiltonian and convert it to a ``PauliSum``.

    H = -h * sum_i X_i - J * sum_i Z_i Z_{i+1}   (plus -J * Z_{L-1} Z_0 if periodic)

    Args:
        L: Number of sites; the operator acts on wires ``0 .. L-1``. Must be >= 1.
        J: Nearest-neighbour ZZ coupling strength.
        h: Transverse-field strength.
        periodic: If True, add the wrap-around bond between site ``L-1`` and
            site ``0``. Ignored for ``L == 1``, where that bond would couple a
            site to itself.

    Returns:
        The Hamiltonian as a ``cqd.expectation.PauliSum`` (via ``PauliSum.from_pennylane``).

    Raises:
        ValueError: If ``L < 1`` (otherwise a term on wire ``-1`` would be built).
    """
    if L < 1:
        raise ValueError(f"L must be a positive number of sites, got {L}")

    # Terms are added interleaved (X_i, then Z_i Z_{i+1}) exactly as before; the
    # term order may be reused downstream (e.g. for Trotterisation) -- TODO confirm
    # before reordering.
    ham = 0
    for i in range(L - 1):
        ham -= h * qml.PauliX(i)
        ham -= J * qml.PauliZ(i) @ qml.PauliZ(i + 1)

    ham -= h * qml.PauliX(L - 1)
    # A single site has no neighbour: skip the Z_0 Z_0 self-coupling.
    if periodic and L > 1:
        ham -= J * qml.PauliZ(L - 1) @ qml.PauliZ(0)

    return PauliSum.from_pennylane(ham)

We now define the model. Note that `LogAmplitudes` is the required format to run the TDVP later. The `model` variable, however, can be replaced by an arbitrary Flax module!

In [None]:
# System: size and couplings of the open-boundary TFIM chain.
n_spins = 10
J = 2.0
h = 1.0
h_tot = tfim_hamiltonian(n_spins, J=J, h=h, periodic=False)

# Classical ansatz -- any flax module can be swapped in here ("plug and play").
model = JastrowPlusSingle()
logmodel = LogAmplitudes(model)  # wrapper format expected by the TDVP below

# Build the parameter tree from a single all-ones input of width n_spins and
# pass it through `zero_tree_like` (presumably all-zero initial parameters).
jax_seed = jax.random.PRNGKey(42)
theta0 = zero_tree_like(model.init(jax_seed, jnp.ones((1, n_spins))))

In [None]:
# Observables tracked during the evolution: ZZ correlators between site 0 and
# its nearest neighbour (site 1) / the middle of the chain (site n_spins // 2).
# NOTE(review): the `1 *` presumably promotes each PauliString to a sum type
# accepted by the callbacks -- confirm before removing it.
z_correlator, z_long_correlator = (
    1 * PauliString(qml.Z(0) @ qml.Z(partner), n_spins)
    for partner in (1, n_spins // 2)
)

observables = [z_correlator, z_long_correlator]

In [None]:
# Define hyperparameters
dt = 0.005  # integration time step (also used by the Euler integrator below)
trotter_step = 0.25  # Trotter step size passed to HybridTDVP
shots = None  # Set shots == 2000 to obtain the shot-based results of the paper
trotter_order = 2
rcond = 1e-5  # presumably relative cutoff of the TDVP linear solve -- TODO confirm
acond = 1e-5  # presumably absolute cutoff of the TDVP linear solve -- TODO confirm
tend = 1  # final simulation time
trotter_correct = True  # passed as `correct_trotter` to HybridTDVP

# `psi0` (the initial quantum state used by HybridTDVP and CQDCallback) is not
# defined anywhere in this notebook. Fail with an actionable message instead of
# a bare NameError when run on a fresh kernel.
if "psi0" not in globals():
    raise NameError(
        "psi0 is not defined: create the initial quantum state expected by "
        "HybridTDVP(init_quantum_state=...) and CQDCallback before running this cell."
    )

# Define the TDVP object
tdvp = HybridTDVP(
    h_tot,
    logmodel,
    theta0,
    init_quantum_state=psi0,
    h_tilde=h_tot,
    trotter_step=trotter_step,
    shots=shots,
    trotter_order=trotter_order,
    dt=dt,
    rcond=rcond,
    acond=acond,
    correct_trotter=trotter_correct,
    integrator=nkex.dynamics.Euler(dt),
)

# Collect results using a callback
callback = CQDCallback(h_tot, psi0, observables)

# Run the algorithm
theta1 = tdvp.run(tend, callback=callback)

To benchmark, we run a purely classical simulation with NetKet.

In [None]:
ham_nk = h_tot.to_netket()

# The benchmark has its own shot setting so it does not overwrite the `shots`
# hyperparameter of the hybrid TDVP cell. None -> exact full summation;
# an integer -> Monte Carlo sampling with that many samples.
nk_shots = None
if nk_shots is None:
    vqs = nk.vqs.FullSumState(ham_nk.hilbert, model)
else:
    sampler = nk.sampler.MetropolisLocal(ham_nk.hilbert)
    vqs = nk.vqs.MCState(sampler, model, n_samples=nk_shots, sampler_seed=42)
# Start from zeroed parameters, as done for `theta0` of the hybrid ansatz.
vqs.parameters = zero_tree_like(vqs.parameters)
integrator = nkex.dynamics.RK45(dt)
tdvp_nk = nkex.TDVP(
    ham_nk,
    vqs,
    integrator,
)


def _make_exact_evolution(ham_dense, psi_init):
    """Return ``t -> exp(-i H t) @ psi_init``, computed by exact diagonalisation.

    Args:
        ham_dense: Dense Hermitian Hamiltonian matrix.
        psi_init: Initial state vector in the same basis as ``ham_dense``.

    Returns:
        A function of time ``t`` returning the evolved state vector.
    """
    energies, eigvecs = np.linalg.eigh(ham_dense)
    coeffs = eigvecs.conj().T @ psi_init  # initial state in the eigenbasis

    def evolve(t):
        return eigvecs @ (np.exp(-1j * energies * float(t)) * coeffs)

    return evolve


# Exact reference trajectory, started from the NetKet initial state so the
# benchmark fidelity is 1 at t = 0 by construction. Real-time convention
# i d/dt psi = H psi (NetKet TDVP default). Dense ED scales as 2**n_spins,
# which is fine for the system sizes used here.
exact_state = _make_exact_evolution(
    np.asarray(ham_nk.to_dense()), np.asarray(vqs.to_array())
)

fids_nk = []
times_nk = []
expvals_nk = []
obs_netket = [obs.to_netket() for obs in observables]
states_hist_nk = []


def nk_callback(_, logdata, tdvp_nk: nkex.TDVP):
    """Record fidelity vs. the exact state, observables and states at each step.

    Returns True so that the NetKet driver keeps running.
    """
    state = tdvp_nk.state.to_array()
    exact = exact_state(tdvp_nk.t)
    fid = np.abs(np.vdot(exact, state)) ** 2
    fids_nk.append(fid)
    times_nk.append(tdvp_nk.t)
    evals = []
    for obs in obs_netket:
        evals.append(tdvp_nk.state.expect(obs).mean.real)
    expvals_nk.append(evals)
    states_hist_nk.append([state, exact])
    return True


tdvp_nk.run(tend, callback=nk_callback)

In [None]:
import matplotlib.pyplot as plt
import json

times = np.asarray(callback.times)
fids = np.asarray(callback.fidelities)[:, 0]  # CQD ansatz (statevector)
fids_q = np.asarray(callback.fidelities)[:, 1]  # Trotter circuit
times_nk = np.array(times_nk)
fids_nk = np.array(fids_nk)

# Mark the samples at multiples of `trotter_step` (plus the first and last one).
# Both sides are rounded to 3 decimals so accumulated floating-point error in the
# integrator time does not hide a grid point.
trotter_grid = np.arange(0.0, tend + 0.5 * trotter_step, trotter_step)
trot_times = np.isin(np.round(times, 3), np.round(trotter_grid, 3))
trot_times[0] = True
trot_times[-1] = True

fig, ax = plt.subplots()
ax.plot(times, fids, label="CQD ansatz (statevector)", color=colors[0])
ax.scatter(times[trot_times], fids[trot_times], color=colors[0])
ax.plot(times, fids_q, label="Trotter circuit", color=colors[2])
ax.scatter(times[trot_times], fids_q[trot_times], color=colors[2])
ax.plot(times_nk, fids_nk, label="Jastrow", color=colors[1])

ax.set(xlabel="Time", ylabel="Fidelity", title="Fidelity over time")
ax.legend()
plt.show()

In [None]:
# Index layout of callback.expectation_values[t, k, i] as plotted below:
# k = 0 Trotter circuit + Jastrow, k = 1 exact, k = 2 Trotter circuit only.
# NOTE(review): layout taken from the plot labels -- confirm against CQDCallback.
expvals = np.asarray(callback.expectation_values)
expvals_nk_arr = np.asarray(expvals_nk)  # shape (n_times_nk, n_observables)
n_obs = len(observables)

# One panel per observable; squeeze=False keeps `axs` indexable for n_obs == 1.
fig, axs = plt.subplots(n_obs, figsize=(10, 6), sharex=True, squeeze=False)
axs = axs[:, 0]
for i, (ax, obs) in enumerate(zip(axs, observables)):
    ax.plot(times, expvals[:, 1, i], label="Exact", color="black", linestyle="--")
    ax.plot(times, expvals[:, 0, i], label="Trotter circuit + Jastrow", color=colors[0])
    ax.plot(times, expvals[:, 2, i], label="Trotter circuit", color=colors[2])
    ax.plot(times_nk, expvals_nk_arr[:, i], label="Jastrow", color=colors[1])

    # Markers at the Trotter times; unlabeled so the legend lists each method once.
    ax.scatter(times[trot_times], expvals[trot_times, 0, i], color=colors[0])
    ax.scatter(times[trot_times], expvals[trot_times, 2, i], color=colors[2])
    # NOTE(review): with text.usetex=True the observable's str() must be valid LaTeX.
    ax.set_ylabel(f"{obs}")

axs[-1].set_xlabel("Time")
axs[0].legend(fontsize="small")
plt.show()

Copyright 2025 Gian Gentinetta - All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.