## Tensors

### Simple Two-layer Network

#### NumPy version

In [16]:
import numpy as np

In [2]:
# batch_size
N = 64
# input dimension
D_in = 1000
# hidden dimension
H = 100
# output dimension
D_out = 10

# Fix NumPy's global RNG so the random data and weights drawn in the
# following cells are identical on every Restart & Run All.
SEED = 0
np.random.seed(SEED)

In [3]:
# Random inputs (N x D_in) and targets (N x D_out); x is drawn before y.
x, y = (np.random.randn(N, dim) for dim in (D_in, D_out))

In [5]:
# Random starting weights: w1 maps inputs to the hidden layer (D_in x H),
# w2 maps hidden units to outputs (H x D_out). w1 is drawn before w2.
w1, w2 = (np.random.randn(*shape) for shape in ((D_in, H), (H, D_out)))

In [6]:
# Step size for the hand-written gradient-descent updates below.
l_rate = 0.000001  # i.e. 1e-6

In [14]:
# Train for 500 steps with hand-derived backpropagation.
for t in range(500):
    # Forward pass: linear -> ReLU -> linear.
    h = np.dot(x, w1)
    h_relu = np.maximum(h, 0)
    y_pred = np.dot(h_relu, w2)

    # Sum-of-squares loss, printed every step.
    loss = ((y_pred - y) ** 2).sum()
    print(t, loss)

    # Backward pass (chain rule), from the loss back to each weight.
    dy_pred = 2.0 * (y_pred - y)          # dloss/dy_pred
    dw2 = np.dot(h_relu.T, dy_pred)       # dloss/dw2
    dh_relu = np.dot(dy_pred, w2.T)       # dloss/dh_relu
    dh = np.where(h < 0, 0, dh_relu)      # no gradient where the ReLU input was negative
    dw1 = np.dot(x.T, dh)                 # dloss/dw1

    # In-place gradient-descent update.
    w1 -= l_rate * dw1
    w2 -= l_rate * dw2

0 1.35740119993e-14
1 1.30322855795e-14
2 1.2512253469e-14
3 1.20130453784e-14
4 1.15336905293e-14
5 1.10735154852e-14
6 1.06315764249e-14
7 1.02073861085e-14
8 9.80018623413e-15
9 9.40908792307e-15
10 9.0337539647e-15
11 8.67334139097e-15
12 8.32728430302e-15
13 7.99516350986e-15
14 7.67627820874e-15
15 7.37005158691e-15
16 7.07614951654e-15
17 6.79385891218e-15
18 6.52286026909e-15
19 6.26266211429e-15
20 6.01290283038e-15
21 5.77310178e-15
22 5.54288765523e-15
23 5.32185042365e-15
24 5.10958948847e-15
25 4.90582077749e-15
26 4.710214722e-15
27 4.52243211205e-15
28 4.34215719741e-15
29 4.16907218759e-15
30 4.00289824361e-15
31 3.84329249702e-15
32 3.69009964495e-15
33 3.54302418243e-15
34 3.40177131132e-15
35 3.26612924084e-15
36 3.13591959389e-15
37 3.01091410892e-15
38 2.89095484629e-15
39 2.77574646937e-15
40 2.66509184e-15
41 2.5588625503e-15
42 2.45685615311e-15
43 2.35896800421e-15
44 2.2649907279e-15
45 2.17471383307e-15
46 2.08805930876e-15
47 2.00486261683e-15
48 1.924969943

459 5.65966336273e-22
460 5.56301493944e-22
461 5.4681224521e-22
462 5.35268197084e-22
463 5.27090498223e-22
464 5.19346075363e-22
465 5.09408731537e-22
466 4.99686572948e-22
467 4.90828183723e-22
468 4.82957225736e-22
469 4.7447716956e-22
470 4.66532723531e-22
471 4.61349856943e-22
472 4.51661460221e-22
473 4.45165686335e-22
474 4.37748657326e-22
475 4.28678921428e-22
476 4.20914151016e-22
477 4.14142404793e-22
478 4.07818053657e-22
479 4.01468221703e-22
480 3.94504150549e-22
481 3.90268306259e-22
482 3.84754139946e-22
483 3.77363330458e-22
484 3.72081121443e-22
485 3.65384646645e-22
486 3.58937409106e-22
487 3.52440214937e-22
488 3.478030985e-22
489 3.42656649021e-22
490 3.36482023816e-22
491 3.32003364349e-22
492 3.26795480303e-22
493 3.22849441593e-22
494 3.18630321356e-22
495 3.13384926063e-22
496 3.09332767044e-22
497 3.04809392302e-22
498 2.99871994939e-22
499 2.96734903423e-22


