# Sample Analysis

In [1]:
%load_ext autoreload
%autoreload 2
%load_ext autotime
%load_ext line_profiler

from joblib import Parallel, delayed

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from dwave.system import DWaveSampler, FixedEmbeddingComposite
from matplotlib.patches import Rectangle
from numba import njit

from qbm.utils import (
    compute_stats_over_dfs,
    convert_bin_list_to_str,
    get_project_dir,
    get_rng,
    load_artifact,
    save_artifact,
)

time: 1.49 s (started: 2022-01-10 14:32:02 +01:00)


In [2]:
# base directory returned by qbm.utils.get_project_dir() (presumably the repository
# root); every artifact/data path in the cells below is built relative to it
project_dir = get_project_dir()

time: 525 µs (started: 2022-01-10 14:32:04 +01:00)


## Analysis Functions

In [3]:
@njit(boundscheck=True)
def compute_KL_divergence(
    p_exact,
    E_exact,
    E_samples,
    counts_samples,
    n_bins=32,
    normalize=False,
    swap=False,
    prob_sum_tol=1e-2,
    ϵ_smooth=1e-6,
):
    """
    Computes the KL divergence of the theory w.r.t. the samples, i.e.,
    D_KL(p_exact || p_samples), after binning both distributions over energy.

    :param p_exact: Exact computed probability vector, i.e., the diagonal of ρ.
    :param E_exact: Exact computed energy vector, i.e., the diagonal of H.
    :param E_samples: Energies of the (distinct) samples.
    :param counts_samples: Number of occurrences of each sample; entry i belongs to
        E_samples[i].
    :param n_bins: Number of equal-width energy bins spanning
        [min(E_exact), max(E_exact)]; the last bin also includes the maximum.
    :param normalize: If True will re-normalize the bin probabilities after
        restricting them to the common support.
    :param swap: If True will compute D_KL(p_samples || p_exact) rather than
        D_KL(p_exact || p_samples).
    :param prob_sum_tol: The tolerance for the probabilities to sum up to approx 1.
    :param ϵ_smooth: Probability given to bins where p > ϵ_smooth but no samples
        landed. The same total mass is removed proportionally from the bins that do
        contain samples. A falsy value (e.g. 0) disables smoothing.

    :returns: D_KL(p_exact || p_samples) (or D_KL(p_samples || p_exact) if swap),
        or NaN if p_exact contains NaNs or does not sum to approx 1.
    """
    p = np.zeros(n_bins)
    q = np.zeros(n_bins)

    # return NaN if exact data isn't a proper probability distribution
    if np.isnan(p_exact).any() or np.abs(p_exact.sum() - 1) > prob_sum_tol:
        return np.nan

    # bin the probabilities; bins are half-open [a, b) except the last, which is
    # closed so that the maximum energy is counted
    bin_edges = np.linspace(E_exact.min(), E_exact.max(), n_bins + 1)
    sum_counts = counts_samples.sum()
    for i, (a, b) in enumerate(zip(bin_edges[:-1], bin_edges[1:])):
        if i < n_bins - 1:
            p[i] = p_exact[np.logical_and(E_exact >= a, E_exact < b)].sum()
            q[i] = (
                counts_samples[np.logical_and(E_samples >= a, E_samples < b)].sum()
                / sum_counts
            )
        else:
            p[i] = p_exact[E_exact >= a].sum()
            q[i] = counts_samples[E_samples >= a].sum() / sum_counts

    # smoothing of sample data: bins that the exact distribution populates but that
    # received no samples get probability ϵ_smooth. The mass is taken proportionally
    # from the occupied bins only (q > 0). Subtracting from empty bins would make them
    # negative, so they would be dropped below and their mass lost. Proportional
    # scaling keeps every q non-negative and the total at 1.
    if ϵ_smooth:
        smooth_mask = np.logical_and(p > ϵ_smooth, q == 0)
        n_smooth = smooth_mask.sum()
        if n_smooth > 0:
            donor_mask = q > 0
            q_donor_sum = q[donor_mask].sum()
            q[donor_mask] *= 1 - ϵ_smooth * n_smooth / q_donor_sum
            q[smooth_mask] = ϵ_smooth

    # assert that p and q sum up to approx 1
    assert np.abs(p.sum() - 1) < prob_sum_tol
    assert np.abs(q.sum() - 1) < prob_sum_tol

    # take intersection of supports to avoid div zero errors
    support_intersection = np.logical_and(p > 0, q > 0)
    p = p[support_intersection]
    q = q[support_intersection]

    # re-normalize the p and q if True
    if normalize:
        p /= p.sum()
        q /= q.sum()

    # swap p and q if True
    if swap:
        p, q = q, p

    return np.sum(p * np.log(p / q))


