In [None]:
import os
import pathlib
import subprocess
import sys

In [None]:
%%time
%%capture setup_cell_output


def clone_phd_repo(github_token: str) -> None:
    """Clone the `phd` repository into ``./phd`` and make it importable.

    The repository is cloned over HTTPS with the token embedded in the URL, and
    ``"phd"`` is put at the front of ``sys.path``. The clone is skipped when
    ``./phd`` already exists, so re-running the setup cell neither fails on the
    existing checkout nor adds duplicate ``sys.path`` entries.

    Args:
        github_token: Token used as the password part of the HTTPS clone URL.

    Raises:
        RuntimeError: If ``git clone`` exits with a non-zero status.
    """
    if not pathlib.Path("phd").exists():
        # The token is part of the URL: never print, log, or raise with `repo_url`.
        repo_url = f"https://iod-ine:{github_token}@github.com/iod-ine/phd.git"
        result = subprocess.run(["git", "clone", repo_url])
        if result.returncode != 0:
            # Fail here rather than with a confusing ImportError in a later cell.
            # `check=True` is avoided on purpose: CalledProcessError would include
            # the full command line, and with it the token.
            raise RuntimeError(
                "git clone of the phd repository failed "
                f"with exit code {result.returncode}."
            )

    if "phd" not in sys.path:
        sys.path.insert(0, "phd")


if (working_dir := pathlib.Path.cwd().as_posix()).startswith("/Users"):
    ENVIRONMENT = "local"
    data_dir = "../data/interim/synthetic_forest"
    accelerator = "cpu"  # mpu (Apple Silicon GPU) does not work for torch_cluster.fps()
    num_workers = 11

    import dotenv

    assert dotenv.load_dotenv()

elif working_dir.startswith("/kaggle"):
    ENVIRONMENT = "Kaggle"
    data_dir = "data"
    accelerator = "gpu"
    num_workers = 3

    # For Kaggle, the notebook that builds wheels for torch-cluster and torch-scatter
    # needs to be attached as input. It's called "torch-scatter & torch-cluster wheels"
    !pip install --upgrade laspy lazrs mlflow lightning torch_geometric
    !pip install --no-index --find-links /kaggle/input/torch-scatter-torch-cluster-wheels/ torch_scatter torch_cluster

    import kaggle_secrets

    # The notebooks needs to be manually granted access to the secrets through the menu
    secrets = kaggle_secrets.UserSecretsClient()
    os.environ["MLFLOW_TRACKING_URI"] = secrets.get_secret("mlflow-uri")
    os.environ["MLFLOW_TRACKING_USERNAME"] = secrets.get_secret("mlflow-username")
    os.environ["MLFLOW_TRACKING_PASSWORD"] = secrets.get_secret("mlflow-password")
    os.environ["KAGGLE_USERNAME"] = secrets.get_secret("kaggle-username")
    os.environ["KAGGLE_KEY"] = secrets.get_secret("kaggle-key")

    clone_phd_repo(github_token=secrets.get_secret("github-token"))

    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

elif working_dir.startswith("/content"):
    ENVIRONMENT = "Colab"
    data_dir = "data"
    accelerator = "gpu"

    !pip install mlflow laspy lazrs lightning torch_geometric torchinfo python-dotenv
    !pip install https://data.pyg.org/whl/torch-2.4.0%2Bcu121/torch_cluster-1.6.3%2Bpt24cu121-cp310-cp310-linux_x86_64.whl
    !pip install https://data.pyg.org/whl/torch-2.4.0%2Bcu121/torch_scatter-2.1.2%2Bpt24cu121-cp310-cp310-linux_x86_64.whl

    from google.colab import userdata

    # The notebooks needs to be manually granted access to the secrets through the menu
    os.environ["MLFLOW_TRACKING_URI"] = userdata.get("mlflow-uri")
    os.environ["MLFLOW_TRACKING_USERNAME"] = userdata.get("mlflow-username")
    os.environ["MLFLOW_TRACKING_PASSWORD"] = userdata.get("mlflow-password")
    os.environ["KAGGLE_USERNAME"] = userdata.get("kaggle-username")
    os.environ["KAGGLE_KEY"] = userdata.get("kaggle-key")

    clone_phd_repo(github_token=userdata.get("github-token"))

    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

