# TILTED paper table generation

Let's start with imports and types...

In [1]:
from __future__ import annotations

from typing import Any, Callable, Dict, Generic, Literal, Tuple, TypeVar
from pathlib import Path
from tqdm.auto import tqdm
import dataclasses
import yaml
import numpy as onp
import itertools

EvalMetricName = Literal["lpips", "mse", "psnr", "ssim"]
TrainMetricName = Literal[
    "iterations_per_second",
    "train/distortion_loss",
    "train/grad_norm",
    "train/interlevel_loss",
    "train/iterations_per_sec",  # Duplicate, oops!
    "train/l12_reg",
    "train/loss",
    "train/mse",
    "train/psnr",
    "train/tv_reg_l1",
    "train/tv_reg_l2",
]

EvalMetrics = Dict[EvalMetricName, float]
TrainMetrics = Dict[TrainMetricName, float]
ConfigDict = Dict[str, Any]

T = TypeVar("T")

@dataclasses.dataclass
class Result(Generic[T]):
    config: ConfigDict
    metrics: T
    path: Path

class Results(Tuple[Result[T], ...], Generic[T]):
    def filter(self, condition: Callable[[Result], bool]) -> Results[T]:
        return Results(filter(condition, self))

    def filter_config_eq(self, config: Dict[str, Any]) -> Results[T]:
        return Results(
            x for x in self
            if all(map(lambda k: x.config[k] == config[k], config.keys()))
        )

    def filter_config_neq(self, config: Dict[str, Any]) -> Results[T]:
        return Results(
            x for x in self
            if all(map(lambda k: x.config[k] != config[k], config.keys()))
        )
    
    def assert_length(self, length: int) -> Results[T]:
        assert len(self) == length, f"Expected length {length}, but got {len(self)}"
        return self


Load our NeRF results...

In [2]:
_output_dirs = [
    Path("./paper_results/"),
]

eval_results = Results[EvalMetrics](
    Result(yaml.load((p / "config_dict.yaml").read_text(), yaml.Loader), yaml.load((p / "eval_metrics_fixed.yaml").read_text(), yaml.Loader), p)
    for p in tqdm(list(itertools.chain(*[d.iterdir() for d in _output_dirs])))
    if (p / "eval_metrics_fixed.yaml").exists()
)
train_results = Results[TrainMetrics](
    Result(yaml.load((p / "config_dict.yaml").read_text(), yaml.Loader), yaml.load((p / "train_metrics.yaml").read_text(), yaml.Loader), p)
    for p in tqdm(list(itertools.chain(*[d.iterdir() for d in _output_dirs])))
    if (p / "train_metrics.yaml").exists()
)

del _output_dirs

  0%|          | 0/1458 [00:00<?, ?it/s]

  0%|          | 0/1458 [00:00<?, ?it/s]

Make some helper functions...

In [3]:
def bold(text: str) -> str:
    return r"\textbf{" + text + "}"

def get_metric_moments(results: Results, metric: str) -> str:
    metric_from_dataset = {}
    for result in results:
        dataset = result.config["dataset_path"]
        if dataset not in metric_from_dataset:
            metric_from_dataset[dataset] = []
        metric_from_dataset[dataset].append(result.metrics[metric])

    means = []
    stderrs = []
    for x in metric_from_dataset.values():
        means.append(onp.mean(x))
        stderrs.append(onp.std(x) / onp.sqrt(len(x) - 1))
        
    mean = onp.mean(means)
    stderr = onp.sqrt(onp.sum(onp.array(stderrs) ** 2)) / len(stderrs)
    return f"{mean:.2f}" + r"{\scriptsize$\pm$" + f"{stderr:.2f}" + r"}"

def get_psnr(results: Results[EvalMetrics]) -> str:
    return(get_metric_moments(results, "psnr"))

def get_train_psnr(results: Results[TrainMetrics]) -> str:
    return(get_metric_moments(results, "train/psnr"))

### How do naive K-Planes / VM compare on axis-aligned vs randomly rotated synthetic data?

In [4]:

def _cell(lego: bool, kplanes: bool, axis_aligned: bool) -> str:
    return get_psnr(
        eval_results.filter_config_eq({
            "dataset_type": "blender",
            "field.primary_channels": 64 if kplanes else 192,
            "primary_transform_count": None,
        }).filter_config_eq({
            "dataset_path": Path("data/nerf_synthetic/lego"),
        } if lego else {}).filter_config_eq({ 
            # Only include axis-aligned scenes.
            "render_config.global_rotate_seed": None,
        } if axis_aligned else {}).filter_config_neq({
            # Randomly rotated scenes only.
            "render_config.global_rotate_seed": None,
        } if not axis_aligned else {}).assert_length(3 if lego else 24)
    )


