In [None]:
import os
import subprocess
import sys

import torch
import torch.nn.functional as F
import tqdm
from plyfile import PlyData, PlyElement
from torch import nn

sys.path.append('../')

import radfoam
from radfoam_model.render import TraceRays
from radfoam_model.utils import *

# Imported after the star import so these names take precedence over anything it re-exports.
import random

import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial import KDTree
from tqdm import tqdm

In [None]:
def generate_fixed_color_map(point_cloud):
    """Map every point index 0..len(point_cloud)-1 to its own random RGB color.

    Colors are drawn from NumPy's global RNG in index order, so seeding
    ``np.random`` beforehand makes the map reproducible.
    """
    color_map = {}
    for point_index in np.arange(len(point_cloud)):
        color_map[point_index] = np.random.rand(3)
    return color_map

def triangle_case1(tet, values, points):
    """Emit one triangle per tetrahedron whose vertices split 1 positive vs 3 non-positive.

    Each row of ``tet`` must have exactly one vertex with ``values > 0``. The
    triangle's corners are the midpoints (not interpolated zero crossings) of the
    three edges joining that vertex to the other three.

    Args:
        tet: integer tensor of vertex indices into ``points``, presumably (T, 4)
            with one row per tetrahedron.
        values: per-vertex scalars with the same shape as ``tet``. A value of
            exactly 0 counts as non-positive, matching the ``> 0`` count the
            caller uses to pick this case.
        points: coordinate tensor indexed by ``tet``.

    Returns:
        new_points: triangle corners, three consecutive rows per tetrahedron.
        new_tri: CPU int64 tensor of shape (T, 3) indexing into ``new_points``.
    """
    positive = values > 0
    # The single positive vertex of each tet, repeated once per edge to the other three.
    diff = tet[positive].long().repeat_interleave(3)
    # Complement of `positive` (not `values < 0`) so zero-valued vertices are not dropped.
    rest = tet[~positive].long()
    new_points = (points[diff] + points[rest]) / 2.
    new_tri = torch.arange(len(new_points)).reshape(len(new_points) // 3, 3)
    return new_points, new_tri

def triangle_case2(tet, values, points):
    """Emit two triangles per tetrahedron whose vertices split 2 positive vs 2 non-positive.

    For positive vertices (d0, d1) and non-positive vertices (r0, r1), the four
    cut-edge midpoints form the quad d0r0 - d0r1 - d1r1 - d1r0. The quad is split
    along the d0r1-d1r0 diagonal.

    Args:
        tet: integer tensor of vertex indices into ``points``, presumably (T, 4).
        values: per-vertex scalars with the same shape as ``tet``. Exactly two
            entries per row must be ``> 0``. Zero counts as non-positive.
        points: coordinate tensor indexed by ``tet``.

    Returns:
        new_points: 4T rows laid out as [d0r0 | d1r1 | d0r1 | d1r0], each group
            holding T rows in tet order.
        new_tri: CPU int64 tensor of shape (2T, 3) indexing into ``new_points``.
    """
    positive = values > 0
    # Row-major masking keeps each tet's two vertices adjacent: [d0, d1, d0, d1, ...].
    diff = tet[positive].long()
    # Complement of `positive` so zero-valued vertices are kept (see docstring).
    rest = tet[~positive].long()
    pos_pts = points[diff]
    neg_pts = points[rest]
    p1p2 = (pos_pts + neg_pts) / 2.
    p1 = p1p2[::2]                                  # midpoint d0-r0
    p2 = p1p2[1::2]                                 # midpoint d1-r1
    p3 = (pos_pts[::2] + neg_pts[1::2]) / 2.        # midpoint d0-r1
    p4 = (pos_pts[1::2] + neg_pts[::2]) / 2.        # midpoint d1-r0
    new_points = torch.cat((p1, p2, p3, p4))
    ls = len(p1)
    new_tri = torch.tensor([[0, 2 * ls, 3 * ls], [ls, 3 * ls, 2 * ls]]).repeat(ls, 1)
    # Shift each tet's pair of triangles to that tet's row within every group.
    new_tri += torch.arange(ls).repeat_interleave(2)[:, None]
    return new_points, new_tri

def marching_tetraheadra(tet_adjacency, primal_values, points):
    """Extract a triangle mesh separating positive from non-positive per-vertex values.

    (The name's spelling is kept for compatibility with existing callers.)

    Tets with 1 or 3 positive vertices produce one triangle each via
    ``triangle_case1``; tets with 2 positive vertices produce two via
    ``triangle_case2``. Tets with 0 or 4 positive vertices are skipped.
    Vertices are edge midpoints.

    Args:
        tet_adjacency: integer tensor of tetrahedra, one row of vertex indices
            per tet (presumably (T, 4)).
        primal_values: per-vertex scalars indexed by ``tet_adjacency``.
        points: vertex coordinates indexed by ``tet_adjacency``.

    Returns:
        (vertices, faces) as NumPy arrays.
    """
    values = primal_values[tet_adjacency]
    positive = values > 0
    pos = positive.sum(1).long()
    # +1 for positive, -1 otherwise (zero included). Passing signs instead of raw values
    # keeps the helpers' split consistent with `pos`. Negating a raw 0 for the 3-positive
    # case would leave it at 0, and that vertex would then be counted on neither side.
    sign = positive.float() * 2 - 1
    new_v, new_f = [], []
    cur_ind = 0
    # 3-positive tets are handled as 1-positive tets with the sign flipped.
    for n_positive, flip in ((1, 1.0), (3, -1.0)):
        mask = pos == n_positive
        if mask.any():
            new_points, new_tri = triangle_case1(tet_adjacency[mask], flip * sign[mask], points)
            new_v.append(new_points)
            new_f.append(cur_ind + new_tri)
            cur_ind += len(new_points)
    # Always appended (possibly empty) so the final torch.cat never sees empty lists.
    mask = pos == 2
    new_points, new_tri = triangle_case2(tet_adjacency[mask], sign[mask], points)
    new_v.append(new_points)
    new_f.append(cur_ind + new_tri)
    return torch.cat(new_v).cpu().detach().numpy(), torch.cat(new_f).cpu().detach().numpy()

def tets_to_edges(tets):
    """Return every tetrahedron edge as a two-column index tensor.

    Edges are grouped by local vertex pair (0-1, 0-2, 0-3, 1-2, 1-3, 2-3). Within
    each group they follow the row order of ``tets``. Edges shared by
    neighbouring tets are not deduplicated.
    """
    vertex_pairs = ((0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3))
    per_pair = [tets[:, [first, second]] for first, second in vertex_pairs]
    return torch.cat(per_pair).view(-1, 2)

def apply_lloyd_iterations(primal_points, triangulation, n_iterations=100,
                           extreme_radius=30.0, restore_order=True):
    """Run Lloyd (centroidal Voronoi) relaxation steps on the sites.

    Each iteration moves every site whose norm is below ``extreme_radius`` to the
    centroid radfoam computes for it, then incrementally rebuilds ``triangulation``.
    If a rebuild raises ``radfoam.TriangulationFailedError``, the sites are
    jittered with doubling Gaussian noise and a full (non-incremental) rebuild is
    retried, up to 25 failures per iteration.

    Args:
        primal_points: 2-D tensor of sites, one row per site, in the same row
            order as ``triangulation``'s vertices. NOTE: the first iteration
            writes the new centroids into this tensor in place.
        triangulation: radfoam triangulation built from ``primal_points``. It
            is rebuilt in place.
        n_iterations: number of Lloyd steps.
        extreme_radius: sites with norm >= this value are kept fixed.
        restore_order: if True (default), return the sites in the caller's
            input row order. This matches the previous behaviour the notebook's
            color tracking relies on. ``triangulation`` may then use a different
            vertex order. If False, return them in ``triangulation``'s current
            vertex order.

    Returns:
        The relaxed sites as a tensor.

    Raises:
        RuntimeError: if more than 25 rebuild attempts fail within one iteration.
    """
    # original_index[i] = input row that current row i came from.
    original_index = torch.arange(primal_points.shape[0], device=primal_points.device)
    for _ in range(n_iterations):
        point_adjacency = triangulation.point_adjacency()
        point_adjacency_offsets = triangulation.point_adjacency_offsets()
        new_centroids = radfoam.centroidal_voronoi_tessellation_single_iteration(
            primal_points, point_adjacency, point_adjacency_offsets)
        is_not_extreme = torch.norm(primal_points, dim=1) < extreme_radius
        primal_points[is_not_extreme, :] = new_centroids[is_not_extreme, :]

        # Retry state is per rebuild, as the abort message below implies.
        perturbation = 1e-6
        failures = 0
        incremental = True
        while True:
            if failures > 25:
                raise RuntimeError("aborted triangulation after 25 attempts")
            try:
                needs_permute = triangulation.rebuild(primal_points, incremental=incremental)
                break
            except radfoam.TriangulationFailedError as e:
                print("caught: ", e)
                perturbation *= 2
                failures += 1
                # Fall back to a full rebuild after a failure (previously assigned but unused).
                incremental = False
                with torch.no_grad():
                    primal_points = (
                        primal_points
                        + perturbation * torch.randn_like(primal_points)
                    )
        # Runs after every successful rebuild. It was previously nested inside the retry
        # loop and so was skipped on success. Keeping the sites in the triangulation's
        # vertex order makes the next iteration's adjacency index the right rows.
        if needs_permute:
            perm = triangulation.permutation().to(torch.long)
            primal_points = primal_points[perm]
            original_index = original_index[perm]

    if not restore_order:
        return primal_points
    restored = torch.empty_like(primal_points)
    restored[original_index] = primal_points
    return restored

def visualize_2D_plane(points, min_max=None, color_map=None, grid_size=400):
    """Rasterize a planar slice of the Voronoi diagram of ``points``.

    The slice is the plane where the FIRST coordinate is 0. The image's column
    axis samples the second coordinate and its row axis the third. Each pixel
    takes the color of its nearest site, i.e. of the Voronoi cell containing it.

    Args:
        points: array of sites with 3 columns (queried against 3-D plane samples).
        min_max: ``(x_min, x_max, y_min, y_max)`` bounds of the slice, in the
            second/third coordinates. Defaults to ``(0, 5, 0, 5)``.
        color_map: dict from site index to RGB triple. Site indices hit by the
            slice but missing from the map get a random color from
            ``np.random``. The dict is updated in place and also returned.
        grid_size: number of pixels per side.

    Returns:
        ``((x_min, x_max, y_min, y_max), image, color_map)``, where ``image`` has
        shape ``(grid_size, grid_size, 3)``.
    """
    if min_max is None:
        x_min, x_max, y_min, y_max = 0, 5, 0, 5
    else:
        x_min, x_max, y_min, y_max = min_max

    x_vals = np.linspace(x_min, x_max, grid_size)
    y_vals = np.linspace(y_min, y_max, grid_size)
    # meshgrid's default 'xy' indexing: grid[row, col] = (y_vals[row], x_vals[col]), so
    # the raveled (C-order) samples are already in image row-major order.
    grid_x, grid_y = np.meshgrid(x_vals, y_vals)
    plane = np.column_stack((np.zeros(grid_x.size), grid_x.ravel(), grid_y.ravel()))

    _, indices = KDTree(points).query(plane)

    if color_map is None:
        color_map = {}
    unique_indices, pixel_to_unique = np.unique(indices, return_inverse=True)
    for idx in unique_indices:
        if idx not in color_map:
            color_map[idx] = np.random.rand(3)

    # Vectorized lookup. It replaces a per-pixel loop whose `searchsorted(...) - 1`
    # shifted every pixel by one and wrapped pixel 0 to the last row/column.
    palette = np.asarray([color_map[idx] for idx in unique_indices], dtype=float).reshape(-1, 3)
    image = palette[pixel_to_unique.ravel()].reshape(grid_size, grid_size, 3)

    return (x_min, x_max, y_min, y_max), image, color_map

def plot_voronoi_images(image1, image2, title1="Voronoi Diagram Before", title2="CVT", save_filename=None):
    """Draw two Voronoi slice images side by side.

    The left panel keeps its axes; the right panel's axes are hidden. When
    ``save_filename`` is given, the figure is saved at 300 dpi and closed, so
    frame loops do not accumulate open figures. Otherwise it is shown.
    """
    fig, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(12, 6))
    title_style = dict(fontsize=14, fontweight="bold")

    left_ax.imshow(image1)
    left_ax.set_title(title1, **title_style)

    right_ax.imshow(image2)
    right_ax.set_title(title2, **title_style)
    right_ax.axis("off")

    if save_filename is not None:
        plt.savefig(save_filename, dpi=300)
        plt.close(fig)
        return
    plt.show()


