In [None]:
# Cell 1: Import required libraries and select a compute device (MPS, else CUDA, else CPU)
import gc
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from pytorch3d.io import load_obj, save_obj
from pytorch3d.loss import (
    chamfer_distance,
    mesh_edge_loss,
    mesh_laplacian_smoothing,
    mesh_normal_consistency,
)
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.structures import Meshes
from pytorch3d.utils import ico_sphere
from tqdm import tqdm

# Choose the compute device in priority order: Apple-Silicon MPS, then CUDA, then CPU.
# NOTE(review): PyTorch3D's custom ops are not guaranteed to have MPS kernels —
# confirm that sample_points_from_meshes / the mesh losses actually run on "mps".
def _pick_device():
    """Return the preferred torch.device and announce the choice on stdout."""
    if torch.backends.mps.is_available():
        chosen = torch.device("mps")
        print("Using MPS acceleration on Apple Silicon!")
        return chosen
    if torch.cuda.is_available():
        chosen = torch.device("cuda:0")
        print("Using CUDA acceleration!")
        return chosen
    chosen = torch.device("cpu")
    print("WARNING: CPU only, this will be slow!")
    return chosen


device = _pick_device()

In [None]:
# Cell 2: Image-loading stub, source sphere mesh, and optimizer setup (no target point cloud is built here yet)
def load_and_preprocess_images(image_folder, target_size=(224, 224)):
    """Load the target images found in ``image_folder``, resized to ``target_size``.

    This is a placeholder: the image-loading code has not been filled in yet,
    so the function raises a descriptive ``NotImplementedError`` rather than
    returning a value.

    Args:
        image_folder: Directory that should contain the target images.
        target_size: Size the images should be resized to — presumably
            ``(height, width)``; TODO confirm when implementing.

    Raises:
        NotImplementedError: Always, until the loading code is supplied.
    """
    raise NotImplementedError(
        f"load_and_preprocess_images({image_folder!r}, target_size={target_size!r}) "
        "is a placeholder; paste your image-loading code here."
    )

# Source shape: a level-4 icosphere created directly on the selected device.
src_mesh = ico_sphere(4, device)

# Learnable per-vertex offsets, one row per packed vertex of src_mesh.
# Starting from zeros means optimisation begins at the undeformed sphere;
# zeros_like inherits the vertices' dtype and device.
deform_verts = torch.zeros_like(src_mesh.verts_packed(), requires_grad=True)

# Plain SGD with momentum, optimising only the vertex offsets.
optimizer = torch.optim.SGD(params=[deform_verts], lr=1.0, momentum=0.9)

In [None]:
# Cell 3: Define visualization function
def plot_pointcloud(mesh, title="", num_samples=5000):
    """Scatter-plot points sampled uniformly from the surface of a mesh.

    Only batch element 0 of ``mesh`` is plotted. Axes are drawn as
    (x, z, -y) with a fixed ``view_init(190, 30)`` camera.

    Args:
        mesh: PyTorch3D ``Meshes`` to sample from.
        title: Title for the 3-D axes.
        num_samples: Number of surface points to sample (default 5000).
    """
    # Sampling is for display only; skip autograd so no graph is built when
    # the mesh vertices require grad.
    with torch.no_grad():
        points = sample_points_from_meshes(mesh, num_samples)
    # Index batch 0 instead of squeeze(): squeeze() breaks for a batch of
    # several meshes and also collapses the point axis when num_samples == 1.
    x, y, z = points[0].cpu().numpy().T
    fig = plt.figure(figsize=(5, 5))
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter3D(x, z, -y)
    ax.set_xlabel('x')
    ax.set_ylabel('z')
    ax.set_zlabel('y')
    ax.set_title(title)
    ax.view_init(190, 30)
    plt.show()

In [None]:
# Cell 4: Training loop with memory optimizations
# Optimization budget — kept short for quick experiments (previously 2000 steps).
Niter = 500

# Relative weights of the regularization terms in the total loss.
w_edge = 1.0
w_normal = 0.01
w_laplacian = 0.1

# Per-iteration scalar history for each tracked loss term.
losses_history = {name: [] for name in ('edge', 'normal', 'laplacian', 'total')}

