In [None]:
import os

import numpy as np
import open3d as o3d
import torch
from skimage import measure
from torch import nn
# import your libraries

from IGR.code.model.network import ImplicitNet

In [None]:
def load_pointcloud(filename):
    """Read a point cloud file and return its coordinates as a float32 tensor.

    Args:
        filename: Path handed to ``o3d.io.read_point_cloud``.

    Returns:
        torch.Tensor built from the ``points`` attribute of the loaded cloud.
    """
    cloud = o3d.io.read_point_cloud(filename)
    coords = np.asarray(cloud.points)
    return torch.tensor(coords, dtype=torch.float32)
    
def write_mesh(v,f,filename):
    """Write a triangle mesh to ``filename``.

    Args:
        v: Vertex coordinates, wrapped in ``o3d.utility.Vector3dVector``.
        f: Triangle vertex indices, wrapped in ``o3d.utility.Vector3iVector``.
        filename: Output path handed to ``o3d.io.write_triangle_mesh``.
    """
    vertices = o3d.utility.Vector3dVector(v)
    triangles = o3d.utility.Vector3iVector(f)
    o3d.io.write_triangle_mesh(
        filename,
        o3d.geometry.TriangleMesh(vertices, triangles),
    )
    
def write_pointcloud(p,filename):
    """Write a point cloud to ``filename``.

    Args:
        p: Point coordinates, wrapped in ``o3d.utility.Vector3dVector``.
        filename: Output path handed to ``o3d.io.write_point_cloud``.
    """
    coords = o3d.utility.Vector3dVector(p)
    o3d.io.write_point_cloud(filename, o3d.geometry.PointCloud(coords))

# class ImplcitNetwork(nn.Module):
#     def __init__(self):
#         pass

#     def forward(self, x):
#         pass

#     def phase_loss(self, x):
#         pass

