# Code: Laplace Approximation for Bayesian Deep Learning

In [2]:
import sys
import numpy as np
from statsmodels.distributions.empirical_distribution import ECDF
import plotly.graph_objects as go
import plotly.io as pio
import plotly.figure_factory as ff
from plotly.colors import DEFAULT_PLOTLY_COLORS

In [3]:
# Model names grouped by architecture family; `all_models` is their concatenation
# in the order the families are listed here.
resnets = [f"resnet{depth}" for depth in (18, 34, 50, 101, 152)]
new_resnets = [
    "wide_resnet50_2",
    "wide_resnet101_2",
    "resnext50_32x4d",
    "resnext101_32x8d",
]
densenets = [f"densenet{depth}" for depth in (121, 161, 169, 201)]
vggs = [f"vgg{depth}" for depth in (11, 13, 16, 19)]
small = [
    "alexnet",
    "shufflenet_v2_x1_0",
    "mobilenet_v2",
    "squeezenet1_1",
    "mnasnet1_0",
]
inception = ["googlenet", "inception_v3"]

all_models = [*resnets, *new_resnets, *densenets, *vggs, *small, *inception]

In [4]:
def logit(p):
    """Return the element-wise log-odds ``log(p) - log(1 - p)``.

    ``p`` may be a scalar or any array-like; it is converted with
    ``np.asarray`` before the computation.
    """
    probs = np.asarray(p)
    log_p = np.log(probs)
    log_not_p = np.log(1 - probs)
    return log_p - log_not_p

In [9]:
def get_data(model, data, estimator, plot, reload=False):
    """Load, or compute and cache, the statistics needed for one plot.

    Results are cached as a pickled dict in a ``.npy`` file whose name is built
    from ``model``, ``data``, ``estimator`` and ``plot``. On a cache miss, or
    when ``reload`` is true, the raw predictions are read from the matching
    ``.npz`` archive, the statistics are computed with helpers imported from
    ``curvature.utils``, and the cache file is (re)written.

    Args:
        model: Network name used in the file names (e.g. ``"resnet50"``).
        data: Dataset name used in the file names (e.g. ``"imagenet"``).
        estimator: Estimator name used in the file names (e.g. ``"kfac"``).
        plot: ``"reliability"``, ``"calibration"`` or ``"entropy"``. For any
            other value nothing is computed or saved and the function falls
            back to loading whatever cache file already exists.
        reload: If true, skip the cache lookup and recompute.

    Returns:
        dict: For ``"reliability"`` and ``"calibration"`` the keys are
        ``acc``, ``ece``, ``aces``, ``accs``, ``confs`` and their ``bnn_``
        prefixed counterparts. For ``"entropy"`` the keys are ``num_classes``,
        ``pred_ent``, ``bnn_pred_ent``, ``ood_pred_ent`` and
        ``bnn_ood_pred_ent``.

    Raises:
        FileNotFoundError: If neither the cache file nor the ``.npz`` archive
            with the raw predictions exists.
    """
    filename = f"{model}_{data}_{estimator}_{plot}"
    path = f"/home/matthias/Data/Ubuntu/git/hummat.github.io/data/{filename}.npy"

    if not reload:
        try:
            return np.load(path, allow_pickle=True).item()
        except FileNotFoundError:
            pass  # Cache miss: compute the statistics from the raw predictions.

    # Make the local `curvature` checkout importable, without adding a
    # duplicate entry to sys.path on every cache miss.
    curvature_dir = "/home/matthias/Data/Ubuntu/git/curvature"
    if curvature_dir not in sys.path:
        sys.path.append(curvature_dir)
    from curvature.utils import expected_calibration_error, calibration_curve, predictive_entropy

    # Indexing an NpzFile reads the array into memory, so the arrays remain
    # usable after the context manager has closed the underlying file.
    archive_path = f"/volume/USERSTORE/humt_ma/{model}/data/{estimator}/{model}_{data}.npz"
    with np.load(archive_path) as archive:
        labels = archive["labels"]
        probabilities = archive["predictions"]
        bnn_probabilities = archive["bnn_predictions"]
        ood_probabilities = archive["ood_predictions"]
        bnn_ood_probabilities = archive["bnn_ood_predictions"]

    result = None
    if plot in ("reliability", "calibration"):
        if plot == "reliability":
            # NOTE(review): the helper's return values are unpacked as
            # (ece, aces, accs, confs) -- confirm against curvature.utils.
            ece, aces, accs, confs = expected_calibration_error(probabilities, labels)
            bnn_ece, bnn_aces, bnn_accs, bnn_confs = expected_calibration_error(bnn_probabilities, labels)
        else:
            # NOTE(review): unpacked as (ece, confs, accs, <ignored>) --
            # confirm against curvature.utils.
            ece, confs, accs, _ = calibration_curve(probabilities, labels)
            bnn_ece, bnn_confs, bnn_accs, _ = calibration_curve(bnn_probabilities, labels)
            aces = confs - accs
            bnn_aces = bnn_confs - bnn_accs
        # Top-1 accuracy in percent.
        acc = 100 * np.mean(np.argmax(probabilities, axis=1) == labels)
        bnn_acc = 100 * np.mean(np.argmax(bnn_probabilities, axis=1) == labels)

        result = {"acc": acc,
                  "ece": ece,
                  "aces": aces,
                  "accs": accs,
                  "confs": confs,
                  "bnn_acc": bnn_acc,
                  "bnn_ece": bnn_ece,
                  "bnn_aces": bnn_aces,
                  "bnn_accs": bnn_accs,
                  "bnn_confs": bnn_confs}
    elif plot == "entropy":
        result = {"num_classes": probabilities.shape[1],
                  "pred_ent": predictive_entropy(probabilities),
                  "bnn_pred_ent": predictive_entropy(bnn_probabilities),
                  "ood_pred_ent": predictive_entropy(ood_probabilities),
                  "bnn_ood_pred_ent": predictive_entropy(bnn_ood_probabilities)}

    if result is not None:
        np.save(path, result)

    # Read the result back from the cache so both code paths return exactly
    # what is stored on disk.
    return np.load(path, allow_pickle=True).item()

