In [38]:
from __future__ import division

import os
import random

import numpy as np
import torch
from PIL import Image

In [None]:
class DataLoader:
    """Loads random batches of image sequences and camera intrinsics from disk.

    Each sample is one image file holding ``num_source + 1`` frames laid out
    side by side along the width axis, plus a text file with the nine values
    of a 3x3 camera matrix.
    """

    def __init__(self, dataset_dir, batch_size, img_height, img_width, num_source):
        """
        Initialize the DataLoader object.

        Args:
            dataset_dir: Root directory of the dataset.
            batch_size: Number of samples per batch.
            img_height: Height of the images.
            img_width: Width of a single frame inside a sequence image.
            num_source: Number of source frames to load.
        """
        self.dataset_dir = dataset_dir
        self.batch_size = batch_size
        self.img_height = img_height
        self.img_width = img_width
        self.num_source = num_source

    def load_train_batch(self, split='train'):
        """
        Load a batch of randomly chosen training instances.

        Args:
            split: Dataset split to use ('train', 'val', or 'test').

        Returns:
            tgt_images: Target images tensor of shape [batch_size, 3, H, W].
            src_image_stacks: Source image stacks tensor of shape [batch_size, num_source * 3, H, W].
            intrinsics: Camera intrinsics tensor of shape [batch_size, 3, 3].

        Raises:
            ValueError: If the split lists fewer samples than ``batch_size``.
        """
        # Load file paths for images and camera intrinsics
        file_list = self.format_file_list(self.dataset_dir, split)
        samples = list(zip(file_list['image_file_list'], file_list['cam_file_list']))

        # Fail with a clear message instead of an IndexError halfway through the batch.
        if len(samples) < self.batch_size:
            raise ValueError(
                f"split '{split}' lists {len(samples)} sample(s), "
                f"fewer than batch_size={self.batch_size}")

        # Shuffle (image, intrinsics) pairs together so they stay aligned.
        random.shuffle(samples)

        tgt_images = []
        src_image_stacks = []
        intrinsics = []
        for image_file, cam_file in samples[:self.batch_size]:
            tgt_image, src_image_stack = self.unpack_image_sequence(image_file)
            tgt_images.append(tgt_image)
            src_image_stacks.append(src_image_stack)
            intrinsics.append(self.load_intrinsics(cam_file))

        # Convert to PyTorch tensors with a leading batch dimension
        return (torch.stack(tgt_images, dim=0),
                torch.stack(src_image_stacks, dim=0),
                torch.stack(intrinsics, dim=0))

    def format_file_list(self, data_root, split):
        """
        Format file paths into separate lists for images and camera intrinsics.

        Each non-blank line of ``<data_root>/<split>.txt`` is read as
        ``<subfolder> <frame_id>``; blank lines are skipped.

        Args:
            data_root: Root directory of the dataset.
            split: Dataset split ('train', 'val', 'test').

        Returns:
            dict: Dictionary with keys 'image_file_list' and 'cam_file_list'.

        Raises:
            ValueError: If a non-blank line has fewer than two fields.
        """
        image_file_list = []
        cam_file_list = []
        list_path = os.path.join(data_root, f'{split}.txt')
        with open(list_path, 'r') as f:
            for line_no, line in enumerate(f, start=1):
                fields = line.split()
                if not fields:
                    # Trailing newlines / blank separators are not samples.
                    continue
                if len(fields) < 2:
                    raise ValueError(
                        f"{list_path}:{line_no}: expected '<subfolder> <frame_id>', got {line!r}")
                subfolder, frame_id = fields[0], fields[1]
                image_file_list.append(os.path.join(data_root, subfolder, f"{frame_id}.jpg"))
                cam_file_list.append(os.path.join(data_root, subfolder, f"{frame_id}_cam.txt"))

        return {'image_file_list': image_file_list, 'cam_file_list': cam_file_list}

    def unpack_image_sequence(self, image_file):
        """
        Unpack an image sequence into target and source images.

        The frame at index ``num_source // 2`` is the target; the remaining
        frames, in left-to-right order, are stacked along the channel axis.

        Args:
            image_file: Path to the image sequence.

        Returns:
            tgt_image: Target image tensor of shape [3, H, W], values in [0, 1].
            src_image_stack: Source image stack tensor of shape [num_source * 3, H, W].

        Raises:
            ValueError: If the image is narrower than ``(num_source + 1) * img_width``.
        """
        # Close the file handle deterministically; force 3 channels so grayscale
        # or CMYK files do not break the [H, W, 3] slicing below.
        with Image.open(image_file) as img:
            image_seq = np.array(img.convert('RGB'))  # [H, W_total, 3]

        num_frames = self.num_source + 1
        required_width = num_frames * self.img_width
        if image_seq.shape[1] < required_width:
            raise ValueError(
                f"{image_file}: width {image_seq.shape[1]} is smaller than the "
                f"{required_width} pixels needed for {num_frames} frames of width {self.img_width}")

        # Cut the strip into individual frames, then pull the target out of the middle.
        frames = [image_seq[:, k * self.img_width:(k + 1) * self.img_width, :]
                  for k in range(num_frames)]
        tgt_image = frames.pop(self.num_source // 2)
        src_image_stack = np.concatenate(frames, axis=2)  # [H, W, 3 * num_source]

        # Transpose dimensions to [C, H, W] for PyTorch and scale to [0, 1]
        tgt_image = torch.from_numpy(tgt_image).permute(2, 0, 1).float() / 255.0
        src_image_stack = torch.from_numpy(src_image_stack).permute(2, 0, 1).float() / 255.0

        return tgt_image, src_image_stack

    def load_intrinsics(self, cam_file):
        """
        Load camera intrinsics from a file.

        The file must contain nine numbers in row-major order, separated by
        commas and/or whitespace.

        Args:
            cam_file: Path to the camera intrinsic file.

        Returns:
            intrinsics: Camera intrinsics tensor of shape [3, 3] (float32).

        Raises:
            ValueError: If a value is not numeric or there are not nine values.
        """
        with open(cam_file, 'r') as f:
            # Treat commas as whitespace so trailing commas and line breaks are tolerated.
            raw_cam_vec = [float(x) for x in f.read().replace(',', ' ').split()]
        intrinsics = np.array(raw_cam_vec).reshape(3, 3)

        return torch.from_numpy(intrinsics).float()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PoseNet(nn.Module):
    """
    Pose estimation network to predict 6-DoF poses for source images relative to the target.

    Seven stride-2 convolutions are followed by a 1x1 convolution that emits
    ``6 * num_source`` channels; these are averaged over the spatial
    dimensions and scaled by 0.01.

    Args:
        input_channels: Number of input channels (target image + source image stack).
        num_source: Number of source frames. When omitted it is derived from
            ``input_channels`` assuming 3-channel frames
            (``input_channels // 3 - 1``).
    Returns:
        pose_final: Predicted 6-DoF poses for source images relative to the target.
                    Shape: [batch_size, num_source, 6]
    Raises:
        ValueError: If ``num_source`` cannot be derived or is smaller than 1.
    """
    def __init__(self, input_channels, num_source=None):
        super().__init__()
        if num_source is None:
            if input_channels % 3 != 0:
                raise ValueError(
                    f"cannot derive num_source from input_channels={input_channels}; "
                    "pass num_source explicitly")
            num_source = input_channels // 3 - 1
        if num_source < 1:
            raise ValueError(f"num_source must be >= 1, got {num_source}")

        self.conv1 = nn.Conv2d(input_channels, 16, kernel_size=7, stride=2, padding=3)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=2)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1)
        self.conv7 = nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1)
        # One 6-DoF vector per source frame, so the head needs 6 * num_source channels.
        self.pose_pred = nn.Conv2d(256, 6 * num_source, kernel_size=1, stride=1)

        self.num_source = num_source

    def forward(self, tgt_image, src_image_stack):
        """Predict poses.

        Args:
            tgt_image: Target image, [batch_size, C, H, W].
            src_image_stack: Source frames stacked on the channel axis,
                [batch_size, num_source * C, H, W].

        Returns:
            Tensor of shape [batch_size, num_source, 6].

        Raises:
            ValueError: If the stack does not hold ``self.num_source`` frames.
        """
        # Each source frame has as many channels as the target frame.
        stacked_sources = src_image_stack.shape[1] // tgt_image.shape[1]
        if stacked_sources != self.num_source:
            raise ValueError(
                f"expected {self.num_source} source frame(s), "
                f"got a stack holding {stacked_sources}")

        # Concatenate target and source images along the channel axis
        inputs = torch.cat((tgt_image, src_image_stack), dim=1)

        # Forward pass through convolutional layers
        x = F.relu(self.conv1(inputs))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        x = F.relu(self.conv7(x))

        # Predict 6-DoF poses
        pose_pred = self.pose_pred(x)  # Shape: [batch_size, 6 * num_source, H', W']

        # Average spatial dimensions and reshape
        pose_avg = torch.mean(pose_pred, dim=(2, 3))  # Shape: [batch_size, 6 * num_source]
        pose_final = 0.01 * pose_avg.view(-1, self.num_source, 6)  # Shape: [batch_size, num_source, 6]

        return pose_final

