In [1]:
%matplotlib inline


PyTorch: Defining new autograd functions
----------------------------------------

A **fully-connected ReLU network** with **one hidden layer** and **no biases**, 

Trained to predict y from x by minimizing squared Euclidean distance.

This implementation computes the forward pass using operations on PyTorch Variables, 

And uses PyTorch autograd to compute gradients.

In this implementation we implement our own custom autograd function to perform the ReLU function.



In [2]:
import torch

In [4]:
# MyReLU is a child class of Function

# ***
# Function can be regarded as the Function Operation
# With Forward Method
# and Backward Method
# ***

class MyReLU(torch.autograd.Function):
    """
    We can implement our own custom autograd Functions 
    by subclassing `torch.autograd.Function` 
    and implementing the forward and backward passes
    which operate on Tensors.
    """

    # staticmethod do not need to take parameter: self
    
    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass 
        We receive a Tensor containing the input 
        And return a Tensor containing the output. 
        
        `ctx` is a context object that can be used to stash information 
        for backward computation. 
        You can cache arbitrary objects for use in the backward pass 
        using the ctx.save_for_backward method.
        """
        # ctx is context
        # the method .save_for_backward
        # is acutally from Function
        ctx.save_for_backward(input)
        # operate ReLU Operation
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass 
        We receive a Tensor containing the gradient of the loss
        with respect to the output, 
        And we need to compute the gradient of the loss
        with respect to the input.
        """
        # This is just a step of ReLU
        # d(ReLu(x)) = 1 if x >= 0
        
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input

In [6]:
dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold input and outputs.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Create random Tensors for weights.
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)


In [7]:

learning_rate = 1e-6
for t in range(500):
    # To apply our Function, we use Function.apply method. We alias this as 'relu'.
    relu = MyReLU.apply

    # Forward pass: compute predicted y using operations; we compute
    # ReLU using our custom autograd operation.
    
    # Here x.mm(w1) is the input
    y_pred = relu(x.mm(w1)).mm(w2)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())

    # Use autograd to compute the backward pass.
    loss.backward()

    # Update weights using gradient descent
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

        # Manually zero the gradients after updating weights
        w1.grad.zero_()
        w2.grad.zero_()

0 22418600.0
1 16678010.0
2 14591717.0
3 14099033.0
4 14131016.0
5 13886398.0
6 12943339.0
7 11197784.0
8 8978464.0
9 6718944.5
10 4786429.0
11 3311236.25
12 2274950.25
13 1578423.5
14 1120292.75
15 819530.125
16 619658.75
17 484166.75
18 389663.59375
19 321647.1875
20 271023.0
21 232176.4375
22 201498.84375
23 176675.640625
24 156143.890625
25 138853.65625
26 124104.3203125
27 111371.0078125
28 100273.2265625
29 90539.59375
30 81953.6015625
31 74335.875
32 67565.6953125
33 61519.14453125
34 56102.29296875
35 51236.6640625
36 46856.88671875
37 42907.890625
38 39340.89453125
39 36108.5703125
40 33177.359375
41 30516.291015625
42 28094.908203125
43 25890.908203125
44 23880.337890625
45 22044.197265625
46 20364.58984375
47 18827.078125
48 17418.388671875
49 16126.783203125
50 14941.0888671875
51 13851.0380859375
52 12848.49609375
53 11925.3759765625
54 11075.6083984375
55 10293.1923828125
56 9573.5419921875
57 8909.46484375
58 8295.40234375
59 7727.3671875
60 7201.416015625
61 6713.895996

374 0.0007555692573077977
375 0.0007296799449250102
376 0.0007056138711050153
377 0.0006819795817136765
378 0.0006615443853661418
379 0.0006406618049368262
380 0.0006192965665832162
381 0.0006002160371281207
382 0.0005818617646582425
383 0.0005641267052851617
384 0.0005468495655804873
385 0.0005308804102241993
386 0.0005147046176716685
387 0.0004991893656551838
388 0.00048419839004054666
389 0.0004693936789408326
390 0.000456712645245716
391 0.00044297732529230416
392 0.00043103029020130634
393 0.00041796392179094255
394 0.0004064012027811259
395 0.00039509052294306457
396 0.0003832665388472378
397 0.00037309995968826115
398 0.00036284918314777315
399 0.00035287835635244846
400 0.0003438949934206903
401 0.0003343108110129833
402 0.0003250027948524803
403 0.0003160472260788083
404 0.00030731240985915065
405 0.0002995036484207958
406 0.0002924600266851485
407 0.0002841871464625001
408 0.00027804626733995974
409 0.00027078884886577725
410 0.0002635517157614231
411 0.000256700674071908
412