In this notebook, we explore the ODMSemantic3D dataset and how to preprocess it for training a semantic segmentation model.

The following are some of the dataset's labels, as listed in its GitHub repository:


- ground: 2
- low_vegetation: 3
- building: 6
- human_made_object: 64

In [None]:
import numpy as np

# Load one of the preprocessed .npz files produced by dataset.py.
# np.load returns an NpzFile that keeps the underlying zip archive open, so it
# is used as a context manager. Indexing a key reads that whole array into
# memory, so `points` and `labels` stay valid after the file is closed.
with np.load("/Users/hussainabass/Documents/pointcloud-segmentation-stub/data/ODMSemantic3D/datasets/odm_data_waterbury-roads_2.npz") as data:
    points = data['pointclouds']   # expected shape: (num_samples, num_points, 3)
    labels = data['labels']   # expected shape: (num_samples, num_points)

# Shapes/dtypes are more informative than the (truncated) array dump alone.
print(f"pointclouds shape: {points.shape}, dtype: {points.dtype}")
print(f"labels shape: {labels.shape}, dtype: {labels.dtype}")
print(points)


Next, we plot the first sample from one of the .npz files produced by dataset.py.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# Top-down (X/Y) view of the first sample loaded in the previous cell, with one
# colour per label. Relies on `points` and `labels` defined there.
pc_full = points[0]  # (N, 3); columns 0 and 1 are plotted as X and Y
lbl_full = labels[0]  # (N,) per-point labels

# Calculate full ranges for X and Y
x_min, x_max = pc_full[:, 0].min(), pc_full[:, 0].max()
y_min, y_max = pc_full[:, 1].min(), pc_full[:, 1].max()

print(f"Total points: {len(pc_full):,}")
print(f"X range: [{x_min:.2f}, {x_max:.2f}]")
print(f"Y range: [{y_min:.2f}, {y_max:.2f}]")

# Create the plot
fig, ax = plt.subplots(figsize=(12, 10))

# One colour per label. tab20 keeps up to 20 labels distinct; tab10 would
# start repeating colours after 10.
unique_labels = np.unique(lbl_full)
colors = plt.cm.tab20(np.linspace(0, 1, len(unique_labels)))
color_map = {label: colors[i] for i, label in enumerate(unique_labels)}

# Randomly subsample very large clouds so the scatter stays responsive.
# A fixed seed keeps the figure reproducible between runs.
max_points_to_plot = 10_000_000  # Adjust based on performance needs
if len(pc_full) > max_points_to_plot:
    rng = np.random.default_rng(0)
    indices = rng.choice(len(pc_full), max_points_to_plot, replace=False)
    pc_plot = pc_full[indices]
    lbl_plot = lbl_full[indices]
    print(f"Sampling {max_points_to_plot:,} points for visualization (out of {len(pc_full):,} total)")
else:
    pc_plot = pc_full
    lbl_plot = lbl_full

# Plot points colored by their labels
for label in unique_labels:
    mask = lbl_plot == label
    if np.any(mask):  # a rare label can disappear after subsampling
        label_points = pc_plot[mask]
        ax.scatter(label_points[:, 0], label_points[:, 1],
                   c=[color_map[label]], s=10, alpha=0.6,
                   label=f'Label {label}', edgecolors='none')

# Set axis limits to show entire dataset
ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)

ax.set_xlabel('X Coordinate')
ax.set_ylabel('Y Coordinate')
ax.set_title('Point Cloud: Entire Dataset - Points Colored by Labels')
# markerscale enlarges the tiny (s=10) scatter markers inside the legend only.
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=16, markerscale=3)
# alpha=0 turned the grid on but made it invisible; use the same faint grid as
# the other plots in this notebook.
ax.grid(True, alpha=0.3)

# Use auto aspect instead of equal to avoid dimension issues
ax.set_aspect('auto')

plt.tight_layout()
plt.show()


It looks like some houses (as labelled) are present next to the roads.

Next, let's examine the other datasets. We also want to explore the local gradients, to see whether changes in Z are steep (step-like) enough for the model to pick up, or whether they are too gradual.


Let's examine the remaining datasets for their characteristics and data cleanliness (e.g. NaN values), and then plot them.


In [None]:
import numpy as np
from pathlib import Path

# Define the datasets directory
datasets_dir = Path("/Users/hussainabass/Documents/pointcloud-segmentation-stub/data/ODMSemantic3D/datasets")
npz_files = sorted(datasets_dir.glob("*.npz"))

# Keys of one summary row, in display order.
SUMMARY_KEYS = ('File', 'Samples', 'Shape', 'Labels', 'Label Values',
                'NaN Points', 'NaN Labels', 'Total Points')


def _has_nan(arr):
    """Return True if *arr* is a floating/complex array containing NaN.

    Integer, boolean and string arrays cannot hold NaN, so they are reported
    as clean without calling np.isnan (which rejects non-numeric dtypes).
    """
    return np.issubdtype(arr.dtype, np.inexact) and bool(np.isnan(arr).any())


