In [1]:
import json
import math
import os
import time
from dataclasses import dataclass, field
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union

import imageio
import nerfview
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
import tyro
import viser
import yaml
from src.datasets.colmap import Dataset, Parser
from src.datasets.traj import (
    generate_interpolated_path,
    generate_ellipse_path_z,
    generate_spiral_path,
)
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from fused_ssim import fused_ssim
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from typing_extensions import Literal, assert_never
from src.splats.utils import AppearanceOptModule, CameraOptModule, knn, rgb_to_sh, set_random_seed

from gsplat.compression import PngCompression
from gsplat.distributed import cli
from gsplat.rendering import rasterization
from gsplat.strategy import DefaultStrategy, MCMCStrategy


In [2]:
from src.splats.simple_trainer import *

In [None]:
# Training presets offered to the user, keyed by the name that selects them.
# Every value is a (CLI help text, Config instance) pair, which is the shape
# tyro.extras.overridable_config_cli consumes.
configs = dict(
    default=(
        "Gaussian splatting training using densification heuristics from the original paper.",
        Config(strategy=DefaultStrategy(verbose=True)),
    ),
    mcmc=(
        "Gaussian splatting training using densification from the paper '3D Gaussian Splatting as Markov Chain Monte Carlo'.",
        Config(
            strategy=MCMCStrategy(verbose=True),
            init_opa=0.5,
            init_scale=0.1,
            opacity_reg=0.01,
            scale_reg=0.01,
        ),
    ),
)

# Pick one preset and apply any command-line overrides on top of it.
# NOTE(review): no explicit `args` are passed, so this presumably parses
# sys.argv; inside a notebook kernel that is the kernel's own launch
# arguments -- confirm this cell is run where argv is meaningful.
cfg = tyro.extras.overridable_config_cli(configs)

# Feed the config's own steps_scaler back into adjust_steps (defined on Config,
# not visible here; presumably rescales step-count settings -- confirm).
cfg.adjust_steps(cfg.steps_scaler)

# Optional dependencies: when PNG compression is selected, import `plas` and
# `torchpq` up front so a missing package fails here, with install
# instructions, before the training entry point is launched.
if cfg.compression == "png":
    try:
        import plas
        import torchpq
    except ImportError as err:
        # Catch only import failures: a bare `except:` would also intercept
        # KeyboardInterrupt/SystemExit and mislabel unrelated errors as a
        # missing install. Chain the original error to keep its cause visible.
        raise ImportError(
            "To use PNG compression, you need to install "
            "torchpq (instruction at https://github.com/DeMoriarty/TorchPQ?tab=readme-ov-file#install) "
            "and plas (via 'pip install git+https://github.com/fraunhoferhhi/PLAS.git') "
        ) from err

# Launch training: hand `main` and the parsed config to gsplat's distributed
# `cli` helper (imported from gsplat.distributed above).
# NOTE(review): `main` is not defined in this notebook; presumably it comes
# from the star import of src.splats.simple_trainer -- confirm.
cli(main, cfg, verbose=True)