In [None]:
# Enable IPython's autoreload extension in mode 2, so that edits to imported
# modules are picked up live instead of requiring a kernel restart.
%load_ext autoreload
%autoreload 2

# Scribble notebook
This notebook is a scratchpad for first experiments, which will later be turned into a cleaner implementation. 
For now, it is based on https://github.com/thomasantony/splat/blob/master/notes/00_Gaussian_Projection.ipynb 

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scipy as sp
from scipy import spatial
from typing import List, Dict 
from copy import deepcopy
import warnings 
import random 
from tqdm import tqdm 

from utils.Camera import Camera
from utils.util_gau import load_ply, naive_gaussian, GaussianData
from utils.constants import * 
from pathlib import Path 
import multiprocessing as mp 
import functools
from decouple import config
from dataclasses import dataclass 
from typing import List, Tuple
from utils.Primitives import Gaussian, PrimitiveSet, PrimitiveSubset
from utils.ImageSegmenter import (
    IterativeImageSegmenter, 
    SubImage, 
    cut_image
)

## Define some classes for optimization

In [None]:
def plot_conics_and_bbs(gaussian_objects, camera: Camera, color: str='blue'):
    """Draw the projected ellipse and bounding box of each Gaussian on the current axes.

    Credit: https://github.com/thomasantony/splat/blob/master/notes/00_Gaussian_Projection.ipynb

    Args:
        gaussian_objects: iterable of ``Gaussian`` instances (asserted per item).
        camera: passed to ``Gaussian.get_conic_and_bb`` and used for
            ``ndc_to_pixel``.
        color: colour of the bounding-box rectangles only; the contour lines
            keep matplotlib's default colouring.

    Gaussians for which ``get_conic_and_bb`` reports ``conic is None`` are
    skipped. Nothing is returned; artists are added to ``plt.gca()``.
    """
    # Open question kept from the original author: the bounding boxes are
    # indexed as four corners although two would probably be enough.
    n_samples = 100
    axes = plt.gca()

    for gaussian in gaussian_objects:
        assert isinstance(gaussian, Gaussian)
        conic, bb_cam, bb_ndc = gaussian.get_conic_and_bb(camera, optimal=True)
        if conic is None:
            continue
        a, b, c = conic

        # bb_cam is already centred on the Gaussian and in the scale the conic
        # expects, so the quadratic form can be sampled on it directly.
        # NOTE(review): corners look like (ul, ur, lr, ll) - confirm in Gaussian.
        xs_cam = np.linspace(bb_cam[0][0], bb_cam[1][0], n_samples)
        ys_cam = np.linspace(bb_cam[1][1], bb_cam[2][1], n_samples)
        grid_x, grid_y = np.meshgrid(xs_cam, ys_cam)

        # sqrt(quadratic form) - 3 vanishes where the quadratic form equals 9.
        # If the conic is the inverse 2D covariance this is the 3-sigma
        # ellipse (the un-rooted variant `... - 3.00` would be sqrt(3)-sigma).
        level_fn = np.sqrt(a*grid_x**2 + 2*b*grid_x*grid_y + c*grid_y**2) - 3.00

        bb_px = camera.ndc_to_pixel(bb_ndc)

        # The same samples are positioned in screen space by stretching them
        # over the pixel-space bounding box.
        xs_px = np.linspace(bb_px[0][0], bb_px[1][0], n_samples)
        ys_px = np.linspace(bb_px[2][1], bb_px[1][1], n_samples)
        px_x, px_y = np.meshgrid(xs_px, ys_px)
        plt.contour(px_x, px_y, level_fn, [0.0])

        # Outline the bounding box itself.
        ul, ur, lr, _ll = (bb_px[k, :2] for k in range(4))
        box_width = ur[0] - ul[0]
        box_height = lr[1] - ur[1]
        axes.add_patch(
            plt.Rectangle((ul[0], ul[1]), box_width, box_height, fill=False, alpha=1., color=color)
        )


