In [9]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from scipy.spatial import Delaunay
import plotly.graph_objects as go
from plotly.subplots import make_subplots



class OrderPredictor(nn.Module):
    def __init__(self):
        super(OrderPredictor, self).__init__()
        # Input: sphere_center (3), sphere_radius (1), ray_origin (3)
        self.net = nn.Sequential(
            nn.Sigmoid(),
            nn.Linear(3*4 + 9, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
    
    def forward(self, tetra_points, sphere_center, sphere_radius, ray_origin):
        tetra_points=tetra_points.reshape(-1,4,3) - ray_origin.reshape(-1,1,3)
        x = torch.cat([
            tetra_points.reshape(-1,12),
            torch.linalg.norm(tetra_points, dim=2),
            sphere_center - ray_origin,
            torch.linalg.norm(sphere_center - ray_origin, dim=1, keepdim=True),
            sphere_radius
        ], dim=1)
        # return torch.abs(self.net(x))
        return torch.linalg.norm(sphere_center - ray_origin, dim=1, keepdim=True) - sphere_radius

def generate_random_points(num_points):
    """Generate random 3D points."""
    return np.random.rand(num_points, 3)

def compute_circumspheres(points, simplices):
    """Compute circumsphere centers and radii for each tetrahedron."""
    centers = []
    radii = []
    
    for simplex in simplices:
        tetra_points = points[simplex]
        # Create matrix A for circumcenter computation
        # For a tetrahedron, we need three equations
        A = np.zeros((3, 3))
        b = np.zeros(3)
        
        # Set up linear system for circumcenter computation
        for i in range(3):
            A[i] = 2 * (tetra_points[i+1] - tetra_points[0])
            b[i] = np.sum(tetra_points[i+1]**2 - tetra_points[0]**2)
        
        try:
            # Solve for circumcenter
            center = np.linalg.solve(A, b)
            # Compute radius
            radius = np.sqrt(np.sum((tetra_points[0] - center)**2))
            
            centers.append(center)
            radii.append(radius)
        except np.linalg.LinAlgError:
            # Handle degenerate cases
            print("Warning: Degenerate tetrahedron encountered")
            center = np.mean(tetra_points, axis=0)
            radius = np.max(np.linalg.norm(tetra_points - center, axis=1))
            centers.append(center)
            radii.append(radius)
    
    return np.array(centers), np.array(radii)

def compute_ray_tetra_intersections(ray_origin, ray_dir, points, simplices):
    """Compute the true ordering of intersections."""
    distances = []
    
    for simplex in simplices:
        tetra_points = points[simplex]
        # Implement ray-tetrahedron intersection
        # This is a simplified version - you'll want to implement proper intersection
        center = np.mean(tetra_points, axis=0)
        dist = np.dot(center - ray_origin, ray_dir)
        distances.append(dist)
    
    return np.argsort(distances)

def ranking_loss(predictions, true_order):
    """
    Improved ranking loss using all-to-all comparisons with margin.
    For each pair (i,j), if i comes before j in true_order, 
    predictions[i] should be less than predictions[j].
    """
    n = len(predictions)
    margin = 1.0  # Margin for separation between ordered pairs
    
    # Convert true_order to PyTorch tensor if it isn't already
    true_order = torch.tensor(true_order, device=predictions.device)
    
    # Create indices for all pairs
    i_indices, j_indices = torch.triu_indices(n, n, offset=1)
    
    # Get predicted values for all pairs
    pred_i = predictions[i_indices]
    pred_j = predictions[j_indices]
    
    # Get true ordering for all pairs
    true_i = true_order[i_indices]
    true_j = true_order[j_indices]
    
    # Compute mask for pairs where i should be less than j
    should_be_less = (true_i < true_j).float()
    
    # Compute mask for pairs where j should be less than i
    should_be_greater = (true_i > true_j).float()
    
    # Compute losses for both cases
    loss_less = torch.relu(pred_i - pred_j + margin) * should_be_less
    loss_greater = torch.relu(pred_j - pred_i + margin) * should_be_greater
    
    # Sum all losses
    loss = torch.sum(loss_less + loss_greater)# / (pred_i.sum() + pred_j.sum()) * 2
    
    return loss

In [10]:
def generate_batch_data(batch_size, points_per_scene):
    """Vectorized batch data generation."""
    # Generate all points at once (batch_size, points_per_scene, 3)
    all_points = np.random.rand(batch_size, points_per_scene, 3)
    
    # Generate all ray origins and directions at once
    ray_origins = np.random.rand(batch_size, 3)
    ray_dirs = np.random.rand(batch_size, 3)
    ray_dirs /= np.linalg.norm(ray_dirs, axis=1, keepdims=True)
    
    all_centers = []
    all_radii = []
    all_ray_origins = []
    all_true_orders = []
    scene_sizes = []
    all_tetra_points = []
    
    # Still need a loop for Delaunay, but everything else is vectorized
    for i in range(batch_size):
        tri = Delaunay(all_points[i])
        tetra_points = all_points[i][tri.simplices]  # (num_tetras, 4, 3)
        
        # Vectorized circumcenter computation
        p0 = tetra_points[:, 0]  # (num_tetras, 3)
        p1_p0 = tetra_points[:, 1] - p0  # (num_tetras, 3)
        p2_p0 = tetra_points[:, 2] - p0
        p3_p0 = tetra_points[:, 3] - p0
        
        # Stack equations for all tetrahedra at once
        A = np.stack([
            2 * p1_p0,
            2 * p2_p0,
            2 * p3_p0
        ], axis=1)  # (num_tetras, 3, 3)
        
        b = np.sum(tetra_points[:, 1:]**2 - p0[:, np.newaxis]**2, axis=2)  # (num_tetras, 3)
        
        # Solve all systems at once using batched solve
        try:
            centers = np.linalg.solve(A, b)
            radii = np.sqrt(np.sum((p0 - centers)**2, axis=1))
        except np.linalg.LinAlgError:
            # Fallback for degenerate cases
            centers = np.mean(tetra_points, axis=1)
            radii = np.max(np.linalg.norm(
                tetra_points - centers[:, np.newaxis], axis=2), axis=1)
        
        # Compute ray intersections (vectorized)
        ray_origin = ray_origins[i]
        ray_dir = ray_dirs[i]
        
        # Vectorized intersection computation
        # Using mean point for simplicity - replace with proper intersection
        mean_points = np.mean(tetra_points, axis=1)  # (num_tetras, 3)
        vectors_to_center = mean_points - ray_origin
        distances = np.dot(vectors_to_center, ray_dir)
        true_order = np.argsort(distances)
        
        num_tetras = len(centers)
        scene_sizes.append(num_tetras)
        
        all_centers.append(centers)
        all_radii.append(radii)
        all_ray_origins.append(np.tile(ray_origin, (num_tetras, 1)))
        all_true_orders.append(true_order)
        all_tetra_points.append(torch.as_tensor(tetra_points).float())
    
    # Pack everything into tensors
    centers_tensor = torch.FloatTensor(np.concatenate(all_centers, axis=0))
    radii_tensor = torch.FloatTensor(np.concatenate(all_radii, axis=0)).unsqueeze(1)
    ray_origins_tensor = torch.FloatTensor(np.concatenate(all_ray_origins, axis=0))
    scene_sizes = torch.LongTensor(scene_sizes)
    # all_tetra_points = torch.stack(all_tetra_points, dim=0)
    
    return centers_tensor, radii_tensor, ray_origins_tensor, all_true_orders, scene_sizes, all_tetra_points


In [11]:

def create_visualization(tetra_points, true_order, predicted_values, ray_origin=None):
    """
    Create three side-by-side 3D scatter plots showing the tetrahedra with different color schemes.
    
    Args:
        tetra_points: (N, 3) array of vertex coordinates
        true_order: (M,) array of true ordering
        predicted_values: (M,) array of network outputs
        ray_origin: optional (3,) array for ray origin point
    """
    # Create figure with 3 3D subplots
    fig = make_subplots(
        rows=1, cols=3,
        specs=[[{'type': 'scene'}, {'type': 'scene'}, {'type': 'scene'}]],
        subplot_titles=('Network Output Values', 'True Order', 'Predicted Order')
    )
    
    predicted_order = np.argsort(predicted_values)
    
    # Function to create edges for a tetrahedron
    def get_tetra_edges(tetra_points):
        # Define edges as pairs of vertex indices
        edges = [(0,1), (0,2), (0,3), (1,2), (1,3), (2,3)]
        x, y, z = [], [], []
        
        # Add each edge with a None to break the line between edges
        for start, end in edges:
            x.extend([tetra_points[start,0], tetra_points[end,0], None])
            y.extend([tetra_points[start,1], tetra_points[end,1], None])
            z.extend([tetra_points[start,2], tetra_points[end,2], None])
        return x, y, z
    
    # Function to create a single tetrahedron plot
    def add_tetrahedra(tetra_points, values, colorscale, row, col, value_prefix):
        for i, value in enumerate(values):
            itetra_points = tetra_points.reshape(-1, 4, 3)[i]
            x, y, z = get_tetra_edges(itetra_points)
            
            # Normalized color for this tetrahedron
            norm_value = (value - np.min(values)) / (np.max(values) - np.min(values))
            
            # Add edges
            fig.add_trace(
                go.Scatter3d(
                    x=x, y=y, z=z,
                    mode='lines',
                    line=dict(color=f'rgb({int(255*norm_value)},0,{int(255*(1-norm_value))})', width=2),
                    showlegend=False,
                    hoverinfo='text',
                    text=[f'{value_prefix}: {value}'] * len(x)
                ),
                row=row, col=col
            )
            
            # Add semi-transparent faces
            for face in [(0,1,2), (0,1,3), (0,2,3), (1,2,3)]:
                fig.add_trace(
                    go.Mesh3d(
                        x=itetra_points[list(face) + [face[0]], 0],
                        y=itetra_points[list(face) + [face[0]], 1],
                        z=itetra_points[list(face) + [face[0]], 2],
                        color=f'rgb({int(255*norm_value)},0,{int(255*(1-norm_value))})',
                        opacity=0.2,
                        hoverinfo='text',
                        text=f'{value_prefix}: {value}',
                        showlegend=False
                    ),
                    row=row, col=col
                )
    
    # Add tetrahedra to each subplot
    add_tetrahedra(tetra_points, predicted_values, 'Viridis', 1, 1, 'Value')
    add_tetrahedra(tetra_points, true_order, 'Viridis', 1, 2, 'True Order')
    add_tetrahedra(tetra_points, predicted_order, 'Viridis', 1, 3, 'Predicted Order')
    
    # Add ray origin if provided
    if ray_origin is not None:
        for i in range(1, 4):
            fig.add_trace(
                go.Scatter3d(
                    x=[ray_origin[0]],
                    y=[ray_origin[1]],
                    z=[ray_origin[2]],
                    mode='markers',
                    marker=dict(
                        size=10,
                        color='red',
                        symbol='diamond'
                    ),
                    name='Ray Origin',
                    showlegend=i==1,
                ),
                row=1, col=i
            )
    
    # Update layout
    fig.update_layout(
        height=400,
        width=1200,
        title_text="Training Visualization",
        showlegend=True,
    )
    
    # Make all subplots have the same camera view
    scene_settings = dict(
        camera=dict(
            up=dict(x=0, y=0, z=1),
            center=dict(x=0, y=0, z=0),
            eye=dict(x=1.5, y=1.5, z=1.5)
        ),
        aspectmode='data'
    )
    
    fig.update_scenes(scene_settings)
    
    return fig

def update_training_visualization(predictions, centers, ray_origin, true_order, tetra_points):
    """
    Create visualization during training.
    """
    # Create and show visualization
    fig = create_visualization(tetra_points, true_order, predictions.squeeze(), ray_origin)
    fig.show()

In [12]:

def train_epoch(model, optimizer, num_batches=100, batch_size=32, points_per_scene=10, visualize=False):
    """Train for one epoch."""
    model.train()
    total_loss = 0
    avg_corrects = []
    
    for batch_idx in range(num_batches):
        # Generate batch data using vectorized operations
        centers_torch, radii_torch, ray_origins_torch, true_orders, scene_sizes, tetra_points_torch = \
            generate_batch_data(batch_size, points_per_scene)
        
        # Forward pass
        tetra_points_cat = torch.cat(tetra_points_torch, dim=0).reshape(-1, 3*4)
        predictions = model(tetra_points_cat, centers_torch, radii_torch, ray_origins_torch)
        
        # Compute loss across all scenes efficiently
        start_idx = 0
        batch_loss = 0
        
        for scene_idx, size in enumerate(scene_sizes):
            scene_preds = predictions.squeeze()[start_idx:start_idx + size]
            scene_true_order = true_orders[scene_idx]
            avg_correct = (torch.tensor(scene_true_order) == torch.argsort(scene_preds)).float().mean()
            avg_corrects.append(avg_correct)
            iloss = ranking_loss(scene_preds, scene_true_order)
            batch_loss += iloss
            
            # Visualize first scene of first batch if requested
            if visualize and batch_idx == 0 and scene_idx == 0:
                scene_centers = centers_torch[start_idx:start_idx + size].cpu().numpy()
                scene_ray_origin = ray_origins_torch[start_idx].cpu().numpy()
                scene_tetra_points = tetra_points_cat[start_idx:start_idx + size].cpu().numpy()
                
                update_training_visualization(
                    predictions[start_idx:start_idx+size].detach().cpu().numpy(),
                    scene_centers, 
                    scene_ray_origin, 
                    scene_true_order,
                    scene_tetra_points
                )
            
            start_idx += size
        
        loss = batch_loss / batch_size
        
        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / num_batches, torch.mean(torch.tensor(avg_corrects))

def main():
    model = OrderPredictor()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    
    num_epochs = 100
    for epoch in range(num_epochs+1):
        # Visualize every 10 epochs
        visualize = (epoch % 10 == 0)
        avg_loss, accuracy = train_epoch(model, optimizer, visualize=visualize)
        if epoch % 1 == 0:
            print(f"Epoch {epoch}, Average Loss: {avg_loss:.4f} Accuracy: {accuracy}")
    
    return model

if __name__ == "__main__":
    main()


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn