In [1]:
require 'cunn'

----------------------------------------------------------------------

tar = 'http://torch7.s3-website-us-east-1.amazonaws.com/data/mnist.t7.tgz'

if not paths.dirp('mnist.t7') then
   os.execute('wget ' .. tar)
   os.execute('tar xvf ' .. paths.basename(tar))
end

trainFile = 'mnist.t7/train_32x32.t7'
testFile = 'mnist.t7/test_32x32.t7'

----------------------------------------------------------------------

print '==> loading dataset'

trainData = nil
testData = nil
train = torch.load(trainFile,'ascii')
test = torch.load(testFile,'ascii')
----------------------------------------------------------------------

==> loading dataset	


In [2]:
trainData = {data = train.data[{ {1,6000}, {}, {} }],
             labels = train.labels[ { {1, 6000} } ]}
testData = {data = test.data[{ {1,1000}, {}, {} }],
             labels = test.labels[ { {1, 1000} } ]}

In [3]:
-- ignore setmetatable for now, it is a feature beyond the scope of this tutorial. It sets the index operator.
setmetatable(trainData, 
    {__index = function(t, i) 
                    return {t.data[i], t.labels[i]} 
                end}
);

function trainData:size() 
    return self.data:size(1)*0.1 
end

In [4]:
trainData.data = trainData.data:cuda():div(255)
mean = trainData.data:mean()
trainData.data = trainData.data:add(-mean)
trainData.labels = trainData.labels:cuda()

testData.data = testData.data:cuda():div(255)
testData.data = testData.data:add(-mean)
testData.labels = testData.labels:cuda()

trsize = trainData.data:size(1)
tesize = testData.data:size(1)

In [5]:
{testData}

{
  1 : 
    {
      data : CudaTensor - size: 1000x1x32x32
      labels : CudaTensor - size: 1000
    }
}


In [6]:
-- cnn structure
require 'nn'

net = nn.Sequential()
net:add(nn.SpatialConvolution(1, 6, 5, 5)) -- 1 input image channel, 6 output channels, 5x5 convolution kernel
net:add(nn.ReLU())                       -- non-linearity 
net:add(nn.SpatialMaxPooling(2,2,2,2))     -- A max-pooling operation that looks at 2x2 windows and finds the max.
net:add(nn.SpatialConvolution(6, 16, 5, 5))
net:add(nn.ReLU())                       -- non-linearity 
net:add(nn.SpatialMaxPooling(2,2,2,2))
net:add(nn.View(16*5*5))                    -- reshapes from a 3D tensor of 16x5x5 into 1D tensor of 16*5*5
net:add(nn.Linear(16*5*5, 120))             -- fully connected layer (matrix multiplication between input and weights)
net:add(nn.ReLU())                       -- non-linearity 
net:add(nn.Linear(120, 84))
net:add(nn.ReLU())                       -- non-linearity 
net:add(nn.Linear(84, 10))                   -- 10 is the number of outputs of the network (in this case, 10 digits)
net:add(nn.LogSoftMax())                     -- converts the output to a log-probability. Useful for classification problems
net = net:cuda()
-- print('Lenet5\n' .. net:__tostring());

In [7]:
criterion = nn.ClassNLLCriterion() -- a negative log-likelihood criterion for multi-class
criterion = criterion:cuda()

In [8]:
trainer = nn.StochasticGradient(net, criterion)
trainer.learningRate = 0.01
trainer.maxIteration = 200 -- just do 5 epochs of training.
trainer:train(trainData)

# StochasticGradient: training	


# current error = 2.2990066736937	


# current error = 2.1771078719695	


# current error = 1.3374640496572	


# current error = 0.65499614377817	


# current error = 0.36557671089967	


# current error = 0.20953761438529	


# current error = 0.14274736146132	


# current error = 0.14139907081922	


# current error = 0.090080716808637	


# current error = 0.049190447727839	


# current error = 0.014279152552287	


# current error = 0.0055411386489868	


# current error = 0.0020602695147196	


# current error = 0.0012853860855103	


# current error = 0.00095393260320028	


# current error = 0.00077550093332926	


# current error = 0.00064958969751994	


# current error = 0.00056095202763875	


# current error = 0.00049582401911418	


# current error = 0.00044140259424845	


# current error = 0.00039875348409017	


