In [1]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch

In [2]:
MAX_DEPTH = 10.

In [12]:
class ProblemInstance:
    
    @staticmethod
    def rand_rot_mtx():
        roll = torch.randint(high=90, size=(1,))
        yaw = torch.randint(high=90, size=(1,))
        pitch = torch.randint(high=90, size=(1,))

        tensor_0 = torch.zeros(1)
        tensor_1 = torch.ones(1)

        RX = torch.stack([
                    torch.stack([tensor_1, tensor_0, tensor_0]),
                    torch.stack([tensor_0, torch.cos(roll), -torch.sin(roll)]),
                    torch.stack([tensor_0, torch.sin(roll), torch.cos(roll)])]).reshape(3,3)

        RX = torch.eye(3)

        print("Pitch is: ", pitch)
        RY = torch.stack([
                        torch.stack([torch.cos(pitch), tensor_0, torch.sin(pitch)]),
                        torch.stack([tensor_0, tensor_1, tensor_0]),
                        torch.stack([-torch.sin(pitch), tensor_0, torch.cos(pitch)])]).reshape(3,3)

        RZ = torch.stack([
                        torch.stack([torch.cos(yaw), -torch.sin(yaw), tensor_0]),
                        torch.stack([torch.sin(yaw), torch.cos(yaw), tensor_0]),
                        torch.stack([tensor_0, tensor_0, tensor_1])]).reshape(3,3)

        RZ = torch.eye(3)

        R = torch.mm(RZ, RY)
        R = torch.mm(R, RX)

        return R
    
    @staticmethod
    def generate_example():
        R = ProblemInstance.rand_rot_mtx()
        t = torch.rand(3, )

        x1, x2 = torch.rand(3,), torch.rand(3, )
        x1[2], x2[2] = 1., 1.

        rdepth = torch.rand(1) * MAX_DEPTH

        X1, X2 = R.T @ (x1 * rdepth - t), R.T @ (x2 * rdepth - t)

        return x1, x2, X1, X2, R, t

In [13]:
ProblemInstance.generate_example()

Pitch is:  tensor([73])


(tensor([0.8180, 0.9137, 1.0000]),
 tensor([0.7823, 0.1215, 1.0000]),
 tensor([ 0.4048,  6.8728, -9.1699]),
 tensor([ 0.6104,  0.6750, -8.9810]),
 tensor([[-0.7362,  0.0000, -0.6768],
         [ 0.0000,  1.0000,  0.0000],
         [ 0.6768,  0.0000, -0.7362]]),
 tensor([0.4916, 0.2753, 0.7986]))

In [131]:
class Up2P:
    
    def __init__(self):
        self.dtype = torch.float64
        # TODO: add angles to potentially optimize for
        # or use them in a solver itself so not to prerotate the scene
    
    # x: [2,3]
    # X: [2,3]
    def __call__(self, x, X):
        assert x.shape == (2, 3)
        assert X.shape == (2, 3)
        # [4, 4]
        # should be transposed as in Eigen order of indexation is a different one
        A = torch.tensor([[-x[0, 2], 0, x[0, 0], X[0, 0] * x[0, 2] - X[0, 2] * x[0, 0]],
                          [0, -x[0, 2], x[0, 1], -X[0, 1] * x[0, 2] - X[0, 2] * x[0, 1]],
                          [-x[1, 2], 0, x[1, 0], X[1, 0] * x[1, 2] - X[1, 2] * x[1, 0]],
                          [0, -x[1, 2], x[1, 1], -X[1, 1] * x[1, 2] - X[1, 2] * x[1, 1]]],
                        dtype=self.dtype)
        # [4, 2]                  
        b = torch.cat([torch.tensor([
                                -2 * X[0, 0] * x[0, 0] - 2 * X[0, 2] * x[0, 2],
                                X[0, 2] * x[0, 0] - X[0, 0] * x[0, 2],
                                -2 * X[0, 0] * x[0, 1],
                                X[0, 2] * x[0, 1] - X[0, 1] * x[0, 2]
                           ], dtype=self.dtype),
                       torch.tensor([
                               -2 * X[1, 0] * x[1, 0] - 2 * X[1, 2] * x[1, 2],
                               X[1, 2] * x[1, 0] - X[1, 0] * x[1, 2],
                               -2 * X[1, 0] * x[1, 1],
                               X[1, 2] * x[1, 1] - X[1, 1] * x[1, 2]
                           ], dtype=self.dtype)],
                      dim=-1).reshape((4, 2))
       
        assert A.shape == (4, 4) and b.shape == (4, 2)
        
        b = A.inverse() @ b
        sols = self.solve_quadratic_real(1., b[3, 0], b[3, 1])
        if sols is None:
            return []
        
        res = []
        for q in sols:
            q2 = q ** 2
            inv_norm = 1 / (1 + q2)
            cq = (1 - q2) * inv_norm
            sq = 2 * q * inv_norm
            
            R = torch.eye(3)
            R[0, 0] = cq
            R[0, 2] = sq
            R[2, 0] = -sq
            R[2, 2] = cq

            t = b[:3, 0] * q + b[:3, 1]
            t *= -inv_norm
            res.append((R, t))
        
        return res
        
    def solve_quadratic_real(self, a, b, c):
        b2m4ac = b * b - 4 * a * c
        if b2m4ac < 0:
            return None
        
        sq = torch.sqrt(b2m4ac)
        roots = []
        if b > 0:
            roots.append((2*c) / (-b - sq))
        else:
            roots.append((2*c) / (-b + sq))
            
        roots.append(c / (a * roots[0]))
        
        return roots
    
    # [TODO] measure some more meaningful metrics
    @staticmethod
    def validate_sol(R, t, Rgt, tgt):
        return (R - Rgt).norm(), (t - tgt).norm()
        
        
