In [1]:
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import numpy as np
import matplotlib.pyplot as plt

def cartesian_product(*arrays):
    """Return every combination of one element taken from each of ``arrays``.

    The result has one row per combination and one column per input array;
    column ``k`` holds the value drawn from ``arrays[k]``. Row order follows
    ``np.meshgrid``'s default ('xy') indexing.
    """
    grids = np.meshgrid(*arrays)
    stacked = np.stack(grids, axis=-1)
    return stacked.reshape(-1, len(arrays))

# Element-wise ReLU
def relu(x):
    """Return ``max(0, x)`` element-wise (rectified linear unit).

    ``np.zeros_like`` builds the zero reference with the same shape and
    dtype as ``x``, so scalars and arrays are both accepted.
    """
    return np.maximum(np.zeros_like(x), x)

def ntwk(x1, x2,
        # weights
        w11, w12, w13, w14, w21, w22, w23, w24,
        u11, u12, u21, u22, u31, u32, u41, u42,
        v11, v21,
        # biases
        a1, a2, a3, a4,
        b1, b2,
        c1
):
    """Evaluate a 2-4-2-1 fully connected ReLU network at one input point.

    Parameters
    ----------
    x1, x2 : scalar
        The two network inputs.
    w11 ... w24 : scalar
        Layer-1 weights; ``wij`` scales input ``i`` into hidden unit ``j``.
    u11 ... u42 : scalar
        Layer-2 weights; ``uij`` scales layer-1 unit ``i`` into layer-2
        unit ``j``.
    v11, v21 : scalar
        Output weights for layer-2 units 1 and 2.
    a1 ... a4, b1, b2, c1 : scalar
        Biases of layer 1, layer 2 and the linear output unit.

    Returns
    -------
    Python scalar
        The network output, converted with ``.item()`` (a ``float`` whenever
        any operand is floating point).
    """
    x = np.array([x1, x2])
    w = np.array([[w11, w12, w13, w14],
                  [w21, w22, w23, w24]])
    u = np.array([[u11, u12],
                  [u21, u22],
                  [u31, u32],
                  [u41, u42]])
    v = np.array([v11, v21])
    a = np.array([a1, a2, a3, a4])
    b = np.array([b1, b2])
    # Each bias is added to the pre-activation, i.e. inside the ReLU. Adding
    # it after the ReLU would pin every layer-1 kink line to the origin and
    # reduce b to nothing more than a constant shift of the output.
    layer1out = relu(x @ w + a)
    layer2out = relu(layer1out @ u + b)
    return (layer2out @ v + c1).item()

# Element-wise version of ntwk for array-valued arguments (np.vectorize loops
# in Python, one ntwk call per element). Declaring the output dtype up front
# spares NumPy an extra probe call on the first element and lets empty
# inputs through instead of raising.
vntwk = np.vectorize(ntwk, otypes=[float])

def plotntwk(
        # weights
        w11, w12, w13, w14, w21, w22, w23, w24,
        u11, u12, u21, u22, u31, u32, u41, u42,
        v11, v21,
        # biases
        a1, a2, a3, a4,
        b1, b2,
        c1):
    """Draw the network output as a heat map over inputs in [-1, 1] x [-1, 1].

    Takes every parameter of ``ntwk`` except the inputs ``x1``/``x2``; the
    ``widgets.interactive`` call in this file binds one slider to each of
    these parameter names. The network is evaluated through ``vntwk`` on a
    grid of 8 x1-values by 12 x2-values and shown with a colour bar.
    """
    x = np.linspace(-1, 1, 8)
    y = np.linspace(-1, 1, 12)
    X, Y = np.meshgrid(x, y)  # both (len(y), len(x)): rows vary x2, columns x1
    data = vntwk(X, Y,
                 w11, w12, w13, w14, w21, w22, w23, w24,
                 u11, u12, u21, u22, u31, u32, u41, u42,
                 v11, v21,
                 a1, a2, a3, a4,
                 b1, b2,
                 c1)
    # Without an extent, imshow labels the axes with sample indices (0..7,
    # 0..11) and draws the square input domain as a tall rectangle. Centre
    # each pixel on its sample point so the axes read in input units.
    dx = x[1] - x[0]
    dy = y[1] - y[0]
    extent = (x[0] - dx / 2, x[-1] + dx / 2, y[0] - dy / 2, y[-1] + dy / 2)
    fig, ax = plt.subplots()
    img = ax.imshow(data, origin='lower', extent=extent)
    ax.set_xlabel('x1')
    ax.set_ylabel('x2')
    fig.colorbar(img, ax=ax)
    plt.show()
    plt.close()

