In [12]:
from __future__ import print_function

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import os
import sys
import time

import sknn
import theano
import theano.tensor as T
import lasagne
from lasagne import layers
from lasagne.updates import nesterov_momentum
from nolearn.lasagne import NeuralNet
from nolearn.lasagne import visualize

from functools import reduce

from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_curve, auc
from six.moves import cPickle as pickle

from sklearn.neural_network import MLPClassifier
from collections import Counter

%matplotlib inline

Import libraries

In [13]:
files = [
    'data_batch_1',
    'data_batch_2',
    'data_batch_3',
    'data_batch_4',
    'data_batch_5',
    'test_batch'
]
data = []
labels = []
start = time.time()
for file in files:
    with open(file, 'rb') as f:
        d = pickle.load(f, encoding='bytes')
        if file == 'test_batch':
            test_data = d[b'data']
            test_labels = d[b'labels']
        else:
            data.append(d[b'data'])
            labels.append(d[b'labels'])
end = time.time()
print('Time to load data: {:.3f}s'.format(end - start))
for i in range(len(data)):
    print('Train data {}:'.format(i), data[i].shape, len(labels[i]))
print('Test data:', test_data.shape, len(test_labels))

merged_data = reduce(lambda a,b: np.vstack((a,b)), data)
merged_labels = reduce(lambda a,b: a+b, labels)
print('Merged train data:', merged_data.shape, len(merged_labels))

Time to load data: 0.157s
Train data 0: (10000, 3072) 10000
Train data 1: (10000, 3072) 10000
Train data 2: (10000, 3072) 10000
Train data 3: (10000, 3072) 10000
Train data 4: (10000, 3072) 10000
Test data: (10000, 3072) 10000
Merged train data: (50000, 3072) 50000


Load data and print out time- this code is taken directly from section

In [14]:
scaler = StandardScaler()
scaler.fit(merged_data) 
#Train our scaler based on our training data

train_data = merged_data[:]
train_labels = merged_labels[:]
train_data = scaler.transform(train_data)
test_data = scaler.transform(test_data)
#Apply our scaling to our training and test data, creating a copy of our merged data called 'train_data'



Scaling all the data so the Neural Network will work as expected

In [15]:
mlp = MLPClassifier(solver='lbgfs', alpha=1e-5, hidden_layer_sizes=(5,2), random_state=1)
start = time.time()
mlp.fit(train_data, train_labels)
end = time.time()
print('Time to build: {:.3f}s'.format(end - start))
#Create and train a basic Multi Layer Perceptron model

Time to build: 171.217s


In [16]:
predictions = mlp.predict(train_data) #Store the predictions for this basic model
correct = merged_labels - predictions #If the numbers are the same, they'll be 0, otherwise- any other number
accuracy = (correct == 0).sum() / len(correct)
accuracy * 100
#This is the accuracy of the base model on the training data

32.013999999999996

In [17]:
test_pred = mlp.predict(test_data)
correct = test_labels - test_pred #If the numbers are the same, they'll be 0, otherwise- any other number
test_accuracy = (correct == 0).sum() / len(correct)
test_accuracy * 100
#Accuracy of the base model on the test data

30.34

Our basic MLP model seems to have an accuracy of roughly 30-32% to start- looks like there's a lot of optimization we can do on this

http://lasagne.readthedocs.io/en/latest/user/tutorial.html

The documentation for the Lasagne package was extensively utilized for creating the following Neural Network.

In [18]:
x_train = merged_data[:]
x_train = x_train.reshape(50000, 3, 32, 32)
y_train = train_labels[:]
#Prepare data for lasagne

In [19]:
def build_mlp(input_var=None):
    l_in = lasagne.layers.InputLayer(shape=(None, 1, 28, 28),
                                     input_var=input_var)
    l_in_drop = lasagne.layers.DropoutLayer(l_in, p=0.2)
    l_hid1 = lasagne.layers.DenseLayer(
        l_in_drop, num_units=800,
        nonlinearity=lasagne.nonlinearities.rectify,
        W=lasagne.init.GlorotUniform())
    l_hid1_drop = lasagne.layers.DropoutLayer(l_hid1, p=0.5)

    l_hid2 = lasagne.layers.DenseLayer(
        l_hid1_drop, num_units=800,
        nonlinearity=lasagne.nonlinearities.rectify)

    l_hid2_drop = lasagne.layers.DropoutLayer(l_hid2, p=0.5)
    
    l_out = lasagne.layers.DenseLayer(
        l_hid2_drop, num_units=10,
        nonlinearity=lasagne.nonlinearities.softmax)
    return l_out

