# Plot frequency response for different noise scaling methods

### Imports

In [None]:
import numpy as np
import scipy.signal as si
from scipy.fft import fft, ifft

import matplotlib.pyplot as plt

import cirq

import mezze.random.SchWARMA as sch
from mezze.tfq import (
    SimpleDephasingSchWARMAFier,
    TensorFlowSchWARMASim,
    get_filter_fcns_only,
    compute_A_BBs
)

from mitiq import zne
import mezze.channel as ch
from functools import reduce

In [None]:
# Define "gate-Trotterization" function

def scale_noise_trotter(circ: cirq.Circuit, scale_factor: float) -> cirq.Circuit:
    """Scale noise by splitting every gate into `scale_factor` fractional pieces.

    Each operation ``op`` of ``circ`` is replaced by ``scale_factor``
    consecutive copies of ``op ** (1 / scale_factor)``, so the circuit becomes
    ``scale_factor`` times longer while (for cirq power gates) the pieces
    compose back to the original gate.

    Args:
        circ: Circuit to scale.
        scale_factor: Positive integer. An integral float such as ``3.0`` is
            accepted and converted to ``int``.

    Returns:
        A new circuit. Operations are re-inserted with ``cirq.Circuit``'s
        default insertion strategy, so the original moment boundaries are not
        necessarily preserved.

    Raises:
        ValueError: If ``scale_factor`` is not an integer, or is smaller than 1.
    """
    if not np.isclose(scale_factor, round(scale_factor)):
        raise ValueError(
            f"Arg `scale_factor` must be an integer but was {scale_factor}."
        )
    # `range` needs a real int: an integral float (e.g. 3.0) passes the check
    # above but would otherwise raise TypeError below.
    num_pieces = int(round(scale_factor))
    if num_pieces < 1:
        raise ValueError(
            f"Arg `scale_factor` must be at least 1 but was {scale_factor}."
        )

    operations = []
    for moment in circ:
        for operation in moment:
            # cirq operations are immutable, so one fractional piece can be
            # repeated `num_pieces` times.
            piece = operation ** (1 / num_pieces)
            operations.extend([piece] * num_pieces)
    return cirq.Circuit(operations)

### Settings

In [None]:
plt.rcParams.update(
    {
        "text.usetex": False,
        "font.family": "serif",
        "font.size": 14,
    }
)
%matplotlib inline

### Runtime parameters

**Note:** some of the parameters are irrelevant for this notebook but we keep them for consistency with `run.ipynb`.

In [None]:
"""Parameters and default values. 

Note: Papermill inserts a new cell after this one if new parameter values are provided.
"""
# Benchmark circuit parameters.
# circuit_type: one of "rb", "qaoa", "mirror" (validated in the circuit-parsing cell).
circuit_type: str = "rb"
# depth: passed as the depth / layer count to the selected circuit generator.
depth: int = 10
# nqubits: number of qubits of the benchmark circuit; also sizes the Pauli basis
# and Z-string observable used for the filter functions.
nqubits: int = 1

# Noise parameters.
# In this notebook the noise parameters below are only used to build `setting_str`
# (the tag appended to saved file names); no noise is simulated here.
noise_type: str = "pink"
base_power: float = 0.01
cutoff_as_fraction_of_pi: float = 0.01  # Lowpass cutoff for lowpass noise. Note: the actual cutoff is 10 times larger.
alpha: float = 2.0  # The α in 1 / f^α noise.
# Scaling parameters.
# local_fold_key: "random", "left" or "right" — selects the gate-folding function
# used for the "Local" scaling method.
local_fold_key: str = "random"
# max_scale_factor: scale factors used are the odd integers 1, 3, ..., <= this value.
max_scale_factor: int = 9

# Option to save data. TODO: Explain save format here / somewhere else.
# NOTE(review): `save` is not read anywhere in this notebook.
save: bool = False
    
# Option to save figures.
savefigs: bool = False

# Other miscellaneous parameters.
# num_monte_carlo: only used in `setting_str` in this notebook.
num_monte_carlo: int = 3000
# verbosity: values >= 1 print the chosen scale factors.
verbosity: int = 1

In [None]:
# Parameters
# Overrides of the defaults in the previous cell (presumably injected by papermill,
# per that cell's docstring). Because this cell runs later, these values win.
circuit_type = "rb"
noise_type = "white"
depth= 2
# Note: an int here, whereas the default cell annotates `alpha` as float.
alpha = 1
savefigs = False
nqubits = 2
save = False

In [None]:
# Underscore-separated tag describing this run's settings; it is appended to the
# names of saved figure files.
_setting_values = (
    circuit_type,
    nqubits,
    depth,
    noise_type,
    base_power,
    cutoff_as_fraction_of_pi,
    alpha,
    local_fold_key,
    max_scale_factor,
    num_monte_carlo,
)
setting_str = "_".join(str(value) for value in _setting_values)
setting_str

