In [1]:
import torch
from torch import Tensor

import dlc_practical_prologue as prologue

In [2]:
# ex1
'''
that take as input a float tensor and returns a tensor of same size, obtained by applying component-wise
respectively tanh, and the first derivative of tanh.
'''

def sigma(x):
    return x.tanh()
    
    
def dsigma(x):
    return 1 - x.tanh().pow(2)


In [3]:
# ex2: loss
'''
v: predicted tensor
t: target tensor
loss: compute the cost with LSE
dloss: derivative of loss
'''

def loss(v, t):
    return (v - t).pow(2).sum()
    
def dloss(v, t):
    return 2 * (v - t)
    

In [4]:
# ex3: forward and backward passes

def forward_pass(w1, b1, w2, b2, x):
    '''
    whose arguments correspond to an input vector to the network, and the weight and bias of the two
    layers, and returns a tuple composed of the corresponding x (0) , s (1) , x (1) , s (2) , and x (2) .
    
    x0: tensor of 
    '''
    x0 = x
    s1 = (w1.mv(x0) + b1)
    x1 = sigma(s1)
    
    s2 = (w2.mv(x1) + b2)
    x2 = sigma(s2)
    
    return x0, s1, x1, s2, x2
    
    
def backward_pass(w1, b1, w2, b2, t, x, s1, x1, s2, x2, dl_dw1, dl_db1, dl_dw2, dl_db2):
    '''
    whose arguments correspond to the target vector, the quantities computed by the forward pass, and
    the tensors used to store the cumulated sums of the gradient on individual samples, and update the
    latters according to the formula of the backward pass.
    '''
    x0 = x
    dl_dx2 = dloss(x2, t)
    dl_ds2 = dl_dx2 * dsigma(s2)
    dl_dx1 = w2.t().mv(dl_ds2)
    dl_ds1 = dl_dx1 * dsigma(s1)
    
    dl_dw2.add_(dl_ds2.view(-1,1).mm(x1.view(1, -1)))
    dl_db2.add_(dl_ds2)
    dl_dw1.add_(dl_ds1.view(-1,1).mm(x0.view(1, -1)))
    dl_db1.add_(dl_ds1)

In [5]:
# driver
train_input, train_target, test_input, test_target = prologue.load_data(one_hot_labels = True, normalize = True)

nb_hidden = 50
nb_classes = train_target.size(1)
nb_train_examples = train_input.size(0)

zeta = 0.9

train_input *= zeta
test_input *= zeta

epsilon = 1e-6

w1 = Tensor(nb_hidden, train_input.size(1)).normal_(0, epsilon)
b1 = Tensor(nb_hidden).normal_(0, epsilon)

w2 = Tensor(nb_classes, nb_hidden).normal_(0, epsilon)
b2 = Tensor(nb_classes).normal_(0, epsilon)

dl_dw1 = Tensor(w1.size())
dl_db1 = Tensor(b1.size())
dl_dw2 = Tensor(w2.size())
dl_db2 = Tensor(b2.size())

iteration_count = 1000

step_size = 0.1 / nb_train_examples

for n in range(iteration_count):
    dl_dw1.zero_()
    dl_db1.zero_()
    dl_dw2.zero_()
    dl_db2.zero_()
    
    nb_train_errors = 0
    # accumulate loss
    acc_loss = 0
    
    for i in range(nb_train_examples):
        x0, s1, x1, s2, x2 = forward_pass(w1, b1, w2, b2, train_input[i])
        
        # check error
        predicted = x2.max(0)[1]
        if train_target[i, predicted] < 0:
            nb_train_errors += 1
        acc_loss += loss(x2, train_target[i])
        
        backward_pass(w1, b1, w2, b2, train_target[i], x0, s1, x1, s2, x2, dl_dw1, dl_db1, dl_dw2, dl_db2)
        
    # Gradient descent
    w1 -= step_size*dl_dw1
    b1 -= step_size*dl_db1
    w2 -= step_size*dl_dw2
    b2 -= step_size*dl_db2
    
    nb_test_errors = 0
    
    for i in range(test_input.size(0)):
        _, _, _, _, x2 = forward_pass(w1, b1, w2, b2, test_input[i])
        
        predicted = x2.max(0)[1]
        if test_target[i, predicted] < 0:
            nb_test_errors += 1
            
    print ('{:d} acc_loss: {:.02f} train_errors {:.02f} test_errors: {:.02f}'
          .format(n, acc_loss, nb_train_errors*100/train_input.size(0), nb_test_errors*100/test_input.size(0)))
        

