# Code: A sense of uncertainty

In [None]:
import numpy as np
from PIL import Image
from scipy.spatial import Delaunay
from scipy.interpolate import griddata
from scipy.optimize import leastsq
import plotly.graph_objects as go
import plotly.io as pio
import plotly.figure_factory as ff
from plotly.subplots import make_subplots

## Chapter 1: Exploring images

In [None]:
def load_data(size=64, path="../images/afghan_girl.jpg"):
    """Load an image and flatten it into per-pixel arrays for plotly scatter plots.

    The image at ``path`` is converted to RGB and resized to ``size`` x ``size``
    pixels. A monochrome ("L") copy is made from the resized image.

    Args:
        size: Side length, in pixels, of the square the image is resized to.
        path: Image file to load.

    Returns:
        dict with keys:
            "rgb": uint8 array of shape (size, size, 3).
            "mono": uint8 array of shape (size, size).
            "x", "y": column and row index of every pixel, flattened row-major.
            "colors": CSS "rgb(r, g, b)" string per pixel.
            "values": monochrome value per pixel, as strings.
            "red", "green", "blue": channel value per pixel, as strings.
    """
    # convert("RGB") guarantees exactly three channels even for grayscale,
    # palette, RGBA or CMYK source files; the context manager closes the file.
    with Image.open(path) as source:
        image = source.convert("RGB").resize((size, size))
    image_rgb = np.asarray(image)
    image_bw = np.asarray(image.convert("L"))

    # Get size and color info
    height, width, _ = image_rgb.shape
    pixels = image_rgb.reshape(-1, 3)
    # Format plain Python ints: interpolating a tuple of NumPy scalars uses their
    # repr, which under NumPy >= 2 gives "rgb(np.uint8(12), ...)" (not valid CSS).
    colors = np.array([f"rgb({r}, {g}, {b})" for r, g, b in pixels.tolist()])
    values = image_bw.reshape(-1).astype(str)

    # Make grid to paint colors (or pixel values) on
    x, y = np.meshgrid(np.arange(width), np.arange(height))
    flat_x, flat_y = x.reshape(-1), y.reshape(-1)

    red, green, blue = (pixels[:, c].astype(str) for c in range(3))

    return {"rgb": image_rgb,
            "mono": image_bw,
            "x": flat_x,
            "y": flat_y,
            "colors": colors,
            "values": values,
            "red": red,
            "green": green,
            "blue": blue}

In [None]:
# Standard layout shared by the pixel-grid figures
def get_layout():
    """Build the common plotly layout for drawing an image as a grid of markers.

    Both axes are hidden. The y axis is anchored to x so the aspect ratio stays
    fixed, and it is reversed so that row 0 is drawn at the top, as in an image.
    Large hover labels, a fixed height of 1024 px and zero margins.
    """
    hidden_x = dict(constrain="domain", visible=False)
    hidden_y = dict(scaleanchor='x', visible=False, autorange="reversed")
    no_margins = dict(r=0, l=0, b=0, t=0, pad=0)
    return go.Layout(
        template="plotly_white",
        xaxis=hidden_x,
        yaxis=hidden_y,
        hoverlabel=dict(font_size=18),
        height=1024,
        margin=no_margins,
    )

### Conversion from colors to color values (interactive)

In [None]:
data = load_data()

# Select channel to visualize (if in RGB mode):
# - 0 = red
# - 1 = green
# - 2 = blue (any other value also selects blue)
channel = 1
text = data[{0: "red", 1: "green"}.get(channel, "blue")]

# Image arrays are (rows, columns), so unpack as (height, width). get_text()
# uses `width` as the row stride of the flattened (row-major) pixel arrays.
height, width = data["mono"].shape
color = data["colors"]
value = data["values"]

# Make figure: one square marker per pixel; labels are filled in on hover
fig = go.FigureWidget(go.Scattergl(x=data["x"],
                                   y=data["y"],
                                   mode="markers+text",
                                   marker=dict(size=16,
                                               color=color,
                                               symbol="square"),
                                   textfont=dict(size=8, color="white"),
                                   hoverinfo="none",
                                   showlegend=False),
                      layout=get_layout())

