# Code: Laplace Approximation for Bayesian Deep Learning

In [None]:
import sys
import numpy as np
from statsmodels.distributions.empirical_distribution import ECDF
import plotly.graph_objects as go
import plotly.io as pio
import plotly.figure_factory as ff
from plotly.colors import qualitative

In [None]:
# Model identifiers grouped by architecture family. `all_models` is the
# concatenation of every group, in the order listed here.
resnets = ["resnet18", "resnet34", "resnet50", "resnet101", "resnet152"]
new_resnets = [
    "wide_resnet50_2",
    "wide_resnet101_2",
    "resnext50_32x4d",
    "resnext101_32x8d",
]
densenets = ["densenet121", "densenet161", "densenet169", "densenet201"]
vggs = ["vgg11", "vgg13", "vgg16", "vgg19"]
small = ["alexnet", "shufflenet_v2_x1_0", "mobilenet_v2", "squeezenet1_1", "mnasnet1_0"]
inception = ["googlenet", "inception_v3"]

all_models = [*resnets, *new_resnets, *densenets, *vggs, *small, *inception]

In [None]:
def logit(p):
    """Return the log-odds ``log(p) - log(1 - p)`` of ``p``, elementwise.

    ``p`` may be a scalar or any array-like; it is converted with
    ``np.asarray``. Probabilities of exactly 0 or 1 map to -inf / +inf
    (NumPy emits a divide-by-zero warning).
    """
    prob = np.asarray(p)
    log_odds = np.log(prob) - np.log(1 - prob)
    return log_odds

In [None]:
def get_data(model, data, estimator, plot, reload=False):
    """Load (or compute and cache) plot-ready statistics for one model/estimator.

    Cached results live in ``<model>_<data>_<estimator>_<plot>.npy`` (``.npz``
    for ``plot == "fgsm"``) in the blog's data directory. On a cache miss, or
    when ``reload`` is True, the raw predictions are read from
    ``/volume/USERSTORE/humt_ma/<model>/data/<estimator>/<model>_<data>.npz``.
    The statistics for ``plot`` are then computed with helpers from the
    ``curvature`` package, saved to the cache, and loaded back.

    Args:
        model: Model name, e.g. ``"resnet50"``.
        data: Dataset name used in the file names, e.g. ``"imagenet"``.
        estimator: Curvature estimator name, e.g. ``"diag"`` or ``"kfac"``.
        plot: ``"reliability"``, ``"calibration"`` or ``"entropy"`` (these can
            be computed on a cache miss), or ``"fgsm"`` (cache only).
        reload: If True, ignore any cached file and recompute.

    Returns:
        The cached dict of statistics. Keys without a ``bnn_`` prefix describe
        the deterministic predictions; ``bnn_`` keys describe the BNN
        predictions. For ``plot == "fgsm"`` the loaded ``NpzFile`` is returned.

    Raises:
        FileNotFoundError: If neither the cache nor the raw predictions exist,
            or if ``plot`` has no computation branch (nothing gets cached).
    """
    filename = f"{model}_{data}_{estimator}_{plot}"
    path = f"/home/ManhLab/Data/Ubuntu/git/hummat.github.io/data/{filename}.npy"

    if not reload:
        try:
            # allow_pickle is needed because the cache stores a pickled dict;
            # only load cache files written by this function.
            if plot == "fgsm":
                return np.load(path.replace(".npy", ".npz"), allow_pickle=True)
            return np.load(path, allow_pickle=True).item()
        except FileNotFoundError:
            pass  # Cache miss: compute the statistics below.

    curvature_repo = "/home/ManhLab/Data/Ubuntu/git/curvature"
    if curvature_repo not in sys.path:  # avoid a duplicate entry on every cache miss
        sys.path.append(curvature_repo)
    from curvature.utils import expected_calibration_error, calibration_curve, predictive_entropy

    # A missing raw file raises FileNotFoundError, which the calling cells catch
    # to skip models without data. The `with` block closes the NpzFile handle
    # once the arrays have been read into memory.
    raw_path = f"/volume/USERSTORE/humt_ma/{model}/data/{estimator}/{model}_{data}.npz"
    with np.load(raw_path) as raw:
        labels = raw["labels"]
        probabilities = raw["predictions"]
        bnn_probabilities = raw["bnn_predictions"]
        ood_probabilities = raw["ood_predictions"]
        bnn_ood_probabilities = raw["bnn_ood_predictions"]

    if plot in ["reliability", "calibration"]:
        if plot == "reliability":
            ece, aces, accs, confs = expected_calibration_error(probabilities, labels)
            bnn_ece, bnn_aces, bnn_accs, bnn_confs = expected_calibration_error(bnn_probabilities, labels)
        else:
            ece, confs, accs, _ = calibration_curve(probabilities, labels)
            bnn_ece, bnn_confs, bnn_accs, _ = calibration_curve(bnn_probabilities, labels)
            # Per-bin gap between confidence and accuracy.
            aces = confs - accs
            bnn_aces = bnn_confs - bnn_accs
        # Top-1 accuracy, in percent.
        acc = 100 * np.mean(np.argmax(probabilities, axis=1) == labels)
        bnn_acc = 100 * np.mean(np.argmax(bnn_probabilities, axis=1) == labels)

        np.save(path,
                {"acc": acc,
                 "ece": ece,
                 "aces": aces,
                 "accs": accs,
                 "confs": confs,
                 "bnn_acc": bnn_acc,
                 "bnn_ece": bnn_ece,
                 "bnn_aces": bnn_aces,
                 "bnn_accs": bnn_accs,
                 "bnn_confs": bnn_confs})
    elif plot == "entropy":
        np.save(path, {"num_classes": probabilities.shape[1],
                       "pred_ent": predictive_entropy(probabilities),
                       "bnn_pred_ent": predictive_entropy(bnn_probabilities),
                       "ood_pred_ent": predictive_entropy(ood_probabilities),
                       "bnn_ood_pred_ent": predictive_entropy(bnn_ood_probabilities)})

    # Return the freshly written cache entry.
    return np.load(path, allow_pickle=True).item()

