# In [None]:
import numpy as np
from numpy import linalg as LA
from scipy.io import loadmat
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib as mpl
from matplotlib.pyplot import cm
import cv2
from tqdm import trange
import computer_vision as cv




def compute_reprojection_error_wrt_T(Pi, X, x):
    """Project X with camera Pi and compare against the observed points x.

    Args:
        Pi: camera matrix applied as ``Pi @ X``.
        X: points to project (presumably homogeneous, one point per column).
        x: observed image points; only rows 0 and 1 are used.

    Returns:
        (err, residual) where ``residual`` is the flattened array
        ``[u-residuals..., v-residuals...]`` (observed minus projected) and
        ``err`` is its squared Euclidean norm.
    """
    # NOTE(review): cv.dehomogenize is a project helper; assumed to divide by
    # the last homogeneous coordinate — confirm.
    x_hat = cv.dehomogenize(Pi @ X)

    du = x[0] - x_hat[0]
    dv = x[1] - x_hat[1]
    # Row-major flatten of a (2, N) array: every u residual comes before any v.
    residual = np.array([du, dv]).reshape(-1)

    err = LA.norm(residual) ** 2
    return err, residual

def compute_total_reprojection_error_wrt_T(P_arr, X_idx_arr, x_arr, inliers_arr, verbose=False, X_init=None):
    """Reprojection error of every camera over its inlier correspondences.

    The original version referenced an undefined ``X`` (it was read before
    being assigned, raising ``UnboundLocalError`` on every call); the 3D points
    are now supplied explicitly via ``X_init``.

    Args:
        P_arr: stack of camera matrices, indexed ``P_arr[i]``.
        X_idx_arr: per camera, the column indices into ``X_init`` it observes.
        x_arr: per camera, the observed image points (columns).
        inliers_arr: per camera, the inlier selector applied to the columns.
        verbose: print total / median / mean squared error when True.
        X_init: 3D points, one per column (presumably homogeneous); required.

    Returns:
        (reproj_err_tot, res_tot): list of per-camera squared errors, and all
        residual vectors concatenated into one 1-D array.

    Raises:
        ValueError: if ``X_init`` is not given.
    """
    if X_init is None:
        raise ValueError('X_init (3D points indexed by X_idx_arr) is required')

    n_cameras = P_arr.shape[0]
    reproj_err_tot = []
    res_tot = []

    for i in range(n_cameras):
        inliers = inliers_arr[i]
        # Select the points seen by camera i, then keep only its inliers.
        X = X_init[:, X_idx_arr[i]][:, inliers]
        x = x_arr[i][:, inliers]

        reproj_err, res = compute_reprojection_error_wrt_T(P_arr[i], X, x)
        reproj_err_tot.append(reproj_err)
        res_tot.append(res)

    # np.concatenate rejects an empty list, so handle zero cameras explicitly.
    res_tot = np.concatenate(res_tot, 0) if res_tot else np.empty(0)

    if verbose:
        print('\nTotal reprojection error:', round(np.sum(reproj_err_tot), 2))
        print('Median reprojection error:', round(np.median(reproj_err_tot), 2))
        print('Avg. reprojection error:', round(np.mean(reproj_err_tot), 2))

    return reproj_err_tot, res_tot

def compute_jacobian_of_residual_wrt_T(Pi, Xj):
    """Jacobian of one point's reprojection residual w.r.t. the translation T.

    With ``a, b, c = Pi[0]·Xj, Pi[1]·Xj, Pi[-1]·Xj`` the residual is
    ``r = x - (a/c, b/c)``. T is the last column of ``Pi``, so
    ``d(Pi[k]·Xj)/dT = w * e_k`` where ``w = Xj[-1]`` (the homogeneous
    weight). That gives::

        dr_u/dT = [-w/c,    0, a*w/c**2]
        dr_v/dT = [   0, -w/c, b*w/c**2]

    The previous implementation summed these partials into one scalar per
    row, which is not a Jacobian.

    Args:
        Pi: camera matrix whose last column is T (presumably [R | T], 3x4).
        Xj: a single homogeneous point (flattened to 1-D).

    Returns:
        ndarray of shape (2, 3): row 0 is dr_u/dT, row 1 is dr_v/dT.
    """
    Pi = np.asarray(Pi)
    Xj = np.asarray(Xj).reshape(-1)

    a = Pi[0] @ Xj
    b = Pi[1] @ Xj
    c = Pi[-1] @ Xj
    w = Xj[-1]

    return np.array([
        [-w / c, 0.0, a * w / c**2],
        [0.0, -w / c, b * w / c**2],
    ])

