In [1]:
%matplotlib inline


PyTorch: Tensors
----------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation uses PyTorch tensors to manually compute the forward pass,
loss, and backward pass.

A PyTorch Tensor is basically the same as a numpy array: it does not know
anything about deep learning or computational graphs or gradients, and is just
a generic n-dimensional array to be used for arbitrary numeric computation.

The biggest difference between a numpy array and a PyTorch Tensor is that
a PyTorch Tensor can run on either CPU or GPU. To run operations on the GPU,
just cast the Tensor to a cuda datatype.



In [2]:
import torch


dtype = torch.FloatTensor
# dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in).type(dtype)
y = torch.randn(N, D_out).type(dtype)

# Randomly initialize weights
w1 = torch.randn(D_in, H).type(dtype)
w2 = torch.randn(H, D_out).type(dtype)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y
    h = x.mm(w1)
    h_relu = h.clamp(min=0)
    y_pred = h_relu.mm(w2)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum()
    print(t, loss)

    # Backprop to compute gradients of w1 and w2 with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h_relu = grad_y_pred.mm(w2.t())
    grad_h = grad_h_relu.clone()
    grad_h[h < 0] = 0
    grad_w1 = x.t().mm(grad_h)

    # Update weights using gradient descent
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

0 32732327.14804079
1 31959849.174433872
2 33936128.95347166
3 32954808.560527187
4 26440663.680798486
5 16918710.633705214
6 9155871.351095974
7 4690412.676198886
8 2588810.785600813
9 1637358.5130011663
10 1178166.9129206857
11 924067.7060414937
12 760996.2238812279
13 643509.8325722283
14 552218.5898687763
15 478246.5307192489
16 416968.17320154374
17 365507.105288386
18 321983.8503787474
19 284797.96655355196
20 252874.499491791
21 225322.14042809757
22 201390.4226730317
23 180597.93072277558
24 162430.00910892233
25 146519.5826552016
26 132487.61915250658
27 120068.70731719502
28 109046.40392197177
29 99247.55397811331
30 90503.39171972875
31 82678.34353545523
32 75669.9902013536
33 69369.3828096719
34 63690.44936084296
35 58561.311304256196
36 53922.6346183978
37 49718.66824318476
38 45899.90550542042
39 42431.17726553301
40 39269.21245037114
41 36383.54002453934
42 33750.63795243988
43 31340.057483040146
44 29130.001493400312
45 27101.53477228889
46 25235.29695970792
47 23518.12

442 0.000906189043997821
443 0.0008815020188945311
444 0.0008568177077218908
445 0.0008317075397029638
446 0.0008082270085273402
447 0.0007867985437337022
448 0.0007662101772507174
449 0.0007468557313156654
450 0.0007259505342097627
451 0.0007071580933376009
452 0.0006864159854645391
453 0.0006690529060634082
454 0.0006531885991841319
455 0.0006361500237617862
456 0.0006198657938561031
457 0.0006056695198673051
458 0.0005890168922289274
459 0.0005761538157770946
460 0.0005605204578594813
461 0.0005464689488728497
462 0.000533508742929234
463 0.0005205800015035122
464 0.0005074438991273938
465 0.0004964117449616223
466 0.00048485595957969974
467 0.0004737099801234712
468 0.00046156150474185864
469 0.00045113438865374
470 0.00044074937760059385
471 0.00043115126558851413
472 0.0004209850311474722
473 0.00041180737286490576
474 0.00040255537443119327
475 0.00039401766333631294
476 0.0003846544311268546
477 0.00037611653391023125
478 0.00036763323860806996
479 0.00036005606460470796
480 0.