In [None]:
# Preprocess data
# Call get_data once for every (model, estimator, plot) combination so that its
# cache files get written; combinations whose files are missing are reported.
combinations = [(m, e, p)
                for m in all_models
                for e in ("diag", "kfac", "efb")
                for p in ("reliability", "calibration", "entropy")]
for model, estimator, plot in combinations:
    try:
        get_data(model, "imagenet", estimator, plot, reload=False)
    except FileNotFoundError:
        print(f"No data for {model}, {estimator}, {plot}.")

## Reliability diagram

In [8]:
# Choose a model from 'all_models' above and an estimator from 'sgd', 'diag', 'efb' or 'kfac'.
# 'sgd' is the deterministic NN.
model = "densenet121"
estimator = "sgd"

# Every cache file written by get_data holds both the deterministic values and
# the BNN values (the latter under "bnn_"-prefixed keys), so the deterministic
# ('sgd') results are read from the 'kfac' file.
is_bnn = estimator != "sgd"
key_prefix = "bnn_" if is_bnn else ""
data = get_data(model, "imagenet", estimator if is_bnn else "kfac", plot="reliability")
accs = data[f"{key_prefix}accs"]
aces = data[f"{key_prefix}aces"]
acc = data[f"{key_prefix}acc"]
ece = data[f"{key_prefix}ece"]

# Bar centres, 0.1 apart to match the bar width.
# NOTE(review): this yields 11 centres with the last one at 1.05, outside the
# visible x-range [0, 1] -- confirm it matches the length of `accs`.
x = np.linspace(0.05, 1.05, 11)

