In [1]:
require 'nn'
require 'optim'

torch.manualSeed(0)
torch.setnumthreads(4)

In [2]:
function setup() 
    classes = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }
    geometry = { 32, 32 }

    net = nn.Sequential()

    net:add(nn.SpatialConvolution(1, 6, 5, 5))
    net:add(nn.ReLU())
    net:add(nn.SpatialMaxPooling(2, 2, 2, 2))

    net:add(nn.SpatialConvolution(6, 16, 5, 5))
    net:add(nn.ReLU())
    net:add(nn.SpatialMaxPooling(2, 2, 2, 2))

    net:add(nn.View(16*5*5))
    net:add(nn.Linear(16*5*5, 120))
    net:add(nn.ReLU())
    net:add(nn.Linear(120, 84))
    net:add(nn.ReLU())
    net:add(nn.Linear(84, #classes))
    net:add(nn.LogSoftMax())
    
    parameters, gradParameters = net:getParameters()
    criterion = nn.ClassNLLCriterion()
    confusion = optim.ConfusionMatrix(classes)
end

In [3]:
function normalize(input, n_channels)
    local mean = {}
    local stdev = {}
    
    for channel = 1, n_channels do
        mean[channel] = input.data[{ {}, {channel}, {}, {} }]:mean()
        stdev[channel] = input.data[{ {}, {channel}, {}, {} }]:std()
        
        print('Channel ' .. channel .. ' mean: ' .. mean[channel] .. ' stdev: ' .. stdev[channel])
        
        input.data[{ {}, {channel}, {}, {} }]:add(-mean[channel])
        input.data[{ {}, {channel}, {}, {} }]:div(stdev[channel])
    end
end

In [4]:
function load_data()    
    train = torch.load('mnist-p1b-train.t7')
    test = torch.load('mnist-p1b-test.t7')
    
    n_train = train.data:size()[1]
    n_test = test.data:size()[1]
    
    train.data = train.data:double()
    test.data = test.data:double()
end

In [5]:
function exec_training(obj, n_channels, current_epoch) 
    confusion:zero()
    
    for t = 1, n_train, batch_size do
        local limit = math.min(t + batch_size - 1, n_train)
        
        local limited_batch_size = limit - t + 1
        local inputs = torch.Tensor(limited_batch_size, n_channels, geometry[1], geometry[2])
        local targets = torch.Tensor(limited_batch_size)
        local k = 1
        
        for i = t, limit do
            local input = obj.data[i]:clone()
            local target = obj.label[i]
            
            inputs[k] = input
            targets[k] = target
            k = k + 1
        end
        
        local feval = function(x)
            collectgarbage()
            
            if x ~= parameters then
                parameters:copy(x)
            end
            gradParameters:zero()
            
            local outputs = net:forward(inputs)
            local f = criterion:forward(outputs, targets)
            local df_do = criterion:backward(outputs, targets)
            net:backward(inputs, df_do)
            
            for i = 1, limited_batch_size do
                confusion:add(outputs[i], targets[i])
            end
            
            return f, gradParameters
        end
        
        sgd_state = sgd_state or {
            -- default parameter
        }
        optim.sgd(feval, parameters, sgd_state)
    end
    
    confusion:updateValids()
    return confusion.averageValid
end

In [6]:
function exec_testing(obj, n_channels, current_epoch)
    confusion:zero()
    
    for t = 1, n_test, batch_size do
        local limit = math.min(t + batch_size - 1, n_test)
        
        local limited_batch_size = limit - t + 1
        local inputs = torch.Tensor(limited_batch_size, n_channels, geometry[1], geometry[2])
        local targets = torch.Tensor(limited_batch_size)
        local k = 1
        
        for i = t, limit do
            local input = obj.data[i]:clone()
            local target = obj.label[i]
            
            inputs[k] = input
            targets[k] = target
            k = k + 1
        end
        
        local preds = net:forward(inputs)
        
        for i = 1, limited_batch_size do
            confusion:add(preds[i], targets[i])
        end
    end 
    
    confusion:updateValids()
    return confusion.averageValid
end

In [7]:
setup()
load_data()
normalize(train, 1)
normalize(test, 1)

n_epoch = 32
batch_size = 256
local last_train_conf
local last_test_conf
for epoch = 1, n_epoch do
    local acc_train = exec_training(train, 1, epoch)
    
    if (epoch == n_epoch) then
        last_train_conf = confusion:__tostring__()
    end
    
    local acc_test = exec_testing(test, 1, epoch)
    
    if (epoch == n_epoch) then
        last_test_conf = confusion:__tostring__()
    end
    
    io.write(string.format('Epoch %3d: %.4f%% | %.4f%%\n', epoch, acc_train * 100, acc_test * 100))
end

print(last_train_conf)
print(last_test_conf)

Channel 1 mean: 25.509416422526 stdev: 70.180423838273	


Channel 1 mean: 25.87140625 stdev: 70.739298703524	


Epoch   1: 10.2643% | 10.8571%


Epoch   2: 11.6346% | 12.7210%


Epoch   3: 14.1223% | 17.3499%


Epoch   4: 20.4657% | 24.3337%


Epoch   5: 26.0021% | 27.9034%


Epoch   6: 29.2847% | 30.9846%


Epoch   7: 33.2610% | 36.0922%


Epoch   8: 39.1443% | 43.0193%


Epoch   9: 46.6054% | 51.4180%


Epoch  10: 55.8640% | 61.8424%


Epoch  11: 65.2546% | 70.9439%


Epoch  12: 73.7239% | 78.8196%


Epoch  13: 78.9270% | 82.0469%


Epoch  14: 81.7246% | 83.9573%


Epoch  15: 83.5813% | 85.1635%


Epoch  16: 84.8827% | 86.2330%


Epoch  17: 85.9751% | 87.1243%


Epoch  18: 86.7504% | 87.7464%


Epoch  19: 87.4073% | 88.4504%


Epoch  20: 88.0097% | 88.9167%


Epoch  21: 88.5033% | 89.4261%


Epoch  22: 88.9600% | 89.8560%


Epoch  23: 89.4188% | 90.2861%


Epoch  24: 89.7933% | 90.6134%


Epoch  25: 90.0960% | 90.9752%


Epoch  26: 90.3860% | 91.2019%


Epoch  27: 90.6327% | 91.4346%


Epoch  28: 90.9199% | 91.7120%


Epoch  29: 91.1404% | 91.9135%


Epoch  30: 91.3843% | 92.1653%


Epoch  31: 91.6019% | 92.3194%


Epoch  32: 91.8077% | 92.5503%
ConfusionMatrix:
[[    5692       1      18       9       6      43      56       9      71      18]   96.100% 	[class: 1]
 [       0    6540      41      36       6      22       9       9      66      13]   97.004% 	[class: 2]
 [      50      40    5326      96     104      29      79      93     104      37]   89.392% 	[class: 3]
 [      15      46     128    5492       2     191      14      72     106      65]   89.578% 	[class: 4]
 [       7      24      36       3    5325       2     103      14      41     287]   91.150% 	[class: 5]
 [      43      39      38     151      29    4914      74      10      81      42]   90.647% 	[class: 6]
 [      49      39      53       0      62      72    5611       0      31       1]   94.812% 	[class: 7]
 [      27      43      94      20      49      12       0    5751      14     255]   91.796% 	[class: 8]
 [      25     131      59     130      32      78      45      25    5210     116]   89.045% 	[class: 9