## trainer.ipynb

This notebook is used to train the model described in "Deep Imitation Learning for Complex Manipulation Tasks from Virtual Reality Teleoperation" (https://arxiv.org/abs/1710.04615)

Make sure the /data/ dataset folder is in the same directory as this notebook.  Dataset format:

/data/
 - /{runNumber}/
   - /depth/ - contains 1 channel depth images
      - fileNames: depth{stepNum}.png
   - /rgb/ - contains 3 channel rgb images
      - fileNames: rgb{stepNum}.png
   - /states/ - contains csv files of the format:
      - endEffectorPt1X, endEffectorPt1Y, endEffectorPt2X, endEffectorPt2Y, endEffectorPt3X, endEffectorPt3Y, isOpen (boolean: {0 = closed, 1 = open})
      - endEffectorX, endEffectorY, endEffectorZ, endEffectorRoll, endEffectorPitch, endEffectorYaw
      - fileNames: states{stepNum}.csv

Note that, for simplicity, this implementation does not train the network with auxiliary points.  Further, the network only considers the linear movement of the end effector (no roll / pitch / yaw control).

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import os
import torchvision
import numpy as np
import random
import matplotlib.pyplot as plt

In [None]:
from VRNet import VRNet
from VRNet import VRDataLoader

In [None]:
# Build the project data loader over the 'data' directory with batches of 128.
# NOTE(review): the meaning of the positional arguments 2 and 160 is defined by
# VRNet.VRDataLoader and is not visible from this notebook — confirm before changing.
dataloader = VRDataLoader('data', 2, 160, batch_size=128)

# Presumably a reference to the loader's own state tensor (not a copy), in which
# case in-place edits made through `states` also change what the loader holds —
# TODO confirm that `dataloader.states` is a plain attribute.
states = dataloader.states

#replace outliers (where |x| > 1) by interpolating from the adjacent states
def interpolate_outliers(values, columns=(0, 1, 2), limit=1):
    """Repair out-of-range entries of ``values`` in place.

    An entry counts as an outlier when its absolute value exceeds ``limit``.
    It is overwritten with the mean of the entries directly before and after
    it in the same column; on the first or last row only the one existing
    neighbour is used. Rows are visited in order, so the previous row
    contributes its already-repaired value while the next row is read as-is.

    Args:
        values: 2-D container indexed as ``values[row][col]`` that supports
            item assignment (here: the state tensor).
        columns: column indices to check.
        limit: largest absolute value that is still considered valid.

    Returns:
        The same ``values`` object, for convenience.
    """
    num_rows = len(values)
    last_row = num_rows - 1
    for row in range(num_rows):
        has_prev = row > 0
        has_next = row < last_row
        for col in columns:
            if abs(values[row][col]) > limit:
                if has_prev and has_next:
                    values[row][col] = (values[row - 1][col] + values[row + 1][col]) / 2
                elif has_next:
                    values[row][col] = values[row + 1][col]
                elif has_prev:
                    values[row][col] = values[row - 1][col]
                # a single-row input has no neighbour to borrow from: leave it
    return values


interpolate_outliers(states)

#smooth the first six state columns with a normalized 9-tap kernel
# The weights ramp linearly up to the centre and back down, i.e. a triangular
# window standing in for a Gaussian.
gaussian_filter = np.array([1, 2, 3, 4, 5, 4, 3, 2, 1])
gaussian_filter = gaussian_filter / gaussian_filter.sum()

# NOTE(review): the smoothed result is rebound to the name `states` only.
# Whether the loader's own tensor changes too depends on `.cpu().numpy()`
# sharing memory with it (it does for CPU tensors) — confirm which is intended.
states = states.cpu().numpy()
for col in range(6):
    # mode='same' keeps the column length; the ends are implicitly zero-padded
    smoothed = np.convolve(states[:, col], gaussian_filter, mode='same')
    states[:, col] = smoothed
states = torch.tensor(states).to('cuda')

#plot the first 500 samples of the x / y / z velocities
# Slice before moving to the CPU so only the plotted samples are copied,
# rather than building a Python list over every state first.
for axis_name, col in (('x', 0), ('y', 1), ('z', 2)):
    plt.plot(states[0:500, col].cpu().numpy(), 'bo')
    plt.title(f'{axis_name} velocity (first 500 samples)')
    plt.show()

#display image
# permute(1, 2, 0) moves the first axis last, presumably channel-first ->
# channel-last as matplotlib expects; .cpu() is a no-op for CPU tensors and
# keeps imshow working if the images live on the GPU.
first_rgb = dataloader.rgb_images[0].permute(1, 2, 0).cpu()
plt.imshow(first_rgb)
plt.show()
#display histogram of image
plt.hist(first_rgb.flatten().numpy())
plt.show()

In [None]:

#show the first image of one batch from the dataset (batch index fixed at 0)
import matplotlib.pyplot as plt
import random

# Fixed batch index so the preview is repeatable (a random pick was used before).
idx = 0 #random.randint(0, len(dataloader))
data = dataloader[idx]

# A batch unpacks into rgb images, depth images and states; take sample 0 of each.
rgb_batch, depth_batch, state_batch = data
rgb_img = rgb_batch[0]
depth_img = depth_batch[0]
state = state_batch[0]

# Show the rgb image first, then the depth image, each with the first axis moved last.
for image in (rgb_img, depth_img):
    plt.imshow(image.permute(1, 2, 0))
    plt.show()

print(state)

print(len(dataloader))

In [None]:
#data augmentation: a small random shear followed by a small random rescale
# (no rotation: degrees=0 in both affine transforms)
augmentations = [
    torchvision.transforms.RandomAffine(degrees=0, shear=3),
    torchvision.transforms.RandomAffine(degrees=0, scale=(0.98, 1.02)),
    # disabled augmentations, kept for reference:
    # torchvision.transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.05, hue=0.05),
    # torchvision.transforms.RandomApply([torchvision.transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))], p=0.5),
]
transform = torchvision.transforms.Compose(augmentations)

def applyTransforms(rgb_img, depth_img):
    """Run the module-level ``transform`` on an rgb/depth pair with identical randomness.

    The torch RNG state is captured before the rgb image is transformed and
    restored before the depth image is transformed, so both calls draw the
    same random numbers.

    Args:
        rgb_img: rgb image passed to ``transform``.
        depth_img: depth image passed to ``transform``.

    Returns:
        Tuple ``(transformed_rgb, transformed_depth)``.
    """
    saved_rng = torch.random.get_rng_state()
    augmented_rgb = transform(rgb_img)
    # rewind the generator so the depth image sees the same random draws
    torch.random.set_rng_state(saved_rng)
    augmented_depth = transform(depth_img)
    return augmented_rgb, augmented_depth

#preview the augmentation on a randomly chosen batch
# randrange excludes the upper bound; randint(0, len(dataloader)) is inclusive
# and could return len(dataloader), one past the last valid batch index.
idx = random.randrange(len(dataloader))
data = dataloader[idx]
# sample at index 1 of the rgb / depth / state batches
rgb_img, depth_img, state = data[0][1], data[1][1], data[2][1]

#apply same transform to both depth and rgb image
rgb_img, depth_img = applyTransforms(rgb_img, depth_img)

plt.imshow(rgb_img.permute(1, 2, 0))
plt.show()

plt.imshow(depth_img.permute(1, 2, 0))
plt.show()

In [None]:
#create the custom loss functions used by the paper

#create custom Lc loss function
class LcLoss(nn.Module):
    """Angular (arc-cosine) loss between predicted and target vectors.

    For each row ``i`` the loss is the angle in radians between ``pred[i]``
    and ``target[i]``; the per-row angles are summed over the batch.

    Args:
        eps: margin kept between the cosine and +/-1 before ``arccos``. The
            derivative of ``arccos`` is infinite at +/-1, and rounding can push
            the cosine slightly outside [-1, 1], both of which produce NaNs.
    """

    def __init__(self, eps=1e-6):
        super(LcLoss, self).__init__()
        self.eps = eps

    def forward(self, pred, target):
        """Return the summed angle between the rows of ``pred`` and ``target``.

        Args:
            pred: tensor of shape (batch, dim).
            target: tensor of the same shape as ``pred``.

        Returns:
            0-d tensor on the same device as ``pred``.
        """
        dot = torch.sum(pred * target, dim=1)
        norms = torch.norm(pred, dim=1) * torch.norm(target, dim=1)
        # guard against division by zero when either vector has zero length
        cosine = dot / norms.clamp_min(1e-8)
        # keep arccos and its gradient finite
        cosine = cosine.clamp(-1 + self.eps, 1 - self.eps)
        return torch.arccos(cosine).sum()

#create custom Lg loss function
class LgLoss(nn.Module):
    """Binary cross-entropy between predictions and 0/1 targets, summed over the batch.

    Per element the loss is ``-(t * log(p) + (1 - t) * log(1 - p))``. The
    logarithm is taken of the prediction, never of the target, so exact 0/1
    targets give a finite value.

    NOTE(review): assumes ``pred`` already holds probabilities in [0, 1]
    (e.g. after a sigmoid) — confirm against the network's output layer.

    Args:
        eps: predictions are clamped to ``[eps, 1 - eps]`` so ``log`` stays finite.
    """

    def __init__(self, eps=1e-7):
        super(LgLoss, self).__init__()
        self.eps = eps

    def forward(self, pred, target):
        """Return the summed binary cross-entropy.

        Args:
            pred: tensor of predicted probabilities.
            target: tensor of the same shape with values in {0, 1}.

        Returns:
            0-d tensor on the same device as ``pred``.
        """
        prob = pred.clamp(self.eps, 1 - self.eps)
        per_sample = -(target * torch.log(prob) + (1 - target) * torch.log(1 - prob))
        return per_sample.sum()

In [None]:
# Per-term weights intended for the combined loss.
# NOTE(review): no active (uncommented) code in the visible cells reads these
# values, so they currently have no effect on training — confirm before relying on them.
loss_weights = [1, 0.01, 0.005, 0.0001]

def train(model, data_loader, num_epochs, learning_rate, device):
    """Train ``model`` with an L1 loss on output columns 0-2 and 6.

    Each batch is fetched by integer index, moved to ``device`` and cast to
    float. The model is called as ``model(rgb, depth)``, and columns 0, 1, 2
    and 6 of its output are compared with the same columns of the batch state
    (presumably the x/y/z motion and the gripper flag — confirm against the
    dataset layout). No other loss term is applied. After every epoch the
    loss of the last batch is printed.

    Args:
        model: module called as ``model(rgb_img, depth_img)``.
        data_loader: object supporting ``len()`` and ``data_loader[i]``, where
            each item unpacks into ``(rgb_img, depth_img, state)`` tensors.
        num_epochs: number of passes over ``data_loader``.
        learning_rate: Adam learning rate (weight decay is fixed at 1e-4).
        device: device the batches are moved to; the model must already be on it.

    Raises:
        ValueError: if ``data_loader`` contains no batches.
    """
    num_batches = len(data_loader)
    if num_batches == 0:
        # previously this surfaced as a NameError at the per-epoch print
        raise ValueError('data_loader is empty; nothing to train on')

    model.train()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.0001)
    L1_loss = nn.L1Loss()
    # output/state columns that take part in the loss
    supervised = [0, 1, 2, 6]

    for epoch in range(num_epochs):
        for i in range(num_batches):
            rgb_img, depth_img, state = data_loader[i]

            #move the batch to the training device as float tensors
            rgb_img = rgb_img.to(device).float()
            depth_img = depth_img.to(device).float()
            state = state.to(device).float()

            optimizer.zero_grad()
            output = model(rgb_img, depth_img)

            loss = L1_loss(output[:, supervised], state[:, supervised])

            loss.backward()
            optimizer.step()

        print(f'Epoch: {epoch}, Iteration: {i}, Loss: {loss.item()}')
    

