
PyTorch: Custom nn Modules
--------------------------

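A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing the squared Euclidean distance between predictions and targets
(summed over all samples and outputs). Both x and y are random tensors, so the
network simply learns to fit this fixed training set.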

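This implementation defines the model as a custom Module subclass. Here the
subclass only wraps an nn.Sequential. However, whenever you need a model more
complex than a simple sequence of existing Modules (for example, one whose
forward pass branches or reuses layers), you will need to define your model this
way and write the computation yourself in forward().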



In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

# Seed PyTorch's global random number generator so the random data and the
# initial network weights are identical on every run.
torch.manual_seed(seed=0)

# Define the model
class MLP(nn.Module):
    """Fully-connected network with one hidden layer: Linear -> ReLU -> Linear.

    Args:
        num_inputs: Number of input features (size of the last input dimension).
        num_hidden_layer_nodes: Width of the single hidden layer.
        num_outputs: Number of output values produced per sample.

    The layers live in an ``nn.Sequential`` stored as ``self.model``, so the
    parameters appear in ``state_dict()`` as ``model.0.*`` and ``model.2.*``.
    """

    def __init__(self, num_inputs, num_hidden_layer_nodes, num_outputs):
        super().__init__()

        # Construction order matters for reproducibility: each nn.Linear draws
        # its initial weights from torch's global RNG when it is created.
        layers = [
            nn.Linear(num_inputs, num_hidden_layer_nodes),   # hidden layer
            nn.ReLU(),                                       # non-linearity
            nn.Linear(num_hidden_layer_nodes, num_outputs),  # output layer
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output for ``x`` (last dimension == num_inputs)."""
        output = self.model(x)
        return output

# ---- Data and hyper-parameters ---------------------------------------------
# Number of samples in the (random) training set
num_data = 1000

# Network shape: input features, hidden-layer width, output values
num_inputs, num_hidden_layer_nodes, num_outputs = 1000, 100, 10

# Full-batch training: each iteration processes all num_data samples once
num_epochs = 100

# Random inputs and random targets. Both are drawn from the seeded RNG before
# the model's weights are initialized; reordering these statements changes
# both the data and the initial weights.
x = torch.randn(num_data, num_inputs)
y = torch.randn(num_data, num_outputs)

# ---- Model, loss and optimizer ----------------------------------------------
model = MLP(num_inputs, num_hidden_layer_nodes, num_outputs)

# reduction="sum": squared error summed over every sample and output, so the
# gradient magnitude scales with num_data.
# NOTE(review): the logged loss spikes around epochs 33-40, which suggests
# lr=1e-4 is near the stability limit for this batch size. Consider
# reduction="mean" with a correspondingly larger lr.
loss_function = nn.MSELoss(reduction="sum")
optimizer = optim.SGD(params=model.parameters(), lr=1e-4)

# ---- Training loop ----------------------------------------------------------
for t in range(num_epochs):
    # Discard gradients from the previous step before computing new ones
    optimizer.zero_grad()

    y_pred = model(x)                  # forward pass
    loss = loss_function(y_pred, y)    # scalar training loss
    print(t, loss.item())

    loss.backward()                    # fill .grad of every parameter
    optimizer.step()                   # SGD update of the weights

0 10581.46484375
1 9755.71875
2 9161.30078125
3 8637.4951171875
4 8135.08154296875
5 7634.916015625
6 7127.3212890625
7 6610.97021484375
8 6087.3720703125
9 5563.39599609375
10 5048.94873046875
11 4551.427734375
12 4078.5126953125
13 3634.7451171875
14 3225.67626953125
15 2852.95361328125
16 2516.31201171875
17 2214.47265625
18 1946.53515625
19 1709.1043701171875
20 1499.0308837890625
21 1313.60205078125
22 1151.3802490234375
23 1009.0332641601562
24 883.9778442382812
25 774.8359375
26 679.2198486328125
27 596.10205078125
28 524.0062255859375
29 462.0333251953125
30 410.55902099609375
31 370.9833984375
32 349.6308288574219
33 361.6944580078125
34 444.9813232421875
35 685.0106811523438
36 1272.0538330078125
37 2519.06689453125
38 4693.54833984375
39 6812.6015625
40 6350.90185546875
41 3024.81298828125
42 1011.5364990234375
43 448.4580078125
44 291.2738037109375
45 218.35183715820312
46 172.94004821777344
47 140.86444091796875
48 116.9203872680664
49 98.4088134765625
50 83.75098419189453