# Learning Pytorch with Examples

In [1]:
import numpy as np

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = np.random.randn(N, D_in)
y = np.random.randn(N, D_out)

# Randomly initialize weights
w1 = np.random.randn(D_in, H)
w2 = np.random.randn(H, D_out)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y
    h = x.dot(w1)
    h_relu = np.maximum(h, 0)
    y_pred = h_relu.dot(w2)
    
    # Compute and print loss
    loss = np.square(y_pred - y).sum()
    print(t, loss)
    
    # Backprop to compute gradients of w1 and w2 with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.T.dot(grad_y_pred)
    grad_h_relu = grad_y_pred.dot(w2.T)
    grad_h = grad_h_relu.copy()
    grad_h[h < 0] = 0
    grad_w1 = x.T.dot(grad_h)
    
    # Update weights
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

0 26418617.82532145
1 22928774.017968252
2 23910229.04539955
3 25888322.92621596
4 26050398.921176262
5 22556652.915615927
6 16380222.002511023
7 10099573.50314493
8 5630988.258478143
9 3060805.3591068555
10 1742346.0146941298
11 1083296.1594924622
12 744339.024945023
13 556880.44742527
14 442802.65346295165
15 366242.7577075392
16 310365.8038042026
17 266951.7203512508
18 231897.5353759936
19 202859.98659602687
20 178395.58587090074
21 157536.17384305593
22 139617.67120714858
23 124117.49477646689
24 110645.42991350722
25 98891.07801249743
26 88605.11377357699
27 79585.84262327518
28 71626.67526190949
29 64582.54618429752
30 58337.27935706784
31 52787.79102130162
32 47839.192608782745
33 43419.78137961954
34 39463.894081725404
35 35915.63938769387
36 32722.2803222424
37 29849.662182240212
38 27260.39792957939
39 24923.871231499674
40 22818.239704223895
41 20915.825677000394
42 19194.336778504934
43 17630.68715686417
44 16210.817039188838
45 14918.618493164617
46 13741.16976234479
47 1

467 1.9945993002502287e-06
468 1.8968146271371027e-06
469 1.8039065091974904e-06
470 1.7154619970240227e-06
471 1.6313493116322023e-06
472 1.5513608729645666e-06
473 1.4752981186249975e-06
474 1.4029832009346496e-06
475 1.3342503080378626e-06
476 1.2688559201624182e-06
477 1.2066471551933468e-06
478 1.1474913925814599e-06
479 1.0912520773193756e-06
480 1.0377732376090902e-06
481 9.869147475231572e-07
482 9.38570141181169e-07
483 8.925585351602515e-07
484 8.487997861575097e-07
485 8.071941348276632e-07
486 7.676383446962927e-07
487 7.300105202380882e-07
488 6.942651960394839e-07
489 6.602327299302433e-07
490 6.278731013840427e-07
491 5.971021022335291e-07
492 5.678461292201028e-07
493 5.400122143921116e-07
494 5.135630352178945e-07
495 4.883993066332563e-07
496 4.6446008039550307e-07
497 4.4170239710283706e-07
498 4.2006535366444e-07
499 3.994797092725188e-07


In [4]:
import torch

