## Imports

In [1]:
import os
import sys
import argparse
import numpy as np
import csv

from params import ParamsKITTI, ParamsEuroc
from dataset import KITTIOdometry, EuRoCDataset

In [2]:
import cv2
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
#%matplotlib inline

In [3]:
# Show the path of the Python interpreter this kernel runs on, i.e. which
# environment the imports above were resolved against.
sys.executable

'/Users/David/miniconda3/envs/dev/bin/python'

In [40]:
# Scratch check: selecting columns with a boolean mask versus a 0/1 integer
# mask. The integer mask has to be cast to bool first; used as-is it would be
# interpreted as column *indices* rather than as a selection mask.
selected_columns = [1, 2, 4]

a = np.outer(np.arange(1, 4), np.arange(1, 7))  # rows are 1x, 2x, 3x of 1..6

mask2 = np.zeros((6,), dtype=int)
mask2[selected_columns] = 1

mask = np.isin(np.arange(6), selected_columns)

for label, value in (
    ("a:\n", a),
    ("mask:\n", mask),
    ("mask2:\n", mask2),
    ("a[:, mask] (WORKS):\n", a[:, mask]),
    ("a[:, mask2]:\n", a[:, mask2.astype(bool)]),
):
    print(label, value)

a:
 [[ 1  2  3  4  5  6]
 [ 2  4  6  8 10 12]
 [ 3  6  9 12 15 18]]
mask:
 [False  True  True False  True False]
mask2:
 [0 1 1 0 1 0]
a[:, mask] (WORKS):
 [[ 2  3  5]
 [ 4  6 10]
 [ 6  9 15]]
a[:, mask2]:
 [[ 2  3  5]
 [ 4  6 10]
 [ 6  9 15]]


## User-defined classes

In [None]:
class GroundTruthLoader:
    """Ground-truth positions for a frame range, relative to the first sample.

    After construction, `traj` is a list of 3-element numpy arrays (CSV
    columns 1-3), each expressed as an offset from the first row that fell
    inside the requested time window.
    """

    def __init__(self, path, dataset, start_idx, end_idx):
        """Read the ground-truth CSV at `path`.

        Args:
            path: CSV file with one header row; column 0 is a timestamp that
                is divided by 1e9 before being compared with
                `dataset.timestamps`, columns 1-3 are the position.
            dataset: object exposing an indexable `timestamps` attribute.
            start_idx: index into `dataset.timestamps` of the first frame to consider.
            end_idx: index into `dataset.timestamps` of the last frame to consider.
        """
        self.traj = []
        window_start = dataset.timestamps[start_idx]
        window_end = dataset.timestamps[end_idx]

        origin = None  # position of the first row inside the window
        with open(path, newline='') as csv_file:
            rows = csv.reader(csv_file, delimiter=',')
            next(rows)  # skip header
            for row in rows:
                stamp = float(row[0]) / 1e9
                if stamp > window_end:
                    # Reading stops at the first row past the window.
                    break
                if stamp >= window_start:
                    position = np.array([float(row[1]), float(row[2]), float(row[3])])
                    if origin is None:
                        origin = position
                    self.traj.append(position - origin)
                

In [None]:
# Processing stages of the VO state machine, in the order they are visited:
# first frame -> second frame -> every frame after that.
FIRST_FRAME, SECOND_FRAME, DEFAULT = range(3)