# `acc` is already stored in percent by get_data; `ece` and the per-bin values
# are scaled by 100 for display.
fig = go.Figure(data=[go.Bar(name=f"Accuracy | {acc:.2f}%",
                             x=x,
                             y=accs,
                             hoverinfo="text",
                             hovertext=[f"{100 * bin_acc:.2f}%" for bin_acc in accs],
                             width=0.1,
                             marker_line_color="black"),
                      go.Bar(name=f"ECE | {100 * ece:.2f}%",
                             x=x,
                             y=aces,
                             hoverinfo="text",
                             hovertext=[f"{100 * ace:.2f}%" for ace in aces],
                             width=0.1,
                             marker=dict(color="rgba(255, 0, 0, 0.5)",
                                         line=dict(color="crimson"))),
                      go.Scatter(x=[0, 0.5, 1],
                                 y=[0, 0.5, 1],
                                 name="Perfect calibration",
                                 hoverinfo="none",
                                 mode="lines",
                                 marker=dict(color="black"),
                                 line=dict(dash="dash",
                                           width=1))])
grid = False
visible = True
# Axis titles use `title=dict(text=..., font=...)`; the older `titlefont_size`
# shortcut is deprecated and no longer accepted by recent Plotly releases.
fig.update_layout(
    xaxis=dict(
        visible=visible,
        range=[0, 1],
        showgrid=grid,
        constrain="domain",
        title=dict(text="Confidence", font=dict(size=16)),
        tickfont_size=14),
    yaxis=dict(
        visible=visible,
        # Leave a little room below the most negative calibration-error bar.
        range=[aces.min() - 0.03, 1],
        showgrid=grid,
        scaleanchor="x",
        title=dict(text="Accuracy", font=dict(size=16)),
        tickfont_size=14),
    legend=dict(
        traceorder="normal",
        font=dict(size=16),
        x=0.2,
        y=0.98,
        bgcolor='rgba(0, 0, 0, 0)',
        bordercolor='rgba(0, 0, 0, 0)'),
    barmode="stack",
    template="plotly_white",
    hoverlabel=dict(font_size=18),
    hovermode='x',
    height=700,
    margin=dict(r=0, l=0, b=0, t=0, pad=0))

In [None]:
# Save figure
# Written as an embeddable HTML snippet (full_html=False) that references
# plotly.js via the CDN; the mode bar is disabled.
filename = f"{model}_imagenet_{estimator}_reliability"
export_options = dict(full_html=False,
                      include_plotlyjs='cdn',
                      config={"displayModeBar": False})
pio.write_html(fig, file=f"../_includes/figures/{filename}.html", **export_options)

## Calibration curve

### 1. Estimator comparison

In [10]:
model = "resnet50"
estimators = ["sgd", "diag", "kfac", "efb"]

# One calibration curve per estimator. The deterministic ('sgd') values are
# stored in every cache file next to the "bnn_"-prefixed BNN values, so they
# are read from the 'kfac' file.
fig = go.Figure()
for estimator in estimators:
    is_bnn = estimator != "sgd"
    key_prefix = "bnn_" if is_bnn else ""
    data = get_data(model, "imagenet", estimator if is_bnn else "kfac", "calibration")
    confs = data[f"{key_prefix}confs"]
    accs = data[f"{key_prefix}accs"]
    aces = data[f"{key_prefix}aces"]
    acc = data[f"{key_prefix}acc"]
    ece = data[f"{key_prefix}ece"]
    fig.add_trace(go.Scatter(
                    mode="markers+lines",
                    name=estimator.upper(),
                    # Confidences are plotted on a logit scale.
                    x=logit(confs),
                    y=aces,
                    hovertext=[f"{100 * ace:.2f}%" for ace in aces],
                    hoverinfo="text+name"))