# Setup interactive conversion
def get_text(x, bw=False):
    """Return per-pixel labels that show values only near the hovered pixel.

    Labels are filled in for the pixels within Manhattan distance 2 of the flat
    pixel index ``x`` (a small diamond) and left blank ("   ") elsewhere.
    Neighbors are computed in (row, column) space. The diamond is therefore
    clipped at the image border instead of wrapping to the previous or next row.

    Reads the notebook globals ``width`` (row stride of the flattened image),
    ``height``, ``text`` (labels of the selected RGB channel) and ``value``
    (monochrome labels).

    Args:
        x: Flat (row-major) index of the hovered pixel.
        bw: If True, show monochrome values instead of the selected channel.

    Returns:
        Array of label strings, one per pixel.
    """
    row, col = divmod(int(x), width)
    source = value if bw else text
    indices = np.array([(row + dr) * width + (col + dc)
                        for dr in range(-2, 3)
                        for dc in range(-2, 3)
                        if abs(dr) + abs(dc) <= 2
                        and 0 <= row + dr < height
                        and 0 <= col + dc < width],
                       dtype=int)
    labels = np.array(["   "] * len(text))
    labels[indices] = source[indices]
    return labels

def pixel_to_number(trace, points, selector):
    """Hover callback: label the pixels around the hovered one with their values.

    The label font color tells the display modes apart. White labels mean RGB
    mode (show the selected channel); any other color means monochrome mode.
    """
    monochrome = trace.textfont.color != "white"
    hovered = points.point_inds[0]
    fig.data[0].text = get_text(hovered, bw=monochrome)

# Refresh the labels whenever the cursor moves over a pixel
fig.data[0].on_hover(pixel_to_number)

# RGB/Monochrome switcher. The label color also tells pixel_to_number() which
# values to show (white = RGB channel, lime = monochrome).
rgb_style = dict(marker=dict(size=16, color=color, symbol="square"),
                 textfont=dict(size=8, color="white"))
mono_style = dict(marker=dict(size=16, color=value, symbol="square", colorscale="Greys"),
                  textfont=dict(size=8, color="lime"))
b1_args = [rgb_style]
b2_args = [mono_style]

buttons = [dict(label="RGB", method="restyle", args=b1_args),
           dict(label="Monochrome", method="restyle", args=b2_args)]

fig.update_layout(updatemenus=[dict(buttons=buttons, x=0.17, y=1)])

### Transition from colors to color values (animated)

In [None]:
data = load_data(size=32)
# Image arrays are (rows, columns, channels)
height, width, channels = data["rgb"].shape
color = np.array(data["colors"])  # working copy; pixels are blacked out as they are revealed
x, y = data["x"], data["y"]

# Select channel to visualize (if in RGB mode):
# - 0 = red
# - 1 = green
# - 2 = blue (any other value also selects blue)
channel = 1
text = data[{0: "red", 1: "green"}.get(channel, "blue")]

# Frame for diagonal offset k reveals (blacks out and labels) every pixel whose
# column - row >= k. The numbers sweep in from the top-right corner until, at
# k = -(height - 1), the whole image is shown. The mask is built from a boolean
# array of ones so that pixel index 0 (top-left) is revealed too.
tmp = np.array(["   "] * len(text))
frames = list()
for k in np.arange(-height + 1, width)[::-1]:
    mask = np.triu(np.ones((height, width), dtype=bool), k=k).ravel()
    tmp[mask] = text[mask]
    color[mask] = "rgb(0, 0, 0)"
    # Pass copies: tmp and color keep being mutated by later iterations
    frames.append(go.Frame(data=go.Scatter(x=x,
                                           y=y,
                                           mode="markers+text",
                                           marker=dict(size=16,
                                                       color=color.copy(),
                                                       symbol="square"),
                                           text=tmp.copy(),
                                           textfont=dict(size=8, color="white"),
                                           hoverinfo="none",
                                           showlegend=False)))

