In [2]:
# -*- coding: utf-8 -*-
import torch


class MyReLU(torch.autograd.Function):
    """
    A hand-written ReLU that carries its own derivative rule.

    Subclassing torch.autograd.Function and supplying static forward and
    backward methods, both of which work on Tensors, is how a custom
    differentiable operation is defined.
    """

    @staticmethod
    def forward(ctx, input):
        """
        Compute the output Tensor from the input Tensor.

        ctx is the context object shared with backward; the input is stashed
        on it through ctx.save_for_backward so that backward can read it
        again. The result is the input with every entry below zero raised
        to zero.
        """
        activated = input.clamp(min=0)
        ctx.save_for_backward(input)
        return activated

    @staticmethod
    def backward(ctx, grad_output):
        """
        Turn grad_output, the gradient of the loss with respect to the output,
        into the gradient of the loss with respect to the input.

        The incoming gradient is zeroed wherever the saved input was negative
        and passed through unchanged everywhere else.
        """
        (saved_input,) = ctx.saved_tensors
        # masked_fill builds a new Tensor, so grad_output itself is untouched.
        return grad_output.masked_fill(saved_input < 0, 0)


# Every Tensor below uses torch.float and lives on the CPU.
dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0")  # Swap in this line to run on the GPU

# Problem sizes: N is the batch size, D_in the input dimension,
# H the hidden dimension and D_out the output dimension.
N = 64
D_in = 1000
H = 100
D_out = 10

# Placement and precision shared by every Tensor created below.
tensor_opts = {"device": device, "dtype": dtype}

# Random input batch and random targets.
x = torch.randn(N, D_in, **tensor_opts)
y = torch.randn(N, D_out, **tensor_opts)

# Randomly initialised weights; requires_grad=True asks autograd to track
# operations on them so that their gradients can be computed.
w1 = torch.randn(D_in, H, requires_grad=True, **tensor_opts)
w2 = torch.randn(H, D_out, requires_grad=True, **tensor_opts)

# Step size for the gradient-descent weight updates.
learning_rate = 1e-6
# A custom autograd Function is invoked through its apply method. The alias
# does not depend on the loop variable, so it is bound once here rather than
# being rebound on every iteration.
relu = MyReLU.apply

for t in range(500):
    # Forward pass: two matrix products with our custom ReLU in between.
    y_pred = relu(x.mm(w1)).mm(w2)

    # Sum-of-squared-errors loss; print the iteration index and its value.
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())

    # Backward pass: autograd fills in w1.grad and w2.grad.
    loss.backward()

    # Gradient-descent step. no_grad keeps the in-place weight updates from
    # being tracked by autograd.
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

        # Gradients accumulate across backward() calls, so reset them by hand
        # before the next iteration.
        w1.grad.zero_()
        w2.grad.zero_()


0 33006048.0
1 31294934.0
2 30084180.0
3 25971114.0
4 18854720.0
5 11722679.0
6 6630625.0
7 3769277.25
8 2304013.5
9 1559606.375
10 1153827.5
11 908913.5625
12 744591.8125
13 624528.1875
14 531627.5625
15 457011.71875
16 395815.90625
17 344800.1875
18 301826.90625
19 265371.78125
20 234325.25
21 207703.28125
22 184707.359375
23 164743.46875
24 147339.3125
25 132109.75
26 118738.2890625
27 106949.796875
28 96536.0390625
29 87314.484375
30 79108.3359375
31 71794.484375
32 65258.140625
33 59402.09375
34 54144.59375
35 49414.80078125
36 45160.41015625
37 41335.625
38 37886.45703125
39 34761.140625
40 31926.40625
41 29348.408203125
42 27002.4140625
43 24863.8515625
44 22912.982421875
45 21131.5625
46 19503.837890625
47 18013.435546875
48 16647.6953125
49 15395.1259765625
50 14245.140625
51 13188.060546875
52 12216.1201171875
53 11322.134765625
54 10498.5419921875
55 9739.384765625
56 9039.2392578125
57 8393.486328125
58 7796.81298828125
59 7245.3349609375
60 6735.48681640625
61 6263.7470703