In [None]:
require 'torch'
require 'nn'
require 'nngraph'
require 'optim'
model_utils = require 'third_party.char-rnn.util.model_utils'

In [None]:
-- Device selection: gpuid = -1 runs on the CPU, gpuid >= 0 selects that GPU.
gpuid = 0
local useGpu = gpuid >= 0
if useGpu then
    print('using CUDA on GPU ' .. gpuid .. '...')
    require 'cutorch'
    require 'cunn'
    -- gpuid counts from 0 while cutorch device ids count from 1, hence the +1.
    local deviceIndex = gpuid + 1
    cutorch.setDevice(deviceIndex)
end


In [None]:
-- Load the MNIST train and test sets from their ASCII-serialized .t7 files.
local function loadAscii(path)
    return torch.load(path, 'ascii')
end

trainFile = 'data/mnist.t7/train_32x32.t7'
testFile = 'data/mnist.t7/test_32x32.t7'
trainData = loadAscii(trainFile)
testData = loadAscii(testFile)

In [None]:
-- Peek at the training set: first six labels, tensor sizes, first six images.
local firstSix = {{1, 6}}
print('Train data:')
print(trainData.labels[firstSix])
print("size: ", trainData.data:size(), trainData.labels:size())
itorch.image(trainData.data[firstSix])
print()

In [None]:
-- Peek at the test set: data size, first six labels, first six images.
local firstSix = {{1, 6}}
print('Test data:')
print(testData.data:size())
print(testData.labels[firstSix])
itorch.image(testData.data[firstSix])
print()

In [None]:
-- Network dimensions and training hyperparameters (globals: later cells read them).
inputSize = 32*32   -- one input unit per value of a flattened 32x32 image
layerSize = 200     -- width of the single hidden layer
numLabels = 10      -- number of output classes
gradClip = 5        -- feval clamps every gradient entry to [-gradClip, gradClip]

-- One-hidden-layer perceptron that outputs log-probabilities, trained with the
-- negative log-likelihood criterion.
mlp = nn.Sequential()
mlp:add(nn.Linear(inputSize, layerSize))
mlp:add(nn.ReLU(false)) -- false: do not run the ReLU in place
mlp:add(nn.Linear(layerSize, numLabels))
mlp:add(nn.LogSoftMax())
criterion = nn.ClassNLLCriterion()

-- Move model and criterion to the GPU when one was selected.
if gpuid >= 0 then
    mlp:cuda()
    criterion:cuda()
end

-- Flatten all parameters (and their gradients) into two tensors so the
-- optimizer can treat the whole model as a single vector.
params, gradParams = model_utils.combine_all_parameters(mlp)
print('params: ', params:size(), params:type())
-- Report gradParams' own type (the original printed params:type() here).
print('gradParams: ', gradParams:size(), gradParams:type())

In [None]:
batchSize = 100
-- Only the first 50000 training examples are cycled through here; the
-- validation cell evaluates on examples 50001-60000.
maxBatch = 50000 / batchSize -- hardcoded value for MNIST
curBatch = 1

--- Loss/gradient closure in the form expected by the optim package.
-- Each call consumes the next minibatch of trainData (wrapping around after
-- maxBatch batches), runs a forward and backward pass through mlp, and clamps
-- the gradient to [-gradClip, gradClip].
-- @param x parameter vector to evaluate at; copied into params when it is a
--          different tensor
-- @return loss on the minibatch
-- @return gradParams, the flattened (clamped) gradient
function feval(x)
    if x ~= params then
        params:copy(x)
    end
    gradParams:zero()

    ------------------ get minibatch -------------------
    local batchStart = (curBatch-1)*batchSize + 1
    local batchEnd = batchStart + batchSize - 1
    curBatch = curBatch + 1
    if curBatch > maxBatch then
        curBatch = 1
    end

    -- Use distinct names for the batch so the parameter x is not shadowed.
    local inputs = torch.reshape(trainData.data[{{batchStart, batchEnd}}], batchSize, inputSize)
    -- Rescale to roughly [-1, 1]; assumes raw values lie in 0..255 -- TODO confirm.
    inputs = inputs:float()/127.5 - 1
    local targets = trainData.labels[{{batchStart, batchEnd}}]
    if gpuid >= 0 then
        -- inputs is already a float tensor at this point, so no second :float().
        inputs = inputs:cuda()
        targets = targets:float():cuda()
    end

    ------------------- forward pass -------------------
    -- Kept local: these used to leak into the global environment on every call.
    local prediction = mlp:forward(inputs)
    local loss = criterion:forward(prediction, targets)

    ------------------ backward pass -------------------
    local dprediction = criterion:backward(prediction, targets)
    mlp:backward(inputs, dprediction)

    gradParams:clamp(-gradClip, gradClip)
    return loss, gradParams
