# Pose detection for the ESA Pose Estimation Challenge

https://kelvins.esa.int/satellite-pose-estimation-challenge/problem/

In [None]:
import cv2
import json
import numpy as np
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1' # 1: RTX, 0: Titan
import tensorflow as tf
from imgaug import augmenters as iaa
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
from mpl_toolkits.mplot3d import Axes3D
from submission import SubmissionWriter
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input, Activation, AveragePooling2D, concatenate, Conv2D, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TensorBoard
from tensorflow.keras import backend as K
from tensorflow.keras.applications.vgg19 import preprocess_input
from tensorflow.keras.preprocessing import image
from tensorflow.keras.preprocessing import image as keras_image
from tensorflow.keras.utils import Sequence
from time import time
from pyrr import Quaternion
from PIL import Image

%load_ext autoreload
%autoreload 2

In [None]:
# Configure TensorFlow so that not all the GPU memory is allocated for this
# notebook: with allow_growth the session claims GPU memory on demand, and
# allow_soft_placement lets ops be placed on another device when the
# requested one is not available.
config = tf.ConfigProto(allow_soft_placement=True)
config.gpu_options.allow_growth = True
# Hand the session to the Keras backend so that it is used from here on.
K.set_session(tf.Session(config=config))

In [None]:
class Camera:

    """ Utility class for accessing camera parameters. """

    fx = 0.0176  # focal length[m]
    fy = 0.0176  # focal length[m]
    nu = 1920  # number of horizontal[pixels]
    nv = 1200  # number of vertical[pixels]
    ppx = 5.86e-6  # horizontal pixel pitch[m / pixel]
    ppy = ppx  # vertical pixel pitch[m / pixel]
    fpx = fx / ppx  # horizontal focal length[pixels]
    fpy = fy / ppy  # vertical focal length[pixels]
    # Pinhole intrinsic matrix; the principal point is placed at the image centre.
    k = [[fpx,   0, nu / 2],
         [0,   fpy, nv / 2],
         [0,     0,      1]]
    K = np.array(k)  # same matrix as a numpy array, used for the projections


def process_json_dataset(root_dir):

    """ Read the train / test / real_test JSON indexes located in root_dir.

    Returns a tuple (partitions, labels):
    partitions maps 'train', 'test' and 'real_test' to the list of image file names of that split,
    labels maps every training file name to {'q': ..., 'r': ...} taken from the
    'q_vbs2tango' and 'r_Vo2To_vbs_true' fields of its annotation.
    """

    def read_index(file_name):
        with open(os.path.join(root_dir, file_name), 'r') as handle:
            return json.load(handle)

    train_entries = read_index('train.json')
    test_entries = read_index('test.json')
    real_test_entries = read_index('real_test.json')

    partitions = {
        'test': [entry['filename'] for entry in test_entries],
        'train': [entry['filename'] for entry in train_entries],
        'real_test': [entry['filename'] for entry in real_test_entries],
    }

    # only the training split carries pose annotations
    labels = {
        entry['filename']: {'q': entry['q_vbs2tango'], 'r': entry['r_Vo2To_vbs_true']}
        for entry in train_entries
    }

    return partitions, labels


def quat2dcm(q):

    """ Direction cosine matrix of a quaternion (adapted from PyNav).

    The quaternion is normalised first; its first component is the one used as
    the scalar part in the formulas below. Returns a 3x3 numpy array.
    """

    # work on the unit quaternion
    unit = q/np.linalg.norm(q)
    w, x, y, z = unit[0], unit[1], unit[2], unit[3]

    dcm = np.zeros((3, 3))

    # diagonal terms
    for axis, component in enumerate((x, y, z)):
        dcm[axis, axis] = 2 * w ** 2 - 1 + 2 * component ** 2

    # off-diagonal terms, filled in mirrored pairs (they differ only in the sign of the scalar product)
    dcm[0, 1], dcm[1, 0] = 2 * x * y + 2 * w * z, 2 * x * y - 2 * w * z
    dcm[0, 2], dcm[2, 0] = 2 * x * z - 2 * w * y, 2 * x * z + 2 * w * y
    dcm[1, 2], dcm[2, 1] = 2 * y * z + 2 * w * x, 2 * y * z - 2 * w * x

    return dcm

def pointInTriangle(t, p):
    """ Check whether the 2D point p lies strictly inside the 2D triangle t.

    Uses the barycentric test from
    https://stackoverflow.com/questions/2049582/how-to-determine-if-a-point-is-in-a-2d-triangle

    t: three (x, y) vertices, in either winding order
    p: the (x, y) point to test
    Returns True only for points strictly inside; points on an edge or a vertex,
    and any point tested against a degenerate (zero-area) triangle, give False.
    """
    (x0, y0), (x1, y1), (x2, y2) = t[0], t[1], t[2]
    px, py = p[0], p[1]

    # twice the signed area of the triangle; its sign encodes the winding order
    doubled_area = -y1 * x2 + y0 * (-x1 + x2) + x0 * (y1 - y2) + x1 * y2
    if doubled_area == 0:
        # collinear vertices: the triangle has no interior (and would divide by zero below)
        return False

    # barycentric coordinates of p with respect to the triangle
    s = (y0 * x2 - x0 * y2 + (y2 - y0) * px + (x0 - x2) * py) / doubled_area
    u = (x0 * y1 - y0 * x1 + (y0 - y1) * px + (x1 - x0) * py) / doubled_area
    return bool(s > 0 and u > 0 and (1 - s - u) > 0)