In [None]:
# Preprocess data: populate the cache for every model / estimator / plot
# combination. Combinations without raw predictions are reported and skipped.
combinations = [(m, e, p)
                for m in all_models
                for e in ["diag", "kfac", "efb"]
                for p in ["reliability", "calibration", "entropy"]]
for model, estimator, plot in combinations:
    try:
        get_data(model, "imagenet", estimator, plot, reload=False)
    except FileNotFoundError:
        print(f"No data for {model}, {estimator}, {plot}.")

## Reliability diagram

In [None]:
# Choose a model from 'all_models' above and an estimator from 'sgd', 'diag', 'efb' or 'kfac'.
# 'sgd' is the deterministic NN.
model = "densenet121"
estimator = "sgd"

# The deterministic ('sgd') statistics are stored in the K-FAC cache under the
# un-prefixed keys; every other estimator uses the 'bnn_'-prefixed keys.
data = get_data(model, "imagenet", estimator if estimator != "sgd" else "kfac", plot="reliability")
accs = data["bnn_accs"] if estimator != "sgd" else data["accs"]
aces = data["bnn_aces"] if estimator != "sgd" else data["aces"]
acc = data["bnn_acc"] if estimator != "sgd" else data["acc"]
ece = data["bnn_ece"] if estimator != "sgd" else data["ece"]

# Bar centres, spaced by the 0.1 bar width and starting at 0.05.
# NOTE(review): 11 centres (the last at 1.05, outside the [0, 1] x range) —
# confirm this matches the number of bins returned by expected_calibration_error.
x = np.linspace(0.05, 1.05, 11)

# Stacked bars (per-bin accuracy with the ACE bars on top) plus the dashed
# diagonal of perfect calibration.
fig = go.Figure(data=[go.Bar(name=f"Accuracy | Avg.: {acc:.2f}%",
                             x=x,
                             y=accs,
                             hoverinfo="text",
                             hovertext=[f"{100 * bin_acc:.2f}%" for bin_acc in accs],
                             width=0.1,
                             marker=dict(line=dict(color="black",
                                                   width=1))),
                      go.Bar(name=f"ACE | ECE: {100 * ece:.2f}%",
                             x=x,
                             y=aces,
                             hoverinfo="text",
                             hovertext=[f"{100 * ace:.2f}%" for ace in aces],
                             width=0.1,
                             marker=dict(color="rgba(255, 0, 0, 0.5)",
                                         line=dict(color="crimson",
                                                   width=1))),
                      go.Scatter(x=[0, 0.5, 1],
                                 y=[0, 0.5, 1],
                                 name="Perfect calibration",
                                 hoverinfo="none",
                                 mode="lines",
                                 marker=dict(color="black"),
                                 line=dict(dash="dash",
                                           width=1))])
grid = False
visible = True
# Axis titles use title=dict(text=..., font=...): the older `titlefont`
# attribute was removed in plotly.js 3 / plotly.py 6.
fig.update_layout(
    xaxis=dict(
        visible=visible,
        range=[0, 1],
        showgrid=grid,
        constrain="domain",
        title=dict(text="Confidence", font=dict(size=16)),
        tickfont_size=14),
    yaxis=dict(
        visible=visible,
        range=[aces.min() - 0.03, 1],
        showgrid=grid,
        scaleanchor="x",
        title=dict(text="Accuracy", font=dict(size=16)),
        tickfont_size=14),
    legend=dict(
        traceorder="normal",
        font=dict(size=16),
        x=0.08,
        y=0.98,
        bgcolor='rgba(0, 0, 0, 0)',
        bordercolor='rgba(0, 0, 0, 0)'),
    barmode="stack",
    template="plotly_white",
    hoverlabel=dict(font_size=18),
    hovermode='x',
    height=700,
    margin=dict(r=0, l=0, b=0, t=0, pad=0))
fig.show()

In [None]:
# Save the figure as an embeddable HTML fragment (plotly.js from CDN, no mode bar).
filename = f"{model}_imagenet_{estimator}_reliability"
pio.write_html(
    fig,
    f"../_includes/figures/curvature/{filename}.html",
    config={"displayModeBar": False},
    include_plotlyjs="cdn",
    full_html=False,
)

## Calibration curve

### 1. Estimator comparison

In [None]:
model = "densenet161"
estimators = ["sgd", "diag", "kfac"]

# Running extremes of the calibration gap over all estimators, used for the
# optional "error range" band. max_ace starts at 0, so the band always reaches 0.
min_ace, max_ace = 100., 0.
fig = go.Figure()
for estimator in estimators:
    # 'sgd' (deterministic NN) statistics come from the K-FAC cache's un-prefixed keys.
    data = get_data(model, "imagenet", estimator if estimator != "sgd" else "kfac", "calibration")
    confs = data["bnn_confs"] if estimator != "sgd" else data["confs"]
    accs = data["bnn_accs"] if estimator != "sgd" else data["accs"]
    aces = data["bnn_aces"] if estimator != "sgd" else data["aces"]
    acc = data["bnn_acc"] if estimator != "sgd" else data["acc"]
    ece = data["bnn_ece"] if estimator != "sgd" else data["ece"]

    min_ace = min(min_ace, aces.min())
    max_ace = max(max_ace, aces.max())

    # x is logit(confidence); the tick labels set below map it back to confidence.
    fig.add_trace(go.Scatter(mode="markers+lines",
                             name=f"{estimator.upper()}   |   ECE: {100 * ece:.2f}%",
                             x=logit(confs),
                             y=aces,
                             hovertext=[f"{100 * ace:.2f}%" for ace in aces],
                             hoverinfo="text+name"))

# Invisible, legend-only spacer entry.
fig.add_trace(go.Scatter(name="",
                         visible="legendonly",
                         hoverinfo="none",
                         x=[0],
                         y=[0],
                         marker_color="rgba(0, 0, 0, 0)"))
# Toggleable shaded band spanning [min_ace, max_ace] over the plotted confidence range.
fig.add_trace(go.Scatter(x=logit([0.1, 0.1, 0.999999, 0.999999, 0.1]),
                         y=[min_ace, max_ace, max_ace, min_ace, min_ace],
                         mode="lines",
                         fill="toself",
                         hoverinfo="none",
                         line=dict(width=0),
                         fillcolor="rgba(255, 0, 0, 0.1)",
                         visible="legendonly",
                         name="error range"))

