In [1]:
import numpy as np

# Problem sizes: batch size, input dimension, hidden dimension, output dimension.
N = 64
D_in = 1000
H = 100
D_out = 10

In [2]:
# Random inputs (N x D_in) and random regression targets (N x D_out).
# Tuple elements are evaluated left to right, so the draw order is x then y.
x, y = np.random.randn(N, D_in), np.random.randn(N, D_out)

# Randomly initialised weights: input -> hidden (w1), hidden -> output (w2).
w1, w2 = np.random.randn(D_in, H), np.random.randn(H, D_out)

# Step size for plain gradient descent.
l_rate = 1e-6

In [4]:
# Train the two-layer ReLU network for 500 steps using hand-derived gradients.
for t in range(500):
    # Forward pass: linear layer, ReLU, linear layer
    h = np.dot(x, w1)
    h_relu = np.maximum(h, 0)
    y_pred = np.dot(h_relu, w2)

    # Sum-of-squares loss, reported on every step
    loss = np.sum(np.square(y_pred - y))
    print(t, loss)

    # Backward pass: chain rule applied from the output back to w1
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = np.dot(h_relu.T, grad_y_pred)
    grad_h_relu = np.dot(grad_y_pred, w2.T)
    # ReLU blocks the gradient wherever its input was negative
    grad_h = grad_h_relu.copy()
    grad_h[h < 0] = 0
    grad_w1 = np.dot(x.T, grad_h)

    # In-place gradient-descent step on both weight matrices
    w1 -= l_rate * grad_w1
    w2 -= l_rate * grad_w2

0 29928684.7193
1 23651400.2125
2 20892431.5575
3 18489305.8844
4 15349153.9665
5 11657482.2019
6 8166077.61782
7 5411953.00847
8 3520851.27881
9 2319312.71408
10 1584422.83709
11 1133534.68649
12 850730.748143
13 665827.447961
14 538727.829281
15 447005.020537
16 377700.823151
17 323267.571318
18 279288.377823
19 243046.487796
20 212658.598207
21 186888.747974
22 164854.751333
23 145889.121427
24 129487.026607
25 115238.902376
26 102803.245813
27 91907.7818601
28 82335.687053
29 73898.7034291
30 66444.5698257
31 59841.9229302
32 53984.6008692
33 48777.2147135
34 44135.403169
35 39989.0410925
36 36279.7789966
37 32957.2036253
38 29979.9042843
39 27301.1955294
40 24889.7672317
41 22716.5246469
42 20755.0636856
43 18982.789873
44 17378.5481321
45 15923.4433079
46 14603.0470788
47 13403.7523455
48 12312.6922025
49 11319.8324036
50 10415.7104512
51 9590.82157571
52 8838.01960849
53 8149.43998132
54 7520.16274753
55 6944.3361024
56 6416.24759963
57 5932.10421903
58 5487.73795087
59 5079.473

442 2.27576233782e-05
443 2.17677564322e-05
444 2.08211878565e-05
445 1.9915838225e-05
446 1.90498804442e-05
447 1.82216928312e-05
448 1.7429645385e-05
449 1.66723117514e-05
450 1.59478413222e-05
451 1.52550875273e-05
452 1.45924340005e-05
453 1.39586193509e-05
454 1.33523814882e-05
455 1.27726437368e-05
456 1.22181044965e-05
457 1.16877110285e-05
458 1.11804018291e-05
459 1.06952543632e-05
460 1.02312235296e-05
461 9.78729272031e-06
462 9.36273107736e-06
463 8.9566346859e-06
464 8.56815805387e-06
465 8.19666261216e-06
466 7.8413029937e-06
467 7.50137739164e-06
468 7.17631226211e-06
469 6.8652897579e-06
470 6.56779198838e-06
471 6.28325327954e-06
472 6.01104986936e-06
473 5.75065134444e-06
474 5.50161056446e-06
475 5.26334940317e-06
476 5.03547717911e-06
477 4.81751714937e-06
478 4.60898469863e-06
479 4.40949157878e-06
480 4.21866630517e-06
481 4.03611368384e-06
482 3.86148604983e-06
483 3.69443925566e-06
484 3.53462646648e-06
485 3.38180439739e-06
486 3.23559414938e-06
487 3.095690436