print(
    r"\begin{tabular}{lllll}",
    r"\toprule",
    r"& K-Planes & VM\\",
    r"\cmidrule(lr){2-2} \cmidrule(l){3-3}",
    f"Lego & {_cell(lego=True, kplanes=True, axis_aligned=True)} $\\to$ {_cell(lego=True, kplanes=True, axis_aligned=False)} & {_cell(lego=True, kplanes=False, axis_aligned=True)} $\\to$ {_cell(lego=True, kplanes=False, axis_aligned=False)}" + r"\\",
    f"Average & {_cell(lego=False, kplanes=True, axis_aligned=True)} $\\to$ {_cell(lego=False, kplanes=True, axis_aligned=False)} & {_cell(lego=False, kplanes=False, axis_aligned=True)} $\\to$ {_cell(lego=False, kplanes=False, axis_aligned=False)}" + r"\\",
    r"\bottomrule",
    r"\end{tabular}",
    sep="\n"
)


\begin{tabular}{lllll}
\toprule
& K-Planes & VM\\
\cmidrule(lr){2-2} \cmidrule(l){3-3}
Lego & 35.31{\scriptsize$\pm$0.02} $\to$ 33.29{\scriptsize$\pm$0.11} & 34.24{\scriptsize$\pm$0.04} $\to$ 32.63{\scriptsize$\pm$0.01}\\
Average & 32.12{\scriptsize$\pm$0.02} $\to$ 31.62{\scriptsize$\pm$0.04} & 31.30{\scriptsize$\pm$0.03} $\to$ 30.76{\scriptsize$\pm$0.03}\\
\bottomrule
\end{tabular}


### What impact does incorporating TILTED have on randomly rotated synthetic data?

In [5]:
def _cell(lego: bool, kplanes: bool, tilted: bool) -> str:
    return get_psnr(
        eval_results.filter_config_eq({
            "dataset_type": "blender",
            "field.primary_channels": 64 if kplanes else 192,
            "primary_transform_count": 8 if tilted else None,
            "field.grid_type": "kplane" if kplanes else "vm"
        }).filter_config_eq({
            "dataset_path": Path("data/nerf_synthetic/lego"),
        } if lego else {}).filter_config_neq({
            # For axis-aligned approaches, only include randomly rotated scenes.
            # For TILTED, we include both the axis-aligned and the randomly rotated options.
            "render_config.global_rotate_seed": None,
        } if not tilted else {}).filter(
            lambda result: result.path.name.endswith("-bottleneck") if tilted else True
        ).assert_length(
            (3 if lego else 24) * (2 if tilted else 1)
        )
    )


print(
    r"\begin{tabular}{lllll}",
    r"\toprule",
    r"& K-Planes & w/ TILTED & VM & w/ TILTED \\",
    r"\cmidrule(lr){2-3} \cmidrule(l){4-5}",
    f"Lego & {_cell(True, kplanes=True, tilted=False)} & {bold(_cell(True, kplanes=True, tilted=True))} & {_cell(True, kplanes=False, tilted=False)} & {bold(_cell(True, kplanes=False, tilted=True))}" r"\\",
    f"Average & {_cell(False, kplanes=True, tilted=False)} & {bold(_cell(False, kplanes=True, tilted=True))} & {_cell(False, kplanes=False, tilted=False)} & {bold(_cell(False, kplanes=False, tilted=True))}" r"]\\",
    r"\bottomrule",
    r"\end{tabular}",
    sep="\n"
)

\begin{tabular}{lllll}
\toprule
& K-Planes & w/ TILTED & VM & w/ TILTED \\
\cmidrule(lr){2-3} \cmidrule(l){4-5}
Lego & 33.29{\scriptsize$\pm$0.11} & \textbf{34.35{\scriptsize$\pm$0.07}} & 32.63{\scriptsize$\pm$0.01} & \textbf{33.90{\scriptsize$\pm$0.06}}\\
Average & 31.62{\scriptsize$\pm$0.04} & \textbf{31.91{\scriptsize$\pm$0.04}} & 30.76{\scriptsize$\pm$0.03} & \textbf{31.08{\scriptsize$\pm$0.02}}]\\
\bottomrule
\end{tabular}


### Real-world experiments: how does TILTED compare?

In [6]:
dataset_names = set(
    [
        p.config["dataset_path"].name for p in
        train_results.filter_config_eq({
            "dataset_type": "nerfstudio",
        })
    ]
)

def get_train_psnr(dname: str, channels: int, grid_type: Literal["kplane", "vm"], tilted: bool) -> float:
    return train_results.filter(
        lambda p: p.config["dataset_path"].name == dname
    ).filter_config_eq(
        {
            "primary_transform_count": 8 if tilted else None,
            "field.primary_channels": channels,
            "field.grid_type": grid_type,
        }
    ).assert_length(1)[0].metrics["train/psnr"]

