# Import Statements

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torch
from IPython.display import HTML
from torch import nn
from torch import optim

# Global Settings

In [None]:
# Hide all four axes spines by default for every figure drawn in this notebook.
plt.rcParams.update( {
    'axes.spines.right': False,
    'axes.spines.top': False,
    'axes.spines.left': False,
    'axes.spines.bottom': False,
} )

# Data

In [None]:
# Number of sample points and the closed interval they are spread over.
data_length = 10000
x_limits = [ -10.0, 10.0 ]

# Evenly spaced inputs as a flat NumPy array ...
x = np.linspace( start = x_limits[ 0 ], stop = x_limits[ 1 ], num = data_length )

# ... and the same values as a float32 column tensor of shape (data_length, 1).
tensor_x = torch.tensor( x, dtype = torch.float32 ).reshape( -1, 1 )

In [None]:
# Fixed seeds so the noise (NumPy generator) and PyTorch's global RNG are reproducible.
rng = np.random.default_rng( seed = 42 )
torch.manual_seed( seed = 4121 )

In [None]:
# Noisy parabola: x^2 - 10 plus one draw of rng.normal() noise per sample.
y_2 = ( x ** 2 - 10 ) + rng.normal( size = data_length )

# Same targets as a float32 column tensor, matching the layout of tensor_x.
tensor_y_2 = torch.tensor( y_2, dtype = torch.float32 ).reshape( -1, 1 )

# Quick look at the raw data.
plt.scatter( x, y_2, color = "#cccccc", s = 1 )
plt.show()

# Helper Functions

In [None]:
def create_model( hidden_nodes ):
  """Build a one-hidden-layer ReLU network mapping one input to one output.

  Args:
    hidden_nodes: number of units in the hidden layer.

  Returns:
    An nn.Sequential of Linear(1, hidden_nodes) -> ReLU -> Linear(hidden_nodes, 1).
  """
  # The layers are created in network order so they draw their random
  # initial weights in the same order as they appear in the model.
  input_to_hidden = nn.Linear( in_features = 1, out_features = hidden_nodes )
  hidden_to_output = nn.Linear( in_features = hidden_nodes, out_features = 1 )
  return nn.Sequential( input_to_hidden, nn.ReLU(), hidden_to_output )

In [None]:
def print_model( model ):
  """Print the model's structure, then every layer followed by its parameters.

  Args:
    model: an iterable container of layers (each exposing .parameters()),
      such as the nn.Sequential returned by create_model.
  """
  def print_section_header( title ):
    # A section header is the title line followed by a rule line.
    print( title )
    print( "------------------------------------------------------\n")

  print_section_header( "\nStructure")
  print( model )

  print_section_header( "\nParameters")
  for layer in model:
    print( layer )
    for parameter in layer.parameters():
      print( parameter )