In [5]:
import torch

dtype = torch.float
# Use the GPU when one is present; otherwise fall back to the CPU so this cell
# still runs on machines without CUDA instead of raising at tensor creation.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Problem sizes: batch size, input dimension, hidden dimension, output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Random inputs and random regression targets, created on the chosen device.
x = torch.randn(N, D_in, device = device, dtype = dtype)
y = torch.randn(N, D_out, device = device, dtype = dtype)

# Randomly initialised weights; gradients are computed by hand in the loop
# that follows, so requires_grad is left at its default (False).
w1 = torch.randn(D_in, H, device = device, dtype = dtype)
w2 = torch.randn(H, D_out, device = device, dtype = dtype)

# Step size for plain gradient descent.
l_rate = 1e-6

In [7]:
# Same two-layer network as the NumPy cell, with torch tensors and manual gradients.
for t in range(500):
    # Forward: hidden pre-activation, ReLU, then the output layer
    h = torch.mm(x, w1)
    h_relu = torch.clamp(h, min=0)
    y_pred = torch.mm(h_relu, w2)

    # Sum of squared errors, extracted as a Python float for printing
    loss = torch.sum(torch.pow(y_pred - y, 2)).item()
    print(t, loss)

    # Backward: hand-written chain rule from the output back to w1
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = torch.mm(h_relu.t(), grad_y_pred)
    grad_h_relu = torch.mm(grad_y_pred, w2.t())
    # Out-of-place masked_fill yields a new tensor with the gradient zeroed
    # wherever the ReLU input was negative
    grad_h = grad_h_relu.masked_fill(h < 0, 0)
    grad_w1 = torch.mm(x.t(), grad_h)

    # In-place gradient-descent step
    w1.sub_(l_rate * grad_w1)
    w2.sub_(l_rate * grad_w2)

0 30581170.0
1 25572624.0
2 24405876.0
3 23216742.0
4 20285164.0
5 15571814.0
6 10622676.0
7 6632684.0
8 4016641.5
9 2465819.5
10 1593426.0
11 1098646.5
12 807696.0625
13 625683.3125
14 503801.25
15 416677.96875
16 350901.8125
17 299209.5625
18 257406.96875
19 222859.09375
20 193917.0
21 169438.546875
22 148589.171875
23 130697.0
24 115263.4375
25 101944.3125
26 90383.15625
27 80304.328125
28 71501.0390625
29 63783.4765625
30 57001.9296875
31 51026.9765625
32 45752.41796875
33 41083.51171875
34 36943.48046875
35 33271.125
36 30010.046875
37 27101.33984375
38 24508.166015625
39 22188.158203125
40 20109.69921875
41 18246.35546875
42 16571.544921875
43 15065.8515625
44 13709.8017578125
45 12487.330078125
46 11386.1533203125
47 10390.8232421875
48 9490.669921875
49 8675.4873046875
50 7936.3115234375
51 7265.744140625
52 6656.6376953125
53 6102.58203125
54 5598.525390625
55 5139.4013671875
56 4720.84765625
57 4339.10546875
58 3990.673828125
59 3672.700927734375
60 3381.998046875
61 3115.887

426 4.101195008843206e-05
427 4.0311111661139876e-05
428 3.999546242994256e-05
429 3.928838123101741e-05
430 3.877200651913881e-05
431 3.823169754468836e-05
432 3.765963629120961e-05
433 3.72283611795865e-05
434 3.6600191378965974e-05
435 3.585304148145951e-05
436 3.538408054737374e-05
437 3.484236367512494e-05
438 3.444786852924153e-05
439 3.411024226807058e-05
440 3.3416836231481284e-05
441 3.3220429031644017e-05
442 3.288083098595962e-05
443 3.23158128594514e-05
444 3.191797441104427e-05
445 3.1463263439945877e-05
446 3.0919723940314725e-05
447 3.0632891139248386e-05
448 3.0105071346042678e-05
449 2.9709795853705145e-05
450 2.9309381716302596e-05
451 2.8999053029110655e-05
452 2.8373200620990247e-05
453 2.8108963306294754e-05
454 2.784532080113422e-05
455 2.7446374588180333e-05
456 2.7004254661733285e-05
457 2.6713791157817468e-05
458 2.640163802425377e-05
459 2.6001158403232694e-05
460 2.5612929675844498e-05
461 2.54006044997368e-05
462 2.5019959139171988e-05
463 2.467962804075796e

