In [12]:
import requests
import io
from typing import Callable, Sequence

import torch
from torch import nn, optim
import numpy as np
from scipy.spatial import KDTree
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
import torch.nn.functional as F
from sklearn.metrics import f1_score
import laspy as lp
import pytorch_lightning as pl

SEED = 42
# Seed every RNG the notebook draws from: numpy (seed-point subsets and query
# jitter) and torch (weight init, random_rotate, per-item point permutation).
# np.random.seed is last so the cell does not echo torch's Generator object.
torch.manual_seed(SEED)
np.random.seed(SEED)

In [13]:
# Test/train point clouds hosted in the project repository. The *_with_geom
# variants are the ones read by get_data_with_geom (extra per-point attributes).
_DATA_ROOT = "https://github.com/calebbuffa/geometric-deep-learning/raw/main/data"

DATA_URLS = tuple(f"{_DATA_ROOT}/{split}.laz" for split in ("test", "train"))

DATA_URLS_WITH_GEOM = tuple(
    f"{_DATA_ROOT}/{split}_with_geom.laz" for split in ("test", "train")
)


In [80]:
def random_rotate(
    x: torch.Tensor,
    xy_columns: Sequence[Sequence[int]] = ((0, 1), (6, 7)),
) -> torch.Tensor:
    """
    Rotate selected XY column pairs by one random angle in [0, 2*pi).
    The default pairs match the feature layout built by ``DemoDataset``:
    columns 0-1 hold the global XY and columns 6-7 the neighbourhood-local XY.
    Every pair is rotated by the same angle; all other columns are unchanged.
    Parameters
    ----------
    x : torch.Tensor
        Features of shape (..., C); every index in ``xy_columns`` must be < C.
    xy_columns : Sequence[Sequence[int]]
        (x, y) column-index pairs to rotate, defaults to ((0, 1), (6, 7)).
    Returns
    -------
    torch.Tensor
        A rotated copy of ``x`` with the same dtype and device; ``x`` itself
        is not modified.
    """
    theta = torch.pi * 2 * torch.rand(1)
    cos_t, sin_t = torch.cos(theta).item(), torch.sin(theta).item()
    # Build the matrix from Python floats directly on x's dtype/device rather
    # than nesting 1-element tensors in torch.tensor() (fragile, CPU-only).
    rotation_matrix = torch.tensor(
        [[cos_t, -sin_t], [sin_t, cos_t]], dtype=x.dtype, device=x.device
    )
    rotated = x.clone()
    for col_x, col_y in xy_columns:
        cols = [col_x, col_y]
        rotated[..., cols] = x[..., cols] @ rotation_matrix
    return rotated

def center(x: np.ndarray, dim: int = 0):
    """
    Shift ``x`` so its mean along ``dim`` becomes zero.
    Parameters
    ----------
    x : np.ndarray
        Values to shift.
    dim : int
        Axis the mean is taken over, defaults to 0.
    Returns
    -------
    np.ndarray
        ``x`` minus its mean along ``dim``; same shape as ``x``.
    """
    # keepdims so the mean broadcasts back against x along ``dim``.
    mean_along_dim = x.mean(axis=dim, keepdims=True)
    return x - mean_along_dim


def scale(
    x: np.ndarray,
    new_range: tuple[int, int] = (-1, 1),
    eps: float = 1e-20,
    dim: int = 0,
):
    """
    Min/max-rescale ``x`` along ``dim`` into ``new_range``.
    Parameters
    ----------
    x : np.ndarray
        Values to rescale.
    new_range : tuple[int, int]
        Target (low, high), defaults to (-1, 1).
    eps : float
        Substitute span for constant slices so the division is defined.
    dim : int
        Axis along which min/max are taken, defaults to 0.
    Returns
    -------
    np.ndarray
        Rescaled values; a constant slice maps entirely to ``new_range[0]``.
    """
    low, high = new_range[0], new_range[1]
    lo_vals = x.min(axis=dim, keepdims=True)
    span = x.max(axis=dim, keepdims=True) - lo_vals
    # Zero span (constant slice) -> use eps so (x - min) / span is 0, not NaN.
    span = np.where(span <= 0.0, eps, span)
    unit = (x - lo_vals) / span
    return unit * (high - low) + low


def center_points(xyz: np.ndarray):
    """
    Normalize XYZ point cloud coordinates.
    Z is min/max scaled to [0, 1]; X and Y are each min/max scaled to
    [-1, 1]. X and Y are scaled independently, so the XY aspect ratio is not
    preserved. Because min/max scaling is translation-invariant, the result
    is not zero-mean in general (the preceding mean-centering does not change
    the output).
    Parameters
    ----------
    xyz : np.ndarray
        Array of shape (N, 3).
    Returns
    -------
    np.ndarray
        New array with the scaled coordinates (same dtype as ``xyz``);
        ``xyz`` itself is not modified.
    """
    centered = center(xyz)
    # Write into a copy so callers' arrays are never mutated as a side effect.
    scaled = xyz.copy()
    scaled[..., -1] = scale(centered[..., -1], new_range=(0, 1))
    scaled[..., :2] = scale(centered[..., :2], new_range=(-1, 1))
    return scaled

def _download_las(url: str, timeout: float = 60.0):
    """Download a LAS/LAZ file and parse it with laspy.

    ``timeout`` bounds connect and per-read waits (not total download time),
    so a stalled server cannot hang the notebook. Raises
    ``requests.HTTPError`` on a non-2xx response instead of passing an error
    page to the LAS parser.
    """
    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()
    return lp.read(io.BytesIO(resp.content))


def _base_attributes(pcl) -> dict:
    """Normalized XYZ, scaled RGB and classification labels of a laspy cloud."""
    return {
        "xyz": center_points(np.column_stack((pcl.x, pcl.y, pcl.z))),
        # NOTE(review): 65280 == 255 * 256, i.e. this assumes 8-bit colours
        # stored in the high byte; channel values above 65280 exceed 1.0.
        "rgb": np.column_stack((pcl.red, pcl.green, pcl.blue)) / 65280.0,
        "labels": pcl.classification,
    }


def get_data():
    """
    Download the clouds listed in ``DATA_URLS``.
    Returns
    -------
    dict
        Keyed by partition ("test", "train"); each value holds "xyz"
        (normalized by ``center_points``), "rgb" and "labels".
    """
    data = {}
    for url in DATA_URLS:
        # ".../test.laz" -> "test"
        partition = url.split("/")[-1].split(".")[0]
        data[partition] = _base_attributes(_download_las(url))
    return data


def get_data_with_geom():
    """
    Download the clouds in ``DATA_URLS_WITH_GEOM`` with their extra attributes.
    Returns
    -------
    dict
        Same layout as ``get_data`` plus "geom": the LAS extra dimensions
        stacked column-wise, shape (N, n_extra_dims); (N, 0) if there are none.
    """
    data = {}
    for url in DATA_URLS_WITH_GEOM:
        # ".../test_with_geom.laz" -> "test"
        partition = url.split("/")[-1].split(".")[0].split("_")[0]
        pcl = _download_las(url)
        attrs = _base_attributes(pcl)
        extra_dims = list(pcl.point_format.extra_dimension_names)
        if extra_dims:
            geom = np.column_stack([getattr(pcl, name) for name in extra_dims])
        else:
            # np.column_stack([]) raises; keep a well-formed empty block.
            geom = np.empty((attrs["xyz"].shape[0], 0))
        attrs["geom"] = geom
        data[partition] = attrs
    return data

class TNet(nn.Module):
    """
    Learned k x k alignment matrix (PointNet-style T-Net).

    ``forward`` expects channels-first input (B, k, N) and returns a (B, k, k)
    matrix: the regressed entries plus the identity.
    """

    def __init__(self, in_channels):
        super().__init__()
        k = in_channels
        self.k = k
        # Shared per-point layers lifting k -> 1024 channels.
        self.mlp1 = nn.Sequential(nn.Conv1d(k, 64, 1), nn.ReLU())
        self.mlp2 = nn.Sequential(nn.Conv1d(64, 128, 1), nn.ReLU())
        self.mlp3 = nn.Sequential(nn.Conv1d(128, 1024, 1), nn.ReLU())
        # Layers applied to the pooled vector, regressing k*k matrix entries.
        self.mlp4 = nn.Sequential(nn.Conv1d(1024, 512, 1), nn.ReLU())
        self.mlp5 = nn.Sequential(nn.Conv1d(512, 256, 1), nn.ReLU())
        self.t_mlp = nn.Conv1d(256, k * k, 1)

    def forward(self, x):
        batch_size = x.shape[0]
        for per_point in (self.mlp1, self.mlp2, self.mlp3):
            x = per_point(x)
        # Max-pool over the point axis: (B, 1024, N) -> (B, 1024, 1).
        pooled = x.max(dim=-1, keepdim=True).values
        for head_layer in (self.mlp4, self.mlp5):
            pooled = head_layer(pooled)
        matrix = self.t_mlp(pooled).reshape(-1, self.k, self.k)  # (B, k, k)
        # Broadcast the identity over the batch so the transform starts near I.
        identity = torch.eye(self.k, device=matrix.device).unsqueeze(0)
        return matrix + identity.expand(batch_size, -1, -1)

class PointNet(pl.LightningModule):
    """
    PointNet-style per-point binary segmentation model.

    ``forward`` takes points-first input (B, N, C), transposes it to the
    channels-first layout the Conv1d layers need, and returns per-point
    logits of shape (B, N, out_channels).

    Parameters
    ----------
    in_channels : int
        Features per point. The first 3 are treated as XYZ when
        ``xyz_transform`` is enabled.
    out_channels : int
        Logits per point (1 for the binary loss used here).
    lr : float
        Adam learning rate, defaults to 1e-3.
    global_feature : bool
        Make ``forward_base`` return only the pooled (B, 1024) vector.
        NOTE(review): ``head`` expects 1088 input channels, so ``forward``
        fails with this flag set; it is only usable through ``forward_base``.
    feature_transform : bool
        Apply a learned 64 x 64 T-Net to intermediate features and add its
        orthogonality regularizer to the loss.
    xyz_transform : bool
        Apply a learned 3 x 3 T-Net to the first three input channels.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        lr: float = 1e-3,
        global_feature: bool=False,
        feature_transform: bool=False,
        xyz_transform: bool=False,
    ):
        super().__init__()
        self._global_feature = global_feature
        self._transform_feature = feature_transform
        self._transform_xyz = xyz_transform
        self.lr = lr

        self.xyz_transform = TNet(3) if xyz_transform else nn.Identity()
        self.feature_transform = TNet(64) if feature_transform else nn.Identity()
        self.mlp1 = nn.Sequential(
            nn.Conv1d(in_channels, 64, 1), nn.BatchNorm1d(64), nn.ReLU()
        )
        self.mlp2 = nn.Sequential(nn.Conv1d(64, 64, 1), nn.BatchNorm1d(64), nn.ReLU())
        self.mlp3 = nn.Sequential(nn.Conv1d(64, 64, 1), nn.BatchNorm1d(64), nn.ReLU())
        self.mlp4 = nn.Sequential(nn.Conv1d(64, 128, 1), nn.BatchNorm1d(128), nn.ReLU())
        self.mlp5 = nn.Sequential(nn.Conv1d(128, 1024, 1), nn.BatchNorm1d(1024), nn.ReLU())
        self.dropout = nn.Dropout()
        # 1088 = 64 local (mlp3) + 1024 global (mlp5, max-pooled) channels.
        self.head = nn.Sequential(
            nn.Conv1d(1088, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Conv1d(512, 256, 1),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Conv1d(256, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Conv1d(128, out_channels, 1),
        )

        # Per-batch validation F1 scores, averaged and cleared at epoch end.
        self.scores = []

    def forward_base(self, x):
        """
        Shared encoder on channels-first input (B, C, N).

        Returns
        -------
        tuple
            (features, xyz_trans, feat_trans). ``features`` is (B, 1088, N)
            unless ``global_feature`` is set, in which case it is the (B, 1024)
            global vector. Each transform is None when its T-Net is disabled.
        """
        N = x.shape[-1]
        if self._transform_xyz:
            xyz = x[:, :3]
            xyz_trans = self.xyz_transform(xyz)
            # Rebuild x out of place: assigning into x[:, :3] would overwrite
            # tensors autograd saved for the T-Net/matmul backward pass (and,
            # through the .mT view taken in forward, the caller's batch).
            x = torch.cat(((xyz.mT @ xyz_trans).mT, x[:, 3:]), dim=1)
        else:
            xyz_trans = None

        x = self.mlp1(x)
        x = self.mlp2(x)
        if self._transform_feature:
            feat_trans = self.feature_transform(x)
            x = (x.mT @ feat_trans).mT
        else:
            feat_trans = None

        x_local = self.mlp3(x)

        x = self.mlp4(x_local)
        x = self.mlp5(x)
        x = self.dropout(x)

        # Max-pool over points: the permutation-invariant global descriptor.
        x_global = x.max(dim=-1, keepdim=True).values

        if self._global_feature:
            return x_global.squeeze(-1), xyz_trans, feat_trans

        x = torch.cat((x_local, x_global.repeat(1, 1, N)), dim=1)
        return x, xyz_trans, feat_trans

    def forward(self, x):
        """Map (B, N, C) features to (B, N, out_channels) logits plus transforms."""
        x, xyz_trans, feat_trans = self.forward_base(x.mT)
        x = self.head(x)
        return x.mT, xyz_trans, feat_trans

    def _common_step(self, batch, batch_idx, partition):
        """Loss (+ feature-transform regularizer), logits and 0/1 predictions.

        ``batch_idx`` and ``partition`` are currently unused.
        """
        x, y = batch
        logits, _, feat_trans = self.forward(x.float())
        if feat_trans is not None:
            reg = self._feature_transform_regularizer(feat_trans.mT) * 0.001
        else:
            reg = 0.0
        loss = self.loss_fn(logits.contiguous(), y) + reg
        pred = self._activation(logits)
        return loss, logits, pred

    def _feature_transform_regularizer(self, feat_trans):
        """Batch mean of the Frobenius norm ||I - A A^T|| for A in ``feat_trans`` (B, D, D)."""
        D = feat_trans.shape[2]
        eye = torch.eye(D, device=self.device)[None, :, :]
        loss = torch.norm(eye - (feat_trans @ feat_trans.mT), dim=(1, 2))
        return loss.mean()

    def _activation(self, y_hat):
        """Threshold sigmoid probabilities at 0.5 into 0/1 labels."""
        probs = torch.sigmoid(y_hat.squeeze(-1))
        return torch.where(probs > 0.5, 1, 0)

    def training_step(self, train_batch, batch_idx):
        loss, logits, pred = self._common_step(train_batch, batch_idx, "train")
        return {"loss": loss, "logits": logits, "pred": pred}

    def validation_step(self, val_batch, batch_idx):
        loss, logits, pred = self._common_step(val_batch, batch_idx, "val")
        self.scores.append(
            f1_score(
                val_batch[1].flatten().cpu().detach().numpy(),
                pred.flatten().cpu().detach().numpy(),
                zero_division=0.0,
            )
        )
        return {"loss": loss, "logits": logits, "pred": pred}

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr, weight_decay=1e-4)

    def on_validation_end(self):
        """Print the mean per-batch validation F1, then reset the accumulator."""
        # Guard against a validation run that produced no batches.
        if self.scores:
            print(f"F1 Score: {round(sum(self.scores) / len(self.scores), 4)}")
        self.scores = []

    def loss_fn(self, y_hat, y):
        """Binary cross-entropy on logits; labels are cast to the logits' dtype."""
        y_hat = y_hat.squeeze(-1)
        return F.binary_cross_entropy_with_logits(
            y_hat, y.type(y_hat.dtype)
        )

class DemoModel(pl.LightningModule):
    """
    Stack of Conv1d/BatchNorm/ReLU layers for per-point binary segmentation.

    ``forward`` takes (B, N, C) features and returns (B, N, out_channels)
    logits; the convolutions slide along the point axis N.

    Parameters
    ----------
    in_channels : int
        Features per point.
    out_channels : int
        Logits per point (1 for the binary loss used here).
    hidden_dims : Sequence[int]
        Output width of each hidden Conv1d layer, in order.
    dropout : float
        Dropout probability before the final Conv1d, defaults to 0.0.
    lr : float
        Adam learning rate, defaults to 0.001.
    kernel_size : int
        Kernel size of the hidden convolutions along the point axis. With 1,
        every point is processed independently; larger values mix each point
        with its neighbours in storage order.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_dims: Sequence[int],
        dropout: float = 0.0,
        lr: float = 0.001,
        kernel_size: int = 1,
    ):
        super().__init__()
        self.lr = lr
        self.model = self._init_layers(
            in_channels,
            out_channels,
            hidden_dims,
            kernel_size,
            dropout
        )
        # Per-batch validation F1 scores, averaged and cleared at epoch end.
        self.scores = []

    def forward(self, x):
        """Map (B, N, C) features to (B, N, out_channels) logits."""
        return self.model(x.mT).mT

    def _common_step(self, batch, batch_idx, partition):
        """Loss, logits and 0/1 predictions; ``batch_idx``/``partition`` unused."""
        x, y = batch
        logits = self.forward(x.float())
        loss = self.loss_fn(logits.contiguous(), y)
        pred = self._activation(logits)
        return loss, logits, pred

    def training_step(self, train_batch, batch_idx):
        loss, logits, pred = self._common_step(train_batch, batch_idx, "train")
        return {"loss": loss, "logits": logits, "pred": pred}

    def validation_step(self, val_batch, batch_idx):
        loss, logits, pred = self._common_step(val_batch, batch_idx, "val")
        self.scores.append(
            f1_score(
                val_batch[1].flatten().cpu().detach().numpy(),
                pred.flatten().cpu().detach().numpy(),
                zero_division=0.0,
            )
        )
        return {"loss": loss, "logits": logits, "pred": pred}

    def _activation(self, y_hat):
        """Threshold sigmoid probabilities at 0.5 into 0/1 labels."""
        probs = torch.sigmoid(y_hat.squeeze(-1))
        return torch.where(probs > 0.5, 1, 0)

    def loss_fn(self, y_hat, y):
        """Binary cross-entropy on logits; labels are cast to the logits' dtype."""
        y_hat = y_hat.squeeze(-1)
        return F.binary_cross_entropy_with_logits(
            y_hat, y.type(y_hat.dtype)
        )

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr, weight_decay=1e-4)

    def on_validation_end(self):
        """Print the mean per-batch validation F1, then reset the accumulator."""
        # Guard against a validation run that produced no batches.
        if self.scores:
            print(f"F1 Score: {round(sum(self.scores) / len(self.scores), 4)}")
        self.scores = []

    def _init_layers(
        self,
        input_channels,
        output_channels,
        hidden_dimensions,
        kernel_size,
        dropout_prob
    ):
        """Build Conv1d -> BatchNorm1d -> ReLU per hidden dim, then Dropout and a 1x1 Conv1d."""
        layers = []
        in_channels = input_channels

        for out_channels in hidden_dimensions:
            layers.append(
                nn.Conv1d(
                    in_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    padding="same",
                    # BatchNorm right after makes a conv bias redundant.
                    bias=False
                )
            )
            layers.append(nn.BatchNorm1d(out_channels))
            layers.append(nn.ReLU())
            in_channels = out_channels
        layers.extend(
            [
                nn.Dropout(dropout_prob),
                nn.Conv1d(
                    in_channels,
                    output_channels,
                    kernel_size=1,
                    padding="same"
                ),
            ]
        )

        return nn.Sequential(*layers)

class DemoDataset(Dataset):
    """
    Samples fixed-size KD-tree neighbourhoods from one point cloud.

    Item ``idx`` is the ``points_per_batch`` nearest neighbours of point
    ``idx`` (after a small random jitter), returned as ``(features, labels)``
    with the points in random order. Feature columns are global XYZ (0-2),
    RGB (3-5), neighbourhood-normalized XYZ (6-8) and, when
    ``geometric_features`` is set, the ``data["geom"]`` columns after those.

    Parameters
    ----------
    tree : KDTree
        KD-tree whose points are row-aligned with the arrays in ``data``.
    data : dict
        Arrays under "xyz", "rgb", "labels" (and "geom" when
        ``geometric_features`` is True).
    points_per_batch : int
        Neighbourhood size, defaults to 5000; capped at the cloud size.
    transform : Callable
        Applied to each feature tensor, defaults to identity.
    geometric_features : bool
        Append ``data["geom"]`` to the features.
    """

    def __init__(
        self,
        tree: KDTree,
        data: dict,
        points_per_batch: int=5000,
        transform: Callable = lambda x: x,
        geometric_features: bool = False,
    ):
        super().__init__()
        self.tree = tree
        self.data = data
        self.points_per_batch = points_per_batch
        self.n_points = tree.data.shape[0]
        self.transform = transform
        self.with_geom = geometric_features

    def __len__(self):
        """Number of points in the cloud, i.e. the valid range of ``idx``."""
        return self.n_points

    def __getitem__(self, idx):
        # Jitter the query point so repeated draws of the same idx differ.
        query_point = self.tree.data[idx] + np.random.uniform(-.01, .01, 3)
        # scipy pads missing neighbours with index n when k > n (IndexError
        # below) and squeezes the result to a scalar when k == 1; cap k and
        # keep the indices 1-D.
        k = min(self.points_per_batch, self.n_points)
        _, idxs = self.tree.query(query_point, k=k)
        idxs = np.atleast_1d(idxs)
        xyz_global = torch.from_numpy(self.data["xyz"][idxs])
        xyz_local = torch.from_numpy(center_points(self.data["xyz"][idxs]))
        rgb = torch.from_numpy(self.data["rgb"][idxs])
        y = torch.from_numpy(self.data["labels"][idxs])
        parts = [xyz_global, rgb, xyz_local]
        if self.with_geom:
            parts.append(torch.from_numpy(self.data["geom"][idxs]))
        x = torch.cat(parts, dim=1)
        # Shuffle point order so any order-dependence in the model is exposed.
        indexes = torch.randperm(y.shape[0])
        return self.transform(x[indexes]), y[indexes]


class DemoDataModule(pl.LightningDataModule):
    """
    Lightning data module serving KD-tree neighbourhoods of the demo clouds.

    Parameters
    ----------
    batch_size : int
        Neighbourhoods per batch, defaults to 1.
    points_per_batch : int
        Points per neighbourhood, defaults to 5000.
    transform : Callable
        Applied to each item's feature tensor, defaults to identity.
    geometric_features : bool
        Load the ``*_with_geom`` files and append their extra attributes.
    """

    def __init__(
        self,
        batch_size: int = 1,
        points_per_batch: int = 5000,
        transform: Callable = lambda x: x,
        geometric_features: bool = False,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.transform = transform
        self.points_per_batch = points_per_batch
        self.with_geom = geometric_features
        # Downloaded clouds, cached so repeated setup() calls don't re-download.
        self._raw_data = None

    def setup(self, stage: str):
        """Download the clouds (first call only) and build train/val datasets."""
        if self._raw_data is None:
            self._raw_data = get_data_with_geom() if self.with_geom else get_data()
        data = self._raw_data
        self.train_ds = self._make_dataset(data["train"])
        self.val_ds = self._make_dataset(data["test"])

    def _make_dataset(self, split: dict) -> "DemoDataset":
        """Wrap one partition in a KD-tree-backed DemoDataset."""
        return DemoDataset(
            KDTree(split["xyz"]),
            split,
            points_per_batch=self.points_per_batch,
            transform=self.transform,
            geometric_features=self.with_geom,
        )

    def _make_dataloader(self, ds: "DemoDataset") -> DataLoader:
        """Loader over n_points // points_per_batch random seed points.

        The seed subset is drawn once per call, then visited in random order.
        """
        idxs = np.arange(0, ds.n_points)
        np.random.shuffle(idxs)
        n_seeds = ds.n_points // ds.points_per_batch
        sampler = SubsetRandomSampler(idxs[:n_seeds].tolist())
        return DataLoader(ds, batch_size=self.batch_size, sampler=sampler)

    def train_dataloader(self):
        return self._make_dataloader(self.train_ds)

    def val_dataloader(self):
        return self._make_dataloader(self.val_ds)

In [62]:
# Hyper-parameters shared by every DemoModel experiment below.
model_kwargs = {
    # Hidden widths: widen to 1024 channels, then taper back to 128.
    "hidden_dims": [64, 128, 256, 512, 1024, 512, 256, 128],
    "dropout": 0.2,
    "lr": 1e-2,
}

# One neighbourhood of 10k points per batch.
dm_kwargs = {"batch_size": 1, "points_per_batch": 10_000}

In [63]:
# NOTE(review): a Trainer keeps its epoch counter across fit() calls; once a
# fit has reached max_epochs, fitting again with this same instance stops
# without training. Use a fresh Trainer for each experiment.
trainer_kwargs = {"max_epochs": 10}
trainer = pl.Trainer(**trainer_kwargs)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


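## Violate Symmetry
1. Train a model with `kernel_size=3`. A kernel wider than 1 mixes each point with its neighbours in the tensor's (randomly permuted) storage order, so the network is no longer invariant to the order of the points.
2. Compare against `kernel_size=1` (used in the following sections), where every point is processed independently.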

In [81]:
# kernel_size=3 lets each output depend on neighbouring points in the
# (randomly permuted) input order, breaking permutation invariance.
model = DemoModel(
    in_channels=9,  # global XYZ + RGB + local XYZ
    out_channels=1,
    kernel_size=3,
    **model_kwargs,
)

dm = DemoDataModule(
    geometric_features=False,
    transform=lambda x: x,  # no augmentation
    **dm_kwargs
)

In [82]:
# Fresh Trainer per experiment: a Trainer that already reached max_epochs
# stops immediately when fit() is called again, so re-running this cell with
# a reused instance would silently skip training.
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model=model, datamodule=dm)

c:\Users\cal11713\AppData\Local\mambaforge\envs\gdl\lib\site-packages\pytorch_lightning\callbacks\model_checkpoint.py:630: Checkpoint directory c:\Users\cal11713\projects\geometric-deep-learning\notebooks\lightning_logs\version_2\checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type       | Params
-------------------------------------------------
0 | xyz_transform     | TNet       | 799 K 
1 | feature_transform | Identity   | 0     
2 | mlp1              | Sequential | 768   
3 | mlp2              | Sequential | 4.3 K 
4 | mlp3              | Sequential | 4.3 K 
5 | mlp4              | Sequential | 8.6 K 
6 | mlp5              | Sequential | 134 K 
7 | dropout           | Dropout    | 0     
8 | head              | Sequential | 723 K 
-------------------------------------------------
1.7 M     Trainable params
0         Non-trainable params
1.7 M     Total params
6.700     Total estimated model params size (MB)


                                                  

c:\Users\cal11713\AppData\Local\mambaforge\envs\gdl\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


F1 Score: 0.6448                                          


c:\Users\cal11713\AppData\Local\mambaforge\envs\gdl\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
c:\Users\cal11713\AppData\Local\mambaforge\envs\gdl\lib\site-packages\pytorch_lightning\loops\fit_loop.py:293: The number of training batches (21) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
`Trainer.fit` stopped: `max_epochs=10` reached.


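## Fix Stability Under Distortion
1. Train a model with `kernel_size=1` and `transform=random_rotate`, which rotates every sampled neighbourhood by a random angle about the Z axis so the model sees the same geometry in many orientations.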

In [None]:
# Point-wise (kernel_size=1) model trained on randomly rotated neighbourhoods.
model = DemoModel(
    in_channels=9,  # global XYZ + RGB + local XYZ
    kernel_size=1,
    out_channels=1,
    **model_kwargs,
)
dm = DemoDataModule(
    geometric_features=False,
    transform=random_rotate,  # random rotation about Z per sample
    **dm_kwargs
)

In [None]:
# Fresh Trainer per experiment: a Trainer that already reached max_epochs in
# an earlier fit() stops immediately, so reusing it would skip training.
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model=model, datamodule=dm)

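## Fix Multiscale Feature Representations
1. Train a model with `geometric_features=True` and `in_channels=21`: the 9 base channels (global XYZ, RGB, local XYZ) plus 12 precomputed geometric attributes read from the `*_with_geom` files.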

# Base features plus precomputed per-point geometric attributes.
# NOTE(review): 12 must equal the number of extra LAS dimensions in the
# *_with_geom files, otherwise the first Conv1d gets the wrong channel count.
model = DemoModel(
    in_channels=9 + 12,
    kernel_size=1,
    out_channels=1,
    **model_kwargs,
)
dm = DemoDataModule(
    geometric_features=True,
    transform=random_rotate,
    **dm_kwargs
)

In [None]:
# Fresh Trainer per experiment: a Trainer that already reached max_epochs in
# an earlier fit() stops immediately, so reusing it would skip training.
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model=model, datamodule=dm)