In [1]:
#import
import torch
from torch.autograd import Variable

In [2]:
#N is batch size; D_in is input dimension;
#H is hidden dimension; D is output Dimension
N, D_in, H , D_out = 64,1000,100,10

In [3]:
#Create random Tensors to hold inputs and outputs, and wrap them in Variables.
x = Variable(torch.randn(N,D_in))
y = Variable(torch.randn(N,D_out),requires_grad=False)

In [4]:
#Use the nn package to define our model as a sequence of Layers. nn.Sequential
#is a Module which contains other Modules , and applies them in sequence to 
# produce its output. Each linear module computes output from input using a 
#Linear function, and holds internal Variables for its weight and bias
model = torch.nn.Sequential(
    torch.nn.Linear(D_in,H),
    torch.nn.ReLU(),
    torch.nn.Linear(H,D_out),
)

In [5]:
#NN package also contains definition of pupular loss function in this
# case we will use Mean squared Error (MSE) as out Loss funciton
loss_fn = torch.nn.MSELoss(size_average=False)

In [6]:
#learning Rate
learning_rate = 1e-4

In [8]:
for t in range(5000):
    #Forward pass: compute predicted y by passing x to the model. Module objects
    # override the __call__operation so you can call them like function. when
    #doing so you pass a variable of input data to the module and its produces
    #a variable of output data
    y_pred = model(x)
    
    #Compute and print Loss, We pass Variables containing the predicted and true
    #values of y , and the loss funciton returns a Variable containing the 
    #LOSS.
    loss = loss_fn(y_pred,y)
    print(t, loss.data[0])
    
    #Zero the gradients before running the backward pass.
    model.zero_grad()
    
    #Backward pass: compute gradient of the Loss with respect to all the learnable
    #parameter of the model. Internally, the parameters of each Module are stored
    #in variables with requires_grad=True, so this call will compute gradient for
    #all Learning parameters in model.
    loss.backward()
    
    #Update the wieght using gradient descent. Each parameter is Variable, so 
    #we can acess its data and gradients like we did before
    for param in model.parameters():
        param.data -= learning_rate * param.grad.data
    

0 654.5167846679688
1 605.8529663085938
2 563.2234497070312
3 525.3441162109375
4 491.6210021972656
5 461.2608642578125
6 433.6844482421875
7 408.57781982421875
8 385.2940368652344
9 363.69171142578125
10 343.46484375
11 324.5604248046875
12 306.83905029296875
13 290.1501770019531
14 274.4251403808594
15 259.5061340332031
16 245.35324096679688
17 231.89918518066406
18 219.07740783691406
19 206.93563842773438
20 195.4456024169922
21 184.5517120361328
22 174.21575927734375
23 164.3877716064453
24 155.0297088623047
25 146.13275146484375
26 137.66827392578125
27 129.6401824951172
28 121.99351501464844
29 114.75846099853516
30 107.88538360595703
31 101.39411926269531
32 95.27053833007812
33 89.50056457519531
34 84.07071685791016
35 78.95783233642578
36 74.15543365478516
37 69.64448547363281
38 65.40620422363281
39 61.426753997802734
40 57.69415283203125
41 54.186767578125
42 50.9011344909668
43 47.82767868041992
44 44.944034576416016
45 42.24866485595703
46 39.719669342041016
47 37.34419250

371 0.0004174032073933631
372 0.0004067549598403275
373 0.0003963976923841983
374 0.00038630544440820813
375 0.00037647594581358135
376 0.00036689190892502666
377 0.00035756570287048817
378 0.0003484733751975
379 0.0003396290121600032
380 0.00033100659493356943
381 0.0003226061526220292
382 0.00031441630562767386
383 0.0003064440388698131
384 0.0002986760518979281
385 0.000291117699816823
386 0.00028373877285048366
387 0.0002765553363133222
388 0.0002695524599403143
389 0.00026273398543708026
390 0.0002560920547693968
391 0.00024962457246147096
392 0.00024331797612830997
393 0.00023718128795735538
394 0.00023118896933738142
395 0.00022535797324962914
396 0.00021967300563119352
397 0.00021413627837318927
398 0.00020873959874734282
399 0.00020348530961200595
400 0.00019835980492644012
401 0.00019336405966896564
402 0.00018850491323973984
403 0.00018376408843323588
404 0.00017914464115165174
405 0.0001746439520502463
406 0.00017025355191435665
407 0.00016598220099695027
408 0.000161813833

