# Time-correlated ZNE

## Usage

1. Run from Jupyter notebook / Colab.
2. Run with Papermill.
  * `papermill run.ipynb out.ipynb -p param1 value1 -p param2 value2 ...`
  * *Tip*: Do `papermill run.ipynb --help-notebook` to see all parameters.

## Setup

### Installs

In [None]:
"""TODO: In order for this to work, 

TODO: Pin all installs for reproducibility.
"""

# !pip install git+https://github.com/mezze-team/mezze --quiet
# !pip install mitiq==0.10.0 --quiet
# !pip install tensorflow==2.3.0

### Imports

In [None]:
import numpy as np
import scipy.signal as si
import matplotlib.pyplot as plt

import cirq

import mezze.random.SchWARMA as sch
from mezze.tfq import (
    SimpleDephasingSchWARMAFier,
    TensorFlowSchWARMASim,
    get_filter_fcns_only,
    compute_A_BBs,
)
from mitiq import zne

import mezze.channel as ch
from functools import reduce

### Helper functions (to be moved into an importable module)

In [None]:
# @title TODO: Put these in a .py file somewhere.
from typing import Tuple


def make_stretched_lowpass_noise(
    scale_factor: float,
    base_power: float = 0.01,
    cutoff: float = 0.1 * np.pi,
) -> SimpleDephasingSchWARMAFier:
    """Build a low-pass dephasing SchWARMA model emulating a stretched pulse.

    A 1000-tap ``scipy.signal.firwin`` FIR filter is designed with its cutoff
    multiplied by ``scale_factor`` (this simulates pulse stretching). The taps
    are divided by the norm of the *unstretched* filter, so the normalisation
    does not depend on ``scale_factor``. They are then rescaled to
    ``base_power`` and finally multiplied by ``sqrt(scale_factor)``.

    Args:
        scale_factor: Noise scale factor (pulse-stretching factor).
        base_power: Power of the unstretched noise.
        cutoff: Cutoff passed to ``firwin`` before stretching. ``firwin``
            interprets it relative to the Nyquist frequency, so
            ``cutoff * scale_factor`` must lie strictly between 0 and 1.

    Returns:
        A ``SimpleDephasingSchWARMAFier`` with moving-average taps only.
    """
    num_taps = 1000

    # Reference norm from the filter at scale_factor == 1.
    reference_norm = np.linalg.norm(si.firwin(num_taps, cutoff))

    # Stretched filter: the cutoff grows with the scale factor.
    taps = si.firwin(num_taps, scale_factor * cutoff)
    taps = taps / reference_norm * np.sqrt(base_power)
    taps = taps * np.sqrt(scale_factor)

    return SimpleDephasingSchWARMAFier(b=taps)


def make_stretched_white_noise(
    scale_factor: float,
    base_power: float = 0.01,
) -> SimpleDephasingSchWARMAFier:
    """Build a white dephasing SchWARMA model for a given noise scale factor.

    The model has a single moving-average tap equal to
    ``sqrt(base_power) * sqrt(scale_factor)``. A one-tap filter has no
    frequency dependence, so only the amplitude is scaled.

    Args:
        scale_factor: Noise scale factor.
        base_power: Power of the unscaled noise.

    Returns:
        A ``SimpleDephasingSchWARMAFier`` with that single tap.
    """
    unit_tap = np.array([1])
    normalized = unit_tap / np.linalg.norm(unit_tap)
    amplitude_scaled = normalized * np.sqrt(base_power) * np.sqrt(scale_factor)
    return SimpleDephasingSchWARMAFier(b=amplitude_scaled)


def _get_pink_arma(
    alpha, power=None, wl=0.001 * np.pi, wh=0.999 * np.pi
) -> Tuple[np.ndarray, np.ndarray]:
    """Implementes the approach from

         S. Plaszczynski, Fluctuation and Noise Letters 7, R1 (2007)

    Args:
        alpha: Noise exponent. (Float in (0,2].)
        power: Power of the noise.
        wl: Normalized frequency cutoff for white band at the start.
        wh: Normalized frequency cutoff for 1/f^2 band at the end.

    Returns:
        bb, aa as np.array's of ARMA coefficients (in si.filter form).
    """
    Nf = np.ceil(2.5 * (np.log10(wh) - np.log10(wl)))
    delp = (np.log10(wh) - np.log10(wl)) / Nf
    logps = np.log10(wl) + 0.5 * (1 - alpha / 2.0) * delp + np.arange(Nf) * delp
    logzs = logps + alpha / 2.0 * delp
    ps = 10 ** (logps)
    zs = 10 ** (logzs)

    pstx = (1 - ps) / (1 + ps)
    zstx = (1 - zs) / (1 + zs)
    bb, aa = si.zpk2tf(zstx, pstx, k=1e-4)

    if power is not None:
        w_pa, h_pa = si.freqz(bb, aa, worN=2048 * 8, whole=True)
        acv = np.fft.ifft(np.abs(h_pa) ** 2)
        bb = bb / np.sqrt(acv[0]) * np.sqrt(power)

    return bb, aa


def make_stretched_pink_noise(
    scale_factor: float, alpha: float, base_power: float = 0.01,
) -> SimpleDephasingSchWARMAFier:
    """Build a 1/f^alpha dephasing SchWARMA model emulating a stretched pulse.

    The ARMA filter from ``_get_pink_arma`` is rebuilt with its band edges
    (``wl``, ``wh``) multiplied by ``scale_factor``. It is normalised by the
    RMS gain of the *unstretched* filter, so the normalisation does not depend
    on ``scale_factor``. It is then rescaled to ``base_power`` and multiplied
    by ``sqrt(scale_factor ** alpha)`` and ``sqrt(scale_factor)``.

    Args:
        scale_factor: Noise scale factor (pulse-stretching factor).
        alpha: Noise exponent of the 1/f^alpha spectrum.
        base_power: Power of the unstretched noise.

    Returns:
        A ``SimpleDephasingSchWARMAFier`` with real ARMA coefficients.
    """
    # Scale-factor-independent normalisation: RMS gain of the unstretched
    # filter. mean(|H|^2) equals ifft(|H|^2)[0] but stays real; the ifft
    # value is complex and would make `b` complex.
    bb, aa = _get_pink_arma(alpha, power=None, wl=0.001 * np.pi, wh=0.999 * np.pi)
    _, h_pa = si.freqz(bb, aa, worN=2048 * 8, whole=True)
    norm = np.sqrt(np.mean(np.abs(h_pa) ** 2))

    # ARMA coefficients with stretched wl and wh.
    b, a = _get_pink_arma(
        alpha,
        power=None,
        wl=0.001 * np.pi * scale_factor,
        wh=0.999 * np.pi * scale_factor,
    )

    # Normalize to base_power.
    b = b / norm * np.sqrt(base_power)

    # Stretch frequency assuming 1/f^alpha noise.
    b = b * np.sqrt(scale_factor ** alpha)

    # Scale noise amplitude.
    b = b * np.sqrt(scale_factor)

    return SimpleDephasingSchWARMAFier(b=b, a=a)

def scale_noise_trotter(circ: cirq.Circuit, scale_factor: float) -> cirq.Circuit:
    """Scale noise by splitting every gate into `scale_factor` fractional powers.

    Each operation ``G`` of ``circ`` is replaced by ``n = scale_factor``
    consecutive copies of ``G ** (1 / n)``. The resulting operations are
    re-inserted with ``cirq.Circuit``'s default insertion strategy, so the
    original moment structure is not preserved.

    Args:
        circ: Circuit to scale.
        scale_factor: Positive integer. An integer-valued float such as
            ``3.0`` is accepted.

    Returns:
        The scaled circuit.

    Raises:
        ValueError: If ``scale_factor`` is not a positive integer.
    """
    if not np.isclose(scale_factor, round(scale_factor)):
        raise ValueError(
            f"Arg `scale_factor` must be an integer but was {scale_factor}."
        )
    # Convert to a real int: `range` rejects floats such as 3.0, which the
    # check above accepts.
    num_pieces = int(round(scale_factor))
    if num_pieces < 1:
        raise ValueError(
            f"Arg `scale_factor` must be a positive integer but was {scale_factor}."
        )

    exponent = 1 / num_pieces
    operations = [
        operation ** exponent
        for moment in circ
        for operation in moment
        for _ in range(num_pieces)
    ]
    return cirq.Circuit(operations)

### Settings

In [None]:
# Global matplotlib style for all figures (serif font, size 14). usetex=False
# renders math with matplotlib's built-in mathtext instead of external LaTeX.
plt.rcParams.update(
    {
        "text.usetex": False,
        "font.family": "serif",
        "font.size": 14,
    }
)
# IPython magic (not plain Python): render figures inline in the notebook.
%matplotlib inline

### Runtime parameters

In [None]:
"""Parameters and default values. 

Note: Papermill inserts a new cell after this one if new parameter values are provided.
"""
# Benchmark circuit parameters.
circuit_type: str = "rb"
depth: int = 2
nqubits: int = 2

# Noise parameters.
noise_type: str = "pink"
base_power: float = 0.01
cutoff_as_fraction_of_pi: float = 0.01  # Lowpass cutoff for lowpass noise. Note: the actual cutoff is 10 times larger.
alpha: float = 2.0  # The α in 1 / f^α noise.
# Scaling parameters.
local_fold_key: str = "random"
max_scale_factor: int = 9

# Option to save data. TODO: Explain save format here / somewhere else.
save: bool = False
    
# Option to save figures.
savefigs: bool = False

# Other miscellaneous parameters.
num_monte_carlo: int = 3000
verbosity: int = 1

In [None]:
# Join all run settings into one underscore-separated tag (used in figure file names).
setting_str = "_".join(
    str(value)
    for value in (
        circuit_type,
        nqubits,
        depth,
        noise_type,
        base_power,
        cutoff_as_fraction_of_pi,
        alpha,
        local_fold_key,
        max_scale_factor,
        num_monte_carlo,
    )
)
setting_str

## Check/parse runtime parameters

In [None]:
# @title Parse benchmark circuit type.
# Builds the benchmark `circuit` for `circuit_type`. For mirror circuits it also
# defines `correct_bitstring`, which is read later when computing E(λ).
valid_circuit_types = ("rb", "qaoa", "mirror")

if circuit_type == "rb":
    from mitiq.benchmarks import generate_rb_circuits

    # A single (trials=1) randomized-benchmarking circuit; `depth` is passed as
    # the second argument (the number of Cliffords in mitiq's API -- confirm).
    (circuit,) = generate_rb_circuits(nqubits, depth, 1)
elif circuit_type == "mirror":
    import networkx as nx
    from mitiq.benchmarks import generate_mirror_circuit
    
    # `depth` mirror layers on an all-to-all connectivity graph, with two-qubit
    # gate probability 1.0. `correct_bitstring` is presumably the noiseless
    # output bitstring; it is used as the target state below.
    circuit, correct_bitstring = generate_mirror_circuit(
        nlayers=depth,
        two_qubit_gate_prob=1.0,
        connectivity_graph=nx.complete_graph(nqubits),
    )
elif circuit_type == "qaoa":
    qreg = cirq.LineQubit.range(nqubits)
    # Set random qaoa parameters
    # NOTE(review): np.random is not seeded, so QAOA circuits differ between runs.
    alphas = np.random.rand(depth)
    betas = np.random.rand(depth)
    # `depth` layers. Each layer applies ZZ**alpha on pairs (i, i+1) with even
    # i, then on pairs with odd i, followed by a moment of X**beta on every
    # qubit. (`alpha` here is the generator's loop variable, not the global
    # noise exponent.)
    circuit = (
        cirq.Circuit(
            [[
                (cirq.ZZ ** alpha).on(qreg[i + shift], qreg[i + shift + 1])
                for shift in (0, 1)
                for i in range(0, nqubits - shift - 1, 2)
            ],
            cirq.Moment([(cirq.X ** beta).on_each(*qreg)])]
            for alpha, beta in zip(alphas, betas)
        )
    )
    # Append inverse circuit such that tr{rho_ideal |0..0><0..|}=1
    circuit += cirq.inverse(circuit)
else:
    raise ValueError(
        f"Value for `circuit_type` ('{circuit_type}') is invalid. Valid options are {valid_circuit_types}."
    )

# Re-map qubits LineQubit(x) -> GridQubit(x, x). This assumes every qubit has an
# `.x` attribute (LineQubits). NOTE(review): presumably the mezze/TFQ simulator
# requires GridQubits -- confirm.
circuit = circuit.transform_qubits(lambda q: cirq.GridQubit(q.x, q.x))

print("Benchmark circuit:")
print(circuit)

In [None]:
# @title Parse noise type.
valid_noise_types = ("lowpass", "white", "pink")

# Define `make_stretched_noise(scale_factor)`, which builds the (approximate)
# pulse-stretched noise model for the chosen `noise_type`. The other runtime
# parameters (base_power, cutoff, alpha) are read when it is called.
if noise_type in ("lowpass", "low_pass"):

    def make_stretched_noise(scale_factor):
        return make_stretched_lowpass_noise(
            scale_factor,
            base_power=base_power,
            cutoff=cutoff_as_fraction_of_pi * np.pi,
        )

elif noise_type == "white":

    def make_stretched_noise(scale_factor):
        return make_stretched_white_noise(scale_factor, base_power=base_power)

elif noise_type == "pink":

    def make_stretched_noise(scale_factor):
        return make_stretched_pink_noise(
            scale_factor, alpha=alpha, base_power=base_power
        )

else:
    raise ValueError(
        f"Value for `noise_type` ('{noise_type}') is invalid. Valid options are {valid_noise_types}."
    )

In [None]:
# @title Parse scaling options.
local_fold_methods = {
    "random": zne.scaling.fold_gates_at_random,
    "left": zne.scaling.fold_gates_from_left,
    "right": zne.scaling.fold_gates_from_right,
}
local_fold_method = local_fold_methods.get(local_fold_key)

# Fail here, like the other parse cells. Otherwise an invalid key only fails
# later with "'NoneType' object is not callable" when circuits are scaled.
if local_fold_method is None:
    raise ValueError(
        f"Value for `local_fold_key` ('{local_fold_key}') is invalid. Valid options are {tuple(local_fold_methods)}."
    )

# Odd integers 1, 3, ..., <= max_scale_factor. Integer factors are required by
# `scale_noise_trotter`.
scale_factors = tuple(range(1, max_scale_factor + 1, 2))

if verbosity >= 1:
    print("Using scale factors:", scale_factors)

## Visualize all noise spectra
Independently of the chosen `noise_type`, we first plot all of the noise spectra considered here.

In [None]:
# Set plot style
plt.rcParams.update({"font.family": "serif", "font.size": 14, "text.usetex": False})
axis_title_size = 16
tick_label_size = 15

# @title Plotting code.
plt.figure(figsize=(7,5))

# (noise model at scale factor 1, legend label), plotted in this order.
spectra_to_plot = [
    (make_stretched_white_noise(scale_factor=1, base_power=base_power), "white"),
    (
        make_stretched_lowpass_noise(
            scale_factor=1, base_power=base_power, cutoff=cutoff_as_fraction_of_pi * np.pi
        ),
        "low-pass",
    ),
    (make_stretched_pink_noise(scale_factor=1, alpha=1, base_power=base_power), "$1/ f$"),
    (make_stretched_pink_noise(scale_factor=1, alpha=2, base_power=base_power), "$1/ f^2$"),
]
for S, spectrum_label in spectra_to_plot:
    w, p = S.psd()
    plt.plot(w, p, lw=3, alpha=0.85, label=spectrum_label)

# Log-log axes; the x-range ends at the last plotted frequency grid's maximum.
plt.ylim(10 ** -4, 10)
plt.xlim(10 ** -3, max(w))
plt.yscale("log")
plt.xscale("log")
plt.xticks(size=tick_label_size)
plt.yticks(size=tick_label_size)
plt.xlabel(r"Normalized frequency $\omega$", size=axis_title_size)
plt.ylabel(r"Noise spectrum $S(\omega)$", size=axis_title_size)
plt.legend()

if savefigs:
    plt.savefig("all_noise_spectra_" + setting_str + ".pdf")
plt.show()

## Visualize noise-scaled spectrum via pulse-stretching

In [None]:
# Set plot style
plt.rcParams.update({"font.family": "serif", "font.size": 14, "text.usetex": False})

# @title Plotting code.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 5))

# One PSD curve per scale factor of the pulse-stretched noise model.
for scale_factor in scale_factors:
    S = make_stretched_noise(scale_factor=scale_factor)
    freqs, power_density = S.psd()
    ax.plot(freqs, power_density, label=rf"$\lambda = {scale_factor}$", lw=3, alpha=0.85)
w, p = freqs, power_density

ax.set(
    title="Dephasing noise spectrum",
    xlabel="Normalized frequency",
    ylabel="Noise spectrum",
)
if noise_type == "pink":
    ax.set_yscale("log")
ax.legend()
plt.tight_layout()

# Saving/showing is disabled to avoid producing too many figures:
# if savefigs:
#     plt.savefig(f"spectrum_" + setting_str + ".pdf")
# plt.show()

## Scale noise with all methods

In [None]:
# Name -> function that scales `circuit` by a given scale factor.
all_scaling_methods = {
    "Global": zne.scaling.fold_global,
    "Local": local_fold_method,
    "Trotter": scale_noise_trotter,
}

# all_scaled_circuits[i][j]: `circuit` scaled by the i-th method at scale_factors[j].
all_scaled_circuits = [
    [scale(circuit, factor) for factor in scale_factors]
    for scale in all_scaling_methods.values()
]

## Visualize frequency response of scaled circuits

In [None]:
num_freqs = 8192 * 16  # Number of frequency samples (worN) for the filter functions.


def get_biggest_ff(circuit, num_freqs):
    """Return the dominant (largest-norm) dephasing filter function of `circuit`.

    The filter functions returned by `get_filter_fcns_only` are grouped by the
    operator `A_BB` that `compute_A_BBs` assigns for the observable
    O = Z ⊗ ... ⊗ Z. Filter functions whose operators are equal are summed;
    those whose operators are exact negatives of each other are subtracted.

    Args:
        circuit: Circuit whose filter functions are computed.
        num_freqs: Number of frequency samples (`worN`).

    Returns:
        The real part of the grouped filter function with the largest 2-norm.
    """
    paulis = ch.PauliBasis(nqubits).basis_list
    O = reduce(np.kron, [ch._sigmaZ] * nqubits)
    Oinv = O.conj().T
    _, A_BBs = compute_A_BBs(paulis, O, Oinv)

    FFs = get_filter_fcns_only(
        circuit, SimpleDephasingSchWARMAFier(b=.1), worN=num_freqs
    )

    ff_tot = []  # Grouped (summed) filter functions.
    ops = []  # The A_BB operator of each group, parallel to ff_tot.
    keys = list(FFs.keys())
    # NOTE(review): assumes the keys come in `nqubits` consecutive blocks whose
    # positions line up with A_BBs -- confirm against mezze.tfq.
    block_len = len(keys) // nqubits

    for i, key in enumerate(keys):
        A_BB = A_BBs[np.mod(i, block_len)]
        ff = FFs[key]
        if ops:
            diff = [np.linalg.norm(op - A_BB) for op in ops]
            ndiff = [np.linalg.norm(op + A_BB) for op in ops]
            if np.min(diff) == 0:
                ff_tot[np.argmin(diff)] += ff
                continue
            if np.min(ndiff) == 0:
                ff_tot[np.argmin(ndiff)] -= ff
                continue
        # First operator seen, or one matching no existing group: new group.
        ff_tot.append(ff.copy())
        ops.append(A_BB)

    # Per-group norms. `axis=1` is essential: without it np.linalg.norm reduces
    # the whole stack to a single scalar and argmax would always be 0.
    norms = np.linalg.norm(np.asarray(ff_tot), axis=1)
    return np.real(ff_tot[np.argmax(norms)])

def compute_gen_dephasing_FF(circuit, num_freqs):
    """Return the element-wise sum of the real parts of all filter functions.

    Args:
        circuit: Circuit whose filter functions are computed.
        num_freqs: Number of frequency samples (`worN`).
    """
    filter_functions = get_filter_fcns_only(
        circuit, SimpleDephasingSchWARMAFier(b=.1), worN=num_freqs
    )
    real_parts = np.real(list(filter_functions.values()))
    return np.sum(real_parts, axis=0)


# Dominant filter function of every scaled circuit, nested like
# `all_scaled_circuits` (one list per scaling method).
# Alternative: use `compute_gen_dephasing_FF` instead of `get_biggest_ff`.
FFs = [
    [get_biggest_ff(scaled, num_freqs) for scaled in scaled_circuits]
    for scaled_circuits in all_scaled_circuits
]

In [None]:
# Set plot style (a single update is enough; it was previously applied twice).
plt.rcParams.update({"font.family": "serif", "font.size": 14, "text.usetex": False})
# @title Plotting code.
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(17, 5), sharex=True, sharey=True)

# Frequency grid: num_freqs points on [0, 2*pi). Only the first half,
# [0, pi), is plotted. Each curve is normalised to its own maximum.
w = np.linspace(0, 2 * np.pi, num_freqs + 1)[:-1]
half = num_freqs // 2
method_names = tuple(all_scaling_methods)

for i, FF in enumerate(FFs):
    panel = axes[i]
    for filter_function, scale_factor in zip(FF, scale_factors):
        positive = filter_function[:half]
        panel.plot(
            w[:half],
            positive / np.max(positive),
            label=rf"$\lambda = {scale_factor}$",
            alpha=0.85,
        )
        panel.set_title(method_names[i])

for ax in axes:
    ax.set_xlabel("Frequency")
    ax.legend()
axes[0].set_ylabel("Filter function")
plt.tight_layout()

if savefigs:
    plt.savefig("response_" + setting_str + ".pdf")

## Compute $E(\lambda)$ for all scaling methods

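Note: We take $E(\lambda) = \text{Tr} \left[ \rho(\lambda) \Pi_0 \right]$, where $\Pi_0$ is the projector onto the ideal output state: the ground state $|0 \cdots 0\rangle$ for RB and QAOA circuits, and the expected (correct) bitstring for mirror circuits.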

In [None]:
def compute_survival_probability(
    circuit: cirq.Circuit, noise_model: SimpleDephasingSchWARMAFier
) -> float:
    """Simulate `circuit` under `noise_model` and return one diagonal entry of rho.

    The density matrix comes from ``TensorFlowSchWARMASim(...).dm_sim`` called
    with the global `num_monte_carlo`. The returned entry is the population of
    basis state 0 (|0...0>). For mirror circuits it is instead the population
    of the basis state whose index is `correct_bitstring` read as a binary
    number.
    """
    density_matrix = TensorFlowSchWARMASim(circuit, noise_model).dm_sim(num_monte_carlo)
    target = (
        int("".join(str(bit) for bit in correct_bitstring), 2)
        if circuit_type == "mirror"
        else 0
    )
    return np.real(density_matrix[target, target])

In [None]:
# Noise model at base noise level.
S = make_stretched_noise(scale_factor=1.0)


def _amplitude_scaled(noise, factor):
    """Return `noise` with its MA coefficients multiplied by sqrt(factor)."""
    return SimpleDephasingSchWARMAFier(b=noise.b * np.sqrt(factor), a=noise.a)


# Row 0: 'ground truth' E(𝜆), i.e. the unscaled circuit under truly scaled noise.
true_expectation_values = [
    compute_survival_probability(circuit, _amplitude_scaled(S, scale_factor))
    for scale_factor in scale_factors
]

# Row 1: E(𝜆) for (an approximation to) pulse stretching.
pulse_stretching_expectation_values = [
    compute_survival_probability(circuit, make_stretched_noise(scale_factor))
    for scale_factor in scale_factors
]

# Remaining rows: each scaled circuit under the base noise S (approximate
# scaled noise), one row per scaling method.
method_expectation_values = [
    [compute_survival_probability(scaled_circuit, S) for scaled_circuit in scaled_circuits]
    for scaled_circuits in all_scaled_circuits
]

all_expectation_values = [
    true_expectation_values,
    pulse_stretching_expectation_values,
] + method_expectation_values

In [None]:
# Set plot style (text.usetex=True requires a working LaTeX installation).
plt.rcParams.update({"font.family": "serif", "font.size": 14, "text.usetex": True})
axis_title_size = 18
tick_label_size = 15

# @title Plotting code.
plt.figure(figsize=(7,5))

# Colours: grey for "True", then the default colour cycle rotated left by one
# so the colours match those in the paper.
cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
colors = cycle_colors[:len(all_expectation_values) - 1]
colors.append(colors.pop(0))
colors.insert(0, "grey")

labels = ["True", "Pulse"] + list(all_scaling_methods.keys())
for evals, label, color in zip(all_expectation_values, labels, colors):
    plt.plot(
        scale_factors,
        evals,
        "-s",
        lw=2,
        alpha=0.85,
        ms=10,
        mec="black",
        label=label,
        color=color,
    )

plt.xticks(size=tick_label_size)
plt.yticks(size=tick_label_size)
plt.xlabel(r"Noise scale factor $\lambda$", size=axis_title_size)
plt.ylabel(r"Expectation value $E(\lambda)$", size=axis_title_size)
plt.legend()

if savefigs:
    plt.savefig("noise_scaling_comparison_" + setting_str + ".pdf")
plt.show()

## Quantify errors for all scaling methods

In [None]:
plt.rcParams.update({"font.family": "serif", "font.size": 14, "text.usetex": False})
axis_title_size = 18
tick_label_size = 15

# @title Plotting code.
plt.figure(figsize=(7, 4))

# Default colour cycle rotated left by one (to match the colours in the paper);
# there is no grey "True" entry because the true curve is the reference.
colors = plt.rcParams['axes.prop_cycle'].by_key()['color'][:len(all_expectation_values) - 1]
colors.append(colors.pop(0))

# Curve labels, defined here so this cell does not depend on `labels` from
# another plotting cell.
labels = ["True", "Pulse"] + list(all_scaling_methods.keys())

# Relative error of each approximate E(λ) curve w.r.t. the 'true' curve.
# NOTE: entries of `true` equal to zero produce inf/nan (with a warning).
true = np.array(all_expectation_values[0])
for evals, label, color in zip(all_expectation_values[1:], labels[1:], colors):
    error = np.abs((np.array(evals) - true) / true)
    plt.plot(
        scale_factors, error, "-s", lw=2, alpha=0.85, ms=10, mec="black", label=label, color=color,
    )
scale_factors_strings = [str(x) for x in scale_factors]
plt.xticks(scale_factors, scale_factors_strings, size=tick_label_size)
plt.yticks(size=tick_label_size)
plt.xlabel(r"$\lambda$", size=axis_title_size)
plt.ylabel(r"$\left| \frac{ E(\lambda) - E^*(\lambda) }{ E^*(\lambda) } \right|$", size=axis_title_size)
plt.legend()

if savefigs:
    plt.savefig(f"errors_" + setting_str + ".pdf")
plt.show()

## Extrapolation to the zero-noise limit

In [None]:
zne_limits = []
zne_curves = []

# Compute zero-noise extrapolations: one exponential fit (mitiq ExpFactory) per
# E(λ) curve, with the asymptote fixed at 1 / 2**nqubits.
for expectation_values in all_expectation_values:
    results = zne.inference.ExpFactory.extrapolate(
        scale_factors=scale_factors,
        exp_values=expectation_values,
        asymptote=1 / (2 ** nqubits),
        full_output=True,
    )
    # First entry: zero-noise limit; last entry: the fitted curve (a callable).
    zne_limits.append(results[0])
    zne_curves.append(results[-1])


# @title Plotting code.
plt.figure(figsize=(8, 6))
# NOTE: text.usetex=True requires a working LaTeX installation.
plt.rcParams.update({"font.family": "serif", "font.size": 14, "text.usetex": True})
axis_title_size = 20
tick_label_size = 18
legend_font_size = 11

labels = ["True", "Pulse"] + list(all_scaling_methods.keys())
zne_labels = [label + " (ZNE)" for label in labels]
points_labels = [label + " (points)" for label in labels]

colors = plt.rcParams['axes.prop_cycle'].by_key()['color'][:len(all_expectation_values) - 1]
# Change order to match colors in paper
colors.append(colors.pop(0))
colors = ["grey"] + colors

# Dense grid (step 0.1) for the fitted curves, computed once. np.linspace
# includes the endpoint, so the curves reach the largest scale factor.
max_scale = max(scale_factors)
x = np.linspace(0, max_scale, int(round(10 * max_scale)) + 1)

for evals, point_label, zne_limit, zne_curve, color, zne_label in zip(
    all_expectation_values, points_labels, zne_limits, zne_curves, colors, zne_labels
):
    # Plot extrapolation curves
    plt.plot(x, zne_curve(x), "--", color=color, lw=2.3, alpha=0.7)
    # Plot noise scaled expectation values (points)
    plt.plot(
        scale_factors, evals, "s", lw=2, alpha=0.8, ms=10, mec="black", label=point_label, color=color,
    )
    # Plot zero noise limits
    plt.plot(
        0, zne_limit, "*", lw=2, alpha=0.8, ms=10, mec="black", label=zne_label, color=color,
    )


scale_factors_strings = [str(x) for x in scale_factors]
plt.xticks(scale_factors, scale_factors_strings, size=tick_label_size)
plt.yticks(size=tick_label_size)
plt.xlabel(r"Noise scale factor $\lambda$", size=axis_title_size)
plt.ylabel(r"Expectation value $E(\lambda)$", size=axis_title_size)
plt.legend(fontsize=legend_font_size)

if savefigs:
    plt.savefig(f"zne_comparison_" + setting_str + ".pdf")
plt.show()

## [Optional] Save output

In [None]:
# @title Saving code.
# Writes the scale factors and all E(λ) curves as text files into a directory
# named after the run settings.
if save:
    import os
    import time

    lowpass_params = f"cutoff_{cutoff_as_fraction_of_pi}"
    noise_params = {
        "lowpass": lowpass_params,
        # "low_pass" is also accepted when `noise_type` is parsed.
        "low_pass": lowpass_params,
        "pink": "alpha_%0.2f" % alpha,
    }.get(noise_type)  # None (-> "None" in the name) for white noise.
    dir_name = f"tzne_circuit_type_{circuit_type}_nqubits_{nqubits}_depth_{depth}_noise_type_{noise_type}_base_power_{base_power}_{noise_params}_max_scale_factor_{max_scale_factor}_num_monte_carlo_{num_monte_carlo}"
    # NOTE(review): asctime() contains ':' characters, which are not allowed in
    # Windows file names.
    time_key = "_".join(time.asctime().split())

    os.makedirs(dir_name, exist_ok=True)

    # Save scale factors.
    np.savetxt(
        os.path.join(dir_name, "scale_factors_" + time_key + ".txt"), scale_factors
    )

    # Save expectation values: one row per E(λ) curve ("True", "Pulse", then
    # one per scaling method) and one column per scale factor.
    np.savetxt(
        os.path.join(dir_name, "all_expectation_values_" + time_key + ".txt"),
        all_expectation_values,
    )

    print("Data saved successfully in", dir_name)