class PHASELoss(nn.Module):
    """Loss for fitting a signed density network ``u`` to a point cloud.

    ``forward`` returns

        grad_term + double_well + lambda_val * reconstruction + mu * normal_constraint

    The first two terms are evaluated on points drawn uniformly from
    ``[-1, 1]^d``; the last two are evaluated on (perturbed) input points.
    """

    def __init__(self, epsilon=0.01, lambda_val=0.1, mu=0.1, ball_radius=0.01,
                 use_normals=False, domain_samples=1000):
        """
        Args:
            epsilon: Regularization parameter that controls smoothness
            lambda_val: Weight for the reconstruction loss
            mu: Weight for the normal/gradient constraint loss
            ball_radius: Radius of balls around point samples for reconstruction loss
            use_normals: If True, uses provided normals; otherwise enforces unit gradients
            domain_samples: Number of random domain points (per batch element)
                used for the double-well and gradient regularization terms
        """
        super(PHASELoss, self).__init__()
        self.epsilon = epsilon
        self.lambda_val = lambda_val
        self.mu = mu
        self.ball_radius = ball_radius
        self.use_normals = use_normals
        self.domain_samples = domain_samples

    def double_well_potential(self, x):
        """Return ``x**2 - 2|x| + 1`` element-wise (zero at ``x = +/-1``)."""
        return x**2 - 2*torch.abs(x) + 1

    def reconstruction_loss(self, u, points, sample_count=50):
        """Mean ``|u|`` over randomly perturbed input points.

        For every batch element, ``sample_count`` input points are drawn with
        replacement and each is displaced by a random offset of length at most
        ``ball_radius``. All samples are evaluated in one call to ``u``.

        Args:
            u: Callable representing the signed density
            points: Input point cloud (B x N x 3)
            sample_count: Number of perturbed samples drawn per batch element

        Returns:
            Scalar tensor: mean absolute value of ``u`` at the samples.

        Raises:
            ValueError: If ``points`` is not B x N x D with N > 0, or if
                ``sample_count`` is smaller than 1.
        """
        if points.dim() != 3 or points.shape[1] == 0:
            raise ValueError("points must have shape (B, N, D) with N > 0")
        if sample_count < 1:
            raise ValueError("sample_count must be at least 1")

        batch_size, n_points, n_dims = points.shape

        # Indices are created on the same device as the data so that gather
        # also works when the point cloud lives on an accelerator.
        idx = torch.randint(0, n_points, (batch_size, sample_count), device=points.device)
        centers = torch.gather(points, 1, idx.unsqueeze(-1).expand(-1, -1, n_dims))

        # Random direction on the unit sphere, scaled by a random radius in
        # [0, ball_radius).
        directions = torch.randn_like(centers)
        directions = directions / torch.norm(directions, dim=-1, keepdim=True)
        radii = self.ball_radius * torch.rand_like(centers[..., :1])

        sampled_points = centers + directions * radii
        return torch.mean(torch.abs(u(sampled_points)))

    def gradient_loss(self, u, w, points, normals=None):
        """
        Computes the gradient constraint loss:
        - If normals are provided (and ``use_normals`` is set), aligns gradients with normals
        - Otherwise, enforces unit gradient norm

        Args:
            u: Callable representing the signed density (unused here; kept for
                interface compatibility)
            w: Log-transformed function (-sqrt(epsilon) * log(1-|u|) * sign(u))
            points: Input point cloud
            normals: Surface normals (optional); must be expandable to the
                shape of the gradients

        Returns:
            Scalar tensor with the constraint loss.
        """
        # Differentiate with respect to a private copy so the caller's tensor
        # is not switched to requires_grad as a side effect.
        query = points.detach().clone().requires_grad_(True)
        w_val = w(query)

        # create_graph=True keeps the graph so the loss can be back-propagated
        # through the input gradient into the network parameters.
        gradients = torch.autograd.grad(
            outputs=w_val,
            inputs=query,
            grad_outputs=torch.ones_like(w_val),
            create_graph=True,
        )[0]

        if self.use_normals and normals is not None:
            # Align gradients with provided normals
            target = normals.to(device=gradients.device, dtype=gradients.dtype)
            return nn.functional.l1_loss(gradients, target.expand_as(gradients))

        # Enforce unit gradient norm
        gradient_norm = torch.norm(gradients, dim=-1)
        return torch.mean((gradient_norm - 1.0)**2)

    def forward(self, model, points, normals=None):
        """
        Computes the complete PHASE loss

        Args:
            model: Neural network model for signed density function
            points: Input point cloud (B x N x D)
            normals: Surface normals (optional)

        Returns:
            Tuple ``(total_loss, components)`` where ``components`` maps
            'grad_term', 'double_well', 'reconstruction' and
            'normal_constraint' to Python floats.
        """
        u = model
        sqrt_eps = self.epsilon ** 0.5  # epsilon is a Python float, not a tensor

        def w(x):
            # Smoothed SDF obtained from the signed density. u is evaluated
            # once, and 1 - |u| is clamped away from zero so the logarithm
            # stays finite when |u| reaches 1.
            u_x = u(x)
            return -sqrt_eps * torch.log(torch.clamp(1 - torch.abs(u_x), min=1e-6)) * torch.sign(u_x)

        # Sample random points in the domain for regularization terms.
        # requires_grad must be enabled BEFORE the forward pass, otherwise the
        # points are not part of the graph and autograd.grad fails.
        batch_size = points.shape[0]
        domain_points = torch.rand(
            (batch_size, self.domain_samples, points.shape[-1]), device=points.device
        ) * 2 - 1
        domain_points.requires_grad_(True)

        u_domain = u(domain_points)
        grad_u = torch.autograd.grad(
            outputs=u_domain,
            inputs=domain_points,
            grad_outputs=torch.ones_like(u_domain),
            create_graph=True,
        )[0]

        # Double-well potential term
        double_well_term = torch.mean(self.double_well_potential(u_domain))

        # Gradient regularization term
        grad_term = torch.mean(self.epsilon * torch.sum(grad_u**2, dim=-1))

        # Reconstruction loss
        recon_loss = self.reconstruction_loss(u, points)

        # Normal/gradient constraint loss
        normal_loss = self.gradient_loss(u, w, points, normals)

        # Total loss
        total_loss = grad_term + double_well_term + self.lambda_val * recon_loss + self.mu * normal_loss

        return total_loss, {
            'grad_term': grad_term.item(),
            'double_well': double_well_term.item(),
            'reconstruction': recon_loss.item(),
            'normal_constraint': normal_loss.item()
        }

