In [187]:
require 'torch';
require 'nn';
require 'optim';

### Prepair data

In [188]:
local matio = require 'matio'
data = matio.load('ex4data1.mat')

In [189]:
trainset = {}
trainset.data = data.X
trainset.label = data.y[{ {}, 1}]

In [190]:
trainset

In [191]:
setmetatable(trainset,
    {__index = function(t,i)
                return {t.data[i], t.label[i]}
        end}
);
 
function trainset:size()
    return self.data:size(1)
end

In [192]:
mean = {}
stdv = {}
for i=1,400 do
    mean[i] = trainset.data[{ {},{i} }]:mean()
    stdv[i] = trainset.data[{ {}, {i} }]:std()
    --print(i .. 'th mean: ' .. mean[i])
    --print(i .. 'th std dev: ' .. stdv[i])
    trainset.data[{ {},{i} }]:add(-mean[i])
    if stdv[i] ~= 0 then
        trainset.data[{ {},{i} }]:div(stdv[i])
    end
end

### define model

In [193]:
n_train_data = trainset:size(1) -- number of training data
n_inputs = trainset:size(2)     -- number of cols = number of dims of input
n_outputs = 10   -- highest label = # of classes

In [194]:
net = nn.Sequential()
net:add(nn.Linear(400,25))
net:add(nn.Sigmoid())
net:add(nn.Linear(25,10))
net:add(nn.Sigmoid())
net:add(nn.LogSoftMax())

### define loss function

In [202]:
opt = {
    optimization = 'sgd',
    batch_size = 5000,
    train_size = 5000,  -- set to 0 or 60000 to use all 60000 training data
    test_size = 0,      -- 0 means load all data
    epochs = 1e3,         -- **approximate** number of passes through the training data (see below for the `iterations` variable, which is calculated from this)
}         -- these options are used throughout


In [196]:
criterion = nn.ClassNLLCriterion()

In [197]:
parameters, gradParameters = net:getParameters()

In [198]:
counter = 0
feval = function(x)
  if x ~= parameters then
    parameters:copy(x)
  end

  gradParameters:zero()
    
  local batch_inputs = trainset.data[{{}, {}}]
  local batch_targets = trainset.label[{{}}]

  batch_outputs = net:forward(batch_inputs)
  batch_loss = criterion:forward(batch_outputs, batch_targets)
  dloss_doutput = criterion:backward(batch_outputs, batch_targets) 
  net:backward(batch_inputs, dloss_doutput)

  return batch_loss, gradParameters
end  

### train

In [203]:
optimState = {
    learningRate = 1e-1,
    weightDecay = 0,
    momentum = 0,
    learningRateDecay = 1e-4
}
optimMethod = optim.sgd

In [None]:
losses = {}          -- training losses for each iteration/minibatch
epochs = opt.epochs  -- number of full passes over all the training data
iterations = epochs * math.ceil(n_train_data / opt.batch_size) -- integer number of minibatches to process

