In [1]:
# Note: hack so that modules located in the parent of the current working
# directory (and in the resnet-18-autoencoder submodule below it) can be
# imported from this notebook.
import sys
import warnings
from pathlib import Path

_parent_dir = Path().resolve().parent
sys.path.extend(
    [
        str(_parent_dir),
        str(_parent_dir / "submodules/resnet-18-autoencoder/src"),
    ]
)

# Note: every warning is silenced from here on, so deprecation and runtime
# warnings raised by later cells will not be shown.
warnings.filterwarnings("ignore")

In [2]:
import torch
import matplotlib.pyplot as plt

from data import CIFAR10GaussianSplatsDataset
from utils import noop_collate, transform_autoencoder_input, transform_autoencoder_output

# Apply the project's matplotlib style sheet to every figure drawn below.
plt.style.use("../style/main.mpltstyle")

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
def collect_samples_by_label(data, n_samples, test_loader, class_to_index, index_to_class):
    """Collect up to ``n_samples`` ``(image, splat)`` pairs for every class.

    Iteration over ``test_loader`` stops as soon as every class has
    ``n_samples`` entries, so no batch is fetched beyond the one that
    completes the last class.

    Args:
        data: Unused. Kept only so existing call sites keep working; the
            result is always built from scratch.
        n_samples: Maximum number of samples to keep per class. When it is
            zero or negative the loader is not touched at all.
        test_loader: Iterable of batches, each batch being an iterable of
            ``(image, index, splat)`` triples.
        class_to_index: Mapping whose keys are the class labels to collect.
        index_to_class: Mapping from the ``index`` found in a sample to its
            class label.

    Returns:
        dict: ``{label: [(image, splat), ...]}`` with one key per key of
        ``class_to_index``, in the same order. Lists are shorter than
        ``n_samples`` if the loader runs out first.

    Raises:
        KeyError: If a sample's index is missing from ``index_to_class`` or
            its label is not a key of ``class_to_index``.
    """
    # Do not reuse the `data` argument name: the result is a fresh dict.
    samples_by_label = {label: [] for label in class_to_index}

    # Number of classes that still need samples. Tracking a counter avoids
    # rescanning every class after each sample.
    n_unfilled = len(samples_by_label) if n_samples > 0 else 0
    if n_unfilled == 0:
        return samples_by_label

    for batch in test_loader:
        for image, index, splat in batch:
            bucket = samples_by_label[index_to_class[index]]
            if len(bucket) >= n_samples:
                continue

            bucket.append((image, splat))
            if len(bucket) >= n_samples:
                n_unfilled -= 1
                if n_unfilled == 0:
                    # Return immediately so the loader is not asked for
                    # another (possibly expensive) batch.
                    return samples_by_label

    return samples_by_label


def custom_forward(means_model, quats_model, scales_model, opacities_model, colors_model, splat):
    """Pass each splat parameter through its own model and reassemble the splat.

    Every model and the splat are moved to ``DEVICE``. The splat is converted
    with ``transform_autoencoder_input(..., "dict")``, each of the five
    entries ("means", "quats", "scales", "opacities", "colors") gets a leading
    dimension added, is fed to its model, and has that leading dimension
    removed again. The five outputs are then handed to
    ``transform_autoencoder_output(..., "dict")``.

    Args:
        means_model: Model applied to the "means" entry.
        quats_model: Model applied to the "quats" entry.
        scales_model: Model applied to the "scales" entry.
        opacities_model: Model applied to the "opacities" entry.
        colors_model: Model applied to the "colors" entry.
        splat: Object exposing ``.to(device)`` that
            ``transform_autoencoder_input`` accepts (presumably a dict-like
            of tensors; confirm against the dataset).

    Returns:
        Whatever ``transform_autoencoder_output`` returns for the dict of
        model outputs.
    """
    # Move all models first, in argument order, before touching the splat.
    models = {
        "means": means_model.to(DEVICE),
        "quats": quats_model.to(DEVICE),
        "scales": scales_model.to(DEVICE),
        "opacities": opacities_model.to(DEVICE),
        "colors": colors_model.to(DEVICE),
    }

    encoded = transform_autoencoder_input(splat.to(DEVICE), "dict")

    # Add the leading dimension to every entry before any model is run.
    batched = {name: encoded[name].unsqueeze(0) for name in models}

    # Forward passes run in the same order as the dict above.
    decoded = {name: model(batched[name]).squeeze(0) for name, model in models.items()}

    return transform_autoencoder_output(decoded, "dict")

In [4]:
# --- Configuration ---
N_CLASSES = 10
N_SAMPLES = 1  # number of test samples collected per class
MODEL_PATH = "../models/final_models/conv_method_3"

# --- Test split, served one sample per batch in dataset order ---
test_dataset = CIFAR10GaussianSplatsDataset(root="../data/CIFAR10GS", test=True, init_type="grid")
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=1, shuffle=False, collate_fn=noop_collate
)

# --- Class name <-> index lookup tables ---
CLASS_TO_INDEX = test_dataset.class_to_index
INDEX_TO_CLASS = {index: name for name, index in CLASS_TO_INDEX.items()}

