In [None]:
I_GPU: int = 0  # index of the CUDA device selected in the setup cell

In [None]:
import os
import sys
import numpy as np
import torch
import glob

# Make the project packages importable. DIR is the parent of the notebook's
# working directory, and ROOT is one level above DIR.
DIR = os.path.dirname(os.getcwd())
ROOT = os.path.join(DIR, "..")

# Move each path to the front of sys.path (DIR ends up first, ROOT second, as
# before) without duplicating it, so re-running this cell is idempotent.
for _path in (ROOT, DIR):
    sys.path[:] = [p for p in sys.path if p != _path]
    sys.path.insert(0, _path)

# Only select a CUDA device when one exists. The cells below build CPU tensors
# only, so the notebook can also run on CPU-only machines.
if torch.cuda.is_available():
    torch.cuda.set_device(I_GPU)

In [None]:
# Build a synthetic multimodal sample with three parts:
#   - a 2n x 2n image made of four solid-color quadrants,
#   - a flat point cloud with one point per pixel (z = 0),
#   - a single image whose pixels are mapped to those points.
# Each quadrant also gets its own class label.
from torch_geometric.data import Data
from torch_points3d.datasets.multimodal.data import MMData
from torch_points3d.datasets.multimodal.image import ImageData, ImageMapping
from torch_points3d.visualization import visualize_mm_data

n = 64  # side of each colored quadrant, in pixels
num_pixels = (2 * n) ** 2


def solid_block(rgb, size):
    """Return a (3, size, size) long tensor whose every pixel is the triplet `rgb`."""
    return torch.tensor(rgb, dtype=torch.long).view(3, 1, 1).repeat(1, size, size)


# Build the 2D image.
# Along dim=1 (rows), red sits above green and blue sits above yellow.
# The two resulting columns are then joined along dim=2, giving:
#     red   | blue
#     green | yellow
r_img = solid_block([1, 0, 0], n)
g_img = solid_block([0, 1, 0], n)
b_img = solid_block([0, 0, 1], n)
y_img = solid_block([1, 1, 0], n)
img = torch.cat((torch.cat((r_img, g_img), dim=1), torch.cat((b_img, y_img), dim=1)), dim=2).unsqueeze(0)
assert img.shape == (1, 3, 2 * n, 2 * n)

# Pixel coordinates (row, col) of every pixel, in row-major order.
# The grid outputs are not called `x, y` so that `y` is free for the labels below.
grid_x, grid_y = torch.meshgrid(torch.arange(2 * n), torch.arange(2 * n))
pixels = torch.stack((grid_x.flatten(), grid_y.flatten()), dim=1)

# Point and image indices: a single image (index 0) and one point per pixel.
img_idx = torch.zeros(num_pixels, dtype=torch.long)
point_idx = torch.arange(num_pixels)

mappings = ImageMapping.from_dense(point_idx, img_idx, pixels)

images = ImageData(
    path=np.zeros(1, dtype='O'),
    pos=torch.Tensor([[n, n, -1]]),
    opk=torch.ones((1, 3)),
    ref_size=(2*n, 2*n),
    images=img,
    mappings=mappings)

# Flat point cloud: point k sits at (row, col, 0) of pixels[k] and takes that pixel's color.
pos = torch.cat((pixels, torch.zeros(num_pixels, 1, dtype=torch.long)), dim=1).float()
rgb = img[0, :, pixels[:, 0], pixels[:, 1]].T.float()
assert rgb.shape == (num_pixels, 3)

# Class label of each point = index of its quadrant, matching class_colors below:
# 0 red (top-left), 1 green (bottom-left), 2 blue (top-right), 3 yellow (bottom-right).
y = (pixels[:, 0] >= n).long() + 2 * (pixels[:, 1] >= n).long()

data = Data(pos=pos, rgb=rgb, y=y, mapping_index=point_idx)

mm_data = MMData(data, images)