count = 0
tilted_better_count = 0
tilted_better_than_naive_double_count = 0
tolerance = 0.0

for dname in dataset_names:
    naive = get_train_psnr(dname, 32, 'kplane', False)
    tilted = get_train_psnr(dname, 32, 'kplane', True)
    naive_double = get_train_psnr(dname, 64, 'kplane', False)
    tilted_double = get_train_psnr(dname, 64, 'kplane', True)

    if tilted > naive:
        tilted_better_count += 1
    if tilted_double > naive_double:
        tilted_better_count += 1
    if tilted > naive_double - tolerance:
        tilted_better_than_naive_double_count += 1
    count += 2

    naive = get_train_psnr(dname, 96, 'vm', False)
    tilted = get_train_psnr(dname, 96, 'vm', True)
    naive_double = get_train_psnr(dname, 192, 'vm', False)
    tilted_double = get_train_psnr(dname, 192, 'vm', True)

    if tilted > naive:
        tilted_better_count += 1
    if tilted_double > naive_double:
        tilted_better_count += 1
    if tilted > naive_double - tolerance:
        tilted_better_than_naive_double_count += 1
    count += 2

print("How often does TILTED improve performance?")
print(tilted_better_count / count* 100, "%")
print()

print("How often does adding TILTED outperform _doubling the channel count_?")
print(tilted_better_than_naive_double_count / (count / 2) * 100, "%")
print()

print("How many experiments have we run?")
print(count)
print()

print("What are the runtimes of each method?")

def minutes(it_per_sec: float) -> str:
    sec = (30_000 / it_per_sec)
    return f"{int(sec // 60)}:{int(sec % 60)}"

print("32c", minutes(train_results.filter_config_eq({
    "dataset_type": "nerfstudio",
    "primary_transform_count": None,
    "field.grid_type": "kplane",
    "field.primary_channels": 32
})[0].metrics["train/iterations_per_sec"]))

print("32c TILTED", minutes(train_results.filter_config_eq({
    "dataset_type": "nerfstudio",
    "primary_transform_count": 8,
    "field.grid_type": "kplane",
    "field.primary_channels": 32
})[0].metrics["train/iterations_per_sec"]))

print("64c", minutes(train_results.filter_config_eq({
    "dataset_type": "nerfstudio",
    "primary_transform_count": None,
    "field.grid_type": "kplane",
    "field.primary_channels": 64
})[0].metrics["train/iterations_per_sec"]))

print("64c TILTED", minutes(train_results.filter_config_eq({
    "dataset_type": "nerfstudio",
    "primary_transform_count": 8,
    "field.grid_type": "kplane",
    "field.primary_channels": 64
})[0].metrics["train/iterations_per_sec"]))

How often does TILTED improve performance?
100.0 %

How often does adding TILTED outperform _doubling the channel count_?
63.888888888888886 %

How many experiments have we run?
72

What are the runtimes of each method?
32c 9:16
32c TILTED 11:4
64c 14:46
64c TILTED 17:36


In [7]:
lines = []
deltas = []

kplane_best_count = 0
kplane_tilted_best_count = 0
kplane_double_best_count = 0

vm_best_count = 0
vm_tilted_best_count = 0
vm_double_best_count = 0

for dname in dataset_names:
    kplane = get_train_psnr(dname, 32, 'kplane', False)
    kplane_tilted = get_train_psnr(dname, 32, 'kplane', True)
    kplane_double = get_train_psnr(dname, 64, 'kplane', False)

    vm = get_train_psnr(dname, 96, 'vm', False)
    vm_tilted = get_train_psnr(dname, 96, 'vm', True)
    vm_double = get_train_psnr(dname, 192, 'vm', False)

    if max(kplane, kplane_tilted, kplane_double) == kplane:
        kplane_best_count += 1
    if max(kplane, kplane_tilted, kplane_double) == kplane_tilted:
        kplane_tilted_best_count += 1
    if max(kplane, kplane_tilted, kplane_double) == kplane_double:
        kplane_double_best_count += 1

    if max(vm, vm_tilted, vm_double) == vm:
        vm_best_count += 1
    if max(vm, vm_tilted, vm_double) == vm_tilted:
        vm_tilted_best_count += 1
    if max(vm, vm_tilted, vm_double) == vm_double:
        vm_double_best_count += 1


    lines.append(
        f"{dname.replace('-', ' ').title().ljust(15)}"
        f" & {kplane:.2f}"
        + " / " + (bold if kplane_double > kplane_tilted else lambda x: x)(f"{kplane_double:.2f}")
        + " / " + (bold if kplane_double <= kplane_tilted else lambda x: x)(f"{kplane_tilted:.2f}")
        + f" & {vm:.2f}"
        + " / " + (bold if vm_double > vm_tilted else lambda x: x)(f"{vm_double:.2f}")
        + " / " + (bold if vm_double <= vm_tilted else lambda x: x)(f"{vm_tilted:.2f}")
        + r" \\"
    )
    # lines.append(f"{dname.replace('-', ' ').title().ljust(15)} & {kplane:.2f} & \\textbf{'{'}{kplane_tilted:.2f} (+{kplane_tilted - kplane:.2f}){'}'} & {vm:.2f} & \\textbf{'{'}{vm_tilted:.2f} (+{vm_tilted - vm:.2f}){'}'}")
    deltas.append(kplane_tilted - kplane + vm_tilted - vm)

