In [1]:
require 'nn'
require 'optim'

-- Seed torch's random number generator with a fixed value so that runs are repeatable.
torch.manualSeed(0)
-- Ask torch to use 4 threads.
torch.setnumthreads(4)

In [2]:
--- Build the global model and its training companions.
-- Defines the globals `classes`, `geometry`, `net`, `parameters`,
-- `gradParameters`, `criterion` and `confusion`.
-- The network is two convolution/ReLU/max-pool stages followed by three
-- linear layers and a log-softmax output over the 11 class labels.
function setup()
    local n_classes = 11
    -- 16 feature maps of 5x5: a 32x32 input (see `geometry`) shrinks
    -- 32 -> 28 -> 14 -> 10 -> 5 through the two conv(5x5)/pool(2x2) stages.
    local flat_size = 16 * 5 * 5

    classes = {}
    for label = 1, n_classes do
        classes[label] = label
    end
    geometry = { 32, 32 }

    local model = nn.Sequential()

    -- 5x5 convolution, ReLU, then 2x2 max-pooling with stride 2.
    local function add_conv_stage(n_in, n_out)
        model:add(nn.SpatialConvolution(n_in, n_out, 5, 5))
        model:add(nn.ReLU())
        model:add(nn.SpatialMaxPooling(2, 2, 2, 2))
    end

    -- Fully connected layer followed by ReLU.
    local function add_hidden(n_in, n_out)
        model:add(nn.Linear(n_in, n_out))
        model:add(nn.ReLU())
    end

    add_conv_stage(3, 6)
    add_conv_stage(6, 16)

    model:add(nn.View(flat_size))
    add_hidden(flat_size, 120)
    add_hidden(120, 84)
    model:add(nn.Linear(84, n_classes))
    model:add(nn.LogSoftMax())

    net = model
    -- Flattened views of all weights / gradients, as consumed by optim.
    parameters, gradParameters = net:getParameters()
    criterion = nn.ClassNLLCriterion()
    confusion = optim.ConfusionMatrix(classes)
end

In [3]:
--- Standardize every channel of a dataset in place.
-- For each channel the mean is subtracted and the result is divided by the
-- standard deviation. `input.data` is indexed as
-- `[{ {}, {channel}, {}, {} }]`, i.e. it is expected to be a 4-D tensor with
-- the channel in the second dimension.
--
-- By default the statistics are computed from `input` itself. To normalize
-- held-out data with the statistics of the training set, pass the tables
-- returned by the call made on the training set as `mean` and `stdev`.
--
-- @param input      table with a `data` tensor; modified in place
-- @param n_channels number of channels to process (1..n_channels)
-- @param mean       optional table of per-channel means to reuse
-- @param stdev      optional table of per-channel standard deviations to reuse
-- @return the per-channel mean table and the per-channel stdev table actually used
function normalize(input, n_channels, mean, stdev)
    local used_mean = {}
    local used_stdev = {}

    for channel = 1, n_channels do
        -- Slice once; the sub-tensor shares storage with input.data, so the
        -- in-place add/div below update the dataset itself.
        local slice = input.data[{ {}, {channel}, {}, {} }]

        local m = mean and mean[channel]
        if m == nil then
            m = slice:mean()
        end
        local s = stdev and stdev[channel]
        if s == nil then
            s = slice:std()
        end
        used_mean[channel] = m
        used_stdev[channel] = s

        print('Channel ' .. channel .. ' mean: ' .. m .. ' stdev: ' .. s)

        slice:add(-m)
        -- A constant channel has zero spread; dividing would fill it with
        -- NaN/inf, so it is only centred.
        if s > 0 then
            slice:div(s)
        end
    end

    return used_mean, used_stdev
end

In [4]:
--- Load the training and test sets from disk.
-- Defines the globals `train`, `test`, `n_train` and `n_test`. The `data`
-- field of each set is converted with `:double()`, and the sample count is
-- the size of its first dimension.
function load_data()
    -- Load one serialized set, convert its data, and report its sample count.
    local function load_as_double(path)
        local set = torch.load(path)
        set.data = set.data:double()
        return set, set.data:size(1)
    end

    train, n_train = load_as_double('mnist-p2b-train.t7')
    test, n_test = load_as_double('mnist-p2b-test.t7')
