# Optimize Gaussian Parameters in 2D for Synthetic Data and Image

In this notebook, we'll be optimizing for 2D Gaussians to represent an image, while building up the intuition that generalizes to 3D Gaussian Splatting. 

In Part 1, we'll be targeting a "synthetic" image (originally reconstructed from Gaussians).
In Part 2, we'll be optimizing for an actual target image.

For inspiration, here are a couple of cool examples on shadertoy of what you could do with your 2D Gaussian parameters:
- [https://www.shadertoy.com/view/tflXRB](https://www.shadertoy.com/view/tflXRB)
- [https://www.shadertoy.com/view/dtSfDD](https://www.shadertoy.com/view/dtSfDD)
- [https://www.shadertoy.com/view/MdfGDH](https://www.shadertoy.com/view/MdfGDH)
- [https://www.shadertoy.com/view/4df3D8](https://www.shadertoy.com/view/4df3D8)
- [https://www.shadertoy.com/view/4XXSDN](https://www.shadertoy.com/view/4XXSDN)

For what it's worth, this notebook creates a cool video of the training process. 

*Note: I wrote a modified version of this notebook as part of the 3rd homework assignment for the Georgia Tech course CS8803/4803 CGA: Computer Graphics in the AI Era. You can find more information about the course here: [https://cgai-gatech.vercel.app/](https://cgai-gatech.vercel.app/). You can also find my corresponding guest lecture on 3D Gaussian Splatting on YouTube: [https://youtu.be/MBVmQSA24Yk](https://youtu.be/MBVmQSA24Yk)*    

## Imports

In [None]:
# Core
import numpy as np
import math
import torch
from torch import nn
import random

# Visu
from matplotlib.patches import Ellipse
from matplotlib import pyplot as plt

# file i/o
import pathlib
from pathlib import Path
import json

# For assembling video
import cv2
import glob

# Misc.
from typing import Union
from datetime import datetime
from tqdm import tqdm  # progress bar
TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

In [None]:
#%matplotlib notebook
# Notebook setup: render matplotlib figures inline in the cell output
# (the commented line above selects an alternative matplotlib backend).
%matplotlib inline
# Load IPython's autoreload extension (mode 2) so edited, imported modules are picked up.
%load_ext autoreload
%autoreload 2

In [None]:
# Pick a PyTorch backend: Apple's Metal (MPS) first, then CUDA, otherwise the CPU.
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")  # Metal backend for Apple devices
elif torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

print(f"Using {DEVICE} backend (torch.device)")

# Helper classes and functions

You don't need to touch anything here. 
They are included here only to make the notebook self-contained, and provide the following functionality used throughout the notebook:
- Domain2D: handles the 2D domain and its discretization
- Simplify handling input/output paths and filenames 
- Plotting helper functions
- Very simple Timer class
- Writing out parameters as JSON and assembling video

If you're curious, feel free to dig in, though.

In [None]:
#@title Domain2D
class Domain2D:
    """
    A rectangular 2D domain discretized on a regular grid.

    `xx` and `yy` hold the x and y coordinate of every grid sample. They are
    built with `indexing='xy'`, so both have shape [res_y, res_x]: the row index
    walks along y and the column index walks along x.
    """
    x_dim: tuple[float, float]  # (x_min, x_max); a 2-element list works too
    y_dim: tuple[float, float]  # (y_min, y_max); a 2-element list works too
    res_x: int  # number of grid samples along x
    res_y: int  # number of grid samples along y
    xx: torch.Tensor  # [res_y, res_x] x-coordinates of the samples
    yy: torch.Tensor  # [res_y, res_x] y-coordinates of the samples

    def __init__(self, x_dim, y_dim, res_x, res_y, device=None):
        """
        :param x_dim: (min, max) extent along x
        :param y_dim: (min, max) extent along y
        :param res_x: number of samples along x (int)
        :param res_y: number of samples along y (int)
        :param device: torch device for the grid tensors; when omitted the
            notebook-wide `DEVICE` is used (previous behavior)
        """
        self.x_dim, self.y_dim = x_dim, y_dim
        self.res_x, self.res_y = res_x, res_y

        if device is None:
            device = DEVICE
        xs = torch.linspace(self.x_dim[0], self.x_dim[1], self.res_x, device=device)
        ys = torch.linspace(self.y_dim[0], self.y_dim[1], self.res_y, device=device)
        self.xx, self.yy = torch.meshgrid(xs, ys, indexing='xy')

    def __str__(self):
        return (f"Domain2D: {self.x_dim}x{self.y_dim} "
                f"discretized on a {self.res_x}x{self.res_y} grid.")

    def get_extent(self):
        """Return (x_min, x_max, y_min, y_max), i.e. the (left, right, bottom, top) `extent` for imshow."""
        extent = (*self.x_dim, *self.y_dim)
        return extent

In [None]:
#@title Files and folders

DATA_FOLDER = pathlib.Path("data")
LAST_FILE_NAME = 0  # running frame counter used by save_fig_to_file
CURRENT_SCENE = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
CURRENT_SCENE_FOLDER = DATA_FOLDER / CURRENT_SCENE


def save_fig_to_file(fig):
    """
    Save `fig` as the next numbered frame `<CURRENT_SCENE_FOLDER>/<counter>.jpg`.

    The counter is reset to 0 by `initialize_file_names()` and grows by one per
    call. It is independent of the training epoch, so frames stay consecutive
    even if only every k-th epoch is saved.
    """
    global LAST_FILE_NAME
    frame_path = (CURRENT_SCENE_FOLDER / f"{LAST_FILE_NAME}.jpg").absolute()
    fig.savefig(frame_path)
    LAST_FILE_NAME = LAST_FILE_NAME + 1


def initialize_file_names():
    """
    Point output to a new folder `data/<current timestamp>` and reset the frame counter.

    Call this when a new optimization run starts. The folder is created if
    needed; two calls within the same second reuse the same folder.
    """
    global LAST_FILE_NAME, CURRENT_SCENE, CURRENT_SCENE_FOLDER

    CURRENT_SCENE = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
    CURRENT_SCENE_FOLDER = DATA_FOLDER / CURRENT_SCENE
    CURRENT_SCENE_FOLDER.mkdir(parents=True, exist_ok=True)
    LAST_FILE_NAME = 0


initialize_file_names()

In [None]:
#@title Plotting

def plot_tensor_image(
        image,
        title: str = "",
        _plt=plt,
        extent=None,
):
    """
    Show a 2D scalar field or an RGB image (last axis of size 3) with `imshow`.

    :param image: torch.Tensor or array-like; tensors are moved to the CPU and detached
    :param title: title of the plot / subplot
    :param _plt: the `pyplot` module itself, or a matplotlib Axes
    :param extent: (left, right, bottom, top), passed through to `imshow`
    """
    if isinstance(image, torch.Tensor):
        image = image.cpu().detach().numpy()

    # RGB colors are clipped into [0, 1]; scalar fields are shown unchanged.
    has_three_channels = image.shape[-1] == 3
    shown = np.clip(image, 0, 1) if has_three_channels else image

    # origin="lower": row 0 is drawn at the bottom (y grows upwards).
    _plt.imshow(shown, extent=extent, origin="lower")

    if _plt != plt:
        # An Axes exposes its title as a Text object.
        _plt.title.set_text(title)
    else:
        _plt.title(title)


def draw_outline(params, _plt=plt):
    """
    Overlay Gaussian outlines on an existing plot.

    Every center is drawn as a small red dot. Every Gaussian is drawn as an
    unfilled red ellipse whose matplotlib width/height are 3.5 * sigma_x and
    3.5 * sigma_y, rotated by theta (radians, converted to degrees).

    :param params: dict with "centers" (N x 2), "sigmas" (N x 2) and "thetas" (N);
        torch tensors or array-likes
    :param _plt: the `pyplot` module (its current Axes is used) or an Axes
    """
    def to_array(value):
        # Bring tensors to the CPU and detach them from the autograd graph.
        if isinstance(value, torch.Tensor):
            value = value.cpu().detach().numpy()
        return value

    centers = to_array(params["centers"])
    sigmas = to_array(params["sigmas"])
    thetas = to_array(params["thetas"])

    ax = _plt if isinstance(_plt, plt.Axes) else _plt.gca()

    # Centers as red dots
    ax.scatter(x=centers[:, 0], y=centers[:, 1], c='r', s=1)

    degrees_per_radian = 180.0 / np.pi
    for idx, center in enumerate(centers):
        sigma_x, sigma_y = sigmas[idx][0], sigmas[idx][1]
        outline = Ellipse(
            (center[0], center[1]),
            width=sigma_x * 3.5,
            height=sigma_y * 3.5,
            angle=thetas[idx] * degrees_per_radian,
            edgecolor='red',
            facecolor='none',
            linewidth=1,
            alpha=1.0,
        )
        ax.add_patch(outline)


def assemble_plot_data(data, outline_params):
    """
    Normalize the inputs of `plot` into a rows x columns grid.

    `data` may be a single item, a flat list/tuple (one row), or a list/tuple of
    lists/tuples (several rows). Every item is converted to a numpy array.

    :param data: image data, or iterable of image data
    :param outline_params: Gaussian parameters for outline overlays: None, a
        single dict (only allowed for a single item), a flat list (one row) or a
        2D list; a list containing only None becomes None
    :return: number of columns (int), rows (int), assembled data (2D list of
        np.ndarray) and outline parameters (2D list or None)
    """
    def to_numpy(item):
        # One plottable item -> np.ndarray (tensors are detached and moved to the CPU).
        if isinstance(item, torch.Tensor):
            item = item.cpu().detach().numpy()
        return np.array(item)

    def is_sequence(obj):
        return isinstance(obj, (list, tuple))

    if not is_sequence(data):
        # A single item still becomes a 1x1 grid.
        grid = [[to_numpy(data)]]
    elif is_sequence(data[0]):
        # Already a list of rows.
        grid = [[to_numpy(item) for item in row] for row in data]
    else:
        # A flat list forms a single row.
        grid = [[to_numpy(item) for item in data]]

    nrows, ncols = len(grid), len(grid[0])

    if isinstance(outline_params, dict):
        assert ncols == nrows == 1, "Non-list outline params is only allowed for plotting a single data."
        outline_params = [[outline_params]]
    if outline_params is not None:
        if all(p is None for p in outline_params):
            outline_params = None
        elif not isinstance(outline_params[0], list):
            # A flat list describes a single row; wrap it for [row][col] indexing.
            outline_params = [outline_params]

    return ncols, nrows, grid, outline_params


def assemble_titles(title, ncols, nrows):
    """
    Expand `title` into a 2D list indexable as `titles[row][col]`.

    :param title: a single string (used for every subplot), a 1D list/tuple of
        titles for a single row, or a 2D list/tuple with one entry per subplot
    :param ncols: number of columns in the plot
    :param nrows: number of rows in the plot
    :return: a 2D list of titles corresponding to a (ncols, nrows) plot. When a
        single title is broadcast, every row is an independent list.
    """
    if isinstance(title, (tuple, list, dict)):
        if isinstance(title[0], (tuple, list, dict)):
            # Title is a 2D array of titles for each subplot, individually
            assert len(title) == nrows and len(title[0]) == ncols
        else:
            # Wrap 1D list of titles to be a 2D array for a single row
            title = [title]
    else:
        # Same title for each subplot. Build a fresh list per row:
        # `[[title] * ncols] * nrows` would reference one list nrows times.
        title = [[title] * ncols for _ in range(nrows)]

    return title


def plot(
        data: Union[torch.Tensor, list, tuple, np.ndarray],
        title: Union[str, list] = "",
        figsize=(16, 6),
        extent=None,
        domain: Domain2D = None,  # if domain is not None, its extent overrides `extent`
        outline_params=None,
        show_plot=True,
):
    """
    Main plotting function: one subplot per piece of data.

    Arrays with 2+ dimensions are shown as images (optionally overlaid with
    Gaussian outlines); 1D arrays are shown as a loss curve.

    :param data: a single item, a list (one row) or a list of lists (rows x cols)
    :param title: a single title, a list (one row) or a 2D list of titles.
        For 1D data an empty title falls back to "Learning curve".
    :param figsize: matplotlib figure size
    :param extent: (left, right, bottom, top) used for images
    :param domain: Domain2D whose extent replaces `extent` when given
    :param outline_params: Gaussian params (dict / list / 2D list / None) drawn over images
    :param show_plot: whether to call plt.show()
    :return: the matplotlib Figure
    """
    ncols, nrows, data, outline_params = assemble_plot_data(data, outline_params)
    titles = assemble_titles(title, ncols, nrows)

    # squeeze=False always yields a 2D array of Axes, indexable as axs[row][col],
    # even when nrows or ncols is 1.
    fig, axs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)

    if domain is not None:
        extent = domain.get_extent()

    for i in range(nrows):
        for j in range(ncols):
            curr_data = data[i][j]
            curr_ax = axs[i][j]
            curr_title = titles[i][j]

            if curr_data.ndim > 1:
                # Plot scalar field or image data
                plot_tensor_image(
                    image=curr_data,
                    title=curr_title,
                    _plt=curr_ax,
                    extent=extent,
                )
                if outline_params is not None and outline_params[i][j] is not None:
                    draw_outline(params=outline_params[i][j], _plt=curr_ax)
                curr_ax.set_aspect('equal')
            elif curr_data.ndim == 1:
                # Plot 1D data, e.g. loss curve
                curr_ax.plot(curr_data)
                # Roughly 8 ticks; max(1, ...) keeps the step valid for an empty curve.
                tick_step = max(1, math.ceil(len(curr_data) / 8))
                curr_ax.set_xticks(range(0, len(curr_data), tick_step))
                # Honor the caller's title instead of always overwriting it.
                curr_ax.set_title(curr_title if curr_title else 'Learning curve')
                curr_ax.set_xlabel('Epoch')
                curr_ax.set_ylabel('Loss')

    if show_plot:
        plt.show()

    return fig

In [None]:
#@title Timer
class Timer:
    """
    Super simple utility class for displaying elapsed time.
    """

    def __init__(self):
        self.init_time = datetime.now()
        print("Current time:", self.init_time.strftime(TIME_FORMAT))

    def print_time(self, reset_time=False):
        """Print the current time and the time elapsed since `init_time`; optionally restart the clock."""
        current_time = datetime.now()
        elapsed_time = current_time - self.init_time

        print("Current time:", current_time.strftime(TIME_FORMAT))
        print("Time elapsed:", str(elapsed_time))

        # Optionally overwrite initial time
        if reset_time:
            self.init_time = current_time

    def get_elapsed_time(self) -> str:
        """
        Return the time elapsed since `init_time` as "HH:MM:SS".

        Uses total_seconds() rather than timedelta.seconds, which excludes whole
        days; hours therefore keep counting past 24 instead of wrapping to 00.
        """
        total_seconds = int((datetime.now() - self.init_time).total_seconds())
        hours, remainder = divmod(total_seconds, 3600)
        minutes, seconds = divmod(remainder, 60)
        return f"{hours:02}:{minutes:02}:{seconds:02}"

In [None]:
#@title Misc.: json i/o and video

def gaussians_from_json_file(filename):
    """
    Read a 2D Gaussian scene in the format written by `gaussians_to_json_file`.

    :param filename: path relative to CURRENT_SCENE_FOLDER (an absolute path is
        used as-is, because joining a pathlib.Path with an absolute path yields it)
    :return: dict with a "gaussians" sub-dict ("alphas", "centers", "sigmas",
        "thetas" as plain lists) and the scene fields "N", "dims", "x_dim",
        "y_dim", "res_x", "res_y"
    :raises KeyError: if one of the expected fields is missing from the file
    """
    filepath = CURRENT_SCENE_FOLDER / pathlib.Path(filename)

    with open(filepath, 'r', encoding='utf-8') as file:
        data = json.load(file)
    # TODO check for invalid input data here

    json_data = {
        'gaussians': {key: data[key] for key in ('alphas', 'centers', 'sigmas', 'thetas')},
    }
    for key in ('N', 'dims', 'x_dim', 'y_dim', 'res_x', 'res_y'):
        json_data[key] = data[key]

    gaussians = json_data['gaussians']
    # Each f-string piece ends with ", " so the fields don't run together.
    print(f"Read gaussian scene from {filepath} with "
          f"N={json_data['N']}, dims={json_data['dims']}, gaussians#={len(gaussians)}, "
          f"#alphas={len(gaussians['alphas'])}, "
          f"#centers={len(gaussians['centers'])}, "
          f"#sigmas={len(gaussians['sigmas'])}, "
          f"#thetas={len(gaussians['thetas'])}")

    return json_data


def gaussians_to_json_file(filename, params, domain):
    """
    Write Gaussian parameters together with the domain description as JSON.

    filename: relative to the current scene (CURRENT_SCENE_FOLDER)
    params: dict with tensors "alphas" [N, dims], "centers", "sigmas" and "thetas" [N]
    domain: object providing x_dim, y_dim, res_x and res_y (e.g. Domain2D)
    """
    filepath = CURRENT_SCENE_FOLDER / filename

    def to_python(values):
        # Tensor/array -> (nested) lists of plain Python floats, detached from autograd.
        if isinstance(values, torch.Tensor):
            values = values.detach().cpu()
        return values.tolist()

    alphas = params['alphas']
    scene_data = dict(
        x_dim=domain.x_dim,
        y_dim=domain.y_dim,
        res_x=domain.res_x,
        res_y=domain.res_y,
        N=alphas.shape[0],
        dims=alphas.shape[-1],
        alphas=to_python(alphas),
        centers=to_python(params['centers']),
        sigmas=to_python(params['sigmas']),
        thetas=to_python(params['thetas']),
    )

    with open(filepath, 'w') as file:
        json.dump(scene_data, file, indent=2)

    print(f"2D Gaussian Scene written to {filepath}.")
    
    
def assemble_video(image_folder: str, output_file: str="video.mp4", fps=10):
    """
    Assemble the numbered frames `<image_folder>/<k>.jpg` into a video, in frame order.

    image_folder: folder containing the frames (str or pathlib.Path)
    output_file: path of the video file to write
    fps: frames per second. It is raised to 20 for more than 100 frames and to
        30 for more than 200 frames, so long runs don't produce overly long videos.
    :raises ValueError: if no numbered .jpg frames are found or the first one can't be read
    :raises RuntimeError: if OpenCV cannot open a video writer
    """
    # Only numerically named frames (as written by save_fig_to_file) are used;
    # sorting by int(stem) keeps "10.jpg" after "9.jpg".
    image_files = sorted(
        (p for p in Path(image_folder).glob("*.jpg") if p.stem.isdigit()),
        key=lambda p: int(p.stem),
    )
    if not image_files:
        raise ValueError("No .jpg files found in the specified directory.")

    # Read the first image to get the dimensions
    first_frame = cv2.imread(str(image_files[0]))
    if first_frame is None:
        raise ValueError(f"Could not read image {image_files[0]}.")
    height, width = first_frame.shape[:2]

    # Adjust FPS based on the number of images
    if len(image_files) > 200:
        fps = 30
    elif len(image_files) > 100:
        fps = 20

    # Prefer H.264; many OpenCV builds ship without it, in which case the writer
    # silently fails to open, so fall back to the widely available MPEG-4 codec.
    video = None
    for codec in ('X264', 'mp4v'):
        video = cv2.VideoWriter(str(output_file), cv2.VideoWriter_fourcc(*codec), fps, (width, height))
        if video.isOpened():
            break
        video.release()
    else:
        raise RuntimeError(f"Could not open a video writer for {output_file}.")

    try:
        for image_file in image_files:
            frame = cv2.imread(str(image_file))
            if frame is None:
                print(f"Skipping unreadable frame {image_file}")
                continue
            video.write(frame)
    finally:
        video.release()

    print(f"Video saved as {output_file}")

## 2D Gaussian Functions

A 2D Gaussian is parameterized by:
- centers $c = [c_x, c_y]^T$
- scales $\sigma = [\sigma_x, \sigma_y]^T$
- scalar rotation $\theta$ (Note: we could also use a unit-length complex number in the spirit of using quaternions in 3D, but a scalar rotation is perfectly fine.) 

Its value at point $\bf{p}$ is given by
$$
f(p) = \text{exp}\left(
  -\frac{1}{2} 
  (p - c)^T 
  \Sigma^{-1}
  (p - c)
\right),
$$

where the covariance matrix $\Sigma$ describes the shape of the Gaussian, and we build it as

$$
\Sigma = 
      RSS^TR^T,
$$

Making use of (1) $(AB)^{-1} = B^{-1} A^{-1}$ and (2) $R^{-1} = R^T$ for rotation matrices, we have
$$
\Sigma^{-1} = R(SS^{T})^{-1}R^T.
$$

If we want to write out our 2D Gaussian function explicitly, we have:

$$
f\left(\begin{bmatrix}
    p_x\\
    p_y
\end{bmatrix}\right) 
= \text{exp}\left(
    -\frac{1}{2} 
    \begin{bmatrix}
        p_x - c_x &
        p_y - c_y
    \end{bmatrix}
    \Sigma^{-1}
    \begin{bmatrix}
        p_x - c_x\\
        p_y - c_y
    \end{bmatrix}
\right),
$$ 

where
$$
    \Sigma^{-1} = 
      \begin{bmatrix}
        \cos \theta & -\sin\theta\\
        \sin \theta & \cos \theta
      \end{bmatrix}
      \begin{bmatrix}
        \frac{1}{\sigma_x^2} & 0 \\
        0 & \frac{1}{\sigma_y^2}
      \end{bmatrix}
      \begin{bmatrix}
        \cos \theta & \sin\theta\\
        -\sin \theta & \cos \theta
      \end{bmatrix}.
$$

To reconstruct a 2D image, we add together functions of this form, each multiplied by an "alpha" weight. In the original 3DGS implementation, they use a Gaussian-wise opacity $\alpha \in \mathbb{R}$ and RGB color $[c_r, c_g, c_b] \in \mathbb{R}^3$ (converted from spherical harmonics in 3D, given our current view direction).

To simplify things here, we multiply these together, and store $\boldsymbol{\alpha} = \alpha \cdot [ c_r, c_g, c_b ]$ for each Gaussian.

Thus, we reconstruct each pixel $\textbf{p}$ of our 2D image $\textbf{I}$ by summing the contributions of $N$ Gaussians:

$$
I(p) = \sum_{i=1}^{N} \boldsymbol{\alpha}_i f_i(p).
$$

This is implemented in `reconstruct_gaussian_2d`, with the two helper functions `build_sigma_invs` and `build_position_tensor`.

In [None]:
def build_sigma_invs(thetas: torch.Tensor, sigmas: torch.Tensor):
    """
    Batch of inverse covariance matrices Sigma^{-1} = R diag(1/sigma_x^2, 1/sigma_y^2) R^T.

    :param thetas: [N] rotation angles in radians
    :param sigmas: [N, 2] scales along the local (x, y) axes
    :return: [N, 2, 2] inverse covariance matrices
    """
    cos_t, sin_t = torch.cos(thetas), torch.sin(thetas)
    # R[n] = [[cos, -sin], [sin, cos]]: build both rows, then stack them -> [N, 2, 2]
    row_0 = torch.stack((cos_t, -sin_t), dim=-1)
    row_1 = torch.stack((sin_t, cos_t), dim=-1)
    rotations = torch.stack([row_0, row_1], dim=-2)

    # Diagonal matrices with 1/sigma^2 on the diagonal -> [N, 2, 2]
    inv_scale_sq = torch.diag_embed((1.0 / (sigmas ** 2)))

    # R @ S^{-2} @ R^T for every Gaussian. The transpose could instead be folded
    # into the index string: torch.einsum("bik,bkl,bjl->bij", R, S, R).
    rotations_t = rotations.transpose(1, 2)
    return torch.einsum("bik,bkl,blj->bij", rotations, inv_scale_sq, rotations_t)


def build_position_tensor(xx, yy, centers: torch.Tensor):
    """
    Offsets (p - c_i) of every grid point p from every Gaussian center c_i.

    :param xx: grid x-coordinates; [res_y, res_x] when built by Domain2D ('xy' indexing)
    :param yy: grid y-coordinates, same shape as `xx`
    :param centers: [N, 2] Gaussian centers (x, y)
    :return: [N, *xx.shape, 2] tensor of (x - c_x, y - c_y)
    """
    # (x, y) pair per grid point, plus a leading singleton Gaussian axis: [1, *xx.shape, 2]
    grid_points = torch.stack([xx, yy], dim=-1).unsqueeze(0)
    # One center per Gaussian, repeated over the grid: [N, 1, 1, 2] -> [N, *xx.shape, 2]
    per_gaussian_centers = centers.view(-1, 1, 1, 2).expand(-1, *xx.shape, 2)
    return grid_points - per_gaussian_centers


def reconstruct_gaussian_2d(params: dict, domain: Domain2D):
    """
    Render N Gaussians on the grid of `domain` and sum them into one image.

    At every grid point p, Gaussian i contributes
    alphas[i] * exp(-0.5 * (p - c_i)^T Sigma_i^{-1} (p - c_i)).

    :param params: dict of torch tensors "alphas" [N, C], "centers" [N, 2],
        "sigmas" [N, 2] and "thetas" [N]
    :param domain: Domain2D providing the xx / yy sample grids
    :return: tensor of shape [*domain.xx.shape, C], i.e. [res_y, res_x, C] for a Domain2D
    """
    alphas, centers, sigmas, thetas = (
        params[key] for key in ('alphas', 'centers', 'sigmas', 'thetas')
    )

    sigma_invs = build_sigma_invs(thetas, sigmas)
    offsets = build_position_tensor(domain.xx, domain.yy, centers)  # [N, H, W, 2]

    # Squared Mahalanobis distance (p - c)^T Sigma^{-1} (p - c) per Gaussian and grid point
    mahalanobis_sq = torch.einsum('nhwj,nij,nhwi->nhw', offsets, sigma_invs, offsets)

    # Unnormalized Gaussian value of every Gaussian at every grid point: [N, H, W]
    weights = torch.exp(-0.5 * mahalanobis_sq)

    # Weighted sum over the Gaussians -> C channels per pixel: [H, W, C]
    return torch.einsum("na,nhw->hwa", alphas, weights)

*A note on batching: after implementing the batched version above, I tried out a simple Python for-loop over the Gaussians, which seems to run comparably fast on the GPUs I tested it on. In some cases, the Python for-loop even ran faster than my batched version. This might have something to do with creating `N` copies of the full grid (one per Gaussian), which is slightly wasteful.*

# Part 1: Optimize for Synthetic Ground Truth
## 1.1. Generate Synthetic Ground Truth

Let's test out the above functionalities by reconstructing a 2D Image from our Gaussians.

In [None]:
DOMAIN = Domain2D(
    x_dim=[-5.0, 5.0], y_dim=[-5.0, 5.0],  # virtual extent of the canvas
    res_x=100, res_y=100,                  # grid resolution
)

# Ground-truth parameters of two Gaussians, as torch tensors on the selected
# device (cpu/metal/cuda).
target_params = {
    'alphas': torch.tensor([[0.7, 0.1, 0.2], [0.3, 0.8, 0.1]], device=DEVICE),  # RGB colors
    'centers': torch.tensor([[-3.0, 1.5], [1.5, 2.5]], device=DEVICE),  # (x, y) positions
    'sigmas': torch.tensor([[2.0, 1.5], [1.5, 0.5]], device=DEVICE),  # scales along local (x, y)
    'thetas': torch.tensor([0.2, 0.8], device=DEVICE),  # rotations in radians
}

# detach(): the target is meant to be static image data — a plain array of
# numbers — not a node in the autograd graph that the loss is later
# differentiated through.
target_image_synthetic = reconstruct_gaussian_2d(target_params, domain=DOMAIN).detach()

# Show the target using the plotting helper
plot(data=target_image_synthetic, title="Synthetic test image", domain=DOMAIN, figsize=(3,3));

## 1.2. Initial random Gaussian parameters for the optimization

Now we run a simple experiment: let's forget that the above "ground truth image" came from our predefined Gaussians. Can we find a set of Gaussian parameters that describe the same image?

Let's generate a random set of Gaussian parameters, and look at what that gives us.

In [None]:
N = 20  # Number of Gaussians

# Parameters to optimize, one nn.Parameter per group (same key order as before).
# Feel free to play around with the random initialization!
params_opt_synthetic = dict(
    # Random RGB color per Gaussian, each channel uniform in [0, 1)
    alphas=nn.Parameter(torch.tensor(
        [[random.random() for _channel in range(3)] for _gaussian in range(N)],
        requires_grad=True, device=DEVICE,
    )),
    # Random (x, y) position inside the domain
    centers=nn.Parameter(torch.tensor(
        [[random.uniform(*DOMAIN.x_dim), random.uniform(*DOMAIN.y_dim)] for _gaussian in range(N)],
        requires_grad=True, device=DEVICE,
    )),
    # Isotropic start: equal scale along the local x and y axes
    sigmas=nn.Parameter(torch.tensor(
        [[1.5, 1.5] for _gaussian in range(N)],
        requires_grad=True, device=DEVICE,
    )),
    # No rotation initially (for random rotations use e.g. np.random.rand() * 2 * np.pi)
    thetas=nn.Parameter(torch.tensor(
        [0.0 for _gaussian in range(N)],
        requires_grad=True, device=DEVICE,
    )),
)

In [None]:
# Plotting initial gaussians
initial_image = reconstruct_gaussian_2d(
    params_opt_synthetic, 
    domain=DOMAIN
)

plot(
    data=[target_image_synthetic, initial_image, initial_image],
    title=["Target Image", "Initial Gaussians", ""],
    outline_params=[None, None, params_opt_synthetic],
    extent=DOMAIN.get_extent(),
    figsize=(7,10)
);

## 1.3. Optimization
### 1.3.1 Initialize the optimizer

Now let's tweak our initial Gaussians. We would like the image reconstructed from them to be the same as our ground truth image $\textbf{I}$. Calling our set of Gaussian parameters $\xi$ (containing alphas, centers, sigmas, and thetas), we can quantify this in a scalar-valued loss function:
 
$$
    \mathcal{L}(\textbf{I}, \xi) = 
    \sum_{\textbf{p} \in \textbf{I}} 
        \left\|
            \boldsymbol{I}(\boldsymbol{p}) - \sum_{i=1}^N \boldsymbol{\alpha}_i f_i(\boldsymbol{p}; \xi_i)
        \right\|,
$$

where $f_i(\cdot;\xi_i)$ simply denotes that we are using the $i$th Gaussian parameters.

For this notebook, both $L_1$ and $L_2$ losses work fine. In the original 3DGS paper, they add together a D-SSIM loss with a weight of $0.2$ and an $L_1$ loss with a weight of $0.8$ to get their scalar loss.

To find Gaussian parameters $\xi$ that minimize this loss function, we use an Adam optimizer that iteratively steps against the gradient of the loss w.r.t. the parameters (i.e., in a descent direction).

In [None]:
# Parameter groups in an explicit, fixed order; each group is labeled with its name.
params_keys = ['alphas', 'centers', 'sigmas', 'thetas']
params_dict_for_optimizer = [
    {'params': params_opt_synthetic[key], 'name': key}
    for key in params_keys
]

# Optional: per-parameter-group learning rates that override the default `lr` below.
# Setting a group's learning rate to 0.0 freezes it (e.g. keep particles in place).
# For more details on setting up and using the optimizer: https://pytorch.org/docs/stable/optim.html
# lr_dict = {'alphas': 0.01, 'centers': 0.2, 'sigmas': 0.0, 'thetas': 0.0}
# for group in params_dict_for_optimizer:
#     if group['name'] in lr_dict:
#         group['lr'] = lr_dict[group['name']]

# Initial (default) learning rate
lr = 0.1

# Adam optimizer over all parameter groups
optimizer = torch.optim.Adam(params_dict_for_optimizer, lr=lr)

# Define the loss function we want to use in the optimziation loop below
def loss_function(prediction, target):
    """
    Mean absolute error (L1 loss) between the rendered image and the target.

    For demo purposes both L1 and L2 work well; for L2 return
    torch.nn.functional.mse_loss(prediction, target) instead. An SSIM term could
    also be added here, see e.g.: https://github.com/VainF/pytorch-msssim
    """
    return torch.nn.functional.l1_loss(prediction, target)

### 1.3.2 Optimization loop

In [None]:
image_save_interval = 1  # Plot/save a frame every `image_save_interval` epochs
num_epochs = 200
display_plots = False  # Whether to plot in the notebook. Useful for not cluttering the notebook.
save_plots = True  # Whether to write out images into a folder.

if save_plots:
    # Set the output folder to be the current time, and restart naming of the file names
    initialize_file_names()
    print(f"Starting optimization. Outputting results into folder `{CURRENT_SCENE_FOLDER}`.")
else:
    print("Starting optimization without saving the results into file.")

timer = Timer()

# Keep track of the loss history (one float per epoch)
loss_trajectory = []

epoch_progress_bar = tqdm(range(num_epochs))
for epoch in epoch_progress_bar:
    # Forward pass, i.e. combining Gaussians into an image
    curr_image = reconstruct_gaussian_2d(params_opt_synthetic, domain=DOMAIN)

    # Calculate and save current loss between output and target
    curr_loss = loss_function(curr_image, target_image_synthetic)
    loss_trajectory.append(curr_loss.item())

    # Zero out gradients, then compute the gradient of the loss w.r.t. the gaussian parameters
    optimizer.zero_grad()
    curr_loss.backward()

    with torch.no_grad():
        # Optionally display/save a frame, but only every `image_save_interval` epochs
        # (saving and closing used to run every epoch, re-saving a stale figure).
        # This happens before optimizer.step() so the outlines show the same
        # parameters that rendered `curr_image`.
        if (display_plots or save_plots) and epoch % image_save_interval == 0:
            fig = plot(
                data=[curr_image, curr_image, target_image_synthetic, loss_trajectory],
                outline_params=[None, params_opt_synthetic, None, None],
                title=[
                    f"Optim at epoch {epoch}",
                    f"Gaussians (N={N})",
                    "Target image",
                    "Loss history"
                ],
                domain=DOMAIN,
                show_plot=display_plots  # Optionally, don't clutter the notebook with showing the plot here.
            )
            if save_plots:
                # Save plot to file as {scene folder}/{frame counter}.jpg
                save_fig_to_file(fig)
            plt.close(fig)  # close the current figure

    # Perform an optimization step (i.e. update gaussian parameters)
    optimizer.step()

    # Note: Densification and Pruning could be done here (under torch.no_grad()) by
    #       removing Gaussians that are too small/stretched/out of bounds/etc,
    #       and duplicating Gaussians based on their positional gradient.

    # Print out optimization details
    progress_text = (
        f"Step {epoch}, "
        f"Loss: {curr_loss.item()}, "
        f"{timer.get_elapsed_time()}, "
        f"N = {N}"
    )
    epoch_progress_bar.set_description(progress_text)

    # You can try experimenting with a learning rate scheduler.
    # In the simplest case, you can decrease the learning rate
    # at some predefined intervals.
    # if epoch % 800 == 0 and epoch > 0:
    #     for param_group in optimizer.param_groups:
    #         param_group['lr'] *= 0.5
    #         print(f"lr: {param_group['lr']}")

### 1.3.3 Plotting final result

In [None]:
# Render the optimized Gaussians once more. This is for display only, so skip
# building an autograd graph.
with torch.no_grad():
    curr_image = reconstruct_gaussian_2d(params_opt_synthetic, domain=DOMAIN)

# Plotting final result
plot(
    data=[curr_image, curr_image, target_image_synthetic],
    outline_params=[None, params_opt_synthetic, None],
    title=[
        "Result of optimization",
        f"Gaussians (N={N})",
        "Target image"
    ],
    domain=DOMAIN,
    figsize=(10,6)
)

print(f"Final params: {params_opt_synthetic}")

In [None]:
# Save the optimized parameters (plus the domain description) as JSON in the current scene folder.
gaussians_to_json_file(
    filename="optimized-params-synthetic.json",
    params=params_opt_synthetic,
    domain=DOMAIN,
)

In [None]:
# Turn the frames saved during optimization into a video (relative path: current working directory).
# Note: you can ignore the "Corrupt JPEG data" warnings if you see any. The video should still render properly.
assemble_video(CURRENT_SCENE_FOLDER, "anim-synthetic.mp4")

# Part 2: Optimize for Target Image

In the second part of this notebook, we want to find Gaussian parameters that match an actual image. The process will be essentially the same as in Part 1, but you might need to tune some hyperparameters for the best result, such as the number of iterations, learning rate, number of Gaussians, etc. 

## 2.1 Load target image  

In [None]:
def get_input_image_as_torch_tensor(filename: str) -> torch.Tensor:
    """Load ``DATA_FOLDER / filename`` as an (H, W, 3) float32 RGB tensor on ``DEVICE``.

    ``plt.imread`` returns floats in [0, 1] for PNG files but integer values
    (e.g. 0..255 for 8-bit JPEG) for other formats, and a 2D array for
    grayscale images. Both cases are normalized here so the target always has
    3 float channels in [0, 1], regardless of the input file format.

    The image is also flipped upside down to match the 'xy' indexing of the
    meshgrid used for rendering.

    Args:
        filename: Image file name, relative to ``DATA_FOLDER``.

    Returns:
        A tensor of shape (H, W, 3), dtype float32, with requires_grad=False.
    """
    full_filename = DATA_FOLDER / filename
    image_np = np.asarray(plt.imread(full_filename))

    # Integer images (e.g. 8-bit JPEG) -> floats in [0, 1]; PNGs already are.
    if np.issubdtype(image_np.dtype, np.integer):
        max_value = np.iinfo(image_np.dtype).max
        image_np = image_np.astype(np.float32) / max_value
    image_np = image_np.astype(np.float32, copy=False)

    if image_np.ndim == 2:
        # Grayscale (H, W): replicate into three identical channels.
        image_np = np.stack([image_np] * 3, axis=-1)
    elif image_np.shape[2] < 3:
        # Single channel (H, W, 1) or gray + alpha (H, W, 2): replicate the gray channel.
        image_np = np.repeat(image_np[:, :, :1], 3, axis=2)
    else:
        # RGB or RGBA: erase the alpha channel, if any.
        image_np = image_np[:, :, 0:3]

    # Flip upside down to match 'xy' indexing of our meshgrid.
    image_np = np.flipud(image_np)

    # flipud returns a view with a negative stride, which torch cannot consume
    # directly, hence the copy.
    image_torch = torch.tensor(image_np.copy(), requires_grad=False, device=DEVICE)

    return image_torch

In [None]:
# Load the target image (expected in the `data` folder) as a constant tensor.
file_name = "mona-lisa.png"
target_image: torch.Tensor = get_input_image_as_torch_tensor(file_name)

# Quick look at the target
plot(target_image, "Target Image", (4, 3))

# Render at the target's pixel resolution: dim 0 is rows (Y), dim 1 is columns (X).
RES_Y, RES_X = target_image.shape[:2]
x_to_y_ratio = float(RES_X) / float(RES_Y)

# Fix the horizontal extent and derive the vertical one from the aspect ratio,
# so that a pixel covers the same distance along x and y.
X_DIM = (-5.0, 5.0)
Y_DIM = tuple(bound / x_to_y_ratio for bound in X_DIM)

DOMAIN = Domain2D(x_dim=X_DIM, y_dim=Y_DIM, res_x=RES_X, res_y=RES_Y)

## 2.2 Initial random Gaussian parameters for the optimization

In [None]:
N = 100  # Number of Gaussians


def _as_trainable(values):
    """Turn a (nested) list of floats into a trainable nn.Parameter on DEVICE."""
    return nn.Parameter(torch.tensor(values, requires_grad=True, device=DEVICE))


# The parameters to optimize, one entry per Gaussian attribute.
# Random values are drawn in the order listed: all colors first, then all centers.
params_opt_image = {
    # One random RGB color per Gaussian (stored under 'alphas'), N x 3
    'alphas': _as_trainable(
        [[random.random() for _ in range(3)] for _ in range(N)]
    ),
    # Positions scattered uniformly over the domain, N x 2 as (x, y)
    'centers': _as_trainable(
        [[random.uniform(*DOMAIN.x_dim), random.uniform(*DOMAIN.y_dim)] for _ in range(N)]
    ),
    # Scales along the local (X, Y) axes; every Gaussian starts isotropic, N x 2
    'sigmas': _as_trainable([[0.3, 0.3] for _ in range(N)]),
    # Rotations, all zero initially (length N). For random initial rotations use
    # e.g. [random.uniform(0.0, 2 * math.pi) for _ in range(N)]
    'thetas': _as_trainable([0.0] * N),
}

In [None]:
# Plotting initial gaussians.
# This render is only a preview, so skip building the autograd graph; otherwise
# it could stay referenced by the global `initial_image` for the whole session.
with torch.no_grad():
    initial_image = reconstruct_gaussian_2d(
        params_opt_image,
        domain=DOMAIN
    )

plot(
    data=[target_image, initial_image, initial_image],
    title=["Target Image", "Initial Gaussians", ""],
    outline_params=[None, None, params_opt_image],
    extent=DOMAIN.get_extent(),
    figsize=(7,10)
);

## 2.3. Optimization

### 2.3.1 Initialize the optimizer

In [None]:
# Explicit, fixed ordering of the parameter groups. Dicts keep insertion order
# in Python 3.7+, but spelling the keys out makes the group order obvious.
params_keys = ['alphas', 'centers', 'sigmas', 'thetas']

# One optimizer parameter group per Gaussian attribute. The extra 'name' entry
# makes it easy to find a group later (e.g. to adjust its learning rate).
params_dict_for_optimizer = [
    {'params': params_opt_image[key], 'name': key}
    for key in params_keys
]

# Groups may carry their own learning rate, overriding the default below.
# A rate of 0.0 effectively freezes a group (e.g. keep the Gaussians in place).
# See https://pytorch.org/docs/stable/optim.html for per-parameter options.
# lr_dict = {'alphas': 0.01, 'centers': 0.2, 'sigmas': 0.0, 'thetas': 0.0}
# for group in params_dict_for_optimizer:
#     if group['name'] in lr_dict:
#         group['lr'] = lr_dict[group['name']]

# Default learning rate for every group that does not define its own 'lr'
lr = 0.1

# Adam over all parameter groups
optimizer = torch.optim.Adam(params_dict_for_optimizer, lr=lr)

def loss_function(prediction, target):
    """Return the mean absolute error (L1 loss) between the rendered image and the target.

    For this demo an L2 loss (``torch.nn.functional.mse_loss``) works comparably
    well; an SSIM term could also be added, see e.g.
    https://github.com/VainF/pytorch-msssim
    """
    return torch.nn.functional.l1_loss(prediction, target)

### 2.3.2 Optimization loop

In [None]:
# Intervals during optimization
image_save_interval = 1  # Render (and optionally save) a frame every this many epochs
num_epochs = 50  # number of epochs to optimize for
display_plots = False  # For not cluttering the notebook
save_plots = True  # Writing out images into a folder

if save_plots:
    # Set the output folder to be the current time, and restart naming of the file names
    initialize_file_names()
    print(f"Starting optimization. Outputting results into folder `{CURRENT_SCENE_FOLDER}`.")
else:
    print(f"Starting optimization without saving the results into file.")

timer = Timer()

# Keep track of the loss history (one value per epoch), plotted in the "Loss history" panel
loss_trajectory = []

epoch_progress_bar = tqdm(range(num_epochs))
for epoch in epoch_progress_bar:
    # Forward pass, i.e. combining Gaussians into an image
    curr_image = reconstruct_gaussian_2d(params_opt_image, domain=DOMAIN)

    # Calculate and save current loss between output and target
    curr_loss = loss_function(curr_image, target_image)
    loss_trajectory.append(curr_loss.detach().cpu().numpy())

    # Zero out gradients before running backward pass
    optimizer.zero_grad()

    # Compute gradient of the loss (w.r.t. gaussian parameters)
    curr_loss.backward()

    # Perform an optimization step (i.e. update gaussian parameters)
    optimizer.step()

    with torch.no_grad():
        # Optionally display/save a frame, but only on epochs that hit the interval.
        # Saving and closing must stay inside this check: otherwise, with
        # image_save_interval > 1, the previous (already closed) figure would be
        # saved again on every intermediate epoch as a duplicate frame.
        if (display_plots or save_plots) and epoch % image_save_interval == 0:
            fig = plot(
                data=[curr_image, curr_image, target_image, loss_trajectory],
                outline_params=[None, params_opt_image, None, None],
                title=[
                    f"Optim at epoch {epoch}",
                    f"Gaussians (N={N})",
                    f"Target image",
                    f"Loss history"
                ],
                domain=DOMAIN,
                show_plot=display_plots  # Optionally, don't clutter the notebook with showing the plot here.
            )
            if save_plots:
                # Save plot to file as {curr_date}/{epoch}.jpg
                save_fig_to_file(fig)
            plt.close(fig)  # close the current figure

        # Print out optimization details
        progress_text = (
            f"Step {epoch}, "
            f"Loss: {curr_loss.item()}, "
            f"{timer.get_elapsed_time()}, "
            f"N = {N}"
        )

        epoch_progress_bar.set_description(progress_text)

        # You can experiment with changing the learning rate here,
        # e.g. halve it every 800 epochs:
        # if epoch % 800 == 0 and epoch > 0:
        #     for param_group in optimizer.param_groups:
        #         param_group['lr'] *= 0.5
        #         print(f"lr: {param_group['lr']}")

### 2.3.3 Plotting final result

In [None]:
# Render the optimized Gaussians for display only. Gradients are not needed here,
# and without no_grad the autograd graph of this render could be kept alive by the
# global `curr_image` for the rest of the session.
with torch.no_grad():
    curr_image = reconstruct_gaussian_2d(params_opt_image, domain=DOMAIN)

# Plotting final result: rendered image, the same image with Gaussian outlines, and the target
plot(
    data=[curr_image, curr_image, target_image],
    outline_params=[None, params_opt_image, None],
    title=[
        f"Result of optimization",
        f"Gaussians (N={N})",
        f"Target image"
    ],
    domain=DOMAIN,
    figsize=(10,6)
)

print(f"Final params: {params_opt_image}")

In [None]:
# Persist the Gaussians fitted to the target image, together with the domain, as JSON.
gaussians_to_json_file(
    "optimized-params-image.json",
    params=params_opt_image,
    domain=DOMAIN,
)

In [None]:
# Assemble the frames written during the optimization into a video.
assemble_video(CURRENT_SCENE_FOLDER, "anim-image.mp4")