class VO:
    """Monocular visual-odometry prototype running on the EuRoC dataset.

    Frames are consumed one at a time through `update()`:

    * first frame: detect Shi-Tomasi corners and store the identity pose;
    * second frame: track the corners with pyramidal Lucas-Kanade flow,
      recover the relative pose from the essential matrix, triangulate the
      tracked points and store the resulting absolute pose;
    * later frames: only track the corners (no pose is estimated yet, so
      `poses` stops growing after the second frame).
    """

    def __init__(self, path, cam, start_idx=0):
        """Load the dataset and set up detectors, tracker and camera model.

        Args:
            path: dataset root handed to `EuRoCDataset`.
            cam: camera model; `fx`, `fy`, `cx`, `cy` and `intrinsic_matrix`
                are read from it.
            start_idx: index of the first frame to process.
        """
        self.stage = FIRST_FRAME
        self.curr_idx = start_idx
        self.num_processed = 0

        # dataset-dependent params
        self.params = ParamsEuroc()
        self.dataset = EuRoCDataset(path)

        self.detector = cv2.ORB_create(nfeatures=200, scaleFactor=1.2, nlevels=1, edgeThreshold=31)
        self.ffdetector = cv2.FastFeatureDetector_create(threshold=25, nonmaxSuppression=True)
        # BRIEF lives in the opencv-contrib `xfeatures2d` module and is not
        # used by the KLT pipeline below, so a build without contrib must not
        # prevent the object from being constructed.
        try:
            self.extractor = cv2.xfeatures2d.BriefDescriptorExtractor_create(bytes=32, use_orientation=False)
        except (AttributeError, cv2.error):
            self.extractor = None
        self.bf_matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False)

        # kpts and descriptors of all frames seen so far
        self.kpts = []
        self.des = []
        self.matches = []

        # params for Shi-Tomasi corner detection
        self.detector_params = dict(maxCorners = 150,
                                    qualityLevel = 0.3,
                                    minDistance = 7,
                                    blockSize = 7)

        # tracker params
        self.tracker_params = dict(winSize = (21, 21),
                                   criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 30, 0.01))

        # triangulated points (3xM array per processed frame pair) and the
        # keypoint indices they belong to
        self.pts_3d = []
        self.good_idxs = []

        # per-keypoint flag set by track_features(): True when the last
        # track passed the forward-backward check
        self.good = np.zeros(0, dtype=bool)

        # relevant images seen so far
        self.prev_img = self.dataset.left[start_idx]
        self.curr_img = self.dataset.left[start_idx]

        # camera model
        self.f = (cam.fx + cam.fy) / 2 # avg of both focal lengths
        self.pp = (cam.cx, cam.cy)
        self.K = np.append(cam.intrinsic_matrix, np.array([[0, 0, 0]]).T, axis=1) # 3x4 ndarray

        # trajectory: list of (R, t) tuples
        self.poses = []

        self.viz = True
        # one polyline per keypoint, index-aligned with self.kpts[-1]
        self.tracks = []

    def update(self):
        """Process the frame at `curr_idx`, optionally draw it, and advance."""
        self.timestamp = self.dataset.timestamps[self.curr_idx]
        self.prev_img = self.curr_img
        self.curr_img = self.dataset.left[self.curr_idx]

        if self.stage == FIRST_FRAME:
            self.process_first_frame()
        elif self.stage == SECOND_FRAME:
            self.process_second_frame()
        else:
            self.process_default()

        if self.viz:
            self.draw()

        self.curr_idx = self.curr_idx + 1
        self.num_processed = self.num_processed + 1

    def draw(self):
        """Show the current frame with the keypoint tracks drawn on a copy of it."""
        vis = self.curr_img.copy()
        latest = self.kpts[-1] if self.kpts else None
        if len(self.tracks) > 0 and latest is not None:
            print("THERE ARE TRACKS")
            # self.tracks stays index-aligned with the keypoint array: a
            # track that fails the check this frame is simply neither
            # extended nor drawn. Dropping it from the list (as before)
            # shifted every later track onto the wrong keypoint.
            live_tracks = []
            for tr, (x, y), good_flag in zip(self.tracks, latest.reshape(-1, 2), self.good):
                if not good_flag:
                    continue
                tr.append((x, y))
                live_tracks.append(tr)
                # OpenCV drawing functions need integer pixel coordinates.
                cv2.circle(vis, (int(round(float(x))), int(round(float(y)))), 2, (0, 255, 0), -1)
            if live_tracks:
                cv2.polylines(vis, [np.int32(tr) for tr in live_tracks], False, (0, 255, 0))
        elif latest is not None:
            # First drawn frame: start one track per detected keypoint.
            for x, y in np.float32(latest).reshape(-1, 2):
                self.tracks.append([(x, y)])
        plt.imshow(vis, cmap='gray')
        plt.show()

    @staticmethod
    def _forward_backward_mask(pts_prev, pts_back, status=None, max_error=1.0):
        """Flag the points whose back-tracked position lands on the original one.

        Args:
            pts_prev: original point positions, any shape reshapeable to (N, 2).
            pts_back: positions obtained by tracking forward and then back.
            status: optional per-point flags of the forward track (1 = found).
            max_error: largest accepted per-axis round-trip error in pixels.

        Returns:
            Boolean array of length N.
        """
        round_trip = np.abs(np.asarray(pts_prev, dtype=float) - np.asarray(pts_back, dtype=float))
        good = round_trip.reshape(-1, 2).max(-1) < max_error
        if status is not None:
            good &= np.asarray(status).reshape(-1) == 1
        return good

    def track_features(self, kpts_prev=None):
        """Track keypoints from `prev_img` into `curr_img` with pyramidal LK flow.

        Args:
            kpts_prev: points to track; defaults to the most recent keypoints.

        Returns:
            The tracked positions of all input points (same order and count).
            `self.good` is set to the per-point forward-backward verdict.
        """
        if kpts_prev is None:
            kpts_prev = self.kpts[-1]
        print("in track_features... in:", len(kpts_prev))
        kpts, st, err = cv2.calcOpticalFlowPyrLK(self.prev_img, self.curr_img, kpts_prev, None, **(self.tracker_params))
        kpts_r, _st, _err = cv2.calcOpticalFlowPyrLK(self.curr_img, self.prev_img, kpts, None, **(self.tracker_params))
        # The check compares the ORIGINAL points with their back-tracked
        # positions; comparing the forward result with the back-tracked one
        # would measure the displacement and reject every point that moved.
        self.good = self._forward_backward_mask(kpts_prev, kpts_r, st)
        return kpts

    def process_first_frame(self):
        """Detect corners in the first frame and record the identity pose."""
        print("\nINFO: in process_first_frame()")
        kpts = cv2.goodFeaturesToTrack(self.curr_img, mask = None, **self.detector_params)
        if kpts is None:
            # Nothing detected: stay in FIRST_FRAME so the next frame retries.
            print("WARNING: no corners detected, retrying on the next frame")
            return
        print(len(kpts))
        R = np.array([[1.0, 0, 0],
                      [0, 1.0, 0],
                      [0, 0, 1.0]]) # rotation matrix
        t = np.array([0, 0, 0]) # translation vector
        self.poses.append((R, t))
        self.kpts.append(np.asarray(kpts))
        self.stage = SECOND_FRAME

    @staticmethod
    def _compose_pose(prev_pose, R_rel, t_rel):
        """Chain a relative motion onto a stored pose.

        Args:
            prev_pose: `(R_prev, t_prev)` tuple as stored in `self.poses`.
            R_rel: 3x3 relative rotation.
            t_rel: relative translation, shape (3,) or (3, 1).

        Returns:
            `(R_rel @ R_prev, R_rel @ t_prev + t_rel)`.
        """
        R_prev, t_prev = prev_pose
        t_rel = np.asarray(t_rel, dtype=float).reshape(3)
        # The translation must be rotated by the RELATIVE rotation; using the
        # already-composed rotation here would apply R_prev twice.
        t_abs = R_rel.dot(t_prev) + t_rel
        R_abs = R_rel.dot(R_prev)
        return R_abs, t_abs

    def process_second_frame(self):
        """Track into the second frame, recover the relative pose and triangulate."""
        print("\nINFO: in process_second_frame()")

        kpts_prev = self.kpts[-1]
        kpts = self.track_features()

        # extract relative pose
        E, mask = cv2.findEssentialMat(np.array(kpts_prev), np.array(kpts), focal=self.f, pp=self.pp, method=cv2.RANSAC, prob=0.99, threshold=0.5)
        _, R_rel, t_rel, _ = cv2.recoverPose(E, np.array(kpts_prev), np.array(kpts), focal=self.f, pp=self.pp)

        # triangulate points from two views (current and previous)
        pts_3d, idxs = self.triangulate_points(R_rel, t_rel, kpts_prev, kpts) # np.array, 3xM

        # compute absolute pose
        R, t = self._compose_pose(self.poses[-1], R_rel, t_rel)

        # bookkeep
        print("translation:", t)

        self.good_idxs.append(np.asarray(idxs))
        self.pts_3d.append(pts_3d)

        self.poses.append((R, t))
        self.kpts.append(np.array(kpts))
        self.stage = DEFAULT

    def process_default(self):
        """Track the keypoints into the current frame (pose estimation is pending)."""
        print("\nINFO: in process_default()")
        # track from previous to current frame
        kpts = self.track_features()

        # TODO: estimate the pose of this frame. Outline of the intended steps:
        #   1. cv2.solvePnPRansac on the stored 3D points (self.pts_3d[-1])
        #      and their tracked 2D positions; skip the frame when fewer
        #      than 5 inliers are found;
        #   2. convert the rotation vector with cv2.Rodrigues;
        #   3. re-triangulate with triangulate_points() and store the result;
        #   4. chain the relative motion onto self.poses[-1] and append it.

        # bookkeep
        self.kpts.append(kpts)

    @staticmethod
    def _select_valid_points(pts_hom, min_depth=0.01):
        """Dehomogenise triangulated points and drop the unusable ones.

        Args:
            pts_hom: 4xN homogeneous points.
            min_depth: smallest accepted z coordinate.

        Returns:
            `(pts, idxs)`: the 3xM array of kept points and the length-M
            integer array with their column indices in `pts_hom`.
        """
        pts_hom = np.asarray(pts_hom, dtype=float)
        w = pts_hom[3, :]
        has_w = w != 0  # w == 0 is a point at infinity; dividing would give inf/nan
        pts = np.zeros((3, pts_hom.shape[1]))
        pts[:, has_w] = pts_hom[:3, has_w] / w[has_w]
        # After the division w is 1 for every point, so testing it says
        # nothing; what has to be positive is the depth (z), i.e. the point
        # must lie in front of the reference camera.
        valid = has_w & np.isfinite(pts).all(axis=0) & (pts[2, :] > min_depth)
        return pts[:, valid], np.nonzero(valid)[0]

    def triangulate_points(self, R, t, kpts1, kpts2):
        """Triangulate matched keypoints of two views.

        The second camera is taken as the origin (`P_2 = K`); the first
        camera's projection matrix is built from the inverse of `(R, t)`.

        Args:
            R: 3x3 rotation between the views.
            t: translation between the views, 3 elements.
            kpts1: keypoints in the first view, reshapeable to (N, 2).
            kpts2: matching keypoints in the second view, reshapeable to (N, 2).

        Returns:
            `(pts, idxs)`: 3xM array of the kept points and the length-M
            integer array of their indices into the input keypoints.
        """
        P_1 = self.K.dot(np.linalg.inv(self.T_from_Rt(R, t)))
        P_2 = self.K # assume camera 2 is at origin

        # cv2.triangulatePoints expects 2xN arrays; the tracker hands out
        # (N, 1, 2) arrays, whose plain transpose would be (2, 1, N).
        pts1 = np.ascontiguousarray(np.asarray(kpts1).reshape(-1, 2).T)
        pts2 = np.ascontiguousarray(np.asarray(kpts2).reshape(-1, 2).T)

        pts_hom = cv2.triangulatePoints(P_1, P_2, pts1, pts2) # in homogeneous coords
        return self._select_valid_points(pts_hom)

    def draw_matches(self):
        """Draw descriptor matches between the previous and current frame.

        Only does something in the DEFAULT stage and when `self.matches`
        holds at least one entry; nothing in this class fills that list at
        the moment, so the call currently returns without drawing.
        """
        plt.ion() # interactive
        if self.stage != DEFAULT or not self.matches:
            return
        img_matches = cv2.drawMatchesKnn(self.prev_image, self.prev_kpts, self.curr_image,
                                         self.curr_kpts, self.matches[-1], None,
                                         flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS)
        plt.figure(figsize=(12,8), dpi=100)
        plt.imshow(img_matches)
        plt.show()
        input("press any key")

    def T_from_Rt(self, R, t):
        """Assemble the 4x4 homogeneous transform [[R, t], [0, 0, 0, 1]]."""
        t = t.reshape((3, 1))
        R = R.reshape((3, 3))
        return np.append(np.append(R, t, axis=1), np.array([[0,0,0,1]]), axis=0)

    def get_image(self, idx):
        """Return the left image at `idx`; negative indices are clamped to 0."""
        idx = max(0, idx)
        return self.dataset.left[idx]

    @property
    def curr_image(self):
        """The frame currently being processed."""
        return self.curr_img

    @property
    def curr_kpts(self):
        """Keypoints of the most recent frame, or [] when none are stored."""
        if len(self.kpts) > 0:
            return self.kpts[-1]
        else:
            return []

    @property
    def prev_image(self):
        """The frame processed before the current one."""
        return self.prev_img

    @property
    def prev_kpts(self):
        """Keypoints of the frame before the most recent one, or [] when unavailable."""
        if len(self.kpts) > 1:
            print("len(self.kpts):", len(self.kpts))
            return self.kpts[-2]
        else:
            return []

