# Sample Analysis

In [1]:
%load_ext autoreload
%autoreload 2
%load_ext autotime
%load_ext line_profiler

from joblib import Parallel, delayed

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from dwave.system import DWaveSampler, FixedEmbeddingComposite
from matplotlib.patches import Rectangle
from numba import njit
from scipy.constants import k as k_B, h as h_P

k_B /= h_P * 1e9

from qbm.utils import (
    compute_stats_over_dfs,
    convert_bin_list_to_str,
    get_project_dir,
    get_rng,
    load_artifact,
    save_artifact,
)

project_dir = get_project_dir()

time: 1.64 s (started: 2022-02-23 13:41:11 +01:00)


## Analysis functions

In [2]:
@njit(boundscheck=True)
def kl_divergence(
    p_exact,
    E_exact,
    E_sample,
    counts_sample,
    n_bins=32,
    prob_sum_tol=1e-6,
    ϵ_smooth=1e-6,
):
    """
    Computes the KL divergence of the theory w.r.t. the sample, i.e., 
    D_KL(p_exact || p_sample).
    
    :param p_exact: Exact computed probability vector, i.e., the diagonal of ρ.
    :param E_exact: Exact computed energy vector, i.e., the diagonal of H.
    :param E_sample: Energies of the sample.
    :param n_bins: Number of bins to compute over.
    :param prob_sum_tol: The tolerance for the probabilities to sum up to approx 1.
    :param ϵ_smooth: Smoothing parameter for the sample distribution.
    
    :returns: D_KL(p_exact || p_sample).
    """
    p = np.zeros(n_bins)
    q = np.zeros(n_bins)
    
    # compute the bin edges
    buffer = np.abs(E_exact).max() * 1e-15
    bin_edges = np.linspace(E_exact.min() - buffer, E_exact.max() + buffer, n_bins + 1)
    
    # check that bin edges include all possible E values
    assert bin_edges.min() <= E_exact.min()
    assert bin_edges.max() >= E_exact.max()
    
    # bin the probabilities
    sum_counts = counts_sample.sum()
    for i, (a, b) in enumerate(zip(bin_edges[:-1], bin_edges[1:])):
        if i < n_bins - 1:
            p[i] = p_exact[np.logical_and(E_exact >= a, E_exact < b)].sum()
            q[i] = (
                counts_sample[np.logical_and(E_sample >= a, E_sample < b)].sum()
                / sum_counts
            )
        else:
            p[i] = p_exact[E_exact >= a].sum()
            q[i] = counts_sample[E_sample >= a].sum() / sum_counts

    # smoothing of sample data
    smooth_mask = np.logical_and(p > 0, q == 0)
    not_smooth_mask = np.logical_not(smooth_mask)
    q[smooth_mask] = p[smooth_mask] * ϵ_smooth
    q[not_smooth_mask] -= q[smooth_mask].sum() / not_smooth_mask.sum()

    # check that p and q sum up to approx 1
    assert np.abs(p.sum() - 1) < prob_sum_tol
    assert np.abs(q.sum() - 1) < prob_sum_tol

    # take intersection of supports to avoid div zero errors
    support_intersection = np.logical_and(p > 0, q > 0)
    p = p[support_intersection]
    q = q[support_intersection]

    return (p * np.log(p / q)).sum()


@njit(boundscheck=True)
def get_state_energies(states, E_exact):
    """
    Returns the (quantum + classical) energies of the provided states corresponding
    to the provided exact calculated energies.
    
    :param states: Array of states. Must be a value in 0, 1, ..., 2 ** n_qubits - 1.
    :param E_exact: Array of exact computed energies, corresponds to the diagonal of H.
    
    :returns: Array where entry i is the energy of states[i].
    """
    E_sample = np.zeros(len(states))
    for i, state in enumerate(states):
        E_sample[i] = E_exact[state]

    return E_sample