fig = go.Figure(go.Scatter(x=x,
                           y=y,
                           mode="markers",
                           marker=dict(size=16,
                                       color=data["colors"],
                                       symbol="square"),
                           textfont=dict(size=8, color="white"),
                           hoverinfo="none",
                           showlegend=False),
                frames=frames,
                layout=get_layout())

fig.update_layout(updatemenus=[dict(type="buttons",
                                    buttons=[dict(label="Start",
                                                  method="animate",
                                                  args=[None])])],
                  height=512)

### Directly showing color values instead of colors

In [None]:
data = load_data()
red, green, blue = data["red"], data["green"], data["blue"]
mono = value = data["values"]
color = data["colors"]

# Same layout as the other pixel figures, with a little room on top for the menu
layout = get_layout()
layout.margin = dict(r=0, l=0, b=0, t=10, pad=0)
fig = go.Figure(go.Scattergl(x=data["x"],
                             y=data["y"],
                             mode="text",
                             text=red,
                             textfont=dict(size=8),
                             hoverinfo="none",
                             showlegend=False),
                layout=layout)

# Menu part 1: show a channel (or the monochrome value) as numbers
numeric_buttons = [dict(label=f"{name} (numeric)",
                        method="restyle",
                        args=[dict(mode="text", text=[labels])])
                   for name, labels in [("Red", red), ("Green", green),
                                        ("Blue", blue), ("Monochrome", mono)]]

# Menu part 2: show a single channel as colored squares
channel_colors = {
    "Red": [f"rgb({int(v)}, 0, 0)" for v in red],
    "Green": [f"rgb(0, {int(v)}, 0)" for v in green],
    "Blue": [f"rgb(0, 0, {int(v)})" for v in blue],
}
color_buttons = [dict(label=name,
                      method="restyle",
                      args=[dict(mode="markers",
                                 marker=dict(size=16, color=squares, symbol="square"))])
                 for name, squares in channel_colors.items()]

# Menu part 3: monochrome and full-color squares
color_buttons.append(dict(label="Monochrome",
                          method="restyle",
                          args=[dict(mode="markers",
                                     marker=dict(size=16, color=value, symbol="square",
                                                 colorscale="Greys"))]))
color_buttons.append(dict(label="RGB",
                          method="restyle",
                          args=[dict(mode="markers",
                                     marker=dict(size=16, color=color, symbol="square"))]))

buttons = numeric_buttons + color_buttons

fig.update_layout(updatemenus=[dict(buttons=buttons, x=0.6, y=1.01)], height=700)
fig

In [None]:
# Save figure as an HTML fragment that loads plotly.js from the CDN, without the mode bar
fig.write_html(file='../_includes/figures/image_numbers.html',
               full_html=False,
               include_plotlyjs='cdn',
               config={"displayModeBar": False})

In [None]:
data = load_data(size=128)
red, green, blue = data["red"], data["green"], data["blue"]
x, y = data["x"], data["y"]
color = data["colors"]

# Left panel: the image as a grid of small colored squares
image2d = go.Scattergl(x=x,
                       y=y,
                       mode='markers',
                       marker=dict(size=4.4, color=color, symbol="square"),
                       hovertemplate="<b>Pixel:</b> %{x}, %{y}<br><b>Color:</b> %{marker.color}<extra></extra>",
                       hoverlabel=dict(bgcolor=color),
                       showlegend=False)

# Right panel: the red, green and blue channels as three stacked layers
# (z = 0, 10, 20); red is drawn more opaque than green and blue
red = [f"rgba({int(v)}, 0, 0, 0.7)" for v in red]
green = [f"rgba(0, {int(v)}, 0, 0.2)" for v in green]
blue = [f"rgba(0, 0, {int(v)}, 0.2)" for v in blue]
rgb_color = np.concatenate([red, green, blue])

x3 = np.tile(x, 3)
y3 = np.tile(y, 3)
z3 = np.repeat(np.array([0, 10, 20], dtype=x.dtype), x.size)
size = 1.85

image3d = go.Scatter3d(x=x3,
                       y=y3,
                       z=z3,
                       mode='markers',
                       marker=dict(size=1, color=rgb_color, symbol="square"),
                       hovertemplate="<b>Pixel:</b> %{x}, %{y}<extra></extra>",
                       hoverlabel=dict(bgcolor=color),
                       showlegend=False)

# Figure: 2D image and 3D channel stack side by side
fig = make_subplots(rows=1,
                    cols=2,
                    horizontal_spacing=0,
                    vertical_spacing=0,
                    specs=[[dict(type="xy"), dict(type="scene")]])
fig.add_trace(image2d, row=1, col=1)
fig.add_trace(image3d, row=1, col=2)

# View along the z axis with an orthographic projection
fig.layout.scene.camera = dict(eye=dict(x=0, y=0, z=-1),
                               up=dict(x=0, y=-1, z=0),
                               projection=dict(type="orthographic"))

hidden_scene_axes = {f"{axis}axis": dict(visible=False) for axis in "xyz"}
fig.update_layout(template="plotly_white",
                  xaxis=dict(constrain="domain", visible=False),
                  yaxis=dict(scaleanchor='x', visible=False, autorange="reversed"),
                  scene=dict(**hidden_scene_axes,
                             aspectratio=dict(x=size, y=size, z=size)),
                  hoverlabel=dict(font_size=18),
                  height=400,
                  margin=dict(r=0, l=0, b=0, t=0, pad=0))

In [None]:
# Save figure as an HTML fragment that loads plotly.js from the CDN
fig.write_html(file='../_includes/figures/image.html',
               full_html=False,
               include_plotlyjs='cdn')

## Chapter 2: Exploring loss landscapes

In [None]:
# Choose network/dataset (index into `values` below):
# 0. ResNet18 on CIFAR10
# 1. ResNet50 on CIFAR10
# 2. DenseNet121 on CIFAR10
# 3. GoogLeNet on CIFAR10
# 4. VGG16 with BN on CIFAR10
# 5. Mobilenet V2 on CIFAR10
# 6. Inception V3 on CIFAR10
# 7. LeNet5 on MNIST
# 8. ResNet50 on ImageNet
# 9. DenseNet121 on ImageNet
# 10. GoogLeNet on ImageNet
choice = 10

values = ["resnet18_cifar10",
          "resnet50_cifar10",
          "densenet121_cifar10",
          "googlenet_cifar10",
          "vgg16bn_cifar10",
          "mobilenet_v2_cifar10",
          "inception_v3_cifar10",
          "lenet5_mnist",
          "resnet50_imagenet",
          "densenet121_imagenet",
          "googlenet_imagenet"]
data = np.load(f"../data/{values[choice]}.npy")

# First four columns of the sample array: x, y location, loss, accuracy
x, y, loss, acc = data.T[:4]
print(f"Found {len(loss)} samples. Max loss: {loss.max():.0f}.")

# Heuristic to get visually pleasing and useful plots: for high-variance
# landscapes, keep only samples whose loss is at most one above the median
threshold = np.median(loss) + 1 if loss.var() > 10 else loss.max()
index = np.flatnonzero(loss <= threshold)
x, y, loss, acc = x[index], y[index], loss[index], acc[index]
print(f"Removed {len(data) - len(loss)} samples at threshold {threshold:.0f}.")

# points2D = np.vstack([x, y]).T
# tri = Delaunay(points2D)
# simplices = tri.simplices

# Interpolate the scattered samples onto a regular 101 x 101 grid over [-1, 1]^2
xi = np.linspace(-1, 1, 101)
yi = np.linspace(-1, 1, 101)
grid = (xi[None, :], yi[:, None])
loss_contours = griddata((x, y), loss, grid, method="cubic")
acc_contours = griddata((x, y), acc, grid, method="cubic")

