# Rendering for Gaussian Splatting

In this notebook, we will implement a rendering algorithm for [Gaussian Splatting](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/3d_gaussian_splatting_high.pdf).

The rendering algorithm we are going to implement takes projected Gaussian splats as input and returns an image. To extend the algorithm to 3D Gaussian splats, one would additionally need to implement a method for projecting 3D splats onto the camera.

The implementation will be based purely on PyTorch, whereas the common implementations rely on custom CUDA/OpenGL code for higher efficiency. This notebook partially adapts code from [this](https://github.com/hbb1/torch-splatting) implementation.

## Classes for Camera and 2D Point Cloud

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pylab as plt
import cv2

from IPython.display import clear_output

class Camera(nn.Module):
    """Stand-in camera that only carries the output image resolution.

    Args:
        width: image width in pixels (default 256).
        height: image height in pixels (default 256).
    """

    def __init__(self, width=256, height=256):
        super().__init__()
        # Use the requested resolution; previously both were hard-coded to 256
        # and the constructor arguments were silently ignored.
        self.image_width = width
        self.image_height = height

class PointCloud2D(nn.Module):
    """Learnable set of 2D Gaussian splats expressed in pixel coordinates.

    Trainable parameters (registered in this order):
        _means2d        [G, 2]  centres; column 0 spans the image width, column 1 the height
        _scaling        [G, 2]  raw per-axis variances, passed through softplus in forward()
        _rotation_angle [G]     rotation angle in radians
        _colors         [G, 3]  colours, returned without any activation
        _opacity        [G]     opacity logits, passed through a sigmoid in forward()
    Buffer (not optimized):
        _depths         [G]     random depths drawn uniformly from [0, 1)
    """

    def __init__(self, n_splats, camera):
        super().__init__()
        for name, value in self._initialize_parameters(n_splats, camera).items():
            self.register_parameter(name, nn.Parameter(value))
        # Depths stay fixed in this example, so they are a buffer, not a parameter.
        self.register_buffer('_depths', torch.rand(n_splats))

    def _initialize_parameters(self, n_splats, camera):
        """Draw random initial values for every trainable tensor, keyed by attribute name."""
        xs = camera.image_width * torch.rand(n_splats)
        ys = camera.image_height * torch.rand(n_splats)
        return {
            '_means2d': torch.stack([xs, ys], dim=-1),
            '_scaling': torch.randn(n_splats, 2) + 15.,
            '_rotation_angle': 2 * torch.pi * torch.rand(n_splats),
            '_colors': torch.rand(n_splats, 3),
            '_opacity': torch.randn(n_splats),
        }

    def forward(self):
        """Return the activated splat attributes.

        Returns a dict with 'means2d' [G, 2], 'cov2d' [G, 2, 2], 'colors' [G, 3],
        'opacity' [G] (sigmoid of the logits) and 'depths' [G].
        """
        phi = self._rotation_angle
        c, s = phi.cos(), phi.sin()
        # Row-wise, rot[g] == [[c, s], [-s, c]], i.e. a rotation by -phi. Because phi
        # is a learned parameter initialised uniformly in [0, 2*pi), the sign
        # convention does not matter.
        rot = torch.stack(
            [torch.stack([c, s], dim=-1),
             torch.stack([-s, c], dim=-1)],
            dim=-2,
        )
        variances = F.softplus(self._scaling)  # strictly positive, [G, 2]
        # cov[g] = rot[g] @ diag(variances[g]) @ rot[g]^T
        cov2d = torch.einsum('gij,gj,gkj->gik', rot, variances, rot)
        return {
            'means2d': self._means2d,
            'cov2d': cov2d,
            'colors': self._colors,
            'opacity': F.sigmoid(self._opacity),
            'depths': self._depths,
        }

Our goal today is to write an algorithm that outputs something similar to `plt.scatter`.

In [None]:
n_splats = 1024
camera = Camera()
pc = PointCloud2D(n_splats, camera)

pc_outputs = pc()
means2d = pc_outputs['means2d'].detach()
colors = pc_outputs['colors'].detach()

# Matplotlib's y axis points up while image rows grow downwards, so the y
# coordinate is flipped to match how the renderer lays out the image.
fig, ax = plt.subplots(figsize=(5, 5))
ax.scatter(
    means2d[:, 0],
    camera.image_height - means2d[:, 1],
    s=128,  # default marker colour here; pass c=colors to colour each splat
)
ax.set_xlim(0, camera.image_width)
ax.set_ylim(0, camera.image_height)
ax.axis('off');

This is the renderer implementation. Your goal is to implement the methods `get_alpha` and `compose_alpha` used for rendering. Additionally, you should implement the method `get_radius` to improve rendering efficiency. This method computes the radius of a circle containing the Gaussian.

We denote the number of pixels by `P` and the number of Gaussians by `G`.

In [None]:
class GaussianRenderer(nn.Module):
    """Tile-based alpha compositor for projected 2D Gaussian splats, in pure PyTorch.

    The image is processed in square tiles of ``tile_size`` pixels; edge tiles are
    smaller when the image size is not a multiple of ``tile_size``. For every tile,
    only the Gaussians whose clipped bounding square (centre +/- ``get_radius``)
    overlaps the tile are kept. They are sorted by ascending depth and alpha-blended
    front to back over ``background_color``.

    Notation: ``P`` = number of pixels in a tile, ``G`` = number of Gaussians.
    """

    def __init__(self, camera, background_color=None, tile_size=64):
        super().__init__()
        self.image_width = camera.image_width
        self.image_height = camera.image_height
        # [H, W, 2] grid of integer pixel coordinates. With indexing='xy',
        # pix_coord[row, col] == (col, row), i.e. (x, y).
        self.register_buffer(
            'pix_coord',
            torch.stack(
                torch.meshgrid(torch.arange(self.image_width),
                               torch.arange(self.image_height),
                               indexing='xy'),
                dim=-1)
        )
        if background_color is None:
            background_color = torch.ones(3)
        # as_tensor also accepts plain sequences such as (1., 1., 1.).
        self.register_buffer('background_color', torch.as_tensor(background_color))
        self.tile_size = tile_size

    @torch.no_grad()
    def get_radius(self, cov2d):
        """Return a conservative per-Gaussian radius in pixels.

        The radius is three standard deviations along the major axis, rounded up:
        ``ceil(3 * sqrt(lambda_max))``. Here ``lambda_max`` is the largest eigenvalue
        of the 2x2 covariance, computed in closed form.

        cov2d: [G, 2, 2] covariance matrices.
        Returns: [G] tensor of radii.
        """
        s00 = cov2d[:, 0, 0]
        s01 = cov2d[:, 0, 1]
        s10 = cov2d[:, 1, 0]
        s11 = cov2d[:, 1, 1]
        mid = 0.5 * (s00 + s11)
        det = s00 * s11 - s01 * s10
        # Eigenvalues are mid +/- sqrt(mid^2 - det). The discriminant is floored at
        # 0.1 so that near-isotropic Gaussians still get a strictly positive,
        # finite radius.
        lambda_max = mid + torch.sqrt((mid * mid - det).clamp(min=0.1))
        return torch.ceil(3.0 * torch.sqrt(lambda_max.clamp(min=0.0)))

    @torch.no_grad()
    def get_rect(self, means2d, radii):
        """Return the (min, max) corners of each Gaussian's square, clipped to the image.

        means2d: [G, 2] centres (x, y); radii: [G].
        Returns two [G, 2] tensors.
        """
        rect_min = (means2d - radii[:,None])
        rect_max = (means2d + radii[:,None])
        rect_min[..., 0] = rect_min[..., 0].clip(0, self.image_width - 1.0)
        rect_min[..., 1] = rect_min[..., 1].clip(0, self.image_height - 1.0)
        rect_max[..., 0] = rect_max[..., 0].clip(0, self.image_width - 1.0)
        rect_max[..., 1] = rect_max[..., 1].clip(0, self.image_height - 1.0)
        return rect_min, rect_max

    def get_in_mask(self, rect, w, h):
        """Return a [G] boolean mask of Gaussians whose rectangle overlaps the tile.

        The tile's top-left pixel is (x=w, y=h). A rectangle that becomes degenerate
        (zero width or height) after intersecting with the tile counts as not
        overlapping.
        """
        over_tl = (
            rect[0][..., 0].clip(min=w),
            rect[0][..., 1].clip(min=h)
        )
        over_br = (
            rect[1][..., 0].clip(max=w+self.tile_size-1),
            rect[1][..., 1].clip(max=h+self.tile_size-1)
        )
        in_mask = (over_br[0] > over_tl[0]) & (over_br[1] > over_tl[1]) # gaussian in the tile
        return in_mask

    def get_alpha(self, x, means2d, cov2d, opacity):
        """Evaluate every Gaussian's opacity-weighted density at every pixel.

        x:       [P, 2] pixel (x, y) coordinates.
        means2d: [G, 2] Gaussian centres.
        cov2d:   [G, 2, 2] covariance matrices.
        opacity: [G] opacities (expected in [0, 1]).

        Returns alpha of shape [G, P] with
            alpha[g, p] = min(opacity[g] * exp(-0.5 * r^T cov2d[g]^{-1} r), 0.99),
        where r = x[p] - means2d[g]. Capping alpha below 1 keeps the transmittance
        (1 - alpha) strictly positive during compositing.
        """
        delta = x[None, :, :].to(means2d.dtype) - means2d[:, None, :]  # [G, P, 2]
        dx, dy = delta[..., 0], delta[..., 1]
        s00 = cov2d[:, 0, 0][:, None]
        s01 = cov2d[:, 0, 1][:, None]
        s10 = cov2d[:, 1, 0][:, None]
        s11 = cov2d[:, 1, 1][:, None]
        det = s00 * s11 - s01 * s10
        # r^T Sigma^{-1} r using the closed-form 2x2 inverse [[s11, -s01], [-s10, s00]] / det.
        mahalanobis_sq = (s11 * dx * dx - (s01 + s10) * dx * dy + s00 * dy * dy) / det
        alpha = opacity[:, None] * torch.exp(-0.5 * mahalanobis_sq)
        return alpha.clamp(max=0.99)

    def compose_alpha(self, alpha, values):
        """Alpha-composite per-Gaussian values front to back.

        alpha:  [G, P] alphas, with Gaussians ordered front (index 0) to back.
        values: [G, C] per-Gaussian values (e.g. colours), or [G] for a scalar.

        Returns (composed_values, accumulated_alpha):
            composed_values:   [P, C] (or [P] for 1-D ``values``),
                               sum_g T_g * alpha_g * values_g
            accumulated_alpha: [P], sum_g T_g * alpha_g
        where T_g = prod_{k<g} (1 - alpha_k) is the transmittance in front of g.
        """
        # Exclusive cumulative product: T_0 = 1, T_g = prod_{k<g} (1 - alpha_k).
        transmittance = torch.cumprod(
            torch.cat([torch.ones_like(alpha[:1]), 1.0 - alpha[:-1]], dim=0),
            dim=0,
        )
        weights = alpha * transmittance  # [G, P]
        # [P, G] @ [G, C] -> [P, C]; [P, G] @ [G] -> [P].
        composed_values = weights.t() @ values
        accumulated_alpha = weights.sum(dim=0)
        return composed_values, accumulated_alpha

    def forward(self, camera, means2d, cov2d, colors, opacity, depths):
        """Render the splats into an image.

        ``camera`` is accepted for interface compatibility but is not used; the
        image size is fixed at construction time. Smaller depth means closer.

        Returns:
            render_color: [H, W, C] image composited over ``background_color``.
            render_alpha: [H, W, 1] accumulated opacity per pixel.
        """
        radii = self.get_radius(cov2d)
        rect = self.get_rect(means2d, radii)
        render_color = []
        render_alpha = []
        for h in range(0, self.image_height, self.tile_size):
            # Edge tiles are smaller when the size is not a multiple of tile_size.
            tile_h = min(self.tile_size, self.image_height - h)
            render_color_h = []
            render_alpha_h = []
            for w in range(0, self.image_width, self.tile_size):
                tile_w = min(self.tile_size, self.image_width - w)
                tile_color, tile_alpha = self._render_tile(
                    rect, h, w, tile_h, tile_w,
                    means2d, cov2d, colors, opacity, depths,
                )
                render_color_h.append(tile_color)
                render_alpha_h.append(tile_alpha)
            render_color.append(torch.cat(render_color_h, dim=1))
            render_alpha.append(torch.cat(render_alpha_h, dim=1))
        render_color = torch.cat(render_color, dim=0)
        render_alpha = torch.cat(render_alpha, dim=0)
        return render_color, render_alpha

    def _render_tile(self, rect, h, w, tile_h, tile_w,
                     means2d, cov2d, colors, opacity, depths):
        """Render the tile with top-left pixel (x=w, y=h).

        Returns a [tile_h, tile_w, C] colour tile and a [tile_h, tile_w, 1] alpha tile.
        """
        in_mask = self.get_in_mask(rect, w, h)
        if not torch.any(in_mask):
            # No Gaussian reaches this tile, so it is pure background with zero
            # alpha. The tile must still be emitted to keep the row/column
            # concatenation in forward() well-shaped.
            color = self.background_color.to(colors.dtype).expand(tile_h, tile_w, -1)
            return color, colors.new_zeros(tile_h, tile_w, 1)

        # Cull to the overlapping Gaussians and order them front to back.
        candidates = torch.nonzero(in_mask, as_tuple=True)[0]
        order = candidates[torch.argsort(depths[candidates])]

        tile_coord = self.pix_coord[h:h + tile_h, w:w + tile_w].reshape(-1, 2)
        alpha = self.get_alpha(tile_coord, means2d[order], cov2d[order], opacity[order])
        tile_color, acc_alpha = self.compose_alpha(alpha, colors[order])
        # Whatever light was not absorbed by the splats comes from the background.
        tile_color = tile_color + (1 - acc_alpha)[:, None] * self.background_color
        return (tile_color.reshape(tile_h, tile_w, -1),
                acc_alpha.reshape(tile_h, tile_w, 1))

In [None]:
n_splats = 1024
camera = Camera()
pc = PointCloud2D(n_splats, camera)
renderer = GaussianRenderer(camera)

pc_outputs = pc()
means2d = pc_outputs['means2d'].detach()
colors = pc_outputs['colors'].detach()

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
# Flip y so the scatter plot (origin bottom-left) matches image row order (origin top-left).
ax[0].scatter(
    means2d[:, 0],
    camera.image_height - means2d[:, 1],
    c=colors,
    s=128
)
ax[0].set_xlim(0, camera.image_width)
ax[0].set_ylim(0, camera.image_height)
ax[0].axis('off')
ax[0].set_title('Matplotlib Scatter Output')

# The render is only displayed here, so do not build an autograd graph for it.
with torch.no_grad():
    color, _ = renderer(camera, **pc_outputs)

ax[1].imshow(color.clamp(0., 1.).cpu())
ax[1].axis('off')
ax[1].set_title('Gaussian Renderer Output');

Now let's train the splats to reproduce a target image.

In [None]:
img = cv2.imread('dragon_2d.png')
if img is None:
    # cv2.imread reports a missing or unreadable file by returning None instead of raising.
    raise FileNotFoundError("could not read 'dragon_2d.png'")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# Match the renderer's output resolution; cv2.resize takes (width, height).
img = cv2.resize(img, (camera.image_width, camera.image_height))
plt.imshow(img)
plt.axis('off')

# uint8 -> [0, 1]: divide by 255 so pure white maps to exactly 1.0, the same
# value as the renderer's default white background.
gt_image = torch.from_numpy(img).to('cuda:0' if torch.cuda.is_available() else 'cpu') / 255

In [None]:
@torch.no_grad()
def plot_current_outputs(gt_image, image, alpha):
    """Show the target, the current render (titled with its L2 loss) and alpha side by side.

    All three tensors are clamped to [0, 1] and moved to the CPU before display.
    The previous cell output is replaced rather than appended to.
    """
    loss = ((gt_image - image) ** 2).mean()
    fig, ax = plt.subplots(1, 3, figsize=(15, 5))

    def to_displayable(tensor):
        return tensor.detach().clamp(0., 1.).cpu()

    panels = (
        ('Target', gt_image, {}),
        (f'Gaussian Splats, $L_2$={loss.item():.3f}', image, {}),
        (r'$\alpha$', alpha, {'cmap': 'inferno'}),
    )
    for axis, (title, tensor, imshow_kwargs) in zip(ax, panels):
        axis.set_title(title)
        axis.imshow(to_displayable(tensor), **imshow_kwargs)
        axis.axis('off')
    clear_output(wait=True)
    plt.show()

In [None]:
# Train on whatever device the target image lives on.
device = gt_image.device
pc = PointCloud2D(1024, camera)
pc = pc.to(device)
renderer = GaussianRenderer(camera)
renderer.to(device)
optim = torch.optim.Adam(pc.parameters(), lr=1e-2)

# Latest render. These stay None if training is interrupted before the first
# render, so we never plot a stale `colors` left over from an earlier cell.
colors = alpha = None
try:
    for i in range(1024):
        colors, alpha = renderer(camera, **pc())
        loss = ((gt_image - colors) ** 2).mean()
        loss.backward()
        optim.step()
        optim.zero_grad()
        if i % 16 == 0:
            plot_current_outputs(gt_image, colors, alpha)
except KeyboardInterrupt:
    print('Training interrupted.')

# Always show the final state, whether training finished or was interrupted.
if colors is not None:
    plot_current_outputs(gt_image, colors, alpha)