end

In [5]:
--- Run one pass of mini-batch SGD over a dataset.
-- Reads the globals `net`, `criterion`, `parameters`, `gradParameters`,
-- `confusion`, `geometry` and `batch_size`, and creates/reuses the global
-- `sgd_state` so the optimizer state persists across calls.
--
-- @param obj           dataset table with `data` (samples along the first
--                      dimension) and `label` (indexable by sample number)
-- @param n_channels    number of channels per sample
-- @param current_epoch epoch number; currently unused
-- @return `confusion.averageValid` after `confusion:updateValids()`, computed
--         from the predictions made on each batch before its parameter update
function exec_training(obj, n_channels, current_epoch)
    confusion:zero()

    -- Use the size of the dataset that was actually passed in rather than a
    -- global counter, so the loop is correct for any `obj`.
    local n_samples = obj.data:size(1)

    for t = 1, n_samples, batch_size do
        local limit = math.min(t + batch_size - 1, n_samples)

        -- The final batch may be smaller than batch_size.
        local limited_batch_size = limit - t + 1
        local inputs = torch.Tensor(limited_batch_size, n_channels, geometry[1], geometry[2])
        local targets = torch.Tensor(limited_batch_size)

        for i = t, limit do
            local k = i - t + 1
            -- Index assignment copies the sample into the batch tensor, so no
            -- intermediate clone is needed.
            inputs[k] = obj.data[i]
            targets[k] = obj.label[i]
        end

        -- Closure handed to optim.sgd: returns the loss and its gradient with
        -- respect to the flattened parameters for this batch.
        local feval = function(x)
            collectgarbage()

            if x ~= parameters then
                parameters:copy(x)
            end
            -- Gradients accumulate in nn modules; clear them for every batch.
            gradParameters:zero()

            local outputs = net:forward(inputs)
            local f = criterion:forward(outputs, targets)
            local df_do = criterion:backward(outputs, targets)
            net:backward(inputs, df_do)

            for i = 1, limited_batch_size do
                confusion:add(outputs[i], targets[i])
            end

            return f, gradParameters
        end

        -- Created on the first batch only; kept global so that it survives
        -- across batches and epochs.
        sgd_state = sgd_state or {
            learningRate = 0.03,
            learningRateDecay = 1e-7,
            momentum = 0.5
        }
        optim.sgd(feval, parameters, sgd_state)
    end

    confusion:updateValids()
    return confusion.averageValid
end

In [6]:
--- Evaluate the network on a dataset without updating its parameters.
-- Reads the globals `net`, `confusion`, `geometry` and `batch_size`.
-- `confusion` is zeroed first, so on return it describes this evaluation only.
--
-- @param obj           dataset table with `data` (samples along the first
--                      dimension) and `label` (indexable by sample number)
-- @param n_channels    number of channels per sample
-- @param current_epoch epoch number; currently unused
-- @return `confusion.averageValid` after `confusion:updateValids()`
function exec_testing(obj, n_channels, current_epoch)
    confusion:zero()

    -- Use the size of the dataset that was actually passed in rather than a
    -- global counter, so the loop is correct for any `obj`.
    local n_samples = obj.data:size(1)

    for t = 1, n_samples, batch_size do
        local limit = math.min(t + batch_size - 1, n_samples)

        -- The final batch may be smaller than batch_size.
        local limited_batch_size = limit - t + 1
        local inputs = torch.Tensor(limited_batch_size, n_channels, geometry[1], geometry[2])
        local targets = torch.Tensor(limited_batch_size)

        for i = t, limit do
            local k = i - t + 1
            -- Index assignment copies the sample into the batch tensor, so no
            -- intermediate clone is needed.
            inputs[k] = obj.data[i]
            targets[k] = obj.label[i]
        end

        local preds = net:forward(inputs)

        for i = 1, limited_batch_size do
            confusion:add(preds[i], targets[i])
        end
    end

    confusion:updateValids()
    return confusion.averageValid