# Clip impossible values produced by cubic interpolation (NaNs are left as-is)
loss_contours = np.clip(loss_contours, 0., None)
acc_contours = np.clip(acc_contours, 0., 100.)

layout = go.Layout(template="plotly_white",
                   xaxis=dict(constrain="domain"),
                   yaxis=dict(scaleanchor='x'),
                   hoverlabel=dict(font_size=18),
                   height=700,
                   margin=dict(r=0, l=0, b=0, t=0, pad=0),
                   font=dict(family="Courier New, monospace",
                             size=16,
                             color="#7f7f7f"))

### 2D loss and accuracy

In [None]:
# Loss contour lines
ht_loss = "<b>Location</b>: %{x:.2f}, %{y:.2f}<br><b>Loss</b>: %{z:.2f}<extra></extra>"
# Contour-level heuristic: fine 0.2 spacing from 0 for flat landscapes
# (variance < 1), otherwise a spacing that is a multiple of 0.5
start = 0 if loss.var() < 1 else round(np.ceil(loss.min()) * 2) / 2
end = loss.max()
size = 0.2 if loss.var() < 1 else round(np.floor((loss.max() - loss.min()) * 0.2)) / 2

loss_data = go.Contour(x=xi,
                       y=yi,
                       z=loss_contours,
                       colorscale="Viridis",
                       contours=dict(start=start,
                                     end=end,
                                     size=size,
                                     showlabels=True,
                                     labelfont=dict(size=12)),
                       contours_coloring='lines',
                       line_width=3,
                       showscale=False,
                       hovertemplate=ht_loss,
                       hoverlabel=dict(bgcolor='darkslategray'),
                       visible=True)

# Accuracy contour lines.
# Must be a plain string: a trailing comma would make it a 1-tuple, which plotly
# accepts as an array of per-point templates rather than a single template.
ht_acc = "<b>Location</b>: %{x:.2f}, %{y:.2f}<br><b>Accuracy</b>: %{z:.2f}<extra></extra>"
acc_data = go.Contour(x=xi,
                      y=yi,
                      z=acc_contours,
                      contours=dict(showlabels=True,
                                    labelfont=dict(size=12)),
                      contours_coloring='lines',
                      line_width=3,
                      showscale=False,
                      hovertemplate=ht_acc,
                      hoverlabel=dict(bgcolor='darkslategray'),
                      visible=False)

# Raw sample locations (hidden by every menu option)
samples = go.Scattergl(x=x,
                       y=y,
                       mode="markers",
                       visible=False)

layout.xaxis.visible = False
layout.yaxis.visible = False
fig = go.Figure([loss_data, acc_data, samples],
                layout=layout)

# Add loss/accuracy/both options; "Both" draws accuracy dashed on top of loss
ht_both = "<b>Location</b>: %{x:.2f}, %{y:.2f}<extra></extra>"
buttons = [dict(label="Loss", method="update", args=[dict(visible=[True, False, False],
                                                          line=dict(dash="solid", width=3),
                                                          hovertemplate=ht_loss)]),
           dict(label="Accuracy", method="update", args=[dict(visible=[False, True, False],
                                                              line=dict(dash="solid", width=3),
                                                              hovertemplate=ht_acc)]),
           dict(label="Both", method="update", args=[dict(visible=[True, True, False],
                                                          line=[dict(dash="solid", width=3),
                                                                dict(dash="dash", width=2)],
                                                          hovertemplate=ht_both)])]

fig.update_layout(updatemenus=[dict(buttons=buttons, x=0.2, y=1)])

In [None]:
# Save figure as an HTML fragment (plotly.js from the CDN, no mode bar)
fig.write_html(file=f"../_includes/figures/loss/{values[choice]}_loss_acc_2d.html",
               full_html=False,
               include_plotlyjs='cdn',
               config={"displayModeBar": False})

### 3D loss landscape

