<a href="https://colab.research.google.com/github/chiayu2002/test/blob/main/3dReconstruct.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import zipfile
import os

# Google Drive id of the dataset archive, local archive name, and extraction target.
file_id = "1r172cIGZKBc3b7_b1-cscPnVFj8bl8HF"
zip_filename = "7SCENES.zip"
extract_dir = "./"

import gdown

# The archive is large (the recorded run shows ~23 GB), so only fetch it when it
# is not already on disk; this keeps "Restart & Run All" from re-downloading it.
if not os.path.isfile(zip_filename):
    downloaded = gdown.download(
        f"https://drive.google.com/uc?id={file_id}", zip_filename, quiet=False
    )
    # NOTE(review): some gdown releases report failure by returning None rather
    # than raising; fail loudly here instead of crashing later in ZipFile.
    if downloaded is None:
        raise RuntimeError(f"Failed to download {zip_filename} (Google Drive id={file_id})")

with zipfile.ZipFile(zip_filename, 'r') as zip_ref:
    zip_ref.extractall(extract_dir)

print(f"✅ 已解壓縮到：{extract_dir}")

Downloading...
From (original): https://drive.google.com/uc?id=1r172cIGZKBc3b7_b1-cscPnVFj8bl8HF
From (redirected): https://drive.google.com/uc?id=1r172cIGZKBc3b7_b1-cscPnVFj8bl8HF&confirm=t&uuid=d4d73112-c2d9-4223-bc70-a979b131582c
To: /content/7SCENES.zip
100%|██████████| 23.1G/23.1G [05:20<00:00, 71.9MB/s]


✅ 已解壓縮到：./


In [None]:
!pip install pillow-heif
!pip install open3d

import os
import os.path as osp
import numpy as np
import open3d as o3d
import cv2
from PIL import Image
from typing import List,Dict,Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

Collecting pillow-heif
  Downloading pillow_heif-0.22.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.6 kB)