elif working_dir.startswith("/home/jupyter/work/resources"):
    ENVIRONMENT = "DataSphere"
    data_dir = "../data/interim/synthetic_forest"
    accelerator = "gpu"

    sys.path.insert(0, "phd")

else:
    raise NotImplementedError("Could not determine where the notebook is running.")

In [None]:
import os
import pathlib
import tempfile
from typing import Optional

import dotenv
import laspy
import lightning as L
import matplotlib.pyplot as plt
import mlflow
import numpy as np
import torch
import torch_geometric
import torch_scatter
import torchinfo
from torch import nn

import src.clouds
import src.visualization.clouds
from src.datasets import SyntheticForest
from src.models.pointnet import PointNet2TreeSegmentor

In [None]:
class PointNet2TreeSegmentorModule(L.LightningModule):
    """A PointNet++ tree segmentor lightning module.

    Wraps ``PointNet2TreeSegmentor`` and optimizes the MSE between its prediction
    and ``batch.y``. The logged metrics (``loss/train`` and ``loss/val``) are that
    MSE divided by the number of trees in the batch; the loss returned for
    backpropagation is not normalized.
    """

    def __init__(self):
        """Create a new PointNet2TreeSegmentorModule instance."""
        super().__init__()

        self.pointnet = PointNet2TreeSegmentor(num_features=4)

        self.save_hyperparameters()

        # Normalized loss of every validation batch of the current epoch.
        self.validation_step_outputs = []

    def _loss_and_tree_count(self, batch):
        """Return the MSE loss of a batch and the total number of trees in it."""
        pred = self.pointnet(batch)
        loss = nn.functional.mse_loss(pred.squeeze(), batch.y.float())
        # Presumably batch.y holds zero-based tree indices per point, so the
        # per-sample maximum plus one is the tree count of a sample — TODO confirm.
        per_batch_max_index, _ = torch_scatter.scatter_max(
            src=batch.y,
            index=batch.batch,
        )
        number_of_trees = (per_batch_max_index + 1).sum()
        return loss, number_of_trees

    def training_step(self, batch, batch_idx):  # noqa: ARG002
        """Process a single batch of the training dataset and return the loss."""
        loss, number_of_trees = self._loss_and_tree_count(batch)
        # Detach instead of .item(): same logged value, no forced device sync.
        self.log("loss/train", loss.detach() / number_of_trees)
        return loss

    def validation_step(self, batch, batch_idx):  # noqa: ARG002
        """Process a single batch of the validation dataset and store its loss."""
        loss, number_of_trees = self._loss_and_tree_count(batch)
        self.validation_step_outputs.append(loss / number_of_trees)

    def on_validation_epoch_end(self):
        """Log the mean normalized validation loss and reset the accumulator."""
        if not self.validation_step_outputs:
            # No validation batches were processed: torch.stack would raise on [].
            return
        average_loss = torch.stack(self.validation_step_outputs).mean()
        self.log("loss/val", average_loss)
        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        """Set up and return the optimizer and the learning rate scheduler."""
        optimizer = torch.optim.Adam(
            params=self.parameters(),
            lr=1e-3,
        )
        # Halve the learning rate every 3 scheduler steps.
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer,
            step_size=3,
            gamma=0.5,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
        }

