In [2]:
# Note: This is a hack to allow importing from the parent directory
import sys
from pathlib import Path

sys.path.append(str(Path().resolve().parent))
sys.path.append(str(Path().resolve().parent / "submodules/resnet-18-autoencoder/src"))

# Note: Ignore warnings, be brave (YoLo)
import warnings

warnings.filterwarnings("ignore")

In [3]:
import torch
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image
from models import ResNetAutoencoder, ConvAutoencoder
from gsplat import rasterization_2dgs
from data import CIFAR10GaussianSplatsDataset
from utils import (
    noop_collate,
    transform_autoencoder_input,
    transform_autoencoder_output,
)
from constants import (
    CIFAR10_TRANSFORM,
    CIFAR10_INVERSE_TRANSFORM,
    TENSOR_TRANSFORM,
    PIL_TRANSFORM,
)
from classes.resnet_autoencoder import AE as DefaultResNetAutoencoder

plt.style.use("../style/main.mpltstyle")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
def collect_samples_by_label(
    data, n_samples, test_loader, class_to_index, index_to_class
):
    data = {k: [] for k in class_to_index.keys()}
    all_labels_filled = False

    for batch in test_loader:
        if all_labels_filled:
            break

        for image, index, splat in batch:
            label = index_to_class[index]
            if len(data[label]) < n_samples:
                data[label].append((image, splat))

            all_labels_filled = all(len(v) >= n_samples for v in data.values())
            if all_labels_filled:
                break

    return data


def compression_ratio(image: Image.Image, latent_tensor: torch.Tensor):
    image_array = np.array(image)
    image_tensor = torch.tensor(image_array).float()

    input_size = image_tensor.numel()
    latent_size = latent_tensor.numel()

    compression_ratio = input_size / latent_size

    return compression_ratio


def custom_forward(
    means_model, quats_model, scales_model, opacities_model, colors_model, splat
):
    means_model = means_model.to(DEVICE)
    quats_model = quats_model.to(DEVICE)
    scales_model = scales_model.to(DEVICE)
    opacities_model = opacities_model.to(DEVICE)
    colors_model = colors_model.to(DEVICE)
    splat = splat.to(DEVICE)
    x = transform_autoencoder_input(splat, "dict")
    x_means = x["means"].unsqueeze(0)
    x_quats = x["quats"].unsqueeze(0)
    x_scales = x["scales"].unsqueeze(0)
    x_opacities = x["opacities"].unsqueeze(0)
    x_colors = x["colors"].unsqueeze(0)
    y_means = means_model.encoder(x_means)
    y_quats = quats_model.encoder(x_quats)
    y_scales = scales_model.encoder(x_scales)
    y_opacities = opacities_model.encoder(x_opacities)
    y_colors = colors_model.encoder(x_colors)
    y = {
        "means": y_means,
        "quats": y_quats,
        "scales": y_scales,
        "opacities": y_opacities,
        "colors": y_colors,
    }
    return y


def default_forward(model, image):
    x = TENSOR_TRANSFORM(image)
    y = model.encoder(x)
    return y

In [5]:
N_CLASSES = 10
N_SAMPLES = 1
MODEL_PATH = "../models/final_models/conv_method_3"  # TODO: Change path
CUSTOM_MEANS_MODEL = torch.load(f"{MODEL_PATH}/means_model.pt", map_location=DEVICE)
CUSTOM_QUATS_MODEL = torch.load(f"{MODEL_PATH}/quats_model.pt", map_location=DEVICE)
CUSTOM_SCALES_MODEL = torch.load(f"{MODEL_PATH}/scales_model.pt", map_location=DEVICE)
CUSTOM_OPACITIES_MODEL = torch.load(
    f"{MODEL_PATH}/opacities_model.pt", map_location=DEVICE
)
CUSTOM_COLORS_MODEL = torch.load(f"{MODEL_PATH}/colors_model.pt", map_location=DEVICE)

DEFAULT_MODEL = DefaultResNetAutoencoder("light")
DEFAULT_MODEL.load_state_dict(
    torch.load("../models/default_resnet_autoencoder.ckpt")["model_state_dict"]
)

test_dataset = CIFAR10GaussianSplatsDataset(
    root="../data/CIFAR10GS",  # TODO: Change path
    test=True,
    init_type="grid",
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=noop_collate,
)

CLASS_TO_INDEX = test_dataset.class_to_index
INDEX_TO_CLASS = {v: k for k, v in CLASS_TO_INDEX.items()}
DATA = collect_samples_by_label(
    test_dataset, N_SAMPLES, test_loader, CLASS_TO_INDEX, INDEX_TO_CLASS
)
CUSTOM_RESULTS = {k: [] for k in CLASS_TO_INDEX.keys()}
DEFAULT_RESULTS = {k: [] for k in CLASS_TO_INDEX.keys()}

In [6]:
for label, samples in DATA.items():
    for image, splat in samples:
        if (
            CUSTOM_MEANS_MODEL
            and CUSTOM_QUATS_MODEL
            and CUSTOM_SCALES_MODEL
            and CUSTOM_OPACITIES_MODEL
            and CUSTOM_COLORS_MODEL
        ):
            CUSTOM_RESULTS[label].append(
                custom_forward(
                    CUSTOM_MEANS_MODEL,
                    CUSTOM_QUATS_MODEL,
                    CUSTOM_SCALES_MODEL,
                    CUSTOM_OPACITIES_MODEL,
                    CUSTOM_COLORS_MODEL,
                    splat,
                )
            )
        if DEFAULT_MODEL:
            DEFAULT_RESULTS[label].append(default_forward(DEFAULT_MODEL, image))

In [7]:
for label, samples in DATA.items():
    for i, (image, splat) in enumerate(samples):
        default_compression_ratio = compression_ratio(image, DEFAULT_RESULTS[label][i])
        custom_means_compression_ratio = compression_ratio(
            image, CUSTOM_RESULTS[label][i]["means"]
        )
        custom_quats_compression_ratio = compression_ratio(
            image, CUSTOM_RESULTS[label][i]["quats"]
        )
        custom_scales_compression_ratio = compression_ratio(
            image, CUSTOM_RESULTS[label][i]["scales"]
        )
        custom_opacities_compression_ratio = compression_ratio(
            image, CUSTOM_RESULTS[label][i]["opacities"]
        )
        custom_colors_compression_ratio = compression_ratio(
            image, CUSTOM_RESULTS[label][i]["colors"]
        )
        custom_compression_ratio = (
            custom_means_compression_ratio
            + custom_quats_compression_ratio
            + custom_scales_compression_ratio
            + custom_opacities_compression_ratio
            + custom_colors_compression_ratio
        )
        print(
            f"Label: {label}, Sample: {i}, Default: {default_compression_ratio}, Custom: {custom_compression_ratio}"
        )

Label: airplane, Sample: 0, Default: 0.75, Custom: 1.5
Label: automobile, Sample: 0, Default: 0.75, Custom: 1.5
Label: bird, Sample: 0, Default: 0.75, Custom: 1.5
Label: cat, Sample: 0, Default: 0.75, Custom: 1.5
Label: deer, Sample: 0, Default: 0.75, Custom: 1.5
Label: dog, Sample: 0, Default: 0.75, Custom: 1.5
Label: frog, Sample: 0, Default: 0.75, Custom: 1.5
Label: horse, Sample: 0, Default: 0.75, Custom: 1.5
Label: ship, Sample: 0, Default: 0.75, Custom: 1.5
Label: truck, Sample: 0, Default: 0.75, Custom: 1.5
