In [None]:
%load_ext autoreload
%autoreload 2

# Scribble notebook
This notebook is a scratchpad for first experiments; the parts that work will later be turned into a cleaner implementation.
For now, it is based on https://github.com/thomasantony/splat/blob/master/notes/00_Gaussian_Projection.ipynb

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scipy as sp
from scipy import spatial
from typing import List, Dict 
from copy import deepcopy
import warnings 
import random 
from tqdm import tqdm 

from utils.Camera import Camera
from utils.util_gau import load_ply, naive_gaussian, GaussianData
from utils.constants import * 
from pathlib import Path 
import multiprocessing as mp 
import functools
from decouple import config
from dataclasses import dataclass 
from typing import List, Tuple
from utils.Primitives import Gaussian, PrimitiveSet, PrimitiveSubset
from utils.ImageSegmenter import (
    IterativeImageSegmenter, 
    SubImage, 
    cut_image
)

In [None]:
# Width and height used to allocate the initial `alpha_mask` in the next cell.
g_width, g_height = 100, 100

In [None]:
# Gaussian data returned by naive_gaussian() (imported from utils.util_gau).
gaussians = naive_gaussian()
# Module-level boolean mask, allocated with shape (g_width, g_height).
alpha_mask = np.zeros((g_width, g_height), dtype=bool)
scale_modifier = 1.0  # NOTE(review): not referenced in the notebook cells visible here

# Iterate over the gaussians and create Gaussian objects
# (one Gaussian per row of the xyz / scale / rot / opacity / sh arrays).
gaussian_objects = []
for (pos, scale, rot, opacity, sh) in zip(gaussians.xyz, gaussians.scale, gaussians.rot, gaussians.opacity, gaussians.sh):
    gau = Gaussian(pos, scale, rot, opacity, sh)
    gaussian_objects.append(gau)

## Define some classes for optimization

In [None]:
def plot_conics_and_bbs(gaussian_objects, camera: Camera, color: str='blue'):
    """Outline every Gaussian's projected ellipse and bounding box on the current axes.

    For each Gaussian the conic ``(A, B, C)`` and its bounding boxes are obtained
    from ``get_conic_and_bb(camera, optimal=True)``; Gaussians whose conic is
    ``None`` are skipped. The contour drawn is the zero level set of
    ``sqrt(A*x**2 + 2*B*x*y + C*y**2) - 3``, i.e. the ellipse at Mahalanobis
    radius 3 (a "3-sigma" outline when the conic is the precision matrix).

    Parameters
    ----------
    gaussian_objects : iterable of Gaussian
        Gaussians to outline (each element is asserted to be a ``Gaussian``).
    camera : Camera
        Camera used for the projection and the NDC -> pixel conversion.
    color : str
        Edge colour of the bounding-box rectangles. The ellipse contours use
        matplotlib's default contour colours.
    """
    axes = plt.gca()
    n_samples = 100  # grid resolution per axis for the contour

    for gaussian in gaussian_objects:
        assert isinstance(gaussian, Gaussian)
        conic, bbox_cam, bbox_ndc = gaussian.get_conic_and_bb(camera, optimal=True)
        if conic is None:
            continue

        A, B, C = conic

        # Camera-space grid; bbox_cam is already centred on the Gaussian and in the
        # scale the conic expects.
        cam_x = np.linspace(bbox_cam[0][0], bbox_cam[1][0], n_samples)
        cam_y = np.linspace(bbox_cam[1][1], bbox_cam[2][1], n_samples)
        X, Y = np.meshgrid(cam_x, cam_y)
        F = np.sqrt(A*X**2 + 2*B*X*Y + C*Y**2) - 3.00

        # Matching pixel-space grid, used only to position the contour on screen.
        # NOTE(review): the pixel y range runs from corner 2 to corner 1 while the
        # camera y range runs from corner 1 to corner 2; plot_opacity_v maps both in
        # the same direction -- confirm which orientation is intended.
        bbox_screen = camera.ndc_to_pixel(bbox_ndc)
        px_x = np.linspace(bbox_screen[0][0], bbox_screen[1][0], n_samples)
        px_y = np.linspace(bbox_screen[2][1], bbox_screen[1][1], n_samples)
        X_px, Y_px = np.meshgrid(px_x, px_y)
        plt.contour(X_px, Y_px, F, [0.0])

        # Bounding-box rectangle: anchored at corner 0, width from corner 1,
        # height from corner 2.
        ul, ur, lr = bbox_screen[0, :2], bbox_screen[1, :2], bbox_screen[2, :2]
        box_width = ur[0] - ul[0]
        box_height = lr[1] - ur[1]
        axes.add_patch(
            plt.Rectangle((ul[0], ul[1]), box_width, box_height, fill=False, alpha=1., color=color)
        )


