In [None]:
from __future__ import print_function
from pylab import *
import os, struct
import gzip
import shutil
import caffe
import urllib.request
import numpy as np
import lmdb
from array import array as pyarray
from numpy import append, array, int8, uint8, zeros
from caffe import layers as L, params as P
from caffe.proto import caffe_pb2
%matplotlib inline

In [None]:
def lenet(lmdb, batch_size):
    """Assemble the LeNet graph with ``caffe.NetSpec`` and return its proto.

    Args:
        lmdb: Path of the LMDB database used as the Data layer's ``source``.
            Inside this function the name shadows the imported ``lmdb``
            module, which is not needed here.
        batch_size: Batch size handed to the Data layer.

    Returns:
        Whatever ``NetSpec.to_proto()`` produces for the assembled layers.
    """
    def xavier():
        # A fresh filler spec per layer, matching the original per-layer dicts.
        return dict(type='xavier')

    net = caffe.NetSpec()

    # Data layer with two tops (data and label); inputs are scaled by 1/255.
    net.data, net.label = L.Data(
        batch_size=batch_size,
        backend=P.Data.LMDB,
        source=lmdb,
        transform_param=dict(scale=1./255),
        ntop=2,
    )

    # Two convolution + max-pooling stages.
    net.conv1 = L.Convolution(net.data, kernel_size=5, num_output=20,
                              weight_filler=xavier())
    net.pool1 = L.Pooling(net.conv1, kernel_size=2, stride=2,
                          pool=P.Pooling.MAX)
    net.conv2 = L.Convolution(net.pool1, kernel_size=5, num_output=50,
                              weight_filler=xavier())
    net.pool2 = L.Pooling(net.conv2, kernel_size=2, stride=2,
                          pool=P.Pooling.MAX)

    # Fully connected layer with an in-place ReLU, then the 10-way score layer.
    net.fc1 = L.InnerProduct(net.pool2, num_output=500,
                             weight_filler=xavier())
    net.relu1 = L.ReLU(net.fc1, in_place=True)
    net.score = L.InnerProduct(net.relu1, num_output=10,
                               weight_filler=xavier())

    # Softmax loss of the scores against the labels.
    net.loss = L.SoftmaxWithLoss(net.score, net.label)

    return net.to_proto()

def download_dataset(fn, dest_dir=None, base_url='http://yann.lecun.com/exdb/mnist/'):
    """Download ``<fn>.gz`` and decompress it to ``<dest_dir>/<fn>``.

    The archive is decompressed into a temporary ``.part`` file that is
    renamed to the final name only after decompression finished. A failed or
    interrupted extraction therefore never leaves a truncated ``<fn>`` behind
    that a later "file already exists" check would mistake for a complete one.

    Args:
        fn: Dataset file name without the ``.gz`` suffix.
        dest_dir: Directory that receives both the ``.gz`` archive and the
            extracted file. Defaults to the notebook-level ``dataset_dir``
            (looked up when the function is called).
        base_url: URL prefix (ending in ``/``) under which ``<fn>.gz`` lives.

    Raises:
        Whatever ``urllib.request.urlretrieve``, ``gzip`` or the file system
        raise; the temporary ``.part`` file is removed in that case.
    """
    if dest_dir is None:
        dest_dir = dataset_dir  # notebook-level global, same as the original behaviour

    dataset_gzfilename = os.path.join(dest_dir, '%s.gz' % fn)
    extracted_path = dataset_gzfilename[:-3]  # strip the '.gz' suffix
    partial_path = extracted_path + '.part'

    urllib.request.urlretrieve('%s%s.gz' % (base_url, fn), dataset_gzfilename)
    try:
        with gzip.open(dataset_gzfilename, 'rb') as f_in, open(partial_path, 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)
        # Atomic rename: the final name appears only once the data is complete.
        os.replace(partial_path, extracted_path)
    finally:
        # Only still present when extraction or the rename failed.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print('download dataset: %s' % fn)
    
def check_dir(dir_path):
    """Make sure ``dir_path`` exists as a directory, creating it if needed.

    Missing parent directories are created as well, and a directory that
    appears between the check and the creation is not treated as an error.

    Args:
        dir_path: Directory path to ensure.

    Raises:
        OSError: If the path cannot be created, e.g. because it already
            exists as a regular file.
    """
    if not os.path.isdir(dir_path):
        print('make dir: %s' % dir_path)
        # makedirs (not mkdir) so nested paths work; exist_ok covers the race
        # where another process creates the directory first.
        os.makedirs(dir_path, exist_ok=True)
        