grid = False
visible = True
# Tick positions are given as confidences and mapped through logit();
# 0.999999 stands in for 1 (whose logit is infinite) and is labelled "1".
tickvals = [0.2, 0.759, 0.927, 0.978, 0.993, 0.998, 0.999999]
# Axis titles use `title=dict(text=..., font=...)`; the older `titlefont_size`
# shortcut is deprecated and no longer accepted by recent Plotly releases.
fig.update_layout(
    xaxis=dict(
        constrain="domain",
        zeroline=False,
        visible=visible,
        range=logit([0.1, 0.999999]),
        tickmode="array",
        tickvals=logit(tickvals),
        ticktext=[0.2, 0.759, 0.927, 0.978, 0.993, 0.998, 1],
        showgrid=grid,
        title=dict(text="Confidence", font=dict(size=16)),
        tickfont_size=14),
    yaxis=dict(
        scaleanchor="x",
        scaleratio=50,
        visible=visible,
        showgrid=grid,
        title=dict(text="Confidence - Accuracy", font=dict(size=16)),
        tickfont_size=14),
    legend=dict(
        traceorder="normal",
        font=dict(size=16),
        x=1,
        y=0.98,
        bgcolor='rgba(0, 0, 0, 0)',
        bordercolor='rgba(0, 0, 0, 0)'),
    template="plotly_white",
    hoverlabel=dict(font_size=18),
    hovermode="x",
    height=500,
    margin=dict(r=0, l=0, b=0, t=0, pad=0))

In [None]:
# Save figure
# Written as an embeddable HTML snippet (full_html=False) that references
# plotly.js via the CDN; the mode bar is disabled.
filename = f"{model}_imagenet_calibration"
export_options = dict(full_html=False,
                      include_plotlyjs='cdn',
                      config={"displayModeBar": False})
pio.write_html(fig, file=f"../_includes/figures/{filename}.html", **export_options)

### 2. Network comparison

In [11]:
prefix = "resnets_"
models = resnets
estimator = "sgd"

# The deterministic ('sgd') values are stored in every cache file next to the
# "bnn_"-prefixed BNN values, so they are read from the 'kfac' file.
is_bnn = estimator != "sgd"
key_prefix = "bnn_" if is_bnn else ""

# One calibration curve per network; networks without data are reported and skipped.
fig = go.Figure()
for model in models:
    try:
        data = get_data(model, "imagenet", estimator if is_bnn else "kfac", "calibration")
    except FileNotFoundError:
        print(f"No data for {model} and estimator {estimator}.")
        continue
    confs = data[f"{key_prefix}confs"]
    accs = data[f"{key_prefix}accs"]
    aces = data[f"{key_prefix}aces"]
    acc = data[f"{key_prefix}acc"]
    ece = data[f"{key_prefix}ece"]
    fig.add_trace(go.Scatter(
                    mode="markers+lines",
                    name=model.capitalize(),
                    # Confidences are plotted on a logit scale.
                    x=logit(confs),
                    y=aces,
                    hovertext=[f"{100 * ace:.2f}%" for ace in aces],
                    hoverinfo="text+name",
                    # These two start hidden and can be toggled from the legend.
                    visible="legendonly" if model in ["inception_v3", "googlenet"] else True))

grid = False
visible = True
# Tick positions are given as confidences and mapped through logit();
# 0.999999 stands in for 1 (whose logit is infinite) and is labelled "1".
tickvals = [0.2, 0.759, 0.927, 0.978, 0.993, 0.998, 0.999999]
# Axis titles use `title=dict(text=..., font=...)`; the older `titlefont_size`
# shortcut is deprecated and no longer accepted by recent Plotly releases.
fig.update_layout(
    xaxis=dict(
        constrain="domain",
        zeroline=False,
        visible=visible,
        range=logit([0.1, 0.999999]),
        tickmode="array",
        tickvals=logit(tickvals),
        ticktext=[0.2, 0.759, 0.927, 0.978, 0.993, 0.998, 1],
        showgrid=grid,
        title=dict(text="Confidence", font=dict(size=16)),
        tickfont_size=14),
    yaxis=dict(
        scaleanchor="x",
        scaleratio=50,
        visible=visible,
        showgrid=grid,
        title=dict(text="Confidence - Accuracy", font=dict(size=16)),
        tickfont_size=14),
    legend=dict(
        traceorder="normal",
        font=dict(size=16),
        x=1,
        y=0.98,
        bgcolor='rgba(0, 0, 0, 0)',
        bordercolor='rgba(0, 0, 0, 0)'),
    template="plotly_white",
    hoverlabel=dict(font_size=18),
    hovermode="x",
    height=500,
    margin=dict(r=0, l=0, b=0, t=0, pad=0))

