In [None]:
%cd /scratch_net/biwidl214/ecetin_scratch/GSCodec/notebooks/

In [None]:
import torch
from torch.utils.cpp_extension import load, load_inline, is_ninja_available
import os

os.environ['CUDA_LAUNCH_BLOCKING']='1'
os.environ['TORCH_USE_CUDA_DSA']='1'

In [None]:
%load_ext wurlitzer

In [None]:
print(torch.cuda.is_available())
print(is_ninja_available())

In [None]:
cuda_src = r'''
#include <torch/types.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <torch/extension.h>
#include <stdio.h>
#include <c10/cuda/CUDAException.h>

#include <vector_types.h>
#include <device_launch_parameters.h>

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a + b - 1) / b;}

__global__ void update_means_opacity_sh(
    float* weights, 
    float* mu, 
    float* opacity,
    float* sh,
    int* node_ids, 
    float* new_mu,
    float* new_node_total_weight,
    float* new_opacity,
    float* new_sh,
    int num_gaussians,
    int num_sh_coeffs
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= num_gaussians) return;

    int node_id = node_ids[idx];
    float weight = weights[idx];

    // Weighted mu for current Gaussian
    float mu_x = mu[idx * 3 + 0];
    float mu_y = mu[idx * 3 + 1];
    float mu_z = mu[idx * 3 + 2];

    atomicAdd(&new_mu[node_id * 3 + 0], weight * mu_x);
    atomicAdd(&new_mu[node_id * 3 + 1], weight * mu_y);
    atomicAdd(&new_mu[node_id * 3 + 2], weight * mu_z);

    // Weighted total weight for current node
    atomicAdd(&new_node_total_weight[node_id], weight);

    // Weighted opacity for current Gaussian
    float op = opacity[idx];
    atomicAdd(&new_opacity[node_id], weight * op);

    // Weighted shs for current Gaussian
    for (int k = 0; k < num_sh_coeffs; k++) {
        atomicAdd(
            &new_sh[node_id * num_sh_coeffs + k], 
            weight * sh[idx * num_sh_coeffs + k]
        );
    }
}

__global__ void divide_by_total_weight(
    float* new_mu,
    float* new_node_total_weight,
    float* new_opacity,
    float* new_sh,
    int num_new_nodes,
    int num_sh_coeffs
) {
    int node_id = blockIdx.x * blockDim.x + threadIdx.x;
    if (node_id >= num_new_nodes) return;

    float total_weight = new_node_total_weight[node_id];

    new_mu[node_id * 3 + 0] /= total_weight;
    new_mu[node_id * 3 + 1] /= total_weight;
    new_mu[node_id * 3 + 2] /= total_weight;

    new_opacity[node_id] /= total_weight;

    for (int k = 0; k < num_sh_coeffs; k++) {
        new_sh[node_id * num_sh_coeffs + k] /= total_weight;
    }
}


__global__ void update_covariances(
    float* weights, 
    float* mu, 
    float* sigma, 
    int* node_ids, 
    float* new_mu, 
    float* new_sigma,
    int num_gaussians
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= num_gaussians) return;

    int node_id = node_ids[idx];
    float weight = weights[idx];

    // Current mu and new_mu values
    float mu_x = mu[idx * 3 + 0];
    float mu_y = mu[idx * 3 + 1];
    float mu_z = mu[idx * 3 + 2];

    float new_mu_x = new_mu[node_id * 3 + 0];
    float new_mu_y = new_mu[node_id * 3 + 1];
    float new_mu_z = new_mu[node_id * 3 + 2];

    // Difference mu_i - new_mu
    float diff_x = mu_x - new_mu_x;
    float diff_y = mu_y - new_mu_y;
    float diff_z = mu_z - new_mu_z;

    // Sigma indices
    int sigma_base_idx = idx * 6;
    int new_sigma_base_idx = node_id * 6;

    // Update covariance matrix components
    atomicAdd(&new_sigma[new_sigma_base_idx + 0], weight * (sigma[sigma_base_idx + 0] + diff_x * diff_x)); // sxx
    atomicAdd(&new_sigma[new_sigma_base_idx + 1], weight * (sigma[sigma_base_idx + 1] + diff_x * diff_y)); // sxy
    atomicAdd(&new_sigma[new_sigma_base_idx + 2], weight * (sigma[sigma_base_idx + 2] + diff_x * diff_z)); // sxz
    atomicAdd(&new_sigma[new_sigma_base_idx + 3], weight * (sigma[sigma_base_idx + 3] + diff_y * diff_y)); // syy
    atomicAdd(&new_sigma[new_sigma_base_idx + 4], weight * (sigma[sigma_base_idx + 4] + diff_y * diff_z)); // syz
    atomicAdd(&new_sigma[new_sigma_base_idx + 5], weight * (sigma[sigma_base_idx + 5] + diff_z * diff_z)); // szz
}

__global__ void divide_sigma_by_total_weight(
    float* new_node_total_weight,
    float* new_sigma,
    int num_new_nodes
) {
    int node_id = blockIdx.x * blockDim.x + threadIdx.x;
    if (node_id >= num_new_nodes) return;

    float total_weight = new_node_total_weight[node_id];

    int sigma_base_idx = node_id * 6;

    new_sigma[sigma_base_idx + 0] /= total_weight;
    new_sigma[sigma_base_idx + 1] /= total_weight;
    new_sigma[sigma_base_idx + 2] /= total_weight;
    new_sigma[sigma_base_idx + 3] /= total_weight;
    new_sigma[sigma_base_idx + 4] /= total_weight;
    new_sigma[sigma_base_idx + 5] /= total_weight;
}

// Host function to launch kernels
// Returns a tuple of tensors containing the updated means, opacities, spherical harmonics, and covariances
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
aggregate_gaussians(
    torch::Tensor weights, 
    torch::Tensor mu, 
    torch::Tensor opacity,
    torch::Tensor sh,
    torch::Tensor sigma, 
    torch::Tensor node_ids, 
    int num_gaussians,
    int num_new_nodes
) {

    // Assuming mu, sigma, and opacity have the same first dimension size as num_nodes
    // and sh has a second dimension for spherical harmonics coefficients.
    auto options = torch::TensorOptions().dtype(weights.dtype()).device(weights.device());

    int num_sh_coeffs = sh.size(1);

    // Initialize new tensors inside the function
    torch::Tensor new_mu = torch::zeros({num_new_nodes, mu.size(1)}, options);
    torch::Tensor new_node_total_weight = torch::zeros({num_new_nodes}, options);
    torch::Tensor new_opacity = torch::zeros({num_new_nodes}, options);
    torch::Tensor new_sh = torch::zeros({num_new_nodes, num_sh_coeffs}, options);
    torch::Tensor new_sigma = torch::zeros({num_new_nodes, sigma.size(1)}, options);

    // Kernel configurations
    dim3 blockSize(256);

    // Launch kernel to update means
    update_means_opacity_sh<<<((num_gaussians + blockSize.x - 1) / blockSize.x), blockSize>>>(
        weights.data_ptr<float>(), 
        mu.data_ptr<float>(), 
        opacity.data_ptr<float>(),
        sh.data_ptr<float>(),
        node_ids.data_ptr<int>(), 
        new_mu.data_ptr<float>(),
        new_node_total_weight.data_ptr<float>(),
        new_opacity.data_ptr<float>(),
        new_sh.data_ptr<float>(),
        num_gaussians,
        num_sh_coeffs
    );

    // Ensure all updates to new_mu are complete
    cudaDeviceSynchronize();

    // Divide by the total weight to normalize the weights
    divide_by_total_weight<<<((num_gaussians + blockSize.x - 1) / blockSize.x), blockSize>>>(
        new_mu.data_ptr<float>(), 
        new_node_total_weight.data_ptr<float>(), 
        new_opacity.data_ptr<float>(),
        new_sh.data_ptr<float>(),
        num_new_nodes,
        num_sh_coeffs
    );

    // Ensure all updates to new_mu are complete
    cudaDeviceSynchronize();

    // Launch kernel to update covariances
    update_covariances<<<((num_gaussians + blockSize.x - 1) / blockSize.x), blockSize>>>(
        weights.data_ptr<float>(), 
        mu.data_ptr<float>(), 
        sigma.data_ptr<float>(), 
        node_ids.data_ptr<int>(), 
        new_mu.data_ptr<float>(), 
        new_sigma.data_ptr<float>(),
        num_gaussians
    );

    // Ensure all updates to new_sigma are complete
    cudaDeviceSynchronize();

    // Divide the covariance by total weight to normalize the weights
    divide_sigma_by_total_weight<<<((num_gaussians + blockSize.x - 1) / blockSize.x), blockSize>>>(
        new_node_total_weight.data_ptr<float>(), 
        new_sigma.data_ptr<float>(),
        num_new_nodes
    );

    return std::make_tuple(new_mu, new_sigma, new_opacity, new_sh);
}
'''