def load_mnist(fname_img, fname_lbl, digits=np.arange(10)):
    """Read an image file and a label file and return the selected samples.

    The label file is read as an 8-byte big-endian header (two uint32 values)
    followed by one signed byte per label. The image file is read as a
    16-byte big-endian header (four uint32 values: magic, count, rows, cols)
    followed by one unsigned byte per pixel. The sample count is taken from
    the image header.

    Args:
        fname_img: Path of the image file.
        fname_lbl: Path of the label file.
        digits: Iterable of label values to keep; samples with any other
            label are dropped. The default array is never modified.

    Returns:
        ``(images, labels)`` where ``images`` is a ``uint8`` array of shape
        ``(N, rows, cols)`` and ``labels`` is an ``int8`` array of shape
        ``(N, 1)``, in file order.

    Raises:
        ValueError: If either file holds less data than the image header's
            sample count requires.
    """
    # Context managers close the files even when parsing fails.
    with open(fname_lbl, 'rb') as flbl:
        struct.unpack(">II", flbl.read(8))  # header values are not used
        lbl = np.frombuffer(flbl.read(), dtype=np.int8)

    with open(fname_img, 'rb') as fimg:
        magic_nr, size, rows, cols = struct.unpack(">IIII", fimg.read(16))
        img = np.frombuffer(fimg.read(), dtype=np.uint8)

    n_pixels = size * rows * cols
    if lbl.size < size:
        raise ValueError('label file %s holds %d labels, but the image header '
                         'announces %d samples' % (fname_lbl, lbl.size, size))
    if img.size < n_pixels:
        raise ValueError('image file %s holds %d pixel bytes, expected %d'
                         % (fname_img, img.size, n_pixels))

    lbl = lbl[:size]
    # list() first so lists, ranges, sets and arrays are all handled alike.
    keep = np.isin(lbl, np.asarray(list(digits)))

    # Boolean indexing copies, so the results are writable and independent
    # of the read-only buffers created by frombuffer.
    images = img[:n_pixels].reshape(size, rows, cols)[keep]
    labels = lbl[keep].reshape(-1, 1)

    return images, labels

def convert_to_lmdb(images, labels, lmdb_path):
    """Store images and labels in an LMDB database as serialized Caffe Datums.

    Every image becomes one single-channel ``Datum`` whose key is the
    zero-padded, 8-character sample index encoded as ASCII.

    Args:
        images: Array of shape ``(N, height, width)``.
        labels: Array holding exactly ``N`` label values, e.g. shape
            ``(N,)`` or ``(N, 1)``.
        lmdb_path: Path passed to ``lmdb.open``.

    Raises:
        ValueError: If ``labels`` does not hold exactly one value per image.
    """
    count = images.shape[0]
    height = images.shape[1]
    width = images.shape[2]
    X = images.reshape(count, 1, height, width)

    # Flatten so each label is a true scalar; int() on a 1-element array is
    # deprecated in recent NumPy releases.
    y = np.asarray(labels).reshape(-1)
    if y.shape[0] != count:
        raise ValueError('expected %d labels, got %d' % (count, y.shape[0]))

    map_size = X.nbytes * 10  # same sizing rule as before: 10x the raw pixel bytes
    env = lmdb.open(lmdb_path, map_size=map_size)
    try:
        with env.begin(write=True) as txn:
            for i in range(count):
                datum = caffe_pb2.Datum()
                datum.channels = X.shape[1]
                datum.height = X.shape[2]
                datum.width = X.shape[3]
                # tobytes() replaces the deprecated tostring() alias.
                datum.data = X[i].tobytes()
                datum.label = int(y[i])
                str_id = '{:08}'.format(i)
                txn.put(str_id.encode('ascii'), datum.SerializeToString())
    finally:
        # Release the environment handle once the transaction is finished.
        env.close()

In [None]:
# --- Locations used throughout the notebook -------------------------------
workdir_prefix = '/batch'
dataset_dir = os.path.join(workdir_prefix, 'data')
model_dir = os.path.join(workdir_prefix, 'models')
snapshot_dir = os.path.join(model_dir, 'lenet')

# File names handled by the download and conversion cells, in the order
# train images, train labels, test images, test labels.
mnist_dataset_filename = [
    'train-images-idx3-ubyte',
    'train-labels-idx1-ubyte',
    't10k-images-idx3-ubyte',
    't10k-labels-idx1-ubyte',
]

# Generated network / solver definitions.
train_net_path = os.path.join(model_dir, 'lenet_auto_train.prototxt')
test_net_path = os.path.join(model_dir, 'lenet_auto_test.prototxt')
solver_config_path = os.path.join(model_dir, 'lenet_auto_solver.prototxt')

# LMDB databases for the train and test splits.
mnist_train_lmdb_path = os.path.join(model_dir, 'mnist_train_lmdb')
mnist_test_lmdb_path = os.path.join(model_dir, 'mnist_test_lmdb')

# Create the directories parent-first, in the same order as before.
for required_dir in (
    workdir_prefix,
    dataset_dir,
    model_dir,
    snapshot_dir,
    mnist_train_lmdb_path,
    mnist_test_lmdb_path,
):
    check_dir(required_dir)

