In [1]:
import matplotlib.pyplot as plt
import numpy as np

from jupyter_dash import JupyterDash

import dash
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output
import plotly.express as px

import plotly.graph_objects as go


# Stylesheet URL list handed to the JupyterDash constructor through its
# `external_stylesheets` argument when the app object is created.
external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']

In [5]:
def Dense(out_dim):
    """Build a fully connected (affine) layer with ``out_dim`` outputs.

    Args:
        out_dim: Size of the last axis produced by the layer.

    Returns:
        An ``(init_fun, apply_fun)`` pair. ``init_fun(input_shape)`` returns
        the output shape together with freshly drawn parameters, and
        ``apply_fun(params, inputs)`` evaluates ``inputs . W + b``.
    """
    def init_fun(input_shape):
        """Return ``(output_shape, {'W': ..., 'b': ...})`` for ``input_shape``."""
        # Only the last axis is replaced; leading axes pass through unchanged.
        output_shape = input_shape[:-1] + (out_dim,)
        in_dim = input_shape[-1]
        # W is drawn uniformly from [-1, 1) and b from [0, 1); W is drawn
        # first so a seeded generator yields the same values as before.
        weights = {
            'W': 2 * np.random.rand(in_dim, out_dim) - 1,
            'b': np.random.rand(out_dim),
        }
        return output_shape, weights

    def apply_fun(params, inputs, **kwargs):
        """Apply the affine map; extra keyword arguments are accepted and ignored."""
        projected = np.dot(inputs, params['W'])
        return projected + params['b']

    return init_fun, apply_fun

In [3]:
def sigmoid(z):
    """Logistic function ``1 / (1 + exp(-z))``, evaluated elementwise.

    The value is computed as ``exp(-log(1 + exp(-z)))`` with ``np.logaddexp``
    so that very negative inputs do not overflow inside ``exp`` (the direct
    formula triggers an overflow RuntimeWarning there).

    Args:
        z: Scalar or NumPy array.

    Returns:
        The logistic of ``z``, with the same shape as ``z``.
    """
    return np.exp(-np.logaddexp(0, -z))


# Parameter-free (init_fun, apply_fun) layer pair wrapping `sigmoid`.
# It is spelled out here instead of being built with `elementwise`, because
# `elementwise` is defined in a later cell and a top-to-bottom run of the
# notebook would otherwise stop here with a NameError.
Sigmoid = (
    lambda input_shape: (input_shape, ()),
    lambda params, inputs: sigmoid(inputs),
)

In [2]:
def serial(*layers):
    """Combinator for composing layers in serial.

    Args:
        *layers: a sequence of layers, each an (init_fun, apply_fun) pair.
            The sequence may be empty, in which case the returned layer
            passes shapes and inputs through unchanged.

    Returns:
        A new layer, meaning an (init_fun, apply_fun) pair, representing the
        serial composition of the given sequence of layers. Its parameters
        are a list holding one entry per layer, in layer order.
    """
    # zip(*layers) produces nothing for an empty sequence, which cannot be
    # unpacked into two names, so fall back to two empty tuples.
    init_funs, apply_funs = zip(*layers) if layers else ((), ())

    def init_fun(input_shape):
        """Initialise every layer in order, threading the shape through."""
        params = []
        for layer_init in init_funs:
            input_shape, param = layer_init(input_shape)
            params.append(param)
        return input_shape, params

    def apply_fun(params, inputs):
        """Feed ``inputs`` through every layer using the matching parameters."""
        for fun, param in zip(apply_funs, params):
            inputs = fun(param, inputs)
        return inputs

    return init_fun, apply_fun

def elementwise(fun):
    """Wrap ``fun`` as a parameter-free layer.

    Returns an ``(init_fun, apply_fun)`` pair: initialisation leaves the
    shape untouched and yields an empty parameter tuple, while application
    simply calls ``fun`` on the inputs and ignores the parameters.
    """
    def init_fun(input_shape):
        return input_shape, ()

    def apply_fun(params, inputs):
        return fun(inputs)

    return init_fun, apply_fun


