# PyTorch Tutorials

In [31]:
import random
import numpy
import torch
import torchvision
import torch.nn as nn
import numpy as np
import torchvision.transforms as transforms

# Two-Layer NN (Numpy, manual backprop)

In [3]:
# Bias-free two-layer network (linear -> ReLU -> linear) fitted to random data
# with plain NumPy. Both the forward pass and the gradients are written by hand.
N, D_in, H, D_out = 64, 1000, 100, 10

# Random inputs and regression targets.
x = np.random.randn(N, D_in)
y = np.random.randn(N, D_out)

# Randomly initialised weight matrices for the two layers.
w1 = np.random.randn(D_in, H)
w2 = np.random.randn(H, D_out)

lr = 1e-6
for epoch in range(500):
    # Forward pass: hidden pre-activation, ReLU, then the output layer.
    h = np.dot(x, w1)
    h_relu = np.maximum(h, 0)
    y_pred = np.dot(h_relu, w2)

    # Sum-of-squared-errors loss, printed once per epoch.
    residual = y_pred - y
    loss = np.square(residual).sum()
    print(epoch, loss)

    # Backward pass: apply the chain rule from the loss back to each weight.
    grad_y_pred = 2.0 * residual
    grad_w2 = np.dot(h_relu.T, grad_y_pred)
    grad_h_relu = np.dot(grad_y_pred, w2.T)
    # ReLU blocks the gradient wherever its input was negative.
    grad_h = np.where(h < 0, 0.0, grad_h_relu)
    grad_w1 = np.dot(x.T, grad_h)

    # Plain gradient-descent step, applied in place.
    w1 -= lr * grad_w1
    w2 -= lr * grad_w2
    

0 29051412.78
1 25778200.9373
2 27801575.4696
3 30631359.0911
4 30189862.9337
5 24318934.4186
6 15690693.6742
7 8490167.47034
8 4333038.79163
9 2333242.00835
10 1422269.3316
11 986611.311105
12 753693.199667
13 610760.285869
14 511713.763276
15 436780.872796
16 377042.87078
17 327939.217337
18 286821.461158
19 252026.073787
20 222345.818044
21 196809.3067
22 174742.119049
23 155582.50284
24 138874.006021
25 124257.721054
26 111435.911581
27 100146.096208
28 90178.1628341
29 81352.1085126
30 73534.0642047
31 66596.4783123
32 60418.9772622
33 54902.969382
34 49963.4426331
35 45534.0223826
36 41553.5725267
37 37968.7954383
38 34733.2742746
39 31810.7331945
40 29166.9431656
41 26773.7912577
42 24599.0689338
43 22622.2167162
44 20823.2029576
45 19184.1603524
46 17689.3102236
47 16324.2333392
48 15076.7346635
49 13934.7112174
50 12888.8747547
51 11931.4854807
52 11052.1734501
53 10244.5085843
54 9502.16815225
55 8819.50038322
56 8191.2384892
57 7612.30855884
58 7078.35193273
59 6585.36563141

473 1.05300798387e-05
474 1.0080525074e-05
475 9.65015639877e-06
476 9.23818875253e-06
477 8.8440305308e-06
478 8.46655861972e-06
479 8.10523680189e-06
480 7.75939722503e-06
481 7.42837123565e-06
482 7.11153364861e-06
483 6.80816675689e-06
484 6.51778616579e-06
485 6.23981468682e-06
486 5.97374745042e-06
487 5.71908347678e-06
488 5.47523733492e-06
489 5.24181910834e-06
490 5.01834486509e-06
491 4.80450843148e-06
492 4.5997793479e-06
493 4.40374779735e-06
494 4.21608950988e-06
495 4.03646789686e-06
496 3.86453341487e-06
497 3.69991244514e-06
498 3.54232137165e-06
499 3.39143665558e-06


# Two-Layer NN (PyTorch, manual backprop)

In [6]:
# Bias-free two-layer network (linear -> ReLU -> linear) on torch Tensors.
# Gradients are still derived and applied by hand; autograd is not used.
N, D_in, H, D_out = 64, 1000, 100, 10

