In [1]:
import numpy as np 
import torch 
import torch.nn as nn
import torch.nn.functional as F

import sys 
sys.path.append('../')

from data_yours import ALS_50K

In [2]:
from torch_kmeans import KMeans, SoftKMeans, ConstrainedKMeans

In [3]:
ds = ALS_50K()
len(ds)

48180

In [5]:
# get sample 
ds = ALS_50K(num_points=2048)
sample = ds[25542]

In [None]:
def compute_threshold(rdpercents): 
    return torch.quantile(rdpercents, 0.75)

In [4]:
# get xyz 
sample = sample[0]
sample.shape

(2048, 3)

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [6]:
sample = torch.tensor(sample, dtype=torch.float32).to(device)  # add batch dimension

## For a single point cloud

## kNN Clustering function 

In [7]:
from kmeans_pytorch import kmeans

In [8]:
def get_tree_center_xy(pc): 
    return torch.mean(pc[:,:2], dim=0)

In [4]:
def get_clusters(pc, c_idx):
    """
    Splits the point cloud (pc) into clusters according to c_idx.
    Returns a list of tensors, one per cluster.
    """
    return [pc[c_idx == i] for i in range(c_idx.max() + 1)]

In [None]:
def initialize_clusters(pc, device, k=100): 
    c_idx, c_centroids = kmeans(
        X=pc,
        num_clusters=k,
        distance='euclidean',
        device=device
    )
    clusters = get_clusters(pc, c_idx)

    clusters_dict = {
        i: {
            'centroid': c_centroids[i].to(device),
            'points': clusters[i].to(device)
        }
        for i in range(len(c_centroids))
    }
    return clusters_dict

In [12]:
def compute_pca(points):
    mean = points.mean(dim=0)
    X = points - mean
    eigvals, eigvecs = torch.linalg.eigh(torch.cov(X.T))
    idx = torch.argsort(eigvals, descending=True)
    eigvecs = eigvecs[:, idx]
    direction = eigvecs[:, 0]  # 1st principal component
    normal = eigvecs[:, 1]     # 2nd principal component
    return mean, direction, normal
#compute_pca(clusters[0])

In [13]:
def cluster_residuals(cluster):


    vec = cluster['points'] - cluster['centroid'] 
    normal = cluster['normal'] / torch.norm(cluster['normal']) 
    distances = torch.abs(torch.matmul(vec, normal))

    return torch.sum(distances)

In [14]:
def merge_clusters(c1, c2):
    merged = torch.cat((c1['points'], c2['points']))
    return merged

In [15]:
def merged_residual_diff_percent(c1, c2):
    r_prior = c1['residual'] + c2['residual']
    merged = {
        'points': merge_clusters(c1, c2),
    }

    mean, direction, normal = compute_pca(merged['points'])
    merged['centroid'] = mean
    merged['direction'] = direction
    merged['normal'] = normal
    r_post = cluster_residuals(merged)
    return (r_post - r_prior)/ r_prior * 100

In [16]:
def distance_pairs(centroids): 
    dmatrix = torch.cdist(centroids, centroids) 
    dmatrix = dmatrix.fill_diagonal_(float('inf'))
    tril_indices = torch.tril_indices(dmatrix.size(0), dmatrix.size(1), offset=-1)
    tril_indices = tril_indices.to(centroids.device)
    distances = dmatrix[tril_indices[0], tril_indices[1]]
    sorted_indices = distances.argsort()
    sorted_indices = sorted_indices.to(centroids.device)
    sorted_pairs = torch.stack((tril_indices[0][sorted_indices], tril_indices[1][sorted_indices]), dim=1)
    sorted_distances = distances[sorted_indices]
    return sorted_pairs, sorted_distances

#distance_pairs(cl_centroids)

In [17]:
def pairs_angle(c1, c2): 
    angle = torch.acos(torch.dot(c1['normal'], c2['normal']) / (torch.norm(c1['normal']) * torch.norm(c2['normal'])))
    return angle

In [18]:
def sort_pairs_distance_angle(pairs, distances, angles):
    # Sort by distance + small weight on angle
    sort_metric = distances + 1e-2 * angles
    sorted_indices = torch.argsort(sort_metric)
    sorted_pairs = pairs[sorted_indices]
    return sorted_pairs

In [19]:
def compute_threshold(rdpercents): 
    return torch.quantile(rdpercents, 0.75)

