## Define imports

In [1]:
import torch as to
import torch.nn as nn
import torch.optim as optim
import numpy as np

from plyfile import PlyData

import pandas as pd

device = "cuda"

## Utility functions

In [2]:
def broadcast(gauss_batch, ray_batch):
    
    # Split up gauss_batch
    means, covariances, opacities, normals, reference_normals = gauss_batch
    # Split up ray_batch
    ray_oris, ray_dirs = ray_batch
    
    R = ray_oris.shape[0]
    G = means.shape[0]

    bcast_ray_oris = ray_oris.unsqueeze(1)
    # (N_rays, 1, 3)
    bcast_ray_dirs = ray_dirs.unsqueeze(1)
    # (1, N_gaussians, 3)
    bcast_means = means.unsqueeze(0)
    # (1, N_gaussians, 3)
    bcast_covariances = covariances.unsqueeze(0)
    # (1, N_gaussians)
    bcast_opacities = opacities.unsqueeze(0)
    # (1, N_gaussians, 3)
    bcast_normals = normals.unsqueeze(0)
    # (1, N_gaussians, 3)
    bcast_reference_normals = reference_normals.unsqueeze(0)

    return (
        bcast_means,
        bcast_covariances,
        bcast_opacities,
        bcast_normals,
        bcast_reference_normals,
        bcast_ray_oris,
        bcast_ray_dirs
    )

In [3]:
def generate_sphere_rays(center, radius, n):
    # Generate random angles for spherical coordinates
    theta = to.rand(n, 1) * 2 * to.pi  # Azimuthal angle
    phi = to.rand(n, 1) * to.pi        # Polar angle

    # Spherical to Cartesian conversion
    x = radius * to.sin(phi) * to.cos(theta)
    y = radius * to.sin(phi) * to.sin(theta)
    z = radius * to.cos(phi)

    # Combine into ray origins
    ray_oris = to.hstack((x, y, z))

    # Ray directions pointing outward from the center
    ray_dirs = ray_oris - center
    # Normalise ray dirs
    ray_dirs = ray_dirs / to.linalg.norm(ray_dirs)

    return ray_oris, ray_dirs

In [4]:
def compute_pairwise_great_circle(points, radius=1.0):
    # Normalize points to lie on the unit sphere
    points_normalized = points / points.norm(dim=1, keepdim=True)
    # Compute the pairwise dot product; for unit vectors, this equals cos(theta)
    dot_prod = to.mm(points_normalized, points_normalized.t())
    # Clamp to ensure numerical stability
    dot_prod = to.clamp(dot_prod, -1.0, 1.0)
    # Compute the great circle distance (angle in radians)
    distances = to.acos(dot_prod)
    # Scale by the sphere's radius if needed
    return distances * radius

In [5]:
def compute_graph_laplacian(points, sigma, radius=1.0):
    # Compute pairwise great circle distances
    distances = compute_pairwise_great_circle(points, radius)
    # Create weight matrix using a Gaussian kernel
    W = to.exp(-distances**2 / (2 * sigma**2))
    # Optionally, remove self-loops by zeroing out the diagonal
    W.fill_diagonal_(0)
    # Compute degree matrix
    D = to.diag(W.sum(dim=1))
    # Graph Laplacian: L = D - W
    L = D - W
    return L

## Load data and process Gaussian model.

In [6]:
def quaternion_to_rotation_matrix(quaternions):
    x = quaternions[:, 1]
    y = quaternions[:, 2]
    z = quaternions[:, 3]
    w = quaternions[:, 0]

    xx = x * x
    yy = y * y
    zz = z * z
    xy = x * y
    xz = x * z
    yz = y * z
    xw = x * w
    yw = y * w
    zw = z * w

    n = quaternions.shape[0]
    R = to.empty((n, 3, 3), dtype=quaternions.dtype)

    R[:, 0, 0] = 1 - 2 * (yy + zz)
    R[:, 0, 1] = 2 * (xy - zw)
    R[:, 0, 2] = 2 * (xz + yw)
    R[:, 1, 0] = 2 * (xy + zw)
    R[:, 1, 1] = 1 - 2 * (xx + zz)
    R[:, 1, 2] = 2 * (yz - xw)
    R[:, 2, 0] = 2 * (xz - yw)
    R[:, 2, 1] = 2 * (yz + xw)
    R[:, 2, 2] = 1 - 2 * (xx + yy)

    return R