In [None]:
# Parameters of the random debug scene.
loc_scaler, scale_scaler, rot_scaler = 1, .03, 10
n_gaussians = 30
loc_bias = 0.  # alternative: np.ones(3,) * -.1

# Arguments follow Gaussian(pos, scale, rot, opacity, sh).
gaussians_debug = [
    Gaussian(
        (np.random.rand(3, )-.5 + loc_bias)*2*loc_scaler,
        np.random.rand(3)*scale_scaler,
        np.random.rand(4)*rot_scaler,
        np.array([1.]),
        np.array([ 1.772484, -1.772484,  1.772484]),
    )
    for _ in range(n_gaussians)
]

fig = plt.figure()
ax = plt.gca()
camera = Camera(100, 100)
plot_conics_and_bbs(gaussians_debug, camera)

# Segment an empty 100x100 image that covers all debug Gaussians.
image = SubImage(np.zeros((100, 100)), (0, 0), (100, 100), camera)
pset = PrimitiveSet(gaussians_debug)
sset = PrimitiveSubset(pset, list(range(len(gaussians_debug))))
segmenter = IterativeImageSegmenter(sset, image, camera, thresh=3)

# Perform seven cuts and draw each splitting line, labelled with its cut index.
for i in range(7):
    bx, by = segmenter.cut(i)
    ul, lr = segmenter.cuts[-1]['corners']
    (ymin, xmin), (ymax, xmax) = ul, lr
    if bx is not None:
        # split at x = bx, spanning the cut region's y range
        plt.plot([bx, bx], [ymin, ymax])
        plt.text(bx, (ymin+ymax)/3*2, str(i))
    else:
        # split at y = by, spanning the cut region's x range
        plt.plot([xmin, xmax], [by, by])
        plt.text((xmax+xmin)/3*2, by, str(i))

plt.xlim([0, camera.w])
plt.ylim([0, camera.h])
plt.grid(True)
plt.show()


