# Un exemple très simple : Implémentation complète d’un réseau à 2 couches

In [2]:
# imports
import numpy as np
from numpy.random import randn

In [3]:
# Problem dimensions plus reproducible random data and initial weights
SEED = 0  # fixed seed so a fresh kernel re-run draws the same data and weights
np.random.seed(SEED)  # randn draws from NumPy's legacy global generator

N = 64 # number of examples
D_in = 1000 # number of input features
H = 100 # number of hidden units in the first layer
D_out = 10 # number of outputs

# Random inputs and regression targets, one row per example
x, y = randn(N, D_in), randn(N, D_out)
# Standard-normal initial weights: input -> hidden and hidden -> output
w1, w2 = randn(D_in, H), randn(H, D_out)

In [4]:
# Fit the two-layer network with full-batch gradient descent
learning_rate = 1e-4  # step size applied to both weight matrices
for t in range(2000):

    # Forward pass: sigmoid hidden layer followed by a linear output layer
    h = 1 / (1 + np.exp(-np.dot(x, w1)))
    y_pred = np.dot(h, w2)

    # Sum-of-squares loss over every example and output
    residual = y_pred - y
    loss = np.square(residual).sum()
    print(t,loss)

    # Backward pass: chain rule from the loss back to each weight matrix
    grad_y_pred = 2 * residual
    grad_w2 = np.dot(h.T, grad_y_pred)
    grad_h = np.dot(grad_y_pred, w2.T)
    # h * (1 - h) is the derivative of the sigmoid with respect to its input
    grad_w1 = np.dot(x.T, grad_h * h * (1 - h))

    # In-place update: w1 and w2 keep their trained values after this cell,
    # so re-running the cell continues training instead of restarting it
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2
    

0 31042.536455843943
1 20716.743109497824
2 16129.478258889983
3 13855.089148288698
4 12597.88837492864
5 11771.8565009951
6 11180.670186846162
7 10720.990144982836
8 10324.33257901821
9 9967.371744807557
10 9631.5394538139
11 9299.587169788054
12 8977.241229998595
13 8703.316511979152
14 8465.397457892734
15 8253.652180846511
16 8059.494998140369
17 7879.238474982014
18 7707.468067728849
19 7540.395069912438
20 7376.703452693107
21 7215.809105338596
22 7066.470678846802
23 6925.314098747896
24 6785.333747050594
25 6650.970037709974
26 6524.161687203425
27 6402.713929194771
28 6287.179058437645
29 6176.211767805769
30 6068.76341792221
31 5966.244756014228
32 5869.245084933778
33 5776.899135290428
34 5688.140366691627
35 5601.858073070453
36 5517.218523414408
37 5433.798028599771
38 5351.415580941581
39 5269.657915268468
40 5188.0208051013215
41 5107.067645728459
42 5028.79503302651
43 4953.601078665239
44 4880.265235533487
45 4807.799241087127
46 4735.871288324018
47 4665.844671228306


463 231.3710189086602
464 230.45796600301887
465 229.54965461887372
466 228.64604175708365
467 227.747085183874
468 226.85274351185572
469 225.96297629160102
470 225.07774411104106
471 224.19700869911844
472 223.32073302927688
473 222.44888141752827
474 221.581419609049
475 220.71831484655763
476 219.8595359132002
477 219.00505314234707
478 218.15483838669348
479 217.30886493940042
480 216.46710740078228
481 215.62954148529477
482 214.79614376531163
483 213.96689135038645
484 213.1417615033203
485 212.3207311972808
486 211.50377662128224
487 210.69087264435393
488 209.88199225143669
489 209.07710596625418
490 208.27618127786127
491 207.4791820881207
492 206.6860681969005
493 205.89679484030893
494 205.1113122948837
495 204.32956555751883
496 203.55149410731093
497 202.77703175175606
498 202.00610655617018
499 201.23864085216186
500 200.47455131873483
501 199.7137491283426
502 198.95614015009602
503 198.2016252033658
504 197.4501003572165
505 196.7014572743485
506 195.95558360238675
507