def summarize_npz(npz_file):
    """Build one summary-table row for a preprocessed point-cloud archive.

    Parameters
    ----------
    npz_file : str or Path
        Archive holding a ``pointclouds`` (or, as a fallback, ``points``)
        array and optionally a ``labels`` array.

    Returns
    -------
    dict or None
        A row keyed by ``SUMMARY_KEYS``, or None when the archive has neither
        a ``pointclouds`` nor a ``points`` array. Label columns are ``'N/A'``
        when the archive has no ``labels`` array.

    Raises
    ------
    Whatever ``np.load`` raises for missing or corrupt files.
    """
    npz_file = Path(npz_file)
    # The context manager closes the zip archive; indexed arrays are already
    # read into memory, so they remain usable afterwards.
    with np.load(npz_file) as data:
        pointclouds = data['pointclouds'] if 'pointclouds' in data else data.get('points', None)
        labels = data.get('labels', None)

    if pointclouds is None:
        return None

    num_samples = pointclouds.shape[0]
    # Every axis except the trailing coordinate axis counts points:
    # (samples, points, 3) -> samples * points, and (points, 3) -> points.
    if pointclouds.ndim > 1:
        total_points = int(np.prod(pointclouds.shape[:-1]))
    else:
        total_points = int(pointclouds.shape[0])

    if labels is None:
        label_count, label_values, nan_labels = 'N/A', 'N/A', 'N/A'
    else:
        unique_labels = np.unique(labels).tolist()  # sorted Python scalars
        label_count = len(unique_labels)
        label_values = str(unique_labels)
        nan_labels = '⚠️' if _has_nan(labels) else '✓'

    return {
        'File': npz_file.name,
        'Samples': num_samples,
        'Shape': str(pointclouds.shape),
        'Labels': label_count,
        'Label Values': label_values,
        'NaN Points': '⚠️' if _has_nan(pointclouds) else '✓',
        'NaN Labels': nan_labels,
        'Total Points': total_points,
    }


results = []
for npz_file in npz_files:
    try:
        row = summarize_npz(npz_file)
    except Exception as e:
        # Keep going so one unreadable archive doesn't hide the others, but
        # keep the reason so it can be reported below.
        row = dict.fromkeys(SUMMARY_KEYS, 'ERROR')
        row['File'] = npz_file.name
        row['Error'] = f"{type(e).__name__}: {e}"
    if row is not None:
        results.append(row)

# Print summary table
print(f"{'File':<45} {'Samples':<10} {'Shape':<25} {'Labels':<8} {'NaN P':<8} {'NaN L':<8} {'Total Points':<15}")
print("-" * 130)

for r in results:
    total_pts = f"{r['Total Points']:,}" if isinstance(r['Total Points'], int) else str(r['Total Points'])
    print(f"{r['File']:<45} {str(r['Samples']):<10} {r['Shape']:<25} {str(r['Labels']):<8} {r['NaN Points']:<8} {r['NaN Labels']:<8} {total_pts:<15}")

print("\n" + "=" * 130)
print("Label Values by File:")
print("=" * 130)
for r in results:
    if r['Label Values'] not in ['N/A', 'ERROR']:
        print(f"{r['File']}: {r['Label Values']}")

failed = [r for r in results if 'Error' in r]
if failed:
    print("\nFiles that could not be summarized:")
    for r in failed:
        print(f"  {r['File']}: {r['Error']}")


In [None]:
import matplotlib.pyplot as plt
from scipy.interpolate import griddata, RegularGridInterpolator
from pathlib import Path
import numpy as np

# Define the datasets directory
datasets_dir = Path("/Users/hussainabass/Documents/pointcloud-segmentation-stub/data/ODMSemantic3D/datasets")
npz_files = sorted(datasets_dir.glob("*.npz"))

# Configuration for performance
max_points_for_gradient = 100000  # Lower for faster processing
max_points_to_plot = 100000  # Lower number of points for visualization
grid_resolution = 150  # Lower resolution for faster gradient computation
point_size = 4  # Point size for scatter plots (modify this to change point size)
rng = np.random.default_rng(0)  # Fixed seed so the random subsamples are reproducible


def collect_all_labels(files):
    """Return the set of label values found across all archives in *files*.

    Archives that cannot be read are reported and skipped, so one bad file
    does not stop the colour map from being built. ``Exception`` is caught
    (not a bare ``except:``) so Ctrl-C still interrupts the cell.
    """
    found = set()
    for npz_file in files:
        try:
            with np.load(npz_file) as data:
                file_labels = data.get('labels', None)
                if file_labels is not None:
                    found.update(np.unique(file_labels).tolist())
        except Exception as e:
            print(f"⚠️ Skipping {npz_file.name} while collecting labels: {e}")
    return found


def build_label_color_map(label_values):
    """Return ``{label: RGBA}`` giving each label in *label_values* its own colour.

    tab20 offers 20 qualitative colours. For more labels, the continuous
    'turbo' map is sampled evenly instead; Set3 has only 12 colours, so it
    would repeat colours even sooner than tab20.
    """
    label_values = sorted(label_values)
    cmap = plt.cm.tab20 if len(label_values) <= 20 else plt.cm.turbo
    colors = cmap(np.linspace(0, 1, len(label_values)))
    return {label: colors[i] for i, label in enumerate(label_values)}


def compute_point_gradients(x, y, z, resolution):
    """Estimate the local slope magnitude |grad z| at every (x, y) point.

    z is linearly interpolated onto a regular ``resolution x resolution`` grid
    spanning the points' bounding box. The grid is differentiated using its
    real spacing, and the gradient magnitude is interpolated back onto the
    points. Because the spacing is taken into account, the result is rise
    over run (dimensionless if X, Y and Z share units) and no longer scales
    with ``resolution``.

    Grid nodes outside the points' convex hull have no z value (NaN). Points
    whose interpolation touches such nodes get a gradient of 0.
    """
    xi = np.linspace(x.min(), x.max(), resolution)
    yi = np.linspace(y.min(), y.max(), resolution)
    xi_grid, yi_grid = np.meshgrid(xi, yi)  # rows follow y (axis 0), columns follow x (axis 1)
    zi_grid = griddata((x, y), z, (xi_grid, yi_grid), method='linear', fill_value=np.nan)

    # Pass the coordinate arrays so derivatives are per unit distance, not per cell.
    dz_dy, dz_dx = np.gradient(zi_grid, yi, xi)
    magnitude = np.hypot(dz_dx, dz_dy)

    # The grid is regular, so a grid interpolator is the direct tool here
    # (griddata would first triangulate all resolution**2 grid nodes).
    interpolator = RegularGridInterpolator(
        (yi, xi), magnitude, method='linear', bounds_error=False, fill_value=0.0
    )
    return np.nan_to_num(interpolator(np.column_stack([y, x])), nan=0.0)


