In [1]:
# -*- coding: utf-8 -*-
import torch


class TwoLayerNet(torch.nn.Module):
    """Two fully connected layers with a clamp-at-zero (ReLU-style) step between them.

    Args:
        D_in: number of input features fed to the first layer.
        H: number of hidden units produced by the first layer.
        D_out: number of output features produced by the second layer.
    """

    def __init__(self, D_in, H, D_out):
        """Build the two nn.Linear layers and store them on the module."""
        super().__init__()
        # Assigning nn.Module instances as attributes registers them as
        # submodules, so their weights/biases show up in self.parameters().
        # Creation order (linear1 first) is kept so seeded init is unchanged.
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """Return the network's prediction for the input tensor ``x``.

        ``x`` goes through ``linear1``, negative activations are clamped to
        zero, and the result is projected by ``linear2``.
        """
        hidden = self.linear1(x)
        hidden = hidden.clamp(min=0)
        return self.linear2(hidden)


# Problem sizes: N = batch size, D_in = input features,
# H = hidden units, D_out = output features.
N = 64
D_in = 1000
H = 100
D_out = 10

# Random inputs and regression targets drawn from a standard normal.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Instantiate the network defined above.
model = TwoLayerNet(D_in, H, D_out)

# Summed squared-error loss, and plain SGD over every learnable tensor
# returned by model.parameters() (the weights and biases of both
# nn.Linear submodules).
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

for t in range(500):
    # Clear gradients left over from the previous step; nothing reads the
    # gradients between here and backward(), so doing this first is
    # equivalent to doing it right before backward().
    optimizer.zero_grad()

    # Forward pass on the full batch, then report the current loss.
    y_pred = model(x)
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Backpropagate and apply one SGD update to the parameters.
    loss.backward()
    optimizer.step()

0 718.9425048828125
1 664.4103393554688
2 617.6167602539062
3 576.3728637695312
4 540.0093383789062
5 507.81072998046875
6 478.4939880371094
7 451.6659240722656
8 427.25128173828125
9 404.3513488769531
10 382.84344482421875
11 362.62799072265625
12 343.5057373046875
13 325.40948486328125
14 308.1457214355469
15 291.6558837890625
16 275.8414306640625
17 260.78240966796875
18 246.4379425048828
19 232.77313232421875
20 219.7017822265625
21 207.1685791015625
22 195.2937469482422
23 183.9872589111328
24 173.1946258544922
25 162.96463012695312
26 153.26315307617188
27 144.05044555664062
28 135.34197998046875
29 127.1148910522461
30 119.33745574951172
31 111.97636413574219
32 105.01868438720703
33 98.42767333984375
34 92.21357727050781
35 86.36959838867188
36 80.87628936767578
37 75.7200698852539
38 70.86865997314453
39 66.31451416015625
40 62.04060363769531
41 58.0314826965332
42 54.275875091552734
43 50.7597770690918
44 47.47849655151367
45 44.402896881103516
46 41.53129577636719
47 38.8487