<a href="https://colab.research.google.com/github/kampelmuehler/2DGaussianSplatting/blob/main/2DGaussianSplatting.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook provides a vanilla implementation of Gaussian splatting in 2D, which reconstructs an image using a fixed number of 2D Gaussian functions.

The implementation largely follows [3DGS](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/), but omits any performance optimizations: every Gaussian is evaluated densely at every pixel. Training is therefore comparatively slow, and both the resolution and the number of Gaussian functions are limited by system/accelerator memory.
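
To get a feel for the memory limit, the size of one dense intermediate tensor can be estimated by hand. The sketch below assumes float32 values (4 bytes each) and counts only a single tensor holding two coordinates per pixel and per Gaussian. `dense_tensor_bytes` is a hypothetical helper, not part of the notebook, and the 256 x 256 image is an assumed example. Autograd keeps several tensors of similar size alive during training, so the real footprint is likely a multiple of this figure.

```python
def dense_tensor_bytes(width, height, n_gaussians, components=2, bytes_per_element=4):
    """Bytes needed for one (pixels x Gaussians x components) tensor."""
    n_pixels = width * height
    return n_pixels * n_gaussians * components * bytes_per_element


# assumed example: a square 256 x 256 image approximated by 2000 Gaussians
size = dense_tensor_bytes(256, 256, 2000)
print(f"{size / 2**20:.0f} MiB")  # prints "1000 MiB"
```

Doubling either the pixel count or the number of Gaussians doubles this estimate, which is why both are capped by the available memory.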

In the following cell, enter the URL of the image you want to approximate as `target_image_url`.



In [None]:
target_image_url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/e/ec/Carl_Friedrich_Gauss_1840_by_Jensen.jpg/800px-Carl_Friedrich_Gauss_1840_by_Jensen.jpg'
target_image_path = 'target'
!wget {target_image_url} -O {target_image_path}

The following cell holds the main logic, implementing basic Gaussian splatting in 2D following the equations from [3DGS](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/).
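
For reference, the relations the cell relies on can be summarized as follows. The notation is an assumption made here: $R_i$ and $S_i$ are the rotation and diagonal scale matrices of Gaussian $i$, $\mu_i$ its position, $o_i$ its opacity, $c_i$ its color, and $x$ a pixel coordinate. This is a summary of what the code computes (up to a global rescaling of the Gaussian weights), not a quotation from the paper.

$$
\Sigma_i = R_i S_i S_i^{\top} R_i^{\top}, \qquad
\alpha_i(x) = o_i \exp\!\left(-\tfrac{1}{2}\,(x-\mu_i)^{\top}\Sigma_i^{-1}(x-\mu_i)\right), \qquad
C(x) = \sum_{i=1}^{N} c_i\,\alpha_i(x)\prod_{j=1}^{i-1}\bigl(1-\alpha_j(x)\bigr)
$$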

In [4]:
import torch
import torch.nn.functional as F

class GaussianSplatting2d(torch.nn.Module):
    def __init__(self, n_gaussians, width, height, device, image=None):
        """ initialize GaussianSplatting2d
            Args:
                n_gaussians: the number of Gaussians to use to approximate an image.
                width: width of the image to approximate
                height: height of the image to approximate
                device: torch.device to run on
                image: torch.tensor of target image in range [0, 1].
                       if given, initialize colors of Gaussians to closest pixel value in target image.
            Returns:
                None
        """
        super().__init__()
        xs = torch.linspace(0, 1, steps=width)
        ys = torch.linspace(0, 1, steps=height)
        x, y = torch.meshgrid(xs, ys, indexing='xy')
        self.X = torch.cat([x.unsqueeze(-1), y.unsqueeze(-1)], -1).view(-1, 1, 2, 1).to(device)
        self.width = width
        self.height = height
        self.device = device

        # parameters of Gaussians
        self.scales = torch.nn.Parameter(torch.logit(torch.rand((n_gaussians, 2)) * 0.1), requires_grad=True)
        self.rotation_angles = torch.nn.Parameter(torch.logit(torch.rand((n_gaussians, 1))), requires_grad=True)
        positions = torch.rand((n_gaussians, 2, 1))
        self.positions = torch.nn.Parameter(torch.logit(positions), requires_grad=True)
        self.rgbas = torch.rand((n_gaussians, 4))
        self.rgbas[:, 3] = (self.rgbas[:, 3] + 0.1) / 1.1  # avoid fully transparent initial alphas
        if image is not None:
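            # look up the target pixel closest to each Gaussian's initial position
            max_xy = torch.tensor([[image.shape[1] - 1, image.shape[0] - 1]])
            pixel_idx = torch.round(positions.squeeze(-1) * max_xy).long().to(image.device)
            self.rgbas[:, :3] = image[pixel_idx[:, 1], pixel_idx[:, 0], :].cpu()
        # clamp to [eps, 1 - eps] so pure black/white pixels do not become infinite logits
        self.rgbas = torch.nn.Parameter(torch.logit(self.rgbas, eps=1e-6), requires_grad=True)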

    def covariance_matrices(self):
        """ Calculate covariance matrices given rotation angles and scale
            parameters. cf. Eq. 6 in 3DGS paper
            Returns:
                tensor containing per Gaussian covariance matrices
        """
        scale_matrices = torch.diag_embed(F.sigmoid(self.scales))
        angles = F.sigmoid(self.rotation_angles) * torch.pi  # rotation angles in (0, pi)
        cosines = torch.cos(angles)
        sines = torch.sin(angles)
        rot_matrices = torch.cat([cosines, -sines, sines, cosines], 1).reshape(-1, 2, 2)
        return rot_matrices @ scale_matrices @ torch.transpose(scale_matrices, -2, -1) @ torch.transpose(rot_matrices, -2, -1)

    def render(self):
        """ Densely evaluate all the Gaussians to generate image
            Returns:
                image tensor generated from the Gaussian parameters
        """
        # calculate the gaussian weights
        x = (self.X - F.sigmoid(self.positions))
        gaussians = torch.exp(-0.5 * torch.transpose(x, -2, -1) @ torch.linalg.solve(self.covariance_matrices(), x))
        # rescale by the single global maximum so the largest weight is 1
        # (one maximum over all Gaussians and pixels, not one per Gaussian)
        norm_gaussians = gaussians / gaussians.max()
        # calculate alpha_i (cf. Eq. 3 in 3DGS paper)
        alpha_is = norm_gaussians.squeeze() * F.sigmoid(self.rgbas[:, 3])
        # blending (cf. Eq. 3 in 3DGS paper)
        # transmittance: product of (1 - alpha_j) over all Gaussians j preceding Gaussian i
        ones = torch.ones((alpha_is.shape[0], 1), device=alpha_is.device)
        product = torch.cumprod(torch.cat([ones, (1 - alpha_is)[..., :-1]], -1), 1).unsqueeze(-1)
        rgb = F.sigmoid(self.rgbas[:, :3]) * alpha_is.unsqueeze(-1) * product
        rgb = rgb.sum(1)
        return rgb.view(self.height, self.width, 3)
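
The three steps of the class above (covariance, Gaussian weight, blending) can be traced for a single pixel without tensors. The following is a minimal pure-Python sketch, not the notebook's implementation: the function names are hypothetical, the opacities are passed in directly instead of going through a sigmoid, and the global rescaling of the weights is left out. Gaussians are blended in the order they are stored, as in `render`.

```python
import math


def covariance_2d(scale_x, scale_y, angle):
    """Return Sigma = R S S^T R^T as a nested 2x2 list."""
    c, s = math.cos(angle), math.sin(angle)
    sx2, sy2 = scale_x ** 2, scale_y ** 2
    # closed form of R @ diag(sx^2, sy^2) @ R^T with R = [[c, -s], [s, c]]
    off_diagonal = c * s * (sx2 - sy2)
    return [[c * c * sx2 + s * s * sy2, off_diagonal],
            [off_diagonal, s * s * sx2 + c * c * sy2]]


def gaussian_weight(pixel, mean, cov):
    """Evaluate exp(-0.5 * d^T Sigma^-1 d) for the offset d = pixel - mean."""
    dx, dy = pixel[0] - mean[0], pixel[1] - mean[1]
    (a, b), (_, c) = cov
    det = a * c - b * b
    # d^T Sigma^-1 d via the explicit inverse of a symmetric 2x2 matrix
    mahalanobis = (c * dx * dx - 2 * b * dx * dy + a * dy * dy) / det
    return math.exp(-0.5 * mahalanobis)


def composite(colors, alphas):
    """Blend per-Gaussian RGB colors at one pixel, in list order."""
    pixel = [0.0, 0.0, 0.0]
    transmittance = 1.0  # product of (1 - alpha_j) over the preceding Gaussians
    for color, alpha in zip(colors, alphas):
        for channel in range(3):
            pixel[channel] += color[channel] * alpha * transmittance
        transmittance *= 1.0 - alpha
    return pixel


cov = covariance_2d(0.05, 0.02, math.pi / 4)
weight = gaussian_weight((0.5, 0.5), (0.5, 0.5), cov)  # pixel at the center: 1.0
# a half-opaque red Gaussian followed by a fully opaque blue one
print(composite([(1.0, 0.0, 0.0), (0.0, 0.0, 1.0)], [0.5 * weight, 1.0]))
# prints [0.5, 0.0, 0.5]
```

The red Gaussian contributes half of its color and lets the remaining half of the light through to the blue one, which is what the cumulative product in `render` computes for every pixel at once.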

The following cell holds the optimization loop and visualization functionality.

In [None]:
import matplotlib.pyplot as plt
import cv2
from IPython.display import clear_output

NUMBER_OF_OPTIMIZATION_STEPS = 1000
NUMBER_OF_GAUSSIANS = 2000
MAX_IMAGE_RESOLUTION = 256

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# load image and resize to max MAX_IMAGE_RESOLUTION
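image = cv2.imread(target_image_path)
if image is None:  # cv2.imread returns None instead of raising on failure
    raise FileNotFoundError(f'could not read an image from {target_image_path!r}')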
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
resize_factor = min(MAX_IMAGE_RESOLUTION / max(image.shape[:2]), 1)
if resize_factor < 1.:
    image = cv2.resize(image, (int(image.shape[1] * resize_factor), int(image.shape[0] * resize_factor)), interpolation=cv2.INTER_AREA)
image = torch.tensor(image).float().to(device)
image = image / 255

# init gaussians and optimizer
gs = GaussianSplatting2d(NUMBER_OF_GAUSSIANS, image.shape[1], image.shape[0], device, image).to(device)
optimizer = torch.optim.Adam(gs.parameters(), lr=.01)

generated_images = []
for i in range(NUMBER_OF_OPTIMIZATION_STEPS):
    optimizer.zero_grad()
    rgb = gs.render()
    loss = F.l1_loss(rgb, image)
    loss.backward()
    optimizer.step()
    # visualize
    fig, axs = plt.subplots(1, 2)
    axs[0].imshow(image.cpu().numpy())
    generated_images.append(torch.clip(rgb, 0, 1).cpu().detach().numpy())
    axs[1].imshow(generated_images[-1])
    axs[0].set_title(f'Target ({rgb.shape[0] * rgb.shape[1] / 1000:.01f}k px)')
    axs[1].set_title(f'Step: {i}, loss: {loss.item():.04f} ({NUMBER_OF_GAUSSIANS / 1000:.01f}k Gaussians)')
    axs[0].set_axis_off()
    axs[1].set_axis_off()
    plt.show()
    clear_output(wait=True)


The code below exports each intermediate image of the optimization, side by side with the target, as a numbered JPEG in `out/`, so that you can create a video of the optimization process, such as [this](https://www.youtube.com/watch?v=DLKLgWZ-BGk). Note that this has not been tested within Colab.

You can subsequently use `ffmpeg` to create a video from the exported images. For example:

```ffmpeg -framerate 100 -i %04d.jpg -c:v libx264 -vf "pad=ceil(iw/2)*2:ceil(ih/2)*2" 2dgaussiansplatting.mp4```

(run inside the directory containing the exported images).

In [6]:
import numpy as np
from pathlib import Path

out_dir = Path('out')
out_dir.mkdir(exist_ok=True)
image = cv2.imread(target_image_path)
resize_factor = min(MAX_IMAGE_RESOLUTION / max(image.shape[:2]), 1)
if resize_factor < 1.:
    image = cv2.resize(image, (int(image.shape[1] * resize_factor), int(image.shape[0] * resize_factor)), interpolation=cv2.INTER_AREA)
for i, img in enumerate(generated_images):
    current_image = cv2.cvtColor((img * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)
    cv2.imwrite(str(out_dir / f'{i:04d}.jpg'), cv2.hconcat([image, current_image]))
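
If you prefer to stay in Python, the `ffmpeg` invocation shown earlier could be wrapped as sketched below. `build_ffmpeg_command` and `encode_video` are hypothetical helper names, the sketch assumes `ffmpeg` is installed and on the `PATH`, and, like the export cell, it has not been tested within Colab. Passing the arguments as a list sidesteps shell quoting of the `pad` filter.

```python
import subprocess
from pathlib import Path


def build_ffmpeg_command(output_name="2dgaussiansplatting.mp4", framerate=100):
    """Assemble the ffmpeg call as an argument list (no shell quoting needed)."""
    return [
        "ffmpeg",
        "-framerate", str(framerate),
        "-i", "%04d.jpg",
        "-c:v", "libx264",
        # pad width and height up to even values
        "-vf", "pad=ceil(iw/2)*2:ceil(ih/2)*2",
        output_name,
    ]


def encode_video(frame_dir="out"):
    """Run ffmpeg inside frame_dir, the directory holding the exported frames."""
    subprocess.run(build_ffmpeg_command(), cwd=Path(frame_dir), check=True)
```

Calling `encode_video()` after the export cell would then write the video into `out/`, next to the frames.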