end

In [7]:
-- Driver: build the model, load and normalize the data, then alternate a
-- training pass and a test pass for every epoch, logging both accuracies.
setup()
load_data()

local n_channels = 3
normalize(train, n_channels)
normalize(test, n_channels)

-- Globals: batch_size is read by exec_training / exec_testing.
n_epoch = 256
batch_size = 256

local last_train_conf, last_test_conf

for epoch = 1, n_epoch do
    local is_last = (epoch == n_epoch)

    -- The shared `confusion` matrix is reset at the start of each pass, so
    -- its text has to be captured right after the pass that filled it.
    local acc_train = exec_training(train, n_channels, epoch)
    if is_last then
        last_train_conf = confusion:__tostring__()
    end

    local acc_test = exec_testing(test, n_channels, epoch)
    if is_last then
        last_test_conf = confusion:__tostring__()
    end

    local line = string.format('Epoch %3d: %.4f%% | %.4f%%\n', epoch, acc_train * 100, acc_test * 100)
    io.write(line)
end

-- Confusion matrices of the final epoch: training first, then test.
print(last_train_conf)
print(last_test_conf)

Channel 1 mean: 113.22781756185 stdev: 67.791216435084	


Channel 2 mean: 107.1339985026 stdev: 64.198819382131	


Channel 3 mean: 93.089492203776 stdev: 69.493239574806	


Channel 1 mean: 112.923084375 stdev: 67.446258384048	


Channel 2 mean: 106.68813339844 stdev: 63.735822418245	




Channel 3 mean: 92.559736328125 stdev: 69.015380306021	


Epoch   1: 9.3933% | 10.3465%


Epoch   2: 9.2686% | 9.2073%


Epoch   3: 11.0037% | 16.6152%


Epoch   4: 35.0955% | 50.6401%


Epoch   5: 59.4007% | 65.1927%


Epoch   6: 69.7084% | 71.8779%


Epoch   7: 74.4332% | 75.0387%


Epoch   8: 76.7416% | 76.7231%


Epoch   9: 78.1652% | 78.2265%


Epoch  10: 79.3385% | 79.3215%


Epoch  11: 80.2173% | 80.4840%


Epoch  12: 80.9366% | 81.3075%


Epoch  13: 81.6457% | 81.8282%


Epoch  14: 82.3293% | 82.2724%


Epoch  15: 82.8602% | 82.4444%


Epoch  16: 83.4239% | 82.7229%


Epoch  17: 83.9239% | 82.8442%


Epoch  18: 84.3653% | 82.9576%


Epoch  19: 84.8183% | 82.9443%


Epoch  20: 85.2779% | 83.0067%


Epoch  21: 85.6471% | 83.3202%


Epoch  22: 86.0492% | 83.3770%


Epoch  23: 86.3733% | 83.5700%


Epoch  24: 86.7435% | 83.5425%


Epoch  25: 87.0762% | 83.4834%


Epoch  26: 87.3900% | 83.5130%


Epoch  27: 87.6869% | 83.6301%


Epoch  28: 87.9974% | 83.6473%


Epoch  29: 88.2677% | 83.6509%


Epoch  30: 88.5824% | 83.5346%


Epoch  31: 88.8128% | 83.4160%


Epoch  32: 89.1252% | 83.5461%


Epoch  33: 89.4017% | 83.4997%


Epoch  34: 89.7864% | 83.4810%


Epoch  35: 90.0678% | 83.5594%


Epoch  36: 90.3289% | 83.4964%


Epoch  37: 90.6008% | 83.5335%


Epoch  38: 90.9064% | 83.3595%


Epoch  39: 91.1830% | 83.5003%


Epoch  40: 91.4817% | 83.3717%


Epoch  41: 91.7214% | 83.3668%


Epoch  42: 92.0252% | 83.2839%


Epoch  43: 92.3082% | 83.1279%