No data for resnet34 and estimator sgd.


In [None]:
# Save figure
# Written as an embeddable HTML snippet (full_html=False) that references
# plotly.js via the CDN; the mode bar is disabled.
filename = f"{prefix}{estimator}_imagenet_calibration"
export_options = dict(full_html=False,
                      include_plotlyjs='cdn',
                      config={"displayModeBar": False})
pio.write_html(fig, file=f"../_includes/figures/{filename}.html", **export_options)

## Uncertainty NN vs BNN

### 1. Estimator comparison using ECDF

In [12]:
model = "densenet121"
colors = DEFAULT_PLOTLY_COLORS
fig = go.Figure()

# For every estimator draw 1 - ECDF of the predictive entropy: dashed for
# in-domain data, solid for out-of-domain data, both in the same colour.
for index, estimator in enumerate(["sgd", "diag", "kfac", "efb"]):
    # The deterministic ('sgd') values are stored in every cache file next to
    # the "bnn_"-prefixed BNN values, so they are read from the 'kfac' file.
    is_bnn = estimator != "sgd"
    key_prefix = "bnn_" if is_bnn else ""
    data = get_data(model, "imagenet", estimator if is_bnn else "kfac", plot="entropy")
    # log(num_classes) is the entropy of a uniform distribution over the
    # classes and is used as the upper end of the evaluation grid.
    x_lim = np.log(data["num_classes"])
    entropy_range = np.linspace(0, x_lim, data["num_classes"])

    ecdf = ECDF(data[f"{key_prefix}pred_ent"])
    ood_ecdf = ECDF(data[f"{key_prefix}ood_pred_ent"])

    # Both traces share a legend group so one legend entry toggles the pair;
    # only the out-of-domain trace gets a legend entry.
    fig.add_trace(go.Scatter(x=entropy_range,
                             y=1 - ecdf(entropy_range),
                             marker_color=colors[index],
                             line=dict(dash="dash",
                                       width=3),
                             showlegend=False,
                             hovertemplate="%{y}<extra></extra>",
                             legendgroup=index,
                             visible="legendonly" if estimator in ["kfac", "efb"] else True))
    fig.add_trace(go.Scatter(name=estimator.upper(),
                             x=entropy_range,
                             y=1 - ood_ecdf(entropy_range),
                             marker_color=colors[index],
                             line_width=3,
                             legendgroup=index,
                             visible="legendonly" if estimator in ["kfac", "efb"] else True))

# Hack to have a separated line style legend
fig.add_trace(go.Scatter(name="",
                         x=[0, 0],
                         y=[0, 0],
                         marker_color="rgba(0, 0, 0, 0)"))
fig.add_trace(go.Scatter(name="out-of-domain",
                         mode="lines",
                         x=[0, 0],
                         y=[0, 0],
                         marker_color="black"))
fig.add_trace(go.Scatter(name="in-domain",
                         mode="lines",
                         x=[0, 0],
                         y=[0, 0],
                         marker_color="black",
                         line=dict(dash="dash")))

# Axis titles use `title=dict(text=..., font=...)`; the older `titlefont_size`
# shortcut is deprecated and no longer accepted by recent Plotly releases.
fig.update_layout(
    xaxis=dict(
        constrain="domain",
        zeroline=False,
        # `x_lim` is the value from the last loop iteration.
        range=[-0.1, np.ceil(x_lim)],
        showgrid=False,
        title=dict(text="Predictive entropy", font=dict(size=16)),
        tickfont_size=14),
    yaxis=dict(
        scaleanchor="x",
        scaleratio=3.5,
        zeroline=False,
        range=[-0.01, 1.01],
        showgrid=False,
        title=dict(text="1-ecdf", font=dict(size=16)),
        tickfont_size=14),
    legend=dict(
        traceorder="normal",
        font=dict(size=16),
        x=0.95,
        y=0.98,
        bgcolor='rgba(0, 0, 0, 0)',
        bordercolor='rgba(0, 0, 0, 0)'),
    template="plotly_white",
    hoverlabel=dict(font_size=18),
    hovermode="x",
    height=500,
    margin=dict(r=0, l=0, b=0, t=0, pad=0))