In [4]:
def target_function(x):
    """Evaluate 0.2 + 0.4*x**2 + 0.3*x*sin(15*x) + 0.05*cos(50*x).

    Works elementwise on NumPy arrays as well as on scalars.
    """
    quadratic = 0.4 * x**2
    slow_wave = 0.3 * x * np.sin(15 * x)
    fast_wave = 0.05 * np.cos(50 * x)
    # Summed left to right, exactly as in the closed-form expression.
    return 0.2 + quadratic + slow_wave + fast_wave

In [6]:
# 100 evenly spaced sample points on [0, 1], reshaped into a single column
# (shape (100, 1)) so that each row is one input to the networks below.
x = np.linspace(0, 1, 100).reshape(-1,1)

In [None]:
# NOTE(review): presumably detects the Jupyter server proxy settings needed
# when the notebook runs in a hosted environment — confirm against the
# jupyter-dash documentation.
JupyterDash.infer_jupyter_proxy_config()

# Application object; the layout and the callbacks defined further down are
# attached to it.
app = JupyterDash(__name__, external_stylesheets=external_stylesheets)

In [9]:
def _axis_layout():
    """Return a fresh dict with the axis styling shared by both plot axes."""
    return {
        'showline': True,
        'showgrid': False,
        'showticklabels': True,
        'linecolor': 'rgb(204, 204, 204)',
        'linewidth': 2,
        'ticks': 'outside',
        'tickfont': {
            'family': 'Arial',
            'size': 12,
            'color': 'rgb(82, 82, 82)',
        },
    }


# Two separate dict objects with identical content, one per axis.
xaxis_layout = _axis_layout()
yaxis_layout = _axis_layout()

In [10]:
# Network 0: x -> 1 hidden node -> 1 output.
nn_0_init, nn_0_apply = serial(Dense(1), Sigmoid, Dense(1))
_, wnn_0 = nn_0_init(x.shape)
# Overwrite the random initial parameters with hand-picked values:
# hidden node sigmoid(10 * x - 5), output node with weight 1 and bias 0.
wnn_0[0].update(W=np.array([[10]]), b=np.array([-5]))
wnn_0[2].update(W=np.array([[1]]), b=np.array([[0]]))

# Network 1: x -> 2 hidden nodes -> 1 output.
nn_1_init, nn_1_apply = serial(Dense(2), Sigmoid, Dense(1))
_, wnn_1 = nn_1_init(x.shape)
# Keep the shapes of the randomly initialised weight matrices so that the
# hand-written values can be reshaped to match them.
shw0_nn1, shw2_nn1 = wnn_1[0]['W'].shape, wnn_1[2]['W'].shape
# Hidden layer: both weights fixed at 100, biases set to -threshold * weight
# with thresholds 0.4 and 0.6.
wnn_1[0]['W'] = 100 * np.array([1, 1]).reshape(shw0_nn1)
wnn_1[0]['b'] = wnn_1[0]['W'] * -np.array([0.4, 0.6])
# Output layer: opposite-signed weights and a zero bias.
wnn_1[2].update(W=np.array([1.2, -1.2]).reshape(shw2_nn1), b=np.array([0]))

# Network 2: x -> 4 hidden nodes -> 1 output, set up next to a sin() curve.
nn_2_init, nn_2_apply = serial(Dense(4), Sigmoid, Dense(1))
_, wnn_2 = nn_2_init(x.shape)
# Shapes of the randomly initialised weight matrices, reused for reshaping.
shw0_2, shw2_2 = wnn_2[0]['W'].shape, wnn_2[2]['W'].shape
# Hidden layer: all weights fixed at 100, biases set to -threshold * weight.
wnn_2[0]['W'] = 100 * np.array([1, 1, 1, 1]).reshape(shw0_2)
wnn_2[0]['b'] = wnn_2[0]['W'] * -np.array([0.15, 0.4, 0.6, 0.8])
# Output layer: two pairs of opposite-signed weights and a zero bias.
wnn_2[2].update(W=np.array([-1.2, 1.2, .7, -.7]).reshape(shw2_2), b=np.array([0]))
# Quick visual check, left disabled:
# plt.plot(x, nn_2_apply(wnn_2, x))
# plt.plot(x, np.sin(-x*(2*np.pi)))

