In [3]:
from utils.lie_group_helper import make_c2w

import torch

In [28]:
tmp = make_c2w(torch.tensor([1.57, 0, 0], dtype=torch.float32), torch.tensor([0, 0, 0], dtype=torch.float32))
tmp

tensor([[ 1.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  7.9626e-04, -1.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  1.0000e+00,  7.9626e-04,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00]])

array([0.43980232, 0.43980232, 0.43980232, 0.64785936])

In [8]:
import numpy as np
import torch
from scipy.spatial.transform import Rotation as RotLib


def SO3_to_quat(R):
    """
    :param R:  (N, 3, 3) or (3, 3) np
    :return:   (N, 4, ) or (4, ) np
    """
    x = RotLib.from_matrix(R)
    quat = x.as_quat()
    return quat


def quat_to_SO3(quat):
    """
    :param quat:    (N, 4, ) or (4, ) np
    :return:        (N, 3, 3) or (3, 3) np
    """
    x = RotLib.from_quat(quat)
    R = x.as_matrix()
    return R


def convert3x4_4x4(input):
    """
    :param input:  (N, 3, 4) or (3, 4) torch or np
    :return:       (N, 4, 4) or (4, 4) torch or np
    """
    if torch.is_tensor(input):
        if len(input.shape) == 3:
            output = torch.cat([input, torch.zeros_like(input[:, 0:1])], dim=1)  # (N, 4, 4)
            output[:, 3, 3] = 1.0
        else:
            output = torch.cat([input, torch.tensor([[0,0,0,1]], dtype=input.dtype, device=input.device)], dim=0)  # (4, 4)
    else:
        if len(input.shape) == 3:
            output = np.concatenate([input, np.zeros_like(input[:, 0:1])], axis=1)  # (N, 4, 4)
            output[:, 3, 3] = 1.0
        else:
            output = np.concatenate([input, np.array([[0,0,0,1]], dtype=input.dtype)], axis=0)  # (4, 4)
            output[3, 3] = 1.0
    return output


def vec2skew(v):
    """
    :param v:  (3, ) torch tensor
    :return:   (3, 3)
    """
    zero = torch.zeros(1, dtype=torch.float32, device=v.device)
    skew_v0 = torch.cat([ zero,    -v[2:3],   v[1:2]])  # (3, 1)
    skew_v1 = torch.cat([ v[2:3],   zero,    -v[0:1]])
    skew_v2 = torch.cat([-v[1:2],   v[0:1],   zero])
    skew_v = torch.stack([skew_v0, skew_v1, skew_v2], dim=0)  # (3, 3)
    return skew_v  # (3, 3)


def Exp(r):
    """so(3) vector to SO(3) matrix
    :param r: (3, ) axis-angle, torch tensor
    :return:  (3, 3)
    """
    skew_r = vec2skew(r)  # (3, 3)
    norm_r = r.norm() + 1e-15
    eye = torch.eye(3, dtype=torch.float32, device=r.device)
    R = eye + (torch.sin(norm_r) / norm_r) * skew_r + ((1 - torch.cos(norm_r)) / norm_r**2) * (skew_r @ skew_r)
    return R


def make_c2w(r, t):
    """
    :param r:  (3, ) axis-angle             torch tensor
    :param t:  (3, ) translation vector     torch tensor
    :return:   (4, 4)
    """
    R = Exp(r)  # (3, 3)
    c2w = torch.cat([R, t.unsqueeze(1)], dim=1)  # (3, 4)
    c2w = convert3x4_4x4(c2w)  # (4, 4)
    return c2w


In [9]:
a = torch.FloatTensor([1, 1, 1])

In [21]:
a = torch.FloatTensor([1, 1, 1])
r = a

skew_r = vec2skew(r)  # (3, 3)
print(skew_r)
norm_r = r.norm() + 1e-15
print(norm_r)
eye = torch.eye(3, dtype=torch.float32, device=r.device)
R = eye + (torch.sin(norm_r) / norm_r) * skew_r + ((1 - torch.cos(norm_r)) / norm_r**2) * (skew_r @ skew_r)

tensor([[ 0., -1.,  1.],
        [ 1.,  0., -1.],
        [-1.,  1.,  0.]])