# First pass: collect all labels across all datasets so that a label gets the
# same colour in every figure.
all_unique_labels = sorted(collect_all_labels(npz_files))
consistent_color_map = build_label_color_map(all_unique_labels)

# Second pass: Visualize all datasets with consistent colors
for npz_file in npz_files:
    try:
        with np.load(npz_file) as data:
            pointclouds = data['pointclouds'] if 'pointclouds' in data else data.get('points', None)
            labels = data.get('labels', None)

        if pointclouds is None:
            continue

        # Only the first sample of each archive is visualized.
        pc_full = pointclouds[0]
        lbl_full = labels[0] if labels is not None else None

        # Subsample for the interpolation-heavy gradient computation.
        if len(pc_full) > max_points_for_gradient:
            indices = rng.choice(len(pc_full), max_points_for_gradient, replace=False)
            pc_grad = pc_full[indices]
            lbl_grad = lbl_full[indices] if lbl_full is not None else None
        else:
            pc_grad, lbl_grad = pc_full, lbl_full

        x, y, z = pc_grad[:, 0], pc_grad[:, 1], pc_grad[:, 2]
        x_min, x_max = x.min(), x.max()
        y_min, y_max = y.min(), y.max()

        gradient_at_points = compute_point_gradients(x, y, z, grid_resolution)

        # Optionally subsample again for plotting (a no-op when both limits match).
        if len(pc_grad) > max_points_to_plot:
            plot_indices = rng.choice(len(pc_grad), max_points_to_plot, replace=False)
            pc_plot = pc_grad[plot_indices]
            lbl_plot = lbl_grad[plot_indices] if lbl_grad is not None else None
            grad_plot = gradient_at_points[plot_indices]
        else:
            pc_plot, lbl_plot, grad_plot = pc_grad, lbl_grad, gradient_at_points

        # Create figure with two subplots side by side
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 10))

        # ===== LEFT PLOT: Points colored by labels =====
        if lbl_plot is not None:
            for label in np.unique(lbl_plot):
                label_points = pc_plot[lbl_plot == label]
                # Use consistent color map; unseen labels fall back to gray
                color = consistent_color_map.get(label, 'gray')
                ax1.scatter(label_points[:, 0], label_points[:, 1],
                            c=[color], s=point_size, alpha=0.6,
                            edgecolors='none')
        else:
            # If no labels, just plot points in gray
            ax1.scatter(pc_plot[:, 0], pc_plot[:, 1],
                        c='gray', s=point_size, alpha=0.6, edgecolors='none')
            ax1.text(0.5, 0.5, 'No labels available',
                     transform=ax1.transAxes, ha='center', va='center', fontsize=14)

        ax1.set_xlim(x_min, x_max)
        ax1.set_ylim(y_min, y_max)
        ax1.set_xlabel('X Coordinate', fontsize=12)
        ax1.set_ylabel('Y Coordinate', fontsize=12)
        ax1.set_title(f'{npz_file.name}\nPoints Colored by Labels', fontsize=14)
        ax1.grid(True, alpha=0.3)
        ax1.set_aspect('auto')

        # ===== RIGHT PLOT: Points colored by gradient =====
        scatter = ax2.scatter(pc_plot[:, 0], pc_plot[:, 1],
                              c=grad_plot, s=point_size, alpha=0.6,
                              cmap='viridis', edgecolors='none')

        cbar = plt.colorbar(scatter, ax=ax2)
        cbar.set_label('Gradient Magnitude (rise / run)', fontsize=12)

        ax2.set_xlim(x_min, x_max)
        ax2.set_ylim(y_min, y_max)
        ax2.set_xlabel('X Coordinate', fontsize=12)
        ax2.set_ylabel('Y Coordinate', fontsize=12)
        ax2.set_title(f'{npz_file.name}\nPoints Colored by Local Gradient', fontsize=14)
        ax2.grid(True, alpha=0.3)
        ax2.set_aspect('auto')

        plt.tight_layout()
        plt.show()

    except Exception as e:
        print(f"❌ ERROR processing {npz_file.name}: {e}")

# Create a large legend figure at the end. Skip it when no labels were found:
# there would be nothing to show and the column count would be 0.
if all_unique_labels:
    fig_legend = plt.figure(figsize=(12, 8))
    ax_legend = fig_legend.add_subplot(111)
    ax_legend.axis('off')

    legend_elements = [
        plt.Line2D([0], [0], marker='o', color='w',
                   markerfacecolor=consistent_color_map[label], markersize=20,
                   label=f'Label {label}')
        for label in all_unique_labels
    ]

    legend = ax_legend.legend(handles=legend_elements, loc='center',
                              ncol=min(4, len(all_unique_labels)),
                              fontsize=24, frameon=True,
                              title='Label Color Mapping', title_fontsize=28)
    legend.get_frame().set_facecolor('white')
    legend.get_frame().set_alpha(0.9)

    plt.tight_layout()
    plt.show()
else:
    print("No labels found in any dataset; skipping the legend figure.")


Finally, let's take a quick look at the height (Z) of the points in each dataset. This shows how height relates to the labels and informs how we model it.

In [None]:
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np

# Define the datasets directory
datasets_dir = Path("/Users/hussainabass/Documents/pointcloud-segmentation-stub/data/ODMSemantic3D/datasets")
npz_files = sorted(datasets_dir.glob("*.npz"))