grid = False
visible = True
tickvals = [0.2, 0.759, 0.927, 0.978, 0.993, 0.998, 0.999999]
# Axis titles use title=dict(text=..., font=...): the older `titlefont`
# attribute was removed in plotly.js 3 / plotly.py 6.
fig.update_layout(
    xaxis=dict(
        constrain="domain",
        zeroline=False,
        visible=visible,
        range=logit([0.1, 0.999999]),
        tickmode="array",
        tickvals=logit(tickvals),
        ticktext=[0.2, 0.759, 0.927, 0.978, 0.993, 0.998, 1],
        showgrid=grid,
        title=dict(text="Confidence", font=dict(size=16)),
        tickfont_size=14),
    yaxis=dict(
        zerolinewidth=2,
        scaleanchor="x",
        scaleratio=50,
        visible=visible,
        showgrid=grid,
        title=dict(text="Confidence - Accuracy", font=dict(size=16)),
        tickfont_size=14),
    legend=dict(
        traceorder="normal",
        font=dict(size=16),
        xanchor="right",
        yanchor="top",
        bgcolor='rgba(0, 0, 0, 0)',
        bordercolor='rgba(0, 0, 0, 0)'),
    template="plotly_white",
    # Hover names are truncated to the estimator name plus 3 characters.
    hoverlabel=dict(font_size=18,
                    namelength=max([len(est) for est in estimators]) + 3),
    hovermode="x",
    height=500,
    margin=dict(r=0, l=0, b=0, t=0, pad=0))

In [None]:
# Save the figure as an embeddable HTML fragment (plotly.js from CDN, no mode bar).
filename = f"{model}_imagenet_calibration"
pio.write_html(
    fig,
    f"../_includes/figures/curvature/{filename}.html",
    config={"displayModeBar": False},
    include_plotlyjs="cdn",
    full_html=False,
)

### 2. Network comparison

In [None]:
compare = "densenets"
estimator = "sgd"

prefix = compare + '_'
# Comparison groups. "all" uses the 24-colour Light24 palette; every other
# group uses Plotly's default qualitative palette.
model_groups = {"all": all_models,
                "resnets": resnets,
                "new_resnets": new_resnets,
                "densenets": densenets,
                "vggs": vggs,
                "small": small,
                "inception": inception}
if compare not in model_groups:
    raise ValueError(f"Unknown comparison group {compare!r}; expected one of {list(model_groups)}.")
models = model_groups[compare]
colors = qualitative.Light24 if compare == "all" else qualitative.Plotly

# Running extremes of the calibration gap over all models, used for the optional
# "error range" band. max_ace starts at 0, so the band always reaches 0.
min_ace, max_ace = 100., 0.
fig = go.Figure()
index = 0
for model in models:
    try:
        data = get_data(model, "imagenet", estimator if estimator != "sgd" else "kfac", "calibration")
    except FileNotFoundError:
        print(f"No data for {model} and estimator {estimator}.")
        continue
    confs = data["bnn_confs"] if estimator != "sgd" else data["confs"]
    accs = data["bnn_accs"] if estimator != "sgd" else data["accs"]
    aces = data["bnn_aces"] if estimator != "sgd" else data["aces"]
    acc = data["bnn_acc"] if estimator != "sgd" else data["acc"]
    ece = data["bnn_ece"] if estimator != "sgd" else data["ece"]

    min_ace = min(min_ace, aces.min())
    max_ace = max(max_ace, aces.max())

    fig.add_trace(go.Scatter(mode="markers+lines",
                             # Cycle through the palette when there are more models than colours.
                             line_color=colors[index % len(colors)],
                             name=f"{model.capitalize()}   |   ECE: {100 * ece:.2f}%",
                             x=logit(confs),
                             y=aces,
                             hovertext=[f"{100 * ace:.2f}%" for ace in aces],
                             hoverinfo="text+name",
                             visible="legendonly" if model in ["inception_v3", "googlenet"] else True))
    index += 1

# Invisible, legend-only spacer entry.
fig.add_trace(go.Scatter(name="",
                         visible="legendonly",
                         hoverinfo="none",
                         x=[0],
                         y=[0],
                         marker_color="rgba(0, 0, 0, 0)"))
# Toggleable shaded band spanning [min_ace, max_ace] over the plotted confidence range.
fig.add_trace(go.Scatter(x=logit([0.1, 0.1, 0.999999, 0.999999, 0.1]),
                         y=[min_ace, max_ace, max_ace, min_ace, min_ace],
                         mode="lines",
                         fill="toself",
                         hoverinfo="none",
                         line=dict(width=0),
                         fillcolor="rgba(255, 0, 0, 0.1)",
                         visible="legendonly",
                         name="error range"))

grid = False
visible = True
tickvals = [0.2, 0.759, 0.927, 0.978, 0.993, 0.998, 0.999999]
# Axis titles use title=dict(text=..., font=...): the older `titlefont`
# attribute was removed in plotly.js 3 / plotly.py 6.
fig.update_layout(
    xaxis=dict(
        constrain="domain",
        zeroline=False,
        visible=visible,
        range=logit([0.1, 0.999999]),
        tickmode="array",
        tickvals=logit(tickvals),
        ticktext=[0.2, 0.759, 0.927, 0.978, 0.993, 0.998, 1],
        showgrid=grid,
        title=dict(text="Confidence", font=dict(size=16)),
        tickfont_size=14),
    yaxis=dict(
        zerolinewidth=2,
        scaleanchor="x",
        scaleratio=50,
        visible=visible,
        showgrid=grid,
        title=dict(text="Confidence - Accuracy", font=dict(size=16)),
        tickfont_size=14),
    legend=dict(
        traceorder="normal",
        font=dict(size=16),
        xanchor="right",
        yanchor="top",
        bgcolor='rgba(0, 0, 0, 0)',
        bordercolor='rgba(0, 0, 0, 0)'),
    template="plotly_white",
    # Hover names are truncated to the longest model name plus 3 characters.
    hoverlabel=dict(font_size=18,
                    namelength=max([len(model) for model in models]) + 3),
    hovermode="x",
    height=500,
    margin=dict(r=0, l=0, b=0, t=0, pad=0))