def convert_spin_vector_to_state_number(spins):
    """
    Converts the spins vector (e.g. all values ±1) to an integer corresponding to the state.
    For example, the spin vector [1, 1, 1, 1] corresponds to the state |0000⟩ which is the
    0th state. The spin vector [-1, -1, -1, -1] corresponds to the state |1111⟩ which is the
    15th state.
    
    :param spins: Vector of spin values (±1).
    
    :returns: Integer corresponding to the state. 
    """
    bit_vector = ((1 - spins) / 2).astype(np.int64)

    return (bit_vector * 2 ** np.arange(len(spins) - 1, -1, -1)).sum()


def kl_divergence_df(exact_data, sample):
    """
    Compares each exact computed data distribution against the provided sample instance.
    
    :param exact_data: Dictionary with keys of the form (s, T) with s being the relative
        anneal time at which H and ρ were computed, and T being the effective temperature.
        Values are of the form {"E": [...], "p": [...]}
    :param sample: Instance of Ocean SDK SampleSet.
    
    :returns: Dataframe of KL divergences, with T values as index and s values as columns.
    """
    # convert spin vectors to state numbers
    states = np.array(
        [convert_spin_vector_to_state_number(x) for x in sample.record.sample]
    )

    dkl = {}
    for s, T in exact_data.keys():
        p_exact = exact_data[(s, T)]["p"]
        E_exact = exact_data[(s, T)]["E"]
        E_sample = get_state_energies(states, E_exact)

        dkl[int(T * 1000), s] = kl_divergence(
            p_exact, E_exact, E_sample, sample.record.num_occurrences
        )

    return pd.Series(dkl)


def process_run_gauge_dir(run, gauge_dir, exact_data):
    """
    Helper function for processing the runs and computing the KL divergences
    in parallel.
    
    :param run: Name of the run.
    :param gauge_dir: Directory of the gauge data.
    :param exact_data: Exact computed data to compare against.
    
    :returns: KL divergence dataframe.
    """
    sample = load_artifact(gauge_dir / f"{run}.pkl")
    dkl_df = kl_divergence_df(exact_data, sample)

    return dkl_df

time: 19.2 ms (started: 2022-02-23 13:41:12 +01:00)


## Data Loading

In [3]:
config_id = 2
n_jobs = 6
best_embedding_id = 5

config_dir = project_dir / f"artifacts/exact_analysis/{config_id:02}/"
embedding_dirs = sorted([x for x in (config_dir / "samples").iterdir()])

config = load_artifact(config_dir / "config.json")
exact_data = load_artifact(config_dir / "exact_data.pkl")

gauge_dirs = sorted(
    [x for x in embedding_dirs[best_embedding_id - 1].iterdir() if x.name.startswith("gauge_")]
)
run_names = sorted([x.stem for x in gauge_dirs[0].iterdir() if x.name != "gauge.pkl"])

run_infos = {}
t_as = []
s_pauses = []
anneal_durations = []
pause_durations = []
for run_name in run_names:
    run_info = {x.split("=")[0]: x.split("=")[1] for x in run_name.split("-")}
    for k, v in run_info.items():
        if k in ("t_pause", "s_pause", "pause_duration", "quench_slope"):
            run_info[k] = float(v)
    for k, v in run_info.items():
        if k in ("reverse", "reinit") and v == "True":
            run_info[k] = True
        elif k in ("reverse", "reinit") and v == "False":
            run_info[k] = False

    if "reverse" in run_info:
        run_info["t_a"] = round(run_info["t_pause"] / (1 - run_info["s_pause"]), 1)
    else:
        run_info["reverse"] = False
        run_info["reinit"] = True
        run_info["t_a"] = round(run_info["t_pause"] / run_info["s_pause"], 1)
    run_infos[run_name] = run_info

    if run_info["t_a"] not in t_as:
        t_as.append(run_info["t_a"])

    if run_info["s_pause"] not in s_pauses:
        s_pauses.append(run_info["s_pause"])

    if run_info["pause_duration"] not in pause_durations:
        pause_durations.append(run_info["pause_duration"])

t_as = sorted(t_as)
s_pauses = sorted(s_pauses)
pause_durations = sorted(pause_durations)
anneal_durations = sorted(anneal_durations)

