In [15]:
# numpy version

import numpy as np

N, D_in, H, D_out = 64, 1000, 100, 10

x = np.random.randn(N, D_in)
y = np.random.randn(N, D_out)

w1 = np.random.randn(D_in, H)
w2 = np.random.randn(H, D_out)

learning_rate = 1e-6

for t in range(500):
    h = x.dot(w1)
    h_relu = np.maximum(h, 0)
    y_pred = h_relu.dot(w2)
    
    loss = np.square(y_pred - y).sum()
    print(t, loss)
    
    grad_y_pred = 2 * (y_pred - y)
    grad_w2 = h_relu.T.dot(grad_y_pred)
    grad_h_relu = grad_y_pred.dot(w2.T)
    grad_h = grad_h_relu.copy()
    grad_h[h < 0] = 0
    grad_w1 = x.T.dot(grad_h)
    
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2


0 36026851.75371924
1 33255284.559068296
2 32535812.30143962
3 28773373.555607535
4 21333569.16442948
5 13168478.228453849
6 7281314.606478389
7 3981889.7573746275
8 2358034.3314947654
9 1560014.7730299286
10 1140006.9488510676
11 891776.0575408603
12 727233.2189031846
13 607511.0671494878
14 515088.48566382437
15 441216.22265056917
16 380725.64467581816
17 330434.96872855606
18 288120.5754391176
19 252279.98365180465
20 221761.20254236914
21 195616.24525817262
22 173098.0820043576
23 153629.4703651574
24 136721.42131871622
25 121987.28096591536
26 109099.78136910601
27 97795.53377799135
28 87832.56332122174
29 79048.76474559892
30 71284.36741817987
31 64401.863303759805
32 58304.85517759934
33 52878.418093523615
34 48030.15279126602
35 43689.737091280724
36 39796.24485022797
37 36299.11586584238
38 33153.243177688244
39 30315.633014775594
40 27752.43151527014
41 25433.912590902706
42 23333.56236033306
43 21428.746675502458
44 19698.37234722608
45 18124.77130682864
46 16691.755794411
4

425 4.139313819863701e-05
426 3.953997438002541e-05
427 3.777006203337061e-05
428 3.607970433549256e-05
429 3.446518145760625e-05
430 3.292297479079137e-05
431 3.1449993575812245e-05
432 3.0043081816264487e-05
433 2.8699283530464532e-05
434 2.7415871769477723e-05
435 2.6189884901858565e-05
436 2.5018924640197592e-05
437 2.3900405598331984e-05
438 2.2831966513014132e-05
439 2.181141998023885e-05
440 2.0836637765377176e-05
441 1.9905489064861634e-05
442 1.9016029475273346e-05
443 1.8166449936645398e-05
444 1.7354959955990065e-05
445 1.6579770313515535e-05
446 1.583930449583496e-05
447 1.5131963788846231e-05
448 1.4456234267741597e-05
449 1.3810762537292212e-05
450 1.3194175070920864e-05
451 1.2605259182103546e-05
452 1.2042648715129965e-05
453 1.1505175140603483e-05
454 1.0991775647638016e-05
455 1.0501290895680055e-05
456 1.0032735906316782e-05
457 9.585144193316332e-06
458 9.157617835113452e-06
459 8.749153654219538e-06
460 8.358940484789509e-06
461 7.986159149680665e-06
462 7.63004870

In [6]:
# torch version 0
import torch

dtype = torch.float
device = torch.device("cuda:0")
# print(torch.cuda.get_device_name(0))

N, D_in, H, D_out = 64, 1000, 100, 10

x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

w1 = torch.randn(D_in, H, device=device, dtype=dtype)
w2 = torch.randn(H, D_out, device=device, dtype=dtype)

learning_rate = 1e-6

