In [None]:
%matplotlib inline


PyTorch: Defining New autograd Functions
----------------------------------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation computes the forward pass using operations on PyTorch
Variables, and uses PyTorch autograd to compute gradients.

In this implementation we implement our own custom autograd function to perform
the ReLU function.



In [1]:
import torch


class MyReLU(torch.autograd.Function):
    """
    We can implement our own custom autograd Functions by subclassing
    torch.autograd.Function and implementing the forward and backward passes
    which operate on Tensors.
    """

    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass we receive a Tensor containing the input and return
        a Tensor containing the output. ctx is a context object that can be used
        to stash information for backward computation. You can cache arbitrary
        objects for use in the backward pass using the ctx.save_for_backward method.
        """
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input.
        """
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input


dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold input and outputs.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Create random Tensors for weights.
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-6
for t in range(500):
    # To apply our Function, we use Function.apply method. We alias this as 'relu'.
    relu = MyReLU.apply

    # Forward pass: compute predicted y using operations; we compute
    # ReLU using our custom autograd operation.
    y_pred = relu(x.mm(w1)).mm(w2)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())

    # Use autograd to compute the backward pass.
    loss.backward()

    # Update weights using gradient descent
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

        # Manually zero the gradients after updating weights
        w1.grad.zero_()
        w2.grad.zero_()

0 23077336.0
1 15462882.0
2 12008316.0
3 10520315.0
4 10063504.0
5 10057845.0
6 10123327.0
7 9912577.0
8 9292999.0
9 8204142.5
10 6839601.5
11 5387828.0
12 4070945.75
13 2976114.5
14 2140889.25
15 1527435.375
16 1094436.5
17 792098.4375
18 583272.5625
19 438168.9375
20 336767.09375
21 264796.5
22 212923.90625
23 174681.796875
24 145904.34375
25 123716.390625
26 106234.546875
27 92185.9453125
28 80676.2421875
29 71084.84375
30 62979.515625
31 56051.34765625
32 50093.1796875
33 44907.4296875
34 40362.98046875
35 36362.06640625
36 32818.2421875
37 29668.8515625
38 26860.486328125
39 24351.08203125
40 22101.607421875
41 20081.62109375
42 18264.587890625
43 16627.146484375
44 15150.0029296875
45 13817.853515625
46 12613.0634765625
47 11521.822265625
48 10532.3935546875
49 9635.2763671875
50 8822.1689453125
51 8082.70068359375
52 7409.8896484375
53 6797.228515625
54 6241.36376953125
55 5734.130859375
56 5270.99560546875
57 4847.6015625
58 4460.6767578125
59 4106.7314453125
60 3782.4565429687

378 6.610759737668559e-05
379 6.489635416073725e-05
380 6.372555799316615e-05
381 6.238464266061783e-05
382 6.114694406278431e-05
383 5.9978050558129326e-05
384 5.8597630413714796e-05
385 5.787896952824667e-05
386 5.679362948285416e-05
387 5.563180457102135e-05
388 5.460943066282198e-05
389 5.373091698857024e-05
390 5.275693547446281e-05
391 5.184466499486007e-05
392 5.0878647016361356e-05
393 5.004566992283799e-05
394 4.9028290959540755e-05
395 4.834568972000852e-05
396 4.76756613352336e-05
397 4.669433837989345e-05
398 4.61639319837559e-05
399 4.529565921984613e-05
400 4.4457949115894735e-05
401 4.387640001368709e-05
402 4.298028943594545e-05
403 4.233920481055975e-05
404 4.146279752603732e-05
405 4.1074436012422666e-05
406 4.034720768686384e-05
407 3.9803417166695e-05
408 3.9143033063737676e-05
409 3.858171112369746e-05
410 3.8166828744579107e-05
411 3.7507557863136753e-05
412 3.690965240821242e-05
413 3.6393015761859715e-05
414 3.591132190194912e-05
415 3.549364919308573e-05
416 3.