In [None]:
class SyntheticForestDataModule(L.LightningDataModule):
    """A data module for the synthetic forest dataset.

    Builds the train and validation splits of ``SyntheticForest`` and wraps them in
    ``torch_geometric`` data loaders. The test and predict stages are not
    implemented.
    """

    def __init__(
        self,
        data_dir: str,
        batch_size: int,
        random_seed: int = 42,
        train_samples: int = 100,
        val_samples: int = 20,
        test_samples: int = 20,
        trees_per_sample: int = 50,
        height_threshold: float = 2.0,
        dx: float = 5.0,
        dy: float = 5.0,
        xy_noise_mean: float = 0.0,
        xy_noise_std: float = 1.0,
        las_features: Optional[list[str]] = None,
        num_workers: Optional[int] = None,
    ):
        """Create a new SyntheticForestDataModule instance.

        Args:
            data_dir: Passed to ``SyntheticForest`` as ``root``.
            batch_size: Number of samples per batch in both data loaders.
            random_seed: Forwarded unchanged to ``SyntheticForest``.
            train_samples: Forwarded unchanged to ``SyntheticForest``.
            val_samples: Forwarded unchanged to ``SyntheticForest``.
            test_samples: Forwarded unchanged to ``SyntheticForest``.
            trees_per_sample: Forwarded unchanged to ``SyntheticForest``.
            height_threshold: Forwarded unchanged to ``SyntheticForest``.
            dx: Forwarded unchanged to ``SyntheticForest``.
            dy: Forwarded unchanged to ``SyntheticForest``.
            xy_noise_mean: Forwarded unchanged to ``SyntheticForest``.
            xy_noise_std: Forwarded unchanged to ``SyntheticForest``.
            las_features: Forwarded unchanged to ``SyntheticForest``.
            num_workers: Worker processes per data loader. When ``None``, the
                notebook-level ``num_workers`` global is used if it exists,
                otherwise 0 (load in the main process).
        """
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        # Deliberately not part of dataset_params: it configures the loaders and is
        # not an argument of the dataset.
        self.num_workers = num_workers
        self.transform = torch_geometric.transforms.Compose(
            [
                torch_geometric.transforms.NormalizeScale(),
                torch_geometric.transforms.NormalizeFeatures(),
            ]
        )
        # Validation currently uses exactly the same transform as training.
        self.val_transform = self.transform
        self.dataset_params = {
            "root": self.data_dir,
            "random_seed": random_seed,
            "train_samples": train_samples,
            "val_samples": val_samples,
            "test_samples": test_samples,
            "trees_per_sample": trees_per_sample,
            "height_threshold": height_threshold,
            "dx": dx,
            "dy": dy,
            "xy_noise_mean": xy_noise_mean,
            "xy_noise_std": xy_noise_std,
            "las_features": las_features,
        }

    def _resolve_num_workers(self) -> int:
        """Return the worker count for the data loaders."""
        if self.num_workers is not None:
            return self.num_workers
        # Backward compatibility: the loaders used to read the notebook-level global
        # directly and raised a NameError in environments that do not define it.
        return globals().get("num_workers", 0)

    def prepare_data(self):
        """Prepare the data for setup (download, tokenize, etc.) on one device."""
        SyntheticForest(**self.dataset_params)

    def setup(self, stage: str):
        """Prepare the data for training (split, transform, etc.) on all devices."""
        # "validate" needs self.val as well, so it shares the setup of "fit".
        if stage in ("fit", "validate"):
            self.train = SyntheticForest(
                split="train",
                **self.dataset_params,
                transform=self.transform,
            )
            self.val = SyntheticForest(
                split="val",
                **self.dataset_params,
                transform=self.val_transform,
            )

        if stage == "test":
            raise NotImplementedError()

        if stage == "predict":
            raise NotImplementedError()

    def train_dataloader(self):
        """Set up and return the train data loader."""
        return torch_geometric.loader.DataLoader(
            dataset=self.train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self._resolve_num_workers(),
        )

    def val_dataloader(self):
        """Set up and return the validation data loader."""
        return torch_geometric.loader.DataLoader(
            dataset=self.val,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self._resolve_num_workers(),
        )

    def test_dataloader(self):
        """Set up and return the test data loader."""
        raise NotImplementedError()

    def predict_dataloader(self):
        """Set up and return the prediction data loader."""
        raise NotImplementedError()

In [None]:
# MLflow logger for the run. The tracking URI is read from the environment and is
# None when the variable is not set. The run is tagged with the environment the
# notebook was detected to run in.
logger = L.pytorch.loggers.MLFlowLogger(
    experiment_name="synthetic_forest_only",
    tracking_uri=os.environ.get("MLFLOW_TRACKING_URI"),
    log_model=True,
    tags={"environment": ENVIRONMENT},
)

In [None]:
if torch.cuda.is_available():
    GPU_AVAILABLE = True
    !nvidia-smi > nvidia_smi.txt
    mlflow.log_artifact(run_id=logger.run_id, local_path="nvidia_smi.txt")

In [None]:
# Attach whatever the setup cell wrote to stdout and stderr (captured there with
# %%capture) to the MLflow run. Empty streams are skipped.
for stream_name in ("stdout", "stderr"):
    captured_text = getattr(setup_cell_output, stream_name)
    if not captured_text:
        continue
    log_file = f"setup.{stream_name}.txt"
    with open(log_file, "w") as f:
        f.write(captured_text)
    mlflow.log_artifact(run_id=logger.run_id, local_path=log_file)

