In [1]:
%matplotlib inline


PyTorch: Tensors
----------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation uses PyTorch tensors to manually compute the forward pass,
loss, and backward pass.

A PyTorch Tensor is basically the same as a numpy array: it does not know
anything about deep learning or computational graphs or gradients, and is just
a generic n-dimensional array to be used for arbitrary numeric computation.

The biggest difference between a numpy array and a PyTorch Tensor is that
a PyTorch Tensor can run on either CPU or GPU. To run operations on the GPU,
just create the Tensor on a CUDA device (e.g. ``device=torch.device("cuda:0")``).



In [2]:
import torch

# Seed the global RNG before drawing anything, so that Restart Kernel -> Run All
# reproduces the same data, the same initial weights and therefore the same
# loss curve on every run.
SEED = 0
torch.manual_seed(SEED)

dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data (standard-normal samples).
x = torch.randn(N, D_in,  device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Randomly initialize weights. w1 and w2 are the parameters to learn; there
# are no bias terms.
w1 = torch.randn(D_in, H,  device=device, dtype=dtype)
w2 = torch.randn(H, D_out, device=device, dtype=dtype)

# Sanity check: show the shape of every tensor used by the network.
print(x.shape)
print(y.shape)
print(w1.shape)
print(w2.shape)

torch.Size([64, 1000])
torch.Size([64, 10])
torch.Size([1000, 100])
torch.Size([100, 10])


In [3]:
learning_rate = 1e-6  # step size for plain gradient descent

for t in range(500):
    # Forward pass. Each intermediate is kept because the hand-written
    # backward pass below reuses it.
    h = x @ w1                      # hidden-layer pre-activations
    h_relu = torch.clamp(h, min=0)  # ReLU: negative entries become zero
    y_pred = h_relu @ w2

    # Loss: sum of squared errors. .item() pulls the value out of the
    # one-element result tensor as a plain Python number.
    err = y_pred - y
    loss = torch.sum(err ** 2).item()
    print(t, loss)

    # Backward pass: the chain rule applied by hand, walking the forward
    # pass in reverse order (no autograd is used here).
    grad_y_pred = 2.0 * err
    grad_w2 = h_relu.t() @ grad_y_pred
    grad_h_relu = grad_y_pred @ w2.t()
    # ReLU lets gradient through only where its input was not negative.
    grad_h = grad_h_relu.masked_fill(h < 0, 0)
    grad_w1 = x.t() @ grad_h

    # Gradient-descent step. The update is in place, so w1 and w2 hold the
    # trained values after the loop; re-running this cell continues training
    # from them instead of starting over.
    w1.sub_(learning_rate * grad_w1)
    w2.sub_(learning_rate * grad_w2)

0 35708764.0
1 32332090.0
2 32699736.0
3 30881566.0
4 24825464.0
5 16349859.0
6 9271131.0
7 4938620.5
8 2732648.5
9 1667830.5
10 1136995.625
11 847166.625
12 669519.75
13 548001.0625
14 458182.90625
15 388183.8125
16 331961.75
17 285753.59375
18 247277.1875
19 214958.8125
20 187656.0
21 164422.9375
22 144542.140625
23 127456.6640625
24 112716.265625
25 99961.25
26 88875.8984375
27 79210.6484375
28 70763.4609375
29 63361.9140625
30 56852.9375
31 51112.9921875
32 46040.21484375
33 41544.28515625
34 37553.75390625
35 34010.23046875
36 30863.4296875
37 28049.958984375
38 25530.400390625
39 23268.86328125
40 21239.21875
41 19412.439453125
42 17763.892578125
43 16274.515625
44 14928.44921875
45 13710.603515625
46 12605.083984375
47 11600.353515625
48 10686.53515625
49 9853.748046875
50 9094.2080078125
51 8400.734375
52 7766.4013671875
53 7185.8037109375
54 6654.00244140625
55 6166.03369140625
56 5718.3740234375
57 5307.07421875
58 4929.17333984375
59 4581.14013671875
60 4260.4287109375
61 39

421 0.00011855332559207454
422 0.0001162412590929307
423 0.00011404048564145342
424 0.00011192861711606383
425 0.0001097627027775161
426 0.00010747944907052442
427 0.00010511450818739831
428 0.00010320331057300791
429 0.0001014163572108373
430 9.953259723260999e-05
431 9.742514521349221e-05
432 9.589003457222134e-05
433 9.37196018639952e-05
434 9.197401232086122e-05
435 9.020479046739638e-05
436 8.817666093818843e-05
437 8.659202285343781e-05
438 8.523035648977384e-05
439 8.334487938554958e-05
440 8.212003012886271e-05
441 8.067401358857751e-05
442 7.904919038992375e-05
443 7.773414836265147e-05
444 7.647225720575079e-05
445 7.50350154703483e-05
446 7.363859185716137e-05
447 7.222645945148543e-05
448 7.097073103068396e-05
449 6.973822019062936e-05
450 6.858133565401658e-05
451 6.725304410792887e-05
452 6.603878136957064e-05
453 6.525038043037057e-05
454 6.402781582437456e-05
455 6.293533806456253e-05
456 6.188894622027874e-05
457 6.10432107350789e-05
458 5.9787864302052185e-05
459 5.91