In [2]:
import sys 
sys.path.append('../')

from util import split_knn_patches 
from util import * 
import torch 
from torch.utils.data import Dataset, DataLoader 
import torch.nn.functional as F

import numpy as np

import h5py 
from pathlib import Path

from pointnet2_ops.pointnet2_utils import furthest_point_sample, ball_query

import fpsample 

## Dataset + DataLoading 

In [2]:
def normalize_pointcloud(pointcloud: np.ndarray) -> np.ndarray: 

    # assert 
    assert pointcloud.shape[1] == 3, "Pointcloud should be of shape (N, 3)"

    # center pointcloud into origin (z-axis too)
    #centroid = np.mean(pointcloud, axis=0)
    #pointcloud = pointcloud - centroid
    
    # farthest distance to origin 
    # normalize to unit sphere 
    m = np.max(np.sqrt(np.sum(pointcloud ** 2, axis=1)))
    pointcloud = pointcloud / m
    return pointcloud

In [3]:
### CHHANGED UNIFORMLY SCALING AND TRANSLATING
def translate_pointcloud(pointcloud, scale_range=[0.6, 1.5], translation_range=[-0.01, 0.01]):
    
    if isinstance(pointcloud, torch.Tensor):
        scale = torch.empty(1).uniform_(scale_range[0], scale_range[1]).item()
        translation = torch.empty(1).uniform_(translation_range[0], translation_range[1]).item()

        translated_pointcloud = (pointcloud * scale + translation).float()
        return translated_pointcloud

    else:
        scale = np.random.uniform(low=scale_range[0], high=scale_range[1])
        translation = np.random.uniform(low=translation_range[0], high=translation_range[1]) 

        translated_pointcloud = np.add(pointcloud*scale, translation).astype('float32')
        return translated_pointcloud

In [4]:
def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02):

    if isinstance(pointcloud, torch.Tensor):
        N, C = pointcloud.shape
        pointcloud += torch.clamp(sigma * torch.randn(N, C), min=-1 * clip, max=clip)
        return pointcloud
    else:
        N, C = pointcloud.shape
        pointcloud += np.clip(sigma * np.random.randn(N, C), -1 * clip, clip)
        return pointcloud

def flip_pointcloud(pointcloud, p=0.5): 
    if np.random.rand() < p: 
        pointcloud[:, :2] = -pointcloud[:, :2]

    return pointcloud

In [5]:
### CHANGED TO RATION AROUND Z-AXIS
def rotate_pointcloud(pointcloud):

    if isinstance(pointcloud, torch.Tensor):
        theta = torch.empty(1).uniform_(0, 2 * torch.pi)
        rotation_matrix = torch.tensor(
            [
                [torch.cos(theta), -torch.sin(theta), 0],
                [torch.sin(theta), torch.cos(theta), 0],
                [0, 0, 1]
            ]
        )
        pointcloud = torch.matmul(pointcloud, rotation_matrix.T)
        return pointcloud
    else:
        theta = np.pi * 2 * np.random.uniform()
        rotation_matrix = np.array(
            [
                [np.cos(theta), -np.sin(theta), 0],
                [np.sin(theta), np.cos(theta), 0],
                [0, 0, 1]

            ]
        )
        pointcloud = pointcloud.dot(rotation_matrix)  # random rotation z-axis
        
        return pointcloud

In [6]:
def get_kNN_patch_idxs(xyz, center_xyz, k):
    dists = torch.cdist(xyz, center_xyz)

    _, knn_idxs = torch.topk(dists, k, largest=False, dim=0) # shape: (N points, 3)

    return knn_idxs

