# DeepLR (MNIST)

### 1. Setup
* Set up the Python environment and Data for Caffe:

In [1]:
from pylab import *
%matplotlib inline

# Import caffe, adding it to `sys.path` if needed. Make sure pycaffe is already built.

# Root directories of the Caffe build and of LRBench. These are absolute,
# machine-specific paths: edit them to match the local checkout.
caffe_root = '/home/yanzhaowu/research/lr_study/caffe-gt/'  # root directory of Caffe
lrbench_root = '/home/yanzhaowu/research/lr_study/DeepLR/LRBench' # root directory of LRBench
import sys
import os
# Both directories are prepended to sys.path *before* `import caffe`, so the
# pycaffe found under caffe_root takes precedence over any other install.
sys.path.insert(0, os.path.join(caffe_root, 'python'))
sys.path.insert(0, lrbench_root)
import caffe

# Select GPU 0 and run Caffe in GPU mode.
caffe.set_device(0)
caffe.set_mode_gpu()

# Remember the notebook's working directory so it can be restored after the
# optional data-preparation steps below (which must run from caffe_root).
cwd = os.getcwd()
# print 'current directory', cwd
# Uncomment the following lines to download and prepare the data.
# NOTE(review): the two scripts below are the CIFAR-10 ones although this
# notebook trains on MNIST; stock Caffe ships data/mnist/get_mnist.sh and
# examples/mnist/create_mnist.sh -- confirm which ones this fork needs.
#os.chdir(caffe_root)
# Download data
#!data/cifar-10/get_cifar10.sh
# Prepare data
#!examples/cifar10/create_cifar10.sh
# Return to the original working directory (a no-op while the chdir above
# stays commented out).
os.chdir(cwd)

### 2. The Neural Network 