# Network 3: x -> 10 hidden nodes -> 1 output, set up next to a complex function.
nn_3_init, nn_3_apply = serial(Dense(10), Sigmoid, Dense(1))
_, wnn_3 = nn_3_init(x.shape)
# Shapes of the randomly initialised weight matrices, reused for reshaping.
shw0_3, shw2_3 = wnn_3[0]['W'].shape, wnn_3[2]['W'].shape
# Hidden layer: all ten weights fixed at 100.0, biases set to
# -threshold * weight.
wnn_3[0]['W'] = np.full(shw0_3, 100.0)
wnn_3[0]['b'] = wnn_3[0]['W'] * -np.array([0., 0.2, 0.2, 0.4, 0.4, 0.6, 0.6, 0.8, 0.8, 1.0])
# Output layer: five pairs of opposite-signed weights plus a bias of 0.4.
wnn_3[2].update(
    W=np.array([-.2, 0.2, -.25, 0.25, 0.06, -0.06, -.2, .2, .35, -.35]).reshape(shw2_3),
    b=np.array([.4]),
)
# Quick visual check, left disabled:
# plt.plot(x, nn_3_apply(wnn_3, x))
# plt.plot(x, target_function(x))

In [11]:
def f(x):
    """Return the reference curve 0.2 + 0.4*x**2 + 0.3*x*sin(15*x) + 0.05*cos(50*x).

    This is the same formula as `target_function`, defined in an earlier
    cell, so the call is forwarded there to keep a single copy of the
    expression.

    Args:
        x: Scalar or NumPy array of evaluation points.

    Returns:
        The function values, with the same shape as ``x``.
    """
    return target_function(x)

