In [None]:
require 'torch'
require 'nn'
require 'nngraph'
require 'optim'
model_utils = require 'third_party.char-rnn.util.model_utils'

In [None]:
-- Locations of the serialized train and test splits.
trainFile = 'data/mnist.t7/train_32x32.t7'
testFile = 'data/mnist.t7/test_32x32.t7'

-- Both files are read with torch's 'ascii' serialization format.
-- The loaded objects are kept global because later cells read their
-- .data and .labels fields.
local serializationFormat = 'ascii'
trainData = torch.load(trainFile, serializationFormat)
testData = torch.load(testFile, serializationFormat)

In [None]:
-- Quick look at the training split: the first six labels, the tensor sizes,
-- and the first six images rendered inline by iTorch.
local previewRange = {{1, 6}}
print('Train data:')
print(trainData.labels[previewRange])
print("size: ", trainData.data:size(), trainData.labels:size())
itorch.image(trainData.data[previewRange])
print()

In [None]:
-- Quick look at the test split: the data tensor size, the first six labels,
-- and the first six images rendered inline by iTorch.
local previewRange = {{1, 6}}
print('Test data:')
print(testData.data:size())
print(testData.labels[previewRange])
itorch.image(testData.data[previewRange])
print()

In [None]:
-- Hyper-parameters (global: the training and prediction cells read them).
inputSize = 32*32   -- number of values per flattened image
layerSize = 200     -- width of the single hidden layer
numLabels = 10      -- number of output classes
gradClip = 5        -- gradients are clamped to [-gradClip, gradClip] in feval

-- Two-layer perceptron emitting log-probabilities:
-- Linear -> ReLU -> Linear -> LogSoftMax.
mlp = nn.Sequential()
    :add(nn.Linear(inputSize, layerSize))
    :add(nn.ReLU(false))
    :add(nn.Linear(layerSize, numLabels))
    :add(nn.LogSoftMax())

-- Negative log-likelihood loss, applied to the LogSoftMax output.
criterion = nn.ClassNLLCriterion()

-- Flatten the model's parameters and their gradients into one tensor each.
params, gradParams = model_utils.combine_all_parameters(mlp)
print('params: ', params:size())
print('gradParams: ', gradParams:size())

In [None]:
-- Number of leading training examples used for optimisation. Rows
-- 50001..60000 of trainData are evaluated separately as a validation split.
local numTrainExamples = 50000

-- The normalised training inputs are the same on every call, so they are
-- built once and reused instead of being reshaped and converted on each of
-- the optimiser's evaluations. The cache is keyed on the source tensor so
-- that assigning a new trainData.data triggers a rebuild.
local cachedSource, cachedInputs

-- Returns the numTrainExamples x inputSize double matrix fed to the network.
local function trainingInputs()
    if cachedSource ~= trainData.data then
        local flat = torch.reshape(trainData.data[{{1, numTrainExamples}}],
                                   numTrainExamples, inputSize)
        -- Maps each raw value v to v / 127.5 - 1 (i.e. [-1, 1] if the raw
        -- pixels are 0..255 -- assumed, not checked here).
        cachedInputs = flat:double() / 127.5 - 1
        cachedSource = trainData.data
    end
    return cachedInputs
end

-- Objective function in the form the optim package expects.
--
-- x: parameter vector to evaluate; copied into `params` when it is a
--    different tensor.
-- Returns: the loss over the whole training split (this is full-batch, not
--          a sampled minibatch), followed by `gradParams` clamped
--          element-wise to [-gradClip, gradClip].
function feval(x)
    if x ~= params then
        params:copy(x)
    end
    gradParams:zero()

    ------------------ full training batch ------------------
    local inputs = trainingInputs()
    local targets = trainData.labels[{{1, numTrainExamples}}]

    ------------------- forward pass -------------------
    local prediction = mlp:forward(inputs)
    local loss = criterion:forward(prediction, targets)

    ------------------ backward pass -------------------
    -- gradParams is presumed to be the flattened gradient storage of mlp
    -- (as produced by combine_all_parameters), so this fills it.
    local dprediction = criterion:backward(prediction, targets)
    mlp:backward(inputs, dprediction)

    gradParams:clamp(-gradClip, gradClip)
    return loss, gradParams
end

-- Sanity check: loss at the initial parameters (gradient return not needed).
loss = feval(params)
print('loss: ', loss)


In [None]:
-- RMSprop settings; the same table is reused across iterations so the
-- optimiser can keep its running state in it.
local optimState = {learningRate = 0.0001, alpha = 0.95}
iterations = 1000

for i = 1, iterations do
    -- optim.rmsprop returns the updated parameters and a table whose first
    -- entry is the loss computed by feval.
    local _, lossHistory = optim.rmsprop(feval, params, optimState)
    trainLoss = lossHistory[1]

    -- Report on the first step and then every tenth step.
    local shouldReport = (i == 1) or (i % 10 == 0)
    if shouldReport then
        print('i=', i, ' train loss: ', trainLoss)
    end
end

In [None]:
-- Classify a batch of images with the trained mlp.
--
-- input: tensor whose first dimension indexes examples; each example is
--        flattened to inputSize values.
-- Returns: the per-row argmax indices of the network output, as returned by
--          max over dimension 2 (one row per example).
function predict(input)
    print ('input: ', input:size())
    local batchSize = input:size(1)
    local flattened = torch.reshape(input, batchSize, inputSize)
    -- Same scaling as used for training: v / 127.5 - 1.
    local scaled = flattened:double()/127.5 - 1
    local logProbs = mlp:forward(scaled)
    local _, classes = torch.max(logProbs, 2)
    return classes
end

-- Spot check on the first two training examples.
local firstTwo = {{1, 2}}
classes = predict(trainData.data[firstTwo])
print("predicted classes: ", classes)
print("ground truth: ", trainData.labels[firstTwo])

In [None]:
-- Fraction of examples whose predicted class equals its label.
--
-- input:  batch of images, passed straight to predict().
-- labels: tensor of ground-truth class indices, one per example.
-- Returns: (number of matching predictions) / labels:size(1).
function evalAccuracy(input, labels)
    local predictions = predict(input)
    -- predict() yields the index tensor produced by max(), while the labels
    -- may be stored as a different tensor type; element-wise tensor ops need
    -- matching types, so both sides are compared as LongTensors.
    -- viewAs lines the predictions up with the shape of `labels`.
    local predicted = predictions:long():viewAs(labels)
    local matches = predicted:eq(labels:long()):sum()
    return matches / labels:size(1)
end

In [None]:
-- Rows 50001..60000 of the training file are held out from optimisation and
-- used here as a validation split.
local validationRange = {{50001, 60000}}
local validationInputs = trainData.data[validationRange]
local validationLabels = trainData.labels[validationRange]
valAcc = evalAccuracy(validationInputs, validationLabels)
print('validation accuracy: ', valAcc)

In [None]:
-- Accuracy over the complete test split.
local testInputs, testLabels = testData.data, testData.labels
testAcc = evalAccuracy(testInputs, testLabels)
print('test accuracy: ', testAcc)