#### PyTorch version

In [15]:
import torch

In [17]:
# Seed PyTorch's RNG so the tensors drawn in the cells below are reproducible.
torch.manual_seed(0)

# Run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

In [18]:
# Random inputs and targets on the chosen device; x is drawn before y.
x, y = (torch.randn(N, dim).type(dtype) for dim in (D_in, D_out))

In [19]:
# Random weights on the chosen device (gradients are computed by hand below).
w1, w2 = (torch.randn(*shape).type(dtype) for shape in ((D_in, H), (H, D_out)))

In [20]:
for t in range(500):
    # Forward pass: linear -> ReLU (clamp at 0) -> linear.
    h = x.mm(w1)
    h_relu = h.clamp(min=0)
    y_pred = h_relu.mm(w2)

    # Sum-of-squares loss. On PyTorch >= 0.4 ``.sum()`` returns a 0-dim
    # tensor, so ``.item()`` is used to print a plain Python number.
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())

    # Manual backward pass, mirroring the NumPy version.
    dy_pred = 2.0 * (y_pred - y)      # dloss/dy_pred
    dw2 = h_relu.t().mm(dy_pred)      # dloss/dw2
    dh_relu = dy_pred.mm(w2.t())      # dloss/dh_relu
    dh = dh_relu.clone()
    dh[h < 0] = 0                     # no gradient where the ReLU input was negative
    dw1 = x.t().mm(dh)                # dloss/dw1

    # In-place gradient-descent update.
    w1 -= l_rate * dw1
    w2 -= l_rate * dw2

0 37481752.0
1 35900668.0
2 35247400.0
3 30252724.0
4 21056786.0
5 12072196.0
6 6298048.0
7 3407320.0
8 2075449.25
9 1435406.0
10 1088348.625
11 871592.0
12 719859.5
13 605315.0
14 514988.0625
15 441841.40625
16 381608.96875
17 331415.03125
18 289237.6875
19 253528.78125
20 223126.609375
21 197077.609375
22 174652.828125
23 155273.84375
24 138434.03125
25 123743.96875
26 110889.734375
27 99593.6171875
28 89650.6640625
29 80869.4453125
30 73095.2734375
31 66188.9140625
32 60031.92578125
33 54530.21875
34 49610.171875
35 45197.93359375
36 41236.4375
37 37667.3203125
38 34446.671875
39 31535.474609375
40 28898.640625
41 26509.45703125
42 24341.59765625
43 22370.515625
44 20575.9921875
45 18940.884765625
46 17448.595703125
47 16085.0966796875
48 14838.916015625
49 13698.37890625
50 12653.3369140625
51 11694.9189453125
52 10815.2431640625
53 10007.2119140625
54 9264.7919921875
55 8582.439453125
56 7954.10302734375
57 7374.9990234375
58 6841.16943359375
59 6348.56640625
60 5893.85400390625
6

471 3.400672721909359e-05
472 3.347924939589575e-05
473 3.315187495900318e-05
474 3.2804782676976174e-05
475 3.2227289921138436e-05
476 3.187752372468822e-05
477 3.15300676447805e-05
478 3.122558700852096e-05
479 3.0961367883719504e-05
480 3.074560663662851e-05
481 3.0309931389638223e-05
482 2.9964530767756514e-05
483 2.9684622859349474e-05
484 2.9206499675638042e-05
485 2.899366882047616e-05
486 2.8616799681913108e-05
487 2.836798012140207e-05
488 2.818137545546051e-05
489 2.796404078253545e-05
490 2.774213135126047e-05
491 2.7414056603447534e-05
492 2.7121288439957425e-05
493 2.686037441890221e-05
494 2.665920510480646e-05
495 2.6381158022559248e-05
496 2.6050627639051527e-05
497 2.5725832529133186e-05
498 2.5666971851023845e-05
499 2.533271981519647e-05