In [None]:
# Download only the files whose extracted copy is not yet in dataset_dir.
for fn in mnist_dataset_filename:
    if os.path.isfile(os.path.join(dataset_dir, fn)):
        continue
    download_dataset(fn)

In [None]:
print('convert dataset to lmdb format.')

# Full paths in list order: train images, train labels, test images, test labels.
train_image_file, train_label_file, test_image_file, test_label_file = [
    os.path.join(dataset_dir, name) for name in mnist_dataset_filename
]

# Training split -> training LMDB.
train_mnist_image, train_mnist_label = load_mnist(
    train_image_file, train_label_file, np.arange(10))
convert_to_lmdb(train_mnist_image, train_mnist_label, mnist_train_lmdb_path)

# Test split -> test LMDB.
test_mnist_image, test_mnist_label = load_mnist(
    test_image_file, test_label_file, np.arange(10))
convert_to_lmdb(test_mnist_image, test_mnist_label, mnist_test_lmdb_path)

In [None]:
# (output prototxt, source LMDB, batch size) for the train and test networks.
net_variants = (
    (train_net_path, mnist_train_lmdb_path, 64),
    (test_net_path, mnist_test_lmdb_path, 100),
)
for net_path, source_lmdb, net_batch_size in net_variants:
    with open(net_path, 'w') as f:
        f.write(str(lenet(source_lmdb, net_batch_size)))

In [None]:
# Build the solver definition in one expression and write it out as text.
s = caffe_pb2.SolverParameter(
    # Fixed seed for repeatable runs.
    random_seed=0xCAFFE,

    # Network definitions generated earlier in the notebook.
    train_net=train_net_path,
    test_net=[test_net_path],

    # Test every 500 iterations, 100 test batches each time.
    test_interval=500,
    test_iter=[100],

    max_iter=10000,

    # SGD with momentum and weight decay.
    type="SGD",
    base_lr=0.01,
    momentum=0.9,
    weight_decay=5e-4,

    # 'inv' learning-rate policy and its two parameters.
    lr_policy='inv',
    gamma=0.0001,
    power=0.75,

    # Logging and snapshot cadence (in iterations).
    display=1000,
    snapshot=5000,
    # NOTE(review): the prefix is the snapshot directory path itself; how a
    # directory-valued prefix is interpreted depends on the Caffe version —
    # confirm where the snapshot files end up.
    snapshot_prefix=snapshot_dir,

    solver_mode=caffe_pb2.SolverParameter.CPU,
)

with open(solver_config_path, 'w') as f:
    f.write(str(s))

In [None]:
# Reset the name first, then build the solver from the solver prototxt file.
# NOTE(review): presumably the reset lets a solver left over from an earlier
# run of this cell be released before a new one is created — confirm.
solver = None
solver = caffe.get_solver(solver_config_path)

In [None]:
# Number of training iterations run by the training loop, plus one
# zero-initialised slot per iteration for the loss and the test accuracy.
niter = 250
train_loss, test_acc = zeros(niter), zeros(niter)

In [None]:
# Train one solver iteration at a time and, after every iteration, measure
# the accuracy of the first test net over a fixed number of test batches.
test_batches = 100  # forward passes of the test net per evaluation
test_net = solver.test_nets[0]

for it in range(niter):
    solver.step(1)  # a single solver iteration
    train_loss[it] = solver.net.blobs['loss'].data

    correct = 0
    evaluated = 0
    for test_it in range(test_batches):
        test_net.forward()
        predictions = test_net.blobs['score'].data.argmax(1)
        targets = test_net.blobs['label'].data
        # Use the array's own sum instead of relying on the pylab star-import
        # having replaced the builtin sum.
        correct += int((predictions == targets).sum())
        evaluated += targets.size

    # Divide by the number of samples actually scored rather than a
    # hard-coded 1e4, so the figure stays right if batch sizes change.
    test_acc[it] = correct / float(evaluated)

    # Each pass is one iteration (not an epoch), so label it as such.
    print('[iter %d] loss: %f, accuracy: %f' % (it + 1, train_loss[it], test_acc[it]))

In [None]:
# Training loss (left axis) and test accuracy (right axis, red) on a shared
# iteration axis.
iterations = arange(niter)
_, ax1 = subplots()
ax2 = ax1.twinx()

# Left axis: training loss.
ax1.plot(iterations, train_loss)
ax1.set_xlabel('iteration')
ax1.set_ylabel('train loss')

# Right axis: test accuracy.
ax2.plot(iterations, test_acc, 'r')
ax2.set_ylabel('test accuracy')

# The title reports the accuracy of the last iteration.
ax2.set_title('Custom Test Accuracy: {:.2f}'.format(test_acc[-1]))