In [None]:
def euler_to_matrix(vec_rot):
    """Converts Euler angles to rotation matrix.

    The result is composed as ``R = R_z @ R_y @ R_x`` (rotation about x is
    applied first).

    Args:
        vec_rot: Euler angles in the order of rx, ry, rz (radians) -- [B, 3] torch.tensor
    Returns:
        A rotation matrix -- [B, 3, 3] torch.tensor on the same device as
        ``vec_rot``; integer input yields a floating-point result.
    """
    batch_size = vec_rot.shape[0]
    rx, ry, rz = vec_rot[:, 0], vec_rot[:, 1], vec_rot[:, 2]

    cos_rx, sin_rx = torch.cos(rx), torch.sin(rx)
    cos_ry, sin_ry = torch.cos(ry), torch.sin(ry)
    cos_rz, sin_rz = torch.cos(rz), torch.sin(rz)

    # Build the constant entries from an existing tensor so they inherit its
    # device and dtype (plain torch.ones/zeros would always land on the CPU).
    ones = torch.ones_like(cos_rx)
    zeros = torch.zeros_like(cos_rx)

    R_x = torch.stack([ones, zeros, zeros,
                       zeros, cos_rx, -sin_rx,
                       zeros, sin_rx, cos_rx], dim=1).view(batch_size, 3, 3)

    R_y = torch.stack([cos_ry, zeros, sin_ry,
                       zeros, ones, zeros,
                       -sin_ry, zeros, cos_ry], dim=1).view(batch_size, 3, 3)

    R_z = torch.stack([cos_rz, -sin_rz, zeros,
                       sin_rz, cos_rz, zeros,
                       zeros, zeros, ones], dim=1).view(batch_size, 3, 3)

    return torch.bmm(R_z, torch.bmm(R_y, R_x))

