In [1]:
import os
import sys

# Put the parent directory (the package checkout) on the import path so the
# `model_wrangler` imports below resolve without installing the package.
sys.path.append(os.pardir)

# All paths are relative to the directory the notebook is run from.
EXAMPLE_DIR = os.curdir
DATA_DIR = os.path.join(EXAMPLE_DIR, 'mnist_data')

# Local cache for the downloaded dataset; only created when nothing is there yet.
if os.path.exists(DATA_DIR):
    pass
else:
    os.makedirs(DATA_DIR)

In [2]:
# Download MNIST dataset

import numpy as np
from sklearn.datasets import fetch_mldata

def onehot_encoding(categories, max_categories):
    """Return one-hot encoded rows for a sequence of integer categories.

    Args:
        categories: sequence (list or array) of category labels. Each value is
            cast with ``astype(int)``, so float labels such as ``3.0`` are
            accepted (fractional values truncate toward zero, as ``int()`` does).
        max_categories: number of output columns; every label must lie in
            ``range(max_categories)``.

    Returns:
        float64 ``np.ndarray`` of shape ``(len(categories), max_categories)``
        holding a single ``1.0`` per row, in the column given by that row's
        label.

    Raises:
        IndexError: if any label is negative or ``>= max_categories``.
    """
    # Flatten so a column vector of labels (n, 1) gives n rows, as the old
    # per-element loop did.
    labels = np.asarray(categories).astype(int).reshape(-1)

    # Reject negative labels explicitly: plain indexing would wrap them around
    # to the last columns and silently produce a wrong encoding.
    out_of_range = (labels < 0) | (labels >= max_categories)
    if out_of_range.any():
        raise IndexError(
            'categories must be in range(0, {}); got {!r}'.format(
                max_categories, labels[out_of_range][:5].tolist()))

    out_array = np.zeros((labels.size, max_categories))
    # One vectorized assignment: row i gets a 1.0 in column labels[i].
    out_array[np.arange(labels.size), labels] = 1.0

    return out_array


# NOTE(review): `fetch_mldata` was deprecated in scikit-learn 0.20 and removed in
# 0.22 (mldata.org is offline). This cell therefore needs an older scikit-learn,
# or a port to `sklearn.datasets.fetch_openml`. TODO before re-running in a fresh env.
mnist_data = fetch_mldata('MNIST original', data_home=DATA_DIR)

# Flat pixel rows -> (n_samples, 28, 28, 1), matching the model's in_size below.
# order='F' fills each 28x28 image column-first. NOTE(review): if the source rows
# are stored row-major, every digit ends up transposed. That is harmless for
# training, but confirm it before plotting the images.
image_data = mnist_data['data'].reshape(-1, 28, 28, order='F')[..., np.newaxis]
# onehot_encoding takes the label array directly; no need to copy it into a list.
image_labels = onehot_encoding(mnist_data['target'], 10)
del mnist_data  # drop the raw bunch; only image_data / image_labels are used below

# Every image must have exactly one 10-way label row.
assert image_labels.shape == (image_data.shape[0], 10)

In [3]:
# Initialize the model

from model_wrangler.corral.convolutional_feedforward import ConvolutionalFeedforward
from model_wrangler.dataset_managers import CategoricalDataManager

model_name = "mnist_example"

# Per-layer-group settings, pulled out so the constructor call reads as an overview.
conv_layer_params = {
    'dropout_rate': 0.0,
    'kernel': [5, 5],
    'strides': 2,
    'pool_size': 2,
}
dense_layer_params = {
    'dropout_rate': 0.1,
    'activation': 'relu',
    'act_reg': None,
}
output_layer_params = {
    'dropout_rate': None,
    'activation': 'softmax',
    'act_reg': None,
}

conv_ff_network = ConvolutionalFeedforward(
    name=model_name,
    verb=True,
    in_size=[28, 28, 1],    # matches the trailing (28, 28, 1) axes of image_data
    out_size=10,            # one output per one-hot label column
    conv_nodes=[16, 24],
    conv_params=conv_layer_params,
    dense_nodes=[10],
    dense_params=dense_layer_params,
    output_params=output_layer_params,
    num_epochs=10,
)

