In [3]:
import numpy as np
import matplotlib.pyplot as plt

from network_components import *
from network_logic import Model
from sample_data import *

import ipywidgets

In [4]:
def get_grid_predictions(x_min, x_max, nx, ny, y_min, y_max, model):
    """Evaluate ``model`` on a regular 2-D grid and return class-0 probabilities.

    Parameters
    ----------
    x_min, x_max : float
        Inclusive range covered along the horizontal axis.
    nx, ny : int
        Number of grid points along the horizontal / vertical axis.
    y_min, y_max : float
        Inclusive range covered along the vertical axis.
    model : callable
        Called once with an array of shape ``(2, nx * ny)`` whose rows are the
        x and y coordinates of every grid point. Its return value is treated
        as raw per-class scores with the classes along axis 0.

    Returns
    -------
    x : ndarray, shape (nx,)
        Grid coordinates along the horizontal axis.
    y : ndarray, shape (ny,)
        Grid coordinates along the vertical axis.
    preds : ndarray, shape (ny, nx)
        Softmax probability of the first score row at every grid point.
    """
    x = np.linspace(x_min, x_max, nx)
    y = np.linspace(y_min, y_max, ny)
    xx, yy = np.meshgrid(x, y)  # both are shaped (ny, nx)
    points = np.vstack((
        xx.ravel(),
        yy.ravel()
    ))

    def softmax(Z):
        # Softmax is invariant to subtracting a per-column constant; using the
        # column maximum keeps np.exp from overflowing to inf (and the
        # subsequent inf / inf from turning into NaN) for large scores.
        Z_shifted = Z - np.max(Z, axis=0, keepdims=True)
        Z_exp = np.exp(Z_shifted)
        return Z_exp / np.sum(Z_exp, axis=0, keepdims=True)

    # Row 0 holds the probability of the first class; reshape it back onto
    # the (ny, nx) grid layout produced by meshgrid.
    preds = softmax(model(points))[0].reshape(ny, nx)

    return x, y, preds

In [70]:
# Load the toy dataset: X holds the inputs, Y the integer class labels.
X, Y = circle()

# One-hot encode the labels column-wise: row 0 marks samples with label 1,
# row 1 marks samples with label 0.
is_label_one = (Y == 1)
is_label_zero = (Y == 0)
Y_ = np.zeros((2, Y.shape[0]))
Y_[0, is_label_one] = 1
Y_[1, is_label_zero] = 1

In [71]:
# Layers are created in forward order: 2 inputs -> 4 -> 4 -> 2 output scores,
# with a ReLU after each of the first two linear layers.
linear1 = LinearLayer(in_features=2, out_features=4)
relu1 = ReLU()

linear2 = LinearLayer(in_features=4, out_features=4)
relu2 = ReLU()

linear3 = LinearLayer(in_features=4, out_features=2)

# The loss object is handed to the model separately from the layer stack.
smax_ce_loss = SoftmaxCrossEntropyLoss()

model = Model(linear1, relu1, linear2, relu2, linear3, loss_func=smax_ce_loss)

In [72]:
# Training hyper-parameters, named so they can be tuned in one place.
N_EPOCHS = 1000
LEARNING_RATE = .01
# Print progress every LOG_EVERY epochs (1 = every epoch, as before).
LOG_EVERY = 1
# Region and resolution of the grid on which the model output is sampled.
GRID_KWARGS = dict(x_min=-2.5, x_max=2.5, nx=50, y_min=-2.5, y_max=2.5, ny=50)

# NOTE(review): this cell trains the existing `model` object; re-running it
# without re-creating the model presumably continues from the current weights.

# One snapshot is stored per epoch, independent of LOG_EVERY, so that
# grid_predictions[i] always corresponds to the state after epoch i.
grid_predictions = []
for i in range(N_EPOCHS):
    loss, acc = model.fit(X, Y_, lr=LEARNING_RATE)
    if i % LOG_EVERY == 0:
        print(f'epoch: {i}, loss = {loss:.3f}, accuracy = {acc:.3f}')
    grid_pred = get_grid_predictions(model=model, **GRID_KWARGS)
    grid_predictions.append(grid_pred)

