In [1]:
%matplotlib inline


PyTorch: Custom nn Modules
--------------------------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation defines the model as a custom Module subclass. Whenever you
want a model more complex than a simple sequence of existing Modules you will
need to define your model this way.



In [2]:
import torch


# construction a Computational Graph
# Take D_in, H, D_out as the inputs
class TwoLayerNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        """
        In the constructor 
        we instantiate two nn.Linear modules 
        and assign them as member variables.
        """
        super(TwoLayerNet, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        In the forward function we accept a Tensor of input data 
        and we must return a Tensor of output data. 
        We can use Modules defined in the constructor 
        as well as arbitrary operators on Tensors.
        """
        h_relu = self.linear1(x).clamp(min=0)
        y_pred = self.linear2(h_relu)
        return y_pred



In [3]:
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

In [4]:
# Construct our model by instantiating the class defined above
model = TwoLayerNet(D_in, H, D_out)
model
# notice, here is a bias

TwoLayerNet(
  (linear1): Linear(in_features=1000, out_features=100, bias=True)
  (linear2): Linear(in_features=100, out_features=10, bias=True)
)

In [5]:


# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor will contain the learnable parameters of the two
# nn.Linear modules which are members of the model.

# criterion is a loss function.
# In essence, it is a Function Operation.
criterion = torch.nn.MSELoss() # reduction='sum'

optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)


for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad() # prepare to calculate the grad
    loss.backward()       # backward to get the grad
    optimizer.step()      # add the grad to weight

0 1.0750528573989868
1 1.0749151706695557
2 1.0747783184051514
3 1.074641227722168
4 1.0745043754577637
5 1.0743672847747803
6 1.0742309093475342
7 1.0740950107574463
8 1.0739580392837524
9 1.0738216638565063
10 1.073686122894287
11 1.0735502243041992
12 1.0734131336212158
13 1.0732777118682861
14 1.0731408596038818
15 1.0730042457580566
16 1.0728687047958374
17 1.072731375694275
18 1.0725959539413452
19 1.0724600553512573
20 1.0723241567611694
21 1.0721882581710815
22 1.0720518827438354
23 1.0719163417816162
24 1.0717798471450806
25 1.0716444253921509
26 1.0715082883834839
27 1.0713728666305542
28 1.0712363719940186
29 1.071101427078247
30 1.070965051651001
31 1.070828914642334
32 1.070693850517273
33 1.0705584287643433
34 1.0704232454299927
35 1.0702866315841675
36 1.0701513290405273
37 1.0700159072875977
38 1.0698808431625366
39 1.069745421409607
40 1.0696094036102295
41 1.069474458694458
42 1.069339394569397
43 1.0692040920257568
44 1.0690690279006958
45 1.0689338445663452
46 1.068

429 1.019552230834961
430 1.019429326057434
431 1.0193067789077759
432 1.0191847085952759
433 1.0190613269805908
434 1.0189383029937744
435 1.0188158750534058
436 1.018693208694458
437 1.0185710191726685
438 1.0184481143951416
439 1.018325924873352
440 1.0182030200958252
441 1.0180809497833252
442 1.0179591178894043
443 1.0178356170654297
444 1.017714262008667
445 1.0175920724868774
446 1.0174691677093506
447 1.0173470973968506
448 1.0172253847122192
449 1.017102599143982
450 1.0169798135757446
451 1.0168578624725342
452 1.0167365074157715
453 1.0166137218475342
454 1.0164918899536133
455 1.0163695812225342
456 1.0162477493286133
457 1.0161244869232178
458 1.0160032510757446
459 1.0158816576004028
460 1.0157593488693237
461 1.0156375169754028
462 1.0155150890350342
463 1.015393614768982
464 1.0152720212936401
465 1.0151498317718506
466 1.0150272846221924
467 1.0149056911468506
468 1.0147838592529297
469 1.0146623849868774
470 1.014540195465088
471 1.0144197940826416
472 1.0142974853515