In [199]:
import sys 
sys.path.append('../')

from util_yours import get_kNN_patch_idxs
from data_yours import ALS_50K
import torch 
from torch.utils.data import DataLoader 
import torch.nn.functional as F

import numpy as np 
from pathlib import Path

from pointnet2_ops import pointnet2_utils

import fpsample 

In [218]:
# Seed every RNG the pipeline draws from: DataLoader shuffling and torch.rand use
# torch's global generator, the quadrant choice in quadrant_masking uses numpy's.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Evaluation dataset and a shuffled loader over it
# (32 clouds per batch, loaded in the main process).
ds = ALS_50K()
test_loader = DataLoader(ds, batch_size=32, shuffle=True, num_workers=0)


In [219]:
# Pull a single batch off the loader for the one-off masking check below.
_loader_it = iter(test_loader)
batch = next(_loader_it)
del _loader_it

In [220]:
# The loader yields (points, labels) pairs; keep only the points tensor.
# Guarded so re-running this cell is a no-op instead of slicing off the first cloud.
if isinstance(batch, (list, tuple)):
    batch = batch[0]

In [221]:
def _split_evenly(total, parts):
    """Split `total` into `parts` ints that differ by at most one, larger shares first."""
    return [total // parts + (1 if i < total % parts else 0) for i in range(parts)]


def _sample_quadrant_patches(quads, num_patches, patch_size, like):
    """Draw `num_patches` FPS-centred kNN patches, spread as evenly as possible over `quads`.

    Args:
        quads: list of (n_i, C) point tensors, each with at least `patch_size` points.
        num_patches: total number of patch centres to draw across all of `quads`.
        patch_size: neighbours gathered per centre (the k of the kNN query).
        like: tensor supplying dtype / device / channel count for empty outputs.

    Returns:
        (patches, centers): the per-quadrant patch tensors concatenated along dim 0,
        and the (num_patches, C) centre points.

    Raises:
        ValueError: if patches are requested but `quads` is empty.
    """
    num_channels = like.shape[-1]
    if num_patches == 0:
        # NOTE(review): assumes a patch is (patch_size, C), i.e. that
        # get_kNN_patch_idxs returns a (patch_size, n_centres) array — TODO confirm.
        return like.new_empty((0, patch_size, num_channels)), like.new_empty((0, num_channels))
    if not quads:
        raise ValueError(
            f"cannot draw {num_patches} patches: no quadrant on this side has >= {patch_size} points"
        )

    patches, centers = [], []
    for q, n_p in zip(quads, _split_evenly(num_patches, len(quads))):
        if n_p == 0:
            continue  # more surviving quadrants than patches to hand out
        # pointnet2_ops FPS reads a contiguous (B, N, 3) xyz tensor, so only the first
        # three columns are passed; extra feature channels would be misread as coordinates.
        xyz = q[:, :3].unsqueeze(0).contiguous()
        fps_idx = pointnet2_utils.furthest_point_sample(xyz, n_p).squeeze(0).long()
        center_points = q[fps_idx]
        knn_idx = get_kNN_patch_idxs(q, center_points, k=patch_size)
        # One gather, equivalent to stacking q[knn_idx[:, j]] for every column j.
        patches.append(q[knn_idx.T])
        centers.append(center_points)
    return torch.cat(patches, dim=0), torch.cat(centers, dim=0)


def quadrant_masking(batch, masking_ratio=0.2, patch_size=16):
    """Block-mask point clouds by quadrant: two random quadrants masked, two visible.

    Each cloud is split into four quadrants at its median x and median y. Exactly two
    quadrants are picked at random as masked and two as visible. Patch centres are
    drawn on each side with furthest point sampling and grown into kNN patches of
    `patch_size` points. Quadrants with fewer than `patch_size` points are skipped,
    and their share of patches goes to the remaining quadrant(s) on the same side.
    This keeps the patch count identical for every cloud so the batch can be stacked.

    Args:
        batch: (B, N, C) point tensor. Columns 0/1 are used for the quadrant split and
            columns 0-2 for FPS. The notebook moves it to CUDA before calling,
            presumably because the pointnet2_ops kernels require it.
        masking_ratio: fraction of the N // patch_size patches that are masked.
        patch_size: points per patch (the k of the kNN query).

    Returns:
        mask_pos: (B, num_masked, ...) stacked masked patches.
        vis_pos: (B, num_visible, ...) stacked visible patches.
        mask_center_pos: (B, num_masked, C) masked patch centres.
        vis_center_pos: (B, num_visible, C) visible patch centres.
        mask_patch_idx, vis_patch_idx: split of a random permutation of range(num_patches).
        shuffle_idx: concatenation (vis_patch_idx, mask_patch_idx).

    Raises:
        ValueError: if a side needs patches but none of its quadrants has >= patch_size points.

    Randomness: quadrant choice uses numpy's global RNG; the index permutation uses torch's.
    NOTE(review): the three index tensors are one random permutation shared by the
    whole batch and are not derived from the quadrant assignment; confirm that is
    what downstream code expects.
    """
    _, N, _ = batch.shape  # also enforces a 3-D (B, N, C) input

    num_patches = N // patch_size
    num_masked_patches = int(num_patches * masking_ratio)
    num_vis_patches = num_patches - num_masked_patches

    vis_pos, mask_pos = [], []
    vis_center_pos, mask_center_pos = [], []

    for pc in batch:
        t_x = torch.median(pc[:, 0])
        t_y = torch.median(pc[:, 1])

        quads = [
            pc[(pc[:, 0] > t_x) & (pc[:, 1] >= t_y)],   # upper right
            pc[(pc[:, 0] <= t_x) & (pc[:, 1] >= t_y)],  # upper left
            pc[(pc[:, 0] > t_x) & (pc[:, 1] < t_y)],    # lower right
            pc[(pc[:, 0] <= t_x) & (pc[:, 1] < t_y)],   # lower left
        ]

        # Exactly two quadrants masked (1) and two visible (0), in random order.
        select_idx = np.zeros(4, dtype=int)
        select_idx[:2] = 1
        np.random.shuffle(select_idx)
        # Quadrants too small for a single kNN patch are dropped here; the helper
        # spreads the whole per-side quota over whatever survives.
        masked_quads = [q for q, s in zip(quads, select_idx) if s == 1 and len(q) >= patch_size]
        visible_quads = [q for q, s in zip(quads, select_idx) if s == 0 and len(q) >= patch_size]

        patches, centers = _sample_quadrant_patches(masked_quads, num_masked_patches, patch_size, pc)
        mask_pos.append(patches)
        mask_center_pos.append(centers)

        patches, centers = _sample_quadrant_patches(visible_quads, num_vis_patches, patch_size, pc)
        vis_pos.append(patches)
        vis_center_pos.append(centers)

    vis_pos = torch.stack(vis_pos, dim=0)
    mask_pos = torch.stack(mask_pos, dim=0)
    vis_center_pos = torch.stack(vis_center_pos, dim=0)
    mask_center_pos = torch.stack(mask_center_pos, dim=0)

    idx_all = torch.rand(num_patches).argsort()
    vis_patch_idx = idx_all[:num_vis_patches]
    mask_patch_idx = idx_all[num_vis_patches:]
    shuffle_idx = torch.cat((vis_patch_idx, mask_patch_idx), dim=0)

    return mask_pos, vis_pos, mask_center_pos, vis_center_pos, mask_patch_idx, vis_patch_idx, shuffle_idx


In [222]:
# Run quadrant masking once on the single batch drawn above.
(mask_pos, vis_pos,
 mask_center_pos, vis_center_pos,
 mask_patch_idx, vis_patch_idx, shuffle_idx) = quadrant_masking(
    batch.cuda(), patch_size=16, masking_ratio=0.2
)
  

In [211]:
from tqdm import tqdm

In [212]:
# Smoke test: run quadrant masking over every batch in the loader to check it never
# fails. The loop uses its own names, so the global `batch` and the single-batch
# results above are not overwritten; re-running the earlier cells therefore gives
# the same results regardless of whether this cell has run.
for points, _labels in tqdm(test_loader, desc="quadrant masking"):
    masking_outputs = quadrant_masking(points.cuda(), masking_ratio=0.2, patch_size=16)

  0%|          | 9/3012 [00:16<1:31:37,  1.83s/it]


KeyboardInterrupt: 