In [None]:
# Save the figure as an embeddable HTML fragment (plotly.js from CDN, no mode bar).
filename = f"{prefix}{estimator}_imagenet_calibration"
pio.write_html(
    fig,
    f"../_includes/figures/curvature/{filename}.html",
    config={"displayModeBar": False},
    include_plotlyjs="cdn",
    full_html=False,
)

## Uncertainty NN vs BNN

### 1. Estimator comparison using ECDF

In [None]:
model = "resnet50"
colors = qualitative.Plotly
fig = go.Figure()

# For each estimator, plot 1 - ECDF of the predictive entropy: dashed for the
# in-domain data, solid for the out-of-domain data. Both share a colour and a
# legend group, so one legend click toggles the pair.
for index, estimator in enumerate(["sgd", "diag", "kfac"]):
    # 'sgd' (deterministic NN) statistics come from the K-FAC cache's un-prefixed keys.
    data = get_data(model, "imagenet", estimator if estimator != "sgd" else "kfac", plot="entropy")
    # log(num_classes) is the maximum entropy of a categorical distribution
    # (assumes predictive_entropy uses the natural log — TODO confirm).
    x_lim = np.log(data["num_classes"])
    entropy_range = np.linspace(0, x_lim, data["num_classes"])

    ecdf = ECDF(data["bnn_pred_ent"] if estimator != "sgd" else data["pred_ent"])
    ood_ecdf = ECDF(data["bnn_ood_pred_ent"] if estimator != "sgd" else data["ood_pred_ent"])

    fig.add_trace(go.Scatter(x=entropy_range,
                             y=1 - ecdf(entropy_range),
                             marker_color=colors[index],
                             line=dict(dash="dash",
                                       width=2.5),
                             showlegend=False,
                             hovertemplate="%{y:.2f}<extra></extra>",
                             legendgroup=index,
                             visible="legendonly" if estimator in ["kfac", "efb"] else True))
    fig.add_trace(go.Scatter(name=estimator.upper(),
                             x=entropy_range,
                             y=1 - ood_ecdf(entropy_range),
                             marker_color=colors[index],
                             line_width=2.5,
                             hovertemplate="%{y:.2f}",
                             legendgroup=index,
                             hoverinfo="x+y+name",
                             visible="legendonly" if estimator in ["kfac", "efb"] else True))

# Legend-only dummy traces (zero-length, at the origin): a blank spacer, then a
# key for the line styles. hoverinfo="none" keeps them out of the x-hover
# labels near x = 0.
fig.add_trace(go.Scatter(name="",
                         hoverinfo="none",
                         x=[0, 0],
                         y=[0, 0],
                         marker_color="rgba(0, 0, 0, 0)"))
fig.add_trace(go.Scatter(name="out-of-domain",
                         mode="lines",
                         hoverinfo="none",
                         x=[0, 0],
                         y=[0, 0],
                         marker_color="black"))
fig.add_trace(go.Scatter(name="in-domain",
                         mode="lines",
                         hoverinfo="none",
                         x=[0, 0],
                         y=[0, 0],
                         marker_color="black",
                         line=dict(dash="dash")))

# Axis titles use title=dict(text=..., font=...): the older `titlefont`
# attribute was removed in plotly.js 3 / plotly.py 6.
fig.update_layout(
    xaxis=dict(
        #constrain="domain",
        zeroline=False,
        range=[-0.1, np.ceil(x_lim)],
        showgrid=False,
        title=dict(text="Predictive entropy", font=dict(size=16)),
        tickfont_size=14),
    yaxis=dict(
        #scaleanchor="x",
        #scaleratio=4,
        zeroline=False,
        range=[0, 1.01],
        showgrid=False,
        title=dict(text="1-ecdf", font=dict(size=16)),
        tickfont_size=14),
    legend=dict(
        traceorder="normal",
        font=dict(size=16),
        xanchor="right",
        yanchor="top",
        bgcolor='rgba(0, 0, 0, 0)',
        bordercolor='rgba(0, 0, 0, 0)'),
    template="plotly_white",
    hoverlabel=dict(font_size=18),
    hovermode="x",
    #width=700,
    height=300,
    margin=dict(r=0, l=0, b=0, t=0, pad=0))

In [None]:
# Save the figure as an embeddable HTML fragment (plotly.js from CDN, no mode bar).
filename = f"{model}_imagenet_ecdf"
pio.write_html(
    fig,
    f"../_includes/figures/curvature/{filename}.html",
    config={"displayModeBar": False},
    include_plotlyjs="cdn",
    full_html=False,
)

### 2. Network comparison using ECDF

In [None]:
compare = "resnets"
estimator = "sgd"

prefix = compare + '_'
# Comparison groups. "all" uses the 24-colour Light24 palette; every other
# group uses Plotly's default qualitative palette.
model_groups = {"all": all_models,
                "resnets": resnets,
                "new_resnets": new_resnets,
                "densenets": densenets,
                "vggs": vggs,
                "small": small,
                "inception": inception}
if compare not in model_groups:
    raise ValueError(f"Unknown comparison group {compare!r}; expected one of {list(model_groups)}.")
models = model_groups[compare]
colors = qualitative.Light24 if compare == "all" else qualitative.Plotly

# For each model, plot 1 - ECDF of the predictive entropy: dashed for the
# in-domain data, solid for the out-of-domain data, sharing a colour and legend group.
fig = go.Figure()
index = 0
for model in models:
    try:
        data = get_data(model, "imagenet", estimator if estimator != "sgd" else "kfac", plot="entropy")
    except FileNotFoundError:
        print(f"No data for {model} and estimator {estimator}.")
        continue
    # log(num_classes) is the maximum entropy of a categorical distribution
    # (assumes predictive_entropy uses the natural log — TODO confirm).
    x_lim = np.log(data["num_classes"])
    entropy_range = np.linspace(0, x_lim, data["num_classes"])

    ecdf = ECDF(data["bnn_pred_ent"] if estimator != "sgd" else data["pred_ent"])
    ood_ecdf = ECDF(data["bnn_ood_pred_ent"] if estimator != "sgd" else data["ood_pred_ent"])

    # Cycle through the palette so groups larger than it do not raise IndexError.
    color = colors[index % len(colors)]
    fig.add_trace(go.Scatter(x=entropy_range,
                             y=1 - ecdf(entropy_range),
                             marker_color=color,
                             line=dict(dash="dash",
                                       width=2.5),
                             showlegend=False,
                             hovertemplate="%{y:.2f}<extra></extra>",
                             legendgroup=index))
    fig.add_trace(go.Scatter(name=model.capitalize(),
                             x=entropy_range,
                             y=1 - ood_ecdf(entropy_range),
                             marker_color=color,
                             line_width=2.5,
                             hovertemplate="%{y:.2f}",
                             hoverinfo="x+y+name",
                             legendgroup=index))
    index += 1