# Configuration
max_points_for_height = 100000  # Points for height visualization
point_size = 2  # Point size for scatter plots
rng = np.random.default_rng(0)  # Fixed seed so the random subsamples are reproducible

# First pass: collect every label across all datasets so that a label gets the
# same colour in every figure. Unreadable archives are reported and skipped;
# catching Exception (not a bare `except:`) lets Ctrl-C still interrupt the cell.
all_unique_labels = set()
for npz_file in npz_files:
    try:
        with np.load(npz_file) as data:
            file_labels = data.get('labels', None)
            if file_labels is not None:
                all_unique_labels.update(np.unique(file_labels).tolist())
    except Exception as e:
        print(f"⚠️ Skipping {npz_file.name} while collecting labels: {e}")

# Create consistent color map for all labels
all_unique_labels = sorted(all_unique_labels)
num_labels = len(all_unique_labels)
# tab20 has 20 distinct colours. For more labels, sample the continuous
# 'turbo' map evenly (Set3 has only 12 colours and would repeat sooner).
cmap = plt.cm.tab20 if num_labels <= 20 else plt.cm.turbo
colors = cmap(np.linspace(0, 1, num_labels))
consistent_color_map = {label: colors[i] for i, label in enumerate(all_unique_labels)}

# Second pass: for each archive, show height next to labels for its first sample.
for npz_file in npz_files:
    try:
        with np.load(npz_file) as data:
            pointclouds = data['pointclouds'] if 'pointclouds' in data else data.get('points', None)
            labels = data.get('labels', None)

        if pointclouds is None:
            continue

        # Get the first sample
        pc_full = pointclouds[0]
        lbl_full = labels[0] if labels is not None else None

        # Randomly subsample large clouds for plotting
        if len(pc_full) > max_points_for_height:
            height_indices = rng.choice(len(pc_full), max_points_for_height, replace=False)
            pc_plot = pc_full[height_indices]
            lbl_plot = lbl_full[height_indices] if lbl_full is not None else None
        else:
            pc_plot, lbl_plot = pc_full, lbl_full
        z_plot = pc_plot[:, 2]  # column 2 is plotted as height

        # Calculate ranges
        x_min, x_max = pc_plot[:, 0].min(), pc_plot[:, 0].max()
        y_min, y_max = pc_plot[:, 1].min(), pc_plot[:, 1].max()

        # Create figure with two subplots side by side
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 10))

        # ===== LEFT PLOT: Points colored by height (Z) =====
        scatter = ax1.scatter(pc_plot[:, 0], pc_plot[:, 1],
                              c=z_plot, s=point_size, alpha=0.6,
                              cmap='terrain', edgecolors='none')

        cbar = plt.colorbar(scatter, ax=ax1)
        cbar.set_label('Height (Z)', fontsize=14)

        ax1.set_xlim(x_min, x_max)
        ax1.set_ylim(y_min, y_max)
        ax1.set_xlabel('X Coordinate', fontsize=12)
        ax1.set_ylabel('Y Coordinate', fontsize=12)
        ax1.set_title(f'{npz_file.name}\nPoints Colored by Height (Z)', fontsize=14)
        ax1.grid(True, alpha=0.3)
        ax1.set_aspect('auto')

        # ===== RIGHT PLOT: Points colored by labels =====
        if lbl_plot is not None:
            for label in np.unique(lbl_plot):
                label_points = pc_plot[lbl_plot == label]
                # Use consistent color map; unseen labels fall back to gray
                color = consistent_color_map.get(label, 'gray')
                ax2.scatter(label_points[:, 0], label_points[:, 1],
                            c=[color], s=point_size, alpha=0.6,
                            edgecolors='none')
        else:
            # If no labels, just plot points in gray
            ax2.scatter(pc_plot[:, 0], pc_plot[:, 1],
                        c='gray', s=point_size, alpha=0.6, edgecolors='none')
            ax2.text(0.5, 0.5, 'No labels available',
                     transform=ax2.transAxes, ha='center', va='center', fontsize=14)

        ax2.set_xlim(x_min, x_max)
        ax2.set_ylim(y_min, y_max)
        ax2.set_xlabel('X Coordinate', fontsize=12)
        ax2.set_ylabel('Y Coordinate', fontsize=12)
        ax2.set_title(f'{npz_file.name}\nPoints Colored by Labels', fontsize=14)
        ax2.grid(True, alpha=0.3)
        ax2.set_aspect('auto')

        plt.tight_layout()
        plt.show()

    except Exception as e:
        print(f"❌ ERROR processing {npz_file.name}: {e}")


There appears to be a noticeable correlation between height and labels. PointNet++ is therefore a reasonable choice: it consumes the raw XYZ coordinates, so changes in height are implicitly available to the model.

Next, let's see how we might split the point clouds into patches during preprocessing, and test how quickly the patches can be extracted. If this takes a long time, we will stick to a small number of patches from a single dataset, which is enough to demonstrate the approach for this exercise.

In [None]:
import time
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
from matplotlib.patches import Rectangle

# Define the datasets directory
datasets_dir = Path("/Users/hussainabass/Documents/pointcloud-segmentation-stub/data/ODMSemantic3D/datasets")
npz_files = sorted(datasets_dir.glob("*.npz"))

# Configuration
num_seeds = 4000  # Seeds (= patches) per file; capped at the file's point count
points_per_patch = 4096
eps_factor = 0.01  # Fraction of point cloud extent for eps
max_bg_points = 50000  # Background points drawn under the patch boxes
rng = np.random.default_rng(42)  # Fixed seed so seeds and patches are reproducible


