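# Nerfbusters Metrics

Relative Rotation Accuracy (RRA) of rendered videos. For a fixed set of frame pairs in each render, LoFTR matches yield an essential-matrix pose estimate. Its rotation is compared with the ground-truth relative rotation of the rendered cameras. Set `dataset_name` to `"nerfbusters"` or `"nerfiller"` in the configuration cell to choose the benchmark.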

In [1]:
import json
from pathlib import Path

import mediapy
import numpy as np

# Method render folders available for each benchmark.
methods_dict = {
    "nerfbusters": ["gsplat", "nerfiller", "cat3d", "mvinpaint"],
    "nerfiller": ["mask", "gsplat", "nerfiller-no-new-views", "mvinpaint-no-new-views"],
}

# Scenes belonging to each benchmark.
datasets_dict = {
    "nerfbusters": [
        "aloe", "art", "car", "century", "flowers", "garbage",
        "picnic", "pikachu", "pipe", "plant", "roses", "table",
    ],
    "nerfiller": ["bear", "billiards", "boot", "cat", "drawing", "dumptruck", "norway", "office", "turtle"],
}

# Benchmark to evaluate; the scene/method lists and paths below all derive from it.
dataset_name = "nerfbusters"
datasets, methods = datasets_dict[dataset_name], methods_dict[dataset_name]
folder = Path(f"/mnt/home/ethanjohnweber/data/{dataset_name}-renders")  # rendered videos per scene/method
output = Path(f"/mnt/home/ethanjohnweber/data/{dataset_name}")  # destination for exported frames

# Export switches, plus a nominal preview length in seconds.
save_images, save_video = False, False
seconds = 10

# Frame pairs evaluated in every video, as fractions of its length.
# Row i is (pairs1[i], pairs2[i]); rows are evenly spread and each pair is 0.1 apart.
num_pairs = 20
pairs1 = np.linspace(0.1, 0.8, num_pairs)
pairs2 = np.linspace(0.2, 0.9, num_pairs)
pairs = np.stack((pairs1, pairs2), axis=1)

In [2]:
# pairs

In [3]:
import viser

viser_port = 8890
# Create the server only on the first execution of this cell in a kernel;
# re-running the cell reuses the existing one instead of binding the port again.
try:
    viser_server
except NameError:
    viser_server = viser.ViserServer(port=viser_port)

In [4]:
import random

import cv2

In [5]:
import kornia
import torch
from kornia.feature import LoFTR

# Prefer the first GPU; fall back to CPU so the notebook still runs without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# LoFTR matcher with the "outdoor" pretrained weights. eval() makes inference
# mode explicit (a no-op if the constructor already set it).
matcher = LoFTR("outdoor").to(device).eval()

In [6]:
def draw_camera(name, c2w, image, fov=1.0, scale=0.2, aspect=1):
    """Add a camera frustum to the global viser scene.

    Args:
        name: scene-graph path for the frustum, e.g. "/cameras/0".
        c2w: camera-to-world matrix with at least 3 rows and 4 columns; a numpy
            array or a CPU torch tensor (converted with ``np.asarray``).
        image: image displayed on the frustum, forwarded to viser unchanged.
        fov, scale, aspect: frustum geometry forwarded to ``add_camera_frustum``.

    Returns:
        The handle returned by ``viser_server.scene.add_camera_frustum``.
    """
    import viser.transforms as vtf

    # Callers pass torch tensors; hand viser plain numpy arrays.
    c2w = np.asarray(c2w)
    # Rotate by pi about the camera x axis before display
    # (presumably a camera-axis convention change — TODO confirm).
    rotation = vtf.SO3.from_matrix(c2w[:3, :3]) @ vtf.SO3.from_x_radians(np.pi)
    return viser_server.scene.add_camera_frustum(
        name=name, fov=fov, scale=scale, aspect=aspect, image=image, wxyz=rotation.wxyz, position=c2w[:3, 3]
    )

In [7]:
import numpy as np


def compute_angular_error_batch(rotation1, rotation2):
    """Angle in degrees between paired rotation matrices.

    The relative rotation R2 @ R1^T has trace 1 + 2*cos(theta), so theta is
    recovered with arccos (clipped against round-off).

    Args:
        rotation1, rotation2: array-likes of shape (..., 3, 3). Leading batch
            dimensions broadcast against each other; a single (3, 3) pair works too.

    Returns:
        Array of shape (...) with angles in degrees, in [0, 180].
    """
    # https://github.com/jasonyzhang/RayDiffusion/blob/main/ray_diffusion/eval/utils.py#L50
    rotation1 = np.asarray(rotation1)
    rotation2 = np.asarray(rotation2)
    # swapaxes on the last two dims transposes each matrix regardless of batch rank.
    R_rel = rotation2 @ np.swapaxes(rotation1, -1, -2)
    t = (np.trace(R_rel, axis1=-2, axis2=-1) - 1) / 2
    theta = np.arccos(np.clip(t, -1, 1))
    return theta * 180 / np.pi

In [8]:
def _frame_to_tensor(frame):
    """Convert one (H, W, C) video frame into a (1, C, H, W) float tensor on `device`, divided by 255."""
    # frame is assumed to be uint8 (mediapy output) — the /255 maps it to [0, 1].
    return (torch.from_numpy(frame)[None].permute(0, 3, 1, 2) / 255.0).to(device)


def _tensor_to_uint8(image):
    """Convert a (1, C, H, W) tensor in [0, 1] into a C-contiguous (H, W, C) uint8 numpy image."""
    return np.ascontiguousarray((image[0].permute(1, 2, 0) * 255).cpu().numpy().astype("uint8"))


def _c2w_matrix(camera):
    """Return a (1, 4, 4) camera-to-world tensor whose top three rows come from camera["camera_to_world"]."""
    c2w = torch.eye(4)[None]
    c2w[0, :3, :] = torch.tensor(camera["camera_to_world"])
    return c2w


def _intrinsics_matrix(camera):
    """Return a (1, 3, 3) pinhole intrinsics tensor on `device` built from fx, fy, cx, cy."""
    return torch.tensor(
        [[camera["fx"], 0, camera["cx"]], [0, camera["fy"], camera["cy"]], [0, 0, 1]]
    )[None].to(device)


def _color_mask(frame, tolerance=0.05):
    """Boolean (1, 1, H, W) mask of pixels within `tolerance` of NEON_YELLOW on every channel."""
    from mvinpaint.utils.visualization_utils import Colors

    target_color = torch.tensor(Colors.NEON_YELLOW.value).to(device)[None, :, None, None]
    close = torch.abs(_frame_to_tensor(frame) - target_color) < tolerance
    return close.all(dim=1, keepdim=True)


def _draw_matches(image0, image1, keypoints0, keypoints1, stride=20):
    """Side-by-side uint8 image with every `stride`-th match drawn as a randomly colored line."""
    canvas = _tensor_to_uint8(torch.cat([image0, image1], dim=-1))
    # Right-hand endpoints are shifted by the actual width of the left image.
    offset = image0.shape[-1]
    for idx in range(0, len(keypoints0), stride):
        x0, y0 = keypoints0[idx]
        x1, y1 = keypoints1[idx]
        color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        cv2.line(canvas, (int(x0), int(y0)), (int(x1) + offset, int(y1)), color, 2)
    return canvas


def get_metric(video, cameras, pair, mask_video=None, visualize=False):
    """Rotation error (degrees) of the relative pose estimated between two frames.

    The two frames sit at fractional positions ``pair[0]`` and ``pair[1]`` of
    ``video``. LoFTR matches between them give a fundamental matrix, which is
    turned into an essential matrix and a relative pose using the intrinsics in
    ``cameras``. The estimated rotation is compared with the ground-truth
    relative rotation from the ``camera_to_world`` entries.

    Args:
        video: indexable sequence of frames (assumed (H, W, 3) uint8, as read by mediapy).
        cameras: dict keyed by the frame index as a string; each entry has
            "camera_to_world", "fx", "fy", "cx", "cy".
        pair: two fractions in [0, 1] selecting the frames.
        mask_video: optional video; when given, only matches whose endpoints
            both fall on NEON_YELLOW pixels are used.
        visualize: if True, draw cameras and triangulated points in viser and
            display the match image.

    Returns:
        Angular error in degrees between the ground-truth and estimated relative rotations.

    Exceptions from matching or pose estimation (e.g. too few matches) propagate to the caller.
    """
    num_frames = len(video)
    # Clamp so a fraction of exactly 1.0 does not index one past the last frame.
    idx0 = min(int(pair[0] * num_frames), num_frames - 1)
    idx1 = min(int(pair[1] * num_frames), num_frames - 1)
    camera0 = cameras[str(idx0)]
    camera1 = cameras[str(idx1)]
    image0 = _frame_to_tensor(video[idx0])
    image1 = _frame_to_tensor(video[idx1])

    # LoFTR matches grayscale images.
    input_dict = {
        "image0": kornia.color.rgb_to_grayscale(image0),
        "image1": kornia.color.rgb_to_grayscale(image1),
    }
    with torch.inference_mode():
        correspondences = matcher(input_dict)

    c2w1 = _c2w_matrix(camera0)
    c2w2 = _c2w_matrix(camera1)
    if visualize:
        draw_camera("/orig/0", c2w1[0], _tensor_to_uint8(image0))
        draw_camera("/orig/1", c2w2[0], _tensor_to_uint8(image1))
    # Express both cameras relative to camera 0 (which becomes the identity).
    # c2w2 is updated first because it needs the original c2w1.
    c2w2 = torch.bmm(torch.linalg.inv(c2w1), c2w2)
    c2w1 = torch.bmm(torch.linalg.inv(c2w1), c2w1)
    if visualize:
        draw_camera("/cameras/0", c2w1[0], _tensor_to_uint8(image0))
        draw_camera("/cameras/1", c2w2[0], _tensor_to_uint8(image1))

    keypoints0 = correspondences["keypoints0"].to(device)
    keypoints1 = correspondences["keypoints1"].to(device)
    confidence = correspondences["confidence"].to(device)
    if mask_video is not None:
        # Keep only matches whose endpoints both land inside the mask. The
        # confidences are filtered too so the solver weights stay aligned with the points.
        mask_image0 = _color_mask(mask_video[idx0])
        mask_image1 = _color_mask(mask_video[idx1])
        mask_value0 = mask_image0[:, :, keypoints0[..., 1].long(), keypoints0[..., 0].long()].flatten()
        mask_value1 = mask_image1[:, :, keypoints1[..., 1].long(), keypoints1[..., 0].long()].flatten()
        keep = mask_value0 & mask_value1
        keypoints0 = keypoints0[keep]
        keypoints1 = keypoints1[keep]
        confidence = confidence[keep]

    K1 = _intrinsics_matrix(camera0)
    K2 = _intrinsics_matrix(camera1)
    points1 = keypoints0[None].clone()
    points2 = keypoints1[None].clone()
    F_mat = kornia.geometry.epipolar.find_fundamental(points1, points2, weights=confidence[None])
    E_mat = kornia.geometry.epipolar.essential_from_fundamental(F_mat, K1, K2)
    R, t, points = kornia.geometry.epipolar.motion_from_essential_choose_solution(
        E_mat, K1, K2, points1, points2, mask=None
    )

    # The essential matrix fixes translation only up to scale; borrow the
    # ground-truth baseline length (this affects placement, not the rotation error).
    gt_trans_norm = torch.linalg.norm(c2w2[0, :3, 3])
    estimated_rel = torch.eye(4)[None].to(device)
    estimated_rel[:, :3, :3] = R
    estimated_rel[:, :3, 3:4] = t * gt_trans_norm
    estimated_rel = torch.inverse(estimated_rel)

    # diag(1, -1, -1, 1) flips the camera y and z axes
    # (presumably converting between OpenCV- and OpenGL-style cameras — TODO confirm).
    flip_yz = torch.diag(torch.tensor([1.0, -1.0, -1.0, 1.0]))[None].to(device)
    prime_0 = flip_yz
    prime_1 = torch.bmm(estimated_rel, prime_0)
    prime_0 = torch.bmm(flip_yz, prime_0)
    prime_1 = torch.bmm(flip_yz, prime_1)

    if visualize:
        draw_camera("/cameras/0_prime", prime_0[0].cpu(), _tensor_to_uint8(image0))
        draw_camera("/cameras/1_prime", prime_1[0].cpu(), _tensor_to_uint8(image1))
        # Triangulated points, with y/z flipped the same way as the cameras.
        points = points @ flip_yz[0, :3, :3].permute(1, 0)
        viser_server.scene.add_point_cloud(
            "/points",
            points=points.view(-1, 3).cpu().numpy(),
            colors=(1.0, 0, 0),
            point_size=0.03,
            point_shape="circle",
        )
        mediapy.show_image(_draw_matches(image0, image1, keypoints0, keypoints1))

    error = compute_angular_error_batch(c2w2[:, :3, :3].cpu().numpy(), prime_1[:, :3, :3].cpu().numpy())
    return float(error[0])


# error = get_metric(video, cameras, pairs[1], mask_video=mask_video, visualize=True)
# print(error)

In [9]:
dataset_name  # display which benchmark this run evaluates

'nerfbusters'

In [None]:
from collections import defaultdict

# Error (degrees) recorded when pose estimation fails for a pair. It is larger than
# any real rotation angle (at most 180), so a failed pair never counts as accurate.
FAILED_PAIR_ERROR = 360.0


def _latest_render(render_dir):
    """Return the last ``.mp4`` in `render_dir` by sorted path; raise if there is none."""
    videos = sorted(Path(render_dir).glob("*.mp4"))
    if not videos:
        raise FileNotFoundError(f"no .mp4 renders found in {render_dir}")
    return videos[-1]


def _pair_errors(video, cameras, pair_list, mask_video=None):
    """Rotation error for each pair in `pair_list`; failed pairs get FAILED_PAIR_ERROR."""
    errors = []
    for pair in pair_list:
        try:
            errors.append(get_metric(video, cameras, pair, mask_video=mask_video, visualize=False))
        except Exception as exc:  # pose estimation can fail on degenerate matches; record and continue
            print(f"  pair {pair} failed ({type(exc).__name__}: {exc}); recording {FAILED_PAIR_ERROR}")
            errors.append(FAILED_PAIR_ERROR)
    return np.array(errors)


metrics = defaultdict(dict)
all_video = []
for dataset in datasets:
    print(dataset)
    videos = []
    for method in methods:
        print(method)
        video_path = _latest_render(folder / dataset / method)
        # Camera metadata lives next to the video under the same stem.
        with open(video_path.with_suffix(".json")) as f:
            cameras = json.load(f)
        video = mediapy.read_video(video_path)
        # Masked matching is disabled. To enable it for nerfiller, load the latest
        # render from the "mask" folder here instead of None.
        mask_video = None
        metrics[dataset][method] = _pair_errors(video, cameras, pairs, mask_video=mask_video)
        if save_images:
            for i, frame in enumerate(video):
                image_path = output / f"{dataset}/{method}/image-{i:06d}.jpg"
                print(image_path)
                image_path.parent.mkdir(parents=True, exist_ok=True)
                mediapy.write_image(image_path, frame)
        if save_video:
            print("todo: save video")
        videos.append(video)
    # Methods side by side along axis 2 (width, assuming (T, H, W, C) videos).
    cat_video = np.concatenate(videos, axis=2)
    all_video.append(cat_video)
# Datasets stacked along axis 1 (height).
cat_all_video = np.concatenate(all_video, axis=1)

aloe
gsplat
nerfiller
cat3d
mvinpaint
art
gsplat
nerfiller
cat3d
mvinpaint
car
gsplat
nerfiller
cat3d
mvinpaint
century
gsplat
nerfiller
cat3d
mvinpaint
flowers
gsplat
nerfiller
cat3d
mvinpaint
garbage
gsplat
nerfiller
cat3d
mvinpaint
picnic
gsplat
nerfiller
cat3d
mvinpaint
pikachu
gsplat
nerfiller
cat3d
mvinpaint
pipe
gsplat
nerfiller
cat3d
mvinpaint
plant
gsplat
nerfiller
cat3d
mvinpaint
roses
gsplat
nerfiller
cat3d
mvinpaint
table
gsplat
nerfiller
cat3d
mvinpaint


In [None]:
# for dataset in datasets:
#     for method in methods:
#         print(dataset, method, metrics[dataset][method].mean())
#     print()

In [None]:
# thresholds = np.arange(10, 30, 1)
# RRA@thr (thr in degrees) = fraction of evaluated pairs whose rotation error is below thr.
thresholds = np.arange(0, 30, 1)
print(f"thresholds: {thresholds}")
rra_dict = defaultdict(dict)
for method in methods:
    for thr in thresholds:
        # Average per-dataset accuracy so every scene carries equal weight.
        per_dataset = [(metrics[dataset][method] < thr).mean() for dataset in datasets]
        rra_dict[method][thr] = float(np.mean(per_dataset))

In [None]:
# rra_dict

In [None]:
import plotly.graph_objects as go

fig = go.Figure()

# Display names for known methods, in legend order. Any other method present in
# rra_dict is plotted under its folder name; colors cycle if there are more traces.
method_labels = {"mvinpaint": "Fillerbuster", "cat3d": "CAT3D-Imp.", "gsplat": "GSplat", "nerfiller": "NeRFiller"}
colors = ["#3498db", "#f1c40f", "#2ecc71", "#e74c3c"]
# `in` does not insert into the defaultdict, so absent methods are simply skipped.
plot_methods = [m for m in method_labels if m in rra_dict] + [m for m in rra_dict if m not in method_labels]
for idx, method in enumerate(plot_methods):
    x = sorted(rra_dict[method])
    y = [rra_dict[method][key] for key in x]
    fig.add_trace(
        go.Scatter(
            x=x,
            y=y,
            mode="lines",
            name=method_labels.get(method, method),
            line=dict(width=6, color=colors[idx % len(colors)]),
        )
    )

fig.update_layout(
    title=dict(text="Relative Rotation Accuracy", x=0.5, y=0.92, xanchor="center"),
    yaxis_title="RRA @ Degrees",
    paper_bgcolor="white",
    plot_bgcolor="white",
    font=dict(size=16, family="Helvetica"),
    legend=dict(orientation="v", yanchor="top", xanchor="right", x=1.2, y=1.0),
    xaxis=dict(showgrid=True, gridwidth=1, gridcolor="black"),
    yaxis=dict(
        showgrid=True,
        gridwidth=1,
        gridcolor="black",
        range=[0, 1.01],  # small headroom above 1.0
    ),
    margin=dict(l=0, r=0, t=50, b=0),
)

fig.show()
fig.write_image("rra_plot.pdf", width=800, height=300)

In [None]:
# mediapy.show_video(cat_all_video, fps=len(cat_video)/seconds)