In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os

import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))

from ravens import tasks
from ravens.environments import environment
from ravens.utils import utils

import pybullet as p
import copy

np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})

In [None]:
# Directory handed to the environment constructor below. The path is relative,
# so it resolves against the kernel's working directory.
ASSETS_PATH = '../ravens/environments/assets/'

In [None]:
# Helper methods
def plot(obs):
    """Show `obs` as an image in its own figure with the axes hidden.

    The figure is closed again after it has been shown.
    """
    fig = plt.figure()
    ax = fig.gca()
    ax.imshow(obs)
    ax.axis('off')
    plt.show()
    plt.close(fig)
    
from scipy.interpolate import interp1d
from scipy.spatial.transform import Rotation, RotationSpline

class PathPlanner:
    """Pick-and-place trajectory through seven timed waypoints.

    The waypoints are, in order: start, pre-pick, pick, post-pick, pre-place,
    place and post-place. The four "pre"/"post" poses are produced by passing
    the pick or place pose to `utils.multiply` together with an offset whose
    translation is `(0, 0, height)` and whose quaternion is `(0, 0, 0, 1)`.

    Positions are interpolated linearly per axis with `interp1d`; orientations
    are interpolated with `RotationSpline`.

    Args:
        start_pose: `(xyz, quaternion)` pair the trajectory starts from.
        pick_pose: `(xyz, quaternion)` pair at which to pick.
        place_pose: `(xyz, quaternion)` pair at which to place.
        times: one time stamp per waypoint, i.e. exactly seven values.
        height: z offset used for every pre/post waypoint.

    Raises:
        ValueError: if `times` does not contain exactly seven entries.
    """

    # Number of waypoints the planner generates (and time stamps it needs).
    NUM_WAYPOINTS = 7

    def __init__(self, start_pose, pick_pose, place_pose, times, height=0.32):
        # Fail early with a readable message instead of deep inside scipy.
        if len(times) != self.NUM_WAYPOINTS:
            raise ValueError(
                "times must contain exactly {} entries (one per waypoint), "
                "got {}".format(self.NUM_WAYPOINTS, len(times)))

        self.times = times
        self.height = height

        self.start_pose = start_pose
        self.pick_pose = pick_pose
        self.place_pose = place_pose

        # NOTE: an earlier experiment lowered pick_pose by 5 mm in z so that
        # the contact check registers; it is intentionally not applied here.

        # Generate intermediate waypoints. Every one of them uses `height`;
        # the post-place offset used to be hard-coded to 0.32, which ignored
        # a caller-supplied height.
        hover_offset = ((0, 0, height), (0, 0, 0, 1))
        self.prepick_pose = utils.multiply(pick_pose, hover_offset)
        self.postpick_pose = utils.multiply(pick_pose, hover_offset)
        self.preplace_pose = utils.multiply(place_pose, hover_offset)
        self.postplace_pose = utils.multiply(place_pose, hover_offset)

        self.rots = None

    def _waypoints(self):
        """Return the seven waypoint poses in the order they are visited."""
        return [
            self.start_pose,
            self.prepick_pose,
            self.pick_pose,
            self.postpick_pose,
            self.preplace_pose,
            self.place_pose,
            self.postplace_pose,
        ]

    def _interpolate_position(self):
        """Build one linear interpolant per Cartesian axis in `pos_spline`."""
        xyzs = np.asarray(
            [pose[0] for pose in self._waypoints()], dtype=np.float64)
        self.pos_spline = [
            interp1d(self.times, xyzs[:, axis], kind='linear')
            for axis in range(3)
        ]

    def _interpolate_rotation(self):
        """Build the rotation spline through the waypoint quaternions."""
        quats = np.asarray(
            [pose[-1] for pose in self._waypoints()], dtype=np.float64)
        self.rots = Rotation.from_quat(quats)
        self.rot_spline = RotationSpline(self.times, self.rots)

    def plan(self):
        """(Re)build the position and rotation interpolants."""
        self._interpolate_position()
        self._interpolate_rotation()

    def __call__(self, times=None):
        """Sample the trajectory.

        Args:
            times: sample times; defaults to the waypoint times.

        Returns:
            A list of `(xyz_tuple, quaternion_tuple)` poses, one per sample.
        """
        self.plan()
        if times is None:
            times = self.times
        xyz = np.vstack([self.pos_spline[i](times) for i in range(3)]).T
        quats = self.rot_spline(times).as_quat()
        return [(tuple(pos), tuple(quat)) for pos, quat in zip(xyz, quats)]