@njit(boundscheck=True)
def get_state_energies(states, E_exact):
    """
    Looks up the (quantum + classical) energy of every sampled state.

    :param states: Array of state numbers, each in 0, 1, ..., 2 ** n_qubits - 1; used
        directly as indices into E_exact (out-of-range values raise because of
        boundscheck=True).
    :param E_exact: Array of exact computed energies, corresponds to the diagonal of H.

    :returns: Float array whose entry k is E_exact[states[k]].
    """
    n_states = len(states)
    energies = np.zeros(n_states)
    for k in range(n_states):
        energies[k] = E_exact[states[k]]
    return energies


def convert_spin_vector_to_state_number(spins):
    """
    Converts the spins vector (e.g. all values ±1) to an integer corresponding to the state.
    For example, the spin vector [1, 1, 1, 1] corresponds to the state |0000⟩ which is the
    0th state. The spin vector [-1, -1, -1, -1] corresponds to the state |1111⟩ which is the
    15th state.

    :param spins: Vector (array or sequence) of spin values (±1). Spin +1 maps to bit 0
        and spin -1 to bit 1; the first spin is the most significant bit.

    :returns: Integer corresponding to the state.
    """
    spins = np.asarray(spins)
    bit_vector = ((1 - spins) / 2).astype(np.int64)

    # explicit int64: np.arange's default integer dtype is platform dependent (32-bit
    # on Windows with NumPy < 2), which would overflow for more than 31 spins
    place_values = 2 ** np.arange(len(spins) - 1, -1, -1, dtype=np.int64)

    return (bit_vector * place_values).sum()


def compute_KL_divergence_df(exact_data, samples, swap):
    """
    Compares each exact computed data distribution against the provided samples instance.

    :param exact_data: Dictionary with keys of the form (s, T) with s being the relative
        anneal time at which H and ρ were computed, and T being the effective temperature.
        Values are of the form {"E": [...], "p": [...]}
    :param samples: Instance of Ocean SDK SampleSet; record.sample holds the spin vectors
        and record.num_occurrences the count of each of them.
    :param swap: If True will compute D_KL(p_samples || p_exact) rather than
        D_KL(p_exact || p_samples).

    :returns: Series of KL divergences indexed by (T_mK, s) tuples, where T_mK is
        round(T * 1000) as an int (the heat map labels it in mK).
    """
    # convert spin vectors to state numbers (indices into E_exact / p_exact)
    states = np.array(
        [convert_spin_vector_to_state_number(x) for x in samples.record.sample]
    )
    counts = samples.record.num_occurrences

    KL_divergence = {}
    for (s, T), data in exact_data.items():
        E_exact = data["E"]
        E_samples = get_state_energies(states, E_exact)
        # round instead of truncating: T * 1000 can land just below an integer
        # (e.g. 19.999999999999996), which int() alone would map to the wrong
        # temperature and could make two distinct temperatures share one key
        T_mK = int(round(T * 1000))
        KL_divergence[T_mK, s] = compute_KL_divergence(
            data["p"], E_exact, E_samples, counts, swap=swap
        )

    return pd.Series(KL_divergence)


def process_run_gauge_dir(run, gauge_dir, exact_data, swap):
    """
    Loads the samples of a single run from a single gauge directory and computes
    their KL divergences against every exact distribution. Used as the unit of work
    that is dispatched in parallel via joblib.

    :param run: Name of the run; the samples are read from "<run>.pkl".
    :param gauge_dir: Directory of the gauge data.
    :param exact_data: Exact computed data to compare against.
    :param swap: If True will compute D_KL(p_samples || p_exact) rather than
        D_KL(p_exact || p_samples).

    :returns: KL divergence series as produced by compute_KL_divergence_df.
    """
    sample_file = gauge_dir / f"{run}.pkl"
    return compute_KL_divergence_df(exact_data, load_artifact(sample_file), swap)