# NOTE(review): swapped in after construction. It presumably selects the data
# manager train() uses for the train/holdout split and batching (the training log
# comes from model_wrangler.dataset_managers). Confirm against the library.
conv_ff_network.tf_mod.DATA_CLASS = CategoricalDataManager

2017-11-04 22:42:31,107 model_wrangler.tf_models INFO     Save directory : ./mnist_example
2017-11-04 22:42:31,109 model_wrangler.tf_models INFO     Directory ./mnist_example already exists
2017-11-04 22:42:31,110 model_wrangler.tf_models INFO     Save directory : ./mnist_example/tb_log
2017-11-04 22:42:31,111 model_wrangler.tf_models INFO     Directory ./mnist_example/tb_log already exists


In [4]:
from model_wrangler.tf_ops import accuracy

In [5]:
# Baseline: accuracy of the not-yet-trained network on every 10th sample.
# Expect roughly chance level before train() runs.
conv_ff_network.score(image_data[::10], image_labels[::10], accuracy)

0.063428573

In [6]:
# Run training on every 10th sample. Samples at other offsets are never passed to
# train(), so they can serve as an independent check afterwards.
#
# Monitor progress with TensorBoard: start a server with
#   tensorboard --logdir examples/mnist_example/tb_log
# then open it in a browser (default: localhost:6006).
conv_ff_network.train(image_data[::10], image_labels[::10])

2017-11-04 22:42:43,142 model_wrangler.dataset_managers INFO     Input data size (7000, 28)
2017-11-04 22:42:43,169 model_wrangler.dataset_managers INFO     Input has 10 groups
2017-11-04 22:42:43,176 model_wrangler.dataset_managers INFO     Num training samples 6305
2017-11-04 22:42:43,177 model_wrangler.dataset_managers INFO     Num holdout samples 695
2017-11-04 22:42:43,178 model_wrangler.model_wrangler INFO     Starting Epoch 0
2017-11-04 22:42:43,626 model_wrangler.model_wrangler INFO     Batch 0: Training score = 3863.681396
2017-11-04 22:42:43,627 model_wrangler.model_wrangler INFO     Batch 0: Holdout score = 11500.026367
2017-11-04 22:42:47,006 model_wrangler.model_wrangler INFO     Saving weights file in ./mnist_example
2017-11-04 22:42:47,820 model_wrangler.tf_models INFO     Save directory : ./mnist_example
2017-11-04 22:42:47,821 model_wrangler.tf_models INFO     Directory ./mnist_example already exists
2017-11-04 22:42:47,822 model_wrangler.tf_models INFO     Saving para

In [7]:
# Accuracy on the same every-10th-sample subset that was passed to train().
# The model was fit on this data, so the number overstates how well it generalises.
conv_ff_network.score(image_data[::10], image_labels[::10], accuracy)

0.99185717

In [10]:
# Accuracy on samples shifted by one ([1::10]). None of these were passed to
# train(), so this is the more honest estimate of performance.
conv_ff_network.score(image_data[1::10], image_labels[1::10], accuracy)

0.97042859

In [8]:
# Rebuild the trained model from the parameter file in its save directory
# (./<model_name>/<model_name>-params.json).
params_basename = '{}-params.json'.format(model_name)
param_file = os.path.join(model_name, params_basename)
restored_model = ConvolutionalFeedforward.load(param_file)

2017-11-04 22:44:45,537 model_wrangler.tf_models INFO     Save directory : ./mnist_example
2017-11-04 22:44:45,538 model_wrangler.tf_models INFO     Directory ./mnist_example already exists
2017-11-04 22:44:45,539 model_wrangler.tf_models INFO     Save directory : ./mnist_example/tb_log
2017-11-04 22:44:45,540 model_wrangler.tf_models INFO     Directory ./mnist_example/tb_log already exists
INFO:tensorflow:Restoring parameters from ./mnist_example/mnist_example-00000009


In [9]:
# The reloaded model is scored on the same every-10th-sample subset; it should
# match the in-memory model's training-subset accuracy.
restored_model.score(image_data[::10], image_labels[::10], accuracy)

0.99185717