dtype = torch.float
device = torch.device('cpu')
# device = torch.device('cuda:0') # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Randomly initialize weights
w1 = torch.randn(D_in, H, device=device, dtype=dtype)
w2 = torch.randn(H, D_out, device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y
    h = x.mm(w1)
    h_relu = h.clamp(min=0)
    y_pred = h_relu.mm(w2)
    
    # Compute and print loss
    loss = (y_pred - y).pow(2).sum().item()
    print(t, loss)
    
    # Backprop to compute gradients of w1 and w2 with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h_relu = grad_y_pred.mm(w2.t())
    grad_h = grad_h_relu.clone()
    grad_h[h < 0] = 0
    grad_w1 = x.t().mm(grad_h)
    
    # Update weights
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

0 35808408.0
1 34572824.0
2 35920068.0
3 33378900.0
4 25074934.0
5 14924482.0
6 7696795.0
7 3946344.25
8 2266574.75
9 1504410.125
10 1118041.0
11 888476.3125
12 731997.9375
13 614981.4375
14 522673.375
15 447688.5625
16 385766.4375
17 334072.8125
18 290611.0625
19 253884.328125
20 222648.625
21 195915.90625
22 172942.421875
23 153109.625
24 135925.765625
25 120982.1015625
26 107952.734375
27 96555.1328125
28 86562.0546875
29 77789.46875
30 70041.2109375
31 63186.01953125
32 57100.31640625
33 51685.51953125
34 46858.52734375
35 42552.546875
36 38707.5703125
37 35270.43359375
38 32182.806640625
39 29404.228515625
40 26898.662109375
41 24634.875
42 22587.35546875
43 20731.796875
44 19048.708984375
45 17519.486328125
46 16128.4248046875
47 14861.0634765625
48 13705.615234375
49 12650.7236328125
50 11686.787109375
51 10804.8115234375
52 9997.2685546875
53 9256.7060546875
54 8577.1279296875
55 7952.94677734375
56 7379.326171875
57 6851.42724609375
58 6365.1689453125
59 5917.01123046875
60 55

420 0.0002886331349145621
421 0.0002814277249854058
422 0.00027458221302367747
423 0.00026730456738732755
424 0.0002611597010400146
425 0.00025531003484502435
426 0.00024930783547461033
427 0.00024328894505742937
428 0.00023780703486409038
429 0.00023216735280584544
430 0.00022668125166092068
431 0.00022157878265716136
432 0.0002158473216695711
433 0.0002110401401296258
434 0.00020695064449682832
435 0.00020167429465800524
436 0.00019735301611945033
437 0.00019320531282573938
438 0.0001889762352220714
439 0.00018472298688720912
440 0.00018085262854583561
441 0.0001774163101799786
442 0.0001734765392029658
443 0.00016959589265752584
444 0.00016646957374177873
445 0.0001626902085263282
446 0.00015950726810842752
447 0.00015641677600797266
448 0.0001534663897473365
449 0.00015018851263448596
450 0.00014737635501660407
451 0.00014425971312448382
452 0.0001413488353136927
453 0.00013838674931321293
454 0.00013558636419475079
455 0.0001328536163782701
456 0.00013080413918942213
457 0.0001278

In [9]:
import torch

dtype = torch.float
device = torch.device('cpu')
# device = torch.device('cuda:0') # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold input and outputs
# Setting requires_grad=False indicates that we do not to compute gradients
# with respect to these Tensors during the backward pass.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Create random Tensors for weights
# Setting requires_grad=True indicates that we want to compute gradients
# with respect to these Tensors during the backward pass.
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-6
for t in range(500):
    y_pred = x.mm(w1).clamp(min=0).mm(w2)
    
    # Compute and print loss
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())
                            
    loss.backward()
    
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad
                            
        # Manually zero the gradients after updating weights
        w1.grad.zero_()
        w2.grad.zero_()

0 31966784.0
1 28670176.0
2 27899198.0
3 25465188.0
4 20294044.0
5 13840972.0
6 8408997.0
7 4868049.0
8 2890029.25
9 1844872.375
10 1286383.75
11 967420.6875
12 768659.8125
13 632946.125
14 533216.8125
15 455972.625
16 393874.28125
17 342694.3125
18 299814.0625
19 263513.875
20 232600.984375
21 206134.0625
22 183288.078125
23 163480.859375
24 146228.984375
25 131144.453125
26 117906.734375
27 106279.046875
28 96023.5625
29 86947.03125
30 78882.9453125
31 71711.546875
32 65305.625
33 59569.01953125
34 54427.91015625
35 49821.625
36 45666.609375
37 41912.29296875
38 38515.76953125
39 35435.9296875
40 32639.9609375
41 30096.65234375
42 27782.908203125
43 25676.548828125
44 23753.35546875
45 21994.927734375
46 20384.052734375
47 18906.70703125
48 17552.021484375
49 16309.0810546875
50 15166.7880859375
51 14114.103515625
52 13144.8427734375
53 12250.65234375
54 11426.1162109375
55 10663.37109375
56 9957.525390625
57 9303.7392578125
58 8697.837890625
59 8135.7734375
60 7613.85791015625
61 71