def dof_vec_to_matrix(dof_vec):
    """Converts 6DoF parameters to transformation matrix.

    Args:
        dof_vec: 6DoF parameters in the order of tx, ty, tz, rx, ry, rz -- [B, 6] torch.tensor
    Returns:
        A transformation matrix -- [B, 4, 4] torch.tensor
        R11 R12 R13 tx
        R21 R22 R23 ty
        R31 R32 R33 tz
        0   0   0   1
        The rotation part comes from ``euler_to_matrix``; the result uses the
        rotation matrix's dtype and device.
    """
    batch_size = dof_vec.shape[0]
    rot_matrix = euler_to_matrix(dof_vec[:, 3:])  # [B, 3, 3]

    # Match the rotation's dtype so integer pose vectors concatenate cleanly.
    translation = dof_vec[:, :3].to(rot_matrix.dtype).unsqueeze(2)  # [B, 3, 1]
    top_rows = torch.cat((rot_matrix, translation), dim=2)  # [B, 3, 4]

    # new_tensor keeps the constant row on the same device/dtype as the rotation
    # (torch.ones/zeros without a device would be created on the CPU).
    bottom_row = rot_matrix.new_tensor([0.0, 0.0, 0.0, 1.0]).expand(batch_size, 1, 4)

    return torch.cat((top_rows, bottom_row), dim=1)