In [40]:
def update_clusters(clusters, idx1, idx2):
    # Merge two clusters and update the dictionary
    c1 = clusters[idx1]
    c2 = clusters[idx2]

    merged_points = merge_clusters(c1, c2)
    mean, direction, normal = compute_pca(merged_points)

    new_cluster = {
        'centroid': mean,
        'direction': direction,
        'normal': normal,
        'points': merged_points,
        'residual': cluster_residuals({'points': merged_points, 'centroid': mean, 'normal': normal})
    }

    for idx in sorted([idx1, idx2], reverse=True):
        del clusters[idx]

    # Add the new cluster with a temporary key
    clusters[max(clusters.keys(), default=-1) + 1] = new_cluster

    # Reindex clusters keys to be 0...len(clusters)-1
    clusters = {new_idx: clusters[old_idx] for new_idx, old_idx in enumerate(sorted(clusters.keys()))}

    return clusters

## Iterative merging process 

In [186]:
# initialize clusters 
clusters = initialize_clusters(sample, device, k=100)
centroids = torch.stack([clusters[i]['centroid'] for i in range(len(clusters))])

running k-means on cuda..


[running kmeans]: 12it [00:00, 19.21it/s, center_shift=0.000084, iteration=12, tol=0.000100]


In [187]:
# get clusters with shape < 3 
clusters_to_merge = [k for k, v in clusters.items() if v['points'].shape[0] < 3]
print(len(clusters_to_merge), 'clusters with < 3 points')

0 clusters with < 3 points


In [188]:
init_pairs = distance_pairs(centroids)

In [189]:
if len(clusters_to_merge) > 0:
    for k in clusters_to_merge: 
        mergin_candidate = [i for i in init_pairs[0] if i in init_pairs[0]][0]

        clusters = update_clusters(clusters, mergin_candidate[0].item(), mergin_candidate[1].item())
        

In [190]:
len(clusters)

100

In [191]:
# compute pca for each cluster 
for k, v in clusters.items():
    mean, direction, normal = compute_pca(v['points'])
    v['centroid'] = mean
    v['direction'] = direction
    v['normal'] = normal
    v['residual'] = cluster_residuals(v)

In [192]:
for i in range(len(clusters)): 
    clusters[i]['residual'] = cluster_residuals(clusters[i])

In [193]:
centroids = torch.stack([clusters[i]['centroid'] for i in range(len(clusters))])
init_pairs = distance_pairs(centroids)

In [194]:

# compute angles between all pairs of clusters 
angles = [
    pairs_angle(clusters[i.item()], clusters[j.item()])
    for i, j in init_pairs[0]
]

angles = torch.tensor(angles, device=device)

sorted_pairs = sort_pairs_distance_angle(init_pairs[0], init_pairs[1], angles)

merged_resid = [
    merged_residual_diff_percent(clusters[i], clusters[j])
    for i, j in sorted_pairs.tolist()
]


t = compute_threshold(torch.tensor(merged_resid, device=device))


In [195]:
t

tensor(58.6180, device='cuda:0')

In [196]:
t = 10 

In [170]:
import time 

In [197]:
# iterative merging until no more pairs are below threshold 

while True:
    t0 = time.time()
    # 1. Recompute centroids from current clusters
    centroids = torch.stack([clusters[i]['centroid'] for i in range(len(clusters))])
    t1 = time.time() - t0
    # 2. Get all pairs and their distances
    sorted_pairs, sorted_distances = distance_pairs(centroids)
    t2 = time.time() - t1
    # 3. Compute angles for all pairs
    angles = [
        pairs_angle(clusters[i.item()], clusters[j.item()])
        for i, j in sorted_pairs
    ]
    t3 = time.time() - t2
    angles = torch.tensor(angles, device=device)
    # 4. Compute merged residuals for all pairs
    merged_resid = [
        merged_residual_diff_percent(clusters[i.item()], clusters[j.item()])
        for i, j in sorted_pairs
    ]
    t4 = time.time() - t3
    merged_resid = torch.tensor(merged_resid, device=device)
    # 5. Find candidate pairs below threshold
    candidate_pairs_idx = merged_resid < t
    if not torch.any(candidate_pairs_idx):
        print("No more candidate pairs below threshold. Stopping.")
        break
    candidate_pairs = sorted_pairs[candidate_pairs_idx]
    merged_resid_candidates = merged_resid[candidate_pairs_idx]
    t5 = time.time() - t4
    # 6. Find best pair to merge
    merging_candidate_idx = merged_resid_candidates.argmin()
    merging_candidate_pairs = candidate_pairs[merging_candidate_idx]
    t6 = time.time() - t5
    idx1, idx2 = merging_candidate_pairs.tolist()
    # 7. Merge clusters and reindex
    clusters = update_clusters(clusters, idx1, idx2)
    t7 = time.time() - t6
    # 8. Print progress
    print(
        f"Time taken for each step:\n",
        f"Recompute centroids: {t1:.4f}s, ",
        f"Get pairs: {t2:.4f}s, ",
        f"Compute angles: {t3:.4f}s, ",
        f"Compute merged residuals: {t4:.4f}s, ",
        f"Find candidates: {t5:.4f}s, ",
        f"Find best pair: {t6:.4f}s, ",
        f"Merge clusters: {t7:.4f}s" 
    )
    print(f"Merged clusters {idx1} and {idx2}, {len(clusters)} clusters remain.")