In [12]:
# Page layout: a title, an introduction and one section per example network
# (1, 2, 4 and 10 hidden nodes). The dcc.Input ids declared here are the ones
# the @app.callback functions defined after the layout listen to.
app.layout = html.Div([    
    
    # Centered title with a link to the video that inspired the notebook.
    html.Div([dcc.Markdown(
    """
    # Universal Approximation Theorem
    inspired by Michael Nielsen YouTube [video](https://www.youtube.com/watch?v=Ijqkc7OLenI&ab_channel=MichaelNielsen)
    """)], style={'textAlign': 'center'}),
    
    # Introductory text describing the scope of the examples.
    html.Div([
       dcc.Markdown(
       """
       In the notebook we use a [Jax](https://jax.readthedocs.io/en/latest/) style for the definition of fully connected neural networks (FCNN). 
       Autodiff is not implemented. There is only the feedforward part. You cannot really train this "networks". 
       
       All the examples use FCNN with only 1 hidden layer and the activation function is **not applied** on the output layer. 
       In the examples some parameter is kept fixed or bounded to other paramenters.
       """
       ) 
    ]),
    
    
    # Heading of the first example.
    html.Div([
       dcc.Markdown(
       """
       ### Single node hidden layer
       """
       ) 
    ]),
    
    # x -> 1 -> 1: network picture (served from the app assets) next to graph
    # 'nn_0', followed by number inputs for w1, b1, w2 and b2.
    
    html.Div([
        html.Div([
            html.Div(children=[
                html.Img(src=app.get_asset_url('net_png.png'), style={'width': '80%', 'display': 'inline-block'})     
                ], style={'width': '45%', 'display': 'inline-block', 'vertical-align': 'middle'}
            ),
            html.Div(children=[
                dcc.Graph(id='nn_0'),
            ], style={'width': '45%', 'display': 'inline-block', 'vertical-align': 'middle'}),
        ]),
        html.Div([
            html.Div([
                html.Label("w1 "),
                dcc.Input(id='nn_0_w1', type='number', value=10, step=0.01, style={'width': 60})
            ], style={'width': '5%', 'display': 'inline-block'}),
            html.Div([
                html.Label("b1 "),
                dcc.Input(id='nn_0_b1', type='number', value=-5, step=0.01, style={'width': 60}),
            ], style={'width': '5%', 'display': 'inline-block'}),
            html.Div([
                html.Label("w2 "),
                dcc.Input(id='nn_0_w2', type='number', value=1, step=0.01, style={'width': 60}),  
            ], style={'width': '5%', 'display': 'inline-block'}),
            html.Div([
                html.Label("b2 "),
                dcc.Input(id='nn_0_b2', type='number', value=0, step=0.01, style={'width': 60})    
            ], style={'width': '5%', 'display': 'inline-block'}),
        ]),
    ], style={'width': '95%', 'display': 'inline-block'}),
        
    # x -> 2 -> 1: heading and explanatory text.
    html.Div([
       dcc.Markdown(
       """
       ### Two nodes in the hidden layer
       Playing with the parameters you see that we can have bumps of arbitrary height and width. Even sort of delta functions.
       
       """
       ) 
    ]),

    # x -> 2 -> 1: picture next to graph 'nn_1', then inputs for the two
    # hidden thresholds (b11, b12) and the two output weights (w21, w22).
    # The hidden-layer weights have no input; the nn_1 callback fixes them at 100.
    html.Div([
        html.Div([
            html.Div(children=[
                html.Img(src=app.get_asset_url('net_png2.png'), style={'width': '80%', 'display': 'inline-block'})     
                ], style={'width': '45%', 'display': 'inline-block', 'vertical-align': 'middle'}
            ),
            html.Div(children=[
            dcc.Graph(id='nn_1'),
            ], style={'width': '45%', 'display': 'inline-block', 'vertical-align': 'middle'}),
        ]),
        html.Div([
            html.Div([
                html.Label("b11 "),
                dcc.Input(id='nn_1_b11', type='number', value=0.4, step=0.01, style={'width': 80}),  
            ], style={'width': '10%', 'display': 'inline-block'}),
            html.Div([
                html.Label("b12 "),
                dcc.Input(id='nn_1_b12', type='number', value=0.6, step=0.01, style={'width': 80}),    
            ], style={'width': '10%', 'display': 'inline-block'}),
            html.Div([
                html.Label("w21 "),
                dcc.Input(id='nn_1_w21', type='number', value=1.2, step=0.01, style={'width': 80}),  
            ], style={'width': '10%', 'display': 'inline-block'}),
            html.Div([
                html.Label("w22 "),
                dcc.Input(id='nn_1_w22', type='number', value=-1.2, step=0.01, style={'width': 80}),
            ], style={'width': '10%', 'display': 'inline-block'}),
        ]),
    ]),
    
    # x -> 4 -> 1: heading and explanatory text.
    html.Div([
       dcc.Markdown(
       """
       ### Four nodes in the hidden layer
       Try to fit a sin() function. One bump goes up the other down
       
       """
       ) 
    ]),
    
    # x -> 4 -> 1: picture next to graph 'nn_2', a row of four hidden
    # thresholds, then a row of two output-weight inputs. Each output-weight
    # input drives a pair of weights with opposite signs in the nn_2 callback.
    html.Div([
        html.Div([
            html.Div(children=[
                html.Img(src=app.get_asset_url('net_png3.png'), style={'width': '80%', 'display': 'inline-block'})     
                ], style={'width': '45%', 'display': 'inline-block', 'vertical-align': 'middle'}
            ),
            html.Div(children=[
            dcc.Graph(id='nn_2'),
            ], style={'width': '45%', 'display': 'inline-block', 'vertical-align': 'middle'}),
        ]),
        html.Div([
            html.Div([
                html.Label("b11 "),
                dcc.Input(id='nn_2_b11', type='number', value=0.15, step=0.01, style={'width': 80}),  
            ], style={'width': '10%', 'display': 'inline-block'}),
            html.Div([
                html.Label("b12 "),
                dcc.Input(id='nn_2_b12', type='number', value=0.35, step=0.01, style={'width': 80}),    
            ], style={'width': '10%', 'display': 'inline-block'}),
            html.Div([
                html.Label("b13 "),
                dcc.Input(id='nn_2_b13', type='number', value=0.6, step=0.01, style={'width': 80}),  
            ], style={'width': '10%', 'display': 'inline-block'}),
            html.Div([
                html.Label("b14 "),
                dcc.Input(id='nn_2_b14', type='number', value=0.8, step=0.01, style={'width': 80}),    
            ], style={'width': '10%', 'display': 'inline-block'}),
        ], style={'margin-bottom': 10}),
        html.Div([
            html.Div([
                html.Label("w21, w22"),
                dcc.Input(id='nn_2_w21', type='number', value=1.2, step=0.01, style={'width': 80}),  
            ], style={'width': '10%', 'display': 'inline-block'}),
            html.Div([
                html.Label("w23, w24"),
                dcc.Input(id='nn_2_w22', type='number', value=1.0, step=0.01, style={'width': 80}),
            ], style={'width': '10%', 'display': 'inline-block'}),
        ]),
    ]),
    
    # x -> 10 -> 1: heading and the formula of the function to approximate.
    html.Div([
        dcc.Markdown(
        """
        ### Ten nodes in the hidden layer
        We want that the predictor will follow the behavior of a complicated function defined by:
        
        f(x) = 0.2+0.4x^2+0.3x sin(15x)+0.05cos(50x)

        """
        ) 
    ]),

    # x -> 10 -> 1: graph 'nn_3', a row of six threshold inputs, a row of five
    # output-weight inputs plus the output bias, and a box showing the loss.
    # Inputs labelled with two names drive both parameters in the nn_3 callback.
    html.Div([
        html.Div([
            dcc.Graph(id='nn_3'),
        ], style={'width': '90%', 'display': 'inline-block', 'padding': '0 20'}),
        html.Div([
            html.Div([
                html.Div([
                    html.Label("b11 "),
                    dcc.Input(id='nn_3_b11', type='number', value=0., step=0.01, style={'width': 80}),  
                ], style={'width': '10%', 'display': 'inline-block'}),
                html.Div([
                    html.Label("b12, b13 "),
                    dcc.Input(id='nn_3_b12', type='number', value=0.2, step=0.01, style={'width': 80}),    
                ], style={'width': '10%', 'display': 'inline-block'}),
                html.Div([
                    html.Label("b14, b15 "),
                    dcc.Input(id='nn_3_b13', type='number', value=0.4, step=0.01, style={'width': 80}),  
                ], style={'width': '10%', 'display': 'inline-block'}),
                html.Div([
                    html.Label("b16, b17 "),       
                    dcc.Input(id='nn_3_b14', type='number', value=0.6, step=0.01, style={'width': 80}),    
                ], style={'width': '10%', 'display': 'inline-block'}),
                html.Div([
                    html.Label("b18, b19"),
                    dcc.Input(id='nn_3_b15', type='number', value=0.8, step=0.01, style={'width': 80}),    
                ], style={'width': '10%', 'display': 'inline-block'}),
                html.Div([
                    html.Label("b110 "),
                    dcc.Input(id='nn_3_b16', type='number', value=1.0, step=0.01, style={'width': 80}),    
                ], style={'width': '10%', 'display': 'inline-block'}),
            ], style={'margin-bottom': 10}),
            html.Div([
                html.Div([
                    html.Label("w21, w22 "),
                    dcc.Input(id='nn_3_w21', type='number', value=.2, step=0.01, style={'width': 80}),  
                ], style={'width': '10%', 'display': 'inline-block'}),
                html.Div([
                    html.Label("w23, w24 "),
                    dcc.Input(id='nn_3_w22', type='number', value=.25, step=0.01, style={'width': 80}),
                ], style={'width': '10%', 'display': 'inline-block'}),
                html.Div([
                    html.Label("w25, w26 "),
                    dcc.Input(id='nn_3_w23', type='number', value=.06, step=0.01, style={'width': 80}),  
                ], style={'width': '10%', 'display': 'inline-block'}),
                html.Div([
                    html.Label("w27, w28 "),
                    dcc.Input(id='nn_3_w24', type='number', value=.2, step=0.01, style={'width': 80}),
                ], style={'width': '10%', 'display': 'inline-block'}),
                html.Div([
                    html.Label("w29, w210 "),
                    dcc.Input(id='nn_3_w25', type='number', value=.35, step=0.01, style={'width': 80}),  
                ], style={'width': '10%', 'display': 'inline-block'}),
                html.Div([
                    html.Label("b21 "),
                    dcc.Input(id='nn_3_b21', type='number', value=0.4, step=0.01, style={'width': 80}),  
                ], style={'width': '10%', 'display': 'inline-block'}),
            ]),
        ], style={'width': '60%', 'display': 'inline-block'}),
        # Box whose children (the loss text) are set by the nn_3 callback.
        html.Div([
            html.Div(id="loss", style={'fontSize': 32, 'font-weight': 'bold', "border":"2px black solid"}),        
        ], style={'width': 250, 'display': 'inline-block', 'height': 50, })
    ])
    # Disabled: separate graph of the target function (fig_t is not defined
    # anywhere in the visible notebook).
#     html.Div([
#         dcc.Graph(id='target_function', figure=fig_t)
#     ], style={'width': '90%', 'display': 'inline-block', 'padding': '0 20'}),
    
])