In [7]:
class SingleTree_Pretrain(Dataset): 
    def __init__(self, num_points=1024):
        super().__init__()
        self.num_points = num_points
        self.data_file = Path('/share/projects/erasmus/hansend/thesis/data/pretraining/ssl_tree_pretraining_dataset.h5')
        self.idx_file = Path('/share/projects/erasmus/hansend/thesis/data/pretraining/ssl_tree_pretraining_dataset_subset_idx.csv')
        
        with open(Path('/share/projects/erasmus/hansend/thesis/data/pretraining/ssl_tree_pretraining_dataset_subset_idx.csv'), 'r') as f: 
            self.idx = f.readlines()[1:] # skip header
            self.idx = [int(i.strip()) for i in self.idx]
                
        self.len = len(self.idx) 

    def __len__(self):
        return self.len 

    def __getitem__(self, idx): 
        with h5py.File(self.data_file, 'r', swmr=True) as f: 
            instance_xyz = f['data']['instance_xyz'][idx] 
            instance_xyz = instance_xyz.reshape(-1, 3) # (N, 3) 
            #instance_nheights = f['data']['instance_nheights'][idx]
            
            # FPS supsampling to num_points
            if instance_xyz.shape[0] > self.num_points:
                instance_idxs = fpsample.bucket_fps_kdline_sampling(instance_xyz, self.num_points, h=3)
                instance_xyz = instance_xyz[instance_idxs]
                #instance_nheights = instance_nheights[instance_idxs]

                # augmentations with jittering as it is not used here before 
                # instance_xyz = rotate_pointcloud(instance_xyz)
                # instance_xyz = translate_pointcloud(instance_xyz)
                # instance_xyz = jitter_pointcloud(instance_xyz)
                # instance_xyz = flip_pointcloud(instance_xyz)
                instance_xyz = normalize_pointcloud(instance_xyz)


            # adding jittered points to num_points
            elif instance_xyz.shape[0] < self.num_points: 
                point_diff = self.num_points - instance_xyz.shape[0]
                idxs = np.random.choice(instance_xyz.shape[0], point_diff, replace=True)
                add_points_xyz = instance_xyz[idxs]
                jitter = np.clip(0.01 * np.random.randn(add_points_xyz.shape[0], add_points_xyz.shape[1]), -0.02, 0.02)
                add_points_xyz += jitter
                instance_xyz = np.concatenate((instance_xyz, add_points_xyz), axis=0)
                #instance_nheights = np.concatenate((instance_nheights, instance_nheights[idxs] + jitter[:, 2]), axis=0) # add jitter to nheights 
            
                # augmentations without jitter as it has beend done before to upsample to num_points 
                # instance_xyz = rotate_pointcloud(instance_xyz)
                # instance_xyz = translate_pointcloud(instance_xyz)
                # instance_xyz = flip_pointcloud(instance_xyz)
                instance_xyz = normalize_pointcloud(instance_xyz)
        

        return instance_xyz #, instance_nheights


In [8]:
# possible pad collate_fn
def pad_collate_fn(batch):
    batch = [torch.tensor(b) for b in batch]
    max_len = max(pc.shape[0] for pc in batch)
    padded_b = []
    mask_b = []
    for pc in batch:
        pad_len = max_len - pc.shape[0]
        padded_pc = F.pad(pc, (0, 0, 0, pad_len), value=0.0)
        padded_b.append(padded_pc)
        # Mask: 1 for real points, 0 for padded
        mask = torch.cat([torch.ones(pc.shape[0]), torch.zeros(pad_len)])
        mask_b.append(mask)

    return torch.stack(padded_b), torch.stack(mask_b) # padded batch, mask batch 

In [9]:
# possible offset_collate_fn
def offset_collate_fn(batch): 
    batch = [torch.tensor(b) for b in batch]
    npoints = [pc.shape[0] for pc in batch]
    offset = torch.tensor(npoints).cumsum(0) 

    batch = torch.cat(batch) 

    return batch, offset 

In [10]:
# possible batch sampling function for collate_fn
def batch_sample(batch):
    points, offsets = batch 
    B = batch[-1].shape[0]
    lengths = offsets.clone().detach() 
    lengths[1:] = offsets[1:] - offsets[:-1] # calculate npoints per instance
    lengths[0] = offsets[0] 

    pcs = torch.split(points, lengths.tolist())

    return B, pcs

In [11]:
ds = SingleTree_Pretrain(num_points=2048)
dl = DataLoader(ds, batch_size=8, shuffle=False, num_workers=4)


FileNotFoundError: [Errno 2] No such file or directory: '/share/projects/erasmus/hansend/thesis/data/pretraining/ssl_tree_pretraining_dataset_subset_idx.csv'

In [26]:
%%timeit 
batch = next(iter(dl))

87 ms ± 2.43 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [46]:
batch = next(iter(dl))

In [47]:
for j in range(1, 4): 
    batch = next(iter(dl))
    for i in range(batch.shape[0]): 
        instance = batch[i]
        np.savetxt(f'../data/instance_{j}_{i}.txt', instance) 
    

In [None]:
# check instances viruatlly 


## Axis aware patch masking 