# Build class colors for y mode visualization
class_colors = [
    [255, 0, 0],  # red
    [0, 255, 0],  # green
    [0, 0, 255],  # blue
    [255, 255, 0]]  # yellow

visualize_mm_data(mm_data, color_mode='y', class_colors=class_colors, figsize=700)

In [None]:
# Downscale 3D twice, with increasing voxel sizes.
# After each step, subsample the image mappings and show the result.
from torch_points3d.core.data_transform import GridSampling3D
from torch_points3d.core.data_transform.multimodal.image import SelectMappingFromPointId

mm_sub = mm_data.clone()

# idx_sampling composes the mapping_index kept at each step. The next cell uses
# it to index the full-resolution `mm_data` directly.
idx_sampling = None
for voxel_size in (2, 2 * 2):
    mm_sub.data = GridSampling3D(voxel_size)(mm_sub.data.clone())

    # mapping_index of the sampled points, presumably expressed in the previous
    # step's point ordering. Copy it so later updates of mm_sub cannot alias it.
    kept = mm_sub.data.mapping_index.clone()
    idx_sampling = kept if idx_sampling is None else idx_sampling[kept]

    # Subsample the mappings accordingly
    mm_sub.data, mm_sub.images = SelectMappingFromPointId()(mm_sub.data, mm_sub.images)

    visualize_mm_data(mm_sub, color_mode='y', class_colors=class_colors, figsize=700)

In [None]:
# Downscale 3D in a single shot: index the full-resolution multimodal data with
# the composed `idx_sampling` from the previous cell, instead of running the
# grid sampling again.
from torch_points3d.core.data_transform import GridSampling3D
from torch_points3d.core.data_transform.multimodal.image import SelectMappingFromPointId

mm_sub = mm_data.clone()[idx_sampling]

# Subsample the image mappings to match the selected points
mm_sub.data, mm_sub.images = SelectMappingFromPointId()(mm_sub.data, mm_sub.images)

visualize_mm_data(mm_sub, color_mode='y', class_colors=class_colors, figsize=700)

In [None]:
# Downscale 3D with 'merge' mode: every original point falls into a voxel, and
# its image mappings are carried over to the sampled point of that voxel.
from torch_points3d.core.data_transform import GridSampling3D
from torch_cluster import grid_cluster
from torch_geometric.nn import voxel_grid
from torch_geometric.nn.pool.consecutive import consecutive_cluster

voxel = 32  # grid size used for 3D sampling

mm_sub = mm_data.clone()

# Get the cluster indices for grid sampling.
# This is meant to mirror the clustering done inside GridSampling3D, so that
# cluster[i] is the (consecutive) voxel id of original point i.
coords = torch.round(mm_sub.data.pos / voxel)
if "batch" not in mm_sub.data:
    cluster = grid_cluster(coords, torch.tensor([1, 1, 1]))
else:
    cluster = voxel_grid(coords, mm_sub.data.batch, 1)
cluster, _ = consecutive_cluster(cluster)

# Actual 3D sampling
mm_sub.data = GridSampling3D(voxel)(mm_sub.data.clone())

# Sanity check: the merge below uses `cluster` as the new point index of each
# original point, so there must be exactly one sampled point per voxel.
# A mismatch means the clustering above no longer mirrors GridSampling3D.
# This only checks counts, not point order.
assert int(cluster.max()) + 1 == mm_sub.data.num_nodes, (
    "voxel clustering does not match GridSampling3D output")

# Subsample the mappings accordingly with 'merge' mode
mm_sub.images.mappings = mm_sub.images.mappings.select_points(cluster, merge=True)

# Reset the mapping_index.
# Remark: unlike SelectMappingFromPointId, we do not need to search for
# potentially unseen images when merge=True. The subsampling index merges every
# point into a new one, so no mappings are lost in the process.
mm_sub.data.mapping_index = torch.arange(mm_sub.data.num_nodes)

visualize_mm_data(mm_sub, color_mode='rgb', class_colors=class_colors, figsize=700)