In [151]:
import os
import numpy as np
from collections import defaultdict
from tqdm import tqdm
from utils import *

import torch.nn.functional as F

import math

from distribution import AngularCentralGaussian, cholesky_wrapper
from pyro.distributions import MultivariateStudentT
from torch.utils.data import Dataset, DataLoader

from dataset import H36M
import torch
import numpy as np
import cv2

In [191]:
def cal_mpjpe_batch(points3d, points2d, R, t):
    """
    Args:
        points3d : (...,J,3)
        points2d : (...,V,J,2)
        R : (...,V,3,3)
        t: (...,V,3,1)
    Returns:
        weights : (...)
    """
    return torch.exp(
        -(
            homo_to_eulid(
                (R[...,None,:,:] @ points3d[...,None,:,:,None] + t[...,None,:,:]
                ).squeeze(-1)
            ) - points2d 
        ).norm(dim=-1).mean((-1,-2))
    )


class ProbabilisticTriangulation():
    def __init__(self, n_batch, n_view):
        self.n_batch = n_batch
        self.n_view = n_view
        self.expect_quan = torch.zeros(self.n_batch,self.n_view-1,4)
        self.tril_R = torch.eye(4,4)[None,None].expand(self.n_batch, self.n_view-1,-1,-1)
        self.mu_t = torch.zeros(self.n_batch,self.n_view-1,3)
        self.tril_t = torch.eye(3,3)[None,None].expand(self.n_batch,self.n_view-1,-1,-1)
        #  conv_quan (B,V,4,4)
        self.distrR = AngularCentralGaussian(self.tril_R)
        #  mu_t (B,V,3) conv_t (B,V,3,3)
        self.distrT = MultivariateStudentT(loc=self.mu_t,scale_tril=self.tril_t,df=3)

    def sample(self, size : torch.Size()):
        self.quan = self.distrR(size)
        self.t = self.distrT(size)
        # print(self.quan.shape, self.t.shape)

        sample_R = torch.cat([torch.eye(3)[None,None,None].expand(size[0],self.n_batch,-1,-1,-1) ,quaternion_to_matrix(self.quan)], dim = -3)
        sample_t = torch.cat([torch.zeros(size[0],self.n_batch,1,3) ,self.t] , dim = -2).unsqueeze(-1)
        return sample_R, sample_t

    def update_paramater_init(self, R,t):
        """
        Args:
            R : (B,V+1,3,3) -> (B,V,3,3)
            t : (B,V+1,3,1) -> (B,V,3,1)
        Returns:
            sample_quan : (M,B,V,4)
            sample_t : (M,B,V,3)
            weights: (M,B)
            M = 16
        """
        self.sample((15,))
        sample_quan = torch.cat([ matrix_to_quaternion(R[:,1:])[None], self.quan ],dim=0)
        sample_t = torch.cat([ t[None,:,1:].squeeze(-1), self.t], dim=0)
        self.quan = sample_quan
        self.t = sample_t
        weights = torch.tensor([1]+[0.1 for i in range(15)])[...,None].expand(-1,self.n_batch)
        self.update_paramater_with_weights(weights)

    def update_paramater_with_weights(self, weights):
        """
        Args:
            self.quan : (M,B,V,4)
            self.t : (M,B,V,3)
            weights : (M,B)
        Returns:
            conv_quan : (B,V,4,4)
            mu_t : (B,V,3)
            conv_t : (B,V,3,3)
        """
        
        # (B,V,4,M) @ (B,V,M,4) -> (B,V,4,4)
        conv_quan = (
            self.quan.permute(1,2,3,0) @ (self.quan * weights[...,None,None]).permute(1,2,0,3)
        ) / weights.sum(0)[...,None,None]
        self.tril_quan = cholesky_wrapper(conv_quan)

        self.mu_t = self.t.mean(0)

        centered_t = self.t - self.mu_t[None]
        # (B,V,3,M) @ (B,V,M,3) -> (B,V,3,3)
        conv_t = (
            centered_t.permute(1,2,3,0) @ (centered_t * weights[...,None,None]).permute(1,2,0,3)
        ) / weights.sum(0)[...,None,None]
        self.tril_t = cholesky_wrapper(conv_t)

        self.expect_quan = (self.quan * weights[...,None,None]).sum(0) / weights.sum(0)[...,None,None]
        self.distrR = AngularCentralGaussian(self.tril_quan)
        self.distrT = MultivariateStudentT(loc=self.mu_t,scale_tril=self.tril_t,df = 3)

    def getRt(self):
        """
        Returns:
            R : (B,V,3,3)
            t : (B,V,3,1)
        """
        R = torch.cat( [ torch.eye(3)[None,None].expand(self.n_batch,1,3,3), quaternion_to_matrix(self.expect_quan) ] ,dim=-3)
        t = torch.cat([torch.zeros(self.n_batch,1,3) ,self.mu_t] , dim = -2).unsqueeze(-1)
        return R,t