def extract_radius_patches(points, seed_points, eps, points_per_patch, rng):
    """Build one fixed-size patch around each seed point.

    A patch holds the points within Euclidean distance ``eps`` (over all
    coordinate columns) of its seed. It is then randomly subsampled down to,
    or padded by re-drawing neighbours up to, exactly ``points_per_patch``
    rows.

    This is deliberately brute force: every seed measures its distance to
    every point, which costs O(len(seed_points) * len(points)). That cost is
    what this cell is timing.

    Returns
    -------
    patches : list of arrays with ``points_per_patch`` rows and the same columns as *points*
    bboxes : list of (x_min, x_max, y_min, y_max) tuples, one per patch
    """
    patches, bboxes = [], []
    for seed_point in seed_points:
        nearby_points = points[np.linalg.norm(points - seed_point, axis=1) <= eps]
        n_nearby = len(nearby_points)
        if n_nearby == 0:
            # A finite seed is always within eps of itself, so this only
            # happens with NaN coordinates; skip the seed instead of crashing.
            continue
        if n_nearby < points_per_patch:
            # Pad by re-drawing existing neighbours (with replacement).
            extra = rng.choice(n_nearby, points_per_patch - n_nearby, replace=True)
            patch_points = np.vstack([nearby_points, nearby_points[extra]])
        elif n_nearby > points_per_patch:
            keep = rng.choice(n_nearby, points_per_patch, replace=False)
            patch_points = nearby_points[keep]
        else:
            patch_points = nearby_points
        patches.append(patch_points)
        bboxes.append((patch_points[:, 0].min(), patch_points[:, 0].max(),
                       patch_points[:, 1].min(), patch_points[:, 1].max()))
    return patches, bboxes


for npz_file in npz_files:
    try:
        print(f"\nProcessing: {npz_file.name}")

        with np.load(npz_file) as data:
            pointclouds = data['pointclouds'] if 'pointclouds' in data else data.get('points', None)
            labels = data.get('labels', None)

        if pointclouds is None:
            continue

        # Get the first sample
        points = pointclouds[0]
        lbl = labels[0] if labels is not None else None

        print(f"Total points: {len(points):,}")

        # eps is a fixed fraction of the mean extent across the coordinate axes
        point_ranges = points.max(axis=0) - points.min(axis=0)
        eps = np.mean(point_ranges) * eps_factor

        # Use a per-file seed count. Overwriting `num_seeds` here would
        # silently shrink it for every later file once a small file was seen.
        n_seeds = min(num_seeds, len(points))
        seed_points = points[rng.choice(len(points), n_seeds, replace=False)]

        # Time the extraction: this is the question this cell answers.
        start = time.perf_counter()
        patches, patch_bboxes = extract_radius_patches(points, seed_points, eps, points_per_patch, rng)
        elapsed = time.perf_counter() - start
        print(f"Extracted {len(patches)} patches in {elapsed:.1f} s")

        # Sample original point cloud for background visualization
        if len(points) > max_bg_points:
            bg_indices = rng.choice(len(points), max_bg_points, replace=False)
            pc_bg = points[bg_indices]
            lbl_bg = lbl[bg_indices] if lbl is not None else None
        else:
            pc_bg, lbl_bg = points, lbl

        # Calculate ranges
        x_min_all, x_max_all = pc_bg[:, 0].min(), pc_bg[:, 0].max()
        y_min_all, y_max_all = pc_bg[:, 1].min(), pc_bg[:, 1].max()

        fig, ax = plt.subplots(figsize=(16, 12))

        # Plot background point cloud (colored by labels if available)
        if lbl_bg is not None:
            unique_labels = sorted(np.unique(lbl_bg))
            colors = plt.cm.tab20(np.linspace(0, 1, len(unique_labels)))
            color_map = {label: colors[i] for i, label in enumerate(unique_labels)}

            for label in unique_labels:
                label_points = pc_bg[lbl_bg == label]
                ax.scatter(label_points[:, 0], label_points[:, 1],
                           c=[color_map[label]], s=0.3, alpha=0.2, edgecolors='none')
        else:
            ax.scatter(pc_bg[:, 0], pc_bg[:, 1], c='gray', s=0.3, alpha=0.2, edgecolors='none')

        # Plot seed points
        ax.scatter(seed_points[:, 0], seed_points[:, 1],
                   c='red', s=30, alpha=0.8, marker='x', linewidths=2,
                   label=f'Seed Points ({n_seeds})', zorder=5)

        # Draw bounding boxes for ALL patches
        for bx_min, bx_max, by_min, by_max in patch_bboxes:
            ax.add_patch(Rectangle((bx_min, by_min), bx_max - bx_min, by_max - by_min,
                                   linewidth=0.5, edgecolor='blue', facecolor='none',
                                   alpha=0.4, linestyle='--', zorder=3))

        ax.set_xlim(x_min_all, x_max_all)
        ax.set_ylim(y_min_all, y_max_all)
        ax.set_xlabel('X Coordinate', fontsize=12)
        ax.set_ylabel('Y Coordinate', fontsize=12)
        ax.set_title(f'{npz_file.name}\nAll {len(patches)} Patch Bounding Boxes Overlaid on Labels', fontsize=14)
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)
        ax.set_aspect('auto')

        plt.tight_layout()
        plt.show()

        # Print statistics (np.mean/np.min of an empty list would warn or raise)
        if patch_bboxes:
            bbox_areas = [(bx_max - bx_min) * (by_max - by_min)
                          for bx_min, bx_max, by_min, by_max in patch_bboxes]
            print("Bounding box statistics:")
            print(f"  Mean area: {np.mean(bbox_areas):.2f}")
            print(f"  Min area: {np.min(bbox_areas):.2f}")
            print(f"  Max area: {np.max(bbox_areas):.2f}")
            print(f"  Std area: {np.std(bbox_areas):.2f}")
        else:
            print("No patches extracted; skipping bounding box statistics.")

    except Exception as e:
        print(f"❌ ERROR processing {npz_file.name}: {e}")
        import traceback
        traceback.print_exc()