for i = 1, iterations do
  local _, minibatch_loss = optimMethod(feval, parameters, optimState)

  if i % 10 == 1 then -- don't print *every* iteration, this is enough to get the gist
      print(string.format("minibatches processed: %6s, loss = %6.6f", i, minibatch_loss[1]))
  end
  losses[#losses + 1] = minibatch_loss[1] -- append the new loss
end


minibatches processed:      1, loss = 2.242163	




minibatches processed:     11, loss = 2.235990	


minibatches processed:     21, loss = 2.229758	


minibatches processed:     31, loss = 2.223466	


minibatches processed:     41, loss = 2.217114	


minibatches processed:     51, loss = 2.210704	


minibatches processed:     61, loss = 2.204238	


minibatches processed:     71, loss = 2.197722	


minibatches processed:     81, loss = 2.191161	


minibatches processed:     91, loss = 2.184560	


minibatches processed:    101, loss = 2.177927	


minibatches processed:    111, loss = 2.171268	


minibatches processed:    121, loss = 2.164592	


minibatches processed:    131, loss = 2.157905	


minibatches processed:    141, loss = 2.151216	


minibatches processed:    151, loss = 2.144533	


minibatches processed:    161, loss = 2.137862	


minibatches processed:    171, loss = 2.131213	


minibatches processed:    181, loss = 2.124591	


minibatches processed:    191, loss = 2.118005	


minibatches processed:    201, loss = 2.111460	


minibatches processed:    211, loss = 2.104963	


minibatches processed:    221, loss = 2.098520	


minibatches processed:    231, loss = 2.092136	


minibatches processed:    241, loss = 2.085817	


minibatches processed:    251, loss = 2.079566	


minibatches processed:    261, loss = 2.073388	


minibatches processed:    271, loss = 2.067286	


minibatches processed:    281, loss = 2.061264	


minibatches processed:    291, loss = 2.055325	


minibatches processed:    301, loss = 2.049470	


minibatches processed:    311, loss = 2.043703	


minibatches processed:    321, loss = 2.038024	


minibatches processed:    331, loss = 2.032434	


minibatches processed:    341, loss = 2.026936	


minibatches processed:    351, loss = 2.021529	


minibatches processed:    361, loss = 2.016213	


minibatches processed:    371, loss = 2.010989	


minibatches processed:    381, loss = 2.005857	


minibatches processed:    391, loss = 2.000815	


minibatches processed:    401, loss = 1.995864	


minibatches processed:    411, loss = 1.991003	


minibatches processed:    421, loss = 1.986231	


minibatches processed:    431, loss = 1.981547	


minibatches processed:    441, loss = 1.976949	


minibatches processed:    451, loss = 1.972436	


minibatches processed:    461, loss = 1.968007	


minibatches processed:    471, loss = 1.963661	


minibatches processed:    481, loss = 1.959395	


minibatches processed:    491, loss = 1.955209	


minibatches processed:    501, loss = 1.951101	


minibatches processed:    511, loss = 1.947069	


minibatches processed:    521, loss = 1.943112	


minibatches processed:    531, loss = 1.939229	


minibatches processed:    541, loss = 1.935416	


minibatches processed:    551, loss = 1.931674	


minibatches processed:    561, loss = 1.928000	


minibatches processed:    571, loss = 1.924393	


minibatches processed:    581, loss = 1.920851	


minibatches processed:    591, loss = 1.917372	


minibatches processed:    601, loss = 1.913956	


minibatches processed:    611, loss = 1.910601	


minibatches processed:    621, loss = 1.907304	


minibatches processed:    631, loss = 1.904066	


minibatches processed:    641, loss = 1.900883	


minibatches processed:    651, loss = 1.897756	


minibatches processed:    661, loss = 1.894682	


minibatches processed:    671, loss = 1.891661	


minibatches processed:    681, loss = 1.888690	


minibatches processed:    691, loss = 1.885770	


minibatches processed:    701, loss = 1.882898	


minibatches processed:    711, loss = 1.880073	


minibatches processed:    721, loss = 1.877294	


minibatches processed:    731, loss = 1.874561	


minibatches processed:    741, loss = 1.871871	


minibatches processed:    751, loss = 1.869225	


minibatches processed:    761, loss = 1.866620	


minibatches processed:    771, loss = 1.864057	


minibatches processed:    781, loss = 1.861533	


minibatches processed:    791, loss = 1.859048	


minibatches processed:    801, loss = 1.856601	


minibatches processed:    811, loss = 1.854192	


minibatches processed:    821, loss = 1.851819	


### test

In [None]:
correction = 0
for i=1,trainset:size() do
    local answer = trainset.label[i]
    local prediction = net:forward(trainset.data[i])
    local confidences, indices = torch.sort(prediction, true)
    if (answer == indices[1]) then
        correction = correction + 1
    end
end
print(correction, 100*correction/trainset:size() .. '%')