In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

### 1. Create a Data Loader

In [2]:
# Synthetic data: random inputs x and targets y drawn independently of x
# (there is no functional relationship between them by construction).
n_samples = 100
input_dim = 3
output_dim = 3
batch_size = 10

# Seed the global torch RNG so the data below, and anything drawn from the
# global RNG afterwards (weight initialisation, DataLoader shuffling), is
# reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

x = torch.randn(n_samples, input_dim)
y = torch.randn(n_samples, output_dim)

# Wrap the tensors so that dataset[i] == (x[i], y[i])
dataset = TensorDataset(x, y)

# Mini-batches of `batch_size` samples, reshuffled at every epoch
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

### 2. Define a Neural Network



In [3]:
class Net(nn.Module):
    """Three-layer fully connected network (multilayer perceptron).

    Architecture:
        Linear(input_dim -> hidden_dim) -> ReLU
        -> Linear(hidden_dim -> hidden_dim) -> ReLU
        -> Linear(hidden_dim -> output_dim)

    The final layer has no activation, so outputs are unbounded.

    Args:
        input_dim: number of input features per sample.
        output_dim: number of output values per sample.
        hidden_dim: width of both hidden layers (default 20, the width the
            original implementation hard-coded).
    """

    def __init__(self, input_dim, output_dim, hidden_dim=20):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        # A single stateless ReLU module, reused after fc1 and fc2.
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a (..., input_dim) tensor to a (..., output_dim) tensor."""
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)

# Build the model; nn.Linear layers are randomly initialised on construction.
net = Net(input_dim=input_dim, output_dim=output_dim)

### 3. Define a loss function and optimizer
The targets are continuous values, so we use mean-squared-error (MSE) loss together with the Adam optimizer.



In [4]:
# Adam over all network parameters; MSE loss averaged over every element.
optimizer = optim.Adam(params=net.parameters(), lr=1e-3)
criterion = nn.MSELoss(reduction="mean")

### 4. Train the network

This is where things start to get interesting.
For each epoch we loop over the data iterator, feed every mini-batch to the
network, compute the loss, and take one optimizer step per batch.



In [5]:
max_epochs = 100

net.train()  # training mode; a no-op for this MLP (no dropout/batch-norm)
for epoch in range(max_epochs):  # one epoch = one full pass over the dataset
    running_loss = 0.0
    for x_batch, y_true_batch in dataloader:
        # zero the parameter gradients left over from the previous step
        optimizer.zero_grad()

        # forward + backward + optimize
        y_pred_batch = net(x_batch)
        loss = criterion(y_pred_batch, y_true_batch)
        loss.backward()
        optimizer.step()

        # criterion (nn.MSELoss with mean reduction) returns a batch mean, so
        # weight it by the ACTUAL batch size: the last batch is smaller than
        # batch_size whenever the dataset size is not a multiple of it.
        running_loss += loss.item() * x_batch.size(0)

    # per-sample average loss over the whole epoch
    ave_loss = running_loss / len(dataloader.dataset)
    print('iter: %d,' % epoch, ' loss: %f' % ave_loss)

print('Finished Training')

iter: 0,  loss: 1.076159
iter: 1,  loss: 1.062353
iter: 2,  loss: 1.052567
iter: 3,  loss: 1.047058
iter: 4,  loss: 1.039210
iter: 5,  loss: 1.034401
iter: 6,  loss: 1.029787
iter: 7,  loss: 1.026330
iter: 8,  loss: 1.022052
iter: 9,  loss: 1.019423
iter: 10,  loss: 1.016291
iter: 11,  loss: 1.013432
iter: 12,  loss: 1.010991
iter: 13,  loss: 1.008025
iter: 14,  loss: 1.005992
iter: 15,  loss: 1.003428
iter: 16,  loss: 1.001189
iter: 17,  loss: 0.998564
iter: 18,  loss: 0.996285
iter: 19,  loss: 0.994500
iter: 20,  loss: 0.992896
iter: 21,  loss: 0.990385
iter: 22,  loss: 0.989453
iter: 23,  loss: 0.986186
iter: 24,  loss: 0.984480
iter: 25,  loss: 0.982497
iter: 26,  loss: 0.980455
iter: 27,  loss: 0.977989
iter: 28,  loss: 0.975705
iter: 29,  loss: 0.973445
iter: 30,  loss: 0.971399
iter: 31,  loss: 0.969911
iter: 32,  loss: 0.966612
iter: 33,  loss: 0.965049
iter: 34,  loss: 0.963404
iter: 35,  loss: 0.960482
iter: 36,  loss: 0.958250
iter: 37,  loss: 0.956955
iter: 38,  loss: 0.954