In [1]:
import numpy as np
import matplotlib.animation as animation
import matplotlib.pyplot as plt

from network_components import *
from network_logic import Model
from sample_data import *
from visualize import *

import ipywidgets

In [2]:
# Get sample data
# NOTE(review): assumes `circle()` (from sample_data) returns inputs X and a
# 1-D array Y of binary 0/1 labels — TODO confirm.
X, Y = circle()

# One-hot encode the labels into a (2, n_samples) target matrix, one column per
# sample. Row 0 flags samples labelled 1 and row 1 flags samples labelled 0, so
# class/output index 0 corresponds to label 1 (the reverse of the usual order).
Y_ = np.zeros((2, Y.shape[0]))
Y_[0, Y == 1] = 1
Y_[1, Y == 0] = 1

# Guard against labels outside {0, 1}: such samples would get an all-zero
# target column and (with a cross-entropy loss) presumably contribute nothing
# to training, failing silently instead of loudly.
assert np.all(Y_.sum(axis=0) == 1), "expected every label in Y to be 0 or 1"

In [6]:
# Define model: 2 inputs -> Linear/TanH -> Linear/TanH -> Linear -> 2 outputs,
# trained with a combined softmax + cross-entropy loss.
# (To compare against ReLU hidden activations, swap the TanH() instances for ReLU().)
hidden_size = 6

linear1 = LinearLayer(in_features=2, out_features=hidden_size)
tanh1 = TanH()
linear2 = LinearLayer(in_features=hidden_size, out_features=hidden_size)
tanh2 = TanH()
linear3 = LinearLayer(in_features=hidden_size, out_features=2)
smax_ce_loss = SoftmaxCrossEntropyLoss()

# NOTE(review): presumably Model applies the layers in the order given.
model = Model(
    linear1,
    tanh1,
    linear2,
    tanh2,
    linear3,
    loss_func=smax_ce_loss,
)

# The loss object's softmax, used later to turn raw model outputs into class
# probabilities for the decision-surface plots.
softmax = smax_ce_loss.softmax

In [7]:
# Training

num_epochs = 100
batch_size = 16
lr = .03

# Square region and resolution on which the decision surface is sampled.
# NOTE(review): assumes get_grid_predictions takes
# (x_lo, x_hi, x_steps, y_lo, y_hi, y_steps, predict_fn) — TODO confirm.
grid_lo, grid_hi, grid_steps = -2.5, 2.5, 50


def get_grid_predictions_wrapper(model):
    """Return `get_grid_predictions` evaluated for `model` on the plotting grid.

    The model's raw outputs are passed through `softmax` so the grid holds
    class probabilities. Handed to `Model.fit`, whose third return value
    (`grid_prediction_list`) is later plotted per epoch.
    """
    return get_grid_predictions(
        grid_lo, grid_hi, grid_steps,
        grid_lo, grid_hi, grid_steps,
        lambda x: softmax(model(x)),
    )


loss_list, acc_list, grid_prediction_list = model.fit(
    X,
    Y_,
    num_epochs,
    batch_size,
    lr,
    get_grid_predictions_wrapper,
)

#Epoch: 100%|####################| 100/100 [00:01<00:00, 58.23it/s, Acc = 1.000]


In [26]:
# Plot output

def plot_grid_prediction_wrapper(x):
    """Draw the decision surface recorded at epoch index `x` (0-based).

    Callback for the ipywidgets slider below; reads `grid_prediction_list`,
    `acc_list`, `X` and `Y` from the notebook's global namespace.
    """
    # The returned (fig, ax, cf, title) handles are not needed here.
    plot_grid_prediction(grid_prediction_list[x], X, Y, x, acc_list[x], dpi=100)
    plt.show()


x = ipywidgets.IntSlider(
    value=0,
    min=0,
    max=len(grid_prediction_list) - 1,
    description='Epoch:',
    continuous_update=False,  # replot only when the slider is released
    orientation='horizontal',
    readout=True,
    readout_format='d',
)

# Trailing ';' suppresses the repr of interact's return value in the notebook.
ipywidgets.interact(plot_grid_prediction_wrapper, x=x);

interactive(children=(IntSlider(value=0, continuous_update=False, description='Epoch:', max=99), Output()), _d…




In [27]:
# Make animation

# Turn off interactive mode (presumably so the inline backend does not also
# show the static first frame) and render animations as a JavaScript player.
plt.ioff()
plt.rcParams["animation.html"] = "jshtml"

# Upper bound on animated epochs, clamped to the number of recorded epochs so a
# shorter training run does not raise IndexError in `animate`.
max_frames = 60
n_frames = min(max_frames, len(grid_prediction_list))

epoch = 0
fig, ax, cf, title = plot_grid_prediction(grid_prediction_list[epoch], X, Y, epoch, acc_list[epoch], dpi=100)
fig.tight_layout()


def animate(epoch):
    """Redraw the decision surface for `epoch` and return the artists to blit."""
    # update_grid_prediction returns the (possibly replaced) contour handle,
    # so the module-level `cf` must be rebound for the next frame.
    global cf
    cf = update_grid_prediction(grid_prediction_list[epoch], epoch, acc_list[epoch], ax, cf, title)
    # NOTE(review): cf is presumably a matplotlib ContourSet. Before 3.8 its
    # drawable parts live in `.collections`; since 3.8 the ContourSet is itself
    # an artist and `.collections` is deprecated (removed in 3.10).
    try:
        return cf.collections
    except AttributeError:
        return [cf]


anim = animation.FuncAnimation(fig, animate, frames=n_frames, blit=True)

In [16]:
# Display the animation inline; rendered as a JavaScript player because
# rcParams["animation.html"] was set to "jshtml" in the animation cell.
anim

In [28]:
# Write the animation to disk as a GIF at 15 frames per second.
anim.save(filename='circle.gif', fps=15)