In [1]:
require 'nn'
require 'optim'

-- Seed torch's random number generator with a fixed value (0) so that
-- repeated runs of the notebook start from the same random state.
torch.manualSeed(0)
-- Request 4 threads from torch.
torch.setnumthreads(4)

In [2]:
--- Build the network, the loss and the bookkeeping objects used by the
-- training/testing cells.
--
-- Everything is published through globals because the other cells read them
-- directly: `classes`, `geometry`, `net`, `parameters`, `gradParameters`,
-- `criterion` and `confusion`.  Calling this again replaces all of them with
-- freshly constructed objects.
function setup()
    classes = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }
    geometry = { 32, 32 }

    net = nn.Sequential()

    -- Stage 1: 1 input plane -> 6 feature maps with 5x5 kernels, then 2x2
    -- max-pooling with stride 2.
    net:add(nn.SpatialConvolution(1, 6, 5, 5))
    net:add(nn.ReLU())
    net:add(nn.SpatialMaxPooling(2, 2, 2, 2))

    -- Stage 2: 6 -> 16 feature maps, same kernel and pooling sizes.
    net:add(nn.SpatialConvolution(6, 16, 5, 5))
    net:add(nn.ReLU())
    net:add(nn.SpatialMaxPooling(2, 2, 2, 2))

    -- Flatten the 16 x 5 x 5 feature volume (32 -> 28 -> 14 -> 10 -> 5 per
    -- side, assuming the default convolution stride and padding).
    -- setNumInputDims(3) tells View that a single sample is 3-D, so a 4-D
    -- input is always treated as a batch.  Without it a batch holding exactly
    -- one sample (1 x 16 x 5 x 5 = 400 elements) is indistinguishable from a
    -- non-batched input and gets flattened to a 1-D vector, which breaks the
    -- per-row handling of the network output downstream.
    net:add(nn.View(16*5*5):setNumInputDims(3))
    net:add(nn.Linear(16*5*5, 120))
    net:add(nn.ReLU())
    net:add(nn.Linear(120, 84))
    net:add(nn.ReLU())
    -- One output per class; LogSoftMax pairs with ClassNLLCriterion below.
    net:add(nn.Linear(84, #classes))
    net:add(nn.LogSoftMax())

    -- Flattened views of all weights and gradients, as used by optim.sgd.
    parameters, gradParameters = net:getParameters()
    criterion = nn.ClassNLLCriterion()
    confusion = optim.ConfusionMatrix(classes)
end

In [3]:
--- Standardise each channel of a dataset in place.
--
-- For every channel the mean and standard deviation are computed over the
-- whole of `input.data` (all samples, all pixels of that channel), printed,
-- and then used to shift and scale that channel.
--
-- A channel whose standard deviation is not a positive number (constant
-- data, or a NaN result) is only mean-centred: dividing by it would turn the
-- whole channel into NaN/Inf.
--
-- @param input       table with a 4-D `data` tensor indexed as
--                    { sample, channel, row, column }
-- @param n_channels  number of channels to normalise (1 .. n_channels)
function normalize(input, n_channels)
    for channel = 1, n_channels do
        -- Range indexing yields a view onto input.data, so the in-place
        -- operations below modify the dataset itself.
        local plane = input.data[{ {}, {channel}, {}, {} }]
        local mean = plane:mean()
        local stdev = plane:std()

        print('Channel ' .. channel .. ' mean: ' .. mean .. ' stdev: ' .. stdev)

        plane:add(-mean)
        if stdev > 0 then
            plane:div(stdev)
        end
    end
end

In [4]:
-- Grow `set` in place to twice `count` samples: slots count+1 .. 2*count
-- receive the inverted image (255 - pixel) and the label of samples
-- 1 .. count.  Returns the new sample count.
local function append_inverted(set, count)
    set.data:resize(count * 2, 1, geometry[1], geometry[2])
    set.label:resize(count * 2)
    for src = 1, count do
        local dst = count + src
        set.data[dst] = 255 - set.data[src]
        set.label[dst] = set.label[src]
    end
    return count * 2
end

--- Load the train and test sets from disk and optionally augment them.
--
-- Sets the globals `train`, `test`, `n_train` and `n_test`.  Reads the global
-- `geometry` when a set has to be resized.
--
-- @param double_training  when truthy, append an inverted copy of every
--                         training sample (the training set doubles in size)
-- @param test_method      'invert' replaces every test image by its inverse;
--                         'double' appends inverted copies to the test set;
--                         any other value leaves the test set untouched
function load_data(double_training, test_method)
    train = torch.load('mnist-p1b-train.t7')
    test = torch.load('mnist-p1b-test.t7')

    n_train = train.data:size(1)
    n_test = test.data:size(1)

    if double_training then
        n_train = append_inverted(train, n_train)
    end

    if test_method == 'double' then
        n_test = append_inverted(test, n_test)
    elseif test_method == 'invert' then
        for row = 1, n_test do
            test.data[row] = 255 - test.data[row]
        end
    end

    -- Convert only after the inversion above, so 255 - x is still computed
    -- on the tensors exactly as they were loaded.
    train.data = train.data:double()
    test.data = test.data:double()
end

In [5]:
--- Run one pass of mini-batch SGD over the first `n_train` samples of `obj`.
--
-- Relies on globals created elsewhere: `net`, `criterion`, `parameters`,
-- `gradParameters`, `confusion`, `geometry`, `n_train` and `batch_size`.
-- The optimiser state lives in the global `sgd_state`; it is created on first
-- use and reused by every later call.
--
-- @param obj            dataset table with `data` and `label` tensors
-- @param n_channels     channel count used to size each input batch
-- @param current_epoch  accepted for the caller's convenience; not used here
-- @return `confusion.averageValid` accumulated over this pass
function exec_training(obj, n_channels, current_epoch)
    confusion:zero()

    for first = 1, n_train, batch_size do
        -- The final batch may be shorter than batch_size.
        local last = math.min(first + batch_size - 1, n_train)
        local count = last - first + 1

        local inputs = torch.Tensor(count, n_channels, geometry[1], geometry[2])
        local targets = torch.Tensor(count)
        for offset = 1, count do
            local sample = first + offset - 1
            inputs[offset] = obj.data[sample]:clone()
            targets[offset] = obj.label[sample]
        end

        -- Closure handed to optim.sgd: returns the batch loss and the
        -- gradient with respect to the flattened parameters.
        local function feval(x)
            collectgarbage()

            if x ~= parameters then
                parameters:copy(x)
            end
            gradParameters:zero()

            local outputs = net:forward(inputs)
            local loss = criterion:forward(outputs, targets)
            net:backward(inputs, criterion:backward(outputs, targets))

            -- Training accuracy is tallied from the same forward pass.
            for row = 1, count do
                confusion:add(outputs[row], targets[row])
            end

            return loss, gradParameters
        end

        if not sgd_state then
            sgd_state = {
                learningRate = 0.03,
                learningRateDecay = 1e-7,
                momentum = 0.5
            }
        end
        optim.sgd(feval, parameters, sgd_state)
    end

    confusion:updateValids()
    return confusion.averageValid
end

In [6]:
--- Evaluate the current network on the first `n_test` samples of `obj`.
--
-- Relies on globals created elsewhere: `net`, `confusion`, `geometry`,
-- `n_test` and `batch_size`.  No parameters are updated.
--
-- @param obj            dataset table with `data` and `label` tensors
-- @param n_channels     channel count used to size each input batch
-- @param current_epoch  accepted for the caller's convenience; not used here
-- @return `confusion.averageValid` accumulated over this pass
function exec_testing(obj, n_channels, current_epoch)
    confusion:zero()

    for first = 1, n_test, batch_size do
        -- The final batch may be shorter than batch_size.
        local last = math.min(first + batch_size - 1, n_test)
        local count = last - first + 1

        local inputs = torch.Tensor(count, n_channels, geometry[1], geometry[2])
        local targets = torch.Tensor(count)
        for offset = 1, count do
            local sample = first + offset - 1
            inputs[offset] = obj.data[sample]:clone()
            targets[offset] = obj.label[sample]
        end

        local preds = net:forward(inputs)

        for row = 1, count do
            confusion:add(preds[row], targets[row])
        end
    end

    confusion:updateValids()
    return confusion.averageValid
end

In [8]:
-- Experiment 1: train on the data as loaded, test on inverted images.
setup()
load_data(false, 'invert')
-- Each set is standardised with its own statistics.
normalize(train, 1)
normalize(test, 1)

-- Globals: batch_size is read by exec_training/exec_testing, n_epoch by the
-- loops in this notebook.
n_epoch = 32
batch_size = 256
local last_train_conf
local last_test_conf
for epoch = 1, n_epoch do
    local is_final = (epoch == n_epoch)

    local acc_train = exec_training(train, 1, epoch)
    -- `confusion` is shared and zeroed by each pass, so its text form has to
    -- be captured before the next pass starts.
    if is_final then
        last_train_conf = confusion:__tostring__()
    end

    local acc_test = exec_testing(test, 1, epoch)
    if is_final then
        last_test_conf = confusion:__tostring__()
    end

    local train_pct, test_pct = acc_train * 100, acc_test * 100
    io.write(string.format('Epoch %3d: %.4f%% | %.4f%%\n', epoch, train_pct, test_pct))
end

print(last_train_conf)
print(last_test_conf)

Channel 1 mean: 25.509416422526 stdev: 70.180423838273	


Channel 1 mean: 229.12859375 stdev: 70.739298703524	


Epoch   1: 50.9522% | 0.8196%


Epoch   2: 91.0095% | 9.1345%


Epoch   3: 94.9684% | 10.8786%


Epoch   4: 96.3106% | 14.1486%


Epoch   5: 96.9947% | 16.8103%


Epoch   6: 97.4155% | 18.7987%


Epoch   7: 97.7032% | 20.3678%


Epoch   8: 97.9702% | 21.9936%


Epoch   9: 98.1568% | 23.7502%


Epoch  10: 98.3046% | 24.9325%


Epoch  11: 98.3956% | 25.9367%


Epoch  12: 98.5231% | 26.8109%


Epoch  13: 98.6360% | 27.7519%


Epoch  14: 98.7139% | 28.5954%


Epoch  15: 98.8025% | 29.2623%


Epoch  16: 98.8848% | 29.6898%


Epoch  17: 98.9554% | 30.2245%


Epoch  18: 99.0113% | 30.2899%


Epoch  19: 99.0798% | 30.7406%


Epoch  20: 99.1239% | 30.9558%


Epoch  21: 99.1661% | 31.0914%


Epoch  22: 99.2128% | 31.0630%


Epoch  23: 99.2636% | 31.1163%


Epoch  24: 99.3009% | 31.0727%


Epoch  25: 99.3539% | 31.1361%


Epoch  26: 99.3875% | 31.0664%


Epoch  27: 99.4239% | 31.0828%


Epoch  28: 99.4481% | 31.1699%


Epoch  29: 99.4830% | 31.1294%


Epoch  30: 99.5260% | 31.1702%


Epoch  31: 99.5546% | 31.2417%


Epoch  32: 99.5780% | 31.2405%
ConfusionMatrix:
[[    5914       1       2       0       0       0       2       0       1       3]   99.848% 	[class: 1]
 [       1    6721       2       0       1       0       2      10       4       1]   99.689% 	[class: 2]
 [       2       3    5946       1       1       0       0       4       1       0]   99.799% 	[class: 3]
 [       1       1       4    6101       0       7       0       6       6       5]   99.511% 	[class: 4]
 [       0       3       2       0    5816       0       5       2       1      13]   99.555% 	[class: 5]
 [       2       1       1       5       1    5395       8       0       5       3]   99.520% 	[class: 6]
 [       1       2       1       0       2       6    5902       0       4       0]   99.730% 	[class: 7]
 [       0       9       3       1       1       0       0    6238       6       7]   99.569% 	[class: 8]
 [       4       6       1       7       3       3       4       2    5813       8]   99.351% 	[class: 9

In [9]:
-- Experiment 2: train and test on the original images plus their inverted
-- copies.
--
-- NOTE(review): as written this cell does not call setup(), so `net` and the
-- global `sgd_state` carry over from the previous experiment and training
-- continues from the already-trained weights.  Confirm that this warm start
-- is intended; if not, call setup() and clear sgd_state first.
load_data(true, 'double')
normalize(train, 1)
normalize(test, 1)

for epoch = 1, n_epoch do
    -- exec_training/exec_testing take (obj, n_channels, current_epoch); the
    -- extra fourth argument formerly passed here was silently discarded.
    local acc_train = exec_training(train, 1, epoch)

    -- `confusion` is shared and zeroed by each pass, so capture its text
    -- form before the test pass overwrites it.
    if (epoch == n_epoch) then
        last_train_conf = confusion:__tostring__()
    end

    local acc_test = exec_testing(test, 1, epoch)

    if (epoch == n_epoch) then
        last_test_conf = confusion:__tostring__()
    end

    io.write(string.format('Epoch %3d: %.4f%% | %.4f%%\n', epoch, acc_train * 100, acc_test * 100))
end

print(last_train_conf)
print(last_test_conf)

Channel 1 mean: 127.5 stdev: 123.80376034192	


Channel 1 mean: 127.5 stdev: 123.82414832818	


Epoch   1: 96.6630% | 89.4073%


Epoch   2: 97.7767% | 96.1924%


Epoch   3: 98.3195% | 97.4384%


Epoch   4: 98.5649% | 97.9543%


Epoch   5: 98.7385% | 98.1426%


Epoch   6: 98.8817% | 98.3238%


Epoch   7: 98.9801% | 98.4296%


Epoch   8: 99.0576% | 98.5067%


Epoch   9: 99.1248% | 98.5269%


Epoch  10: 99.1900% | 98.5834%


Epoch  11: 99.2377% | 98.6414%


Epoch  12: 99.2827% | 98.7053%


Epoch  13: 99.3354% | 98.7055%


Epoch  14: 99.3780% | 98.7074%


Epoch  15: 99.4240% | 98.6732%


Epoch  16: 99.4662% | 98.6950%


Epoch  17: 99.4980% | 98.7153%


Epoch  18: 99.5326% | 98.7235%


Epoch  19: 99.5674% | 98.6929%


Epoch  20: 99.5927% | 98.6717%


Epoch  21: 99.6227% | 98.6675%


Epoch  22: 99.6442% | 98.6584%


Epoch  23: 99.6700% | 98.6687%


Epoch  24: 99.6967% | 98.6825%


Epoch  25: 99.7112% | 98.7095%


Epoch  26: 99.7381% | 98.6992%


Epoch  27: 99.7633% | 98.7197%


Epoch  28: 99.7865% | 98.7105%


Epoch  29: 99.8025% | 98.7209%


Epoch  30: 99.8210% | 98.7511%


Epoch  31: 99.8403% | 98.7521%


Epoch  32: 99.8549% | 98.7477%
ConfusionMatrix:
[[   11843       0       0       0       0       0       2       0       1       0]   99.975% 	[class: 1]
 [       0   13465       1       0       2       1       1      13       1       0]   99.859% 	[class: 2]
 [       0       3   11909       0       0       0       1       2       1       0]   99.941% 	[class: 3]
 [       0       1       1   12246       0       2       0       3       5       4]   99.870% 	[class: 4]
 [       0       3       0       0   11660       0       3       3       1      14]   99.795% 	[class: 5]
 [       0       0       0       5       0   10828       4       0       3       2]   99.871% 	[class: 6]
 [       3       4       2       0       2       6   11813       0       6       0]   99.806% 	[class: 7]
 [       0      12       1       1       0       0       0   12512       0       4]   99.856% 	[class: 8]
 [       0       5       1       3       0       2       3       1   11682       5]   99.829% 	[class: 9