In [None]:
require 'torch'
require 'nn'
require 'image'
mnist = require 'mnist'

In [None]:
-- Device selection: -1 runs everything on the CPU; a value >= 0 selects that
-- GPU using 0-based numbering.
gpuid = 0
-- Always a real boolean (previously left nil on CPU runs); later cells pass it
-- to mnist.load / mnist.State and test it before touching cutorch.
use_cuda = gpuid >= 0
if use_cuda then
    print('using CUDA on GPU ' .. gpuid .. '...')
    require 'cutorch'
    require 'cunn'
    -- cutorch numbers devices starting at 1, so shift the 0-based gpuid up by one.
    cutorch.setDevice(gpuid + 1)
end

In [None]:
-- Restore a trained model from its checkpoint and report the flattened
-- parameter tensor (type and element count) plus the stored normalization stats.
local checkpoint_path = 'checkpoints/mnist-2015-6-18-17-38-0.9939.nn'
state = mnist.load(checkpoint_path, use_cuda)
local params = state.dog.params
print('params: ', params:type(), params:storage():size())
print(string.format('mean: %s, std: %s', state.mean, state.std))

In [None]:
-- Load the digit sheet as a one-channel byte image, resize it to 320x320,
-- convert to float, and normalize with the checkpoint's mean/std so it matches
-- the model's training inputs. Prints the resulting stats and shows the image.
digits = image.load('data/digits.png', 1, 'byte')
digits = image.scale(digits, 320, 320):float()
-- Torch in-place ops return the tensor itself, so these can be chained.
digits:add(-state.mean):div(state.std)
local fmt = string.format
print(fmt('digits mean: %s', digits:mean()))
print(fmt('digits std: %s', digits:std()))
itorch.image(digits)

In [None]:
-- Classify every CELL x CELL tile of the normalized digit sheet and report the
-- fraction of correct predictions. Row i of the grid is expected to contain
-- digit i % 10 (so the last row holds 0), as encoded by the comparison below.
-- Each misclassified tile is displayed together with its predicted digit.
local GRID = 10   -- tiles per row and per column of the sheet
local CELL = 32   -- tile side in pixels; must match the model's 32x32 input
good_cnt = 0
bad_cnt = 0
for i = 1, GRID do
  for j = 1, GRID do
    -- Select tile (i, j): row block i along dim 2, column block j along dim 3.
    local x = digits:narrow(2, (i - 1) * CELL + 1, CELL):narrow(3, (j - 1) * CELL + 1, CELL)
    -- predict() appears to return 1-based class indices; subtracting 1 maps
    -- them to digits 0-9 (NOTE(review): confirm against mnist.State:predict).
    local num = state:predict(x:reshape(1, 1, CELL, CELL))[1][1] - 1
    if num == i % 10 then
      good_cnt = good_cnt + 1
      -- print(string.format('%s - OK', num))
    else
      bad_cnt = bad_cnt + 1
      itorch.image(x)
      print(string.format('%s - error!', num))
    end
  end
end
-- Report true percentages rather than raw counts; the two only coincide
-- when the sheet has exactly 100 tiles.
local total = good_cnt + bad_cnt
print(string.format('Accuracy: %.1f%%', 100 * good_cnt / total))
print(string.format('Error: %.1f%%', 100 * bad_cnt / total))

In [None]:
-- Build an untrained 'mnist_conv' model and load the MNIST dataset into it,
-- then compare the stored normalization stats with the test images' actual stats.
state2 = mnist.State(use_cuda)
state2:create_new('mnist_conv', nil)
state2:load_data('data/mnist.t7')
print('mean: ', state2.mean)
print('std: ', state2.std)
local test_images = state2.test_data.data
print('cur mean: ', test_images:mean())
print('cur std: ', test_images:std())

In [None]:
-- Gather the digit-4 images among the first 100 training samples and display
-- them together. Labels are 1-based here, so label 5 is digit 4.
fours = {}
local label_of_four = 5
local train = state2.train_data
for idx = 1, 100 do
    if train.labels[idx] == label_of_four then
        fours[#fours + 1] = train.data[idx]
    end
end
itorch.image(fours)

In [None]:
-- Display the first training image.
local first_image = state2.train_data.data[1]
itorch.image(first_image)

In [None]:
-- Display the first training image passed through image.rotate with an angle
-- of 0.0 using bilinear sampling (a zero-angle rotation check).
local sample = state2.train_data.data[1]
itorch.image(image.rotate(sample, 0.0, 'bilinear'))

In [None]:
-- Drop the references to both models and run a full garbage-collection cycle
-- so their memory can be reclaimed. Any later cell that uses `state` needs it
-- to be loaded again first.
state, state2 = nil, nil
collectgarbage('collect')

In [None]:
-- Load the MNIST data into the checkpointed model so the accuracy cells below
-- can run. The cleanup cell sets `state` to nil, so fail with an actionable
-- message instead of Lua's generic "attempt to index a nil value".
assert(state ~= nil, 'state is nil (freed by the cleanup cell?); re-run the checkpoint-loading cell first')
state:load_data('data/mnist.t7')

In [None]:
-- Report the loaded model's accuracy on the training and validation splits.
do
    local model = state
    print('train acc:', model:train_accuracy())
    print('val acc: ', model:val_accuracy())
end

In [None]:
-- Report total and free memory on the GPU selected in the CUDA setup cell.
-- The original hard-coded device 1 and called cutorch unconditionally.
-- cutorch is only required when use_cuda is true, so the query is skipped on
-- CPU runs.
if use_cuda then
    freeMemory, totalMemory = cutorch.getMemoryUsage(cutorch.getDevice())
    print('total GPU memory: ', totalMemory)
    print('free GPU memory: ', freeMemory)
else
    print('CUDA disabled; no GPU memory to report')
end

In [None]:
-- Report the loaded model's accuracy on the held-out test split.
local model = state
print('test acc: ', model:test_accuracy())