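# Test path-integrator RNN

Train a single-layer tanh RNN to estimate position from velocity sequences (path integration), using the batch stored in `data/test_batch.npz`.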

In [1]:
import numpy as np
import torch

from motion import load_batch

In [2]:
class PathRNN(torch.nn.Module):
    """Single-layer tanh RNN that maps velocity sequences to position estimates.

    A vanilla ``torch.nn.RNN`` reads the velocity sequence, and a linear
    readout projects every hidden state back into position space.

    Args:
        n_units: Number of recurrent hidden units.
        n_dims: Spatial dimensionality of the velocity input and position
            output. Defaults to 2, matching the original hard-coded value.
    """

    def __init__(self, n_units: int, n_dims: int = 2):
        super().__init__()

        self.n_units = n_units
        self.n_dims = n_dims

        # Recurrent layer; batch_first=True means batched inputs are (batch, seq, n_dims)
        self.rnn = torch.nn.RNN(
            input_size=n_dims,
            hidden_size=n_units,
            num_layers=1,
            nonlinearity='tanh',
            batch_first=True,
        )

        # Linear readout from hidden state to position estimate
        self.output = torch.nn.Linear(n_units, n_dims)

    def forward(self, vel: torch.Tensor):
        """Run the RNN over a batch of velocity sequences.

        Args:
            vel: Velocity tensor of shape (batch, seq, n_dims). No initial
                hidden state is passed, so PyTorch's default (zeros) is used.

        Returns:
            (pos_est, u_vals): position estimates of shape (batch, seq, n_dims)
            and hidden-unit activations of shape (batch, seq, n_units).
        """
        # Hidden-unit values at every time step; the final hidden state is unused
        u_vals, _ = self.rnn(vel)

        # Read out a position estimate at every time step
        pos_est = self.output(u_vals)

        return pos_est, u_vals

In [3]:
# Read the position/velocity arrays from disk, then turn each NumPy array
# into a float32 PyTorch tensor (pos first, then vel).
pos_arr, vel_arr = load_batch('data/test_batch.npz')
pos, vel = (torch.from_numpy(arr).float() for arr in (pos_arr, vel_arr))

In [4]:
# Seed PyTorch's RNG so weight initialisation (and therefore the loss curve
# printed below) is reproducible under Restart & Run All.
torch.manual_seed(0)

# Instantiate model with 100 recurrent units.
# Model and data stay on PyTorch's default (CPU) device; no `.to(device)` call is made.
model = PathRNN(n_units=100)

# Define hyperparameters
n_epochs = 100
lr = 0.01

# Define Loss, Optimizer
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [5]:
# Full-batch training: each epoch runs the whole velocity batch through the
# model, scores the position estimate against the true positions, and takes
# a single optimizer step.
for epoch in range(1, n_epochs + 1):
    optimizer.zero_grad()            # discard gradients from the previous step

    pos_est, u_vals = model(vel)     # forward pass; u_vals = hidden-unit activity
    loss = criterion(pos_est, pos)

    loss.backward()                  # backprop through the unrolled RNN
    optimizer.step()                 # apply the parameter update

    # Only report on every tenth epoch
    if epoch % 10 != 0:
        continue
    print('Epoch: {}/{}.............'.format(epoch, n_epochs),
          "Loss: {:.4f}".format(loss.item()))

Epoch: 10/100............. Loss: 0.3377
Epoch: 20/100............. Loss: 0.3992
Epoch: 30/100............. Loss: 0.3337
Epoch: 40/100............. Loss: 0.3338
Epoch: 50/100............. Loss: 0.3317
Epoch: 60/100............. Loss: 0.3296
Epoch: 70/100............. Loss: 0.3299
Epoch: 80/100............. Loss: 0.3324
Epoch: 90/100............. Loss: 0.3299
Epoch: 100/100............. Loss: 0.3307