In [None]:
'''def subroutine(tpl, A, B, C, y_iter, opacity, alphas, bitmap, color, alpha_thresh, w, h):
    """This function was originally made so that I could parallellize plot_opacity, which turned out to be slower due
    to overhead computations, but I still kept this function to "refactor" the plot_opacity function"""
    # time.sleep(random.randint(0, 1000)/1000)
    if alpha_thresh is None: alpha_thresh = np.inf
    x, x_cam = tpl
    x = min([alphas.shape[1]-1, max([x, 0])])
    alphas = alphas[:, x]
    bm = bitmap[:, x]
    if x < 0 or x >= w:
        return bm, alphas
    for y, y_cam in y_iter: 
        if y < 0 or y >= h or alpha_mask[x, y]:
            continue

        # Gaussian is typically calculated as f(x, y) = A * exp(-(a*x^2 + 2*b*x*y + c*y^2))
        power = -(A*x_cam**2 + C*y_cam**2)/2.0 - B * x_cam * y_cam # TODO: can be better by just computing the range for y up front
        if power > 0.0:
            continue

        alpha = opacity * np.exp(power)
        alpha = min(0.99, alpha) 
        # if opacity < 1.0 / 255.0:
        #     continue

        # Set the pixel color to the given color and opacity
        # Do alpha blending using "over" method 
        old_alpha = alphas[y]
        new_alpha = alpha + old_alpha * (1.0 - alpha) # "over" = old_alpha + alpha * (1-old_alpha) >= max(alpha, old_alpha) - can mark that in global object
        alphas[y] = new_alpha
        bm[y, :] = (color[0:3]) * alpha + bm[y, :] * (1.0 - alpha) # TODO: looks wrong, no?
        if alpha_thresh is not None and new_alpha > alpha_thresh: 
            alpha_mask[x, y] = True 
            continue 
    return bm, alphas

def plot_opacity(gaussian: Gaussian, camera: Camera, bitmap: np.ndarray, alphas: np.ndarray, alpha_thresh: float=None, responsible_range=None):
    """Compute the opacity of a gaussian given the camera"""
    shp = bitmap.shape
    w, h = camera.w, camera.h
    conic, bboxsize_cam, bbox_ndc = gaussian.get_conic_and_bb(camera) # different bounding boxes (active areas for gaussian)

    A, B, C = conic # precision matrix is (A, B; B, C)

    h, w = bitmap.shape[:2]
    bbox_screen = camera.ndc_to_pixel(bbox_ndc, w, h)
    
    if np.any(np.isnan(bbox_screen)):
        return

    ul = bbox_screen[0,:2] # Bounding box vertices 
    ur = bbox_screen[1,:2]
    lr = bbox_screen[2,:2]
    ll = bbox_screen[3,:2]
    
    y1 = int(np.floor(ul[1]))
    y2 = int(np.ceil(ll[1]))
    
    x1 = int(np.floor(ul[0]))
    x2 = int(np.ceil(ur[0]))
    nx = x2 - x1
    ny = y2 - y1

    # Extract out inputs for the gaussian
    coordxy = bboxsize_cam
    x_cam_1 = coordxy[0][0]   # ul
    x_cam_2 = coordxy[1][0]   # ur
    y_cam_1 = coordxy[1][1]   # ur (y)
    y_cam_2 = coordxy[2][1]   # lr

    camera_dir = gaussian.pos - camera.position
    camera_dir = camera_dir / np.linalg.norm(camera_dir) # normalized camera viewing direction
    color = gaussian.get_color(camera_dir)
    if responsible_range is not None:
        x_iter = [(x, x_cam) for x, x_cam in zip(range(x1, x2), np.linspace(x_cam_1, x_cam_2, nx)) if x in responsible_range]
    else: x_iter = zip(range(x1, x2), np.linspace(x_cam_1, x_cam_2, nx))
    for tpl in x_iter: # TODO: better to provide this range explicitly or for each x computing the relevant y's
        x, cam = tpl
        x = min(bitmap.shape[1]-1, x)
        tpl = (x, cam)
        bm, al = subroutine(tpl, A, B, C, zip(range(y1, y2), np.linspace(y_cam_1, y_cam_2, ny)), 
                            opacity, alphas, bitmap, color, alpha_thresh, w=w, h=h)
        bitmap[:, x] = bm 
        # print(f'{bm.sum()=}')
        alphas[:, x] = al
    # print(f'{bitmap.max()=}, {bitmap.sum()=}. {alphas.max()=}, {alphas.sum()=}')
    return bitmap, alphas
'''
def plot_opacity_v(gaussian: Gaussian, camera: Camera, bitmap: np.ndarray, alphas: np.ndarray, alpha_thresh: float=None):
    """Splat a single Gaussian into ``bitmap`` / ``alphas`` with vectorised blending.

    The Gaussian's screen-space bounding box is clipped to the image, the
    Gaussian is evaluated on the visible pixels only, and the result is
    accumulated front-to-back: every pixel receives
    ``(1 - alphas) * alpha * color``. ``bitmap`` therefore holds
    alpha-premultiplied colour and ``alphas`` the accumulated coverage.

    Parameters
    ----------
    gaussian : Gaussian
        Primitive to draw.
    camera : Camera
        Camera used for projection and for the NDC -> pixel conversion.
    bitmap : np.ndarray
        Colour buffer of shape ``(h, w, 3)``; updated in place.
    alphas : np.ndarray
        Accumulated-alpha buffer of shape ``(h, w)``; updated in place.
    alpha_thresh : float, optional
        Accepted for interface compatibility; not used by this implementation.

    Returns
    -------
    tuple
        ``(bitmap, alphas)``. The pair is returned on every path, including
        when the Gaussian is skipped, so callers can always unpack the result.
    """
    conic, bboxsize_cam, bbox_ndc = gaussian.get_conic_and_bb(camera, optimal=True)
    if conic is None:
        # get_conic_and_bb can report "no conic" (plot_conics_and_bbs skips those too).
        return bitmap, alphas

    h, w = bitmap.shape[:2]
    bbox_screen = camera.ndc_to_pixel(bbox_ndc, w, h)

    # NaN or infinite corners cannot be rasterised (int() of either raises); leave the
    # buffers untouched but still return them.
    if not np.all(np.isfinite(bbox_screen)):
        return bitmap, alphas

    ul = bbox_screen[0, :2]  # bounding-box vertices
    ur = bbox_screen[1, :2]
    ll = bbox_screen[3, :2]

    # Unclipped pixel extent of the bounding box ...
    x_lo, x_hi = int(np.floor(ul[0])), int(np.ceil(ur[0]))
    y_lo, y_hi = int(np.floor(ul[1])), int(np.ceil(ll[1]))
    # ... and the part of it that lies inside the image.
    x1, x2 = max(x_lo, 0), min(x_hi, w)
    y1, y2 = max(y_lo, 0), min(y_hi, h)
    nx = x2 - x1
    ny = y2 - y1
    if nx <= 0 or ny <= 0:
        return bitmap, alphas

    A, B, C = conic  # precision matrix is (A, B; B, C)

    # Camera-space extent of the (unclipped) bounding box.
    coordxy = bboxsize_cam
    x_cam_1 = coordxy[0][0]   # ul
    x_cam_2 = coordxy[1][0]   # ur
    y_cam_1 = coordxy[1][1]   # ur (y)
    y_cam_2 = coordxy[2][1]   # lr

    def visible_samples(cam_from, cam_to, px_lo, px_hi, px_start, px_stop):
        """Camera coordinates of pixels [px_start, px_stop) of a box spanning [px_lo, px_hi)."""
        n_full = px_hi - px_lo
        step = (cam_to - cam_from) / (n_full - 1) if n_full > 1 else 0.0
        return cam_from + step * (np.arange(px_start, px_stop) - px_lo)

    # The camera-space range belongs to the *whole* box. Sampling it with only the
    # clipped pixel count would squeeze a partly off-screen Gaussian into the visible
    # region, so the visible pixels are mapped through the unclipped extent instead.
    x_samples = visible_samples(x_cam_1, x_cam_2, x_lo, x_hi, x1, x2)
    y_samples = visible_samples(y_cam_1, y_cam_2, y_lo, y_hi, y1, y2)

    opacity = gaussian.opacity

    camera_dir = gaussian.pos - camera.position
    camera_dir = camera_dir / np.linalg.norm(camera_dir)  # normalized camera viewing direction
    color = gaussian.get_color(camera_dir)

    y_cam, x_cam = np.meshgrid(y_samples, x_samples, indexing='ij')  # both (ny, nx)
    power = -(A*x_cam**2 + C*y_cam**2)/2.0 - B * x_cam * y_cam
    alpha_ = opacity * np.exp(power)

    # Front-to-back "over": the new contribution is attenuated by what is already there.
    region = (slice(y1, y2), slice(x1, x2))
    weight = (1 - alphas[region]) * alpha_
    bitmap[region] = bitmap[region] + weight.reshape(ny, nx, 1) * color[0:3].reshape(1, 1, -1)
    alphas[region] = alphas[region] + weight
    return bitmap, alphas