class Plane:

    """ Triangle in 3D space together with the unit normal of the plane it spans.

    All rays handled by this class start at the origin and are given by their direction vector.
    """

    def __init__(self, points):
        """ points: exactly three 3D points (any array-like); raises ValueError otherwise. """
        if len(points) != 3:
            raise ValueError("Plane always consists of three points")

        # Convert first so that plain nested lists work in the subtraction below
        self.points = np.asarray(points, dtype=float)
        n = np.cross(self.points[1] - self.points[0], self.points[2] - self.points[0])
        self.normal = n / np.linalg.norm(n, 2)

    def intersect(self, v):
        """ Return the point where the line through the origin with direction v meets the plane.

        Raises ValueError when the line is (numerically) parallel to the plane.
        """
        v = np.asarray(v, dtype=float)
        ndotu = self.normal.dot(v)
        if abs(ndotu) < 1e-6:
            raise ValueError("Line is parallel to plane")

        # The line is t * v; solving normal . (t * v - points[0]) = 0 for t gives the factor below
        return (self.normal.dot(self.points[0]) / ndotu) * v

    def intersects(self, v):
        """ Test whether the line through the origin with direction v passes through the triangle.

        Always returns a tuple (intersection, inside):
        intersection is the 3D point on the plane, or None when the line is parallel to the plane,
        inside tells whether that point lies strictly inside the triangle.
        """
        # First calculate intersection point of vector and this plane:
        try:
            intersection = self.intersect(v)
        except ValueError:
            # The vector is parallel to the plane (could be completely on it or completely off),
            # so report "no hit" in the same two-value shape as the regular case
            return None, False

        # We have a rotated 3D plane and want to drop the coordinate along its normal while
        # keeping relations (i.e. express the triangle in 2D coordinates within the plane)
        # https://stackoverflow.com/questions/1023948/rotate-normal-vector-onto-axis-plane
        zAxisNew = self.normal
        xAxisOld = np.array([1.0, 0.0, 0.0])
        if np.array_equal(np.absolute(zAxisNew), xAxisOld):
            # the seed axis must not coincide with the normal (the new z axis), otherwise the
            # cross product below vanishes. Therefore use the old z axis as seed instead
            xAxisOld = np.array([0.0, 0.0, 1.0])
        yAxisNew = np.cross(xAxisOld, zAxisNew)
        xAxisNew = np.cross(zAxisNew, yAxisNew)
        yAxisNew /= np.linalg.norm(yAxisNew, 2)
        xAxisNew /= np.linalg.norm(xAxisNew, 2)
        projected2dtriangle = np.asarray([[p.dot(xAxisNew), p.dot(yAxisNew)] for p in self.points])
        # Now we know the 2d projection of the points of the polygon. Also project the 3d intersection point to the same plane
        projected2dpoint = np.asarray([intersection.dot(xAxisNew), intersection.dot(yAxisNew)])
        return intersection, pointInTriangle(projected2dtriangle, projected2dpoint)

def getSatelliteModel():
    """ Return the box model of the satellite.

    Returns a tuple (vertices, triangles): vertices is an (8, 3) array of corner
    coordinates in the satellite frame, triangles is a (12, 3) array of vertex
    indices, two triangles for each of the six faces u, v, w, x, y, z.

    #     0         1
    #     +---a-----+
    #  d-/|   u    /|-c
    # 3 +---------+ | 2
    #   |w| y  z  |x|     (y: front, z: back)
    # 4 | +-------|-+ 5
    #   |/ v (0,0)|/-b
    # 7 +---------+ 6
    """
    half_a = 0.75 / 2   # half of extent a (x direction)
    half_d = 0.8 / 2    # half of extent d (y direction, corners 0-3)
    half_b = 0.6 / 2    # half of extent b (y direction, corners 4-7)
    c = 0.32            # z coordinate of corners 0-3; corners 4-7 lie at z = 0

    vertices = np.array([
        [-half_a,  half_d, c],  # 0
        [ half_a,  half_d, c],  # 1
        [ half_a, -half_d, c],  # 2
        [-half_a, -half_d, c],  # 3
        [-half_a,  half_b, 0],  # 4
        [ half_a,  half_b, 0],  # 5
        [ half_a, -half_b, 0],  # 6
        [-half_a, -half_b, 0],  # 7
    ])

    faces = (
        ([0, 1, 2], [0, 3, 2]),  # u
        ([4, 5, 6], [4, 7, 6]),  # v
        ([0, 3, 7], [0, 4, 7]),  # w
        ([1, 2, 6], [1, 5, 6]),  # x
        ([3, 2, 6], [3, 7, 6]),  # y
        ([0, 1, 5], [0, 4, 5]),  # z
    )
    triangles = np.array([triangle for face in faces for triangle in face])

    return vertices, triangles