## Check/parse runtime parameters

In [None]:
# @title Parse benchmark circuit type.
# Builds the global benchmark `circuit` according to `circuit_type` and raises
# ValueError for an unknown type.
valid_circuit_types = ("rb", "qaoa", "mirror")

if circuit_type == "rb":
    from mitiq.benchmarks import generate_rb_circuits

    # Request a single RB circuit (last argument `1`) and unpack it from the result.
    (circuit,) = generate_rb_circuits(nqubits, depth, 1)
elif circuit_type == "mirror":
    import networkx as nx
    from mitiq.benchmarks import generate_mirror_circuit
    
    # `correct_bitstring` is not used further in this notebook.
    circuit, correct_bitstring = generate_mirror_circuit(
        nlayers=depth,
        two_qubit_gate_prob=1.0,
        connectivity_graph=nx.complete_graph(nqubits),
    )
elif circuit_type == "qaoa":
    qreg = cirq.LineQubit.range(nqubits)
    # Set random qaoa parameters
    # (unseeded, so the QAOA circuit differs from run to run).
    alphas = np.random.rand(depth)
    betas = np.random.rand(depth)
    # One layer per (alpha, beta) pair: ZZ**alpha on neighbouring qubit pairs
    # (pairs starting at even indices first, then odd), followed by a moment of
    # X**beta on every qubit. The generator's `alpha` is local to the generator
    # and does not overwrite the global noise parameter `alpha`.
    circuit = (
        cirq.Circuit(
            [[
                (cirq.ZZ ** alpha).on(qreg[i + shift], qreg[i + shift + 1])
                for shift in (0, 1)
                for i in range(0, nqubits - shift - 1, 2)
            ],
            cirq.Moment([(cirq.X ** beta).on_each(*qreg)])]
            for alpha, beta in zip(alphas, betas)
        )
    )
    # Append inverse circuit such that tr{rho_ideal |0..0><0..|}=1
    circuit += cirq.inverse(circuit)
else:
    raise ValueError(
        f"Value for `circuit_type` ('{circuit_type}') is invalid. Valid options are {valid_circuit_types}."
    )

# Relabel every qubit q as GridQubit(q.x, q.x).
# NOTE(review): this requires qubits exposing `.x` (e.g. LineQubit) and puts them
# on the grid diagonal; confirm this is the layout the mezze simulator expects.
circuit = circuit.transform_qubits(lambda q: cirq.GridQubit(q.x, q.x))

print("Benchmark circuit:")
print(circuit)

In [None]:
# @title Parse scaling options.
# Gate-folding function used for the "Local" scaling method, keyed by
# `local_fold_key`.
local_fold_methods = {
    "random": zne.scaling.fold_gates_at_random,
    "left": zne.scaling.fold_gates_from_left,
    "right": zne.scaling.fold_gates_from_right,
}
if local_fold_key not in local_fold_methods:
    # Fail here with a clear message instead of later with
    # "'NoneType' object is not callable" when the circuits are scaled.
    raise ValueError(
        f"Value for `local_fold_key` ('{local_fold_key}') is invalid. "
        f"Valid options are {tuple(local_fold_methods)}."
    )
local_fold_method = local_fold_methods[local_fold_key]

# Odd integer scale factors 1, 3, 5, ..., up to `max_scale_factor`.
# At least one scale factor is required: later cells index `scale_factors[0]`.
if max_scale_factor < 1:
    raise ValueError(
        f"Value for `max_scale_factor` ({max_scale_factor}) must be at least 1."
    )
scale_factors = tuple(range(1, max_scale_factor + 1, 2))

if verbosity >= 1:
    print("Using scale factors:", scale_factors)

## Scale circuits with all gate-level methods

In [None]:
# Gate-level noise-scaling methods, keyed by the name shown in the figure.
all_scaling_methods = {
    "Global": zne.scaling.fold_global,
    "Local": local_fold_method,
    "Trotter": scale_noise_trotter,
}

# all_scaled_circuits[m][k]: the benchmark circuit scaled by the m-th method
# at scale_factors[k] (methods in dict order, scale factors in order).
all_scaled_circuits = [
    [method(circuit, factor) for factor in scale_factors]
    for method in all_scaling_methods.values()
]

## Visualize frequency response of scaled circuits

In [None]:
# Number of frequency samples (on [0, 2*pi)) requested for each filter function.
num_freqs = 8192 * 16