756 3.74071547071253e-08
757 3.664120029611695e-08
758 3.582488972142528e-08
759 3.5031785472483534e-08
760 3.4345525534718035e-08
761 3.3570707103081077e-08
762 3.2874325484044675e-08
763 3.216122834714952e-08
764 3.148015181864139e-08
765 3.0843224863019714e-08
766 3.0157707442413084e-08
767 2.9522920996782887e-08
768 2.894478967618852e-08
769 2.831543177705953e-08
770 2.7672854230331723e-08
771 2.7119686052401448e-08
772 2.65726640691355e-08
773 2.600170923017231e-08
774 2.539512244936759e-08
775 2.4883281213305963e-08
776 2.4365627737665818e-08
777 2.38962929444142e-08
778 2.3410599681028543e-08
779 2.2870475291369985e-08
780 2.243628216547222e-08
781 2.195788795233966e-08
782 2.148760813724948e-08
783 2.1083000234511928e-08
784 2.0647034304488443e-08
785 2.025273282413309e-08
786 1.982176023318516e-08
787 1.9448108901087835e-08
788 1.902336421721884e-08
789 1.8647650534830973e-08
790 1.8291547831950083e-08
791 1.7921340855764356e-08
792 1.7544596886409636e-08
793 1.718426823060781

1099 6.878744240879087e-10
1100 6.862850288058553e-10
1101 6.848439593198918e-10
1102 6.796438412060013e-10
1103 6.78565759137939e-10
1104 6.736512458971333e-10
1105 6.706978861181767e-10
1106 6.700355270616853e-10
1107 6.684849895854938e-10
1108 6.611862723993056e-10
1109 6.604345403893319e-10
1110 6.593515733399613e-10
1111 6.542012487287252e-10
1112 6.496895244012535e-10
1113 6.467929525300065e-10
1114 6.434034416358259e-10
1115 6.381596362459163e-10
1116 6.348692682678347e-10
1117 6.318641165847794e-10
1118 6.286214881967567e-10
1119 6.264745944228878e-10
1120 6.299784027774535e-10
1121 6.27931207031196e-10
1122 6.217973913535957e-10
1123 6.174843414363806e-10
1124 6.177617861702345e-10
1125 6.193936474829798e-10
1126 6.100375760098586e-10
1127 6.072282676683471e-10
1128 6.053834655794788e-10
1129 6.008759045883494e-10
1130 6.006157238225285e-10
1131 5.955386184197664e-10
1132 5.939623792805548e-10
1133 5.938445291064909e-10
1134 5.932803137653764e-10
1135 5.869243424605486e-10
113

1450 2.372788632243328e-10
1451 2.360185380467783e-10
1452 2.3581864239119454e-10
1453 2.3440699381538366e-10
1454 2.3486179667742135e-10
1455 2.3386215186604886e-10
1456 2.3209743849061937e-10
1457 2.3092197598550968e-10
1458 2.3075037713926605e-10
1459 2.3145464711493702e-10
1460 2.3095135526229882e-10
1461 2.3072892207931517e-10
1462 2.3144250405060518e-10
1463 2.3228656498286426e-10
1464 2.312677965798926e-10
1465 2.3121693448757696e-10
1466 2.2977096614251735e-10
1467 2.2985495451433025e-10
1468 2.2969542934347942e-10
1469 2.2879914629569953e-10
1470 2.2930449206093328e-10
1471 2.271069859949293e-10
1472 2.2750282213657158e-10
1473 2.2682315747868387e-10
1474 2.2505772245828837e-10
1475 2.256031472747111e-10
1476 2.2515236897113766e-10
1477 2.239247398616584e-10
1478 2.246933472616064e-10
1479 2.2282538314488676e-10
1480 2.219430195182781e-10
1481 2.2124982401727777e-10
1482 2.216229699758543e-10
1483 2.2068184779566735e-10
1484 2.206608090693507e-10
1485 2.1979655595583125e-10
14

