In [1]:
import os
import pickle
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [22]:
# constants
in_dim = 7
out_dim = 6


In [23]:
# Hyper-params
batch_size = 32
hidden_dim = 128
learning_rate = 0.001
num_epochs = 200

In [24]:
model_dir = 'models/NN/'

## Step 1: Load data


In [25]:
data_path = 'dataset/pps_20_5_aug18/'

In [26]:
train_df = pd.read_pickle(os.path.join(data_path, 'train.pkl'))
val_df = pd.read_pickle(os.path.join(data_path, 'val.pkl'))
test_df = pd.read_pickle(os.path.join(data_path, 'test.pkl'))

In [27]:
with open(os.path.join(data_path, 'stats.pkl'), 'rb') as fp:
    target_stats = pickle.load(fp)

In [28]:
class CustomTrajDataset(Dataset):
    def __init__(self, traj_df):
        positions = torch.from_numpy(np.array(list(traj_df['position']))).type(torch.FloatTensor)
        orientations = torch.from_numpy(np.array(list(traj_df['orientation']))).type(torch.FloatTensor)
        self.x = torch.cat((positions, orientations), 2)
        
        forces = torch.from_numpy(np.array(list(traj_df['net_force']))).type(torch.FloatTensor)
        torques = torch.from_numpy(np.array(list(traj_df['net_torque']))).type(torch.FloatTensor)
                                   
        self.y = torch.cat((forces, torques), 2)

    
    def __len__(self):
        return len(self.x)
                                   
    def __getitem__(self, i):
        return self.x[i], self.y[i]
                                   

In [29]:
train_dataset = CustomTrajDataset(train_df)
valid_dataset = CustomTrajDataset(val_df)
test_dataset = CustomTrajDataset(test_df)

In [30]:
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=1)
valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=True, num_workers=1)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=1)

## Step 2: Build NN model

In [31]:
class NN(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super(NN, self).__init__()
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim
        
        self.layers = nn.Sequential(
              nn.Linear(self.in_dim, self.hidden_dim),
              nn.ReLU(),
              nn.Linear(self.hidden_dim, self.hidden_dim),
              nn.ReLU(),
              nn.Linear(self.hidden_dim, self.out_dim)
        )
        
    def forward(self, x):
        return self.layers(x)


In [32]:
model = NN(in_dim=in_dim, hidden_dim=hidden_dim, out_dim=out_dim)
model.to(device)

NN(
  (layers): Sequential(
    (0): Linear(in_features=7, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=6, bias=True)
  )
)

In [33]:
# Make sure model is on GPU
assert next(model.parameters()).is_cuda

## Step 3: Setting loss function,  optimizer , scheduler

In [34]:
mse_loss = nn.MSELoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.7, patience=5, min_lr=0.00001)

In [35]:
def calculate_error(prediction, target):
    force_std = torch.from_numpy(target_stats['force_std']).type(torch.FloatTensor).to(device)
    predicted_force = prediction[:, :, :3] * force_std
    target_force = target[:, :, :3] * force_std
    force_error = (predicted_force - target_force).abs().sum()
    
    torque_std = torch.from_numpy(target_stats['torque_std']).type(torch.FloatTensor).to(device)
    predicted_torque = prediction[:, :, 3:] * torque_std
    target_torque = target[:, :, 3:] * torque_std
    
    torque_error = (predicted_torque - target_torque).abs().sum()
    
    return force_error.item(), torque_error.item()
    

## Step 4: Training & validation

In [36]:
def train():
    model.train()
    train_loss = 0.
    for i, (feature_tensor, target_tensor) in enumerate(train_dataloader):
        feature_tensor = feature_tensor.to(device)
        target_tensor = target_tensor.to(device)
        
        optimizer.zero_grad()
        prediction = model(feature_tensor)
        loss = mse_loss(prediction, target_tensor)
        train_loss += loss.item() * target_tensor.shape[0]
        
        loss.backward()
        optimizer.step()
            
    train_loss = train_loss / len(train_dataloader)
    
    return train_loss

