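# NeRFiller metrics

Scores each method's individual inpainted views against renders from its `*dataset-renders` directory (`train/rgb`) using PSNR, SSIM and LPIPS. Images are resized to 512×512 before scoring, and the results are then averaged per method over the nine scenes. Note: as written, only the first paired view per scene/method is scored.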

In [1]:
import glob
import os
from pathlib import Path

import mediapy
import numpy as np
import torch

# Method variants to evaluate. Each name is used as a path component both under the
# nerfstudio outputs root ({method}/{dataset}) and under the renders root ({dataset}/{method}).
methods = [
    "mvinpaint-no-new-views",
    "nerfiller-no-new-views",
    "mvinpaint-no-new-views-no-normals",
    "nerfiller-no-new-views-no-normals",
]
# Scenes; every method is evaluated on each of them.
datasets = ["bear", "billiards", "boot", "cat", "drawing", "dumptruck", "norway", "office", "turtle"]
# Prefer the first GPU, but fall back to CPU so the notebook still runs (slowly)
# on a machine without CUDA instead of failing when metrics are moved to the device.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [2]:
# Image-quality metrics used for scoring; all are configured for float images in [0, 1].
from torchmetrics.image import PeakSignalNoiseRatio
from pytorch_msssim import SSIM
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

psnr = PeakSignalNoiseRatio(data_range=1.0)
# Three-channel SSIM, averaged over the batch into a single scalar.
ssim = SSIM(data_range=1.0, size_average=True, channel=3)
# normalize=True: inputs are in [0, 1] rather than LPIPS's default [-1, 1].
lpips = LearnedPerceptualImagePatchSimilarity(normalize=True)

  self.load_state_dict(torch.load(model_path, map_location="cpu"), strict=False)


In [3]:
# Put every metric module on the evaluation device before scoring.
psnr, ssim, lpips = [metric_module.to(device) for metric_module in (psnr, ssim, lpips)]

In [4]:
from collections import defaultdict

# Results table indexed as metrics[dataset][method] -> {"psnr": ..., "ssim": ..., "lpips": ...}.
# The per-dataset dict is created on first access, so results can be stored with
# a single `metrics[dataset][method] = ...` assignment.
metrics: "defaultdict[str, dict]" = defaultdict(dict)

In [9]:
# Score every (dataset, method) pair: compare the method's individual inpaints with
# an evenly spaced subset of its dataset renders and store mean PSNR / SSIM / LPIPS
# in metrics[dataset][method].
import warnings

# Machine-specific input roots; edit to match your setup.
OUTPUTS_ROOT = Path("/mnt/home/ethanjohnweber/nerfstudio-outputs")
RENDERS_ROOT = Path("/mnt/home/ethanjohnweber/data/nerfiller-renders")
# Number of evenly spaced renders paired (in sorted order) with the inpainted views.
NUM_CONDITIONING_VIEWS = 32
# Both images of a pair are resized to this size before scoring.
EVAL_SIZE = (512, 512)
# Positions in the paired lists that are actually scored. (0,) reproduces the reported
# numbers (first pair only); set to None to score every pair.
# TODO(review): confirm that scoring a single view per scene is intended.
EVAL_VIEW_INDICES = (0,)


def _latest(paths, description):
    """Return the lexicographically last entry of `paths` as a Path.

    Raises FileNotFoundError mentioning `description` when `paths` is empty,
    instead of the bare IndexError that `sorted(...)[-1]` would raise.
    """
    candidates = sorted(paths)
    if not candidates:
        raise FileNotFoundError(f"no {description} found")
    return Path(candidates[-1])


def _paired_filenames(method, dataset):
    """Return (inpaint_paths, render_paths) for one method/dataset, to be paired by position."""
    run_dir = _latest(
        (OUTPUTS_ROOT / method / dataset / "mvinpaint-splatfacto").iterdir(),
        f"runs under {OUTPUTS_ROOT / method / dataset / 'mvinpaint-splatfacto'}",
    )
    inpaint_paths = sorted(glob.glob(str(run_dir / "individual-inpaints" / "*")))

    render_dir = _latest(
        (RENDERS_ROOT / dataset / method).glob("*dataset-renders"),
        f"*dataset-renders directories under {RENDERS_ROOT / dataset / method}",
    )
    all_renders = sorted(glob.glob(str(render_dir / "train" / "rgb" / "*")))
    if not all_renders:
        # Without this, linspace(0, -1, ...) would index into an empty list.
        raise FileNotFoundError(f"no renders found in {render_dir / 'train' / 'rgb'}")
    conditioning_indices = torch.linspace(0, len(all_renders) - 1, NUM_CONDITIONING_VIEWS).long()
    render_paths = [all_renders[i] for i in conditioning_indices.tolist()]
    return inpaint_paths, render_paths


def _load_image(path):
    """Read an image, resize it to EVAL_SIZE and return a (1, 3, H, W) float tensor on `device`.

    Assumes 8-bit images (values are divided by 255 to map into [0, 1]).
    """
    image = mediapy.resize_image(mediapy.read_image(path), EVAL_SIZE)
    if image.ndim == 3 and image.shape[-1] == 4:
        # Drop alpha: SSIM is built with channel=3 and LPIPS expects RGB input.
        image = image[..., :3]
    return (torch.from_numpy(image) / 255.0)[None].permute(0, 3, 1, 2).to(device)


for dataset in datasets:
    print(dataset)
    for method in methods:
        print(method)
        inpaint_paths, render_paths = _paired_filenames(method, dataset)
        if len(inpaint_paths) != len(render_paths):
            # zip() below silently drops the unmatched tail; make the mismatch visible.
            warnings.warn(
                f"{dataset}/{method}: {len(inpaint_paths)} inpaints vs "
                f"{len(render_paths)} renders; extra files are ignored"
            )

        psnrs, ssims, lpipss = [], [], []
        with torch.no_grad():
            for idx, (inpaint_path, render_path) in enumerate(zip(inpaint_paths, render_paths)):
                if EVAL_VIEW_INDICES is not None and idx not in EVAL_VIEW_INDICES:
                    continue
                inpaint = _load_image(inpaint_path)
                render = _load_image(render_path)
                psnrs.append(psnr(inpaint, render).item())
                ssims.append(ssim(inpaint, render).item())
                lpipss.append(lpips(inpaint, render).item())

        if not psnrs:
            warnings.warn(f"{dataset}/{method}: no image pairs were scored; metrics will be NaN")
        metrics[dataset][method] = {"psnr": np.mean(psnrs), "ssim": np.mean(ssims), "lpips": np.mean(lpipss)}

bear
mvinpaint-no-new-views
nerfiller-no-new-views
mvinpaint-no-new-views-no-normals
nerfiller-no-new-views-no-normals
billiards
mvinpaint-no-new-views
nerfiller-no-new-views
mvinpaint-no-new-views-no-normals
nerfiller-no-new-views-no-normals
boot
mvinpaint-no-new-views
nerfiller-no-new-views
mvinpaint-no-new-views-no-normals
nerfiller-no-new-views-no-normals
cat
mvinpaint-no-new-views
nerfiller-no-new-views
mvinpaint-no-new-views-no-normals
nerfiller-no-new-views-no-normals
drawing
mvinpaint-no-new-views
nerfiller-no-new-views
mvinpaint-no-new-views-no-normals
nerfiller-no-new-views-no-normals
dumptruck
mvinpaint-no-new-views
nerfiller-no-new-views
mvinpaint-no-new-views-no-normals
nerfiller-no-new-views-no-normals
norway
mvinpaint-no-new-views
nerfiller-no-new-views
mvinpaint-no-new-views-no-normals
nerfiller-no-new-views-no-normals
office
mvinpaint-no-new-views
nerfiller-no-new-views
mvinpaint-no-new-views-no-normals
nerfiller-no-new-views-no-normals
turtle
mvinpaint-no-new-views
ne

In [10]:
# Per-method summary: mean PSNR, SSIM and LPIPS over all datasets, one printed line each.
num_scenes = len(datasets)
for method in methods:
    per_scene = [metrics[scene][method] for scene in datasets]
    # sum(..., 0) adds left to right from 0, the same order as a running total.
    mean_psnr = sum((entry["psnr"] for entry in per_scene), 0) / num_scenes
    mean_ssim = sum((entry["ssim"] for entry in per_scene), 0) / num_scenes
    mean_lpips = sum((entry["lpips"] for entry in per_scene), 0) / num_scenes
    print(method, mean_psnr, mean_ssim, mean_lpips)

mvinpaint-no-new-views 29.656557506985134 0.9073409570588006 0.10202627463473214
nerfiller-no-new-views 25.99903678894043 0.8663748635186089 0.20223851419157451
mvinpaint-no-new-views-no-normals 31.161026424831814 0.9369973672760857 0.06258510260118379
nerfiller-no-new-views-no-normals 26.391361448499893 0.8793721596399943 0.19031956626309288