# Legend-only dummy traces (zero-length, at the origin): a blank spacer, then a
# key for the line styles.
fig.add_trace(go.Scatter(name="",
                         visible="legendonly",
                         hoverinfo="none",
                         x=[0, 0],
                         y=[0, 0],
                         marker_color="rgba(0, 0, 0, 0)"))
fig.add_trace(go.Scatter(name="out-of-domain",
                         mode="lines",
                         hoverinfo="none",
                         x=[0, 0],
                         y=[0, 0],
                         marker_color="black"))
fig.add_trace(go.Scatter(name="in-domain",
                         mode="lines",
                         hoverinfo="none",
                         x=[0, 0],
                         y=[0, 0],
                         marker_color="black",
                         line=dict(dash="dash")))

# Axis titles use title=dict(text=..., font=...): the older `titlefont`
# attribute was removed in plotly.js 3 / plotly.py 6.
fig.update_layout(
    xaxis=dict(
        #constrain="domain",
        zeroline=False,
        range=[-0.1, np.ceil(x_lim)],
        showgrid=False,
        title=dict(text="Predictive entropy", font=dict(size=16)),
        tickfont_size=14),
    yaxis=dict(
        #scaleanchor="x",
        #scaleratio=3.5,
        zeroline=False,
        range=[-0.01, 1.01],
        showgrid=False,
        title=dict(text="1-ecdf", font=dict(size=16)),
        tickfont_size=14),
    legend=dict(
        traceorder="normal",
        font=dict(size=16),
        xanchor="right",
        yanchor="top",
        bgcolor='rgba(0, 0, 0, 0)',
        bordercolor='rgba(0, 0, 0, 0)'),
    template="plotly_white",
    hoverlabel=dict(font_size=18),
    hovermode="x",
    height=300,
    margin=dict(r=0, l=0, b=0, t=0, pad=0))

In [None]:
# Save the figure as an embeddable HTML fragment (plotly.js from CDN, no mode bar).
filename = f"{prefix}{estimator}_imagenet_ecdf"
pio.write_html(
    fig,
    f"../_includes/figures/curvature/{filename}.html",
    config={"displayModeBar": False},
    include_plotlyjs="cdn",
    full_html=False,
)

### In- vs out-of-distribution histogram

In [None]:
model = "resnet50"
estimator = "kfac"
normalize = False  # True: each bar shows the fraction of samples in its bin (histnorm="probability")
overlay = True     # True: draw both histograms on top of each other, semi-transparent

# 'sgd' (deterministic NN) statistics come from the K-FAC cache's un-prefixed keys.
data = get_data(model, "imagenet", estimator if estimator != "sgd" else "kfac", plot="entropy")
in_domain = data["pred_ent"] if estimator == "sgd" else data["bnn_pred_ent"]
out_of_domain = data["ood_pred_ent"] if estimator == "sgd" else data["bnn_ood_pred_ent"]

# Both histograms share bingroup=1, so their bins line up.
fig = go.Figure(data=[go.Histogram(x=in_domain,
                                   nbinsx=100,
                                   name="Known domain",
                                   histnorm="probability" if normalize else "",
                                   marker=dict(line=dict(color="black", width=1 if overlay else 0)),
                                   hoverinfo="x+y+name",
                                   bingroup=1),
                      go.Histogram(x=out_of_domain,
                                   nbinsx=100,
                                   name="Unknown domain",
                                   histnorm="probability" if normalize else "",
                                   marker=dict(line=dict(color="black", width=1 if overlay else 0)),
                                   hoverinfo="x+y+name",
                                   bingroup=1)])
if overlay:
    fig.update_traces(opacity=0.75)

# Axis titles use title=dict(text=..., font=...): the older `titlefont`
# attribute was removed in plotly.js 3 / plotly.py 6.
fig.update_layout(
    xaxis=dict(
        #constrain="domain",
        #range=[-0.1, np.ceil(np.log(data["num_classes"]))],
        zeroline=False,
        showgrid=False,
        title=dict(text="Predictive entropy", font=dict(size=16)),
        tickfont_size=14),
    yaxis=dict(
        #scaleanchor="x",
        #scaleratio=0.00065,
        zeroline=False,
        showgrid=False,
        # histnorm="probability" gives per-bin fractions, not a density.
        title=dict(text="Probability" if normalize else "Frequency", font=dict(size=16)),
        tickfont_size=14),
    legend=dict(
        traceorder="normal",
        font=dict(size=16),
        xanchor="right",
        yanchor="top",
        bgcolor='rgba(0, 0, 0, 0)',
        bordercolor='rgba(0, 0, 0, 0)'),
    barmode="overlay" if overlay else "group",
    template="plotly_white",
    hoverlabel=dict(font_size=18),
    hovermode="x",
    height=500,
    margin=dict(r=0, l=0, b=0, t=0, pad=0))

In [None]:
# Save the figure as an embeddable HTML fragment (plotly.js from CDN, no mode bar).
filename = f"{model}_{estimator}_imagenet_hist"
pio.write_html(
    fig,
    f"../_includes/figures/curvature/{filename}.html",
    config={"displayModeBar": False},
    include_plotlyjs="cdn",
    full_html=False,
)

## The Hessian

In [None]:
# Sketch of the Hessian's structure on a floral-white |W| x |W| canvas. There is
# one salmon square per layer block along the diagonal (top-left to
# bottom-right), and three dots stand in for the omitted middle layers. Each
# block also gets a short line trace whose hover label names the block.
fig = go.Figure()
fig.add_shape(
    type="rect",
    xref="x",
    yref="y",
    x0=0,
    y0=0,
    x1=1,
    y1=1,
    line=dict(width=0),
    fillcolor="floralwhite")