# current error = 0.00036344687143962	


# current error = 0.00033293008804321	


# current error = 0.00030826965967814	


# current error = 0.00028576294581095	


# current error = 0.00026652415593465	


# current error = 0.0002497394879659	


# current error = 0.00023494402567546	


# current error = 0.00022150675455729	


# current error = 0.00020989576975505	


# current error = 0.00019898414611816	


# current error = 0.00018956661224365	


# current error = 0.00018049399058024	


# current error = 0.00017237345377604	


# current error = 0.00016546885172526	


# current error = 0.00015831470489502	


# current error = 0.0001521635055542	


# current error = 0.00014603932698568	


# current error = 0.0001408592859904	


# current error = 0.00013573328653971	


# current error = 0.00013092199961344	


# current error = 0.00012655893961589	


# current error = 0.00012246290842692	


# current error = 0.00011843999226888	


# current error = 0.00011483033498128	


# current error = 0.00011136690775553	


# current error = 0.0001080592473348	


# current error = 0.00010498205820719	


# current error = 0.00010202725728353	


# current error = 9.9236170450846e-05	


# current error = 9.6529324849447e-05	


# current error = 9.4067255655924e-05	


# current error = 9.1582934061686e-05	


# current error = 8.9348157246908e-05	


# current error = 8.7118148803711e-05	


# current error = 8.5037549336751e-05	


# current error = 8.3049138387044e-05	


# current error = 8.1156094868978e-05	


# current error = 7.9278945922852e-05	


# current error = 7.7540079752604e-05	


# current error = 7.5910886128743e-05	


# current error = 7.4222882588704e-05	


# current error = 7.2673161824544e-05	


# current error = 7.1199735005697e-05	


# current error = 6.9737434387207e-05	


# current error = 6.8402290344238e-05	


# current error = 6.7009925842285e-05	


# current error = 6.5784454345703e-05	


# current error = 6.4527193705241e-05	




# current error = 6.3387552897135e-05	


# current error = 6.2198638916016e-05	


# current error = 6.1100323994954e-05	


# current error = 6.005605061849e-05	


# current error = 5.8979988098145e-05	


# current error = 5.8013598124186e-05	


# current error = 5.7031313578288e-05	


# current error = 5.6093533833822e-05	


# current error = 5.5162111918132e-05	


# current error = 5.4262479146322e-05	


# current error = 5.3407351175944e-05	


# current error = 5.2598317464193e-05	


# current error = 5.1786104838053e-05	


# current error = 5.0959587097168e-05	


# current error = 5.0209363301595e-05	


# current error = 4.948616027832e-05	


# current error = 4.8732757568359e-05	


# current error = 4.8027038574219e-05	


# current error = 4.728635152181e-05	


# current error = 4.662831624349e-05	


# current error = 4.5933723449707e-05	


# current error = 4.5326550801595e-05	


# current error = 4.4655799865723e-05	


# current error = 4.4045448303223e-05	


# current error = 4.3444633483887e-05	


# current error = 4.2870839436849e-05	


# current error = 4.2312939961751e-05	


# current error = 4.1743914286296e-05	


# current error = 4.1213035583496e-05	


# current error = 4.0710767110189e-05	


# current error = 4.0179888407389e-05	


# current error = 3.9698282877604e-05	


# current error = 3.9167404174805e-05	


# current error = 3.8711229960124e-05	


# current error = 3.8270950317383e-05	


# current error = 3.7811597188314e-05	


# current error = 3.7353833516439e-05	


# current error = 3.6929448445638e-05	


# current error = 3.6524136861165e-05	


# current error = 3.6071141560872e-05	


# current error = 3.569761912028e-05	


# current error = 3.5268465677897e-05	


# current error = 3.4890174865723e-05	


# current error = 3.4521420796712e-05	


# current error = 3.4151077270508e-05	


# current error = 3.3755302429199e-05	


# current error = 3.3427874247233e-05	


# current error = 3.3055941263835e-05	


# current error = 3.2720565795898e-05	


# current error = 3.23486328125e-05	


# current error = 3.2013257344564e-05	


# current error = 3.1685829162598e-05	


# current error = 3.1382242838542e-05	


# current error = 3.1061172485352e-05	


# current error = 3.0752817789714e-05	


# current error = 3.0431747436523e-05	


# current error = 3.0144055684408e-05	