In [39]:
# set masking and npoints setting 
num_points = 2048
masking_ratio = 0.1
patch_size = 16 

num_patches = num_points // patch_size 
num_masked_patches = int(num_patches * masking_ratio) 
num_vis_patches = num_patches - num_masked_patches 
print(
    f'Num of patches: {num_patches}: Visible: {num_vis_patches}, Masked: {num_masked_patches}'
)

Num of patches: 128: Visible: 116, Masked: 12


In [114]:
pc = ds[2609]

In [115]:
pc = pc[0]

In [117]:
# generate random T 
if np.random.choice([True, False]): 
    t, axis = np.median(pc[:,0]), 'x'
else: 
    t, axis = np.median(pc[:, 1]), 'y'

In [120]:
# to tensor for single instance check 
pc = torch.tensor(pc)

In [121]:
# subset the point cloud
if axis == 'x':
    pc_vis = pc[pc[:, 0] > t] 
    pc_masked = pc[pc[:, 0] <= t]
elif axis == 'y': 
    pc_vis = pc[pc[:, 1] <= t] 
    pc_masked = pc[pc[:, 1] > t]

In [122]:
# find patch centers in the subsets 
vis_centers = furthest_point_sample(pc_vis.cuda().unsqueeze(0), num_vis_patches).cpu().squeeze(0)
masked_centers = furthest_point_sample(pc_masked.cuda().unsqueeze(0), num_masked_patches).cpu().squeeze(0)

# select center points 
vis_center_points = pc_vis[vis_centers]
masked_center_points = pc_masked[masked_centers]


In [123]:
vis_patch_idxs = get_kNN_patch_idxs(pc_vis, vis_center_points, k=patch_size)
masked_patch_idxs = get_kNN_patch_idxs(pc_masked, masked_center_points, k=patch_size)

In [124]:
vis_patches = [pc_vis[vis_patch_idxs[:, i]] for i in range(vis_patch_idxs.shape[1])]
masked_patches = [pc_masked[masked_patch_idxs[:, i]] for i in range(masked_patch_idxs.shape[1])]

In [125]:
vis_patches_tensor = torch.stack(vis_patches, dim=0)
masked_patches_tensor = torch.stack(masked_patches, dim=0)


In [58]:
# mask_pos, vis_pos 
vis_patches_tensor.shape
masked_patches_tensor.shape

torch.Size([12, 16, 3])

In [126]:
# mask_center_pos, vis_center_pos 
vis_center_points.shape 
masked_center_points.shape

torch.Size([12, 3])

In [127]:
# vis_patch_idx, masked_patch_idx
vis_patch_idx = torch.arange(vis_patches_tensor.shape[0])
masked_patch_idx = torch.arange(masked_patches_tensor.shape[0])
num_patches = vis_patches_tensor.shape[0] + masked_patches_tensor.shape[0]
shuffle_idx = torch.stack([torch.randperm(num_patches)])

In [128]:

points_vis = vis_patches_tensor.reshape(-1, 3)
points_masked = masked_patches_tensor.reshape(-1, 3)
patch_labels_vis = np.repeat(np.arange(vis_patches_tensor.shape[0]), vis_patches_tensor.shape[1])
patch_labels_masked = np.repeat(np.arange(masked_patches_tensor.shape[0]), masked_patches_tensor.shape[1])

# Combine all points and labels
all_points = np.vstack([points_vis.numpy(), points_masked.numpy()])
all_patch_labels = np.concatenate([patch_labels_vis, patch_labels_masked])
all_patch_mask_labels = np.concatenate([patch_mask_label_vis, patch_mask_label])

# Optionally, you can also concatenate center points and their labels if needed

data_to_save = np.column_stack([all_points, all_patch_labels, all_patch_mask_labels])

np.savetxt('../data/masked_instance.txt', data_to_save)
np.savetxt('../data/mask_instance.txt', masked_patches_tensor.reshape(-1, 3).numpy())
np.savetxt('../data/vis_instance.txt', vis_patches_tensor.reshape(-1, 3).numpy())
np.savetxt('../data/instance.txt', pc) 

## Custom Masking Logic as function 

