In [None]:
import os
import pathlib
import subprocess
import sys

In [None]:
%%time
%%capture setup_cell_output


def clone_phd_repo(github_token: str, destination: str = "phd") -> None:
    """Clone the private ``iod-ine/phd`` repository and put it on ``sys.path``.

    Safe to re-run: an existing checkout at ``destination`` is reused instead of
    re-cloned, and ``destination`` is inserted into ``sys.path`` only once.

    Args:
        github_token: GitHub token with read access to the repository. It is only
            used for the clone and is stripped from the checkout's remote URL
            afterwards, so it does not stay in ``<destination>/.git/config``.
        destination: Directory to clone into. Defaults to ``"phd"``, which is the
            directory ``git clone`` picks for this repository on its own.

    Raises:
        RuntimeError: If ``git clone`` exits with a non-zero status. The message
            deliberately omits the command line, because it contains the token.
    """
    public_url = "https://github.com/iod-ine/phd.git"
    if not pathlib.Path(destination).exists():
        authenticated_url = f"https://iod-ine:{github_token}@github.com/iod-ine/phd.git"
        result = subprocess.run(["git", "clone", authenticated_url, destination])
        if result.returncode != 0:
            # Not using check=True: CalledProcessError would print the token-bearing URL.
            raise RuntimeError(
                f"Cloning iod-ine/phd into {destination!r} failed "
                f"(git exit status {result.returncode})."
            )
        # Drop the token from the stored remote so it is not persisted on disk
        # (e.g. in Kaggle's saved working directory).
        subprocess.run(["git", "-C", destination, "remote", "set-url", "origin", public_url])
    if destination not in sys.path:
        sys.path.insert(0, destination)


