In [None]:
import os
import random
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch

from gaussian_splatting.colmap import parse_cameras, parse_images, parse_points3d, clean_text
from gaussian_splatting.model import View, GaussianCloud, train
from gaussian_splatting.model.util import create_rasterizer

# Seed every RNG the pipeline may draw from (stdlib, numpy, torch) so that
# Restart & Run All reproduces the same split and initialisation.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

In [None]:
# Name of the scene folder under ../data to load and train on.
dataset: str = "hotdog"

In [None]:
# Load the COLMAP reconstruction for the selected dataset: the image files,
# cameras.txt, points3D.txt and images.txt, then combine them via parse_images.
base_path = Path("../data") / dataset
image_dir = base_path / "images"


def _read_rgb(path: Path) -> np.ndarray:
    """Read an image file as an RGB float array scaled to [0, 1].

    OpenCV decodes colour images in BGR channel order, so the channel axis is
    reversed. Raises ValueError when OpenCV cannot decode the file (it returns
    None in that case), instead of an opaque "'NoneType' object is not
    subscriptable" error.
    """
    bgr = cv2.imread(str(path))
    if bgr is None:
        raise ValueError(f"could not read image file: {path}")
    return bgr[:, :, ::-1] / 255


# Raw pixel arrays keyed by file name (a separate name from `images`, which
# below holds the parse_images result). Subdirectories and dotfiles such as
# .DS_Store are skipped; sorting gives a deterministic insertion order.
image_arrays = {
    entry.name: _read_rgb(entry)
    for entry in sorted(image_dir.iterdir())
    if entry.is_file() and not entry.name.startswith(".")
}

with open(base_path / "cameras.txt", "r") as f:
    cameras = parse_cameras(clean_text(f.readlines()))

with open(base_path / "points3D.txt", "r") as f:
    points3d = parse_points3d(clean_text(f.readlines()))

with open(base_path / "images.txt", "r") as f:
    # presumably pairs each entry of images.txt with its array by file name — confirm in parse_images
    images = parse_images(clean_text(f.readlines()), cameras, points3d, image_arrays)
    

In [None]:
# Wrap every parsed image in a View and preview the first four.
views = list(map(View.from_image, images.values()))
views[:4]

In [None]:
# Split the views into train/test partitions (70/30).
# A dedicated RNG (same seed, 42, as the config cell) shuffles a copy, so
# `views` is not mutated and re-running this cell always yields the same split,
# independent of how much of the global `random` state was consumed earlier.
train_split = 0.7
dataset_size = len(views)
split_index = int(dataset_size * train_split)

shuffled_views = list(views)
random.Random(42).shuffle(shuffled_views)

train_dataset = shuffled_views[:split_index]
test_dataset = shuffled_views[split_index:]

# Fail here with a clear message rather than later with an IndexError.
assert train_dataset and test_dataset, (
    f"train/test split of {dataset_size} views left an empty partition"
)

In [None]:
# Initialise the Gaussian cloud from the sparse COLMAP 3D points and move it
# to the GPU.
gaussian_cloud = GaussianCloud.from_point_cloud(list(points3d.values())).to("cuda")
gaussian_cloud

In [None]:
# Helper to visually compare the model's rendering with the ground truth.
def compare(view: View, gaussian_cloud: GaussianCloud):
    """Show the cloud's rendering of ``view`` next to ``view.image``.

    The rendered image (left) and the ground-truth image (right) are stacked
    horizontally and displayed in one matplotlib figure.

    Side effect: switches ``gaussian_cloud`` to eval mode and leaves it there;
    call ``gaussian_cloud.train()`` before resuming training.

    Args:
        view: camera view to render. ``view.image`` is assumed to be an
            (H, W, 3) float array in [0, 1] with the same height as the
            render — TODO confirm against View.from_image.
        gaussian_cloud: model whose ``parameters`` are fed to the rasterizer.
    """
    gaussian_cloud.eval()
    rasterizer = create_rasterizer(view)
    # Inference only: no need to build an autograd graph for this render.
    with torch.no_grad():
        rendered, _ = rasterizer(**gaussian_cloud.parameters)
    # Presumably (C, H, W) -> (H, W, C) for matplotlib.
    rendered_np = rendered.detach().cpu().numpy().transpose([1, 2, 0])
    # The raw render can fall outside [0, 1]; clip explicitly so imshow does
    # not warn and clip on its own.
    rendered_np = np.clip(rendered_np, 0.0, 1.0)

    _, ax = plt.subplots()
    ax.imshow(np.hstack([rendered_np, view.image]))
    ax.set_title("rendered (left) vs. ground truth (right)")
    ax.axis("off")
    plt.show()


In [None]:
# Baseline: render a held-out view with the cloud before any training.
compare(view=test_dataset[1], gaussian_cloud=gaussian_cloud)

In [None]:
# Number of passes over the training views.
n_epochs = 300

# Switch the cloud back to training mode (compare() leaves it in eval mode),
# then optimise it and keep the train/test loss histories that train returns.
gaussian_cloud.train()
train_losses, test_losses = train(
    gaussian_cloud,
    train_dataset,
    test_dataset,
    epochs=n_epochs,
)

In [None]:
# Plot both loss curves with a legend, title and axis labels so the figure
# stands on its own.
_, ax = plt.subplots()
ax.plot(train_losses, label="train")
ax.plot(test_losses, label="test")
# NOTE(review): x axis assumes one value per recorded step of `train`;
# confirm whether entries are per epoch or per iteration.
ax.set(title=f"Loss ({dataset})", xlabel="step", ylabel="loss")
ax.legend()
plt.show()

In [None]:
# Re-render the same held-out view after training, for comparison with the baseline.
compare(view=test_dataset[1], gaussian_cloud=gaussian_cloud)

In [None]:
# Persist the trained cloud for later use. Create the output directory first:
# if ../models does not exist, save() would fail after the full training run.
model_dir = Path("../models")
model_dir.mkdir(parents=True, exist_ok=True)
gaussian_cloud.save(str(model_dir / f"{dataset}.pkl"))