In [9]:
import torch

dtype = torch.float
# Use the GPU when one is present; otherwise fall back to the CPU so this cell
# still runs on machines without CUDA instead of raising at tensor creation.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Problem sizes: batch size, input dimension, hidden dimension, output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Random inputs and random regression targets; no gradients are needed for these.
x = torch.randn(N, D_in, device = device, dtype = dtype)
y = torch.randn(N, D_out, device = device, dtype = dtype)

# requires_grad=True makes autograd track operations on the weights so that
# loss.backward() fills in w1.grad and w2.grad.
w1 = torch.randn(D_in, H, device = device, dtype = dtype, requires_grad = True)
w2 = torch.randn(H, D_out, device = device, dtype = dtype, requires_grad = True)
# Step size for plain gradient descent.
l_rate = 1e-6

In [10]:
# Train with autograd: only the forward pass is written out by hand.
for t in range(500):
    # Forward pass: linear -> clamp at zero (ReLU) -> linear
    y_pred = torch.mm(torch.clamp(torch.mm(x, w1), min=0), w2)

    # Sum of squared errors, kept as a tensor so it can be differentiated
    loss = torch.sum(torch.pow(y_pred - y, 2))
    print(t, loss.item())

    # Fills w1.grad and w2.grad
    loss.backward()

    # The weight update must not be recorded by autograd
    with torch.no_grad():
        w1.sub_(l_rate * w1.grad)
        w2.sub_(l_rate * w2.grad)

        # Clear the stored gradients before the next backward() call
        w1.grad.zero_()
        w2.grad.zero_()

0 40267672.0
1 39017248.0
2 36020840.0
3 27253648.0
4 16663354.0
5 8751938.0
6 4635003.0
7 2747921.5
8 1879619.125
9 1424785.5
10 1146111.0
11 951351.625
12 803265.4375
13 685496.5
14 589517.25
15 510148.09375
16 443779.65625
17 387881.5
18 340462.3125
19 299983.0625
20 265231.8125
21 235297.609375
22 209459.96875
23 186999.875
24 167387.8125
25 150200.953125
26 135097.96875
27 121782.140625
28 110009.21875
29 99567.640625
30 90290.0
31 82016.7421875
32 74619.390625
33 67994.484375
34 62046.796875
35 56699.3984375
36 51882.4140625
37 47542.9296875
38 43620.1875
39 40067.60546875
40 36846.28125
41 33920.59375
42 31257.7109375
43 28830.8125
44 26617.62890625
45 24596.609375
46 22748.30859375
47 21055.109375
48 19501.5546875
49 18075.53515625
50 16765.9609375
51 15562.666015625
52 14455.646484375
53 13435.57421875
54 12494.994140625
55 11627.263671875
56 10825.533203125
57 10084.4814453125
58 9398.767578125
59 8764.14453125
60 8176.1435546875
61 7631.07275390625
62 7125.47802734375
63 665