# First three blocks, |W.0| to |W.2|, each 0.1 wide.
index = 0
for w in np.arange(0, 0.3, 0.1):
    fig.add_trace(go.Scatter(x=[w, 0.05 + w],
                             y=[0.95 - w, 1 - w],
                             name=f"|W.{index}|",
                             hoverinfo="name",
                             showlegend=False))
    fig.add_shape(
        type="rect",
        xref="x",
        yref="y",
        x0=w,
        y0=0.9 - w,
        x1=0.1 + w,
        y1=1 - w,
        line=dict(width=0),
        fillcolor="salmon")
    index += 1

# A larger (0.2 wide) block, |W.3|.
fig.add_trace(go.Scatter(x=[0.3, 0.4],
                         y=[0.6, 0.7],
                         name=f"|W.{index}|",
                         hoverinfo="name",
                         showlegend=False))
fig.add_shape(
    type="rect",
    xref="x",
    yref="y",
    x0=0.3,
    y0=0.5,
    x1=0.5,
    y1=0.7,
    line=dict(width=0),
    fillcolor="salmon")

# Three dots for the layers in between.
for i in range(3):
    fig.add_shape(type="circle",
                  xref='x',
                  yref='y',
                  x0=0.54 + (i * 0.05),
                  y0=0.46 - (i * 0.05),
                  x1=0.56 + (i * 0.05),
                  y1=0.44 - (i * 0.05),
                  line_width=0,
                  fillcolor="salmon")

# Last two blocks, |W.L-1| and |W.L|.
fig.add_trace(go.Scatter(x=[0.7, 0.8],
                         y=[0.2, 0.3],
                         name="|W.L-1|",
                         hoverinfo="name",
                         showlegend=False))
fig.add_shape(
    type="rect",
    xref="x",
    yref="y",
    x0=0.7,
    y0=0.3,
    x1=0.9,
    y1=0.1,
    line=dict(width=0),
    fillcolor="salmon")

fig.add_trace(go.Scatter(x=[0.9, 0.95],
                         y=[0.05, 0.1],
                         name="|W.L|",
                         hoverinfo="name",
                         showlegend=False))
fig.add_shape(
    type="rect",
    xref="x",
    yref="y",
    x0=0.9,
    y0=0.1,
    x1=1,
    y1=0,
    line=dict(width=0),
    fillcolor="salmon")

# Square, tick-free axes. Axis titles use title=dict(text=..., font=...): the
# older `titlefont` attribute was removed in plotly.js 3 / plotly.py 6.
fig.update_layout(
    xaxis=dict(title=dict(text="|W|", font=dict(size=16)),
               range=[0, 1],
               constrain="domain",
               tickmode="array",
               tickvals=[],
               tick0=0,
               dtick=1),
    yaxis=dict(title=dict(text="|W|", font=dict(size=16)),
               range=[0, 1],
               scaleanchor='x',
               scaleratio=1,
               constrain="domain",
               tickmode="array",
               tickvals=[],
               tick0=0,
               dtick=1),
    hoverlabel=dict(font_size=18),
    margin=dict(t=0, b=0, l=0, r=0),
    height=350,
    template="plotly_white")

In [None]:
# Save the figure as an embeddable HTML fragment (plotly.js from CDN, no mode bar).
pio.write_html(
    fig,
    "../_includes/figures/hessian.html",
    config={"displayModeBar": False},
    include_plotlyjs="cdn",
    full_html=False,
)

## Variance and Covariance

In [None]:
def multivariate_gaussian(pos, mu, Sigma):
    """Evaluate the multivariate normal density N(mu, Sigma) at every point in ``pos``.

    Args:
        pos: Array whose *last* axis holds the coordinates of a point, e.g. the
            meshgrid arrays of x_1, ..., x_k stacked along a final axis of length k.
        mu: Mean vector of length k.
        Sigma: k x k covariance matrix (must be invertible).

    Returns:
        Array of shape ``pos.shape[:-1]`` holding the density at each point.
    """
    dim = mu.shape[0]
    centered = pos - mu
    precision = np.linalg.inv(Sigma)
    norm_const = np.sqrt((2*np.pi)**dim * np.linalg.det(Sigma))
    # Squared Mahalanobis distance (x - mu)^T Sigma^-1 (x - mu), vectorised
    # over all leading axes of ``pos``.
    mahalanobis_sq = np.einsum('...k,kl,...l->...', centered, precision, centered)
    return np.exp(-mahalanobis_sq / 2) / norm_const

In [None]:
# Bivariate Gaussian density on a 60 x 60 grid over [-5, 5]^2, shown as a 3D
# surface. Three sliders pick the variance of X, the variance of Y and the covariance.
N = 60
x = np.linspace(-5, 5, N)
y = np.linspace(-5, 5, N)
X, Y = np.meshgrid(x, y)

# Zero mean.
mu = np.array([0., 0.])

# Grid points with the (x, y) coordinates stacked on the last axis: shape (N, N, 2).
pos = np.stack([X, Y], axis=-1)

# One covariance matrix per slider position, in the order the sliders index
# into fig.data. Starting from 2 * I, vary Var(X) = 2..6, then Var(Y) = 2..6,
# then Cov(X, Y) = -1..1 in steps of 0.25.
covariances = (
    [np.array([[vx, 0.], [0., 2.]]) for vx in np.arange(2, 7, 1)]
    + [np.array([[2., 0.], [0., vy]]) for vy in np.arange(2, 7, 1)]
    + [np.array([[2., cov], [cov, 2.]]) for cov in np.arange(-1, 1.2, 0.25)]
)

# One initially hidden surface per covariance matrix.
fig = go.Figure()
for cov_matrix in covariances:
    Z = multivariate_gaussian(pos, mu, cov_matrix)
    fig.add_trace(go.Surface(x=X, y=Y, z=Z,
                             opacity=0.7,
                             visible=False,
                             hoverinfo="none"))

# Draw z-contours projected onto the z plane and hide the colour bars.
fig.update_traces(contours_z=dict(show=True, usecolormap=True, project_z=True),
                  showscale=False)