time: 27.7 ms (started: 2022-01-10 14:32:04 +01:00)


## Plotting Functions

In [13]:
def plot_heat_map(KL_divergence, title, **kwargs):
    """
    Plots the KL divergence heat map over the (s, T) grid.

    :param KL_divergence: KL divergence series indexed by (T, s) tuples, e.g. as produced
        by compute_KL_divergence_df (T in mK). s becomes the x axis, T the y axis.
    :param title: Title of the plot.
    :param **kwargs: Additional kwargs for sns.heatmap() (e.g. vmin, vmax, cbar_kws).

    :returns: Matplotlib fig, ax.
    """
    # (T, s)-indexed series -> DataFrame with one row per T and one column per s
    KL_divergence = KL_divergence.copy()
    KL_divergence.index = pd.MultiIndex.from_tuples(KL_divergence.index)
    KL_divergence = KL_divergence.unstack(level=-1)

    cmap = sns.color_palette("rocket_r", as_cmap=True)
    s_values = KL_divergence.columns.to_numpy()
    T_values = KL_divergence.index.to_numpy()
    # sns.heatmap draws cell i over [i, i + 1]; place the labels at the cell centres
    # (i + 0.5) so each label sits on the cell it names rather than on its edge
    xticks = np.arange(len(s_values))[::10] + 0.5
    xticklabels = s_values[::10]
    yticks = np.arange(len(T_values))[::2] + 0.5
    yticklabels = T_values[::2]

    fig, ax = plt.subplots(figsize=(8, 6), dpi=300)
    sns.heatmap(KL_divergence, ax=ax, cmap=cmap, **kwargs)
    ax.set_title(title)
    # sns.heatmap puts the first row at the top; flip so T increases upwards
    ax.invert_yaxis()
    ax.set_xlabel(r"$s$")
    ax.set_ylabel(r"$T$ [mK]")
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticklabels, rotation=0)
    ax.set_yticks(yticks)
    ax.set_yticklabels(yticklabels, rotation=0)
    fig.tight_layout()

    return fig, ax


def plot_histogram(E, title, **kwargs):
    """
    Draws a 32-bin histogram of the given energies on a new figure.

    :param E: Energy values to histogram.
    :param title: Title placed above the axes.
    :param **kwargs: Forwarded unchanged to ax.hist() (e.g. weights=...).

    :returns: Matplotlib fig, ax.
    """
    figure, axis = plt.subplots(figsize=(10, 6), dpi=144)
    axis.hist(E, bins=32, **kwargs)
    axis.set(title=title, xlabel=r"$E$")
    axis.grid()
    plt.tight_layout()
    return figure, axis

time: 1.13 ms (started: 2022-01-10 14:54:44 +01:00)


## Data Loading

In [19]:
config_id = 3
embedding_id = 1
n_jobs = 6

config_dir = project_dir / f"artifacts/exact_analysis/{config_id:02}/"
embedding_dir = config_dir / f"samples/embedding_{embedding_id:02}"

config = load_artifact(config_dir / "config.json")
exact_data = load_artifact(config_dir / "exact_data.pkl")

gauge_dirs = sorted(x for x in embedding_dir.iterdir() if x.name.startswith("gauge_"))
if not gauge_dirs:
    raise FileNotFoundError(f"no gauge_* directories found in {embedding_dir}")

# the run list is taken from the first gauge directory. Only .pkl files are runs,
# since the KL computation loads f"{run}.pkl". gauge.pkl is excluded (presumably the
# gauge itself rather than a run), and any other stray file is ignored instead of
# breaking the run-name parsing below.
run_names = sorted(
    x.stem
    for x in gauge_dirs[0].iterdir()
    if x.suffix == ".pkl" and x.name != "gauge.pkl"
)

