In [1]:
%matplotlib inline


PyTorch: 定义支持autograd的自定义函数
----------------------------------------

这里仍然使用前面那个全连接网络的例子，不过我们不再直接在网络里调用clamp来实现ReLU，而是自己定义一个支持autograd的MyReLU函数（它的forward内部仍然使用clamp，backward则由我们手动实现）。




In [2]:
import torch


class MyReLU(torch.autograd.Function):
    """
    ReLU written as a custom autograd function.

    A custom autograd operation is created by subclassing
    torch.autograd.Function and supplying static forward and backward
    methods. It is invoked through MyReLU.apply.
    """

    @staticmethod
    def forward(ctx, input):
        """
        Compute the element-wise ReLU of input.

        input is stashed on ctx via save_for_backward so that backward can
        later tell which elements were negative.
        """
        ctx.save_for_backward(input)
        return torch.clamp(input, min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Return the gradient with respect to input.

        grad_output is the incoming gradient for the value returned by
        forward. It is passed through unchanged wherever the saved input is
        not negative and replaced by zero wherever the saved input is
        negative.
        """
        (saved_input,) = ctx.saved_tensors
        negative = saved_input < 0
        # masked_fill is out-of-place, so grad_output itself is never modified.
        return grad_output.masked_fill(negative, 0)


dtype = torch.float
device = torch.device("cpu")

# N: number of rows in x and y; D_in: input width; H: hidden width;
# D_out: output width.
N, D_in, H, D_out = 64, 1000, 100, 10

# Random inputs and targets.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Random initial weights; requires_grad=True so that loss.backward() fills in
# w1.grad and w2.grad.
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

# A custom autograd Function is called through its apply method. The binding
# does not change between iterations, so it is created once, outside the loop.
relu = MyReLU.apply

learning_rate = 1e-6
for t in range(500):
    # Forward pass: two matrix multiplies with the custom ReLU in between.
    y_pred = relu(x.mm(w1)).mm(w2)

    # Sum of squared errors between the prediction and the target.
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())

    # Backward pass: computes w1.grad and w2.grad.
    loss.backward()

    # The update is done under no_grad so that it is not tracked by autograd.
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

        # Gradients accumulate across backward() calls, so reset them here.
        w1.grad.zero_()
        w2.grad.zero_()

0 31219966.0
1 29228598.0
2 31855032.0
3 33837796.0
4 31309610.0
5 23219566.0
6 14017930.0
7 7351454.5
8 3824305.25
9 2181492.5
10 1429416.125
11 1053271.25
12 838279.4375
13 696292.375
14 591994.9375
15 509986.34375
16 443016.09375
17 387132.71875
18 339843.0
19 299528.5
20 264898.1875
21 235043.203125
22 209189.78125
23 186678.046875
24 167006.03125
25 149769.140625
26 134602.125
27 121227.1484375
28 109402.171875
29 98908.84375
30 89572.1796875
31 81244.7421875
32 73807.5859375
33 67140.8203125
34 61165.69140625
35 55794.71484375
36 50951.6796875
37 46584.84375
38 42636.40234375
39 39065.08984375
40 35830.0703125
41 32897.5859375
42 30235.6875
43 27815.34375
44 25608.77734375
45 23597.3984375
46 21764.626953125
47 20090.896484375
48 18560.94921875
49 17160.408203125
50 15876.46484375
51 14699.072265625
52 13618.0888671875
53 12626.189453125
54 11713.9921875
55 10874.2978515625
56 10101.07421875
57 9388.1953125
58 8730.314453125
59 8122.8173828125
60 7561.619140625
61 7042.642578125


475 9.529116505291313e-05
476 9.362120908917859e-05
477 9.179205517284572e-05
478 9.045530896401033e-05
479 8.876148058334365e-05
480 8.733809227123857e-05
481 8.590223296778277e-05
482 8.421840175287798e-05
483 8.267716475529596e-05
484 8.127454202622175e-05
485 7.992322207428515e-05
486 7.877988537074998e-05
487 7.754105899948627e-05
488 7.635498332092538e-05
489 7.502910011680797e-05
490 7.355350680882111e-05
491 7.20532116247341e-05
492 7.087919220793992e-05
493 6.966372893657535e-05
494 6.870318611618131e-05
495 6.756203947588801e-05
496 6.656658661086112e-05
497 6.57599521218799e-05
498 6.4848420151975e-05
499 6.390204362105578e-05
