# Curriculum-based positional patching

In [1]:
import numpy as np 
import torch 
import torch.nn as nn
import torch.nn.functional as F

import sys 
sys.path.append('../')

from data_yours import ALS_50K

from torch_kmeans import KMeans

from torch.utils.data import DataLoader

In [2]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## Prepare example data

In [20]:
# Seed for the DataLoader shuffle so the example batch is identical on every Restart & Run All.
SEED = 0

ds = ALS_50K(num_points=2048)

# One example cloud at an arbitrary index. Element 0 of the dataset item is used as the
# point array (presumably the (N, D) coordinates -- TODO confirm against ALS_50K.__getitem__).
sample = ds[25542]
sample = torch.tensor(sample[0]).to(device).unsqueeze(0)  # add a leading batch dimension

# A dedicated, seeded generator makes the shuffle reproducible without touching the global RNG.
dl = DataLoader(ds, batch_size=8, shuffle=True, generator=torch.Generator().manual_seed(SEED))
batch = next(iter(dl))[0]

In [4]:
def get_cluster_points(pc, c_idx, k):
    """Split ``pc`` into ``k`` groups by cluster label.

    Returns a list of length ``k`` whose i-th entry is ``pc[c_idx == i]``. A label
    with no points still gets its slot (an empty tensor), so list positions always
    line up with cluster ids even when KMeans leaves a cluster empty.
    """
    return [pc[c_idx == cluster_id] for cluster_id in range(k)]

In [5]:
def get_sample_clusters(pc, c_idx, c_centroids, k):
    """
    Group one sample's points by cluster label and pair each group with its centroid.

    pc: the sample's points, selected per cluster with the boolean mask ``c_idx == i``
    c_idx: per-point cluster labels
    c_centroids: centroid tensor; row ``i`` is the center of cluster ``i``
    k: number of label groups to build; the result is keyed by centroid rows, so ``k``
       must be at least ``c_centroids.shape[0]`` or the group lookup raises IndexError

    Returns:
        dict mapping cluster id -> {'center': centroid row, 'points': that cluster's points}
    """
    point_groups = get_cluster_points(pc, c_idx, k)

    clusters_dict = {}
    for cluster_id in range(c_centroids.shape[0]):
        clusters_dict[cluster_id] = {
            'center': c_centroids[cluster_id, :],
            'points': point_groups[cluster_id],
        }
    return clusters_dict

In [8]:
# Euclidean distance in the XY plane from each centroid to the stem center (Z is ignored).
def compute_xy_distance_to_stem_center(centroids, stem_center):
    """
    Horizontal (XY-plane) distance from every centroid to its sample's stem center.

    Only the first two coordinates are used. This equals the distance from each
    centroid to the point on the vertical line through ``stem_center`` at the
    centroid's own height.

    centroids: (BS, C, 3) tensor of centroids
    stem_center: (BS, 3) tensor of stem centers
    returns: (BS, C) tensor of XY distances
    """
    # Broadcast each sample's stem XY, shape (BS, 1, 2), against its C centroids.
    xy_offsets = centroids[:, :, :2] - stem_center[:, None, :2]
    return torch.linalg.vector_norm(xy_offsets, dim=2)

In [9]:
def get_region_mask(z_height, xy_distance_to_stem, z_region='low', xy_region='high'):
    """
    Get region mask based on z height and xy distance regions.

    Both quantities are split per sample into three bands ('low', 'middle', 'high'),
    but with different thresholds:
      - z height: fixed fractions (33% / 67%) of the sample's min-max height range;
      - xy distance: the sample's 33rd / 67th percentiles (``torch.quantile``).
    'low' is strictly below the lower threshold, 'middle' is inclusive on both ends and
    'high' is strictly above the upper threshold, so the three bands partition the values.

    Args:
        z_height: tensor of z coordinates (BS, C)
        xy_distance_to_stem: floating tensor of xy distances to stem (BS, C)
        z_region: 'low', 'middle', or 'high'
        xy_region: 'low', 'middle', or 'high'

    Returns:
        region_mask: boolean tensor indicating selected region (BS, C)

    Raises:
        ValueError: if ``z_region`` or ``xy_region`` is not one of the three names.
            (Previously an unknown name surfaced as an UnboundLocalError.)
    """
    valid_regions = ('low', 'middle', 'high')
    for arg_name, region in (('z_region', z_region), ('xy_region', xy_region)):
        if region not in valid_regions:
            raise ValueError(f"{arg_name} must be one of {valid_regions}, got {region!r}")

    def _band_mask(values, lower, upper, region):
        # Select the band named by `region` using per-sample (BS, 1) thresholds.
        if region == 'low':
            return values < lower
        if region == 'middle':
            return (values >= lower) & (values <= upper)
        return values > upper

    # z bands: fractions of each sample's height range.
    z_min = z_height.min(dim=1, keepdim=True)[0]  # (BS, 1)
    z_max = z_height.max(dim=1, keepdim=True)[0]  # (BS, 1)
    z_range = z_max - z_min
    z_region_mask = _band_mask(z_height, z_min + 0.33 * z_range, z_min + 0.67 * z_range, z_region)

    # xy bands: per-sample quantiles of the distance distribution.
    quantile_33 = torch.quantile(xy_distance_to_stem, 0.33, dim=1, keepdim=True)  # (BS, 1)
    quantile_67 = torch.quantile(xy_distance_to_stem, 0.67, dim=1, keepdim=True)  # (BS, 1)
    xy_region_mask = _band_mask(xy_distance_to_stem, quantile_33, quantile_67, xy_region)

    return z_region_mask & xy_region_mask