In [None]:
# Loss landscape as a 3D surface over the interpolated grid
# loss3d = ff.create_trisurf(x=x, y=y, z=loss,
#                            colormap="Viridis",
#                            simplices=simplices,
#                            plot_edges=False,
#                            show_colorbar=False).data[0]
loss3d = go.Surface(x=xi, y=yi, z=loss_contours, colorscale="Viridis", showscale=False)
fig = go.Figure(loss3d)

# Contour levels reuse the 2D heuristic (start/end/size) and are projected onto the floor
loss_levels = dict(show=True, usecolormap=True, start=start, end=end, size=size, project_z=True)
fig.update_traces(contours_z=loss_levels,
                  hovertemplate="<b>Location</b>: %{x:.2f}, %{y:.2f}<br><b>Loss</b>: %{z:.2f}<extra></extra>",
                  hoverlabel=dict(bgcolor='darkslategray'))

# Viewpoint: in front of and above the surface
camera = dict(eye=dict(x=0, y=-2.2, z=1.2),
              center=dict(x=0, y=0, z=0.05))

# Figure properties: hidden axes, true data proportions, orbit-style dragging
hidden_scene_axes = {f"{axis}axis": dict(visible=False) for axis in "xyz"}
fig.update_layout(scene=dict(**hidden_scene_axes, aspectmode='data'),
                  hoverlabel=dict(font_size=18),
                  height=700,
                  margin=dict(r=0, l=0, b=0, t=0, pad=0),
                  scene_camera=camera,
                  scene_dragmode="orbit")

In [None]:
# Save figure as an HTML fragment (plotly.js from the CDN, no mode bar)
fig.write_html(file=f"../_includes/figures/loss/{values[choice]}_loss_3d.html",
               full_html=False,
               include_plotlyjs='cdn',
               config={"displayModeBar": False})

### Loss landscape vs Gaussian

In [None]:
# loss3d = ff.create_trisurf(x=x, y=y, z=np.exp(-loss),
#                           colormap="Viridis",
#                            simplices=simplices,
#                            plot_edges=False,
#                            show_colorbar=False).data[0]

# Obtain the likelihood/posterior, which is the exponential of the negative loss
exp_neg_loss = np.exp(-loss_contours)

# Remove NaN values for fitting: trim the same border width from every side
# (always at least one ring) until the remaining interior is NaN-free.
i = 1
while np.isnan(exp_neg_loss[i:-i, i:-i]).any():
    i += 1
exp_neg_loss = exp_neg_loss[i:-i, i:-i]
if exp_neg_loss.size == 0:
    # The loop only stops on an empty slice when NaNs reach the grid center;
    # fail here instead of inside the Gaussian fit with an obscure error.
    raise ValueError("No NaN-free interior left after trimming the interpolated loss grid.")
xii = xi[i:-i]
yii = yi[i:-i]

# Plot likelihood/posterior
loss3d = go.Surface(x=xii,
                    y=yii,
                    z=exp_neg_loss,
                    contours_z=dict(show=True,
                                    usecolormap=True,
                                    start=0.2,
                                    end=1,
                                    size=0.1,
                                    project_z=True),
                    colorscale="Viridis",
                    showscale=False)
loss3d.hovertemplate = "<b>Location</b>: %{x:.2f}, %{y:.2f}<br><b>Exp. neg. loss</b>: %{z:.2f}<extra></extra>"
loss3d.hoverlabel = dict(bgcolor="darkslategray")

# Fit an axis-aligned 2D Gaussian (no correlation term) to it
# Source: https://scipy-cookbook.readthedocs.io/items/FittingData.html
def gaussian(height, center_x, center_y, width_x, width_y):
    """Return a callable ``f(x, y)`` that evaluates an axis-aligned 2D Gaussian.

    ``f(x, y) = height * exp(-(((center_x - x) / width_x)**2
    + ((center_y - y) / width_y)**2) / 2)``; works element-wise on arrays.
    """
    sigma_x = float(width_x)
    sigma_y = float(width_y)

    def evaluate(x, y):
        dx = (center_x - x) / sigma_x
        dy = (center_y - y) / sigma_y
        return height * np.exp(-(dx ** 2 + dy ** 2) / 2)

    return evaluate