# current error = 2.9854774475098e-05	


# current error = 2.9559135437012e-05	


# current error = 2.9300053914388e-05	


# current error = 2.9021898905436e-05	


# current error = 2.8740564982096e-05	


# current error = 2.8492609659831e-05	


# current error = 2.8214454650879e-05	


# current error = 2.7982393900553e-05	


# current error = 2.7713775634766e-05	


# current error = 2.7464230855306e-05	


# current error = 2.7190844217936e-05	


# current error = 2.6942888895671e-05	


# current error = 2.6731491088867e-05	


# current error = 2.6483535766602e-05	


# current error = 2.626101175944e-05	


# current error = 2.6024182637533e-05	


# current error = 2.5820732116699e-05	


# current error = 2.5575955708822e-05	


# current error = 2.536137898763e-05	


# current error = 2.5161107381185e-05	


# current error = 2.497673034668e-05	


# current error = 2.47589747111e-05	


# current error = 2.4545987447103e-05	


# current error = 2.4353663126628e-05	


# current error = 2.4145444234212e-05	


# current error = 2.3951530456543e-05	


# current error = 2.375602722168e-05	


# current error = 2.3579597473145e-05	


# current error = 2.3369789123535e-05	


# current error = 2.3190180460612e-05	


# current error = 2.3010571797689e-05	


# current error = 2.2821426391602e-05	


# current error = 2.264658610026e-05	


# current error = 2.2487640380859e-05	


# current error = 2.2323926289876e-05	


# current error = 2.2144317626953e-05	


# current error = 2.1971066792806e-05	


# current error = 2.1800994873047e-05	


# current error = 2.1653175354004e-05	




# current error = 2.1486282348633e-05	


# current error = 2.1328926086426e-05	


# current error = 2.1165211995443e-05	


# current error = 2.1006266276042e-05	


# current error = 2.0874341328939e-05	


# current error = 2.0707448323568e-05	


# current error = 2.0577112833659e-05	


# current error = 2.0418167114258e-05	




# current error = 2.0273526509603e-05	


# current error = 2.0128885904948e-05	


# current error = 2.0006497701009e-05	


# current error = 1.986821492513e-05	


# current error = 1.9720395406087e-05	


# current error = 1.9598007202148e-05	


# current error = 1.9461313883464e-05	


# current error = 1.9327799479167e-05	


# current error = 1.9195874532064e-05	


# current error = 1.9070307413737e-05	


# current error = 1.8963813781738e-05	


# current error = 1.8817583719889e-05	


# current error = 1.8690427144368e-05	


# current error = 1.8555323282878e-05	


# current error = 1.8444061279297e-05	


# current error = 1.8307367960612e-05	


# current error = 1.8194516499837e-05	


# current error = 1.8091201782227e-05	


# current error = 1.7970403035482e-05	


# current error = 1.7848014831543e-05	


# current error = 1.7743110656738e-05	


# current error = 1.7627080281576e-05	


# current error = 1.7526944478353e-05	


# current error = 1.7414093017578e-05	


# current error = 1.7321904500326e-05	


# current error = 1.7201105753581e-05	


# current error = 1.7086664835612e-05	
# StochasticGradient: you have reached the maximum number of iterations	
# training error = 1.7086664835612e-05	


In [9]:
correct = 0
for i=1,1000 do
    local groundtruth = testData.labels[i]
    local prediction = net:forward(testData.data[i])
    local confidences, indices = torch.sort(prediction, true)  -- true means sort in descending order
    if groundtruth == indices[1] then
        correct = correct + 1
    end
end

print(correct, 100*correct/1000 .. ' % ')

910	91 % 	


In [15]:
class_performance = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
for i=1,1000 do
    local groundtruth = testData.labels[i]
    local prediction = net:forward(testData.data[i])
    local confidences, indices = torch.sort(prediction, true)  -- true means sort in descending order
    if groundtruth == indices[1] then
        class_performance[groundtruth] = class_performance[groundtruth] + 1
    end
end

In [17]:
classes = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
for i=1,#classes do
    print(classes[i], 100*class_performance[i]/100 .. ' %')
end

0	82 %	
1	124 %	
2	104 %	
3	94 %	
4	98 %	
5	80 %	
6	80 %	
7	86 %	
8	76 %	
9	86 %	
