# GSPLAT training & rasterization investigation

In [None]:
import os
# Switch the working directory to the gsplat examples folder; presumably this is
# what lets the `datasets` and `simple_trainer_rgba` imports below resolve.
# NOTE: machine-specific absolute path.
os.chdir("/home/minhtran/Code/gsplat/examples")

In [None]:
from datasets.colmap import Parser
from datasets.colmap_rgba import DatasetRGBA

## DataLoader

In [None]:
import imageio.v2 as imageio

In [None]:
# Frame identifier interpolated into the dataset and result paths defined below.
FRAME = "00000000"

In [None]:
# Training-split RGBA data for the selected frame (machine-specific path).
data_dir = f"/home/minhtran/Code/data/vocap/minh_2/frames/{FRAME}/train/rgba"
# Settings forwarded to Parser as `factor`, `normalize` and `test_every`.
data_factor = 1
normalize = False
test_every = 6
# Results directory for this frame. Note that `result_dir` is reassigned in the
# rendering section before this value is used anywhere.
result_dir = f"/home/minhtran/Code/data/vocap/minh_2/gsplat_results/frames/{FRAME}/train/rgba"

In [None]:
# Build the COLMAP parser for the training data from the settings defined above.
parser_options = {
    "data_dir": data_dir,
    "factor": data_factor,
    "normalize": normalize,
    "test_every": test_every,
}
parser = Parser(**parser_options)

In [None]:
# Inspect the camera ids exposed by the training parser.
parser.camera_ids

In [None]:
# Read the first image listed by the training parser.
index = 0
image_path = parser.image_paths[index]
img = imageio.imread(image_path)

In [None]:
# Show the loaded image without axis decorations.
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.imshow(img)
ax.axis('off')
plt.show()

In [None]:
# NOTE(review): `cfg` is first assigned in the rendering section further down, so this
# cell raises NameError when the notebook is run top to bottom in a fresh kernel.
cfg.normalize_world_space

# Test dataloader

In [None]:
# Test-split RGBA data for the same frame as the training data. The frame id is taken
# from FRAME (instead of a hard-coded literal) so train and test paths cannot drift apart.
test_data_dir = f"/home/minhtran/Code/data/vocap/minh_2/frames/{FRAME}/test/rgba"

In [None]:
# Build the COLMAP parser for the test data; it reuses the training factor/normalize
# settings but passes test_every=1.
test_parser_options = {
    "data_dir": test_data_dir,
    "factor": data_factor,
    "normalize": normalize,
    "test_every": 1,
}
test_parser = Parser(**test_parser_options)

In [None]:
# Wrap the test parser in the RGBA dataset, requesting its "eval" split.
test_dataset = DatasetRGBA(test_parser, split="eval")

In [None]:
# Inspect the camera ids exposed by the test parser.
test_parser.camera_ids

In [None]:
# Fetch the first sample so its contents can be inspected later.
first_item = test_dataset[0]

In [None]:
# Number of samples in the test dataset.
len(test_dataset)

# Render a test image

In [None]:
import os

from typing import Tuple, Dict, Optional, Literal
from torch import Tensor

import torch

from gsplat.rendering import rasterization
from gsplat.strategy import DefaultStrategy, MCMCStrategy

from simple_trainer_rgba import Runner

import yaml
from easydict import EasyDict as edict

In [None]:
# Directory of the trained run to load (its cfg.yml and checkpoints are read below).
# This rebinds `result_dir` from the DataLoader section. The frame id comes from FRAME
# rather than a hard-coded literal so it stays in sync with the data paths.
result_dir = f"/home/minhtran/Code/data/vocap/minh_2/gsplat_results/frames/{FRAME}/rgba/mcmc_random_bkgd"

In [None]:
# Path to cfg.yml inside the result directory.
result_cfg = os.path.join(result_dir, "cfg.yml")