device = torch.device('cpu')
# device = torch.device('cuda')  # uncomment to run on a GPU

# Random inputs and regression targets.
x = torch.randn(N, D_in, device=device)
y = torch.randn(N, D_out, device=device)

# Randomly initialised weight matrices for the two layers.
w1 = torch.randn(D_in, H, device=device)
w2 = torch.randn(H, D_out, device=device)

lr = 1e-6
for epoch in range(500):
    # Forward pass: hidden pre-activation, ReLU, then the output layer.
    h = torch.mm(x, w1)
    h_relu = torch.clamp(h, min=0)
    y_pred = torch.mm(h_relu, w2)

    # Sum-of-squared-errors loss. It is a scalar stored in a Tensor, so
    # loss.item() is used to get its value as a Python number.
    residual = y_pred - y
    loss = residual.pow(2).sum()
    print(epoch, loss.item())

    # Backward pass: apply the chain rule from the loss back to each weight.
    grad_y_pred = 2.0 * residual
    grad_w2 = torch.mm(h_relu.t(), grad_y_pred)
    grad_h_relu = torch.mm(grad_y_pred, w2.t())
    # ReLU blocks the gradient wherever its input was negative.
    grad_h = grad_h_relu.masked_fill(h < 0, 0)
    grad_w1 = torch.mm(x.t(), grad_h)

    # Plain gradient-descent step, applied in place.
    w1.sub_(lr * grad_w1)
    w2.sub_(lr * grad_w2)

0 34918032.0
1 32124314.0
2 33087986.0
3 32080922.0
4 26621100.0
5 17988800.0
6 10348984.0
7 5504402.5
8 3036208.0
9 1855520.75
10 1278017.125
11 966402.75
12 775822.1875
13 644949.4375
14 546931.375
15 469690.5625
16 406735.59375
17 354485.0
18 310630.15625
19 273464.0625
20 241687.1875
21 214372.484375
22 190774.734375
23 170284.984375
24 152413.265625
25 136752.375
26 122977.1484375
27 110821.7578125
28 100074.90625
29 90536.0859375
30 82046.515625
31 74461.4296875
32 67681.046875
33 61607.34375
34 56155.51171875
35 51250.40234375
36 46836.01953125
37 42853.34375
38 39248.80859375
39 35980.9140625
40 33019.4921875
41 30329.009765625
42 27879.9140625
43 25648.841796875
44 23614.23828125
45 21756.74609375
46 20061.140625
47 18508.744140625
48 17087.24609375
49 15784.095703125
50 14588.7138671875
51 13491.1953125
52 12482.990234375
53 11556.5302734375
54 10703.7724609375
55 9918.369140625
56 9195.1162109375
57 8528.513671875
58 7913.64013671875
59 7346.07275390625
60 6822.08837890625
6

445 6.138425669632852e-05
446 6.041454616934061e-05
447 5.950756894890219e-05
448 5.8511308452580124e-05
449 5.763873195974156e-05
450 5.661117029376328e-05
451 5.548699846258387e-05
452 5.4439686209661886e-05
453 5.3661166020901874e-05
454 5.279388278722763e-05
455 5.19423047080636e-05
456 5.115211024531163e-05
457 5.0374310376355425e-05
458 4.947253182763234e-05
459 4.86635253764689e-05
460 4.793733387487009e-05
461 4.707926927949302e-05
462 4.6610173740191385e-05
463 4.6000735892448574e-05
464 4.493154483498074e-05
465 4.424808867042884e-05
466 4.353362965048291e-05
467 4.2888957977993414e-05
468 4.234831430949271e-05
469 4.166871076449752e-05
470 4.11224682466127e-05
471 4.05208847951144e-05
472 4.001010529464111e-05
473 3.947801451431587e-05
474 3.894445762853138e-05
475 3.852220834232867e-05
476 3.7772821087855846e-05
477 3.728881347342394e-05
478 3.685180490720086e-05
479 3.634230961324647e-05
480 3.58016332029365e-05
481 3.519692472764291e-05
482 3.481325984466821e-05
483 3.441