def linearize_reprojection_error_wrt_T(Pi, X, x):
    """Residual and Jacobian (w.r.t. T) of one camera's reprojection error.

    Args:
        Pi: camera matrix at the current estimate.
        X: points, one per column (``X[:, j]`` is point j).
        x: observed image points, one per column.

    Returns:
        (res, jac). ``res`` is ordered ``[u_1..u_N, v_1..v_N]`` (as produced by
        ``compute_reprojection_error_wrt_T``), and the rows of ``jac`` follow
        the same order, so ``jac`` has shape (2N, k).

    Raises:
        ValueError: if X has no columns (nothing to stack).
    """
    _, res = compute_reprojection_error_wrt_T(Pi, X, x)

    # One (2, k) Jacobian per point; points are COLUMNS of X, hence X[:, j].
    per_point = np.stack(
        [compute_jacobian_of_residual_wrt_T(Pi, X[:, j]) for j in range(X.shape[1])]
    )

    # res lists all u-residuals before all v-residuals, so stack the u-rows of
    # every point first, then the v-rows, instead of interleaving per point.
    jac = np.concatenate((per_point[:, 0, :], per_point[:, 1, :]), axis=0)

    return res, jac

def compute_update(res, jac, mu):
    """Damped (Levenberg–Marquardt) step for the linearized residual.

    Solves ``(jac.T @ jac + mu * I) delta = -jac.T @ res``.

    Args:
        res: residual vector of length m.
        jac: Jacobian of ``res`` w.r.t. the parameters, shape (m, n).
        mu: damping factor; larger values give shorter, gradient-like steps.

    Returns:
        The parameter update ``delta`` of length n.
    """
    jac = np.asarray(jac)
    n_params = jac.shape[1]
    lhs = jac.T @ jac + mu * np.eye(n_params)
    rhs = jac.T @ res
    # Solve the normal equations directly; forming an explicit inverse is
    # slower and less numerically stable.
    return -LA.solve(lhs, rhs)

def optimize_T(P_arr, X_init, X_idx_arr, x_arr, inliers_arr, mu_init, n_its, verbose=False):
    """Refine each camera's translation with Levenberg–Marquardt (rotation fixed).

    For every camera i the last column of ``P_arr[i]`` (T) is optimized to
    reduce the squared reprojection error of its inlier points. The rest of
    the camera matrix (``P_arr[i][:, :-1]``) is kept constant.

    Args:
        P_arr: stack of camera matrices; MODIFIED IN PLACE (last column).
        X_init: 3D points, one per column (presumably homogeneous).
        X_idx_arr: per camera, column indices into ``X_init`` it observes.
        x_arr: per camera, observed image points (columns).
        inliers_arr: per camera, inlier selector applied to the columns.
        mu_init: initial damping factor.
        n_its: maximum number of LM iterations per camera.
        verbose: print iteration statistics when True.

    Returns:
        ``P_arr`` (the same object, updated).
    """
    n_cameras = P_arr.shape[0]
    steps = []

    for i in trange(n_cameras):
        Ri = P_arr[i][:, :-1]
        Ti = P_arr[i][:, -1].copy()
        Pi = np.column_stack((Ri, Ti))

        inliers = inliers_arr[i]
        X = X_init[:, X_idx_arr[i]][:, inliers]
        x = x_arr[i][:, inliers]

        # A camera with no inlier points has nothing to optimize.
        if X.shape[1] == 0:
            steps.append(0)
            continue

        mu = mu_init
        step = 0
        reproj_err, _ = compute_reprojection_error_wrt_T(Pi, X, x)

        while step < n_its:
            step += 1

            # Linearize at the CURRENT estimate (Pi is rebuilt on each accepted step).
            res, jac = linearize_reprojection_error_wrt_T(Pi, X, x)
            delta_T = compute_update(res, jac, mu)
            Ti_opt = Ti + delta_T
            Pi_opt = np.column_stack((Ri, Ti_opt))
            reproj_err_opt, _ = compute_reprojection_error_wrt_T(Pi_opt, X, x)

            if np.isclose(reproj_err_opt, reproj_err):
                break
            if reproj_err_opt < reproj_err:
                # Accept: move the linearization point and the baseline error.
                Ti, Pi, reproj_err = Ti_opt, Pi_opt, reproj_err_opt
                mu /= 10
            else:
                # Reject: increase damping for a shorter step.
                mu *= 10

        P_arr[i, :, -1] = Ti
        steps.append(step)

    if verbose and steps:
        print('\nAvg its:', np.mean(steps))
        print('Max its:', np.max(steps))
        print('Min its:', np.min(steps))

    return P_arr






# import numpy as np
# from scipy.io import loadmat
# import scipy
# from mpl_toolkits import mplot3d
# import matplotlib.pyplot as plt
# import matplotlib.image as mpimg
# from matplotlib.pyplot import cm
# import matplotlib as mpl
# import cv2
# from tqdm import trange
# from scipy.spatial.transform import Rotation
# import sys



# from scipy.io import loadmat
# import matplotlib.image as mpimg
# import matplotlib as mpl
# import time
# from get_dataset_info import *
# import cv2