def moments(data):
    """Estimate 2D Gaussian parameters from the moments of ``data``.

    The center is the intensity-weighted centroid over array indices (first
    axis -> x, second axis -> y). Each width comes from the second moment of
    the 1D profile through the (truncated) centroid along that axis.

    Returns:
        (height, x, y, width_x, width_y), with height = data.max().
    """
    def spread(profile, mean):
        offsets = np.arange(profile.size) - mean
        return np.sqrt(np.abs(offsets ** 2 * profile).sum() / profile.sum())

    total = data.sum()
    first_idx, second_idx = np.indices(data.shape)
    x = (first_idx * data).sum() / total
    y = (second_idx * data).sum() / total

    width_x = spread(data[:, int(y)], x)
    width_y = spread(data[int(x), :], y)
    return data.max(), x, y, width_x, width_y

def fitgaussian(data):
    """Fit an axis-aligned 2D Gaussian to ``data`` by least squares.

    Starts from the moment-based estimate of ``moments(data)`` and minimizes the
    pixel-wise residual with ``scipy.optimize.leastsq``.

    Returns:
        Array of fitted (height, x, y, width_x, width_y).
    """
    grid = np.indices(data.shape)

    def residuals(params):
        return np.ravel(gaussian(*params)(*grid) - data)

    best_params, _status = leastsq(residuals, moments(data))
    return best_params

# Evaluate the best-fit Gaussian on the index grid of the trimmed surface
gauss_params = fitgaussian(exp_neg_loss)
fit = gaussian(*gauss_params)(*np.indices(exp_neg_loss.shape))

# Same contour levels as the exp(-loss) surface so both panels are comparable
density_levels = dict(show=True, usecolormap=True, start=0.2, end=1, size=0.1, project_z=True)
gauss = go.Surface(x=xii,
                   y=yii,
                   z=fit,
                   contours_z=density_levels,
                   colorscale="Viridis",
                   hovertemplate="<b>Location</b>: %{x:.2f}, %{y:.2f}<br><b>Density</b>: %{z:.2f}<extra></extra>",
                   hoverlabel=dict(bgcolor="darkslategray"),
                   showscale=False)

# Figure: exp(-loss) on the left, fitted Gaussian on the right
fig = make_subplots(rows=1,
                    cols=2,
                    horizontal_spacing=0,
                    specs=[[{'is_3d': True}, {'is_3d': True}]])
fig.add_trace(loss3d, row=1, col=1)
fig.add_trace(gauss, row=1, col=2)

# Shared viewpoint for both scenes
zoom = 1.3
camera = dict(eye=dict(x=zoom, y=zoom, z=zoom),
              center=dict(x=0, y=0, z=-0.4))

# Figure properties: same camera and hidden axes in both scenes
fig.layout.scene1.camera = camera
fig.layout.scene2.camera = camera
hidden_scene_axes = {f"{axis}axis": dict(visible=False) for axis in "xyz"}
fig.update_layout(scene1=hidden_scene_axes,
                  scene2=hidden_scene_axes,
                  height=400,
                  margin=dict(r=0, l=0, b=0, t=0, pad=0))

# A FigureWidget lets the right-hand scene follow camera changes made on the left
fig = go.FigureWidget(fig)


def cam_change(layout, camera):
    """on_change callback: copy the new camera of scene 1 onto scene 2."""
    fig.layout.scene2.camera = camera


fig.layout.scene1.on_change(cam_change, 'camera')
fig

In [None]:
# Save figure as an HTML fragment that loads plotly.js from the CDN
fig.write_html(file=f"../_includes/figures/loss/{values[choice]}_loss_vs_gauss.html",
               full_html=False,
               include_plotlyjs='cdn')