In [3]:
%matplotlib inline


PyTorch: 流程控制和参数共享
--------------------------------------

为了展示PyTorch动态图的能力，我们这里会实现一个很奇怪的模型：这个全连接网络的隐层个数是1到4之间的随机数（每次前向计算时重新随机选取），
而且中间这些隐层的参数是共享的（它们复用同一个线性层）。



In [4]:
import random
import torch


class DynamicNet(torch.nn.Module):
    """Fully connected network whose depth is re-sampled on every forward pass.

    The network always applies an input layer and an output layer. Between
    them, one shared ``middle_linear`` layer is applied a random number of
    times, so every repetition reuses the same weights.
    """

    def __init__(self, D_in, H, D_out, max_middle_layers=3):
        """
        Construct the three nn.Linear instances used by the network.

        Args:
            D_in: size of each input sample.
            H: size of the hidden representation.
            D_out: size of each output sample.
            max_middle_layers: inclusive upper bound on how many times the
                shared middle layer is applied in one forward pass.

        Raises:
            ValueError: if ``max_middle_layers`` is negative.
        """
        super(DynamicNet, self).__init__()
        if max_middle_layers < 0:
            raise ValueError(
                "max_middle_layers must be >= 0, got {}".format(max_middle_layers)
            )
        self.max_middle_layers = max_middle_layers
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        Run the input layer, a random number of shared middle layers, and
        the output layer.

        The number of middle layers is drawn uniformly from
        0..max_middle_layers (both ends inclusive; 0..3 by default). Because
        the computation graph is built on the fly for every call, ordinary
        Python control flow such as a for loop is enough to express this.
        Calling the same Module several times is what shares its parameters.

        Args:
            x: input tensor whose last dimension is D_in.

        Returns:
            Tensor of predictions whose last dimension is D_out.
        """
        h_relu = self.input_linear(x).clamp(min=0)
        # random.randint includes both endpoints.
        num_middle = random.randint(0, self.max_middle_layers)
        for _ in range(num_middle):
            # Same module each time -> same weights, gradients accumulate.
            h_relu = self.middle_linear(h_relu).clamp(min=0)
        return self.output_linear(h_relu)

 
# N is the number of samples in the batch, D_in the input dimension,
# H the hidden dimension and D_out the output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Seed both generators so a fresh-kernel run is reproducible: `torch` drives
# the random data and weight initialisation, `random` drives the per-step
# depth chosen inside DynamicNet.forward.
random.seed(0)
torch.manual_seed(0)

# Random inputs and random regression targets.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

model = DynamicNet(D_in, H, D_out)

# Sum (rather than mean) of squared errors over all elements.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
for t in range(500):
    # Forward pass; the network depth is re-drawn on every call.
    y_pred = model(x)

    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Clear stale gradients, backpropagate, then update the parameters.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 667.2325439453125
1 693.5020141601562
2 661.39453125
3 657.8023071289062
4 629.8822021484375
5 519.1797485351562
6 653.5802612304688
7 642.0776977539062
8 637.75048828125
9 584.5928344726562
10 627.3572998046875
11 643.9644775390625
12 556.0167236328125
13 638.9312744140625
14 325.15899658203125
15 632.3028564453125
16 590.1881713867188
17 501.169677734375
18 483.6795959472656
19 609.3463745117188
20 438.00732421875
21 210.1587677001953
22 188.89439392089844
23 366.6983642578125
24 341.95196533203125
25 560.5984497070312
26 289.00738525390625
27 450.74407958984375
28 126.8094482421875
29 491.33477783203125
30 464.3042297363281
31 431.2166442871094
32 199.5053253173828
33 371.03472900390625
34 119.70623779296875
35 191.9884490966797
36 300.1385803222656
37 273.79559326171875
38 221.24205017089844
39 233.88214111328125
40 97.59255981445312
41 191.1996307373047
42 163.98350524902344
43 80.37187194824219
44 70.52588653564453
45 122.30460357666016
46 211.8968505859375
47 104.5218429565429

386 0.371745765209198
387 0.49127039313316345
388 0.5159963965415955
389 0.4905443489551544
390 0.3951423168182373
391 0.5412530899047852
392 0.4557349681854248
393 0.44451403617858887
394 0.39588019251823425
395 0.5645430088043213
396 0.2031405121088028
397 0.19988282024860382
398 0.16355444490909576
399 0.1183689758181572
400 0.08736192435026169
401 0.07636848092079163
402 0.5534483194351196
403 0.4135391116142273
404 0.410937637090683
405 0.3865385353565216
406 0.2663421928882599
407 0.25469735264778137
408 0.3201303780078888
409 0.2853842079639435
410 0.493245393037796
411 0.1930563747882843
412 0.16763263940811157
413 1.0437934398651123
414 0.11094801127910614
415 0.3258151710033417
416 0.11869552731513977
417 0.17231100797653198
418 0.6704407930374146
419 0.7999902963638306
420 0.12791769206523895
421 0.4975522458553314
422 0.6035558581352234
423 0.4231157600879669
424 0.40610167384147644
425 0.463609904050827
426 0.35453495383262634
427 0.46370208263397217
428 0.2505448460578918