In [1]:
%matplotlib inline


PyTorch的Tensor
----------------

和前面一样，我们还是实现一个全连接的Relu激活的网络，它只有一个隐层并且没有bias。loss是预测与真实值的欧氏距离。


之前我们用Numpy实现，自己手动前向计算loss，反向计算梯度。这里还是一样，只不过把numpy数组换成了PyTorch的Tensor。

但是使用PyTorch的好处是我们可以利用GPU来加速计算，如果想用GPU计算，我们值需要在创建tensor的时候指定device为gpu。




In [2]:
import torch


dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # 如果想在GPU上运算，把这行注释掉。
 
N, D_in, H, D_out = 64, 1000, 100, 10
 
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)
 
w1 = torch.randn(D_in, H, device=device, dtype=dtype)
w2 = torch.randn(H, D_out, device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(500): 
    h = x.mm(w1)
    h_relu = h.clamp(min=0)
    y_pred = h_relu.mm(w2)
 
    loss = (y_pred - y).pow(2).sum().item()
    print(t, loss)
 
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h_relu = grad_y_pred.mm(w2.t())
    grad_h = grad_h_relu.clone()
    grad_h[h < 0] = 0
    grad_w1 = x.t().mm(grad_h)
 
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

0 27160548.0
1 23449182.0
2 26120808.0
3 32149890.0
4 37190440.0
5 35832544.0
6 26425754.0
7 14859924.0
8 6947733.0
9 3219258.5
10 1716543.5
11 1111947.25
12 835275.8125
13 681324.5
14 578198.625
15 500179.78125
16 437220.53125
17 384775.625
18 340300.59375
19 302186.21875
20 269241.375
21 240663.921875
22 215757.671875
23 193945.390625
24 174762.15625
25 157839.625
26 142858.71875
27 129555.2890625
28 117708.453125
29 107117.5234375
30 97642.5078125
31 89144.1484375
32 81506.3515625
33 74628.2109375
34 68425.6328125
35 62819.69140625
36 57745.23046875
37 53145.54296875
38 48966.77734375
39 45171.5234375
40 41716.015625
41 38566.4140625
42 35691.43359375
43 33060.91796875
44 30652.283203125
45 28444.80078125
46 26417.853515625
47 24555.509765625
48 22842.537109375
49 21265.677734375
50 19814.068359375
51 18478.091796875
52 17244.171875
53 16104.5859375
54 15050.564453125
55 14074.8291015625
56 13172.9384765625
57 12337.578125
58 11562.255859375
59 10842.37109375
60 10173.2822265625
61 

432 0.00935486238449812
433 0.00907206628471613
434 0.008797461166977882
435 0.008530347608029842
436 0.008267566561698914
437 0.008022594265639782
438 0.007777847815304995
439 0.007544893771409988
440 0.007315889932215214
441 0.007104233838617802
442 0.006887681782245636
443 0.006679466459900141
444 0.006482238415628672
445 0.006291826721280813
446 0.006104573607444763
447 0.005923294462263584
448 0.005745504982769489
449 0.005578520707786083
450 0.005415928550064564
451 0.005256837699562311
452 0.005104752257466316
453 0.004957721568644047
454 0.004813564009964466
455 0.004674824420362711
456 0.00453700078651309
457 0.004409389570355415
458 0.004281625151634216
459 0.004159301985055208
460 0.00403966661542654
461 0.003928312100470066
462 0.0038173189386725426
463 0.003707451745867729
464 0.003601861884817481
465 0.003503216663375497
466 0.003407241078093648
467 0.0033099823631346226
468 0.003218397730961442
469 0.003128425218164921
470 0.0030428029131144285
471 0.00296075944788754
47