In [None]:
# Test IterativeImageSegmenter
# The idea is to segment an image such that each segment has roughly the same number of Gaussians.
# It's still not perfect (see blog)
# TODO: There seems to be a bug that causes some of the splits to be sub-optimal. Have to investigate
loc_scaler = 1
scale_scaler = .03
rot_scaler = 10
n_gaussians = 30
loc_bias = 0. 

# Seeded generator: the same random scene is produced on every run, so a
# sub-optimal split can be reproduced. Change the seed to explore other scenes.
rng = np.random.default_rng(0)
gaussians_debug = [
    Gaussian(
        (rng.random(3) - .5 + loc_bias)*2*loc_scaler,   # pos
        rng.random(3)*scale_scaler,                     # scale
        rng.random(4)*rot_scaler,                       # rot
        np.array([1.]),                                 # opacity
        np.array([ 1.772484, -1.772484,  1.772484]),    # sh
    )
    for _ in range(n_gaussians)
]

fig = plt.figure()
ax = plt.gca()
camera = Camera(100, 100)
plot_conics_and_bbs(gaussians_debug, camera)
image = SubImage(np.zeros((100, 100)), (0, 0), (100, 100), camera)
pset = PrimitiveSet(gaussians_debug)
sset = PrimitiveSubset(pset, list(range(len(gaussians_debug))))
segmenter = IterativeImageSegmenter(sset, image, camera, thresh=3)

for i in range(7):
    bx, by = segmenter.cut(i)
    ul, lr = segmenter.cuts[-1]['corners']
    ymin, xmin = ul
    ymax, xmax = lr 
    # The label goes two thirds of the way along the cut line. (The previous
    # `(min+max)/3*2` only did that for segments starting at 0 and could put
    # the label outside the segment otherwise.)
    if bx is None:
        # plot y splitting line
        plt.plot([xmin, xmax], [by]*2)
        plt.text(xmin + (xmax - xmin)*2/3, by, str(i))
    else:
        # plot x splitting line
        plt.plot([bx]*2, [ymin, ymax])
        plt.text(bx, ymin + (ymax - ymin)*2/3, str(i))
    
plt.xlim([0, camera.w])
plt.ylim([0, camera.h])
plt.grid(True)
plt.show()