In [None]:
# Instantiate the model and the data module for this experiment.
model = PointNet2TreeSegmentorModule()

# NOTE(review): the dataset root is hard-coded here, while the setup cell defines a
# per-environment `data_dir` that is not used — confirm which one is intended.
data = SyntheticForestDataModule(
    data_dir="data/raw/synthetic_forest",
    batch_size=1,
    random_seed=43,
    trees_per_sample=9,
    dx=4,
    dy=4,
    xy_noise_std=2,
)

# Record the dataset configuration as parameters of the MLflow run.
mlflow.log_params(params=data.dataset_params, run_id=logger.run_id)

In [None]:
# Write the torchinfo summary of the model to a throwaway file and attach it to
# the MLflow run. The temporary directory is removed once the upload is done.
summary = torchinfo.summary(model)
with tempfile.TemporaryDirectory() as tmp:
    summary_file = f"{tmp}/model_summary.txt"
    pathlib.Path(summary_file).write_text(str(summary))
    mlflow.log_artifact(local_path=summary_file, run_id=logger.run_id)

In [None]:
# Trainer configuration: stop early when the validation loss stops improving, keep
# the best and the last checkpoint, and log the learning rate once per epoch.
trainer = L.Trainer(
    fast_dev_run=False,
    accelerator=accelerator,
    max_epochs=50,
    logger=logger,
    log_every_n_steps=5,
    accumulate_grad_batches=1,
    enable_progress_bar=True,
    callbacks=[
        L.pytorch.callbacks.EarlyStopping(
            monitor="loss/val",
            mode="min",
            patience=3,
        ),
        L.pytorch.callbacks.ModelCheckpoint(
            monitor="loss/val",
            mode="min",
            dirpath="checkpoints/",
            # The metric name contains "/", which automatic name insertion would
            # copy into the file name and thereby turn into a subdirectory
            # ("epoch=3-loss/val=0.12.ckpt"). Spell the labels out by hand instead.
            filename="epoch={epoch}-loss_val={loss/val:.2f}",
            auto_insert_metric_name=False,
            save_last=True,
            save_top_k=1,
            every_n_epochs=1,
        ),
        L.pytorch.callbacks.LearningRateMonitor(logging_interval="epoch"),
    ],
)

In [None]:
# Train the model. The data module goes into `datamodule`, the argument meant for
# it, instead of `train_dataloaders`.
trainer.fit(
    model=model,
    datamodule=data,
)

In [None]:
# Run the trained network on one validation sample and log the result as a LAZ
# file. Fall back to the CPU when CUDA is not available, so the cell also works
# in the local (CPU-only) environment.
device = "cuda" if torch.cuda.is_available() else "cpu"

example = data.val[0]
# Assign every point to batch 0; presumably the network expects a batch vector
# even for a single sample — TODO confirm.
example["batch"] = torch.zeros_like(example.y)

# `model` is rebound from the Lightning module to the bare network, as before.
# getattr keeps the cell re-runnable: on a second run `model` is already the
# network and is used as is.
model = getattr(model, "pointnet", model).to(device)
model.eval()
with torch.no_grad():
    pred = model(example.to(device))

# Store the prediction as an extra float32 dimension next to the points.
las = src.clouds.pyg_data_to_las(example.cpu())
las.add_extra_dim(laspy.ExtraBytesParams(name="pred", type=np.float32))
las["pred"][:] = pred.cpu().squeeze().numpy()
las.write("example_prediction.laz")
mlflow.log_artifact(run_id=logger.run_id, local_path="example_prediction.laz")

In [None]:
# Plot the example cloud colored by the predicted values and attach the figure to
# the MLflow run.
ax = src.visualization.clouds.scatter_point_cloud_3d(las.xyz, color=las["pred"])
logger.experiment.log_figure(
    figure=ax.figure,
    artifact_file="example_prediction.png",
    run_id=logger.run_id,
)

In [None]:
# On Kaggle only: delete the cloned repository (phd) and the data directory from
# the working directory once the run is finished.
if ENVIRONMENT == "Kaggle":
    print("Cleaning up.")
    !rm -rf phd data