In [None]:
# Save figure
# Written as an embeddable HTML snippet (full_html=False) that references
# plotly.js via the CDN; the mode bar is disabled.
filename = f"{model}_imagenet_ecdf"
export_options = dict(full_html=False,
                      include_plotlyjs='cdn',
                      config={"displayModeBar": False})
pio.write_html(fig, file=f"../_includes/figures/{filename}.html", **export_options)

### 2. Network comparison using ECDF

In [13]:
prefix = "resnets_"
models = resnets
estimator = "sgd"
colors = DEFAULT_PLOTLY_COLORS
fig = go.Figure()

# The deterministic ('sgd') values are stored in every cache file next to the
# "bnn_"-prefixed BNN values, so they are read from the 'kfac' file.
is_bnn = estimator != "sgd"
key_prefix = "bnn_" if is_bnn else ""

# `index` counts only the networks that were actually plotted, so colours are
# assigned without gaps when a network has no data.
index = 0
for model in models:
    try:
        data = get_data(model, "imagenet", estimator if is_bnn else "kfac", plot="entropy")
    except FileNotFoundError:
        print(f"No data for {model} and estimator {estimator}.")
        continue
    # log(num_classes) is the entropy of a uniform distribution over the
    # classes and is used as the upper end of the evaluation grid.
    x_lim = np.log(data["num_classes"])
    entropy_range = np.linspace(0, x_lim, data["num_classes"])

    ecdf = ECDF(data[f"{key_prefix}pred_ent"])
    ood_ecdf = ECDF(data[f"{key_prefix}ood_pred_ent"])

    # Dashed: in-domain, solid: out-of-domain. Both traces share a legend
    # group so one legend entry toggles the pair.
    fig.add_trace(go.Scatter(x=entropy_range,
                             y=1 - ecdf(entropy_range),
                             marker_color=colors[index],
                             line=dict(dash="dash",
                                       width=3),
                             showlegend=False,
                             hovertemplate="%{y}<extra></extra>",
                             legendgroup=index))
    fig.add_trace(go.Scatter(name=model.capitalize(),
                             x=entropy_range,
                             y=1 - ood_ecdf(entropy_range),
                             marker_color=colors[index],
                             line_width=3,
                             legendgroup=index))
    index += 1

# Hack to have a separated line style legend
fig.add_trace(go.Scatter(name="",
                         x=[0, 0],
                         y=[0, 0],
                         marker_color="rgba(0, 0, 0, 0)"))
fig.add_trace(go.Scatter(name="out-of-domain",
                         mode="lines",
                         x=[0, 0],
                         y=[0, 0],
                         marker_color="black"))
fig.add_trace(go.Scatter(name="in-domain",
                         mode="lines",
                         x=[0, 0],
                         y=[0, 0],
                         marker_color="black",
                         line=dict(dash="dash")))

# Axis titles use `title=dict(text=..., font=...)`; the older `titlefont_size`
# shortcut is deprecated and no longer accepted by recent Plotly releases.
fig.update_layout(
    xaxis=dict(
        constrain="domain",
        zeroline=False,
        # `x_lim` is the value from the last network that had data.
        range=[-0.1, np.ceil(x_lim)],
        showgrid=False,
        title=dict(text="Predictive entropy", font=dict(size=16)),
        tickfont_size=14),
    yaxis=dict(
        scaleanchor="x",
        scaleratio=3.5,
        zeroline=False,
        range=[-0.01, 1.01],
        showgrid=False,
        title=dict(text="1-ecdf", font=dict(size=16)),
        tickfont_size=14),
    legend=dict(
        traceorder="normal",
        font=dict(size=16),
        x=0.95,
        y=0.98,
        bgcolor='rgba(0, 0, 0, 0)',
        bordercolor='rgba(0, 0, 0, 0)'),
    template="plotly_white",
    hoverlabel=dict(font_size=18),
    hovermode="x",
    height=500,
    margin=dict(r=0, l=0, b=0, t=0, pad=0))

