In [1]:
%matplotlib inline


PyTorch: Defining new autograd functions
----------------------------------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation computes the forward pass using operations on PyTorch
Variables, and uses PyTorch autograd to compute gradients.

Here we define our own custom autograd Function that computes the ReLU
nonlinearity, and use it in place of a built-in operation.



In [2]:
import torch
from torch.autograd import Variable


class MyReLU(torch.autograd.Function):
    """
    A hand-written ReLU expressed as a custom autograd Function.

    Subclassing torch.autograd.Function and providing static forward and
    backward methods lets us define both the computation and its gradient.
    The Function is invoked through MyReLU.apply(tensor).
    """

    @staticmethod
    def forward(ctx, input):
        """
        Compute the ReLU of `input`.

        Args:
            ctx: context object shared between forward and backward; the input
                Tensor is stored on it via ctx.save_for_backward so that
                backward can read it back from ctx.saved_tensors.
            input: Tensor to apply the nonlinearity to.

        Returns:
            A Tensor equal to `input` with every value clamped below at 0.
        """
        # Keep the pre-activation values: backward needs their sign.
        ctx.save_for_backward(input)
        return torch.clamp(input, min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Propagate the gradient of the loss through the ReLU.

        Args:
            ctx: the context object populated by forward.
            grad_output: gradient of the loss with respect to the output of
                forward.

        Returns:
            Gradient of the loss with respect to the input of forward: a copy
            of `grad_output` with zeros wherever the saved input was negative.
        """
        (saved_input,) = ctx.saved_tensors
        # masked_fill builds a new Tensor, so grad_output itself is left
        # untouched. Only strictly negative inputs block the gradient; an
        # input of exactly 0 lets it pass through.
        return grad_output.masked_fill(saved_input < 0, 0)


dtype = torch.FloatTensor
# dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold input and outputs, and wrap them in Variables.
# These are fixed data, so no gradient is requested for them.
x = Variable(torch.randn(N, D_in).type(dtype), requires_grad=False)
y = Variable(torch.randn(N, D_out).type(dtype), requires_grad=False)

# Create random Tensors for weights, and wrap them in Variables.
# requires_grad=True so that loss.backward() fills in w1.grad and w2.grad.
w1 = Variable(torch.randn(D_in, H).type(dtype), requires_grad=True)
w2 = Variable(torch.randn(H, D_out).type(dtype), requires_grad=True)

# To apply our Function, we use the Function.apply method. We alias this as
# 'relu'. The alias never changes, so it is bound once, outside the loop.
relu = MyReLU.apply

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y using operations on Variables; we compute
    # ReLU using our custom autograd operation.
    y_pred = relu(x.mm(w1)).mm(w2)

    # Compute and print loss. The loss is a scalar, so .item() is used to get
    # it as a Python number; indexing it with [0] fails on 0-dim tensors in
    # current PyTorch releases.
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())

    # Use autograd to compute the backward pass.
    loss.backward()

    # Update weights using gradient descent. Working on .data keeps the update
    # itself out of the autograd graph.
    w1.data -= learning_rate * w1.grad.data
    w2.data -= learning_rate * w2.grad.data

    # Manually zero the gradients after updating weights; backward()
    # accumulates into .grad rather than overwriting it.
    w1.grad.data.zero_()
    w2.grad.data.zero_()

0 31867192.0
1 27227074.0
2 25794266.0
3 23609850.0
4 19402774.0
5 13901236.0
6 8911919.0
7 5375584.5
8 3257291.25
9 2074466.5
10 1421328.125
11 1045540.4375
12 813898.9375
13 659212.5
14 548061.75
15 463634.9375
16 396671.9375
17 342065.15625
18 296833.84375
19 258782.765625
20 226495.5
21 198915.171875
22 175211.625
23 154749.03125
24 137010.671875
25 121591.171875
26 108148.6328125
27 96370.3125
28 86066.6171875
29 77004.8046875
30 69012.2109375
31 61960.94921875
32 55725.20703125
33 50191.078125
34 45267.80859375
35 40877.67578125
36 36957.859375
37 33452.7421875
38 30317.5859375
39 27505.119140625
40 24979.841796875
41 22708.080078125
42 20662.58984375
43 18817.2109375
44 17151.251953125
45 15649.208984375
46 14292.0615234375
47 13062.8427734375
48 11947.712890625
49 10935.822265625
50 10016.9775390625
51 9181.123046875
52 8420.3876953125
53 7728.90966796875
54 7099.77294921875
55 6526.20703125
56 6002.37939453125
57 5523.6552734375
58 5085.9677734375
59 4685.3544921875
60 4318.44

382 0.00011144750897074118
383 0.00010889928671531379
384 0.00010635921353241429
385 0.0001038106347550638
386 0.0001014333829516545
387 9.882374433800578e-05
388 9.671242878539488e-05
389 9.475171827943996e-05
390 9.235226025339216e-05
391 9.035936091095209e-05
392 8.829970465740189e-05
393 8.616990817245096e-05
394 8.431333844782785e-05
395 8.277186861960217e-05
396 8.112344949040562e-05
397 7.936661859275773e-05
398 7.770376396365464e-05
399 7.636867667315528e-05
400 7.460124470526353e-05
401 7.309101783903316e-05
402 7.15477071935311e-05
403 7.002995698712766e-05
404 6.851374200778082e-05
405 6.710779416607693e-05
406 6.616320024477318e-05
407 6.499900337075815e-05
408 6.374906661221758e-05
409 6.235305045265704e-05
410 6.145894440123811e-05
411 6.018417116138153e-05
412 5.914532084716484e-05
413 5.825841799378395e-05
414 5.714189319405705e-05
415 5.620365118375048e-05
416 5.5055632401490584e-05
417 5.388219142332673e-05
418 5.293450885801576e-05
419 5.21436522831209e-05
420 5.1219