In [129]:
require 'nn'
require 'optim'

torch.manualSeed(0)
torch.setnumthreads(4)

In [130]:
function setup() 
    classes = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }
    geometry = { 32, 32 }

    net = nn.Sequential()

    net:add(nn.SpatialConvolution(1, 6, 5, 5))
    net:add(nn.ReLU())
    net:add(nn.SpatialMaxPooling(3, 3, 3, 3))

    net:add(nn.SpatialConvolution(6, 16, 5, 5))
    net:add(nn.ReLU())
    net:add(nn.SpatialMaxPooling(3, 3, 3, 3))

    -- net:add(nn.View(16*5*5))
    net:add(nn.View(16*5*5))
    net:add(nn.Linear(16*5*5, 120))
    net:add(nn.ReLU())
    net:add(nn.Linear(120, 84))
    net:add(nn.ReLU())
    net:add(nn.Linear(84, #classes))
    net:add(nn.LogSoftMax())
    
    parameters, gradParameters = net:getParameters()
    criterion = nn.ClassNLLCriterion()
    confusion = optim.ConfusionMatrix(classes)
end

In [131]:
function random_translate(obj) 
    n = torch.Tensor(1, 40, 40):zero()
    n[{ {1}, {5, 36}, {5, 36 }}] = obj
    
    rand_x = torch.random(-6, 6)
    rand_y = torch.random(-6, 6)
    
    n = image.translate(n, rand_x, rand_y)
    n = image.crop(n, 4, 4, 36, 36)
    return n
end

In [132]:
function normalize(input, n_channels)
    local mean = {}
    local stdev = {}
    
    for channel = 1, n_channels do
        mean[channel] = input.data[{ {}, {channel}, {}, {} }]:mean()
        stdev[channel] = input.data[{ {}, {channel}, {}, {} }]:std()
        
        print('Channel ' .. channel .. ' mean: ' .. mean[channel] .. ' stdev: ' .. stdev[channel])
        
        input.data[{ {}, {channel}, {}, {} }]:add(-mean[channel])
        input.data[{ {}, {channel}, {}, {} }]:div(stdev[channel])
    end
end

In [133]:
function load_data()    
    train = torch.load('mnist-p1b-train.t7')
    test = torch.load('mnist-p1b-test.t7')
    
    n_train = train.data:size()[1]
    n_test = test.data:size()[1]
    
    for i = 1, n_test do
        test.data[i] = random_translate(test.data[i])
    end
    
    train.data = train.data:double()
    test.data = test.data:double()
    
    n_train = 1000
    n_test = 1000
end

In [134]:
function exec_training(obj, n_channels, current_epoch) 
    confusion:zero()
    
    for t = 1, n_train, batch_size do
        local limit = math.min(t + batch_size - 1, n_train)
        
        local limited_batch_size = limit - t + 1
        local inputs = torch.Tensor(limited_batch_size, n_channels, geometry[1], geometry[2])
        local targets = torch.Tensor(limited_batch_size)
        local k = 1
        
        for i = t, limit do
            local input = obj.data[i]:clone()
            local target = obj.label[i]
            
            inputs[k] = input
            targets[k] = target
            k = k + 1
        end
        
        local feval = function(x)
            collectgarbage()
            
            if x ~= parameters then
                parameters:copy(x)
            end
            gradParameters:zero()
            
            local outputs = net:forward(inputs)
            local f = criterion:forward(outputs, targets)
            local df_do = criterion:backward(outputs, targets)
            net:backward(inputs, df_do)
            
            for i = 1, limited_batch_size do
                confusion:add(outputs[i], targets[i])
            end
            
            return f, gradParameters
        end
        
        sgd_state = sgd_state or {
            learningRate = 0.03,
            learningRateDecay = 1e-7,
            momentum = 0.5
        }
        optim.sgd(feval, parameters, sgd_state)
    end
    
    confusion:updateValids()
    return confusion.averageValid
end

In [135]:
function exec_testing(obj, n_channels, current_epoch)
    confusion:zero()
    
    for t = 1, n_test, batch_size do
        local limit = math.min(t + batch_size - 1, n_test)
        
        local limited_batch_size = limit - t + 1
        local inputs = torch.Tensor(limited_batch_size, n_channels, geometry[1], geometry[2])
        local targets = torch.Tensor(limited_batch_size)
        local k = 1
        
        for i = t, limit do
            local input = obj.data[i]:clone()
            local target = obj.label[i]
            
            inputs[k] = input
            targets[k] = target
            k = k + 1
        end
        
        local preds = net:forward(inputs)
        
        for i = 1, limited_batch_size do
            confusion:add(preds[i], targets[i])
        end
    end 
    
    confusion:updateValids()
    return confusion.averageValid
end

In [136]:
setup()
load_data()
normalize(train, 1)
normalize(test, 1)

n_epoch = 1
batch_size = 256
local last_train_conf
local last_test_conf
for epoch = 1, n_epoch do
    local acc_train = exec_training(train, 1, epoch)
    
    if (epoch == n_epoch) then
        last_train_conf = confusion:__tostring__()
    end
    
    local acc_test = exec_testing(test, 1, epoch)
    
    if (epoch == n_epoch) then
        last_test_conf = confusion:__tostring__()
    end
    
    io.write(string.format('Epoch %3d: %.4f%% | %.4f%%\n', epoch, acc_train * 100, acc_test * 100))
end

print(last_train_conf)
print(last_test_conf)

Channel 1 mean: 25.509416422526 stdev: 70.180423838273	


Channel 1 mean: 25.781723046875 stdev: 70.645559945625	


...rs/gbudiman/torch/install/share/lua/5.1/nn/Container.lua:67: 
In 7 module of nn.Sequential:
/Users/gbudiman/torch/install/share/lua/5.1/nn/View.lua:47: input view (256x16x1x1) and desired view (400) do not match
stack traceback:
	[C]: in function 'error'
	/Users/gbudiman/torch/install/share/lua/5.1/nn/View.lua:47: in function 'batchsize'
	/Users/gbudiman/torch/install/share/lua/5.1/nn/View.lua:79: in function </Users/gbudiman/torch/install/share/lua/5.1/nn/View.lua:77>
	[C]: in function 'xpcall'
	...rs/gbudiman/torch/install/share/lua/5.1/nn/Container.lua:63: in function 'rethrowErrors'
	...s/gbudiman/torch/install/share/lua/5.1/nn/Sequential.lua:44: in function 'forward'
	[string "function exec_training(obj, n_channels, curre..."]:29: in function 'opfunc'
	/Users/gbudiman/torch/install/share/lua/5.1/optim/sgd.lua:44: in function 'sgd'
	[string "function exec_training(obj, n_channels, curre..."]:46: in function 'exec_training'
	[string "setup()..."]:11: in main chunk
	[C]: in function 'xpcall'
	/Users/gbudiman/torch/install/share/lua/5.1/itorch/main.lua:210: in function </Users/gbudiman/torch/install/share/lua/5.1/itorch/main.lua:174>
	/Users/gbudiman/torch/install/share/lua/5.1/lzmq/poller.lua:75: in function 'poll'
	.../gbudiman/torch/install/share/lua/5.1/lzmq/impl/loop.lua:307: in function 'poll'
	.../gbudiman/torch/install/share/lua/5.1/lzmq/impl/loop.lua:325: in function 'sleep_ex'
	.../gbudiman/torch/install/share/lua/5.1/lzmq/impl/loop.lua:370: in function 'start'
	/Users/gbudiman/torch/install/share/lua/5.1/itorch/main.lua:389: in main chunk
	[C]: in function 'require'
	(command line):1: in main chunk
	[C]: at 0x01041b2d00

WARNING: If you see a stack trace below, it doesn't point to the place where this error occurred. Please use only the one above.
stack traceback:
	[C]: in function 'error'
	...rs/gbudiman/torch/install/share/lua/5.1/nn/Container.lua:67: in function 'rethrowErrors'
	...s/gbudiman/torch/install/share/lua/5.1/nn/Sequential.lua:44: in function 'forward'
	[string "function exec_training(obj, n_channels, curre..."]:29: in function 'opfunc'
	/Users/gbudiman/torch/install/share/lua/5.1/optim/sgd.lua:44: in function 'sgd'
	[string "function exec_training(obj, n_channels, curre..."]:46: in function 'exec_training'
	[string "setup()..."]:11: in main chunk
	[C]: in function 'xpcall'
	/Users/gbudiman/torch/install/share/lua/5.1/itorch/main.lua:210: in function </Users/gbudiman/torch/install/share/lua/5.1/itorch/main.lua:174>
	/Users/gbudiman/torch/install/share/lua/5.1/lzmq/poller.lua:75: in function 'poll'
	.../gbudiman/torch/install/share/lua/5.1/lzmq/impl/loop.lua:307: in function 'poll'
	.../gbudiman/torch/install/share/lua/5.1/lzmq/impl/loop.lua:325: in function 'sleep_ex'
	.../gbudiman/torch/install/share/lua/5.1/lzmq/impl/loop.lua:370: in function 'start'
	/Users/gbudiman/torch/install/share/lua/5.1/itorch/main.lua:389: in main chunk
	[C]: in function 'require'
	(command line):1: in main chunk
	[C]: at 0x01041b2d00: 