In [None]:
# Load the run configuration from the YAML file in the result directory.
import yaml
with open(result_cfg, "r") as f:
    # SECURITY: UnsafeLoader can construct arbitrary Python objects while parsing;
    # only use it on configuration files you trust.
    cfg = yaml.load(f, Loader=yaml.UnsafeLoader)

In [None]:
# Wrap the loaded config in an EasyDict so fields can be read as attributes
# (e.g. cfg.sh_degree in the rasterization call below).
cfg = edict(cfg)

In [None]:
# Build the Runner from the loaded config for a single-process setup
# (local/world rank 0, world size 1).
runner = Runner(local_rank=0, world_rank=0, world_size=1, cfg=cfg)

In [None]:
# Display the sample fetched from the test dataset earlier.
first_item

In [None]:
# Checkpoint file to restore, located in the result directory's "ckpts" folder.
ckpt_path = os.path.join(result_dir, "ckpts", 'ckpt_29999_rank0.pt')

In [None]:
import torch
# Use the first CUDA device when one is available; otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Load the checkpoint directly onto `device` so the restored tensors end up on the same
# device as the inputs moved with `.to(device)` below, regardless of the device they
# were saved from.
ckpt = torch.load(ckpt_path, map_location=device)

In [None]:
# Pull the 'splats' entry out of the loaded checkpoint.
splats = ckpt['splats']

In [None]:
# Replace the runner's splats with the ones restored from the checkpoint.
# NOTE(review): presumably ckpt['splats'] has the container type Runner expects — confirm.
runner.splats = splats

In [None]:
# Iterate the test dataset in order, 7 samples per batch, with no worker processes.
loader_options = {
    "batch_size": 7,
    "shuffle": False,
    "num_workers": 0,
}
test_loader = torch.utils.data.DataLoader(test_dataset, **loader_options)

In [None]:
# Take the first batch produced by the test loader.
data = next(iter(test_loader))

In [None]:
# Move the camera data of the batch to the target device.
camtoworlds = data["camtoworld"].to(device)
Ks = data["K"].to(device)

# Scale the image values by 1/255. Presumably [B, H, W, 4] with RGBA in the last
# axis (the loader uses batch_size=7) — TODO confirm.
rgba = data["image"].to(device) / 255.0
pixels_alpha = rgba[..., 3:]
# Weight the first three channels by the alpha channel.
pixels = rgba[..., :3] * pixels_alpha

# The batch may or may not carry a mask.
if "mask" in data:
    masks = data["mask"].to(device)
else:
    masks = None

# Axes 1 and 2 of the image tensor give the render resolution.
height = pixels.shape[1]
width = pixels.shape[2]

In [None]:
# Render the batch through the runner using the cameras prepared above and the
# SH degree / clipping planes from the loaded config.
render_options = {
    "camtoworlds": camtoworlds,
    "Ks": Ks,
    "width": width,
    "height": height,
    "sh_degree": cfg.sh_degree,
    "near_plane": cfg.near_plane,
    "far_plane": cfg.far_plane,
    "masks": masks,
}
colors, alphas, _ = runner.rasterize_splats(**render_options)

In [None]:
# Append the rendered alphas to the rendered colors along the last axis to form RGBA.
renders_rgba = torch.cat([colors, alphas], dim=-1)

In [None]:
# Bring the render to host memory as a NumPy array.
renders_rgba_np = renders_rgba.detach().cpu().numpy()
# Clip to [0, 1] before quantizing to 8 bits: values outside that range would
# otherwise wrap around in the uint8 cast and show up as wrong colors.
renders_rgba_np = (renders_rgba_np.clip(0.0, 1.0) * 255.0).astype('uint8')

In [None]:
# First image of the rendered batch.
first_img = renders_rgba_np[0, ...]

In [None]:
# Show the first rendered RGBA image without axis decorations.
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.imshow(first_img)
ax.axis('off')
plt.show()

In [None]:
# Inspect the test parser's camera ids again, next to the rendered output.
test_parser.camera_ids