## Simple Neural Network Using only Tensors

In [1]:
import torch

In [2]:
# Every tensor in this notebook is created with this floating-point dtype.
dtype = torch.float

# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


In [4]:
# Problem dimensions used by the data and weight tensors below.
N = 64        # batch size (number of samples)
D_in = 1000   # input dimension
H = 100       # number of hidden units
D_out = 10    # output dimension

In [5]:
# Random input batch (N x D_in) and random targets (N x D_out), both drawn
# from a standard normal distribution on the selected device.
x = torch.randn((N, D_in), dtype=dtype, device=device)
y = torch.randn((N, D_out), dtype=dtype, device=device)

In [6]:
# Randomly initialize the weights of both layers:
# w1 maps input -> hidden (D_in x H), w2 maps hidden -> output (H x D_out).
w1 = torch.randn((D_in, H), dtype=dtype, device=device)
w2 = torch.randn((H, D_out), dtype=dtype, device=device)

In [9]:
# Step size for plain gradient descent.
learning_rate = 1e-6

for t in range(500):
    # Forward pass: linear layer -> ReLU -> linear layer.
    h = torch.mm(x, w1)
    h_relu = torch.clamp(h, min=0)
    y_pred = torch.mm(h_relu, w2)

    # Loss: sum of squared errors over the whole batch, as a Python float.
    residual = y_pred - y
    loss = residual.pow(2).sum().item()
    print(t, loss)

    # Backward pass: apply the chain rule by hand to obtain the gradients of
    # the loss with respect to w1 and w2.
    grad_y_pred = 2.0 * residual
    grad_w2 = torch.mm(h_relu.t(), grad_y_pred)
    grad_h_relu = torch.mm(grad_y_pred, w2.t())
    # ReLU blocks the gradient wherever its input was negative.
    grad_h = grad_h_relu.masked_fill(h < 0, 0)
    grad_w1 = torch.mm(x.t(), grad_h)

    # Gradient-descent step, applied in place to the weight tensors.
    w1.sub_(learning_rate * grad_w1)
    w2.sub_(learning_rate * grad_w2)

0 33051412.0
1 32553072.0
2 37769400.0
3 41279868.0
4 36832980.0
5 24174994.0
6 12142266.0
7 5362460.5
8 2590345.0
9 1530816.375
10 1083047.0
11 849522.125
12 700124.5
13 590487.75
14 504316.0
15 434119.15625
16 375969.03125
17 327352.8125
18 286324.0
19 251458.5
20 221691.390625
21 196201.171875
22 174239.53125
23 155224.046875
24 138686.4375
25 124244.46875
26 111549.15625
27 100379.2734375
28 90520.8671875
29 81790.7109375
30 74044.1328125
31 67151.34375
32 61002.8359375
33 55504.05078125
34 50572.63671875
35 46145.83984375
36 42167.75
37 38580.21484375
38 35340.4765625
39 32407.0
40 29747.90625
41 27335.84375
42 25143.90625
43 23134.12109375
44 21305.29296875
45 19638.291015625
46 18115.56640625
47 16723.533203125
48 15450.400390625
49 14284.0380859375
50 13215.73828125
51 12235.52734375
52 11334.44140625
53 10505.4814453125
54 9742.8095703125
55 9040.400390625
56 8392.638671875
57 7795.07568359375
58 7243.6826171875
59 6734.41845703125
60 6264.2783203125
61 5829.28662109375
62 542

In [13]:
# Sum-of-squared-errors loss of the most recent predictions, as a Python float.
torch.sum((y_pred - y).pow(2)).item()

3.533364360919222e-05

## Same Code Using Autograd

In [18]:
# Deliberate failure demo: w1 and w2, as created earlier in this notebook, were
# not built with requires_grad=True, so autograd tracks nothing for them and
# loss.backward() raises a RuntimeError saying the loss does not require grad.
# The expected error is caught and printed so the notebook can still be run
# top to bottom; any other RuntimeError is re-raised unchanged.
try:
    for t in range(500):
        # Forward pass
        y_pred = x.mm(w1).clamp(min=0).mm(w2)

        # Compute the loss (kept as a tensor so backward() can be called on it)
        loss = (y_pred - y).pow(2).sum()
        print(t, loss.item())

        # Backpropagation
        loss.backward()

        # Manually update the weights; no_grad keeps the update itself out of
        # the autograd graph.
        with torch.no_grad():
            w1 -= learning_rate * w1.grad
            w2 -= learning_rate * w2.grad

            # Gradients accumulate across backward() calls, so reset them.
            w1.grad.zero_()
            w2.grad.zero_()