# --- One saved model per splat parameter ---
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
CUSTOM_MEANS_MODEL = torch.load(f"{MODEL_PATH}/means_model.pt", map_location=DEVICE)
CUSTOM_QUATS_MODEL = torch.load(f"{MODEL_PATH}/quats_model.pt", map_location=DEVICE)
CUSTOM_SCALES_MODEL = torch.load(f"{MODEL_PATH}/scales_model.pt", map_location=DEVICE)
CUSTOM_OPACITIES_MODEL = torch.load(f"{MODEL_PATH}/opacities_model.pt", map_location=DEVICE)
CUSTOM_COLORS_MODEL = torch.load(f"{MODEL_PATH}/colors_model.pt", map_location=DEVICE)

# --- Samples per class, plus an empty per-class results container ---
DATA = collect_samples_by_label(
    test_dataset,
    N_SAMPLES,
    test_loader,
    CLASS_TO_INDEX,
    INDEX_TO_CLASS,
)
CUSTOM_RESULTS = {name: [] for name in CLASS_TO_INDEX.keys()}

In [5]:
# For every collected sample, print the [min, max] of each splat parameter
# before and after the pass through the per-parameter models.
for label, samples in DATA.items():
    print(f"Label: {label}")
    for image, splat in samples:
        result = custom_forward(
            CUSTOM_MEANS_MODEL,
            CUSTOM_QUATS_MODEL,
            CUSTOM_SCALES_MODEL,
            CUSTOM_OPACITIES_MODEL,
            CUSTOM_COLORS_MODEL,
            splat,
        )
        for param, value in result.items():
            # The "Ks" and "viewmats" entries are left out of the report.
            if param not in ("Ks", "viewmats"):
                print(f"{param}: [{splat[param].min()}, {splat[param].max()}] -> [{result[param].min()}, {result[param].max()}]")
    print("===")

Label: airplane
colors: [-9.533025741577148, 12.163064002990723] -> [-0.9641077518463135, 0.8002510666847229]
means: [-1.0, 1.0] -> [-0.9979274272918701, 0.9987775087356567]
opacities: [-4.73273229598999, 5.293581008911133] -> [-0.9243193864822388, 0.835913360118866]
quats: [-1.6623637676239014, 2.6217942237854004] -> [-0.5537052750587463, 0.5027226805686951]
scales: [-10.672653198242188, 3.9764418601989746] -> [-0.9996795654296875, 1.0]
===
Label: automobile
colors: [-10.232390403747559, 11.857702255249023] -> [-0.9708322286605835, 0.8739020228385925]
means: [-1.0, 1.0] -> [-0.9979274272918701, 0.9987775087356567]
opacities: [-4.8982954025268555, 5.20357608795166] -> [-0.9029794335365295, 0.8699905872344971]
quats: [-1.4995923042297363, 2.910125494003296] -> [-0.6938697099685669, 0.5712622404098511]
scales: [-11.359330177307129, 4.201854705810547] -> [-0.9996407628059387, 1.0]
===
Label: bird
colors: [-7.313357830047607, 8.84171199798584] -> [-0.9591242671012878, 0.816911518573761]
me

In [6]:
# Train and validation splits, concatenated and served one sample per batch.
train_dataset = CIFAR10GaussianSplatsDataset(root="../data/CIFAR10GS", train=True, init_type="grid")
val_dataset = CIFAR10GaussianSplatsDataset(root="../data/CIFAR10GS", val=True, init_type="grid")
dataset = torch.utils.data.ConcatDataset([train_dataset, val_dataset])
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=1, shuffle=False, collate_fn=noop_collate
)

# Per-parameter lists that receive each sample's minimum and maximum.
TEMP_SPLAT_RANGES = {
    name: {"min": [], "max": []}
    for name in ("means", "quats", "scales", "opacities", "colors")
}

# Walk over every sample and record the extremes of each splat parameter;
# the "Ks" and "viewmats" entries are not tracked.
for batch in dataloader:
    for image, index, splat in batch:
        label = INDEX_TO_CLASS[index]  # looked up but not used below
        for param, value in splat.items():
            if param not in ("Ks", "viewmats"):
                TEMP_SPLAT_RANGES[param]["min"].append(value.min())
                TEMP_SPLAT_RANGES[param]["max"].append(value.max())

# Reduce the per-sample extremes to one global [min, max] per parameter.
SPLAT_RANGES = {
    name: {
        "min": min(collected["min"]).item(),
        "max": max(collected["max"]).item(),
    }
    for name, collected in TEMP_SPLAT_RANGES.items()
}
SPLAT_RANGES

{'means': {'min': -1.0, 'max': 1.0},
 'quats': {'min': -3.7537035942077637, 'max': 4.574342727661133},
 'scales': {'min': -14.256706237792969, 'max': 6.657063961029053},
 'opacities': {'min': -5.512201309204102, 'max': 7.002721309661865},
 'colors': {'min': -15.537788391113281, 'max': 17.288856506347656}}