This takes a while, so below we extract only a small number of circular (ball-query) patches instead. We will use a similar method for our training run.

In [None]:
# Using the ball query function from dataset.py for PyTorch PointNet++ training
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
from pointcloud_segmentation.dataset import extract_patches_ball_query_from_npz, BallQueryPatchDataset
from torch.utils.data import DataLoader

# Extract patches using ball query from the dataset
npz_path = "/Users/hussainabass/Documents/pointcloud-segmentation-stub/data/ODMSemantic3D/datasets/odm_data_waterbury-roads_2.npz"

# Extraction settings, reused by the checks and plots below so the reported
# sizes and the drawn radius always match what was actually requested.
num_patches = 50  # Number of patches to extract
points_per_patch = 4096  # Exact number of points per patch
radius_percent = 0.02  # Ball radius as a fraction of max(x_range, y_range)
rng = np.random.default_rng(0)  # Reproducible background sample for the plots

print("Extracting patches using ball query...")
patches, patch_centers, patch_labels = extract_patches_ball_query_from_npz(
    npz_path=npz_path,
    num_patches=num_patches,
    points_per_patch=points_per_patch,
    radius_percent=radius_percent,
    sample_idx=0,
    random_seed=42,
    device="cpu"
)

# Check the patch size instead of printing an unconditional tick.
patch_size_ok = len(patches.shape) >= 2 and patches.shape[1] == points_per_patch
size_check = '✓' if patch_size_ok else f'✗ (patch shape {tuple(patches.shape)})'

print(f"\nExtracted {len(patches)} patches")
print(f"Patch shape: {patches.shape}")  # Expected: (num_patches, points_per_patch, 3)
print(f"All patches have exactly {points_per_patch} points: {size_check}")

if patch_labels is not None:
    print(f"Patch labels shape: {patch_labels.shape}")  # Expected: (num_patches, points_per_patch)
    print(f"Unique labels in patches: {np.unique(patch_labels)}")
else:
    print("No labels available")

# Create a PyTorch Dataset from the patches
dataset = BallQueryPatchDataset(patches, patch_labels)

# Create a DataLoader for training
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# Example: Get a batch
batch_patches, batch_labels = next(iter(dataloader))
print(f"\nBatch shape from DataLoader: {batch_patches.shape}")  # Expected: (batch_size, points_per_patch, 3)
if batch_labels is not None:
    print(f"Batch labels shape: {batch_labels.shape}")  # Expected: (batch_size, points_per_patch)

# Load the same sample (sample_idx=0) for background context
with np.load(npz_path) as data:
    points = data['pointclouds'][0]
    labels = data['labels'][0] if 'labels' in data else None

# Sample background points for context
bg_sample_size = 10000
bg_indices = rng.choice(len(points), min(bg_sample_size, len(points)), replace=False)
bg_points = points[bg_indices]
bg_labels = labels[bg_indices] if labels is not None else None

# Recompute the ball radius for drawing. This assumes the extractor uses
# radius_percent * max(x_range, y_range) of the same sample, as described
# for its radius_percent argument.
x_range = points[:, 0].max() - points[:, 0].min()
y_range = points[:, 1].max() - points[:, 1].min()
radius = max(x_range, y_range) * radius_percent

# Visualize a few patches (squeeze=False keeps `axes` 2-D even for one panel)
num_patches_to_show = 5
fig, axes = plt.subplots(1, num_patches_to_show, figsize=(20, 4), squeeze=False)

for idx, ax in enumerate(axes[0]):
    if idx >= len(patches):
        ax.axis('off')  # fewer patches than panels
        continue
    patch = patches[idx]
    center = patch_centers[idx]

    # Plot background points
    ax.scatter(bg_points[:, 0], bg_points[:, 1],
               c='lightgray', s=0.5, alpha=0.3, edgecolors='none')

    # Plot patch points
    ax.scatter(patch[:, 0], patch[:, 1],
               c='blue', s=2, alpha=0.6, edgecolors='none', label='Patch points')

    # Plot query center
    ax.scatter(center[0], center[1],
               c='red', s=100, marker='x', linewidths=3,
               label='Query center', zorder=10)

    # Draw radius circle
    ax.add_patch(Circle((center[0], center[1]), radius,
                        fill=False, edgecolor='red', linestyle='--',
                        linewidth=2, alpha=0.7, label=f'Radius={radius:.1f}'))

    # Zoom to the ball plus a margin proportional to the patch's own extent
    patch_x_range = patch[:, 0].max() - patch[:, 0].min()
    patch_y_range = patch[:, 1].max() - patch[:, 1].min()
    margin = max(patch_x_range, patch_y_range) * 0.2

    ax.set_xlim(center[0] - radius - margin, center[0] + radius + margin)
    ax.set_ylim(center[1] - radius - margin, center[1] + radius + margin)

    ax.set_xlabel('X Coordinate', fontsize=10)
    ax.set_ylabel('Y Coordinate', fontsize=10)
    ax.set_title(f'Patch {idx+1}\n({len(patch)} points)', fontsize=11)
    ax.grid(True, alpha=0.3)
    ax.set_aspect('equal')
    ax.legend(fontsize=8)

plt.tight_layout()
plt.show()

# Overall visualization showing all patches
fig, ax = plt.subplots(figsize=(16, 12))

