In [2]:
# Note: hack that makes the parent directory and the bundled
# resnet-18-autoencoder sources importable from this notebook.
import sys
from pathlib import Path

_PARENT_DIR = Path().resolve().parent
for _import_root in (
    _PARENT_DIR,
    _PARENT_DIR / "submodules/resnet-18-autoencoder/src",
):
    sys.path.append(str(_import_root))

# Note: every warning is silenced for the rest of the session.
import warnings

warnings.filterwarnings(action="ignore")

In [3]:
import torch
import matplotlib.pyplot as plt

from data import CIFAR10GaussianSplatsDataset
from utils import (
    noop_collate,
    transform_autoencoder_input,
    transform_autoencoder_output,
)

# Shared matplotlib style sheet for the figures of this notebook.
plt.style.use("../style/main.mpltstyle")

# Use CUDA when PyTorch reports it as available, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

In [4]:
def collect_samples_by_label(
    data, n_samples, test_loader, class_to_index, index_to_class
):
    """Collect up to ``n_samples`` ``(image, splat)`` pairs for every class.

    Args:
        data: Unused. Kept only so existing call sites keep working (the
            previous implementation overwrote it before reading it).
        n_samples: Maximum number of samples kept per class label.
        test_loader: Iterable of batches; every batch is an iterable of
            ``(image, index, splat)`` triples.
        class_to_index: Mapping whose keys are the class labels to collect.
        index_to_class: Mapping from a sample's ``index`` to its class label.

    Returns:
        A dict mapping every key of ``class_to_index`` to a list of
        ``(image, splat)`` tuples in loader order, at most ``n_samples`` long.
        A class that never shows up in the loader keeps an empty list.

    Raises:
        KeyError: If a sample's ``index`` is missing from ``index_to_class``
            or maps to a label that is not a key of ``class_to_index``.
    """
    samples_by_label = {label: [] for label in class_to_index}

    def quota_met():
        # True once every class holds at least ``n_samples`` entries.
        return all(len(kept) >= n_samples for kept in samples_by_label.values())

    # Nothing to collect (no classes, or a non-positive quota): leave the
    # loader untouched instead of pulling a batch just to discover that.
    if quota_met():
        return samples_by_label

    for batch in test_loader:
        for image, index, splat in batch:
            kept = samples_by_label[index_to_class[index]]
            if len(kept) < n_samples:
                kept.append((image, splat))
                # The quota state only changes on an append, so this is the
                # only place it has to be re-checked. Returning right here
                # also avoids fetching one more batch from the loader.
                if quota_met():
                    return samples_by_label

    return samples_by_label


def custom_forward(
    means_model, quats_model, scales_model, opacities_model, colors_model, splat
):
    """Reconstruct one splat by running each parameter through its own model.

    The splat is converted with ``transform_autoencoder_input(splat, "dict")``.
    Each of the entries "means", "quats", "scales", "opacities" and "colors"
    is given a leading axis of size one, passed through the matching model,
    and has that axis removed again. The five outputs are handed to
    ``transform_autoencoder_output(..., "dict")``.

    The whole pass runs under ``torch.no_grad()``: this function is only used
    for inspection, so no autograd graph is needed and the returned values do
    not keep one alive.

    Args:
        means_model: Model applied to the "means" entry.
        quats_model: Model applied to the "quats" entry.
        scales_model: Model applied to the "scales" entry.
        opacities_model: Model applied to the "opacities" entry.
        colors_model: Model applied to the "colors" entry.
        splat: Object exposing ``.to(device)`` that
            ``transform_autoencoder_input`` accepts.

    Returns:
        Whatever ``transform_autoencoder_output`` returns for the five
        reconstructed entries.

    Note:
        The models are moved to ``DEVICE`` but their train/eval mode is left
        as the caller set it.
        NOTE(review): if they contain dropout or batch-norm layers, confirm
        they were put in eval mode before calling this.
    """
    models = {
        "means": means_model.to(DEVICE),
        "quats": quats_model.to(DEVICE),
        "scales": scales_model.to(DEVICE),
        "opacities": opacities_model.to(DEVICE),
        "colors": colors_model.to(DEVICE),
    }
    splat = splat.to(DEVICE)

    with torch.no_grad():
        x = transform_autoencoder_input(splat, "dict")
        # unsqueeze(0) adds the leading axis before the model call and
        # squeeze(0) removes it from the model output again.
        y = {
            name: model(x[name].unsqueeze(0)).squeeze(0)
            for name, model in models.items()
        }
        return transform_autoencoder_output(y, "dict")