In [None]:
def initialize_weight( hidden_nodes, x_limits ):
    """Create an initializer that overwrites a model's parameters deterministically.

    Args:
        hidden_nodes: number of hidden units; used for the bin width and for
            the index at which the first-layer weights switch sign.
        x_limits: two-element sequence [lower, upper].

    Returns:
        A function taking a model and modifying its parameters in place
        (it returns None):
          * parameter 0: column 0 of row j is set to +1 for j < hidden_nodes // 2,
            otherwise to -1;
          * parameter 1: entry j is set to |x_limits[0] + (j + 1) * cutoff_bin|,
            where cutoff_bin = (upper - lower) / (hidden_nodes + 1);
          * every later parameter is filled with zeros.
    """
    span = x_limits[ 1 ] - x_limits[ 0 ]
    cutoff_bin = span / ( hidden_nodes + 1 )

    def initialize_weight_( model ):
        half_index = int( hidden_nodes // 2 )
        with torch.no_grad():
            for position, tensor in enumerate( model.parameters() ):
                values = tensor.data
                if position == 0:
                    # First half of the rows get +1, the remaining rows get -1.
                    for row in range( len( values ) ):
                        values[ row ][ 0 ] = 1 if row < half_index else -1
                elif position == 1:
                    # Magnitude of the (row + 1)-th interior grid point of x_limits.
                    for row in range( len( values ) ):
                        values[ row ] = abs( x_limits[ 0 ] + ( row + 1 ) * cutoff_bin )
                else:
                    values.zero_()

    return initialize_weight_

In [None]:
def fit( model,
         learning_rate,
         tensor_x,
         tensor_y,
         save_file_path = "fit.mp4",
         loss_function = nn.MSELoss(),
         animate = True,
         n_epochs = 100,
         linestyles = [ "dotted", "dashed", "dashdot" ],
         colors = [ "#a6cee3", "#1f78b4", "#b2df8a", "#33a02c", "#fb9a99", "#e31a1c",
                    "#fdbf6f", "#ff7f00", "#cab2d6", "#6a3d9a", "#ffff99", "#b15928" ],
         x_limits = None ):
  """Train `model` with Adam and optionally animate how the fit evolves.

  The model is indexed like an nn.Sequential: `model[ 0 ]` must expose
  `weight` and `bias` (the layer feeding the activation) and `model[ 1 ]` is
  the module whose output is captured through a forward hook.

  A snapshot (kink positions -bias / weight of the first layer, activation
  outputs and model outputs at those positions, and the loss) is recorded
  before every optimizer step and once more after the last step, so there
  are `n_epochs + 1` snapshots / animation frames.

  Args:
    model: the network to train; modified in place.
    learning_rate: learning rate passed to optim.Adam.
    tensor_x: training inputs.
    tensor_y: training targets.
    save_file_path: where the animation is saved (only when `animate`).
      Saving uses matplotlib's Animation.save, which needs a movie writer.
    loss_function: callable `(predicted, target) -> loss tensor`.
    animate: when falsy, only train and return None.
    n_epochs: number of optimizer steps.
    linestyles: line styles cycled over the per-node component curves.
    colors: colors cycled over the per-node component curves.
    x_limits: optional `[ lower, upper ]` plotting range. Defaults to the
      minimum and maximum of `tensor_x`, so the function no longer depends
      on notebook-level globals.

  Returns:
    The animation as a JavaScript/HTML string (`to_jshtml()`), or None when
    `animate` is falsy.
  """
  optimizer = optim.Adam( params = model.parameters(), lr = learning_rate )

  # Previously read from a notebook global; now derived from the actual data.
  if x_limits is None:
    x_limits = [ tensor_x.min().item(), tensor_x.max().item() ]
  else:
    x_limits = [ float( x_limits[ 0 ] ), float( x_limits[ 1 ] ) ]

  colors_length = len( colors )
  linestyles_length = len( linestyles )
  hidden_nodes = model[ 0 ].weight.size()[ 0 ]

  # One entry per snapshot in each of these lists.
  losses = []
  single_relu_limits = []
  all_relu_x_limits = []
  components_y_limits = []
  y_limits = []
  tensor_y_min = tensor_y.min().item()
  tensor_y_max = tensor_y.max().item()

  # Holds the most recent output of model[ 1 ], filled in by the forward hook.
  activation = {}

  def capture_activation( module, inputs, output ):
    activation[ "output" ] = output.detach()

  def record_snapshot( loss ):
    """Store kink positions, per-node activations, model outputs and loss."""
    with torch.no_grad():
      weights = np.array( model[ 0 ].weight.detach().flatten().tolist() )
      biases = np.array( model[ 0 ].bias.detach().flatten().tolist() )
      # Each hidden node switches on/off where weight * x + bias == 0.
      relu_x_limits = ( -biases / weights ).tolist() + x_limits
      relu_x_limits.sort()
      single_relu_limits.append( relu_x_limits[ 1:-1 ] )
      all_relu_x_limits.append( relu_x_limits )
      # This forward pass also refreshes `activation` via the hook, so the
      # captured activations below belong to exactly these x positions.
      limits_tensor = torch.tensor( relu_x_limits, dtype = tensor_x.dtype ).reshape( ( -1, 1 ) )
      predicted_x_limits = model( limits_tensor )
      components_y_limits.append( activation[ "output" ].transpose( 0, 1 ).tolist() )
      losses.append( loss.item() )
      y_limits.append( predicted_x_limits.flatten().tolist() )

  hook_handle = model[ 1 ].register_forward_hook( capture_activation )
  try:
    for _ in range( n_epochs ):
      predicted = model( tensor_x )
      loss = loss_function( predicted, tensor_y )
      record_snapshot( loss )

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

    # Final snapshot: re-evaluate the model so the recorded loss matches the
    # parameters after the last optimizer step (also works for n_epochs == 0).
    with torch.no_grad():
      final_loss = loss_function( model( tensor_x ), tensor_y )
    record_snapshot( final_loss )
  finally:
    # Always detach the hook so repeated calls do not accumulate hooks.
    hook_handle.remove()

  if not animate:
    return None

  losses_indices = list( range( len( losses ) ) )

  fig, ( ax_1, ax_2 ) = plt.subplots( 1, 2 )
  fig.set_size_inches( 8, 4 )

  ax_1.set_title( "Data VS Approximation", pad = 20 )
  ax_1.set( xlabel = "x", ylabel = "y" )
  # Plot the data that was actually passed in rather than notebook globals.
  x_nl = tensor_x.detach().numpy().ravel()
  y_nl = tensor_y.detach().numpy().ravel()
  ax_1.axis( [ x_limits[ 0 ], x_limits[ 1 ],
               y_nl.min(), y_nl.max() ] )

  ax_1.scatter( x_nl, y_nl, s = 1, color = "#cccccc" )
  components = []
  for j in range( hidden_nodes ):
    cl, = ax_1.plot( [], [],
                     color = colors[ j % colors_length ],
                     linestyle = linestyles[ j % linestyles_length ] )
    components.append( cl )
  v_1 = ax_1.vlines( [], tensor_y_min, tensor_y_max, colors = "#cccccc" )
  l_1, = ax_1.plot( [], [], color = "black" )
  l_2, = ax_2.plot( [], [], color = "red" )
  # One region label per interval between consecutive limits.
  annotations = []
  for _ in range( hidden_nodes + 1 ):
    ann = ax_1.annotate(
        "",
        xy = ( 0, 0 ),
        ha = "center",
        va = "top",
        color = "#aaaaaa" )
    annotations.append( ann )

  ax_2.set_title( "Error progression", pad = 20 )
  ax_2.set( xlabel = "epoch", ylabel = "error" )
  min_loss = np.min( losses )
  max_loss = np.max( losses )
  ax_2.axis( [ losses_indices[ 0 ], losses_indices[ -1 ],
               min_loss, max_loss ] )
  error_ann = ax_2.annotate(
      "",
      xy = ( losses_indices[ -1 ], max_loss ),
      position = ( losses_indices[ -1 ], max_loss ),
      ha = "right",
      va = "top",
      color = "red" )

  fig.subplots_adjust( left = 0.09, right = 0.95, top = 0.85, bottom = 0.15 )
  # Close this specific figure; only the animation output is returned.
  plt.close( fig )

  def draw_frame( i ):
    """Update every artist to show snapshot `i`."""
    for j in range( hidden_nodes ):
      components[ j ].set_data( all_relu_x_limits[ i ], components_y_limits[ i ][ j ] )
    l_1.set_data( all_relu_x_limits[ i ], y_limits[ i ] )
    v_1.set_segments( [ np.array( [ [ xx, tensor_y_min ],
                                    [ xx, tensor_y_max ] ] )
                        for xx in single_relu_limits[ i ] ] )
    for j, ann in enumerate( annotations ):
      new_pos = ( np.mean( [ all_relu_x_limits[ i ][ j ], all_relu_x_limits[ i ][ j + 1 ] ] ),
                  tensor_y_max )
      ann.set_position( new_pos )
      ann.xy = new_pos
      ann.set_text( "$R_{}$".format( j + 1 ) )

    error_ann.set_text( "Error: {:.2f}".format( losses[ i ] ) )
    # Include snapshot i itself so the curve agrees with the error label and
    # the final loss is drawn in the last frame.
    l_2.set_data( losses_indices[ : i + 1 ], losses[ : i + 1 ] )

  ani = animation.FuncAnimation( fig, draw_frame, frames = len( losses ) )

  ani.save( save_file_path )

  return ani.to_jshtml()

# Training

## Model 1

In [None]:
# Two hidden nodes, deterministically initialized over the interval [-9, 9].
hidden_nodes = 2
model_1 = create_model( hidden_nodes = hidden_nodes )
initialize_weight( hidden_nodes = hidden_nodes, x_limits = [ -9, 9 ] )( model_1 )

# Show the starting structure and parameters before training.
print_model( model_1 )

In [None]:
# Train model_1 and save the animation of the fit to good-fit.mp4.
jshtml_1 = fit( model = model_1,
                learning_rate = 0.5,
                tensor_x = tensor_x,
                tensor_y = tensor_y_2,
                save_file_path = "good-fit.mp4",
                colors = [ "blue", "green" ] )

# Mean squared error of the trained model over the whole data set.
with torch.no_grad():
  print( nn.MSELoss()( model_1( tensor_x ), tensor_y_2 ).item() )

HTML( jshtml_1 )

## Model 2

In [None]:
# Same architecture and deterministic initialization as model_1 ...
hidden_nodes = 2
model_2 = create_model( hidden_nodes = hidden_nodes )
initialize_weight( hidden_nodes = hidden_nodes, x_limits = [ -9, 9 ] )( model_2 )

# ... except that the second hidden node's first-layer weight and bias have
# their signs flipped.
with torch.no_grad():
  model_2[ 0 ].weight[ 1, 0 ] *= -1
  model_2[ 0 ].bias[ 1 ] *= -1

print_model( model_2 )

In [None]:
# Train model_2 and save the animation of the fit to bad-fit.mp4.
jshtml_2 = fit( model = model_2,
                learning_rate = 0.5,
                tensor_x = tensor_x,
                tensor_y = tensor_y_2,
                save_file_path = "bad-fit.mp4",
                colors = [ "blue", "green" ] )

# Mean squared error of the trained model over the whole data set.
with torch.no_grad():
  print( nn.MSELoss()( model_2( tensor_x ), tensor_y_2 ).item() )

HTML( jshtml_2 )