In [1]:
import torch

In [2]:
# Tensor type that every tensor in this notebook is cast to via `.type(dtype)`.
# torch.FloatTensor is PyTorch's CPU 32-bit floating-point tensor type.
dtype = torch.FloatTensor

In [3]:
# Problem dimensions for the two-layer network.
batch_size, input_dim = 64, 1000   # rows per batch, width of each input row
hidden_dim, output_dim = 100, 10   # width of the hidden layer, width of each target row

In [4]:
# Seed the global RNG before drawing anything, so that re-running this cell
# (or Restart & Run All) regenerates the same data and the same initial
# weights instead of a fresh random problem each time.
torch.manual_seed(0)

# Random input and target data: one row per sample in the batch.
x = torch.randn(batch_size, input_dim).type(dtype)
y = torch.randn(batch_size, output_dim).type(dtype)

# Randomly initialize weights: w1 maps input -> hidden, w2 maps hidden -> output.
w1 = torch.randn(input_dim, hidden_dim).type(dtype)
w2 = torch.randn(hidden_dim, output_dim).type(dtype)

In [5]:
# Train the two-layer network (linear -> ReLU -> linear) with plain gradient
# descent. Gradients are derived by hand below; autograd is not used.
learning_rate = 1e-6
for t in range(500):
    # Forward pass: predict y
    h = x.mm(w1)                 # hidden-layer pre-activation
    h_relu = h.clamp(min=0)      # ReLU: negative entries become 0
    y_pred = h_relu.mm(w2)

    # Compute and print loss (sum of squared errors over the whole batch).
    loss = (y_pred - y).pow(2).sum()
    # float() makes the log line a plain number whether `loss` is a 0-dim
    # tensor or already a Python float, rather than a `tensor(...)` repr.
    print(t, float(loss))

    # Backprop to compute gradients of w1 and w2 with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)          # d(loss)/d(y_pred)
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h_relu = grad_y_pred.mm(w2.t())
    grad_h = grad_h_relu.clone()              # copy so masking leaves grad_h_relu intact
    grad_h[h < 0] = 0                         # ReLU passes no gradient where its input was negative
    grad_w1 = x.t().mm(grad_h)

    # Update weights using gradient descent (in place on w1 and w2)
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

0 32629818.697378688
1 29639069.575577572
2 31565690.009640925
3 32405940.646661058
4 28353118.857581124
5 19716795.435692035
6 11201645.904863745
7 5685191.090651594
8 2966191.8320646994
9 1738024.3777328655
10 1172917.1127420184
11 880926.3565512658
12 706738.9071125507
13 587910.9585375916
14 498969.72646291554
15 428615.83280366845
16 371242.8594270777
17 323790.5308198726
18 283923.64515687595
19 250095.35423040146
20 221167.2744523422
21 196276.01938981423
22 174746.74711279152
23 156047.52482570393
24 139733.42658157414
25 125445.28035887978
26 112878.81285783146
27 101798.48348675203
28 92003.75664738257
29 83311.31713879586
30 75575.60022745095
31 68666.66589825135
32 62489.17666272074
33 56951.80910727063
34 51980.75553461498
35 47510.75894754102
36 43479.82513540902
37 39837.380138760855
38 36542.52530213431
39 33555.43675281806
40 30842.561831531537
41 28375.731322248335
42 26129.751894926565
43 24084.823317346883
44 22218.837466398836
45 20513.572041687992
46 18953.3333833