# inputSliders = []
# for i in range(2):
#     inputSliders.append(widgets.FloatSlider(min=-1, max=1, step=.01))
# One integer slider per weight: 2*4 (layer 1) + 4*2 (layer 2) + 2*1 (output) = 18.
# Each slider gets its own Layout instance.
weightSliders = [
    widgets.IntSlider(min=-1, max=1, step=1, value=1,
                      layout=widgets.Layout(width='20%'))
    for _ in range(2 * 4 + 4 * 2 + 2 * 1)
]
# One float slider per bias: 4 (layer 1) + 2 (layer 2) + 1 (output) = 7.
biasSliders = [
    widgets.FloatSlider(min=-1, max=1, step=.1)
    for _ in range(4 + 2 + 1)
]

# Bind every plotntwk parameter to its slider by keyword; interactive() calls
# plotntwk again whenever one of these slider values changes.
widget = widgets.interactive(plotntwk,
        #x1=inputSliders[0], x2=inputSliders[1],
        # weights
        w11=weightSliders[0], w12=weightSliders[1], w13=weightSliders[2], w14=weightSliders[3],
        w21=weightSliders[4], w22=weightSliders[5], w23=weightSliders[6], w24=weightSliders[7],
        u11=weightSliders[8], u12=weightSliders[9], u21=weightSliders[10], u22=weightSliders[11],
        u31=weightSliders[12], u32=weightSliders[13], u41=weightSliders[14], u42=weightSliders[15],
        v11=weightSliders[16], v21=weightSliders[17],
        # biases
        a1=biasSliders[0], a2=biasSliders[1], a3=biasSliders[2], a4=biasSliders[3],
        b1=biasSliders[4], b2=biasSliders[5],
        c1=biasSliders[6])

def _boxed_row(children):
    """Return an HBox laying ``children`` out in one wrapping, bordered row."""
    return widgets.HBox(list(children),
                        layout=widgets.Layout(flex_flow='row wrap',
                                              border='solid 1px'))

# interactive() uses the slider instances passed to it as its own children,
# so take them straight from the slider lists rather than slicing
# widget.children at positions that depend on plotntwk's parameter order.
layer1weights = _boxed_row(weightSliders[0:8])    # w11 .. w24
layer2weights = _boxed_row(weightSliders[8:16])   # u11 .. u42
layer3weights = _boxed_row(weightSliders[16:18])  # v11, v21
layer1bias = _boxed_row(biasSliders[0:4])         # a1 .. a4
layer2bias = _boxed_row(biasSliders[4:6])         # b1, b2
layer3bias = _boxed_row(biasSliders[6:7])         # c1
output = widget.children[-1]  # the Output area interactive() draws into

# NOTE(review): depending on the ipywidgets version, interactive() may defer
# its first call of plotntwk until the interactive container itself is
# displayed. Only its children are shown here (inside the VBox below), so
# render the initial plot explicitly; if a first call already happened this
# simply redraws the same figure.
widget.update()

widgets.VBox([layer1weights, layer2weights, layer3weights,
              layer1bias, layer2bias, layer3bias, output])

VBox(children=(HBox(children=(IntSlider(value=1, description='w11', layout=Layout(width='20%'), max=1, min=-1)…