# Plot all background points (colored by label when labels exist)
if bg_labels is not None:
    unique_labels = np.unique(bg_labels)
    colors = plt.cm.tab20(np.linspace(0, 1, len(unique_labels)))
    color_map = {label: colors[i] for i, label in enumerate(unique_labels)}

    for label in unique_labels:
        label_points = bg_points[bg_labels == label]
        ax.scatter(label_points[:, 0], label_points[:, 1],
                   c=[color_map[label]], s=0.5, alpha=0.2, edgecolors='none')
else:
    ax.scatter(bg_points[:, 0], bg_points[:, 1],
               c='lightgray', s=0.5, alpha=0.2, edgecolors='none')

# Plot all query centers
centers_array = np.array(patch_centers)
ax.scatter(centers_array[:, 0], centers_array[:, 1],
           c='red', s=50, marker='x', linewidths=2,
           label=f'Query Centers ({len(patch_centers)})', zorder=10)

# Draw radius circles for all patches
for center in patch_centers:
    ax.add_patch(Circle((center[0], center[1]), radius,
                        fill=False, edgecolor='blue', linestyle='--',
                        linewidth=1, alpha=0.4))

# Set limits
x_min, x_max = points[:, 0].min(), points[:, 0].max()
y_min, y_max = points[:, 1].min(), points[:, 1].max()
ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)

ax.set_xlabel('X Coordinate', fontsize=12)
ax.set_ylabel('Y Coordinate', fontsize=12)
ax.set_title(f'Ball Query Patching: {len(patches)} Patches with Radius={radius:.2f}', fontsize=14)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
ax.set_aspect('auto')

plt.tight_layout()
plt.show()

# Print statistics
print(f"\nPatch Statistics:")
print(f"  All patches have exactly {points_per_patch} points: {size_check}")
print(f"  {len(patches)} patches × {points_per_patch} points = {len(patches) * points_per_patch:,} total patch points")
print(f"\nData format ready for PyTorch PointNet++ training:")
print(f"  - Patches: shape {patches.shape} -> (num_patches, {points_per_patch}, 3)")
print(f"  - Labels: shape {patch_labels.shape if patch_labels is not None else 'None'} -> (num_patches, {points_per_patch})")
print(f"  - Can be used with DataLoader for batch training")


In [None]:
# Save and load patches for PointNet++ training
import importlib
import sys

# Reload the module to get the latest functions (in case kernel was already running)
if 'pointcloud_segmentation.dataset' in sys.modules:
    importlib.reload(sys.modules['pointcloud_segmentation.dataset'])

from pointcloud_segmentation.dataset import (
    extract_and_save_patches_ball_query,
    load_patches_ball_query,
    save_patches_ball_query,
    BallQueryPatchDataset
)
from torch.utils.data import DataLoader
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Circle

# Option 1: extract patches from the raw point cloud and write them to disk in one call
npz_path = "/Users/hussainabass/Documents/pointcloud-segmentation-stub/data/ODMSemantic3D/datasets/odm_data_waterbury-roads_2.npz"
output_patches_path = "/Users/hussainabass/Documents/pointcloud-segmentation-stub/data/ODMSemantic3D/patches_waterbury_roads_2.npz"

# Arguments for the extraction step, gathered in one place
extraction_kwargs = dict(
    npz_path=npz_path,
    output_path=output_patches_path,
    num_patches=100,        # kept small to avoid memory issues
    points_per_patch=4096,  # exact number of points per patch
    radius_percent=0.02,    # ball radius = 2% of max(x_range, y_range)
    sample_idx=0,
    random_seed=42,
    device="cpu",
)

print("Extracting and saving patches...")
print("Note: Using 100 patches to avoid memory issues. Increase num_patches if needed.")
try:
    extract_and_save_patches_ball_query(**extraction_kwargs)
except Exception as e:
    # Print a hint for the user, then re-raise so the cell still fails loudly
    print(f"Error during extraction: {e}")
    print("Try reducing num_patches or using a smaller point cloud sample.")
    raise

print("\n" + "=" * 60)
print("Patches saved! Now loading them back...")
print("=" * 60)

# Option 2: read the patch file written above back into memory
patches, patch_labels, patch_centers, metadata = load_patches_ball_query(output_patches_path)

# Optional arrays are reported as the string 'None' when absent
labels_shape = 'None' if patch_labels is None else patch_labels.shape
centers_shape = 'None' if patch_centers is None else patch_centers.shape

print("\nLoaded patches:")
print(f"  - Shape: {patches.shape}")
print(f"  - Labels shape: {labels_shape}")
print(f"  - Centers shape: {centers_shape}")
if metadata:
    print(f"  - Metadata: {metadata}")

# Wrap the saved patch file in a PyTorch Dataset.
# Alternative: build it from the in-memory arrays instead, e.g.
#   BallQueryPatchDataset(patches=patches, patch_labels=patch_labels)
dataset = BallQueryPatchDataset(patches_path=output_patches_path)

# Shuffled mini-batches of 8 patches for training
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Pull one batch to sanity-check the shapes the loader yields
batch_patches, batch_labels = next(iter(dataloader))
batch_labels_shape = 'None' if batch_labels is None else batch_labels.shape

print("\nBatch from DataLoader:")
print(f"  - Patches shape: {batch_patches.shape}")  # expected (batch_size, 4096, 3)
print(f"  - Labels shape: {batch_labels_shape}")  # expected (batch_size, 4096)

# Visualization
print("\n" + "="*60)
print("Visualizing patches...")
print("="*60)

# Load the original point cloud (first sample) for background context.
# Using NpzFile as a context manager closes its underlying file handle once
# the arrays have been read (the previous version left it open).
with np.load(npz_path) as data:
    points = data['pointclouds'][0]
    labels = data['labels'][0] if 'labels' in data else None

