In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import math
import datetime
from copy import deepcopy as dc
import random

In [2]:
class SineLayer(nn.Module):
    """Affine layer followed by a frequency-scaled sine: ``sin(omega_0 * (W x + b))``.

    The role of ``omega_0`` is discussed in the SIREN paper (sec. 3.2, final
    paragraph) and its supplement (sec. 1.5).

    * ``is_first=True``: ``omega_0`` is a plain frequency factor applied to the
      pre-activations. It is a hyperparameter; different signals may need a
      different value in the first layer.
    * ``is_first=False``: the initial weights are additionally divided by
      ``omega_0`` so that activation magnitudes stay constant while gradients
      with respect to the weight matrix are boosted (supplement sec. 1.5).

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: forwarded to ``nn.Linear``.
        is_first: selects the first-layer weight initialisation.
        omega_0: frequency factor multiplying the pre-activation.
    """

    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega_0=30):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first

        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        self.init_weights()

    def init_weights(self):
        """Re-draw the weight matrix from a symmetric uniform range (bias is left as is)."""
        if self.is_first:
            bound = 1 / self.in_features
        else:
            bound = np.sqrt(6 / self.in_features) / self.omega_0

        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, input):
        """Return ``sin(omega_0 * linear(input))``."""
        pre_activation = self.linear(input)
        return torch.sin(self.omega_0 * pre_activation)

class Siren(nn.Module):
    """Stack of ``SineLayer`` modules described by a list of layer widths.

    ``architecture`` is ``[in_features, hidden_1, ..., hidden_n, out_features]``.
    The stack consists of:

    1. a first sine layer ``in_features -> hidden_1`` using ``first_omega_0``;
    2. one sine layer per consecutive pair of hidden widths;
    3. one additional ``hidden_n -> hidden_n`` sine layer;
    4. the output layer ``hidden_n -> out_features``: a plain ``nn.Linear``
       (weights re-initialised with the hidden-layer bound) when
       ``outermost_linear`` is true, otherwise another sine layer.

    The instance also carries bookkeeping slots that are filled in from
    outside the class: ``loss_history``, ``optimizer``, ``lr_scheduler`` and a
    display ``name``.
    """

    def __init__(self, architecture, outermost_linear=False,
                 first_omega_0=60, hidden_omega_0=60):
        super().__init__()
        self.architecture = architecture
        n_in, n_out = architecture[0], architecture[-1]
        last_width = architecture[-2]
        n_hidden = len(architecture) - 2

        self.loss_history = []
        self.optimizer = None
        self.name = "SIREN"
        self.lr_scheduler = None

        # Layers are created strictly in network order.
        layers = [SineLayer(n_in, architecture[1],
                            is_first=True, omega_0=first_omega_0)]

        # Pairs (hidden_1, hidden_2), ..., (hidden_{n-1}, hidden_n).
        layers.extend(
            SineLayer(fan_in, fan_out, is_first=False, omega_0=hidden_omega_0)
            for fan_in, fan_out in zip(architecture[1:n_hidden],
                                       architecture[2:n_hidden + 1])
        )

        # Extra square sine layer on the last hidden width.
        layers.append(SineLayer(last_width, last_width,
                                is_first=False, omega_0=hidden_omega_0))

        if not outermost_linear:
            layers.append(SineLayer(last_width, n_out,
                                    is_first=False, omega_0=hidden_omega_0))
        else:
            head = nn.Linear(last_width, n_out)
            bound = np.sqrt(6 / last_width) / hidden_omega_0
            with torch.no_grad():
                head.weight.uniform_(-bound, bound)
            layers.append(head)

        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        """Run ``coords`` through the sequential stack and return its output."""
        # To differentiate w.r.t. the input, re-enable the line below.
        #coords = coords.clone().detach().requires_grad_(True) # allows to take derivative w.r.t. input
        return self.net(coords)