def inverse_dof(dof_vec):
    """
    Computes the inverse of 6DoF parameters.

    For a pose ``p' = R p + t`` (``R = R_z R_y R_x`` as built by
    ``euler_to_matrix``) the inverse pose is ``p = R^T p' - R^T t``. Simply
    negating the six values is only correct for a pure translation or a
    rotation about a single axis, so the inverse is computed from the
    rotation matrix instead.

    Args:
        dof_vec: Tensor of shape [B, 6], representing 6DoF parameters (tx, ty, tz, rx, ry, rz).

    Returns:
        Inverted 6DoF parameters: Tensor of shape [B, 6] (floating point).
        Angles are returned in their principal range (rx, rz in (-pi, pi],
        ry in [-pi/2, pi/2]). At gimbal lock (inverse ry = +/-pi/2) the rx/rz
        split is not unique and the returned angles are not reliable.
    """
    rot = euler_to_matrix(dof_vec[:, 3:])  # R, [B, 3, 3]
    rot_inv = rot.transpose(1, 2)  # R^T is the inverse of a rotation matrix

    # Inverse translation: -R^T t
    translation = dof_vec[:, :3].to(rot.dtype).unsqueeze(2)  # [B, 3, 1]
    translation_inv = -torch.bmm(rot_inv, translation).squeeze(2)  # [B, 3]

    # Recover Euler angles of R^T under the same R_z R_y R_x convention, where
    # M[2,0] = -sin(ry), M[2,1] = cos(ry) sin(rx), M[2,2] = cos(ry) cos(rx),
    # M[1,0] = sin(rz) cos(ry), M[0,0] = cos(rz) cos(ry).
    # Clamp guards asin against values a hair outside [-1, 1] from rounding.
    ry_inv = torch.asin((-rot_inv[:, 2, 0]).clamp(-1.0, 1.0))
    rx_inv = torch.atan2(rot_inv[:, 2, 1], rot_inv[:, 2, 2])
    rz_inv = torch.atan2(rot_inv[:, 1, 0], rot_inv[:, 0, 0])
    rotation_inv = torch.stack((rx_inv, ry_inv, rz_inv), dim=1)  # [B, 3]

    return torch.cat((translation_inv, rotation_inv), dim=1)


In [None]:
def step_cloud(I_t, dof_vec):
    """
    Moves a batch of point clouds by a 6DoF rigid transformation.

    Args:
        I_t: Tensor of shape [B, N, 3], a batch of point clouds.
        dof_vec: Tensor of shape [B, 6] holding (tx, ty, tz, rx, ry, rz).

    Returns:
        I_t_1: Tensor of shape [B, N, 3] with the transformed points.
    """
    n_batch, n_points = I_t.shape[:2]

    # Append a column of ones -> homogeneous row vectors, [B, N, 4]
    homogeneous = torch.cat(
        (I_t, torch.ones(n_batch, n_points, 1, device=I_t.device)), dim=2)

    # Points are stored as rows, so multiply by the transposed [B, 4, 4] matrix.
    pose_transposed = dof_vec_to_matrix(dof_vec).transpose(1, 2)
    moved = torch.bmm(homogeneous, pose_transposed)  # [B, N, 4]

    # Keep x, y, z and discard the homogeneous coordinate.
    return moved[:, :, :3]

def photo_Loss(I, I_pred):
    """Placeholder for a photometric loss between ``I`` and ``I_pred``.

    Not implemented yet: both arguments are ignored and ``None`` is returned.
    """
    # TODO: compute the loss and return it (photo_loss).
    return None

In [74]:
def pixel_to_3d(points, intrins):
    """
    Converts pixel coordinates and depth to 3D coordinates.

    Args:
        points: Tensor of shape [B, N, 3], where B is the batch size, N is the number of points,
                and each point is represented as (u, v, w) where u and v are pixel coordinates and w is depth.
        intrins: Camera intrinsics, either a single 3x3 matrix (nested list or
                 tensor) shared by the whole batch, or a tensor of shape
                 [B, 3, 3] with one matrix per batch element.

    Returns:
        Tensor of shape [B, N, 3] representing the 3D coordinates.

    Raises:
        ValueError: If ``intrins`` is neither 2-D nor 3-D.
    """
    cam = torch.as_tensor(intrins, device=points.device)
    if points.is_floating_point():
        cam = cam.to(points.dtype)
    elif not cam.is_floating_point():
        # Integer points with integer intrinsics: still divide in floating point.
        cam = cam.to(torch.get_default_dtype())

    if cam.dim() == 2:
        cam = cam.unsqueeze(0)  # shared matrix -> [1, 3, 3], broadcasts over B
    elif cam.dim() != 3:
        raise ValueError(
            f"intrins must have shape [3, 3] or [B, 3, 3], got {tuple(cam.shape)}")

    # Shape [B or 1, 1] so each value broadcasts against the [B, N] coordinates.
    fx = cam[:, 0, 0].unsqueeze(1)
    fy = cam[:, 1, 1].unsqueeze(1)
    cx = cam[:, 0, 2].unsqueeze(1)
    cy = cam[:, 1, 2].unsqueeze(1)

    u = points[:, :, 0]
    v = points[:, :, 1]
    w = points[:, :, 2]

    x = ((u - cx) * w) / fx
    y = ((v - cy) * w) / fy
    z = w

    return torch.stack((x, y, z), dim=2)