class GaussianModel:    
    def __init__(self, path):
        # Load in data
        plyfile = PlyData.read(path)
        plydata = plyfile['vertex'].data
        # Covert data into tensors
        df = pd.DataFrame(plydata)
        means_mask = ["x", "y", "z"]
        quaternions_mask = ["rot_0", "rot_1", "rot_2", "rot_3"]
        scales_mask = ["scale_0", "scale_1", "scale_2"]
        opacities_mask = ["opacity"]

        self.means = to.tensor(df[means_mask].values)
        self.quaternions = to.tensor(df[quaternions_mask].values)
        self.scales = to.tensor(df[scales_mask].values)
        self.opacities = to.tensor(df[opacities_mask].values)
        
        # Set base data
        self.n_gaussians = plydata.shape[0]

        self.means = to.tensor(df[means_mask].values)
        self.quaternions = to.tensor(df[quaternions_mask].values)
        self.scales = to.tensor(df[scales_mask].values)
        self.opacities = to.tensor(df[opacities_mask].values)
        
        # Activate opacities
        self.opacities = 1 / (1 + to.exp(-self.opacities))
        # Derive rotation matrix
        self.normalised_quaternions = self.quaternions / to.linalg.norm(self.quaternions)
        self.rotations = quaternion_to_rotation_matrix(self.normalised_quaternions)
        # Derive scale matrix
        self.scales_exp = to.exp(self.scales)
        self.scales_d = to.eye(3)[None, :, :] * (self.scales_exp)[:, :, None]
        self.scales_d **= 2
        self.scales_i_d = to.eye(3)[None, :, :] * (1/self.scales_exp)[:, :, None]
        self.scales_i_d **= 2
        # Derive covariance matrix
        self.rotations_t = self.rotations.transpose(-1,-2)
        self.scales_d_t = self.scales_d.transpose(-1,-2)
        self.covariances = self.rotations @ self.scales_d @ self.rotations_t
        # Derive the normals (use the centroid to flip normals correctly.)
        min_indices = self.scales_exp.argmin(axis=1)
        self.normals = self.rotations[to.arange(self.n_gaussians), :, min_indices]
        self.normals = self.normals / to.linalg.norm(self.normals)
        centroid = self.means.mean(dim=0)
        vectors_to_centroid = centroid - self.means
        dot_products = (vectors_to_centroid * self.normals).sum(dim=1)
        flip_mask = dot_products < 0
        self.normals[flip_mask] = -self.normals[flip_mask]
        self.reference_normals = self.normals

# Projection shit

In [7]:
def evaluate_points(points, gaussian_means, gaussian_inv_covs, gaussian_opacities):
    distance_to_mean = points - gaussian_means
    exponent = -0.5 * (distance_to_mean[:,:,None,:] @ gaussian_inv_covs @ distance_to_mean[..., None])
    evaluations = gaussian_opacities * to.exp(exponent).squeeze(-1)
    return evaluations

In [8]:
def skew_symmetric(v):
    row1 = to.stack([to.zeros_like(v[..., 0]), -v[..., 2], v[..., 1]], dim=-1)
    row2 = to.stack([v[..., 2], to.zeros_like(v[..., 1]), -v[..., 0]], dim=-1)
    row3 = to.stack([-v[..., 1], v[..., 0], to.zeros_like(v[..., 2])], dim=-1)
    K = to.stack([row1, row2, row3], dim=-2)
    return K

In [9]:
def normals_to_rot_matrix(a, b):
    # Given 2 RxNx3 vectors a and b, return an RxNx3x3 rotation matrix
    a_dot_b = (a[:,:,None,:] @ b[..., None]).squeeze(-1).squeeze(-1)
    a_norm = to.linalg.norm(a)
    b_norm = to.linalg.norm(b,dim=2)
    angle = to.acos((a_dot_b / (a_norm * b_norm)))
    v = to.cross(a,b)
    s = to.norm(v,dim=2) * to.sin(angle)
    c = a_dot_b * to.cos(angle) 
    i = to.eye(3).tile(a.shape[0],a.shape[1],1,1)
    v_skew = skew_symmetric(v)
    last_term = 1 / (1 + c)
    return i + v_skew + (v_skew @ v_skew) * last_term[...,None,None]

