In [12]:
import numpy as np
import pickle
import matplotlib.pyplot as plt
import os
from mpl_toolkits.mplot3d import Axes3D
%matplotlib qt

In [2]:
#loading data
# Open in binary mode (the documented mode for pickle files) and close the
# handle deterministically via `with`.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file --
# only load 'pe_data_dicts' from a trusted source.
with open('pe_data_dicts', 'rb') as data_file:
    data_dicts = pickle.load(data_file)

In [42]:
def get_target_actual_pairs(data, desired_states):
    """Flatten per-grasp target/actual states into index-aligned per-sample records.

    For every grasp of every object in ``data`` and every sample index ``i`` of
    that grasp, one dict is built mapping each name in ``desired_states`` to its
    ``i``-th target value, and one dict mapping it to its ``i``-th actual value.

    Args:
        data: mapping ``obj_name -> {'grasps': {grasp_id: grasp}}``. Each
            ``grasp`` holds ``'target_states'`` and ``'actual_states'`` (mappings
            from state name to a per-sample sequence) and
            ``grasp['grasp_output']['success']`` (a per-sample sequence).
        desired_states: non-empty sequence of state names to extract.

    Returns:
        dict with ``'targets'`` and ``'actuals'`` (lists of dicts, one per
        sample) and ``'successes'`` (numpy array, one entry per sample). All
        three are aligned by index.

    Raises:
        ValueError: if ``desired_states`` is empty, or if within a grasp any
            requested state sequence or the success sequence has a different
            length than the first requested target state. Otherwise the
            success labels would silently drift out of alignment with the samples.
    """
    if len(desired_states) == 0:
        raise ValueError('desired_states must contain at least one state name')

    all_pairs = {
        'targets': [],
        'actuals': [],
        'successes': []
    }

    for obj_name in data:
        for grasp_id, grasp in data[obj_name]['grasps'].items():
            target_states = grasp['target_states']
            actual_states = grasp['actual_states']
            grasp_successes = grasp['grasp_output']['success']

            # The first requested state defines how many samples this grasp has.
            n = len(target_states[desired_states[0]])

            # Every sequence must have exactly n entries so that the samples and
            # success labels stay index-aligned after flattening.
            for state in desired_states:
                n_target = len(target_states[state])
                n_actual = len(actual_states[state])
                if n_target != n or n_actual != n:
                    raise ValueError(
                        'object %r grasp %r: state %r has %d target / %d actual '
                        'samples, expected %d'
                        % (obj_name, grasp_id, state, n_target, n_actual, n))
            if len(grasp_successes) != n:
                raise ValueError(
                    'object %r grasp %r: %d success labels for %d samples'
                    % (obj_name, grasp_id, len(grasp_successes), n))

            targets = [{} for _ in range(n)]
            actuals = [{} for _ in range(n)]

            for state in desired_states:
                target_vals = target_states[state]
                actual_vals = actual_states[state]
                for i in range(n):
                    targets[i][state] = target_vals[i]
                    actuals[i][state] = actual_vals[i]

            all_pairs['targets'].extend(targets)
            all_pairs['actuals'].extend(actuals)
            all_pairs['successes'].extend(grasp_successes)

    all_pairs['successes'] = np.array(all_pairs['successes'])
    return all_pairs

In [43]:
# State channels extracted from every grasp. forward_kinematics reads
# arm_ext, arm_elev and arm_rot; gripper_rot is extracted but not used there.
desired_states = [
    'arm_ext',
    'arm_elev',
    'gripper_rot',
    'arm_rot',
]
pairs = get_target_actual_pairs(data=data_dicts, desired_states=desired_states)

In [44]:
def forward_kinematics(state):
    """Map a state dict to a Cartesian point ``[x, y, z]``.

    Treats ``arm_ext`` as a radius, ``arm_rot`` as an angle (radians, as
    consumed by ``np.cos``/``np.sin``) and ``arm_elev`` as the height, i.e.
    cylindrical -> Cartesian conversion. Any other keys (e.g. ``gripper_rot``)
    are ignored. Length units are whatever the recorded states use -- not
    visible here.

    Returns:
        numpy array ``[r*cos(t), r*sin(t), z]``.
    """
    height, radius, angle = state['arm_elev'], state['arm_ext'], state['arm_rot']
    return np.array([radius * np.cos(angle), radius * np.sin(angle), height])