No data for resnet34 and estimator sgd.


In [None]:
# Save figure
# Written as an embeddable HTML snippet (full_html=False) that references
# plotly.js via the CDN; the mode bar is disabled.
filename = f"{prefix}{estimator}_imagenet_ecdf"
export_options = dict(full_html=False,
                      include_plotlyjs='cdn',
                      config={"displayModeBar": False})
pio.write_html(fig, file=f"../_includes/figures/{filename}.html", **export_options)

### In- vs out-of-distribution histogram

In [14]:
model = "densenet121"
estimator = "sgd"
normalize = False  # True: plot bin probabilities instead of raw counts.
overlay = True     # True: draw both histograms on top of each other; False: side by side.

# The deterministic ('sgd') values are stored in every cache file next to the
# "bnn_"-prefixed BNN values, so they are read from the 'kfac' file.
is_bnn = estimator != "sgd"
key_prefix = "bnn_" if is_bnn else ""
data = get_data(model, "imagenet", estimator if is_bnn else "kfac", plot="entropy")
in_domain = data[f"{key_prefix}pred_ent"]
out_of_domain = data[f"{key_prefix}ood_pred_ent"]

# Bar outlines are only drawn in overlay mode.
fig = go.Figure(data=[go.Histogram(x=in_domain,
                                   nbinsx=100,
                                   name="Known domain",
                                   histnorm="probability" if normalize else "",
                                   marker=dict(line=dict(color="black", width=1 if overlay else 0))),
                      go.Histogram(x=out_of_domain,
                                   nbinsx=100,
                                   name="Unknown domain",
                                   histnorm="probability" if normalize else "",
                                   marker=dict(line=dict(color="black", width=1 if overlay else 0)))])
if overlay:
    # Make the overlaid histograms translucent so both stay visible.
    fig.update_traces(opacity=0.75)

# Axis titles use `title=dict(text=..., font=...)`; the older `titlefont_size`
# shortcut is deprecated and no longer accepted by recent Plotly releases.
fig.update_layout(
    xaxis=dict(
        #constrain="domain",
        #range=[-0.1, np.ceil(np.log(data["num_classes"]))],
        zeroline=False,
        showgrid=False,
        title=dict(text="Predictive entropy", font=dict(size=16)),
        tickfont_size=14),
    yaxis=dict(
        #scaleanchor="x",
        #scaleratio=0.00065,
        zeroline=False,
        showgrid=False,
        title=dict(text="Density" if normalize else "Frequency", font=dict(size=16)),
        tickfont_size=14),
    legend=dict(
        traceorder="normal",
        font=dict(size=16),
        x=0.95,
        y=0.98,
        bgcolor='rgba(0, 0, 0, 0)',
        bordercolor='rgba(0, 0, 0, 0)'),
    barmode="overlay" if overlay else "group",
    template="plotly_white",
    hoverlabel=dict(font_size=18),
    hovermode="x",
    height=500,
    margin=dict(r=0, l=0, b=0, t=0, pad=0))

In [None]:
# Save figure
# Written as an embeddable HTML snippet (full_html=False) that references
# plotly.js via the CDN; the mode bar is disabled.
filename = f"{model}_{estimator}_imagenet_hist"
export_options = dict(full_html=False,
                      include_plotlyjs='cdn',
                      config={"displayModeBar": False})
pio.write_html(fig, file=f"../_includes/figures/{filename}.html", **export_options)