* Using MNIST
** Reduce the data-set (use --full for the full thing)
** Use 1000 train and 1000 test samples
0 acc_loss: 10000.00 train_errors 90.30 test_errors: 90.10
1 acc_loss: 7712.08 train_errors 88.30 test_errors: 90.10
2 acc_loss: 6327.60 train_errors 88.30 test_errors: 90.10
3 acc_loss: 5499.01 train_errors 88.30 test_errors: 90.10
4 acc_loss: 4982.97 train_errors 88.30 test_errors: 90.10
5 acc_loss: 4638.62 train_errors 88.30 test_errors: 90.10
6 acc_loss: 4350.50 train_errors 88.30 test_errors: 90.10
7 acc_loss: 3952.47 train_errors 88.30 test_errors: 90.10
8 acc_loss: 3649.57 train_errors 88.30 test_errors: 90.10
9 acc_loss: 3610.92 train_errors 88.30 test_errors: 90.10
10 acc_loss: 3605.49 train_errors 88.30 test_errors: 90.10
11 acc_loss: 3603.86 train_errors 88.30 test_errors: 90.10
12 acc_loss: 3603.02 train_errors 88.40 test_errors: 90.10
13 acc_loss: 3602.42 train_errors 88.40 test_errors: 90.10
14 acc_loss: 3601.89 train_errors 88.40 test_errors: 90.20
15 acc_loss: 36

139 acc_loss: 312.86 train_errors 2.60 test_errors: 15.40
140 acc_loss: 308.45 train_errors 2.40 test_errors: 16.00
141 acc_loss: 304.20 train_errors 2.50 test_errors: 15.40
142 acc_loss: 300.04 train_errors 2.40 test_errors: 16.00
143 acc_loss: 296.00 train_errors 2.40 test_errors: 15.40
144 acc_loss: 292.06 train_errors 2.40 test_errors: 15.60
145 acc_loss: 288.22 train_errors 2.40 test_errors: 15.60
146 acc_loss: 284.47 train_errors 2.40 test_errors: 15.70
147 acc_loss: 280.81 train_errors 2.40 test_errors: 15.60
148 acc_loss: 277.23 train_errors 2.40 test_errors: 15.70
149 acc_loss: 273.73 train_errors 2.30 test_errors: 15.90
150 acc_loss: 270.31 train_errors 2.30 test_errors: 15.70
151 acc_loss: 266.96 train_errors 2.30 test_errors: 15.80
152 acc_loss: 263.69 train_errors 2.20 test_errors: 15.70
153 acc_loss: 260.48 train_errors 2.10 test_errors: 15.70
154 acc_loss: 257.34 train_errors 2.10 test_errors: 15.70
155 acc_loss: 254.27 train_errors 2.00 test_errors: 15.70
156 acc_loss: 

281 acc_loss: 101.69 train_errors 1.20 test_errors: 15.10
282 acc_loss: 101.27 train_errors 1.20 test_errors: 15.10
283 acc_loss: 100.85 train_errors 1.20 test_errors: 15.10
284 acc_loss: 100.43 train_errors 1.20 test_errors: 15.10
285 acc_loss: 100.02 train_errors 1.20 test_errors: 15.10
286 acc_loss: 99.62 train_errors 1.20 test_errors: 15.10
287 acc_loss: 99.22 train_errors 1.20 test_errors: 15.10
288 acc_loss: 98.83 train_errors 1.20 test_errors: 15.10
289 acc_loss: 98.44 train_errors 1.20 test_errors: 15.10
290 acc_loss: 98.06 train_errors 1.20 test_errors: 15.20
291 acc_loss: 97.68 train_errors 1.20 test_errors: 15.20
292 acc_loss: 97.30 train_errors 1.20 test_errors: 15.20
293 acc_loss: 96.93 train_errors 1.20 test_errors: 15.20
294 acc_loss: 96.56 train_errors 1.20 test_errors: 15.20
295 acc_loss: 96.20 train_errors 1.20 test_errors: 15.20
296 acc_loss: 95.85 train_errors 1.20 test_errors: 15.20
297 acc_loss: 95.49 train_errors 1.20 test_errors: 15.20
298 acc_loss: 95.14 train_

425 acc_loss: 67.61 train_errors 1.10 test_errors: 14.90
426 acc_loss: 67.49 train_errors 1.10 test_errors: 14.90
427 acc_loss: 67.38 train_errors 1.10 test_errors: 14.90
428 acc_loss: 67.27 train_errors 1.10 test_errors: 14.90
429 acc_loss: 67.15 train_errors 1.10 test_errors: 14.90
430 acc_loss: 67.04 train_errors 1.10 test_errors: 14.90
431 acc_loss: 66.93 train_errors 1.10 test_errors: 14.90
432 acc_loss: 66.81 train_errors 1.10 test_errors: 14.90
433 acc_loss: 66.70 train_errors 1.10 test_errors: 14.90
434 acc_loss: 66.59 train_errors 1.10 test_errors: 14.90
435 acc_loss: 66.48 train_errors 1.10 test_errors: 14.90
436 acc_loss: 66.37 train_errors 1.10 test_errors: 14.90
437 acc_loss: 66.25 train_errors 1.10 test_errors: 14.90
438 acc_loss: 66.14 train_errors 1.10 test_errors: 14.90
439 acc_loss: 66.03 train_errors 1.10 test_errors: 14.90
440 acc_loss: 65.92 train_errors 1.10 test_errors: 14.90
441 acc_loss: 65.81 train_errors 1.10 test_errors: 14.90
442 acc_loss: 65.70 train_error