# run names have the form "key1=value1,key2=value2,..." with numeric values
run_infos = {}
t_as = []
pause_durations = []
for run_name in run_names:
    run_info = {
        key: float(value)
        for key, _, value in (item.partition("=") for item in run_name.split(","))
    }
    # anneal time implied by reaching s_pause at t_pause with a linear ramp
    run_info["t_a"] = round(run_info["t_pause"] / run_info["s_pause"], 1)
    run_infos[run_name] = run_info

    # unique values in first-seen order (used for plot ordering below)
    if run_info["t_a"] not in t_as:
        t_as.append(run_info["t_a"])

    if run_info["pause_duration"] not in pause_durations:
        pause_durations.append(run_info["pause_duration"])

time: 98.7 ms (started: 2022-01-10 15:15:11 +01:00)


## KL Divergence Computations

In [20]:
# set compute_DKL to False to reuse a previously saved KL_divergences.pkl
compute_DKL = True
if (embedding_dir / "KL_divergences.pkl").exists() and not compute_DKL:
    KL_divergences = load_artifact(embedding_dir / "KL_divergences.pkl")
else:
    KL_divergences = {}
    for run_name in run_names:
        # every gauge of this run is processed in parallel, then summarised
        KL_divergence_dfs = Parallel(n_jobs=n_jobs)(
            delayed(process_run_gauge_dir)(run_name, gauge_dir, exact_data, swap=False)
            for gauge_dir in gauge_dirs
        )
        KL_divergences[run_name] = compute_stats_over_dfs(KL_divergence_dfs)
    save_artifact(KL_divergences, embedding_dir / "KL_divergences.pkl")

time: 9min 39s (started: 2022-01-10 15:15:12 +01:00)


## KL Divergence Min Value Plots

In [21]:
plot_dir_DKL_mins = config_dir / f"plots/DKL_mins/embedding_{embedding_id:02}"
plot_dir_DKL_mins.mkdir(parents=True, exist_ok=True)

# fixed colour per pause duration [μs]; durations not listed here fall back to the
# default matplotlib colour cycle
color_map = {0: "tab:red", 10: "tab:blue", 100: "tab:orange", 1000: "tab:green"}

α_quench = 2.0
for t_a in t_as:
    fig, ax = plt.subplots(figsize=(10, 6), dpi=300)
    ax.set_title(fr"$t_a = {t_a:.0f}$ μs, $α_{{quench}} = {α_quench:.0f}$")
    ax.set_xlabel(r"$s_{pause}$")
    ax.set_ylabel(r"$\min_{s,T}\{D_{KL}(p_{exact}(s,T) \ || \ p_{samples})\}$")

    # per-config axis ranges
    if config_id == 1:
        ax.set_xticks(np.arange(0.25, 0.8, 0.05))
        ax.set_yticks(np.arange(0, 0.07, 0.01))
        ax.set_ylim(0, 0.06)
    elif config_id in (2, 3):
        ax.set_xticks(np.arange(0.25, 0.8, 0.05))
        ax.set_yticks(np.arange(0, 0.14, 0.02))
        ax.set_ylim(0, 0.12)
    elif config_id == 4:
        ax.set_xticks(np.arange(0.55, 0.675, 0.025))
        ax.set_ylim(0, 0.25)

    # keep only runs matching this t_a AND the α_quench named in the title and file
    # name, so runs with other quench slopes are not mixed into this plot
    run_names_plot = [
        run_name
        for run_name, run_info in run_infos.items()
        if run_info["t_a"] == t_a and run_info["quench_slope"] == α_quench
    ]
    for pause_duration in pause_durations:
        x = []
        y = []
        y_err = []
        for run_name in run_names_plot:
            if run_infos[run_name]["pause_duration"] != pause_duration:
                continue
            KL_divergences_run = KL_divergences[run_name]
            # position of the (T, s) point with the smallest mean KL divergence
            y_min_index = KL_divergences_run["means"].argmin()

            x.append(run_infos[run_name]["s_pause"])
            y.append(KL_divergences_run["means"].iloc[y_min_index])
            y_err.append(KL_divergences_run["stds"].iloc[y_min_index])

        if x:
            label = fr"$\Delta_{{pause}} = {int(pause_duration)}$"
            ax.errorbar(
                x,
                y,
                color=color_map.get(pause_duration),
                yerr=y_err,
                fmt="o",
                markersize=5,
                linewidth=1.8,
                capsize=10,
                capthick=1.8,
                label=label,
            )

    ax.grid()
    # avoid matplotlib's "no artists with labels" warning on empty plots
    if ax.get_legend_handles_labels()[0]:
        ax.legend(loc="upper left")
    plt.tight_layout()
    plt.savefig(plot_dir_DKL_mins / f"t_a={t_a},quench_slope={α_quench}.png")

