In [None]:
from scipy.linalg import norm
from scipy.spatial.transform import Rotation as R
from sklearn.neighbors import NearestNeighbors
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import glob
import os

In [None]:
# All pose files under data/pose, in lexicographic filename order (the split
# helpers below rely on this order being stable and reproducible).
pose_files = sorted(glob.glob(os.path.join('data/pose', '*.txt')))

In [None]:
def get_train_val_split(pose_files, skip, max_index=None, stride=1):
    """Split a (sorted) sequence of pose files into training and validation lists.

    The sequence is first truncated to its first ``max_index`` entries and
    subsampled with ``stride``.  Of the remaining entries, those at positions
    0, skip, 2*skip, ... go to validation and all others go to training.

    Parameters
    ----------
    pose_files : sequence
        File names; any sliceable sequence works.
    skip : int
        Validation period within the subsampled sequence.
    max_index : int or None, optional
        Only entries before this index are considered.  ``None`` (default)
        uses the whole sequence.
    stride : int, optional
        Step used when subsampling the truncated sequence.

    Returns
    -------
    (train_filenames, val_filenames) : tuple of lists
    """
    # The default is None rather than len(pose_files): a default value is
    # evaluated once at definition time, so it would be frozen to the length of
    # whatever global list existed then and silently truncate other inputs.
    subset = pose_files[:max_index:stride]
    train_filenames = [name for i, name in enumerate(subset) if i % skip != 0]
    val_filenames = [name for i, name in enumerate(subset) if i % skip == 0]

    return train_filenames, val_filenames

In [None]:
def load_poses(pose_files):
    """Read rotation and translation parts from whitespace-separated pose files.

    Each file is parsed as a numeric table.  The upper-left 3x3 sub-matrix is
    taken as the rotation, and the first three entries of the last column as
    the translation.  This presumably matches a 3x4 or 4x4 ``[R | t]`` layout;
    the shape is not checked here.

    Parameters
    ----------
    pose_files : iterable of str
        Paths of the pose text files.

    Returns
    -------
    rots : list of (3, 3) ndarray
    ts : list of (3,) ndarray
        One entry per input file, in input order.
    """
    rots = []
    ts = []
    for path in pose_files:
        # Split on any run of whitespace.  A literal ' ' delimiter would turn a
        # trailing blank into an extra all-NaN last column, which would then be
        # read as the translation.
        values = np.loadtxt(path, ndmin=2)
        rots.append(np.array(values[0:3, 0:3]))
        ts.append(np.array(values[0:3, -1]))
    return rots, ts

In [None]:
def get_nn_indices(neighbors, items):
    """Return, for each row of ``items``, the index of its closest row in ``neighbors``.

    Uses scikit-learn's ``NearestNeighbors`` with a single neighbour and its
    default metric.  The result is the index array from ``kneighbors``, one
    row per item and one column.
    """
    model = NearestNeighbors(n_neighbors=1).fit(neighbors)
    _, indices = model.kneighbors(items)
    return indices

In [None]:
def get_rotvecs_from_matrices(rots):
    """Convert 3x3 rotation matrices to rotation vectors (axis scaled by angle in radians).

    Returns a list holding one length-3 ndarray per input matrix, in input order.
    """
    return [R.from_matrix(matrix).as_rotvec() for matrix in rots]

In [None]:
def unit_vector(vector):
    """Scale ``vector`` to unit Euclidean length.

    A zero vector yields NaNs, since the division is by a zero norm.
    """
    length = np.linalg.norm(vector)
    return vector / length

def angle_between(v1, v2):
    """Return the angle in radians, in [0, pi], between vectors ``v1`` and ``v2``.

    The cosine is clipped to [-1, 1] before ``arccos`` to absorb rounding
    error on (anti)parallel inputs.  A zero-length input yields NaN.

    Examples::

        >>> float(angle_between((1, 0, 0), (0, 1, 0)))
        1.5707963267948966
        >>> float(angle_between((1, 0, 0), (1, 0, 0)))
        0.0
        >>> float(angle_between((1, 0, 0), (-1, 0, 0)))
        3.141592653589793
    """
    cosine = np.dot(unit_vector(v1), unit_vector(v2))
    cosine = np.clip(cosine, -1.0, 1.0)
    return np.arccos(cosine)


#angles = []
#for i, val_rot in enumerate(val_rots):
#    angle = get_angle(val_rot, train_rots[nn_indices[i,0]])
#    angles.append(np.rad2deg(angle))
#
#print(angles)

In [None]:
# Load train and validation poses
def get_val_nn_train_angles(train_rots, val_rots, unit='deg'):
    # Transform to axis representation
    train_rotvecs = get_rotvecs_from_matrices(train_rots)
    val_rotvecs = get_rotvecs_from_matrices(val_rots)

    # Find nearest neighbors by angle
    nn_indices = get_nn_indices(train_rotvecs, val_rotvecs)

    # Get angles to nearest neighbords in degrees
    angles = []
    for i, val_rotvec in enumerate(val_rotvecs):
        angle = angle_between(val_rotvec, train_rotvecs[nn_indices[i,0]])
        if unit == 'deg':
            angle = np.rad2deg(angle)
        angles.append(angle)
    
    return angles

In [None]:
# Compare train/val splits of the first MAX_INDEX poses at several subsampling
# strides.  Within each subsampled list, every VAL_EVERY-th pose is held out
# for validation.  For each split, histogram the angle from every validation
# pose to its closest training pose.
MAX_INDEX = 2700
VAL_EVERY = 6
STRIDES = (1, 2, 4, 8)
NUM_BINS = 20

assert pose_files, "no pose files found under data/pose"

tv_files = [get_train_val_split(pose_files, VAL_EVERY, MAX_INDEX, stride) for stride in STRIDES]

title = 'Train:Val = {}:{} samples \n (min, max, mean) = ({:.1f}, {:.1f}, {:.1f})'

# squeeze=False keeps `axes` 2-D, so the loop also works with a single stride.
fig, axes = plt.subplots(1, len(tv_files), figsize=(15, 4), squeeze=False)
for ax, (train_files, val_files) in zip(axes[0], tv_files):
    train_rots, _ = load_poses(train_files)
    val_rots, _ = load_poses(val_files)
    angles = get_val_nn_train_angles(train_rots, val_rots, unit='deg')

    ax.set_title(title.format(len(train_files), len(val_files),
                              np.min(angles), np.max(angles), np.mean(angles)))
    ax.hist(angles, bins=NUM_BINS)
    ax.set_xlabel('degrees')
    ax.set_ylabel('count')
plt.show()