457 9.573072020430118e-05
458 9.37798322411254e-05
459 9.211566793965176e-05
460 9.065726771950722e-05
461 8.922030974645168e-05
462 8.785425598034635e-05
463 8.638050348963588e-05
464 8.477733354084194e-05
465 8.33106751088053e-05
466 8.188329957192764e-05
467 8.092919597402215e-05
468 7.942497177282348e-05
469 7.827102672308683e-05
470 7.68477184465155e-05
471 7.566295244032517e-05
472 7.459369953721762e-05
473 7.33578999643214e-05
474 7.236692181322724e-05
475 7.098913920344785e-05
476 6.97816867614165e-05
477 6.88207583152689e-05
478 6.78742362651974e-05
479 6.680437218165025e-05
480 6.558316817972809e-05
481 6.470384687418118e-05
482 6.36976765235886e-05
483 6.286619463935494e-05
484 6.16505931247957e-05
485 6.094136551837437e-05
486 6.000285429763608e-05
487 5.909325409447774e-05
488 5.8552686823531985e-05
489 5.7706576626515016e-05
490 5.6933382438728586e-05
491 5.630909436149523e-05
492 5.544571467908099e-05
493 5.470530595630407e-05
494 5.387675264501013e-05
495 5.323209916241

In [11]:
import torch