t_as = [x for x in t_as if x != 10]
run_names_ = []
run_names_best_embedding = []
for run_name in run_names:
    run_info = run_infos[run_name]
    if (
        round(run_info["s_pause"] * 100) % 5 == 0
        and run_info["pause_duration"] == 0
        and run_info["t_a"] == 20
        and not run_info["reverse"]
    ):
        run_names_.append(run_name)
    if (
        run_info["pause_duration"] == 0
        and run_info["t_a"] == 20
        and not run_info["reverse"]
    ):
        run_names_best_embedding.append(run_name)
        
run_names_best_embedding = sorted(run_names_best_embedding, key=lambda x: run_infos[x]["s_pause"])
run_infos_best_embedding = {k: v for k, v in run_infos.items() if k in run_names_best_embedding}
run_names = sorted(run_names_, key=lambda x: run_infos[x]["s_pause"])
run_infos = {k: v for k, v in run_infos.items() if k in run_names}

time: 360 ms (started: 2022-02-23 13:41:13 +01:00)


## KL Divergence Computations

In [4]:
compute_kl_divergences = False
dkls_file_path = config_dir / "kl_divergences_embedding_comparison.pkl"
if not dkls_file_path.exists() or compute_kl_divergences:
    dkls_embeddings = {}
    for embedding_dir in embedding_dirs:
        embedding_id = int(embedding_dir.name.split("_")[-1])
        gauge_dirs = sorted(
            [x for x in embedding_dir.iterdir() if x.name.startswith("gauge_")]
        )

        dkls_embeddings[embedding_id] = {}
        for run_name in run_names:
            if run_name not in dkls_embeddings[embedding_id]:
                dkl_dfs = Parallel(n_jobs=n_jobs)(
                    delayed(process_run_gauge_dir)(run_name, gauge_dir, exact_data)
                    for gauge_dir in gauge_dirs
                )
                dkls_embeddings[embedding_id][run_name] = compute_stats_over_dfs(
                    dkl_dfs
                )

    save_artifact(dkls_embeddings, dkls_file_path)
else:
    dkls_embeddings = load_artifact(dkls_file_path)

dkls_best_file_path = config_dir / "kl_divergences_best_embedding.pkl"
if not dkls_best_file_path.exists() or compute_kl_divergences:
    dkls_best_embedding = {}
    embedding_dir = [
        x for x in embedding_dirs if str(x).endswith(f"embedding_{best_embedding_id:02}")
    ][0]
    gauge_dirs = sorted(
        [x for x in embedding_dir.iterdir() if x.name.startswith("gauge_")]
    )

    for run_name in run_names_best_embedding:
        if run_name not in dkls_best_embedding:
            dkl_dfs = Parallel(n_jobs=n_jobs)(
                delayed(process_run_gauge_dir)(run_name, gauge_dir, exact_data)
                for gauge_dir in gauge_dirs
            )
            dkls_best_embedding[run_name] = compute_stats_over_dfs(dkl_dfs)

    save_artifact(dkls_best_embedding, dkls_best_file_path)
else:
    dkls_best_embedding = load_artifact(dkls_best_file_path)

time: 695 ms (started: 2022-02-23 13:41:14 +01:00)


## KL Divergence Min Value Plots

In [5]:
plot_dir_dkl_min = project_dir / f"results/plots/qbm/8x4/embedding_comparison"
if not plot_dir_dkl_min.exists():
    plot_dir_dkl_min.mkdir(parents=True)

