In [1]:
# Render any matplotlib figures inline in the notebook output.
# NOTE(review): no plotting is visible in this section; confirm whether a later cell needs this magic.
%matplotlib inline


PyTorch: nn模块
-----------


我们接下来使用nn模块来实现这个简单的全连接网络。前面我们通过Tensor和Operation等low-level API来手动创建
动态的计算图，这里我们改用更简单的high-level API。




In [2]:
import torch
print(torch.__version__)

# Fix the RNG seed so that the random data, the initial weights, and therefore
# the printed loss curve are reproducible under "Restart & Run All".
torch.manual_seed(0)

# N is the batch size; D_in is the input dimension;
# H is the hidden-layer size; D_out is the output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Random tensors used as the inputs and the regression targets.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Define the network with the nn package. nn.Sequential is a Module that
# contains other Modules and applies them in order. Each Linear module computes
# a linear function of its input and internally creates its own weight and bias.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

# Common loss functions are provided by nn, so we don't implement them ourselves.
# reduction='sum' sums the squared errors over all elements. It replaces the
# deprecated size_average=False argument and computes the same value.
loss_fn = torch.nn.MSELoss(reduction='sum')

learning_rate = 1e-4
for t in range(500):
    # Forward pass: compute y_pred from x. Module defines __call__, so the
    # model can be invoked like a function.
    y_pred = model(x)

    # Compute and print the loss; .item() extracts the Python scalar.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Clear the gradients of every parameter inside the Sequential model.
    # Without this, backward() would accumulate onto the previous step's gradients.
    model.zero_grad()

    # Backward pass: compute the gradient of the loss with respect to every
    # parameter registered in the model.
    loss.backward()

    # Plain gradient-descent update. All parameters are in model.parameters().
    # no_grad() keeps the in-place update out of the autograd graph.
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad

0.5.0a0+a24163a
0 665.8291015625
1 614.0447998046875
2 569.7244262695312
3 530.7206420898438
4 495.89105224609375
5 464.5169982910156
6 436.1329650878906
7 410.459228515625
8 386.9529724121094
9 365.2677307128906
10 344.98291015625
11 325.9330749511719
12 308.0951843261719
13 291.2874755859375
14 275.3282165527344
15 260.17352294921875
16 245.74899291992188
17 231.9853515625
18 218.92092895507812
19 206.46051025390625
20 194.60479736328125
21 183.24606323242188
22 172.43678283691406
23 162.18336486816406
24 152.42430114746094
25 143.16773986816406
26 134.38656616210938
27 126.10890197753906
28 118.2713623046875
29 110.8513412475586
30 103.85079956054688
31 97.25940704345703




32 91.06820678710938
33 85.24884796142578
34 79.78056335449219
35 74.64293670654297
36 69.823974609375
37 65.31278991699219
38 61.07954025268555
39 57.120540618896484
40 53.41852569580078
41 49.95926284790039
42 46.72844314575195
43 43.7166748046875
44 40.90652847290039
45 38.291927337646484
46 35.860286712646484
47 33.591575622558594
48 31.47673225402832
49 29.512805938720703
50 27.68549346923828
51 25.98240852355957
52 24.395750045776367
53 22.916954040527344
54 21.536808013916016
55 20.245407104492188
56 19.041471481323242
57 17.920501708984375
58 16.871746063232422
59 15.892293930053711
60 14.976529121398926
61 14.120612144470215
62 13.319745063781738
63 12.570850372314453
64 11.870187759399414
65 11.21413516998291
66 10.5996732711792
67 10.024039268493652
68 9.484028816223145
69 8.977757453918457
70 8.502097129821777
71 8.053671836853027
72 7.6325201988220215
73 7.236522197723389
74 6.864195346832275
75 6.513209342956543
76 6.182273864746094
77 5.870550632476807
78 5.5760774612426

425 3.1406325433636084e-05
426 3.0463617804343812e-05
427 2.954726187454071e-05
428 2.865969690901693e-05
429 2.7798820156021975e-05
430 2.6964484277414158e-05
431 2.615593621158041e-05
432 2.5366945919813588e-05
433 2.460663927195128e-05
434 2.3869777578511275e-05
435 2.3153837901190855e-05
436 2.2457608793047257e-05
437 2.178503382310737e-05
438 2.113391565217171e-05
439 2.0499797756201588e-05
440 1.9884930225089192e-05
441 1.928942401718814e-05
442 1.8710994481807575e-05
443 1.8150161849916913e-05
444 1.7608286725590006e-05
445 1.7079788449336775e-05
446 1.6568703358643688e-05
447 1.6071722711785696e-05
448 1.5590912880725227e-05
449 1.5125535355764441e-05
450 1.4674023987026885e-05
451 1.4234493391995784e-05
452 1.3808922631142195e-05
453 1.3395805581239983e-05
454 1.2994831195101142e-05
455 1.2606441487150732e-05
456 1.2229614185343962e-05
457 1.1865588930959348e-05
458 1.1510438525874633e-05
459 1.1165910109411925e-05
460 1.083310507965507e-05
461 1.050994706019992e-05
462 1.0196