Epoch  44: 92.5789% | 83.1062%


Epoch  45: 92.8611% | 83.1046%


Epoch  46: 93.1344% | 82.9438%


Epoch  47: 93.3712% | 82.8321%


Epoch  48: 93.6075% | 82.7306%


Epoch  49: 93.8969% | 82.6073%


Epoch  50: 94.1591% | 82.4534%


Epoch  51: 94.4156% | 82.5265%


Epoch  52: 94.6765% | 82.4821%


Epoch  53: 94.9754% | 82.3279%


Epoch  54: 95.1601% | 82.4129%


Epoch  55: 95.4405% | 82.2399%


Epoch  56: 95.6558% | 82.2444%


Epoch  57: 95.9020% | 82.2653%


Epoch  58: 96.0638% | 81.9005%


Epoch  59: 96.3553% | 82.0465%


Epoch  60: 96.3424% | 82.0174%


Epoch  61: 96.4384% | 81.8618%


Epoch  62: 96.6502% | 81.5325%


Epoch  63: 96.7152% | 81.8863%


Epoch  64: 96.7458% | 81.6176%


Epoch  65: 96.5710% | 81.1337%


Epoch  66: 96.6720% | 81.7378%


Epoch  67: 96.7626% | 81.8583%


Epoch  68: 96.6607% | 81.7753%


Epoch  69: 96.9876% | 81.3569%


Epoch  70: 96.9723% | 81.2368%


Epoch  71: 97.1927% | 81.6894%


Epoch  72: 97.1459% | 81.8403%


Epoch  73: 97.1823% | 81.4901%


Epoch  74: 97.2841% | 81.5608%


Epoch  75: 97.2628% | 81.4927%


Epoch  76: 97.4545% | 81.7068%


Epoch  77: 97.4075% | 81.9553%


Epoch  78: 97.6193% | 81.3637%


Epoch  79: 97.6557% | 81.0150%


Epoch  80: 97.9048% | 81.3337%


Epoch  81: 98.0409% | 80.9581%


Epoch  82: 97.9942% | 80.9295%


Epoch  83: 97.8505% | 81.6929%


Epoch  84: 97.8290% | 81.5694%


Epoch  85: 98.1359% | 81.7087%


Epoch  86: 98.4174% | 80.8660%


Epoch  87: 98.4462% | 80.6553%


Epoch  88: 98.4617% | 80.8551%


Epoch  89: 98.0218% | 81.0875%


Epoch  90: 97.9410% | 81.2338%


Epoch  91: 98.2297% | 81.3910%


Epoch  92: 98.6463% | 81.3475%


Epoch  93: 98.8176% | 81.0438%


Epoch  94: 98.7407% | 81.8777%


Epoch  95: 98.8287% | 80.8815%


Epoch  96: 98.9036% | 81.6060%


Epoch  97: 98.9696% | 82.0456%


Epoch  98: 98.8185% | 81.0467%


Epoch  99: 98.8723% | 81.4086%


Epoch 100: 98.9358% | 81.7760%


Epoch 101: 98.9997% | 81.6653%


Epoch 102: 98.9489% | 81.7148%


Epoch 103: 98.6728% | 81.5827%


Epoch 104: 98.4303% | 80.2812%


Epoch 105: 98.0851% | 81.5526%


Epoch 106: 98.1437% | 81.5463%


Epoch 107: 98.5847% | 81.3273%


Epoch 108: 98.8706% | 81.8569%


Epoch 109: 98.6548% | 81.5282%


Epoch 110: 98.9667% | 81.6053%


Epoch 111: 99.1337% | 81.9670%


Epoch 112: 99.2942% | 81.8819%


Epoch 113: 99.3702% | 81.7915%


Epoch 114: 99.3817% | 81.9591%


Epoch 115: 99.4400% | 81.6566%


Epoch 116: 99.3124% | 81.6182%


Epoch 117: 99.5341% | 82.1324%


Epoch 118: 99.6266% | 82.2078%


Epoch 119: 99.6724% | 82.2938%


Epoch 120: 99.6903% | 82.0067%