def get_biggest_ff(circuit, num_freqs):
    """Return the real part of the dominant merged filter function of `circuit`.

    The filter functions returned by `get_filter_fcns_only` (dephasing
    SchWARMA-fier with b=0.1) are grouped by an associated A_BB operator from
    `compute_A_BBs` with observable O = Z ⊗ ... ⊗ Z. The i-th filter function
    is paired with A_BBs[i mod (number_of_filter_functions // nqubits)].
    A filter function whose operator exactly equals an already-seen operator is
    added to that group. One whose operator exactly equals the negative of a
    seen operator is subtracted from it. Any other starts a new group. The group
    with the largest norm is returned (real part only).

    Reads the global `nqubits`.
    """
    pauli_list = ch.PauliBasis(nqubits).basis_list
    z_obs = reduce(np.kron, [ch._sigmaZ] * nqubits)
    _, a_bb_ops = compute_A_BBs(pauli_list, z_obs, z_obs.conj().T)

    raw_ffs = get_filter_fcns_only(
        circuit, SimpleDephasingSchWARMAFier(b=0.1), worN=num_freqs
    )

    ff_keys = list(raw_ffs.keys())
    period = len(ff_keys) // nqubits
    merged = []  # summed filter function per group
    reps = []  # representative A_BB operator per group
    for pos, key in enumerate(ff_keys):
        a_bb = a_bb_ops[np.mod(pos, period)]

        if not reps:
            merged.append(raw_ffs[key].copy())
            reps.append(a_bb)
            continue

        same = [np.linalg.norm(rep - a_bb) for rep in reps]
        if np.min(same) == 0:
            merged[np.argmin(same)] += raw_ffs[key]
            continue

        flipped = [np.linalg.norm(rep + a_bb) for rep in reps]
        if np.min(flipped) == 0:
            merged[np.argmin(flipped)] -= raw_ffs[key]
            continue

        merged.append(raw_ffs[key].copy())
        reps.append(a_bb)

    # argsort(...)[-1] (rather than argmax) keeps the original tie-breaking.
    group_norms = [np.linalg.norm(group) for group in merged]
    return np.real(merged[np.argsort(group_norms)[-1]])

# For every digital scaling method and scale factor, keep the first
# num_freqs // 2 samples (frequencies in [0, pi)) of the dominant filter function.
# NOTE: `scaled_circuits` stays bound after the loop (to the last method's
# circuits); the loop form is kept in case a later cell relies on it.
half_spectrum = num_freqs // 2
digital_FFs = []
for scaled_circuits in all_scaled_circuits:
    digital_FFs.append(
        [get_biggest_ff(sc, num_freqs)[:half_spectrum] for sc in scaled_circuits]
    )

In [None]:
    