In [5]:
N_CLASSES = 10
N_SAMPLES = 1
MODEL_PATH = "../models/final_models/conv_method_3"


def _load_on_device(checkpoint_path):
    """Load one saved model with ``torch.load``, mapped onto ``DEVICE``."""
    return torch.load(checkpoint_path, map_location=DEVICE)


# One saved model per splat parameter.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
CUSTOM_MEANS_MODEL = _load_on_device(f"{MODEL_PATH}/means_model.pt")
CUSTOM_QUATS_MODEL = _load_on_device(f"{MODEL_PATH}/quats_model.pt")
CUSTOM_SCALES_MODEL = _load_on_device(f"{MODEL_PATH}/scales_model.pt")
CUSTOM_OPACITIES_MODEL = _load_on_device(f"{MODEL_PATH}/opacities_model.pt")
CUSTOM_COLORS_MODEL = _load_on_device(f"{MODEL_PATH}/colors_model.pt")

# Test split, iterated one sample at a time in dataset order.
test_dataset = CIFAR10GaussianSplatsDataset(
    root="../data/CIFAR10GS", test=True, init_type="grid"
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=1, shuffle=False, collate_fn=noop_collate
)

CLASS_TO_INDEX = test_dataset.class_to_index
INDEX_TO_CLASS = dict((index, name) for name, index in CLASS_TO_INDEX.items())

# The first argument is not read by collect_samples_by_label; the samples
# come from test_loader.
DATA = collect_samples_by_label(
    test_dataset, N_SAMPLES, test_loader, CLASS_TO_INDEX, INDEX_TO_CLASS
)
CUSTOM_RESULTS = {name: [] for name in CLASS_TO_INDEX}

In [6]:
# For every collected sample, run the five per-parameter models and print the
# value range of each splat parameter as "[input range] -> [output range]".
for label, samples in DATA.items():
    print(f"Label: {label}")
    for _image, splat in samples:
        result = custom_forward(
            means_model=CUSTOM_MEANS_MODEL,
            quats_model=CUSTOM_QUATS_MODEL,
            scales_model=CUSTOM_SCALES_MODEL,
            opacities_model=CUSTOM_OPACITIES_MODEL,
            colors_model=CUSTOM_COLORS_MODEL,
            splat=splat,
        )
        for param, _value in result.items():
            # "Ks" and "viewmats" are left out of the report.
            if param not in ("Ks", "viewmats"):
                print(
                    f"{param}: [{splat[param].min()}, {splat[param].max()}] -> [{result[param].min()}, {result[param].max()}]"
                )
    print("===")