In [37]:
def validation(model, data_loader):
    model.eval()
    with torch.no_grad():
        valid_loss = 0.
        valid_force_error = 0.
        valid_torque_error = 0.
        for i, (feature_tensor, target_tensor) in enumerate(data_loader):
            feature_tensor = feature_tensor.to(device)
            target_tensor = target_tensor.to(device)
            prediction = model(feature_tensor)
            loss = mse_loss(prediction, target_tensor)
            valid_loss += loss.item() * target_tensor.shape[0]
            force_error, torque_error = calculate_error(prediction, target_tensor)
            valid_force_error += force_error
            valid_torque_error += torque_error
        
        return valid_force_error/len(data_loader), valid_torque_error/len(data_loader)
            

In [40]:
best_val_error = None
best_model_path = None
for epoch in range(num_epochs):
    
    train_loss = train()
    val_force_error, val_torque_error = validation(model, valid_dataloader)
    val_error = val_force_error + val_torque_error
    scheduler.step(val_error)
    
    print('epoch {}/{}: \n\t train_loss: {}, \n\t val_error: {}, \n\t val_force_error: {}, \n \t val_torque_error: {}'.
          format(epoch + 1, num_epochs, train_loss, val_error, val_force_error, val_torque_error))
    
    
    if best_val_error is None:
        best_val_error = val_error
        
    if val_error <= best_val_error:
        best_val_error = val_error
        best_model_path = os.path.join(model_dir, 'e_' + str(epoch) +'h_' + str(hidden_dim) + '.chkpnt')
        torch.save(model, best_model_path)
        print('best_val_error: {}, best_epoch: {}'.format(best_val_error, epoch))
    print('*********************************************************')    

epoch 1/200: 
	 train_loss: 1.0265701711177826, 
	 val_error: 1.5504864963239936e+23, 
	 val_force_error: 1.4557442126574695e+23, 
 	 val_torque_error: 9.474228366652402e+21
best_val_error: 1.5504864963239936e+23, best_epoch: 0
*********************************************************
epoch 2/200: 
	 train_loss: 0.9774841296672822, 
	 val_error: 1.589680674478017e+23, 
	 val_force_error: 1.4925240170828662e+23, 
 	 val_torque_error: 9.71566573951507e+21
*********************************************************
epoch 3/200: 
	 train_loss: 0.9473240923881531, 
	 val_error: 1.6220682657601484e+23, 
	 val_force_error: 1.522177190349923e+23, 
 	 val_torque_error: 9.989107541022536e+21
*********************************************************
epoch 4/200: 
	 train_loss: 0.9288430917263031, 
	 val_error: 1.6296926830021586e+23, 
	 val_force_error: 1.527427634770928e+23, 
 	 val_torque_error: 1.0226504823123055e+22
*********************************************************
epoch 5/200: 
	 train

epoch 37/200: 
	 train_loss: 0.6707087647914887, 
	 val_error: 1.9754371342812413e+23, 
	 val_force_error: 1.828161701988651e+23, 
 	 val_torque_error: 1.4727543229259037e+22
*********************************************************
epoch 38/200: 
	 train_loss: 0.6684429717063903, 
	 val_error: 1.981310079914383e+23, 
	 val_force_error: 1.834379783391879e+23, 
 	 val_torque_error: 1.4693029652250417e+22
*********************************************************
epoch 39/200: 
	 train_loss: 0.6675071620941162, 
	 val_error: 1.9837435501117515e+23, 
	 val_force_error: 1.8368859079102358e+23, 
 	 val_torque_error: 1.4685764220151562e+22
*********************************************************
epoch 40/200: 
	 train_loss: 0.6656546342372894, 
	 val_error: 1.9911585395163037e+23, 
	 val_force_error: 1.8437306845027534e+23, 
 	 val_torque_error: 1.474278550135503e+22
*********************************************************
epoch 41/200: 
	 train_loss: 0.664069402217865, 
	 val_error: 1.9890

epoch 73/200: 
	 train_loss: 0.6370391643047333, 
	 val_error: 2.0299468597056222e+23, 
	 val_force_error: 1.877710755448944e+23, 
 	 val_torque_error: 1.5223610425667825e+22
*********************************************************
epoch 74/200: 
	 train_loss: 0.6365615856647492, 
	 val_error: 2.0319280703862676e+23, 
	 val_force_error: 1.8796003500457412e+23, 
 	 val_torque_error: 1.5232772034052649e+22
*********************************************************
epoch 75/200: 
	 train_loss: 0.6360268211364746, 
	 val_error: 2.0333245823462555e+23, 
	 val_force_error: 1.88062392816905e+23, 
 	 val_torque_error: 1.527006541772055e+22