@app.callback(
    Output('nn_0', 'figure'),
    Input('nn_0_w1', 'value'),
    Input('nn_0_b1', 'value'),
    Input('nn_0_w2', 'value'),
    Input('nn_0_b2', 'value'))
def update_graph(nn_0_w1, nn_0_b1, nn_0_w2, nn_0_b2):
    """Redraw graph 'nn_0' for the x -> 1 -> 1 network.

    Args:
        nn_0_w1: Weight of the hidden node.
        nn_0_b1: Bias of the hidden node.
        nn_0_w2: Weight of the output node.
        nn_0_b2: Bias of the output node.

    Returns:
        A plotly Figure with the network output over the sample points ``x``.

    Raises:
        dash.exceptions.PreventUpdate: If any input value is None, so that
            the figure currently shown is left untouched.
    """
    # A number input can deliver None (for example while the field is empty);
    # NumPy cannot compute with None, so skip the update instead of failing.
    if any(value is None for value in (nn_0_w1, nn_0_b1, nn_0_w2, nn_0_b2)):
        raise dash.exceptions.PreventUpdate

    # Parameters are built locally rather than written into the module-level
    # wnn_0 list, so the callback does not modify shared state. The list
    # mirrors serial(Dense, Sigmoid, Dense); the Sigmoid entry is empty.
    params = [
        {'W': np.array([[nn_0_w1]]), 'b': np.array([nn_0_b1])},
        (),
        {'W': np.array([[nn_0_w2]]), 'b': np.array([[nn_0_b2]])},
    ]

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=x.flatten(), y=nn_0_apply(params, x).flatten(), mode='lines',
                    name='lines'))

    fig.update_layout(xaxis=xaxis_layout, yaxis=yaxis_layout, autosize=True, showlegend=False,
                      plot_bgcolor='white', margin={'t': 0,'l':0,'b':0,'r':10})
    return fig