def projectModel(q, r, plot=False):
    """
    Project the vertices of the satellite model to the image frame and flag the visible ones.

    # 1) Determine the 8 vertex points of the model
    # 2) Determine the 12 surface triangles of the model
    # 3) Determine the 8 rays between the camera origin and each 3d vertex
    # 4) Check if a ray hits any surface triangle in front of its vertex. If so, the vertex is occluded

    q: quaternion, handed to quat2dcm
    r: translation (3 values), used as last column of the pose matrix
    plot: when True, additionally draw the triangles, hit points and visible rays in a 3D figure

    Returns (x, y, visible_points): image coordinates of all vertices and a boolean
    array that is False for every occluded vertex.
    """
    model_coordinates, cube_polygon_indices = getSatelliteModel()
    p_axes = np.ones((model_coordinates.shape[0], model_coordinates.shape[1] + 1))
    p_axes[:,:-1] = model_coordinates
    points_body = np.transpose(p_axes)

    # transformation to camera frame
    pose_mat = np.hstack((np.transpose(quat2dcm(q)), np.expand_dims(r, 1)))
    p_cam = np.dot(pose_mat, points_body)

    ax = None
    if plot:
        # Poly3DCollection lives in the art3d module; it is not reachable through the Axes3D class
        from mpl_toolkits.mplot3d.art3d import Poly3DCollection
        fig = plt.figure()
        # add_subplot attaches the 3D axes to the figure on old and new matplotlib versions alike
        ax = fig.add_subplot(111, projection='3d')

    points_camera_t = p_cam.transpose()
    occluded = set()
    # No ray towards a visible vertex may pass through any of the model triangles
    for polygon_indices in cube_polygon_indices:
        plane = Plane(points_camera_t[polygon_indices])

        if plot:
            tri = Poly3DCollection([plane.points], alpha=0.2)
            tri.set_color([1,0,0])
            tri.set_edgecolor('k')
            ax.add_collection3d(tri)

        for i, p in enumerate(points_camera_t):
            result = plane.intersects(p)
            if not isinstance(result, tuple):
                # Plane.intersects may answer with a bare False for a ray parallel to the plane
                continue
            intersection, intersects = result
            if not intersects:
                continue

            # The ray between camera origin and vertex passes through one of the 12 triangles.
            # There are two border cases to check:
            # 1) Sometimes the vertex itself lies on a neighboring triangle
            # 2) The ray hits a triangle that actually is behind the vertex
            dist_intersection = np.linalg.norm(intersection, 2)
            dist_point = np.linalg.norm(p, 2)
            if abs(dist_intersection - dist_point) > 0.01 and dist_intersection < dist_point and i not in occluded:
                occluded.add(i)
                if plot:
                    ax.scatter([intersection[0]], [intersection[1]], [intersection[2]])

    visible_points = np.ones(len(p_axes), dtype=bool)
    if occluded:
        visible_points[sorted(occluded)] = False

    if plot:
        for p in points_camera_t[visible_points]:
            ax.plot([0, p[0]], [0, p[1]], [0, p[2]])

        ax.autoscale()
        ax.set_xlabel('X axis')
        ax.set_ylabel('Y axis')
        ax.set_zlabel('Z axis')

        plt.show()

    # getting homogeneous coordinates
    points_camera_frame = p_cam / p_cam[2]
    # projection to image plane
    points_image_plane = Camera.K.dot(points_camera_frame)

    x, y = (points_image_plane[0], points_image_plane[1])
    return x, y, visible_points

def projectAxes(q, r):

    """ Project the body-frame origin and the tips of its three unit axes to the image frame.

    Returns (x, y): two arrays with four entries each, ordered origin, x tip, y tip, z tip.
    """

    # homogeneous reference points in satellite frame, one per column:
    # origin, x axis tip, y axis tip, z axis tip
    body_points = np.array([[0, 1, 0, 0],
                            [0, 0, 1, 0],
                            [0, 0, 0, 1],
                            [1, 1, 1, 1]])

    # 3x4 pose matrix: transposed DCM next to the translation column
    rotation = np.transpose(quat2dcm(q))
    translation = np.expand_dims(r, 1)
    camera_points = np.dot(np.hstack((rotation, translation)), body_points)

    # divide by depth, then apply the camera intrinsics
    normalised = camera_points / camera_points[2]
    pixels = Camera.K.dot(normalised)

    return pixels[0], pixels[1]


class SatellitePoseEstimationDataset:

    """ Class for dataset inspection: easily accessing single images, and corresponding ground truth pose data. """

    def __init__(self, root_dir='/datasets/speed_debug'):

        """ Read the JSON indexes below root_dir (see process_json_dataset). """

        self.partitions, self.labels = process_json_dataset(root_dir)
        self.root_dir = root_dir

    def get_image(self, i=0, split='train'):

        """ Loading the i-th image of the given split as RGB PIL image. """

        img_name = self.partitions[split][i]
        img_path = os.path.join(self.root_dir, 'images', split, img_name)
        return Image.open(img_path).convert('RGB')

    def get_pose(self, i=0):

        """ Getting pose label (q, r) for the i-th training image; only the train split is labeled. """

        img_id = self.partitions['train'][i]
        q, r = self.labels[img_id]['q'], self.labels[img_id]['r']
        return q, r

    def visualize(self, i, partition='train', ax=None):

        """ Visualizing the i-th image of a partition; for 'train' the ground truth pose axes are drawn on top.

        ax: matplotlib axes to draw on, defaults to the current axes.
        """

        if ax is None:
            ax = plt.gca()
        # load from the requested partition, not always from the training images
        img = self.get_image(i, partition)
        ax.imshow(img)

        # no pose label for test
        if partition == 'train':
            q, r = self.get_pose(i)
            xa, ya = projectAxes(q, r)
            # index 0 is the projected origin, 1..3 the tips of the x (red), y (green), z (blue) axes
            for tip, color in ((1, 'r'), (2, 'g'), (3, 'b')):
                ax.arrow(xa[0], ya[0], xa[tip] - xa[0], ya[tip] - ya[0], head_width=30, color=color)

        return