class MyReLU(torch.autograd.Function):
    """ReLU implemented as a custom autograd Function (invoke via ``MyReLU.apply``)."""

    @staticmethod
    def forward(ctx, input):
        """Return ``input`` with negative entries replaced by zero.

        The unmodified input is saved on ``ctx`` because ``backward`` needs
        its sign pattern.
        """
        ctx.save_for_backward(input)
        return torch.clamp(input, min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient with respect to the forward input.

        The incoming gradient passes through unchanged, except where the
        saved input was strictly negative; there it is set to zero.
        """
        (input,) = ctx.saved_tensors
        # masked_fill is out-of-place, so grad_output itself is never modified.
        return grad_output.masked_fill(input < 0, 0)

In [14]:
dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU

# Placement and precision shared by every tensor created in this cell.
tensor_opts = dict(device=device, dtype=dtype)

# Problem sizes: N = batch size, D_in = input dimension,
# H = hidden dimension, D_out = output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Random input and target tensors (not tracked by autograd).
x = torch.randn(N, D_in, **tensor_opts)
y = torch.randn(N, D_out, **tensor_opts)

# Random weight tensors, tracked by autograd so backward() fills in .grad.
w1 = torch.randn(D_in, H, requires_grad=True, **tensor_opts)
w2 = torch.randn(H, D_out, requires_grad=True, **tensor_opts)

# Step size for plain gradient descent.
learning_rate = 1e-6
# To apply our Function, we use Function.apply method. We alias this as 'relu'.
# The alias never changes, so it is bound once here rather than being
# re-created on each of the 500 iterations.
relu = MyReLU.apply

for t in range(500):
    # Forward pass: compute predicted y using operations; we compute
    # ReLU using our custom autograd operation.
    y_pred = relu(x.mm(w1)).mm(w2)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())

    # Use autograd to compute the backward pass.
    loss.backward()

    # Update weights using gradient descent; no_grad keeps the in-place
    # update itself out of the autograd graph.
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

        # Manually zero the gradients after updating weights
        w1.grad.zero_()
        w2.grad.zero_()


0 25689308.0
1 22780176.0
2 23468494.0
3 24785984.0
4 24322230.0
5 20705014.0
6 15025802.0
7 9457044.0
8 5487216.5
9 3143423.25
10 1892876.875
11 1238898.5
12 885462.875
13 679977.4375
14 549055.375
15 457782.9375
16 389422.3125
17 335449.46875
18 291350.0625
19 254533.09375
20 223381.546875
21 196753.3125
22 173874.90625
23 154109.90625
24 136951.015625
25 121997.890625
26 108937.328125
27 97480.125
28 87405.453125
29 78528.03125
30 70677.171875
31 63718.671875
32 57552.0234375
33 52063.5390625
34 47169.55078125
35 42795.84765625
36 38880.8203125
37 35367.7578125
38 32212.666015625
39 29374.50390625
40 26816.837890625
41 24506.884765625
42 22419.681640625
43 20530.583984375
44 18818.919921875
45 17265.85546875
46 15854.6484375
47 14571.5087890625
48 13403.3740234375
49 12339.9501953125
50 11371.0634765625
51 10486.80078125
52 9678.6923828125
53 8939.234375
54 8262.0859375
55 7641.326171875
56 7071.6708984375
57 6548.72607421875
58 6068.46923828125
59 5626.810546875
60 5220.224609375
6

392 0.0005779537605121732
393 0.0005608471110463142
394 0.0005444571143016219
395 0.0005291588604450226
396 0.000513773353304714
397 0.0004980055382475257
398 0.00048317149048671126
399 0.0004690426867455244
400 0.000457625777926296
401 0.00044469430577009916
402 0.0004330166557338089
403 0.0004210165934637189
404 0.0004089346039108932
405 0.00039827218279242516
406 0.00038809419493190944
407 0.00037760345730930567
408 0.0003682305687107146
409 0.00035738665610551834
410 0.0003483816981315613
411 0.0003401674621272832
412 0.00033086526673287153
413 0.00032329955138266087
414 0.00031448996742255986
415 0.0003063718613702804
416 0.00029961540712974966
417 0.0002918649697676301
418 0.0002842186077032238
419 0.00027793156914412975
420 0.0002712691784836352
421 0.0002645842032507062
422 0.00025751543580554426
423 0.0002525824820622802
424 0.00024655324523337185
425 0.00024084869073703885
426 0.0002358261845074594
427 0.0002308440743945539
428 0.0002248440869152546
429 0.0002192479296354577


In [16]:
import torch
# Problem sizes: batch size, input dimension, hidden dimension, output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Random inputs and random regression targets.
x = torch.randn((N, D_in))
y = torch.randn((N, D_out))

# Two-layer network: linear -> ReLU -> linear.
model = torch.nn.Sequential(
    torch.nn.Linear(in_features=D_in, out_features=H),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=H, out_features=D_out),
)

# Squared error summed (not averaged) over all elements.
loss_fn = torch.nn.MSELoss(reduction='sum')

# Step size for the manual gradient-descent update.
l_rate = 1e-4

In [17]:
# Train the nn.Sequential model for 500 steps with a hand-written SGD update.
for t in range(500):
    # Forward pass through the whole model, then the summed squared error
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Reset the parameters' gradients, then backpropagate this step's loss
    model.zero_grad()
    loss.backward()

    # Plain gradient-descent step on every parameter, outside autograd tracking
    with torch.no_grad():
        for param in model.parameters():
            param.sub_(l_rate * param.grad)

0 737.5913696289062
1 681.0422973632812
2 632.0503540039062
3 589.1618041992188
4 551.1578979492188
5 517.0707397460938
6 486.11907958984375
7 457.5793762207031
8 431.1181335449219
9 406.43731689453125
10 383.2294921875
11 361.359130859375
12 340.69049072265625
13 321.04376220703125
14 302.37811279296875
15 284.60565185546875
16 267.617431640625
17 251.42074584960938
18 235.99681091308594
19 221.30380249023438
20 207.32667541503906
21 194.07000732421875
22 181.4579315185547
23 169.50401306152344
24 158.20655822753906
25 147.50579833984375
26 137.38954162597656
27 127.8740234375
28 118.9213638305664
29 110.493408203125
30 102.59172821044922
31 95.2114028930664
32 88.30426788330078
33 81.8650894165039
34 75.87013244628906
35 70.2989501953125
36 65.11898040771484
37 60.30846405029297
38 55.83930587768555
39 51.68503952026367
40 47.83652114868164
41 44.277896881103516
42 40.99211502075195
43 37.95561599731445
44 35.15298080444336
45 32.55710983276367
46 30.17137336730957
47 27.970451354980

380 3.84576924261637e-05
381 3.7331279600039124e-05
382 3.624145392677747e-05
383 3.5183846193831414e-05
384 3.415368701098487e-05
385 3.3160224120365456e-05
386 3.2192252547247335e-05
387 3.1253275665221736e-05
388 3.0344001061166637e-05
389 2.9460063160513528e-05
390 2.8605529223568738e-05
391 2.7774216505349614e-05
392 2.6967714802594855e-05
393 2.6186318791587837e-05
394 2.542666879890021e-05
395 2.469177525199484e-05
396 2.3975724616320804e-05
397 2.3282789697987027e-05
398 2.261142071802169e-05
399 2.1957996068522334e-05
400 2.1322954125935212e-05
401 2.0707346266135573e-05
402 2.0110832338104956e-05
403 1.9530831195879728e-05
404 1.8969050870509818e-05
405 1.8424929294269532e-05
406 1.789587076928001e-05
407 1.7381005818606354e-05
408 1.688295196800027e-05
409 1.6398353182012215e-05
410 1.592809167050291e-05
411 1.5472045561182313e-05
412 1.5028254892968107e-05
413 1.4598104826291092e-05
414 1.4181306141836103e-05
415 1.3775475963484496e-05
416 1.3381827557168435e-05
417 1.30002

In [18]:
import torch
# Problem sizes: batch size, input dimension, hidden dimension, output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Random inputs and random regression targets (x is drawn before y).
x, y = torch.randn(N, D_in), torch.randn(N, D_out)

# Two-layer network: linear -> ReLU -> linear.
model = torch.nn.Sequential(
    torch.nn.Linear(in_features=D_in, out_features=H),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=H, out_features=D_out),
)

# Squared error summed (not averaged) over all elements.
loss_fn = torch.nn.MSELoss(reduction='sum')

# Adam takes over the parameter updates, using l_rate as its learning rate.
l_rate = 1e-4
optimizer = torch.optim.Adam(params=model.parameters(), lr=l_rate)

In [19]:
# Train for 500 steps, letting the optimizer perform the parameter updates.
for t in range(500):
    # Forward pass and summed squared error
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Reset the gradients held by the optimizer's parameters, backpropagate
    # this step's loss, then let the optimizer apply its update.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 670.0675659179688
1 653.05224609375
2 636.5446166992188
3 620.4712524414062
4 604.8778076171875
5 589.7322998046875
6 575.0533447265625
7 560.7755126953125
8 546.9067993164062
9 533.4485473632812
10 520.3580932617188
11 507.6657409667969
12 495.3187561035156
13 483.3780822753906
14 471.8153991699219
15 460.54119873046875
16 449.51837158203125
17 438.7836608886719
18 428.39788818359375
19 418.36279296875
20 408.6347961425781
21 399.2022399902344
22 390.0315246582031
23 381.1014099121094
24 372.3504333496094
25 363.7919921875
26 355.4184875488281
27 347.2419128417969
28 339.2574768066406
29 331.4222106933594
30 323.76055908203125
31 316.2690734863281
32 308.9320373535156
33 301.7555847167969
34 294.730712890625
35 287.8853759765625
36 281.18011474609375
37 274.62152099609375
38 268.2050476074219
39 261.908447265625
40 255.73062133789062
41 249.6732177734375
42 243.73187255859375
43 237.90484619140625
44 232.1938018798828
45 226.5972442626953
46 221.11148071289062
47 215.7300567626953
4

359 0.00012464376050047576
360 0.0001174215431092307
361 0.00011060736869694665
362 0.00010417057637823746
363 9.811169729800895e-05
364 9.239042992703617e-05
365 8.700090256752446e-05
366 8.191104279831052e-05
367 7.712458318565041e-05
368 7.260217535076663e-05
369 6.834043597336859e-05
370 6.432898226194084e-05
371 6.0548227338586e-05
372 5.698519817087799e-05
373 5.362609954318032e-05
374 5.0463349907658994e-05
375 4.748904029838741e-05
376 4.4679294660454616e-05
377 4.2039468098664656e-05
378 3.954977000830695e-05
379 3.7207290006335825e-05
380 3.4998523915419355e-05
381 3.2923479011515155e-05
382 3.0968672945164144e-05
383 2.9127129892003722e-05
384 2.7393651180318557e-05
385 2.5765404643607326e-05
386 2.4231332645285875e-05
387 2.2787353373132646e-05
388 2.1426491002785042e-05
389 2.0148203475400805e-05
390 1.8945835108752362e-05
391 1.781329592631664e-05
392 1.6748248526710086e-05
393 1.5747549696243368e-05
394 1.4804768397880252e-05
395 1.3919106095272582e-05
396 1.308494483964