@app.callback(
    Output('nn_1', 'figure'),
    Input('nn_1_w21', 'value'),
    Input('nn_1_w22', 'value'),
    Input('nn_1_b11', 'value'),
    Input('nn_1_b12', 'value'))
def update_graph(nn_1_w21, nn_1_w22, nn_1_b11, nn_1_b12):
    """Redraw graph 'nn_1' for the x -> 2 -> 1 network.

    The hidden weights are fixed at 100 and each hidden bias is set to
    ``-threshold * weight``; the output bias is fixed at 0.

    Args:
        nn_1_w21: Output weight of the first hidden node.
        nn_1_w22: Output weight of the second hidden node.
        nn_1_b11: Threshold of the first hidden node.
        nn_1_b12: Threshold of the second hidden node.

    Returns:
        A plotly Figure with the network output over the sample points ``x``.

    Raises:
        dash.exceptions.PreventUpdate: If any input value is None, so that
            the figure currently shown is left untouched.
    """
    # A number input can deliver None (for example while the field is empty);
    # NumPy cannot compute with None, so skip the update instead of failing.
    if any(value is None for value in (nn_1_w21, nn_1_w22, nn_1_b11, nn_1_b12)):
        raise dash.exceptions.PreventUpdate

    # Parameters are built locally rather than written into the module-level
    # wnn_1 list, so the callback does not modify shared state. The list
    # mirrors serial(Dense, Sigmoid, Dense); the Sigmoid entry is empty.
    hidden_w = 100 * np.array([1, 1]).reshape(shw0_nn1)
    params = [
        {'W': hidden_w, 'b': -np.array([nn_1_b11, nn_1_b12]) * hidden_w},
        (),
        {'W': np.array([nn_1_w21, nn_1_w22]).reshape(shw2_nn1), 'b': np.array([0])},
    ]

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=x.flatten(), y=nn_1_apply(params, x).flatten(), mode='lines',
                    name='lines'))

    fig.update_layout(xaxis=xaxis_layout, yaxis=yaxis_layout, autosize=True, showlegend=False,
                      plot_bgcolor='white')
    return fig

