# Code: Laplace Approximation for Bayesian Deep Learning

In [None]:
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio

In [None]:
def get_data(model, data, estimator, plot):
    filename = f"{model}_{data}_{estimator}_{plot}"
    path = f"/home/matthias/Data/Ubuntu/git/hummat.github.io/data/{filename}.npy"
    try:
        return np.load(path, allow_pickle=True).item()
    except FileNotFoundError:
        !git clone https://github.com/hummat/curvature.git
        from curvature.curvature.utils import expected_calibration_error, calibration_curve

        data = np.load(f"/home/matthias/Data/Ubuntu/results/{model}/data/{estimator}/{model}_{data}.npz")
        
        labels, probabilities, bnn_predictions = data["labels"], data["predictions"], data["bnn_predictions"]
        
        if plot == "reliability":
            ece, bin_aces, bin_accs, bin_confs = expected_calibration_error(probabilities, labels)
        elif plot == "calibration":
            ece, bin_confs, bin_accs, _ = calibration_curve(probabilities, labels)
            bin_aces = bin_confs - bin_accs
        acc = 100 * np.mean(np.argmax(probabilities, axis=1) == labels)
        
        !rm -rf curvature
        
        np.save(path,
               {"acc": acc,
                "ece": ece,
                "bin_aces": bin_aces,
                "bin_accs": bin_accs,
                "bin_confs": bin_confs})

        return np.load(path, allow_pickle=True).item()

## Reliability diagram

In [None]:
# Reliability diagram for DenseNet121 on ImageNet (diagonal estimator):
# stacked bars of per-bin accuracy and per-bin calibration gap, plus the
# diagonal that a perfectly calibrated model would follow.
data = get_data("densenet121", "imagenet", "diag", "reliability")
# Bar centres for confidence bins of width 0.1.
# NOTE(review): this yields 11 centres, the last at 1.05 (outside the x-axis
# range [0, 1]) — confirm it matches the number of bins returned by get_data.
x = np.linspace(0.05, 1.05, 11)

fig = go.Figure(data=[go.Bar(name=f"Accuracy | {data['acc']:.2f}%",
                             x=x,
                             y=data["bin_accs"],
                             hoverinfo="text",
                             hovertext=[f"{100 * acc:.2f}%" for acc in data["bin_accs"]],
                             width=0.1,
                             marker_line_color="black"),
                      go.Bar(name=f"ECE | {100 * data['ece']:.2f}%",
                             x=x,
                             y=data["bin_aces"],
                             hoverinfo="text",
                             hovertext=[f"{100 * ace:.2f}%" for ace in data["bin_aces"]],
                             width=0.1,
                             marker=dict(color="rgba(255, 0, 0, 0.5)",
                                         line=dict(color="crimson"))),
                      go.Scatter(x=[0, 0.5, 1],
                                 y=[0, 0.5, 1],
                                 name="Perfect calibration",
                                 hoverinfo="none",
                                 mode="lines",
                                 marker=dict(color="black"),
                                 line=dict(dash="dash",
                                           width=1))])
grid = False
visible = True
fig.update_layout(
    xaxis=dict(
        visible=visible,
        range=[0, 1],
        showgrid=grid,
        constrain="domain",
        # `title.font` replaces the deprecated `titlefont` axis attribute.
        title=dict(text="Confidence", font=dict(size=16)),
        tickfont=dict(size=14)),
    yaxis=dict(
        visible=visible,
        # Leave room below zero for negative calibration gaps.
        range=[data["bin_aces"].min() - 0.03, 1],
        showgrid=grid,
        # Lock the y scale to the x scale so the diagonal stays at 45 degrees.
        scaleanchor='x',
        title=dict(text="Accuracy", font=dict(size=16)),
        tickfont=dict(size=14)),
    legend=dict(
        traceorder="normal",
        font=dict(size=16),
        x=0.18,
        y=0.98,
        bgcolor='rgba(0, 0, 0, 0)',
        bordercolor='rgba(0, 0, 0, 0)'),
    barmode="stack",
    template="plotly_white",
    hoverlabel=dict(font_size=18),
    hovermode='x',
    height=700,
    margin=dict(r=0, l=0, b=0, t=0, pad=0))