# Two-Layer NN (PyTorch, autograd)

When using **autograd**, the forward pass of your network will define a **computational graph**; *nodes* in the graph will be Tensors, and *edges* will be functions that produce output Tensors from input Tensors. Backpropagating through this graph then allows you to easily compute gradients.

If we want to compute gradients with respect to some Tensor, then we set **requires_grad=True** when constructing that Tensor. Any PyTorch operations on that Tensor will cause a computational graph to be constructed, allowing us to later perform backpropagation through the graph. If x is a Tensor with requires_grad=True, then after backpropagation **x.grad** will be another Tensor holding the gradient of some scalar value (typically the loss) with respect to x. On the other hand, we usually don't want to backpropagate through the weight update steps when training a neural network. In such scenarios we can use the **torch.no_grad()** context manager to prevent the construction of a computational graph.

In [7]:
# Bias-free two-layer network (linear -> ReLU -> linear) trained with
# autograd: only the forward pass is written out; loss.backward() supplies
# the gradients.
N, D_in, H, D_out = 64, 1000, 100, 10

device = torch.device('cpu')
# device = torch.device('cuda')  # uncomment to run on a GPU

# Random inputs and regression targets (no gradients needed for these).
x = torch.randn(N, D_in, device=device)
y = torch.randn(N, D_out, device=device)

# Weights are flagged with requires_grad=True so autograd tracks them.
w1 = torch.randn(D_in, H, device=device, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, requires_grad=True)

lr = 1e-6
for epoch in range(500):
    # Forward pass.
    hidden = torch.clamp(torch.mm(x, w1), min=0)
    y_pred = torch.mm(hidden, w2)

    # Sum-of-squared-errors loss, printed once per epoch.
    loss = torch.sum((y_pred - y).pow(2))
    print(epoch, loss.item())

    # Populates w1.grad and w2.grad with the gradient of the loss with
    # respect to each weight.
    loss.backward()

    # The update itself must not be recorded in the graph, so it runs under
    # torch.no_grad(). Gradients accumulate across backward() calls, so each
    # one is reset to zero right after it has been applied.
    with torch.no_grad():
        for weight in (w1, w2):
            weight.sub_(lr * weight.grad)
            weight.grad.zero_()

0 31326422.0
1 28625564.0
2 29273698.0
3 28508714.0
4 23866528.0
5 16603359.0
6 9874511.0
7 5445775.0
8 3073711.5
9 1902177.375
10 1314993.0
11 996134.0625
12 801942.5625
13 669692.9375
14 571395.6875
15 493943.75
16 430661.78125
17 377924.3125
18 333325.75
19 295225.375
20 262470.3125
21 234134.9375
22 209507.4375
23 188001.796875
24 169132.0625
25 152522.8125
26 137860.53125
27 124872.0390625
28 113358.6796875
29 103117.8359375
30 93973.3984375
31 85793.3671875
32 78450.734375
33 71846.03125
34 65893.7265625
35 60517.375
36 55658.6953125
37 51256.7890625
38 47262.2421875
39 43631.2578125
40 40323.078125
41 37306.171875
42 34554.734375
43 32035.01171875
44 29728.58203125
45 27614.287109375
46 25673.958984375
47 23892.435546875
48 22255.748046875
49 20748.24609375
50 19358.328125
51 18076.53125
52 16892.37890625
53 15797.296875
54 14784.529296875
55 13846.4501953125
56 12976.4267578125
57 12169.2880859375
58 11420.1220703125
59 10723.837890625
60 10076.1513671875
61 9473.4345703125
62 

479 0.0021785253193229437
480 0.002119706943631172
481 0.0020604718010872602
482 0.0020072595216333866
483 0.0019527877448126674
484 0.0018999500898644328
485 0.0018471170915290713
486 0.001800810219720006
487 0.001752884709276259
488 0.001706977840512991
489 0.0016628953162580729
490 0.0016196384094655514
491 0.0015752729959785938
492 0.0015352274058386683
493 0.0014975386438891292
494 0.0014596175169572234
495 0.0014223044272512197
496 0.0013842078624293208
497 0.0013506714021787047
498 0.00131758744828403
499 0.0012843938311561942