<Figure size 3000x1800 with 1 Axes>

<Figure size 3000x1800 with 1 Axes>

time: 1.38 s (started: 2022-01-10 15:24:51 +01:00)


## KL Divergence Heatmaps

In [22]:
# loop-invariant: create the output directory once rather than on every iteration
plot_dir_heatmaps = config_dir / f"plots/heatmaps/embedding_{embedding_id:02}"
plot_dir_heatmaps.mkdir(parents=True, exist_ok=True)

# one heat map over (s, T) per run, using the "means" entry of each run's KL
# divergence stats (presumably the mean over gauges, see compute_stats_over_dfs)
for run_name, KL_divergence_dict in KL_divergences.items():
    KL_divergence_means = KL_divergence_dict["means"]

    run_info = run_infos[run_name]
    t_pause = run_info["t_pause"]
    s_pause = run_info["s_pause"]
    pause_duration = run_info["pause_duration"]
    α_quench = run_info["quench_slope"]
    title = fr"$t_{{pause}} = {t_pause:.0f}$ μs, $s_{{pause}} = {s_pause}$, $\Delta_{{pause}} = {pause_duration:.0f}$ μs, $\alpha_{{quench}} = {α_quench:.0f}$"
    cbar_kws = {"label": r"$D_{KL}(p_{exact}(s,T) \ || \ p_{samples})$"}

    fig, ax = plot_heat_map(KL_divergence_means, title, vmin=0, vmax=0.20, cbar_kws=cbar_kws)
    plt.savefig(plot_dir_heatmaps / f"{run_name}.png")
    # close each figure so they do not accumulate over many runs
    fig.clear()
    plt.close(fig)

time: 42.5 s (started: 2022-01-10 15:24:53 +01:00)


## Anneal Schedule Plots

In [6]:
# load the anneal schedule data: A(s) and B(s) tabulated over s
anneal_schedule_data = pd.read_csv(
    project_dir
    / "data/anneal_schedules/csv/09-1265A-A_Advantage_system5_1_annealing_schedule.csv",
    index_col="s",
)
# for some reason 0.5 is missing for Advantage_system5.1 so we need to interpolate
if 0.5 not in anneal_schedule_data.index:
    anneal_schedule_data.loc[0.5] = (
        anneal_schedule_data.loc[0.499] + anneal_schedule_data.loc[0.501]
    ) / 2
anneal_schedule_data.sort_index(inplace=True)

# tabulated schedule as plain arrays (sorted by s) for np.interp
s_grid = anneal_schedule_data.index.to_numpy()
A_grid = anneal_schedule_data["A(s) (GHz)"].to_numpy()
B_grid = anneal_schedule_data["B(s) (GHz)"].to_numpy()

plot_dir_anneal_schedules = project_dir / "artifacts/plots/anneal_schedules"
plot_dir_anneal_schedules.mkdir(parents=True, exist_ok=True)