441 0.0008549016201868653
442 0.0008329160627909005
443 0.0008107459871098399
444 0.0007889915723353624
445 0.0007682730793021619
446 0.0007483436493203044
447 0.0007292699301615357
448 0.0007099871872924268
449 0.000691087800078094
450 0.0006742451805621386
451 0.0006581321940757334
452 0.0006410822388716042
453 0.0006249093567021191
454 0.0006108076777309179
455 0.0005951786879450083
456 0.0005805519176647067
457 0.0005664811469614506
458 0.0005531718488782644
459 0.0005403075483627617
460 0.000526582938618958
461 0.0005154077662155032
462 0.0005024208803661168
463 0.0004912837175652385
464 0.0004792911931872368
465 0.00046909847878850996
466 0.0004582099209073931
467 0.0004483349621295929
468 0.000437986251199618
469 0.00042731547728180885
470 0.0004180881951469928
471 0.0004088474961463362
472 0.00040026206988841295
473 0.0003909956431016326
474 0.00038245003088377416
475 0.00037434790283441544
476 0.0003661430673673749
477 0.0003587337560020387
478 0.000350799469742924
479 0.00034

In [15]:
import torch


class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.clamp(min=0)
    
    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input

dtype = torch.float
device = torch.device('cpu')
# device = torch.device('cuda:0') # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold input and outputs
# Setting requires_grad=False indicates that we do not to compute gradients
# with respect to these Tensors during the backward pass.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Create random Tensors for weights
# Setting requires_grad=True indicates that we want to compute gradients
# with respect to these Tensors during the backward pass.
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-6
for t in range(500):
    relu = MyReLU.apply
    
    y_pred = relu(x.mm(w1)).mm(w2)
    
    # Compute and print loss
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())
                            
    loss.backward()
    
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad
                            
        # Manually zero the gradients after updating weights
        w1.grad.zero_()
        w2.grad.zero_()

0 33782052.0
1 34346496.0
2 37580396.0
3 36776236.0
4 29346716.0
5 18218966.0
6 9450016.0
7 4643156.0
8 2494767.0
9 1558922.25
10 1117045.625
11 873864.8125
12 717568.875
13 604781.375
14 517137.03125
15 446357.0
16 387877.6875
17 338882.5625
18 297472.75
19 262244.5625
20 232081.109375
21 206105.984375
22 183606.75
23 164033.265625
24 146953.90625
25 131994.203125
26 118807.1015625
27 107178.8671875
28 96876.328125
29 87730.1796875
30 79591.1796875
31 72334.484375
32 65877.3515625
33 60088.65625
34 54890.3984375
35 50218.546875
36 46009.9375
37 42209.33203125
38 38771.73046875
39 35656.328125
40 32827.921875
41 30257.841796875
42 27918.052734375
43 25785.8046875
44 23839.755859375
45 22061.478515625
46 20434.82421875
47 18945.8828125
48 17580.60546875
49 16326.033203125
50 15172.44140625
51 14110.591796875
52 13132.1796875
53 12229.6767578125
54 11396.322265625
55 10626.19140625
56 9914.4599609375
57 9255.5947265625
58 8645.12890625
59 8079.66650390625
60 7555.6796875
61 7069.05224609

428 0.000501803238876164
429 0.0004877421597484499
430 0.0004757973656523973
431 0.00046301327529363334
432 0.0004524414543993771
433 0.00044119585072621703
434 0.0004297650302760303
435 0.000419158284785226
436 0.0004082326777279377
437 0.00039913671207614243
438 0.0003886091581080109
439 0.0003788176109082997
440 0.00037018151488155127
441 0.00036176806315779686
442 0.0003526692744344473
443 0.00034366996260359883
444 0.0003356722882017493
445 0.00032836952595971525
446 0.0003204709792044014
447 0.0003133289283141494
448 0.00030608600354753435
449 0.00029870972502976656
450 0.00029212867957539856
451 0.0002859053201973438
452 0.00027957020211033523
453 0.0002740583731792867
454 0.0002673472627066076
455 0.00026164305745624006
456 0.0002561573637649417
457 0.00025041954359039664
458 0.0002450781757943332
459 0.00024002778809517622
460 0.0002349150163354352
461 0.00023043990950100124
462 0.00022561970399692655
463 0.00022076285677030683
464 0.00021630038099829108
465 0.0002115416427841

## TensorFlow: Static Graph

