In [None]:
%matplotlib inline


PyTorch: Tensors
----------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation uses PyTorch tensors to manually compute the forward pass,
loss, and backward pass.

A PyTorch Tensor is basically the same as a numpy array: it does not know
anything about deep learning or computational graphs or gradients, and is just
a generic n-dimensional array to be used for arbitrary numeric computation.

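The biggest difference between a numpy array and a PyTorch Tensor is that
a PyTorch Tensor can run on either CPU or GPU. To run operations on the GPU,
create the Tensors on a CUDA device, e.g. by setting ``device`` to
``torch.device("cuda")`` and passing it when the Tensors are constructed.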



In [1]:
import torch


# Single-precision tensors living on the CPU.
dtype = torch.float
device = torch.device("cpu")
# To run on the GPU instead, use: device = torch.device("cuda:0")

# Problem sizes: N = batch size, D_in = input width,
# H = hidden width, D_out = output width.
N, D_in, H, D_out = 64, 1000, 100, 10

# Every tensor below is allocated with the same device/dtype.
tensor_opts = dict(device=device, dtype=dtype)

# Random inputs and random regression targets.
x = torch.randn(N, D_in, **tensor_opts)
y = torch.randn(N, D_out, **tensor_opts)

# Random starting weights for the two bias-free linear layers.
w1 = torch.randn(D_in, H, **tensor_opts)
w2 = torch.randn(H, D_out, **tensor_opts)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: linear -> ReLU -> linear.
    h = torch.mm(x, w1)
    h_relu = torch.clamp(h, min=0)
    y_pred = torch.mm(h_relu, w2)

    # Loss is the sum of squared errors; report it as a Python float.
    residual = y_pred - y
    loss = residual.pow(2).sum().item()
    print(t, loss)

    # Manual backprop: gradients of the loss with respect to w1 and w2.
    grad_y_pred = 2.0 * residual
    grad_w2 = torch.mm(h_relu.t(), grad_y_pred)
    grad_h_relu = torch.mm(grad_y_pred, w2.t())
    # ReLU lets no gradient through where its input was negative.
    grad_h = grad_h_relu.masked_fill(h < 0, 0.0)
    grad_w1 = torch.mm(x.t(), grad_h)

    # Plain gradient-descent step, applied in place.
    w1.sub_(learning_rate * grad_w1)
    w2.sub_(learning_rate * grad_w2)

0 28755626.0
1 24108876.0
2 25181006.0
3 28319676.0
4 29905792.0
5 26958088.0
6 19843236.0
7 11983835.0
8 6411629.0
9 3358278.25
10 1898358.375
11 1211776.25
12 870356.125
13 680365.9375
14 559919.5
15 474362.875
16 408684.28125
17 355752.03125
18 311835.03125
19 274745.40625
20 243146.765625
21 216004.046875
22 192520.578125
23 172111.0625
24 154293.828125
25 138695.828125
26 124974.75
27 112871.5
28 102164.0859375
29 92651.6484375
30 84182.90625
31 76615.7109375
32 69841.3203125
33 63763.359375
34 58297.30078125
35 53370.3046875
36 48914.3671875
37 44882.73828125
38 41230.0546875
39 37917.2109375
40 34907.5
41 32167.654296875
42 29670.33203125
43 27390.8515625
44 25307.439453125
45 23401.396484375
46 21655.7265625
47 20057.564453125
48 18591.240234375
49 17243.884765625
50 16004.1806640625
51 14862.640625
52 13810.779296875
53 12840.6884765625
54 11944.9765625
55 11120.037109375
56 10357.7216796875
57 9652.40625
58 8999.396484375
59 8394.322265625
60 7833.341796875
61 7313.03125
62 6

386 0.0008756075403653085
387 0.0008480591350235045
388 0.0008225216879509389
389 0.0007953006424941123
390 0.0007709946949034929
391 0.0007488494738936424
392 0.0007255630334839225
393 0.0007043033838272095
394 0.0006823848816566169
395 0.0006618263432756066
396 0.0006423038430511951
397 0.0006253602914512157
398 0.0006076639983803034
399 0.0005907275481149554
400 0.0005733777652494609
401 0.0005572675145231187
402 0.0005424026167020202
403 0.0005270944093354046
404 0.0005128629854880273
405 0.0004992678295820951
406 0.0004849306424148381
407 0.0004710914217866957
408 0.00045912101631984115
409 0.000447734200861305
410 0.00043565864325501025
411 0.00042440718971192837
412 0.0004136148781981319
413 0.0004023742221761495
414 0.00039307604311034083
415 0.0003830235218629241
416 0.0003737788356374949
417 0.0003635796019807458
418 0.00035523989936336875
419 0.0003464616893325001
420 0.0003369877231307328
421 0.00032964497222565114
422 0.0003219510253984481
423 0.00031377363484352827
424 0.

In [3]:
import torch


# Same two-layer network as the CPU cell above, but placed on the GPU when one exists.
dtype = torch.float
# Prefer CUDA, but fall back to the CPU so this cell still runs on machines
# without a usable GPU instead of failing when the first tensor is allocated.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Randomly initialize weights
w1 = torch.randn(D_in, H, device=device, dtype=dtype)
w2 = torch.randn(H, D_out, device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y (linear -> ReLU -> linear, no biases)
    h = x.mm(w1)
    h_relu = h.clamp(min=0)
    y_pred = h_relu.mm(w2)

    # Compute and print loss (sum of squared errors, pulled out as a Python float)
    loss = (y_pred - y).pow(2).sum().item()
    print(t, loss)

    # Backprop to compute gradients of the loss with respect to w1 and w2
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h_relu = grad_y_pred.mm(w2.t())
    # Copy before masking so grad_h_relu itself is left untouched
    grad_h = grad_h_relu.clone()
    # ReLU passes no gradient where its input was negative
    grad_h[h < 0] = 0
    grad_w1 = x.t().mm(grad_h)

    # Update weights in place using gradient descent
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

0 41995528.0
1 44494432.0
2 46319912.0
3 37742540.0
4 22449208.0
5 10110632.0
6 4463806.0
7 2385295.5
8 1597030.5
9 1219603.625
10 988099.25
11 821667.625
12 693087.375
13 590160.25
14 506095.1875
15 436805.875
16 379021.15625
17 330495.96875
18 289478.03125
19 254603.1875
20 224818.03125
21 199205.328125
22 177103.921875
23 157963.34375
24 141292.34375
25 126726.046875
26 113968.59375
27 102736.09375
28 92813.953125
29 84026.1328125
30 76226.328125
31 69286.7890625
32 63101.4375
33 57564.5546875
34 52601.66796875
35 48137.1796875
36 44117.10546875
37 40508.25
38 37250.921875
39 34300.9375
40 31625.84765625
41 29194.3125
42 26980.68359375
43 24960.830078125
44 23116.166015625
45 21429.62109375
46 19886.3046875
47 18470.98828125
48 17174.408203125
49 15982.712890625
50 14886.068359375
51 13876.35546875
52 12944.73046875
53 12084.255859375
54 11289.0244140625
55 10553.205078125
56 9872.052734375
57 9240.818359375
58 8655.0927734375
59 8111.41650390625
60 7605.86474609375
61 7135.79003906

400 0.0019212213810533285
401 0.0018581750337034464
402 0.0017986635211855173
403 0.0017410878790542483
404 0.0016865080688148737
405 0.0016307660844177008
406 0.0015792391495779157
407 0.0015293298056349158
408 0.0014811926521360874
409 0.001434591133147478
410 0.0013911665882915258
411 0.001347595825791359
412 0.00130530446767807
413 0.001264880527742207
414 0.0012273427564650774
415 0.001190938288345933
416 0.0011552071664482355
417 0.0011216626735404134
418 0.0010871641570702195
419 0.001055314438417554
420 0.0010250461054965854
421 0.0009949738159775734
422 0.0009671626030467451
423 0.0009397784015163779
424 0.0009124545031227171
425 0.0008857547654770315
426 0.0008612402016296983
427 0.0008367304690182209
428 0.0008134546224027872
429 0.0007908549159765244
430 0.0007702450966462493
431 0.0007486121030524373
432 0.0007280783611349761
433 0.000707754457835108
434 0.0006891511147841811
435 0.0006694962503388524
436 0.0006522070616483688
437 0.0006350856274366379
438 0.00061891961377