In [1]:
import numpy as np
import torch
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader

from tlspt.datamodules.components.base_site import BaseSiteDataset
from tlspt.datamodules.components.numpy_dataset import NumpyDataset
from tlspt.datamodules.components.octree_dataset import OctreeDataset
from tlspt.models.pointmae.pointmae import PointMAE
from tlspt.models.utils import get_masked, get_unmasked
from tlspt.transforms import TLSSampler

In [2]:
# Configuration: everything a reader might want to tune for this walkthrough.
SEED = 0
SPLIT_FILE = '../data/plot_octrees/allen-spain/octrees/allen-spain-plot_splits-tr0.7-val0.15-te0.15_seed0.csv'
SCALE = 1.5
UNIFORM_POINTS = 8192
FARTHEST_POINTS = 1024
BATCH_SIZE = 8

# Seed the RNGs up front so point sampling, DataLoader shuffling and (presumably)
# PointMAE's patch masking give the same numbers on Restart & Run All.
# Indexing the dataset draws a fresh random sample on every call; seeding makes
# that sequence repeatable.
# NOTE(review): TLSSampler's RNG source is not visible here, so both torch and
# numpy are seeded — confirm which one it actually uses.
torch.manual_seed(SEED)
np.random.seed(SEED)

# NOTE(review): feature_names=None while features_to_normalize lists RGB, and
# prepare_data() logs "No features to normalize" — the RGB list appears to have
# no effect with this configuration; confirm intent.
site_dataset = OctreeDataset(
    split_file=SPLIT_FILE,
    split='train',
    feature_names=None,
    features_to_normalize=['red', 'green', 'blue'],
    scale=SCALE,
    transform=TLSSampler(uniform_points=UNIFORM_POINTS, farthest_points=FARTHEST_POINTS),
)

site_dataloader = DataLoader(site_dataset, batch_size=BATCH_SIZE, shuffle=True)

[32m2024-11-17 18:25:47.397[0m | [1mINFO    [0m | [36mtlspt.datamodules.components.base_site[0m:[36m__init__[0m:[36m49[0m - [1mOctreeDataset(../data/plot_octrees/allen-spain/octrees/allen-spain-plot_splits-tr0.7-val0.15-te0.15_seed0.csv, train, 1.5): reading splits from ../data/plot_octrees/allen-spain/octrees/allen-spain-plot_splits-tr0.7-val0.15-te0.15_seed0.csv[0m
[32m2024-11-17 18:25:47.419[0m | [1mINFO    [0m | [36mtlspt.datamodules.components.base_site[0m:[36m__init__[0m:[36m56[0m - [1mOctreeDataset(../data/plot_octrees/allen-spain/octrees/allen-spain-plot_splits-tr0.7-val0.15-te0.15_seed0.csv, train, 1.5): looking for 14 folders in ../data/plot_octrees/allen-spain/octrees/[0m
[32m2024-11-17 18:25:47.437[0m | [1mINFO    [0m | [36mtlspt.datamodules.components.base_site[0m:[36m__init__[0m:[36m64[0m - [1mOctreeDataset(../data/plot_octrees/allen-spain/octrees/allen-spain-plot_splits-tr0.7-val0.15-te0.15_seed0.csv, train, 1.5): found 14 plots for 'tr

In [3]:
# Prepare the dataset before it is indexed.
# NOTE(review): presumably force_compute=True redoes any cached preprocessing — confirm.
FORCE_RECOMPUTE = False
site_dataset.prepare_data(force_compute=FORCE_RECOMPUTE)

[32m2024-11-17 18:25:48.931[0m | [1mINFO    [0m | [36mtlspt.utils[0m:[36mprepare_data[0m:[36m45[0m - [1mNo features to normalize[0m


In [4]:
site_dataset[0]

{'points': tensor([[ 0.4989, -0.4872, -0.0186],
         [-0.5862,  0.5312,  0.0910],
         [ 0.4152,  0.4868,  0.0368],
         ...,
         [ 0.4989, -0.4872, -0.0186],
         [ 0.4989, -0.4872, -0.0186],
         [ 0.4989, -0.4872, -0.0186]])}

In [5]:
site_dataset[0]

{'points': tensor([[ 0.1640,  0.2004, -0.0191],
         [ 1.0000, -0.5764,  0.0342],
         [-0.2895, -0.5446,  0.1007],
         ...,
         [ 0.1640,  0.2004, -0.0191],
         [ 0.1640,  0.2004, -0.0191],
         [ 0.1640,  0.2004, -0.0191]])}

In [6]:
model = PointMAE()

In [7]:
batch = next(iter(site_dataloader))

In [8]:
type(batch)

dict

In [9]:
batch['points'].shape

torch.Size([8, 1024, 3])

In [10]:
test_out = model(batch)

In [11]:
patches, centers = model.group(batch['points'])

In [12]:
patches.shape

torch.Size([8, 64, 32, 3])

In [13]:
centers.shape

torch.Size([8, 64, 3])

In [14]:
x_vis, mask, vis_pos_embeddings = model.forward_encoder(
    patches, centers
)

In [15]:
x_vis.shape

torch.Size([8, 26, 384])

In [16]:
mask.shape

torch.Size([8, 64])

In [17]:
vis_pos_embeddings.shape

torch.Size([8, 26, 384])

In [18]:
from tlspt.models.utils import get_masked, get_unmasked
masked_centers = get_masked(centers, mask)
masked_pos_embeddings = model.pos_encoder(masked_centers)

In [19]:
masked_centers.shape

torch.Size([8, 38, 3])

In [20]:
masked_pos_embeddings.shape

torch.Size([8, 38, 384])

In [21]:
B, N, _ = masked_pos_embeddings.shape
mask_tokens = model.mask_token.expand(B, N, -1)

In [22]:
mask_tokens = model.mask_token.expand(B, N, -1)

In [23]:
x_full = torch.cat((x_vis, mask_tokens), dim=1)

In [24]:
x_full.shape

torch.Size([8, 64, 384])

In [25]:
full_pos_embeddings = torch.cat((vis_pos_embeddings, masked_pos_embeddings), dim=1)

In [26]:
full_pos_embeddings.shape

torch.Size([8, 64, 384])

In [27]:
x_hat = model.forward_decoder(x_full, full_pos_embeddings, N)

In [28]:
x_hat.shape

torch.Size([8, 38, 32, 3])

In [29]:
x_gt = get_masked(patches, mask)

In [30]:
x_gt.shape

torch.Size([8, 38, 32, 3])

In [31]:
loss = model.get_loss(x_hat, x_gt)

In [34]:
loss

(tensor(1.0146, grad_fn=<DivBackward0>), None)