if (working_dir := pathlib.Path.cwd().as_posix()).startswith("/kaggle"):
    ENVIRONMENT = "Kaggle"
    data_dir = "data"
    accelerator = "gpu"
    num_workers = 3

    # For Kaggle, the notebook that builds wheels for torch-cluster and torch-scatter
    # needs to be attached as input. It's called "torch-scatter & torch-cluster wheels"
    !pip install --upgrade laspy lazrs mlflow lightning torch_geometric
    !pip install --no-index --find-links /kaggle/input/torch-scatter-torch-cluster-wheels/ torch_scatter torch_cluster

    # It doesn't make sense to download the datasets through the API. Attach them as
    # inputs, and copy all raw files to where they are expected to be to skip the
    # download and save some GPU time.
    !mkdir -p data/raw
    !cp /kaggle/input/tree-detection-lidar-rgb/ortho/*.tif data/raw/
    !cp -r /kaggle/input/uav-point-clouds-of-individual-trees/* data/raw/

    import kaggle_secrets

    # The notebooks needs to be manually granted access to the secrets through the menu
    secrets = kaggle_secrets.UserSecretsClient()
    os.environ["MLFLOW_TRACKING_URI"] = secrets.get_secret("mlflow-uri")
    os.environ["MLFLOW_TRACKING_USERNAME"] = secrets.get_secret("mlflow-username")
    os.environ["MLFLOW_TRACKING_PASSWORD"] = secrets.get_secret("mlflow-password")
    os.environ["KAGGLE_USERNAME"] = secrets.get_secret("kaggle-username")
    os.environ["KAGGLE_KEY"] = secrets.get_secret("kaggle-key")

    clone_phd_repo(github_token=secrets.get_secret("github-token"))

    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    sys.path.insert(0, "phd")

    script_path = "phd/src/experiments/sf_rgb_patch.py"

elif working_dir.startswith("/content"):
    ENVIRONMENT = "Colab"
    data_dir = "data"
    accelerator = "gpu"
    num_workers = 2

    !pip install mlflow laspy lazrs rasterio lightning torch_geometric torchinfo python-dotenv
    !pip install https://data.pyg.org/whl/torch-2.4.0%2Bcu121/torch_cluster-1.6.3%2Bpt24cu121-cp310-cp310-linux_x86_64.whl
    !pip install https://data.pyg.org/whl/torch-2.4.0%2Bcu121/torch_scatter-2.1.2%2Bpt24cu121-cp310-cp310-linux_x86_64.whl

    from google.colab import userdata

    # The notebooks needs to be manually granted access to the secrets through the menu
    os.environ["MLFLOW_TRACKING_URI"] = userdata.get("mlflow-uri")
    os.environ["MLFLOW_TRACKING_USERNAME"] = userdata.get("mlflow-username")
    os.environ["MLFLOW_TRACKING_PASSWORD"] = userdata.get("mlflow-password")
    os.environ["KAGGLE_USERNAME"] = userdata.get("kaggle-username")
    os.environ["KAGGLE_KEY"] = userdata.get("kaggle-key")

    clone_phd_repo(github_token=userdata.get("github-token"))

    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    sys.path.insert(0, "phd")

    script_path = "phd/src/experiments/sf_rgb_patch.py"

elif working_dir.startswith("/home/jupyter/work/resources"):
    ENVIRONMENT = "DataSphere"
    data_dir = "../data/interim/synthetic_forest"
    accelerator = "gpu"
    num_workers = 27

    sys.path.insert(0, "..")

    script_path = "../src/experiments/sf_rgb_patch.py"

    import torch

    torch.set_float32_matmul_precision("medium")

else:
    raise NotImplementedError("Could not determine where the notebook is running.")

In [None]:
import pathlib
import random
import tempfile

import laspy
import lightning as L
import matplotlib.pyplot as plt
import mlflow
import numpy as np
import torch
import torchinfo

import src.clouds
import src.visualization.clouds
from src.experiments.sf_rgb_patch import (
    PointNet2TreeSegmentorModule,
    SyntheticForestColoredDataModule,
)

In [None]:
import warnings

warnings.filterwarnings("ignore", category=FutureWarning)

In [None]:
# One MLflow experiment per training script, named after the script's file stem.
EXPERIMENT_NAME = pathlib.Path(script_path).stem

logger = L.pytorch.loggers.MLFlowLogger(
    experiment_name=EXPERIMENT_NAME,
    tracking_uri=os.getenv("MLFLOW_TRACKING_URI"),
    tags={"environment": ENVIRONMENT},
    log_model=True,
)

In [None]:
# Record the GPU configuration alongside the run. GPU_AVAILABLE is bound on every
# machine (True or False), so later cells can branch on it without a NameError.
GPU_AVAILABLE = torch.cuda.is_available()
if GPU_AVAILABLE:
    # Write into a temporary directory, like the model summary, so nothing is left
    # behind in the working directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        smi_path = pathlib.Path(tmp_dir) / "nvidia_smi.txt"
        with open(smi_path, "w") as smi_file:
            # Equivalent of `nvidia-smi > nvidia_smi.txt`: only stdout goes to the file.
            subprocess.run(["nvidia-smi"], stdout=smi_file)
        mlflow.log_artifact(run_id=logger.run_id, local_path=str(smi_path))

In [None]:
# Attach whatever the (captured) setup cell printed to the MLflow run, one artifact
# per non-empty stream: setup.stdout.txt and setup.stderr.txt. The files are written
# to a temporary directory, so the working directory stays clean.
with tempfile.TemporaryDirectory() as tmp_dir:
    for stream_name in ("stdout", "stderr"):
        captured_text = getattr(setup_cell_output, stream_name)
        if not captured_text:
            continue
        log_path = pathlib.Path(tmp_dir) / f"setup.{stream_name}.txt"
        with open(log_path, "w") as log_file:
            log_file.write(captured_text)
        mlflow.log_artifact(run_id=logger.run_id, local_path=str(log_path))

In [None]:
# Hyperparameters for PointNet2TreeSegmentorModule.
model_kwargs = dict(
    num_features=3,
    set_abstraction_ratios=(0.5, 0.25),
    set_abstraction_radii=(0.2, 0.4),
    loss=torch.nn.L1Loss(),
    lr=1e-2,
    lr_start_factor=0.1,
    lr_warmup_iters=3,
    lr_decay_iters=20,
    lr_end_factor=0.1,
)

# Dataset, augmentation and patching parameters for SyntheticForestColoredDataModule.
datamodule_kwargs = dict(
    data_dir=data_dir,
    batch_size=1,
    random_jitter=0.3,
    random_scale_range=(0.9, 1.1),
    random_rotate_degrees_range=(-180, 180),
    height_threshold=2.0,
    patch_width=10,
    patch_height=10,
    patch_overlap=0.5,
    height_dropout_sigmoid_scale=8,
    height_dropout_sigmoid_shift=3,
    height_dropout_sigmoid_seed=None,
    train_samples=150,
    val_samples=50,
    num_workers=num_workers,
)

model = PointNet2TreeSegmentorModule(**model_kwargs)
data = SyntheticForestColoredDataModule(**datamodule_kwargs)

# Log every knob that shaped this run, plus the script that defines the model and data.
run_id = logger.run_id
training_params = {
    "loss": type(model.loss).__name__,
    "batch_size": data.batch_size,
    "random_jitter": data.random_jitter,
    "random_scale_range": data.random_scale_range,
    "random_rotate_degrees_range": data.random_rotate_degrees_range,
}
mlflow.log_params(run_id=run_id, params=model.hparams)
mlflow.log_params(run_id=run_id, params=training_params)
mlflow.log_params(run_id=run_id, params=data.dataset_params)
mlflow.log_artifact(run_id=run_id, local_path=script_path)
mlflow.log_artifact(run_id=logger.run_id, local_path=script_path)

In [None]:
# Build the datamodule's "fit"-stage datasets up front.
# NOTE(review): trainer.fit below is also given `data`, and Lightning calls
# setup("fit") on the datamodule itself, so this may run twice — confirm whether
# the explicit call is still needed.
data.setup("fit")

In [None]:
# Attach a text summary of the model's layers and parameter counts to the run.
summary = torchinfo.summary(model)
with tempfile.TemporaryDirectory() as tmp_dir:
    summary_path = pathlib.Path(tmp_dir) / "model_summary.txt"
    with open(summary_path, "w") as summary_out:
        summary_out.write(str(summary))
    mlflow.log_artifact(run_id=logger.run_id, local_path=str(summary_path))

In [None]:
# Stop when validation loss stalls for 5 epochs. Keep the weights of the epoch with
# the best validation accuracy, plus the last epoch's, under checkpoints/.
callbacks = [
    L.pytorch.callbacks.EarlyStopping(monitor="loss/val", mode="min", patience=5),
    L.pytorch.callbacks.ModelCheckpoint(
        dirpath="checkpoints/",
        filename="epoch={epoch}-acc={accuracy/val:.3f}",
        auto_insert_metric_name=False,
        monitor="accuracy/val",
        mode="max",
        save_top_k=1,
        save_last=True,
        save_weights_only=True,
        every_n_epochs=1,
    ),
    L.pytorch.callbacks.LearningRateMonitor(logging_interval="epoch"),
]

trainer = L.Trainer(
    accelerator=accelerator,
    logger=logger,
    callbacks=callbacks,
    max_epochs=200,
    log_every_n_steps=5,
    accumulate_grad_batches=1,
    fast_dev_run=False,
    enable_progress_bar=True,
)

In [None]:
# Lightning accepts a datamodule positionally in place of train_dataloaders; the
# explicit keyword makes the intent obvious.
trainer.fit(model, datamodule=data)

In [None]:
# Run the trained network on every 10th validation sample. For each one, log to the
# MLflow run:
#   * a .laz cloud with the raw and rounded predictions as extra dimensions,
#   * a 3D scatter coloured by the raw prediction,
#   * XY projections of labels / raw prediction / rounded prediction side by side.
# `model` keeps referring to the LightningModule, so this cell can be re-run.
device = "cuda" if torch.cuda.is_available() else "cpu"
pointnet = model.pointnet.to(device)
pointnet.eval()

# (LAS field to colour by, panel title) for the side-by-side 2D projections.
PANELS_2D = (
    ("label", "Labels"),
    ("pred_raw", "Raw prediction"),
    ("pred_rounded", "Rounded prediction"),
)

with torch.no_grad(), tempfile.TemporaryDirectory() as tmp_dir:
    for i, example in enumerate(data.val[0::10]):
        # A single cloud: every point belongs to batch 0. PyG batch vectors are long
        # indices, one per node.
        example["batch"] = torch.zeros(example.num_nodes, dtype=torch.long)
        pred = pointnet(example.to(device)).cpu().squeeze()

        las = src.clouds.pyg_data_to_las(example.cpu())
        las.add_extra_dim(laspy.ExtraBytesParams(name="pred_raw", type=np.float64))
        las.add_extra_dim(laspy.ExtraBytesParams(name="pred_rounded", type=np.int32))
        las["pred_raw"][:] = pred.numpy()
        las["pred_rounded"][:] = pred.round().numpy()

        # Written to a temporary directory; the artifact keeps the file's base name.
        laz_path = pathlib.Path(tmp_dir) / f"example_prediction_{i}.laz"
        las.write(str(laz_path))
        mlflow.log_artifact(run_id=logger.run_id, local_path=str(laz_path))

        ax_3d = src.visualization.clouds.scatter_point_cloud_3d(
            las.xyz,
            color=las["pred_raw"],
        )
        logger.experiment.log_figure(
            run_id=logger.run_id,
            figure=ax_3d.figure,
            artifact_file=f"example_prediction_{i}.png",
        )
        # Figures are kept in MLflow; close them so the loop doesn't accumulate them.
        plt.close(ax_3d.figure)

        fig_2d, axes_2d = plt.subplots(
            1, len(PANELS_2D), figsize=(10, 4), tight_layout=True
        )
        for panel_ax, (field, title) in zip(axes_2d, PANELS_2D):
            src.visualization.clouds.scatter_point_cloud_2d(
                las.xyz,
                projection="XY",
                c=las[field],
                s=2,
                ax=panel_ax,
            )
            panel_ax.set_title(title)
        logger.experiment.log_figure(
            run_id=logger.run_id,
            figure=fig_2d,
            artifact_file=f"example_prediction_2d_{i}.png",
        )
        plt.close(fig_2d)

In [None]:
# Kaggle saves the working directory as notebook output; remove the cloned
# repository and the copied raw data so they are not included in it.
if ENVIRONMENT == "Kaggle":
    print("Cleaning up.")
    subprocess.run(["rm", "-rf", "phd", "data"])