## Load dataset via API

In [None]:
# Root folder of the dataset; the JSON index files and the images/ folder are read from here.
dataset_root_dir = './speed'
# Reads the JSON index files right away (done in the constructor).
dataset = SatellitePoseEstimationDataset(root_dir=dataset_root_dir)

## Visualize a few images from the dataset

In [None]:
# Show the first five training images with the visible model vertices and the pose axes on top
for i in range(5):
    img = np.array(dataset.get_image(i))

    fig, ax = plt.subplots(1)
    ax.imshow(img)

    q, r = dataset.get_pose(i)
    xa, ya, visible = projectModel(q, r)
    for x, y, v in zip(xa, ya, visible):
        # mark only vertices that are visible and fall inside the camera frame
        in_frame = 0.0 <= x <= Camera.nu and 0.0 <= y <= Camera.nv
        if not (v and in_frame):
            continue
        marker = Rectangle((x, y), 3, 3, linewidth=5, edgecolor='r', facecolor='none')
        ax.add_patch(marker)

    dataset.visualize(i, ax=ax)

    plt.show()


## Configuration

In [None]:
# Settings shared by the network definition and the data generators
# (the generators receive them as keyword arguments).
config = dict(
    dim=(600, 960),        # size the images are loaded at and network input size (without channels)
    batch_size=1,          # samples per generated batch
    label_size=1.5,        # sigma of the blobs drawn into the target heatmaps
    shuffle=True,          # reshuffle the sample order after every epoch
    output_scale=8,        # the heatmaps are dim / output_scale in size
    n_output_vertices=8,   # number of heatmap channels, one per model vertex
    stages=6,              # number of targets the generator emits per batch
    augmentation=False,    # switch for the image augmentation in the generator
)

## Neural Network definition
### Option 1: Define model for training from scratch

In [None]:
# Use pre-trained VGG19 convolutional layer as first stage:
# ImageNet weights, no classifier head (include_top=False), and an input of
# config['dim'] with three channels appended.
pretrained_model = tf.keras.applications.vgg19.VGG19(
    weights="imagenet",
    include_top=False,
    input_shape=config['dim'] + (3,)
)