α_quench = 2.0
fig, ax = plt.subplots(figsize=(10, 6), dpi=300)
# ax.set_title(r"$t_a = 20$ μs, $\Delta_{{pause}} = 0$ μs, $\alpha_{{quench}} = 2$, $h_i, J_{ij} \sim \mathcal{N}(0, 0.1)$")
ax.set_xlabel(r"$s_{{quench}}$")
ax.set_ylabel(r"$\min_{s,T}\{D_{KL}(p_{exact} \ || \ p_{sample})\}$")
markers = ["o", "^", "v", "<", ">", "s", "p", "*", "P", "X"]
colors = [
    "tab:blue",
    "tab:orange",
    "tab:green",
    "tab:red",
    "tab:purple",
    "tab:brown",
    "tab:pink",
    "tab:gray",
    "tab:olive",
    "tab:cyan",
]
xs = []
ys = []
y_errs = []
for embedding_id, dkls_embedding in dkls_embeddings.items():
#     if embedding_id in (2, 7, 4, 6):
#         continue
    ax.set_xticks(np.arange(0.25, 1.05, 0.05))
    ax.set_yticks(np.arange(0, 0.045, 0.005))
    ax.set_ylim(0, 0.04)

    x = []
    y = []
    y_err = []
    print(embedding_id)
    for run_name in run_names:
        if round(run_infos[run_name]["s_pause"] * 100) % 5 != 0:
            continue
        print(run_name)
        x.append(run_infos[run_name]["s_pause"])
        argmin = np.argmin(dkls_embedding[run_name]["means"])
        y.append(dkls_embedding[run_name]["means"].iloc[argmin])
        y_err.append(dkls_embedding[run_name]["stds"].iloc[argmin])
        print(dkls_embedding[run_name]["means"].index[argmin])

        dkl_means = dkls_embedding[run_name]["means"].copy()
        dkl_stds = dkls_embedding[run_name]["stds"].copy()
        dkl_means.index = pd.MultiIndex.from_tuples(dkl_means.index)
        dkl_stds.index = pd.MultiIndex.from_tuples(dkl_stds.index)
        dkl_means = dkl_means.unstack(level=-1)[1.0]
        dkl_stds = dkl_stds.unstack(level=-1)[1.0]
        argmin = np.argmin(dkl_means)
        print(y[-1] - dkl_means.iloc[argmin])
        
    x = np.array(x)
    y = np.array(y)
    y_err = np.array(y_err)
    ys.append(y)
    y_errs.append(y_err)
    
    label = fr"Embedding {embedding_id}"
    ax.fill_between(x, y - y_err, y + y_err, interpolate=True, color=colors[embedding_id-1], alpha=0.10)
    ax.plot(
        x,
        y,
        marker=markers[embedding_id - 1],
        markersize=10,
        linewidth=1.2,
        label=label,
        color=colors[embedding_id - 1],
    )

ys = np.vstack(ys)
y_errs = np.vstack(y_errs)
y = np.mean(ys, axis=0)
y_err = np.sqrt(np.sum(y_errs ** 2, axis=0) / (len(y_errs) - 1))

ax.fill_between(x, y - y_err, y + y_err, interpolate=True, color="k", alpha=0.10)
ax.plot(
    x,
    y,
    marker="d",
    markersize=10,
    linewidth=1.2,
    linestyle="--",
    label="Average",
    fillstyle="none",
    mew=2,
    color="k",
)

ax.grid(True)
ax.legend(ncol=2)
plt.tight_layout()
plt.savefig(plot_dir_dkl_min / f"kl_divergence_mins.png")

1
t_pause=5.0-s_pause=0.25-pause_duration=0.0-quench_slope=2.0
(64, 0.65)
-7.202486547459343e-05
t_pause=6.0-s_pause=0.3-pause_duration=0.0-quench_slope=2.0
(92, 0.84)
-1.0727401095412398e-05
t_pause=7.0-s_pause=0.35-pause_duration=0.0-quench_slope=2.0
(86, 0.82)
-5.589581369544114e-05
t_pause=8.0-s_pause=0.4-pause_duration=0.0-quench_slope=2.0
(104, 0.93)
-7.733179684810826e-05
t_pause=9.0-s_pause=0.45-pause_duration=0.0-quench_slope=2.0
(98, 0.92)
-1.5378251493900277e-05
t_pause=10.0-s_pause=0.5-pause_duration=0.0-quench_slope=2.0
(74, 0.8)
-5.461874699373703e-06
t_pause=11.0-s_pause=0.55-pause_duration=0.0-quench_slope=2.0
(46, 0.62)
-6.3316406117217006e-06
t_pause=12.0-s_pause=0.6-pause_duration=0.0-quench_slope=2.0
(76, 0.88)
-1.9299320016719253e-05
t_pause=13.0-s_pause=0.65-pause_duration=0.0-quench_slope=2.0
(46, 0.64)
-9.456223128108919e-05
t_pause=14.0-s_pause=0.7-pause_duration=0.0-quench_slope=2.0
(56, 0.73)
-0.00010413780126480691
t_pause=15.0-s_pause=0.75-pause_duration=0.