In [None]:
#create the model on the GPU when one is available, otherwise on the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VRNet().to(device)

In [None]:
# 500 epochs at a learning rate of 5e-4
train(model, dataloader, num_epochs=500, learning_rate=0.0005, device=device)

In [None]:
#test model on an example image
import numpy as np

# randrange excludes the upper bound; randint(0, len(dataloader)) is inclusive
# and could return len(dataloader), one past the last valid batch index.
idx = random.randrange(len(dataloader))
# fetch the batch once and take the first sample of each component
batch = dataloader[idx]
rgb_img, depth_img, state = batch[0][0], batch[1][0], batch[2][0]

model.eval()
#add the batch dimension the model expects
rgb_img = rgb_img.unsqueeze(0).to(device).float()
depth_img = depth_img.unsqueeze(0).to(device).float()

print(rgb_img.shape, depth_img.shape)
# inference only: no autograd graph is needed
with torch.no_grad():
    output = model(rgb_img, depth_img)

#show output (fixed-point formatting, so no scientific notation)
print('output: {:.6f} {:.6f} {:.6f} {:.6f}'.format(output[0][0].item(), output[0][1].item(), output[0][2].item(), output[0][6].item()))

# MSE over the components printed above (columns 0, 1, 2 and 6), which are the
# ones the L1 training loss uses. The earlier `[0:2]` slice cut the batch axis
# of the (1, N) output, so it did not select components at all.
supervised = [0, 1, 2, 6]
predicted = output[0].detach().cpu().numpy()[supervised]
expected = state.detach().cpu().numpy()[supervised]
print('mse: ', np.mean((predicted - expected) ** 2))
print(state)

#show the inputs that were fed to the model (first axis moved last for matplotlib)
plt.imshow(rgb_img.permute(0, 2, 3, 1).cpu()[0])
plt.show()

plt.imshow(depth_img.permute(0, 2, 3, 1).cpu()[0])
plt.show()


In [None]:
#save the model
torch.save(model.state_dict(), 'model.pt')

#load the model
# map_location remaps the stored tensors onto `device`, so a checkpoint written
# on a GPU machine can also be restored on a CPU-only one.
# NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
model = VRNet().to(device)
model.load_state_dict(torch.load('model.pt', map_location=device))