epoch: 0, loss = 0.814, accuracy = 0.500
epoch: 1, loss = 0.812, accuracy = 0.500
epoch: 2, loss = 0.811, accuracy = 0.500
epoch: 3, loss = 0.809, accuracy = 0.500
epoch: 4, loss = 0.808, accuracy = 0.500
epoch: 5, loss = 0.807, accuracy = 0.500
epoch: 6, loss = 0.805, accuracy = 0.500
epoch: 7, loss = 0.804, accuracy = 0.500
epoch: 8, loss = 0.802, accuracy = 0.500
epoch: 9, loss = 0.801, accuracy = 0.500
epoch: 10, loss = 0.800, accuracy = 0.500
epoch: 11, loss = 0.799, accuracy = 0.500
epoch: 12, loss = 0.797, accuracy = 0.500
epoch: 13, loss = 0.796, accuracy = 0.500
epoch: 14, loss = 0.795, accuracy = 0.500
epoch: 15, loss = 0.794, accuracy = 0.500
epoch: 16, loss = 0.792, accuracy = 0.500
epoch: 17, loss = 0.791, accuracy = 0.500
epoch: 18, loss = 0.790, accuracy = 0.500
epoch: 19, loss = 0.789, accuracy = 0.500
epoch: 20, loss = 0.788, accuracy = 0.500
epoch: 21, loss = 0.787, accuracy = 0.500
epoch: 22, loss = 0.786, accuracy = 0.500
epoch: 23, loss = 0.784, accuracy = 0.500
ep

epoch: 285, loss = 0.696, accuracy = 0.500
epoch: 286, loss = 0.696, accuracy = 0.500
epoch: 287, loss = 0.696, accuracy = 0.500
epoch: 288, loss = 0.696, accuracy = 0.500
epoch: 289, loss = 0.696, accuracy = 0.500
epoch: 290, loss = 0.696, accuracy = 0.500
epoch: 291, loss = 0.696, accuracy = 0.500
epoch: 292, loss = 0.696, accuracy = 0.500
epoch: 293, loss = 0.696, accuracy = 0.500
epoch: 294, loss = 0.696, accuracy = 0.500
epoch: 295, loss = 0.696, accuracy = 0.500
epoch: 296, loss = 0.696, accuracy = 0.500
epoch: 297, loss = 0.696, accuracy = 0.500
epoch: 298, loss = 0.696, accuracy = 0.500
epoch: 299, loss = 0.696, accuracy = 0.500
epoch: 300, loss = 0.695, accuracy = 0.500
epoch: 301, loss = 0.695, accuracy = 0.500
epoch: 302, loss = 0.695, accuracy = 0.500
epoch: 303, loss = 0.695, accuracy = 0.500
epoch: 304, loss = 0.695, accuracy = 0.500
epoch: 305, loss = 0.695, accuracy = 0.500
epoch: 306, loss = 0.695, accuracy = 0.500
epoch: 307, loss = 0.695, accuracy = 0.500
epoch: 308,

epoch: 576, loss = 0.690, accuracy = 0.592
epoch: 577, loss = 0.690, accuracy = 0.592
epoch: 578, loss = 0.690, accuracy = 0.590
epoch: 579, loss = 0.690, accuracy = 0.588
epoch: 580, loss = 0.690, accuracy = 0.588
epoch: 581, loss = 0.690, accuracy = 0.588
epoch: 582, loss = 0.690, accuracy = 0.588
epoch: 583, loss = 0.690, accuracy = 0.590
epoch: 584, loss = 0.690, accuracy = 0.590
epoch: 585, loss = 0.690, accuracy = 0.588
epoch: 586, loss = 0.690, accuracy = 0.586
epoch: 587, loss = 0.690, accuracy = 0.586
epoch: 588, loss = 0.689, accuracy = 0.586
epoch: 589, loss = 0.689, accuracy = 0.586
epoch: 590, loss = 0.689, accuracy = 0.586
epoch: 591, loss = 0.689, accuracy = 0.586
epoch: 592, loss = 0.689, accuracy = 0.584
epoch: 593, loss = 0.689, accuracy = 0.582
epoch: 594, loss = 0.689, accuracy = 0.582
epoch: 595, loss = 0.689, accuracy = 0.580
epoch: 596, loss = 0.689, accuracy = 0.580
epoch: 597, loss = 0.689, accuracy = 0.580
epoch: 598, loss = 0.689, accuracy = 0.580
epoch: 599,