In [10]:
def get_max_responses_and_tvals(
    ray_oris, 
    means, 
    covs,
    ray_dirs, 
    opacities,
    normals,
    old_normals
):
    new_rotations = normals_to_rot_matrix(old_normals, normals)
    new_covs = new_rotations @ covs @ new_rotations.transpose(-2,-1)
    inv_covs = to.linalg.inv(new_covs)
    rg_diff = means - ray_oris
    inv_cov_d = inv_covs @ ray_dirs[..., None]
    numerator = (rg_diff[:,:,None,:] @ inv_cov_d).squeeze(-1)
    denomenator = (ray_dirs[:,:,None,:] @ inv_cov_d).squeeze(-1)
    t_values = numerator / denomenator
    best_positions = (ray_oris + t_values * ray_dirs)
    max_responses = evaluate_points(best_positions, means, inv_covs, opacities)

    return max_responses, t_values

In [11]:
from torch.utils.data import TensorDataset, DataLoader

class GaussianParameters(nn.Module):
    def __init__(self, path):
        super(GaussianParameters, self).__init__()
        self.gaussian_model = GaussianModel(path)
        self.means = nn.Parameter(self.gaussian_model.means)
        self.normals = nn.Parameter(self.gaussian_model.normals)
        ray_oris, ray_dirs = generate_sphere_rays(to.tensor([0.0,0.0,0.0]), 10,100)
        self.ray_oris = ray_oris
        self.ray_dirs = ray_dirs
        self.laplacian = compute_graph_laplacian(ray_oris, 1, 10)

    def forward(self):
        return self.means, self.normals

    def project(self):
        gaussian_dataset = TensorDataset(
            self.means,
            self.gaussian_model.covariances,
            self.gaussian_model.opacities,
            self.normals,
            self.gaussian_model.reference_normals
        )
        rays_dataset = TensorDataset(
            self.ray_oris,
            self.ray_dirs
        )

        gaussian_generator = DataLoader(gaussian_dataset, batch_size=10000)
        ray_generator = DataLoader(rays_dataset, batch_size=10000)

        self.contributions = []
        self.alphas = []
        self.tvals = []
        self.values = []
        for ray_batch in ray_generator:
            self.alphas_along_ray = []
            self.tvals_along_ray = []
            for gauss_batch in gaussian_generator:
                (bcast_means,
                bcast_covariances,
                bcast_opacities,
                bcast_normals,
                bcast_reference_normals,
                bcast_ray_oris,
                bcast_ray_dirs) = broadcast(gauss_batch, ray_batch)

                batch_alphas, batch_tvals = get_max_responses_and_tvals(
                    bcast_ray_oris, 
                    bcast_means, 
                    bcast_covariances, 
                    bcast_ray_dirs, 
                    bcast_opacities,
                    bcast_normals,
                    bcast_reference_normals
                )

                
                self.alphas_along_ray.append(batch_alphas)
                self.tvals_along_ray.append(batch_tvals)
                 
            self.alphas_along_ray = to.cat(self.alphas_along_ray, dim=1)
            self.tvals_along_ray = to.cat(self.tvals_along_ray, dim=1)
            self.alphas.append(self.alphas_along_ray)
            self.tvals.append(self.tvals_along_ray)

            _, sorted_idx = to.sort(self.tvals_along_ray, dim=1)
            sorted_alphas = self.alphas_along_ray.gather(dim=1, index=sorted_idx)
            alphas_compliment = 1 - sorted_alphas        
            transmittance = to.cumprod(alphas_compliment, dim=1)
            shifted = to.ones_like(transmittance)
            # Fill shifted starting from the second column with the values of x's columns 0 to N-2
            shifted[:, 1:] = transmittance[:, :-1]
            # Calculate contribution 
            sorted_contribution = shifted - transmittance
            # Normalise
            norm_factor = to.sum(sorted_contribution, dim=1)[...,None]
            sorted_contribution = sorted_contribution / norm_factor
            # unsort the contribution
            inv_idx = sorted_idx.argsort(dim=1)
            # Reorder contribution back to the original order:
            contribution = sorted_contribution.gather(dim=1, index=inv_idx)
            self.contributions.append(contribution)
            self.values.append(to.sum(contribution * self.alphas_along_ray, dim=1))
        self.values = to.cat(self.values, dim=1)
        self.alphas = to.cat(self.alphas, dim=1)
        self.tvals = to.cat(self.tvals, dim=1)
        self.contributions = to.cat(self.contributions, dim=1)
        self.blended_tvals = to.sum(self.contributions * self.tvals, dim=1)
        return self.blended_tvals
    def harmonic_loss(self):
        projected_values = self.project() 
        loss = projected_values.T @ self.laplacian @ projected_values
        return loss