In [10]:
def select_clusters_with_region_mask(region_mask, centroids, num_clusters):
    """
    Pick ``num_clusters`` clusters per batch element, preferring those inside the region.

    Args:
        region_mask: (BS, C) boolean tensor marking clusters inside the target region
        centroids: (BS, C, 3) tensor of centroids
        num_clusters: int, number of clusters to select per batch element

    Returns:
        (BS, C) boolean tensor with True at the selected clusters.

    In-region clusters are taken in ascending index order. When the region holds too
    few, the remainder is filled with the out-of-region clusters nearest to the mean
    of the in-region centroids (or to the mean of all centroids if the region is empty).
    """
    n_batch, _ = region_mask.shape
    selection = torch.zeros_like(region_mask, dtype=torch.bool)

    for b in range(n_batch):
        row = region_mask[b]
        inside = row.nonzero(as_tuple=True)[0]
        shortfall = num_clusters - len(inside)

        if shortfall <= 0:
            chosen = inside[:num_clusters]
        else:
            # Anchor the fill-in search on the region's mean centroid.
            anchor_pool = centroids[b, inside] if len(inside) > 0 else centroids[b]
            anchor = anchor_pool.mean(dim=0, keepdim=True)

            outside = (~row).nonzero(as_tuple=True)[0]
            gaps = torch.norm(centroids[b, outside] - anchor, dim=1)
            nearest_first = outside[torch.argsort(gaps)]
            chosen = torch.cat([inside, nearest_first[:shortfall]])

        selection[b, chosen] = True

    return selection

In [12]:
def get_cluster_tensor(points_with_clusters, mask, num_points=16):
    """
    Gather the points of the selected clusters into a fixed-size tensor, resampling
    each cluster to exactly ``num_points`` points.

    - points_with_clusters: [BS, N, D] (last column is cluster index)
    - mask: [BS, C] boolean (clusters to select)
    - num_points: int, number of points per cluster

    Returns:
        Tensor of shape [BS, S, num_points, D-1], where S is the largest number of
        selected clusters in any batch row. Per cluster:
          * more than ``num_points`` points -> random subset without replacement;
          * exactly ``num_points`` points  -> all points, in their original order;
          * fewer points                   -> all points, then random repeats;
          * no points (empty cluster)      -> left as zeros.
        Rows with fewer than S selected clusters are zero-padded at the end.
        The output uses torch's default floating dtype, as before.
    """
    BS, N, D = points_with_clusters.shape
    device = points_with_clusters.device
    cluster_labels = points_with_clusters[..., -1].long()  # [BS, N]
    selected_cluster_indices = [torch.where(mask[b])[0] for b in range(BS)]
    num_selected = max((len(idx) for idx in selected_cluster_indices), default=0)

    out = torch.zeros(BS, num_selected, num_points, D - 1, device=device)
    for b in range(BS):
        for i, cluster_id in enumerate(selected_cluster_indices[b]):
            pts = points_with_clusters[b][cluster_labels[b] == cluster_id][:, :D - 1]  # drop label column
            n = pts.shape[0]
            if n == 0:
                # Empty cluster: nothing to sample from (randint(0, 0) would raise); keep zeros.
                continue
            if n > num_points:
                sel = torch.randperm(n, device=device)[:num_points]
            elif n < num_points:
                # Upsample with replacement: keep every original point, then add repeats.
                extra = torch.randint(0, n, (num_points - n,), device=device)
                sel = torch.cat([torch.arange(n, device=device), extra])
            else:
                # Exact fit: previously this case was skipped, leaving the patch all zeros.
                sel = torch.arange(n, device=device)
            out[b, i] = pts[sel]
    return out

In [13]:
def get_centroids_from_mask(mask, centroids):
    """
    Return, for each batch row, the column indices of the True entries of ``mask``.

    Despite the name this returns indices, not coordinates; gather positions with
    ``centroids[torch.arange(BS)[:, None], idx]``. ``centroids`` is accepted for
    backward compatibility and is not used.

    Args:
        mask: (BS, C) boolean tensor; every row must contain the same number of True entries.
        centroids: unused.

    Returns:
        (BS, K) long tensor of ascending cluster indices per row, where K is the per-row
        True count (K may be 0).

    Raises:
        ValueError: if rows contain different numbers of True entries. A plain
            ``reshape(BS, -1)`` would otherwise fail, or silently mix indices from
            different rows when the total happens to be divisible by BS.
    """
    BS, C = mask.shape
    counts = mask.sum(dim=1)
    if BS > 0 and bool((counts != counts[0]).any()):
        raise ValueError(f"every mask row must select the same number of clusters, got {counts.tolist()}")
    per_row = int(counts[0]) if BS > 0 else 0
    # Explicit width instead of -1: reshape(BS, -1) is ambiguous when nothing is selected.
    return torch.where(mask)[1].reshape(BS, per_row)