epoch: 860, loss = 0.685, accuracy = 0.536
epoch: 861, loss = 0.685, accuracy = 0.536
epoch: 862, loss = 0.685, accuracy = 0.536
epoch: 863, loss = 0.685, accuracy = 0.536
epoch: 864, loss = 0.685, accuracy = 0.536
epoch: 865, loss = 0.685, accuracy = 0.536
epoch: 866, loss = 0.685, accuracy = 0.536
epoch: 867, loss = 0.685, accuracy = 0.536
epoch: 868, loss = 0.685, accuracy = 0.536
epoch: 869, loss = 0.685, accuracy = 0.536
epoch: 870, loss = 0.685, accuracy = 0.536
epoch: 871, loss = 0.685, accuracy = 0.536
epoch: 872, loss = 0.685, accuracy = 0.536
epoch: 873, loss = 0.685, accuracy = 0.536
epoch: 874, loss = 0.685, accuracy = 0.536
epoch: 875, loss = 0.685, accuracy = 0.536
epoch: 876, loss = 0.685, accuracy = 0.536
epoch: 877, loss = 0.685, accuracy = 0.536
epoch: 878, loss = 0.685, accuracy = 0.536
epoch: 879, loss = 0.684, accuracy = 0.536
epoch: 880, loss = 0.684, accuracy = 0.536
epoch: 881, loss = 0.684, accuracy = 0.536
epoch: 882, loss = 0.684, accuracy = 0.536
epoch: 883,

In [73]:
def plot_grid_prediction(grid_prediction, X, Y):
    """Draw one grid snapshot together with the data points.

    Parameters
    ----------
    grid_prediction : tuple
        ``(x, y, preds)`` triple: the two grid axes and the values to contour
        over them.
    X : array-like
        Data points; ``X[0]`` and ``X[1]`` are used as the scatter coordinates.
    Y : array-like
        Per-point values used to colour the scatter markers.
    """
    grid_x, grid_y, grid_values = grid_prediction

    colormap = 'RdYlBu'
    fill_levels = np.linspace(0, 1, 200)
    colorbar_ticks = np.linspace(0, 1, 11)

    # Filled surface over [0, 1], with its colour bar.
    plt.contourf(grid_x, grid_y, grid_values, levels=fill_levels, cmap=colormap)
    plt.colorbar(ticks=colorbar_ticks)
    # Thin line where the surface crosses 0.5.
    plt.contour(grid_x, grid_y, grid_values, levels=[.5], linewidths=.5)
    # Data points on top, coloured with the same colour map.
    plt.scatter(X[0], X[1], c=Y, edgecolors='grey', cmap=colormap)
    plt.show()

def plot_grid_prediction_wrapper(x):
    """Plot the snapshot stored at index ``x`` of the global ``grid_predictions``.

    Uses the notebook-level ``X`` and ``Y`` for the scatter overlay.
    """
    selected_snapshot = grid_predictions[x]
    plot_grid_prediction(selected_snapshot, X, Y)

# Slider that selects which stored snapshot to display; its range spans every
# index of grid_predictions.
x = ipywidgets.IntSlider(
    description='Epoch:',
    min=0,
    max=len(grid_predictions)-1,
    value=0,
    orientation='horizontal',
    continuous_update=False,  # redraw only when the slider is released
    readout=True,
    readout_format='d'
)

# Bind the slider to the plotting callback (passed as keyword argument `x`).
ipywidgets.interact(plot_grid_prediction_wrapper, x=x)
# Ends the cell with a statement so the value returned by interact() is not
# echoed as the cell result.
print()

interactive(children=(IntSlider(value=0, continuous_update=False, description='Epoch:', max=999), Output()), _…


