In [None]:
%matplotlib inline


PyTorch: Tensors
----------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation uses PyTorch tensors to manually compute the forward pass,
loss, and backward pass.

A PyTorch Tensor is basically the same as a numpy array: it does not know
anything about deep learning or computational graphs or gradients, and is just
a generic n-dimensional array to be used for arbitrary numeric computation.

The biggest difference between a numpy array and a PyTorch Tensor is that
a PyTorch Tensor can run on either CPU or GPU. To run operations on the GPU,
just create the Tensors on a CUDA device instead of the CPU device.



In [1]:
import torch


# Every tensor below is created with this dtype and on this device, so moving
# the whole computation to the GPU only requires rebinding `device`.
dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0")  # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Randomly initialize weights
w1 = torch.randn(D_in, H, device=device, dtype=dtype)
w2 = torch.randn(H, D_out, device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y
    h = x.mm(w1)  # (N, H) hidden pre-activations
    h_relu = h.clamp(min=0)  # ReLU: negative entries become 0
    y_pred = h_relu.mm(w2)  # (N, D_out) predictions

    # Compute and print loss: sum of squared errors, extracted as a Python
    # float with .item()
    loss = (y_pred - y).pow(2).sum().item()
    print(t, loss)

    # Backprop to compute gradients of the loss with respect to w1 and w2
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h_relu = grad_y_pred.mm(w2.t())
    # The gradient flows back through the ReLU only where its input was not
    # negative; work on a clone so grad_h_relu itself is left untouched.
    grad_h = grad_h_relu.clone()
    grad_h[h < 0] = 0
    grad_w1 = x.t().mm(grad_h)

    # Update weights using gradient descent
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

0 31611660.0
1 25878570.0
2 23227854.0
3 20424506.0
4 16688835.0
5 12318376.0
6 8426835.0
7 5493042.0
8 3580632.75
9 2401347.0
10 1693632.875
11 1259054.375
12 981637.8125
13 794009.125
14 659672.3125
15 558487.4375
16 479256.40625
17 415384.6875
18 362660.65625
19 318470.28125
20 280998.21875
21 248950.328125
22 221371.53125
23 197473.515625
24 176728.984375
25 158585.28125
26 142659.046875
27 128654.7109375
28 116288.2578125
29 105387.7421875
30 95711.390625
31 87090.7421875
32 79385.625
33 72484.1796875
34 66289.59375
35 60715.3359375
36 55696.21484375
37 51164.6171875
38 47065.734375
39 43353.36328125
40 39987.75390625
41 36927.4921875
42 34141.55859375
43 31601.6796875
44 29280.73828125
45 27158.48046875
46 25216.59765625
47 23435.71484375
48 21800.71484375
49 20299.248046875
50 18918.05078125
51 17644.837890625
52 16471.751953125
53 15389.0517578125
54 14388.26953125
55 13464.18359375
56 12610.7197265625
57 11819.9521484375
58 11086.6064453125
59 10405.951171875
60 9773.8671875
6

384 0.03898973762989044
385 0.03767462074756622
386 0.03640568628907204
387 0.03516550362110138
388 0.03399255871772766
389 0.032846204936504364
390 0.03174227103590965
391 0.03067142702639103
392 0.029646070674061775
393 0.02865086868405342
394 0.027681272476911545
395 0.026754969730973244
396 0.025853291153907776
397 0.024982944130897522
398 0.024146348237991333
399 0.023339372128248215
400 0.022568203508853912
401 0.02180424891412258
402 0.02107037603855133
403 0.020375985652208328
404 0.01969251222908497
405 0.019042039290070534
406 0.018416130915284157
407 0.01780114881694317
408 0.017206551507115364
409 0.016637304797768593
410 0.016080889850854874
411 0.015545572154223919
412 0.0150283919647336
413 0.014531505294144154
414 0.014055818319320679
415 0.013587349094450474
416 0.013142485171556473
417 0.012707649730145931
418 0.012287919409573078
419 0.011884867213666439
420 0.011493369936943054
421 0.011118670925498009
422 0.01075703650712967
423 0.010403238236904144
424 0.010066032