In [45]:
# Convert every recorded target/actual state to a Cartesian point and stack the
# points into an array (one row per sample).
pairs['targets_pts'] = np.array(list(map(forward_kinematics, pairs['targets'])))
pairs['actuals_pts'] = np.array(list(map(forward_kinematics, pairs['actuals'])))

In [46]:
# Per-sample displacement of the actual point from its target point.
# Vectorized subtraction replaces the index loop. The explicit shape check stops
# a length mismatch from being silently truncated or broadcast.
if pairs['actuals_pts'].shape != pairs['targets_pts'].shape:
    raise ValueError(
        'actuals_pts %s and targets_pts %s have different shapes'
        % (pairs['actuals_pts'].shape, pairs['targets_pts'].shape))
pairs['relatives'] = pairs['actuals_pts'] - pairs['targets_pts']

In [110]:
#diffs scatter
pts = pairs['relatives']
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.set_xlim([-0.008, 0.008])
ax.set_ylim([-0.008, 0.008])
ax.set_zlim([-0.002, 0.014])
ax.scatter(pts[:,0], pts[:,1], pts[:,2])
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.set_title('Diffs')
fig.show()

In [113]:
#mixed actual and target scatter
xs = np.r_[pairs['targets_pts'][:,0], pairs['actuals_pts'][:,0]]
ys = np.r_[pairs['targets_pts'][:,1], pairs['actuals_pts'][:,1]]
zs = np.r_[pairs['targets_pts'][:,2], pairs['actuals_pts'][:,2]]

n = zs.shape[0]/2
cs = np.r_[['r']*n, ['b']*n]

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.set_xlim([-0.18, 0.18])
ax.set_ylim([-0.18, 0.18])
ax.set_zlim([-0.02, 0.16])
ax.scatter(xs, ys, zs, c=cs)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.set_title('Red = Targets. Blue = Actuals')
fig.show()

In [85]:
# Split the per-sample displacements by grasp outcome (success label 0 ->
# failures, 1 -> successes). Boolean masks keep the rows' original shape
# directly, so no argwhere/take/reshape round-trip is needed. The length check
# guards against labels that do not line up with the samples.
if len(pairs['successes']) != len(pairs['relatives']):
    raise ValueError(
        '%d success labels for %d displacement samples'
        % (len(pairs['successes']), len(pairs['relatives'])))
pairs['diffs_failures'] = pairs['relatives'][pairs['successes'] == 0]
pairs['diffs_successes'] = pairs['relatives'][pairs['successes'] == 1]

In [100]:
#mixed successes and failures scatter
xs = np.r_[pairs['diffs_failures'][:,0], pairs['diffs_successes'][:,0]]
ys = np.r_[pairs['diffs_failures'][:,1], pairs['diffs_successes'][:,1]]
zs = np.r_[pairs['diffs_failures'][:,2], pairs['diffs_successes'][:,2]]

n = zs.shape[0]/2
cs = np.r_[['r']*n, ['b']*n]

fig = plt.figure(figsize=(500,500))
ax = fig.add_subplot(111, projection='3d')
ax.set_xlim([-0.008, 0.008])
ax.set_ylim([-0.008, 0.008])
ax.set_zlim([-0.002, 0.014])
ax.scatter(xs, ys, zs, c=cs)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.set_title('Diffs. Red = Failures. Blue = Successes')
fig.show()

In [115]:
# Summary statistics of the actual-minus-target displacements: per-coordinate
# mean, and the covariance between coordinates (columns are the variables,
# rows are samples).
diffs_mean = pairs['relatives'].mean(axis=0)
diffs_cov = np.cov(pairs['relatives'], rowvar=False)
print('mean')
print(repr(diffs_mean))
print('cov')
print(repr(diffs_cov))

mean
array([  1.70398127e-03,   3.08850682e-05,   6.11979781e-04])
cov
array([[  5.24198858e-07,  -1.23597659e-07,   1.93426049e-07],
       [ -1.23597659e-07,   1.04596358e-06,   1.81981953e-08],
       [  1.93426049e-07,   1.81981953e-08,   4.16089676e-06]])