# Random debug scene used to exercise plot_opacity_v.
loc_scaler, scale_scaler, rot_scaler = 1, .05, 10
n_gaussians = 20
loc_bias = 0.  # alternative: np.ones(3,) * -.1

# Arguments follow Gaussian(pos, scale, rot, opacity, sh).
gaussians_debug = [
    Gaussian(
        (np.random.rand(3, )-.5 + loc_bias)*2*loc_scaler,
        np.random.rand(3)*scale_scaler,
        np.random.rand(4)*rot_scaler,
        np.array([1.]),
        np.array([ 1.772484, -1.772484,  1.772484]),
    )
    for _ in range(n_gaussians)
]
gaussian_objects = gaussians_debug

h, w = 300, 400
camera = Camera(h, w)
# Order in which the Gaussians are splatted: ascending get_depth.
indices = np.argsort([g.get_depth(camera) for g in gaussian_objects])

# Colour buffer (h, w, 3) and accumulated-alpha buffer (h, w), both float32.
bitmap = np.zeros((h, w, 3), dtype=np.float32)
alphas = np.zeros((h, w), dtype=np.float32)

# Module-level mask; note that it is allocated as (w, h).
alpha_mask = np.zeros((w, h), dtype=bool)

plt.figure(figsize=(6, 6))
for idx in indices:
    bitmap, alphas = plot_opacity_v(gaussian_objects[idx], camera, bitmap, alphas)