In [18]:
def region_based_patching(
    batch, 
    patch_size=16, 
    mask_ratio=0.1, 
    z_region='low', 
    xy_region='high', 
    device=None
):
    """
    Split each point cloud into KMeans patches and mask the patches in a chosen region.

    Args:
        batch: (BS, N, D) point clouds. The last coordinate picks the stem center and
            column 2 is used as centroid height (presumably D == 3 with z last -- TODO confirm).
        patch_size: points per patch; the number of clusters is N // patch_size.
        mask_ratio: fraction of clusters to mask (truncated to an int).
        z_region, xy_region: region names passed to ``get_region_mask``.
        device: device to run on; ``batch`` is moved there. Defaults to ``batch.device``.

    Returns:
        mask_pos: points of the masked patches, (BS, num_masked, patch_size, D)
        vis_pos: points of the visible patches, (BS, num_vis, patch_size, D)
        mask_center_pos: centroids of the masked patches
        vis_center_pos: centroids of the visible patches
        mask_idx: (BS, num_masked) patch positions 0..num_masked-1 (masked patches first)
        vis_idx: (BS, num_vis) patch positions num_masked..num_clusters-1
        shuffle_idx: (BS, num_clusters) concatenation of mask_idx and vis_idx

    Raises:
        ValueError: if a cloud has fewer than ``patch_size`` points (no clusters possible).
    """
    if device is None:
        device = batch.device
    # Honour `device` for the data too, not only for KMeans.
    batch = batch.to(device)
    batch_size = batch.shape[0]

    num_clusters = batch.shape[1] // patch_size
    if num_clusters < 1:
        raise ValueError(
            f"need at least patch_size={patch_size} points per cloud, got {batch.shape[1]}"
        )
    num_masked = int(num_clusters * mask_ratio)  # Number of clusters to mask

    # KMeans clustering
    kmeans = KMeans(init_method='k-means++', n_clusters=num_clusters, device=device)
    clusters = kmeans(batch)
    c_idx, c_centroids = clusters.labels, clusters.centers

    # Append each point's cluster label as an extra (last) column.
    points_with_clusters = torch.cat([batch, c_idx.unsqueeze(-1).float()], dim=-1)

    # Stem center: the point with the largest last coordinate in each cloud.
    batch_rows = torch.arange(batch_size, device=batch.device)
    stem_centers = batch[batch_rows, batch[:, :, -1].argmax(dim=1)]
    xy_distance_to_stem = compute_xy_distance_to_stem_center(c_centroids, stem_centers)
    c_heights = c_centroids[:, :, 2]

    region_mask = get_region_mask(c_heights, xy_distance_to_stem, z_region=z_region, xy_region=xy_region)
    mask = select_clusters_with_region_mask(region_mask, c_centroids, num_masked)

    mask_pos = get_cluster_tensor(points_with_clusters, mask, num_points=patch_size)
    vis_pos = get_cluster_tensor(points_with_clusters, ~mask, num_points=patch_size)

    masked_centroids_idx = get_centroids_from_mask(mask, c_centroids)
    vis_centroids_idx = get_centroids_from_mask(~mask, c_centroids)

    centroid_rows = torch.arange(c_centroids.shape[0], device=c_centroids.device)[:, None]
    mask_center_pos = c_centroids[centroid_rows, masked_centroids_idx]
    vis_center_pos = c_centroids[centroid_rows, vis_centroids_idx]

    # One row of patch positions per batch element (this was hard-coded to 3 rows,
    # giving wrongly shaped index tensors for any other batch size).
    patch_idx = torch.arange(num_clusters, device=batch.device).unsqueeze(0).repeat(batch_size, 1)
    mask_idx = patch_idx[:, :num_masked]
    vis_idx = patch_idx[:, num_masked:]

    shuffle_idx = torch.cat([mask_idx, vis_idx], dim=1)

    return mask_pos, vis_pos, mask_center_pos, vis_center_pos, mask_idx, vis_idx, shuffle_idx


In [21]:
# Patch the example batch, masking clusters in the mid-height / mid-distance region.
# `x` is the tuple
# (mask_pos, vis_pos, mask_center_pos, vis_center_pos, mask_idx, vis_idx, shuffle_idx).
x = region_based_patching(
    batch,
    patch_size=16,
    mask_ratio=0.1,
    z_region='middle',
    xy_region='middle',
    device=device,
)

Full batch converged at iteration 15/100 with center shifts = tensor([0.0000e+00, 1.0117e-05, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 2.2254e-05]).