In [None]:
class ContinuousEnvironment(environment.Environment):
    """Environment whose `step` moves to a single pose and toggles suction.

    Each call to `step` with an action passes that action to `movep`; when the
    end effector reports contact afterwards, the suction state is flipped.
    """

    def reset(self):
        """Clear the suction flag and patience counter, then reset the base."""
        self.suction_on = False
        self.patience = 0
        return super().reset()

    def get_ee_pose(self):
        """Return the first two entries of the end-effector tip link state."""
        link_state = p.getLinkState(self.ur5, self.ee_tip)
        return link_state[0:2]

    def step(self, action=None):
        """Advance the environment by one action.

        Args:
            action: value passed to `movep`; when None, no motion is
                commanded and no task reward is queried.

        Returns:
            The tuple `(obs, reward, done, info)`.
        """
        has_action = action is not None

        if has_action:
            # A truthy result from movep is treated as a timeout and ends
            # the episode immediately with zero reward.
            if self.movep(action):
                return self._get_obs(), 0.0, True, self.info

            if self.ee.detect_contact():
                print("Detected contact.")
                if self.suction_on:
                    print("\tDeactivating suction.")
                    self.ee.release()
                else:
                    print("\tActivating suction.")
                    self.ee.activate()
                self.suction_on = not self.suction_on

        # Keep stepping the simulator until the scene reports it is static.
        while not self.is_static:
            p.stepSimulation()

        # Task reward is only queried when an action was actually executed.
        if has_action:
            reward, info = self.task.reward()
        else:
            reward, info = 0, {}
        done = self.task.done()

        # Merge the ground-truth robot state into the returned info.
        info.update(self.info)

        return self._get_obs(), reward, done, info

In [None]:
# Build the simulation with the continuous-stepping subclass defined above.
# The stock environment is kept here for reference only:
# env = environment.Environment(ASSETS_PATH)
env = ContinuousEnvironment(ASSETS_PATH)
task = tasks.BlockInsertionEasy()
env.set_task(task)
# Fixed seed, set before the first reset.
env.seed(1)
agent = task.oracle(env)
obs = env.reset()
# Show the first entry of the colour observation (presumably the first
# camera view -- TODO confirm).
plot(obs['color'][0])
# No step has been taken yet, so there is no info to give the oracle.
info = None
act = agent.act(obs, info)  # Expert action.

In [None]:
# NOTE(review): MIN_DELTA is not referenced by any cell visible in this
# notebook -- confirm whether it is still needed.
MIN_DELTA = 0.001
# Also used below as the number of poses sampled from the planned trajectory.
HORIZON = 200  # max steps it should take to solve this task.

In [None]:
# Plan from the current end-effector pose through the expert's pick and place
# poses. The planner gets seven knot times spread evenly over [0, 10].
planner = PathPlanner(
    env.get_ee_pose(),
    act['pose0'],
    act['pose1'],
    np.linspace(0, 10, 7, endpoint=True),
)

# Sample the trajectory densely: HORIZON poses over the same time span.
times = np.linspace(0, 10, HORIZON, endpoint=True)
plan = planner(times)
print(len(plan))

In [None]:
# Execute the plan one pose at a time and keep one colour frame per step.
# The loop variable is deliberately not named `act`: reusing that name would
# overwrite the expert action dict and break re-running the planning cell.
obses = []
for i, waypoint in enumerate(plan):
    # Lightweight progress indicator.
    if i % 10 == 0:
        print(i)
    obs, _, done, info = env.step(waypoint)
    obses.append(obs['color'][0])
    if done:
        print("Done, exiting.")
        break

In [None]:
import imageio
from IPython.display import Video

# Playback rate of the saved rollout. FPS was previously undefined, which made
# this cell fail with a NameError. 20 fps plays the 200 frames sampled over the
# 10 s trajectory in real time.
FPS = 20

imageio.mimsave("demo.mp4", obses, fps=FPS)
Video("demo.mp4")

In [None]:
from scipy.interpolate import interp1d
from scipy.spatial.transform import Rotation, RotationSpline