Downloading pillow_heif-0.22.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/7.8 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━[0m [32m6.0/7.8 MB[0m [31m179.5 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m7.8/7.8 MB[0m [31m176.9 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m108.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pillow-heif
Successfully installed pillow-heif-0.22.0
Collecting open3d
  Downloading open3d-0.19.0-cp311-cp311-manylinux_2_31_x86_64.whl.metadata (4.3 kB)
Collecting dash>=2.6.0 (from open3d)
  Downloading dash-3.0.4-py3-none-any.whl.meta

In [None]:
# Generate ground truth point clouds (seq2ply).
# Pinhole camera intrinsics shared by every frame, in the order (fx, fy, cx, cy).
# SevenSceneSequence.get_views unpacks this tuple to build the 3x3 camera matrix.
# NOTE(review): values are presumably in pixels for a 640x480 image — confirm.
INTRINSINC = (525, 525, 320, 240)  # fx, fy, cx, cy


def imread_cv2(path:str, options=cv2.IMREAD_COLOR):
    """Load an image or depth map from disk through OpenCV.

    Args:
        path: File to read. Paths ending in ".exr" or "EXR" are always read
            with ``cv2.IMREAD_ANYDEPTH``, overriding ``options``.
        options: OpenCV imread flag used for every other file.

    Returns:
        The decoded array. Three-dimensional results are converted from
        OpenCV's BGR channel order to RGB; other results are returned as read.

    Raises:
        IOError: If OpenCV could not decode the file.
    """
    # EXR files force the any-depth flag regardless of the caller's choice.
    options = cv2.IMREAD_ANYDEPTH if path.endswith((".exr", "EXR")) else options

    img = cv2.imread(path, options)
    if img is None:
        raise IOError(f"Could not load image={path} with {options=}")

    # Only multi-channel arrays need the BGR -> RGB swap.
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if img.ndim == 3 else img

def depthmap_to_world_coordinates(depthmap, camera_intrinsics, camera_pose=None, **kw):
    """
    Back-project a depth map into 3D points, optionally moved into world space.

    Args:
        depthmap (H x W): Depth value per pixel, along the camera z axis.
        camera_intrinsics (3 x 3): Pinhole intrinsic matrix. The skew terms
            ``[0, 1]`` and ``[1, 0]`` must be zero.
        camera_pose (optional, at least 3 x 4): Camera-to-world transform. Only the
            rotation ``[:3, :3]`` and translation ``[:3, 3]`` are read. When
            ``None`` (the default) the points stay in camera coordinates.
        **kw: Ignored. Lets callers pass a whole view dict with ``**view``.

    Returns:
        pts_world (H x W x 3): float32 3D point per pixel.
        valid_mask (H x W): Boolean mask, True where depth is strictly positive.
    """

    H, W = depthmap.shape
    camera_intrinsics = np.float32(camera_intrinsics)

    # Extract intrinsic parameters (the model assumes zero skew)
    assert camera_intrinsics[0, 1] == 0.0 and camera_intrinsics[1, 0] == 0.0
    fu, fv = camera_intrinsics[0, 0], camera_intrinsics[1, 1]
    cu, cv_ = camera_intrinsics[0, 2], camera_intrinsics[1, 2]

    # Generate pixel coordinate grid
    u, v = np.meshgrid(np.arange(W), np.arange(H))  # u: cols, v: rows

    # Backproject depth to 3D camera coordinates
    z = depthmap
    x = (u - cu) * z / fu
    y = (v - cv_) * z / fv
    pts_cam = np.stack((x, y, z), axis=-1).astype(np.float32)

    # Mark valid points (depth > 0)
    valid_mask = z > 0.0

    # Transform to world coordinates if pose is given
    if camera_pose is not None:
        R = camera_pose[:3, :3]
        t = camera_pose[:3, 3]
        # Per-pixel R @ p + t, written as an einsum over the (row, col) grid
        pts_world = np.einsum("ik, vuk -> vui", R, pts_cam) + t
    else:
        pts_world = pts_cam

    return pts_world, valid_mask

class SevenSceneSequence:
    """One sequence directory: frame discovery plus per-keyframe back-projection.

    A frame ``<name>`` is valid when ``<name>.color.png``,
    ``<name>.depth.proj.png`` and ``<name>.pose.txt`` all exist in the directory.
    """

    def __init__(
            self,
            seq_dir_path,
        ):
        """Collect the sorted names of all valid frames found in ``seq_dir_path``."""
        self.seq_dir_path = seq_dir_path
        # Find all the filenames ending with ".color.png" and check that the
        # corresponding ".depth.proj.png" and ".pose.txt" exist.
        color_suffix = ".color.png"
        _color_files = [f for f in os.listdir(seq_dir_path) if f.endswith(color_suffix)]
        # Slice the suffix off. str.rstrip() must not be used here: it strips a
        # *set of characters*, so "clip.color.png" would become "cli".
        frame_names = [f[:-len(color_suffix)] for f in _color_files]

        self.valid_frame_names = []
        for name in frame_names:
            proj_path = osp.join(seq_dir_path, f"{name}.depth.proj.png")
            pose_path = osp.join(seq_dir_path, f"{name}.pose.txt")
            if osp.isfile(proj_path) and osp.isfile(pose_path):
                self.valid_frame_names.append(name)
        self.valid_frame_names = sorted(self.valid_frame_names)

        print(f"{len(self.valid_frame_names)} frames collected in {self.seq_dir_path}!!")
        print(f"{len(_color_files) - len(self.valid_frame_names)} rgb frames miss .proj.png or .pose.txt!!")


    def get_views(self,kf_every = 200)->List[Dict]:
        """Load every ``kf_every``-th valid frame and back-project it to 3D.

        Args:
            kf_every: Stride over the sorted valid frame names.

        Returns:
            One dict per keyframe with the keys ``name``, ``img`` (RGB float32
            in [0, 1]), ``true_shape`` ((height, width) int32), ``depthmap``
            (float32, 0 where invalid), ``camera_pose``, ``camera_intrinsics``,
            ``pts3d``, ``valid_mask`` and ``img_mask``.

        Raises:
            AssertionError: If a frame is not 640x480 after resizing to the depth
                map, contains non-finite depth or pose values, or holds a value
                of an unexpected dtype.
        """
        names = self.valid_frame_names[::kf_every] # select 1 out of every kf_every frames for reconstruction

        views = []
        for name in names:
            view = dict()

            impath = osp.join(self.seq_dir_path, f"{name}.color.png")
            depthpath = osp.join(self.seq_dir_path, f"{name}.depth.proj.png")
            posepath = osp.join(self.seq_dir_path, f"{name}.pose.txt")
            view["name"] = f'{self.seq_dir_path}/{name}'

            rgb_image = imread_cv2(impath)
            depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
            # Bring the colour image onto the depth map's pixel grid.
            rgb_image = cv2.resize(rgb_image, (depthmap.shape[1], depthmap.shape[0]))

            height, width = rgb_image.shape[:2]
            assert (width,height) == (640,480)
            view['img'] = (rgb_image / 255.0 ).astype(np.float32)# Normalize to 0 to 1 for open3d format
            view["true_shape"] = np.int32((height, width))

            # 65535 is treated as "no measurement"; the raw values are divided by
            # 1000 (presumably millimetres -> metres — TODO confirm).
            depthmap[depthmap == 65535] = 0
            # ``nan=`` must be passed by keyword: the second positional argument
            # of np.nan_to_num is ``copy``, not the NaN replacement value.
            depthmap = np.nan_to_num(depthmap.astype(np.float32), nan=0.0) / 1000.0
            # Discard depths outside the (1e-3, 10] range.
            depthmap[depthmap > 10] = 0
            depthmap[depthmap < 1e-3] = 0
            assert np.isfinite(depthmap).all(), \
                f"NaN in depthmap for view {view['name']}"
            view['depthmap'] = depthmap

            camera_pose = np.loadtxt(posepath).astype(np.float32)
            fx, fy, cx, cy = INTRINSINC ### NOTE: This intrinsic does not match with that on internet
            intrinsics = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)
            assert np.isfinite(camera_pose).all(), \
                f"NaN in camera pose for view {view['name']}"

            view['camera_pose'] = camera_pose
            view['camera_intrinsics'] = intrinsics

            # Back-project; the helper picks the keys it needs out of **view and
            # ignores the rest.
            pts3d, valid_mask = depthmap_to_world_coordinates(**view)
            view["pts3d"] = pts3d
            view["valid_mask"] = valid_mask & np.isfinite(pts3d).all(axis=-1)
            view["img_mask"] = True

            # check all datatypes
            for key, val in view.items():
                res, err_msg = self._is_good_type(key, val)
                assert res, f"{err_msg} with {key}={val} for view {view['name']}"

            views.append(view)

        for view in views:
            height, width = view['true_shape']
            assert width >= height, ValueError("Width > Height")

        return views

    def _is_good_type(self,key, v):
        """Return ``(is_good, err_msg)`` for a view value.

        ``str``/``int``/``tuple`` values (``bool`` included, as an ``int``
        subclass) are always accepted; anything else must expose a ``dtype``
        among float32, bool, int32, int64 and uint8.
        """
        if isinstance(v, (str, int, tuple)):
            return True, None
        if v.dtype not in (np.float32, bool, np.int32, np.int64, np.uint8):
            return False, f"bad {v.dtype=}"
        return True, None

def seq2ply(seq_dir_path, ply_path, kf_every = 1, crop_size = None, voxel_grid_size = None):
    """
    Fuse the keyframes of one sequence into a coloured point cloud and write it as .ply.

    Parameters:
        seq_dir_path (str): Sequence directory holding, per frame, a
            ``.color.png`` RGB image, a ``.depth.proj.png`` depth image and a
            ``.pose.txt`` camera pose.
        ply_path (str): Destination path of the output .ply file.
        kf_every (int): Use one keyframe every ``kf_every`` valid frames.
        crop_size (int | None): If given, keep only a centred square window of
            about this many pixels per side from each view (clamped to the
            image size, with a printed warning).
        voxel_grid_size (float | None): If given, voxel size used to
            down-sample the merged cloud before writing.

    Description:
        Every selected view contributes its valid 3D points and their colours;
        all views are merged, optionally down-sampled, and saved with Open3D.
    """
    # Step 1: gather the per-keyframe data needed for reconstruction
    sequence = SevenSceneSequence(seq_dir_path = seq_dir_path)
    keyframes = sequence.get_views(kf_every = kf_every)

    # Step 2: optionally trust only the central window of each camera image
    assert crop_size is None \
        or isinstance(crop_size, int), \
        "crop_size must be None or an integer"

    colors_per_view, points_per_view, masks_per_view = [], [], []
    for view in keyframes:
        image = view["img"]          # H,W,3
        mask = view["valid_mask"]    # H,W
        pts_gt = view['pts3d']       # H,W,3

        if crop_size is not None:
            H, W = image.shape[:2]
            if crop_size > H or crop_size > W:
                print(f"Warning: Adjust crop_size({crop_size}) since it exceeds H({H}) or W({W})")
                crop_size = min(W,H)
            half = crop_size // 2
            # Window centred on the image midpoint
            rows = slice(H // 2 - half, H // 2 + half)
            cols = slice(W // 2 - half, W // 2 + half)
            image, mask, pts_gt = image[rows, cols], mask[rows, cols], pts_gt[rows, cols]

        colors_per_view.append(image)
        points_per_view.append(pts_gt)
        masks_per_view.append(mask)

    # Step 3: merge all views along a new leading axis and keep valid pixels only
    images_all = np.concatenate([a[None, ...] for a in colors_per_view], axis=0)
    pts_gt_all = np.concatenate([a[None, ...] for a in points_per_view], axis=0)
    masks_all = np.concatenate([a[None, ...] for a in masks_per_view], axis=0)
    keep = masks_all > 0

    pcd_gt = o3d.geometry.PointCloud()
    pcd_gt.points = o3d.utility.Vector3dVector(pts_gt_all[keep].reshape(-1, 3))
    pcd_gt.colors = o3d.utility.Vector3dVector(images_all[keep].reshape(-1, 3))
    print(f'Points Cloud has {len(pcd_gt.points)} points')

    if voxel_grid_size is not None:
        pcd_gt = pcd_gt.voxel_down_sample(voxel_size=voxel_grid_size)
        print(f'After downsample, Points Cloud has {len(pcd_gt.points)} points')

    o3d.io.write_point_cloud(ply_path, pcd_gt, )

In [None]:
def generate_ground_truth_ply(scenes_root, pointcloud_root, scene_list, split='train',
                               kf_every=20, voxel_grid_size=0.0075, enable=True):
    """
    Generate ground-truth point clouds and store them as .ply files.

    For every scene, the sequence names are read from
    ``<scenes_root>/<scene>/<Split>Split.txt`` (e.g. ``TrainSplit.txt``) and each
    sequence directory ``<scenes_root>/<scene>/<split>/seq-XX`` is converted
    with ``seq2ply`` into ``<pointcloud_root>/<scene>-seq-<n>.ply``. Missing
    split files or sequence folders are skipped with a warning, and existing
    .ply files are not regenerated.

    Args:
        scenes_root (str): Root directory of the images (e.g. '../7SCENES').
        pointcloud_root (str): Folder where the .ply files are saved (e.g. './train_truth').
        scene_list (List[str]): Names of the scenes to process.
        split (str): 'train' or 'test'.
        kf_every (int): Pick one keyframe every ``kf_every`` frames.
        voxel_grid_size (float): Voxel size used for down-sampling.
        enable (bool): If False, processing is skipped entirely.
    """
    if not enable:
        print(f"[INFO] Ground truth generation for split '{split}' is disabled.")
        return
    os.makedirs(pointcloud_root, exist_ok=True)

    for scene in scene_list:
        root_path = osp.join(scenes_root, scene)
        split_path = osp.join(root_path, split)
        split_txt = osp.join(root_path, f"{split.capitalize()}Split.txt")

        if not osp.isfile(split_txt):
            print(f"[WARNING] {split_txt} 不存在，跳過 scene {scene}")
            continue

        with open(split_txt, "r") as f:
            # Ignore blank lines (e.g. a trailing newline); int("") would raise.
            seq_names = [line.strip() for line in f if line.strip()]

        for seq in seq_names:
            # Split files list entries such as "sequence1"; keep the number only.
            seq_num = int(seq.replace("sequence", ""))
            seq_dir = osp.join(split_path, f"seq-{seq_num:02d}")
            if not osp.isdir(seq_dir):
                print(f"[WARNING] 資料夾不存在：{seq_dir}，跳過")
                continue

            ply_path = osp.join(pointcloud_root, f"{scene}-seq-{seq_num}.ply")
            if osp.isfile(ply_path):
                print(f"[SKIP] 已存在：{ply_path}")
                continue

            print(f"[INFO] 正在處理：{scene} - seq-{seq_num:02d}")
            seq2ply(seq_dir, ply_path, kf_every=kf_every, voxel_grid_size=voxel_grid_size)

In [None]:
import torchvision.transforms as tvf
import PIL.Image
from pillow_heif import register_heif_opener
import re
from PIL import ExifTags
def exif_transpose(image: Image.Image, *, in_place: bool = False) -> Image.Image | None:
    """
    If an image has an EXIF Orientation tag, other than 1, transpose the image
    accordingly, and remove the orientation data.

    :param image: The image to transpose.
    :param in_place: Boolean. Keyword-only argument.
        If ``True``, the original image is modified in-place, and ``None`` is returned.
        If ``False`` (default), a new :py:class:`~PIL.Image.Image` object is returned
        with the transposition applied. If there is no transposition, a copy of the
        image will be returned.
    :returns: ``None`` when ``in_place`` is true, otherwise the transposed image
        or a copy of ``image``.
    """
    # Make sure the image data is loaded before it is inspected or transposed.
    image.load()
    image_exif = image.getexif()
    # A missing Orientation tag is treated as 1, i.e. "no transposition needed".
    orientation = image_exif.get(ExifTags.Base.Orientation, 1)
    # Orientation codes 2-8 map to a transpose operation; any other value
    # (including 1) yields None and leaves the pixels untouched.
    method = {
        2: Image.Transpose.FLIP_LEFT_RIGHT,
        3: Image.Transpose.ROTATE_180,
        4: Image.Transpose.FLIP_TOP_BOTTOM,
        5: Image.Transpose.TRANSPOSE,
        6: Image.Transpose.ROTATE_270,
        7: Image.Transpose.TRANSVERSE,
        8: Image.Transpose.ROTATE_90,
    }.get(orientation)
    if method is not None:
        if in_place:
            # Replace the underlying image core and keep the cached size in sync.
            image.im = image.im.transpose(method)
            image._size = image.im.size
        else:
            transposed_image = image.transpose(method)
        # The image whose metadata still has to be cleaned up.
        exif_image = image if in_place else transposed_image

        exif = exif_image.getexif()
        if ExifTags.Base.Orientation in exif:
            # Drop the tag so the orientation is not applied a second time, then
            # rewrite every serialized copy of the metadata kept in ``info``.
            del exif[ExifTags.Base.Orientation]
            if "exif" in exif_image.info:
                exif_image.info["exif"] = exif.tobytes()
            elif "Raw profile type exif" in exif_image.info:
                exif_image.info["Raw profile type exif"] = exif.tobytes().hex()
            # XMP metadata may carry the orientation too, either as an attribute
            # or as an element; strip both forms from str or bytes payloads.
            for key in ("XML:com.adobe.xmp", "xmp"):
                if key in exif_image.info:
                    for pattern in (
                        r'tiff:Orientation="([0-9])"',
                        r"<tiff:Orientation>([0-9])</tiff:Orientation>",
                    ):
                        value = exif_image.info[key]
                        exif_image.info[key] = (
                            re.sub(pattern, "", value)
                            if isinstance(value, str)
                            else re.sub(pattern.encode(), b"", value)
                        )
        if not in_place:
            return transposed_image
    elif not in_place:
        # Nothing to transpose: still hand back a new object, as documented.
        return image.copy()
    return None
def _resize_pil_image(img, long_edge_size):
    """Resize ``img`` so that its longer side becomes ``long_edge_size`` pixels.

    The aspect ratio is kept (each side is scaled and rounded to the nearest
    integer). Lanczos resampling is used when the image shrinks, bicubic
    otherwise.
    """
    longest_side = max(img.size)
    # Shrinking -> Lanczos; enlarging or same size -> bicubic.
    resample = PIL.Image.LANCZOS if longest_side > long_edge_size else PIL.Image.BICUBIC
    target_size = tuple(
        int(round(side * long_edge_size / longest_side)) for side in img.size
    )
    return img.resize(target_size, resample)

def load_images(folder_or_list, size, square_ok=False, verbose=True, rotate_clockwise_90=False, crop_to_landscape=False):
    """Open and convert all images in a list or folder to the input format for DUSt3R.

    Args:
        folder_or_list: Either a directory path (its sorted content is scanned)
            or a list of image paths.
        size: Target resolution. ``224`` resizes the short side to 224 and
            centre-crops a square; any other value resizes the long side to
            ``size`` and centre-crops to multiples of 16 pixels.
        square_ok: When false, a square result is cropped to a 4:3 window.
        verbose: Print progress information.
        rotate_clockwise_90: Rotate every image by -90 degrees before resizing.
        crop_to_landscape: Centre-crop every image to a 4:3 aspect ratio first.

    Returns:
        A list with one dict per image: ``img`` (normalised tensor with a
        leading batch axis of 1), ``true_shape`` (int32 array ``[[H, W]]``),
        ``idx`` and ``instance`` (position in the list, as int and str).

    Raises:
        ValueError: If ``folder_or_list`` is neither a str nor a list.
        AssertionError: If no supported image was found.
    """
    ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    try:
        register_heif_opener()
        heif_support_enabled = True
    except ImportError:
        heif_support_enabled = False

    if isinstance(folder_or_list, str):
        if verbose:
            print(f">> Loading images from {folder_or_list}")
        root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list))

    elif isinstance(folder_or_list, list):
        if verbose:
            print(f">> Loading a list of {len(folder_or_list)} images")
        root, folder_content = "", folder_or_list

    else:
        raise ValueError(f"bad {folder_or_list=} ({type(folder_or_list)})")

    supported_images_extensions = [".jpg", ".jpeg", ".png"]
    if heif_support_enabled:
        supported_images_extensions += [".heic", ".heif"]
    supported_images_extensions = tuple(supported_images_extensions)

    imgs = []
    for path in folder_content:
        if not path.lower().endswith(supported_images_extensions):
            continue
        # exif_transpose() and convert() both return new images, so the file
        # handle can be released as soon as they are done.
        with PIL.Image.open(os.path.join(root, path)) as raw_img:
            img = exif_transpose(raw_img).convert("RGB")
        if rotate_clockwise_90:
            img = img.rotate(-90, expand=True)
        if crop_to_landscape:
            # Centre-crop to a 4:3 landscape aspect ratio
            desired_aspect_ratio = 4 / 3
            width, height = img.size
            current_aspect_ratio = width / height

            if current_aspect_ratio > desired_aspect_ratio:
                # Wider than 4:3: crop width
                new_width = int(height * desired_aspect_ratio)
                left = (width - new_width) // 2
                right = left + new_width
                top = 0
                bottom = height
            else:
                # Taller than 4:3: crop height
                new_height = int(width / desired_aspect_ratio)
                top = (height - new_height) // 2
                bottom = top + new_height
                left = 0
                right = width

            img = img.crop((left, top, right, bottom))

        W1, H1 = img.size
        if size == 224:
            # resize short side to 224 (then crop)
            img = _resize_pil_image(img, round(size * max(W1 / H1, H1 / W1)))
        else:
            # resize long side to `size`
            img = _resize_pil_image(img, size)
        W, H = img.size
        cx, cy = W // 2, H // 2
        if size == 224:
            half = min(cx, cy)
            img = img.crop((cx - half, cy - half, cx + half, cy + half))
        else:
            # Half extents rounded down to multiples of 8 -> sides are multiples of 16
            halfw, halfh = ((2 * cx) // 16) * 8, ((2 * cy) // 16) * 8
            if not (square_ok) and W == H:
                # halfw is a multiple of 8, so this integer division is exact and
                # keeps the crop box integral (same convention as load_depth_cv2).
                halfh = (3 * halfw) // 4
            img = img.crop((cx - halfw, cy - halfh, cx + halfw, cy + halfh))

        W2, H2 = img.size
        if verbose:
            print(f" - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}")
        imgs.append(
            dict(
                img=ImgNorm(img)[None],
                true_shape=np.int32([img.size[::-1]]),
                idx=len(imgs),
                instance=str(len(imgs)),
            )
        )

    assert imgs, "no images foud at " + root
    if verbose:
        print(f" (Found {len(imgs)} images)")
    return imgs

In [None]:
#generate train/test dataset
def get_keyframe_paths(seq_dir, kf_every=20):
    """Return the sorted ``*.color.png`` paths whose frame id is a multiple of ``kf_every``.

    The frame id is the integer left after removing ``"frame-"`` and
    ``".color.png"`` from the file name.
    """
    color_files = (
        fname for fname in sorted(os.listdir(seq_dir)) if fname.endswith(".color.png")
    )
    return [
        os.path.join(seq_dir, fname)
        for fname in color_files
        if int(fname.replace("frame-", "").replace(".color.png", "")) % kf_every == 0
    ]
def load_depth_cv2(path, size, square_ok=False):
    """Load a depth image and resize/crop it with the same geometry as ``load_images``.

    Args:
        path: Depth image file; raw values are divided by 1000 after loading
            (presumably millimetres -> metres — TODO confirm).
        size: Target resolution, interpreted exactly as in ``load_images``.
        square_ok: When false, a square result is cropped to a 4:3 window.

    Returns:
        depth: float32 array, 0 where the depth is invalid.
        mask: float32 array, 1 where the depth is valid and 0 elsewhere.

    Raises:
        IOError: If OpenCV could not read the file.
    """
    raw = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if raw is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise IOError(f"Could not load depth map={path}")
    depth = raw.astype(np.float32)
    # Invalid pixels are tracked as NaN until the mask is built.
    depth[depth == 65535] = np.nan
    depth /= 1000.0
    depth[(depth < 1e-3) | (depth > 10.0)] = np.nan

    H1, W1 = depth.shape

    # Resize: use the same long-edge rule and the same rounding as
    # _resize_pil_image so depth and RGB always share one pixel grid
    # (truncating instead of rounding could differ by a pixel).
    if size == 224:
        # resize short side to 224
        long_edge = round(size * max(W1 / H1, H1 / W1))
    else:
        # resize long side to size
        long_edge = size
    longest_side = max(W1, H1)
    new_size = tuple(int(round(side * long_edge / longest_side)) for side in (W1, H1))

    # Nearest neighbour keeps invalid (NaN) pixels from bleeding into valid ones.
    depth = cv2.resize(depth, new_size, interpolation=cv2.INTER_NEAREST)

    # Center crop
    H, W = depth.shape
    cx, cy = W // 2, H // 2
    if size == 224:
        half = min(cx, cy)
        depth = depth[cy - half:cy + half, cx - half:cx + half]
    else:
        halfw = ((2 * cx) // 16) * 8
        halfh = ((2 * cy) // 16) * 8
        if not square_ok and W == H:
            halfh = int(3 * halfw / 4)
        depth = depth[cy - halfh:cy + halfh, cx - halfw:cx + halfw]


    mask = (~np.isnan(depth)).astype(np.float32)
    depth = np.nan_to_num(depth, nan=0.0)

    return depth, mask

class MultiViewPointCloudDataset(Dataset):
    """Dataset yielding, per sequence, its keyframes (RGB + depth) and the ground-truth cloud."""

    def __init__(self, scenes_root, pointcloud_root, scene_list, kf_every=20, views_per_sample=5, size=384, split='train'):
        """
        Index every usable sequence of the requested split.

        Args:
            scenes_root: Dataset root, e.g. '../7SCENES'.
            pointcloud_root: Folder holding the matching point clouds, e.g. './test_truth'.
            scene_list: Scene names, e.g. ['chess', 'fire', ...].
            kf_every: Keep frames whose id is a multiple of this value.
            views_per_sample: Minimum number of keyframes a sequence must have
                to be included.
            size: Target image resolution passed to the loaders.
            split: 'train' or 'test'.
        """
        assert split in ['train', 'test'], "split must be 'train' or 'test'"
        self.samples = []
        self.size = size
        self.views_per_sample = views_per_sample

        for scene in scene_list:
            scene_path = osp.join(scenes_root, scene)
            split_dir = osp.join(scene_path, split)
            split_txt = osp.join(scene_path, f'{split.capitalize()}Split.txt')

            # Same policy as generate_ground_truth_ply: a scene without a split
            # file is skipped instead of aborting the whole dataset.
            if not osp.isfile(split_txt):
                print(f"[WARNING] {split_txt} not found, skipping scene {scene}")
                continue

            with open(split_txt, "r") as f:
                # Ignore blank lines (e.g. a trailing newline); int("") would raise.
                seq_names = [line.strip() for line in f if line.strip()]

            for seq in seq_names:
                seq_num = int(seq.replace("sequence", ""))
                seq_dir = osp.join(split_dir, f"seq-{seq_num:02d}")
                ply_path = osp.join(pointcloud_root, f"{scene}-seq-{seq_num}.ply")

                if not osp.isdir(seq_dir) :
                    continue

                image_paths = get_keyframe_paths(seq_dir, kf_every=kf_every)
                if len(image_paths) < views_per_sample:
                    continue

                self.samples.append({
                    # NOTE(review): the last keyframe is dropped here; the reason
                    # is not evident from this file — confirm it is intended.
                    "image_paths": image_paths[:-1],
                    "ply_path": ply_path
                })

    def __len__(self):
        """Number of indexed sequences."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load all keyframes of sequence ``idx`` together with its ground-truth cloud.

        Returns a dict with ``images`` [V, 3, H, W], ``depths`` [V, 1, H, W],
        ``masks`` [V, 1, H, W], ``target_pointcloud`` [N, 3] and ``image_paths``.
        """
        sample = self.samples[idx]
        image_paths = sample["image_paths"]
        ply_path = sample["ply_path"]
        # The depth image sits next to each colour image under a parallel name.
        depth_paths = [path.replace("color.png", "depth.proj.png") for path in image_paths]
        depthmaps = []
        valid_masks = []
        for path in depth_paths:

            depth, mask = load_depth_cv2(path, self.size)
            depthmaps.append(torch.from_numpy(depth[None]))     # [1, H, W]
            valid_masks.append(torch.from_numpy(mask[None]))     # [1, H, W]

        depths = torch.stack(depthmaps)       # [V, 1, H, W]
        masks = torch.stack(valid_masks)      # [V, 1, H, W]

        images = load_images(image_paths, size=self.size, verbose=False)
        images = torch.stack([img_dict["img"].squeeze(0) for img_dict in images])
        gt_pcd = o3d.io.read_point_cloud(ply_path)
        gt_points = np.asarray(gt_pcd.points).astype(np.float32)  # (N, 3)

        return {
            "images": images,              # [V, 3, H, W]
            "depths": depths,              # [V, 1, H, W]
            "masks": masks,                # [V, 1, H, W]
            "target_pointcloud": torch.from_numpy(gt_points),  # [N, 3]
            "image_paths": image_paths
        }

In [None]:
from torchvision.models import resnet18
from torchvision.models.feature_extraction import create_feature_extractor
#model resnet+ace
class ACEHead(nn.Module):
    """Prediction head: two parallel 1x1-conv branches, a learned gate, and a 4-channel output.

    The branches ``skip`` and ``dense`` have the same layout and both consume
    the input directly; their outputs are concatenated along the channel axis.
    ``eca`` pools the concatenation globally and maps it to a single sigmoid
    value per sample, which rescales every channel before the output layer.
    """

    def __init__(self, in_channels=512, mid_channels=256):
        super().__init__()
        fused_channels = mid_channels * 2

        def pointwise_branch():
            # conv1x1 -> PReLU -> conv1x1 -> PReLU
            return nn.Sequential(
                nn.Conv2d(in_channels, mid_channels, 1),
                nn.PReLU(),
                nn.Conv2d(mid_channels, mid_channels, 1),
                nn.PReLU(),
            )

        # Two parallel branches over the same input
        self.skip = pointwise_branch()
        self.dense = pointwise_branch()

        self.eca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),                        # [B, C, 1, 1]
            nn.Conv2d(fused_channels, 1, kernel_size=1),    # Conv2d is used rather than Conv1d
            nn.Sigmoid()
        )

        self.output_layer = nn.Conv2d(fused_channels, 4, 1)  # [x, y, z, w_hat]

    def forward(self, x):
        """Map features [B, in_channels, H, W] to predictions [B, 4, H, W]."""
        branches = torch.cat((self.skip(x), self.dense(x)), dim=1)

        # Gate of shape [B, 1, 1, 1], broadcast over channels and positions
        gate = self.eca(branches).view(branches.shape[0], -1, 1, 1)

        return self.output_layer(branches * gate)


class ResNetBackbone(nn.Module):
    """ResNet-18 feature extractor with a configurable number of input channels.

    Args:
        in_channels: Number of channels of the input tensor.
        out_layer: Name of the ResNet node whose activation is returned.
    """

    def __init__(self, in_channels=3, out_layer='layer4'):
        super().__init__()
        # ``weights=`` replaces the deprecated ``pretrained=True`` flag and loads
        # the same ImageNet checkpoint.
        model = resnet18(weights="IMAGENET1K_V1")
        # Swap the stem only when the channel count differs from RGB. For
        # 3-channel input a fresh layer would have the identical shape and would
        # just throw the pretrained stem weights away.
        if in_channels != 3:
            model.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)

        # Remove the stride of layer4's first block (main path and shortcut)
        # so that layer4 does not downsample its input any further.
        model.layer4[0].conv1.stride = (1, 1)
        model.layer4[0].downsample[0].stride = (1, 1)

        return_nodes = {out_layer: "features"}
        self.backbone = create_feature_extractor(model, return_nodes=return_nodes)

    def forward(self, x):
        """Return the ``out_layer`` activation for ``x``."""
        return self.backbone(x)["features"]  # [B, 512, H', W'] for the default 'layer4'

class RGBDPointPredictor(nn.Module):
    """Two-stream RGB + depth network predicting a 4-channel map per pixel.

    Each modality goes through its own ``ResNetBackbone``; the features are
    concatenated, fused by a 1x1 convolution, decoded by ``ACEHead`` and
    bilinearly resized to ``output_size`` (H, W).
    """

    def __init__(self, output_size=(288, 384)):
        super().__init__()
        self.output_size = output_size  # (H, W)
        self.rgb_backbone = ResNetBackbone(in_channels=3)
        self.depth_backbone = ResNetBackbone(in_channels=1)
        # 1x1 conv that fuses the two 512-channel streams back into 512 channels
        self.fusion_conv = nn.Conv2d(512 * 2, 512, kernel_size=1)
        self.head = ACEHead(in_channels=512)

    def forward(self, rgb, depth):
        """
        rgb: [B, 3, H, W]  (e.g., [B, 3, 288, 384])
        depth: [B, 1, H, W]
        returns: [B, 4, *output_size]
        """
        per_modality = [
            self.rgb_backbone(rgb),      # [B, 512, H', W']
            self.depth_backbone(depth),  # [B, 512, H', W']
        ]
        fused = self.fusion_conv(torch.cat(per_modality, dim=1))  # [B, 512, H', W']
        coarse = self.head(fused)                                 # [B, 4, H', W']

        # Bring the coarse prediction up to the configured output resolution
        return F.interpolate(
            coarse,
            size=self.output_size,
            mode="bilinear",
            align_corners=False
        )


In [None]:
from scipy.spatial import cKDTree
def chamfer_distance(p1, p2):
    """
    Symmetric Chamfer distance (squared Euclidean) between two unbatched point sets.

    Args:
        p1: Tensor (P1, D)
        p2: Tensor (P2, D)

    Returns:
        Scalar tensor: mean nearest-neighbour squared distance from p1 to p2
        plus the same quantity from p2 to p1.
    """
    # All pairwise squared distances via broadcasting: (P1, 1, D) - (1, P2, D)
    pairwise_sq = (p1[:, None] - p2[None]).pow(2).sum(dim=-1)   # (P1, P2)

    nearest_from_p1 = pairwise_sq.min(dim=1).values   # (P1,)
    nearest_from_p2 = pairwise_sq.min(dim=0).values   # (P2,)

    return nearest_from_p1.mean() + nearest_from_p2.mean()
def point_cloud_accuracy(pred_points: np.ndarray, gt_points: np.ndarray) -> float:
    """
    Accuracy metric: the median, over predicted points, of the distance to the
    closest ground-truth point.
    """
    # query() returns (distances, indices); only the distances are needed
    nearest_gt_dist = cKDTree(gt_points).query(pred_points, k=1)[0]
    return np.median(nearest_gt_dist)

def point_cloud_completeness(pred_points: np.ndarray, gt_points: np.ndarray) -> float:
    """
    Completeness metric: the median, over ground-truth points, of the distance
    to the closest predicted point.
    """
    # query() returns (distances, indices); only the distances are needed
    nearest_pred_dist = cKDTree(pred_points).query(gt_points, k=1)[0]
    return np.median(nearest_pred_dist)

In [None]:
# train/test dependency
def random_sampling(points, num_samples=2048):
    """Return at most ``num_samples`` rows of ``points``, chosen without replacement.

    Inputs that already have ``num_samples`` rows or fewer are returned unchanged.
    """
    total = points.shape[0]
    if total <= num_samples:
        return points
    # First num_samples entries of a random permutation = sample without replacement
    chosen = torch.randperm(total)[:num_samples]
    return points[chosen]
def extract_scene_and_seq(path):
    """Split a frame path into its scene name and integer sequence id.

    Example: './7SCENES/stairs/train/seq-06/frame-000000.color.png'
    gives ('stairs', 6). The path is split on ``os.sep``.
    """
    components = path.split(os.sep)
    scene = components[-4]                      # 'stairs'
    seq_folder = components[-2]                 # 'seq-06'
    return scene, int(seq_folder.split('-')[1])  # '06' -> 6
def evaluate(model, loader, device, desc="Evaluation", save_results=False, save_dir="./test"):
    """Run the model over ``loader`` and average Chamfer loss, accuracy and completeness.

    Args:
        model: Network called as ``model(images, depths)``.
        loader: Iterable of batches with the keys ``images``, ``depths``,
            ``target_pointcloud`` and ``image_paths``.
        device: Device the batch tensors are moved to.
        desc: Label of the progress bar.
        save_results: When true, the sampled prediction of every batch is
            written to ``<save_dir>/<scene>-seq-XX.ply``.
        save_dir: Output folder for the predictions; created if missing.

    Returns:
        (mean_loss, mean_accuracy, mean_completeness) over all batches, or
        three NaNs when the loader yields no batch.
    """
    model.eval()
    if save_results:
        # Make sure the output folder exists before any point cloud is written.
        os.makedirs(save_dir, exist_ok=True)

    total_loss = 0
    total_acc = 0
    total_comp = 0
    with torch.no_grad():
        for data in tqdm(loader, desc=desc, leave=False):
            W = data["images"].shape[-1]
            H = data["images"].shape[-2]
            # Fold the view axis into the batch axis
            images = data["images"].view(-1, 3, H, W).to(device)
            depths = data["depths"].view(-1, 1, H, W).to(device)
            target_pcd = data["target_pointcloud"].to(device)

            pred = model(images, depths)
            # Keep the first three channels as xyz, one 3D point per pixel
            xyz = pred[:, :3].permute(0, 2, 3, 1).reshape(-1, 3)
            xyz_sampled = random_sampling(xyz, 8192)
            target_sampled = random_sampling(target_pcd[0], 8192)

            loss = chamfer_distance(xyz_sampled, target_sampled)

            # Convert once and reuse for both metrics and the optional export
            xyz_np = xyz_sampled.detach().cpu().numpy()
            target_np = target_sampled.detach().cpu().numpy()
            acc = point_cloud_accuracy(xyz_np, target_np)
            comp = point_cloud_completeness(xyz_np, target_np)

            total_loss += loss.item()
            total_acc += acc
            total_comp += comp
            # Save prediction as .ply file
            if save_results:
                scene, sequence_id = extract_scene_and_seq(data["image_paths"][0][0])
                save_path = os.path.join(save_dir, f"{scene}-seq-{sequence_id:02d}.ply")

                pcd = o3d.geometry.PointCloud()
                pcd.points = o3d.utility.Vector3dVector(xyz_np)
                o3d.io.write_point_cloud(save_path, pcd)

    num_batches = len(loader)
    if num_batches == 0:
        # Nothing was evaluated: report NaN instead of dividing by zero.
        nan = float("nan")
        return nan, nan, nan
    return total_loss / num_batches, total_acc / num_batches, total_comp / num_batches

In [None]:
# train.py
from torch.utils.data import random_split
from tqdm import tqdm
import argparse
from torch.optim.lr_scheduler import ReduceLROnPlateau

def train(device, seed=42):
    """Train ``RGBDPointPredictor`` on the 7-Scenes train split with early stopping.

    The dataset is split 70/30 into training and validation subsets. On every
    step the first three channels of the model output are read as per-pixel
    XYZ, 2048 predicted points and 8192 target points are drawn with
    ``random_sampling``, and the chamfer distance between them is minimised
    with Adam (lr 1e-4). After each epoch the model is scored with
    ``evaluate`` on the validation subset; that loss drives a
    ``ReduceLROnPlateau`` scheduler (factor 0.5, patience 2), the
    best-checkpoint save to ``./checkpoints/best_model.pth`` and early
    stopping after 5 epochs without improvement.

    Args:
        device: Device on which the model and the batches are placed.
        seed: Seed of the generator used for the train/validation split, so
            the same samples land in each subset on every run.
    """
    dataset = MultiViewPointCloudDataset(
        scenes_root="./7SCENES",
        pointcloud_root="./train_truth",
        scene_list=['chess', 'fire', 'heads', 'office', 'pumpkin', 'redkitchen', 'stairs'],
        kf_every=10,
        views_per_sample=10,  # number of images chosen per sample
        size=384,
        split='train'
    )

    val_size = int(0.3 * len(dataset))
    train_size = len(dataset) - val_size
    # A dedicated, seeded generator keeps the split reproducible without
    # touching the global RNG state used elsewhere.
    split_generator = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset = random_split(
        dataset, [train_size, val_size], generator=split_generator
    )
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4)

    model = RGBDPointPredictor().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    # `verbose=True` is deprecated in recent PyTorch releases; learning-rate
    # changes are reported explicitly after each scheduler step instead.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)
    save_dir = "./checkpoints"
    os.makedirs(save_dir, exist_ok=True)

    num_epochs = 60
    patience = 5  # early stop patience
    best_val_loss = float("inf")
    best_epoch = 0
    patience_counter = 0

    for epoch in range(num_epochs):
        model.train()
        total_loss, total_acc, total_comp = 0, 0, 0
        progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

        for data in progress:
            W, H = data["images"].shape[-1], data["images"].shape[-2]
            images = data["images"].view(-1, 3, H, W).to(device)
            depths = data["depths"].view(-1, 1, H, W).to(device)
            target_pcd = data["target_pointcloud"].to(device)

            pred = model(images, depths)
            # Channels-last, then flatten every pixel of every view into one point list.
            xyz = pred[:, :3].permute(0, 2, 3, 1).reshape(-1, 3)
            xyz_sampled = random_sampling(xyz, 2048)
            target_sampled = random_sampling(target_pcd[0], 8192)

            loss = chamfer_distance(xyz_sampled, target_sampled)

            # Convert once and share between the two monitoring metrics.
            pred_np = xyz_sampled.detach().cpu().numpy()
            target_np = target_sampled.detach().cpu().numpy()
            acc = point_cloud_accuracy(pred_np, target_np)
            comp = point_cloud_completeness(pred_np, target_np)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_acc += acc
            total_comp += comp
            progress.set_postfix(loss=loss.item(), acc=acc, comp=comp)

        avg_loss = total_loss / len(train_loader)
        avg_acc = total_acc / len(train_loader)
        avg_comp = total_comp / len(train_loader)
        val_loss, val_acc, val_comp = evaluate(model, val_loader, device, desc="Validation")

        lr_before = optimizer.param_groups[0]["lr"]
        scheduler.step(val_loss)
        lr_after = optimizer.param_groups[0]["lr"]
        if lr_after != lr_before:
            print(f"Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}")

        print(f"[Epoch {epoch+1}] Train Loss: {avg_loss:.4f} | Acc: {avg_acc:.4f} | Comp: {avg_comp:.4f}")
        print(f"               Val   Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | Comp: {val_comp:.4f}")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch + 1
            patience_counter = 0
            torch.save(model.state_dict(), f"{save_dir}/best_model.pth")
            print(f"Saved new best model at epoch {epoch+1}")
        else:
            patience_counter += 1
            print(f"No improvement. Patience: {patience_counter}/{patience}")

        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}. Best model from epoch {best_epoch}")
            break


In [None]:
#test and create folder
def test(device):
    """Evaluate the best checkpoint on the test split and export the predictions.

    Builds the test dataset, restores the weights stored in
    ``./checkpoints/best_model.pth``, runs ``evaluate`` with
    ``save_results=True`` so that one ``.ply`` per batch is written to
    ``./test``, and prints the averaged loss, accuracy and completeness.

    Args:
        device: Device on which the model and the batches are placed.
    """
    test_dataset = MultiViewPointCloudDataset(
        scenes_root="./7SCENES",
        pointcloud_root="./test_truth",
        scene_list=['chess', 'fire', 'heads', 'office', 'pumpkin', 'redkitchen', 'stairs'],
        kf_every=20,
        views_per_sample=30,
        size=384,
        split='test'
    )
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)

    model = RGBDPointPredictor().to(device)
    # map_location lets a checkpoint saved during a GPU run be restored on a
    # CPU-only runtime (and vice versa).
    state_dict = torch.load("./checkpoints/best_model.pth", map_location=device)
    model.load_state_dict(state_dict)

    # Create the export folder here so the function does not depend on a
    # separate `mkdir` cell having been executed first.
    output_dir = "./test"
    os.makedirs(output_dir, exist_ok=True)

    test_loss, test_acc, test_comp = evaluate(
        model, test_loader, device,
        desc="Test",
        save_results=True,
        save_dir=output_dir
    )

    print(f"[Test] Loss: {test_loss:.4f} | Acc: {test_acc:.4f} | Comp: {test_comp:.4f}")

In [None]:
# Create the folders used by the cells below: `test_truth` and `train_truth`
# are the point-cloud roots of the test/train splits, and `test` is where
# test() saves the predicted point clouds.
!mkdir -p test_truth
!mkdir -p train_truth
!mkdir -p test

In [None]:
# Build the ground-truth point clouds for both splits (train first, then
# test), each into its own `./<split>_truth` folder, then start training.
scene_list = ["chess", "fire", "heads", "office", "pumpkin", "redkitchen", "stairs"]

for gt_split in ("train", "test"):
    generate_ground_truth_ply(
        scenes_root="./7SCENES",
        pointcloud_root=f"./{gt_split}_truth",
        scene_list=scene_list,
        split=gt_split,
        enable=True,
    )

# Prefer the GPU when one is available; `device` is reused by the test cell.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
train(device)

Epoch 1/60: 100%|██████████| 20/20 [00:15<00:00,  1.27it/s, acc=0.284, comp=0.662, loss=1.09]


[Epoch 1] Train Loss: 3.1453 | Acc: 0.4537 | Comp: 1.3342
               Val   Loss: 0.7786 | Acc: 0.3110 | Comp: 0.5242
Saved new best model at epoch 1


Epoch 2/60: 100%|██████████| 20/20 [00:15<00:00,  1.27it/s, acc=0.259, comp=0.343, loss=0.504]


[Epoch 2] Train Loss: 0.8330 | Acc: 0.3184 | Comp: 0.5058
               Val   Loss: 0.3461 | Acc: 0.2910 | Comp: 0.2484
Saved new best model at epoch 2


Epoch 3/60: 100%|██████████| 20/20 [00:15<00:00,  1.26it/s, acc=0.189, comp=0.175, loss=0.173]


[Epoch 3] Train Loss: 0.3704 | Acc: 0.2749 | Comp: 0.2495
               Val   Loss: 0.4578 | Acc: 0.2711 | Comp: 0.2881
No improvement. Patience: 1/5


Epoch 4/60: 100%|██████████| 20/20 [00:16<00:00,  1.24it/s, acc=0.143, comp=0.214, loss=0.244]


[Epoch 4] Train Loss: 0.2781 | Acc: 0.2441 | Comp: 0.2388
               Val   Loss: 0.2495 | Acc: 0.1829 | Comp: 0.2129
Saved new best model at epoch 4


Epoch 5/60: 100%|██████████| 20/20 [00:15<00:00,  1.29it/s, acc=0.181, comp=0.233, loss=0.16]


[Epoch 5] Train Loss: 0.2125 | Acc: 0.2078 | Comp: 0.2290
               Val   Loss: 0.2842 | Acc: 0.1725 | Comp: 0.2385
No improvement. Patience: 1/5


Epoch 6/60: 100%|██████████| 20/20 [00:15<00:00,  1.27it/s, acc=0.217, comp=0.187, loss=0.172]


[Epoch 6] Train Loss: 0.1743 | Acc: 0.1871 | Comp: 0.2039
               Val   Loss: 0.3049 | Acc: 0.1604 | Comp: 0.2782
No improvement. Patience: 2/5


Epoch 7/60: 100%|██████████| 20/20 [00:15<00:00,  1.28it/s, acc=0.414, comp=0.296, loss=0.383]


[Epoch 7] Train Loss: 0.1588 | Acc: 0.1778 | Comp: 0.1981
               Val   Loss: 0.2575 | Acc: 0.1614 | Comp: 0.1956
No improvement. Patience: 3/5


Epoch 8/60: 100%|██████████| 20/20 [00:15<00:00,  1.27it/s, acc=0.114, comp=0.21, loss=0.128]


[Epoch 8] Train Loss: 0.1361 | Acc: 0.1643 | Comp: 0.1956
               Val   Loss: 0.2408 | Acc: 0.1571 | Comp: 0.2076
Saved new best model at epoch 8


Epoch 9/60: 100%|██████████| 20/20 [00:15<00:00,  1.27it/s, acc=0.544, comp=0.178, loss=0.387]


[Epoch 9] Train Loss: 0.1273 | Acc: 0.1563 | Comp: 0.1857
               Val   Loss: 0.2376 | Acc: 0.1530 | Comp: 0.2342
Saved new best model at epoch 9


Epoch 10/60: 100%|██████████| 20/20 [00:15<00:00,  1.28it/s, acc=0.115, comp=0.193, loss=0.106]


[Epoch 10] Train Loss: 0.1235 | Acc: 0.1566 | Comp: 0.1851
               Val   Loss: 0.2386 | Acc: 0.1543 | Comp: 0.2226
No improvement. Patience: 1/5


Epoch 11/60: 100%|██████████| 20/20 [00:16<00:00,  1.23it/s, acc=0.0982, comp=0.185, loss=0.099]


[Epoch 11] Train Loss: 0.1137 | Acc: 0.1479 | Comp: 0.1781
               Val   Loss: 0.2769 | Acc: 0.1552 | Comp: 0.2588
No improvement. Patience: 2/5


Epoch 12/60: 100%|██████████| 20/20 [00:15<00:00,  1.28it/s, acc=0.131, comp=0.16, loss=0.108]


[Epoch 12] Train Loss: 0.1084 | Acc: 0.1455 | Comp: 0.1767
               Val   Loss: 0.2772 | Acc: 0.1571 | Comp: 0.2374
No improvement. Patience: 3/5


Epoch 13/60: 100%|██████████| 20/20 [00:15<00:00,  1.31it/s, acc=0.103, comp=0.196, loss=0.0818]


[Epoch 13] Train Loss: 0.1031 | Acc: 0.1411 | Comp: 0.1759
               Val   Loss: 0.2453 | Acc: 0.1490 | Comp: 0.2254
No improvement. Patience: 4/5


Epoch 14/60: 100%|██████████| 20/20 [00:15<00:00,  1.29it/s, acc=0.116, comp=0.128, loss=0.0584]
                                                         

[Epoch 14] Train Loss: 0.1004 | Acc: 0.1415 | Comp: 0.1704
               Val   Loss: 0.2567 | Acc: 0.1509 | Comp: 0.2361
No improvement. Patience: 5/5
Early stopping at epoch 14. Best model from epoch 9




In [None]:
# Evaluate the saved best checkpoint on the test split and write the predicted point clouds to ./test.
test(device)

                                                     

[Test] Loss: 0.8222 | Acc: 0.5178 | Comp: 0.1784