Time taken for each step:
 Recompute centroids: 0.0046s,  Get pairs: 1750758688.4683s,  Compute angles: 1.9556s,  Compute merged residuals: 1750758696.2158s,  Find candidates: 2.0050s,  Find best pair: 1750758696.2160s,  Merge clusters: 2.0067s
Merged clusters 64 and 19, 99 clusters remain.
Time taken for each step:
 Recompute centroids: 0.0007s,  Get pairs: 1750758698.2237s,  Compute angles: 1.8701s,  Compute merged residuals: 1750758707.4071s,  Find candidates: 1.9159s,  Find best pair: 1750758707.4072s,  Merge clusters: 1.9169s
Merged clusters 54 and 16, 98 clusters remain.
Time taken for each step:
 Recompute centroids: 0.0007s,  Get pairs: 1750758709.3250s,  Compute angles: 1.8480s,  Compute merged residuals: 1750758718.1239s,  Find candidates: 1.9034s,  Find best pair: 1750758718.1240s,  Merge clusters: 1.9045s
Merged clusters 66 and 41, 97 clusters remain.
Time taken for each step:
 Recompute centroids: 0.0002s,  Get pairs: 1750758720.0313s,  Compute angles: 1.6105s,  Compute me

In [None]:
# create point cloud from clusters 
points = [] 
cluster_idx = [] 

for idx, v in clusters.items():
    points.append(clusters[idx]['points'].cpu().numpy())
    cluster_idx.extend([idx] * v['points'].shape[0]) 

points = np.concatenate(points, axis=0)
cluster_idx = np.array(cluster_idx)

pc = np.hstack((points, cluster_idx[:, None]))



In [None]:
# get point cloud from clusters 


In [201]:
np.savetxt('sample_cluster.txt', pc, fmt='%.6f', delimiter=',')

In [209]:
# cluster into 120 clusters with kmeans 
clusters = initialize_clusters(sample, device, k=32)

running k-means on cuda..


[running kmeans]: 16it [00:00, 71.17it/s, center_shift=0.000071, iteration=16, tol=0.000100]


In [210]:
points = [] 
cluster_idx = [] 

for idx, v in clusters.items():
    points.append(clusters[idx]['points'].cpu().numpy())
    cluster_idx.extend([idx] * v['points'].shape[0]) 

points = np.concatenate(points, axis=0)
cluster_idx = np.array(cluster_idx)

pc = np.hstack((points, cluster_idx[:, None]))

In [211]:
np.savetxt('sample_cluster.txt', pc, fmt='%.6f', delimiter=',')