# Example usage: back-project two batches of (u, v, depth) points with one camera matrix.
intrins_example = [
    [956.9475, 0.0, 693.9767],
    [0.0, 952.2352, 238.6081],
    [0.0, 0.0, 1.0],
]
points_batch = torch.tensor([
    [[1, 2, 3], [4, 5, 6]],
    [[7, 8, 9], [10, 11, 12]],
])

points_3d = pixel_to_3d(points_batch, intrins_example)
print(points_3d)

tensor([[[-2.1725, -0.7454,  3.0000],
         [-4.3261, -1.4720,  6.0000]],

        [[-6.4609, -2.1796,  9.0000],
         [-8.5770, -2.8683, 12.0000]]])


In [76]:
def _3d_to_pixel(points_3d, intrins):
    """
    Converts 3D coordinates to pixel coordinates and depth.

    Args:
        points_3d: Tensor of shape [B, N, 3], where B is the batch size, N is the number of points,
                   and each point is represented as (x, y, z) where x, y, and z are 3D coordinates.
        intrins: Camera intrinsics, either a single 3x3 matrix (nested list or
                 tensor) shared by the whole batch, or a tensor of shape
                 [B, 3, 3] with one matrix per batch element.

    Returns:
        Tensor of shape [B, N, 3] representing the pixel coordinates and depth.
        Points with z == 0 produce non-finite pixel coordinates.

    Raises:
        ValueError: If ``intrins`` is neither 2-D nor 3-D.
    """
    cam = torch.as_tensor(intrins, device=points_3d.device)
    if points_3d.is_floating_point():
        cam = cam.to(points_3d.dtype)
    elif not cam.is_floating_point():
        # Integer points with integer intrinsics: still divide in floating point.
        cam = cam.to(torch.get_default_dtype())

    if cam.dim() == 2:
        cam = cam.unsqueeze(0)  # shared matrix -> [1, 3, 3], broadcasts over B
    elif cam.dim() != 3:
        raise ValueError(
            f"intrins must have shape [3, 3] or [B, 3, 3], got {tuple(cam.shape)}")

    # Shape [B or 1, 1] so each value broadcasts against the [B, N] coordinates.
    fx = cam[:, 0, 0].unsqueeze(1)
    fy = cam[:, 1, 1].unsqueeze(1)
    cx = cam[:, 0, 2].unsqueeze(1)
    cy = cam[:, 1, 2].unsqueeze(1)

    x = points_3d[:, :, 0]
    y = points_3d[:, :, 1]
    z = points_3d[:, :, 2]

    u = (x * fx / z) + cx
    v = (y * fy / z) + cy
    w = z

    return torch.stack((u, v, w), dim=2)

# Example usage: project the 3D points from the previous cell back to pixels;
# the result should reproduce the original (u, v, depth) values.
intrins_example = [
    [956.9475, 0.0, 693.9767],
    [0.0, 952.2352, 238.6081],
    [0.0, 0.0, 1.0],
]
points_3d_example = points_3d

pixels = _3d_to_pixel(points_3d_example, intrins_example)
print(pixels)

tensor([[[ 0.9999,  2.0000,  3.0000],
         [ 3.9999,  5.0000,  6.0000]],

        [[ 7.0000,  8.0000,  9.0000],
         [ 9.9999, 11.0000, 12.0000]]])


In [None]:
# Shape check: push a tiny random integer point cloud through step_cloud.
torch.manual_seed(42)
batch, num_points = 1, 2
I_t_example = torch.randint(0, 10, size=(batch, num_points, 3))
dof_vec_example = torch.randint(0, 10, size=(batch, 6))

print('I_t_example\n', I_t_example)
print('I_t_example.shape', I_t_example.size())
print("dof_vec_example\n", dof_vec_example)
print('dof_vec_example.shape', dof_vec_example.size())

# Left as the final expression so the notebook displays the transformed cloud.
step_cloud(I_t_example, dof_vec_example)

I_t_example
 tensor([[[2, 7, 6],
         [4, 6, 5]]])