@app.callback(
    Output('nn_2', 'figure'),
    Input('nn_2_w21', 'value'),
    Input('nn_2_w22', 'value'),
    Input('nn_2_b11', 'value'),
    Input('nn_2_b12', 'value'),
    Input('nn_2_b13', 'value'),
    Input('nn_2_b14', 'value'))
def update_graph(nn_2_w21, nn_2_w22, nn_2_b11, nn_2_b12, nn_2_b13, nn_2_b14):
    """Redraw graph 'nn_2' for the x -> 4 -> 1 network next to sin(-2*pi*x).

    The hidden weights are fixed at 100 and each hidden bias is set to
    ``-threshold * weight``; the output bias is fixed at 0.

    Args:
        nn_2_w21: Magnitude shared by the first two output weights, which
            are applied as ``-nn_2_w21`` and ``+nn_2_w21``.
        nn_2_w22: Magnitude shared by the last two output weights, which
            are applied as ``+nn_2_w22`` and ``-nn_2_w22``.
        nn_2_b11: Threshold of the first hidden node.
        nn_2_b12: Threshold of the second hidden node.
        nn_2_b13: Threshold of the third hidden node.
        nn_2_b14: Threshold of the fourth hidden node.

    Returns:
        A plotly Figure with two line traces: the network output and
        ``sin(-2*pi*x)`` over the sample points ``x``.

    Raises:
        dash.exceptions.PreventUpdate: If any input value is None, so that
            the figure currently shown is left untouched.
    """
    values = (nn_2_w21, nn_2_w22, nn_2_b11, nn_2_b12, nn_2_b13, nn_2_b14)
    # A number input can deliver None (for example while the field is empty);
    # NumPy cannot compute with None, so skip the update instead of failing.
    if any(value is None for value in values):
        raise dash.exceptions.PreventUpdate

    # Parameters are built locally rather than written into the module-level
    # wnn_2 list, so the callback does not modify shared state. The list
    # mirrors serial(Dense, Sigmoid, Dense); the Sigmoid entry is empty.
    hidden_w = 100 * np.array([1, 1, 1, 1]).reshape(shw0_2)
    thresholds = np.array([nn_2_b11, nn_2_b12, nn_2_b13, nn_2_b14])
    output_w = np.array([-nn_2_w21, nn_2_w21, nn_2_w22, -nn_2_w22]).reshape(shw2_2)
    params = [
        {'W': hidden_w, 'b': -thresholds * hidden_w},
        (),
        {'W': output_w, 'b': np.array([0])},
    ]

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=x.flatten(), y=nn_2_apply(params, x).flatten(), mode='lines',
                    name='lines'))
    fig.add_trace(go.Scatter(x=x.flatten(), y=np.sin(-x*(2*np.pi)).flatten(), mode='lines',
                    name='lines'))

    fig.update_layout(xaxis=xaxis_layout, yaxis=yaxis_layout, autosize=True, showlegend=False,
                      plot_bgcolor='white')
    return fig


@app.callback(
    [Output('nn_3', 'figure'),
    Output('loss', 'children')],
    [Input('nn_3_b11', 'value'),
    Input('nn_3_b12', 'value'),
    Input('nn_3_b13', 'value'),
    Input('nn_3_b14', 'value'),
    Input('nn_3_b15', 'value'),
    Input('nn_3_b16', 'value'),
    Input('nn_3_w21', 'value'),
    Input('nn_3_w22', 'value'),
    Input('nn_3_w23', 'value'),
    Input('nn_3_w24', 'value'),
    Input('nn_3_w25', 'value'),
    Input('nn_3_b21', 'value')])