except RuntimeError as err:
    if "does not require grad" not in str(err):
        raise
    print("Expected failure:", err)

In [31]:
# Re-create the weights with requires_grad=True so that autograd records the
# operations applied to them and loss.backward() can populate w1.grad / w2.grad.
w1 = torch.randn((D_in, H), dtype=dtype, device=device, requires_grad=True)
w2 = torch.randn((H, D_out), dtype=dtype, device=device, requires_grad=True)

for t in range(500):
    # Forward pass: linear layer -> ReLU -> linear layer.
    hidden = torch.mm(x, w1)
    y_pred = torch.mm(torch.clamp(hidden, min=0), w2)

    # Sum-of-squared-errors loss, kept as a tensor so it can be differentiated.
    residual = y_pred - y
    loss = residual.pow(2).sum()
    print(t, loss.item())

    # Backpropagation: autograd computes the gradients for w1 and w2.
    loss.backward()

    # Manual gradient-descent step; no_grad keeps the update out of the graph.
    with torch.no_grad():
        for weight in (w1, w2):
            weight.sub_(learning_rate * weight.grad)
            # Gradients accumulate across backward() calls, so clear them.
            weight.grad.zero_()

0 33007128.0
1 30457684.0
2 30816770.0
3 29110346.0
4 23432378.0
5 15652928.0
6 9039271.0
7 4920465.0
8 2772875.5
9 1717094.875
10 1185766.375
11 894216.875
12 714992.5
13 592664.125
14 501612.8125
15 430216.875
16 372164.1875
17 324006.1875
18 283462.09375
19 249007.671875
20 219555.0
21 194230.546875
22 172338.375
23 153331.953125
24 136781.1875
25 122354.0390625
26 109688.0234375
27 98542.09375
28 88703.0078125
29 79999.4921875
30 72278.1875
31 65411.33203125
32 59297.21484375
33 53831.63671875
34 48937.11328125
35 44545.953125
36 40599.15625
37 37047.8828125
38 33846.8203125
39 30957.3515625
40 28345.3125
41 25979.81640625
42 23835.015625
43 21886.720703125
44 20115.48046875
45 18503.80078125
46 17034.60546875
47 15694.599609375
48 14471.0361328125
49 13352.404296875
50 12328.6767578125
51 11391.0341796875
52 10531.541015625
53 9742.5
54 9018.07421875
55 8352.3623046875
56 7740.2041015625
57 7176.947265625
58 6658.2109375
59 6180.07373046875
60 5739.3994140625
61 5332.79541015625
6

439 0.00019459838222246617
440 0.00019114142924081534
441 0.00018765253480523825
442 0.0001840300828916952
443 0.00018020438437815756
444 0.00017718967865221202
445 0.00017354327428620309
446 0.00017061854305211455
447 0.0001677732652751729
448 0.0001637759996810928
449 0.00016083387890830636
450 0.0001587642473168671
451 0.0001554312475491315
452 0.0001530339359305799
453 0.0001499862555647269
454 0.00014761692727915943
455 0.0001453368313377723
456 0.0001427654642611742
457 0.0001404343347530812
458 0.0001380049216095358
459 0.00013576305354945362
460 0.00013348502398002893
461 0.0001312407257501036
462 0.00012938778672832996
463 0.00012726856220979244
464 0.00012518961739260703
465 0.00012289235019125044
466 0.0001214246149174869
467 0.00011937804811168462
468 0.00011731660924851894
469 0.000114947855763603
470 0.0001137601284426637
471 0.00011158105917274952
472 0.00011033198825316504
473 0.00010869177640415728
474 0.00010689064220059663
475 0.00010551226296229288
476 0.00010358652

### Defining new autograd functions