In [None]:
# Read the input point cloud from disk
points = load_pointcloud('bunny.ply')

# Build the implicit network and an Adam optimizer over all of its parameters
model = ImplicitNet(d_in=3, dims=[512] * 8, skip_in=[4], geometric_init=True)
opt = torch.optim.Adam([dict(params=model.parameters(), lr=0.005, weight_decay=0)])

In [None]:
def compute_chamfer_distance(pred_points, gt_points):
    """Symmetric Chamfer distance (sum of mean squared nearest-neighbour distances).

    Args:
        pred_points (torch.Tensor): Predicted point cloud (N x 3)
        gt_points (torch.Tensor): Ground truth point cloud (M x 3)

    Returns:
        float: mean over predicted points of the squared distance to the
        nearest ground-truth point, plus the same quantity in the opposite
        direction.

    Raises:
        ValueError: If either point cloud is empty.
    """
    if pred_points.shape[0] == 0 or gt_points.shape[0] == 0:
        raise ValueError("Chamfer distance is undefined for an empty point cloud")

    # Bring both clouds onto the same device and a common floating dtype.
    common_dtype = torch.promote_types(pred_points.dtype, gt_points.dtype)
    pred = pred_points.to(dtype=common_dtype)
    gt = gt_points.to(device=pred_points.device, dtype=common_dtype)

    # cdist yields the (N, M) distance matrix without materialising the
    # (N, M, 3) difference tensor; the non-matmul mode keeps distances between
    # nearby points accurate.
    sq_dist = torch.cdist(
        pred, gt, p=2.0, compute_mode='donot_use_mm_for_euclid_dist'
    ).pow(2)

    # Nearest-neighbour squared distance in both directions
    dist_pred_to_gt = sq_dist.min(dim=1).values  # (N,)
    dist_gt_to_pred = sq_dist.min(dim=0).values  # (M,)

    return (dist_pred_to_gt.mean() + dist_gt_to_pred.mean()).item()

def sample_mesh_points(mesh_path, n_points=10000):
    """Load a triangle mesh and draw points uniformly from its surface.

    Args:
        mesh_path (str): Path to the mesh file
        n_points (int): Number of points to sample

    Returns:
        torch.Tensor: float32 tensor built from the sampled cloud's ``points``.
    """
    surface = o3d.io.read_triangle_mesh(mesh_path)
    sampled = surface.sample_points_uniformly(number_of_points=n_points)
    coords = np.asarray(sampled.points)
    return torch.tensor(coords, dtype=torch.float32)