*********************************************************
epoch 76/200: 
	 train_loss: 0.6357109129428864, 
	 val_error: 2.0327400408494776e+23, 
	 val_force_error: 1.8803710381825454e+23, 
 	 val_torque_error: 1.523690026669322e+22
*********************************************************
epoch 77/200: 
	 train_loss: 0.6355567848682404, 
	 val_error: 2.0334

epoch 109/200: 
	 train_loss: 0.622808289527893, 
	 val_error: 2.053762832248937e+23, 
	 val_force_error: 1.8977838651549447e+23, 
 	 val_torque_error: 1.559789670939924e+22
*********************************************************
epoch 110/200: 
	 train_loss: 0.6223305702209473, 
	 val_error: 2.0534909716532167e+23, 
	 val_force_error: 1.8974219944911718e+23, 
 	 val_torque_error: 1.5606897716204493e+22
*********************************************************
epoch 111/200: 
	 train_loss: 0.6219807499647141, 
	 val_error: 2.054600769582963e+23, 
	 val_force_error: 1.8987257930170092e+23, 
 	 val_torque_error: 1.5587497656595375e+22
*********************************************************
epoch 112/200: 
	 train_loss: 0.6215588831901551, 
	 val_error: 2.0561454362990495e+23, 
	 val_force_error: 1.899453761294491e+23, 
 	 val_torque_error: 1.566916750045584e+22
*********************************************************
epoch 113/200: 
	 train_loss: 0.6212853002548218, 
	 val_error: 2.

epoch 145/200: 
	 train_loss: 0.6079902279376984, 
	 val_error: 2.0749476808557018e+23, 
	 val_force_error: 1.912086416152689e+23, 
 	 val_torque_error: 1.6286126470301282e+22
*********************************************************
epoch 146/200: 
	 train_loss: 0.6075341516733169, 
	 val_error: 2.0743962701700042e+23, 
	 val_force_error: 1.9114160167458724e+23, 
 	 val_torque_error: 1.6298025342413195e+22
*********************************************************
epoch 147/200: 
	 train_loss: 0.607253890633583, 
	 val_error: 2.0762502887423878e+23, 
	 val_force_error: 1.9127665047324182e+23, 
 	 val_torque_error: 1.6348378400996964e+22
*********************************************************
epoch 148/200: 
	 train_loss: 0.6067491376399994, 
	 val_error: 2.079248861372535e+23, 
	 val_force_error: 1.9152464826380812e+23, 
 	 val_torque_error: 1.6400237873445376e+22
*********************************************************
epoch 149/200: 
	 train_loss: 0.6064093935489655, 
	 val_error:

epoch 180/200: 
	 train_loss: 0.5940491759777069, 
	 val_error: 2.1032319576727363e+23, 
	 val_force_error: 1.932687206193315e+23, 
 	 val_torque_error: 1.7054475147942113e+22
*********************************************************
epoch 181/200: 
	 train_loss: 0.5941293787956238, 
	 val_error: 2.106180576134552e+23, 
	 val_force_error: 1.9354967896882776e+23, 
 	 val_torque_error: 1.7068378644627447e+22
*********************************************************
epoch 182/200: 
	 train_loss: 0.5935939329862595, 
	 val_error: 2.106564689755199e+23, 
	 val_force_error: 1.935882156274678e+23, 
 	 val_torque_error: 1.70682533480521e+22
*********************************************************
epoch 183/200: 
	 train_loss: 0.5932133537530899, 
	 val_error: 2.1056495784162727e+23, 
	 val_force_error: 1.9349248132344604e+23, 
 	 val_torque_error: 1.7072476518181245e+22
*********************************************************
epoch 184/200: 
	 train_loss: 0.5926132190227509, 
	 val_error: 2.

## Step 5: Testing

In [42]:
if best_model_path:
    model = torch.load(best_model_path, map_location=device)

test_force_error, test_torque_error = validation(model, test_dataloader)
test_error = test_force_error + test_torque_error
print('Testing \n\t test error: {}, \n\t test_force_error: {}, \n\t test_torque_error: {}'.
      format(test_error, test_force_error, test_torque_error))

Testing 
	 test error: 2.49087401351966e+23, 
	 test_force_error: 2.107556073574886e+23, 
	 test_torque_error: 3.833179399447743e+22