# Network architecture according to https://arxiv.org/abs/1602.00134
def create_stage(inp, filters):
    """ Build one stage (stage definition according to the CPM paper).

    inp: input tensor of the stage
    filters: number of channels of the stage output
    Returns the output tensor of the final 1x1 convolution (no activation applied).
    """
    def conv_bn_relu(tensor, kernel_size):
        # 64-filter convolution followed by batch normalisation and ReLU
        tensor = Conv2D(filters=64, kernel_size=kernel_size, strides=1, padding="same")(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation("relu")(tensor)

    features = inp
    # five 7x7 blocks, then a single 1x1 block
    for kernel_size in (7, 7, 7, 7, 7, 1):
        features = conv_bn_relu(features, kernel_size)

    return Conv2D(filters=filters, kernel_size=1, strides=1, padding="same")(features)
    
# Adding new trainable hidden and output layers to the model

inp = pretrained_model.input
# Average-pooled copy of the input image (stride 8); it is concatenated into every later stage
inp_avg = AveragePooling2D(pool_size=9, strides=8, padding="same")(inp)

# Features of the pre-trained network, taken at layer index 13
c = pretrained_model.layers[13].output
c = Conv2D(filters=128, kernel_size=3, strides=1, padding="same")(c)
c = BatchNormalization()(c)
c = Activation("relu")(c)

# Shared feature map that is fed to the first stage and to all following stages
pre = Conv2D(filters=128, kernel_size=3, strides=1, padding="same")(c)
pre = BatchNormalization()(pre)
pre = Activation("relu")(pre)

# First stage: predicts the heatmaps from the shared features alone
s1 = Conv2D(filters=128, kernel_size=1, strides=1, padding="same")(pre)
stages = [Conv2D(filters=config['n_output_vertices'], kernel_size=1, strides=1, padding="same")(s1)]

# The first stage plus (config['stages'] - 1) further stages; taking the count from the
# configuration keeps the number of model outputs equal to the number of targets
# produced by the data generator. Each stage sees the shared features, the previous
# heatmaps and the pooled image.
for i in range(config['stages'] - 1):
    stages.append(create_stage(concatenate([pre, stages[-1], inp_avg], axis=3), config['n_output_vertices']))

# One output (and therefore one MSE loss term) per stage
model = tf.keras.models.Model(inputs=inp, outputs=stages)
model.compile(loss="mse", optimizer=Adam(lr=1e-3))

model.summary()


### Option 2: Load trained model for inference

In [None]:
# Load a saved network from "model.h5"; this rebinds `model`, so it replaces the
# network built in Option 1 if that cell was run before.
model = load_model("model.h5")
model.summary()

## Data generator definition and declaration

In [None]:
class KerasDataGenerator(Sequence):

    """ DataGenerator for Keras to be used with fit_generator (https://keras.io/models/sequential/#fit_generator)

    Every batch is a tuple (imgs, targets): imgs holds the preprocessed images,
    targets is a list with one heatmap array per stage. A heatmap array has one
    channel per model vertex, with a Gaussian blob (peak value 255) at the projected
    position of every visible vertex.
    """

    def __init__(self,
                 preprocessor,
                 label_list,
                 speed_root,
                 label_size,
                 batch_size,
                 dim,
                 shuffle=True,
                 output_scale=8,
                 n_output_vertices=8,
                 stages=6,
                 augmentation=True):

        """
        preprocessor: callable applied to every loaded image array
        label_list: annotations with the keys 'filename', 'q_vbs2tango' and 'r_Vo2To_vbs_true'
        speed_root: folder containing the image files named in label_list
        label_size: sigma of the blobs drawn into the heatmaps
        batch_size: number of samples per batch
        dim: size the images are loaded at
        shuffle: reshuffle the sample order after every epoch
        output_scale: the heatmaps are dim / output_scale in size
        n_output_vertices: number of heatmap channels
        stages: number of (identical) targets returned per batch
        augmentation: apply random image augmentation
        """

        # loading dataset
        self.image_root = speed_root

        # Initialization
        self.preprocessor = preprocessor
        self.dim = dim
        self.batch_size = batch_size
        self.labels = {label['filename']: {'q': label['q_vbs2tango'], 'r': label['r_Vo2To_vbs_true']} for label in label_list}
        self.list_IDs = [label['filename'] for label in label_list]
        self.shuffle = shuffle
        self.label_size = label_size
        self.indexes = None
        self.output_scale = output_scale
        self.n_output_vertices = n_output_vertices
        self.stages = stages
        self.augmentation = augmentation
        self.on_epoch_end()

    def __len__(self):

        """ Denotes the number of batches per epoch (an incomplete last batch is dropped). """

        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, index):

        """ Generate one batch of data """

        # Generate indexes of the batch
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

        # Find list of IDs
        list_IDs_temp = [self.list_IDs[k] for k in indexes]

        # Generate data
        return self.__data_generation(list_IDs_temp)

    def on_epoch_end(self):

        """ Updates indexes after each epoch """

        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def drawBlob(self, img, pos, sigma=3):

        """ Draw a Gaussian blob around pos = (row, column) into the 2D array img (in place).

        Nothing is drawn when the blob window of 3 * sigma around pos does not fit
        completely into img.
        """

        # https://github.com/NVlabs/Deep_Object_Pose/blob/master/src/training/train.py#L851
        w = int(sigma*3)
        if pos[0]-w>=0 and pos[0]+w<img.shape[0] and pos[1]-w>=0 and pos[1]+w<img.shape[1]:
            for i in range(int(pos[0])-w, int(pos[0])+w):
                for j in range(int(pos[1])-w, int(pos[1])+w):
                    img[i,j] = np.exp(-(((i - pos[0])**2 + (j - pos[1])**2)/(2*(sigma**2))))

    def _build_augmenter(self):

        """ Create the random augmentation pipeline; made deterministic so one draw is shared by a whole batch. """

        return iaa.SomeOf((0, 3),
            [
                iaa.OneOf([
                    iaa.GaussianBlur((0, 3.0)), # blur images with a sigma between 0 and 3.0
                    iaa.AverageBlur(k=(2, 7)), # blur image using local means with kernel sizes between 2 and 7
                ]),
                iaa.Sharpen(alpha=(0, 1.0), lightness=(0.75, 1.5)), # sharpen images
                # search either for all edges or for directed edges,
                # blend the result with the original image using a blobby mask
                iaa.SimplexNoiseAlpha(iaa.OneOf([
                    iaa.EdgeDetect(alpha=(0.5, 1.0)),
                    iaa.DirectedEdgeDetect(alpha=(0.5, 1.0), direction=(0.0, 1.0)),
                ])),
                iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05*255)), # add gaussian noise to images
                iaa.OneOf([
                    iaa.Dropout((0.01, 0.1)), # randomly remove up to 10% of the pixels
                    iaa.CoarseDropout((0.03, 0.15), size_percent=(0.02, 0.05)),
                ]),
                iaa.Add((-10, 10)), # change brightness of images (by -10 to 10 of original value)
                iaa.OneOf([
                    iaa.Multiply((0.5, 1.5)),
                    iaa.FrequencyNoiseAlpha(
                        exponent=(-4, 0),
                        first=iaa.Multiply((0.5, 1.5)),
                        second=iaa.ContrastNormalization((0.5, 2.0))
                    )
                ]),
                iaa.ContrastNormalization((0.5, 2.0)), # improve or worsen the contrast
            ],
            random_order=True
        ).to_deterministic()

    def __data_generation(self, list_IDs_temp):

        """ Generates data containing batch_size samples """

        # Initialization
        imgs = np.empty((self.batch_size, *self.dim, 3))
        # np.float64 instead of the np.float alias, which no longer exists in current NumPy
        masks = np.zeros(
            (self.batch_size,
             int(self.dim[0] / self.output_scale),
             int(self.dim[1] / self.output_scale),
             self.n_output_vertices
            ), dtype=np.float64)

        # The augmentation pipeline is only needed (and only built) when augmentation is enabled
        seq = self._build_augmenter() if self.augmentation else None

        # Generate data
        for i, ID in enumerate(list_IDs_temp):
            img_path = os.path.join(self.image_root, ID)
            img = keras_image.load_img(img_path, target_size=self.dim) #, color_mode = "grayscale")
            img = keras_image.img_to_array(img)
            if seq is not None:
                img = seq.augment_image(img)
            imgs[i] = self.preprocessor(img)

            q, r = self.labels[ID]['q'], self.labels[ID]['r']
            xa, ya, visibles = projectModel(q, r)
            for j, (x, y, visible) in enumerate(zip(xa, ya, visibles)):
                if j >= self.n_output_vertices:
                    break

                # Camera pixel -> fraction of the frame, shrunk by output_scale so that the
                # multiplication with dim below lands directly on heatmap coordinates
                x /= (Camera.nu * self.output_scale)
                y /= (Camera.nv * self.output_scale)
                if visible and x >= 0.0 and y >= 0.0 and x <= 1.0 and y <= 1.0:
                    x_s, y_s = int(x * self.dim[1]), int(y * self.dim[0])
                    # drawBlob expects (row, column)
                    self.drawBlob(masks[i][...,j], (y_s, x_s), self.label_size)

        # every stage is trained against the same heatmaps, scaled to a peak value of 255
        return imgs, [masks*255 for _ in range(self.stages)]