print(r"\begin{tabular}{lcc}")
print(r"\toprule")
print(r"Dataset & K-Plane / 2x / TILTED & VM / 2x / TILTED\\")
print(r"\cmidrule(r){1-1} \cmidrule(lr){2-2} \cmidrule(l){3-3}")
for i in sorted(range(len(lines)), key=lambda i: -deltas[i]):
    print(lines[i])
print(r"\cmidrule(r){1-1} \cmidrule(lr){2-2} \cmidrule(l){3-3}")
print(f"Best # & {kplane_best_count} / {kplane_double_best_count} / {kplane_tilted_best_count} & {vm_best_count} / {vm_double_best_count} / {vm_tilted_best_count}" + r"\\")
print(r"\bottomrule")
print(r"\end{tabular}")

\begin{tabular}{lcc}
\toprule
Dataset & K-Plane / 2x / TILTED & VM / 2x / TILTED\\
\cmidrule(r){1-1} \cmidrule(lr){2-2} \cmidrule(l){3-3}
Kitchen         & 25.95 / 26.91 / \textbf{27.12} & 25.63 / 26.54 / \textbf{26.90} \\
Floating Tree   & 24.58 / \textbf{25.17} / 25.06 & 24.03 / 24.70 / \textbf{25.04} \\
Poster          & 33.14 / 33.71 / \textbf{33.79} & 32.84 / 33.49 / \textbf{33.61} \\
Redwoods2       & 23.55 / 24.08 / \textbf{24.12} & 23.22 / 23.81 / \textbf{23.85} \\
Stump           & 26.82 / \textbf{27.29} / 27.28 & 26.33 / 26.83 / \textbf{26.97} \\
Vegetation      & 21.62 / 22.10 / \textbf{22.10} & 21.11 / 21.55 / \textbf{21.73} \\
Bww_Entrance    & 24.64 / \textbf{25.06} / 24.95 & 24.22 / 24.75 / \textbf{24.80} \\
Library         & 25.24 / 25.68 / \textbf{25.78} & 25.50 / 25.78 / \textbf{25.84} \\
Storefront      & 29.71 / \textbf{30.12} / 29.87 & 29.15 / 29.77 / \textbf{29.87} \\
Dozer           & 22.37 / \textbf{22.88} / 22.69 & 21.91 / \textbf{22.46} / 22.40 \\
Egypt       

### Ablations

Let's compare:
- 4 transform, 8 transform TILTED
- 2-phase registration vs direct optimization

We'll use Lego, since that's where the ``correct'' alignment might be the most obvious.

In [18]:
print(
    r"\begin{tabular}{lll}",
    r"\toprule",
    sep="\n"
)
print(r"& 8 transforms & 4 transforms\\")
print(r"\cmidrule(lr){2-2} \cmidrule(l){3-3}")
for two_phase_enabled in (True, False):
    print("Two phase" if two_phase_enabled else "w/o", end=" ")
    for primary_transform_count in (8, 4):
        print(
            "&",
            get_psnr(
                eval_results.filter_config_eq(
                    {
                        "primary_transform_count": primary_transform_count,
                        "field.grid_type": "kplane",
                        "field.primary_channels": 64,
                    }
                ).filter_config_neq({
                    # Since TILTED is transform invariant, we can include both rotated and axis-aligned scenes.
                    # "render_config.global_rotate_seed": None,
                }).filter(
                    lambda p: p.config["dataset_path"].name == "lego"
                ).filter(
                    lambda p: p.path.name.endswith("-bottleneck") if two_phase_enabled else not p.path.name.endswith("-bottleneck")
                ).assert_length(6)
            ),
            end=" "
        )
    print(r" \\")

print(
    r"\bottomrule",
    r"\end{tabular}",
    sep="\n"
)

\begin{tabular}{lll}
\toprule
& 8 transforms & 4 transforms\\
\midrule\\
Two phase & 34.35{\scriptsize$\pm$0.07} & 34.19{\scriptsize$\pm$0.22}  \\
w/o & 33.95{\scriptsize$\pm$0.15} & 33.83{\scriptsize$\pm$0.08}  \\
\bottomrule
\end{tabular}