In [24]:
class MyRelu(torch.autograd.Function):
    """ReLU written as a custom autograd Function with a hand-coded backward pass.

    Invoke it through ``MyRelu.apply(tensor)``.
    """

    @staticmethod
    def forward(ctx, input):
        """Return ``input`` with every negative entry replaced by zero.

        ``ctx`` is the context object used to stash information for the
        backward computation; here it stores ``input``.
        """
        ctx.save_for_backward(input)
        return torch.clamp(input, min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient of the loss with respect to ``input``.

        The derivative of ReLU is a step function: the incoming gradient is
        zeroed wherever the saved input was negative and passed through
        unchanged everywhere else.
        """
        (saved_input,) = ctx.saved_tensors
        return grad_output.masked_fill(saved_input < 0, 0)
    

In [26]:
# Train the same two-layer network with autograd, but with the activation
# provided by the custom MyRelu Function. The weights need requires_grad=True
# so that loss.backward() can populate their .grad attributes.
w1 = torch.randn((D_in, H), dtype=dtype, device=device, requires_grad=True)
w2 = torch.randn((H, D_out), dtype=dtype, device=device, requires_grad=True)

# Custom autograd Functions are invoked through .apply; alias it as `relu`.
relu = MyRelu.apply

for t in range(500):
    # Forward pass: linear layer -> custom ReLU -> linear layer.
    hidden = torch.mm(x, w1)
    y_pred = torch.mm(relu(hidden), w2)

    # Sum-of-squared-errors loss, kept as a tensor so it can be differentiated.
    loss = torch.sum((y_pred - y).pow(2))
    print(t, loss.item())

    # Backpropagation: autograd calls MyRelu.backward for the activation.
    loss.backward()

    # Manual gradient-descent step; no_grad keeps the update out of the graph.
    with torch.no_grad():
        w1.sub_(learning_rate * w1.grad)
        w2.sub_(learning_rate * w2.grad)

        # Gradients accumulate across backward() calls, so clear them.
        w1.grad.zero_()
        w2.grad.zero_()

0 25754042.0
1 21067474.0
2 21649362.0
3 24709188.0
4 27508894.0
5 27185820.0
6 22333338.0
7 15028934.0
8 8583111.0
9 4535551.0
10 2425255.75
11 1412119.375
12 922474.25
13 669883.375
14 524736.5625
15 430845.5625
16 363630.21875
17 311803.15625
18 269982.375
19 235301.828125
20 206077.984375
21 181204.5625
22 159918.875
23 141616.375
24 125750.6015625
25 111973.359375
26 99962.1328125
27 89457.609375
28 80243.2421875
29 72118.28125
30 64952.03515625
31 58621.2734375
32 53017.5625
33 48047.890625
34 43638.5546875
35 39701.39453125
36 36177.984375
37 33017.71484375
38 30178.99609375
39 27623.42578125
40 25318.41796875
41 23237.154296875
42 21354.9453125
43 19649.345703125
44 18101.677734375
45 16695.1796875
46 15415.3984375
47 14249.400390625
48 13184.9384765625
49 12212.1953125
50 11321.65625
51 10505.9365234375
52 9757.7080078125
53 9070.4306640625
54 8438.4326171875
55 7856.96923828125
56 7321.16064453125
57 6826.7861328125
58 6370.48974609375
59 5949.0478515625
60 5559.10009765625
6

436 0.0003870144719257951
437 0.0003768503956962377
438 0.00036765806726180017
439 0.0003583166399039328
440 0.0003497657016851008
441 0.00034097046591341496
442 0.0003322034317534417
443 0.0003236506599932909
444 0.0003168244147673249
445 0.00030987319769337773
446 0.00030237872852012515
447 0.0002956635144073516
448 0.0002881543477997184
449 0.00028210689197294414
450 0.0002753455482888967
451 0.00026944425189867616
452 0.00026317170704714954
453 0.0002567792544141412
454 0.000250667188083753
455 0.00024484857567586005
456 0.0002397948846919462
457 0.00023482643882744014
458 0.0002295688318554312
459 0.00022460798209067434
460 0.00022010192333254963
461 0.00021541690512094647
462 0.00021068072237540036
463 0.0002068852772936225
464 0.0002021842956310138
465 0.0001978733198484406
466 0.00019349539070390165
467 0.00018920090224128217
468 0.00018536209245212376
469 0.0001816108706407249
470 0.0001779973244993016
471 0.00017426052363589406
472 0.00017092717462219298
473 0.000167612393852