1798 1.2458747922217128e-10
1799 1.2425566131568644e-10
1800 1.241437092014408e-10
1801 1.24751778352028e-10
1802 1.245546582540058e-10
1803 1.2481139732845037e-10
1804 1.2478938715698717e-10
1805 1.2429515749978748e-10
1806 1.244958580670641e-10
1807 1.2531903292867241e-10
1808 1.2501111257279263e-10
1809 1.248643966000884e-10
1810 1.2459891451932492e-10
1811 1.246344139005373e-10
1812 1.2434359097923675e-10
1813 1.2428172380118951e-10
1814 1.24257354405799e-10
1815 1.2432116447413932e-10
1816 1.234390090143478e-10
1817 1.2272807770052907e-10
1818 1.2153125727998315e-10
1819 1.2079992561808695e-10
1820 1.2100563606676218e-10
1821 1.2131490256805932e-10
1822 1.2103848479050328e-10
1823 1.213347478046245e-10
1824 1.208129707386263e-10
1825 1.2061537879581863e-10
1826 1.2054628129032352e-10
1827 1.2077419619949126e-10
1828 1.2019439610266858e-10
1829 1.2002369931263246e-10
1830 1.1990541892714646e-10
1831 1.2029899298937607e-10
1832 1.200662208544756e-10
1833 1.1967646318389313e-10
1834 

2170 8.288228436903466e-11
2171 8.29267349233831e-11
2172 8.267565798636412e-11
2173 8.256129113703992e-11
2174 8.258152495166371e-11
2175 8.278065732891804e-11
2176 8.289917363679677e-11
2177 8.321464350924401e-11
2178 8.29837448756976e-11
2179 8.245147620211668e-11
2180 8.237192872240229e-11
2181 8.221641423222792e-11
2182 8.233670689694605e-11
2183 8.228843995095048e-11
2184 8.210836177635628e-11
2185 8.215832181246441e-11
2186 8.248231264662564e-11
2187 8.233065618146185e-11
2188 8.17901718574987e-11
2189 8.199718681822787e-11
2190 8.198060980069144e-11
2191 8.20847834148708e-11
2192 8.189177114203972e-11
2193 8.190048639278302e-11
2194 8.183587141274984e-11
2195 8.194266098993097e-11
2196 8.200010115366752e-11
2197 8.130968121022875e-11
2198 8.14449341302037e-11
2199 8.140831064817888e-11
2200 8.143262453241817e-11
2201 8.10210093460384e-11
2202 8.113014426935905e-11
2203 8.10710942822368e-11
2204 8.083247959866924e-11
2205 8.064071632674086e-11
2206 8.131029183289229e-11
2207 8.1

2492 6.323553486398126e-11
2493 6.316697859221065e-11
2494 6.245379213787317e-11
2495 6.232400706629448e-11
2496 6.238845551287397e-11
2497 6.241518413219183e-11
2498 6.25950402621811e-11
2499 6.206124503194133e-11
2500 6.157591103672644e-11
2501 6.149271369881859e-11
2502 6.142610031734108e-11
2503 6.15879153231802e-11
2504 6.175455979917643e-11
2505 6.182275524846403e-11
2506 6.224103177299156e-11
2507 6.179201594846973e-11
2508 6.186073875369402e-11
2509 6.184724954394483e-11
2510 6.157035992160331e-11
2511 6.154859955032066e-11
2512 6.137695907071361e-11
2513 6.119918460889551e-11
2514 6.119220408162818e-11
2515 6.127395812960401e-11
2516 6.115338790912972e-11
2517 6.100966953859199e-11
2518 6.097089499945696e-11
2519 6.12149775314208e-11
2520 6.145234321408566e-11
2521 6.135150720787408e-11
2522 6.115966066921885e-11
2523 6.055758672296463e-11
2524 6.041567246484192e-11
2525 6.024719612085505e-11
2526 6.047714412593663e-11
2527 6.052377349297089e-11
2528 6.056185414271553e-11
2529

2810 4.9515509747966036e-11
2811 4.943082748676275e-11
2812 4.924986113374885e-11
2813 4.928205760146298e-11
2814 4.941450720830076e-11
2815 4.9157823645007426e-11
2816 4.934545133616908e-11
2817 5.0346154328861914e-11
2818 5.0114894872832494e-11
2819 4.959887708877453e-11
2820 4.9486009040533574e-11
2821 4.9137010432742656e-11
2822 4.922760463155207e-11
2823 4.933776304172355e-11
2824 4.927381419550514e-11
2825 4.88871235160282e-11
2826 4.904699563157422e-11
2827 4.912743128970831e-11
2828 4.918860457836516e-11
2829 4.9085187303621325e-11
2830 4.895612387700865e-11
2831 4.91409066216697e-11
2832 4.9299946069947254e-11
2833 4.919158830274384e-11
2834 4.99111238450034e-11
2835 4.976901529785138e-11
2836 4.9751529285213536e-11
2837 4.974736594887119e-11
2838 4.968563754870203e-11
2839 4.9822972136848165e-11
2840 4.976701689640706e-11
2841 4.9754263209411675e-11
2842 4.9708855087704507e-11
2843 4.9624034048623145e-11
2844 4.951306725731186e-11
2845 4.956580978987546e-11
2846 4.96041679953