In [None]:
import open3d as o3d

PATH_TO_PRETRAINED = "../../radfoam_original/output/bonsai/scene.ply"
# Total frames rendered below: frame 0 plus N_lloyd - 1 single Lloyd steps.
N_lloyd = 100


def load_ply(file_path):
    """Read a PLY file with Open3D and return the point cloud it produces."""
    return o3d.io.read_point_cloud(file_path)


file_path = PATH_TO_PRETRAINED
pcd = load_ply(file_path)

# Voronoi sites as a float32 CUDA tensor.
site_coords = np.asarray(pcd.points)
primal_points = torch.from_numpy(site_coords).float().cuda()

# Build the triangulation, then reorder the sites to match its vertex order.
triangulation = radfoam.Triangulation(primal_points)
perm = triangulation.permutation().to(torch.long)
primal_points = primal_points[perm]

color_map = generate_fixed_color_map(primal_points)

In [None]:
(x_min, x_max, y_min, y_max), image_voronoi_plane_before, color_map = visualize_2D_plane(
    primal_points.detach().cpu().numpy(), color_map=color_map)

SAVE_PATH = "2D_slide_cvt_experiment/frames"
os.makedirs(SAVE_PATH, exist_ok=True)

# Frame 0 shows the initial slice in both panels.
plot_voronoi_images(image_voronoi_plane_before, image_voronoi_plane_before,
                    save_filename=os.path.join(SAVE_PATH, "frame_0.png"))

