In [None]:
import os
import numpy as np
import h5py
from tqdm import tqdm
from plyfile import PlyData
from your_module import process_all_files, label_point_cloud  # Replace with actual module name

In [None]:
def load_ply_to_array(ply_path):
    """Read a PLY file and return its vertex coordinates as a single array.

    Args:
        ply_path: Path of the ``.ply`` file to read.

    Returns:
        A 2-D NumPy array with one row per vertex whose three columns are
        the ``x``, ``y`` and ``z`` properties of the ``vertex`` element.
    """
    vertices = PlyData.read(ply_path)['vertex']
    # Stack the three coordinate properties as rows, then transpose so each
    # vertex becomes one (x, y, z) row.
    coordinate_rows = [vertices[axis] for axis in ('x', 'y', 'z')]
    return np.vstack(coordinate_rows).T

def load_npz_to_array(npz_path):
    """Load the ``points`` and ``labels`` arrays from an ``.npz`` archive.

    Args:
        npz_path: Path of an ``.npz`` file that stores arrays under the
            keys ``points`` and ``labels``.

    Returns:
        A ``(points, labels)`` tuple of NumPy arrays.

    Raises:
        KeyError: If the archive lacks a ``points`` or ``labels`` entry.
    """
    # np.load() on an .npz returns a lazily-read NpzFile that keeps the
    # underlying file open; the context manager closes it deterministically.
    # Each array is fully read into memory on item access, so the returned
    # arrays stay valid after the archive is closed.
    with np.load(npz_path) as archive:
        return archive['points'], archive['labels']

def _find_subject_inputs(subject_path):
    """Return ``(obj_file, txt_file)`` paths found directly in *subject_path*.

    Either element is ``None`` when no file with that extension is present.
    Entries are scanned in sorted order, so when several files share an
    extension the alphabetically last one is chosen deterministically.
    """
    obj_file = None
    txt_file = None
    for entry in sorted(os.listdir(subject_path)):
        if entry.endswith('.obj'):
            obj_file = os.path.join(subject_path, entry)
        elif entry.endswith('.txt'):
            txt_file = os.path.join(subject_path, entry)
    return obj_file, txt_file


def main(data_root, output_h5_path, temp_output_folder='temp_outputs'):
    """Build one HDF5 dataset from every subject folder under *data_root*.

    For each sub-directory that contains both an ``.obj`` and a ``.txt``
    file, the unlabeled point cloud and its labeled counterpart are
    generated, loaded and collected. All unlabeled clouds (labelled 0)
    followed by all labeled clouds are then concatenated and written to
    *output_h5_path* as the datasets ``data`` (float32) and ``label``
    (int64).

    Args:
        data_root: Directory whose sub-directories are the subjects.
        output_h5_path: Destination path of the ``.h5`` file; its parent
            directory is created if it does not exist.
        temp_output_folder: Directory for the intermediate ``.ply`` and
            ``.npz`` files; created if missing.

    Raises:
        ValueError: If no subject produced any data, so there is nothing
            to write.
    """
    os.makedirs(temp_output_folder, exist_ok=True)

    labeled_data = []
    labeled_labels = []
    unlabeled_data = []
    unlabeled_labels = []

    # os.listdir() returns entries in arbitrary order; sorting makes the row
    # order of the resulting dataset reproducible across runs and machines.
    for subject in tqdm(sorted(os.listdir(data_root)), desc="Processing subjects"):
        subject_path = os.path.join(data_root, subject)
        if not os.path.isdir(subject_path):
            continue

        obj_file, txt_file = _find_subject_inputs(subject_path)
        if not (obj_file and txt_file):
            continue

        # Step 1: Generate unlabeled .ply
        # process_all_files is expected to write "<subject>_unlabeled.ply"
        # into temp_output_folder; that is verified right below.
        process_all_files(subject_path, temp_output_folder, visualize=False)
        ply_path = os.path.join(temp_output_folder, f"{subject}_unlabeled.ply")
        if not os.path.exists(ply_path):
            print(f"Missing expected .ply file: {ply_path}")
            continue

        # Step 2: Label the point cloud
        # NOTE(review): the .npz path is not passed to label_point_cloud;
        # presumably it derives "<subject>_labeled.npz" itself - confirm.
        npz_output_path = os.path.join(temp_output_folder, f"{subject}_labeled.npz")
        label_point_cloud(ply_path, txt_file, radius=0.006, output_ply=None, save_npz=True)

        if not os.path.exists(npz_output_path):
            print(f"Missing labeled .npz file: {npz_output_path}")
            continue

        # Step 3: Load and collect data
        unlabeled_points = load_ply_to_array(ply_path)
        labeled_points, labels = load_npz_to_array(npz_output_path)

        unlabeled_data.append(unlabeled_points)
        # Every point of the unlabeled cloud gets label 0.
        unlabeled_labels.append(np.zeros(unlabeled_points.shape[0], dtype=np.int32))

        labeled_data.append(labeled_points)
        labeled_labels.append(labels)  # Assume labels are already correct

    # np.concatenate() fails with an opaque message on an empty list; raise
    # the same exception type with an explanation instead.
    if not unlabeled_data:
        raise ValueError(
            f"No subject under {data_root!r} produced both a point cloud and "
            f"labels; nothing to write to {output_h5_path!r}."
        )

    # Step 4: Stack and write to HDF5
    all_points = np.concatenate(unlabeled_data + labeled_data, axis=0)
    all_labels = np.concatenate(unlabeled_labels + labeled_labels, axis=0)

    # Make sure the destination directory exists so the write cannot fail
    # on a missing folder after all the processing has been done.
    output_dir = os.path.dirname(output_h5_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print(f"Saving to HDF5: {output_h5_path}")
    with h5py.File(output_h5_path, 'w') as f:
        f.create_dataset('data', data=all_points.astype(np.float32))
        f.create_dataset('label', data=all_labels.astype(np.int64))

    print("Done.")

if __name__ == "__main__":
    # Command-line entry point: parse the three path options and hand them
    # to main() in the order it expects.
    import argparse

    cli = argparse.ArgumentParser(
        description="Generate HDF5 dataset for PointNet++",
    )
    cli.add_argument(
        '--data_root',
        type=str,
        required=True,
        help="Root folder of subject folders",
    )
    cli.add_argument(
        '--output_h5',
        type=str,
        required=True,
        help="Output path for .h5 file",
    )
    cli.add_argument(
        '--temp_output',
        type=str,
        default='temp_outputs',
        help="Temporary folder for intermediate files",
    )
    options = cli.parse_args()

    main(
        options.data_root,
        options.output_h5,
        options.temp_output,
    )