# Recompute the ball-query radius for drawing: 2% of the larger XY extent.
# NOTE: the 0.02 factor mirrors radius_percent=0.02 passed to
# extract_and_save_patches_ball_query above; keep the two in sync.
x_range = points[:, 0].max() - points[:, 0].min()
y_range = points[:, 1].max() - points[:, 1].min()
max_range = max(x_range, y_range)
radius = max_range * 0.02

# Visualize a few patches, one panel each
num_patches_to_show = 5
# squeeze=False always returns a 2-D array of Axes. Flattening it yields a
# 1-D array even when num_patches_to_show == 1, where the default would
# return a single Axes object that cannot be iterated.
fig, axes = plt.subplots(1, num_patches_to_show, figsize=(20, 4), squeeze=False)
axes = axes.ravel()

# Sample background points for context (also used by the overview plot below)
bg_sample_size = 10000
bg_indices = np.random.choice(len(points), min(bg_sample_size, len(points)), replace=False)
bg_points = points[bg_indices]

for idx, ax in enumerate(axes):
    if idx >= len(patches):
        # Fewer patches than panels: hide the unused panel instead of leaving
        # an empty frame in the figure.
        ax.set_visible(False)
        continue

    patch = patches[idx]
    # Fall back to the patch centroid when no query centers were saved
    center = patch_centers[idx] if patch_centers is not None else patch.mean(axis=0)

    # Plot background points
    ax.scatter(bg_points[:, 0], bg_points[:, 1],
               c='lightgray', s=0.5, alpha=0.3, edgecolors='none')

    # Plot patch points
    ax.scatter(patch[:, 0], patch[:, 1],
               c='blue', s=2, alpha=0.6, edgecolors='none', label='Patch points')

    # Plot query center
    ax.scatter(center[0], center[1],
               c='red', s=100, marker='x', linewidths=3,
               label='Query center', zorder=10)

    # Draw radius circle
    circle = Circle((center[0], center[1]), radius,
                    fill=False, edgecolor='red', linestyle='--',
                    linewidth=2, alpha=0.7, label=f'Radius={radius:.1f}')
    ax.add_patch(circle)

    # Set limits to show the ball around the center plus a margin
    # proportional to the patch's own extent
    patch_x_range = patch[:, 0].max() - patch[:, 0].min()
    patch_y_range = patch[:, 1].max() - patch[:, 1].min()
    margin = max(patch_x_range, patch_y_range) * 0.2

    ax.set_xlim(center[0] - radius - margin, center[0] + radius + margin)
    ax.set_ylim(center[1] - radius - margin, center[1] + radius + margin)

    ax.set_xlabel('X Coordinate', fontsize=10)
    ax.set_ylabel('Y Coordinate', fontsize=10)
    ax.set_title(f'Patch {idx+1}\n({len(patch)} points)', fontsize=11)
    ax.grid(True, alpha=0.3)
    ax.set_aspect('equal')
    ax.legend(fontsize=8)

plt.tight_layout()
plt.show()

# Overall visualization showing all patches over the sampled point cloud
fig, ax = plt.subplots(figsize=(16, 12))

# Plot background points, colored by class when labels are available
if labels is not None:
    # Index the label array once rather than once per class inside the loop
    bg_labels = labels[bg_indices]
    unique_labels = np.unique(bg_labels)
    colors = plt.cm.tab20(np.linspace(0, 1, len(unique_labels)))
    color_map = {label: colors[i] for i, label in enumerate(unique_labels)}

    # unique_labels is taken from bg_labels itself, so every mask selects at
    # least one point and no emptiness check is needed.
    for label in unique_labels:
        mask = bg_labels == label
        label_points = bg_points[mask]
        ax.scatter(label_points[:, 0], label_points[:, 1],
                   c=[color_map[label]], s=0.5, alpha=0.2, edgecolors='none')
else:
    ax.scatter(bg_points[:, 0], bg_points[:, 1],
               c='lightgray', s=0.5, alpha=0.2, edgecolors='none')

# Plot all query centers and their ball-query radius
if patch_centers is not None:
    centers_array = np.array(patch_centers)
    ax.scatter(centers_array[:, 0], centers_array[:, 1],
               c='red', s=50, marker='x', linewidths=2,
               label=f'Query Centers ({len(patch_centers)})', zorder=10)

    for center in patch_centers:
        circle = Circle((center[0], center[1]), radius,
                        fill=False, edgecolor='blue', linestyle='--',
                        linewidth=1, alpha=0.4)
        ax.add_patch(circle)

# Set limits to the full X/Y extent of the point cloud
x_min, x_max = points[:, 0].min(), points[:, 0].max()
y_min, y_max = points[:, 1].min(), points[:, 1].max()
ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)

ax.set_xlabel('X Coordinate', fontsize=12)
ax.set_ylabel('Y Coordinate', fontsize=12)
ax.set_title(f'Ball Query Patching: {len(patches)} Patches with Radius={radius:.2f}', fontsize=14)
# Only the query-center scatter carries a label. When it was not drawn, calling
# legend() would just emit matplotlib's "No artists with labels" warning.
if ax.get_legend_handles_labels()[0]:
    ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
# 'equal' keeps the radius circles circular; 'auto' stretches them into
# ellipses whenever the X and Y ranges differ.
ax.set_aspect('equal')

plt.tight_layout()
plt.show()

print("\n" + "=" * 60)
print("Ready for PointNet++ training!")
print("=" * 60)
# Usage reminder for the training loop. One print of the newline-joined lines
# writes exactly what printing each line separately would.
print("\n".join([
    "\nYou can now use the DataLoader in your training loop:",
    "  for batch_patches, batch_labels in dataloader:",
    "      # batch_patches: (B, 4096, 3)",
    "      # batch_labels: (B, 4096)",
    "      # Train your model...",
]))
