In [1]:
%load_ext autoreload
%autoreload 2
%load_ext autotime
%load_ext line_profiler

from joblib import Parallel, delayed

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from dwave.system import DWaveSampler, FixedEmbeddingComposite

from qbm.utils import (
    compute_stats_over_dfs,
    convert_bin_list_to_str,
    get_project_dir,
    get_rng,
    load_artifact,
    save_artifact,
)

time: 1.45 s (started: 2021-12-29 12:38:03 +01:00)


In [2]:
project_dir = get_project_dir()

time: 862 µs (started: 2021-12-29 12:38:04 +01:00)


## Sample Analysis

In [3]:
from numba import njit


@njit(boundscheck=True)
def compute_KL_divergence(
    p_exact,
    E_exact,
    E_samples,
    counts_samples,
    n_bins=32,
    normalize=False,
    invert=False,
    prob_sum_tol=1e-2
):
    """
    Computes the KL divergence of the theory w.r.t. the samples, i.e., 
    D_KL(p_exact || p_samples).
    
    :param p_exact: Exact computed probability vector, i.e., the diagonal of ρ.
    :param E_exact: Exact computed energy vector, i.e., the diagonal of H.
    :param E_samples: Energies of the samples.
    :param n_bins: Number of bins to compute over.
    :param normalize: If True will normalize the bin probabilities.
    :param invert: If true will compute D_KL(p_samples || p_exact).
    :param prob_sum_tol: The tolerance for the probabilities to sum up to approx 1.
    
    :returns: D_KL(p_exact || p_samples).
    """
    p = np.zeros(n_bins)
    q = np.zeros(n_bins)

    if np.isnan(p_exact).any() or np.abs(p_exact.sum() - 1) > prob_sum_tol:
        return np.nan

    bin_edges = np.linspace(E_exact.min(), E_exact.max(), n_bins + 1)
    sum_counts = counts_samples.sum()
    for i, (a, b) in enumerate(zip(bin_edges[:-1], bin_edges[1:])):
        if i < n_bins - 1:
            p[i] = p_exact[np.logical_and(E_exact >= a, E_exact < b)].sum()
            q[i] = (
                counts_samples[np.logical_and(E_samples >= a, E_samples < b)].sum()
                / sum_counts
            )
        else:
            p[i] = p_exact[E_exact >= a].sum()
            q[i] = counts_samples[E_samples >= a].sum() / sum_counts

    assert np.abs(p.sum() - 1) < prob_sum_tol
    assert np.abs(q.sum() - 1) < prob_sum_tol

    support_intersection = np.logical_and(p > 0, q > 0)
    p = p[support_intersection]
    q = q[support_intersection]

    if normalize:
        p /= p.sum()
        q /= q.sum()

    if invert:
        p, q = q, p

    return np.sum(p * np.log(p / q))


@njit(boundscheck=True)
def get_state_energies(states, E_exact):
    """
    Returns the (quantum + classical) energies of the provided states corresponding
    to the provided exact calculated energies.
    
    :param states: Array of states. Must be a value in 0, 1, ..., 2 ** n_qubits - 1.
    :param E_exact: Array of exact computed energies, corresponds to the diagonal of H.
    
    :returns: Array where entry i is the energy of states[i].
    """
    E_samples = np.zeros(len(states))
    for i, state in enumerate(states):
        E_samples[i] = E_exact[state]

    return E_samples


def convert_spin_vector_to_state_number(spins):
    """
    Converts the spins vector (e.g. all values ±1) to an integer corresponding to the state.
    For example, the spin vector [1, 1, 1, 1] corresponds to the state |0000⟩ which is the
    0th state. The spin vector [-1, -1, -1, -1] corresponds to the state |1111⟩ which is the
    15th state.
    
    :param spins: Vector of spin values (±1).
    
    :returns: Integer corresponding to the state. 
    """
    bit_vector = ((1 - spins) / 2).astype(np.int64)
    return (bit_vector * 2 ** np.arange(len(spins) - 1, -1, -1)).sum()


def compare_exact_to_samples(exact_data, samples):
    """
    Compares each exact computed data distribution against the provided samples instance.
    
    :param exact_data: Dictionary with keys of the form (s, T) with s being the relative
        anneal time at which H and ρ were computed, and T being the effective temperature.
        Values are of the form {"E": [...], "p": [...]}
    :param samples: Instance of Ocean SDK SampleSet.
    
    :returns: Dataframe of KL divergences, energies of the samples.
    """
    states = np.array(
        [convert_spin_vector_to_state_number(x) for x in samples.record.sample]
    )
    counts = samples.record.num_occurrences

    s_values = []
    T_values = []
    for s, T in exact_data.keys():
        s_values.append(s)
        T_values.append(T)
    s_values = np.sort(np.unique(s_values))
    T_values = np.sort(np.unique(T_values) * 1000).astype(np.int64)

    KL_divergence = pd.DataFrame(
        np.zeros((len(T_values), len(s_values))), index=T_values, columns=s_values,
    )
    E_samples = {}
    KL_divergence = {}
    for s, T in exact_data.keys():
        p_exact = exact_data[(s, T)]["p"]
        E_exact = exact_data[(s, T)]["E"]
        E_samples_sT = get_state_energies(states, E_exact)
        E_samples[(s, T)] = E_samples_sT
        KL_divergence[int(T * 1000), s] = compute_KL_divergence(
            p_exact, E_exact, E_samples_sT, counts, invert=False
        )

    KL_divergence = pd.DataFrame.from_dict(KL_divergence, orient="index")

    return KL_divergence, E_samples