In [20]:
import tensorflow as tf
import numpy as np

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold input and outputs
x = tf.placeholder(tf.float32, shape=(None, D_in))
y = tf.placeholder(tf.float32, shape=(None, D_out))

# Create random Tensors for weights
w1 = tf.Variable(tf.random_normal((D_in, H)))
w2 = tf.Variable(tf.random_normal((H, D_out)))

h = tf.matmul(x, w1)
h_relu = tf.maximum(h, tf.zeros(1))
y_pred = tf.matmul(h_relu, w2)

loss = tf.reduce_sum((y - y_pred) ** 2.0)

grad_w1, grad_w2 = tf.gradients(loss, [w1, w2])

learning_rate = 1e-6
new_w1 = w1.assign(w1 - learning_rate * grad_w1)
new_w2 = w2.assign(w2 - learning_rate * grad_w2)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    
    x_value = np.random.randn(N, D_in)
    y_value = np.random.randn(N, D_out)
    for _ in range(500):
        loss_value, _, _ = sess.run([loss, new_w1, new_w2],
                                    feed_dict={x: x_value, y: y_value})
        print(loss_value)

29804144.0
29075380.0
32276444.0
34110956.0
30560214.0
21492880.0
12209178.0
6088504.5
3101149.2
1784820.5
1197995.8
904397.0
732693.9
616529.44
529581.4
460475.8
403608.62
355959.8
315546.44
280976.56
251169.67
225406.28
202963.58
183345.19
166085.9
150832.28
137306.08
125268.7
114527.14
104912.09
96284.67
88511.75
81494.97
75148.85
69396.75
64159.94
59396.746
55055.523
51084.89
47449.883
44118.37
41061.062
38248.68
35657.71
33267.28
31061.027
29021.4
27133.406
25383.52
23761.857
22256.934
20859.031
19558.703
18350.654
17227.363
16179.809
15202.452
14289.79
13437.871
12642.5
11898.32
11201.9
10549.8
9938.816
9366.203
8828.914
8324.973
7851.804
7407.4663
6990.0767
6597.8965
6229.2334
5883.1367
5558.299
5252.6494
4965.025
4694.232
4439.1963
4198.8745
3972.4597
3758.9053
3557.6511
3367.7874
3188.5903
3019.485
2859.8723
2709.142
2566.7988
2432.351
2305.3372
2185.3816
2071.9497
1964.6824
1863.2645
1767.3293
1676.593
1590.8193
1509.7855
1433.0725
1360.4763
1291.728
1226.6118
1164.9362
1106.

## PyTorch: nn

In [22]:
import torch

N, D_in, H, D_out = 64, 1000, 100, 10

x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out)
)

loss_fn = torch.nn.MSELoss(reduction='sum')

leanring_rate = 1e-4
for t in range(500):
    y_pred = model(x)
    
    loss = loss_fn(y_pred, y)
    print(t, loss.item())
    
    model.zero_grad()
    
    loss.backward()
    
    with torch.no_grad():
        for param in model.parameters():
            param -= leanring_rate * param.grad

0 634.2044067382812
1 587.6046142578125
2 546.87255859375
3 510.763916015625
4 478.4909362792969
5 449.4096374511719
6 422.9550476074219
7 398.5513000488281
8 375.77471923828125
9 354.7318420410156
10 335.1690673828125
11 316.86676025390625
12 299.65081787109375
13 283.35296630859375
14 267.8443908691406
15 253.1413116455078
16 239.23989868164062
17 226.00816345214844
18 213.42022705078125
19 201.435302734375
20 189.9920654296875
21 179.15765380859375
22 168.86610412597656
23 159.09786987304688
24 149.8577117919922
25 141.1049041748047
26 132.7996826171875
27 124.9355697631836
28 117.50630950927734
29 110.45330047607422
30 103.8097915649414
31 97.554931640625
32 91.66276550292969
33 86.10077667236328
34 80.86578369140625
35 75.94041442871094
36 71.30916595458984
37 66.96902465820312
38 62.903533935546875
39 59.08238983154297
40 55.49288558959961
41 52.13631820678711
42 48.98802185058594
43 46.035728454589844
44 43.271507263183594
45 40.67979431152344
46 38.24748229980469
47 35.96559524