I_t_example.shape torch.Size([1, 2, 3])
dof_vec_example
 tensor([[0, 4, 0, 3, 8, 4]])
dof_vec_example.shape torch.Size([1, 6])


tensor([[[-2.4927, 13.0113, -1.2582],
         [-1.9954, 11.8566, -3.3604]]])

In [69]:
# Round-trip check: apply a pose, then its inverse, and expect the input back.
I_t = torch.rand(2, 5, 3)  # original point cloud

# Translation-only poses: +1 along x for the first cloud, +1 along y for the second.
dof_vec = torch.zeros(2, 6, dtype=torch.float32)
dof_vec[0, 0] = 1
dof_vec[1, 1] = 1

# Forward transform, inverse pose, then transform back.
transformed_cloud = step_cloud(I_t, dof_vec)
inverse_dof_vec = inverse_dof(dof_vec)
recovered_cloud = step_cloud(transformed_cloud, inverse_dof_vec)

print("Original Point Cloud:\n", I_t)
print("Transformed Point Cloud:\n", transformed_cloud)
print("Recovered Point Cloud:\n", recovered_cloud)

# The recovered cloud must match the input up to float rounding.
assert torch.allclose(I_t, recovered_cloud, atol=1e-6), "The original and recovered point clouds do not match!"

Original Point Cloud:
 tensor([[[0.4340, 0.1371, 0.5117],
         [0.1585, 0.0758, 0.2247],
         [0.0624, 0.1816, 0.9998],
         [0.5944, 0.6541, 0.0337],
         [0.1716, 0.3336, 0.5782]],

        [[0.0600, 0.2846, 0.2007],
         [0.5014, 0.3139, 0.4654],
         [0.1612, 0.1568, 0.2083],
         [0.3289, 0.1054, 0.9192],
         [0.4008, 0.9302, 0.6558]]])
Transformed Point Cloud:
 tensor([[[1.4340, 0.1371, 0.5117],
         [1.1585, 0.0758, 0.2247],
         [1.0624, 0.1816, 0.9998],
         [1.5944, 0.6541, 0.0337],
         [1.1716, 0.3336, 0.5782]],

        [[0.0600, 1.2846, 0.2007],
         [0.5014, 1.3139, 0.4654],
         [0.1612, 1.1568, 0.2083],
         [0.3289, 1.1054, 0.9192],
         [0.4008, 1.9302, 0.6558]]])
