In [1]:
import numpy as np
import argparse
from pruning import prune_3dgs
from utils import load_pointcept_data, load_3dgs_data, update_3dgs_attributes, remove_duplicates, remove_outliers_3dgs
from tqdm import tqdm
import os
import open3d as o3d
from plyfile import PlyData, PlyElement

def _as_three_columns(values, name):
    """
    Return the first three columns of ``values`` as an (N, 3) array.

    An empty input is treated as zero rows so that an empty cloud can still
    be written.

    Args:
        values (array-like): Per-point rows with at least three columns.
        name (str): Argument name used in the error message.

    Raises:
        ValueError: If ``values`` is non-empty and is not 2-D with at least
            three columns.
    """
    arr = np.asarray(values)
    if arr.size == 0:
        return arr.reshape(0, 3)
    if arr.ndim != 2 or arr.shape[1] < 3:
        raise ValueError(f"{name} must have shape (N, 3), got {arr.shape}")
    return arr[:, :3]


def save_ply(points, colors, labels, output_path):
    """
    Save a labelled, coloured point cloud as an ASCII PLY file.

    Args:
        points (np.ndarray): Point coordinates, shape (N, 3).
        colors (np.ndarray): Point colours, shape (N, 3), expected in the
            0-255 range (stored as unsigned 8-bit values).
        labels (np.ndarray): Per-point labels, shape (N,) or (N, 1).
        output_path (str): Destination path of the PLY file.

    Raises:
        ValueError: If the inputs are not row-aligned, i.e. they have a wrong
            shape or a different number of points. Nothing is written in
            that case.
    """
    xyz = _as_three_columns(points, "points")
    rgb = _as_three_columns(colors, "colors")
    label_column = np.asarray(labels).reshape(-1)

    # A length mismatch means the per-point attributes are misaligned; fail
    # loudly instead of silently writing a truncated cloud.
    if not (len(xyz) == len(rgb) == len(label_column)):
        raise ValueError(
            "points, colors and labels must describe the same number of points, "
            f"got {len(xyz)}, {len(rgb)} and {len(label_column)}"
        )

    # Fill the structured array column by column; this avoids building one
    # Python tuple per point, which dominates runtime for large clouds.
    vertex = np.empty(
        len(xyz),
        dtype=[
            ('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
            ('red', 'u1'), ('green', 'u1'), ('blue', 'u1'),
            ('label', 'i4')
        ]
    )
    for column, field in enumerate(('x', 'y', 'z')):
        vertex[field] = xyz[:, column]
    for column, field in enumerate(('red', 'green', 'blue')):
        vertex[field] = rgb[:, column]
    vertex['label'] = label_column

    el = PlyElement.describe(vertex, 'vertex')
    PlyData([el], text=True).write(output_path)
    print(f"Saved PLY file to {output_path}")

def merge_pointcept_with_3dgs(pointcept_dir, path_3dgs, output_dir, prune_methods=None, prune_params=None):
    """
    Merge a Pointcept point cloud with a 3DGS point cloud and save the result as a PLY file.

    The 3DGS points are pruned, given colours/labels derived from the
    Pointcept points, de-duplicated and outlier-filtered (each step is
    delegated to a helper imported from ``pruning`` / ``utils``) before being
    appended after the Pointcept points. Only coordinates, colours and the
    ``segment20`` labels are written to the output file.

    Args:
        pointcept_dir (str): Pointcept data directory (e.g. scannet/train/scene0000_00).
        path_3dgs (str): Path to the 3DGS point cloud.
        output_dir (str): Output directory (e.g. scannet_merged/train/scene0000_00);
            created if it does not exist.
        prune_methods (dict): Pruning methods to apply and their ratios.
            Defaults to distance + SOR pruning only.
        prune_params (dict): Pruning hyper-parameters. Defaults to a 0.05
            distance threshold and SOR with 20 neighbours / std ratio 2.0.
    """
    # Fall back to the built-in configuration when the caller supplies none.
    prune_methods = prune_methods if prune_methods is not None else {
        'scale': False,
        'scale_ratio': 0.0,
        'opacity': False,
        'opacity_ratio': 0.0,
        'distance': True,
        'sor': True
    }
    prune_params = prune_params if prune_params is not None else {
        'distance_threshold': 0.05,
        'sor_nb_neighbors': 20,
        'sor_std_ratio': 2.0
    }

    # 1. Load the Pointcept scene (stored as .npy files).
    scene = load_pointcept_data(pointcept_dir)
    base_xyz = scene['coord']
    base_rgb = scene['color']
    base_seg20 = scene['segment20']
    base_seg200 = scene['segment200']
    base_inst = scene['instance']

    # 2. Load the 3DGS cloud.
    gs_xyz, gs_normals, gs_vertex = load_3dgs_data(path_3dgs)

    # 3. Prune the 3DGS points.
    gs_xyz, gs_normals, gs_vertex = prune_3dgs(
        gs_vertex, gs_xyz, gs_normals, base_xyz, prune_methods, prune_params
    )

    # 4. Copy colours and labels onto the 3DGS points from nearby Pointcept points.
    gs_rgb, gs_seg20, gs_seg200, gs_inst = update_3dgs_attributes(
        gs_xyz, base_xyz, base_rgb, base_seg20, base_seg200, base_inst
    )

    # 5. Remove duplicates (the 3DGS points are the ones dropped).
    gs_xyz, gs_normals, gs_seg20, gs_seg200, gs_inst, gs_rgb = remove_duplicates(
        gs_xyz, base_xyz, gs_normals, gs_seg20, gs_seg200, gs_inst, gs_rgb
    )

    # 6. Apply outlier removal to the remaining 3DGS points.
    gs_xyz, gs_rgb, gs_normals, gs_seg20, gs_seg200, gs_inst = remove_outliers_3dgs(
        gs_xyz, gs_rgb, gs_normals, gs_seg20, gs_seg200, gs_inst
    )

    # 7. Merge: Pointcept points first, surviving 3DGS points after them.
    merged_xyz = np.vstack([base_xyz, gs_xyz])
    merged_rgb = np.vstack([base_rgb, gs_rgb])
    merged_labels = np.concatenate([base_seg20, gs_seg20])
    print(f"Final merged points: {len(merged_xyz)} (Pointcept: {len(base_xyz)}, 3DGS: {len(gs_xyz)})")

    # 8. Write the merged cloud as a PLY file.
    os.makedirs(output_dir, exist_ok=True)
    save_ply(merged_xyz, merged_rgb, merged_labels, os.path.join(output_dir, "merged_point_cloud.ply"))


def process_single_scene(scene, input_root, output_root, path_3dgs_root, split, prune_methods, prune_params):
    """
    Process one scene: locate its inputs, merge them and write the PLY output.

    A message is printed and nothing is written when either input is missing.
    Any exception raised while processing is caught and reported with a
    printed message rather than propagated.

    Args:
        scene (str): Name of the scene to process.
        input_root (str): Root directory of the input Pointcept data.
        output_root (str): Root directory for the output.
        path_3dgs_root (str): Root directory of the 3DGS data.
        split (str): 'train' or 'val'.
        prune_methods (dict): Pruning methods to apply and their ratios.
        prune_params (dict): Pruning hyper-parameters.
    """
    try:
        # Input locations.
        pointcept_dir = os.path.join(input_root, split, scene)
        path_3dgs = os.path.join(path_3dgs_root, scene, "point_cloud.ply")
        # Output location.
        output_dir = os.path.join(output_root, split, scene)

        # The presence of coord.npy is used as the marker for a usable Pointcept scene.
        coord_file = os.path.join(pointcept_dir, "coord.npy")

        if not os.path.exists(coord_file):
            print(f"Pointcept data not found: {pointcept_dir}")
        elif not os.path.exists(path_3dgs):
            print(f"3DGS PLY file not found: {path_3dgs}")
        else:
            # Both inputs exist: merge them and save the PLY file.
            merge_pointcept_with_3dgs(pointcept_dir, path_3dgs, output_dir, prune_methods, prune_params)
    except Exception as err:
        print(f"Error processing scene {scene}: {err}")

if __name__ == "__main__":
    # Command-line entry point: process one ScanNet scene and write a PLY for visualisation.
    cli = argparse.ArgumentParser(description="Process a single ScanNet scene with 3DGS merging and save as PLY for visualization.")

    # Scene selection.
    cli.add_argument("--scene", default="scene0010_00", help="Scene to process (e.g., scene0010_00)")
    cli.add_argument("--split", default="train", choices=["train", "val"], help="Split to process (train or val, default: train)")

    # Dataset locations.
    cli.add_argument(
        "--input_root",
        default="/home/knuvi/Desktop/song/Pointcept/data/scannet",
        help="Path to the ScanNet dataset root (default: /home/knuvi/Desktop/song/Pointcept/data/scannet)",
    )
    cli.add_argument(
        "--output_root",
        default="/home/knuvi/Desktop/song/Pointcept/data/scannet_merged",
        help="Output path for processed data (default: /home/knuvi/Desktop/song/Pointcept/data/scannet_merged)",
    )
    cli.add_argument(
        "--path_3dgs_root",
        default="/home/knuvi/Desktop/song/data/3dgs_scans/3dgs_output",
        help="Path to the 3DGS dataset (default: /home/knuvi/Desktop/song/data/3dgs_scans/3dgs_output)",
    )

    # Pruning controls.
    cli.add_argument(
        "--scale_ratio",
        default=0.3,
        type=float,
        help="Ratio of points to prune based on scale (top X%). If > 0, scale pruning is enabled.",
    )
    cli.add_argument(
        "--opacity_ratio",
        default=0.5,
        type=float,
        help="Ratio of points to prune based on opacity (bottom X%). If > 0, opacity pruning is enabled.",
    )
    cli.add_argument("--enable_distance", action="store_true", help="Enable distance-based pruning")
    cli.add_argument("--enable_sor", action="store_true", help="Enable SOR-based pruning")

    opts = cli.parse_args()

    # Pruning method selection: scale and opacity pruning are switched on
    # whenever their ratio is positive; distance and SOR pruning are opt-in flags.
    prune_methods = dict(
        scale=opts.scale_ratio > 0,
        scale_ratio=opts.scale_ratio,
        opacity=opts.opacity_ratio > 0,
        opacity_ratio=opts.opacity_ratio,
        distance=opts.enable_distance,
        sor=opts.enable_sor,
    )

    # Pruning hyper-parameters (fixed; not exposed on the command line).
    prune_params = dict(
        distance_threshold=0.05,
        sor_nb_neighbors=20,
        sor_std_ratio=2.0,
    )

    # Process the single requested scene.
    process_single_scene(
        opts.scene,
        opts.input_root,
        opts.output_root,
        opts.path_3dgs_root,
        opts.split,
        prune_methods,
        prune_params,
    )

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
Initial Pointcept points: 237360
Initial 3DGS points: 314733
Scale (log) distribution: min=-5.604605, max=0.279888, mean=-3.512323
Scale (log) percentiles: 10th=-4.173076, 90th=-2.831165
Exp(Scale) distribution: min=0.003681, max=1.322982, mean=0.035132
Exp(Scale) percentiles: 10th=0.015405, 50th=0.029063, 90th=0.058944
Opacity (logit) distribution: min=-6.829210, max=10.873964, mean=-1.618875
Opacity (logit) percentiles: 10th=-3.813301, 90th=1.171892
Sigmoid(Opacity) distribution: min=0.001081, max=0.999981, mean=0.250290
Sigmoid(Opacity) percentiles: 10th=0.021598, 50th=0.128078, 90th=0.763487
Dynamic Scale Threshold (top 40.0%): 0.032801
Inside prune_3dgs_by_scale - Exp(Scale) distribution: min=0.003681, max=1.322982, mean=0.035132
Inside prune_3dgs_by_scale - Exp(Scale) percentiles: 10th=0.015405, 50th=0.029063, 90t