def train_step(src_mesh, deform_verts, optimizer, target_points=None,
               w_chamfer=1.0, num_samples=5000):
    """Run one optimization step on the per-vertex offsets.

    Builds ``src_mesh.offset_verts(deform_verts)``, evaluates the edge-length,
    normal-consistency and uniform Laplacian-smoothing losses weighted by the
    module-level ``w_edge``, ``w_normal`` and ``w_laplacian``, backpropagates
    into ``deform_verts`` and calls ``optimizer.step()``.

    NOTE(review): without ``target_points`` there is no data term, so nothing
    pulls the mesh toward a target. Per the PyTorch3D docs, ``mesh_edge_loss``
    defaults to a target edge length of 0, so it tends to shrink the mesh.

    Args:
        src_mesh: Source ``Meshes`` whose vertices are offset.
        deform_verts: Offsets added to ``src_mesh``'s packed vertices; must be
            among the parameters owned by ``optimizer``.
        optimizer: Optimizer that updates ``deform_verts``.
        target_points: Optional target point cloud passed to
            ``chamfer_distance`` — presumably shaped ``(1, P, 3)``; confirm
            against the target you build. ``None`` (default) adds no data term.
        w_chamfer: Weight of the Chamfer term; unused without ``target_points``.
        num_samples: Points sampled from the deformed mesh for the Chamfer term.

    Returns:
        ``(current_losses, new_src_mesh)``. ``current_losses`` maps 'edge',
        'normal', 'laplacian' and 'total' (plus 'chamfer' when
        ``target_points`` is given) to Python floats. ``new_src_mesh`` is the
        mesh evaluated in this step, i.e. built from the offsets *before*
        ``optimizer.step()`` updated them.
    """
    optimizer.zero_grad()

    # Deform the mesh (offset_verts returns a new mesh; src_mesh is unchanged).
    new_src_mesh = src_mesh.offset_verts(deform_verts)

    loss_edge = mesh_edge_loss(new_src_mesh)
    loss_normal = mesh_normal_consistency(new_src_mesh)
    loss_laplacian = mesh_laplacian_smoothing(new_src_mesh, method="uniform")

    # Weighted sum of the regularizers.
    total_loss = (
        loss_edge * w_edge
        + loss_normal * w_normal
        + loss_laplacian * w_laplacian
    )

    # Optional data term: Chamfer distance between surface samples of the
    # deformed mesh and the target point cloud.
    loss_chamfer = None
    if target_points is not None:
        sample_src = sample_points_from_meshes(new_src_mesh, num_samples)
        loss_chamfer, _ = chamfer_distance(sample_src, target_points)
        total_loss = total_loss + loss_chamfer * w_chamfer

    # .item() copies each scalar to the host as a plain float, so the history
    # holds no device tensors (or their autograd graph).
    current_losses = {
        'edge': loss_edge.item(),
        'normal': loss_normal.item(),
        'laplacian': loss_laplacian.item(),
        'total': total_loss.item(),
    }
    if loss_chamfer is not None:
        current_losses['chamfer'] = loss_chamfer.item()

    total_loss.backward()
    optimizer.step()

    # Deliberately no empty_cache() here: releasing the allocator cache on
    # every step forces re-allocation on the next one and slows training.
    return current_losses, new_src_mesh

def _empty_device_cache():
    """Release cached allocator memory on the active accelerator (no-op on CPU)."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        # torch.mps.empty_cache only exists in newer PyTorch releases.
        mps_module = getattr(torch, "mps", None)
        if mps_module is not None and hasattr(mps_module, "empty_cache"):
            mps_module.empty_cache()


def _save_current_mesh(path):
    """Write the mesh implied by the *current* ``deform_verts`` to an OBJ file.

    The mesh is rebuilt from ``deform_verts`` instead of reusing the last
    ``train_step`` result. That result may not exist yet (failure on the
    first step) and always lags one ``optimizer.step()`` behind.
    """
    with torch.no_grad():
        latest_mesh = src_mesh.offset_verts(deform_verts.detach())
    final_verts, final_faces = latest_mesh.get_mesh_verts_faces(0)
    save_obj(f=path, verts=final_verts, faces=final_faces)


try:
    # Main optimization loop.
    for i in tqdm(range(Niter)):
        current_losses, current_mesh = train_step(src_mesh, deform_verts, optimizer)

        # setdefault tolerates extra loss terms (e.g. 'chamfer') in the result.
        for loss_name, loss_value in current_losses.items():
            losses_history.setdefault(loss_name, []).append(loss_value)

        # Periodic visualization, plus the final iteration.
        if i % 50 == 0 or i == Niter - 1:
            # tqdm.write keeps the progress bar intact, unlike print().
            tqdm.write(f"\nIteration {i}")
            tqdm.write(f"Total Loss: {current_losses['total']:.6f}")

            with torch.no_grad():
                plot_pointcloud(current_mesh, title=f"iter: {i}")

            # Reclaim host objects (e.g. figures) and accelerator cache.
            gc.collect()
            _empty_device_cache()

except (Exception, KeyboardInterrupt) as e:
    # KeyboardInterrupt is included because interrupting the kernel is the
    # usual way training gets stopped in a notebook.
    print(f"Training interrupted: {type(e).__name__}: {e}")
    # Save the latest state even if interrupted.
    _save_current_mesh('interrupted_model.obj')
    raise

# Plot every tracked loss curve on one shared set of axes.
fig, ax = plt.subplots(figsize=(10, 5))
for name, values in losses_history.items():
    ax.plot(values, label=f'{name} loss')
ax.legend()
ax.set_xlabel('Iteration')
ax.set_ylabel('Loss')
ax.set_title('Training Losses')
plt.show()