def evaluate_reconstruction(model, gt_mesh_path, resolution=64, bounds=(-1.0, 1.0), n_points=10000):
    """Extract the zero level set of ``model`` and score it against a reference mesh.

    The model is evaluated on a regular ``resolution^3`` grid spanning
    ``bounds`` along every axis, the zero iso-surface is extracted with
    marching cubes, written to 'temp_reconstruction.ply', and compared with the
    ground-truth mesh through the Chamfer distance of points sampled from both.

    Args:
        model: Neural network model for implicit function
        gt_mesh_path (str): Path to ground truth mesh
        resolution (int): Resolution for marching cubes grid
        bounds (tuple): Min and max bounds for the grid
        n_points (int): Number of points to sample for Chamfer distance

    Returns:
        Tuple ``(chamfer_dist, v, f)``: the Chamfer distance as a float, and
        the vertices / faces of the extracted mesh in grid coordinates mapped
        back to ``bounds``.
    """
    lower, upper = bounds[0], bounds[1]

    # Evaluate on whichever device holds the model's weights; fall back to CPU
    # for plain callables that expose no parameters.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device('cpu')

    # `marching_cubes_lewiner` is absent from recent scikit-image releases,
    # where `marching_cubes` takes its place; keep the old name as a fallback.
    marching_cubes = getattr(measure, 'marching_cubes', None) or measure.marching_cubes_lewiner

    with torch.no_grad():
        # Create grid for marching cubes
        axis = np.linspace(lower, upper, resolution)
        X, Y, Z = np.meshgrid(axis, axis, axis, indexing='ij')
        grid_points = torch.tensor(
            np.stack([X.flatten(), Y.flatten(), Z.flatten()], axis=1),
            dtype=torch.float32,
        )

        # Process in batches to avoid memory issues
        batch_size = 10000
        sdf_batches = []
        for start in range(0, grid_points.shape[0], batch_size):
            batch_points = grid_points[start:start + batch_size].to(device)
            sdf_batches.append(model(batch_points).cpu().numpy())

        sdf_grid = np.concatenate(sdf_batches, axis=0).reshape(resolution, resolution, resolution)

    # Generate mesh using marching cubes
    v, f, _, _ = marching_cubes(sdf_grid, 0, gradient_direction='ascent')

    # Scale vertices from grid-index units back to the original coordinate system
    v = v / (resolution - 1) * (upper - lower) + lower

    # Save the reconstructed mesh so it can be sampled like the reference mesh
    temp_mesh_path = 'temp_reconstruction.ply'
    write_mesh(v, f, temp_mesh_path)

    # Sample points from both meshes
    pred_points = sample_mesh_points(temp_mesh_path, n_points)
    gt_points = sample_mesh_points(gt_mesh_path, n_points)

    # Compute Chamfer distance
    chamfer_dist = compute_chamfer_distance(pred_points, gt_points)

    return chamfer_dist, v, f

In [None]:
iters=100000
lam, eps, mu = [10, 0.01, 10]

# NOTE(review): this path ends in .xyz but is read with
# o3d.io.read_triangle_mesh inside sample_mesh_points - confirm the file is
# loadable as a mesh.
gt_mesh_path = "Preimage_Implicit_DLTaskData/armadillo_10000.xyz"

# Set to True to load normals and align gradients with them.
use_normals = False
normals = load_pointcloud('bunny_normals.ply') if use_normals else None

# The loss only uses normals when it is told to, so keep the flag in sync with
# whether normals were actually loaded.
loss_fn = PHASELoss(epsilon=eps, lambda_val=lam, mu=mu, ball_radius=0.001,
                    use_normals=normals is not None)

# Intermediate meshes are written below; make sure the target folder exists.
os.makedirs('intermediates', exist_ok=True)

model.train()

for i in range(iters):
    # Draw a fresh set of surface samples for this iteration
    gt_points = sample_mesh_points(gt_mesh_path, n_points=10000)

    # Zero gradients at the start of each iteration
    opt.zero_grad()

    # Get a batch of points for training
    # For simplicity, we use the entire point cloud, but you could implement batch sampling here
    points_batch = gt_points.unsqueeze(0)  # Add batch dimension

    # Forward pass and compute loss (normals defaults to None inside the loss)
    loss, loss_components = loss_fn(model, points_batch, normals)

    # Backward pass
    loss.backward()

    # Update parameters
    opt.step()

    # Print progress
    if i %1000 == 0:
        # run evaluation
        chamfer_dist, v, f = evaluate_reconstruction(model, gt_mesh_path, resolution=64, bounds=(-1.0, 1.0), n_points=10000)

        print(f"Iter {i}/{iters}, Loss: {loss.item():.6f}, "
              f"Grad: {loss_components['grad_term']:.6f}, "
              f"DW: {loss_components['double_well']:.6f}, "
              f"Recon: {loss_components['reconstruction']:.6f}, "
              f"Norm: {loss_components['normal_constraint']:.6f}")

        print(f"Chamfer distance: {chamfer_dist:.6f}")

        # create mesh with marching cubes
        write_mesh(v,f,f'intermediates/mesh_{i}.ply')