In [None]:
# Load labels for training
with open(os.path.join(dataset_root_dir, 'train.json'), 'r') as f:
    label_list = json.load(f)

# Split training validation 80/20 (first 80 % of the list for training, the rest for validation)
split_index = int(len(label_list)*.8)
train_labels, validation_labels = label_list[:split_index], label_list[split_index:]

# Data generators for training and validation; both read from the train image folder
train_image_dir = dataset_root_dir + "/images/train"
training_generator = KerasDataGenerator(preprocess_input, train_labels, train_image_dir, **config)
validation_generator = KerasDataGenerator(preprocess_input, validation_labels, train_image_dir, **config)

A data generator can also be created for the real data, but this is currently not used.

In [None]:
# Labels of the real images
with open(os.path.join(dataset_root_dir, 'real.json'), 'r') as f:
    real_label_list = json.load(f)

# The image folder is derived from dataset_root_dir, like for the training generators,
# so that changing the dataset root in one place is enough.
real_generator = KerasDataGenerator(preprocess_input, real_label_list, dataset_root_dir + "/images/real", **config)

Verify the data generator by plotting the scaled image next to the corresponding keypoint heatmap (the n per-vertex output heatmaps are merged into a single heatmap).

In [None]:
for imgs, masks in training_generator:
    print(imgs.shape, len(masks), masks[0].shape)
    # all stages share the same target, so the first one is enough for the check
    for img, mask in zip(imgs, masks[0]):
        plt.figure(figsize=(20, 20))

        # left: first channel of the preprocessed input image
        plt.subplot(121)
        plt.imshow(img[...,0], cmap='gray')

        # right: all vertex heatmaps summed into one image
        # (np.float64 replaces the np.float alias, which no longer exists in current NumPy)
        plt.subplot(122)
        m = np.zeros(mask.shape[:2], dtype=np.float64)
        for i in range(config['n_output_vertices']):
            m += mask[...,i]
        plt.imshow(m, cmap='gray')
        plt.show()


## Training

In [None]:
# Keep only the best model (lowest val_loss) on disk
checkpoint = ModelCheckpoint("model_.h5", monitor="val_loss", save_best_only=True, verbose=1)
# NOTE(review): `reduce` is created but not part of the callbacks list below, so the
# learning rate is never reduced on a plateau - confirm whether that is intended.
reduce = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3)
# Stop once val_loss has not improved for 6 epochs
earlyStopping = EarlyStopping(monitor="val_loss", patience=6, verbose=1)
# One TensorBoard log folder per run, named after the start time
tensorboard = TensorBoard(log_dir='logs/{}'.format(time()))

# The epoch count is only an upper bound; early stopping ends the training
history = model.fit_generator(
    generator=training_generator,
    validation_data=validation_generator,
    epochs=10000,
    callbacks=[earlyStopping, checkpoint, tensorboard],
    workers=8,
    use_multiprocessing=False
)