def update_graph(nn_3_b11, nn_3_b12, nn_3_b13, nn_3_b14, nn_3_b15, nn_3_b16, nn_3_w21, nn_3_w22, nn_3_w23, nn_3_w24, nn_3_w25, nn_3_b21):
    """Redraw graph 'nn_3' for the x -> 10 -> 1 network and report the loss.

    The hidden weights are fixed at 100 and each hidden bias is set to
    ``-threshold * weight``.

    Args:
        nn_3_b11: Threshold of hidden node 1.
        nn_3_b12: Threshold shared by hidden nodes 2 and 3.
        nn_3_b13: Threshold shared by hidden nodes 4 and 5.
        nn_3_b14: Threshold shared by hidden nodes 6 and 7.
        nn_3_b15: Threshold shared by hidden nodes 8 and 9.
        nn_3_b16: Threshold of hidden node 10.
        nn_3_w21: Magnitude of output weights 1 and 2 (applied as -, +).
        nn_3_w22: Magnitude of output weights 3 and 4 (applied as -, +).
        nn_3_w23: Magnitude of output weights 5 and 6 (applied as +, -).
        nn_3_w24: Magnitude of output weights 7 and 8 (applied as -, +).
        nn_3_w25: Magnitude of output weights 9 and 10 (applied as +, -).
        nn_3_b21: Bias of the output node.

    Returns:
        A ``(figure, text)`` pair: the plotly Figure with the network output,
        ``f(x)`` as a line and ``f(x)`` as legend-only markers, and the string
        ``'Loss = <value>'`` where the value is the root-mean-square
        difference between ``f(x)`` and the network output, rounded to three
        decimals.

    Raises:
        dash.exceptions.PreventUpdate: If any input value is None, so that
            the figure and loss currently shown are left untouched.
    """
    values = (nn_3_b11, nn_3_b12, nn_3_b13, nn_3_b14, nn_3_b15, nn_3_b16,
              nn_3_w21, nn_3_w22, nn_3_w23, nn_3_w24, nn_3_w25, nn_3_b21)
    # A number input can deliver None (for example while the field is empty);
    # NumPy cannot compute with None, so skip the update instead of failing.
    if any(value is None for value in values):
        raise dash.exceptions.PreventUpdate

    # Parameters are built locally rather than written into the module-level
    # wnn_3 list, so the callback does not modify shared state. The list
    # mirrors serial(Dense, Sigmoid, Dense); the Sigmoid entry is empty.
    hidden_w = 100 * np.ones(10).reshape(shw0_3)
    thresholds = np.array([nn_3_b11, nn_3_b12, nn_3_b12, nn_3_b13, nn_3_b13,
                           nn_3_b14, nn_3_b14, nn_3_b15, nn_3_b15, nn_3_b16])
    output_w = np.array([-nn_3_w21, nn_3_w21, -nn_3_w22, nn_3_w22, nn_3_w23,
                         -nn_3_w23, -nn_3_w24, nn_3_w24, nn_3_w25, -nn_3_w25]).reshape(shw2_3)
    params = [
        {'W': hidden_w, 'b': -thresholds * hidden_w},
        (),
        {'W': output_w, 'b': np.array([nn_3_b21])},
    ]

    # Evaluate each curve once and reuse it for the traces and the loss.
    xs = x.flatten()
    prediction = nn_3_apply(params, x).flatten()
    target = f(x).flatten()

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=xs, y=prediction, mode='lines',
                    name='NN', marker_colorscale=px.colors.qualitative.G10))
    fig.add_trace(go.Scatter(x=xs, y=target, mode='lines',
                    name='f(x)', marker_colorscale=px.colors.qualitative.G10))
    fig.add_trace(go.Scatter(x=xs, y=target, mode='markers',
                    name='y', marker_colorscale=px.colors.qualitative.G10, visible='legendonly'))

    fig.update_layout(xaxis=xaxis_layout, yaxis=yaxis_layout, autosize=True, showlegend=True,
                      plot_bgcolor='white')
    loss = np.round(np.sqrt(np.sum((target - prediction)**2)/len(xs)), 3)
    loss_s = u'Loss = {}'.format(loss)
    return fig, loss_s



In [13]:
# Start serving the app built above.
# NOTE(review): mode="jupyterlab" presumably shows the app inside a JupyterLab
# tab — confirm against the jupyter-dash documentation.
app.run_server(mode="jupyterlab")