In [2]:
import numpy as np
import torch
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader

from tlspt.datamodules.components.base_site import BaseSiteDataset
from tlspt.datamodules.components.numpy_dataset import NumpyDataset
from tlspt.datamodules.components.octree_dataset import OctreeDataset
from tlspt.models.pointmae.pointmae import PointMAE
from tlspt.models.utils import get_at_index
from tlspt.transforms import TLSSampler, UniformDownsample

In [5]:
# Reproducibility: seed the global RNGs and give the loader its own seeded
# generator so the shuffled batch order is identical after Restart & Run All.
# NOTE(review): whether the tlspt transforms draw from these global RNGs is
# not visible in this notebook - confirm.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Tunable settings for this walkthrough, kept in one place.
SPLIT_FILE = '../data/plot_octrees/allen-spain/octrees/allen-spain-plot_splits-tr0.7-val0.15-te0.15_seed0.csv'
NUM_POINTS = 8192  # points per sample after downsampling
BATCH_SIZE = 8

# Training split of the octree dataset, downsampled to NUM_POINTS per sample.
# NOTE(review): prepare_data() logs "No features to normalize" for this
# configuration, so features_to_normalize appears to have no effect while
# feature_names is None - confirm this is intended.
site_dataset = OctreeDataset(
    split_file=SPLIT_FILE,
    split='train',
    feature_names=None,
    features_to_normalize=['red', 'green', 'blue'],
    scale=1.5,
    transform=UniformDownsample(num_points=NUM_POINTS),
)