# Get FF for pulse stretching.
# The base circuit is the first circuit of the last scaling method in
# `all_scaled_circuits` ("Trotter" at scale factor 1). This is the same circuit
# the `for scaled_circuits in all_scaled_circuits` loop leaves bound to
# `scaled_circuits[0]`, now referenced explicitly so the cell does not depend
# on a leaked loop variable.
# NOTE(review): confirm whether the unscaled circuit `all_scaled_circuits[0][0]`
# (global folding at scale factor 1) was the intended base instead.
base_FF = get_biggest_ff(all_scaled_circuits[-1][0], num_freqs)[: num_freqs // 2]

# Pulse stretching is not simulated separately: every scale factor reuses
# `base_FF`, and the plotting cell rescales the frequency axis by 1 / scale_factor.
pulse_FFs = [base_FF for _ in scale_factors]

# Stack pulse stretching before all the other filter functions, so FFs[0] is
# "Pulse" and FFs[1:] follow the order of `all_scaling_methods`.
FFs = [pulse_FFs] + digital_FFs
print(len(FFs))

In [None]:
max_scale_factor_to_plot = 5
y_max = 1.5

# Set plot style (usetex requires a working LaTeX installation).
plt.rcParams.update({"font.family": "serif", "font.size": 17, "text.usetex": True})
axis_size = 20

# @title Plotting code.
fig, axes_grid = plt.subplots(nrows=2, ncols=2, figsize=(12, 12), sharex=False, sharey=True)
axes = axes_grid.flatten()

# Frequencies: num_freqs points on [0, 2*pi); keep the first half, i.e. [0, pi),
# matching the num_freqs // 2 samples kept for every filter function.
w = np.linspace(0, 2 * np.pi, num_freqs + 1)[:-1]
w = np.array(w[: num_freqs // 2])

xticks = [0, np.pi / 4, np.pi / 2, 3 * np.pi / 4, np.pi]
xticks_labels = ["0", r"$\pi/4$", r"$\pi/2$", r"$3 \pi /4$", r"$\pi$"]

# Plot all noise-scaled filter functions, each normalized to a peak of 1.
# FFs[0] is pulse stretching; FFs[1:] are the digital scaling methods.
for (i, FF) in enumerate(FFs):
    for filter_function, scale_factor in zip(FF, scale_factors):
        if scale_factor > max_scale_factor_to_plot:
            continue

        if i == 0 and scale_factor > 1:
            # Pulse stretching by `scale_factor` compresses the frequency axis.
            w_plot = w / float(scale_factor)
            ff_plot = filter_function.copy()
            # Zero the last sample so the compressed curve ends at 0.
            # NOTE(review): intent inferred; confirm this is the desired effect.
            ff_plot[-1] = 0.0
        else:
            w_plot = w
            ff_plot = filter_function

        axes[i].plot(
            w_plot,
            ff_plot / np.max(ff_plot),
            label=rf"$\lambda = {scale_factor}$",
            alpha=0.75,
        )

# Shared axis formatting, applied once per axis after all curves are drawn.
for ax in axes:
    ax.set_ylim(0, y_max)
    ax.set_xticks(xticks, minor=False)
    ax.set_xticklabels(xticks_labels)

fig_labels = ["(a)", "(b)", "(c)", "(d)"]
fig_names = ["- Pulse", "- Global", "- Local", "- Trotter"]
x = 0
y = 1.35
for panel, (ax, fig_label, fig_name) in enumerate(zip(axes, fig_labels, fig_names)):
    ax.text(x, y, fig_label, size=22)
    ax.text(x + 0.4, y + 0.008, fig_name, size=17)
    if panel >= 2:  # bottom row gets the x-axis label
        # Raw string: "\o" is an invalid escape sequence in a normal literal.
        ax.set_xlabel(r"Normalized frequency $\omega$", size=axis_size)
    ax.legend()

axes[0].set_ylabel("Normalized filter function", size=axis_size)
axes[2].set_ylabel("Normalized filter function", size=axis_size)

if savefigs:
    fig.savefig("frequency_response_functions_" + setting_str + ".pdf")

In [None]:
# Diagnostic: repeat the filter-function merging done inside `get_biggest_ff`,
# for the global-folding circuit at the first scale factor, but keep every merged
# group in `ff_tot` (instead of only the largest) so their norms can be plotted.
# NOTE(review): this cell rebinds the globals `circuit` and `FFs` (the latter
# becomes the dict returned by `get_filter_fcns_only`); re-running the plotting
# cell after this one will therefore not see the stacked filter functions.
circuit = all_scaled_circuits[0][0]

# Pauli basis and the Z ⊗ ... ⊗ Z observable (and its adjoint) for compute_A_BBs.
paulis = ch.PauliBasis(nqubits).basis_list
O = reduce(np.kron, [ch._sigmaZ,]*nqubits)
Oinv = O.conj().T
A_BB_idxs, A_BBs = compute_A_BBs(paulis, O, Oinv)

FFs = get_filter_fcns_only(circuit, SimpleDephasingSchWARMAFier(b=.1),worN=num_freqs)

# ff_tot[k]: summed filter functions of group k; ops[k]: that group's A_BB operator.
ff_tot = []
ops = []
keys = list(FFs.keys())
for i, key in enumerate(keys):
    # The i-th filter function is paired with A_BBs[i mod (len(keys) // nqubits)].
    idx = np.mod(i,len(keys)//nqubits)
    A_BB = A_BBs[idx]

    # Distances to every known operator and to its negative (exact-zero match below).
    diff = [np.linalg.norm(op-A_BB) for op in ops]
    ndiff = [np.linalg.norm(op+A_BB) for op in ops]

    if len(ops)==0:
        # First filter function: start the first group.
        ff_tot.append(FFs[key].copy())
        ops.append(A_BB)
    elif np.min(diff)==0:
        # Same operator as an existing group: add.
        ii = np.argmin(diff)
        ff_tot[ii]+=FFs[key]
    elif np.min(ndiff)==0:
        # Negated operator of an existing group: subtract.
        ii = np.argmin(ndiff)
        ff_tot[ii]-=FFs[key]
    else:
        # New operator: start a new group.
        ff_tot.append(FFs[key].copy())
        ops.append(A_BB)

In [None]:
# Diagnostic plot: norm of the real and imaginary parts of every merged filter
# function in `ff_tot`, one point per A_BB group.
real_norms = [np.linalg.norm(np.real(ff)) for ff in ff_tot]
imag_norms = [np.linalg.norm(np.imag(ff)) for ff in ff_tot]

plt.figure()
for norms, curve_label in ((real_norms, 'Real Part'), (imag_norms, 'Imaginary Part')):
    plt.plot(norms, '-o', label=curve_label)
plt.xlabel(r"$\mathcal{A}_{\beta\beta'}$ index")
plt.ylabel(r"$||\sum_{\mathcal{A}_{\beta\beta'}=\sigma}\mathcal{F}_{\alpha\beta,\alpha'\beta'}||$")
plt.legend()
plt.show()