## Autograd

### Variables and autograd

In [21]:
from torch.autograd import Variable

In [22]:
# Fresh data wrapped as Variables that autograd does not track
# (x is N x D_in, y is N x D_out; x is drawn first).
x, y = (Variable(torch.randn(N, dim).type(dtype), requires_grad=False)
        for dim in (D_in, D_out))

In [23]:
# Learnable weights: requires_grad=True makes backward() populate w1.grad and
# w2.grad. w1 is drawn before w2.
w1, w2 = (Variable(torch.randn(*shape).type(dtype), requires_grad=True)
          for shape in ((D_in, H), (H, D_out)))

In [25]:
for t in range(500):
    # Forward pass in one expression; autograd records the ops on w1/w2.
    y_pred = x.mm(w1).clamp(min=0).mm(w2)

    loss = (y_pred - y).pow(2).sum()
    # ``loss`` is a 0-dim tensor on PyTorch >= 0.4, where ``loss.data[0]``
    # fails with IndexError; ``.item()`` returns the value as a Python number.
    print(t, loss.item())

    # Autograd fills w1.grad and w2.grad.
    loss.backward()

    # Update through ``.data`` so the step itself is not recorded by autograd.
    w1.data -= l_rate * w1.grad.data
    w2.data -= l_rate * w2.grad.data

    # manually zero gradients after updating weights (backward() accumulates)
    w1.grad.data.zero_()
    w2.grad.data.zero_()

0 0.0008918086532503366
1 0.0008707644301466644
2 0.000849952339194715
3 0.0008299434557557106
4 0.0008082665153779089
5 0.0007902406505309045
6 0.0007714335806667805
7 0.0007531735463999212
8 0.000736358342692256
9 0.0007198908133432269
10 0.000702892430126667
11 0.0006873729871585965
12 0.0006711854366585612
13 0.0006555626168847084
14 0.0006407683831639588
15 0.0006260634982027113
16 0.0006126239313744009
17 0.0005981423310004175
18 0.0005851310561411083
19 0.0005719797336496413
20 0.0005599421565420926
21 0.0005478112725540996
22 0.0005366306868381798
23 0.0005249556852504611
24 0.0005138044944033027
25 0.0005020102835260332
26 0.0004923974629491568
27 0.000480982125736773
28 0.0004706983454525471
29 0.00046181498328223825
30 0.0004523643001448363
31 0.00044246422476135194
32 0.00043421966256573796
33 0.00042526950710453093
34 0.0004167452279943973
35 0.0004076504847034812
36 0.000400115386582911
37 0.0003921192546840757
38 0.0003840464341919869
39 0.00037674664054065943
40 0.00036

330 1.624300239200238e-05
331 1.611813968338538e-05
332 1.6080552086350508e-05
333 1.6023361240513623e-05
334 1.5937144780764356e-05
335 1.589963176229503e-05
336 1.5849607734708115e-05
337 1.5785528376000002e-05
338 1.5638521290384233e-05
339 1.550542401673738e-05
340 1.5368790627690032e-05
341 1.5390924090752378e-05
342 1.5208646800601855e-05
343 1.508814330009045e-05
344 1.505184354755329e-05
345 1.49479174069711e-05
346 1.4911550351826008e-05
347 1.4792484762438107e-05
348 1.469411017751554e-05
349 1.4621769878431223e-05
350 1.4588833437301219e-05
351 1.4468089830188546e-05
352 1.4404271496459842e-05
353 1.4375848877534736e-05
354 1.4310137885331642e-05
355 1.4232571629690938e-05
356 1.4147745787340682e-05
357 1.4058779015613254e-05
358 1.403526130161481e-05
359 1.4003062460687943e-05
360 1.3840744941262528e-05
361 1.3777588719676714e-05
362 1.3723176380153745e-05
363 1.3615400348498952e-05
364 1.3561004379880615e-05
365 1.3533030141843483e-05
366 1.3467341887007933e-05
367 1.33587