In [None]:
def plot_opacity_v(gaussian: Gaussian, camera: Camera, bitmap: np.ndarray, alphas: np.ndarray, alpha_thresh: float=None):
    """Blend one Gaussian into ``bitmap``/``alphas``, vectorised over its bounding box.

    Vectorized version of the plot_opacity function of the original repo - we
    want to loop as little as possible for speed.

    Args:
        gaussian: primitive to draw; must provide ``get_conic_and_bb``,
            ``opacity``, ``pos`` and ``get_color``.
        camera: must provide ``ndc_to_pixel`` and ``position``.
        bitmap: colour accumulator indexed ``[row, col, channel]``; the touched
            window is updated in place.
        alphas: accumulated opacity indexed ``[row, col]``; updated in place.
        alpha_thresh: currently unused - the early-exit heuristic it was meant
            for is still commented out below.

    Returns:
        ``(bitmap, alphas)``. The pair is returned in every case, also when the
        Gaussian is skipped (no conic, NaN bounding box, or a box that does not
        overlap the image), so callers can always unpack the result.
    """
    conic, bboxsize_cam, bbox_ndc = gaussian.get_conic_and_bb(camera, optimal=True) # different bounding boxes (active areas for gaussian)
    # bboxsize_cam, bbox_ndc = gaussian.get_nonopt_bb(camera)
    if conic is None:
        # No conic means there is nothing to draw for this Gaussian.
        return bitmap, alphas

    h, w = bitmap.shape[:2]
    bbox_screen = camera.ndc_to_pixel(bbox_ndc, w, h)

    if np.any(np.isnan(bbox_screen)):
        # Was a bare `return` (None), which broke `bitmap, alphas = plot_opacity_v(...)`.
        return bitmap, alphas

    ul = bbox_screen[0,:2] # Bounding box vertices 
    ur = bbox_screen[1,:2]
    ll = bbox_screen[3,:2]

    # Pixel window covered by the bounding box, clipped to the image.
    y1 = max(int(np.floor(ul[1])), 0)
    y2 = min(int(np.ceil(ll[1])), h)
    x1 = max(int(np.floor(ul[0])), 0)
    x2 = min(int(np.ceil(ur[0])), w)
    nx = x2 - x1
    ny = y2 - y1
    if nx <= 0 or ny <= 0:
        return bitmap, alphas

    # TODO: this was an attempt at checking whether we have rendered enough so that the alpha value in a pixel
    # is so high that we do not need to keep rendering deeper Gaussians in that pixel. However, the logic below
    # was slow enough that it did not pay off to implement this
    # TODO: possibly we could be cleverer when checking alpha - e.g. by using information from the foreground.
    # I am working on this

    # midy = y1+ny//2
    # midx = x1+nx//2
    # # if alpha_thresh is not None and (alphas[y1:y2,x1:x2] >= alpha_thresh).sum() >= nx*ny/3: 
    # #     return bitmap, alphas 
    # if alpha_thresh is not None and min([ # check alpha on bb borders (heuristic)
    #     # alphas[y1:y2, x1].max(),
    #     # alphas[y1:y2, x2-1].max(),
    #     # alphas[y1, x1:x2].max(),
    #     # alphas[y2-1, x1:x2].max(),
    #     alphas[midy, x1],
    #     alphas[midy, x2-1],
    #     alphas[y1, midx],
    #     alphas[y2-1, midx],
    #     alphas[midx,midx]
    # ]) >= alpha_thresh: 
    #     return bitmap, alphas
    # conic = gaussian.get_conic(camera)

    A, B, C = conic # precision matrix is (A, B; B, C)

    # Extract out inputs for the gaussian
    coordxy = bboxsize_cam
    x_cam_1 = coordxy[0][0]   # ul
    x_cam_2 = coordxy[1][0]   # ur
    y_cam_1 = coordxy[1][1]   # ur (y)
    y_cam_2 = coordxy[2][1]   # lr

    opacity = gaussian.opacity 

    camera_dir = gaussian.pos - camera.position
    camera_dir = camera_dir / np.linalg.norm(camera_dir) # normalized camera viewing direction
    color = gaussian.get_color(camera_dir)

    # Sample the bounding box on an (ny, nx) grid matching the pixel window.
    y_cam, x_cam = np.meshgrid(np.linspace(y_cam_1, y_cam_2, ny), np.linspace(x_cam_1, x_cam_2, nx), indexing='ij')
    power = -(A*x_cam**2 + C*y_cam**2)/2.0 - B * x_cam * y_cam
    # power = np.clip(power, -np.inf, 0.)
    alpha_ = opacity * np.exp(power)

    # This Gaussian's opacity, attenuated by what has already been accumulated
    # in each pixel of the window.
    contribution = (1 - alphas[y1:y2, x1:x2]) * alpha_
    bitmap[y1:y2, x1:x2] = bitmap[y1:y2, x1:x2] + contribution.reshape(ny, nx, 1) * color[0:3].reshape(1, 1, -1)
    alphas[y1:y2, x1:x2] = alphas[y1:y2, x1:x2] + contribution
    return bitmap, alphas 

# Plot some Gaussians
loc_scaler = 1
scale_scaler = .05
rot_scaler = 10
n_gaussians = 50
loc_bias = 0. # np.ones(3,) * -.1

# Seeded generator: the same random scene is rendered on every run
# (Restart & Run All reproduces the figure). Change the seed for other scenes.
rng = np.random.default_rng(0)
gaussians_debug = [
    Gaussian(
        (rng.random(3) - .5 + loc_bias)*2*loc_scaler,   # pos
        rng.random(3)*scale_scaler,                     # scale
        rng.random(4)*rot_scaler,                       # rot
        np.array([1.]),                                 # opacity
        np.array([ 1.772484, -1.772484,  1.772484]),    # sh
    )
    for _ in range(n_gaussians)
]
gaussian_objects = gaussians_debug

