# Extending the system size

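Here we show how additional degrees of freedom — a bath of qubits coupled to the quantum register — can be incorporated into the classical part of the ansatz, and we benchmark the result against exact time evolution and a NetKet TDVP simulation.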

In [None]:
import pennylane as qml
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
import netket as nk
from cqd.expectation import PauliSum
from cqd.utils import zero_tree_like, all_z
from cqd.models import (
    TwoSubCQD,
    JastrowPlusSingle,
    QuantumPlusMeanField,
    build_neural_mean_field,
    NetketMeanField,
)
from cqd.forces import ExactForcesAndQGT
from cqd.tdvp import TwoSubTDVP
from netket.experimental.dynamics import *

In [None]:
import matplotlib.pyplot as plt
from matplotlib import cycler


plt.rcParams.update({
    "text.usetex": True, # enable latex font (requires a working LaTeX installation)
    "font.family": "Helvetica", # set font style
    "text.latex.preamble": r'\usepackage{amsmath}', # add latex packages
    "font.size": "18", # set font size
    "figure.figsize": [10, 6], # set figure size (inches)
    "lines.linewidth": 2, # set default line width for all plots
})
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

# Line styles, markers and colour palettes shared by the plots below.
# Entry 4 of ``linestyles`` is a custom dash pattern: (offset, (on, off, ...)).
linestyles = [
    "-",
    ":",
    "-.",
    "--",
    (0, (3, 1, 1, 1, 1, 1)),
    "-",
    "--",
    "-.",
    ":",
]
markers = [
    "o", "v", "s", "d", "X", "p", "*",
    "o", "v", "s", "d", "*",
]
# Six base colours, repeated twice so up to twelve series get a colour.
colors = ['#c65742', '#9ad0bb', '#e4bf44', '#87584E', '#aba18d', '#332737'] * 2
# Dark-to-light green ramp (not referenced in the cells shown here).
colors2 = [
    '#003122',
    '#225544',
    '#497c6a',
    '#71a591',
    '#9ad0bb',
    '#c5fde7',
]
# Use the ``colors`` palette as the default colour cycle for all later plots.
plt.rcParams['axes.prop_cycle'] = cycler('color', colors)

In [None]:
def active_plus_bath_ham(n_active, n_bath, J_active=1.0, J_bath=0.25, J_int=0.5):
    """Build the active-only and full Hamiltonians of an active chain in a bath.

    Qubit layout: ``[active (n_active) | left bath (n_bath // 2) |
    right bath (remaining n_bath - n_bath // 2)]``.

    * Active chain: open XX chain with strength ``J_active``.
    * Bath: two independent open XX chains (left and right halves) with
      strength ``J_bath``.
    * Coupling: ``J_int * X_0 X_{n_active}`` links the first active qubit to
      the first left-bath qubit, and ``J_int * X_{n_active-1} X_{last}`` links
      the last active qubit to the last qubit of the right bath.
    * Every qubit carries a unit-strength Z term.

    Returns:
        ``(h_tilde, h_tot)``: ``PauliSum`` of the active region alone, and of
        the full active + bath system.
    """
    first_bath = n_active
    first_right = n_active + n_bath // 2
    n_total = n_active + n_bath

    ham = 0
    for a in range(n_active - 1):
        ham += J_active * qml.X(a) @ qml.X(a + 1)
    for a in range(n_active):
        ham += qml.Z(a)
    # Snapshot the active-only Hamiltonian before any bath terms are added.
    h_tilde = PauliSum.from_pennylane(ham)

    ham += J_int * qml.X(0) @ qml.X(first_bath)
    ham += J_int * qml.X(n_active - 1) @ qml.X(n_total - 1)
    # Left bath chain first, then right bath chain (each open, no wrap-around).
    for start, stop in ((first_bath, first_right), (first_right, n_total)):
        for b in range(start, stop - 1):
            ham += J_bath * qml.X(b) @ qml.X(b + 1)
    for b in range(first_bath, n_total):
        ham += qml.Z(b)
    return h_tilde, PauliSum.from_pennylane(ham)

In [None]:
# create the Hamiltonian
# Coupling strengths: within the active chain (Jq), within the bath chains
# (Jb), and between the active chain and the bath (Ji).
Jq = 1
Jb = 0.1
Ji = 0.25
# n_classical bath qubits (passed as n_bath) and n_quantum active qubits
# (passed as n_active).
n_classical = 4
n_quantum = 4
# h_tilde acts on the n_quantum active qubits only; h_tot on all qubits.
h_tilde, h_tot = active_plus_bath_ham(n_quantum, n_classical, Jq, Jb, Ji)

# use any flax module
psi_c1 = JastrowPlusSingle()
# play around with the number of hidden units (alpha) and layers
psi_c2 = build_neural_mean_field(alphas_global=[], alphas=[2], activation=nn.tanh)
# this combines the two ansatze to create the full classical ansatz
psi_c = QuantumPlusMeanField(psi_c2, psi_c1)

# Initialise the parameters with a fixed PRNG key (reproducible runs). The two
# dummy inputs are single-configuration batches of widths n_quantum and
# n_classical. NOTE(review): argument order (quantum, then classical) inferred
# from the shapes — confirm against QuantumPlusMeanField.__call__.
params = psi_c.init(
    jax.random.PRNGKey(123), np.ones((1, n_quantum)), np.ones((1, n_classical))
)
# The CQD ansatz, including the classical exact sampling for mean-field methods
model = TwoSubCQD(psi_c, n_classical, n_quantum)

In [None]:
# exact simulation for benchmarking
# Initial state: normalised uniform superposition over all basis states of the
# n_quantum + n_classical qubits.
psi0 = np.ones(2 ** (n_quantum + n_classical), dtype=complex)
psi0 /= np.linalg.norm(psi0)
# Full diagonalisation of the dense Hamiltonian: H = V diag(D) V^dagger.
D, V = np.linalg.eigh(h_tot.to_dense())
# Coefficients of psi0 in the eigenbasis; time independent, so computed once.
psi0_eig = np.conj(V.T) @ psi0


@jax.jit
def exact_state(t):
    """Return the exactly evolved state ``exp(-i H t) |psi0>`` at time ``t``.

    The phases ``exp(-i D t)`` are applied elementwise in the eigenbasis
    rather than through a dense ``diag`` matrix, which avoids an O(d^2)
    allocation and matrix product on every call.
    """
    return V @ (jnp.exp(-1.0j * D * t) * psi0_eig)

In [None]:
# Functions to compute entanglement entropy and partial fidelity
def sqrt(mat):
    """Return the principal square root of a Hermitian positive-semidefinite matrix.

    The matrix is diagonalised with ``eigh``. Eigenvalues below ``1e-10`` are
    set to zero before taking the square root. This includes small negative
    values caused by round-off, which would otherwise produce NaNs.

    Args:
        mat: Hermitian PSD matrix (e.g. a density matrix).

    Returns:
        ``V diag(sqrt(w)) V^dagger`` as a JAX array.
    """
    evalues, evectors = jnp.linalg.eigh(mat)
    # Clamp tiny AND negative (round-off) eigenvalues; sqrt of a negative
    # eigenvalue would otherwise propagate NaN into the fidelity.
    evalues = jnp.where(evalues < 1e-10, 0.0, evalues)
    # Scaling column i of V by sqrt(w_i) equals V @ diag(sqrt(w)).
    return (evectors * jnp.sqrt(evalues)) @ jnp.conj(evectors).T


def schmidt_decompose(state, n_sites_a, tol=1e-10):
    """Schmidt-decompose a pure state across the cut after ``n_sites_a`` qubits.

    The state vector is reshaped (row-major) to a ``(2**n_sites_a, -1)``
    matrix. Its SVD gives ``state = sum_i s_i |u_i>|v_i>``. Schmidt
    coefficients not exceeding ``tol`` are discarded, together with their
    singular vectors.

    Args:
        state: State vector whose length is a multiple of ``2**n_sites_a``.
        n_sites_a: Number of qubits in subsystem A (the leading, most
            significant index block of the vector).
        tol: Threshold at or below which coefficients are treated as zero.

    Returns:
        ``(u, s, vh)`` with ``u`` of shape ``(2**n_sites_a, rank)``, ``s`` of
        shape ``(rank,)`` in descending order, and ``vh`` of shape
        ``(rank, d_B)``.
    """
    mat = np.asarray(state).reshape(2**n_sites_a, -1)
    # Thin SVD: at most min(d_A, d_B) singular vectors can ever be kept.
    u, s, vh = np.linalg.svd(mat, full_matrices=False)
    # ``s`` is sorted in descending order, so the kept values form a prefix.
    rank = int(np.count_nonzero(s > tol))
    return u[:, :rank], s[:rank], vh[:rank, :]


def entanglement_entropy(state, n_sites_a):
    """Von Neumann entanglement entropy of a pure state across an A|B cut.

    Subsystem A consists of the first ``n_sites_a`` qubits (the state is
    reshaped row-major to ``(2**n_sites_a, -1)``). The entropy is
    ``-sum_i p_i log p_i`` (natural log). Here ``p_i = s_i**2`` are the
    squared Schmidt coefficients, renormalised to sum to one. Coefficients
    ``<= 1e-10`` are ignored, so ``0 * log 0`` never occurs.
    """
    mat = np.asarray(state).reshape(2**n_sites_a, -1)
    # Only the singular values are needed; skip computing U and V^dagger.
    s = np.linalg.svd(mat, compute_uv=False)
    s = s[s > 1e-10]
    probs = s**2
    # Renormalise so an unnormalised input still gives a valid distribution.
    probs = probs / probs.sum()
    return -np.sum(probs * np.log(probs))


def partial_state(state, n_sites_a):
    """Reduced density matrix of subsystem A (first ``n_sites_a`` qubits).

    Reshaping the pure state (row-major) to ``M`` of shape
    ``(2**n_sites_a, -1)`` gives ``rho_A = Tr_B |psi><psi| = M M^dagger``.
    This is computed directly, with no SVD and no truncation.

    Returns:
        Hermitian matrix of shape ``(2**n_sites_a, 2**n_sites_a)``.
    """
    mat = np.asarray(state).reshape(2**n_sites_a, -1)
    return mat @ mat.conj().T


def partial_fidelity(state1, state2, n_sites_a):
    """Uhlmann fidelity between the reduced states of subsystem A.

    Computes ``F(rho, sigma) = (Tr sqrt(sqrt(rho) sigma sqrt(rho)))**2``,
    where ``rho`` and ``sigma`` are the reduced density matrices of
    ``state1`` and ``state2`` on the first ``n_sites_a`` qubits.

    The trace is real in exact arithmetic, so only its real part is kept.
    The result is therefore a real ``float``, not a complex scalar that
    would be silently cast when stored in arrays or plotted.
    """
    rho = partial_state(state1, n_sites_a)
    sigma = partial_state(state2, n_sites_a)
    sqrt_rho = sqrt(rho)
    return float(np.real(np.trace(sqrt(sqrt_rho @ sigma @ sqrt_rho)))) ** 2


def partial_fid_old(state1, state2, n_sites_a):
    """Overlap ``Tr(rho sigma)`` of the reduced states on subsystem A.

    This is a legacy figure of merit. It coincides with the Uhlmann fidelity
    when at least one of the reduced states is pure. The trace of a product
    of two Hermitian matrices is real, so the real part is returned as a
    ``float``.
    """
    rho = partial_state(state1, n_sites_a)
    sigma = partial_state(state2, n_sites_a)
    return float(np.real(np.trace(rho @ sigma)))

In [None]:
shots = None  # set to 500 for shot-noise like in the paper
samples = 10  # ignored if shots == None
acond = 1e-4  # NOTE(review): presumably absolute regularisation cutoff of the TDVP solve — confirm in TwoSubTDVP
rcond = 1e-4  # NOTE(review): presumably relative regularisation cutoff — confirm in TwoSubTDVP
dt = 0.001  # step size passed to the Euler integrator
trotter_step = 0.1  # forwarded to TwoSubTDVP (Trotter step of the quantum part — TODO confirm)
trotter_order = 2  # forwarded to TwoSubTDVP

# CQD time evolution: h_tot is the full Hamiltonian, h_tilde the active-only part.
tdvp = TwoSubTDVP(
    model,
    params,
    h_tot,
    h_tilde,
    shots=shots,
    samples=samples,
    trotter_step=trotter_step,
    trotter_order=trotter_order,
    integrator=Euler(dt),
    acond=acond,
    rcond=rcond,
)

# Histories appended to by ``callback`` during ``tdvp.run``.
fids = []
fids_q = []
times = []
entropy = []


def callback(t, theta, tdvp):
    """Record benchmark quantities at time ``t`` of the CQD TDVP evolution.

    Appends to the module-level histories:

    * ``fids``: full-state fidelity ``|<exact|state>|**2``;
    * ``fids_q``: partial (Uhlmann) fidelities on the first ``n_quantum``
      qubits, ``[CQD state, state built from tdvp.theta0]``;
    * ``entropy``: entanglement entropies across the ``n_quantum`` cut,
      ``[CQD state, exact state, state built from tdvp.theta0]``;
    * ``times``: the time ``t``.
    """
    state = model.full_state(theta, tdvp.phi_q)
    state = state / np.linalg.norm(state)
    # The reference state built from theta0 must be normalised the same way;
    # otherwise its partial fidelity and entropy describe an unnormalised vector.
    state_q = model.full_state(tdvp.theta0, tdvp.phi_q)
    state_q = state_q / np.linalg.norm(state_q)
    exact = exact_state(t)
    fid = np.abs(np.vdot(exact, state)) ** 2

    print(f"t={t}, Fid={fid}", end="\r")
    fids.append(fid)
    fids_q.append(
        [
            partial_fidelity(state, exact, n_quantum),
            partial_fidelity(state_q, exact, n_quantum),
        ]
    )
    entropy.append(
        [
            entanglement_entropy(state, n_quantum),
            entanglement_entropy(exact, n_quantum),
            entanglement_entropy(state_q, n_quantum),
        ]
    )
    times.append(t)


final_params = tdvp.run(1, callback=callback)

In [None]:
import functools
import netket.experimental as nkex

# NetKet reference: same Hamiltonian and classical network, evolved with
# NetKet's TDVP driver.
ham_nk = h_tot.to_netket()
model_nk = NetketMeanField(psi_c, n_quantum)


# FullSumState enumerates the whole Hilbert space instead of sampling, so
# this reference run has no sampling noise.
vqs = nk.vqs.FullSumState(ham_nk.hilbert, model_nk)


# NOTE(review): optional zero initialisation of the parameters, currently disabled.
# vqs.parameters = zero_tree_like(vqs.parameters)
integrator = nkex.dynamics.RK45(dt)  # same step size dt as the CQD run
tdvp_nk = nkex.TDVP(
    ham_nk,
    vqs,
    integrator,
    # Smoothed pseudo-inverse as the linear solver (rtol_smooth=1e-5).
    linear_solver=functools.partial(nk.optimizer.solver.pinv_smooth, rtol_smooth=1e-5),
)

# Histories filled by ``nk_callback`` during the NetKet TDVP run.
fids_nk, fids_nk_q, times_nk = [], [], []
expvals_nk, states_hist_nk, entropy_nk = [], [], []


def nk_callback(_, logdata, tdvp_nk: nkex.TDVP):
    """Log fidelities, entropy and states of the NetKet run at its current time.

    Compares the variational state of ``tdvp_nk`` with ``exact_state`` at the
    driver's time. Returns ``True`` on every call so the run is never stopped
    from here.
    """
    vqs_state = tdvp_nk.state.to_array()
    reference = exact_state(tdvp_nk.t)
    overlap = np.vdot(reference, vqs_state)
    fids_nk.append(np.abs(overlap) ** 2)
    fids_nk_q.append(partial_fidelity(vqs_state, reference, n_quantum))
    times_nk.append(tdvp_nk.t)
    entropy_nk.append(entanglement_entropy(vqs_state, n_quantum))
    states_hist_nk.append([vqs_state, reference])
    return True


tdvp_nk.run(1, callback=nk_callback)

In [None]:
# Convert the CQD histories to arrays so columns can be sliced in the plots
# below. NOTE: this rebinds the globals to ndarrays, so re-running the CQD run
# cell requires re-creating the history lists first.
fids = np.array(fids)
fids_q = np.array(fids_q)
entropy = np.array(entropy)
# Full-state fidelity |<exact|psi(t)>|^2 of both variational methods.
plt.plot(times, fids, label="Neural Mean-Field", linewidth=2)
plt.plot(times_nk, fids_nk, label="NetKet", linewidth=2)
plt.ylabel("Fidelity")
plt.xlabel("Time")
plt.legend()

In [None]:
# Partial (Uhlmann) fidelity of the reduced state on the n_quantum qubits.
# Column 1 of fids_q is the state built from tdvp.theta0 (labelled "Pure Quantum").
plt.plot(times, fids_q[:, 0], label="Neural Mean-Field")
plt.plot(times_nk, fids_nk_q, label="NetKet")
plt.plot(times, fids_q[:, 1], label="Pure Quantum", linestyle="--")
plt.ylabel("Partial Fidelity")
plt.xlabel("Time")
plt.legend()

In [None]:
# Entanglement entropy across the n_quantum | bath cut. Column 2 of ``entropy``
# (state built from tdvp.theta0) is collected by the callback but not plotted.
plt.plot(times, entropy[:, 0], label="Neural Mean-Field")
plt.plot(times_nk, entropy_nk, label="NetKet")
plt.plot(times, entropy[:, 1], label="Exact", linestyle="--", color="black")
plt.ylabel("Entanglement Entropy")
plt.xlabel("Time")
plt.legend()

Copyright 2025 Gian Gentinetta - All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.