### Custom autograd functions

In [26]:
class MyReLU(torch.autograd.Function):
    """ReLU implemented as a custom autograd function.

    Written in the static ``forward(ctx, ...)`` / ``backward(ctx, ...)`` style
    and invoked as ``MyReLU.apply(tensor)``. Current PyTorch releases refuse
    to run instance-style (non-static) autograd functions.
    """

    @staticmethod
    def forward(ctx, input):
        # Save the input: backward needs it to know where the ReLU was active.
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        inp, = ctx.saved_tensors
        # Pass the incoming gradient through, except where the input was
        # negative (the output was the constant 0 there).
        grad_input = grad_output.clone()
        grad_input[inp < 0] = 0
        return grad_input

    def __call__(self, *args):
        # Compatibility shim: legacy callers that instantiate the function
        # (``relu = MyReLU(); relu(x)``) are routed to the static ``apply``.
        return type(self).apply(*args)

In [27]:
# Custom autograd functions are invoked through ``apply`` on the class;
# instantiating and calling them is rejected by current PyTorch releases.
# The callable does not change between iterations, so look it up once.
relu = MyReLU.apply

for t in range(500):
    y_pred = relu(x.mm(w1)).mm(w2)

    loss = (y_pred - y).pow(2).sum()
    # ``.item()`` replaces ``loss.data[0]``, which fails on 0-dim tensors.
    print(t, loss.item())

    loss.backward()

    # Gradient-descent step outside of autograd's recording.
    w1.data -= l_rate * w1.grad.data
    w2.data -= l_rate * w2.grad.data

    # Reset gradients, which would otherwise accumulate across iterations.
    w1.grad.data.zero_()
    w2.grad.data.zero_()

0 7.877531061240006e-06
1 7.856249794713221e-06
2 7.810985152900685e-06
3 7.790823474351782e-06
4 7.746594747004565e-06
5 7.74355521571124e-06
6 7.720849680481479e-06
7 7.710377758485265e-06
8 7.673384970985353e-06
9 7.662734788027592e-06
10 7.641466254426632e-06
11 7.601269317092374e-06
12 7.56986719352426e-06
13 7.57312454879866e-06
14 7.573107723146677e-06
15 7.555689990113024e-06
16 7.493642442568671e-06
17 7.464402187906671e-06
18 7.4198260335833766e-06
19 7.415371328534093e-06
20 7.407283646898577e-06
21 7.407648354273988e-06
22 7.356275546044344e-06
23 7.339926014537923e-06
24 7.340760021179449e-06
25 7.296692274394445e-06
26 7.239095793920569e-06
27 7.2105003710021265e-06
28 7.17982038622722e-06
29 7.147733867896022e-06
30 7.136844942579046e-06
31 7.102513791323872e-06
32 7.077278951328481e-06
33 7.017298230493907e-06
34 7.010160516074393e-06
35 7.024937076494098e-06
36 6.9935604187776335e-06
37 6.93881247570971e-06
38 6.900280823174398e-06
39 6.8995914261904545e-06
40 6.890888

370 3.5231580568506615e-06
371 3.529979039740283e-06
372 3.5275811569590587e-06
373 3.5319808375788853e-06
374 3.5133680285071023e-06
375 3.5092946291115368e-06
376 3.489018354230211e-06
377 3.490544486339786e-06
378 3.485079560050508e-06
379 3.4750826216622954e-06
380 3.4648958262550877e-06
381 3.4613647130754543e-06
382 3.4506633710407186e-06
383 3.4300585411983775e-06
384 3.4283596050954657e-06
385 3.4169438549724873e-06
386 3.402037236810429e-06
387 3.4040494938381016e-06
388 3.382293698450667e-06
389 3.3687751965771895e-06
390 3.374095740582561e-06
391 3.3717490168783115e-06
392 3.37333085553837e-06
393 3.36882158080698e-06
394 3.374185553184361e-06
395 3.374560947122518e-06
396 3.366476221344783e-06
397 3.3604405871301424e-06
398 3.3560786505404394e-06
399 3.339800741741783e-06
400 3.3300038921879604e-06
401 3.3388955671398435e-06
402 3.350563247295213e-06
403 3.351321993250167e-06
404 3.342101990710944e-06
405 3.3242113204323687e-06
406 3.324488034195383e-06
407 3.32170725414471