class SIRELU(nn.Module):
    """Hybrid network that alternates sine layers with ``Linear + ReLU`` blocks.

    ``architecture`` is ``[in_features, hidden_1, ..., hidden_n, out_features]``.
    The stack consists of:

    1. a first sine layer ``in_features -> hidden_1`` using ``first_omega_0``;
    2. for the ``i``-th consecutive pair of hidden widths (``i`` starting at 0):
       ``nn.Linear`` followed by ``nn.ReLU`` when ``i`` is even, a sine layer
       when ``i`` is odd;
    3. one additional ``hidden_n -> hidden_n`` block: ``Linear + ReLU`` when the
       number of hidden widths is odd, a sine layer otherwise;
    4. the output layer ``hidden_n -> out_features``: a plain ``nn.Linear``
       (weights re-initialised with the hidden-layer bound) when
       ``outermost_linear`` is true, otherwise another sine layer.

    The instance also carries bookkeeping slots that are filled in from
    outside the class: ``loss_history``, ``optimizer``, ``lr_scheduler`` and a
    display ``name``.
    """

    def __init__(self, architecture, outermost_linear=False,
                first_omega_0=60, hidden_omega_0=60):
        super().__init__()
        self.architecture = architecture
        n_in, n_out = architecture[0], architecture[-1]
        last_width = architecture[-2]
        n_hidden = len(architecture) - 2

        self.loss_history = []
        self.optimizer = None
        self.name = "SIRELU"
        self.lr_scheduler = None

        # Layers are created strictly in network order.
        layers = [SineLayer(n_in, architecture[1],
                            is_first=True, omega_0=first_omega_0)]

        for idx in range(n_hidden - 1):
            fan_in, fan_out = architecture[idx + 1], architecture[idx + 2]
            if idx % 2 != 0:
                layers.append(SineLayer(fan_in, fan_out,
                                        is_first=False, omega_0=hidden_omega_0))
            else:
                layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]

        # Extra square block on the last hidden width; its type depends on parity.
        if n_hidden % 2 == 1:
            layers += [nn.Linear(last_width, last_width), nn.ReLU()]
        else:
            layers.append(SineLayer(last_width, last_width,
                                    is_first=False, omega_0=hidden_omega_0))

        if not outermost_linear:
            layers.append(SineLayer(last_width, n_out,
                                    is_first=False, omega_0=hidden_omega_0))
        else:
            head = nn.Linear(last_width, n_out)
            bound = np.sqrt(6 / last_width) / hidden_omega_0
            with torch.no_grad():
                head.weight.uniform_(-bound, bound)
            layers.append(head)

        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        """Run ``coords`` through the sequential stack and return its output."""
        # To differentiate w.r.t. the input, re-enable the line below.
        #coords = coords.clone().detach().requires_grad_(True) # allows to take derivative w.r.t. input
        return self.net(coords)

In [4]:
def open_object(file_path):
    """Read vertex positions and triangles from a Wavefront OBJ file.

    Only ``v`` (vertex) and ``f`` (face) records are used; every other record
    (``vn``, ``vt``, comments, ...) is ignored.

    * Face corners may be written as ``i``, ``i/t``, ``i//n`` or ``i/t/n``;
      only the vertex index ``i`` is kept.
    * Positive indices are 1-based and converted to 0-based. Negative indices
      are relative to the vertices read so far (``-1`` is the most recent one).
    * Faces with more than three corners are split into a triangle fan around
      their first corner, so no part of a polygon is dropped.

    Args:
        file_path: path of the OBJ file.

    Returns:
        ``(vertices, faces)``: a float array of shape ``(num_vertices, 3)`` and
        an integer array of shape ``(num_triangles, 3)`` holding 0-based vertex
        indices. Both have shape ``(0, 3)`` when the file has no such records.
    """
    vertices = []
    faces = []
    with open(file_path, 'r') as file:
        for line in file:
            parts = line.split()
            if not parts:
                continue

            if parts[0] == 'v':
                vertices.append([float(value) for value in parts[1:4]])
            elif parts[0] == 'f':
                corners = []
                for token in parts[1:]:
                    index = int(token.split('/')[0])
                    # OBJ: positive = 1-based absolute, negative = relative to
                    # the end of the vertex list at this point in the file.
                    corners.append(index - 1 if index > 0 else len(vertices) + index)
                # Fan triangulation; a plain triangle yields exactly one entry.
                for k in range(1, len(corners) - 1):
                    faces.append([corners[0], corners[k], corners[k + 1]])

    vertex_array = np.array(vertices, dtype=np.float64) if vertices else np.empty((0, 3), dtype=np.float64)
    face_array = np.array(faces, dtype=np.int64) if faces else np.empty((0, 3), dtype=np.int64)
    return vertex_array, face_array