373 0.00033805842394940555
374 0.0003295188071206212
375 0.0003211937437299639
376 0.00031309318728744984
377 0.0003051979292649776
378 0.00029750209068879485
379 0.0002900110848713666
380 0.00028270683833397925
381 0.00027559237787500024
382 0.0002686608349904418
383 0.00026190545759163797
384 0.0002553286321926862
385 0.00024891606881283224
386 0.00024266888794954866
387 0.00023658189456909895
388 0.00023064782726578414
389 0.0002248746168334037
390 0.00021923820895608515
391 0.00021375331562012434
392 0.00020840656361542642
393 0.000203193339984864
394 0.00019811210222542286
395 0.00019315969257149845
396 0.0001883372460724786
397 0.00018364116840530187
398 0.00017906010907609016
399 0.00017459654191043228
400 0.00017025155830197036
401 0.0001660097623243928
402 0.00016187447181437165
403 0.00015785846335347742
404 0.00015393133799079806
405 0.00015010160859674215
406 0.00014637422282248735
407 0.00014273636043071747
408 0.00013919557386543602
409 0.00013574736658483744
410 0.000132

## Pytorch: optim

In [25]:
import torch

N, D_in, H, D_out = 64, 1000, 100, 10

x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out)
)

loss_fn = torch.nn.MSELoss(reduction='sum')

leanring_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(500):
    y_pred = model(x)
    
    loss = loss_fn(y_pred, y)
    print(t, loss.item())
    
    model.zero_grad()
    
    optimizer.zero_grad()
    
    loss.backward()
    
    optimizer.step()

0 661.7081909179688
1 661.5352172851562
2 661.3621826171875
3 661.1892700195312
4 661.016357421875
5 660.8434448242188
6 660.6705932617188
7 660.4977416992188
8 660.3248901367188
9 660.1521606445312
10 659.9793701171875
11 659.8067016601562
12 659.6339721679688
13 659.4613647460938
14 659.288818359375
15 659.1162719726562
16 658.9437255859375
17 658.7713012695312
18 658.5988159179688
19 658.4264526367188
20 658.2540893554688
21 658.0817260742188
22 657.9094848632812
23 657.7372436523438
24 657.5651245117188
25 657.3930053710938
26 657.2210693359375
27 657.0491333007812
28 656.877197265625
29 656.705322265625
30 656.5335083007812
31 656.3617553710938
32 656.18994140625
33 656.0182495117188
34 655.8465576171875
35 655.6749877929688
36 655.5035400390625
37 655.3320922851562
38 655.16064453125
39 654.9892578125
40 654.81787109375
41 654.6465454101562
42 654.4752197265625
43 654.303955078125
44 654.1327514648438
45 653.9615478515625
46 653.7904663085938
47 653.6195068359375
48 653.448608398

453 588.3358764648438
454 588.1845092773438
455 588.033203125
456 587.8818359375
457 587.7305908203125
458 587.579345703125
459 587.427978515625
460 587.2767333984375
461 587.125732421875
462 586.9747314453125
463 586.8236694335938
464 586.6726684570312
465 586.5217895507812
466 586.3707885742188
467 586.2198486328125
468 586.069091796875
469 585.9183349609375
470 585.767578125
471 585.6168212890625
472 585.4662475585938
473 585.3155517578125
474 585.1650390625
475 585.0145263671875
476 584.864013671875
477 584.713623046875
478 584.5631103515625
479 584.4127197265625
480 584.2623901367188
481 584.1121215820312
482 583.9618530273438
483 583.8116455078125
484 583.66162109375
485 583.5115356445312
486 583.3614501953125
487 583.2115478515625
488 583.0615844726562
489 582.9116821289062
490 582.76171875
491 582.6119995117188
492 582.4622192382812
493 582.3123779296875
494 582.16259765625
495 582.012939453125
496 581.86328125
497 581.713623046875
498 581.56396484375
499 581.4144287109375


## Pytorch: Custom nn Moduels

In [31]:
import random
import torch


class DynamicNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(DynamicNet, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)
        
    def forward(self, x):
        h_relu = self.linear1(x).clamp(min=0)
        y_pred = self.linear2(h_relu)
        return y_pred
    
    
N, D_in, H, D_out = 64, 1000, 100, 10

x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