In [None]:
def split_knn_patches(xyz, mask_ratio=0.7, nsample=32, random=True):
    B, N, C = xyz.shape
    npoint = N // nsample
    device = xyz.device

    num_patches = npoint
    num_mask = int(mask_ratio * num_patches)
    num_vis = num_patches - num_mask

    center_idx = pointnet2_utils.furthest_point_sample(xyz, npoint).long()
    center_pos = index_points(xyz, center_idx)

    if random:
        shuffle_idx = torch.rand(num_patches, device=device).argsort()
        vis_patch_idx, mask_patch_idx = shuffle_idx[:num_vis], shuffle_idx[num_vis:]

        mask_center_idx = center_idx[:, mask_patch_idx]
        vis_center_idx = center_idx[:, vis_patch_idx]
        mask_center_pos = index_points(xyz, mask_center_idx)
        vis_center_pos = index_points(xyz, vis_center_idx)

        all_patch_idx = knn_point(nsample, xyz, center_pos)  # [B, num_patches, nsample]
        mask_idx = all_patch_idx[:, mask_patch_idx]
        vis_idx = all_patch_idx[:, vis_patch_idx]
        mask_pos = index_points(xyz, mask_idx)  # [B, num_mask, nsample, C]
        vis_pos = index_points(xyz, vis_idx)  # [B, num_vis, nsample, C]
    else:
        mask_point_idx = np.random.randint(num_patches)
        mask_point_pos = center_pos[:, mask_point_idx]
        mask_patch_idx = knn_point(num_mask, center_pos, mask_point_pos.unsqueeze(1)).squeeze()
        vis_patch_idx = torch.empty((B, num_vis), device=device, dtype=int)
        for b in range(B):
            idx_all = set(np.arange(num_patches, dtype=int))
            mask_idx = set(mask_patch_idx[b].tolist())
            vis_idx = idx_all - mask_idx
            vis_patch_idx[b] = torch.tensor(list(vis_idx), device=device, dtype=torch.long)

        shuffle_idx = torch.cat((vis_patch_idx, mask_patch_idx), dim=1).to(device)

        batch_idx = torch.arange(B, device=device).unsqueeze(-1)
        mask_center_idx = center_idx[batch_idx, mask_patch_idx]
        vis_center_idx = center_idx[batch_idx, vis_patch_idx]
        mask_center_pos = index_points(xyz, mask_center_idx)
        vis_center_pos = index_points(xyz, vis_center_idx)

        all_patch_idx = knn_point(nsample, xyz, center_pos)  # [B, num_patches, nsample]
        mask_idx = all_patch_idx[batch_idx, mask_patch_idx]
        vis_idx = all_patch_idx[batch_idx, vis_patch_idx]
        mask_pos = index_points(xyz, mask_idx)  # [B, num_mask, nsample, C]
        vis_pos = index_points(xyz, vis_idx)  # [B, num_vis, nsample, C]

    return mask_pos,  vis_pos,  mask_center_pos, vis_center_pos, mask_patch_idx, vis_patch_idx, shuffle_idx

In [None]:
def split_kmeans_patches(batch, mask_ratio=0.7, nsample=32):
    B, N, C = batch.shape 

    # calculate number of patches
    num_patches = N // nsample # assuming N is num_points
    num_masked_patches = int(num_patches * mask_ratio) 
    num_vis_patches = num_patches - num_masked_patches

    vis_pos_batch = []
    mask_pos_batch = [] 
    mask_center_pos_batch = [] 
    vis_center_pos_batch = [] 
    mask_patch_idx_batch = []
    vis_patch_idx_batch = []
    # generate patch_idx and shuffle_idx later 

    for i in range(B): 
        pc = batch[i]

        model_kmeans = KMeans(num_clusters=num_patches) 
        results = model_kmeans()

        cl_idx = results.labels  # [N,]
        cl_centroids = results.centers  # [num_patches, C]

        clusters = get_clusters(pc, cl_idx)

        for i, cluster in enumerate(clusters): 
            if cluster.shape[0] < nsample:
                # Randomly select indices (with replacement) to pad the cluster to nsample points
                to_add = nsample - cluster.shape[0]
                rand_idx = torch.randint(0, cluster.shape[0], (to_add,), device=cluster.device)
                cluster = torch.cat([cluster, cluster[rand_idx]], dim=0)
            elif cluster.shape[0] > nsample:
                # If the cluster has more points than nsample, perform fps 
                idx = furthest_point_sample(
                    cluster.unsqueeze(0), nsample
                ).squeeze(0)
                cluster = cluster[idx]
            
            
            clusters[i] = cluster.unsqueeze(0)

        clusters = torch.cat(clusters, dim=0)

In [35]:
def split_kmeans_patches(batch, patch_size=16, mask_ratio=0.2):
    B, N, C = batch.shape 

    # calculate number of patches
    num_patches = N // patch_size # assuming N is num_points
    num_masked_patches = int(num_patches * mask_ratio) 
    num_vis_patches = num_patches - num_masked_patches

    model_kmeans = KMeans(n_clusters=num_patches)
    results = model_kmeans(batch)

    cl_idx = results.labels  # [N,]
    cl_centroids = results.centers  # [num_patches, C]

    return cl_idx, cl_centroids


In [None]:

        clusters = get_clusters(pc, cl_idx)

        for i, cluster in enumerate(clusters): 
            if cluster.shape[0] < nsample:
                # Randomly select indices (with replacement) to pad the cluster to nsample points
                to_add = nsample - cluster.shape[0]
                rand_idx = torch.randint(0, cluster.shape[0], (to_add,), device=cluster.device)
                cluster = torch.cat([cluster, cluster[rand_idx]], dim=0)
            elif cluster.shape[0] > nsample:
                # If the cluster has more points than nsample, perform fps 
                idx = furthest_point_sample(
                    cluster.unsqueeze(0), nsample
                ).squeeze(0)
                cluster = cluster[idx]
            
            
            clusters[i] = cluster.unsqueeze(0)

        clusters = torch.cat(clusters, dim=0)

