In [1]:
# Note: This is a hack to allow importing from the parent directory
import sys
from pathlib import Path

# Parent of the current working directory, resolved to an absolute path. The
# membership check keeps the cell idempotent: re-running it no longer appends
# the same entry to sys.path again.
_parent_dir = str(Path().resolve().parent)
if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)

In [2]:
import torch
from configs import Config
from torchvision import datasets
from models import GaussianImageTrainer
from constants import CIFAR10_TRANSFORM
from utils import (
    generate_random_splat,
    merge_spherical_harmonics,
    save_gs_data,
    load_gs_data,
    tensor_to_image,
)

In [3]:
# Save/load round trip: build a random splat for the first CIFAR-10 training
# sample, write it to disk with its image and label, then read it back.
NUM_POINTS = 1024  # 32x32
ROUNDTRIP_PATH = Path("test.pt")  # scratch file, written to the current working directory

CIFAR10 = datasets.CIFAR10(
    root="../data/CIFAR10/train",
    train=True,
    download=True,
    transform=CIFAR10_TRANSFORM,
)
means, quats, scales, opacities, colors, viewmats, Ks, sh0, shN = generate_random_splat(
    NUM_POINTS
)

# Index the dataset once and unpack: every `CIFAR10[i]` access runs the
# dataset's item lookup (and its transform) again.
sample, label = CIFAR10[0]
image = tensor_to_image(sample)
splat = torch.nn.ParameterDict(
    {
        "mean": torch.nn.Parameter(means),
        "quat": torch.nn.Parameter(quats),
        "scale": torch.nn.Parameter(scales),
        "opacity": torch.nn.Parameter(opacities),
        "color": torch.nn.Parameter(colors),
        "viewmat": torch.nn.Parameter(viewmats),
        "Ks": torch.nn.Parameter(Ks),
        "sh0": torch.nn.Parameter(sh0),
        "shN": torch.nn.Parameter(shN),
    }
)
print(image, label, splat)

splat = merge_spherical_harmonics(splat)
save_gs_data(image, label, splat, ROUNDTRIP_PATH)

# Remember what was written so the reloaded values can be checked below.
saved_label, saved_size = label, image.size
image, label, splat = load_gs_data(ROUNDTRIP_PATH)
print(image, label, splat)

# Fail loudly instead of relying on eyeballing the two printouts.
assert label == saved_label, f"label changed in round trip: {saved_label} -> {label}"
assert image.size == saved_size, f"image size changed in round trip: {saved_size} -> {image.size}"

Files already downloaded and verified
<PIL.Image.Image image mode=RGB size=32x32 at 0x7ADAC6FF89D0> 6 ParameterDict(
    (Ks): Parameter containing: [torch.FloatTensor of size 3x3]
    (color): Parameter containing: [torch.FloatTensor of size 1024x3]
    (mean): Parameter containing: [torch.FloatTensor of size 1024x3]
    (opacity): Parameter containing: [torch.FloatTensor of size 1024x1]
    (quat): Parameter containing: [torch.FloatTensor of size 1024x4]
    (scale): Parameter containing: [torch.FloatTensor of size 1024x3]
    (sh0): Parameter containing: [torch.FloatTensor of size 1024x1]
    (shN): Parameter containing: [torch.FloatTensor of size 1024x2]
    (viewmat): Parameter containing: [torch.FloatTensor of size 4x4]
)
<PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32 at 0x7ADAC701B190> 6 ParameterDict(
    (Ks): Parameter containing: [torch.FloatTensor of size 3x3]
    (color): Parameter containing: [torch.FloatTensor of size 1024x3]
    (mean): Parameter containing:

In [None]:
# Example of how the CIFAR10GS dataset should be created: fit a splat to each
# CIFAR-10 sample and save it together with that sample's image and label.
CIFAR10GS_ROOT = Path("...")  # placeholder: point this at the real output directory
MAX_SAMPLES = 1  # how many samples to convert; set to None to process the whole dataset

# The split name does not change during the loop, so compute it once.
split = "train" if CIFAR10.train else "test"
num_samples = len(CIFAR10) if MAX_SAMPLES is None else min(MAX_SAMPLES, len(CIFAR10))

trainer = GaussianImageTrainer(Config())
for index in range(num_samples):
    image, label = CIFAR10[index]
    trainer.reinitialize(Config(image=image))

    splat = trainer.train()
    splat = merge_spherical_harmonics(splat)

    # NOTE(review): the earlier save_gs_data call in this notebook passes the
    # result of tensor_to_image(...), not the transformed dataset sample, so the
    # same conversion is applied here. Config(image=...) still receives the
    # sample unchanged. Confirm both against save_gs_data / Config.
    save_gs_data(
        tensor_to_image(image),
        label,
        splat,
        CIFAR10GS_ROOT / split / f"{index}.pt",
    )