In [None]:
def center_split_masking(batch, masking_ratio=0.2, patch_size=16): 
    B, N, C = batch.shape 

    # calculate number of patches
    num_patches = N // patch_size # assuming N is num_points
    num_masked_patches = int(num_patches * masking_ratio) 
    num_vis_patches = num_patches - num_masked_patches

    vis_pos = []
    mask_pos = [] 
    mask_center_pos = [] 
    vis_center_pos = [] 
    # generate patch_idx and shuffle_idx later 

    for i in range(B): 
        pc = batch[i]

        # generate random T 
        if np.random.choice([True, False]): 
            t, axis = torch.median(pc[:,0]), 'x'
        else: 
            t, axis = torch.median(pc[:, 1]), 'y'
        
        # subset the point cloud
        if axis == 'x':
            pc_vis = pc[pc[:, 0] > t] 
            pc_masked = pc[pc[:, 0] <= t]
        elif axis == 'y': 
            pc_vis = pc[pc[:, 1] <= t] 
            pc_masked = pc[pc[:, 1] > t]

        # find patch centers in the subsets 
        vis_centers = pointnet2_utils.furthest_point_sample(pc_vis.cuda().unsqueeze(0), num_vis_patches).cpu().squeeze(0)
        masked_centers = pointnet2_utils.furthest_point_sample(pc_masked.cuda().unsqueeze(0), num_masked_patches).cpu().squeeze(0)

        # select center points 
        vis_center_points = pc_vis[vis_centers]
        masked_center_points = pc_masked[masked_centers]

        # get patch idxs
        vis_patch_idxs = get_kNN_patch_idxs(pc_vis, vis_center_points, k=patch_size)
        masked_patch_idxs = get_kNN_patch_idxs(pc_masked, masked_center_points, k=patch_size)

        # get patches 
        vis_patches = [pc_vis[vis_patch_idxs[:, i]] for i in range(vis_patch_idxs.shape[1])]
        masked_patches = [pc_masked[masked_patch_idxs[:, i]] for i in range(masked_patch_idxs.shape[1])]

        vis_patches_tensor = torch.stack(vis_patches, dim=0)
        masked_patches_tensor = torch.stack(masked_patches, dim=0)

        vis_pos.append(vis_patches_tensor)
        vis_center_pos.append(vis_center_points)
        mask_pos.append(masked_patches_tensor)
        mask_center_pos.append(masked_center_points)

    vis_pos = torch.stack(vis_pos, dim=0)
    vis_center_pos = torch.stack(vis_center_pos, dim=0)
    mask_pos = torch.stack(mask_pos, dim=0)
    mask_center_pos = torch.stack(mask_center_pos, dim=0)

    idx_all = torch.rand(num_patches).argsort()
    vis_patch_idx = idx_all[:num_vis_patches]
    mask_patch_idx = idx_all[num_vis_patches:]

    shuffle_idx = torch.cat((vis_patch_idx, mask_patch_idx), dim=0)

    return mask_pos, vis_pos, mask_center_pos, vis_center_pos, mask_patch_idx, vis_patch_idx, shuffle_idx