In [36]:
class RenderContext():
    def __init__(self, width=800, height=600, point_size=0.02):
        # Set up scene, camera, and renderer
        self.width = width
        self.height = height
        self.point_size = point_size
        self.first_update = False
        
        # Initialize scene
        self.scene = Scene(background="#f0f0f0")
        
        # Set up camera
        self.camera = PerspectiveCamera(
            position=[5, 5, 5],
            up=[0, 1, 0],
            aspect=width/height,
            fov=50
        )
        
        # Create renderer
        self.controls = OrbitControls(controlling=self.camera)
        self.renderer = Renderer(
            camera=self.camera,
            scene=self.scene,
            controls=[self.controls],
            width=width,
            height=height,
            antialias=True
        )
        
        # Add lighting
        self.scene.add(AmbientLight(intensity=0.6))
        self.scene.add(DirectionalLight(position=[1, 1, 1], intensity=0.4))
        self.scene.add(DirectionalLight(position=[-1, -1, -1], intensity=0.4))
        
        # Add coordinate axes for reference
        axesHelper = AxesHelper(size=1)
        self.scene.add(axesHelper)
        
        # Create grid for reference
        gridHelper = GridHelper(size=10, divisions=10)
        gridHelper.position = [0, -0.01, 0]  # Slightly below the origin
        self.scene.add(gridHelper)
        
        # Initialize point cloud with dummy data (just one point at origin)
        positions = np.array([[0, 0, 0]], dtype=np.float32)
        colors = np.array([[0.5, 0.5, 0.5]], dtype=np.float32)
        
        # Setup for the point cloud
        self.positions_attr = BufferAttribute(array=positions, normalized=False)
        self.colors_attr = BufferAttribute(array=colors, normalized=False)
        
        point_geo = BufferGeometry(
            attributes={
                'position': self.positions_attr,
                'color': self.colors_attr
            }
        )
        
        point_mat = PointsMaterial(
            vertexColors='VertexColors',
            size=self.point_size,
            sizeAttenuation=True
        )
        
        self.points = Points(geometry=point_geo, material=point_mat)
        self.scene.add(self.points)
        
        # Create a colorbar widget
        self.colorbar_widget = self._create_colorbar_widget()
        
        # Package everything together
        self.canvas = widgets.VBox([
            self.renderer,
            self.colorbar_widget
        ])
    
    def _create_colorbar_widget(self):
        """Create a simple colorbar widget using HTML"""
        # Create a gradient for the colorbar using HTML
        gradient_html = """
        <div style="
            width: 100%; 
            height: 20px; 
            background: linear-gradient(to right, 
                #0d0887, #41049d, #6a00a8, #8f0da4, #b12a90, 
                #cc4778, #e16462, #f2844b, #fca636, #fcce25, #f0f921
            );
            border-radius: 3px;
            margin-top: 5px;
        "></div>
        <div style="
            display: flex;
            justify-content: space-between;
            font-family: Arial;
            font-size: 12px;
            margin-top: 2px;
        ">
            <span>Min</span>
            <span>Max</span>
        </div>
        """
        
        return widgets.HTML(value=gradient_html)
    
    def update(self, positions, values, colormap='plasma'):
        """
        Update the point cloud visualization
        
        Parameters:
        -----------
        positions : numpy.ndarray or torch.Tensor
            Array of 3D positions with shape (N, 3)
        values : numpy.ndarray or torch.Tensor
            Array of scalar values for coloring with shape (N,) or (N, 1)
        colormap : str
            Name of the matplotlib colormap to use
        """
        # Convert to numpy if tensor
        if isinstance(positions, to.Tensor):
            positions = positions.detach().cpu().numpy()
        if isinstance(values, to.Tensor):
            values = values.detach().cpu().numpy()
        
        # Safety check
        if len(positions) == 0:
            return
        
        # Ensure positions are the right shape
        if positions.ndim != 2 or positions.shape[1] != 3:
            raise ValueError(f"Positions must have shape (N, 3), got {positions.shape}")
        
        # Ensure values are the right shape (flatten if needed)
        if values.ndim > 1:
            values = values.flatten()
        
        # Ensure values length matches positions
        if len(values) != len(positions):
            raise ValueError(f"Values length {len(values)} must match positions length {len(positions)}")
            
        # Normalize values for coloring
        min_val = np.min(values)
        max_val = np.max(values)
        norm = mcolors.Normalize(vmin=min_val, vmax=max_val)
        cmap = cm.get_cmap(colormap)
        
        # Generate colors from values using the colormap
        colors = cmap(norm(values))[:, :3].astype(np.float32)
        
        # Update positions and colors in the buffer
        self.positions_attr.array = positions.astype(np.float32)
        self.colors_attr.array = colors
        
        # Mark attributes as needing update
        self.positions_attr.needsUpdate = True
        self.colors_attr.needsUpdate = True
        
        # Update colorbar with new min/max values
        self.colorbar_widget.value = f"""
        <div style="
            width: 100%; 
            height: 20px; 
            background: linear-gradient(to right, 
                #0d0887, #41049d, #6a00a8, #8f0da4, #b12a90, 
                #cc4778, #e16462, #f2844b, #fca636, #fcce25, #f0f921
            );
            border-radius: 3px;
            margin-top: 5px;
        "></div>
        <div style="
            display: flex;
            justify-content: space-between;
            font-family: Arial;
            font-size: 12px;
            margin-top: 2px;
        ">
            <span>{min_val:.4f}</span>
            <span>{max_val:.4f}</span>
        </div>
        """
        
        # Adjust camera if needed (for first update)
        if not self.first_update:
            # Calculate bounding box of points
            min_bounds = np.min(positions, axis=0)
            max_bounds = np.max(positions, axis=0)
            center = (min_bounds + max_bounds) / 2
            
            # Set camera to look at center of points
            self.controls.target = center.tolist()
            
            # Reset camera position relative to the center
            max_dim = np.max(max_bounds - min_bounds)
            camera_dist = max_dim * 2
            self.camera.position = [
                center[0] + camera_dist,
                center[1] + camera_dist,
                center[2] + camera_dist
            ]
            
            self.first_update = True