932 50.13863719550345
933 50.00733570476418
934 49.87643828533133
935 49.745943633533244
936 49.615850427482066
937 49.4861573253615
938 49.35686296406775
939 49.227965958221745
940 49.099464899565625
941 48.97135835675435
942 48.843644875547184
943 48.71632297940009
944 48.58939117045669
945 48.462847930928305
946 48.33669172485206
947 48.21092100021036
948 48.08553419139052
949 47.96052972196144
950 47.83590600773876
951 47.71166146010815
952 47.58779448957405
953 47.46430350949799
954 47.341186939989214
955 47.218443211910504
956 47.09607077095933
957 46.974068081785816
958 46.852433632108564
959 46.73116593679005
960 46.6102635418348
961 46.48972502827468
962 46.36954901590745
963 46.24973416685694
964 46.1302791889256
965 46.011182838712955
966 45.89244392447533
967 45.77406130870675
968 45.656033910421776
969 45.53836070712619
970 45.42104073646311
971 45.3040730975252
972 45.18745695182736
973 45.07119152393672
974 44.95527610175909
975 44.839710036484675
976 44.72449274219764
9

1385 17.5396495868028
1386 17.502945888570125
1387 17.466332966989434
1388 17.429810541351372
1389 17.39337833240579
1390 17.357036062358688
1391 17.320783454868867
1392 17.284620235044223
1393 17.248546129437642
1394 17.212560866042782
1395 17.176664174289403
1396 17.14085578503827
1397 17.10513543057594
1398 17.069502844609207
1399 17.033957762259025
1400 16.998499920054186
1401 16.96312905592498
1402 16.927844909196125
1403 16.892647220579452
1404 16.857535732166625
1405 16.822510187421177
1406 16.787570331170336
1407 16.752715909596656
1408 16.71794667022934
1409 16.683262361935064
1410 16.648662734908843
1411 16.614147540664433
1412 16.57971653202434
1413 16.54536946310992
1414 16.51110608933082
1415 16.476926167374486
1416 16.44282945519521
1417 16.40881571200312
1418 16.374884698252686
1419 16.341036175631263
1420 16.30726990704739
1421 16.27358565661861
1422 16.239983189659476
1423 16.206462272669068
1424 16.173022673318446
1425 16.139664160437917
1426 16.10638650400425
1427 16

1825 7.402170261119686
1826 7.388090365413612
1827 7.3740358747899934
1828 7.3600067574891375
1829 7.346002983047983
1830 7.3320245223007685
1831 7.318071347378844
1832 7.30414343170942
1833 7.290240750013382
1834 7.276363278302071
1835 7.262510993872963
1836 7.248683875304297
1837 7.234881902448603
1838 7.221105056425042
1839 7.207353319610662
1840 7.193626675630441
1841 7.179925109346165
1842 7.166248606843982
1843 7.152597155420935
1844 7.138970743570079
1845 7.125369360964334
1846 7.111792998439309
1847 7.0982416479745885
1848 7.084715302673851
1849 7.0712139567438115
1850 7.05773760547172
1851 7.044286245201776
1852 7.030859873310112
1853 7.017458488178727
1854 7.004082089167919
1855 6.99073067658785
1856 6.977404251668546
1857 6.964102816529042
1858 6.9508263741451275
1859 6.93757492831615
1860 6.9243484836306575
1861 6.911147045430962
1862 6.897970619776742
1863 6.884819213407668
1864 6.87169283370506
1865 6.8585914886527455
1866 6.845515186796972
1867 6.832463937205708
1868 6.8

In [14]:
# Recompute the hidden-layer activations with the current weights and show their shape
pre_activation = np.dot(x, w1)
h = 1 / (1 + np.exp(-pre_activation))
h.shape

(64, 100)