In [1]:
%matplotlib inline


PyTorch: Custom nn Modules
--------------------------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation defines the model as a custom Module subclass. Whenever you
want a model more complex than a simple sequence of existing Modules you will
need to define your model this way.



In [2]:
import torch


class TwoLayerNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        """
        In the constructor we instantiate two nn.Linear modules and assign them as
        member variables.
        """
        super(TwoLayerNet, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        In the forward function we accept a Tensor of input data and we must return
        a Tensor of output data. We can use Modules defined in the constructor as
        well as arbitrary operators on Tensors.
        """
        h_relu = self.linear1(x).clamp(min=0)
        y_pred = self.linear2(h_relu)
        return y_pred


# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = TwoLayerNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor will contain the learnable parameters of the two
# nn.Linear modules which are members of the model.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 680.4923706054688
1 628.7054443359375
2 583.7244873046875
3 543.9805297851562
4 508.8058776855469
5 477.3746337890625
6 448.9125061035156
7 422.97064208984375
8 399.0166015625
9 376.7982482910156
10 356.11578369140625
11 336.7657470703125
12 318.4747314453125
13 301.2025146484375
14 284.9186706542969
15 269.3970031738281
16 254.66268920898438
17 240.52073669433594
18 227.05731201171875
19 214.20358276367188
20 201.99855041503906
21 190.4016876220703
22 179.37213134765625
23 168.91091918945312
24 158.99951171875
25 149.59127807617188
26 140.64991760253906
27 132.18324279785156
28 124.17243957519531
29 116.5859146118164
30 109.40998840332031
31 102.6418228149414
32 96.27250671386719
33 90.27861022949219
34 84.64252471923828
35 79.33430480957031
36 74.34005737304688
37 69.64608001708984
38 65.23111724853516
39 61.1014289855957
40 57.23769760131836
41 53.61716079711914
42 50.23176574707031
43 47.06723403930664
44 44.112125396728516
45 41.34918212890625
46 38.77133560180664
47 36.36462020

377 0.00020123121794313192
378 0.0001961381931323558
379 0.00019117635383736342
380 0.0001863546931417659
381 0.0001816537551349029
382 0.000177073321538046
383 0.00017261115135625005
384 0.00016826213686726987
385 0.00016403294284828007
386 0.0001599079550942406
387 0.00015589233953505754
388 0.00015198146866168827
389 0.00014817265036981553
390 0.00014445683336816728
391 0.00014084168651606888
392 0.00013731194485444576
393 0.0001338821602985263
394 0.00013054395094513893
395 0.00012728488945867866
396 0.00012411194620653987
397 0.00012101800530217588
398 0.00011800177890108898
399 0.00011506797454785556
400 0.00011220949818380177
401 0.00010942331573460251
402 0.00010670533811207861
403 0.00010406126239104196
404 0.00010147484135814011
405 9.896038682200015e-05
406 9.650989522924647e-05
407 9.412135113961995e-05
408 9.179361950373277e-05
409 8.952522330218926e-05
410 8.731526759220287e-05
411 8.516465459251776e-05
412 8.306097879540175e-05
413 8.101583080133423e-05
414 7.90179474279