# PyTorch: Defining new autograd functions

In [9]:
class MyReLU(torch.autograd.Function):
    """
    Hand-written ReLU exposed as a custom autograd Function.

    Subclassing torch.autograd.Function and providing static forward and
    backward methods lets autograd differentiate through an operation whose
    gradient we define ourselves. Invoke it with MyReLU.apply(tensor).
    """

    @staticmethod
    def forward(ctx, x):
        """
        Return x with every negative entry replaced by zero.

        The input is stashed on the context object because the backward pass
        needs to know which entries were negative.
        """
        ctx.save_for_backward(x)
        return torch.clamp(x, min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Return the gradient of the loss with respect to the forward input.

        grad_output is the gradient of the loss with respect to the forward
        output. It is passed through unchanged except where the saved input
        was negative, where it is set to zero. A new Tensor is returned;
        grad_output itself is not modified.
        """
        (saved_input,) = ctx.saved_tensors
        return grad_output.masked_fill(saved_input < 0, 0)

dtype = torch.float
device = torch.device('cpu')
# device = torch.device('cuda') # Uncomment this to run on GPU

# N: batch size, D_in: input width, H: hidden width, D_out: output width.
N, D_in, H, D_out = 64, 1000, 100, 10

# Options shared by every Tensor created below.
tensor_opts = dict(device=device, dtype=dtype)

# Random inputs and regression targets.
x = torch.randn(N, D_in, **tensor_opts)
y = torch.randn(N, D_out, **tensor_opts)

# Random weights, tracked by autograd.
w1 = torch.randn(D_in, H, requires_grad=True, **tensor_opts)
w2 = torch.randn(H, D_out, requires_grad=True, **tensor_opts)

learning_rate = 1e-6
for t in range(500):
    # Forward pass; the custom ReLU is invoked through MyReLU.apply.
    hidden = MyReLU.apply(torch.mm(x, w1))
    y_pred = torch.mm(hidden, w2)

    # Sum-of-squared-errors loss, printed once per step.
    loss = torch.sum((y_pred - y).pow(2))
    print(t, loss.item())

    # Autograd fills w1.grad and w2.grad.
    loss.backward()

    # Gradient-descent step outside the graph, then reset each gradient so
    # the next backward() call starts from zero.
    with torch.no_grad():
        for weight in (w1, w2):
            weight.sub_(learning_rate * weight.grad)
            weight.grad.zero_()


0 31899396.0
1 27617632.0
2 27153348.0
3 26163906.0
4 22556914.0
5 16746497.0
6 10848735.0
7 6449658.0
8 3771219.25
9 2303944.75
10 1522605.375
11 1094335.125
12 843099.25
13 682586.5
14 570584.4375
15 486853.5625
16 420816.90625
17 366950.5625
18 322126.8125
19 284263.53125
20 251895.90625
21 224018.546875
22 199888.359375
23 178906.3125
24 160560.65625
25 144465.953125
26 130300.6015625
27 117807.4921875
28 106726.5546875
29 96879.2578125
30 88105.546875
31 80274.421875
32 73262.8359375
33 66968.2265625
34 61308.4375
35 56216.140625
36 51627.80078125
37 47478.0703125
38 43719.109375
39 40303.37109375
40 37194.58984375
41 34362.28125
42 31780.23828125
43 29422.6328125
44 27266.67578125
45 25291.546875
46 23480.126953125
47 21816.00390625
48 20287.248046875
49 18880.62109375
50 17584.0625
51 16388.623046875
52 15284.7177734375
53 14264.912109375
54 13322.2080078125
55 12449.3505859375
56 11640.5302734375
57 10890.904296875
58 10195.5322265625
59 9549.84765625
60 8949.556640625
61 8391.

423 0.0032636660616844893
424 0.0031628040596842766
425 0.0030665684025734663
426 0.0029743097256869078
427 0.0028835353441536427
428 0.0027930899523198605
429 0.0027107412461191416
430 0.002628244459629059
431 0.002553538652136922
432 0.0024768647272139788
433 0.002399270888417959
434 0.0023285404313355684
435 0.0022599902004003525
436 0.002194951055571437
437 0.002128880936652422
438 0.0020655388943850994
439 0.00200473191216588
440 0.0019476930610835552
441 0.0018893115920946002
442 0.0018355398206040263
443 0.0017850252334028482
444 0.0017341901548206806
445 0.0016853944398462772
446 0.0016395312268286943
447 0.0015934089897200465
448 0.001551691791974008
449 0.0015069692162796855
450 0.0014636326814070344
451 0.0014257528819143772
452 0.0013867653906345367
453 0.0013488081749528646
454 0.001313907327130437
455 0.0012794842477887869
456 0.0012460962170735002
457 0.0012100809253752232
458 0.0011788943083956838
459 0.0011467678705230355
460 0.0011163947638124228
461 0.001087639946490

# PyTorch: nn

When building neural networks we frequently think of arranging the computation into **layers**, some of which have **learnable parameters** which will be optimized during learning.

In PyTorch, the **nn** package serves this purpose. The nn package defines a set of **Modules**, which are roughly equivalent to *neural network layers*. A Module receives input Tensors and computes output Tensors, but may also hold internal state such as Tensors containing learnable parameters. The nn package also defines a set of useful loss functions that are commonly used when training neural networks.

In [11]:
# Two-layer network built from torch.nn Modules and trained with a
# hand-written gradient-descent step over model.parameters().
device = torch.device('cpu')

N, D_in, H, D_out = 64, 1000, 100, 10
dtype = torch.float

# Random inputs and regression targets.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Linear -> ReLU -> Linear; each Linear layer holds a weight and a bias.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
).to(device)

# Summed (not averaged) squared error. reduction='sum' is the supported
# spelling of the deprecated size_average=False argument.
loss_fn = torch.nn.MSELoss(reduction='sum')

learning_rate = 1e-4
for t in range(500):
    # Forward pass and loss, printed once per step.
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Clear gradients left over from the previous step before backward().
    model.zero_grad()

    # Backward pass: computes the gradient of the loss with respect to every
    # learnable parameter of the model.
    loss.backward()

    # Update each parameter in place. torch.no_grad() already keeps the
    # update out of the graph, so going through param.data is unnecessary.
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad

0 684.3977661132812
1 630.6233520507812
2 584.8399047851562
3 545.417236328125
4 510.755859375
5 479.97943115234375
6 452.16815185546875
7 426.7303161621094
8 403.4039611816406
9 381.9970397949219
10 362.0254821777344
11 343.37371826171875
12 325.68609619140625
13 309.110595703125
14 293.5045166015625
15 278.7313232421875
16 264.6668701171875
17 251.2382354736328
18 238.4595184326172
19 226.25018310546875
20 214.5805206298828
21 203.4050750732422
22 192.7078857421875
23 182.54283142089844
24 172.84783935546875
25 163.5914306640625
26 154.750732421875
27 146.29690551757812
28 138.2255859375
29 130.53326416015625
30 123.21186828613281
31 116.25303649902344
32 109.64311218261719
33 103.36421203613281
34 97.39875793457031
35 91.74606323242188
36 86.38350677490234
37 81.30864715576172
38 76.50733184814453
39 71.97425842285156
40 67.68302154541016
41 63.630184173583984
42 59.8081169128418
43 56.203956604003906
44 52.81985855102539
45 49.638519287109375
46 46.64120101928711
47 43.819381713867

372 0.0001576106733409688
373 0.00015363056445494294
374 0.0001497528573963791
375 0.00014598784036934376
376 0.00014231790555641055
377 0.00013874594878870994
378 0.00013527055853046477
379 0.00013188490993343294
380 0.00012858898844569921
381 0.0001253804366569966
382 0.0001222524733748287
383 0.00011921197437914088
384 0.00011625051411101595
385 0.00011336465831845999
386 0.00011055084905819967
387 0.00010781348828459159
388 0.00010514898895053193
389 0.00010255022789351642
390 0.0001000207630568184
391 9.75562070379965e-05
392 9.515850979369134e-05
393 9.282132668886334e-05
394 9.054211841430515e-05
395 8.832311141304672e-05
396 8.61580774653703e-05
397 8.40540014905855e-05
398 8.200037700589746e-05
399 7.999984518392012e-05
400 7.804857159499079e-05
401 7.61500486987643e-05
402 7.430143159581348e-05
403 7.249343616422266e-05
404 7.073766755638644e-05
405 6.902222230564803e-05
406 6.735320494044572e-05
407 6.572420534212142e-05
408 6.413731898646802e-05
409 6.259329529711977e-05
41

In [28]:
# Show the shape of every learnable parameter held by the model, each
# followed by a blank line.
for param_size in (param.size() for param in model.parameters()):
    print(param_size, '\n')

torch.Size([100, 1000]) 

torch.Size([100]) 

torch.Size([10, 100]) 

torch.Size([10]) 



# PyTorch: optim

In [29]:
# Two-layer network trained with an optimizer from torch.optim instead of a
# hand-written parameter update.
N, D_in, H, D_out = 64, 1000, 100, 10

dtype = torch.float
# Defined before the data so inputs, targets and model share one device.
device = torch.device('cpu')

# Random inputs and regression targets.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Linear -> ReLU -> Linear.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
).to(device)

# Summed (not averaged) squared error. reduction='sum' is the supported
# spelling of the deprecated size_average=False argument.
loss_fn = torch.nn.MSELoss(reduction='sum')

learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(500):
    y_pred = model(x)  # Forward pass

    loss = loss_fn(y_pred, y)
    print(epoch, loss.item())

    # Clear gradients from the previous step before computing new ones.
    optimizer.zero_grad()

    loss.backward()  # Backprop (backward pass)

    # Let the optimizer update every parameter it was given.
    optimizer.step()

0 674.7896728515625
1 657.497802734375
2 640.7019653320312
3 624.3692016601562
4 608.525146484375
5 593.1459350585938
6 578.2470092773438
7 563.7449340820312
8 549.7078247070312
9 536.0301513671875
10 522.7205810546875
11 509.83123779296875
12 497.3307800292969
13 485.2369079589844
14 473.5286865234375
15 462.1333923339844
16 451.0416564941406
17 440.2685241699219
18 429.7782287597656
19 419.59405517578125
20 409.7344055175781
21 400.174072265625
22 390.85308837890625
23 381.7866516113281
24 373.00115966796875
25 364.47210693359375
26 356.1755065917969
27 348.0770568847656
28 340.25701904296875
29 332.73968505859375
30 325.3472595214844
31 318.10906982421875
32 311.0151062011719
33 304.0462646484375
34 297.21734619140625
35 290.5299072265625
36 284.0111083984375
37 277.65673828125
38 271.4178466796875
39 265.3093566894531
40 259.3272399902344
41 253.47862243652344
42 247.7609100341797
43 242.16700744628906
44 236.6873779296875
45 231.33023071289062
46 226.06214904785156
47 220.91255187

383 0.0002773124142549932
384 0.00026221232837997377
385 0.000247920339461416
386 0.0002343969390494749
387 0.00022158610227052122
388 0.00020945884170942008
389 0.00019798232824541628
390 0.00018710379663389176
391 0.0001768125221133232
392 0.00016707654867786914
393 0.00015786885342095047
394 0.00014914733765181154
395 0.00014089983596932143
396 0.00013309378118719906
397 0.00012571272964123636
398 0.00011872866161866114
399 0.00011212610843358561
400 0.00010587689030217007
401 9.997258894145489e-05
402 9.43797203944996e-05
403 8.909960888558999e-05
404 8.410928421653807e-05
405 7.938402995932847e-05
406 7.492578879464418e-05
407 7.070764695527032e-05
408 6.672004383290187e-05
409 6.295695493463427e-05
410 5.939501352258958e-05
411 5.603269164566882e-05
412 5.285347288008779e-05
413 4.985680789104663e-05
414 4.701704892795533e-05
415 4.4342392357066274e-05
416 4.181148688076064e-05
417 3.9422098780050874e-05
418 3.716950232046656e-05
419 3.503640982671641e-05
420 3.302521872683428e-0

# PyTorch: Custom nn Modules

In [30]:
class TwoLayerNet(torch.nn.Module):
    """Two Linear layers with a ReLU (clamp at zero) between them."""

    def __init__(self, D_in, H, D_out):
        """
        Create the two Linear layers.

        D_in is the input width, H the hidden width and D_out the output
        width.
        """
        super().__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """Map x through linear1, clamp negatives to zero, then linear2."""
        hidden = torch.clamp(self.linear1(x), min=0)
        return self.linear2(hidden)
    
# Train the custom TwoLayerNet Module on random data with Adam.
N, D_in, H, D_out = 64, 1000, 100, 10
dtype = torch.float
device = torch.device('cpu')

# Random inputs and regression targets, placed on the chosen device.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# The model goes on the same device as the data, so changing `device` is
# the only edit needed to run elsewhere.
model = TwoLayerNet(D_in, H, D_out).to(device)

# Summed (not averaged) squared error. reduction='sum' is the supported
# spelling of the deprecated size_average=False argument.
loss_fn = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(500):
    # Forward pass and loss, printed once per epoch.
    y_pred = model(x)

    loss = loss_fn(y_pred, y)
    print(epoch, loss.item())

    # Clear old gradients, backpropagate, then let Adam update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 741.2385864257812
1 723.5322875976562
2 706.2889404296875
3 689.5057983398438
4 673.1799926757812
5 657.3681030273438
6 642.031494140625
7 627.0531616210938
8 612.4808959960938
9 598.3419799804688
10 584.5706787109375
11 571.1962280273438
12 558.2335205078125
13 545.6686401367188
14 533.4560546875
15 521.5656127929688
16 509.9730224609375
17 498.7438049316406
18 487.8216552734375
19 477.1598205566406
20 466.762451171875
21 456.62158203125
22 446.797119140625
23 437.27734375
24 427.9966735839844
25 418.9244689941406
26 410.0148620605469
27 401.3161926269531
28 392.8056640625
29 384.4886474609375
30 376.36639404296875
31 368.411376953125
32 360.6177062988281
33 352.9718933105469
34 345.4583740234375
35 338.0610046386719
36 330.7948303222656
37 323.6539306640625
38 316.6589660644531
39 309.8097839355469
40 303.08642578125
41 296.4962463378906
42 290.025146484375
43 283.6714782714844
44 277.4303894042969
45 271.29345703125
46 265.2447204589844
47 259.2868347167969
48 253.42575073242188
4

374 0.0001561605604365468
375 0.0001491261791670695
376 0.00014241502503864467
377 0.00013601100363302976
378 0.0001299079303862527
379 0.0001240887213498354
380 0.00011852701572934166
381 0.00011322968930471689
382 0.00010816362919285893
383 0.0001033312946674414
384 9.87238236120902e-05
385 9.432495426153764e-05
386 9.01198509382084e-05
387 8.611012890469283e-05
388 8.227946818806231e-05
389 7.86228192737326e-05
390 7.513109449064359e-05
391 7.179170643212274e-05
392 6.86145358486101e-05
393 6.55680923955515e-05
394 6.266170385060832e-05
395 5.988304837956093e-05
396 5.723603317164816e-05
397 5.470384348882362e-05
398 5.228301961324178e-05
399 4.99686757393647e-05
400 4.776086279889569e-05
401 4.564904884318821e-05
402 4.362952313385904e-05
403 4.170315514784306e-05
404 3.9858299714978784e-05
405 3.809448389802128e-05
406 3.64117331628222e-05
407 3.4802094887709245e-05
408 3.32623822032474e-05
409 3.1791125365998596e-05
410 3.038307659153361e-05
411 2.904094617406372e-05
412 2.775391

# PyTorch: Control Flow + Weight Sharing

In [35]:
class DynamicNet(torch.nn.Module):
    """
    Network whose depth is chosen at random on every forward pass.

    The same middle_linear layer is applied between zero and three times,
    so its weights are shared across all of those applications.
    """

    def __init__(self, D_in, H, D_out):
        """
        Create the input, (shared) middle and output Linear layers.

        D_in is the input width, H the hidden width and D_out the output
        width.
        """
        super().__init__()
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        Run x through the input layer, a random number of middle layers and
        the output layer, clamping negatives to zero after every hidden layer.
        """
        hidden = torch.clamp(self.input_linear(x), min=0)
        # One draw per forward pass: 0, 1, 2 or 3 passes through the shared
        # middle layer.
        remaining = random.randint(0, 3)
        while remaining > 0:
            hidden = torch.clamp(self.middle_linear(hidden), min=0)
            remaining -= 1
        return self.output_linear(hidden)
    
# Train DynamicNet on random data with SGD plus momentum.
N, D_in, H, D_out = 64, 1000, 100, 10
dtype = torch.float
device = torch.device('cpu')

# Random inputs and regression targets, placed on the chosen device.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# The model goes on the same device as the data, so changing `device` is
# the only edit needed to run elsewhere.
model = DynamicNet(D_in, H, D_out).to(device)

# Summed (not averaged) squared error. reduction='sum' is the supported
# spelling of the deprecated size_average=False argument.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

for epoch in range(500):
    # Forward pass and loss, printed once per epoch.
    y_pred = model(x)

    loss = criterion(y_pred, y)
    print(epoch, loss.item())

    # Clear old gradients, backpropagate, then take one optimizer step.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 645.0264282226562
1 635.6986083984375
2 668.4109497070312
3 605.0932006835938
4 587.5155639648438
5 547.5601806640625
6 644.6016845703125
7 445.21600341796875
8 532.16748046875
9 340.1044921875
10 636.3773803710938
11 631.947509765625
12 216.30519104003906
13 468.6410217285156
14 447.1859130859375
15 634.4487915039062
16 585.002685546875
17 128.39662170410156
18 120.14215850830078
19 324.0041809082031
20 96.63548278808594
21 588.6312866210938
22 260.7765808105469
23 74.43215942382812
24 543.7608642578125
25 420.6781005859375
26 62.53512191772461
27 463.0618591308594
28 344.4100036621094
29 319.8106994628906
30 290.9010009765625
31 261.0998229980469
32 314.343505859375
33 277.7725830078125
34 214.62258911132812
35 221.85061645507812
36 120.69895935058594
37 107.068359375
38 282.34393310546875
39 57.499183654785156
40 49.56890106201172
41 403.3985900878906
42 277.38604736328125
43 230.16119384765625
44 372.1958923339844
45 167.9290313720703
46 124.01951599121094
47 75.01535034179688
48

455 0.8101961016654968
456 0.20547448098659515
457 0.47768574953079224
458 0.46305638551712036
459 0.6589791178703308
460 0.6462878584861755
461 0.08529990166425705
462 0.6077752709388733
463 0.39241823554039
464 0.08763162791728973
465 0.08111317455768585
466 0.06987768411636353
467 0.40240201354026794
468 0.32480597496032715
469 0.5442531704902649
470 0.2962420582771301
471 0.039496030658483505
472 0.03862518444657326
473 0.036267444491386414
474 0.23205174505710602
475 0.20748737454414368
476 0.04033225029706955
477 0.6617398262023926
478 0.41631442308425903
479 0.06875152140855789
480 0.07062743604183197
481 0.05899016559123993
482 0.040962621569633484
483 0.02651415579020977
484 0.021035211160779
485 0.46564170718193054
486 0.7916224598884583
487 0.2361866980791092
488 0.20484541356563568
489 0.3714633285999298
490 1.0481152534484863
491 0.22264577448368073
492 0.19921651482582092
493 0.04326990246772766
494 0.7270206809043884
495 0.6710115671157837
496 0.036094240844249725
497 0.