model = DynamicNet(D_in, H, D_out)

criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
for t in range(500):
    y_pred = model(x)
    
    loss = criterion(y_pred, y)
    print(t, loss.item())
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 703.5922241210938
1 651.6735229492188
2 566.9246215820312
3 472.3230285644531
4 381.13531494140625
5 300.6065979003906
6 231.46157836914062
7 172.7969970703125
8 125.92198181152344
9 91.85741424560547
10 70.827392578125
11 61.763427734375
12 61.76057052612305
13 65.22379302978516
14 65.95622253417969
15 60.65507507324219
16 50.827613830566406
17 40.72055435180664
18 33.31852340698242
19 28.647382736206055
20 25.382436752319336
21 23.005775451660156
22 21.830934524536133
23 21.774818420410156
24 21.78876304626465
25 20.782732009887695
26 18.59795379638672
27 16.005409240722656
28 13.878814697265625
29 12.429287910461426
30 11.266716003417969
31 10.03994083404541
32 8.811214447021484
33 7.8201751708984375
34 7.058350563049316
35 6.268895149230957
36 5.322849273681641
37 4.416608810424805
38 3.830085515975952
39 3.5991649627685547
40 3.4997880458831787
41 3.3199880123138428
42 3.0496230125427246
43 2.7939021587371826
44 2.5888733863830566
45 2.36767315864563
46 2.0776047706604004
47 1.7

494 1.7429449810507647e-12
495 1.7643722854260302e-12
496 1.591311500973025e-12
497 1.825568559168933e-12
498 1.9209909270928582e-12
499 1.794455426265551e-12


## PyTorch: Control Flow + Weight sharing

In [30]:
import random
import torch


class DynamicNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(DynamicNet, self).__init__()
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)
        
    def forward(self, x):
        h_relu = self.input_linear(x).clamp(min=0)
        for _ in range(random.randint(0, 3)):
            h_relu = self.middle_linear(h_relu).clamp(min=0)
        y_pred = self.output_linear(h_relu)
        return y_pred
    
    
N, D_in, H, D_out = 64, 1000, 100, 10

x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

model = DynamicNet(D_in, H, D_out)

criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
for t in range(500):
    y_pred = model(x)
    
    loss = criterion(y_pred, y)
    print(t, loss.item())
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 613.8414916992188
1 653.973876953125
2 604.682373046875
3 609.8871459960938
4 466.8778381347656
5 604.8815307617188
6 579.6489868164062
7 603.9298095703125
8 596.2282104492188
9 592.130615234375
10 552.2302856445312
11 250.0496063232422
12 577.1481323242188
13 518.5153198242188
14 499.8061218261719
15 177.79685974121094
16 152.01919555664062
17 586.3853149414062
18 103.75989532470703
19 580.1700439453125
20 381.44384765625
21 359.7198791503906
22 330.310546875
23 550.1314086914062
24 263.08544921875
25 433.1102294921875
26 202.12844848632812
27 168.98866271972656
28 357.010009765625
29 149.64122009277344
30 409.73095703125
31 178.24819946289062
32 149.55484008789062
33 111.43412780761719
34 367.756103515625
35 230.65811157226562
36 103.37852478027344
37 92.09828186035156
38 66.74730682373047
39 54.2598991394043
40 49.81157684326172
41 302.8063659667969
42 145.90037536621094
43 53.539791107177734
44 120.4664077758789
45 30.6953125
46 280.3010559082031
47 187.47216796875
48 177.3464965

471 0.5882894992828369
472 0.5832046270370483
473 0.17954297363758087
474 0.17752380669116974
475 0.08164795488119125
476 0.07422234117984772
477 0.42409443855285645
478 0.7175725698471069
479 0.2310369312763214
480 0.07523498684167862
481 0.08567391335964203
482 0.6826462745666504
483 0.3984709084033966
484 0.319871187210083
485 0.45393693447113037
486 0.3985227644443512
487 0.42174264788627625
488 0.33844682574272156
489 0.21404844522476196
490 0.5939167737960815
491 0.13132597506046295
492 0.4686504006385803
493 0.1660091131925583
494 0.42131394147872925
495 0.38440465927124023
496 0.1285063624382019
497 0.12443900853395462
498 0.2895931601524353
499 0.8113041520118713