In [43]:
def train_model(model, num_iterations=10000, lr=0.005):
    optimizer = to.optim.Adam(model.parameters(), lr=lr)
    context = RenderContext()
    display(context.canvas)
    
    for iteration in tqdm(range(num_iterations)):
        optimizer.zero_grad()
        loss = model.harmonic_loss()
        loss.backward()
        optimizer.step()
        
        # Update visualization every few iterations to avoid slowdown
        if iteration % 10 == 0:
            positions = (model.ray_oris + model.blended_tvals * model.ray_dirs).detach().numpy()
            blended_tvals = model.blended_tvals.detach().numpy()
            context.update(positions, blended_tvals)

In [44]:
model = GaussianParameters("point_cloud.ply")
train_model(model=model, num_iterations=10)
print(model.harmonic_loss())

VBox(children=(Renderer(camera=PerspectiveCamera(aspect=1.3333333333333333, position=(5.0, 5.0, 5.0), projecti…

  0%|          | 0/10 [00:00<?, ?it/s]

  cmap = cm.get_cmap(colormap)


tensor([[0.0808]], grad_fn=<MmBackward0>)


In [42]:
# get snapshot : returns current solution using train_model.harmonic_loss()

In [45]:
model = GaussianParameters('point_cloud.ply')
ray_oris, ray_dirs = generate_sphere_rays(to.tensor([0.0,0.0,0.0]), 10.0, 100)
laplacian = compute_graph_laplacian(ray_oris, 1.0, 10.0)
projected_value = model.project()
projected_value.T @ laplacian @ projected_value

tensor([[84.7075]], grad_fn=<MmBackward0>)

In [7]:
import numpy as np

# test_example = true

import fastplotlib as fpl
print(fpl.__version__)
figure = fpl.Figure(size=(700, 560))

data = np.array(
    [[0, 1, 2],
     [3, 4, 5]]
)
image_graphic = figure[0, 0].add_image(data)

figure.show()

0.1.0.a16


AttributeError: module 'fastplotlib' has no attribute 'Figure'