569 acc_loss: 51.56 train_errors 0.90 test_errors: 15.00
570 acc_loss: 51.50 train_errors 0.90 test_errors: 15.00
571 acc_loss: 51.44 train_errors 0.90 test_errors: 15.00
572 acc_loss: 51.39 train_errors 0.90 test_errors: 15.00
573 acc_loss: 51.33 train_errors 0.90 test_errors: 15.00
574 acc_loss: 51.28 train_errors 0.90 test_errors: 15.00
575 acc_loss: 51.23 train_errors 0.90 test_errors: 15.00
576 acc_loss: 51.18 train_errors 0.90 test_errors: 15.00
577 acc_loss: 51.12 train_errors 0.90 test_errors: 15.00
578 acc_loss: 51.07 train_errors 0.90 test_errors: 15.00
579 acc_loss: 51.02 train_errors 0.90 test_errors: 15.10
580 acc_loss: 50.97 train_errors 0.90 test_errors: 15.10
581 acc_loss: 50.92 train_errors 0.90 test_errors: 15.10
582 acc_loss: 50.87 train_errors 0.90 test_errors: 15.10
583 acc_loss: 50.82 train_errors 0.90 test_errors: 15.10
584 acc_loss: 50.77 train_errors 0.90 test_errors: 15.10
585 acc_loss: 50.72 train_errors 0.90 test_errors: 15.10
586 acc_loss: 50.67 train_error

713 acc_loss: 46.36 train_errors 0.90 test_errors: 15.00
714 acc_loss: 46.33 train_errors 0.90 test_errors: 15.00
715 acc_loss: 46.31 train_errors 0.90 test_errors: 15.00
716 acc_loss: 46.28 train_errors 0.90 test_errors: 15.00
717 acc_loss: 46.26 train_errors 0.90 test_errors: 15.00
718 acc_loss: 46.23 train_errors 0.90 test_errors: 15.00
719 acc_loss: 46.21 train_errors 0.90 test_errors: 15.00
720 acc_loss: 46.18 train_errors 0.90 test_errors: 15.00
721 acc_loss: 46.16 train_errors 0.90 test_errors: 15.00
722 acc_loss: 46.13 train_errors 0.90 test_errors: 15.00
723 acc_loss: 46.11 train_errors 0.90 test_errors: 15.00
724 acc_loss: 46.08 train_errors 0.90 test_errors: 15.00
725 acc_loss: 46.06 train_errors 0.90 test_errors: 15.00
726 acc_loss: 46.04 train_errors 0.90 test_errors: 15.00
727 acc_loss: 46.01 train_errors 0.90 test_errors: 15.00
728 acc_loss: 45.99 train_errors 0.90 test_errors: 15.00
729 acc_loss: 45.96 train_errors 0.90 test_errors: 15.00
730 acc_loss: 45.94 train_error

857 acc_loss: 42.99 train_errors 0.80 test_errors: 15.10
858 acc_loss: 42.96 train_errors 0.80 test_errors: 15.10
859 acc_loss: 42.93 train_errors 0.80 test_errors: 15.30
860 acc_loss: 42.90 train_errors 0.80 test_errors: 15.30
861 acc_loss: 42.87 train_errors 0.80 test_errors: 15.30
862 acc_loss: 42.85 train_errors 0.80 test_errors: 15.30
863 acc_loss: 42.82 train_errors 0.80 test_errors: 15.30
864 acc_loss: 42.79 train_errors 0.80 test_errors: 15.30
865 acc_loss: 42.76 train_errors 0.80 test_errors: 15.30
866 acc_loss: 42.73 train_errors 0.80 test_errors: 15.30
867 acc_loss: 42.70 train_errors 0.80 test_errors: 15.30
868 acc_loss: 42.67 train_errors 0.80 test_errors: 15.30
869 acc_loss: 42.65 train_errors 0.80 test_errors: 15.30
870 acc_loss: 42.62 train_errors 0.80 test_errors: 15.30
871 acc_loss: 42.59 train_errors 0.80 test_errors: 15.20
872 acc_loss: 42.56 train_errors 0.80 test_errors: 15.20
873 acc_loss: 42.53 train_errors 0.80 test_errors: 15.20
874 acc_loss: 42.50 train_error