In [None]:
%matplotlib inline


PyTorch: Custom nn Modules
--------------------------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation defines the model as a custom Module subclass. Whenever you
want a model more complex than a simple sequence of existing Modules, you will
need to define your model this way.



In [1]:
import torch


class TwoLayerNet(torch.nn.Module):
    """Fully-connected network with a single hidden layer: Linear -> ReLU -> Linear."""

    def __init__(self, D_in, H, D_out):
        """
        Create the two nn.Linear layers and store them as attributes.

        Args:
            D_in: number of features in each input sample.
            H: number of units in the hidden layer.
            D_out: number of features in each output sample.
        """
        super().__init__()
        # Input -> hidden affine map, followed by hidden -> output affine map.
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        Map a Tensor of inputs to a Tensor of predictions.

        The input goes through the first linear layer, negative activations
        are clamped to zero (the ReLU non-linearity), and the result is fed
        to the second linear layer.
        """
        hidden = self.linear1(x)
        activated = torch.clamp(hidden, min=0)
        return self.linear2(activated)


# N is the batch size; D_in is the input dimension;
# H is the hidden dimension; D_out is the output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold the inputs and the target outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = TwoLayerNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor returns the learnable parameters of the two
# nn.Linear modules which are members of the model.
# reduction='sum' sums the squared errors over every element; it is the
# supported spelling of the deprecated size_average=False argument.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print the loss for this iteration
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero the gradients first (backward() accumulates into .grad), then
    # run the backward pass and let the optimizer update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 666.1155395507812
1 614.9331665039062
2 571.3746337890625
3 534.205078125
4 501.3841247558594
5 472.0885925292969
6 445.50714111328125
7 421.3489074707031
8 399.1278076171875
9 378.44097900390625
10 359.18505859375
11 341.1101379394531
12 324.0578308105469
13 307.8788757324219
14 292.4938049316406
15 277.8373718261719
16 263.9089660644531
17 250.6515350341797
18 238.04371643066406
19 226.02825927734375
20 214.54881286621094
21 203.56126403808594
22 193.07276916503906
23 183.07032775878906
24 173.50726318359375
25 164.3750457763672
26 155.64903259277344
27 147.30453491210938
28 139.35311889648438
29 131.7988739013672
30 124.63439178466797
31 117.83165740966797
32 111.35284423828125
33 105.1778335571289
34 99.32102966308594
35 93.77847290039062
36 88.52814483642578
37 83.56310272216797
38 78.86182403564453
39 74.41251373291016
40 70.20599365234375
41 66.24284362792969
42 62.50136947631836
43 58.969913482666016
44 55.632301330566406
45 52.48615264892578
46 49.52914810180664
47 46.749217

495 2.17585784412222e-06
496 2.1165692487556953e-06
497 2.0592478904291056e-06
498 2.003363988478668e-06
499 1.9491694729367737e-06