## PyTorch modules

### nn

In [33]:
# New CPU data for the module-based sections; neither tensor needs gradients.
x, y = (Variable(torch.randn(N, dim), requires_grad=False)
        for dim in (D_in, D_out))
# Step size used by the nn and optim sections below.
l_rate = 1e-4

In [34]:
# Two-layer network assembled from built-in modules: Linear -> ReLU -> Linear.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

# Summed (not averaged) squared error, matching the manual loss above.
# ``reduction='sum'`` replaces the deprecated ``size_average=False``.
loss_fn = torch.nn.MSELoss(reduction='sum')

In [35]:
for t in range(500):
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    # ``.item()`` extracts the 0-dim loss as a Python number
    # (``loss.data[0]`` fails with IndexError on current PyTorch).
    print(t, loss.item())

    # zero gradients before running backprop
    model.zero_grad()

    loss.backward()

    # update weights using gradient descent
    for param in model.parameters():
        param.data -= l_rate * param.grad.data

0 677.706298828125
1 628.9736938476562
2 586.2559204101562
3 548.93115234375
4 515.4884643554688
5 485.2983093261719
6 457.73486328125
7 432.3388671875
8 408.7256774902344
9 386.5161437988281
10 365.687255859375
11 346.1318054199219
12 327.6876525878906
13 310.21624755859375
14 293.6852111816406
15 277.93975830078125
16 262.9082336425781
17 248.56341552734375
18 234.9402313232422
19 221.99424743652344
20 209.7062225341797
21 197.9618682861328
22 186.72474670410156
23 176.01364135742188
24 165.84043884277344
25 156.1518096923828
26 146.9368438720703
27 138.1688690185547
28 129.84854125976562
29 121.96017456054688
30 114.51398468017578
31 107.46424102783203
32 100.80532836914062
33 94.5455322265625
34 88.63758087158203
35 83.07050323486328
36 77.82722473144531
37 72.90787506103516
38 68.28022766113281
39 63.9354133605957
40 59.86166000366211
41 56.04523468017578
42 52.47010803222656
43 49.122554779052734
44 45.989566802978516
45 43.063655853271484
46 40.32310104370117
47 37.7552261352539

402 0.00012057117419317365
403 0.00011742235801648349
404 0.0001143579647759907
405 0.00011137212277390063
406 0.00010846913210116327
407 0.00010564154217718169
408 0.00010289011697750539
409 0.00010020809713751078
410 9.76032461039722e-05
411 9.506835340289399e-05
412 9.259585203835741e-05
413 9.01893072295934e-05
414 8.784539386397228e-05
415 8.556349348509684e-05
416 8.334479935001582e-05
417 8.117983816191554e-05
418 7.907625695224851e-05
419 7.702804578002542e-05
420 7.503369852202013e-05
421 7.309295324375853e-05
422 7.119898509699851e-05
423 6.935634155524895e-05
424 6.756274524377659e-05
425 6.581558409379795e-05
426 6.41138685750775e-05
427 6.245783879421651e-05
428 6.0846650740131736e-05
429 5.927741221967153e-05
430 5.77474529563915e-05
431 5.625821358989924e-05
432 5.4809115681564435e-05
433 5.3397448937175795e-05
434 5.20211506227497e-05
435 5.068674363428727e-05
436 4.9381094868294895e-05
437 4.810952668776736e-05
438 4.687574983108789e-05
439 4.566995266941376e-05
440 4.

### optim

In [36]:
# NOTE: ``model`` still holds the weights trained in the previous section, so Adam continues from there.
optimizer = torch.optim.Adam(params=model.parameters(), lr=l_rate)

In [38]:
for t in range(500):
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    # ``.item()`` replaces ``loss.data[0]``, which fails on 0-dim tensors.
    print(t, loss.item())

    # Clear old gradients, backprop, then let the optimizer update the weights.
    optimizer.zero_grad()

    loss.backward()

    optimizer.step()

