In [227]:
import sys 
sys.path.append('../')

from util import split_knn_patches 
import torch 
from torch.utils.data import Dataset, DataLoader 
import torch.nn.functional as F

import numpy as np

import h5py 
from pathlib import Path

from pointnet2_ops.pointnet2_utils import furthest_point_sample

import fpsample 

In [250]:
# dataset class for the 50K single tree, env dataset
class SingleTree_Pretrain(Dataset):
    """Point-cloud dataset that returns one fixed-size instance per index.

    Each item is read from ``data/instance_xyz`` in an HDF5 file, reshaped to
    ``(N, 3)`` and brought to exactly ``num_points`` rows: larger clouds are
    downsampled with ``fpsample.bucket_fps_kdline_sampling`` and smaller ones
    are padded by repeating randomly chosen existing points.

    Args:
        num_points: Number of points every returned cloud has.
        data_file: Path to the HDF5 file. Defaults to the location that was
            previously hard-coded, so ``SingleTree_Pretrain()`` behaves as before.
    """

    def __init__(
        self,
        num_points=1024,
        data_file='/share/projects/erasmus/hansend/thesis/data/pretraining/ssl_tree_pretraining_dataset.h5',
    ):
        super().__init__()
        self.num_points = num_points
        self.data_file = Path(data_file)

        # Only the number of instances is read here; the file is reopened per item.
        with h5py.File(self.data_file, 'r') as f:
            self.len = f['data']['instance_xyz'].shape[0]

    def __len__(self):
        """Return the number of instances recorded at construction time."""
        return self.len

    def __getitem__(self, idx):
        """Read instance ``idx`` and return it as a ``(num_points, 3)`` array."""
        # Keep the file open only for the read itself; resampling works on the
        # in-memory array afterwards.
        with h5py.File(self.data_file, 'r', swmr=True) as f:
            instance_xyz = f['data']['instance_xyz'][idx]

        return self._resample(instance_xyz.reshape(-1, 3))  # -> (N, 3)

    def _resample(self, instance_xyz):
        """Down- or up-sample an ``(N, 3)`` array to ``self.num_points`` rows.

        Raises:
            ValueError: If the cloud has no points but ``num_points`` is positive,
                since there is nothing to duplicate.
        """
        n_points = instance_xyz.shape[0]

        if n_points > self.num_points:
            # fps downsample pc
            instance_idxs = fpsample.bucket_fps_kdline_sampling(instance_xyz, self.num_points, h=3)
            return instance_xyz[instance_idxs]

        if n_points < self.num_points:
            if n_points == 0:
                raise ValueError(
                    f'Cannot pad an empty point cloud to {self.num_points} points'
                )
            # Pad by appending randomly chosen duplicates of existing points;
            # the original points stay in place at the front.
            point_diff = self.num_points - n_points
            idxs = np.random.choice(n_points, point_diff, replace=True)
            return np.concatenate((instance_xyz, instance_xyz[idxs]), axis=0)

        return instance_xyz

In [229]:
def pad_collate_fn(batch):
    """Collate variable-length point clouds into one zero-padded batch.

    Every item is converted with ``torch.tensor`` and padded with zeros at the
    end of its second-to-last dimension until it has as many rows as the
    longest item in the batch.

    Returns:
        A tuple ``(padded, mask)``: the stacked padded clouds and a stacked
        mask holding 1 for real points and 0 for padded ones.
    """
    clouds = [torch.tensor(item) for item in batch]
    longest = max(cloud.shape[0] for cloud in clouds)

    padded_clouds, masks = [], []
    for cloud in clouds:
        n_real = cloud.shape[0]
        n_pad = longest - n_real

        # Pad spec is (last dim: before, after; second-to-last dim: before, after),
        # so only rows are appended at the end.
        padded_clouds.append(F.pad(cloud, (0, 0, 0, n_pad), value=0.0))

        # Start from all-padding and flag the leading real points.
        mask = torch.zeros(longest)
        mask[:n_real] = 1
        masks.append(mask)

    return torch.stack(padded_clouds), torch.stack(masks)

In [80]:
def offset_collate_fn(batch):
    """Collate variable-length point clouds by concatenation.

    Every item is converted with ``torch.tensor`` and all items are
    concatenated along the first dimension.

    Returns:
        A tuple ``(points, offset)`` where ``offset[i]`` is the cumulative
        number of points up to and including item ``i``.
    """
    clouds = []
    counts = []
    for item in batch:
        cloud = torch.tensor(item)
        clouds.append(cloud)
        counts.append(cloud.shape[0])

    # Running total of per-item point counts marks where each item ends.
    offset = torch.cumsum(torch.tensor(counts), dim=0)

    return torch.cat(clouds), offset

In [270]:
# Build the dataset with its default arguments and pull one batch for inspection.
ds = SingleTree_Pretrain()
# No collate_fn is passed here, so the DataLoader's default collation is used.
dl = DataLoader(ds, batch_size=8, shuffle=False, num_workers=4)
batch = next(iter(dl))

In [271]:
# Display the dimensions of the batch fetched above.
batch.shape

torch.Size([8, 1024, 3])

