In [1]:
%matplotlib inline


复习numpy
--------------

我们这里实现一个全连接的激活为ReLU的网络，它只有一个隐层，没有bias，用于回归预测一个值，loss是计算实际值和预测值的欧氏距离。

我们这里完全使用numpy手动的进行前向和后向计算。

numpy数组就是一个n维的数值，它并不知道任何关于深度学习、梯度下降或者计算图的东西，它只是进行数值运算。




In [2]:
import numpy as np

# N是batch size；D_in是输入大小
# H是隐层的大小；D_out是输出大小。
N, D_in, H, D_out = 64, 1000, 100, 10

# 随机产生输入与输出
x = np.random.randn(N, D_in)
y = np.random.randn(N, D_out)

# 随机初始化参数
w1 = np.random.randn(D_in, H)
w2 = np.random.randn(H, D_out)

learning_rate = 1e-6
for t in range(500):
    # 前向计算y
    h = x.dot(w1)
    h_relu = np.maximum(h, 0)
    y_pred = h_relu.dot(w2)

    # 计算loss
    loss = np.square(y_pred - y).sum()
    print(t, loss)

    # 反向计算梯度 
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.T.dot(grad_y_pred)
    grad_h_relu = grad_y_pred.dot(w2.T)
    grad_h = grad_h_relu.copy()
    grad_h[h < 0] = 0
    grad_w1 = x.T.dot(grad_h)

    # 更新参数
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

0 33646753.85632531
1 27453271.08056476
2 24080728.07197475
3 20012098.00318867
4 15020097.40268158
5 10091813.961896304
6 6341133.404644929
7 3910062.1217965647
8 2488779.0392809846
9 1676342.0555654871
10 1205096.8024708242
11 915995.4516851017
12 726284.396605152
13 593105.7328345232
14 494089.0852438557
15 417139.10831823596
16 355578.76236067386
17 305310.7595813564
18 263623.60571311344
19 228668.0907029406
20 199128.49472756893
21 174049.81959830096
22 152653.4961341839
23 134268.9237992906
24 118418.86725722818
25 104720.75201967722
26 92834.60862786917
27 82481.04671131056
28 73439.96087784009
29 65528.99758426403
30 58583.28115482667
31 52465.886487346346
32 47063.94844026916
33 42284.91975667222
34 38047.84042800703
35 34285.54548707235
36 30937.511824526584
37 27954.474727795743
38 25296.420806012582
39 22919.451085092278
40 20789.087979882876
41 18877.836070129975
42 17160.75784321333
43 15615.917956880401
44 14223.92132894742
45 12967.805813690054
46 11834.965045754612
47

428 1.451992367164433e-06
429 1.3779210821161447e-06
430 1.3076173662682819e-06
431 1.240924839728358e-06
432 1.1776450252452056e-06
433 1.117618407562209e-06
434 1.0606845366603025e-06
435 1.0066465953464284e-06
436 9.553852296428742e-07
437 9.06760319941981e-07
438 8.605999561195762e-07
439 8.168207355867012e-07
440 7.752650309969766e-07
441 7.358335325904279e-07
442 6.984182556189208e-07
443 6.629139398230205e-07
444 6.292344437323959e-07
445 5.972736160506238e-07
446 5.669366937675013e-07
447 5.381468211347913e-07
448 5.108301346590006e-07
449 4.849100521774883e-07
450 4.60313435199732e-07
451 4.3696401610213727e-07
452 4.1480617322114424e-07
453 3.9377687780501087e-07
454 3.7382564511017e-07
455 3.5488954234910913e-07
456 3.3691317307948826e-07
457 3.1985133777856254e-07
458 3.0365884476756654e-07
459 2.8829064135645517e-07
460 2.737087106467845e-07
461 2.598621944334112e-07
462 2.4672072982228823e-07
463 2.3424726777596949e-07
464 2.224080388842777e-07
465 2.1117328135298994e-07
