# PyTorch MLP example

This notebook shows how to implement a multilayer perceptron (MLP) in PyTorch. First, import PyTorch:

In [None]:
import torch

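Define the training data with `torch.tensor`. Integer literals produce an integer tensor, so you usually need `.float()` to convert the data to floating point. Otherwise the linear layers, whose weights are floating point, raise a dtype error during training.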

In [None]:
# Data for XOR
x = torch.tensor([[0, 0],
                  [0, 1],
                  [1, 0],
                  [1, 1]]).float()
t = torch.tensor([[0],
                  [1],
                  [1],
                  [0]]).float()
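
As a quick check of why `.float()` is needed, the sketch below inspects the dtypes involved. The names `x_int`, `x_float`, and `x_literal` are illustrative and not part of the notebook; the dtypes shown assume PyTorch's default settings.

```python
import torch

# Integer literals produce an int64 tensor by default.
x_int = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]])
print(x_int.dtype)      # torch.int64

# .float() converts to float32, the default dtype of torch.nn.Linear parameters.
x_float = x_int.float()
print(x_float.dtype)    # torch.float32

# Writing float literals gives the same tensor without the conversion.
x_literal = torch.tensor([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
print(x_literal.dtype)  # torch.float32
```

Either style works; `.float()` is convenient when the data arrives as integers.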

A network structure is created by defining a class that extends `torch.nn.Module`. The layers are defined in the initializer. The linear layer is the one you are familiar with from previous exercises; it is created with `torch.nn.Linear(n_inputs, n_units)`. Here, we create a network with a single hidden layer. The `forward` function defines the forward propagation, that is, how to compute the network output from the input using the layers. PyTorch automatically derives the gradients for backpropagation from this forward definition.

In [None]:
# Define neural network
class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)
        self.predict = torch.nn.Linear(n_hidden, n_output)

    def forward(self, x):
        x = torch.sigmoid(self.hidden(x))
        x = self.predict(x)
        return x
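
To make the linear layer concrete, here is a minimal sketch of what `torch.nn.Linear` stores and computes. The names `layer`, `x_demo`, and `manual` are illustrative, and the input is random rather than the XOR data.

```python
import torch

# A linear layer with 2 inputs and 4 units, like the hidden layer above.
layer = torch.nn.Linear(2, 4)
print(layer.weight.shape)  # torch.Size([4, 2])
print(layer.bias.shape)    # torch.Size([4])

# A batch of 3 random examples with 2 features each.
x_demo = torch.rand(3, 2)

# The layer computes x @ W^T + b for every row of the batch.
manual = x_demo @ layer.weight.T + layer.bias
print(torch.allclose(layer(x_demo), manual, atol=1e-6))
```

Note that the weight matrix is stored with shape `(n_units, n_inputs)`, which is why the manual version multiplies by its transpose.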

Now we can instantiate our model using the class defined above:

In [None]:
model = Net(n_feature=2, n_hidden=4, n_output=1)

Next we define the loss function, in this case the mean squared error (MSE).

In [None]:
loss_func = torch.nn.MSELoss()
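
For reference, with its default settings `torch.nn.MSELoss` averages the squared differences over all elements. The sketch below compares it with a hand-written version; the values in `pred` are made up for illustration and are not outputs of the network.

```python
import torch

loss_func = torch.nn.MSELoss()

# Made-up predictions and the XOR targets.
pred = torch.tensor([[0.2], [0.9], [0.6], [0.1]])
target = torch.tensor([[0.], [1.], [1.], [0.]])

# Built-in loss and the same quantity computed by hand.
builtin = loss_func(pred, target)
manual = ((pred - target) ** 2).mean()

print(builtin.item(), manual.item())
```

Both values are approximately 0.055, since the squared errors 0.04, 0.01, 0.16, and 0.01 average to 0.055.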

We also have to define the optimizer. We will use stochastic gradient descent (SGD), as in our own implementation, with a learning rate of 0.1.

In [None]:
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

Next we define a loop to train the network. Each iteration consists of five steps, explained in the comments below:

In [None]:
num_epochs = 100

for _ in range(num_epochs):
    prediction = model(x)            # Forward pass; saves the intermediate values needed for the backward pass
    loss = loss_func(prediction, t)  # Computes the mean loss over all examples with the loss function defined above
    optimizer.zero_grad()            # Clears the gradients from the previous iteration
    loss.backward()                  # Backpropagates the error through the network to compute the gradients
    optimizer.step()                 # Updates the weights using the gradients

    #print("Prediction: ", prediction.detach().numpy())
    print("Loss: ", loss.detach().numpy())
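
To show what `optimizer.step()` does for plain SGD, the sketch below performs one update with the optimizer and one by hand, then compares the resulting weights. It assumes SGD without momentum or weight decay, as configured above. A `torch.nn.Sequential` model stands in for `Net` to keep the block self-contained, and names such as `model_a` and `model_b` are illustrative.

```python
import copy
import torch

torch.manual_seed(0)

# XOR data, as defined earlier.
x = torch.tensor([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
t = torch.tensor([[0.], [1.], [1.], [0.]])

# Stand-in for Net: same layers, written with Sequential.
model_a = torch.nn.Sequential(
    torch.nn.Linear(2, 4),
    torch.nn.Sigmoid(),
    torch.nn.Linear(4, 1),
)
model_b = copy.deepcopy(model_a)  # identical starting weights

loss_func = torch.nn.MSELoss()
lr = 0.1

# One update using the optimizer.
optimizer = torch.optim.SGD(model_a.parameters(), lr=lr)
loss_a = loss_func(model_a(x), t)
optimizer.zero_grad()
loss_a.backward()
optimizer.step()

# The same update written by hand: w <- w - lr * grad.
loss_b = loss_func(model_b(x), t)
loss_b.backward()
with torch.no_grad():
    for p in model_b.parameters():
        p -= lr * p.grad
        p.grad.zero_()  # plays the role of optimizer.zero_grad()

same = all(
    torch.allclose(pa, pb)
    for pa, pb in zip(model_a.parameters(), model_b.parameters())
)
print("Weights match:", same)
```

The manual loop is only for illustration; in practice the optimizer object is preferable because it also handles options such as momentum.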


## Additional optimizer options
For faster convergence, the Adam optimizer can be useful. It adapts the learning rate individually for each weight. In this example you can see that the loss decreases much faster when using Adam:

In [None]:
model = Net(n_feature=2, n_hidden=4, n_output=1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

num_epochs = 100

for _ in range(num_epochs):
    prediction = model(x)
    loss = loss_func(prediction, t)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print("Loss: ", loss.detach().numpy())
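
Because each new model starts from different random weights, comparing two separate runs can be misleading. The sketch below is one way to compare the two optimizers from the same starting weights. The helper `train` and the `torch.nn.Sequential` stand-in for `Net` are assumptions made for this illustration, not part of the notebook.

```python
import copy
import torch

torch.manual_seed(0)

# XOR data, as defined earlier.
x = torch.tensor([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
t = torch.tensor([[0.], [1.], [1.], [0.]])

# Stand-in for Net; every run below starts from a copy of these weights.
base_model = torch.nn.Sequential(
    torch.nn.Linear(2, 4),
    torch.nn.Sigmoid(),
    torch.nn.Linear(4, 1),
)
loss_func = torch.nn.MSELoss()


def train(optimizer_cls, num_epochs=100, lr=0.1):
    """Hypothetical helper: train a copy of base_model and return the loss per epoch."""
    model = copy.deepcopy(base_model)
    optimizer = optimizer_cls(model.parameters(), lr=lr)
    losses = []
    for _ in range(num_epochs):
        loss = loss_func(model(x), t)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses


sgd_losses = train(torch.optim.SGD)
adam_losses = train(torch.optim.Adam)

print("Final loss with SGD: ", sgd_losses[-1])
print("Final loss with Adam:", adam_losses[-1])
```

The exact numbers depend on the seed and the PyTorch version, so treat a single run as an indication rather than a benchmark.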

## Saving and loading models
To save the model weights:

In [None]:
torch.save(model.state_dict(), 'filename.pth')

To load the saved weights into a model with the same architecture:

In [None]:
model.load_state_dict(torch.load('filename.pth'))
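
The sketch below runs a full save and load round trip and checks that the restored model gives the same output. It writes to a temporary directory so no file is left behind. The helper `make_model` and the file name `model_weights.pth` are hypothetical, and a `torch.nn.Sequential` model stands in for `Net`.

```python
import os
import tempfile
import torch


def make_model():
    """Hypothetical helper: build a fresh model with the same architecture as Net."""
    return torch.nn.Sequential(
        torch.nn.Linear(2, 4),
        torch.nn.Sigmoid(),
        torch.nn.Linear(4, 1),
    )


model = make_model()
restored = make_model()  # different random weights until the state dict is loaded

x_demo = torch.rand(4, 2)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model_weights.pth")
    torch.save(model.state_dict(), path)          # save the weights only
    restored.load_state_dict(torch.load(path))    # load them into the new model

# Switch to evaluation mode and compare outputs without tracking gradients.
restored.eval()
with torch.no_grad():
    same_output = torch.equal(model(x_demo), restored(x_demo))

print("Outputs match after loading:", same_output)
```

Saving the `state_dict` stores only the parameters, so the class or code that builds the architecture must be available when loading.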

In [2]:
import pickle

with open('motor_angles.pickle', 'rb') as data:
    knownPos = pickle.load(data)
print(knownPos)

with open('xy_pixel.pickle', 'rb') as data:
    xy_pixel = pickle.load(data)
print(xy_pixel)

[(-90, 0), (-75, 0), (-60, 0), (-45, 0), (-30, 0), (-15, 0), (0, 0), (15, 0), (30, 0), (45, 0), (60, 0), (75, 0), (90, 0), (90, -15), (75, -15), (60, -15), (45, -15), (30, -15), (15, -15), (0, -15), (-15, -15), (-30, -15), (-45, -15), (-60, -15), (-75, -15), (-90, -15), (-90, -30), (-75, -30), (-60, -30), (-45, -30), (-30, -30), (-15, -30), (0, -30), (15, -30), (30, -30), (45, -30), (60, -30), (75, -30), (90, -30), (90, -45), (75, -45), (60, -45), (45, -45), (30, -45), (15, -45), (0, -45), (-15, -45), (-30, -45), (-45, -45), (-60, -45), (-75, -45), (-90, -45), (-90, -60), (-75, -60), (-60, -60), (-45, -60), (-30, -60), (-15, -60), (0, -60), (15, -60), (30, -60), (45, -60), (60, -60), (75, -60), (90, -60), (90, -75), (75, -75), (60, -75), (45, -75), (30, -75), (15, -75), (0, -75), (-15, -75), (-30, -75), (-45, -75), (-60, -75), (-75, -75), (-90, -75), (-90, -90), (-75, -90), (-60, -90), (-45, -90), (-30, -90), (-15, -90), (0, -90), (15, -90), (30, -90), (45, -90), (60, -90), (75, -90), 