solver = Up2P()
# solver(None, None)

In [132]:
solver(torch.stack([x1, x2]), torch.stack([X1, X2]))

[(tensor([[-0.6669,  0.0000,  0.7451],
          [ 0.0000,  1.0000,  0.0000],
          [-0.7451,  0.0000, -0.6669]]),
  tensor([0.1421, 0.6658, 0.8468], dtype=torch.float64)),
 (tensor([[-0.9117,  0.0000, -0.4108],
          [ 0.0000,  1.0000,  0.0000],
          [ 0.4108,  0.0000, -0.9117]]),
  tensor([-2.4490,  0.1652,  2.4131], dtype=torch.float64))]

In [133]:
x1, x2, X1, X2, Rg, tg = ProblemInstance.generate_example()

Pitch is:  tensor([69])


In [134]:
x1, x2, X1, X2, R, t

(tensor([0.5298, 0.1372, 1.0000]),
 tensor([0.1857, 0.0610, 1.0000]),
 tensor([3.5739, 0.7794, 5.9480]),
 tensor([1.3741, 0.2890, 6.2022]),
 tensor([[-0.9117,  0.0000, -0.4108],
         [ 0.0000,  1.0000,  0.0000],
         [ 0.4108,  0.0000, -0.9117]]),
 tensor([-2.4490,  0.1652,  2.4131], dtype=torch.float64))

In [135]:
print(Rg, tg)
for R, t in solver(torch.stack([x1, x2]), torch.stack([X1, X2])):
    re, te = Up2P.validate_sol(R, t, Rg, tg)
    print(f"--------- R error: {re.float()}, t error: {te.float()} ----------")
    print(R,  t)
    print()

tensor([[ 0.9934,  0.0000, -0.1148],
        [ 0.0000,  1.0000,  0.0000],
        [ 0.1148,  0.0000,  0.9934]]) tensor([0.5413, 0.1033, 0.1147])
--------- R error: 1.589472731211572e-06, t error: 8.036430699576158e-06 ----------
tensor([[ 0.9934,  0.0000, -0.1148],
        [ 0.0000,  1.0000,  0.0000],
        [ 0.1148,  0.0000,  0.9934]]) tensor([0.5413, 0.1033, 0.1147], dtype=torch.float64)

--------- R error: 0.25262829661369324, t error: 1.2381080389022827 ----------
tensor([[ 0.9571,  0.0000, -0.2897],
        [ 0.0000,  1.0000,  0.0000],
        [ 0.2897,  0.0000,  0.9571]]) tensor([ 1.5444,  0.0601, -0.6098], dtype=torch.float64)