In [None]:
# Save the current figure as an HTML fragment (not a full page), loading
# plotly.js from the CDN and hiding the mode bar.
# NOTE(review): `filename` is not assigned at notebook scope in the visible
# cells (it is a local of get_data) — confirm it is defined before running this.
pio.write_html(
    fig,
    file=f"../_includes/figures/{filename}.html",
    full_html=False,
    include_plotlyjs="cdn",
    config={"displayModeBar": False},
)

## Calibration curve

In [None]:
from matplotlib import pyplot as plt

# Calibration curve (matplotlib): per-bin calibration gap against confidence
# on a logit-scaled x axis, with a dashed zero line marking perfect calibration.
data = get_data("densenet121", "imagenet", "diag", "calibration")
fig, ax = plt.subplots(figsize=(12, 7), tight_layout=True)

# Open frame: hide the top and right spines and their ticks.
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.tick_params(direction='out', labelsize=14, right=False, top=False)
ax.set_xlabel('Confidence', fontsize=16)
# The y axis shows the gap, not raw accuracy; the earlier 'Accuracy' label
# was immediately overwritten and has been dropped.
ax.set_ylabel('Confidence - Accuracy', fontsize=16)

ax.axhline(0, color='black', linestyle='--', linewidth=1)
# NOTE(review): the label below is only displayed if ax.legend() is called,
# which this cell never does — confirm whether a legend is wanted.
ax.plot(data["bin_confs"], data["bin_aces"], marker='o',
        label=f"DenseNet121 | ECE: {100 * data['ece']:.2f}%",
        linewidth=2, linestyle='-', alpha=0.3, color='crimson')

# The logit scale is undefined at exactly 1, so the axis stops at 0.999999
# and that tick is labelled "1".
ax.set_xscale('logit')
ax.set_xlim(0.1, 0.999999)
ax.minorticks_off()
plt.xticks([0.2, 0.759, 0.927, 0.978, 0.993, 0.998, 0.999999],
           labels=[0.2, 0.759, 0.927, 0.978, 0.993, 0.998, 1])
plt.show()

In [None]:
def logit(p):
    """Return the log-odds ``log(p) - log(1 - p)`` of ``p``, element-wise.

    ``p`` may be a scalar or any array-like; it is converted with
    ``np.asarray`` before the computation.
    """
    probs = np.asarray(p)
    log_odds = np.log(probs) - np.log(1 - probs)
    return log_odds

In [None]:
# Compare the calibration curves of the DIAG, KFAC and EFB estimators for
# DenseNet121 on ImageNet, on a logit-transformed confidence axis.
data = get_data("densenet121", "imagenet", "diag", "calibration")
data1 = get_data("densenet121", "imagenet", "kfac", "calibration")
data2 = get_data("densenet121", "imagenet", "efb", "calibration")
x = logit(data["bin_confs"])

# Each run is plotted against its own bin confidences; previously every trace
# reused the DIAG run's x values.
fig = go.Figure([
    go.Scatter(
        name=name,
        x=logit(stats["bin_confs"]),
        y=stats["bin_aces"],
        hovertext=[f"{100 * ace:.2f}%" for ace in stats["bin_aces"]],
        hoverinfo="text+name")
    for name, stats in (("DIAG", data), ("KFAC", data1), ("EFB", data2))])

grid = False
visible = True
# Tick positions in probability space; the last one stands in for 1, whose
# logit is infinite, and is labelled "1" below.
tickvals = [0.2, 0.759, 0.927, 0.978, 0.993, 0.998, 0.9999]
fig.update_layout(
    xaxis=dict(
        zeroline=False,
        visible=visible,
        range=logit([0.1, 0.9999]),
        tickmode="array",
        tickvals=logit(tickvals),
        ticktext=[0.2, 0.759, 0.927, 0.978, 0.993, 0.998, 1],
        showgrid=grid,
        # `title.font` replaces the deprecated `titlefont` axis attribute.
        title=dict(text="Confidence", font=dict(size=16)),
        tickfont=dict(size=14)),
    yaxis=dict(
        visible=visible,
        showgrid=grid,
        title=dict(text="Confidence - Accuracy", font=dict(size=16)),
        tickfont=dict(size=14)),
    legend=dict(
        traceorder="normal",
        font=dict(size=16),
        x=1,
        y=1,
        bgcolor='rgba(0, 0, 0, 0)',
        bordercolor='rgba(0, 0, 0, 0)'),
    barmode="stack",
    template="plotly_white",
    hoverlabel=dict(font_size=18),
    hovermode="x",
    height=700,
    margin=dict(r=0, l=0, b=0, t=0, pad=0))