In [None]:
%matplotlib inline


PyTorch: Custom nn Modules
--------------------------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation defines the model as a custom Module subclass. Whenever you
want a model more complex than a simple sequence of existing Modules you will
need to define your model this way.



In [1]:
import torch


class TwoLayerNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        """
        In the constructor we instantiate two nn.Linear modules and assign them as
        member variables.
        """
        super(TwoLayerNet, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        In the forward function we accept a Tensor of input data and we must return
        a Tensor of output data. We can use Modules defined in the constructor as
        well as arbitrary operators on Tensors.
        """
        h_relu = self.linear1(x).clamp(min=0)
        y_pred = self.linear2(h_relu)
        return y_pred


# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = TwoLayerNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor will contain the learnable parameters of the two
# nn.Linear modules which are members of the model.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 714.5966796875
1 660.20751953125
2 613.3257446289062
3 572.3291625976562
4 536.3473510742188
5 503.8334655761719
6 474.3138732910156
7 447.28802490234375
8 422.50494384765625
9 399.3360900878906
10 377.5850524902344
11 357.0859680175781
12 337.6787414550781
13 319.30908203125
14 301.9495849609375
15 285.3998718261719
16 269.61663818359375
17 254.57537841796875
18 240.2565460205078
19 226.5531005859375
20 213.5130157470703
21 201.05709838867188
22 189.21868896484375
23 177.98974609375
24 167.34657287597656
25 157.2512969970703
26 147.66236877441406
27 138.57728576660156
28 129.95602416992188
29 121.80867004394531
30 114.09290313720703
31 106.8165054321289
32 99.95207214355469
33 93.50264739990234
34 87.41598510742188
35 81.71761322021484
36 76.37152099609375
37 71.3697738647461
38 66.68040466308594
39 62.297203063964844
40 58.198116302490234
41 54.37871170043945
42 50.81940841674805
43 47.493309020996094
44 44.392906188964844
45 41.50539779663086
46 38.815494537353516
47 36.3060531616

366 0.00017065179417841136
367 0.000165465273312293
368 0.00016044446965679526
369 0.00015557120786979795
370 0.00015085277846083045
371 0.00014627586642745882
372 0.000141848242492415
373 0.00013754426618106663
374 0.0001333788241026923
375 0.0001293342502322048
376 0.00012541987234726548
377 0.00012161985068814829
378 0.00011793510930147022
379 0.0001143710978794843
380 0.00011090745829278603
381 0.00010755359835457057
382 0.00010430270049255341
383 0.00010115351324202493
384 9.80954137048684e-05
385 9.513088298263028e-05
386 9.226117981597781e-05
387 8.94733821041882e-05
388 8.677432197146118e-05
389 8.415147749474272e-05
390 8.161572623066604e-05
391 7.915588503237814e-05
392 7.676835230085999e-05
393 7.44534918339923e-05
394 7.220983388833702e-05
395 7.003436621744186e-05
396 6.792481144657359e-05
397 6.587823736481369e-05
398 6.389475311152637e-05
399 6.197148468345404e-05
400 6.010639845044352e-05
401 5.8297529903938994e-05
402 5.654422420775518e-05
403 5.484209759742953e-05
404