for t in range(500):
    h = x.mm(w1)
    h_relu = h.clamp(min=0)
    y_pred = h_relu.mm(w2)
    
    loss = (y_pred - y).pow(2).sum().item()
    print(t, loss)
    
    grad_y_pred = 2 * (y_pred - y)
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h_relu = grad_y_pred.mm(w2.t())
    grad_h = grad_h_relu.clone()
    grad_h[h < 0] = 0
    grad_w1 = x.t().mm(grad_h)
    
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2


GeForce GTX 1080 Ti
0 37252260.0
1 30495010.0
2 24724340.0
3 17973014.0
4 11712066.0
5 7122838.0
6 4356665.5
7 2813657.0
8 1961588.625
9 1463899.25
10 1149751.75
11 934587.375
12 776841.75
13 655463.125
14 558832.25
15 480350.03125
16 415561.21875
17 361499.84375
18 316071.34375
19 277626.53125
20 244829.328125
21 216681.578125
22 192586.96875
23 171776.21875
24 153663.859375
25 137823.9375
26 123933.1875
27 111710.7890625
28 100912.8984375
29 91348.515625
30 82852.015625
31 75283.0
32 68519.15625
33 62461.19921875
34 57025.05859375
35 52133.1484375
36 47723.21484375
37 43742.2265625
38 40144.00390625
39 36882.5078125
40 33920.87109375
41 31226.982421875
42 28774.474609375
43 26539.001953125
44 24498.431640625
45 22631.728515625
46 20923.869140625
47 19358.26171875
48 17922.2265625
49 16603.564453125
50 15391.201171875
51 14275.8291015625
52 13251.0498046875
53 12308.0478515625
54 11438.5703125
55 10635.8994140625
56 9894.1279296875
57 9208.654296875
58 8574.044921875
59 7988.90234375


438 0.00010157780343433842
439 9.981735638575628e-05
440 9.751777542987838e-05
441 9.584626241121441e-05
442 9.380749543197453e-05
443 9.1886380687356e-05
444 9.011435759020969e-05
445 8.828358841128647e-05
446 8.694266580278054e-05
447 8.539969712728634e-05
448 8.387487469008192e-05
449 8.200555748771876e-05
450 8.054218778852373e-05
451 7.896155875641853e-05
452 7.785906200297177e-05
453 7.639429531991482e-05
454 7.512605952797458e-05
455 7.388700032606721e-05
456 7.248899055412039e-05
457 7.132780592655763e-05
458 7.014444418018684e-05
459 6.898058200022206e-05
460 6.782348646083847e-05
461 6.659739301539958e-05
462 6.537822628160939e-05
463 6.460803706431761e-05
464 6.375138036673889e-05
465 6.263935938477516e-05
466 6.123733328422531e-05
467 6.0492973716463894e-05
468 5.965030504739843e-05
469 5.85495145060122e-05
470 5.783137021353468e-05
471 5.703520218958147e-05
472 5.6041029893094674e-05
473 5.5377862736349925e-05
474 5.4733551223762333e-05
475 5.3800566092832014e-05
476 5.290

In [7]:
# torch version with autograd
import torch

dtype = torch.float
device = torch.device("cuda:0")
# print(torch.cuda.get_device_name(0))

N, D_in, H, D_out = 64, 1000, 100, 10

x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-6

for t in range(500):
    h = x.mm(w1)
    h_relu = h.clamp(min=0)
    y_pred = h_relu.mm(w2)
    
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())
    
    loss.backward()
    
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad
        
        w1.grad.zero_()
        w2.grad.zero_()