n_points = 1001  # samples per linear segment (ramp and quench)
for run_name, run_info in run_infos.items():
    t_pause = run_info["t_pause"]
    s_pause = run_info["s_pause"]
    pause_duration = run_info["pause_duration"]
    α_quench = run_info["quench_slope"]
    quench_duration = (1 - s_pause) / α_quench
    t_quench_start = t_pause + pause_duration
    t_end = t_quench_start + quench_duration
    title = fr"$t_{{pause}} = {t_pause:.0f}$ μs, $s_{{pause}} = {s_pause}$, $\Delta_{{pause}} = {pause_duration:.0f}$ μs, $\alpha_{{quench}} = {α_quench:.0f}$"

    # piecewise-linear schedule with these segments:
    #   ramp 0 -> s_pause over [0, t_pause]
    #   hold s_pause for pause_duration
    #   quench s_pause -> 1 with slope α_quench
    # Each linear segment is sampled on matching t and s grids, so the quench follows
    # the real slope. The pause is the flat line from the last ramp point to the first
    # quench point. Float np.arange grids are avoided because their length depends on
    # rounding.
    t = np.concatenate(
        (np.linspace(0, t_pause, n_points), np.linspace(t_quench_start, t_end, n_points))
    )
    s = np.concatenate(
        (np.linspace(0, s_pause, n_points), np.linspace(s_pause, 1, n_points))
    )
    # evaluate A(s), B(s) at the schedule's s values rather than assuming the CSV
    # rows line up positionally with t
    A = np.interp(s, s_grid, A_grid)
    B = np.interp(s, s_grid, B_grid)

    fig, axs = plt.subplots(2, 1, figsize=(10, 12), dpi=300)
    fig.suptitle("Anneal Schedule")

    axs[0].set_title(title)
    axs[0].plot(t, A, color="tab:blue", linewidth=2, label="A(s)")
    axs[0].plot(t, B, color="tab:red", linewidth=2, label="B(s)")
    axs[0].set_xlabel(r"$t$ [μs]")
    axs[0].set_ylabel(r"$E$ [GHz]")
    axs[0].grid()
    axs[0].legend()

    axs[1].plot(t, s, color="tab:blue", linewidth=2)
    axs[1].set_xlabel(r"$t$ [μs]")
    axs[1].set_ylabel(r"$s$")
    axs[1].grid()

    plt.tight_layout()
    plt.savefig(plot_dir_anneal_schedules / f"{run_name}.png")
    fig.clear()
    plt.close(fig)

time: 55.1 s (started: 2022-01-07 13:37:09 +01:00)


## Histograms

In [9]:
plot_dir_histograms_exact = config_dir / f"plots/histograms/exact"
if not plot_dir_histograms_exact.exists():
    plot_dir_histograms_exact.mkdir(parents=True)

# one p-weighted energy histogram per exact (s, T) distribution; distributions that
# contain NaNs or do not sum to ~1 (tolerance 1e-2) are skipped
for (s, T), data in exact_data.items():
    p = data["p"]
    is_proper_distribution = not (np.isnan(p).any() or abs(p.sum() - 1) > 1e-2)
    if is_proper_distribution:
        fig, ax = plot_histogram(data["E"], fr"$s = {s}, \ T = {T}$", weights=p)
        plt.savefig(plot_dir_histograms_exact / f"s={s:.2f},T={T:.3f}.png")
        fig.clear()
        plt.close(fig)

time: 8min 39s (started: 2022-01-04 23:01:28 +01:00)


In [10]:
plot_dir_histograms_samples = (
    config_dir / f"plots/histograms/embedding_{embedding_id:02}"
)
plot_dir_histograms_samples.mkdir(parents=True, exist_ok=True)

n_gauges = len(gauge_dirs)
for run_name in run_names:
    # pool the sampled energies of all gauges into one normalized distribution;
    # each gauge contributes 1 / n_gauges of the total probability mass
    energy_densities = {}
    for gauge_dir in gauge_dirs:
        samples = load_artifact(gauge_dir / f"{run_name}.pkl")
        counts = samples.record.num_occurrences
        total_counts = counts.sum()  # computed once per gauge, not per sample
        for energy, count in zip(samples.record.energy, counts):
            energy_densities[energy] = (
                energy_densities.get(energy, 0) + count / total_counts / n_gauges
            )

    # use the keys produced by the run-name parsing (t_a, quench_slope); the older
    # anneal_duration / max_slope keys are not part of the current run names
    run_info = run_infos[run_name]
    title = fr"$t_{{a}} = {int(run_info['t_a'])} \ μs, \ s_{{pause}} = {run_info['s_pause']}, \ \Delta_{{pause}} = {int(run_info['pause_duration'])} \ μs, \ \alpha_{{quench}} = {int(run_info['quench_slope'])}$"

    fig, ax = plot_histogram(
        list(energy_densities.keys()), title, weights=list(energy_densities.values())
    )
    plt.savefig(plot_dir_histograms_samples / f"{run_name}.png")
    fig.clear()
    plt.close(fig)

time: 7.77 s (started: 2022-01-04 23:10:08 +01:00)