class PathPlanner:
    """Approach-and-pick trajectory through four evenly timed waypoints.

    The waypoints are start, pre-pick, pick and post-pick. Pre-pick and
    post-pick are produced by passing the pick pose to `utils.multiply`
    together with an offset whose translation is `(0, 0, height)` and whose
    quaternion is `(0, 0, 0, 1)`. Positions use a cubic `interp1d` per axis,
    orientations a `RotationSpline`.

    Args:
        start_pose: `(xyz, quaternion)` pair the trajectory starts from.
        pick_pose: `(xyz, quaternion)` pair at which to pick.
        place_pose: `(xyz, quaternion)` pair; stored but not part of the
            interpolated waypoints.
        total_time: time assigned to the last waypoint; the first is at 0.
        height: z offset of the pre/post-pick waypoints.
    """

    def __init__(self, start_pose, pick_pose, place_pose, total_time, height=0.32):
        self.total_time = total_time
        self.height = height

        self.start_pose = start_pose
        self.pick_pose = pick_pose
        self.place_pose = place_pose

        # Intermediate waypoints on either side of the pick.
        above_pick = ((0, 0, height), (0, 0, 0, 1))
        self.prepick_pose = utils.multiply(pick_pose, above_pick)
        after_pick = ((0, 0, height), (0, 0, 0, 1))
        self.postpick_pose = utils.multiply(pick_pose, after_pick)

        # One knot per waypoint, evenly spaced over [0, total_time].
        self.times = np.linspace(0, total_time, 4, endpoint=True, dtype=float)
        self.rots = None

    def _waypoints(self):
        """Return the four waypoint poses in the order they are visited."""
        return [
            self.start_pose,
            self.prepick_pose,
            self.pick_pose,
            self.postpick_pose,
        ]

    def _interpolate_position(self):
        """Fit one cubic interpolant per Cartesian axis into `pos_spline`."""
        positions = [pose[0] for pose in self._waypoints()]
        splines = []
        for axis in range(3):
            coords = [xyz[axis] for xyz in positions]
            splines.append(interp1d(self.times, coords, kind='cubic'))
        self.pos_spline = splines

    def _interpolate_rotation(self):
        """Fit the rotation spline through the waypoint quaternions."""
        orientations = [pose[-1] for pose in self._waypoints()]
        self.rots = Rotation.from_quat(orientations)
        self.rot_spline = RotationSpline(self.times, self.rots)

    def plan(self):
        """(Re)build the position and rotation interpolants."""
        self._interpolate_position()
        self._interpolate_rotation()

    def __call__(self, times):
        """Return `(xyz_tuple, quaternion_tuple)` poses sampled at `times`."""
        self.plan()
        per_axis = [self.pos_spline[axis](times) for axis in range(3)]
        xyz = np.vstack(per_axis).T
        quats = self.rot_spline(times).as_quat()
        poses = []
        for position, quat in zip(xyz, quats):
            poses.append((tuple(position), tuple(quat)))
        return poses

In [None]:
# Build a planner from the poses stored in the most recent `info` dict.
# NOTE(review): assumes `info` exposes 'ee_pose', 'pick_pose' and
# 'place_pose' -- confirm against the environment's info contents.
planner = PathPlanner(
    start_pose=info['ee_pose'],
    pick_pose=info['pick_pose'],
    place_pose=info['place_pose'],
    total_time=3,
)

In [None]:
# Sample the trajectory at five evenly spaced times between the first and last
# knot. Calling the planner also (re)builds its interpolants.
times_plot = np.linspace(planner.times[0], planner.times[-1], num=5)
plan = planner(times_plot)

In [None]:
# Dense sample times for drawing smooth curves.
times_plot = np.linspace(planner.times[0], planner.times[-1], 100)

# One panel per Cartesian axis: the interpolated curve plus the knots in black.
names = ['x-pos', 'y-pos', 'z-pos']
fig, axes = plt.subplots(1, 3, figsize=(20, 4))
for ax, name, spline in zip(axes, names, planner.pos_spline):
    ax.plot(times_plot, spline(times_plot))
    ax.scatter(planner.times, spline(planner.times), c='black')
    ax.set_title(name)
plt.show()

In [None]:
# Waypoint orientations as 'XYZ' Euler angles in degrees.
angles = planner.rots.as_euler('XYZ', degrees=True)

# Euler angles: interpolated curve plus the waypoint values as crosses.
angles_plot = planner.rot_spline(times_plot).as_euler('XYZ', degrees=True)
plt.plot(times_plot, angles_plot)
plt.plot(planner.times, angles, 'x')
plt.title("Euler angles")
plt.show()

# Angular rates: first derivative of the rotation spline, in degrees.
# The knot markers are evaluated at the planner's own knot times; this cell
# previously referenced an undefined `angular_rate` and the unrelated `times`
# array from the rollout.
angular_rate = np.rad2deg(planner.rot_spline(planner.times, 1))
angular_rate_plot = np.rad2deg(planner.rot_spline(times_plot, 1))
plt.plot(times_plot, angular_rate_plot)
plt.plot(planner.times, angular_rate, 'x')
plt.title("Angular rate")
plt.show()