In [None]:
def points_to_pose(points_3d, points_2d, camera_matrix=Camera.K):
    """ Estimate the pose from 3D model points and their 2D image positions (RANSAC PnP).

    points_3d: model points, one row per correspondence
    points_2d: matching image points, one row per correspondence
    camera_matrix: 3x3 intrinsic matrix

    Returns (quaternion, location, projected_points): the rotation as quaternion in
    w, x, y, z order, the translation as list of three values and the model points
    re-projected with the solution found by OpenCV. Returns ([], [], []) when no
    solution was found.
    """
    def rvec_to_pyrr_quaternion(rvec):
        '''Convert rvec (which is log quaternion) to a pyrr quaternion (order is XYZW)'''
        # solvePnP delivers a (3, 1) array; flatten it so that the quaternion ends up with 4 plain entries
        rvec = np.asarray(rvec, dtype=float).ravel()
        theta = np.linalg.norm(rvec)  # in radians
        if theta < 1e-12:
            # no rotation at all: the axis is undefined (division by zero), any axis with angle 0 will do
            return Quaternion.from_axis_rotation(np.array([1.0, 0.0, 0.0]), 0.0)
        # https://pyrr.readthedocs.io/en/latest/oo_api_quaternion.html
        return Quaternion.from_axis_rotation(rvec / theta, theta)

    (success, rotation_vector, translation_vector, _inliers) = cv2.solvePnPRansac(
        points_3d, points_2d, camera_matrix, None, iterationsCount=250, reprojectionError=50)
    if not success:
        return [], [], []

    location = list(translation_vector[...,0])
    # Kept in pyrr's XYZW order until the very end, so that the quaternion product below
    # combines two quaternions of the same convention
    quaternion = rvec_to_pyrr_quaternion(rotation_vector)

    projected_points, _ = cv2.projectPoints(points_3d, rotation_vector, translation_vector, camera_matrix, None)
    projected_points = np.squeeze(projected_points)

    # If the location.Z is negative or object is behind the camera then flip both location and rotation
    x, y, z = location
    if z < 0:
        print("neg")
        # Get the opposite location
        location = [-x, -y, -z]

        # Change the rotation by 180 degree
        rotate_angle = np.pi
        rotate_quaternion = Quaternion.from_axis_rotation(np.array(location, dtype=float), rotate_angle)
        quaternion = rotate_quaternion.cross(quaternion)

    # change order from xyzw to wxyz
    return np.roll(np.asarray(quaternion), 1), location, projected_points

def extract_maxima(belief, area=5):
    """ Pick up to three peaks from a 2D belief map, strongest first.

    After a peak is accepted, the window [peak - area, peak + area) around it is
    cleared (on a copy, belief itself is not modified) before the next search.
    The search stops at the first peak that is below 30.0 or below half of the
    previously accepted peak.

    Returns an array with one (row, column) entry per accepted peak; it is empty
    when no peak qualifies.
    """
    remaining = belief.copy()
    n_rows, n_cols = remaining.shape[0], remaining.shape[1]

    peaks = []
    last_value = None
    for _ in range(3):
        peak = np.unravel_index(remaining.argmax(), remaining.shape)
        row, col = peak[0], peak[1]
        value = remaining[row][col]

        too_weak = value < 30.0
        much_weaker = last_value is not None and value < last_value * 0.5
        if much_weaker or too_weak:
            break

        last_value = value
        # clear the neighbourhood so the next search finds a different peak
        row_window = slice(max(0, row - area), min(n_rows, row + area))
        col_window = slice(max(0, col - area), min(n_cols, col + area))
        remaining[row_window, col_window] = 0
        peaks.append(peak)

    return np.asarray(peaks)

def predict_pose(model, img, mask=None, debug=True):

    """ Predict the pose of the satellite shown in a single preprocessed image.

    model: network returning one heatmap array per stage; the last stage is used
    img: preprocessed input image
    mask: optional target heatmaps of the image, only shown in the debug plots
    debug: plot the heatmaps and the re-projected solution and print the correspondences

    Returns (quaternion, translation). When fewer than four keypoints are found, both
    are all-zero arrays; when the PnP solver finds no solution, both are empty lists.
    """

    model_points_all, _ = getSatelliteModel()
    n_rows = len(model_points_all)

    # Heatmap pixel -> camera pixel: 8 for the heatmap downscaling, 2 because the
    # network input is assumed to be half the camera resolution
    heatmap_to_camera = 8 * 2
    # Camera pixel -> pixel of the (half resolution) input image used for the debug overlay
    camera_to_display = 2

    pred = model.predict(np.asarray([img]))

    if debug:
        plt.figure(figsize=(20, 40))
        # rough undo of the preprocessing for display; clipped to avoid uint8 wrap-around
        preview = np.clip(img + 128, 0, 255).astype(np.uint8)

    points_3d = []
    points_2d = []
    for i, p3d in enumerate(model_points_all):
        # heatmap of vertex i from the last stage
        p = pred[-1][0][...,i]
        if debug:
            plt.subplot(n_rows, 3, i * 3 + 1)
            plt.imshow(preview, cmap='gray')
            ax = plt.subplot(n_rows, 3, i * 3 + 2)
            ax.imshow(p)
        maxima = extract_maxima(p)
        if debug and mask is not None:
            axmask = plt.subplot(n_rows, 3, i * 3 + 3)
            axmask.imshow(mask[...,i], vmin=0, vmax=255)
        for m in maxima:
            # every peak becomes a candidate correspondence; m is (row, column)
            points_3d.append(p3d)
            points_2d.append([m[1] * heatmap_to_camera, m[0] * heatmap_to_camera])
            if debug:
                ax.add_patch(Rectangle((m[1], m[0]),1,1,linewidth=5,edgecolor='r',facecolor='none'))
                if mask is not None:
                    axmask.add_patch(Rectangle((m[1], m[0]),1,1,linewidth=5,edgecolor='r',facecolor='none'))
    points_3d = np.array(points_3d, dtype=np.float32)
    points_2d = np.array(points_2d, dtype=np.float32)
    if len(points_2d) < 4:
        # PnP needs at least four correspondences
        if debug:
            plt.tight_layout()
            plt.show()
        return np.array([0]*4), np.array([0]*3)

    quat_res, trans_res, points = points_to_pose(points_3d, points_2d, Camera.K)
    if debug:
        plt.tight_layout()
        plt.show()

        print("orig", points_2d)
        print("model", points_3d)
        print("proj", points)

        aximgproj = plt.subplot(111)
        aximgproj.imshow(preview, cmap='gray')
        for p in points:
            aximgproj.add_patch(Rectangle((p[0] / camera_to_display, p[1] / camera_to_display),
                                          3, 3, linewidth=2, edgecolor='r', facecolor='none'))

        # Draw the axes of the pose that was just estimated (if there is one), scaled
        # to the displayed image like the re-projected points above
        if len(quat_res) > 0 and len(trans_res) > 0:
            xa, ya = projectAxes(np.asarray(quat_res), np.asarray(trans_res))
            xa, ya = xa / camera_to_display, ya / camera_to_display
            for tip, color in ((1, 'r'), (2, 'g'), (3, 'b')):
                aximgproj.arrow(xa[0], ya[0], xa[tip] - xa[0], ya[tip] - ya[0], head_width=30, color=color)

        plt.show()
    return quat_res, trans_res