# Start with the first surface (Var(X) = Var(Y) = 2, Cov = 0) shown.
fig.data[0].visible = True

# Slider step i shows trace i only. "visible" is a trace attribute, so it goes
# in the first (trace-update) element of the "update" method's args.
n_traces = len(fig.data)
steps = [dict(label="",
              method="update",
              args=[{"visible": [k == i for k in range(n_traces)]}])
         for i in range(n_traces)]

# The three sliders are stacked vertically via increasing top padding, with tick marks hidden.
hidden_ticks = dict(tickwidth=0, ticklen=0)
sliders = [
    dict(active=0,
         currentvalue={"prefix": "<b>Variance X</b>"},
         pad={"l": 100, "r": 100},
         steps=steps[:5],
         **hidden_ticks),
    dict(active=0,
         currentvalue={"prefix": "<b>Variance Y</b>"},
         pad={"l": 100, "r": 100, "t": 70},
         steps=steps[5:10],
         **hidden_ticks),
    # active=4 selects Cov = 0.
    dict(active=4,
         currentvalue={"prefix": "<b>Covariance</b>"},
         pad={"l": 100, "r": 100, "t": 140, "b": 10},
         steps=steps[10:],
         **hidden_ticks),
]

fig.update_layout(scene=dict(xaxis=dict(visible=False),
                             yaxis=dict(visible=False),
                             zaxis=dict(visible=False)),
                  height=700,
                  margin=dict(r=0, l=0, b=0, t=0, pad=0),
                  sliders=sliders,
                  scene_camera=dict(eye=dict(x=0.6, y=0, z=1),
                                    center=dict(x=0, y=0, z=-0.1)),
                  template="plotly_white")

In [None]:
# Save the figure as an embeddable HTML fragment (plotly.js from CDN, no mode bar).
pio.write_html(
    fig,
    "../_includes/figures/gaussian_covariance.html",
    config={"displayModeBar": False},
    include_plotlyjs="cdn",
    full_html=False,
)

## Kronecker product

In [None]:
def add_four_squares(fig, x0, y0, sidelength, colors):
    """Add a 2x2 grid of filled rectangles to ``fig``.

    The grid's lower-left corner is ``(x0, y0)`` in the data coordinates of
    axes ``"x"``/``"y"``; every cell is ``sidelength`` wide and tall, so the
    whole grid spans ``2 * sidelength`` in each direction.

    Args:
        fig: Figure to draw on; only ``fig.add_shape`` is called.
        x0: Left edge of the grid.
        y0: Bottom edge of the grid.
        sidelength: Side length of a single cell.
        colors: Indexable sequence of (at least) four fill colours, used for
            the bottom-left, top-left, top-right and bottom-right cell in
            that order. Shapes are added in that same order; a shorter
            sequence raises ``IndexError`` after the preceding cells have
            been added.
    """
    # Cell boundaries along each axis: start, middle, end.
    x_edges = (x0, x0 + sidelength, x0 + 2*sidelength)
    y_edges = (y0, y0 + sidelength, y0 + 2*sidelength)
    # (column, row) of each cell, in the order its colour is taken from `colors`.
    cell_order = ((0, 0), (0, 1), (1, 1), (1, 0))
    for k, (col, row) in enumerate(cell_order):
        fig.add_shape(type="rect",
                      xref="x",
                      yref="y",
                      x0=x_edges[col],
                      y0=y_edges[row],
                      x1=x_edges[col + 1],
                      y1=y_edges[row + 1],
                      line=dict(width=1),
                      fillcolor=colors[k])

# Kronecker-product illustration, drawn in data coordinates:
#   left   (x in [0, 1])     : 4x4 grid, each 2x2 quadrant uniformly coloured
#   middle (x in [1.2, 1.7]) : 2x2 grid holding the four colours
#   right  (x in [1.9, 2.4]) : uniform gray 2x2 grid
# with "=" and "⊗" annotations placed between them.
fig = go.Figure()
add_four_squares(fig, 0, 0, 0.25, ["tomato"] * 4)
add_four_squares(fig, 0.5, 0, 0.25, ["gold"] * 4)
add_four_squares(fig, 0.5, 0.5, 0.25, ["cornflowerblue"] * 4)
add_four_squares(fig, 0, 0.5, 0.25, ["mediumspringgreen"] * 4)

# Colour order (bottom-left, top-left, top-right, bottom-right) mirrors the
# quadrant placement of the left-hand grid.
add_four_squares(fig, 1.2, 0.25, 0.25, ["tomato", "mediumspringgreen", "cornflowerblue", "gold"])
add_four_squares(fig, 1.9, 0.25, 0.25, ["gray"] * 4)

# NOTE(review): LaTeX in annotations only renders if MathJax is available on
# the page embedding this fragment; write_html(include_plotlyjs='cdn') does
# not load it by itself — confirm the site provides MathJax.
annotations = [dict(x=x_pos,
                    y=0.5,
                    xref="x",
                    yref="y",
                    text=symbol,
                    showarrow=False)
               for x_pos, symbol in ((1.1, r"$=$"), (1.8, r"$\otimes$"))]

# Both axes are hidden; y is locked to x (scaleanchor/scaleratio) so the
# squares stay square. `title.font` replaces the deprecated `titlefont`
# property, which newer plotly versions no longer accept.
hidden_axis = dict(visible=False,
                   constrain="domain",
                   tickmode="array",
                   tickvals=[],
                   tick0=0,
                   dtick=1,
                   title=dict(font=dict(size=16)))

fig.update_layout(
    xaxis=dict(range=[0, 2.5], **hidden_axis),
    yaxis=dict(range=[0, 1],
               scaleanchor='x',
               scaleratio=1,
               **hidden_axis),
    hoverlabel=dict(font_size=18),
    margin=dict(t=0, b=0, l=0, r=0),
    height=150,
    annotations=annotations,
    template="plotly_white")

In [None]:
# Export the Kronecker-product figure as an HTML fragment (full_html=False)
# that loads plotly.js from the CDN and hides the mode bar.
export_kwargs = dict(full_html=False,
                     include_plotlyjs='cdn',
                     config=dict(displayModeBar=False))
pio.write_html(fig,
               file="../_includes/figures/kronecker_product.html",
               **export_kwargs)