In [192]:
a = ProbabilisticTriangulation(1,4)
# R,t = a.sample((16,))
# weights = cal_mpjpe_batch(pose_3d[None], pose_2d[None], R, t)
a.update_paramater_init(Rgt,tgt)
R,t = a.getRt()

In [194]:
R,Rgt

(tensor([[[[ 1.0000,  0.0000,  0.0000],
           [ 0.0000,  1.0000,  0.0000],
           [ 0.0000,  0.0000,  1.0000]],
 
          [[-0.4481, -0.0861, -0.8898],
           [-0.0860,  0.9949, -0.0530],
           [ 0.8898,  0.0528, -0.4532]],
 
          [[ 0.8750,  0.3302,  0.3541],
           [-0.1889,  0.9063, -0.3781],
           [-0.4458,  0.2640,  0.8553]],
 
          [[-0.9555,  0.2104,  0.2069],
           [ 0.2730,  0.8965,  0.3490],
           [-0.1120,  0.3900, -0.9140]]]]),
 tensor([[[[ 1.0000,  0.0000,  0.0000],
           [ 0.0000,  1.0000,  0.0000],
           [ 0.0000,  0.0000,  1.0000]],
 
          [[-0.6687,  0.1179, -0.7341],
           [-0.1866,  0.9291,  0.3192],
           [ 0.7197,  0.3505, -0.5993]],
 
          [[ 0.6768, -0.0671,  0.7331],
           [ 0.0947,  0.9955,  0.0037],
           [-0.7301,  0.0669,  0.6801]],
 
          [[-0.9931, -0.0928, -0.0716],
           [-0.1147,  0.8949,  0.4312],
           [ 0.0241,  0.4365, -0.8994]]]]))

