In [1]:
%matplotlib inline


PyTorch: Tensors
----------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation uses PyTorch tensors to manually compute the forward pass,
loss, and backward pass.

A PyTorch Tensor is basically the same as a numpy array: it does not know
anything about deep learning or computational graphs or gradients, and is just
a generic n-dimensional array to be used for arbitrary numeric computation.

The biggest difference between a numpy array and a PyTorch Tensor is that
a PyTorch Tensor can run on either CPU or GPU. To run operations on the GPU,
just cast the Tensor to a cuda datatype.



In [4]:
import torch


# Choose the tensor type from the hardware that is actually present: CUDA float
# tensors when a GPU is usable, CPU float tensors otherwise. Hard-coding the
# CUDA type makes this cell fail on machines without a GPU.
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in).type(dtype)
y = torch.randn(N, D_out).type(dtype)

# Randomly initialize weights (no bias terms in this network)
w1 = torch.randn(D_in, H).type(dtype)
w2 = torch.randn(H, D_out).type(dtype)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y
    h = x.mm(w1)                # hidden pre-activations, N x H
    h_relu = h.clamp(min=0)     # ReLU
    y_pred = h_relu.mm(w2)      # predictions, N x D_out

    # Compute and print loss (sum of squared errors).
    # float() yields a plain number whether `loss` is a Python float or a
    # 0-dim tensor, so the log shows "t value" rather than a tensor repr.
    loss = (y_pred - y).pow(2).sum()
    print(t, float(loss))

    # Backprop to compute gradients of w1 and w2 with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h_relu = grad_y_pred.mm(w2.t())
    # ReLU passes no gradient where its input was negative; clone first so
    # grad_h_relu itself is left untouched by the in-place masking.
    grad_h = grad_h_relu.clone()
    grad_h[h < 0] = 0
    grad_w1 = x.t().mm(grad_h)

    # Update weights using gradient descent
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

0 38696376.0
1 37189804.0
2 35755600.0
3 29463194.0
4 19699762.0
5 11028149.0
6 5836564.5
7 3305322.5
8 2129290.0
9 1540972.75
10 1204123.25
11 982495.8125
12 821221.125
13 696676.5625
14 596738.1875
15 514883.5625
16 446979.46875
17 390068.28125
18 341898.21875
19 300927.46875
20 265914.65625
21 235771.25
22 209720.203125
23 187101.890625
24 167360.734375
25 150065.8125
26 134862.78125
27 121456.6640625
28 109600.3203125
29 99081.171875
30 89710.0703125
31 81371.4765625
32 73923.765625
33 67256.1875
34 61275.91015625
35 55900.703125
36 51057.7890625
37 46689.16796875
38 42742.20703125
39 39167.71875
40 35937.62890625
41 33009.55078125
42 30347.04296875
43 27923.865234375
44 25714.2734375
45 23697.154296875
46 21854.123046875
47 20169.408203125
48 18626.39453125
49 17212.44921875
50 15915.154296875
51 14724.27734375
52 13630.71875
53 12624.666015625
54 11699.095703125
55 10847.3037109375
56 10062.1123046875
57 9338.041015625
58 8669.8603515625
59 8053.14501953125
60 7483.8623046875
61 