def create_torch_mesh(vertices, faces, device=None):
    """Convert mesh arrays into torch tensors on a single device.

    Args:
        vertices: array-like of vertex coordinates.
        faces: array-like of vertex indices.
        device: target device; when ``None``, CUDA is used if available and
            the CPU otherwise.

    Returns:
        ``(vertices_tensor, faces_tensor)`` with dtypes ``float32`` and
        ``int64`` respectively.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    return (
        torch.tensor(vertices, dtype=torch.float32, device=device),
        torch.tensor(faces, dtype=torch.int64, device=device),
    )

def get_distance_from_contour(vertices, points):
    """Return, for every row of ``points``, the distance to its closest row of ``vertices``.

    The distance is measured to the nearest *vertex* (via ``torch.cdist``),
    not to the nearest location on any surface spanned by the vertices.
    """
    pairwise = torch.cdist(points, vertices)
    return pairwise.min(dim=1).values
def get_signed_distance_from_contour(vertices, faces, points):
    """Approximate signed distance from query points to a triangle mesh.

    Sign convention: positive inside the mesh, negative outside.

    * Magnitude: distance to the closest mesh *vertex* (``torch.cdist``), an
      approximation of the true point-to-surface distance.
    * Sign: parity of the number of triangles hit by a ray cast from the point
      along +x, using the Moller-Trumbore ray/triangle test. An odd count
      means inside.

    Triangles (nearly) parallel to the ray are skipped. If no usable triangle
    remains, every point is treated as outside.

    NOTE(review): the barycentric bounds are inclusive, so a ray passing
    exactly through an edge shared by two triangles is presumably counted
    twice; the parity test assumes this is rare for randomly sampled points
    and that the mesh is closed.

    Args:
        vertices: float tensor, one xyz position per row.
        faces: integer tensor, three vertex indices per row.
        points: float tensor of query positions, one xyz per row, on the same
            device as ``vertices``. May be empty.

    Returns:
        1-D tensor with one signed distance per query point.
    """
    device = points.device
    num_points = points.shape[0]

    # Unsigned part: nearest-vertex distance.
    min_distances = torch.cdist(points, vertices).min(dim=1).values

    # Fixed ray direction, in the mesh dtype so no mixed-precision ops occur.
    ray_direction = torch.tensor([1.0, 0.0, 0.0], device=device, dtype=vertices.dtype)

    # Per-triangle quantities that do not depend on the ray origin.
    face_vertices = vertices[faces]          # [num_faces, 3, 3]
    v0 = face_vertices[:, 0]
    edges1 = face_vertices[:, 1] - v0
    edges2 = face_vertices[:, 2] - v0
    h_vectors = torch.cross(ray_direction.expand_as(edges2), edges2, dim=1)
    determinants = (edges1 * h_vectors).sum(dim=1)

    # Drop triangles (nearly) parallel to the ray.
    valid_mask = determinants.abs() > 1e-8
    if not bool(valid_mask.any()):
        # Nothing can be hit: treat every point as outside.
        return -min_distances

    valid_v0 = v0[valid_mask]
    valid_edges1 = edges1[valid_mask]
    valid_edges2 = edges2[valid_mask]
    valid_h = h_vectors[valid_mask]
    inv_det = 1.0 / determinants[valid_mask]

    # All (point, triangle) pairs of a chunk are evaluated in one broadcast.
    # The chunk size only bounds the temporary [chunk, faces, 3] tensors; it
    # does not affect the result.
    pair_budget = 1 << 22
    chunk = max(1, pair_budget // valid_v0.shape[0])

    inside = torch.zeros(num_points, dtype=torch.bool, device=device)
    for start in range(0, num_points, chunk):
        origins = points[start:start + chunk]

        s_vectors = origins.unsqueeze(1) - valid_v0.unsqueeze(0)            # [P, F, 3]
        u_coords = inv_det * (s_vectors * valid_h).sum(dim=-1)              # [P, F]
        q_vectors = torch.cross(
            s_vectors, valid_edges1.unsqueeze(0).expand_as(s_vectors), dim=-1
        )
        v_coords = inv_det * (q_vectors * ray_direction).sum(dim=-1)
        t_coords = inv_det * (q_vectors * valid_edges2).sum(dim=-1)

        # Hit: barycentric coordinates inside the triangle and the
        # intersection strictly in front of the ray origin.
        hits = (
            (u_coords >= 0.0) & (u_coords <= 1.0)
            & (v_coords >= 0.0) & (u_coords + v_coords <= 1.0)
            & (t_coords > 1e-8)
        )
        inside[start:start + chunk] = (hits.sum(dim=1) % 2) == 1

    # True -> +1, False -> -1, in the dtype of the distances.
    signs = inside.to(min_distances.dtype) * 2 - 1
    return signs * min_distances
def train_model(model, optimizer, num_epoch, lr_scheduler = None,
                obj_path="./3D_model_data/stanford-bunny.obj",
                num_points=30000, domain_size=0.2):
    """Fit ``model`` to the signed distance field of a mesh and plot the loss curve.

    Each epoch draws ``num_points`` random query points in a cube of side
    ``domain_size`` centred at the origin, computes their target signed
    distances with ``get_signed_distance_from_contour`` and takes one L1-loss
    optimisation step.

    All tensors are placed on the device that already holds the model's
    parameters, so the function runs on CPU-only installations as well as on
    CUDA. The model itself is never moved.

    Args:
        model: network mapping ``(N, 3)`` points to ``(N, 1)`` distances.
        optimizer: optimizer already bound to ``model``'s parameters.
        num_epoch: number of optimisation steps.
        lr_scheduler: optional scheduler, stepped once per epoch when given.
        obj_path: OBJ file describing the target mesh.
        num_points: number of random query points per epoch.
        domain_size: side length of the sampling cube.

    Returns:
        The trained ``model`` (same object that was passed in).
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    vertices, faces = open_object(obj_path)
    # Upload the mesh once; it does not change between epochs.
    vertices_tensor, faces_tensor = create_torch_mesh(vertices, faces, device=device)
    print("Model loaded successfully")

    criterion = nn.L1Loss()
    loss_history = []
    # Integer cadence (about ten summaries per run); never zero.
    steps_til_summary = max(1, num_epoch // 10)

    for epoch in range(num_epoch):
        optimizer.zero_grad()

        # Uniform samples in [-domain_size/2, domain_size/2]^3.
        input_points = ((torch.rand((num_points, 3)) - 0.5) * domain_size).to(device)
        target_distances = get_signed_distance_from_contour(
            vertices_tensor, faces_tensor, input_points
        ).unsqueeze(1)

        predicted_distances = model(input_points)
        loss = criterion(predicted_distances, target_distances)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        if epoch % steps_til_summary == 0:
            print(f'Epoch {epoch}, Loss: {loss_value}')
        loss_history.append(loss_value)

        # CUDA housekeeping is only meaningful (and only safe) on a CUDA device.
        if device.type == 'cuda':
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

        if lr_scheduler is not None:
            lr_scheduler.step()

    plt.plot(loss_history)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.yscale('log')
    plt.title('Training Loss Over Time')
    plt.show()
    return model
def train_two_models(model1, model2,num_epoch=100,
                     obj_path="stanford-bunny.obj",
                     num_points=30000, domain_size=0.2):
    """Train two models side by side on identical samples of a mesh's signed distance field.

    Each epoch draws ``num_points`` random query points in a cube of side
    ``domain_size`` centred at the origin, computes the target signed
    distances once and gives both models one L1-loss optimisation step on that
    same batch. Each model's ``loss_history`` list is extended in place.

    Both models must expose ``optimizer``, ``lr_scheduler`` (may be ``None``)
    and ``loss_history`` attributes. Tensors are placed on the device of each
    model's parameters; the models themselves are never moved.

    Args:
        model1: first network (its device is used for the target computation).
        model2: second network.
        num_epoch: number of optimisation steps.
        obj_path: OBJ file describing the target mesh.
        num_points: number of random query points per epoch.
        domain_size: side length of the sampling cube.
    """
    def device_of(net):
        # Device of the first parameter; CPU for a parameter-free module.
        param = next(net.parameters(), None)
        return param.device if param is not None else torch.device('cpu')

    device1 = device_of(model1)
    device2 = device_of(model2)

    vertices, faces = open_object(obj_path)
    # Upload the mesh once; it does not change between epochs.
    vertices_tensor, faces_tensor = create_torch_mesh(vertices, faces, device=device1)
    print("Model loaded successfully")

    criterion = nn.L1Loss()
    # Integer cadence (about ten summaries per run); never zero.
    steps_til_summary = max(1, num_epoch // 10)

    for epoch in range(num_epoch):
        model1.optimizer.zero_grad()
        model2.optimizer.zero_grad()

        # Uniform samples in [-domain_size/2, domain_size/2]^3.
        input_points = ((torch.rand((num_points, 3)) - 0.5) * domain_size).to(device1)
        target_distances = get_signed_distance_from_contour(
            vertices_tensor, faces_tensor, input_points
        ).unsqueeze(1)

        # Use the arguments, not a global: the first prediction belongs to model1.
        predicted_distances1 = model1(input_points)
        predicted_distances2 = model2(input_points.to(device2))
        loss1 = criterion(predicted_distances1, target_distances)
        loss2 = criterion(predicted_distances2, target_distances.to(device2))
        loss1.backward()
        loss2.backward()
        model1.optimizer.step()
        model2.optimizer.step()

        loss1_value = loss1.item()
        loss2_value = loss2.item()
        if epoch % steps_til_summary == 0:
            print(f'Epoch {epoch}, Loss1: {loss1_value}, Loss2: {loss2_value}')
        model1.loss_history.append(loss1_value)
        model2.loss_history.append(loss2_value)

        # CUDA housekeeping is only meaningful (and only safe) on a CUDA device.
        if device1.type == 'cuda' or device2.type == 'cuda':
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

        for scheduler in (model1.lr_scheduler, model2.lr_scheduler):
            if scheduler is not None:
                scheduler.step()
def plot2model_loss(model1,model2):
    """Draw the ``loss_history`` of both models on one log-scaled figure and show it."""
    for net in (model1, model2):
        plt.plot(net.loss_history)

    plt.legend(['model1','model2'])
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.yscale('log')
    plt.title('Training Loss Over Time')
    plt.show()
    



In [None]:
# Layer widths shared by both networks: 3-D input, four hidden widths, scalar output.
architecture = [3, 256, 256, 256, 256, 1]


def _report_and_attach_training_state(net):
    """Print the trainable parameter count, then attach Adam and a StepLR schedule to ``net``."""
    trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print(f"number of parameters: {trainable}")
    net.optimizer = torch.optim.Adam(lr=1e-6, params=net.parameters())
    net.lr_scheduler = torch.optim.lr_scheduler.StepLR(net.optimizer, step_size=100, gamma=0.7)


model = Siren(architecture, outermost_linear=True, first_omega_0=60, hidden_omega_0=60)
_report_and_attach_training_state(model)

model2 = SIRELU(architecture=architecture, outermost_linear=True, first_omega_0=60, hidden_omega_0=60)
_report_and_attach_training_state(model2)

print(model2.net)
#

AssertionError: Torch not compiled with CUDA enabled