Epoch 121: 99.7164% | 82.2698%


Epoch 122: 99.7868% | 82.4040%


Epoch 123: 99.7759% | 82.2554%


Epoch 124: 99.7896% | 82.2483%


Epoch 125: 99.8245% | 82.3170%


Epoch 126: 99.8314% | 82.2528%


Epoch 127: 99.8465% | 82.3583%


Epoch 128: 99.8531% | 82.3857%


Epoch 129: 99.8648% | 82.2586%


Epoch 130: 99.8549% | 82.2521%


Epoch 131: 99.8616% | 82.4276%


Epoch 132: 99.8631% | 82.3613%


Epoch 133: 99.8593% | 82.2623%


Epoch 134: 99.8732% | 82.3179%


Epoch 135: 99.8785% | 82.3595%


Epoch 136: 99.8749% | 82.3647%


Epoch 137: 99.8699% | 82.3700%


Epoch 138: 99.8766% | 82.3287%


Epoch 139: 99.8800% | 82.2915%


Epoch 140: 99.8818% | 82.3067%


Epoch 141: 99.8784% | 82.3256%


Epoch 142: 99.8850% | 82.2837%


Epoch 143: 99.8801% | 82.2571%


Epoch 144: 98.4419% | 79.4952%


Epoch 145: 94.0638% | 81.5974%


Epoch 146: 96.6854% | 81.4756%


Epoch 147: 97.7361% | 81.3494%


Epoch 148: 98.6147% | 82.0590%


Epoch 149: 98.9142% | 80.5397%


Epoch 150: 99.2461% | 82.3763%


Epoch 151: 99.4672% | 82.1521%


Epoch 152: 99.6473% | 82.0497%


Epoch 153: 99.7725% | 82.0615%


Epoch 154: 99.8378% | 82.1330%


Epoch 155: 99.8598% | 82.2194%


Epoch 156: 99.8684% | 82.1108%


Epoch 157: 99.8653% | 82.2821%


Epoch 158: 99.8769% | 82.1083%


Epoch 159: 99.8865% | 82.2499%


Epoch 160: 99.8900% | 82.2429%


Epoch 161: 99.8917% | 82.2083%


Epoch 162: 99.8800% | 82.1330%


Epoch 163: 99.8887% | 82.2918%


Epoch 164: 99.8919% | 82.2360%


Epoch 165: 99.8969% | 82.1025%


Epoch 166: 99.8898% | 82.2542%


Epoch 167: 99.9001% | 82.1763%


Epoch 168: 99.9034% | 82.1347%


Epoch 169: 99.8984% | 82.0562%


Epoch 170: 99.9088% | 82.0753%


Epoch 171: 99.9053% | 82.1471%


Epoch 172: 99.9052% | 82.0542%


Epoch 173: 99.9070% | 82.0335%


Epoch 174: 99.9169% | 82.0545%


Epoch 175: 99.9138% | 82.0672%


Epoch 176: 99.9102% | 82.1364%


Epoch 177: 99.9136% | 82.1097%


Epoch 178: 99.9219% | 82.2825%


Epoch 179: 99.9170% | 82.1331%


Epoch 180: 99.9217% | 82.1669%


Epoch 181: 99.9233% | 82.0780%


Epoch 182: 99.9166% | 82.1140%


Epoch 183: 99.9286% | 82.1032%


Epoch 184: 99.9213% | 82.0939%


Epoch 185: 99.9215% | 82.1210%


Epoch 186: 99.9230% | 82.1106%


Epoch 187: 99.9182% | 82.1657%


Epoch 188: 99.9035% | 82.1049%


Epoch 189: 99.9183% | 82.1197%


Epoch 190: 99.9169% | 82.1872%


Epoch 191: 99.9116% | 82.1698%


Epoch 192: 99.9204% | 82.0810%


Epoch 193: 99.9266% | 82.0799%


Epoch 194: 99.9168% | 82.0768%


Epoch 195: 99.9218% | 82.0941%


Epoch 196: 99.9284% | 82.1197%