In [12]:
# split at both axes  x and y, create patches on # opposite sides of the split 
def center_split_masking_2axes(batch, masking_ratio=0.2, patch_size=16): 
    B, N, C = batch.shape 

    # calculate number of patches
    num_patches = N // patch_size # assuming N is num_points
    num_masked_patches = int(num_patches * masking_ratio) 
    num_vis_patches = num_patches - num_masked_patches

    vis_pos = []
    mask_pos = [] 
    mask_center_pos = [] 
    vis_center_pos = [] 
    # generate patch_idx and shuffle_idx later 

    for i in range(B): 
        pc = batch[i]

        t_x = torch.median(pc[:,0])
        t_y = torch.median(pc[:, 1])

        # subset point cloud into quadrants 
        qupper_right = pc[(pc[:, 0] > t_x) & (pc[:, 1] >= t_y)] 
        qupper_left = pc[(pc[:, 0] <= t_x) & (pc[:, 1] >= t_y)]
        qlower_right = pc[(pc[:, 0] > t_x) & (pc[:, 1] < t_y)]
        qlower_left = pc[(pc[:, 0] <= t_x) & (pc[:, 1] < t_y)]

        # quadrants 
        quadrants = [qupper_right, qupper_left, qlower_right, qlower_left]

        # randomly select which quadrant to use masked patches 
        select_idx = np.random.randint(0, 1, size=4) 
        masked_q = [q for q, s in zip(quadrants, select_idx) if s == 1]
        visible_q = [q for q, s in zip(quadrants, select_idx) if s == 0]

        # specify number of patches per quadrant 
        num_patches_per_q_vis = num_vis_patches / 2 
        num_masked_patches_per_q = num_masked_patches / 2

        # for each quadrant find vis_centers 
        qvis_1_centers = pointnet2_utils.furthest_point_sample(visible_q[0].cuda().unsqueeze(0), num_patches_per_q_vis).cpu().squeeze(0)
        qvis_2_centers = pointnet2_utils.furthest_point_sample(visible_q[1].cuda().unsqueeze(0), num_patches_per_q_vis).cpu().squeeze(0)

        # for each quadrant find masked_centers
        qmasked_1_centers = pointnet2_utils.furthest_point_sample(masked_q[0].cuda().unsqueeze(0), num_masked_patches_per_q).cpu().squeeze(0)
        qmasked_2_centers = pointnet2_utils.furthest_point_sample(masked_q[1].cuda().unsqueeze(0), num_masked_patches_per_q).cpu().squeeze(0)

        # select center points
        vis_center_points_1 = visible_q[0][qvis_1_centers]
        vis_center_points_2 = visible_q[1][qvis_2_centers]
        masked_center_points_1 = masked_q[0][qmasked_1_centers]
        masked_center_points_2 = masked_q[1][qmasked_2_centers]
        vis_center_points = torch.cat((vis_center_points_1, vis_center_points_2), dim=0)
        masked_center_points = torch.cat((masked_center_points_1, masked_center_points_2), dim=0)

        # get patch idxs
        vis_patch_idxs_1 = get_kNN_patch_idxs(visible_q[0], vis_center_points_1, k=patch_size)
        vis_patch_idxs_2 = get_kNN_patch_idxs(visible_q[1], vis_center_points_2, k=patch_size)
        masked_patch_idxs_1 = get_kNN_patch_idxs(masked_q[0], masked_center_points_1, k=patch_size)
        masked_patch_idxs_2 = get_kNN_patch_idxs(masked_q[1], masked_center_points_2, k=patch_size)

        # get patches
        vis_patches_1 = [visible_q[0][vis_patch_idxs_1[:, i]] for i in range(vis_patch_idxs_1.shape[1])]
        vis_patches_2 = [visible_q[1][vis_patch_idxs_2[:, i]] for i in range(vis_patch_idxs_2.shape[1])]
        masked_patches_1 = [masked_q[0][masked_patch_idxs_1[:, i]] for i in range(masked_patch_idxs_1.shape[1])]
        masked_patches_2 = [masked_q[1][masked_patch_idxs_2[:, i]] for i in range(masked_patch_idxs_2.shape[1])]

        # get patchess
        vis_patches_tensor = torch.stack(vis_patches_1 + vis_patches_2, dim=0)
        masked_patches_tensor = torch.stack(masked_patches_1 + masked_patches_2, dim=0)

        # append to lists
        vis_pos.append(vis_patches_tensor)
        vis_center_pos.append(vis_center_points)
        mask_pos.append(masked_patches_tensor)
        mask_center_pos.append(masked_center_points)

    vis_pos = torch.stack(vis_pos, dim=0)
    vis_center_pos = torch.stack(vis_center_pos, dim=0)
    mask_pos = torch.stack(mask_pos, dim=0)
    mask_center_pos = torch.stack(mask_center_pos, dim=0)

    idx_all = torch.rand(num_patches).argsort()
    vis_patch_idx = idx_all[:num_vis_patches]
    mask_patch_idx = idx_all[num_vis_patches:]

    shuffle_idx = torch.cat((vis_patch_idx, mask_patch_idx), dim=0)

    return mask_pos, vis_pos, mask_center_pos, vis_center_pos, mask_patch_idx, vis_patch_idx, shuffle_idx