In [6]:
class CalibrationBatch():
    def __init__(self, points2d, confi2d):
        """
        points2d : (B,V,J,2)
        confi2d : (B,V,J)
        points3d : (B,J,3)
        confi3d : (B,J)
        R : (B,V,3,3)
        t : (B,V,3,1)
        isdistribution : bool
        """
        self.n_batch,self.n_view,self.n_joint = points2d.shape[:3]
        self.points2d = points2d
        self.confi2d = confi2d
        self.points3d = torch.zeros((self.n_batch,self.n_joint,3))
        self.confi3d = torch.zeros((self.n_batch,self.n_joint))
        self.R = torch.zeros((self.n_batch,self.n_view,3,3))
        self.t = torch.zeros((self.n_batch,self.n_view,3,1))
        self.prob_tri = ProbabilisticTriangulation(self.n_batch, self.n_view)


    def weighted_triangulation(self, points2d, confi2d, R ,t):
        """
        Args:
            points2d : (V',J,2)
            confi2d : (V',J)
            R : (V',3,3)
            t : (V',3,1)
        Returns:
            points3d : (J,3)
            confi3d : (J)
        """
        n_view_filter= points2d.shape[0]
        points3d = torch.zeros((self.n_joint, 3))
        confi3d = torch.zeros((self.n_joint))
        # print(points2d.shape,confi2d.shape,R.shape,t.shape)
        for j in range(self.n_joint):
            A = []
            for i in range(n_view_filter):
                if confi2d[i,j] > 0.5:
                    P = torch.cat([R[i],t[i]],dim=1)
                    P3T = P[2]
                    A.append(confi2d[i,j] * (points2d[i,j,0]*P3T - P[0]))
                    A.append(confi2d[i,j] * (points2d[i,j,1]*P3T - P[1]))
            A = torch.stack(A)
            # print(A.shape)
            if A.shape[0] >= 4:
                u, s, vh = torch.linalg.svd(A)
                error = s[-1]
                X = vh[len(s) - 1]
                points3d[j,:] = X[:3] / X[3]
                confi3d[j] = np.exp(-torch.abs(error))
            else:
                points3d[:,j] = torch.tensor([0.0,0.0,0.0])
                confi3d[j] = 0

        return points3d, confi3d

    def weighted_triangulation_sample(self, points2d, confi2d, R ,t):
        """
        Args:
            points2d : (B,V',J,2)
            confi2d : (B,V',J)
            R : (M,B, V',3,3)
            t : (M,B, V',3,1)
        Returns:
            sample_points3d : (M,B,J,3)
            sample_confi3d : (M,B,J)
        """
        n_sample = R.shape[0]
        sample_points3d = torch.zeros((n_sample,self.n_batch,self.n_joint,3))
        sample_confi3d = torch.zeros((n_sample,self.n_batch,self.n_joint))
        for i in range(n_sample):
            for j in range(self.n_batch):
                sample_points3d[i,j], sample_confi3d = self.weighted_triangulation(
                    points2d[j], confi2d[j], R[i,j], t[i,j]
                )
        return sample_points3d, sample_confi3d

    def pnp(self,batch_id):
        for i in range(self.n_view):
            mask = torch.logical_and(self.confi2d[batch_id,i]>0.8,self.confi3d[batch_id]>0.8)
            p2d = self.points2d[batch_id,i,mask].numpy()
            p3d = self.points3d[batch_id,mask].numpy()
            ret, rvec, tvec = cv2.solvePnP(p3d, p2d, np.eye(3), np.zeros(5))
            R, _ = cv2.Rodrigues(rvec)
            self.R[batch_id,i] = torch.tensor(R)
            self.t[batch_id,i] = torch.tensor(tvec)


    def eight_point(self):
        for batch_id in range(self.n_batch):
            mask = torch.logical_and(self.confi2d[batch_id,0]>0.8, self.confi2d[batch_id,1]>0.8)
            
            p0 = self.points2d[batch_id,0,mask].numpy()
            p1 = self.points2d[batch_id,1,mask].numpy()
            # p0,p1 (N,2)
            E, mask = cv2.findEssentialMat(p0, p1, focal=1.0, pp=(0., 0.),
                                            method=cv2.RANSAC, prob=0.999, threshold=0.0003)
            p0_inliers = p0[mask.ravel() == 1]
            p1_inliers = p0[mask.ravel() == 1]
            point, R, t,mask  = cv2.recoverPose(E, p0_inliers, p1_inliers)
            self.R[batch_id,0],self.t[batch_id,0] = torch.eye(3), torch.zeros((3,1))
            self.R[batch_id,1],self.t[batch_id,1] = torch.tensor(R),torch.tensor(t)

            print(self.R[batch_id,0],self.t[batch_id,0])

            self.points3d[batch_id], self.confi3d[batch_id] = self.weighted_triangulation(
                self.points2d[batch_id,:2],self.confi2d[batch_id,:2],self.R[batch_id,:2],self.t[batch_id,:2]
            )
            
            self.pnp(batch_id)
            
            # print(self.R[batch_id,0],self.t[batch_id,0])
            # print(self.mpjpe(2))
            # print(self.confi3d[batch_id])

            self.points3d[batch_id], self.confi3d[batch_id] = self.weighted_triangulation(
                self.points2d[batch_id],self.confi2d[batch_id],self.R[batch_id],self.t[batch_id]
            )
            # print(self.confi3d[batch_id])
            # print(self.mpjpe(self.n_view))

    def monte_carlo(self):
        self.eight_point()
        self.prob_tri.update_paramater_init(self.R,self.t)
        for i in range(16):
            sample_R, sample_t = self.prob_tri.sample()
            sample_points3d, sample_confi3d = self.weighted_triangulation_sample(self.points2d, self.confi2d, sample_R, sample_t)
            weights = cal_mpjpe_batch(sample_points3d, self.points2d[None],  sample_R, sample_t)
            self.prob_tri.update_paramater_with_weights(weights)
            self.R, self.t = self.prob_tri.getRt()

    def mpjpe(self, n_view_filter):
        return (homo_to_eulid((self.R[...,:n_view_filter,None,:,:] @ self.points3d[...,None,:,:,None] + self.t[...,:n_view_filter,None,:,:]).squeeze(-1)) - self.points2d[:,:n_view_filter] ).mean()
    


# calibr = CalibrationBatch(pose_2d,confi)
# calibr.eight_point()
# calibr.mpjpe(2)
# calibr.confi2d

In [182]:
h36m = H36M()
h36mloader = DataLoader(h36m, batch_size = 1, shuffle = True)
for step, (pose_3d, pose_2d, confi, Rgt, tgt) in enumerate(h36mloader):
    print(pose_2d.shape,pose_3d.shape, confi.shape,Rgt.shape,tgt.shape)
    break

torch.Size([1, 4, 17, 2]) torch.Size([1, 17, 3]) torch.Size([1, 4, 17]) torch.Size([1, 4, 3, 3]) torch.Size([1, 4, 3, 1])


In [7]:
calibr = CalibrationBatch(pose_2d,confi)
calibr.eight_point()
calibr.mpjpe(2)
calibr.confi2d

tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]]) tensor([[0.],
        [0.],
        [0.]])
tensor([[ 1.0000e+00,  4.3324e-07, -3.1816e-07],
        [-4.3324e-07,  1.0000e+00, -1.0566e-06],
        [ 3.1816e-07,  1.0566e-06,  1.0000e+00]]) tensor([[1.7213e-07],
        [5.0632e-07],
        [3.7451e-08]])
tensor(-6.9096e-09)
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
tensor(1.9859e-10)


tensor([[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]],
       dtype=torch.float64)

In [20]:
quaternions.shape

torch.Size([2, 2, 4])