Epoch 197: 99.9268% | 82.1165%


Epoch 198: 99.9217% | 82.1559%


Epoch 199: 99.9350% | 82.1280%


Epoch 200: 99.9313% | 82.1893%


Epoch 201: 99.9219% | 82.1730%


Epoch 202: 99.9334% | 82.1439%


Epoch 203: 99.5925% | 81.1380%


Epoch 204: 95.7599% | 81.0773%


Epoch 205: 96.3561% | 80.9835%


Epoch 206: 97.6217% | 79.8839%


Epoch 207: 97.7989% | 81.0918%


Epoch 208: 98.6238% | 81.9189%


Epoch 209: 98.9974% | 82.2505%


Epoch 210: 99.3146% | 82.0730%


Epoch 211: 99.6177% | 81.8812%


Epoch 212: 99.8074% | 82.1682%


Epoch 213: 99.8516% | 82.2742%


Epoch 214: 99.8886% | 82.2949%


Epoch 215: 99.9021% | 82.3541%


Epoch 216: 99.9021% | 82.3382%


Epoch 217: 99.8954% | 82.2409%


Epoch 218: 99.9037% | 82.1866%


Epoch 219: 99.9153% | 82.2577%


Epoch 220: 99.9059% | 82.2717%


Epoch 221: 99.9186% | 82.2632%


Epoch 222: 99.9169% | 82.2377%


Epoch 223: 99.9135% | 82.2460%


Epoch 224: 99.9203% | 82.2943%


Epoch 225: 99.9236% | 82.2711%


Epoch 226: 99.9182% | 82.3038%


Epoch 227: 99.9204% | 82.2021%


Epoch 228: 99.9300% | 82.3174%


Epoch 229: 99.9249% | 82.2611%


Epoch 230: 99.9284% | 82.2920%


Epoch 231: 99.9283% | 82.2203%


Epoch 232: 99.9350% | 82.2732%


Epoch 233: 99.9365% | 82.2850%


Epoch 234: 99.9252% | 82.2548%


Epoch 235: 99.9285% | 82.1496%


Epoch 236: 99.9317% | 82.1574%


Epoch 237: 99.9251% | 82.1408%


Epoch 238: 99.9353% | 82.2341%


Epoch 239: 99.9370% | 82.2171%


Epoch 240: 99.9318% | 82.1847%


Epoch 241: 99.9268% | 82.1597%


Epoch 242: 99.9383% | 82.0933%


Epoch 243: 99.9383% | 82.1208%


Epoch 244: 99.9319% | 82.1741%


Epoch 245: 99.9384% | 82.1573%


Epoch 246: 99.9349% | 82.0571%


Epoch 247: 99.9370% | 82.1475%


Epoch 248: 99.9316% | 82.1547%


Epoch 249: 99.9399% | 82.0947%


Epoch 250: 99.9400% | 82.1464%


Epoch 251: 99.9384% | 82.1021%


Epoch 252: 99.9384% | 82.0681%


Epoch 253: 99.9385% | 82.0553%


Epoch 254: 99.9399% | 82.0473%


Epoch 255: 99.9351% | 82.0644%


Epoch 256: 99.9468% | 82.0541%
ConfusionMatrix:
[[    5357       0       0       0       0       0       0       0       0       0       2]   99.963% 	[class: 1]
 [       0    6134       0       0       0       0       0       0       0       1       3]   99.935% 	[class: 2]
 [       0       0    5414       0       0       0       0       0       0       0       3]   99.945% 	[class: 3]
 [       1       0       1    5591       0       0       0       0       0       0       1]   99.946% 	[class: 4]
 [       0       0       0       0    5328       0       0       0       0       1       2]   99.944% 	[class: 5]
 [       0       0       0       0       0    4917       1       0       1       0       1]   99.939% 	[class: 6]
 [       0       0       0       0       1       0    5383       0       0       0       1]   99.963% 	[class: 7]
 [       0       0       0       0       0       0       0    5655       0       0       3]   99.947% 	[class: 8]
 [       0       1       0       0      