[View in Colaboratory](https://colab.research.google.com/github/christopher-ell/Deep_Learning_Begin/blob/master/pytorch_tutorials/OPT2_Learning_Pytorch_with_Examples.ipynb)

Official Pytorch Tutorials - Learning Pytorch with Examples

Source: http://pytorch.org/tutorials/beginner/pytorch_with_examples.html

In [1]:
## File created in Google colaboratory so need to download libraries and data on begin 
!pip install torch



In [0]:
import numpy as np
import torch

**Tensors**

Warm-up: Numpy

- Create two layer network using only numpy and no PyTorch

In [3]:
# n is batch size; D_in is input dimensions
# H is hidden dimension; D_out is output dimension
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
## Create data that will be used to train the model randomly
x = np.random.randn(N, D_in)
y = np.random.randn(N, D_out)

# Randomly initialise weights
## Initialise weights, which need to start somewhere
w1 = np.random.randn(D_in, H)
w2 = np.random.randn(H, D_out)

## Initialise the learning rate which is how much the parameter
## weights will change as the model trains
learning_rate = 1e-6

for t in range(500):
  # Forward pass: Compute predicted y
  ## Multiply data (x) by weights (w1) of first layer
  h = x.dot(w1)
  ## Apply a relu function by making all negative values zero
  h_relu = np.maximum(h,0)
  ## Calculate the second layer by multiplying the values 
  ## from the 2nd layer by the weights from that layer
  y_pred = h_relu.dot(w2)
  
  # Compute and print loss
  ## Calculate the squared loss of the error based on above 
  ## weights
  loss = np.square(y_pred - y)
  ## Output loss for each example and iteration number
  print(t, loss, "\n")
  
  # Backprop to compute gradient of w1 and w2 wrt loss
  ##   - Since there is no backpropagation for training
  ##     in Numpy it is done manually below
  
  ## CALCULATE WEIGHTS FOR LAYER 2
  ## Calculate the gradient of the loss function for each 
  ## example, by using the differential of the square error
  grad_y_pred = 2.0 * (y_pred - y)
  ## Takes the dot product of h_relu and the transpose (T)
  ## of the gradient (grad_y_pred)
  ## Apply the gradient of the loss function at the current
  ## weights to the results after the first layer
  grad_w2 = h_relu.T.dot(grad_y_pred)
  
  ## CALCULATE WEIGHTS FOR LAYER 1
  ## Takes the dot product of the gradient of the loss function
  ## at the current weights and multiply that by the transpose (T)
  ## of the 2nd layers weights.
  ## Takes the gradient back one level so they can be applied
  ## to second layer
  grad_h_relu = grad_y_pred.dot(w2.T)
  ## Copy above result
  grad_h = grad_h_relu.copy()
  ## Apply relu function by Making all values less than zero 
  ## equal to zero
  grad_h[h<0] = 0
  ## Multiply inputs for the model (x) by the above model value
  ## by the gradient of the loss function at layer 1 for the 
  ## current weights for the 1st layers gradient
  grad_w1 = x.T.dot(grad_h)
  
  # Update weights
  ## Use two above calculated gradients to adjust weights of 
  ## w1 and w2.
  w1 -= learning_rate * grad_w1
  w2 -= learning_rate * grad_w2

(0, array([[3.16921061e+04, 5.58389901e+04, 1.01185053e+03, 2.02534563e+05,
        6.17028767e+04, 1.81585685e+02, 6.92076492e+04, 8.28405888e+03,
        8.83162750e+03, 7.87949086e-03],
       [7.00225570e+03, 4.59350219e+04, 1.41580439e+05, 1.24747748e+05,
        2.30882977e+04, 1.77174542e+04, 6.84227584e+03, 4.62522309e+05,
        1.09711669e+05, 1.30196259e+04],
       [8.73754375e+04, 2.40071947e+02, 1.75977519e+04, 2.17441112e+04,
        2.46688342e+05, 1.57854613e+05, 1.22497891e+03, 1.09034432e+05,
        1.19952852e+01, 3.28869700e+05],
       [1.81688206e+04, 3.38091670e+03, 6.52379671e+04, 9.25276304e+04,
        1.00012762e+04, 1.53964799e+05, 3.60611923e+03, 2.00756859e+05,
        4.75164092e+03, 1.06107541e+03],
       [2.01388157e+05, 3.11244249e+04, 1.07157047e+03, 7.74608081e+04,
        1.60262262e+02, 3.72335921e+04, 6.53549856e+04, 1.38594469e+02,
        1.89157641e+05, 9.00789388e+04],
       [2.88001558e+05, 3.89358035e+04, 1.78671290e+02, 2.09249825e+05,

(21, array([[8.07610417e+01, 5.19593918e+02, 9.96498466e-01, 4.09625028e+02,
        2.61235300e+02, 1.67463460e+02, 2.21811590e+02, 8.14976250e+01,
        6.50121918e+02, 5.11186647e+01],
       [1.55871159e+02, 5.69053999e+01, 2.79476791e+02, 1.16400094e+01,
        1.83286199e-01, 1.88940319e+02, 1.42281934e+02, 2.24712156e+02,
        6.74317580e+01, 1.39716937e+02],
       [1.38978142e+02, 3.63361007e+00, 6.69268268e+01, 4.25282102e+00,
        5.23875050e+01, 2.30711595e+02, 1.09053062e+01, 2.11996706e+02,
        5.94998285e+01, 3.52572957e+02],
       [1.77537408e+02, 1.39544763e+01, 1.24405885e+03, 1.98748382e+01,
        1.21261771e+02, 1.82858384e+02, 4.14564737e+01, 3.72623036e+02,
        2.33424698e+02, 1.41405562e+02],
       [1.51435029e+02, 6.33788339e+01, 6.12947627e+00, 6.68805234e+02,
        3.22143005e+02, 2.15283611e+02, 5.56534388e+02, 9.06141160e+02,
        9.27382132e-01, 3.62751164e+01],
       [1.26959050e+03, 1.43855521e+01, 9.64632638e+01, 4.70083694e+01

(36, array([[5.49689225e+01, 7.24502601e+01, 9.13526963e+00, 7.72055700e+00,
        9.05963215e+00, 6.91994652e+01, 1.88786052e+01, 4.55092772e+01,
        2.98934150e+02, 3.88157084e+01],
       [3.85727875e+00, 6.57852138e-01, 2.59882714e+01, 4.16046884e+00,
        1.99159909e+00, 1.05437587e+01, 1.51546963e+01, 1.03803216e+01,
        7.85123095e+00, 1.17495696e+01],
       [3.29874271e+01, 4.21664236e+00, 9.42895130e+00, 9.28806241e+00,
        2.18759914e-03, 2.58599843e+01, 1.73370548e+00, 2.96249628e+01,
        4.95470065e+00, 3.36236369e+01],
       [7.66495234e+01, 2.61065908e-01, 1.22797175e+02, 2.54052585e-02,
        2.75009911e+01, 4.19611645e+00, 9.94514004e+00, 1.14747309e+02,
        5.02914411e+01, 3.52722402e+01],
       [1.41960164e+01, 3.53187121e+00, 1.42910961e+00, 9.35360837e+01,
        6.47885311e+01, 2.16360256e+01, 2.88054692e+01, 2.12342691e+02,
        1.95543805e+01, 1.44808850e+00],
       [2.10362942e+02, 3.70962439e+00, 8.59961669e+00, 3.12609390e+00


(53, array([[3.06545905e+01, 9.17805634e+00, 5.77828305e+00, 1.77937074e+00,
        3.43095278e-01, 2.94723753e+01, 3.05398320e+00, 1.38371389e+01,
        1.12041380e+02, 2.32815783e+01],
       [5.72538292e-06, 8.28654712e-03, 1.89267336e+00, 6.64242892e-01,
        6.16204146e-01, 9.25972440e-01, 1.76089640e+00, 3.24315438e-01,
        7.02119707e-01, 5.67931980e-01],
       [6.75778028e+00, 1.00245457e+00, 1.12130039e+00, 4.83580448e+00,
        5.71275387e-01, 3.73688873e+00, 1.00488174e+00, 3.33448782e+00,
        1.15980208e+00, 5.21996835e+00],
       [2.28633332e+01, 2.30288709e-01, 1.15987230e+01, 2.58183368e-01,
        7.44961485e+00, 9.25562734e-02, 1.09762731e+00, 3.32444132e+01,
        9.75607021e+00, 1.02875318e+01],
       [1.71222060e+00, 2.32914489e-01, 1.34825231e+00, 1.66911549e+01,
        1.12809493e+01, 2.07541204e+00, 2.28529947e+00, 5.07481521e+01,
        8.27866691e+00, 3.43590497e+00],
       [3.59933102e+01, 6.57353492e-01, 4.72064494e-01, 5.13608097e+0

(64, array([[1.90045009e+01, 2.61006028e+00, 3.23452466e+00, 3.00426465e+00,
        1.11878369e+00, 1.64460833e+01, 1.24839255e+00, 6.41269011e+00,
        5.76265727e+01, 1.47416113e+01],
       [2.78981549e-02, 6.05245714e-03, 3.47113147e-01, 1.55418272e-01,
        2.86775235e-01, 3.00113811e-01, 4.36499037e-01, 5.02427085e-02,
        1.18428356e-01, 5.84561710e-02],
       [2.48735821e+00, 3.21993230e-01, 2.91484774e-01, 2.71054123e+00,
        4.49452843e-01, 1.33152111e+00, 6.41052247e-01, 8.24052439e-01,
        6.15124440e-01, 1.90760143e+00],
       [1.00783925e+01, 2.98851651e-01, 2.84797626e+00, 3.07073515e-01,
        3.67423976e+00, 2.74301192e-01, 2.45837131e-01, 1.60657080e+01,
        3.82312042e+00, 4.95649348e+00],
       [5.86526872e-01, 6.34359807e-02, 9.42601467e-01, 6.30150031e+00,
        3.73741666e+00, 5.29011583e-01, 8.07575344e-01, 2.05389210e+01,
        3.97629992e+00, 2.33158498e+00],
       [1.25654112e+01, 2.27151584e-01, 5.69331220e-02, 3.14340221e+00

(83, array([[7.37040363e+00, 3.22888229e-01, 1.03590514e+00, 2.09220548e+00,
        9.66920019e-01, 5.66758967e+00, 3.06973281e-01, 1.81138719e+00,
        1.78361841e+01, 5.89644555e+00],
       [1.53441953e-02, 5.73703442e-05, 1.76114787e-02, 1.75562034e-03,
        8.41902091e-02, 5.86913223e-02, 2.49739338e-02, 1.43110766e-02,
        2.22126591e-04, 1.98410957e-05],
       [4.57333967e-01, 3.83052802e-02, 3.00598466e-02, 8.80520491e-01,
        1.89522107e-01, 2.63056640e-01, 2.38181240e-01, 7.01780085e-02,
        2.30810386e-01, 4.02566303e-01],
       [2.47875801e+00, 1.71102830e-01, 3.13107362e-01, 2.20845328e-01,
        1.19725632e+00, 1.87374535e-01, 1.52423968e-02, 4.96621034e+00,
        9.31966289e-01, 1.45291959e+00],
       [1.35019902e-01, 1.19597136e-02, 3.36191234e-01, 1.40097938e+00,
        6.05487059e-01, 5.43181942e-02, 2.51843975e-01, 4.36045330e+00,
        1.03368830e+00, 9.11389217e-01],
       [2.30079612e+00, 3.59618572e-02, 2.69431985e-04, 1.01564472e+00

(100, array([[2.93614756e+00, 5.16276783e-02, 3.57783976e-01, 1.04430842e+00,
        5.11088568e-01, 2.10883318e+00, 8.56900672e-02, 6.18027900e-01,
        6.17626256e+00, 2.39853694e+00],
       [4.21435004e-03, 1.97847872e-04, 1.04068632e-03, 3.52624637e-03,
        3.13615427e-02, 1.62376169e-02, 1.81185808e-04, 1.28196037e-02,
        3.96749357e-03, 4.00077475e-04],
       [1.03226338e-01, 4.82657858e-03, 3.46745248e-03, 3.01955781e-01,
        7.31247684e-02, 6.83672790e-02, 8.42358300e-02, 5.72442425e-03,
        9.35035395e-02, 1.10062028e-01],
       [7.40045109e-01, 7.65739231e-02, 5.39473992e-02, 1.22674489e-01,
        4.55414316e-01, 8.47271781e-02, 4.53693292e-04, 1.82241785e+00,
        3.01853123e-01, 4.95355361e-01],
       [4.82965625e-02, 2.49045566e-03, 1.10210490e-01, 4.20862526e-01,
        1.27560650e-01, 6.83933462e-03, 1.03459275e-01, 1.09399825e+00,
        3.05912067e-01, 3.62734324e-01],
       [5.60698654e-01, 6.59358332e-03, 1.41363492e-04, 3.38355989e-0

(118, array([[1.05646892e+00, 8.24196713e-03, 1.17096880e-01, 4.28302928e-01,
        2.18576029e-01, 7.26492949e-01, 2.22048275e-02, 2.03050089e-01,
        2.01045453e+00, 8.83306704e-01],
       [8.04435428e-04, 3.27320830e-04, 4.35362556e-05, 6.07940559e-03,
        1.19057823e-02, 4.36129210e-03, 1.03463458e-03, 9.73502088e-03,
        5.07805374e-03, 3.09550201e-04],
       [2.17767294e-02, 3.74837781e-04, 2.44171246e-04, 9.39379837e-02,
        2.41717083e-02, 1.72177732e-02, 2.67200910e-02, 1.13315222e-04,
        3.36669729e-02, 2.88206402e-02],
       [2.19429532e-01, 2.96768606e-02, 1.06382375e-02, 5.67171236e-02,
        1.66824257e-01, 3.18535700e-02, 1.46222803e-04, 6.52035232e-01,
        9.97256788e-02, 1.62221583e-01],
       [1.97143634e-02, 2.29218258e-04, 3.24185535e-02, 1.27610382e-01,
        2.63146926e-02, 5.92060183e-04, 3.84618322e-02, 2.52399347e-01,
        8.64330505e-02, 1.32028247e-01],
       [1.35756767e-01, 1.11425928e-03, 8.28566829e-05, 1.01657867e-0

(128, array([[5.91455988e-01, 3.08805487e-03, 6.31620970e-02, 2.51530581e-01,
        1.30824517e-01, 4.00039687e-01, 1.03327324e-02, 1.10448156e-01,
        1.07934805e+00, 4.99948037e-01],
       [3.09400216e-04, 2.77923493e-04, 8.13376509e-06, 5.38876153e-03,
        7.03962270e-03, 2.05988263e-03, 1.34691213e-03, 7.67920819e-03,
        4.25170871e-03, 2.08269271e-04],
       [9.22019144e-03, 5.29196062e-05, 3.06711685e-05, 4.87098363e-02,
        1.28021308e-02, 8.10270034e-03, 1.40124211e-02, 4.75058677e-06,
        1.87543226e-02, 1.37929475e-02],
       [1.14597612e-01, 1.73370286e-02, 4.76657784e-03, 3.55242254e-02,
        9.58522687e-02, 1.81056074e-02, 3.24028301e-04, 3.72126645e-01,
        5.50978109e-02, 8.82151764e-02],
       [1.26396411e-02, 2.16125964e-05, 1.63569794e-02, 6.71099853e-02,
        1.11318343e-02, 1.04223883e-04, 2.18401594e-02, 1.11099561e-01,
        4.35334872e-02, 7.49141579e-02],
       [6.34216481e-02, 4.20379061e-04, 3.17304912e-05, 5.19473572e-0

(139, array([[3.10062791e-01, 1.07873949e-03, 3.20205662e-02, 1.37527819e-01,
        7.29310353e-02, 2.07080968e-01, 4.43192733e-03, 5.68400297e-02,
        5.45224340e-01, 2.65566574e-01],
       [1.04830467e-04, 1.92099985e-04, 3.97561516e-06, 4.02222823e-03,
        4.00915644e-03, 9.10531297e-04, 1.23424643e-03, 5.64713921e-03,
        3.09600746e-03, 1.45316354e-04],
       [3.56466478e-03, 5.43661266e-07, 8.23156507e-09, 2.35978824e-02,
        6.26595235e-03, 3.60692859e-03, 6.82809043e-03, 6.55775049e-05,
        9.81438447e-03, 6.16951786e-03],
       [5.71622815e-02, 9.51955962e-03, 2.10124591e-03, 2.05787554e-02,
        5.19658938e-02, 9.81301777e-03, 3.55858222e-04, 2.02013222e-01,
        2.86822642e-02, 4.57962479e-02],
       [7.88785592e-03, 1.63703783e-06, 7.62554831e-03, 3.39522418e-02,
        4.43095483e-03, 4.51827607e-06, 1.16503908e-02, 4.43628169e-02,
        2.06091597e-02, 4.04900642e-02],
       [2.78135412e-02, 1.32698066e-04, 1.22200565e-05, 2.50051393e-0

(148, array([[1.82257086e-01, 4.65095650e-04, 1.84111436e-02, 8.31152242e-02,
        4.46173067e-02, 1.20775713e-01, 2.18761259e-03, 3.31571263e-02,
        3.12108513e-01, 1.57520257e-01],
       [4.30714700e-05, 1.33867417e-04, 4.22992589e-06, 2.94291485e-03,
        2.54672830e-03, 4.70604992e-04, 9.98644239e-04, 4.25419903e-03,
        2.26638756e-03, 1.05365825e-04],
       [1.62393717e-03, 3.73414841e-06, 4.35887184e-06, 1.30500887e-02,
        3.48056047e-03, 1.87297773e-03, 3.79138274e-03, 8.93977312e-05,
        5.71602673e-03, 3.21047060e-03],
       [3.28507571e-02, 5.83250145e-03, 1.13845831e-03, 1.30320875e-02,
        3.15871956e-02, 5.87303790e-03, 3.17398935e-04, 1.23239187e-01,
        1.70336755e-02, 2.69542868e-02],
       [5.33732641e-03, 1.61355255e-05, 4.07033417e-03, 1.96036255e-02,
        2.08474596e-03, 4.47163076e-07, 6.84354323e-03, 2.06908531e-02,
        1.13625728e-02, 2.45685971e-02],
       [1.43758881e-02, 5.15723313e-05, 4.18492269e-06, 1.37915066e-0

(167, array([[5.88893107e-02, 8.82754143e-05, 5.75922966e-03, 2.81250108e-02,
        1.54556860e-02, 3.86675967e-02, 4.88361324e-04, 1.06940287e-02,
        9.64775344e-02, 5.18635190e-02],
       [8.10674520e-06, 6.11385800e-05, 4.53554471e-06, 1.35506541e-03,
        9.95917982e-04, 1.15517256e-04, 5.17377293e-04, 2.15932453e-03,
        1.06454326e-03, 4.90245551e-05],
       [2.98364487e-04, 1.28256742e-05, 1.10830980e-05, 3.70904495e-03,
        9.83362912e-04, 4.83997452e-04, 1.07672049e-03, 7.12014579e-05,
        1.80315744e-03, 8.17255147e-04],
       [1.05950374e-02, 2.09774440e-03, 3.38487124e-04, 4.85539189e-03,
        1.10836043e-02, 1.99551147e-03, 1.90143372e-04, 4.37661349e-02,
        5.75530901e-03, 8.95809197e-03],
       [2.34735949e-03, 3.91934390e-05, 1.11078151e-03, 6.25929666e-03,
        4.30265492e-04, 6.04918846e-06, 2.17162226e-03, 3.96564791e-03,
        3.33640907e-03, 8.57069207e-03],
       [3.70044669e-03, 7.98408976e-06, 1.17840639e-08, 3.97127617e-0

(184, array([[2.13088046e-02, 2.20576984e-05, 2.03997214e-03, 1.05107335e-02,
        5.87974713e-03, 1.39485955e-02, 1.27624577e-04, 3.90790312e-03,
        3.39384507e-02, 1.90685851e-02],
       [2.82219198e-06, 2.93619703e-05, 3.61134122e-06, 6.25608590e-04,
        4.35982049e-04, 3.18566023e-05, 2.48990382e-04, 1.10355629e-03,
        5.02875513e-04, 2.36544176e-05],
       [6.16247314e-05, 1.07617653e-05, 8.99739912e-06, 1.20106395e-03,
        3.13711975e-04, 1.48166193e-04, 3.46293312e-04, 3.94416313e-05,
        6.34363196e-04, 2.41946702e-04],
       [4.00035785e-03, 8.48826670e-04, 1.23674565e-04, 1.97613985e-03,
        4.36445987e-03, 7.61940612e-04, 1.01356645e-04, 1.74756976e-02,
        2.22832679e-03, 3.39939886e-03],
       [1.11882280e-03, 3.74015632e-05, 3.63938509e-04, 2.28808902e-03,
        1.06873752e-04, 5.69707072e-06, 7.67793613e-04, 8.39059918e-04,
        1.14512607e-03, 3.35996225e-03],
       [1.13466919e-03, 1.58467261e-06, 2.53431687e-07, 1.31776627e-0


(202, array([[7.23949024e-03, 5.60351967e-06, 6.79492155e-04, 3.67286577e-03,
        2.09074418e-03, 4.74289672e-03, 3.16825468e-05, 1.35071273e-03,
        1.12979356e-02, 6.58011346e-03],
       [1.41085864e-06, 1.33042356e-05, 2.30285337e-06, 2.63336862e-04,
        1.82768765e-04, 7.81870926e-06, 1.05599101e-04, 5.18688551e-04,
        2.19120543e-04, 1.09579354e-05],
       [1.02326617e-05, 6.35296433e-06, 5.26305718e-06, 3.64996226e-04,
        9.32894305e-05, 4.35770195e-05, 1.03080964e-04, 1.80442465e-05,
        2.09334000e-04, 6.70395435e-05],
       [1.46981171e-03, 3.28076101e-04, 4.37676726e-05, 7.54584885e-04,
        1.63384663e-03, 2.78649887e-04, 4.73295307e-05, 6.64333211e-03,
        8.25729925e-04, 1.23704705e-03],
       [5.01997060e-04, 2.59408035e-05, 1.16763378e-04, 7.98139622e-04,
        2.47767085e-05, 3.55592361e-06, 2.54410247e-04, 1.40910258e-04,
        3.79316334e-04, 1.25707175e-03],
       [3.33291551e-04, 2.99419279e-07, 3.56425111e-07, 4.13520061e-

(217, array([[2.94160412e-03, 1.84040084e-06, 2.71389834e-04, 1.52319014e-03,
        8.77348673e-04, 1.93314107e-03, 1.01689267e-05, 5.58502851e-04,
        4.53991819e-03, 2.70500113e-03],
       [9.62430480e-07, 6.73454109e-06, 1.42117255e-06, 1.24944062e-04,
        8.83454804e-05, 2.32909028e-06, 4.94180353e-05, 2.69861198e-04,
        1.07666886e-04, 5.78778744e-06],
       [1.88248369e-06, 3.67944453e-06, 3.04767887e-06, 1.36147027e-04,
        3.42676229e-05, 1.60233389e-05, 3.74268948e-05, 8.77414650e-06,
        8.26632350e-05, 2.30976903e-05],
       [6.49778385e-04, 1.48975790e-04, 1.87359052e-05, 3.36504353e-04,
        7.23762812e-04, 1.21440532e-04, 2.39174981e-05, 2.97670525e-03,
        3.64406324e-04, 5.38367498e-04],
       [2.53802789e-04, 1.67473591e-05, 4.67280628e-05, 3.34001342e-04,
        7.19748824e-06, 2.11264068e-06, 1.01227564e-04, 2.58169816e-05,
        1.53795492e-04, 5.57428037e-04],
       [1.22153502e-04, 6.99127491e-08, 2.72243016e-07, 1.58490661e-0

(227, array([[1.61349215e-03, 8.97281055e-07, 1.47237058e-04, 8.46141108e-04,
        4.90992270e-04, 1.06361684e-03, 4.87080850e-06, 3.10047302e-04,
        2.47696385e-03, 1.49417357e-03],
       [7.63097254e-07, 4.26422638e-06, 9.83618962e-07, 7.53852563e-05,
        5.44993671e-05, 1.00495433e-06, 2.93870511e-05, 1.72779211e-04,
        6.66509620e-05, 3.79306666e-06],
       [4.99861550e-07, 2.44259054e-06, 2.05081594e-06, 7.07046676e-05,
        1.75817876e-05, 8.30119855e-06, 1.89556162e-05, 5.31778527e-06,
        4.44970896e-05, 1.13622027e-05],
       [3.79787151e-04, 8.81186035e-05, 1.06676688e-05, 1.96073051e-04,
        4.21176672e-04, 7.00560432e-05, 1.48983315e-05, 1.74514835e-03,
        2.12002852e-04, 3.10510878e-04],
       [1.59619436e-04, 1.19276147e-05, 2.57975485e-05, 1.87564139e-04,
        3.14405783e-06, 1.42245225e-06, 5.48833843e-05, 6.61865233e-06,
        8.48829860e-05, 3.24660744e-04],
       [6.31106776e-05, 2.70295944e-08, 1.95160790e-07, 8.39064528e-0

(246, array([[5.15474580e-04, 2.35708978e-07, 4.60167407e-05, 2.76580928e-04,
        1.62561650e-04, 3.42605234e-04, 1.27859439e-06, 1.01374268e-04,
        7.87088827e-04, 4.83261674e-04],
       [4.87196706e-07, 1.77815837e-06, 4.52676261e-07, 2.85065788e-05,
        2.17278772e-05, 1.83764803e-07, 1.07192494e-05, 7.27553364e-05,
        2.65772663e-05, 1.69511679e-06],
       [7.27796503e-09, 1.06764757e-06, 9.10150573e-07, 2.03990462e-05,
        4.91227751e-06, 2.44084655e-06, 5.15161971e-06, 2.00188215e-06,
        1.37841039e-05, 2.94561822e-06],
       [1.38657897e-04, 3.25073505e-05, 3.66890607e-06, 7.02425742e-05,
        1.51240532e-04, 2.48177277e-05, 5.88521744e-06, 6.34128591e-04,
        7.61901391e-05, 1.10109661e-04],
       [6.51640555e-05, 5.82942202e-06, 8.48286914e-06, 6.31281896e-05,
        6.27000389e-07, 6.34054736e-07, 1.72494332e-05, 5.42675746e-08,
        2.79309645e-05, 1.16978774e-04],
       [1.82560697e-05, 3.49776499e-09, 9.44123799e-08, 2.52140674e-0

(263, array([[1.85777286e-04, 7.11752766e-08, 1.62233364e-05, 1.01780639e-04,
        6.02500161e-05, 1.24651549e-04, 4.11804292e-07, 3.72686079e-05,
        2.83458614e-04, 1.75903018e-04],
       [3.12078217e-07, 8.03751515e-07, 2.12904142e-07, 1.18214850e-05,
        9.50680578e-06, 3.29351477e-08, 4.27846364e-06, 3.30180033e-05,
        1.16067395e-05, 8.21035843e-07],
       [1.19456488e-08, 4.91856189e-07, 4.28631144e-07, 6.76743810e-06,
        1.58941552e-06, 8.31570555e-07, 1.58931335e-06, 8.21537230e-07,
        4.84452018e-06, 8.79586533e-07],
       [5.68351035e-05, 1.33320160e-05, 1.41547195e-06, 2.79903227e-05,
        6.07481074e-05, 9.86524386e-06, 2.50441044e-06, 2.56902382e-04,
        3.07306312e-05, 4.38985650e-05],
       [2.87844784e-05, 2.89821692e-06, 3.20420512e-06, 2.39431203e-05,
        1.28287532e-07, 2.95738569e-07, 6.15547925e-06, 2.46945850e-07,
        1.04672153e-05, 4.71323252e-05],
       [6.10899603e-06, 4.83561862e-10, 4.40910062e-08, 8.66934213e-0

(280, array([[6.69845124e-05, 2.15206932e-08, 5.71376350e-06, 3.75219404e-05,
        2.22984867e-05, 4.54580072e-05, 1.41130888e-07, 1.36769599e-05,
        1.02427174e-04, 6.40001874e-05],
       [1.89111208e-07, 3.60634809e-07, 9.59672000e-08, 4.87291981e-06,
        4.15087574e-06, 3.86120526e-09, 1.69469342e-06, 1.47961857e-05,
        5.05422635e-06, 3.95949479e-07],
       [2.51217894e-08, 2.20080369e-07, 1.97161716e-07, 2.25793215e-06,
        5.17131910e-07, 2.87846773e-07, 4.84908399e-07, 3.32825759e-07,
        1.71065819e-06, 2.60581935e-07],
       [2.34532403e-05, 5.47359932e-06, 5.44572364e-07, 1.11551418e-05,
        2.44885225e-05, 3.94202157e-06, 1.04786998e-06, 1.04265057e-04,
        1.24574685e-05, 1.76144120e-05],
       [1.25433037e-05, 1.38660062e-06, 1.23162026e-06, 9.13911831e-06,
        2.19861840e-08, 1.33092772e-07, 2.21710457e-06, 4.61594971e-07,
        3.97217271e-06, 1.90289249e-05],
       [2.07192318e-06, 1.99443679e-11, 1.92728152e-08, 2.99914063e-0

(297, array([[2.41592404e-05, 6.53297139e-09, 2.01154134e-06, 1.38693608e-05,
        8.24829655e-06, 1.66153242e-05, 5.14547217e-08, 5.00852607e-06,
        3.71126616e-05, 2.32776568e-05],
       [1.09242546e-07, 1.61227703e-07, 4.15661531e-08, 2.00163608e-06,
        1.80860523e-06, 6.95330023e-11, 6.69187146e-07, 6.56538913e-06,
        2.19754001e-06, 1.89240581e-07],
       [2.23964963e-08, 9.62762591e-08, 8.89147590e-08, 7.55362217e-07,
        1.66710939e-07, 1.01115860e-07, 1.45248693e-07, 1.34797251e-07,
        6.08833283e-07, 7.63376057e-08],
       [9.72751211e-06, 2.24571567e-06, 2.11227149e-07, 4.45171309e-06,
        9.91132045e-06, 1.57801937e-06, 4.32877347e-07, 4.24089125e-05,
        5.07986116e-06, 7.10428425e-06],
       [5.40886939e-06, 6.42281644e-07, 4.76929763e-07, 3.50915110e-06,
        2.36724905e-09, 5.86160354e-08, 8.07442880e-07, 4.05037130e-07,
        1.52068280e-06, 7.70350595e-06],
       [7.10697980e-07, 4.08595545e-12, 8.07005733e-09, 1.04576241e-0

(312, array([[9.82726442e-06, 2.17647561e-09, 7.98071643e-07, 5.78055332e-06,
        3.42164420e-06, 6.84811234e-06, 2.17694293e-08, 2.05854839e-06,
        1.51840947e-05, 9.53767443e-06],
       [6.51208920e-08, 7.85404111e-08, 1.95219386e-08, 9.10436394e-07,
        8.65129068e-07, 1.40762332e-10, 2.93789304e-07, 3.18270057e-06,
        1.05223707e-06, 9.80986153e-08],
       [1.58808856e-08, 4.64833855e-08, 4.37129233e-08, 2.89944363e-07,
        6.31667617e-08, 4.06197473e-08, 4.99621485e-08, 6.05265971e-08,
        2.46503817e-07, 2.57359455e-08],
       [4.48557625e-06, 1.02538486e-06, 8.98418576e-08, 1.97892607e-06,
        4.46856266e-06, 7.08520283e-07, 1.97147351e-07, 1.91846936e-05,
        2.30252026e-06, 3.20228610e-06],
       [2.55667096e-06, 3.21837703e-07, 2.08148773e-07, 1.51133104e-06,
        3.09471768e-11, 2.81008410e-08, 3.32492460e-07, 2.87058467e-07,
        6.55622449e-07, 3.47348039e-06],
       [2.78432976e-07, 1.99984983e-11, 3.73421832e-09, 4.15328127e-0

(323, array([[5.08095259e-06, 9.52789969e-10, 4.04831433e-07, 3.04924258e-06,
        1.79361698e-06, 3.57828787e-06, 1.18110557e-08, 1.06995239e-06,
        7.89146455e-06, 4.95742569e-06],
       [4.38096268e-08, 4.63028982e-08, 1.10699802e-08, 5.10557983e-07,
        5.03250532e-07, 3.25683603e-10, 1.60640745e-07, 1.86460704e-06,
        6.12751267e-07, 6.03384447e-08],
       [1.14549148e-08, 2.69667940e-08, 2.58775038e-08, 1.44359639e-07,
        3.10771323e-08, 2.08903519e-08, 2.26502496e-08, 3.36167113e-08,
        1.27331890e-07, 1.15028129e-08],
       [2.54645518e-06, 5.76700108e-07, 4.83102025e-08, 1.09312720e-06,
        2.49639003e-06, 3.93544831e-07, 1.10433360e-07, 1.07366941e-05,
        1.29298254e-06, 1.78830876e-06],
       [1.47086665e-06, 1.91642629e-07, 1.13802592e-07, 8.16482432e-07,
        1.01901495e-10, 1.62948508e-08, 1.74132551e-07, 2.07092366e-07,
        3.54750903e-07, 1.93832230e-06],
       [1.40677558e-07, 2.20422067e-11, 2.07577585e-09, 2.11712662e-0

(333, array([[2.78888366e-06, 4.44197286e-10, 2.18147336e-07, 1.70787108e-06,
        9.96377631e-07, 1.98458982e-06, 6.84097208e-09, 5.89195052e-07,
        4.35502992e-06, 2.73431262e-06],
       [3.01933932e-08, 2.86013452e-08, 6.55689766e-09, 3.01573182e-07,
        3.07135860e-07, 3.97486134e-10, 9.27858850e-08, 1.14435670e-06,
        3.74628377e-07, 3.86424141e-08],
       [8.16970931e-09, 1.64017054e-08, 1.59687953e-08, 7.67475100e-08,
        1.63067239e-08, 1.14878241e-08, 1.09354950e-08, 1.97899275e-08,
        7.01976096e-08, 5.50115122e-09],
       [1.52336577e-06, 3.41883826e-07, 2.74479375e-08, 6.37816474e-07,
        1.47209181e-06, 2.30899207e-07, 6.50786563e-08, 6.33822068e-06,
        7.65822741e-07, 1.05445749e-06],
       [8.87202501e-07, 1.19150716e-07, 6.56323202e-08, 4.67116814e-07,
        2.99405566e-10, 9.83830164e-09, 9.71526214e-08, 1.48397737e-07,
        2.03278620e-07, 1.14122104e-06],
       [7.58469147e-08, 2.26791206e-11, 1.22818795e-09, 1.15068576e-0

(338, array([[2.06605944e-06, 2.99729427e-10, 1.60068072e-07, 1.27882950e-06,
        7.42543042e-07, 1.47818209e-06, 5.22715563e-09, 4.36873618e-07,
        3.23571247e-06, 2.03065447e-06],
       [2.49749473e-08, 2.24692124e-08, 5.03672859e-09, 2.31779332e-07,
        2.39844684e-07, 4.03608198e-10, 7.05252381e-08, 8.95797272e-07,
        2.92926281e-07, 3.08977276e-08],
       [6.81205138e-09, 1.27766930e-08, 1.25243344e-08, 5.59335046e-08,
        1.18123188e-08, 8.54438817e-09, 7.59027647e-09, 1.52080397e-08,
        5.23364079e-08, 3.80316191e-09],
       [1.17877762e-06, 2.63234697e-07, 2.06685608e-08, 4.87535020e-07,
        1.13101656e-06, 1.76955997e-07, 4.99176412e-08, 4.87092182e-06,
        5.89253824e-07, 8.10123218e-07],
       [6.88615651e-07, 9.37069527e-08, 4.99043071e-08, 3.53704945e-07,
        3.77440505e-10, 7.64828350e-09, 7.25618250e-08, 1.24173151e-07,
        1.54054497e-07, 8.75882597e-07],
       [5.57848135e-08, 1.94815071e-11, 9.30501150e-10, 8.48520091e-0

(359, array([[5.85542300e-07, 5.29676236e-11, 4.34333504e-08, 3.81495029e-07,
        2.15550568e-07, 4.29632459e-07, 1.71964659e-09, 1.23531530e-07,
        9.29901796e-07, 5.81848088e-07],
       [1.09808950e-08, 8.11867296e-09, 1.63188019e-09, 7.66638367e-08,
        8.47258680e-08, 3.07398202e-10, 2.23248780e-08, 3.18747177e-07,
        1.04075730e-07, 1.19816686e-08],
       [3.03250617e-09, 4.44682972e-09, 4.51011898e-09, 1.50231899e-08,
        3.11075677e-09, 2.46759270e-09, 1.59310839e-09, 5.01727500e-09,
        1.51894502e-08, 7.83963114e-10],
       [4.01946663e-07, 8.78626240e-08, 6.30023745e-09, 1.57753667e-07,
        3.74591797e-07, 5.79474240e-08, 1.63486208e-08, 1.61468879e-06,
        1.96934678e-07, 2.68447254e-07],
       [2.36418054e-07, 3.38841185e-08, 1.58309799e-08, 1.10063035e-07,
        4.27894171e-10, 2.61451767e-09, 2.15347541e-08, 5.58919927e-08,
        4.81538621e-08, 2.88402617e-07],
       [1.53885307e-08, 1.35778543e-11, 3.04017831e-10, 2.38586741e-0

(376, array([[2.10683803e-07, 1.08405727e-11, 1.50165544e-08, 1.44224394e-07,
        7.90112136e-08, 1.58211714e-07, 7.10939004e-10, 4.40234657e-08,
        3.39067973e-07, 2.11393193e-07],
       [5.50469551e-09, 3.54441058e-09, 6.46110819e-10, 3.13029474e-08,
        3.63937081e-08, 1.98864121e-10, 8.81842282e-09, 1.37418467e-07,
        4.50026586e-08, 5.51956436e-09],
       [1.50094904e-09, 1.87907524e-09, 1.95701842e-09, 5.25003271e-09,
        1.07477649e-09, 9.10897215e-10, 4.33310909e-10, 2.05958155e-09,
        5.64187335e-09, 2.10969810e-10],
       [1.68507646e-07, 3.61852920e-08, 2.40043705e-09, 6.34055172e-08,
        1.53537660e-07, 2.35353397e-08, 6.60586416e-09, 6.61865319e-07,
        8.13042593e-08, 1.10132554e-07],
       [9.90855384e-08, 1.47110663e-08, 6.26599887e-09, 4.29160081e-08,
        3.24718950e-10, 1.08847666e-09, 8.12277770e-09, 2.78938935e-08,
        1.88372021e-08, 1.17510998e-07],
       [5.45257884e-09, 8.62559873e-12, 1.21324807e-10, 8.61745755e-0

(396, array([[6.31311183e-08, 1.05267774e-12, 4.26861561e-09, 4.62914297e-08,
        2.41753710e-08, 4.88977119e-08, 2.54392856e-10, 1.29089234e-08,
        1.03454658e-07, 6.41520736e-08],
       [2.38410080e-09, 1.33174447e-09, 2.13857423e-10, 1.09179306e-08,
        1.34314051e-08, 1.04483379e-10, 2.96740171e-09, 5.08540699e-08,
        1.67702398e-08, 2.19711010e-09],
       [6.28366818e-10, 6.80049015e-10, 7.31549769e-10, 1.53680764e-09,
        3.12268076e-10, 2.85241527e-10, 8.85917405e-11, 7.28684733e-10,
        1.78646069e-09, 4.31738198e-11],
       [6.07207320e-08, 1.27555499e-08, 7.69626187e-10, 2.17784892e-08,
        5.39508955e-08, 8.18073840e-09, 2.26955613e-09, 2.32328032e-07,
        2.87601841e-08, 3.87407928e-08],
       [3.54637241e-08, 5.46699577e-09, 2.10689239e-09, 1.42345332e-08,
        1.93550525e-10, 3.83170671e-10, 2.60313692e-09, 1.18461261e-08,
        6.25977994e-09, 4.09285365e-08],
       [1.61928332e-09, 4.49743223e-12, 4.10810315e-11, 2.62708971e-0

(413, array([[2.26028748e-08, 2.76995695e-14, 1.45069321e-09, 1.77452853e-08,
        8.80336327e-09, 1.80384822e-08, 1.06697079e-10, 4.48672483e-09,
        3.76913671e-08, 2.32488586e-08],
       [1.15077230e-09, 5.77507617e-10, 8.26790610e-11, 4.46297617e-09,
        5.74775618e-09, 5.63496267e-11, 1.17989592e-09, 2.17784594e-08,
        7.24088556e-09, 9.97389992e-10],
       [2.93061857e-10, 2.86314550e-10, 3.16663097e-10, 5.45965038e-10,
        1.10909170e-10, 1.07071762e-10, 2.14399249e-11, 3.01792935e-10,
        6.78654456e-10, 1.05039474e-11],
       [2.55383918e-08, 5.26357225e-09, 2.92373295e-10, 8.80794715e-09,
        2.22355761e-08, 3.34073242e-09, 9.14406845e-10, 9.56105291e-08,
        1.19074514e-08, 1.59815985e-08],
       [1.47790269e-08, 2.34443923e-09, 8.35369363e-10, 5.58956199e-09,
        1.12711149e-10, 1.56614539e-10, 9.96273867e-10, 5.58454956e-09,
        2.45792328e-09, 1.67212358e-08],
       [5.78395061e-10, 2.56199650e-12, 1.64427325e-11, 9.64945879e-1

(430, array([[8.06574246e-09, 4.42873527e-14, 4.87992100e-10, 6.84855440e-09,
        3.19402482e-09, 6.65507361e-09, 4.48835297e-11, 1.53549663e-09,
        1.37154093e-08, 8.40945741e-09],
       [5.48029731e-10, 2.49787957e-10, 3.16710592e-11, 1.82583074e-09,
        2.45581433e-09, 2.91081431e-11, 4.70562936e-10, 9.30648409e-09,
        3.12457615e-09, 4.50061054e-10],
       [1.34695025e-10, 1.19916228e-10, 1.36329847e-10, 1.97057317e-10,
        4.03897038e-11, 4.04337446e-11, 4.64484057e-12, 1.26061963e-10,
        2.59890890e-10, 2.34333339e-12],
       [1.07503303e-08, 2.17515172e-09, 1.10854970e-10, 3.56836747e-09,
        9.18201490e-09, 1.36662846e-09, 3.68590254e-10, 3.94190084e-08,
        4.93998858e-09, 6.60662459e-09],
       [6.15028686e-09, 9.99646106e-10, 3.32260914e-10, 2.20032592e-09,
        6.18607613e-11, 6.38339567e-11, 3.83885019e-10, 2.59406899e-09,
        9.66622096e-10, 6.83804862e-09],
       [2.06969110e-10, 1.38743613e-12, 6.55067660e-12, 3.56636561e-1

(448, array([[2.69714770e-09, 1.12406178e-13, 1.51437252e-10, 2.51789378e-09,
        1.08604332e-09, 2.31596962e-09, 1.79317486e-11, 4.82953178e-10,
        4.69444195e-09, 2.85842824e-09],
       [2.46726347e-10, 1.02655498e-10, 1.13757597e-11, 7.09566809e-10,
        9.96986731e-10, 1.39747593e-11, 1.78446173e-10, 3.77541993e-09,
        1.28248952e-09, 1.92794520e-10],
       [5.81961035e-11, 4.78118344e-11, 5.57522257e-11, 6.74162815e-11,
        1.40265888e-11, 1.45715192e-11, 7.50786150e-13, 5.02442912e-11,
        9.54757507e-11, 4.08342957e-13],
       [4.30726562e-09, 8.54264157e-10, 3.96048521e-11, 1.37579971e-09,
        3.60901748e-09, 5.31950609e-10, 1.40837966e-10, 1.54571151e-08,
        1.94708387e-09, 2.59905539e-09],
       [2.42611840e-09, 4.04898390e-10, 1.25024391e-10, 8.22784636e-10,
        3.10891475e-11, 2.43532146e-11, 1.40885230e-10, 1.13579189e-09,
        3.60165671e-10, 2.65659067e-09],
       [6.99010346e-11, 7.09955766e-13, 2.49549864e-12, 1.25326605e-1

(466, array([[8.96869452e-10, 1.13933903e-13, 4.61349454e-11, 9.32907224e-10,
        3.67043215e-10, 8.05640769e-10, 7.16021659e-12, 1.47734403e-10,
        1.60247885e-09, 9.68453033e-10],
       [1.09881674e-10, 4.21067064e-11, 4.05098265e-12, 2.76090300e-10,
        4.04340410e-10, 6.53345230e-12, 6.79058161e-11, 1.52927250e-09,
        5.26062839e-10, 8.21671492e-11],
       [2.49421201e-11, 1.90231804e-11, 2.27446926e-11, 2.34939392e-11,
        5.01267491e-12, 5.26853378e-12, 7.37116527e-14, 2.01263982e-11,
        3.52699154e-11, 4.90963837e-14],
       [1.72714635e-09, 3.36016411e-10, 1.41417348e-11, 5.31466868e-10,
        1.42121255e-09, 2.07392822e-10, 5.38955941e-11, 6.07422200e-09,
        7.69294378e-10, 1.02440813e-09],
       [9.57183150e-10, 1.63378666e-10, 4.70636522e-11, 3.08271740e-10,
        1.52370916e-11, 9.29576019e-12, 5.20364461e-11, 4.91540028e-10,
        1.34122743e-10, 1.03407002e-09],
       [2.35907085e-11, 3.54115990e-13, 9.50778201e-13, 4.44032949e-1

(484, array([[2.96215027e-10, 8.45413281e-14, 1.36980140e-11, 3.48396067e-10,
        1.23099755e-10, 2.80052524e-10, 2.85241197e-12, 4.35950115e-11,
        5.45248160e-10, 3.26817183e-10],
       [4.84700780e-11, 1.72372794e-11, 1.42913435e-12, 1.07573491e-10,
        1.63852606e-10, 2.99007446e-12, 2.59320155e-11, 6.18766411e-10,
        2.15658031e-10, 3.48527284e-11],
       [1.05859166e-11, 7.56371415e-12, 9.28361473e-12, 8.27929258e-12,
        1.82711904e-12, 1.92239332e-12, 3.27958838e-16, 8.09930021e-12,
        1.31655260e-11, 1.81817275e-15],
       [6.93383326e-10, 1.32389454e-10, 5.03868734e-12, 2.05918099e-10,
        5.60925834e-10, 8.10730498e-11, 2.06510321e-11, 2.39172611e-09,
        3.04278405e-10, 4.04662396e-10],
       [3.77461265e-10, 6.57857596e-11, 1.77604024e-11, 1.15860821e-10,
        7.25043083e-12, 3.51933372e-12, 1.93357540e-11, 2.10872137e-10,
        4.99508327e-11, 4.03021834e-10],
       [7.95280711e-12, 1.71717792e-13, 3.64440371e-13, 1.58642441e-1

Pytorch: Tensors
- Here we do the same as above except rather than using numpy arrays we use PyTorch Tensors

In [4]:
dtype = torch.float
device = torch.device("cpu")

# N is batch size; D_in is input dimension
# H is hidden dimension; D_out is output dimension
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
## Create training data by generating random numbers in
## PyTorch Tensors.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Randomly initialise weights
## Randomly initialise weights for both layers as they
## need to begin somewhere for gradient descent to work
w1 = torch.randn(D_in, H, device=device, dtype=dtype)
w2 = torch.randn(H, D_out, device=device, dtype=dtype)

## Initialise learning rate which is rate at which 
## weights change per iteration
learning_rate = 1e-6

## Run 500 epochs over the data
for t in range(500):
  # Forward pass compute predicted y
  ## Run matrix multiplication (mm) of the input 
  ## values (x) against the weights of the first
  ## layer (w1)
  h = x.mm(w1)
  ## Apply the relu function by setting all values
  ## less than zero to zero
  h_relu = h.clamp(min=0)
  ## Make the prediction by multiplying the final
  ## layer weights by the above value.
  y_pred = h_relu.mm(w2)
  
  # Compute and print loss
  ## Compute the loss using the sum of squared errors
  ## per example (item())
  loss = (y_pred - y).pow(2).sum().item()
  ## print epoch and loss
  print(t, loss)
  
  # Backprop to compute gradient of w1 and w2 wrt the loss
  ## Calculate the gradient of the loss function at the 
  ## point where the weights are equal to w1 and w2, using
  ## the derivative of the loss function
  grad_y_pred = 2.0 * (y_pred - y)
  
  ## CALCULATE GRADIENT FOR LOSS FUNCTION OF LAYER 2
  ## Multiply gradient of loss function by transpose  
  ## of inputs to layer 2.
  grad_w2 = h_relu.t().mm(grad_y_pred)
  
  ## CALCULATE GRADIENT FOR LOSS FUNCTION OF LAYER 1
  ## Multiply gradient by transpose  (t()) of layer 2
  ## weights to get the gradient back to layer 1.
  grad_h_relu = grad_y_pred.mm(w2.t())
  ## Copy the gradient for layer 1
  grad_h = grad_h_relu.clone()
  ## Make all values less than 0 in layer 1 gradient
  ## equal to 0.
  grad_h[h < 0] = 0
  ## Multiply transpose of layer 1 inputs (x) by 
  ## above value to get the gradient of layer 1
  grad_w1 = x.t().mm(grad_h)
  
  # Update weights using gradient descent
  ## Subtract learning rate by each layers gradient
  ## to get updated layer weights.
  w1 -= learning_rate * grad_w1
  w2 -= learning_rate * grad_w2

(0, 39136240.0)
(1, 43451088.0)
(2, 53348148.0)
(3, 55577556.0)
(4, 40911448.0)
(5, 19727364.0)
(6, 7353861.0)
(7, 3075230.0)
(8, 1801901.75)
(9, 1331428.875)
(10, 1080138.0)
(11, 906339.5625)
(12, 771484.875)
(13, 662808.0625)
(14, 573595.75)
(15, 499437.03125)
(16, 437312.65625)
(17, 384837.875)
(18, 340177.59375)
(19, 301925.53125)
(20, 268971.375)
(21, 240441.9375)
(22, 215698.703125)
(23, 194127.390625)
(24, 175220.40625)
(25, 158562.78125)
(26, 143857.265625)
(27, 130864.296875)
(28, 119306.796875)
(29, 108987.546875)
(30, 99741.0625)
(31, 91440.0703125)
(32, 83963.625)
(33, 77212.578125)
(34, 71108.015625)
(35, 65573.0859375)
(36, 60548.421875)
(37, 55975.33984375)
(38, 51808.48046875)
(39, 48004.08203125)
(40, 44526.671875)
(41, 41341.41796875)
(42, 38419.8984375)
(43, 35737.5)
(44, 33271.22265625)
(45, 30998.94140625)
(46, 28903.224609375)
(47, 26969.3125)
(48, 25181.802734375)
(49, 23529.833984375)
(50, 22001.640625)
(51, 20585.158203125)
(52, 19270.96484375)
(53, 18050.62890

**Autograd**

Pytorch: Tensors and autograd

In [5]:
dtype = torch.float
device = torch.device("cpu")
# dtype = torch.device("cuda:0") # uncomment this to run on GPU

# N is batch size; D_in is input dimension
# H is hidden dimension; D_out is output dimension
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Create random tensors for weight
# Setting requires_grad=True indicates we want to compute gradients with 
# respect to Tensors during the backward pass.
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-6
for t in range(500):
  # Forward pass: Compute predicted y using operations on Tensors; these
  # are exactly the same operations we used to compute the forward pass using
  # Tensors, but we do not need to keep references to intermediate values since
  # we are not implementing the backward pass by hand.
  ## This is exactly the same as above except its done on one line because the 
  ## intermediate values don't need to be preserved for the backpropagation.
  y_pred = x.mm(w1).clamp(min=0).mm(w2)
  
  # Compute and print loss using operations on Tensors; these
  # Now loss is a tensor of shape (1,)
  # Loss.item() gets a scalar value in the loss.
  ## COmpute the loss using the sum  of squared errors
  loss = (y_pred - y).pow(2).sum()
  ## print loss and epoch
  print(t, loss.item())
  
  # Use autograd to compute the backward pass. This call will compute the
  # gradient of loss wrt all tensors with requires_grad=True.
  # After this call w1.grad and w2.grad will be tensorsholding the gradient
  # of the loss wrt w1 and w2 respectively.
  ## This will perform the whole backpropagation calculating gradients for each
  ## layer. 
  loss.backward()
  
  # Manually update weights using gradient descent. Wrap in torch.no_grad()
  # because weights have requires_grad=True, but we don't need to track this
  # in autograd.
  # An alternative way is to operate on weight.data and weight.grad.data.tensor
  # but this doesn't track history.
  # You can also use torch.optim.SGD to achieve this.
  ## Only apply below operation to weights with requires_grad = True
  with torch.no_grad():
    ## Subtrack learning rate by weights gradient from each layer to adjust 
    ## weights for each layer
    w1 -= learning_rate * w1.grad
    w2 -= learning_rate * w2.grad
    
    # Manually zero the gradients after using weights
    ## gradients must be zeroed as they are cumulative
    w1.grad.zero_()
    w2.grad.zero_()

(0, 24904014.0)
(1, 19165894.0)
(2, 17861988.0)
(3, 18246090.0)
(4, 18705246.0)
(5, 17928222.0)
(6, 15525073.0)
(7, 11933672.0)
(8, 8285634.5)
(9, 5332510.5)
(10, 3323698.25)
(11, 2074352.25)
(12, 1339144.5)
(13, 908960.25)
(14, 653836.0625)
(15, 496364.75)
(16, 394135.84375)
(17, 323768.4375)
(18, 272529.96875)
(19, 233378.890625)
(20, 202321.5)
(21, 176880.5)
(22, 155609.140625)
(23, 137552.875)
(24, 122052.9609375)
(25, 108641.4140625)
(26, 96977.0703125)
(27, 86775.875)
(28, 77818.6015625)
(29, 69924.4609375)
(30, 62946.51171875)
(31, 56762.96484375)
(32, 51273.921875)
(33, 46385.9453125)
(34, 42025.50390625)
(35, 38133.56640625)
(36, 34649.359375)
(37, 31522.478515625)
(38, 28711.8515625)
(39, 26182.51171875)
(40, 23901.197265625)
(41, 21841.6796875)
(42, 19980.41015625)
(43, 18295.7734375)
(44, 16768.279296875)
(45, 15381.802734375)
(46, 14122.17578125)
(47, 12976.9169921875)
(48, 11935.1943359375)
(49, 10985.8720703125)
(50, 10120.033203125)
(51, 9328.8505859375)
(52, 8605.60156

Pytorch: Defining new Autograd Functions
- Each autograd operator is really two functions operating on the Tensor. A forward that computes the output. A backward, that receives the gradient of some output wrt some  scalar and calculates the gradient of the output.

In [6]:
## Create own custom ReLU PyTorch operator the sets all negative values to 0
class MyReLU(torch.autograd.Function):
  """
  We can implement our own custom autograd Functions by subclassing
  torch.autograd.Function and implementing the forward and backward passes
  which operate on Tensors.
  """  
  
  ## Override forward function with our own that calculate ReLU output
  @staticmethod
  def forward(ctx, input):
    """
    In the forward pass we receive a Tensor containing the input and return
    a Tensor containing the output. ctx is a context object that can be used
    to stash information for backward computation. You can cache arbitrarily
    objects for use in the backward pass using the cts.save_for_backward method. 
    """
    ctx.save_for_backward(input)
    return input.clamp(min=0)
  
  ## Override backward function with our own that receive gradient of output
  ## and calculates gradient of input.
  @staticmethod
  def backward(ctx, grad_output):
    """
    In the backward pass we receive a Tensor containing the gradient of the Loss
    wrt the output, and we need to compute the gradient of the loss wrt the 
    input.
    """
    
    input, = ctx.saved_tensors
    grad_input = grad_output.clone()
    grad_input[input < 0] = 0
    return grad_input
  
dtype = torch.float
## Store that this is being run on a CPU
device = torch.device("cpu")
# dtype = torch.device("cuda:0") # Uncomment if running on GPU
 
# N is batch size; D_in is input dimension
# H is hidden dimension; D_out is output dimension
N, D_in, H, D_out = 64, 1000, 100, 10
 
# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)
 
# Create random Tensors for weights
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)
 
learning_rate = 1e-6
for t in range(500):
  # To apply our function, we use Function.apply method. We alias as 'relu'.
  ## Create variable to store methods
  relu = MyReLU.apply
  
  # Forward pass: compute predicted y using operations; we compute
  # ReLU using our custom autograd operation.
  ## Apply ReLU function to first layer inputs multiplied by weights using newly
  ## created operator and then apply second layer weights
  y_pred = relu(x.mm(w1)).mm(w2)
  
  # Compute and print loss
  ## Calculate loss using squared mean error
  loss = (y_pred - y).pow(2).sum()
  ## Print loss and epoch of that loss
  print(t, loss.item())
   
  # Use autograd to compute the backward pass
  ## Calculate gradient of loss function by applying the backward method to 
  ## loss
  loss.backward()
   
  # Update weights using gradient descent
  ## If the gradient is calculated apply weight adjustment
  with torch.no_grad():
    ## Subtraxct from weight the learning rate multiplied by gradient to adjust
    ## the weight closer to optimal value.
    w1 -= learning_rate * w1.grad
    w2 -= learning_rate * w2.grad
     
    # Manually zero the gradients after initialising weights
    ## Zero gradients as they are cumulative across epochs
    w1.grad.zero_()
    w2.grad.zero_()

(0, 24532838.0)
(1, 21624412.0)
(2, 24346328.0)
(3, 30079592.0)
(4, 34826724.0)
(5, 33591468.0)
(6, 24881554.0)
(7, 14043618.0)
(8, 6586287.5)
(9, 3008534.75)
(10, 1547691.875)
(11, 953449.1875)
(12, 684115.375)
(13, 538608.625)
(14, 444674.75)
(15, 375934.5)
(16, 321922.03125)
(17, 277890.84375)
(18, 241256.53125)
(19, 210468.984375)
(20, 184372.65625)
(21, 162118.46875)
(22, 143030.78125)
(23, 126592.859375)
(24, 112385.609375)
(25, 100045.25)
(26, 89294.59375)
(27, 79892.7421875)
(28, 71652.1484375)
(29, 64396.48828125)
(30, 58003.8046875)
(31, 52355.7109375)
(32, 47345.75390625)
(33, 42892.49609375)
(34, 38929.61328125)
(35, 35391.87890625)
(36, 32226.533203125)
(37, 29387.3515625)
(38, 26835.404296875)
(39, 24537.08203125)
(40, 22466.19140625)
(41, 20595.814453125)
(42, 18904.099609375)
(43, 17371.9609375)
(44, 15981.5859375)
(45, 14719.3232421875)
(46, 13570.734375)
(47, 12524.224609375)
(48, 11569.501953125)
(49, 10698.0498046875)
(50, 9901.6435546875)
(51, 9173.0498046875)
(52,

**nn module**

PyTorch: nn
- nn defines a set of modules equivalent to a neural network layer, receives input and output tensors and can also hold internal state parameters such as learnable parameters.
- nn package also defines a set of useful loss parameters

In [8]:
# N is batch size; D_in is input dimensions;
# H is hidden dimensions; D_out is output dimensions
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the NN package to define our model as a sequence of Layers. nn.Sequential
# is a module that contains other Modules, and applies them in sequence to 
# produce its output. Each Linear module computes output from input using a
# linear function and holds internal Tensors for its weight and bias.
## Defines a two layer neaural network
## nn.sequential enables you to build a neaural network by sequentially
## specifying the building blocks of the net.
model = torch.nn.Sequential(
    ## Specify a linear layer applying a linear transformation to the incoming
    ## D_in number of layers outputing H layers
    torch.nn.Linear(D_in, H), 
    ## Applies the rectified linear unit function element wise
    torch.nn.ReLU(),
    ## Specifies a linear layer applying a linear transformation to the incoming
    ## H layers and outputing D_out layers.
    torch.nn.Linear(H, D_out),
)

# The nn package also contains definitions of popular loss functions; in this
# case we will use Mean Squared Error (MSE) as our loss function.
## The nn module also specifies a range of loss functions
## Specifies the Mean Squared Error loss function
loss_fn = torch.nn.MSELoss(size_average=False)

learning_rate = 1e-4
for t in range(500):
  # Forward pass: Compute predicted y by passing x to the model. Module objects
  # override the __call__ operator so you can call them like functions. When
  # doing so you pass a Tensor of input data to the Module and it produces a 
  # Tensor of output data.
  ## Applies above created neaural network model going from input of x to 
  ## output of y_pred
  y_pred = model(x)
  
  # Compute and print loss. We pass Tensors computing the predicted and true
  # value of y, and the loss function returns a Tensor containing the loss.
  ## Apply loss function to predictions and actual values
  loss = loss_fn(y_pred, y)
  print(t, loss.item())
  
  # Zero the gradients before running the backward pass
  ## Zero gradients because gradient values are cumulative 
  model.zero_grad()
  
  # Backward pass: Compute gradient of the loss wrt all learnable  parameters 
  # of the model. Internally the parameters of each Module are stored in Tensors
  # with requires_grad=True, so this call will compute gradients for all 
  # learnable parameters in the model.
  ## Compute gradients of all learnable parameters wrt models learnable 
  ## parameters
  loss.backward()
  
  # Update the weights using gradient descent. Each parameter is a Tensor, so
  # we can access the gradients like we did before.
  ## Where the model has learnable parameters make adjustments to these 
  ## parameters
  with torch.no_grad():
    ## Loop through each of the models parameters
    for param in model.parameters():
      ## Adjust the weight of each learnable parameter in the model
      param -= learning_rate * param.grad

(0, 651.0094604492188)
(1, 599.4236450195312)
(2, 555.6410522460938)
(3, 518.0137329101562)
(4, 485.1589660644531)
(5, 456.1421203613281)
(6, 430.0504150390625)
(7, 406.13525390625)
(8, 384.2814636230469)
(9, 364.1424865722656)
(10, 345.3360595703125)
(11, 327.63519287109375)
(12, 310.99322509765625)
(13, 295.3150329589844)
(14, 280.5481872558594)
(15, 266.5860290527344)
(16, 253.32142639160156)
(17, 240.66139221191406)
(18, 228.6158447265625)
(19, 217.12298583984375)
(20, 206.11361694335938)
(21, 195.57254028320312)
(22, 185.47674560546875)
(23, 175.8717498779297)
(24, 166.70237731933594)
(25, 157.95936584472656)
(26, 149.5911865234375)
(27, 141.6107635498047)
(28, 133.98545837402344)
(29, 126.72557067871094)
(30, 119.81433868408203)
(31, 113.25624084472656)
(32, 107.03474426269531)
(33, 101.09329223632812)
(34, 95.45187377929688)
(35, 90.10154724121094)
(36, 85.04039764404297)
(37, 80.23994445800781)
(38, 75.69782257080078)
(39, 71.40235900878906)
(40, 67.34630584716797)
(41, 63.5142

PyTorch: Optim
- The optim module in PyTorch provides a set of optimisation algorithms that adjust the parameters much more efficiently than the manual adjustments we do above e.g. param -= learning_rate*grad

In [11]:
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Use the nn package to define our model and loss function
## Specify two layer linear model use sequential to store the sequential steps
## of the model
model = torch.nn.Sequential(
    ## First layer applies a linear transform to the data taking in D_in layers
    ## and outputting H layers
    torch.nn.Linear(D_in, H),
    ## The ReLU function turns every negative value into a 0 value elementwise
    torch.nn.ReLU(), 
    ## Applies linear transform for the second layer with H input layers and 
    ## D_out output layers.
    torch.nn.Linear(H, D_out), 
)

## Stores the Mean Squared Error loss function in a variable
loss_fn = torch.nn.MSELoss(size_average=False)

# Use the optim package to define an Optimiser that will update the weights of
# the model for us. Here we will use Adam; the optim package contains many other
# optimisation algorithms. The first argument to the Adam constructor tells the 
# optimiser which tensors it should update.
learning_rate = 1e-4
## Use the Adam optimisation algorithm from the optim package
## model.parameters tells the optimiser what it should be updating.
optimiser = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(500):
  # Forward pass: Compute predicted y by passing x to the model.
  ## Apply above calculated two layer model to input data to get y_pred output
  y_pred = model(x)
  
  # Compute and print loss
  ## Apply above specified loss function with the actual output and predicted 
  ## output to get the loss of the model wrt the current learnable parameters
  loss = loss_fn(y_pred, y)
  ## Output the loss value and current epoch
  print(t, loss.item())
  
  # Before the backward pass, use the optimiser object to zero all the
  # gradients for the variables it will update (which are the learnable weights
  # of the model). This is because by default, gradients are accumulated in 
  # buffers (i.e. not overwriten) whenever .backward() is called. Check out docs
  # of torch.autograd.backward for more details.
  ## Zero gradients because they are accumulated in the buffers at each epoch
  optimiser.zero_grad()
  
  # Backwards pass: Compute the gradient of the loss wrt model parameters
  ## Calculate gradients of loss functions using backward pass
  loss.backward()
  
  # Calling the step function on an Optimiser makes an update to its parameter
  ## Adjust parameter weights using above specified Adam algorithm
  optimiser.step()

(0, 624.0660400390625)
(1, 607.7557983398438)
(2, 591.9405517578125)
(3, 576.6320190429688)
(4, 561.7581787109375)
(5, 547.2724609375)
(6, 533.2161865234375)
(7, 519.6121215820312)
(8, 506.4575500488281)
(9, 493.75445556640625)
(10, 481.4327697753906)
(11, 469.512939453125)
(12, 458.0037536621094)
(13, 446.80096435546875)
(14, 435.96722412109375)
(15, 425.45361328125)
(16, 415.26776123046875)
(17, 405.4201354980469)
(18, 395.85797119140625)
(19, 386.5498046875)
(20, 377.4913330078125)
(21, 368.7260437011719)
(22, 360.2066955566406)
(23, 351.8802490234375)
(24, 343.72625732421875)
(25, 335.7688293457031)
(26, 328.0138854980469)
(27, 320.4402160644531)
(28, 313.03570556640625)
(29, 305.78472900390625)
(30, 298.7236022949219)
(31, 291.8276672363281)
(32, 285.06787109375)
(33, 278.4387512207031)
(34, 271.96661376953125)
(35, 265.63507080078125)
(36, 259.42364501953125)
(37, 253.34400939941406)
(38, 247.39735412597656)
(39, 241.57379150390625)
(40, 235.85243225097656)
(41, 230.2572784423828

PyTorch: Custom nn Modules
- When you want to specify models that are more complex than a sequence of existing modules by subclassing nn.Module and defining a forward.

In [13]:
## Two layer neural network as a class that inherits the torch.nn.Module class
class TwoLayerNet(torch.nn.Module):
  ## Initialise the models key parameters 
  def __init__(self, D_in, H, D_out):
    """
    In the constructor we instantiate two nn.Linear modules and assign them as
    member variables.
    """
    ## Inherit the __init__ variables from the super class torch.nn.Module
    super(TwoLayerNet, self).__init__()
    ## Specify layer 1 as a linear transform with D_in inputs and H outputs
    self.linear1 = torch.nn.Linear(D_in, H)
    ## Specify layer two as a linear transform with H inputs and D_out outputs
    self.linear2 = torch.nn.Linear(H, D_out)
    
  ## Calculate the modelled output given the model and input data
  def forward(self, x):
    """
    In the forward function we accept a Tensor of input data and we must return
    a Tensor of output data. We can use Modules defined in the constructor as 
    well as arbitrary operators on the Tensors.
    """
    ## Run input data through the first layer applying a ReLU function by 
    ## making every negative value 0.
    h_relu = self.linear1(x).clamp(min=0)
    ## Run the output of the first layer through the second layer to get the
    ## final prediction.
    y_pred = self.linear2(h_relu)
    return y_pred
  
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
## Specify model dimension parameters
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
## Instatiate model class to create a model
model = TwoLayerNet(D_in, H, D_out)

# Construct our loss function and an Optimiser. The call to model.parameters()
# in the SGD constructor will contain the learnable parameters of the two 
# nn.linear modules which are members of the model.
## Specify the Mean Squared Error as the loss function 
criterion = torch.nn.MSELoss(size_average=False)
## Use the Stochastic Gradient Descent algorithm to adjust the model parameters
## each epoch
optimiser = torch.optim.SGD(model.parameters(), lr=1e-4)
## Run through the model 500 epochs
for t in range(500):
  
  ## Epoch Formula
  ## - Make model predictions
  ## - Calculate loss based on predictions
  ## - Zero gradients
  ## - Calculate gradients
  ## - Adjust learnable parameters using gradients
  
  # Forward pass: Compute predicted y by passing x to the model
  ## Make prediction using the model and its current learned parameters
  y_pred = model(x)
  
  # Compute and print loss
  ## Compute the loss using the model predictions and actual values
  loss = criterion(y_pred, y)
  ## Output the loss for each epoch
  print(t, loss.item())
  
  # zero gradients, perform a backward pass, and update the weights.
  ## Zero the gradients as they are cumulative on each epoch
  optimiser.zero_grad()
  ## Calculate the gradients using backpropogation
  loss.backward()
  ## Apply the optimiser to adjust the learnable parameters
  optimiser.step()

(0, 653.7261352539062)
(1, 606.1356201171875)
(2, 564.8939819335938)
(3, 528.5020751953125)
(4, 495.9006042480469)
(5, 466.3626708984375)
(6, 439.81182861328125)
(7, 415.3147888183594)
(8, 392.6358947753906)
(9, 371.58404541015625)
(10, 352.0029602050781)
(11, 333.5412902832031)
(12, 316.1431579589844)
(13, 299.7132263183594)
(14, 284.149658203125)
(15, 269.3353576660156)
(16, 255.28533935546875)
(17, 241.91165161132812)
(18, 229.1586456298828)
(19, 217.0066680908203)
(20, 205.45175170898438)
(21, 194.43385314941406)
(22, 183.95587158203125)
(23, 173.97352600097656)
(24, 164.4420928955078)
(25, 155.34762573242188)
(26, 146.65707397460938)
(27, 138.38970947265625)
(28, 130.55499267578125)
(29, 123.11859130859375)
(30, 116.06204223632812)
(31, 109.38883972167969)
(32, 103.06236267089844)
(33, 97.04481506347656)
(34, 91.34182739257812)
(35, 85.94685363769531)
(36, 80.84688568115234)
(37, 76.02531433105469)
(38, 71.46884155273438)
(39, 67.17565155029297)
(40, 63.13433074951172)
(41, 59.322

PyTorch: Control Flow + Weight Sharing
- Below the program shows how the same middle layer can be used to create an arbitrary number of layers with the same weights
- This is done by getting a random number and using it to create the middle layer that many times

In [14]:
import random

## Specify the dynamic class with potentially 2-5 layers and inheriting parent
## class torch.nn.Module
class DynamicNet(torch.nn.Module):
  ## Initialise parameters of class
  def __init__(self, D_in, H, D_out):
    """
    In the constructor we construct three nn.Linear instances that we will use 
    in the forward pass.
    """
    ## Inherit initialised variables from parent class
    super(DynamicNet, self).__init__()
    ## Create three layers that will be used to construct model, they will all
    ## apply a linear transform to their input.
    self.input_linear = torch.nn.Linear(D_in, H)
    ## The middle layer can potentially be used multiple times 
    self.middle_linear = torch.nn.Linear(H, H)
    self.output_linear = torch.nn.Linear(H, D_out)
    
  def forward(self, x):
    """
    For the forward pass of the model, we randomly choose either 0, 1, 2, or 3
    and reuse the middle_linear Module that many times to compute the hidden 
    layer representations.
    
    Since each forward pass builds a dynamioc computation graph, we can use the 
    normal python control-flow operations like loops or conditional statements
    when defining the forward pass of the model.
    
    Here we also see it is perfectly safe to reuse the same Module many times
    when defining a computational graph. This is a big improvement from Lua 
    Torch, where each module can be used only once.
    """
    
    ## Run the input through the first layer, a linear transform. Then apply
    ## a ReLU to turn all negative values to 0.
    h_relu = self.input_linear(x).clamp(min=0)
    ## Create a loop that will be run between 0-3 times randomly and generate
    ## another hidden (middle) layer each time.
    for _ in range(random.randint(0,3 )):
      ## Generate hidden layer for model with a ReLU function that makes each 
      ## negative value 0.
      h_relu = self.middle_linear(h_relu).clamp(min=0)
    ## Run the final layer to get the models predictions  
    y_pred = self.output_linear(h_relu)
    return y_pred
  
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
## Specify model dimension parameters
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
## Generate data using random values
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
## Create a model by creating an instance of the above model class
model = DynamicNet(D_in, H, D_out)

# Construct our loss function and an optimiser. Training this strange model 
# with vanilla Stochastic Gradient Descent is tough, so we use momentum
## Specify a Mean Squared Error loss function
criterion = torch.nn.MSELoss(size_average=False)
## Will use Stochastic Gradient Descent algorithm to adjust the learnable 
## parameters of the model 
optimiser = torch.optim.SGD(model.parameters(), lr=1e-4, momentum = 0.9)
## Run 500 epochs to train the model
for t in range(500):
  ## Epoch Formula
  ## - Make predictions
  ## - Calculate loss based on predictions
  ## - Zero gradients
  ## - Calculate gradients
  ## - Adjust learnable parameters using gradients
  
  # Forward pass: Compute predicted y by passing x to the model
  ## Make predictions using the model
  y_pred = model(x)
  
  # Compute and print loss
  ## Calculate loss using the predicted y values and the actual y values
  loss = criterion(y_pred, y)
  ## print the loss function using the epoch number and loss
  print(t, loss.item())
  
  # zero gradients, perform a backward pass, and update the weights.
  ## Zero gradients because they are cumulative
  optimiser.zero_grad()
  ## Calculate gradeints using backpropagation
  loss.backward()
  ## Adjust learnable parameter weigths using Stochastic Gradient Descent
  optimiser.step()    

(0, 625.7408447265625)
(1, 653.1375732421875)
(2, 629.5321655273438)
(3, 601.989990234375)
(4, 589.5699462890625)
(5, 504.5060729980469)
(6, 624.5457153320312)
(7, 413.7421875)
(8, 621.9774169921875)
(9, 536.9540405273438)
(10, 624.7692260742188)
(11, 267.5934143066406)
(12, 619.1369018554688)
(13, 616.3870239257812)
(14, 614.5101928710938)
(15, 483.1741638183594)
(16, 160.67813110351562)
(17, 606.8574829101562)
(18, 603.2153930664062)
(19, 598.3306274414062)
(20, 591.9722900390625)
(21, 563.0183715820312)
(22, 92.56158447265625)
(23, 564.9708251953125)
(24, 520.9458618164062)
(25, 351.7651062011719)
(26, 328.8033447265625)
(27, 86.19957733154297)
(28, 275.37579345703125)
(29, 470.6893005371094)
(30, 223.30712890625)
(31, 91.8984146118164)
(32, 340.6063537597656)
(33, 376.065673828125)
(34, 150.76905822753906)
(35, 315.28515625)
(36, 127.87216186523438)
(37, 84.9565200805664)
(38, 244.17286682128906)
(39, 72.72901916503906)
(40, 61.81351089477539)
(41, 102.38778686523438)
(42, 263.1224