Typically, two external files define the neural network and corresponding training methods:
* the net `prototxt` defines the architecture and path to the train/test data (mnist_train_test.prototxt)
* the solver `prototxt` defines the learning hyper-parameters (here, we generate it through Caffe's Python interface instead).

### 3. Metrics

* Utility
    1. Accuracy (Top-1, Top-5) # Note for MNIST Top-5 may be 100%
    2. Average Confidence (AC)
    3. Confidence Variance (CV)
    4. Confidence Variance Across Class (CVAC)
* Robustness
    5. Loss Difference
    6. Base LR Scale (1, 0.1, 0.5, 5, 10)
* Cost
    7. #Iterations
    8. #Iterations @ Accuracy Threshold


### 4. Experiments

Compare different learning rates by setting the `solver`.


In [2]:
# Solver configuration (paths used by generateSolver and trainWithLR)

# Shell setting kept for reference (presumably a protobuf workaround -- confirm):
# export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

# The same net definition file is used for both the train net and the test net.
train_net_path = 'mnist_train_test.prototxt'
test_net_path = 'mnist_train_test.prototxt'
# File-name suffix of the generated solver prototxt; trainWithLR prefixes it
# with the policy name and writes it inside that policy's directory.
solver_config_path = 'tmp_mnist_solver.prototxt' # temporary solver file
### define solver
from caffe.proto import caffe_pb2

def generateSolver(LR):
    """Build the Caffe solver definition for one learning-rate policy.

    Parameters
    ----------
    LR : LRBench learning-rate object
        Must provide ``toString()`` (used to build the snapshot prefix) and
        ``getLR(iteration)`` (used to obtain the initial learning rate k0).

    Returns
    -------
    caffe_pb2.SolverParameter
        The populated solver message. ``str()`` of it is the prototxt text
        that can be written to disk and loaded with ``caffe.get_solver``.
    """
    solver = caffe_pb2.SolverParameter()

    # Set the random seed for reproducible experiments
    solver.random_seed = 0xCAFFE

    # Net definitions (module-level paths)
    solver.train_net = train_net_path
    solver.test_net.append(test_net_path)

    # Print the training status every `display` iterations
    solver.display = 100
    # Testing schedule
    solver.test_interval = 500    # test every test_interval iterations
    solver.test_iter.append(100)  # number of forward passes per test
    solver.max_iter = 10000

    # Momentum (default) and weight decay
    solver.momentum = 0.9
    solver.weight_decay = 5e-4  # Different from CIFAR-10

    # Train on the GPU
    solver.solver_mode = caffe_pb2.SolverParameter.GPU

    # Snapshots are written under "<policy name>/snapshot/"
    solver.snapshot = 5000
    solver.snapshot_prefix = LR.toString() + '/snapshot/mnist_'  # Single Solver

    # Plain SGD; the learning rate itself comes from LRBench.
    solver.type = 'SGD'
    # NOTE(review): 'set' is not a stock Caffe lr_policy; presumably it is
    # provided by the Caffe fork in use so that the rate can be assigned from
    # Python at every iteration -- confirm.
    solver.lr_policy = 'set'
    solver.base_lr = LR.getLR(0)  # initial learning rate (k0)

    return solver

In [3]:
def _summarize_test_pass(probabilities, labels, nclass):
    """Compute the utility metrics for one full pass over the test set.

    Parameters
    ----------
    probabilities : array-like, one row of class scores per test sample
    labels : array-like, ground-truth class index of each sample (same order)
    nclass : int, number of classes

    Returns
    -------
    tuple
        ``(top1_accuracy, top5_accuracy, average_confidence,
        confidence_std, confidence_std_across_classes)``.
        The confidence statistics only consider samples whose top-1
        prediction is correct; they are NaN when there is no such sample.
    """
    labels = np.asarray(labels).ravel().astype(int)
    probabilities = np.asarray(probabilities).reshape(len(labels), -1)
    n_samples = float(len(labels))

    # Top-1: the arg-max class equals the label; its score is the confidence.
    correct = probabilities.argmax(axis=1) == labels
    correct_confidence = probabilities.max(axis=1)[correct]
    correct_labels = labels[correct]

    # Top-5: the label is among the five highest-scoring classes.
    top5 = np.argsort(-probabilities, axis=1)[:, :5]
    top5_hits = (top5 == labels[:, np.newaxis]).any(axis=1)

    # Mean confidence per class, skipping classes without a correct sample.
    per_class_mean = [np.nanmean(correct_confidence[correct_labels == c])
                      for c in range(nclass)
                      if np.any(correct_labels == c)]

    return (correct.sum() / n_samples,
            top5_hits.sum() / n_samples,
            np.mean(correct_confidence),
            np.std(correct_confidence),
            np.nanstd(per_class_mean))


def trainWithLR(LR):
    """Train the MNIST net with the given LRBench policy and record metrics.

    A directory named after ``LR.toString()`` is created (with a ``snapshot``
    sub-directory), the solver prototxt produced by ``generateSolver`` is
    written into it, and the net is trained one iteration at a time so that
    the learning rate can be set from ``LR.getLR(it)`` before every step.

    Parameters
    ----------
    LR : LRBench learning-rate object providing ``toString()`` and
        ``getLR(iteration)``.

    Returns
    -------
    tuple of numpy arrays
        ``(train_loss, lr, test_acc, test_acc_top5, test_loss,
        test_average_confidence, test_confidence_variance,
        test_confidence_variance_per_class)``.
        ``train_loss`` has one entry per ``display`` iterations, ``lr`` one
        entry per iteration, and every ``test_*`` array one entry per
        ``test_interval`` iterations. The two "variance" arrays hold standard
        deviations (``np.std`` / ``np.nanstd``).
    """
    import time

    lrPolicy = LR.toString()
    snapshot_dir = os.path.join(lrPolicy, 'snapshot')
    if not os.path.isdir(snapshot_dir):
        os.makedirs(snapshot_dir)  # also creates the policy directory

    # Write the solver definition next to the results, then load it.
    solver_path = os.path.join(lrPolicy, lrPolicy + '_' + solver_config_path)
    with open(solver_path, 'w') as f:
        f.write(str(generateSolver(LR)))
    solver = caffe.get_solver(solver_path)

    ### train (Parameters)
    niter = 10000 + 1      # one extra iteration so the final test pass runs
    test_interval = 500
    test_iteration = 100   # forward passes per test
    display = 100
    nclass = 10

    n_display = int(np.ceil(niter * 1.0 / display))
    n_test = int(np.ceil(niter * 1.0 / test_interval))

    # losses will also be stored in the log
    train_loss = np.zeros(n_display)
    lr = np.zeros(niter)

    test_acc = np.zeros(n_test)
    test_acc_top5 = np.zeros(n_test)
    test_loss = np.zeros(n_test)
    test_average_confidence = np.zeros(n_test)
    test_confidence_variance = np.zeros(n_test)
    test_confidence_variance_per_class = np.zeros(n_test)

    _train_loss = 0

    start_time = time.time()
    # the major training part
    for it in range(niter):
        # NOTE(review): the rate is read *before* the new one is set, so
        # lr[it] presumably holds the rate used by the previous iteration
        # (off by one) -- confirm against the solver's `lr` semantics.
        lr[it] = solver.lr
        solver.set_learning_rate(LR.getLR(it))  # Set learning rate with 'SET' policy
        solver.step(1)  # SGD by Caffe

        # accumulate the train loss and store its average every `display` its
        _train_loss += solver.net.blobs['loss'].data
        if it % display == 0:
            # the very first window contains a single iteration
            train_loss[it // display] = _train_loss / min(it + 1, display)
            _train_loss = 0

        # run a full test every so often
        if it % test_interval == 0:
            test_idx = it // test_interval
            test_net = solver.test_nets[0]
            _test_loss = 0
            prob_batches = []
            label_batches = []
            for _ in range(test_iteration):
                test_net.forward()
                _test_loss += test_net.blobs['loss'].data
                # Copy the blob contents: `.data` exposes Caffe's own buffer,
                # which the next forward() overwrites. Keeping a bare
                # reference would drop this batch and duplicate the next one.
                prob_batches.append(test_net.blobs['probability'].data.copy())
                label_batches.append(test_net.blobs['label'].data.copy())
            test_loss[test_idx] = _test_loss / test_iteration

            (test_acc[test_idx],
             test_acc_top5[test_idx],
             test_average_confidence[test_idx],
             test_confidence_variance[test_idx],
             test_confidence_variance_per_class[test_idx]) = _summarize_test_pass(
                np.concatenate(prob_batches, axis=0),
                np.concatenate(label_batches, axis=0),
                nclass)

    end_time = time.time()
    print('### Training Time (s): {}, Final Accuracy: {}'.format(end_time-start_time, test_acc[-1]))
    return  train_loss, lr, test_acc, test_acc_top5, test_loss, test_average_confidence, test_confidence_variance, test_confidence_variance_per_class

In [None]:
from LRBench.lr.LR import LR

# Learning-rate policies to compare.
# In a full study these would be obtained from the LRBench database; the
# three policies below are just a demo set.
LRs = [LR({'lrPolicy': 'FIX', 'k0': 0.01}),
       LR({'lrPolicy': 'EXP', 'k0': 0.01, 'gamma': 0.99994}),
       LR({'lrPolicy': 'NSTEP', 'k0': 0.001, 'gamma': 0.1, 'l': [8571, 9286, 10000]}),
      ]

# Train once per policy and store every recorded metric in
# "<policy name>/<policy name>.npz".
for p in LRs:
    lr_policy = p.toString()
    print(lr_policy)  # single-argument call form works on Python 2 and 3
    (train_loss, lr, test_acc, test_acc_top5, test_loss,
     test_average_confidence, test_confidence_variance,
     test_confidence_variance_per_class) = trainWithLR(p)
    np.savez(os.path.join(lr_policy, lr_policy + '.npz'),
             train_loss=train_loss,
             learning_rate=lr,
             test_acc=test_acc,
             test_acc_top5=test_acc_top5,
             test_loss=test_loss,
             test_average_confidence=test_average_confidence,
             test_confidence_variance=test_confidence_variance,
             test_confidence_variance_per_class=test_confidence_variance_per_class)