# One Lloyd step per frame. Plotting dominates the runtime; comment out the
# visualize/plot calls below for a faster run.
for frame_number in tqdm(range(1, N_lloyd)):
    primal_points = apply_lloyd_iterations(primal_points, triangulation, n_iterations=1)
    # Rebuild from scratch and reorder the sites to the new vertex order.
    triangulation = radfoam.Triangulation(primal_points)
    perm = triangulation.permutation().to(torch.long)
    primal_points = primal_points[perm]
    # Row i now holds the site that was previously at row perm[i]; carry its color along.
    # Copy perm to the host once; calling .item() per site forced one device sync per element.
    perm_host = perm.cpu().numpy()
    new_color_map = {idx: color_map[int(perm_host[idx])] for idx in color_map}
    color_map = new_color_map
    _, image_voronoi_plane_after, color_map = visualize_2D_plane(
        primal_points.detach().cpu().numpy(),
        min_max=(x_min, x_max, y_min, y_max), color_map=color_map)

    plot_voronoi_images(image_voronoi_plane_before, image_voronoi_plane_after,
                        save_filename=os.path.join(SAVE_PATH, f"frame_{frame_number}.png"))

In [None]:
# Function to save the video
SAVE_PATH_VIDEO = "2D_slide_cvt_experiment/video"
os.makedirs(SAVE_PATH_VIDEO,exist_ok=True)
video_path = f"{SAVE_PATH_VIDEO}/output.mp4"
frame_folder = SAVE_PATH
ffmpeg_cmd = f"ffmpeg -framerate 30 -i {frame_folder}/frame_%d.png -c:v libx264 -pix_fmt yuv420p {video_path} -y"
os.system(ffmpeg_cmd)

In [None]:
import IPython.display as display

# Embed the encoded video inline in the cell output.
video_widget = display.Video(video_path, embed=True)
video_widget