In [19]:
from tqdm import tqdm
import os
from pathlib import Path
import numpy as np
import plotly.graph_objs as go
from plyfile import PlyData, PlyElement

# Global constant for grid step size
GRID_STEP = 1

def create_grid_cells(vertex_data_1, vertex_data_2):
    # Combine all points to find global min/max
    all_points = np.vstack((vertex_data_1, vertex_data_2))
    
    # Find min and max coordinates
    min_coords = np.min(all_points, axis=0)
    max_coords = np.max(all_points, axis=0)
    
    # Initialize lists for box vertices and indices
    boxes_vertices = []
    occupied_cells = {}  # Changed to dict to store z ranges
    
    # Find occupied cells and track z ranges
    for point in all_points:
        ix = int((point[0] - min_coords[0]) // GRID_STEP)
        iy = int((point[1] - min_coords[1]) // GRID_STEP)
        cell_key = (ix, iy)
        
        if cell_key not in occupied_cells:
            occupied_cells[cell_key] = {'min_z': point[2], 'max_z': point[2]}
        else:
            occupied_cells[cell_key]['min_z'] = min(occupied_cells[cell_key]['min_z'], point[2])
            occupied_cells[cell_key]['max_z'] = max(occupied_cells[cell_key]['max_z'], point[2])
    
    # Create vertices for occupied cells
    for (ix, iy), z_range in occupied_cells.items():
        x = min_coords[0] + ix * GRID_STEP
        y = min_coords[1] + iy * GRID_STEP
        z_min = z_range['min_z']
        z_max = z_range['max_z']
        
        # Create box vertices
        boxes_vertices.extend([
            # Bottom face
            [x, y, z_min], [x + GRID_STEP, y, z_min],
            [x + GRID_STEP, y, z_min], [x + GRID_STEP, y + GRID_STEP, z_min],
            [x + GRID_STEP, y + GRID_STEP, z_min], [x, y + GRID_STEP, z_min],
            [x, y + GRID_STEP, z_min], [x, y, z_min],
            # Top face
            [x, y, z_max], [x + GRID_STEP, y, z_max],
            [x + GRID_STEP, y, z_max], [x + GRID_STEP, y + GRID_STEP, z_max],
            [x + GRID_STEP, y + GRID_STEP, z_max], [x, y + GRID_STEP, z_max],
            [x, y + GRID_STEP, z_max], [x, y, z_max],
            # Vertical edges
            [x, y, z_min], [x, y, z_max],
            [x + GRID_STEP, y, z_min], [x + GRID_STEP, y, z_max],
            [x + GRID_STEP, y + GRID_STEP, z_min], [x + GRID_STEP, y + GRID_STEP, z_max],
            [x, y + GRID_STEP, z_min], [x, y + GRID_STEP, z_max],
        ])
    
    return np.array(boxes_vertices)

# Function to load and sample every 10,000th vertex from a .ply file
def load_sampled_ply(file_path, sample_step=10000):
    ply_data = PlyData.read(file_path)
    vertex_data = []
    for i, v in enumerate(ply_data['vertex'].data):
        if i % sample_step == 0:
            vertex_data.append((v['x'], v['y'], v['z']))
    return np.array(vertex_data)

# Load and sample the first .ply file
vertex_data_1 = load_sampled_ply('splats/trained_export_m7_1_8_adc.ply')

# Load and sample the second .ply file
vertex_data_2 = load_sampled_ply('splats/trained_export_m7_2_8_adc.ply')

# Calculate the range of each axis for aspect ratio
x_range = max(vertex_data_1[:, 0].max(), vertex_data_2[:, 0].max()) - min(vertex_data_1[:, 0].min(), vertex_data_2[:, 0].min())
y_range = max(vertex_data_1[:, 1].max(), vertex_data_2[:, 1].max()) - min(vertex_data_1[:, 1].min(), vertex_data_2[:, 1].min())
z_range = max(vertex_data_1[:, 2].max(), vertex_data_2[:, 2].max()) - min(vertex_data_1[:, 2].min(), vertex_data_2[:, 2].min())
aspect_ratio = dict(x=x_range, y=y_range, z=z_range)


grid_vertices = create_grid_cells(vertex_data_1, vertex_data_2)
# Add grid trace
grid_trace = go.Scatter3d(
    x=grid_vertices[:, 0],
    y=grid_vertices[:, 1],
    z=grid_vertices[:, 2],
    mode='lines',
    line=dict(color='gray', width=1),
    name='Grid',
    showlegend=True
)


# Create 3D scatter plots for each point cloud
trace1 = go.Scatter3d(
    x=vertex_data_1[:, 0],
    y=vertex_data_1[:, 1],
    z=vertex_data_1[:, 2],
    mode='markers',
    marker=dict(
        size=2,
        color='blue',
        opacity=0.5
    ),
    name='Point Cloud 1'
)

trace2 = go.Scatter3d(
    x=vertex_data_2[:, 0],
    y=vertex_data_2[:, 1],
    z=vertex_data_2[:, 2],
    mode='markers',
    marker=dict(
        size=2,
        color='red',
        opacity=0.5
    ),
    name='Point Cloud 2'
)

layout = go.Layout(
    scene=dict(
        xaxis=dict(title='X'),
        yaxis=dict(title='Y'),
        zaxis=dict(title='Z'),
        aspectratio=aspect_ratio
    ),
    title='Interactive 3D Point Cloud Visualization of Two Files'
)


def export_grid_cells(input_ply_path, grid_step=GRID_STEP):
    # Create output directory
    output_dir = f"grid-{grid_step}"
    os.makedirs(output_dir, exist_ok=True)
    
    # Read input PLY file
    print(f"Reading {input_ply_path}...")
    ply_data = PlyData.read(input_ply_path)
    vertex_data = ply_data['vertex'].data
    
    # Get base filename without extension
    base_name = Path(input_ply_path).stem
    
    # Create dictionary to store vertices by cell
    cell_vertices = {}
    
    # Process vertices and assign to cells
    print("Assigning vertices to cells...")
    for vertex in tqdm(vertex_data):
        # Calculate cell indices directly from global coordinates
        cell_x = int(vertex['x'] // grid_step)
        cell_y = int(vertex['y'] // grid_step)
        cell_key = (cell_x, cell_y)
        
        if cell_key not in cell_vertices:
            cell_vertices[cell_key] = []
        cell_vertices[cell_key].append(vertex)
    
    # Export each cell as a separate PLY file
    print("Exporting cell files...")
    for (cell_x, cell_y), vertices in tqdm(cell_vertices.items()):
        # Create vertex array for this cell
        vertex_array = np.array(vertices, dtype=vertex_data.dtype)
        
        # Create PLY element
        vertex_element = PlyElement.describe(vertex_array, 'vertex')
        
        # Create output filename
        output_filename = f"{base_name}_{grid_step}s_{cell_x}x_{cell_y}y.ply"
        output_path = os.path.join(output_dir, output_filename)
        
        # Save PLY file
        PlyData([vertex_element], text=False).write(output_path)
    
    print(f"Export complete! Files saved in {output_dir}/")
    print(f"Total cells created: {len(cell_vertices)}")
    

# Example usage:
export_grid_cells('splats/trained_export_m7_1_8_adc.ply')
# # Update figure with grid
# fig = go.Figure(data=[trace1, trace2, grid_trace], layout=layout)
# fig.show()


Reading splats/trained_export_m7_1_8_adc.ply...
Assigning vertices to cells...


100%|██████████| 17698442/17698442 [01:02<00:00, 284636.79it/s]


Exporting cell files...


100%|██████████| 437/437 [27:50<00:00,  3.82s/it]  

Export complete! Files saved in grid-1/
Total cells created: 437