Recovered Point Cloud:
 tensor([[[0.4340, 0.1371, 0.5117],
         [0.1585, 0.0758, 0.2247],
         [0.0624, 0.1816, 0.9998],
         [0.5944, 0.6541, 0.0337],
         [0.1716, 0.3336, 0.5782]],

        [[0.0600, 0.2846, 0.2007],
       

In [None]:
# Demo: print random integer 6DoF vectors and the 4x4 matrices built from them.
batch = 2
example_input = torch.randint(0, 10, size=(batch, 6))

print(example_input)
print(dof_vec_to_matrix(example_input))

tensor([[2, 7, 6, 4, 6, 5],
        [0, 4, 0, 3, 8, 4]])
tensor([[[ 0.2724, -0.5668,  0.7775,  2.0000],
         [-0.9207, -0.3882,  0.0395,  7.0000],
         [ 0.2794, -0.7267, -0.6276,  6.0000],
         [ 0.0000,  0.0000,  0.0000,  1.0000]],

        [[ 0.0951, -0.8405,  0.5334,  0.0000],
         [ 0.1101,  0.5414,  0.8335,  4.0000],
         [-0.9894, -0.0205,  0.1440,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  1.0000]]])


In [None]:
def projective_inverse_warp(src_image, depth, pose, intrinsics):
    """
    Warps the source image to the target frame using depth and pose.

    Every target pixel is back-projected with its depth, moved by ``pose``
    (via ``step_cloud``), re-projected with the same intrinsics, and the
    source image is bilinearly sampled at the resulting location. Locations
    that fall outside the source image sample as zeros (``grid_sample``'s
    default padding).

    Args:
        src_image: Source image tensor, channels first (shape: [B, C, H, W]),
            the layout ``grid_sample`` expects.
        depth: Depth map for the target view (shape: [B, H, W] or [B, 1, H, W]).
        pose: 6-DoF pose parameters (shape: [B, 6]).
        intrinsics: Camera intrinsics matrix (shape: [B, 3, 3], or [3, 3]
            shared by the whole batch).

    Returns:
        warped_image: Source image warped to the target frame (shape: [B, C, H, W]).

    Raises:
        ValueError: If ``depth`` does not hold H * W values per batch element.
    """
    batch_size, _, img_height, img_width = src_image.shape
    device, dtype = src_image.device, src_image.dtype
    num_pixels = img_height * img_width

    if depth.numel() != batch_size * num_pixels:
        raise ValueError(
            f"depth has {depth.numel()} values, expected "
            f"{batch_size} x {img_height} x {img_width}")

    # Step 1: Create pixel grid in row-major order (index = v * W + u) so it
    # lines up with the flattened depth map and the final view(B, H, W).
    u = torch.arange(img_width, device=device, dtype=dtype).repeat(img_height)
    v = torch.arange(img_height, device=device, dtype=dtype).repeat_interleave(img_width)
    u = u.unsqueeze(0)  # [1, HW]
    v = v.unsqueeze(0)  # [1, HW]

    # Per-batch intrinsics as [B or 1, 1] columns that broadcast over the pixels.
    cam = intrinsics.to(device=device, dtype=dtype)
    if cam.dim() == 2:
        cam = cam.unsqueeze(0)
    fx = cam[:, 0, 0].unsqueeze(1)
    fy = cam[:, 1, 1].unsqueeze(1)
    cx = cam[:, 0, 2].unsqueeze(1)
    cy = cam[:, 1, 2].unsqueeze(1)

    # Step 2: Backproject pixels to 3D space
    z = depth.reshape(batch_size, num_pixels).to(device=device, dtype=dtype)
    cam_coords = torch.stack(((u - cx) / fx * z, (v - cy) / fy * z, z), dim=2)  # [B, HW, 3]

    # Step 3: Apply 6-DoF transformation
    cam_coords_transformed = step_cloud(cam_coords, pose)  # [B, HW, 3]

    # Step 4: Reproject to 2D space. Clamp z so points on or behind the image
    # plane do not divide by zero / flip sign.
    x_t = cam_coords_transformed[:, :, 0]
    y_t = cam_coords_transformed[:, :, 1]
    z_t = cam_coords_transformed[:, :, 2].clamp(min=1e-6)
    u_proj = (x_t * fx / z_t + cx).view(batch_size, img_height, img_width)
    v_proj = (y_t * fy / z_t + cy).view(batch_size, img_height, img_width)

    # Step 5: Sample from source image. With align_corners=False the centre of
    # pixel i sits at (2 * i + 1) / size - 1 in normalized coordinates.
    grid = torch.stack(
        ((2.0 * u_proj + 1.0) / img_width - 1.0,
         (2.0 * v_proj + 1.0) / img_height - 1.0), dim=-1)  # [B, H, W, 2]
    warped_image = torch.nn.functional.grid_sample(src_image, grid, align_corners=False)

    return warped_image

In [None]:
def compute_smoothness_loss(pred_depth, image):
    """
    Edge-aware smoothness penalty on a predicted depth map.

    Absolute depth differences between neighbouring pixels are down-weighted
    by ``exp(-|image gradient|)``, so depth changes are cheaper where the
    image itself has an edge.

    Args:
        pred_depth: Predicted depth map (Tensor [B, H, W] or [B, 1, H, W]).
        image: Corresponding image used for the edge weights (Tensor [B, C, H, W]).

    Returns:
        smoothness_loss: Mean weighted horizontal gradient plus mean weighted
            vertical gradient (scalar Tensor).
    """
    def _abs_dx(t):
        # Absolute difference between horizontal neighbours -> [B, 1, H, W-1]
        return (t[:, :, :, 1:] - t[:, :, :, :-1]).abs()

    def _abs_dy(t):
        # Absolute difference between vertical neighbours -> [B, 1, H-1, W]
        return (t[:, :, 1:, :] - t[:, :, :-1, :]).abs()

    # Give a [B, H, W] depth map an explicit channel axis.
    depth = pred_depth.unsqueeze(1) if pred_depth.dim() == 3 else pred_depth

    # Channel mean serves as the grayscale image, [B, 1, H, W].
    gray = image.mean(dim=1, keepdim=True)

    horizontal = _abs_dx(depth) * torch.exp(-_abs_dx(gray))
    vertical = _abs_dy(depth) * torch.exp(-_abs_dy(gray))

    return horizontal.mean() + vertical.mean()

In [None]:
# Initialize SSIM metric
import torch
from torchmetrics import StructuralSimilarityIndexMeasure

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Shared SSIM metric for images scaled to [0, 1], placed on the chosen device.
ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0)
ssim_metric = ssim_metric.to(device)

def compute_loss(pred_depth, pred_poses, tgt_image, src_image_stack, intrinsics, opt):
    """
    Computes photometric loss, smoothness loss, and total loss.

    For every depth map in ``pred_depth`` the images are resized to that
    map's resolution, the intrinsics are rescaled to match, each source frame
    is warped into the target view and compared with the target
    (0.85 * L1 + 0.15 * (1 - SSIM)), and an edge-aware smoothness term is
    added. Uses the module-level ``device`` and ``ssim_metric``.

    Args:
        pred_depth: List of predicted depth maps, one per scale
            (tensors [B, H_s, W_s] or [B, 1, H_s, W_s]).
        pred_poses: Predicted 6-DoF poses for source frames (Tensor [B, num_source, 6]).
        tgt_image: Target image tensor (shape: [B, 3, H, W]).
        src_image_stack: Source image stack (shape: [B, 3*num_source, H, W]).
        intrinsics: Camera intrinsics matrix for the full-resolution image
            (shape: [B, 3, 3]).
        opt: Options object; ``opt.num_source`` and ``opt.smooth_weight`` are read.

    Returns:
        total_loss: Combined loss (Tensor).
        photometric_loss: Photometric loss (Tensor).
        smoothness_loss: Smoothness loss (Tensor).
    """
    # Work in floating point; rescale to [0, 1] when the data looks like 0-255.
    tgt_image = tgt_image.float()
    src_image_stack = src_image_stack.float()
    if tgt_image.max() > 1.0:
        tgt_image = tgt_image / 255.0
    if src_image_stack.max() > 1.0:
        src_image_stack = src_image_stack / 255.0

    # Ensure all tensors are on the correct device
    tgt_image = tgt_image.to(device)
    src_image_stack = src_image_stack.to(device)
    intrinsics = intrinsics.to(device)
    pred_poses = pred_poses.to(device)

    full_height, full_width = tgt_image.shape[-2:]

    photometric_loss = 0.0
    smoothness_loss = 0.0

    for curr_depth in pred_depth:
        # Depth maps were previously left wherever the caller created them.
        curr_depth = curr_depth.to(device)
        curr_height, curr_width = curr_depth.shape[-2:]

        # Resize images to exactly the depth map's resolution for this scale.
        curr_tgt_image = F.interpolate(
            tgt_image, size=(curr_height, curr_width),
            mode='bilinear', align_corners=False)  # [B, 3, H', W']
        curr_src_image_stack = F.interpolate(
            src_image_stack, size=(curr_height, curr_width),
            mode='bilinear', align_corners=False)  # [B, 3*num_source, H', W']

        # Focal lengths and principal point are expressed in pixels, so they
        # shrink with the image: scale row 0 by W'/W and row 1 by H'/H.
        row_scale = intrinsics.new_tensor(
            [curr_width / full_width, curr_height / full_height, 1.0]).view(3, 1)
        curr_intrinsics = intrinsics * row_scale

        for i in range(opt.num_source):
            # Extract the current source image
            src_image = curr_src_image_stack[:, i * 3:(i + 1) * 3, :, :]  # [B, 3, H', W']

            # Warp the source image to the target frame
            warped_image = projective_inverse_warp(
                src_image, curr_depth, pred_poses[:, i, :], curr_intrinsics)  # [B, 3, H', W']

            # Compute photometric loss (L1 + SSIM)
            l1_loss = F.l1_loss(warped_image, curr_tgt_image, reduction='mean')  # Scalar
            ssim_loss = 1 - ssim_metric(warped_image, curr_tgt_image)
            ssim_loss = ssim_loss.mean()  # Scalar (no-op if the metric already reduced)
            photometric_loss += (0.85 * l1_loss + 0.15 * ssim_loss)  # Scalar

        # Compute smoothness loss for the current scale
        smoothness_loss += compute_smoothness_loss(curr_depth, curr_tgt_image)  # Scalar

    # Combine photometric and smoothness loss
    total_loss = photometric_loss + opt.smooth_weight * smoothness_loss  # Scalar

    return total_loss, photometric_loss, smoothness_loss