(h, w) = (300, 400)
camera = Camera(h, w)
# Get gaussian indices sorted by depth
indices = np.argsort([g.get_depth(camera) for g in gaussian_objects]) # TODO: vectorize get_depth

# Initialize the colour bitmap (h x w x 3) and the per-pixel accumulated alpha (h x w)
bitmap = np.zeros((h, w, 3), np.float32)
alphas = np.zeros((h, w), np.float32)

plt.figure(figsize=(6,6))
# Render in ascending depth order; every call blends one Gaussian into the buffers.
for idx in indices:
    # bitmap, alphas = plot_opacity(gaussian_objects[idx], camera, bitmap, alphas)
    bitmap, alphas = plot_opacity_v(gaussian_objects[idx], camera, bitmap, alphas)
print(f'after execution, {bitmap.max()=}')
# Plot the bitmap
plt.imshow(bitmap, vmin=0, vmax=1.0)

plt.show()

In [None]:
# Define utility functions for parallel computation

# List of Gaussians a pool worker renders from. It stays None in the notebook
# process and is set inside each worker by _init_render_worker.
_worker_gaussians = None


def _init_render_worker(gaussians):
    """Pool initializer: remember the list of Gaussians this worker indexes into."""
    global _worker_gaussians
    _worker_gaussians = gaussians


def helper_primitives(indices: List[int], camera, bitmap, alphas, alpha_thresh):
    """Render the Gaussians selected by ``indices`` onto ``bitmap``/``alphas``.

    Args:
        indices: positions into the Gaussian list, in the order to render.
        camera, bitmap, alphas, alpha_thresh: forwarded to ``plot_opacity_v``.

    Returns:
        The ``(bitmap, alphas)`` pair after all selected Gaussians were blended.

    Inside a pool started by ``plot_model_par_primitives`` the list handed to
    that function is used. When called directly, the module-level
    ``gaussian_objects`` is used, as before.
    """
    gaussians = gaussian_objects if _worker_gaussians is None else _worker_gaussians
    for idx in indices:
        bitmap, alphas = plot_opacity_v(gaussians[idx], camera, bitmap, alphas, alpha_thresh)
    return bitmap, alphas 