time: 155 ms (started: 2021-12-29 12:38:04 +01:00)


In [4]:
n_jobs = 6

config_id = 4
embedding_id = 1

config_dir = project_dir / f"artifacts/exact_analysis/{config_id:02}/"
embedding_dir = config_dir / f"samples/embedding_{embedding_id:02}"

config = load_artifact(config_dir / "config.json")
exact_data = load_artifact(config_dir / "exact_data.pkl")


def process_run_gauge_dir(run, gauge_dir, exact_data):
    samples = load_artifact(gauge_dir / f"{run}.pkl")
    KL_divergence, E_samples = compare_exact_to_samples(exact_data, samples)
    return KL_divergence


gauge_dirs = sorted([x for x in embedding_dir.iterdir() if x.name.startswith("gauge_")])
runs = sorted([x.stem for x in gauge_dirs[0].iterdir() if x.name != "gauge.pkl"])
run_data = {}
for run in runs:
    results = Parallel(n_jobs=n_jobs)(
        delayed(process_run_gauge_dir)(run, gauge_dir, exact_data)
        for gauge_dir in gauge_dirs
    )
    run_data[run] = compute_stats_over_dfs(results)

time: 1min 53s (started: 2021-12-29 12:38:04 +01:00)


In [28]:
def plot_heat_map(KL_divergence, title):
    cmap = sns.color_palette("rocket_r", as_cmap=True)
    s_values = KL_divergence.columns.to_numpy()
    T_values = KL_divergence.index.to_numpy()
    xticks = np.arange(len(s_values))[::5]
    xticklabels = s_values[::5]
    yticks = np.arange(len(T_values))
    yticklabels = T_values

    fig, ax = plt.subplots(figsize=(8, 6), dpi=300)
    sns.heatmap(KL_divergence, ax=ax, vmin=0, vmax=0.025, cmap=cmap)
    ax.set_title(title)
    ax.invert_yaxis()
    ax.set_xlabel(r"$s$")
    ax.set_ylabel(r"$T$ [mK]")
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticklabels, rotation=0)
    ax.set_yticks(yticks)
    ax.set_yticklabels(yticklabels, rotation=0)
    plt.tight_layout()
    
    return fig, ax
    
plot_dir_heatmaps = config_dir / f"plots/heatmaps/embedding_{embedding_id:02}"
if not plot_dir_heatmaps.exists():
    plot_dir_heatmaps.mkdir(parents=True)
    
for name, KL_divergence in run_data.items():
    KL_divergence = KL_divergence["means"]
    KL_divergence.index = pd.MultiIndex.from_tuples(KL_divergence.index)
    KL_divergence = KL_divergence.unstack(level=-1)[0]
    
    title = {x.split("=")[0]: float(x.split("=")[1]) for x in name.split(",")}
    title['t_pause'] = title['anneal_duration'] * title['s_pause']
    title = fr"$t_{{pause}} = {title['t_pause']} \ μs, \ s_{{pause}} = {title['s_pause']}, \ \Delta_{{pause}} = {title['pause_duration']} \ μs, \ \alpha_{{quench}} = {title['max_slope']}$"
    
    fig, ax = plot_heat_map(KL_divergence, title)
    plt.savefig(plot_dir_heatmaps / f"{name}.png")
    plt.close(fig)

time: 18.8 s (started: 2021-12-29 13:00:25 +01:00)


In [42]:
def plot_histogram(E, title, weights=None):
    fig, ax = plt.subplots(figsize=(10, 6), dpi=144)
    ax.hist(E, bins=32, density=True, weights=weights)
    ax.set_title(title)
    ax.set_xlabel(r"$E$")
    ax.grid()
    
    return fig, ax

time: 513 µs (started: 2021-12-29 14:06:59 +01:00)


In [43]:
plot_dir_histograms_exact = config_dir / f"plots/histograms/exact"
if not plot_dir_histograms_exact.exists():
    plot_dir_histograms_exact.mkdir(parents=True)
    
for (s, T), data in exact_data.items():
    if np.isnan(data["p"]).any() or abs(data["p"].sum() - 1) > 1e-2:
        continue
    E = np.random.choice(data["E"], size=10 ** 6, p=data["p"])
    title = fr"$s = {s}, \ T = {T}$"
    fig, ax = plot_histogram(E, title)
    plt.savefig(plot_dir_histograms_exact / f"s={s:.2f},T={T:.3f}.png")
    plt.close(fig)

time: 5min (started: 2021-12-29 14:07:01 +01:00)


In [44]:
plot_dir_histograms_samples = config_dir / f"plots/histograms/embedding_{embedding_id:02}"
if not plot_dir_histograms_samples.exists():
    plot_dir_histograms_samples.mkdir(parents=True)
    
for run in runs:
    for gauge_dir in gauge_dirs:
        samples = load_artifact(gauge_dir / f"{run}.pkl")
        
        title = {x.split("=")[0]: float(x.split("=")[1]) for x in run.split(",")}
        title['t_pause'] = title['anneal_duration'] * title['s_pause']
        title = fr"$t_{{pause}} = {title['t_pause']} \ μs, \ s_{{pause}} = {title['s_pause']}, \ \Delta_{{pause}} = {title['pause_duration']} \ μs, \ \alpha_{{quench}} = {title['max_slope']}$"

        fig, ax = plot_histogram(samples.record.energy, title, weights=samples.record.num_occurrences)
        plt.savefig(plot_dir_histograms_samples / f"{run}_{gauge_dir.name}.png")
        plt.close(fig)

time: 1min 3s (started: 2021-12-29 14:12:01 +01:00)