print(f'after execution, {bitmap.max()=}')

# Show the rendered image.
plt.imshow(bitmap, vmin=0, vmax=1.0)
plt.show()

In [None]:
# define utility functions for parallel computation
def helper(rng, indices, camera, bitmap, alphas, alpha_thresh):
    """Render the Gaussians in ``indices`` and return the image columns listed in ``rng``.

    Parameters
    ----------
    rng : sequence of int
        Column indices this call is responsible for. If it is empty, nothing is
        rendered and ``None`` is returned.
    indices : iterable of int
        Indices into the module-level ``gaussian_objects`` list, in drawing order.
    camera, bitmap, alphas, alpha_thresh
        Forwarded to ``plot_opacity_v``.

    Returns
    -------
    np.ndarray or None
        ``bitmap[:, x]`` for every ``x`` in ``rng``, stacked along axis 0.

    Notes
    -----
    The previous implementation called ``plot_opacity``, which only exists inside a
    disabled string literal in this notebook, so every call raised ``NameError``.
    ``plot_opacity_v`` has no per-column restriction: each Gaussian is rendered over
    its whole bounding box and only the requested columns are extracted afterwards.
    """
    if not rng:
        return
    for idx in tqdm(indices):
        rendered = plot_opacity_v(gaussian_objects[idx], camera, bitmap, alphas, alpha_thresh)
        if rendered is not None:  # tolerate a renderer that returns nothing for a skipped Gaussian
            bitmap, alphas = rendered
    return np.stack([bitmap[:, x] for x in rng], axis=0)

def plot_model_par(camera, gaussian_objects: List[Gaussian], alpha_thresh: float=None, n_threads:int=1):
    """Render the scene column-wise, optionally spread over ``n_threads`` worker processes.

    Worker ``k`` is responsible for the image columns ``k, k + n_threads, ...``.
    Each worker renders into its own copy of the buffers via ``helper`` and
    returns only its columns, which are then written back into one image.

    Parameters
    ----------
    camera : Camera
        Provides ``w``, ``h`` and the depth ordering via ``Gaussian.get_depth``.
    gaussian_objects : list of Gaussian
        Used for the depth sort. NOTE(review): ``helper`` looks Gaussians up in the
        module-level ``gaussian_objects``, so this argument must be that same list.
    alpha_thresh : float, optional
        Forwarded to ``helper``.
    n_threads : int
        Number of worker processes; 1 renders in the current process.

    Returns
    -------
    np.ndarray
        ``(h, w, 3)`` float32 image.

    Raises
    ------
    ValueError
        If ``n_threads`` is smaller than 1.
    """
    if n_threads < 1:
        raise ValueError(f'n_threads must be at least 1, got {n_threads}')

    print('Sorting the gaussians by depth')
    indices = np.argsort([gau.get_depth(camera) for gau in gaussian_objects])  # fast
    w, h = camera.w, camera.h

    print('Plotting with', len(gaussian_objects), 'gaussians')
    bitmap = np.zeros((h, w, 3), np.float32)
    alphas = np.zeros((h, w), np.float32)
    # Interleaved column ownership: one range per worker.
    ranges = [range(offset, w, n_threads) for offset in range(n_threads)]

    # Same call for the serial and the parallel path; only the column range varies.
    render_columns = functools.partial(
        helper,
        indices=indices,
        alpha_thresh=alpha_thresh,
        camera=camera,
        bitmap=bitmap,
        alphas=alphas,
    )
    if n_threads > 1:
        with mp.Pool(n_threads) as pool:
            results = pool.map(render_columns, ranges)
    else:
        results = [render_columns(ranges[0])]

    # Each result has shape (len(rng), h, 3); write it back as columns of the image.
    for res, rng in zip(results, ranges):
        if res is None:  # empty column range, e.g. more workers than columns
            continue
        bitmap[:, rng.start:rng.stop:rng.step] = np.swapaxes(res, 0, 1)

    return bitmap