<Figure size 3000x1800 with 1 Axes>

time: 2.25 s (started: 2022-02-23 13:41:14 +01:00)


In [6]:
plot_dir_dkl_min = project_dir / f"results/plots/qbm/8x4/embedding_comparison"
if not plot_dir_dkl_min.exists():
    plot_dir_dkl_min.mkdir(parents=True)

α_quench = 2.0
fig, ax = plt.subplots(figsize=(10, 6), dpi=300)
# ax.set_title(r"$t_a = 20$ μs, $\Delta_{{pause}} = 0$ μs, $\alpha_{{quench}} = 2$, $h_i, J_{ij} \sim \mathcal{N}(0, 0.1)$")
ax.set_xlabel(r"$s_{{quench}}$")
ax.set_ylabel(r"$\min_{s,T}\{D_{KL}(p_{exact} \ || \ p_{sample})\}$")
markers = ["o", "^", "v", "<", ">", "s", "p", "*", "P", "X"]
colors = [
    "tab:blue",
    "tab:orange",
    "tab:green",
    "tab:red",
    "tab:purple",
    "tab:brown",
    "tab:pink",
    "tab:gray",
    "tab:olive",
    "tab:cyan",
]
xs = []
ys = []
y_errs = []
for embedding_id, dkls_embedding in dkls_embeddings.items():
    if embedding_id in (2, 7, 4, 6):
        continue
    ax.set_xticks(np.arange(0.25, 1.05, 0.05))
    ax.set_yticks(np.arange(0, 0.045, 0.005))
    ax.set_ylim(0, 0.04)

    x = []
    y = []
    y_err = []
    for run_name in run_names:
        if round(run_infos[run_name]["s_pause"] * 100) % 5 != 0:
            continue
        x.append(run_infos[run_name]["s_pause"])
        argmin = np.argmin(dkls_embedding[run_name]["means"])
        y.append(dkls_embedding[run_name]["means"].iloc[argmin])
        y_err.append(dkls_embedding[run_name]["stds"].iloc[argmin])

    x = np.array(x)
    y = np.array(y)
    y_err = np.array(y_err)
    ys.append(y)
    y_errs.append(y_err)
    
    label = fr"Embedding {embedding_id}"
    ax.fill_between(x, y - y_err, y + y_err, interpolate=True, color=colors[embedding_id-1], alpha=0.10)
    ax.plot(
        x,
        y,
        marker=markers[embedding_id - 1],
        markersize=10,
        linewidth=1.2,
        label=label,
        color=colors[embedding_id - 1],
    )

ys = np.vstack(ys)
y_errs = np.vstack(y_errs)
y = np.mean(ys, axis=0)
y_err = np.sqrt(np.sum(y_errs ** 2, axis=0) / (len(y_errs) - 1))

ax.fill_between(x, y - y_err, y + y_err, interpolate=True, color="k", alpha=0.10)
ax.plot(
    x,
    y,
    marker="d",
    markersize=10,
    linewidth=1.2,
    linestyle="--",
    label="Average",
    fillstyle="none",
    mew=2,
    color="k",
)

ax.grid(True)
ax.legend(ncol=2)
plt.tight_layout()
plt.savefig(plot_dir_dkl_min / f"kl_divergence_mins_subset.png")

<Figure size 3000x1800 with 1 Axes>

time: 975 ms (started: 2022-02-23 13:41:17 +01:00)