def plot_model_par_primitives(camera, gaussian_objects: List[Gaussian], alpha_thresh: float=None, n_threads:int=1):
    """Parallellize image rendering by splitting the sorted list of Gaussians and rendering them individually.
    Requires blending the segments later (for now).

    Args:
        camera: camera used for depth sorting and rendering; ``camera.w`` and
            ``camera.h`` give the canvas size.
        gaussian_objects: the Gaussians to render. Workers index into this
            list, not into whatever the notebook global of the same name holds.
        alpha_thresh: forwarded to ``plot_opacity_v``.
        n_threads: number of worker processes in the pool.

    Yields:
        One ``(bitmap, alphas)`` pair per chunk, in ascending depth order
        (``Pool.imap`` preserves input order). Every chunk is rendered onto its
        own zero-initialised canvas, so the caller has to composite the pairs.
    """
    print('Sorting the gaussians by depth')
    indices = np.argsort([gau.get_depth(camera) for gau in gaussian_objects])# [::-1] # fast-ish. probably get_depth is slowing it down (vectorization/parallellization TODO)
    w, h = camera.w, camera.h
    
    print('Plotting with', len(gaussian_objects), 'gaussians')
    bitmap = np.zeros((h, w, 3), np.float32)
    alphas = np.zeros((h, w), np.float32)

    # About three chunks per worker (arbitrary choice TODO). At least one
    # Gaussian per chunk: with fewer than 3*n_threads Gaussians the chunk size
    # used to become 0 and the chunking below divided by zero.
    skip = max(1, len(indices) // (3*n_threads))
    gsns = [indices[start:start + skip] for start in range(0, len(indices), skip)]

    render_chunk = functools.partial(
        helper_primitives,
        alpha_thresh=alpha_thresh, 
        camera=camera,# None,
        bitmap=bitmap,
        alphas=alphas
    )
    # The list is handed to the workers once, through the pool initializer,
    # instead of being attached to every task.
    with mp.Pool(n_threads, initializer=_init_render_worker, initargs=(gaussian_objects,)) as pool:
        for r in pool.imap(render_chunk, gsns): 
            yield r 




In [None]:
# Load the trained point cloud and wrap every entry in a Gaussian object.
model = load_ply(str(Path(config('MODEL_PATH'))/'debug/point_cloud/iteration_30000/point_cloud.ply'))
from tqdm import tqdm

print('Loading gaussians ...')
(h, w) = (100, 100)

# zip() has no length, so tqdm is given the total explicitly to show progress
# and an ETA. NOTE(review): assumes model.xyz supports len() - confirm in GaussianData.
gaussian_objects = [
    Gaussian(*fields)  # fields = (pos, scale, rot, opacity, sh)
    for fields in tqdm(
        zip(model.xyz, model.scale, model.rot, model.opacity, model.sh),
        total=len(model.xyz),
    )
]

In [None]:
# Mean position over (at most) the first 1000 Gaussians; positions are stacked
# along the last axis, so the mean is taken over that axis.
sampled_poss = np.stack(
    [gaussian.pos for gaussian in gaussian_objects[:1000]],
    axis=-1,
)
sampled_poss.mean(axis=-1)

In [None]:
(w, h) = (400, 400)
alpha_thresh = .95

camera = Camera(h, w, position=(-.2, 2., 1.4), target=(-.15, 1.5, .7))# target is pos for one of the gaussians
bitmap_parts = list(tqdm(plot_model_par_primitives(camera, gaussian_objects, alpha_thresh=alpha_thresh, n_threads=20)))

# do alpha blending post-hoc
# Every part was rendered front-to-back onto its own empty canvas, so its
# colours are already weighted by its own alpha. Compositing the parts in depth
# order therefore only needs the remaining transmittance (1 - alpha) of what is
# in front; multiplying by `al` again would darken the image relative to the
# single-process renderer.
bitmap = np.zeros((h, w, 3))
alpha = np.zeros((h, w))
for bm, al in bitmap_parts:# [::-1]: # back to front
    bitmap = bitmap + (1 - alpha)[..., np.newaxis] * bm
    alpha = alpha + (1 - alpha) * al

plt.figure(figsize=(6, 6))
plt.imshow(bitmap, vmin=0, vmax=1.0)
plt.show()

In [None]:
from IPython.display import display, clear_output
import time
%matplotlib widget
fig = plt.figure(figsize=(6, 6))
ax = fig.add_subplot()
# draw as we generate
(w, h) = (400, 400)
alpha_thresh = None
camera = Camera(h, w, position=(-.2, 2., 1.), target=(-.15, 1.5, .6))
# camera = Camera(h, w, position=(-.1, 1., 0), target=(-.15, 1.5, .7)) # target is pos for one of the gaussians
bitmap_parts = plot_model_par_primitives(camera, gaussian_objects, alpha_thresh=alpha_thresh, n_threads=20)

# do alpha blending post-hoc
bitmap = np.zeros((w, h, 3))
alpha = np.zeros((w, h))
# i = 0
for bm, al in plot_model_par_primitives(camera, gaussian_objects, alpha_thresh=alpha_thresh, n_threads=20):
    # print(i, end='\r')
    # i += 1
    bitmap = bitmap + ((1-alpha)*al).reshape(*alpha.shape, 1)*bm
    alpha = alpha + (1-alpha)*al
    ax.cla()
    ax.imshow(bitmap, vmin=0, vmax=1.0)
    display(fig)
    clear_output(wait=True)
    
