In [38]:
from __future__ import division

import os
import random

import tensorflow as tf
import tf_slim as slim
import torch

In [None]:
# Reference code for loading the data. TODO: fix the dataset paths before use.

class DataLoader(object):
    """Graph-mode (TF1 queue-based) loader for frame sequences and camera intrinsics.

    Each training sample is one JPEG holding ``num_source + 1`` frames of size
    ``img_height x img_width`` placed side by side, with the target frame in the
    centre. Each sample also has a text file with 9 comma-separated values that
    are reshaped into the 3x3 camera intrinsics matrix.

    NOTE: the queue/reader APIs used here exist in TF2 only under
    ``tf.compat.v1``. They need graph mode: TF1, or TF2 after
    ``tf.compat.v1.disable_eager_execution()``.
    """

    def __init__(self,
                 dataset_dir=None,
                 batch_size=None,
                 img_height=None,
                 img_width=None,
                 num_source=None):
        self.dataset_dir = dataset_dir
        self.batch_size = batch_size
        self.img_height = img_height
        self.img_width = img_width
        # Number of source frames per sample (the target frame is not counted).
        self.num_source = num_source

    def load_train_batch(self):
        """Load a batch of training instances.

        Builds the queue-based input pipeline over ``<dataset_dir>/train.txt``.
        Side effect: sets ``self.steps_per_epoch``.

        Returns:
            tgt_image: [B, img_height, img_width, 3] target frames.
            src_image_stack: [B, img_height, img_width, 3 * num_source] source
                frames stacked along the channel axis.
            intrinsics: [B, 3, 3] camera intrinsics matrices.
        """
        # Fresh random seed per call, so every run shuffles differently. The SAME
        # seed goes to both producers so that the image and intrinsics file
        # queues are intended to be shuffled in the same order and stay paired.
        seed = random.randint(0, 2**31 - 1)

        # Paths of the images and their matching "<frame>_cam.txt" files.
        file_list = self.format_file_list(self.dataset_dir, 'train')

        # Queues that hand out one file path at a time.
        image_paths_queue = tf.compat.v1.train.string_input_producer(
            file_list['image_file_list'],
            seed=seed,
            shuffle=True)
        cam_paths_queue = tf.compat.v1.train.string_input_producer(
            file_list['cam_file_list'],
            seed=seed,
            shuffle=True)
        self.steps_per_epoch = int(
            len(file_list['image_file_list']) // self.batch_size)

        # Read the raw JPEG bytes and decode them. channels=3 forces RGB output,
        # so the 3-channel static shapes set in unpack_image_sequence also hold
        # for grayscale files.
        img_reader = tf.compat.v1.WholeFileReader()
        _, image_contents = img_reader.read(image_paths_queue)
        image_seq = tf.image.decode_jpeg(image_contents, channels=3)

        # Split the strip into the centre target frame and the source stack.
        tgt_image, src_image_stack = self.unpack_image_sequence(
            image_seq, self.img_height, self.img_width, self.num_source)

        # Camera intrinsics: one text line of 9 CSV floats -> 3x3 matrix.
        # NOTE(review): TextLineReader yields one record per line, so this
        # presumably expects exactly one line per _cam.txt file -- confirm.
        cam_reader = tf.compat.v1.TextLineReader()
        _, raw_cam_contents = cam_reader.read(cam_paths_queue)
        rec_def = [[1.] for _ in range(9)]
        raw_cam_vec = tf.io.decode_csv(raw_cam_contents, record_defaults=rec_def)
        intrinsics = tf.reshape(tf.stack(raw_cam_vec), [3, 3])

        # Group single examples into batches.
        src_image_stack, tgt_image, intrinsics = tf.compat.v1.train.batch(
            [src_image_stack, tgt_image, intrinsics],
            batch_size=self.batch_size)

        return tgt_image, src_image_stack, intrinsics

    def make_intrinsics_matrix(self, fx, fy, cx, cy):
        """Build a batch of 3x3 pinhole intrinsics matrices.

        Args:
            fx, fy: focal lengths; cx, cy: principal point. They are stacked on
                axis 1, so presumably 1-D tensors of shape [B] -- confirm.
        Returns:
            Tensor [B, 3, 3] with rows [fx, 0, cx], [0, fy, cy], [0, 0, 1].
        """
        zeros = tf.zeros_like(fx)
        ones = tf.ones_like(fx)
        r1 = tf.stack([fx, zeros, cx], axis=1)
        r2 = tf.stack([zeros, fy, cy], axis=1)
        # Build the constant last row from fx rather than tiling a float32
        # constant by the static batch size. That also works when the batch
        # dimension is unknown (None) or fx is not float32.
        r3 = tf.stack([zeros, zeros, ones], axis=1)
        return tf.stack([r1, r2, r3], axis=1)

    def format_file_list(self, data_root, split):
        """Read ``<data_root>/<split>.txt`` and build image / intrinsics path lists.

        Each non-blank line must start with ``<subfolder> <frame_id>``; any
        further fields are ignored.

        Returns:
            dict with 'image_file_list' (``<subfolder>/<frame_id>.jpg``) and
            'cam_file_list' (``<subfolder>/<frame_id>_cam.txt``), both joined
            onto ``data_root`` and kept in file order.

        Raises:
            ValueError: if a non-blank line has fewer than two fields.
        """
        list_path = os.path.join(data_root, f'{split}.txt')
        subfolders, frame_ids = [], []
        with open(list_path, 'r') as f:
            for line_no, line in enumerate(f, start=1):
                # split() with no argument splits on any whitespace and drops
                # the newline, so repeated or leading spaces are harmless.
                fields = line.split()
                if not fields:
                    # Tolerate blank lines, e.g. a trailing newline at EOF.
                    continue
                if len(fields) < 2:
                    raise ValueError(
                        f"{list_path}:{line_no}: expected '<subfolder> <frame_id>', got {line!r}")
                subfolders.append(fields[0])
                frame_ids.append(fields[1])
        image_file_list = [os.path.join(data_root, sub, f"{fid}.jpg")
                           for sub, fid in zip(subfolders, frame_ids)]
        cam_file_list = [os.path.join(data_root, sub, f"{fid}_cam.txt")
                         for sub, fid in zip(subfolders, frame_ids)]
        return {'image_file_list': image_file_list, 'cam_file_list': cam_file_list}

    def unpack_image_sequence(self, image_seq, img_height, img_width, num_source):
        """Split a side-by-side frame strip into the target frame and a source stack.

        Args:
            image_seq: decoded image tensor [H, total_width, C] holding frames of
                width ``img_width`` laid out horizontally, target in the centre.
            img_height, img_width: size of a single frame.
            num_source: number of source frames.
                NOTE(review): the slicing takes num_source // 2 frames on each
                side of the target, so num_source is presumably even -- confirm.
        Returns:
            tgt_image [img_height, img_width, 3] and
            src_image_stack [img_height, img_width, 3 * num_source].
        """
        # Target frame: the centre slot of the strip.
        tgt_start_idx = int(img_width * (num_source // 2))
        tgt_image = tf.slice(image_seq,
                             [0, tgt_start_idx, 0],
                             [-1, img_width, -1])

        # Source frames to the left of the target.
        src_image_1 = tf.slice(image_seq,
                               [0, 0, 0],
                               [-1, int(img_width * (num_source // 2)), -1])

        # Source frames to the right of the target.
        src_image_2 = tf.slice(image_seq,
                               [0, int(tgt_start_idx + img_width), 0],
                               [-1, int(img_width * (num_source // 2)), -1])

        src_image_seq = tf.concat([src_image_1, src_image_2], axis=1)

        # Move the source frames from the width axis onto the channel axis,
        # giving [H, W, num_source * 3].
        src_image_stack = tf.concat([tf.slice(src_image_seq,
                                              [0, i * img_width, 0],
                                              [-1, img_width, -1])
                                     for i in range(num_source)], axis=2)
        src_image_stack.set_shape([img_height, img_width, num_source * 3])
        tgt_image.set_shape([img_height, img_width, 3])

        return tgt_image, src_image_stack


In [None]:
def pose_exp_net(tgt_image, src_image_stack, is_training=True):
    """
    Pose estimation network to predict 6-DoF poses for source images relative to the target.

    Args:
        tgt_image: Target image batch. It is concatenated with the sources on
                   axis 3, so presumably NHWC [B, H, W, 3] -- confirm.
        src_image_stack: Source frames stacked along channels, [B, H, W, 3 * num_source].
        is_training: Not used by the current implementation; kept for API compatibility.
    Returns:
        pose_final: Predicted 6-DoF poses for the source images relative to the
                    target, shape [batch_size, num_source, 6]. The raw predictions
                    are spatially averaged and then multiplied by 0.01.
    Raises:
        ValueError: if the channel count of src_image_stack is unknown or not a
                    multiple of 3.
    """
    # Concatenate target and source images along the channel axis.
    inputs = tf.concat([tgt_image, src_image_stack], axis=3)

    # as_list() gives plain ints (or None) on both TF1 and TF2. The previous
    # `.value` attribute only exists on TF1 Dimension objects.
    src_channels = src_image_stack.get_shape().as_list()[3]
    if src_channels is None or src_channels % 3 != 0:
        raise ValueError(
            f"src_image_stack must have a static channel count divisible by 3, got {src_channels}")
    num_source = src_channels // 3

    # Variable scope so the network's variables live under 'pose_net'.
    with tf.compat.v1.variable_scope('pose_net'):
        with slim.arg_scope([slim.conv2d],
                            normalizer_fn=None,
                            weights_regularizer=slim.l2_regularizer(0.05),
                            activation_fn=tf.nn.relu):
            # Seven stride-2 convolutions for feature extraction.
            cnv1 = slim.conv2d(inputs, 16, [7, 7], stride=2, scope='cnv1')
            cnv2 = slim.conv2d(cnv1, 32, [5, 5], stride=2, scope='cnv2')
            cnv3 = slim.conv2d(cnv2, 64, [3, 3], stride=2, scope='cnv3')
            cnv4 = slim.conv2d(cnv3, 128, [3, 3], stride=2, scope='cnv4')
            cnv5 = slim.conv2d(cnv4, 256, [3, 3], stride=2, scope='cnv5')
            cnv6 = slim.conv2d(cnv5, 256, [3, 3], stride=2, scope='cnv6')
            cnv7 = slim.conv2d(cnv6, 256, [3, 3], stride=2, scope='cnv7')

            # Linear 1x1 head emitting 6 values per source frame.
            # NOTE(review): presumably ordered (tx, ty, tz, rx, ry, rz) to match
            # dof_vec_to_matrix -- confirm.
            pose_pred = slim.conv2d(cnv7, 6 * num_source, [1, 1], scope='pred',
                                    stride=1, normalizer_fn=None, activation_fn=None)

            # Average over the spatial dimensions, then scale by 0.01.
            # NOTE(review): the scaling presumably keeps early pose predictions
            # small -- confirm the intent.
            pose_avg = tf.reduce_mean(pose_pred, [1, 2])
            pose_final = 0.01 * tf.reshape(pose_avg, [-1, num_source, 6])

    return pose_final

In [66]:

def euler_to_matrix(vec_rot):
    """Converts Euler angles to rotation matrices.

    The rotation is composed as ``R = R_z(rz) @ R_y(ry) @ R_x(rx)``: a point is
    rotated about x first, then about y, then about z.

    Args:
        vec_rot: Euler angles (radians) in the order rx, ry, rz -- [B, 3] torch.tensor.
            Integer tensors are accepted; torch.cos / torch.sin promote them to
            floating point.
    Returns:
        A rotation matrix per batch element -- [B, 3, 3] torch.tensor, with the
        dtype and device of the trigonometric results.
    """
    batch_size = vec_rot.shape[0]
    rx, ry, rz = vec_rot[:, 0], vec_rot[:, 1], vec_rot[:, 2]

    cos_rx, sin_rx = torch.cos(rx), torch.sin(rx)
    cos_ry, sin_ry = torch.cos(ry), torch.sin(ry)
    cos_rz, sin_rz = torch.cos(rz), torch.sin(rz)

    # Build the constant entries from the trig output. torch.zeros(B) /
    # torch.ones(B) would always be CPU, default-dtype tensors, and
    # torch.stack then fails for CUDA inputs.
    zero = torch.zeros_like(cos_rx)
    one = torch.ones_like(cos_rx)

    R_x = torch.stack([one, zero, zero,
                       zero, cos_rx, -sin_rx,
                       zero, sin_rx, cos_rx], dim=1).view(batch_size, 3, 3)

    R_y = torch.stack([cos_ry, zero, sin_ry,
                       zero, one, zero,
                       -sin_ry, zero, cos_ry], dim=1).view(batch_size, 3, 3)

    R_z = torch.stack([cos_rz, -sin_rz, zero,
                       sin_rz, cos_rz, zero,
                       zero, zero, one], dim=1).view(batch_size, 3, 3)

    return torch.bmm(R_z, torch.bmm(R_y, R_x))

def dof_vec_to_matrix(dof_vec):
    """Converts 6DoF parameters to homogeneous transformation matrices.

    Args:
        dof_vec: 6DoF parameters in the order tx, ty, tz, rx, ry, rz -- [B, 6] torch.tensor.
            The angles follow the euler_to_matrix convention (R = R_z R_y R_x).
    Returns:
        A transformation matrix -- [B, 4, 4] torch.tensor, in the dtype/device
        of the rotation block:
        R11 R12 R13 tx
        R21 R22 R23 ty
        R31 R32 R33 tz
        0   0   0   1
    """
    batch_size = dof_vec.shape[0]
    rot_matrix = euler_to_matrix(dof_vec[:, 3:])  # [B, 3, 3]
    # Cast the translation so integer inputs do not mix dtypes with the rotation.
    translation = dof_vec[:, :3].to(rot_matrix.dtype).unsqueeze(2)  # [B, 3, 1]
    top_rows = torch.cat((rot_matrix, translation), dim=2)  # [B, 3, 4]
    # Homogeneous last row [0, 0, 0, 1], on the same dtype/device as the rest.
    bottom_row = rot_matrix.new_tensor([0.0, 0.0, 0.0, 1.0]).expand(batch_size, 1, 4)
    return torch.cat((top_rows, bottom_row), dim=1)

def _rotation_to_euler(rot):
    """Recovers (rx, ry, rz) such that euler_to_matrix(angles) reproduces ``rot`` [B, 3, 3]."""
    # For R = R_z R_y R_x:  R[2,0] = -sin(ry),  R[2,1] = cos(ry) sin(rx),
    # R[2,2] = cos(ry) cos(rx),  R[1,0] = sin(rz) cos(ry),  R[0,0] = cos(rz) cos(ry).
    cos_ry = torch.sqrt(rot[:, 2, 1] ** 2 + rot[:, 2, 2] ** 2)  # >= 0, so ry in [-pi/2, pi/2]
    ry = torch.atan2(-rot[:, 2, 0], cos_ry)
    # Gimbal lock (cos(ry) ~ 0): only a combination of rx and rz is observable.
    # Fix rx = 0 and read that combined angle from R[0,1] = -sin(rz),
    # R[1,1] = cos(rz).
    # NOTE(review): gradients through sqrt may be non-finite exactly at gimbal lock.
    gimbal = cos_ry < 1e-6
    rx = torch.where(gimbal, torch.zeros_like(ry), torch.atan2(rot[:, 2, 1], rot[:, 2, 2]))
    rz = torch.where(gimbal,
                     torch.atan2(-rot[:, 0, 1], rot[:, 1, 1]),
                     torch.atan2(rot[:, 1, 0], rot[:, 0, 0]))
    return torch.stack((rx, ry, rz), dim=1)

def inverse_dof(dof_vec):
    """
    Computes the 6DoF parameters of the inverse rigid transformation.

    For T = [R | t] the inverse is [R^T | -R^T t]. Negating the parameters is
    not enough in general:
      - the inverse translation must be rotated by R^T;
      - (R_z R_y R_x)^-1 = R_x(-rx) R_y(-ry) R_z(-rz), which applies the axes in
        the opposite order to euler_to_matrix.
    The angles are therefore re-extracted from R^T.

    Args:
        dof_vec: Tensor of shape [B, 6], representing 6DoF parameters (tx, ty, tz, rx, ry, rz).

    Returns:
        Inverted 6DoF parameters: floating-point tensor of shape [B, 6] such that
        dof_vec_to_matrix(result) equals the inverse of dof_vec_to_matrix(dof_vec)
        up to round-off.
    """
    transform = dof_vec_to_matrix(dof_vec)
    rot_inv = transform[:, :3, :3].transpose(1, 2)  # R^T == R^-1 for rotations
    translation = transform[:, :3, 3:]  # [B, 3, 1]
    translation_inv = -torch.bmm(rot_inv, translation).squeeze(2)  # -R^T t
    return torch.cat((translation_inv, _rotation_to_euler(rot_inv)), dim=1)


In [None]:
def step_cloud(I_t, dof_vec):
    """
    Applies a 6DoF rigid transformation to a batch of point clouds.

    Each point p becomes R p + t, where [R | t] = dof_vec_to_matrix(dof_vec).

    Args:
        I_t: Tensor of shape [B, N, 3], representing a batch of point clouds.
        dof_vec: Tensor of shape [B, 6], representing 6DoF parameters (tx, ty, tz, rx, ry, rz).

    Returns:
        I_t_1: Transformed point cloud, Tensor of shape [B, N, 3]. Its dtype is
               the promotion of I_t's dtype and the transformation's dtype.
    """
    transf_mat = dof_vec_to_matrix(dof_vec)  # [B, 4, 4]

    # torch.bmm does not type-promote. Bring both operands to a common dtype:
    # integer clouds, or float64 clouds with float32 poses (and vice versa),
    # would otherwise fail inside bmm.
    common_dtype = torch.promote_types(I_t.dtype, transf_mat.dtype)
    transf_mat = transf_mat.to(dtype=common_dtype)
    rotation = transf_mat[:, :3, :3]     # [B, 3, 3]
    translation = transf_mat[:, :3, 3]   # [B, 3]

    # Row-vector form of p' = R p + t:  P' = P R^T + t. This is equivalent to
    # multiplying homogeneous points by T^T, without building the extra column.
    I_t_1 = torch.bmm(I_t.to(common_dtype), rotation.transpose(1, 2)) + translation.unsqueeze(1)

    return I_t_1

def photo_Loss(I, I_pred):
    """Mean absolute (L1) photometric error between an image and its reconstruction.

    Args:
        I: Reference image tensor (any shape).
        I_pred: Reconstructed / warped image tensor with the same shape as I.

    Returns:
        Scalar tensor: mean of |I - I_pred| over all elements.

    Raises:
        ValueError: if the shapes differ (prevents silent broadcasting).
    """
    if I.shape != I_pred.shape:
        raise ValueError(f"photo_Loss shape mismatch: {tuple(I.shape)} vs {tuple(I_pred.shape)}")
    dtype = torch.promote_types(I.dtype, I_pred.dtype)
    if not dtype.is_floating_point:
        # Integer images: subtracting uint8 tensors wraps around, and mean() is
        # undefined for integer dtypes, so compute in the default float dtype.
        dtype = torch.get_default_dtype()
    return (I.to(dtype) - I_pred.to(dtype)).abs().mean()

In [74]:
def pixel_to_3d(points, intrins):
    """
    Back-projects pixel coordinates with depth to 3D camera coordinates.

    Inverts the pinhole model: x = (u - cx) * w / fx, y = (v - cy) * w / fy, z = w.

    Args:
        points: Tensor of shape [B, N, 3], where B is the batch size and N the number
                of points. Each point is (u, v, w): pixel coordinates u, v and depth w.
        intrins: Either one 3x3 intrinsics matrix shared by the whole batch (nested
                 list or tensor), or a tensor of shape [B, 3, 3] with one matrix per
                 batch element.

    Returns:
        Tensor of shape [B, N, 3] representing the 3D coordinates (x, y, z).
    """
    if torch.is_tensor(intrins) and intrins.dim() == 3:
        # Per-sample intrinsics: keep a trailing axis ([B, 1]) so each value
        # broadcasts over that sample's N points.
        fx = intrins[:, 0, 0].unsqueeze(1)
        fy = intrins[:, 1, 1].unsqueeze(1)
        cx = intrins[:, 0, 2].unsqueeze(1)
        cy = intrins[:, 1, 2].unsqueeze(1)
    else:
        fx = intrins[0][0]
        fy = intrins[1][1]
        cx = intrins[0][2]
        cy = intrins[1][2]

    u = points[:, :, 0]
    v = points[:, :, 1]
    w = points[:, :, 2]

    x = ((u - cx) * w) / fx
    y = ((v - cy) * w) / fy
    z = w

    return torch.stack((x, y, z), dim=2)

# Example: back-project a batch of 2 x 2 (u, v, depth) points using a single
# shared intrinsics matrix.
points_batch = torch.tensor([
    [[1, 2, 3], [4, 5, 6]],
    [[7, 8, 9], [10, 11, 12]],
])
intrins_example = [
    [956.9475, 0.0, 693.9767],   # fx, 0, cx
    [0.0, 952.2352, 238.6081],   # 0, fy, cy
    [0.0, 0.0, 1.0],
]

points_3d = pixel_to_3d(points_batch, intrins_example)
print(points_3d)

tensor([[[-2.1725, -0.7454,  3.0000],
         [-4.3261, -1.4720,  6.0000]],

        [[-6.4609, -2.1796,  9.0000],
         [-8.5770, -2.8683, 12.0000]]])


In [76]:
def _3d_to_pixel(points_3d, intrins):
    """
    Projects 3D camera coordinates to pixel coordinates, keeping depth.

    Pinhole model: u = x * fx / z + cx, v = y * fy / z + cy, w = z.

    Args:
        points_3d: Tensor of shape [B, N, 3], where B is the batch size and N the
                   number of points. Each point is (x, y, z) in camera coordinates.
        intrins: Either one 3x3 intrinsics matrix shared by the whole batch (nested
                 list or tensor), or a tensor of shape [B, 3, 3] with one matrix per
                 batch element.

    Returns:
        Tensor of shape [B, N, 3] representing the pixel coordinates and depth (u, v, w).
        Points with z == 0 yield inf/nan pixel coordinates; no clamping is applied.
    """
    if torch.is_tensor(intrins) and intrins.dim() == 3:
        # Per-sample intrinsics: keep a trailing axis ([B, 1]) so each value
        # broadcasts over that sample's N points.
        fx = intrins[:, 0, 0].unsqueeze(1)
        fy = intrins[:, 1, 1].unsqueeze(1)
        cx = intrins[:, 0, 2].unsqueeze(1)
        cy = intrins[:, 1, 2].unsqueeze(1)
    else:
        fx = intrins[0][0]
        fy = intrins[1][1]
        cx = intrins[0][2]
        cy = intrins[1][2]

    x = points_3d[:, :, 0]
    y = points_3d[:, :, 1]
    z = points_3d[:, :, 2]

    u = (x * fx / z) + cx
    v = (y * fy / z) + cy
    w = z

    return torch.stack((u, v, w), dim=2)

# Round-trip check: projecting the back-projected points must recover the
# original (u, v, depth) values. Reuses `points_3d`, `points_batch` and
# `intrins_example` from the back-projection example cell (no duplicate
# intrinsics definition here).
points_3d_example = points_3d

pixels = _3d_to_pixel(points_3d_example, intrins_example)
print(pixels)

# float32 round-off shows up as e.g. 0.9999 instead of 1, hence the tolerance.
assert torch.allclose(pixels, points_batch.to(pixels.dtype), atol=1e-3), \
    "pixel -> 3D -> pixel round trip did not recover the input"

tensor([[[ 0.9999,  2.0000,  3.0000],
         [ 3.9999,  5.0000,  6.0000]],

        [[ 7.0000,  8.0000,  9.0000],
         [ 9.9999, 11.0000, 12.0000]]])


In [None]:
# Unit test: step_cloud must map a [B, N, 3] cloud to a [B, N, 3] cloud.
torch.manual_seed(42)
batch = 1
num_points = 2
I_t_example = torch.randint(0, 10, (batch, num_points, 3))
dof_vec_example = torch.randint(0, 10, (batch, 6))

print('I_t_example\n', I_t_example)
print('I_t_example.shape', I_t_example.shape)
print("dof_vec_example\n", dof_vec_example)
print('dof_vec_example.shape', dof_vec_example.shape)

moved_example = step_cloud(I_t_example, dof_vec_example)
assert moved_example.shape == I_t_example.shape, "step_cloud changed the point-cloud shape"
moved_example

I_t_example
 tensor([[[2, 7, 6],
         [4, 6, 5]]])
I_t_example.shape torch.Size([1, 2, 3])
dof_vec_example
 tensor([[0, 4, 0, 3, 8, 4]])
dof_vec_example.shape torch.Size([1, 6])


tensor([[[-2.4927, 13.0113, -1.2582],
         [-1.9954, 11.8566, -3.3604]]])

In [69]:
# Invertibility check: move a random cloud, apply the inverse pose, and compare
# the result with the input. Only pure translations are exercised here;
# rotational components are not covered by this check.
I_t = torch.rand(2, 5, 3)  # original point cloud, [B=2, N=5, 3]
dof_vec = torch.tensor(
    [[1, 0, 0, 0, 0, 0],   # shift +1 along x
     [0, 1, 0, 0, 0, 0]],  # shift +1 along y
    dtype=torch.float32,
)

# Forward transform, inverse parameters, then transform back.
transformed_cloud = step_cloud(I_t, dof_vec)
inverse_dof_vec = inverse_dof(dof_vec)
recovered_cloud = step_cloud(transformed_cloud, inverse_dof_vec)

print("Original Point Cloud:\n", I_t)
print("Transformed Point Cloud:\n", transformed_cloud)
print("Recovered Point Cloud:\n", recovered_cloud)

assert torch.allclose(I_t, recovered_cloud, atol=1e-6), "The original and recovered point clouds do not match!"

Original Point Cloud:
 tensor([[[0.4340, 0.1371, 0.5117],
         [0.1585, 0.0758, 0.2247],
         [0.0624, 0.1816, 0.9998],
         [0.5944, 0.6541, 0.0337],
         [0.1716, 0.3336, 0.5782]],

        [[0.0600, 0.2846, 0.2007],
         [0.5014, 0.3139, 0.4654],
         [0.1612, 0.1568, 0.2083],
         [0.3289, 0.1054, 0.9192],
         [0.4008, 0.9302, 0.6558]]])
Transformed Point Cloud:
 tensor([[[1.4340, 0.1371, 0.5117],
         [1.1585, 0.0758, 0.2247],
         [1.0624, 0.1816, 0.9998],
         [1.5944, 0.6541, 0.0337],
         [1.1716, 0.3336, 0.5782]],

        [[0.0600, 1.2846, 0.2007],
         [0.5014, 1.3139, 0.4654],
         [0.1612, 1.1568, 0.2083],
         [0.3289, 1.1054, 0.9192],
         [0.4008, 1.9302, 0.6558]]])
Recovered Point Cloud:
 tensor([[[0.4340, 0.1371, 0.5117],
         [0.1585, 0.0758, 0.2247],
         [0.0624, 0.1816, 0.9998],
         [0.5944, 0.6541, 0.0337],
         [0.1716, 0.3336, 0.5782]],

        [[0.0600, 0.2846, 0.2007],
       

In [None]:
# Smoke test: convert random integer 6-DoF vectors to 4x4 transforms and check
# their homogeneous structure: orthonormal rotation block and last row [0, 0, 0, 1].
batch = 2

# A local seeded generator keeps this cell reproducible without touching the global RNG.
example_input = torch.randint(0, 10, (batch, 6), generator=torch.Generator().manual_seed(0))
example_transform = dof_vec_to_matrix(example_input)

print(example_input)
print(example_transform)

assert example_transform.shape == (batch, 4, 4)
assert torch.allclose(example_transform[:, 3],
                      torch.tensor([0.0, 0.0, 0.0, 1.0]).expand(batch, 4))
example_rot = example_transform[:, :3, :3]
assert torch.allclose(example_rot @ example_rot.transpose(1, 2),
                      torch.eye(3).expand(batch, 3, 3), atol=1e-5)

tensor([[2, 7, 6, 4, 6, 5],
        [0, 4, 0, 3, 8, 4]])
tensor([[[ 0.2724, -0.5668,  0.7775,  2.0000],
         [-0.9207, -0.3882,  0.0395,  7.0000],
         [ 0.2794, -0.7267, -0.6276,  6.0000],
         [ 0.0000,  0.0000,  0.0000,  1.0000]],

        [[ 0.0951, -0.8405,  0.5334,  0.0000],
         [ 0.1101,  0.5414,  0.8335,  4.0000],
         [-0.9894, -0.0205,  0.1440,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  1.0000]]])