In [None]:
cpp_src = r'''
#include <torch/extension.h>
#include <cstdio>
#include <tuple>
#include <string>

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
aggregate_gaussians(
    torch::Tensor weights, 
    torch::Tensor mu, 
    torch::Tensor opacity,
    torch::Tensor sh,
    torch::Tensor sigma, 
    torch::Tensor node_ids, 
    int num_gaussians,
    int num_new_nodes
);
'''

In [None]:
module = load_inline(
    cuda_sources=[cuda_src], cpp_sources=[cpp_src], 
    functions=["aggregate_gaussians"],
    build_directory="aggregate",
    extra_cuda_cflags=[],
    verbose=True, name="my_cuda_extension",
)

In [None]:
%cd /scratch_net/biwidl214/ecetin_scratch/GSCodec

In [None]:
from submodules.octree_generation.jit_setup import setup
setup()

In [None]:
from typing import Tuple
import octree_generation

class OctreeGenerator(torch.autograd.Function):
    def __init__(self):
        super(OctreeGenerator, self).__init__()

    @staticmethod
    def forward(ctx, points3D: torch.Tensor, max_depth: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generate an octree from a set of 3D points. The octree is axis-aligned and
        the bounding box is computed from the points. The octree is generated with a
        maximum depth of `max_depth`.
        
        Args:
            points3D (torch.Tensor): A tensor of shape (N, 3) representing the 3D points.
            max_depth (int): The maximum depth of the octree.
            
        Returns:
            torch.Tensor: A tensor of shape (N, 10) representing the octant each Gaussian
                belongs to in every level of the octree.
            torch.Tensor: A tensor of shape (N, `max_depth`) representing the node ids 
                of the octree nodes that contain the points.
        """
        aabb_min = torch.min(points3D, dim=0)[0]
        aabb_max = torch.max(points3D, dim=0)[0]

        box_d = aabb_max - aabb_min
        box_min = aabb_min - 0.1 * box_d
        box_max = aabb_max + 0.1 * box_d

        # Returns point_level_bboxes, point_node_assignment
        return octree_generation.generate_octree(points3D, box_min, box_max, max_depth)

    @staticmethod
    def backward(ctx, grad_output):
        return None, None

In [None]:
def generate_octree(points: torch.Tensor, max_depth: int) -> Tuple[torch.Tensor, torch.Tensor]:
    return OctreeGenerator.apply(points, max_depth)

In [None]:
from config.build_config_spaces import ConfigReader

In [None]:
reader = ConfigReader("config/preset_configs/hierarchical_gaussian.yaml")
dataset = reader.dataset_config
pipeline = reader.pipeline_config
optimization = reader.optimization_config

In [None]:
from models.splatting.base_gaussian_model import BaseGaussianModel
from training.base_gaussian_trainer import BaseGaussianTrainer
from utils.general_utils import build_scales_and_quaternions_from_cov

In [None]:
dataset.data_path = dataset.data_path + "/tandt/train"
print(dataset.data_path)

In [None]:
gaussians = BaseGaussianModel(dataset=dataset)
trainer = BaseGaussianTrainer(
    dataset, optimization, pipeline, gaussians, logger=None, 
    checkpoint=None
)
trainer.restore("output/base/train/checkpoints/ckpt_30000.pth")

In [None]:
def assign_unique_values(data):
    # data should be a CUDA tensor of shape (N, 10)
    new_data = data.clone()
    N, C = data.shape
    for i in range(C):
        unique_values, new_indices = torch.unique(data[:, i], sorted=True, return_inverse=True)
        new_data[:, i] = new_indices  # Replace column with new indices

    return new_data

def calculate_weights(scales, opacity):
    radii = 3 * scales
    p = 1.6075
    a, b, c = radii.chunk(3, dim=1)
    # 4 * torch.pi * ((a**p * b**p + a**p * c**p + b**p * c**p) / 3)**(1/p)
    weights = (
        opacity * ((a**p * b**p + a**p * c**p + b**p * c**p)**(1/p))
    )
    return weights

def recursive_aggregation(
    weights: torch.Tensor, 
    mu: torch.Tensor, 
    opacity: torch.Tensor, 
    sh: torch.Tensor, 
    sigma: torch.Tensor, 
    node_ids: torch.Tensor, 
    min_level: int, 
    max_level: int,
    return_all_levels: bool = False
):

    assert min_level <= max_level, "Minimum level should be less than or equal to maximum level"

    current_mu = mu
    current_opacity = opacity
    current_sh = sh
    current_sigma = sigma
    current_weights = weights
    
    # Current mapping between Gaussians and lowest level nodes
    # current_node_ids = node_ids[:, max_level - 1]
    sorted_node_ids = node_ids

    if return_all_levels:
        new_gaussian_mus = torch.empty(0, 3).cuda()
        new_gaussian_scales = torch.empty(0, 3).cuda()
        new_gaussian_rotations = torch.empty(0, 4).cuda()
        new_gaussian_opacities = torch.empty(0, 1).cuda()
        new_gaussian_shs = torch.empty(0, sh.size(1)).cuda()

    for level in range(max_level - 1, min_level-1, -1):
        # Map node ids to dense structure instead of sparse one
        new_node_ids_assigned = sorted_node_ids[:, level].type(torch.int32)
        num_nodes_at_level = new_node_ids_assigned.max() + 1

        # Call a CUDA function to perform aggregation
        new_mu, new_sigma, new_opacity, new_sh, new_total_weights = module.aggregate_gaussians(
            current_weights,
            current_mu,
            current_opacity,
            current_sh,
            current_sigma,
            new_node_ids_assigned.cuda(),
            current_mu.shape[0],
            num_nodes_at_level
        )

        # Debugging: Check if the means are correct
        # uniques, counts = torch.unique(new_node_ids_assigned, return_counts=True)
        # mask = new_node_ids_assigned == uniques[torch.argmax(counts)]
        # print(current_mu[mask].shape)
        # print(current_weights[mask].shape)
        # print((current_mu[mask] * current_weights[mask]).sum(dim=0) / current_weights[mask].sum())
        # print(new_mu[uniques[torch.argmax(counts)]])

        current_mu = new_mu
        current_sigma = new_sigma
        current_opacity = new_opacity
        current_sh = new_sh

        current_scales, current_rotations = build_scales_and_quaternions_from_cov(current_sigma)

        # Append new mu and sigma
        if return_all_levels:
            new_gaussian_mus = torch.cat(
                (new_gaussian_mus, current_mu), dim=0)
            new_gaussian_scales = torch.cat(
                (new_gaussian_scales, current_scales), dim=0)
            new_gaussian_rotations = torch.cat(
                (new_gaussian_rotations, current_rotations), dim=0)
            new_gaussian_opacities = torch.cat(
                (new_gaussian_opacities, current_opacity[:, None]), dim=0)
            new_gaussian_shs = torch.cat(
                (new_gaussian_shs, current_sh), dim=0)

        current_weights = calculate_weights(current_scales, current_opacity[:, None])

        sorted_node_ids = sorted_node_ids[new_node_ids_assigned.sort()[1]]

        # Find where the value changes in the specified column
        unique_indices = torch.cat((torch.tensor([True]).cuda(), sorted_node_ids[1:, -1] != sorted_node_ids[:-1, -1]))
        sorted_node_ids = sorted_node_ids[unique_indices, :-1]

    if return_all_levels:
        return new_gaussian_mus, new_gaussian_scales, new_gaussian_rotations, new_gaussian_opacities, new_gaussian_shs
    else:
        return current_mu, current_scales, current_rotations, current_opacity, current_sh

# Example usage
octree_max_depth = 15
octree_min_depth = 5
mu = gaussians._xyz.data
sigma = gaussians.get_covariance()
opacity = gaussians.get_opacity
sh = gaussians.get_features
num_sh_coeffs = sh.size(1)

scales, quaternions = build_scales_and_quaternions_from_cov(sigma)

weights = calculate_weights(gaussians.get_scaling, opacity)

# 3-sigma Gaussian ellipsoid surface approximation
_, gaussian_node_assignments = generate_octree(
    points=gaussians._xyz, max_depth=octree_max_depth
)

# Node ids were unique for the overall tensor, now we make them unique per depth level
# gaussian_node_assignments = assign_unique_values(gaussian_node_assignments)
unique_per_col_gaussian_node_assignments = assign_unique_values(gaussian_node_assignments)

final_mu, final_scale, final_rotation, final_opacity, final_sh = recursive_aggregation(
    weights, mu, opacity, sh, sigma, unique_per_col_gaussian_node_assignments, 
    octree_min_depth, octree_max_depth, return_all_levels=True
)

In [None]:
print(final_mu.min(), final_mu.max())
print(final_scale.min(), final_scale.max())
print(final_rotation.min(), final_rotation.max())
print(final_opacity.min(), final_opacity.max())
print(final_sh.min(), final_sh.max())

In [None]:
import torch
import plotly.graph_objects as go

# Sample data in PyTorch tensor
points = final_mu[::10].detach().cpu()
print(points.shape[0])

# Convert PyTorch tensor to NumPy arrays
x = points[:, 0].numpy()
y = points[:, 1].numpy()
z = points[:, 2].numpy()

# Create a 3D scatter plot
fig = go.Figure(data=[go.Scatter3d(
    x=x,
    y=y,
    z=z,
    mode='markers',
    marker=dict(
        size=1,
        color=z,  # color points by Z value
        colorscale='Viridis',  # choose a colorscale
        opacity=0.8
    )
)])

# Add titles and labels
fig.update_layout(
    title='3D Point Cloud',
    scene=dict(
        xaxis_title='X Axis',
        yaxis_title='Y Axis',
        zaxis_title='Z Axis',
        camera=dict(
            eye=dict(x=1, y=1, z=0.5),  # Adjust x, y, z to change the viewpoint
            up=dict(x=0, y=0, z=1),     # Keeping the z-axis pointing up
            center=dict(x=0, y=0, z=0)  # Camera is looking at the origin
        )
    )
)

# Show plot
fig.show()


In [None]:
import torch
import plotly.graph_objects as go

# Sample data in PyTorch tensor
points = gaussians._xyz[::2].detach().cpu()
print(points.shape[0])

# Convert PyTorch tensor to NumPy arrays
x = points[:, 0].numpy()
y = points[:, 1].numpy()
z = points[:, 2].numpy()

# Create a 3D scatter plot
fig = go.Figure(data=[go.Scatter3d(
    x=x,
    y=y,
    z=z,
    mode='markers',
    marker=dict(
        size=1,
        color=z,  # color points by Z value
        colorscale='Viridis',  # choose a colorscale
        opacity=0.8
    )
)])

# Add titles and labels
fig.update_layout(
    title='3D Point Cloud',
    scene=dict(
        xaxis_title='X Axis',
        yaxis_title='Y Axis',
        zaxis_title='Z Axis',
        camera=dict(
            eye=dict(x=1, y=1, z=0.5),  # Adjust x, y, z to change the viewpoint
            up=dict(x=0, y=0, z=1),     # Keeping the z-axis pointing up
            center=dict(x=0, y=0, z=0)  # Camera is looking at the origin
        )
    )
)

# Show plot
fig.show()


In [1]:
%cd /scratch_net/biwidl214/ecetin_scratch/GSCodec

/scratch_net/biwidl214/ecetin_scratch/GSCodec


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [4]:
import torch
from submodules.gaussian_aggregation.jit_setup import setup as setup_gaussian_aggregation
from submodules.octree_generation.jit_setup import setup as setup_octree_generation

setup_octree_generation()
setup_gaussian_aggregation()

Using /home/ecetin/.cache/torch_extensions/py310_cu118 as PyTorch extensions root...
No modifications detected for re-loaded extension module octree_generation, skipping build step...
Loading extension module octree_generation...
Using /home/ecetin/.cache/torch_extensions/py310_cu118 as PyTorch extensions root...
No modifications detected for re-loaded extension module gaussian_aggregation, skipping build step...
Loading extension module gaussian_aggregation...


In [5]:
import gaussian_aggregation

class AggregationFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, weights, xyzs, opacities, shs, sigmas, node_ids):
        num_gaussians = xyzs.shape[0]
        num_nodes_at_level = node_ids.max() + 1

        new_xyzs, new_sigmas, new_opacities, new_shs, new_node_total_weights = (
            gaussian_aggregation.aggregate_gaussians_forward(
                weights, xyzs, opacities, shs, sigmas, node_ids, num_gaussians, num_nodes_at_level
            )
        )

        # Store for backward
        ctx.save_for_backward(
            weights, xyzs, new_xyzs, opacities, shs, sigmas, node_ids, new_node_total_weights
        )
        ctx.num_gaussians = num_gaussians
        ctx.num_nodes_at_level = num_nodes_at_level

        return new_xyzs, new_sigmas, new_opacities, new_shs

    @staticmethod
    def backward(ctx, grad_new_xyzs, grad_new_sigmas, grad_new_opacities, grad_new_shs):
        # Gradient calculations

        # Retrieve saved tensors
        (weights, mu, new_mu, opacity, sh, sigma, node_ids, node_total_weights) = ctx.saved_tensors

        num_gaussians = ctx.num_gaussians
        num_nodes_at_level = ctx.num_nodes_at_level

        # Calculate gradients for inputs based on grad_outputs
        grad_weights, grad_mu, grad_opacity, grad_sh, grad_sigma = (
            gaussian_aggregation.aggregate_gaussians_backward(
                grad_new_xyzs,
                grad_new_sigmas,
                grad_new_opacities,
                grad_new_shs,
                node_total_weights,
                weights,
                mu,
                new_mu,
                opacity,
                sh,
                sigma,
                node_ids,
                num_gaussians,
                num_nodes_at_level,
            )
        )

        return grad_weights, grad_mu, grad_opacity, grad_sh, grad_sigma, None

In [6]:
from typing import Tuple
import octree_generation

class OctreeGenerator(torch.autograd.Function):
    def __init__(self):
        super(OctreeGenerator, self).__init__()

    @staticmethod
    def forward(ctx, points3D: torch.Tensor, max_depth: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generate an octree from a set of 3D points. The octree is axis-aligned and
        the bounding box is computed from the points. The octree is generated with a
        maximum depth of `max_depth`.

        Args:
            points3D (torch.Tensor): A tensor of shape (N, 3) representing the 3D points.
            max_depth (int): The maximum depth of the octree.

        Returns:
            torch.Tensor: A tensor of shape (N, 10) representing the octant each Gaussian
                belongs to in every level of the octree.
            torch.Tensor: A tensor of shape (N, `max_depth`) representing the node ids
                of the octree nodes that contain the points.
        """
        aabb_min = torch.min(points3D, dim=0)[0]
        aabb_max = torch.max(points3D, dim=0)[0]

        box_d = aabb_max - aabb_min
        box_min = aabb_min - 0.1 * box_d
        box_max = aabb_max + 0.1 * box_d

        # Returns point_level_bboxes, point_node_assignment
        return octree_generation.generate_octree(points3D, box_min, box_max, max_depth)

    @staticmethod
    def backward(ctx, grad_output):
        return None, None


def generate_octree(points: torch.Tensor, max_depth: int) -> Tuple[torch.Tensor, torch.Tensor]:
    return OctreeGenerator.apply(points, max_depth)

In [7]:
from utils.general_utils import build_scales_and_quaternions_from_cov, build_covariance_from_scaling_rotation

def calculate_weights(scales, opacity):
    radii = 3 * scales
    p = 1.6075
    a, b, c = radii.chunk(3, dim=1)
    # 4 * torch.pi * ((a**p * b**p + a**p * c**p + b**p * c**p) / 3)**(1/p)
    weights = opacity * ((a**p * b**p + a**p * c**p + b**p * c**p) ** (1 / p))
    return weights


def assign_unique_values(data):
    # data should be a CUDA tensor of shape (N, 10)
    new_data = data.clone()
    N, C = data.shape
    for i in range(C):
        unique_values, new_indices = torch.unique(data[:, i], sorted=True, return_inverse=True)
        new_data[:, i] = new_indices  # Replace column with new indices

    return new_data


def aggregate_gaussians_recursively(
    weights: torch.Tensor,
    mu: torch.Tensor,
    opacity: torch.Tensor,
    sh: torch.Tensor,
    sigma: torch.Tensor,
    node_ids: torch.Tensor,
    min_level: int,
    max_level: int,
    return_all_levels: bool = False,
):

    assert min_level <= max_level, "Minimum level should be less than or equal to maximum level"

    current_mu = mu
    current_opacity = opacity
    current_sh = sh
    current_sigma = sigma
    current_weights = weights
    current_rotations = None
    current_scales = None

    # Current mapping between Gaussians and lowest level nodes
    # current_node_ids = node_ids[:, max_level - 1]
    sorted_node_ids = node_ids

    if return_all_levels:
        new_gaussian_mus = torch.empty(0, 3).cuda()
        # new_gaussian_scales = torch.empty(0, 3).cuda()
        # new_gaussian_rotations = torch.empty(0, 4).cuda()
        new_gaussian_sigmas = torch.empty(0, 6).cuda()
        new_gaussian_opacities = torch.empty(0, 1).cuda()
        new_gaussian_shs = torch.empty(0, sh.size(1) * 3).cuda()

    for level in range(max_level - 1, min_level - 1, -1):
        # Map node ids to dense structure instead of sparse one
        new_node_ids_assigned = sorted_node_ids[:, level].type(torch.int32)

        # Call function to perform aggregation
        new_mu, new_sigma, new_opacity, new_sh = AggregationFunction.apply(
            current_weights,
            current_mu,
            current_opacity,
            current_sh,
            current_sigma,
            new_node_ids_assigned,
        )

        # Debugging: Check if the means are correct
        # uniques, counts = torch.unique(new_node_ids_assigned, return_counts=True)
        # mask = new_node_ids_assigned == uniques[torch.argmax(counts)]
        # print(current_mu[mask].shape)
        # print(current_weights[mask].shape)
        # print((current_mu[mask] * current_weights[mask]).sum(dim=0) / current_weights[mask].sum())
        # print(new_mu[uniques[torch.argmax(counts)]])

        current_mu = new_mu
        current_sigma = new_sigma
        current_opacity = new_opacity
        current_sh = new_sh

        # Return covariance directly:
        current_scales, current_rotations = build_scales_and_quaternions_from_cov(current_sigma)

        # Append new mu and sigma
        if return_all_levels:
            new_gaussian_mus = torch.cat((new_gaussian_mus, current_mu), dim=0)
            new_gaussian_scales = torch.cat((new_gaussian_scales, current_scales), dim=0)
            new_gaussian_rotations = torch.cat((new_gaussian_rotations, current_rotations), dim=0)
            new_gaussian_sigmas = torch.cat((new_gaussian_sigmas, current_sigma), dim=0)
            new_gaussian_opacities = torch.cat(
                (new_gaussian_opacities, current_opacity[:, None]), dim=0
            )
            new_gaussian_shs = torch.cat((new_gaussian_shs, current_sh), dim=0)

        current_weights = calculate_weights(current_scales, current_opacity[:, None])

        sorted_node_ids = sorted_node_ids[new_node_ids_assigned.sort()[1]]

        # Find where the value changes in the specified column
        unique_indices = torch.cat(
            (torch.tensor([True]).cuda(), sorted_node_ids[1:, -1] != sorted_node_ids[:-1, -1])
        )
        sorted_node_ids = sorted_node_ids[unique_indices, :-1]

    # Returning current weights for debugging purposes
    if return_all_levels:
        return (
            new_gaussian_mus,
            new_gaussian_scales,
            new_gaussian_rotations,
            new_gaussian_sigmas,
            new_gaussian_opacities,
            new_gaussian_shs,
        )
    else:
        return (
            current_mu,
            current_scales,
            current_rotations,
            current_sigma,
            current_opacity,
            current_sh,
        )

In [None]:
N = 10
R = 50
xyzs = torch.randn((N, 3)).cuda() * 50
scales = torch.randn((N, 3)).cuda() * 20
rotations = torch.randn((N, 4)).cuda()
opacities = torch.sigmoid(torch.randn((N, 1)).cuda() * 10)
shs = torch.randn((N, 16, 3)).cuda()
sigmas = build_covariance_from_scaling_rotation(scales, 1.0, rotations)
weights = torch.ones(N, 1).cuda()

octree_max_depth = 1

# 3-sigma Gaussian ellipsoid surface approximation
_, gaussian_node_assignments = generate_octree(points=xyzs, max_depth=octree_max_depth)

# Node ids were unique for the overall tensor, now we make them unique per depth level
# gaussian_node_assignments = assign_unique_values(gaussian_node_assignments)
unique_per_col_gaussian_node_assignments = assign_unique_values(
    gaussian_node_assignments
).contiguous()

(
    current_mu,
    current_scales,
    current_rotations,
    current_sigma,
    current_opacity,
    current_sh,
) = aggregate_gaussians_recursively(
    weights,
    xyzs,
    opacities,
    shs,
    sigmas,
    unique_per_col_gaussian_node_assignments,
    0,
    octree_max_depth,
    return_all_levels=False,
)

In [104]:
new_means = torch.empty(0, 3).cuda()

for i in range(unique_per_col_gaussian_node_assignments.unique().shape[0]):

    mean = (xyzs[(unique_per_col_gaussian_node_assignments == i)[:, 0]] * weights[(unique_per_col_gaussian_node_assignments == i)[:, 0]]).sum(0) \
        / weights[(unique_per_col_gaussian_node_assignments == i)[:, 0]].sum()
    
    new_means = torch.cat((new_means, mean[None]), dim=0)

new_sigmas = torch.zeros((new_means.shape[0], 6)).cuda()

for i in range(unique_per_col_gaussian_node_assignments.unique().shape[0]):

    diffs = (xyzs[(unique_per_col_gaussian_node_assignments == i)[:, 0]] - new_means[i])
    diffs_x = diffs[:, 0]
    diffs_y = diffs[:, 1]
    diffs_z = diffs[:, 2]

    sigma = torch.zeros(6).cuda()

    for j in range(diffs.shape[0]):
        sigma[0] += weights[(unique_per_col_gaussian_node_assignments == i)[:, 0]][j, 0] * \
            (sigmas[j, 0] + diffs_x[j] * diffs_x[j]) / weights[(unique_per_col_gaussian_node_assignments == i)[:, 0]].sum()
        sigma[1] += weights[(unique_per_col_gaussian_node_assignments == i)[:, 0]][j, 0] * \
            (sigmas[j, 1] + diffs_x[j] * diffs_y[j]) / weights[(unique_per_col_gaussian_node_assignments == i)[:, 0]].sum()
        sigma[2] += weights[(unique_per_col_gaussian_node_assignments == i)[:, 0]][j, 0] * \
            (sigmas[j, 2] + diffs_x[j] * diffs_z[j]) / weights[(unique_per_col_gaussian_node_assignments == i)[:, 0]].sum()
        sigma[3] += weights[(unique_per_col_gaussian_node_assignments == i)[:, 0]][j, 0] * \
            (sigmas[j, 3] + diffs_y[j] * diffs_y[j]) / weights[(unique_per_col_gaussian_node_assignments == i)[:, 0]].sum()
        sigma[4] += weights[(unique_per_col_gaussian_node_assignments == i)[:, 0]][j, 0] * \
            (sigmas[j, 4] + diffs_y[j] * diffs_z[j]) / weights[(unique_per_col_gaussian_node_assignments == i)[:, 0]].sum()
        sigma[5] += weights[(unique_per_col_gaussian_node_assignments == i)[:, 0]][j, 0] * \
            (sigmas[j, 5] + diffs_z[j] * diffs_z[j]) / weights[(unique_per_col_gaussian_node_assignments == i)[:, 0]].sum()
        
    new_sigmas[i] = sigma

print(new_means)
print(new_sigmas)

tensor([[-1.1150, -0.9386, -1.2061],
        [ 1.2101, -1.3982, -1.2553],
        [-0.9554,  1.1580, -0.7895],
        [ 0.6118,  0.7821, -1.0802],
        [-0.9496, -1.0083,  0.4498],
        [ 0.6847, -0.9190,  0.4641],
        [-0.7937,  0.9383,  0.4655],
        [ 0.8571,  0.4280,  0.6931]], device='cuda:0')
tensor([[ 0.7775,  0.0474, -0.1793,  0.6820,  0.0602,  0.9855],
        [ 1.6438,  0.5553, -0.3258,  1.4254,  0.1036,  1.6146],
        [ 1.2818,  0.3129, -0.6469,  0.8555, -0.1525,  1.1421],
        [ 1.6533,  0.4284, -0.3207,  1.3396, -0.0149,  1.5847],
        [ 1.3409,  0.2237, -0.4443,  0.9981,  0.0648,  1.4182],
        [ 1.2297,  0.2280, -0.2533,  1.1439,  0.0292,  1.2440],
        [ 1.3946,  0.2321, -0.2262,  1.3414, -0.1972,  1.2725],
        [ 1.3676,  0.2132, -0.4252,  0.8227, -0.0582,  1.1683]],
       device='cuda:0')


In [105]:
sigmas

tensor([[ 1.8007e+00,  2.9974e-01, -8.2606e-01,  9.8344e-01, -2.6561e-01,
          1.0607e+00],
        [ 4.2579e-02,  1.1345e-01,  1.7128e-01,  8.6590e-01,  1.1469e+00,
          1.5522e+00],
        [ 1.7446e-01,  1.3222e-02,  3.2376e-02,  2.0183e-01,  7.8317e-02,
          4.3991e-02],
        [ 5.2014e-01, -1.0774e-01, -1.3292e-01,  3.5611e-02,  1.1307e-01,
          8.4713e-01],
        [ 6.2518e-01,  7.6846e-02, -1.4033e-01,  6.2591e-01, -2.2500e-01,
          3.4635e-01],
        [ 5.5271e-01, -6.5183e-02, -1.0296e-02,  1.8329e-01,  2.2725e-01,
          3.8100e-01],
        [ 4.0767e+00,  1.6006e+00, -2.8158e+00,  1.0559e+00, -1.2280e+00,
          2.7741e+00],
        [ 2.4729e+00,  1.3487e+00, -1.2702e+00,  8.7487e-01, -8.0113e-01,
          9.5965e-01],
        [ 8.3145e-03,  3.0544e-02,  4.7372e-03,  2.2030e-01,  1.0711e-01,
          7.8924e-02],
        [ 3.7630e+00,  1.5963e-01, -4.3913e-01,  3.1914e+00, -6.2918e-01,
          3.0263e+00],
        [ 1.4791e+00,  3.6491e

In [106]:
new_means

tensor([[-1.1150, -0.9386, -1.2061],
        [ 1.2101, -1.3982, -1.2553],
        [-0.9554,  1.1580, -0.7895],
        [ 0.6118,  0.7821, -1.0802],
        [-0.9496, -1.0083,  0.4498],
        [ 0.6847, -0.9190,  0.4641],
        [-0.7937,  0.9383,  0.4655],
        [ 0.8571,  0.4280,  0.6931]], device='cuda:0')

In [111]:
current_sigma

tensor([[ 0.6002,  0.0053,  0.4492,  0.5337,  0.0951,  0.9660],
        [ 1.3554,  0.3313,  0.2212,  1.4141,  0.0587,  0.9869],
        [ 0.6641, -0.0925, -0.1752,  1.1153,  0.0428,  0.5854],
        [ 1.9483,  0.2515,  0.2096,  2.1152, -0.2107,  1.5738],
        [ 0.7485,  0.0105, -0.1144,  1.3327,  0.3460,  1.3714],
        [ 1.2583,  0.0675, -0.3408,  1.1805,  0.0855,  1.5241],
        [ 1.0597,  0.0228, -0.0241,  1.1374, -0.0286,  1.1231],
        [ 1.8036, -0.3376,  0.1089,  1.7728,  0.0638,  2.1481]],
       device='cuda:0')

In [135]:
sigma = torch.zeros(6, 1).cuda()

group_id = 0
grouped_xyzs = xyzs[(unique_per_col_gaussian_node_assignments == group_id)[:, 0]]
grouped_sigmas = sigmas[(unique_per_col_gaussian_node_assignments == group_id)[:, 0]]
grouped_weights = weights[(unique_per_col_gaussian_node_assignments == group_id)[:, 0]]
for i in range(grouped_xyzs.shape[0]):
    sigma[0] += (grouped_weights[i] / grouped_weights.sum()) * \
        (grouped_sigmas[i, 0] + (grouped_xyzs[i] - new_means[group_id])[0] * (grouped_xyzs[i] - new_means[group_id])[0])
    sigma[1] += (grouped_weights[i] / grouped_weights.sum()) * \
        (grouped_sigmas[i, 1] + (grouped_xyzs[i] - new_means[group_id])[0] * (grouped_xyzs[i] - new_means[group_id])[1])
    sigma[2] += (grouped_weights[i] / grouped_weights.sum()) * \
        (grouped_sigmas[i, 2] + (grouped_xyzs[i] - new_means[group_id])[0] * (grouped_xyzs[i] - new_means[group_id])[2])
    sigma[3] += (grouped_weights[i] / grouped_weights.sum()) * \
        (grouped_sigmas[i, 3] + (grouped_xyzs[i] - new_means[group_id])[1] * (grouped_xyzs[i] - new_means[group_id])[1])
    sigma[4] += (grouped_weights[i] / grouped_weights.sum()) * \
        (grouped_sigmas[i, 4] + (grouped_xyzs[i] - new_means[group_id])[1] * (grouped_xyzs[i] - new_means[group_id])[2])
    sigma[5] += (grouped_weights[i] / grouped_weights.sum()) * \
        (grouped_sigmas[i, 5] + (grouped_xyzs[i] - new_means[group_id])[2] * (grouped_xyzs[i] - new_means[group_id])[2])

print(sigma.squeeze(-1))

tensor([0.6002, 0.0053, 0.4492, 0.5337, 0.0951, 0.9660], device='cuda:0')


In [None]:
# 2, 5