Label: airplane
colors: [-9.533025741577148, 12.163064002990723] -> [-31.062448501586914, 26.774341583251953]
means: [-1.0, 1.0] -> [-0.999977707862854, 0.9999849796295166]
opacities: [-4.73273229598999, 5.293581008911133] -> [-9.395462989807129, 2.7971932888031006]
quats: [-1.6623637676239014, 2.6217942237854004] -> [-0.4206662178039551, 1.6715068817138672]
scales: [-10.672653198242188, 3.9764418601989746] -> [-22.754470825195312, 8.203146934509277]
===
Label: automobile
colors: [-10.232390403747559, 11.857702255249023] -> [-30.712785720825195, 26.374431610107422]
means: [-1.0, 1.0] -> [-0.999977707862854, 0.9999849796295166]
opacities: [-4.8982954025268555, 5.20357608795166] -> [-9.535422325134277, 3.6803319454193115]
quats: [-1.4995923042297363, 2.910125494003296] -> [-0.433808296918869, 1.6882603168487549]
scales: [-11.359330177307129, 4.201854705810547] -> [-22.670961380004883, 8.19424057006836]
===
Label: bird
colors: [-7.313357830047607, 8.84171199798584] -> [-30.999523162841797

In [7]:
# Train and validation splits are concatenated and iterated one sample at a
# time, in order, by the cells that compute dataset-wide statistics.
train_dataset = CIFAR10GaussianSplatsDataset(
    root="../data/CIFAR10GS", train=True, init_type="grid"
)
val_dataset = CIFAR10GaussianSplatsDataset(
    root="../data/CIFAR10GS", val=True, init_type="grid"
)
dataset = torch.utils.data.ConcatDataset([train_dataset, val_dataset])

_loader_options = {
    "batch_size": 1,
    "shuffle": False,
    "collate_fn": noop_collate,
}
dataloader = torch.utils.data.DataLoader(dataset, **_loader_options)


# Per-parameter accumulators. "min"/"max" receive one scalar per sample;
# "mean"/"std" receive the raw per-sample tensors so the statistics can be
# computed over the whole dataset after the loop.
TEMP_SPLAT_RANGES = {
    param: {"min": [], "max": [], "mean": [], "std": []}
    for param in ("means", "quats", "scales", "opacities", "colors")
}

# Loop through all samples and record the extrema and raw values of every
# tracked splat parameter.
for batch in dataloader:
    for _image, _index, splat in batch:
        for param, value in splat.items():
            # Only parameters with an accumulator are recorded. This skips
            # "Ks" and "viewmats", and any other untracked key, instead of
            # raising KeyError on it.
            if param not in TEMP_SPLAT_RANGES:
                continue
            collected = TEMP_SPLAT_RANGES[param]
            collected["min"].append(value.min())
            collected["max"].append(value.max())
            collected["mean"].append(value)
            collected["std"].append(value)

# Reduce the per-sample records to one summary per splat parameter:
#   min / max  - global extrema over all per-sample extrema,
#   mean / std - taken over the first two axes of the stacked per-sample
#                tensors (the stacking axis and the one that follows it).
SPLAT_RANGES = {
    param: {
        "min": min(collected["min"]).item(),
        "max": max(collected["max"]).item(),
        "mean": torch.stack(collected["mean"]).mean(dim=(0, 1)).tolist(),
        "std": torch.stack(collected["std"]).std(dim=(0, 1)).tolist(),
    }
    for param, collected in TEMP_SPLAT_RANGES.items()
}
SPLAT_RANGES

{'means': {'min': -1.0,
  'max': 1.0,
  'mean': [3.073364496231079e-08, 0.0, 0.0],
  'std': [0.5956833958625793, 0.5956833958625793, 0.0]},
 'quats': {'min': -3.7537035942077637,
  'max': 4.574342727661133,
  'mean': [0.49086257815361023,
   0.5112597346305847,
   0.49628859758377075,
   0.48551133275032043],
  'std': [0.4739896357059479,
   0.5310168266296387,
   0.5126069188117981,
   0.4742636978626251]},
 'scales': {'min': -14.256706237792969,
  'max': 6.657063961029053,
  'mean': [-1.8529794216156006, -1.8229098320007324, -2.4987568855285645],
  'std': [1.4585336446762085, 1.509230375289917, 0.5868567228317261]},
 'opacities': {'min': -5.512201309204102,
  'max': 7.002721309661865,
  'mean': -3.367799997329712,
  'std': 1.417393684387207},
 'colors': {'min': -15.537788391113281,
  'max': 17.288856506347656,
  'mean': [[0.6861101388931274, 0.6411524415016174, 0.6533592343330383],
   [-0.034583691507577896, -0.02748388983309269, -0.015395736321806908],
   [0.7585484981536865, 0.7298

In [14]:
# Per-parameter accumulators for the extrema of the standardized values.
TEMP_SPLAT_RANGES = {
    param: {"min_std": [], "max_std": []}
    for param in ("means", "quats", "scales", "opacities", "colors")
}

# The normalisation constants do not depend on the sample, so they are built
# once per parameter here rather than once per sample inside the loop.
_NORMALIZERS = {}
for param in TEMP_SPLAT_RANGES:
    mean = torch.tensor(SPLAT_RANGES[param]["mean"])
    std = torch.tensor(SPLAT_RANGES[param]["std"])
    # Components with zero spread would divide by zero; scale them by 1.
    std[std == 0] = 1
    _NORMALIZERS[param] = (mean, std)

# Loop through all samples and record the min and max of every tracked splat
# parameter after standardizing it with the dataset mean and std.
for batch in dataloader:
    for _image, _index, splat in batch:
        for param, value in splat.items():
            # Skips "Ks", "viewmats" and any other key without statistics.
            if param not in _NORMALIZERS:
                continue
            mean, std = _NORMALIZERS[param]
            standardized = (value - mean) / std
            TEMP_SPLAT_RANGES[param]["min_std"].append(standardized.min())
            TEMP_SPLAT_RANGES[param]["max_std"].append(standardized.max())

# Global extrema of the standardized values, one pair per splat parameter.
STD_SPLAT_RANGES = {
    param: {
        "min_std": min(collected["min_std"]).item(),
        "max_std": max(collected["max_std"]).item(),
    }
    for param, collected in TEMP_SPLAT_RANGES.items()
}
STD_SPLAT_RANGES

{'means': {'min_std': -1.6787440776824951, 'max_std': 1.6787440776824951},
 'quats': {'min_std': -8.814098358154297, 'max_std': 8.55053424835205},
 'scales': {'min_std': -8.383855819702148, 'max_std': 5.618740558624268},
 'opacities': {'min_std': -1.5129185914993286, 'max_std': 7.31661319732666},
 'colors': {'min_std': -8.328888893127441, 'max_std': 7.9224629402160645}}

In [None]:
# Store the "mean" and "std" statistics as float32 tensors.
# torch.as_tensor returns an existing float32 tensor as-is, so re-running this
# cell is a no-op instead of copy-constructing tensors from tensors, which
# torch.tensor warns against.
for param_stats in SPLAT_RANGES.values():
    for stat in ("mean", "std"):
        param_stats[stat] = torch.as_tensor(param_stats[stat], dtype=torch.float32)
SPLAT_RANGES

{'means': {'min': -1.0,
  'max': 1.0,
  'mean': tensor([3.0734e-08, 0.0000e+00, 0.0000e+00]),
  'std': tensor([0.5957, 0.5957, 0.0000])},
 'quats': {'min': -3.7537035942077637,
  'max': 4.574342727661133,
  'mean': tensor([0.4909, 0.5113, 0.4963, 0.4855]),
  'std': tensor([0.4740, 0.5310, 0.5126, 0.4743])},
 'scales': {'min': -14.256706237792969,
  'max': 6.657063961029053,
  'mean': tensor([-1.8530, -1.8229, -2.4988]),
  'std': tensor([1.4585, 1.5092, 0.5869])},
 'opacities': {'min': -5.512201309204102,
  'max': 7.002721309661865,
  'mean': tensor(-3.3678),
  'std': tensor(1.4174)},
 'colors': {'min': -15.537788391113281,
  'max': 17.288856506347656,
  'mean': tensor([[ 0.6861,  0.6412,  0.6534],
          [-0.0346, -0.0275, -0.0154],
          [ 0.7585,  0.7298,  0.6635],
          [-0.0100, -0.0208, -0.0225]]),
  'std': tensor([[2.4589, 2.3819, 2.4028],
          [1.9165, 1.8434, 1.8344],
          [2.1524, 2.0766, 2.0886],
          [1.9149, 1.8357, 1.8274]])}}

In [None]:
# Prints, per class label and sample, the value range of each splat parameter
# before and after the per-parameter models.
# NOTE(review): the original comment read "How it was ... Check output!" -
# presumably the saved output of this cell comes from an earlier model state;
# confirm before re-running and overwriting it.
_CUSTOM_MODELS = (
    CUSTOM_MEANS_MODEL,
    CUSTOM_QUATS_MODEL,
    CUSTOM_SCALES_MODEL,
    CUSTOM_OPACITIES_MODEL,
    CUSTOM_COLORS_MODEL,
)
_UNREPORTED = ("Ks", "viewmats")

for label, samples in DATA.items():
    print(f"Label: {label}")
    for _image, splat in samples:
        result = custom_forward(*_CUSTOM_MODELS, splat)
        for param, _value in result.items():
            if param in _UNREPORTED:
                continue
            print(
                f"{param}: [{splat[param].min()}, {splat[param].max()}] -> [{result[param].min()}, {result[param].max()}]"
            )
    print("===")

Label: airplane
colors: [-9.533025741577148, 12.163064002990723] -> [-0.9641077518463135, 0.8002510666847229]
means: [-1.0, 1.0] -> [-0.9979274272918701, 0.9987775087356567]
opacities: [-4.73273229598999, 5.293581008911133] -> [-0.9243193864822388, 0.835913360118866]
quats: [-1.6623637676239014, 2.6217942237854004] -> [-0.5537052750587463, 0.5027226805686951]
scales: [-10.672653198242188, 3.9764418601989746] -> [-0.9996795654296875, 1.0]
===
Label: automobile
colors: [-10.232390403747559, 11.857702255249023] -> [-0.9708322286605835, 0.8739020228385925]
means: [-1.0, 1.0] -> [-0.9979274272918701, 0.9987775087356567]
opacities: [-4.8982954025268555, 5.20357608795166] -> [-0.9029794335365295, 0.8699905872344971]
quats: [-1.4995923042297363, 2.910125494003296] -> [-0.6938697099685669, 0.5712622404098511]
scales: [-11.359330177307129, 4.201854705810547] -> [-0.9996407628059387, 1.0]
===
Label: bird
colors: [-7.313357830047607, 8.84171199798584] -> [-0.9591242671012878, 0.816911518573761]
me