tensor(1.7321)


In [22]:
def restore_rt(c2w):
    """
    Restores the rotation vector (axis-angle) and translation vector from the c2w matrix
    :param c2w: (4, 4) torch tensor
    :return:    (3, ) rotation vector, (3, ) translation vector
    """
    R = c2w[:3, :3]  # (3, 3)
    t = c2w[:3, 3]   # (3, )

    # Convert rotation matrix to axis-angle
    rotation_matrix = R.cpu().numpy()  # Convert to numpy for scipy
    rot = RotLib.from_matrix(rotation_matrix)
    r = torch.tensor(rot.as_rotvec(), dtype=torch.float32, device=c2w.device)  # Convert back to torch tensor

    return r, t

In [1]:
from models.intrinsics import LearnFocal
model = LearnFocal(H=512, 
                        W=384, 
                        n_images=88, 
                        init_focal=[0 for _ in range(88)])

In [2]:
import torch
ckpt = torch.load("/home/diya/Public/Image2Smiles/KMolOCR_DL_Server/3D-CV-Project/dustnerf/logs/nerfmm/lego/lr_0.001_gpu0_seed_17_resize_1_Nsam_128_Ntr_img_-1_freq_10__240607_2013/latest_focal.pth")


In [3]:
model.load_state_dict(ckpt['model_state_dict'])

<All keys matched successfully>

In [4]:
ckpt

{'epoch': 600,
 'model_state_dict': OrderedDict([('weight_global',
               tensor(0.7835, device='cuda:0')),
              ('bias_global', tensor(-101.8961, device='cuda:0')),
              ('bias_local',
               tensor([ 2.0973e-01, -1.1642e-01,  1.9123e+00,  2.1880e+00, -4.6085e-01,
                       -3.8331e+00,  4.3606e-01, -2.6331e+00, -1.7671e+00,  2.6570e+00,
                       -3.7215e+00, -2.4732e+00,  2.7677e+00,  1.1785e+00,  1.2484e+00,
                       -1.5905e+00, -2.1595e+00,  2.7557e+00,  2.9114e-01,  7.3540e-02,
                        4.9491e+00, -4.6179e+00, -1.8507e+00, -8.0863e-01,  9.9891e-01,
                        2.6053e-01,  3.1926e+00,  2.0520e+00, -9.4284e+00, -3.3916e+00,
                       -3.7549e+00, -2.8690e+00, -4.4767e+00, -2.6130e+00, -1.2340e+01,
                        1.3400e+00,  1.5014e+00,  5.1290e-03, -4.9794e+00,  7.5949e+00,
                        1.0287e+00,  4.2055e+00, -1.2009e+00,  2.8225e+00,  1.1607e+

In [5]:
model.get_all_focals()

tensor([[-101.8521, -101.8825,  -98.2392,  -97.1087, -101.6837,  -87.2032,
         -101.7059,  -94.9630,  -98.7733,  -94.8365,  -88.0468,  -95.7793,
          -94.2360, -100.5072, -100.3375,  -99.3664,  -97.2326,  -94.3021,
         -101.8113, -101.8907,  -77.4023,  -80.5711,  -98.4708, -101.2422,
         -100.8982, -101.8282,  -91.7034,  -97.6853,  -13.0012,  -90.3933,
          -87.7967,  -93.6646,  -81.8550,  -95.0682,   50.3814, -100.1005,
          -99.6418, -101.8960,  -77.1012,  -44.2137, -100.8379,  -84.2094,
         -100.4538,  -93.9297, -100.5489,  -85.6869,  -93.4529,  -88.9712,
          -97.4887, -101.8884,  -91.1981, -101.1854, -101.1605, -101.7142,
         -101.5274,  -89.2821, -101.8357,  -96.8990,  -73.4749, -101.7855,
         -101.5877,  -97.3508,  -99.8376, -101.3184, -101.7229,  -97.5653,
         -101.5519,  -96.9606,  -64.8330,  -99.9271, -101.8592,  -93.5854,
         -101.7880,  -99.0758,  -98.8656,  -16.9338,  -83.2017,  -95.6560,
          -82.8239,  -11.