In [1]:
%matplotlib inline


PyTorch: Tensors
----------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation uses PyTorch tensors to manually compute the forward pass,
loss, and backward pass.

A PyTorch Tensor is basically the same as a numpy array: it does not know
anything about deep learning or computational graphs or gradients, and is just
a generic n-dimensional array to be used for arbitrary numeric computation.

The biggest difference between a numpy array and a PyTorch Tensor is that
a PyTorch Tensor can run on either CPU or GPU. To run operations on the GPU,
just cast the Tensor to a cuda datatype.



In [2]:
import torch


dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Randomly initialize weights
w1 = torch.randn(D_in, H, device=device, dtype=dtype)
w2 = torch.randn(H, D_out, device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y
    # .mm matrix multiply I guess! 
    h = x.mm(w1)
    # .clamp put all values between min and max (below effectively ReLU)
    h_relu = h.clamp(min=0)
    y_pred = h_relu.mm(w2)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum().item()
    print(t, loss)

    # Backprop to compute gradients of w1 and w2 with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h_relu = grad_y_pred.mm(w2.t())
    grad_h = grad_h_relu.clone()
    grad_h[h < 0] = 0
    grad_w1 = x.t().mm(grad_h)

    # Update weights using gradient descent
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

0 23326258.0
1 16143872.0
2 13022788.0
3 11671908.0
4 11029584.0
5 10514736.0
6 9796171.0
7 8739341.0
8 7410247.5
9 5972921.0
10 4616528.0
11 3453905.0
12 2532668.75
13 1839864.75
14 1337836.25
15 981140.75
16 730009.6875
17 553191.375
18 427885.375
19 338165.1875
20 272818.78125
21 224383.34375
22 187760.96875
23 159538.90625
24 137290.03125
25 119388.859375
26 104721.890625
27 92510.828125
28 82198.078125
29 73386.1640625
30 65782.296875
31 59163.5234375
32 53360.296875
33 48242.328125
34 43709.6953125
35 39679.9609375
36 36081.6640625
37 32859.77734375
38 29967.912109375
39 27365.10546875
40 25017.54296875
41 22897.490234375
42 20979.9375
43 19241.875
44 17664.57421875
45 16230.7265625
46 14925.4453125
47 13736.34765625
48 12651.7705078125
49 11661.818359375
50 10757.2119140625
51 9931.298828125
52 9176.580078125
53 8484.849609375
54 7850.38818359375
55 7267.482421875
56 6731.7861328125
57 6239.2861328125
58 5785.95458984375
59 5368.26220703125
60 4983.37451171875
61 4628.4306640625

387 0.00024423195281997323
388 0.0002381547965342179
389 0.00023174899979494512
390 0.00022475192963611335
391 0.00021933838434051722
392 0.0002136503899237141
393 0.00020775805751327425
394 0.0002024164132308215
395 0.00019731096108444035
396 0.0001925286342157051
397 0.00018757829093374312
398 0.00018334665219299495
399 0.00017903264961205423
400 0.0001749873481458053
401 0.000170141996932216
402 0.0001665024901740253
403 0.0001624705910217017
404 0.00015850427735131234
405 0.00015456909022759646
406 0.00015101867029443383
407 0.0001481139042880386
408 0.0001447455579182133
409 0.00014131466741673648
410 0.00013798782310914248
411 0.0001345923519693315
412 0.00013157790817786008
413 0.00012914705439470708
414 0.00012656653416343033
415 0.00012342911213636398
416 0.00012089009396731853
417 0.00011853745672851801
418 0.00011577295663300902
419 0.00011393283057259396
420 0.00011144752352265641
421 0.00010917444160440937
422 0.00010694429511204362
423 0.00010479355842107907
424 0.0001026