0 9.522657819616143e-06
1 0.04931570217013359
2 0.03335961699485779
3 0.023293310776352882
4 0.02058095484972
5 0.017825264483690262
6 0.016609590500593185
7 0.014975471422076225
8 0.011265886016190052
9 0.008477889001369476
10 0.007836366072297096
11 0.007485910318791866
12 0.006875071674585342
13 0.006385214626789093
14 0.005870559718459845
15 0.0053136409260332584
16 0.0047590541653335094
17 0.004097468685358763
18 0.0034378813579678535
19 0.0029390035197138786
20 0.002657304983586073
21 0.002578489016741514
22 0.002547033131122589
23 0.0024029735941439867
24 0.0021550192032009363
25 0.0018651385325938463
26 0.0015419587725773454
27 0.0012677102349698544
28 0.0011317543685436249
29 0.0011208205251023173
30 0.0011639350559562445
31 0.0011490958277136087
32 0.0009865507017821074
33 0.0007577264332212508
34 0.0006146169034764171
35 0.0005787246627733111
36 0.0005818786448799074
37 0.0005730177508667111
38 0.0005330974236130714
39 0.00045983868767507374
40 0.00038508576108142734
41 0.00

338 3.0667906685266644e-05
339 2.1918815036769956e-05
340 1.9223702111048624e-05
341 2.48059459408978e-05
342 3.5290508094476536e-05
343 4.5645505451830104e-05
344 5.3689447668148205e-05
345 6.201536598382518e-05
346 7.7054399298504e-05
347 0.00010715977987274528
348 0.00016202150436583906
349 0.0002541264984756708
350 0.00040067537338472903
351 0.0006205260287970304
352 0.0009167782845906913
353 0.0012359624961391091
354 0.0014303637435659766
355 0.0013171805767342448
356 0.0008946536108851433
357 0.0004664189473260194
358 0.0003788005269598216
359 0.0006439123535528779
360 0.0009282921673730016
361 0.0008881844696588814
362 0.0005291436682455242
363 0.0002035155484918505
364 0.00018626372911967337
365 0.0003784166765399277
366 0.0004857259336858988
367 0.00037985146627761424
368 0.0002132190129486844
369 0.00018597418966237456
370 0.0002958255645353347
371 0.00037898513255640864
372 0.00033566015190444887
373 0.00023971677001100034
374 0.00021849205950275064
375 0.0002872566110454499

### Custom nn modules

In [39]:
class TwoLayerNet(torch.nn.Module):
    """Linear -> ReLU -> Linear network written as a custom ``nn.Module``.

    Args:
        D_in: number of input features.
        H: number of hidden units.
        D_out: number of output features.
    """

    def __init__(self, D_in, H, D_out):
        super().__init__()
        # Sub-modules assigned as attributes are registered with the module,
        # so their weights show up in ``parameters()``.
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """Map a batch of inputs to predictions."""
        hidden = self.linear1(x)
        return self.linear2(hidden.clamp(min=0))

In [40]:
model = TwoLayerNet(D_in=D_in, H=H, D_out=D_out)  # fresh model with untrained weights

In [41]:
# Summed squared error (``reduction='sum'`` replaces the deprecated
# ``size_average=False``) and plain SGD over the new model's parameters.
loss_fn = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

In [42]:
for t in range(500):
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    # ``.item()`` replaces ``loss.data[0]``, which fails on 0-dim tensors.
    print(t, loss.item())

    # Clear old gradients, backprop, then take an SGD step.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 680.0205688476562
1 630.1744995117188
2 586.8239135742188
3 548.3598022460938
4 513.8623046875
5 482.53118896484375
6 453.8783264160156
7 427.51629638671875
8 403.0627136230469
9 380.3773498535156
10 359.1781921386719
11 339.4368591308594
12 321.051025390625
13 303.6702575683594
14 287.284912109375
15 271.6813659667969
16 256.89898681640625
17 242.8348388671875
18 229.48631286621094
19 216.7857666015625
20 204.75912475585938
21 193.33641052246094
22 182.4900360107422
23 172.21437072753906
24 162.46385192871094
25 153.17462158203125
26 144.3374786376953
27 135.9636993408203
28 128.03106689453125
29 120.51023864746094
30 113.38441467285156
31 106.66439056396484
32 100.32189178466797
33 94.33175659179688
34 88.67887878417969
35 83.34928131103516
36 78.32959747314453
37 73.60810852050781
38 69.17536926269531
39 65.0117416381836
40 61.097312927246094
41 57.41884231567383
42 53.96588897705078
43 50.719032287597656
44 47.67274475097656
45 44.814144134521484
46 42.13180923461914
47 39.617794