In [None]:
def quadrant_masking(batch, masking_ratio=0.2, patch_size=16): 
    B, N, C = batch.shape 

    num_patches = N // patch_size
    num_masked_patches = int(num_patches * masking_ratio)
    num_vis_patches = num_patches - num_masked_patches

    vis_pos, mask_pos = [], []
    vis_center_pos, mask_center_pos = [], []

    for i in range(B): 
        pc = batch[i]
        t_x = torch.median(pc[:,0])
        t_y = torch.median(pc[:, 1])

        # Split into quadrants
        quads = [
            pc[(pc[:, 0] > t_x) & (pc[:, 1] >= t_y)],  # upper right
            pc[(pc[:, 0] <= t_x) & (pc[:, 1] >= t_y)], # upper left
            pc[(pc[:, 0] > t_x) & (pc[:, 1] < t_y)],   # lower right
            pc[(pc[:, 0] <= t_x) & (pc[:, 1] < t_y)]   # lower left
        ]

        # Always select exactly 2 quadrants as masked (1) and 2 as visible (0)
        select_idx = np.zeros(4, dtype=int)
        select_idx[:2] = 1
        np.random.shuffle(select_idx)
        masked_quads = [q for q, s in zip(quads, select_idx) if s == 1 and len(q) >= patch_size] # make sure there is at least 1 patch
        visible_quads = [q for q, s in zip(quads, select_idx) if s == 0 and len(q) >= patch_size]

        # Distribute patches as evenly as possibel
       
        patches_per_masked = [num_masked_patches // 2 + (1 if x < num_masked_patches % 2 else 0) for x in range(2)] if 2 > 0 else []
        patches_per_visible = [num_vis_patches // 2 + (1 if x < num_vis_patches % 2 else 0) for x in range(2)] if 2 > 0 else []

        # FPS and patch extraction for masked quadrants
        masked_patches, masked_centers = [], []
        for q, n_p in zip(masked_quads, patches_per_masked):
            if n_p > 0 and q.shape[0] >= patch_size:
                centers = furthest_point_sample(q.cuda().unsqueeze(0), n_p).cpu().squeeze(0)
                center_points = q[centers]
                patch_idxs = get_kNN_patch_idxs(q, center_points, k=patch_size)
                patches = [q[patch_idxs[:, j]] for j in range(patch_idxs.shape[1])]
                masked_patches.extend(patches)
                masked_centers.append(center_points)
        if masked_patches:
            mask_pos.append(torch.stack(masked_patches, dim=0))
            mask_center_pos.append(torch.cat(masked_centers, dim=0))
        else:
            mask_pos.append(torch.empty((0, patch_size, C)))
            mask_center_pos.append(torch.empty((0, C)))

        # FPS and patch extraction for visible quadrants
        visible_patches, visible_centers = [], []
        for q, n_p in zip(visible_quads, patches_per_visible):
            if n_p > 0 and q.shape[0] >= patch_size:
                centers = furthest_point_sample(q.cuda().unsqueeze(0), n_p).cpu().squeeze(0)
                center_points = q[centers]
                patch_idxs = get_kNN_patch_idxs(q, center_points, k=patch_size)
                patches = [q[patch_idxs[:, j]] for j in range(patch_idxs.shape[1])]
                visible_patches.extend(patches)
                visible_centers.append(center_points)
        if visible_patches:
            vis_pos.append(torch.stack(visible_patches, dim=0))
            vis_center_pos.append(torch.cat(visible_centers, dim=0))
        else:
            vis_pos.append(torch.empty((0, patch_size, C)))
            vis_center_pos.append(torch.empty((0, C)))

    vis_pos = torch.stack(vis_pos, dim=0)
    vis_center_pos = torch.stack(vis_center_pos, dim=0)
    mask_pos = torch.stack(mask_pos, dim=0)
    mask_center_pos = torch.stack(mask_center_pos, dim=0)

    idx_all = torch.rand(num_patches).argsort()
    vis_patch_idx = idx_all[:num_vis_patches]
    mask_patch_idx = idx_all[num_vis_patches:]
    shuffle_idx = torch.cat((vis_patch_idx, mask_patch_idx), dim=0)

    return mask_pos, vis_pos, mask_center_pos, vis_center_pos, mask_patch_idx, vis_patch_idx, shuffle_idx

In [19]:
test_batch = torch.rand(8, 2048, 3)
mask_pos, vis_pos, mask_center_pos, vis_center_pos, mask_patch_idx, vis_patch_idx, shuffle_idx = center_split_masking_2axes(test_batch)

In [20]:
mask_pos.shape, vis_pos.shape, mask_center_pos.shape, vis_center_pos.shape, mask_patch_idx.shape, vis_patch_idx.shape, shuffle_idx.shape

(torch.Size([8, 25, 16, 3]),
 torch.Size([8, 103, 16, 3]),
 torch.Size([8, 25, 3]),
 torch.Size([8, 103, 3]),
 torch.Size([25]),
 torch.Size([103]),
 torch.Size([128]))

In [181]:
%%timeit 
masked_batch = center_split_masking(batch[0])

83.1 ms ± 10.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