end

-- Sanity check: loss of the untrained model on the first minibatch.
-- Note that this call also advances curBatch by one.
loss = feval(params)
print('loss: ', loss)


In [None]:
-- RMSprop training: one minibatch per iteration, progress printed on the
-- first iteration and then every 5000 iterations.
local optimState = {alpha = 0.95, learningRate = 0.00001}
iterations = 100000

for i = 1, iterations do
    -- optim.rmsprop returns the updated parameters and a table whose first
    -- entry is the value feval returned for this step.
    local _, fx = optim.rmsprop(feval, params, optimState)
    trainLoss = fx[1]
    local shouldReport = (i == 1) or (i % 5000 == 0)
    if shouldReport then
        print('i=', i, ' train loss: ', trainLoss)
    end
end

In [None]:
--- Classify a batch of examples with the current network.
-- @param input tensor whose first dimension indexes examples; each example
--              must hold inputSize values
-- @return for every row of the network output, the index of its largest
--         entry (second result of max over dimension 2)
function predict(input)
    local numExamples = input:size(1)
    -- Flatten each example and apply the same rescaling used in feval.
    local flat = torch.reshape(input, numExamples, inputSize):float()
    flat = flat/127.5 - 1
    if gpuid >= 0 then
        flat = flat:cuda()
    end
    local scores = mlp:forward(flat)
    local _, argmax = scores:max(2)
    return argmax
end

-- Spot check on the first two training examples.
local firstTwo = {{1, 2}}
classes = predict(trainData.data[firstTwo])
print("predicted classes: ", classes)
print("ground truth: ", trainData.labels[firstTwo])

In [None]:
--- Fraction of examples whose predicted class equals the given label.
-- The input is processed in chunks of up to 1000 examples so a large set is
-- never pushed through the network in a single forward pass.
-- @param input  tensor of examples; the first dimension indexes examples
-- @param labels tensor of labels with one entry per example
-- @return number of matching predictions divided by labels:size(1)
function evalAccuracy(input, labels)
    local total = input:size(1)
    local chunkSize = 1000 -- named so it does not shadow the global batchSize
    local matches = 0

    -- Step straight from chunk start to chunk start instead of visiting
    -- every index just to detect the chunk boundaries.
    for first = 1, total, chunkSize do
        local last = math.min(first + chunkSize - 1, total)
        local span = {{first, last}}
        -- Compare in float on both sides so map is handed two tensors of the
        -- same type however the labels are stored.
        local curLabels = labels[span]:float()
        local predictions = predict(input[span]):float()
        -- The callback returns nothing, so map only counts matches and
        -- leaves curLabels unmodified.
        curLabels:map(predictions, function(xx, yy)
            if xx == yy then
                matches = matches + 1
            end
        end)
    end

    return matches / labels:size(1)
end

In [None]:
-- Validation: training examples 50001-60000.
local valRange = {{50001, 60000}}
valAcc = evalAccuracy(trainData.data[valRange], trainData.labels[valRange])
print('validation accuracy: ', valAcc)

In [None]:
-- Accuracy on the full test set.
local testInputs, testLabels = testData.data, testData.labels
testAcc = evalAccuracy(testInputs, testLabels)
print('test accuracy: ', testAcc)