In [None]:

        # get random clusters 
        shuffle_idx = torch.rand(num_patches, device=device).argsort()
        vis_patch_idx, mask_patch_idx = shuffle_idx[:num_vis_patches], shuffle_idx[num_vis_patches:]

        mask_pos = clusters[mask_patch_idx]
        vis_pos = clusters[vis_patch_idx]

        # get center points 
        mask_center_pos = torch.mean(mask_pos, dim=1)
        vis_center_pos = torch.mean(vis_pos, dim=1)

In [None]:


        # add to batch lists 
        vis_pos_batch.append(vis_pos)
        mask_pos_batch.append(mask_pos)
        mask_center_pos_batch.append(mask_center_pos)
        vis_center_pos_batch.append(vis_center_pos)
        mask_patch_idx_batch.append(mask_patch_idx)
        vis_patch_idx_batch.append(vis_patch_idx)

    mask_pos = torch.stack(mask_pos_batch, dim=0)  # [B, num_mask, nsample, C]
    vis_pos = torch.stack(vis_pos_batch, dim=0)  # [B, num_vis, nsample, C]
    mask_center_pos = torch.stack(mask_center_pos_batch, dim=0)  # [B, num_mask, C]
    vis_center_pos = torch.stack(vis_center_pos_batch, dim=0)  # [B, num_vis, C]
    mask_patch_idx = torch.stack(mask_patch_idx, dim=0)  # [B, num_mask]
    vis_patch_idx = torch.stack(vis_patch_idx, dim=0)  # [B, num_vis]

In [31]:
x = batch

In [32]:
knn_model = KMeans(num_clusters=110) 
results = knn_model(x.unsqueeze(0))
cl_idx = results.labels  # [N,]
cl_centroids = results.centers  # [num_patches, C]

Full batch converged at iteration 54/100 with center shifts = tensor([0., 0., 0., 0.], device='cuda:0').


In [28]:
results.labels.shape

torch.Size([1, 2048])

In [None]:

        # stack batch results 
    mask_pos = torch.stack(mask_pos_batch, dim=0)  # [B, num_mask, nsample, C]
    vis_pos = torch.stack(vis_pos_batch, dim=0)  # [B, num_vis, nsample, C]
    mask_center_pos = torch.stack(mask_center_pos_batch, dim=0)  # [B, num_mask, C]
    vis_center_pos = torch.stack(vis_center_pos_batch, dim=0)  # [B, num_vis, C]
    mask_patch_idx = torch.stack(mask_patch_idx, dim=0)  # [B, num_mask]
    vis_patch_idx = torch.stack(vis_patch_idx, dim=0)  # [B, num_vis]

    # create shuffle index
    shuffle_idx = torch.cat((vis_patch_idx, mask_patch_idx), dim=1)

    return mask_pos, vis_pos, mask_center_pos, vis_center_pos, mask_patch_idx, vis_patch_idx, shuffle_idx

        

In [6]:
from torch.utils.data import DataLoader 

train_loader = DataLoader(
    ds, 
    batch_size=4
)

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [14]:
batch = next(iter(train_loader))
batch = batch[0].to(device)  # move to device

In [15]:
batch.shape

torch.Size([4, 2048, 3])

In [36]:
clusters = split_kmeans_patches(batch, patch_size=16, mask_ratio=0.2)

OutOfMemoryError: CUDA out of memory. Tried to allocate 96.00 MiB. GPU 0 has a total capacity of 10.75 GiB of which 55.62 MiB is free. Process 2324173 has 10.49 GiB memory in use. Including non-PyTorch memory, this process has 210.00 MiB memory in use. Of the allocated memory 8.47 MiB is allocated by PyTorch, and 13.53 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [32]:
clusters[0].shape, clusters[1].shape

(torch.Size([4, 2048]), torch.Size([4, 8, 3]))

In [25]:
clusters[0].shape, clusters[1].shape

(torch.Size([4, 2048]), torch.Size([4, 8, 3]))

In [21]:
clusters[0].shape, clusters[1].shape

(torch.Size([4, 2048]), torch.Size([4, 8, 3]))

In [276]:
for i in clusters: 
    print(i.shape)

torch.Size([89, 16, 3])
torch.Size([39, 16, 3])
torch.Size([89, 3])
torch.Size([39, 3])
torch.Size([89])
torch.Size([39])
torch.Size([128])


In [253]:
idx = furthest_point_sample(cluster.unsqueeze(0), 16).squeeze(0)

In [254]:
cluster = cluster[idx]
cluster.shape

torch.Size([16, 3])