## Main

In [None]:
# create ground truth loader object
ground_truth_path = os.path.join(path, 'mav0/state_groundtruth_estimate0/data.csv')
gtloader = GroundTruthLoader(ground_truth_path, dataset, start_idx, start_idx+frames_to_process)

In [None]:
# create VO object
path = '/Users/David/Downloads'
start_idx = 400
frames_to_process = 1000
dataset = EuRoCDataset(path)
vo = VO(path, dataset.left_cam, start_idx)

In [None]:
# testing
# Manual inspection: show frames 390-419 one at a time, each in its own
# figure, waiting for the user between frames.
idx = 120  # not used anywhere in this cell
plt.ion()
for i in range(390, 420):
    plt.figure()
    plt.imshow(vo.get_image(i), cmap='gray')
    plt.show()
    # NOTE(review): input() blocks until Enter is pressed, so this cell
    # cannot complete in an unattended "Run All".
    input("press any key")
    #print(vo.dataset.timestamps[i])

In [None]:
# Select the inline backend for the figures that vo.update() shows through
# draw() (plt.imshow / plt.show) while vo.viz is True.
%matplotlib inline

# main loop
# Only 9 frames are processed for now; the commented-out bound would run
# over frames_to_process frames instead.
for i in range(1, 10):#frames_to_process + 1):
    # i counts processed frame pairs from 1; it is not a dataset index.
    print("INFO: images", i-1, "-", i)
    vo.update()
    #vo.draw_matches()

In [None]:
# Draw on a copy: vo.curr_img is the very array the tracker reuses as
# prev_img on the next update(), so drawing in place would corrupt its input.
img = vo.curr_img.copy()
cv2.circle(img, (200,300), 20, (0, 255, 0), 2)
# Show through matplotlib like the rest of the notebook; cv2.imshow needs a
# cv2.waitKey() event loop to actually paint its window.
plt.imshow(img, cmap='gray')
plt.show()

## Visualize results

In [None]:
%matplotlib notebook 
# why again??

plot_ground_truth = False

fig = plt.figure()
ax = fig.gca(projection='3d')

xs = []
ys = []
zs = []
for r, t in vo.poses:
    xs.append(t[0])
    ys.append(t[1])
    zs.append(t[2])

xs_t = []
ys_t = []
zs_t = []
for pt in gtloader.traj:
    xs_t.append(pt[0])
    ys_t.append(pt[1])
    zs_t.append(pt[2])
    
ax.plot(xs, ys, zs, label='Estimated 3D trajectory')
if plot_ground_truth:
    ax.plot(xs_t, ys_t, zs_t, label='Ground truth')
ax.legend()
plt.show()