In [114]:
def batch_sample(batch):
    """Split a concatenated ``(points, offsets)`` batch into per-instance clouds.

    ``offsets`` holds cumulative point counts, so the size of each instance
    is the difference between consecutive entries (the first entry is the
    size of the first instance).

    Returns:
        A tuple ``(B, pcs)``: the number of offsets and the tuple of
        per-instance tensors produced by ``torch.split``.
    """
    points, offsets = batch
    B = offsets.shape[0]

    # Work on plain Python numbers: first size, then consecutive differences.
    bounds = offsets.tolist()
    sizes = [bounds[0]] + [end - start for start, end in zip(bounds, bounds[1:])]

    pcs = torch.split(points, sizes)

    return B, pcs

In [115]:
# batch_sample returns a (B, pcs) tuple, so `pcs` here is bound to that whole
# tuple rather than only to the per-instance clouds.
# NOTE(review): batch_sample unpacks `batch` into (points, offsets), while the
# DataLoader cell in this notebook shows `batch` as a single tensor — confirm
# which collate_fn produced `batch` when this cell was run.
pcs = batch_sample(batch)
    

In [129]:
# `pcs` is the (B, pcs) tuple from batch_sample: [-1] picks the tuple of
# per-instance clouds and [0] picks the first cloud in it.
pc = pcs[-1][0]

In [120]:
# Masking and point-count settings used by the patch experiments below.
num_points = 1024 
masking_ratio = 0.3 
patch_size = 16 

# Patch budget: total patches by integer division, masked count truncated by
# int(), and the remainder left visible.
num_patches = num_points // patch_size 
num_masked_patches = int(num_patches * masking_ratio) 
num_vis_patches = num_patches - num_masked_patches 
print(
    f'Num of patches: {num_patches}: Visible: {num_vis_patches}, Masked: {num_masked_patches}'
)

Num of patches: 64: Visible: 45, Masked: 19


In [144]:
# Number of rows (points) in the selected point cloud.
pc_points = pc.shape[0]
pc_points

2538

In [145]:
# Pick the split axis at random and use the median of that coordinate column
# as the threshold: column 0 is labelled 'x', column 1 is labelled 'y'.
# NOTE(review): no random seed is set in the visible cells, so the chosen axis
# differs between runs.
if np.random.choice([True, False]): 
    t, axis = np.median(pc[:,0]), 'x'
else: 
    t, axis = np.median(pc[:, 1]), 'y'

In [146]:
# Split the point cloud into a visible and a masked subset at threshold `t`.
# NOTE(review): the visible side is `> t` for 'x' but `<= t` for 'y' — confirm
# that the flipped direction is intentional.
if axis == 'x':
    pc_vis = pc[pc[:, 0] > t] 
    pc_masked = pc[pc[:, 0] <= t]
elif axis == 'y': 
    pc_vis = pc[pc[:, 1] <= t] 
    pc_masked = pc[pc[:, 1] > t]

In [178]:
# Find patch centers in the subsets.
# Each subset is moved to the GPU and given a leading batch dimension before
# sampling; the result is moved back to the CPU and that dimension is removed.
vis_centers = furthest_point_sample(pc_vis.cuda().unsqueeze(0), num_vis_patches).cpu().squeeze(0)
masked_centers = furthest_point_sample(pc_masked.cuda().unsqueeze(0), num_masked_patches).cpu().squeeze(0)

# Select center points: the sampled results are used as row indices into each subset.
vis_center_points = pc_vis[vis_centers]
masked_center_points = pc_masked[masked_centers]


In [198]:
def get_kNN_patch_idxs(xyz, center_xyz, k):
    """Return, for every center, the indices of its ``k`` nearest points.

    Args:
        xyz: Points to search in.
        center_xyz: Patch centers.
        k: Number of neighbours per center.

    Returns:
        Index tensor into ``xyz`` of shape ``(k, num_centers)``; column ``j``
        lists the ``k`` points closest to center ``j``.
    """
    # Pairwise distances: one row per point, one column per center.
    pairwise = torch.cdist(xyz, center_xyz)

    # The k smallest distances are taken along dim 0, i.e. over the points.
    nearest = pairwise.topk(k, dim=0, largest=False)

    return nearest.indices

In [200]:
# For each subset, look up the `patch_size` nearest points of every center;
# the indices refer to rows of the subset that was passed in.
vis_patch_idxs = get_kNN_patch_idxs(pc_vis, vis_center_points, k=patch_size)
masked_patch_idxs = get_kNN_patch_idxs(pc_masked, masked_center_points, k=patch_size)

In [220]:
# Gather one patch per center: each column of the index tensor selects the
# rows of the subset that belong to that center's patch.
vis_patches = [pc_vis[column] for column in vis_patch_idxs.unbind(dim=1)]
masked_patches = [pc_masked[column] for column in masked_patch_idxs.unbind(dim=1)]

In [221]:
# Stack the per-center patches along a new leading dimension, giving one
# tensor per subset with one entry per patch.
vis_patches_tensor = torch.stack(vis_patches, dim=0)
masked_patches_tensor = torch.stack(masked_patches, dim=0)