def helper_primitives(indices: List[int], camera, bitmap, alphas, alpha_thresh):
    """Render one chunk of Gaussians into the given buffers.

    Parameters
    ----------
    indices : list of int
        Indices into the module-level ``gaussian_objects`` list, in drawing order.
    camera : Camera
        Forwarded to ``plot_opacity_v``.
    bitmap : np.ndarray
        ``(h, w, 3)`` colour buffer.
    alphas : np.ndarray
        ``(h, w)`` accumulated-alpha buffer; modified in place.
    alpha_thresh : float or None
        Forwarded to ``plot_opacity_v``.

    Returns
    -------
    tuple
        ``(bitmap, alphas)`` after all Gaussians of the chunk have been drawn.
    """
    # The module-level alpha_mask is allocated with shape (w, h) while the alpha
    # buffer is (h, w), so the mask has to be transposed before it is added.
    # Masked pixels start at full coverage and therefore receive no further colour.
    alphas += alpha_mask.T.astype(float)
    for idx in indices:
        rendered = plot_opacity_v(gaussian_objects[idx], camera, bitmap, alphas, alpha_thresh)
        if rendered is not None:  # tolerate a renderer that returns nothing for a skipped Gaussian
            bitmap, alphas = rendered
    return bitmap, alphas

def plot_model_par_primitives(camera, gaussian_objects: List[Gaussian], alpha_thresh: float=None, n_threads:int=1):
    """Render depth-ordered chunks of Gaussians in parallel and return one layer per chunk.

    The depth-sorted indices are cut into consecutive chunks (about three per
    worker). Every chunk is rendered by ``helper_primitives`` starting from empty
    buffers, so the caller has to composite the returned layers afterwards.

    Parameters
    ----------
    camera : Camera
        Provides ``w``, ``h`` and the depth ordering via ``Gaussian.get_depth``.
    gaussian_objects : list of Gaussian
        Used for the depth sort. NOTE(review): ``helper_primitives`` looks Gaussians
        up in the module-level ``gaussian_objects``, so this must be that same list.
    alpha_thresh : float, optional
        Forwarded to ``helper_primitives``.
    n_threads : int
        Number of worker processes.

    Returns
    -------
    list of (bitmap, alphas)
        One pair per chunk, in ascending depth order. The list always contains
        at least one pair, so callers may index ``[0]`` even for an empty scene.

    Raises
    ------
    ValueError
        If ``n_threads`` is smaller than 1.
    """
    if n_threads < 1:
        raise ValueError(f'n_threads must be at least 1, got {n_threads}')

    print('Sorting the gaussians by depth')
    indices = np.argsort([gau.get_depth(camera) for gau in gaussian_objects])  # fast
    w, h = camera.w, camera.h

    print('Plotting with', len(gaussian_objects), 'gaussians')
    bitmap = np.zeros((h, w, 3), np.float32)
    alphas = np.zeros((h, w), np.float32)

    # Chunk size: never 0, otherwise small scenes (fewer than 3*n_threads Gaussians)
    # fail with a division by zero while chunking.
    skip = max(1, len(indices) // (3*n_threads))
    gsns = [indices[start:start+skip] for start in range(0, len(indices), skip)]
    if not gsns:
        gsns = [indices]  # empty scene: still produce one (blank) layer

    render_chunk = functools.partial(
        helper_primitives,
        alpha_thresh=alpha_thresh,
        camera=camera,
        bitmap=bitmap,
        alphas=alphas,
    )
    with mp.Pool(n_threads) as pool:
        # imap keeps the chunk order, which is the depth order needed for compositing.
        results = list(tqdm(pool.imap(render_chunk, gsns), total=len(gsns)))
    return results


In [None]:
# Load the trained point cloud; the base directory comes from the MODEL_PATH setting
# read through decouple's `config`.
model = load_ply(str(Path(config('MODEL_PATH'))/'debug/point_cloud/iteration_30000/point_cloud.ply'))
from tqdm import tqdm

print('Loading gaussians ...')
# Rebinds the module-level `gaussian_objects`, which helper() and helper_primitives()
# read as a global when rendering.
gaussian_objects = []
(h, w) = (100, 100)

# One Gaussian per row of the loaded xyz / scale / rot / opacity / sh arrays.
for (pos, scale, rot, opacity, sh) in tqdm(zip(model.xyz, model.scale, model.rot, model.opacity, model.sh)):
    gaussian_objects.append(Gaussian(pos, scale, rot, opacity, sh))

In [None]:
# Positions of (up to) the first 1000 loaded Gaussians, stacked along a new last axis.
sampled_poss = np.stack([o.pos for o in gaussian_objects[:1000]], axis=-1)
# Average over that last axis, i.e. the mean position of the sample (cell output).
sampled_poss.mean(axis=-1)

In [None]:
(w, h) = (400, 400)
alpha_thresh = .95
# Module-level mask read by helper_primitives; allocated as (w, h).
alpha_mask = np.zeros((w, h), dtype=bool)
# target = (-.15, 1.5, .7) 
# # offset = (.05, -.5, -.7)
# norm = np.sqrt(.05**2+.5**2+.7**2)
# offset = np.array([
#     1, 
#     1, 
#     1, 
# ])
# offset = offset / np.linalg.norm(offset)
# position = tuple((t+o for t, o in zip(target, offset)))

# camera = Camera(h, w, position=position, target=target)
camera = Camera(h, w, position=(-.2, 2., 1.4), target=(-.15, 1.5, .7))
# camera = Camera(h, w, position=(-.1, 1., 0), target=(-.15, 1.5, .7)) # target is pos for one of the gaussians
bitmap_parts = plot_model_par_primitives(camera, gaussian_objects, alpha_thresh=alpha_thresh, n_threads=20)

# Composite the per-chunk layers front to back (chunks arrive in ascending depth order).
# Each layer's colour is already alpha-premultiplied by plot_opacity_v, so it is only
# attenuated by the coverage accumulated so far. Multiplying by the layer's own alpha
# again would apply that alpha twice and darken the image.
bitmap = np.zeros(bitmap_parts[0][0].shape)
alpha = np.zeros(bitmap_parts[0][1].shape)
for bm, al in bitmap_parts:
    bitmap = bitmap + (1 - alpha)[..., np.newaxis] * bm
    alpha = alpha + (1 - alpha) * al

plt.figure(figsize=(6, 6))
plt.imshow(bitmap, vmin=0, vmax=1.0)
plt.show()

In [None]:
def draw_from_unit_view(dir, norm=1.):
    """Render the loaded model from a camera placed along ``dir`` and show the image.

    The camera sits at ``norm * dir / |dir|`` (measured from the world origin, not
    from the target) and looks at the fixed target ``(-.15, 1.5, 1)``. The image is
    400x400 and is rendered with ``plot_model_par_primitives`` using the
    module-level ``gaussian_objects``.

    Parameters
    ----------
    dir : array-like of 3 floats
        Direction of the camera position; lists and tuples are accepted as well
        as arrays.
    norm : float
        Distance of the camera from the origin.
    """
    (w, h) = (400, 400)
    alpha_thresh = None
    target = (-.15, 1.5, 1)

    # Convert first: dividing a plain list or tuple by a scalar would raise.
    direction = np.asarray(dir, dtype=float)
    position = direction / np.linalg.norm(direction) * norm

    camera = Camera(h, w, position=position, target=target)
    bitmap_parts = plot_model_par_primitives(camera, gaussian_objects, alpha_thresh=alpha_thresh, n_threads=24)

    # Composite the per-chunk layers front to back. The layers arrive in ascending depth
    # order and the accumulation below attenuates each new layer by the coverage of the
    # layers already composited, so they must be visited in that order (not reversed).
    # Layer colour is already alpha-premultiplied by plot_opacity_v, so it must not be
    # multiplied by the layer alpha again.
    bitmap = np.zeros(bitmap_parts[0][0].shape)
    alpha = np.zeros(bitmap_parts[0][1].shape)
    for bm, al in bitmap_parts:
        bitmap = bitmap + (1 - alpha)[..., np.newaxis] * bm
        alpha = alpha + (1 - alpha) * al

    plt.figure(figsize=(6, 6))
    plt.imshow(bitmap, vmin=0, vmax=1.0)
    plt.show()
# # draw from random directions
# n_trials = 10 
# # d = np.array([.05, -.5, -.7])# np.random.random(3) - .5
# d = np.array([-.1, 1., 0])
# print(f'{d=}')
# for i in range(n_trials):    
#     draw_from_unit_view(d, norm=(1+i)/6)