In [7]:
fig, ax = plt.subplots(figsize=(10, 6), dpi=300)
# ax.set_title(r"$t_a = 20$ μs, $\Delta_{{pause}} = 0$ μs, $\alpha_{{quench}} = 2$, $h_i, J_{ij} \sim \mathcal{N}(0, 0.1)$")
ax.set_xlabel(r"$s_{{quench}}$")
ax.set_ylabel(r"$\min_{s,T}\{D_{KL}(p_{exact} \ || \ p_{sample})\}$")
ax.set_xticks(np.arange(0.25, 1.05, 0.05))
ax.set_yticks(np.arange(0, 0.03, 0.005))
ax.set_ylim(0, 0.025)
markers = ["o", "^", "v", "<", ">", "s", "p", "*", "P", "X"]
colors = [
    "tab:blue",
    "tab:orange",
    "tab:green",
    "tab:red",
    "tab:purple",
    "tab:brown",
    "tab:pink",
    "tab:gray",
    "tab:olive",
    "tab:cyan",
]
x = []
y = []
y_err = []
for run_name, dkls_embedding in dkls_best_embedding.items():
    run_info = run_infos_best_embedding[run_name]

    x.append(run_info["s_pause"])
    argmin = np.argmin(dkls_embedding["means"])
    y.append(dkls_embedding["means"].iloc[argmin])
    y_err.append(dkls_embedding["stds"].iloc[argmin])

x = np.array(x)
y = np.array(y)
y_err = np.array(y_err)
    
label = fr"Embedding {best_embedding_id}"
ax.fill_between(x, y - y_err, y + y_err, interpolate=True, color=colors[embedding_id-1], alpha=0.10)
ax.plot(
    x,
    y,
    marker=markers[best_embedding_id - 1],
    markersize=10,
    linewidth=1.2,
    label=label,
    color=colors[best_embedding_id - 1],
)

ax.grid(True)
ax.legend(ncol=2)
plt.tight_layout()
plt.savefig(plot_dir_dkl_min / f"kl_divergence_mins_best_embedding.png")

<Figure size 3000x1800 with 1 Axes>

time: 714 ms (started: 2022-02-23 13:41:20 +01:00)


## $T_{\text{optimal}}(s^*)$ Plot

In [8]:
plot_dir_dkl_min = project_dir / f"results/plots/qbm/8x4/embedding_comparison"
if not plot_dir_dkl_min.exists():
    plot_dir_dkl_min.mkdir(parents=True)

α_quench = 2.0
fig, ax = plt.subplots(figsize=(10, 6), dpi=300)
ax.set_title(r"Optimal Effective Temperature to Approximate Boltzmann Distribution at $s = 1$")
ax.set_xlabel(r"$s_{{quench}}$")
ax.set_ylabel(r"$T$ [mK]")
markers = ["o", "^", "v", "<", ">", "s", "p", "*", "P", "X"]
colors = [
    "tab:blue",
    "tab:orange",
    "tab:green",
    "tab:red",
    "tab:purple",
    "tab:brown",
    "tab:pink",
    "tab:gray",
    "tab:olive",
    "tab:cyan",
]
ys = []
for embedding_id, dkls_embedding in dkls_embeddings.items():
    ax.set_xticks(np.arange(0.25, 1.05, 0.05))

    x = []
    y = []
    for run_name in run_names:
        dkls_ = dkls_embedding[run_name]["means"].copy()
        x.append(run_infos[run_name]["s_pause"])
        argmin = np.argmin(dkls_)
        dkls_.index = pd.MultiIndex.from_tuples(dkls_.index)
        dkls_ = dkls_.unstack(level=-1)
        dkls_1 = dkls_[1]
        argmin = np.argmin(dkls_1)
        y.append(dkls_.index[argmin])

    x = np.array(x)
    y = np.array(y)
    ys.append(y)
    
    label = fr"Embedding {embedding_id}"
    ax.plot(
        x,
        y,
        marker=markers[embedding_id - 1],
        markersize=10,
        linewidth=1.2,
        label=label,
        color=colors[embedding_id - 1],
    )

ys = np.vstack(ys)
y_errs = np.vstack(y_errs)
y = np.mean(ys, axis=0)
y_err = np.sqrt(np.sum(y_errs ** 2, axis=0) / (len(y_errs) - 1))

ax.plot(
    x,
    y,
    marker="d",
    markersize=10,
    linewidth=1.2,
    linestyle="--",
    label="Average",
    fillstyle="none",
    mew=2,
    color="k",
)