0 30004208.0
1 24382408.0
2 22424054.0
3 20867300.0
4 18112156.0
5 14240100.0
6 10092217.0
7 6672046.0
8 4276254.0
9 2772914.0
10 1869940.375
11 1330267.125
12 998438.3125
13 784621.6875
14 638665.3125
15 533186.375
16 453190.9375
17 390163.90625
18 338966.21875
19 296468.9375
20 260658.984375
21 230235.796875
22 204091.359375
23 181463.203125
24 161776.046875
25 144582.421875
26 129526.8125
27 116293.8046875
28 104666.0625
29 94399.578125
30 85302.609375
31 77218.5859375
32 70015.8515625
33 63586.75390625
34 57836.2265625
35 52684.04296875
36 48055.87890625
37 43895.73046875
38 40150.34375
39 36773.44921875
40 33723.8359375
41 30964.447265625
42 28462.23046875
43 26193.376953125
44 24133.466796875
45 22262.2578125
46 20556.787109375
47 18999.732421875
48 17577.125
49 16275.5498046875
50 15086.6796875
51 13996.7705078125
52 12996.9169921875
53 12078.2177734375
54 11233.544921875
55 10456.599609375
56 9740.9775390625
57 9080.7587890625
58 8471.263671875
59 7908.384765625
60 7388.1660156

498 0.00024342755204997957
499 0.00023903213150333613


In [20]:
# torch version with autograd and nn
import torch

dtype = torch.float
device = torch.device("cuda:0")
# print(torch.cuda.get_device_name(0))

N, D_in, H, D_out = 64, 1000, 100, 10

x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out)
)

loss_fn = torch.nn.MSELoss(reduction='sum')

learning_rate = 1e-4

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for t in range(500):
    y_pred = model.to(device)(x)
    
    loss = loss_fn(y_pred, y)
#     print(loss.size())
    print(t, loss.item())
    optimizer.zero_grad()
    loss.backward()
    
    optimizer.step()

0 660.8055419921875
1 644.6014404296875
2 628.8951416015625
3 613.5344848632812
4 598.6131591796875
5 584.0795288085938
6 570.02197265625
7 556.4755859375
8 543.3504028320312
9 530.570556640625
10 518.106201171875
11 505.9735107421875
12 494.2369689941406
13 482.8907775878906
14 471.8343811035156
15 461.0474853515625
16 450.5661315917969
17 440.3694763183594
18 430.3622131347656
19 420.5999450683594
20 411.0989685058594
21 401.85870361328125
22 392.86444091796875
23 384.0890808105469
24 375.53680419921875
25 367.2034606933594
26 359.0461730957031
27 351.07958984375
28 343.27130126953125
29 335.6396484375
30 328.20599365234375
31 320.9803771972656
32 313.9300842285156
33 307.0394592285156
34 300.310791015625
35 293.762451171875
36 287.349853515625
37 281.05816650390625
38 274.9222717285156
39 268.9276123046875
40 263.0546875
41 257.2928466796875
42 251.64990234375
43 246.09974670410156
44 240.64620971679688
45 235.27810668945312
46 230.02008056640625
47 224.8814239501953
48 219.84661865

379 0.00027861641137860715
380 0.00026403836091049016
381 0.0002502004208508879
382 0.0002370573638472706
383 0.00022457409068010747
384 0.00021273409947752953
385 0.00020149486954323947
386 0.00019082297512795776
387 0.00018069971702061594
388 0.0001710915967123583
389 0.00016198451339732856
390 0.0001533411123091355
391 0.00014514068607240915
392 0.0001373663399135694
393 0.0001299881550949067
394 0.00012299635272938758
395 0.0001163654014817439
396 0.00011008107685483992
397 0.00010412433766759932
398 9.847782348515466e-05
399 9.312792826676741e-05
400 8.806020923657343e-05
401 8.32536825328134e-05
402 7.870498666306958e-05
403 7.439147884724662e-05
404 7.031393761280924e-05
405 6.644702079938725e-05
406 6.278751243371516e-05
407 5.931983105256222e-05
408 5.604184843832627e-05
409 5.2935778512619436e-05
410 4.999502198188566e-05
411 4.721618097391911e-05
412 4.4583615817828104e-05
413 4.209541293676011e-05
414 3.9741953514749184e-05
415 3.7513949791900814e-05
416 3.540548641467467e-