In [4]:
import torch


# Tensor settings shared by every tensor created below.
dtype = torch.float
device = torch.device("cpu")
# To run on GPU instead, switch the device to torch.device("cuda:0").

# Problem dimensions.
N = 64        # batch size
D_in = 1000   # input dimension
H = 100       # hidden dimension
D_out = 10    # output dimension

# Random inputs and regression targets.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Randomly initialized weights of the two linear layers (no biases).
w1 = torch.randn(D_in, H, device=device, dtype=dtype)
w2 = torch.randn(H, D_out, device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: linear -> ReLU -> linear.
    h = torch.mm(x, w1)
    h_relu = torch.clamp(h, min=0)
    y_pred = torch.mm(h_relu, w2)

    # Loss is the sum of squared errors; .item() yields a Python float.
    residual = y_pred - y
    loss = residual.pow(2).sum().item()
    print(t, loss)

    # Backward pass: hand-derived gradients of the loss w.r.t. w1 and w2.
    grad_y_pred = 2.0 * residual
    grad_w2 = torch.mm(h_relu.t(), grad_y_pred)
    grad_h_relu = torch.mm(grad_y_pred, w2.t())
    # No gradient flows through units whose pre-activation was negative.
    grad_h = grad_h_relu.masked_fill(h < 0, 0.0)
    grad_w1 = torch.mm(x.t(), grad_h)

    # Gradient-descent step, applied in place to the weight tensors.
    w1.sub_(learning_rate * grad_w1)
    w2.sub_(learning_rate * grad_w2)

0 32096646.0
1 27609032.0
2 25178188.0
3 21615508.0
4 16556869.0
5 11244575.0
6 7040286.0
7 4292120.5
8 2685276.75
9 1780115.625
10 1263432.75
11 953287.4375
12 753568.5
13 615714.625
14 514281.03125
15 436035.0
16 373550.59375
17 322472.8125
18 280155.03125
19 244596.578125
20 214429.96875
21 188725.359375
22 166663.515625
23 147621.4375
24 131130.28125
25 116788.4921875
26 104279.828125
27 93317.7578125
28 83688.0703125
29 75202.8125
30 67705.1796875
31 61063.578125
32 55181.5703125
33 49949.93359375
34 45282.19140625
35 41112.6796875
36 37378.12890625
37 34029.64453125
38 31021.404296875
39 28313.576171875
40 25873.728515625
41 23671.248046875
42 21680.033203125
43 19876.685546875
44 18243.1875
45 16762.109375
46 15416.5859375
47 14191.87109375
48 13075.8828125
49 12058.5517578125
50 11130.64453125
51 10281.3916015625
52 9503.7373046875
53 8791.787109375
54 8139.39892578125
55 7540.80322265625
56 6990.599609375
57 6484.77099609375
58 6019.40625
59 5590.89013671875
60 5196.1416015625