In [1]:
from tlspt.datamodules.components.numpy_dataset import NumpyDataset
from tlspt.datamodules.components.base_site import BaseSiteDataset
from tlspt.datamodules.components.octree_dataset import OctreeDataset
from tlspt.datamodules.components.merged_dataset import MergedOctreeDataset

from tlspt.transforms import UniformTLSSampler
from tlspt.models.pointmae.pointmae_seg import PointMAESegmentation
from torch.utils.data import DataLoader
import torch
from matplotlib import pyplot as plt

import numpy as np

In [2]:
seg_dataset = MergedOctreeDataset(
    split_files=['../data/plot_octrees/allen-spain/octrees/____TEST.csv'],
    split='train',
    scales=[2],
    min_points=[512],
    feature_names=['scalar_label'],
    features_to_normalize=None,
    normalize=True,
    transform=UniformTLSSampler(num_points=16384)
    )

seg_dataloader = DataLoader(seg_dataset, batch_size=32, shuffle=True)

[32m2024-11-29 23:53:52.536[0m | [1mINFO    [0m | [36mtlspt.datamodules.components.base_site[0m:[36m__init__[0m:[36m49[0m - [1mOctreeDataset(../data/plot_octrees/allen-spain/octrees/____TEST.csv, train, 2): reading splits from ../data/plot_octrees/allen-spain/octrees/____TEST.csv[0m
[32m2024-11-29 23:53:52.855[0m | [1mINFO    [0m | [36mtlspt.datamodules.components.base_site[0m:[36m__init__[0m:[36m56[0m - [1mOctreeDataset(../data/plot_octrees/allen-spain/octrees/____TEST.csv, train, 2): looking for 1 folders in ../data/plot_octrees/allen-spain/octrees/[0m
[32m2024-11-29 23:53:52.856[0m | [1mINFO    [0m | [36mtlspt.datamodules.components.base_site[0m:[36m__init__[0m:[36m64[0m - [1mOctreeDataset(../data/plot_octrees/allen-spain/octrees/____TEST.csv, train, 2): found 1 plots for 'train' out of 1 plots defined in split file[0m
[32m2024-11-29 23:53:52.879[0m | [1mINFO    [0m | [36mtlspt.structures.file_octree[0m:[36m__init__[0m:[36m155[0m - [1mI

In [3]:
seg_dataset.prepare_data()

[32m2024-11-29 23:53:53.153[0m | [1mINFO    [0m | [36mtlspt.utils[0m:[36mprepare_data[0m:[36m95[0m - [1mreading stats from ../data/plot_octrees/allen-spain/octrees/stats/stats_1ca57fe9ad288.pkl[0m
[32m2024-11-29 23:53:53.154[0m | [1mINFO    [0m | [36mtlspt.utils[0m:[36mprepare_data[0m:[36m96[0m - [1mfor dataset OctreeDataset.{'features_to_normalize': None}.train.1ca57fe9ad288[0m
[32m2024-11-29 23:53:53.222[0m | [1mINFO    [0m | [36mtlspt.utils[0m:[36mprepare_data[0m:[36m105[0m - [1mmean: tensor([nan])[0m
[32m2024-11-29 23:53:53.223[0m | [1mINFO    [0m | [36mtlspt.utils[0m:[36mprepare_data[0m:[36m106[0m - [1mstd: tensor([nan])[0m
[32m2024-11-29 23:53:53.223[0m | [1mINFO    [0m | [36mtlspt.utils[0m:[36mprepare_data[0m:[36m110[0m - [1mtorch.float32[0m
[32m2024-11-29 23:53:53.224[0m | [1mINFO    [0m | [36mtlspt.utils[0m:[36mprepare_data[0m:[36m114[0m - [1mDataset has 1 features named ['scalar_label']. 
 Normalizing None

In [4]:
seg_dataset[0]

{'points': tensor([[ 0.1273,  0.1973, -0.0532],
         [-0.0376, -0.3053, -0.0085],
         [ 0.1211,  0.1789,  0.0530],
         ...,
         [-0.0525,  0.1358,  0.0539],
         [-0.0174, -0.3105, -0.0911],
         [-0.0016,  0.2324, -0.0354]]),
 'features': tensor([[1.],
         [0.],
         [1.],
         ...,
         [1.],
         [0.],
         [1.]]),
 'lengths': 3672,
 'scales': 2}

In [5]:
batch = next(iter(seg_dataloader))

In [6]:
batch['points'].shape
batch['features'].shape
batch['lengths'].shape
batch['scales'].shape

torch.Size([32])

In [7]:
model = PointMAESegmentation(neighbor_alg='ball_query', ball_radius=0.2, scale=2.0)

In [8]:
fix_batch = {'points': batch['points'], 'features': torch.nan_to_num(batch['features']), 'lengths': batch['lengths'], 'scales': batch['scales']}

In [9]:
model(fix_batch)

torch.Size([32, 1152, 1]) torch.Size([32, 1152, 1])


tensor(0.7062, grad_fn=<NllLossBackward0>)

In [10]:
B, N, _ = batch['points'].shape
print(B, N)

32 16384


In [11]:
x_gt = batch['features']
print(x_gt.shape)
x_gt = x_gt.squeeze(-1)
print(x_gt.shape)

torch.Size([32, 16384, 1])
torch.Size([32, 16384])


In [12]:
patches, centers = model.group(
    batch['points'], batch['lengths']
)
print(patches.shape)
print(centers.shape)

torch.Size([32, 64, 32, 3])
torch.Size([32, 64, 3])


In [13]:
patch_embeddings = model.patch_encoder(patches)
print(patch_embeddings.shape)
pos_embeddings = model.pos_encoder(centers)
print(pos_embeddings.shape)

torch.Size([32, 64, 384])
torch.Size([32, 64, 384])


In [14]:
x, feature_list = model.transformer_encoder(patch_embeddings, pos_embeddings, feature_blocks=model.feature_blocks)
print(model.feature_blocks)
print("===")
print(x.shape)
print([f.shape for f in feature_list])

[3, 7, 11]
===
torch.Size([32, 64, 384])
[torch.Size([32, 64, 384]), torch.Size([32, 64, 384]), torch.Size([32, 64, 384])]


In [17]:
feature_tensor = torch.cat(feature_list, dim=2)
feature_tensor = feature_tensor.transpose(1, 2)
print(feature_tensor.shape)

torch.Size([32, 1152, 64])


In [18]:
x_max = torch.max(feature_tensor, dim=2, keepdim=True)[0]
x_avg = torch.mean(feature_tensor, dim=2, keepdim=True)
print(x_max.shape)
print(x_avg.shape)

torch.Size([32, 1152, 1])
torch.Size([32, 1152, 1])


In [19]:
x_max_feature = x_max.expand(-1, -1, N)
x_avg_feature = x_avg.expand(-1, -1, N)
print(x_max_feature.shape)
print(x_avg_feature.shape)

torch.Size([32, 1152, 16384])
torch.Size([32, 1152, 16384])


In [21]:
x_global_feature = torch.cat([x_max_feature, x_avg_feature], dim=1)
print(x_global_feature.shape)

torch.Size([32, 2304, 16384])


In [22]:
f_level_0 = model.propagation_0(
    batch['points'].transpose(-1, -2), centers.transpose(-1, -2), batch['points'].transpose(-1, -2), feature_tensor
)

In [23]:
f_level_0.shape

torch.Size([32, 1024, 16384])

In [27]:
x = torch.cat(
    (f_level_0, x_global_feature), dim=1
)
print(x.shape)

torch.Size([32, 3328, 16384])


In [28]:
x = model.relu(model.bns1(model.convs1(x)))
x = model.dp1(x)
print(x.shape)
x = model.relu(model.bns2(model.convs2(x)))
print(x.shape)
x = model.convs3(x)
print(x.shape)

torch.Size([32, 512, 16384])
torch.Size([32, 256, 16384])
torch.Size([32, 2, 16384])


In [29]:
x_hat = x.transpose(1, 2)
print(x_hat.shape)

torch.Size([32, 16384, 2])


In [30]:
loss = model.get_loss(x_hat, x_gt.long( ))

In [31]:
print(loss)

tensor(0.6843, grad_fn=<NllLossBackward0>)