ax.grid(True)
ax.legend(ncol=2)
plt.tight_layout()
plt.savefig(plot_dir_dkl_min / f"optimal_distribution_temp.png")

<Figure size 3000x1800 with 1 Axes>

time: 1.19 s (started: 2022-02-23 13:41:22 +01:00)


## KL Divergence Heatmap

In [9]:
# load anneal schedule
df_anneal = pd.read_csv(
    project_dir
    / "data/anneal_schedules/csv/09-1265A-A_Advantage_system5_1_annealing_schedule.csv",
    index_col="s",
)
if 0.5 not in df_anneal.index:
    df_anneal.loc[0.5] = (df_anneal.loc[0.499] + df_anneal.loc[0.501]) / 2
df_anneal.sort_index(inplace=True)

time: 4.22 ms (started: 2022-02-23 13:41:25 +01:00)


In [54]:
embedding_id = 5
s_pause = 0.55
run_name = [x for x in run_names if f"s_pause={s_pause}" in x][0]

plot_dir_heatmap = project_dir / f"results/plots/qbm/8x4"
if not plot_dir_heatmap.exists():
    plot_dir_heatmap.mkdir(parents=True)

dkl = dkls_embeddings[embedding_id][run_name]["means"].copy()
dkl.index = pd.MultiIndex.from_tuples(dkl.index)
dkl = dkl.unstack(level=-1)

run_info = run_infos[run_name]
t_pause = run_info["t_pause"]
s_pause = run_info["s_pause"]
pause_duration = run_info["pause_duration"]
α_quench = run_info["quench_slope"]
title = fr"Embedding {embedding_id}, $t_{{pause}} = {t_pause:.1f}$ μs, $s_{{pause}} = {s_pause}$, $\Delta_{{pause}} = {pause_duration:.0f}$ μs, $\alpha_{{quench}} = {α_quench:.0f}$, $h_i, J_{{ij}} \sim \mathcal{{N}}(0, 0.1)$"

cbar_kws = {"label": r"$D_{KL}(p_{exact} \ || \ p_{sample})$"}
cmap = sns.color_palette("rocket_r", as_cmap=True)
s_values = dkl.columns.to_numpy()
T_values = dkl.index.to_numpy()
xticks = np.arange(len(s_values))[::10]
xticklabels = s_values[::10]
yticks = np.arange(len(T_values))[::10]
yticklabels = T_values[::10]

fig, ax = plt.subplots(figsize=(8, 6), dpi=300)

sns.heatmap(dkl, ax=ax, cmap=cmap, vmin=0, vmax=0.2, cbar_kws=cbar_kws)
# ax.set_title(title)
ax.invert_yaxis()
ax.set_xlabel(r"$s$")
ax.set_ylabel(r"$T$ [mK]")
ax.set_xticks(xticks + 0.5)
ax.set_xticklabels(xticklabels, rotation=0)
ax.set_yticks(yticks + 0.5)
ax.set_yticklabels(yticklabels, rotation=0)

# plot D-Wave temp
T_DW = 16.4 / 2
ax.axhline(T_DW, c="w", label=r"$T_{DW} = 16.4 \pm 0.1$ mK")

# plot constant βB(s)
s_mins = np.round(np.arange(50, 101, 1) / 100, 2)
T_mins = []
B_mins = []
βB_mins = []
for s in s_mins:
    i = np.argmin(dkl.loc[:, s])
    T_mins.append(dkl.index[i])
    B_mins.append(df_anneal.loc[s, "B(s) (GHz)"])
    βB_mins.append(B_mins[-1] / k_B / T_mins[-1])
T_mins = 1 / k_B / np.mean(βB_mins) * np.array(B_mins)
ax.plot(
    s_mins * 100,
    T_mins / 2,
    linestyle="--",
    color="k",
    label=fr"$B(s) / T = {np.mean(βB_mins) * k_B:.3f}}}$",
)

ax.legend()
plt.tight_layout()
plt.savefig(plot_dir_heatmap / f"dkl_min_heatmap.png")

<Figure size 2400x1800 with 2 Axes>

time: 716 ms (started: 2022-02-23 14:19:28 +01:00)