## Hyperparameters

In [None]:
# Our 2-dimensional distribution will be over variables X and Y.
# N is the grid resolution (points per axis); it is unrelated to the "N"
# hyperparameter shown on the first slider below.
N = 60
x = np.linspace(-5, 5, N)
y = np.linspace(-5, 5, N)
X, Y = np.meshgrid(x, y)

# Mean vector and base precision (inverse covariance) matrix
mu = np.array([0., 0.])
base_inv_cov = np.linalg.inv([[3., 1.],
                              [1., 3.]])

# Pack X and Y into a single array of shape (N, N, 2)
pos = np.empty(X.shape + (2,))
pos[:, :, 0] = X
pos[:, :, 1] = Y

# Hyperparameter values for the two sliders: the same grids as
# np.arange(1., 2., 0.2) and np.arange(0., 1., 0.2), but linspace guarantees
# exactly five values each (the slider split below depends on the count).
n_values = np.linspace(1., 1.8, 5)
tau_values = np.linspace(0., 0.8, 5)

# Each precision is computed from the *base* precision P, not from the
# previous iterate, so the sliders show n * P and P + tau * I respectively
# (in-place `*=` / `+=` would compound the effect across steps).
precisions = ([n * base_inv_cov for n in n_values]
              + [base_inv_cov + tau * np.eye(2) for tau in tau_values])

# One (initially hidden) surface per precision matrix.
fig = go.Figure()
for inv_cov in precisions:
    Z = multivariate_gaussian(pos, mu, np.linalg.inv(inv_cov))
    fig.add_trace(go.Surface(x=X,
                             y=Y,
                             z=Z,
                             opacity=0.7,
                             visible=False,
                             hoverinfo="none"))

# Fixed contour levels so the contour lines are comparable across surfaces.
fig.update_traces(contours_z=dict(show=True,
                                  usecolormap=True,
                                  start=0.01,
                                  end=0.3,
                                  size=0.02,
                                  project_z=True),
                  showscale=False)

fig.data[0].visible = True

# One slider step per trace. With method="update", args[0] is a *trace*
# update: each step shows exactly one surface and hides the rest.
trace_count = len(fig.data)
steps = [dict(label="",
              method="update",
              args=[{"visible": [j == i for j in range(trace_count)]}])
         for i in range(trace_count)]

# First slider drives the n-scaled surfaces, second the tau-shifted ones.
n_count = len(n_values)
sliders = [
    dict(active=0,
         currentvalue={"prefix": "<b>N</b>"},
         pad={"l": 100, "r": 100},
         steps=steps[:n_count],
         tickwidth=0,
         ticklen=0,),
    dict(active=0,
         currentvalue={"prefix": "<b>tau</b>"},
         pad={"l": 100, "r": 100, "t": 70},
         steps=steps[n_count:],
         tickwidth=0,
         ticklen=0)]

fig.update_layout(scene=dict(
                    xaxis=dict(visible=False),
                    yaxis=dict(visible=False),
                    zaxis=dict(visible=False)),
                  height=700,
                  margin=dict(r=0, l=0, b=0, t=0, pad=0),
                  sliders=sliders,
                  scene_camera=dict(eye=dict(x=0.6, y=0, z=1),
                                    center=dict(x=0, y=0, z=-0.1)),
                  template="plotly_white")

In [None]:
# Export the hyperparameter figure as an HTML fragment (full_html=False)
# that loads plotly.js from the CDN and hides the mode bar.
export_kwargs = dict(full_html=False,
                     include_plotlyjs='cdn',
                     config=dict(displayModeBar=False))
pio.write_html(fig,
               file="../_includes/figures/gauss_hyper.html",
               **export_kwargs)

## Adversarial attack

In [None]:
model = "resnet18"
metric = "ent"

# Y-axis label per metric key; unknown keys fall back to the key itself.
metric_titles = {"ent": "Predictive entropy"}

fig = go.Figure()
for estimator in ['sgd', 'kfac', 'diag']:
    # The SGD curve is read from the KFAC results file, using its non-BNN
    # "stats" entry; the other estimators use their own file's "bnn_stats".
    source = estimator if estimator != "sgd" else "kfac"
    try:
        data = get_data(model, "imagenet", source, plot="fgsm")
    except FileNotFoundError:
        print(f"Data for model {model} and estimator {estimator.upper()} not available.")
        continue
    try:
        stats, bnn_stats = data['stats'].item(), data['bnn_stats'].item()
    finally:
        # np.load on an .npz returns an NpzFile holding an open file handle;
        # close it once the needed entries have been read.
        close = getattr(data, "close", None)
        if close is not None:
            close()
    epsilons = stats['eps']
    fig.add_trace(go.Scatter(name=estimator.upper(),
                             x=epsilons,
                             y=stats[metric] if estimator == "sgd" else bnn_stats[metric],
                             hovertemplate="%{y:.2f}"))

# `title.font` replaces the deprecated `titlefont` axis property, which
# newer plotly versions no longer accept.
fig.update_layout(
    xaxis=dict(
        zeroline=False,
        showgrid=False,
        title=dict(text="Step size", font=dict(size=16)),
        tickfont_size=14),
    yaxis=dict(
        zeroline=False,
        showgrid=False,
        title=dict(text=metric_titles.get(metric, metric), font=dict(size=16)),
        tickfont_size=14),
    legend=dict(
        traceorder="normal",
        font=dict(size=16),
        xanchor="right",
        yanchor="bottom",
        y=0.1,
        bgcolor='rgba(0, 0, 0, 0)',
        bordercolor='rgba(0, 0, 0, 0)'),
    template="plotly_white",
    hoverlabel=dict(font_size=18),
    hovermode="x",
    height=400,
    margin=dict(r=0, l=0, b=0, t=0, pad=0))

In [None]:
# Export the FGSM figure for the current model as an HTML fragment
# (full_html=False) under figures/curvature/, loading plotly.js from the CDN.
filename = f"{model}_imagenet_fgsm"
out_path = f"../_includes/figures/curvature/{filename}.html"
export_kwargs = dict(full_html=False,
                     include_plotlyjs='cdn',
                     config=dict(displayModeBar=False))
pio.write_html(fig, file=out_path, **export_kwargs)