In [20]:
input_var = T.tensor4('inputs')
target_var = T.ivector('targets')
network = build_mlp(input_var)
prediction = lasagne.layers.get_output(network)
loss = lasagne.objectives.categorical_crossentropy(prediction, target_var)
loss = loss.mean()
params = lasagne.layers.get_all_params(network, trainable=True)
updates = lasagne.updates.nesterov_momentum(
        loss, params, learning_rate=0.01, momentum=0.9)
test_prediction = lasagne.layers.get_output(network, deterministic=True)
test_loss = lasagne.objectives.categorical_crossentropy(test_prediction,
                                                        target_var)
test_loss = test_loss.mean()
test_acc = T.mean(T.eq(T.argmax(test_prediction, axis=1), target_var),
                  dtype=theano.config.floatX)

In [24]:
train_fn = theano.function([x_train, y_train], loss)

TypeError: ("Unknown input type: <class 'numpy.ndarray'>, expected Variable instance", array([[[[ 59,  43,  50, ..., 158, 152, 148],
         [ 16,   0,  18, ..., 123, 119, 122],
         [ 25,  16,  49, ..., 118, 120, 109],
         ..., 
         [208, 201, 198, ..., 160,  56,  53],
         [180, 173, 186, ..., 184,  97,  83],
         [177, 168, 179, ..., 216, 151, 123]],

        [[ 62,  46,  48, ..., 132, 125, 124],
         [ 20,   0,   8, ...,  88,  83,  87],
         [ 24,   7,  27, ...,  84,  84,  73],
         ..., 
         [170, 153, 161, ..., 133,  31,  34],
         [139, 123, 144, ..., 148,  62,  53],
         [144, 129, 142, ..., 184, 118,  92]],

        [[ 63,  45,  43, ..., 108, 102, 103],
         [ 20,   0,   0, ...,  55,  50,  57],
         [ 21,   0,   8, ...,  50,  50,  42],
         ..., 
         [ 96,  34,  26, ...,  70,   7,  20],
         [ 96,  42,  30, ...,  94,  34,  34],
         [116,  94,  87, ..., 140,  84,  72]]],


       [[[154, 126, 105, ...,  91,  87,  79],
         [140, 145, 125, ...,  96,  77,  71],
         [140, 139, 115, ...,  79,  68,  67],
         ..., 
         [175, 156, 154, ...,  42,  61,  93],
         [165, 156, 159, ..., 103, 123, 131],
         [163, 158, 163, ..., 143, 143, 143]],

        [[177, 137, 104, ...,  95,  90,  81],
         [160, 153, 125, ...,  99,  80,  73],
         [155, 146, 115, ...,  82,  70,  69],
         ..., 
         [167, 154, 160, ...,  34,  53,  83],
         [154, 152, 161, ...,  93, 114, 121],
         [148, 148, 156, ..., 133, 134, 133]],

        [[187, 136,  95, ...,  71,  71,  70],
         [169, 154, 118, ...,  78,  62,  61],
         [164, 149, 112, ...,  64,  55,  55],
         ..., 
         [166, 160, 170, ...,  36,  57,  91],
         [128, 130, 142, ...,  96, 120, 131],
         [120, 122, 133, ..., 139, 142, 144]]],


       [[[255, 253, 253, ..., 253, 253, 253],
         [255, 255, 255, ..., 255, 255, 255],
         [255, 254, 254, ..., 254, 254, 254],
         ..., 
         [113, 111, 105, ...,  72,  72,  72],
         [111, 104,  99, ...,  68,  70,  78],
         [106,  99,  95, ...,  78,  79,  80]],

        [[255, 253, 253, ..., 253, 253, 253],
         [255, 255, 255, ..., 255, 255, 255],
         [255, 254, 254, ..., 254, 254, 254],
         ..., 
         [120, 118, 112, ...,  81,  80,  80],
         [118, 111, 106, ...,  75,  76,  84],
         [113, 106, 102, ...,  85,  85,  86]],

        [[255, 253, 253, ..., 253, 253, 253],
         [255, 255, 255, ..., 255, 255, 255],
         [255, 254, 254, ..., 254, 254, 254],
         ..., 
         [112, 111, 106, ...,  80,  79,  79],
         [110, 104,  98, ...,  73,  75,  82],
         [105,  98,  94, ...,  83,  83,  84]]],


       ..., 
       [[[ 35,  40,  42, ...,  99,  79,  89],
         [ 57,  44,  50, ..., 156, 141, 116],
         [ 98,  64,  69, ..., 188, 119,  61],
         ..., 
         [ 73,  53,  54, ...,  17,  21,  33],
         [ 61,  55,  57, ...,  24,  17,   7],
         [ 44,  46,  49, ...,  27,  21,  12]],

        [[178, 176, 176, ..., 177, 147, 148],
         [182, 184, 183, ..., 182, 177, 149],
         [197, 189, 192, ..., 195, 135,  79],
         ..., 
         [ 79,  63,  68, ...,  40,  36,  48],
         [ 68,  70,  79, ...,  48,  35,  23],
         [ 56,  66,  77, ...,  52,  43,  31]],

        [[235, 239, 241, ..., 219, 197, 189],
         [234, 250, 240, ..., 200, 206, 175],
         [237, 252, 245, ..., 206, 147,  90],
         ..., 
         [ 77,  68,  80, ...,  64,  51,  49],
         [ 75,  86, 103, ...,  72,  53,  32],
         [ 73,  88, 105, ...,  77,  66,  50]]],


       [[[189, 186, 185, ..., 175, 172, 169],
         [194, 191, 190, ..., 173, 171, 167],
         [208, 205, 204, ..., 175, 172, 169],
         ..., 
         [207, 203, 203, ..., 135, 162, 168],
         [198, 189, 180, ..., 178, 175, 175],
         [198, 189, 178, ..., 195, 196, 195]],

        [[211, 208, 207, ..., 195, 194, 194],
         [210, 207, 206, ..., 192, 191, 190],
         [219, 216, 215, ..., 191, 190, 191],
         ..., 
         [199, 195, 196, ..., 132, 158, 163],
         [190, 181, 172, ..., 171, 169, 169],
         [189, 181, 170, ..., 184, 189, 190]],

        [[240, 236, 235, ..., 224, 222, 220],
         [239, 236, 235, ..., 220, 218, 216],
         [244, 240, 239, ..., 217, 216, 215],
         ..., 
         [181, 175, 173, ..., 127, 150, 151],
         [170, 159, 147, ..., 160, 156, 154],
         [173, 162, 149, ..., 169, 171, 171]]],


       [[[229, 236, 234, ..., 217, 221, 222],
         [222, 239, 233, ..., 223, 227, 210],
         [213, 234, 231, ..., 220, 220, 202],
         ..., 
         [150, 140, 132, ..., 224, 230, 241],
         [137, 130, 125, ..., 181, 202, 212],
         [122, 118, 120, ..., 179, 164, 163]],

        [[229, 237, 236, ..., 219, 223, 223],
         [221, 239, 234, ..., 223, 228, 211],
         [206, 232, 233, ..., 220, 219, 203],
         ..., 
         [143, 135, 127, ..., 222, 228, 241],
         [132, 127, 121, ..., 180, 201, 211],
         [119, 116, 116, ..., 177, 164, 163]],

        [[239, 247, 247, ..., 233, 234, 233],
         [229, 249, 246, ..., 236, 238, 220],
         [211, 239, 244, ..., 232, 232, 215],
         ..., 
         [135, 127, 120, ..., 218, 225, 238],
         [126, 120, 115, ..., 178, 198, 207],
         [114, 110, 111, ..., 173, 162, 161]]]], dtype=uint8))