## Perform an orientation determination on the validation data set

In [None]:
for imgs, masks in validation_generator:
    # only the target of the first stage is handed over for the debug plots
    for img, mask in zip(imgs, masks[0]):
        try:
            quat_res, trans_res = predict_pose(model, img, mask, True)
        except cv2.error:
            # skip samples for which OpenCV raises an error
            continue
        else:
            print("res", quat_res, trans_res)
            print("=" * 30)


## Perform orientation determination and save results

In [None]:
def evaluate(model, dataset, append_submission, dataset_root):

    """ Running evaluation on test set, appending results to a submission.

    model: network handed to predict_pose
    dataset: name of the split; '<dataset>.json' and 'images/<dataset>' below dataset_root are read
    append_submission: callable(filename, quaternion, translation) collecting the results
    dataset_root: root folder of the dataset

    Images without a usable pose get a zero quaternion and a fixed fallback translation.
    """

    with open(os.path.join(dataset_root, dataset + '.json'), 'r') as f:
        image_list = json.load(f)

    print('Running evaluation on {} set...'.format(dataset))

    # fallback counters: err1 no PnP solution, err2 implausibly far away,
    # err3 OpenCV error, err4 all-zero result (too few keypoints)
    err1 = 0
    err2 = 0
    err3 = 0
    err4 = 0
    for i, img in enumerate(image_list):
        print("index", i)
        img_path = os.path.join(dataset_root, 'images', dataset, img['filename'])

        img_raw = keras_image.load_img(img_path, target_size=(600, 960)) #, color_mode = "grayscale")
        img_raw = keras_image.img_to_array(img_raw)
        img_proc = preprocess_input(img_raw)

        try:
            quat_res, trans_res = predict_pose(model, img_proc, debug=True)
        except cv2.error:
            append_submission(img['filename'], [0]*4, [0,0,10])
            err3 += 1
            print("="*30)
            continue

        print(quat_res, trans_res)

        # Classify the result before anything is drawn: projecting the axes of an empty or
        # all-zero pose is meaningless and must not abort the whole evaluation run
        no_solution = len(quat_res) == 0 or len(trans_res) == 0
        all_zero = not no_solution and (np.array(trans_res) == 0).all()

        img_orig = keras_image.load_img(img_path)
        plt.imshow(img_orig)
        if not no_solution and not all_zero:
            xa, ya = projectAxes(np.array(quat_res), np.array(trans_res))
            for tip, color in ((1, 'r'), (2, 'g'), (3, 'b')):
                plt.arrow(xa[0], ya[0], xa[tip] - xa[0], ya[tip] - ya[0], head_width=30, color=color)
        plt.show()

        if no_solution:
            append_submission(img['filename'], [0]*4, [0,0,3])
            err1 += 1
        elif all_zero:
            append_submission(img['filename'], [0]*4, [0,0,10])
            err4 += 1
        elif trans_res[2] > 60:
            append_submission(img['filename'], [0]*4, [0,0,40])
            err2 += 1
        else:
            append_submission(img['filename'], quat_res, trans_res)
        print("="*30)

    print("Err amount", err1, err2, err3, err4)
            
# Evaluate both unlabeled splits into one submission and export it
submission = SubmissionWriter()
for split_name, append_result in (('test', submission.append_test),
                                  ('real_test', submission.append_real_test)):
    evaluate(model, split_name, append_result, dataset_root_dir)
submission.export(suffix='submission')