# Shuffled loader; the explicit generator pins the shuffle order to SEED.
site_dataloader = DataLoader(
    site_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

[32m2024-11-24 15:03:10.029[0m | [1mINFO    [0m | [36mtlspt.datamodules.components.base_site[0m:[36m__init__[0m:[36m49[0m - [1mOctreeDataset(../data/plot_octrees/allen-spain/octrees/allen-spain-plot_splits-tr0.7-val0.15-te0.15_seed0.csv, train, 1.5): reading splits from ../data/plot_octrees/allen-spain/octrees/allen-spain-plot_splits-tr0.7-val0.15-te0.15_seed0.csv[0m
[32m2024-11-24 15:03:10.112[0m | [1mINFO    [0m | [36mtlspt.datamodules.components.base_site[0m:[36m__init__[0m:[36m56[0m - [1mOctreeDataset(../data/plot_octrees/allen-spain/octrees/allen-spain-plot_splits-tr0.7-val0.15-te0.15_seed0.csv, train, 1.5): looking for 14 folders in ../data/plot_octrees/allen-spain/octrees/[0m
[32m2024-11-24 15:03:10.334[0m | [1mINFO    [0m | [36mtlspt.datamodules.components.base_site[0m:[36m__init__[0m:[36m64[0m - [1mOctreeDataset(../data/plot_octrees/allen-spain/octrees/allen-spain-plot_splits-tr0.7-val0.15-te0.15_seed0.csv, train, 1.5): found 14 plots for 'tr

In [6]:
# Run the dataset's preparation step before indexing into it.
# NOTE(review): force_compute=False presumably reuses previously computed
# artefacts instead of recomputing them - confirm against OctreeDataset.
site_dataset.prepare_data(force_compute=False)

[32m2024-11-24 15:03:15.157[0m | [1mINFO    [0m | [36mtlspt.utils[0m:[36mprepare_data[0m:[36m45[0m - [1mNo features to normalize[0m


In [7]:
# Peek at one training sample: a dict holding a 'points' tensor and an integer 'lengths'.
# NOTE(review): 'lengths' (3594 here) is smaller than the 8192 rows of 'points';
# presumably it counts the real points before padding - confirm.
site_dataset[0]

{'points': tensor([[-0.0490, -0.7468, -0.1046],
         [ 0.0396, -0.2320,  0.0827],
         [ 0.3818,  0.4459, -0.0212],
         ...,
         [-0.1105, -0.6332,  0.0235],
         [ 0.1133, -0.2267,  0.0674],
         [-0.1596, -0.9737, -0.1474]]),
 'lengths': 3594}

In [9]:
# Each sample holds 8192 points with 3 coordinates, matching UniformDownsample(num_points=8192).
site_dataset[0]['points'].shape

torch.Size([8192, 3])

In [11]:
# Instantiate PointMAE with the 'ball_query' neighbor algorithm and a ball radius of 0.2;
# every other constructor argument is left at its default.
model = PointMAE(neighbor_alg='ball_query', ball_radius=0.2)

In [12]:
# Pull a single batch from the loader so the model can be stepped through by hand.
batch = next(iter(site_dataloader))

In [13]:
# No custom collate_fn is passed to the DataLoader, and the batch keeps the
# per-sample dict structure.
type(batch)

dict

In [14]:
# (batch size, points per cloud, 3) -> (8, 8192, 3)
batch['points'].shape

torch.Size([8, 8192, 3])

In [15]:
# Smoke test: run the model's full forward pass on the batch.
# test_out is not used again in the cells below, which call the model's
# sub-steps (group, mask_generator, encoder, decoder, loss) one at a time.
test_out = model(batch)

In [17]:
# Group each cloud into local patches around a set of centers.
#   patches: (batch, patches, points per patch, 3) = (8, 64, 32, 3)
#   centers: (batch, patches, 3)                   = (8, 64, 3)
patches, centers = model.group(batch['points'], batch['lengths'])

In [18]:
# (batch, patches, points per patch, 3)
patches.shape

torch.Size([8, 64, 32, 3])

In [19]:
# (batch, patches, 3): one center per patch
centers.shape

torch.Size([8, 64, 3])

In [21]:
# Partition the 64 patch indices of every cloud into a masked set and a
# visible (unmasked) set; with this model 38 patches are masked and 26 stay visible.
masked_idx, unmasked_idx = model.mask_generator(centers)

In [29]:
# Patch indices hidden from the encoder, one row per cloud in the batch.
print(masked_idx)
print(masked_idx.shape)

tensor([[ 7, 56, 46, 38,  5, 35, 49,  3, 58, 16, 40, 29, 12, 11,  6, 21, 23, 43,
         47, 22, 19, 10, 53, 24,  8, 36, 52, 59, 28, 39, 60, 26, 63, 37, 41, 55,
         27,  4],
        [30, 26, 54,  7, 62, 22, 17,  3, 12, 20, 13, 49, 28, 18, 11, 43, 16, 34,
         60, 53, 47, 46, 24, 23, 63, 39, 15, 19, 52, 35,  8, 38, 10, 25, 50, 41,
         36, 58],
        [20, 58, 10, 39, 36, 11, 15, 44, 61, 22,  6, 21, 52,  5, 28, 45, 26, 43,
         23,  3, 14, 53, 38, 56, 50, 25,  4, 12, 51, 63, 37, 29, 48, 17, 47, 54,
         41,  7],
        [36, 38, 11, 22, 29, 46, 26, 42, 28, 48,  5, 53, 40,  3,  1, 27, 20, 15,
         49, 60, 55, 23,  4,  6, 30, 59, 45, 54, 61, 52, 41, 33, 51, 58, 21, 44,
         43, 12],
        [60, 10, 39, 47, 15, 22,  6, 62,  8, 28, 19,  5, 45, 38, 33, 13, 11, 23,
          9, 36, 58, 61, 50, 49,  0, 53, 27, 16, 63, 59, 12, 46, 51, 52, 29, 30,
         44, 25],
        [39, 12, 56, 13, 31,  2, 51, 17,  3,  0, 53, 20, 32, 43, 62, 24,  8, 27,
         16, 60, 28

In [30]:
# Patch indices that remain visible to the encoder, one row per cloud in the batch.
print(unmasked_idx)
print(unmasked_idx.shape)

tensor([[34, 44, 31,  9, 14,  2, 45, 15, 30, 20, 62, 61, 17, 48, 51, 42, 57, 50,
         33, 18,  0, 25, 54,  1, 32, 13],
        [ 6, 59, 33, 31, 48,  1, 55, 56, 57,  4, 32, 42, 27,  0, 21, 37,  2, 40,
          9, 45, 29, 44,  5, 51, 61, 14],
        [ 0,  9, 33, 59,  1, 62, 31, 49, 30,  8, 27, 13,  2, 42, 16, 60, 18, 55,
         32, 24, 46, 35, 19, 57, 40, 34],
        [37, 50, 25, 32,  8, 47, 56,  7,  0, 62, 18,  9, 24, 57, 16, 13,  2, 31,
         14, 63, 35, 34, 19, 39, 17, 10],
        [42, 18,  3, 31, 54, 17, 40, 35, 21, 41,  4, 14,  7, 34, 55, 48,  1,  2,
         57, 56, 24, 37, 20, 43, 26, 32],
        [18, 21, 61, 48, 57, 54, 40, 19, 58, 14, 35, 42, 37,  5,  1, 34, 38, 55,
         45,  9, 25, 29, 63, 33, 11, 47],
        [21, 26, 46,  3, 62, 41, 12,  4, 33, 17, 44, 13, 18, 48, 14, 36, 52, 11,
         22, 32, 10, 29,  2, 47,  7, 40],
        [ 3, 44, 24, 12, 36, 50, 30, 19, 39, 60, 52,  0, 15, 20, 62, 26, 32, 10,
         58,  6, 37, 33, 42, 56, 18, 25]])
torch.Size([8, 

In [31]:
# Encode only the visible patches selected by unmasked_idx.
#   x_vis:              (batch, visible patches, transformer_dim) = (8, 26, 384)
#   vis_pos_embeddings: same shape; presumably the positional embeddings of the
#                       visible patch centers - confirm against forward_encoder.
x_vis, vis_pos_embeddings = model.forward_encoder(patches, centers, unmasked_idx)

In [32]:
# (batch, visible patches, transformer_dim)
x_vis.shape

torch.Size([8, 26, 384])

In [33]:
# Same shape as x_vis: one embedding per visible patch.
vis_pos_embeddings.shape

torch.Size([8, 26, 384])

In [34]:
# Gather the centers of the masked patches and embed their positions.
# get_at_index is imported alongside the other tlspt imports in the first cell.
masked_centers = get_at_index(centers, masked_idx)  # (batch, masked patches, 3)
masked_pos_embeddings = model.pos_encoder(masked_centers)  # (batch, masked patches, transformer_dim)
print(masked_centers.shape)
print(masked_pos_embeddings.shape)

torch.Size([8, 38, 3])
torch.Size([8, 38, 384])


In [37]:
# One mask token per masked patch: broadcast model.mask_token to (B, N, dim).
# expand() returns a view, so no memory is copied.
# B = batch size, N = number of masked patches (N is reused by the decoder call later).
B, N, _ = masked_pos_embeddings.shape
mask_tokens = model.mask_token.expand(B, N, -1)
print(mask_tokens.shape)

torch.Size([8, 38, 384])


In [38]:
# Decoder input: visible-patch encodings first, then the mask tokens (26 + 38 = 64 tokens).
x_full = torch.cat((x_vis, mask_tokens), dim=1)
print(x_full.shape)

torch.Size([8, 64, 384])


In [39]:
# Concatenate positional embeddings in the same visible-then-masked order as x_full,
# so every token stays paired with the embedding of its own patch center.
full_pos_embeddings = torch.cat((vis_pos_embeddings, masked_pos_embeddings), dim=1)
print(full_pos_embeddings.shape)

torch.Size([8, 64, 384])


In [40]:
# Decode the full token sequence; N is the number of masked patches.
# The output is (batch, masked patches, points per patch, 3), i.e. one
# reconstructed point patch per masked position.
x_hat = model.forward_decoder(x_full, full_pos_embeddings, N)
print(x_hat.shape)

torch.Size([8, 38, 32, 3])


In [41]:
# Ground-truth point patches at the masked indices; same shape as x_hat.
x_gt = get_at_index(patches, masked_idx)
print(x_gt.shape)

torch.Size([8, 38, 32, 3])


In [42]:
# Reconstruction loss between the predicted and ground-truth masked patches.
# NOTE(review): the loss function is defined inside PointMAE.get_loss and is not visible here.
loss = model.get_loss(x_hat, x_gt)

In [43]:
# Display the scalar loss; the grad_fn in its repr shows it is attached to the autograd graph.
loss

tensor(0.9476, grad_fn=<DivBackward0>)