# MNIST classifier from nnBuilder

This is a simple implementation of an MNIST classifier using the nnBuilder framework. The network follows the convolutional network from the TensorFlow tutorial at https://www.tensorflow.org/tutorials/mnist/pros/, with input noise and batch normalization added and the fully connected part expressed as convolutions.

In [None]:
import numpy as np
import tensorflow as tf
import sys, os, seaborn, time
import matplotlib.pyplot as plt
sys.path.append('../nnBuilder')
from nnLayer import *
from nnInput import *
from nnTrainer import *
from nnHandler import *

In [None]:
test_split_n=10 # evaluate the 10000-image test set in test_split_n slices to use less memory
# Training input: MNIST in batches of 128 images.
data=Layer(type="MNIST",batch=128)
# Test input: the MNIST test split, re-sliced into batches of 10000//test_split_n images
# (assumes the test split has 10000 images, as the divisor implies).
data_test=Layer(type="MNIST",test=True)
data_test=Layer(x=data_test,type="Batch_Slice",batch=10000//test_split_n)
# Feature specs shared by several layers below: uniform input noise (scale 0.2) and batch norm.
# drop_on_test=True presumably disables the noise in test-mode copies -- TODO confirm in nnLayer.
noise=dict(type="Noise",rand_type="uniform",scale=0.2,drop_on_test=True)
batch_norm=dict(type="Batch_Norm")
layers=[]
# Stage 1: 5x5 convolution, 64 features, ReLU, followed by 2x2 max pooling with stride 1.
layers.append(dict(type="Convolution",pad="SAME",window=5,stride=1,size=64,relu=True,in_features=[noise]))
layers.append(dict(type="Pool",pad="SAME",window=2,stride=1,pool_type="max",out_features=[batch_norm]))
# Stage 2: 5x5 convolution, 128 features; then 4x4 max pooling with stride 1.
# NOTE(review): the pools use stride=1, so the downsampling presumably happens through the
# following layer's input_stride -- confirm against nnLayer.
layers.append(dict(type="Convolution",pad="SAME",window=5,stride=1,input_stride=2,size=128,relu=True,in_features=[noise]))
layers.append(dict(type="Pool",pad="SAME",window=4,stride=1,pool_type="max",out_features=[batch_norm]))
# 7x7 VALID convolution, 256 features: presumably spans the whole remaining feature map,
# acting as a fully connected layer.
layers.append(dict(type="Convolution",pad="VALID",window=7,stride=1,input_stride=4,size=256,relu=True,
                   in_features=[noise],out_features=[batch_norm]))
# 1x1 convolution, 128 features, ReLU.
layers.append(dict(type="Convolution",pad="SAME",window=1,stride=2,size=128,relu=True,
                   in_features=[noise],out_features=[batch_norm]))
# 1x1 convolution, 10 outputs and no ReLU: presumably the logits for the 10 digit classes.
layers.append(dict(type="Convolution",pad="SAME",window=1,stride=1,size=10,relu=False,
                   in_features=[noise],out_features=[batch_norm]))
# Disabled alternative head made of fully connected layers:
#layers.append(dict(type="Relu",size=256,in_features=[noise],out_features=[batch_norm]))
#layers.append(dict(type="Relu",size=128,in_features=[noise],out_features=[batch_norm]))
#layers.append(dict(type="Linear",size=10,in_features=[dict(type="Dropout")]))
network_def=dict(type="Network",layers=layers)
# Training network, fed by the training input.
network=Layer(x=data,**network_def)
# Test-mode copies sharing the trained variables: network_eval reads the training input,
# network_test reads the test input.
network_eval=network.copy(x=data,share_vars=True,test=True)
network_test=network.copy(x=data_test,share_vars=True,test=True)
# One trainer per network; only `trainer` is given an optimizer (Adam).
trainer=ClassifierTrainer(network=network,optimizer="adam",array=True)
trainer_eval=ClassifierTrainer(network=network_eval,test=True,array=True)
tester=ClassifierTrainer(network=network_test,test=True,array=True)
# Hand every component to the session manager and start the session.
sess=SessManager(data,network,trainer,data_test,network_test,tester,network_eval,trainer_eval)
sess.start()

In [None]:
# Schedule for the training cell below.
# 5000 batches of 128 is roughly 10.7 passes over 60000 images (60000//128 would be ~1 epoch).
batches_per_step = 5000
batches_per_eval = 100   # batches sampled when measuring training / eval error
n_steps = 100            # outer iterations; each one ends with an evaluation and a plot refresh
def make_plt():
    %matplotlib notebook
    global fig,ax,train_plot,test_plot,eval_plot
    fig,ax = plt.subplots(1,1)
    train_plot=ax.plot([],[], label="training")[0]
    eval_plot=ax.plot([],[], label="eval")[0]
    test_plot=ax.plot([],[], label="test")[0]
    ax.set_ylim(0,1)
    plt.legend()
    fig.canvas.draw()
    time.sleep(.01)
def update_plt():
    """Redraw the live plot from the global trains / evals / tests histories.

    Entry k of each history is plotted at batches_per_step*k training batches.
    The y-axis is zoomed to 110% of the test error after the first training
    step (tests[1]), which cuts off the untrained starting error; if that step
    has not happened yet, the initial test error is used instead.
    """
    x=[batches_per_step*i for i in range(len(trains))]
    for curve, history in ((train_plot, trains), (eval_plot, evals), (test_plot, tests)):
        curve.set_data(x, history)
    # With a single point x[-1] is 0, and identical x-limits trigger a matplotlib warning.
    ax.set_xlim(0, max(x[-1], batches_per_step))
    ax.set_ylim(0, tests[min(1, len(tests) - 1)] * 1.1)
    fig.canvas.draw()
def learn_rate_for(step):
    """Return the step-decay learning rate for 0-based training step `step`.

    5e-3 for steps 0-5, 5e-4 for steps 6-10, and 5e-5 from step 11 on.
    """
    if step>10:
        return 5e-5
    if step>5:
        return 5e-4
    return 5e-3

# Measure the untrained network once. The guard keeps the existing history when
# this cell is re-run, so training continues where it left off.
# NOTE(review): the history also survives re-running the network-building cell,
# in which case the curves mix two different networks.
if "trains" not in globals():
    trains=[trainer.eval_error(n=batches_per_eval)[1]]
    evals=[trainer_eval.eval_error(n=batches_per_eval)[1]]
    # Same test-set coverage as the per-step test points below.
    tests=[tester.eval_error(n=test_split_n)[1]]
make_plt()
# `trains` holds the initial measurement plus one entry per completed step. Counting
# steps over the whole history lets a re-run continue the learning-rate decay instead
# of jumping back to the highest rate.
first_step=len(trains)-1
for step in range(first_step, first_step+n_steps):
    trainer.train(batches_per_step,keep_rate=0.5,l2reg=1e-6,learn_rate=learn_rate_for(step))
    trains.append(trainer.eval_error(n=batches_per_eval)[1]) #Training error, reaches 0%
    # Training input again, but through the test-mode copy of the network.
    evals.append(trainer_eval.eval_error(n=batches_per_eval,info="eval")[1])
    # test_split_n slices of the test set, presumably covering all of it.
    tests.append(tester.eval_error(n=test_split_n)[1]) #Testing error, reaches 0.35-0.45%
    update_plt()

In [None]:
# Serialize the network definition and check that it round-trips.
save=network.save() #The full network definition
# A Layer rebuilt from the saved definition must serialize back to the same dict.
assert Layer(x=data,**save).save()==save
# Per the original note, save() adds entries that network_def did not spell out, so this
# exact comparison may print False even though both describe the same network.
print(save==network_def)
save

In [None]:
# Leftover scratch cell (it only displayed `range(1, 2)`); safe to delete.