## Control Flow + Weight Sharing

In [43]:
import random

In [44]:
class DynamicNet(torch.nn.Module):
    """Network whose depth is picked at random on every forward pass.

    The single ``middle_linear`` layer is applied between 0 and 3 times
    (inclusive), so its weights are shared by every repetition.
    """

    def __init__(self, D_in, H, D_out):
        super().__init__()
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """Apply the input layer, 0-3 shared middle layers, then the output layer."""
        hidden = self.input_linear(x).clamp(min=0)
        depth = random.randint(0, 3)
        for _ in range(depth):
            hidden = self.middle_linear(hidden).clamp(min=0)
        return self.output_linear(hidden)

In [45]:
model = DynamicNet(D_in=D_in, H=H, D_out=D_out)  # fresh model for the dynamic-graph demo

In [46]:
optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-4, momentum=0.9)  # SGD with momentum

In [47]:
for t in range(500):
    # Each call samples a new depth, so the graph differs between iterations.
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    # ``.item()`` replaces ``loss.data[0]``, which fails on 0-dim tensors.
    print(t, loss.item())

    # Clear old gradients, backprop, then take a momentum-SGD step.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 667.3516235351562
1 657.2692260742188
2 647.53515625
3 653.5028686523438
4 650.7179565429688
5 604.3248901367188
6 640.3576049804688
7 652.4511108398438
8 632.6267700195312
9 556.6858520507812
10 635.7134399414062
11 511.6328125
12 462.05572509765625
13 616.6563720703125
14 519.0269165039062
15 506.7500915527344
16 487.2482604980469
17 248.26937866210938
18 436.56805419921875
19 405.7793273925781
20 562.5972900390625
21 601.907470703125
22 152.41273498535156
23 142.45083618164062
24 484.270263671875
25 250.39117431640625
26 108.4697494506836
27 95.09793853759766
28 201.49586486816406
29 66.140869140625
30 361.9932556152344
31 158.037841796875
32 315.7113342285156
33 402.5297546386719
34 358.3771667480469
35 309.770751953125
36 259.4389343261719
37 217.58103942871094
38 195.78826904296875
39 219.44308471679688
40 204.60133361816406
41 179.23680114746094
42 742.9996337890625
43 370.3631896972656
44 272.4890441894531
45 982.8380126953125
46 212.8589324951172
47 467.8999328613281
48 227.

417 1.7168396711349487
418 0.9909673929214478
419 0.1306210458278656
420 1.545019268989563
421 1.1572035551071167
422 0.9889354705810547
423 0.11437994986772537
424 0.9452104568481445
425 0.9604310393333435
426 0.9277898669242859
427 1.384697675704956
428 0.20066851377487183
429 0.2401019036769867
430 0.9576117992401123
431 0.2278365045785904
432 1.8125834465026855
433 0.14241166412830353
434 0.5586472749710083
435 1.0710761547088623
436 0.6599811911582947
437 0.8032755255699158
438 0.901730477809906
439 0.7978502511978149
440 1.3673429489135742
441 0.43701261281967163
442 1.1252562999725342
443 0.2057044953107834
444 1.084328055381775
445 1.004183292388916
446 0.9524351358413696
447 0.12499898672103882
448 0.8037099242210388
449 1.552160382270813
450 1.0779005289077759
451 0.11325645446777344
452 1.0666747093200684
453 0.8537534475326538
454 1.2321127653121948
455 0.7232624292373657
456 0.9656876921653748
457 0.11947818845510483
458 0.6103894710540771
459 0.6500977277755737
460 1.2470