3162 3.9269955343090857e-11
3163 3.923392860594177e-11
3164 3.8736507057546277e-11
3165 3.877804674590202e-11
3166 3.877602405832903e-11
3167 3.8783601330472095e-11
3168 3.892749664280437e-11
3169 3.872164741625106e-11
3170 3.8762701382033526e-11
3171 3.875810783426914e-11
3172 3.8745701091968954e-11
3173 3.8755026965375805e-11
3174 3.876191034812848e-11
3175 3.8855522965786093e-11
3176 3.87773910204281e-11
3177 3.869545656121076e-11
3178 3.8828905368770705e-11
3179 3.88638773940464e-11
3180 3.8837731641816475e-11
3181 3.884001106846391e-11
3182 3.87929376122198e-11
3183 3.7956717630072134e-11
3184 3.810693080530392e-11
3185 3.817898427960209e-11
3186 3.8021610165861475e-11
3187 3.800029388378867e-11
3188 3.8125041318393116e-11
3189 3.761985514771915e-11
3190 3.7597595176075416e-11
3191 3.758965014255544e-11
3192 3.7593227142362906e-11
3193 3.7677493069931955e-11
3194 3.7737056535203095e-11
3195 3.778882068372624e-11
3196 3.802485409876155e-11
3197 3.8003034746880715e-11
3198 3.7939141

3532 3.2316156695078035e-11
3533 3.238141699224428e-11
3534 3.230947454024857e-11
3535 3.228504963370682e-11
3536 3.233335127417192e-11
3537 3.232691198062909e-11
3538 3.226901384989489e-11
3539 3.240534923731886e-11
3540 3.257499131548158e-11
3541 3.257216024676879e-11
3542 3.252059038727495e-11
3543 3.243818408327215e-11
3544 3.245150675956765e-11
3545 3.246299756787252e-11
3546 3.239359475104564e-11
3547 3.186801517118809e-11
3548 3.178516477797544e-11
3549 3.1719661619522554e-11
3550 3.171905099685901e-11
3551 3.171077983532555e-11
3552 3.1621295859540766e-11
3553 3.161414185992584e-11
3554 3.159970896060571e-11
3555 3.1585775661646665e-11
3556 3.1587746307515374e-11
3557 3.1642036213419544e-11
3558 3.172569151832505e-11
3559 3.173299123471196e-11
3560 3.16614928719261e-11
3561 3.169191298280083e-11
3562 3.177673402188219e-11
3563 3.1901356556396365e-11
3564 3.209586763031069e-11
3565 3.196425069074138e-11
3566 3.193649511512575e-11
3567 3.198800946346836e-11
3568 3.203316778499499

3890 2.7369823421552475e-11
3891 2.7370690783290463e-11
3892 2.7376463943018514e-11
3893 2.7447164333005425e-11
3894 2.7420796536170577e-11
3895 2.7437491514903378e-11
3896 2.7464140336941334e-11
3897 2.742067857497421e-11
3898 2.735627349648162e-11
3899 2.740551882651765e-11
3900 2.7441023411900467e-11
3901 2.7512103706328617e-11
3902 2.7557851833837077e-11
3903 2.7539821118027774e-11
3904 2.7647995004542736e-11
3905 2.759277181740849e-11
3906 2.7595249002532185e-11
3907 2.7547773090441652e-11
3908 2.7529787477442724e-11
3909 2.759591513634696e-11
3910 2.7639161792603062e-11
3911 2.7637928404211642e-11
3912 2.7564487151132688e-11
3913 2.7628880086560947e-11
3914 2.7624772261369834e-11
3915 2.745246564794801e-11
3916 2.7545120698246883e-11
3917 2.7467628865851523e-11
3918 2.749577301952577e-11
3919 2.7452391054838543e-11
3920 2.7416419828840688e-11
3921 2.7407482533492455e-11
3922 2.753515818132435e-11
3923 2.7428965349018952e-11
3924 2.7419084364099788e-11
3925 2.7411729136561647e-11


4257 2.3530278420169637e-11
4258 2.351851005610861e-11
4259 2.349769337439689e-11
4260 2.351412467516134e-11
4261 2.360607889717592e-11
4262 2.3587926750723298e-11
4263 2.354213005095751e-11
4264 2.352625039225842e-11
4265 2.3316584774057958e-11
4266 2.3351223732426263e-11
4267 2.354928752001939e-11
4268 2.34340463700633e-11
4269 2.3412209670947703e-11
4270 2.320060116245415e-11
4271 2.3360834100483174e-11
4272 2.3308431573720867e-11
4273 2.329488685282044e-11
4274 2.3150890926526557e-11
4275 2.293789463925222e-11
4276 2.3164491158578215e-11
4277 2.3081252187306944e-11
4278 2.3125883152896876e-11
4279 2.293764483907168e-11
4280 2.31198879485639e-11
4281 2.3053829678598703e-11
4282 2.3054717857018403e-11
4283 2.3270017857068837e-11
4284 2.3219169642541004e-11
4285 2.3161299267382418e-11
4286 2.3096739798500465e-11
4287 2.3014333494497663e-11
4288 2.314311936535418e-11
4289 2.3174039076589992e-11
4290 2.3238050372853536e-11
4291 2.3196597420671594e-11
4292 2.320392489263412e-11
4293 2.32

4557 2.0560228866650654e-11
4558 2.072076711601145e-11
4559 2.0899179956068714e-11
4560 2.097439756598707e-11
4561 2.0983851808931142e-11
4562 2.0687311239053763e-11
4563 2.080366261203448e-11
4564 2.0616312476628984e-11
4565 2.0626804084211692e-11
4566 2.0597494196361588e-11
4567 2.0551419940839644e-11
4568 2.0397876096533984e-11
4569 2.0490690741392648e-11
4570 2.0352467974826816e-11
4571 2.038977146845422e-11
4572 2.0317162882643736e-11
4573 2.0497858618795384e-11
4574 2.0437906575465625e-11
4575 2.0618706395025832e-11
4576 2.0575574230519145e-11
4577 2.0666612518538408e-11
4578 2.0545265141946878e-11
4579 2.063763569759569e-11
4580 2.0548595811020753e-11
4581 2.058194066567598e-11
4582 2.0495801236752875e-11
4583 2.069997125098144e-11
4584 2.049122503622325e-11
4585 2.064438030247029e-11
4586 2.0674439590862015e-11
4587 2.06021363163833e-11
4588 2.0661810803956904e-11
4589 2.0422224675242795e-11
4590 2.048683965527598e-11
4591 2.0387363672269565e-11
4592 2.059541946708432e-11
4593 

4862 1.9002584228378083e-11
4863 1.9088237934727914e-11
4864 1.909473273942197e-11
4865 1.8866415374407808e-11
4866 1.8948349833625144e-11
4867 1.8909152021961972e-11
4868 1.8842427618182e-11
4869 1.8812229551912196e-11
4870 1.881077932308628e-11
4871 1.8754775510387844e-11
4872 1.8766765919053796e-11
4873 1.8790122235934348e-11
4874 1.8751930563887242e-11
4875 1.867188348381177e-11
4876 1.8629250919666163e-11
4877 1.854309761295525e-11
4878 1.854925935074192e-11
4879 1.855968156938559e-11
4880 1.857300424568109e-11
4881 1.858362422280102e-11
4882 1.8658314476782678e-11
4883 1.8652097227844777e-11
4884 1.863588797168525e-11
4885 1.8534635631839436e-11
4886 1.8627672321303024e-11
4887 1.8646285904200255e-11
4888 1.8610037122446244e-11
4889 1.8566516379880937e-11
4890 1.8493297171406908e-11
4891 1.8567682114056794e-11
4892 1.8555844361056728e-11
4893 1.8596769957301973e-11
4894 1.8619251973550632e-11
4895 1.8448693961392593e-11
4896 1.8417607716703088e-11
4897 1.850853498241989e-11
4898 

# Pytorch: Custom nn Module

In [None]:
class TwoLayerNet(torch.nn.Module):
    def __init__(self,D_in,H,Dout):
        '''
        In the constructor we instantiate two nn.Linear modules and assign then as member variable
        '''
        super(TwoLayerNet,self).__init__()
        self.linear1 = torch.nn.Linear(D_in,H)
        self.linera2 = torch.nn.Linear(H,D_out)
        
    def forward(self,x):
        '''
        In the forward function we accept a Variable of input data and we must return 
        a Variable of output data. we can use Modules defined in the constructor as wee
        as arbitrary operators on Variables
        '''
        h_relu = self.linear1(x).clamp(min=0)
        y_pred = self.linear2(h_relu)
        return y_pred
        
        
        
        