In [1]:
# For tips on running notebooks in Google Colab, see
# https://pytorch.org/tutorials/beginner/colab
%matplotlib inline

%matplotlib notebook  

Main Skript für CGAN 
==============




Einleitung
============
In diesem Skript soll die großteilige Arbeit stattfinden um das Skrpit für die Masterarbeit zu erstellen.
Dabei wird vorerst das Tutorial für ein GAN von Pytorch als Vorlage genutzt. 


benötigte Eigenschaften des CGANs
===============================
Um den Code für die Masterarbeit nützlich zu machen, müssen einige Dinge angepasst werden:
* CGAN für Generierung aus spezifischen Formen
* 3D-Meshes als Input fürs Training

Des weiteren müssen Architektur und Gewichte mit der Zeit angepasst werden



In [2]:
import glob
import os
import sys
import numpy 

import torch
import torch.nn as nn
import torch.optim as optim

import torch.nn.functional as F




from torch.utils.data import Dataset, DataLoader
import torch

from scipy.spatial import Delaunay
from collections import OrderedDict

from pytorch3d.io import load_objs_as_meshes
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.structures import Meshes
from pytorch3d.io import load_obj, save_obj
from pytorch3d.datasets import collate_batched_meshes

import pymeshlab
from preprocess import find_neighbor
from layers import SpatialDescriptor, StructuralDescriptor, MeshConvolution
import open3d as o3d

import trimesh
from scipy.spatial import KDTree


Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


Data
====


In [4]:
class MeshDataset(torch.utils.data.Dataset):
    def __init__(self, root, max_faces, max_vertices=40):
        """
        Dataset class to preprocess and load meshes for training.
        :param root: Root directory of the dataset (train directory).
        :param max_faces: Maximum number of faces to process per mesh.
        """
        self.root = root
        self.max_faces = max_faces
        self.max_vertices = max_vertices

        # List all mesh files in the train directory
        self.mesh_files = [
            os.path.join(root, file)
            for file in os.listdir(root)
            if file.endswith('.obj') or file.endswith('.npz') or file.endswith('.off')
        ]

    def __getitem__(self, idx):
        path = self.mesh_files[idx]
        file_name = os.path.basename(path)  # Extract '17.obj'
        #file_number = int(os.path.splitext(file_name)[0])
        file_number = 0
        file_number = torch.tensor(file_number)
        
        

        if path.endswith('.stl'):
            vertices, faces = load_stl(path)
        
        elif path.endswith('.npz'):
            # Load preprocessed file
            data = numpy.load(path)
            face = data['faces']
            neighbor_index = data['neighbors']
        else:
            face, neighbor_index, vertices = self.process_mesh(path)
            if face is None:
                return self.__getitem__((idx + 1) % len(self.mesh_files))          
        num_point = len(face)
        num_vertices = vertices.shape[0]
       

        if num_point < self.max_faces or num_vertices < self.max_vertices:
            random_indices_face = numpy.random.randint(0, num_point, self.max_faces - num_point) if num_point < self.max_faces else []
            random_indices_vertices = torch.randint(0, num_vertices, (self.max_vertices - num_vertices,), device=vertices.device) if num_vertices < self.max_vertices else []
            

    # Face padding
            if num_point < self.max_faces:
                fill_face = face[random_indices_face]
                fill_neighbor_index = neighbor_index[random_indices_face]
                face = numpy.concatenate((face, fill_face))
                neighbor_index = numpy.concatenate((neighbor_index, fill_neighbor_index))

    # Vertex padding
            if num_vertices < self.max_vertices:
               random_vertices = vertices[random_indices_vertices]
               vertices = torch.tensor(vertices, dtype=torch.float32)
               random_vertices = torch.tensor(random_vertices, dtype=torch.float32)
               vertices_t = torch.cat([vertices, random_vertices], dim=0)
            else:
                vertices_t = torch.tensor(vertices, dtype=torch.float32)
        else:
             vertices_t = torch.tensor(vertices, dtype=torch.float32)

        #print( pad_size, vertices.size(), random_vertices.size(), vertices_t)
        # Convert to PyTorch tensors
        face = torch.from_numpy(face).float()
        neighbor_index = torch.from_numpy(neighbor_index).long()

        # Extract features
        face = face.permute(1, 0).contiguous()  # (features, num_faces)
        centers, corners, normals = face[:3], face[3:12], face[12:]
        corners = corners - torch.cat([centers, centers, centers], 0)  # Center corners around face centers

        
        curve_score = torch.tensor(comp_curve_score(normals, neighbor_index)).to(device)
        centers = centers.to(device)
        corners = corners.to(device)
        normals = normals.to(device)
        neighbor_index = neighbor_index.to(device)
        
        

        return centers, corners, normals, neighbor_index, vertices_t, file_number, curve_score

    def __len__(self):
        return len(self.mesh_files)

    def process_mesh(self, path):
        """
        Preprocess a single mesh file.
        :param path: Path to the mesh file.
        :return: Preprocessed face and neighbor_index arrays.
        """
        
        ms = pymeshlab.MeshSet()
        ms.load_new_mesh(path)
        mesh = ms.current_mesh()       
        vertices = mesh.vertex_matrix()
        faces = mesh.face_matrix()
        obj_verts = vertices

        if faces.shape[0] > self.max_faces:
            #print(f"Skipping mesh with more than {self.max_faces} faces: {path}")
            #print("num faces: ",faces.shape[0])
            return None, None, None

        # Normalize mesh
        center = (numpy.max(vertices, 0) + numpy.min(vertices, 0)) / 2
        vertices -= center
        max_len = numpy.max(vertices[:, 0]**2 + vertices[:, 1]**2 + vertices[:, 2]**2)
        vertices /= numpy.sqrt(max_len)

        
        # Compute face normals
        ms.clear()
        ms.add_mesh(pymeshlab.Mesh(vertices, faces))
        face_normals = ms.current_mesh().face_normal_matrix()

        # Compute face centers and corners
        faces_contain_this_vertex = [set() for _ in range(len(vertices))]
        centers = []
        corners = []
        for i, face in enumerate(faces):
            v1, v2, v3 = face
            x1, y1, z1 = vertices[v1]
            x2, y2, z2 = vertices[v2]
            x3, y3, z3 = vertices[v3]
            centers.append([(x1 + x2 + x3) / 3, (y1 + y2 + y3) / 3, (z1 + z2 + z3) / 3])
            corners.append([x1, y1, z1, x2, y2, z2, x3, y3, z3])
            faces_contain_this_vertex[v1].add(i)
            faces_contain_this_vertex[v2].add(i)
            faces_contain_this_vertex[v3].add(i)

        # Find neighbors
        neighbors = []
        for i, face in enumerate(faces):
            v1, v2, v3 = face
            n1 = find_neighbor(faces, faces_contain_this_vertex, v1, v2, i)
            n2 = find_neighbor(faces, faces_contain_this_vertex, v2, v3, i)
            n3 = find_neighbor(faces, faces_contain_this_vertex, v3, v1, i)
            neighbors.append([n1, n2, n3])

        # Convert to numpy arrays
        centers = numpy.array(centers)
        corners = numpy.array(corners)
        faces = numpy.concatenate([centers, corners, face_normals], axis=1)
        neighbors = numpy.array(neighbors)

        
        return faces, neighbors, obj_verts

def collate_fn1(batch, max_faces=1700):
    centers, corners, normals, neighbor_indices, vertices, file_number, curve_score = zip(*batch)

    # Stapeln der Batches zu Tensors
    centers = torch.stack(centers)
    corners = torch.stack(corners)
    normals = torch.stack(normals)
    neighbor_indices = torch.stack(neighbor_indices)
    vertices = torch.stack(vertices)
    file_numbers = torch.stack(file_number)
    curve_score = torch.stack(curve_score)
    

    return centers, corners, normals, neighbor_indices, vertices, file_numbers, curve_score

def parse_shapes_list(file_path):
    """
    Parse ShapesList.txt to create a dictionary mapping numbers to shape names.
    :param file_path: Path to ShapesList.txt
    :return: Dictionary {int: str}
    """
    shape_dict = {}
    with open(file_path, 'r') as file:
        for line in file:
            # Parse lines in the format "1 = ShapeName"
            if '=' in line:
                number, shape = line.split('=')
                shape_dict[int(number.strip())] = shape.strip()
    return shape_dict

def comp_curve_score(normals, neighbor_index):
    """
    Compute curvature scores for faces based on the angle between neighboring face normals.
    :param normals: Tensor of shape (num_faces, 3) - Normal vectors for each face
    :param neighbor_index: Tensor of shape (num_faces, k_neighbors) - Neighbor indices for each face
    :return: Tensor of shape (num_faces,) - Curvature scores for each face
    """
    # Gather neighboring normals
    normals = normals.permute(1,0)
    neighbor_normals = normals[neighbor_index]  # Shape: (num_faces, k_neighbors, 3)

    # Compute dot product between face normals and their neighbors
    dot_products = torch.sum(normals.unsqueeze(1) * neighbor_normals, dim=-1)  # (num_faces, k_neighbors)

    # Clamp dot products to [-1, 1] to avoid numerical issues with arccos
    dot_products = torch.clamp(dot_products, -1.0, 1.0)

    # Compute angles (in radians) between normals
    angles = torch.acos(dot_products)  # (num_faces, k_neighbors)

    # Average angle across neighbors as the curvature score
    scores = torch.mean(angles, dim=-1)  # (num_faces,)
    scores = (scores - scores.min()) / (scores.max() - scores.min() + 1e-8)
    return scores



New Data Processing
====





In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mesh = load_objs_as_meshes(["7.obj"], device=device)
point_cloud = sample_points_from_meshes(mesh, 500).squeeze(0).cpu().numpy()

# Convert to Open3D point cloud
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(point_cloud)

pcd.estimate_normals()

# Apply Ball Pivoting Algorithm (BPA)
radii = [0.005, 0.01, 0.02]  # Adjust radii for better reconstruction
mesh_bpa = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(
    pcd, o3d.utility.DoubleVector(radii)
)

# Save and visualize
o3d.io.write_triangle_mesh("bpa_triangulated_mesh.obj", mesh_bpa)
o3d.visualization.draw_geometries([mesh_bpa])



Face-Generator
==============





In [5]:
def face_generator(point_clouds):
    """
    Konvertiere eine Batch von Punktwolken in eine Batch von PyTorch3D-Meshes.
    
    Args:
        point_clouds: Tensor der Form (batch_size, num_points, 3), die Punktwolken.
    
    Returns:
        meshes: PyTorch3D Meshes-Objekt mit dem Batch an Meshes.
    """
    #print(point_clouds.size())
   
    batch_size = point_clouds.size(0)
    verts_list = []  # Liste für Vertices (pro Mesh)
    faces_list = []  # Liste für Faces (pro Mesh)
    
    for i in range(batch_size):
        points = point_clouds[i].cpu().detach().numpy()  # Konvertiere zu NumPy für SciPy
        
        # Führe Delaunay-Triangulation durch
        tri = Delaunay(points)

        # Extrahiere Vertices und Faces
        verts = torch.tensor(tri.points, dtype=torch.float32).to(point_clouds.device)  # Shape: (num_vertices, 3)
        faces = torch.tensor(tri.convex_hull, dtype=torch.int64).to(point_clouds.device)  # Shape: (num_faces, 3)

        # Filtere ungültige Faces
        valid_faces = faces[(faces >= 0).all(dim=1)]  # Entferne Faces mit -1

        # Speichere die Vertices und gültigen Faces in den Listen
        verts_list.append(verts)
        faces_list.append(valid_faces)
        

    # Erstelle PyTorch3D-Meshes
    meshes = Meshes(verts=verts_list, faces=faces_list).to(point_clouds.device)
    return {"mesh": meshes, "verts": verts_list, "faces": faces_list}

Shape Extractor
=============

In [8]:
class ShapeExtractor(nn.Module):
    def __init__(self, cfg):
        super(ShapeExtractor, self).__init__()
        self.spatial_descriptor = SpatialDescriptor()
        self.structural_descriptor = StructuralDescriptor(cfg['structural_descriptor'])
        self.mesh_conv1 = MeshConvolution(cfg['mesh_convolution'], 64, 131, 96, 96)
        self.mesh_conv2 = MeshConvolution(cfg['mesh_convolution'], 96, 96, 96, 64)
        self.fusion_mlp = nn.Sequential(
            nn.Conv1d(160, 96, 1),
            nn.BatchNorm1d(96),
            nn.ReLU(),
        )
        self.concat_mlp = nn.Sequential(
            nn.Conv1d(96 + 96 + 96, 128, 1), 
            nn.BatchNorm1d(128),
            nn.ReLU(),
        )
    
    def forward(self, centers, corners, normals, neighbor_index, curvature_scores):
        spatial_fea0 = self.spatial_descriptor(centers)
        structural_fea0 = self.structural_descriptor(corners, normals, neighbor_index)

        spatial_fea1, structural_fea1 = self.mesh_conv1(spatial_fea0, structural_fea0, neighbor_index)
        spatial_fea2, structural_fea2 = self.mesh_conv2(spatial_fea1, structural_fea1, neighbor_index)
        spatial_fea3 = self.fusion_mlp(torch.cat([spatial_fea2, structural_fea2], 1))

        combined_fea = torch.cat([spatial_fea1, spatial_fea2, spatial_fea3], 1)  # b, c, n


        curvature_weight = curvature_scores.unsqueeze(1).expand_as(combined_fea)    
        combined_fea_with_scores = combined_fea 
        

        # Aggregate global features based on combined features
        aggregated_fea = self.concat_mlp(combined_fea_with_scores)  # (batch_size, 512, num_faces)
        global_fea = torch.max(aggregated_fea, dim=2)[0]  # Global feature (batch_size, 512)
        
        return global_fea

Generator
==============


In [9]:
class MeshGenerator(nn.Module):
    def __init__(self, cfg):
        super(MeshGenerator, self).__init__()
        self.shape_extractor = ShapeExtractor(cfg)
        self.vertex_decoder = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, cfg['num_vertices'] * 3)  # Output shape: (batch_size, num_vertices * 3)
        )
    
    def forward(self, centers, corners, normals, neighbor_index, curvature_scores):
        global_fea = self.shape_extractor(centers, corners, normals, neighbor_index, curvature_scores)
        generated_vertices = self.vertex_decoder(global_fea)
        generated_vertices = generated_vertices.view(generated_vertices.size(0), -1, 3)  # Reshape to (batch_size, num_vertices, 3)
        generated_vertices = scale_to_unit(generated_vertices)
        
        return generated_vertices



def scale_to_unit(vertices):
    """
    Scale vertices to the range [-1, 1] while preserving the shape.
    :param vertices: Tensor of shape (batch_size, num_vertices, 3).
    :return: Scaled vertices in the range [-1, 1].
    """
    # Step 1: Center the vertices around the origin
    centroid = torch.mean(vertices, dim=1, keepdim=True)  # Compute the centroid (batch_size, 1, 3)
    centered_vertices = vertices - centroid               # Center the vertices (batch_size, num_vertices, 3)

    # Step 2: Find the maximum absolute value along any axis
    max_abs = torch.amax(torch.abs(centered_vertices), dim=(1, 2), keepdim=True)[0]  # (batch_size, 1, 1)

    # Step 3: Scale the vertices to fit in the range [-1, 1]
    scaled_vertices = centered_vertices / max_abs


    return scaled_vertices

Diskriminator
==============


In [10]:
class MeshDiscriminator(nn.Module):
    def __init__(self, cfg):
        super(MeshDiscriminator, self).__init__()
        
        # Convolutional layers to extract local features
        self.conv1 = nn.Conv1d(in_channels=3, out_channels=32, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=1)
        self.conv3 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=1)

        # Batch normalization for stability
        self.bn1 = nn.BatchNorm1d(32)
        self.bn2 = nn.BatchNorm1d(64)
        self.bn3 = nn.BatchNorm1d(128)

        # Fully connected layers for classification
        self.fc1 = nn.Linear(128 * 64, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 1)

    def forward(self, x):
        # Expecting input shape: (batch_size, num_vertices, 3)
        
        x = x.permute(0, 2, 1)  # Change to (batch_size, 3, num_vertices) for Conv1d
    
        # Apply convolutional layers
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        

        # Apply adaptive pooling to ensure a fixed output length (64 in this case)
        x = F.adaptive_avg_pool1d(x, 64)
        

        # Flatten for fully connected layers
        x = x.view(x.shape[0], -1)
        

        # Fully connected layers
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))  # Sigmoid for binary classification

        return x.squeeze()

Loss-Functions
==============

Alle speziellen Loss_Functions


In [11]:

def chamfer_distance1(x, y):
    x_exp = x.unsqueeze(2)  # (batch_size, num_points, 1, 3)
    y_exp = y.unsqueeze(1)  # (batch_size, 1, num_points, 3)
    dist = torch.norm(x_exp - y_exp, dim=-1)  # Pairwise distances
    min_dist_x, _ = torch.min(dist, dim=2)
    min_dist_y, _ = torch.min(dist, dim=1)
    return torch.mean(min_dist_x) + torch.mean(min_dist_y)

def compute_vertex_density(vertices, k=10):
    tree = KDTree(vertices)
    density_scores = numpy.zeros(len(vertices))

    for i, vertex in enumerate(vertices):
        dists, idx = tree.query(vertex, k=k+1)  # +1, weil der Punkt selbst enthalten ist
        density_scores[i] = k / (numpy.mean(dists[1:]) + 1e-8)  # Entfernung nutzen, aber den Punkt selbst ausschließen

    return density_scores

def curve_loss(generated_vertices, corners, curvature_scores):
    """
    Penalize the generator for not placing vertices in high-curvature regions.
    :param generated_vertices: Tensor of shape (batch_size, num_generated_vertices, 3)
    :param corners: Tensor of shape (batch_size, num_faces, 9) - Real corners
    :param curvature_scores: Tensor of shape (batch_size, num_faces) - Real curvature scores
    :return: Loss scalar
    """
    corners = corners.permute(0,2,1)
    # Compute distances between generated vertices and corners
    generated_vertices = generated_vertices.unsqueeze(2)  # (batch_size, num_generated_vertices, 1, 3)
    corners = corners.view(corners.size(0), corners.size(1), 3, 3)  # (batch_size, num_faces, 3, 3)
    distances = torch.norm(generated_vertices - corners, dim=-1)  # (batch_size, num_generated_vertices, num_faces, 3)

    # Weight distances by curvature scores
    weighted_distances = distances * curvature_scores.unsqueeze(1).unsqueeze(-1)  # (batch_size, num_generated_vertices, num_faces, 3)

    # Minimize the weighted distances
    return weighted_distances.mean()
    
def map_vertices_to_faces(generated_vertices, face_centers):
    """
    Findet für jedes generierte Vertex das nächstgelegene Face anhand seiner Mittelpunkte.

    generated_vertices: (N, 3) numpy array der generierten Vertex-Positionen
    face_centers: (M, 3) numpy array mit Mittelpunkten der echten Faces
    
    Return:
    vertex_face_map: (N,) numpy array mit dem Index des nächstgelegenen Faces für jedes generierte Vertex
    """
    face_centers = face_centers.T
    tree = KDTree(face_centers)  # KDTree für schnellen Nachbarschaftsvergleich
    _, closest_face_indices = tree.query(generated_vertices)  # Findet das nächstgelegene Face für jedes Vertex

    return closest_face_indices
    
def compute_face_vertex_density(generated_vertices, face_centers, num_faces):
    """
    Berechnet die Vertex-Dichte für jedes echte Face basierend auf den generierten Vertices.

    generated_vertices: (N, 3) numpy array mit generierten Vertex-Positionen
    face_centers: (M, 3) numpy array mit Mittelpunkten der echten Faces
    num_faces: Anzahl der echten Faces

    Return:
    face_density: (M,) numpy array mit der Anzahl der generierten Vertices pro Face
    """
    face_density = numpy.zeros(num_faces)

    # Bestimme, welches Face zu jedem generierten Vertex gehört
    closest_faces = map_vertices_to_faces(generated_vertices, face_centers)

    # Zähle, wie viele Vertices pro Face liegen
    for face_idx in closest_faces:
        face_density[face_idx] += 1
    
    # Normalisiere die Dichte auf einen Bereich von 0 bis 1
    face_density = (face_density - numpy.min(face_density)) / (numpy.max(face_density) - numpy.min(face_density) + 1e-8)

    
    return face_density

def curvature_density_loss(generated_vertices, face_centers, real_curvature_scores):
    """
    Bestraft falsche Vertex-Verteilungen basierend auf dem echten Curvature Score der Faces.

    generated_vertices: (B, N, 3) Tensor mit den generierten Vertex-Positionen
    face_centers: (B, M, 3) Tensor mit den Mittelpunkten der echten Faces
    real_curvature_scores: (B, M) Tensor mit echten Curvature Scores der Faces
    """


    batch_size = generated_vertices.shape[0]
    loss = 0.0

    for i in range(batch_size):
        # Konvertiere in NumPy für Berechnungen mit KDTree
        verts = generated_vertices[i].detach().cpu().numpy()
        centers = face_centers[i].detach().cpu().numpy()
        curv_scores = real_curvature_scores[i]

        # Berechne die Face-Dichte basierend auf den generierten Vertices
        face_vertex_density = compute_face_vertex_density(verts, centers, centers.shape[1])

        # Konvertiere in PyTorch Tensor
        face_vertex_density = torch.tensor(face_vertex_density, dtype=torch.float32, device=generated_vertices.device)

        # Fehler = Differenz zwischen generierter Vertex-Dichte und echtem Curvature Score
        density_diff = torch.abs(face_vertex_density - curv_scores)

        # Mean Squared Error als Bestrafung
        loss += torch.mean(density_diff ** 2)

    
    return loss / batch_size  # Durchschnitt über den Batch nehmen

Training
==============

In [21]:
latent_dim = 128   # Dimension des latenten Vektors
num_epochs = 1000# Anzahl der Trainings-Epochen
batch_size = 23 # Größe des Batches
N = 1024
lr_g = 1e-4
lr_d = 1e-5
feature_dim = 512
num_verts = 850
num_faces = 1700


# num_vertices sollte sich so wenig wie möglich von realen Objekten unterscheiden
#num_vertices = vetices in Sphere, für Testing
# Initialisierung von Generator und Diskriminator
cfg = {
        'structural_descriptor': {
            'num_kernel': 64,
            'sigma': 0.2
        },
        'mesh_convolution': {
            'aggregation_method': 'Max'
        },
        'num_vertices': num_verts # Number of vertices in the input mesh
    }
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = MeshGenerator(cfg)
discriminator = MeshDiscriminator(cfg)

generator = generator.to(device)
discriminator = discriminator.to(device)

# Optimizer für Generator und Diskriminator
optimizer_G = optim.Adam(generator.parameters(), lr=lr_g, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr_d, betas=(0.5, 0.999))

def weights_init(m):
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv1d):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)



meshDataset = MeshDataset(root ="/newModelData/goodTopo", max_faces=num_faces, max_vertices=num_verts)
meshDataloader = DataLoader(dataset=meshDataset, batch_size=batch_size,shuffle="true", collate_fn=collate_fn1)
# Beispiel für reale Daten (Dummy)

shapes_dict = parse_shapes_list("ShapesList.txt")

print("Starting training...")


# Trainingsloop
for epoch in range(num_epochs):
    
    #for i, real_meshes in enumerate(dataloader):
    centers, corners, normals, neighbor_index, vertices, file_numbers, curve_score = next(iter(meshDataloader))    

    centers = centers.to(dtype=torch.float32)
    corners = corners.to(dtype=torch.float32)
    normals = normals.to(dtype=torch.float32)

    # -----------------
    # Trainieren des Diskriminators
    # -----------------
    optimizer_D.zero_grad()
   
 

    vertices = vertices.to(device) 
    real_pred = discriminator(vertices)
    real_labels = torch.ones_like((real_pred)).to(device)  # Label 1 für reale Daten
    real_loss = F.binary_cross_entropy(real_pred, real_labels).to(device)
    
    # Schritt 2: Generierte Daten durch den Diskriminator
    
    generated_verts = generator(centers, corners, normals, neighbor_index, curve_score)  # Generiertes Mesh
    generated_verts = generated_verts.to(device) 
    fake_pred = discriminator(generated_verts)
    fake_labels = torch.zeros_like(fake_pred).to(device)  # Label 0 für generierte Daten
    fake_loss = F.binary_cross_entropy(fake_pred, fake_labels).to(device)
    chamfer_loss = chamfer_distance1(generated_verts, vertices)
    chamfer_loss_d = fake_loss + chamfer_loss 

   # print(chamfer_loss_d, fake_loss)

    # Diskriminator-Loss berechnen und Schritt machen
    adv_d = real_loss + fake_loss
    d_loss = real_loss + fake_loss + chamfer_loss_d 
    d_loss.backward()
    optimizer_D.step()
    
    # -----------------
    # Trainieren des Generators
    # -----------------
    optimizer_G.zero_grad()

    generated_verts = generator(centers, corners, normals, neighbor_index, curve_score).to(device)   
    generated_meshes = face_generator(generated_verts)

    real_meshes_cd = face_generator(vertices)
     
    # Diskriminator-Output für das generierte Mesh
    generated_verts = generated_verts.to(device) 
    fake_pred = discriminator(generated_verts)
    
    adv_loss = F.binary_cross_entropy(fake_pred, torch.ones_like(fake_pred)).to(device)
    
    fake_point_clouds = sample_points_from_meshes(generated_meshes["mesh"], num_samples=1024).to(device)
    real_point_clouds = sample_points_from_meshes(real_meshes_cd["mesh"],num_samples=1024).to(device)

    curve_loss = curvature_density_loss(generated_verts, centers, curve_score)
    chamfer_loss = chamfer_distance1(generated_verts,vertices)  

    g_loss = adv_loss + chamfer_loss + curve_loss
 
    g_loss.backward()
    optimizer_G.step()


    
    print(f"Epoch [{epoch+1}/{num_epochs}]  Loss D: {d_loss.item():.4f}, Loss Adv_d: {adv_d.item():.4f} ------ Loss G: {g_loss.item():.4f}, Curve Loss: {curve_loss.item():.4f}, Chamfer Loss: {chamfer_loss.item():.4f}")

  state_dict = torch.load("MeshNet_best_9192.pkl")
  curve_score = torch.tensor(comp_curve_score(normals, neighbor_index)).to(device)


Starting training...
Epoch [1/1000]  Loss D: 2.5879, Loss Adv_d: 1.3779 ------ Loss G: 1.2780, Curve Loss: 0.1379, Chamfer Loss: 0.4722
Epoch [2/1000]  Loss D: 2.4757, Loss Adv_d: 1.3730 ------ Loss G: 1.1926, Curve Loss: 0.1343, Chamfer Loss: 0.3795
Epoch [3/1000]  Loss D: 2.4456, Loss Adv_d: 1.3636 ------ Loss G: 1.1954, Curve Loss: 0.1364, Chamfer Loss: 0.3706
Epoch [4/1000]  Loss D: 2.4262, Loss Adv_d: 1.3575 ------ Loss G: 1.2010, Curve Loss: 0.1352, Chamfer Loss: 0.3673
Epoch [5/1000]  Loss D: 2.3929, Loss Adv_d: 1.3512 ------ Loss G: 1.1924, Curve Loss: 0.1342, Chamfer Loss: 0.3500
Epoch [6/1000]  Loss D: 2.3736, Loss Adv_d: 1.3447 ------ Loss G: 1.2002, Curve Loss: 0.1371, Chamfer Loss: 0.3463
Epoch [7/1000]  Loss D: 2.3503, Loss Adv_d: 1.3392 ------ Loss G: 1.1982, Curve Loss: 0.1353, Chamfer Loss: 0.3376
Epoch [8/1000]  Loss D: 2.3446, Loss Adv_d: 1.3324 ------ Loss G: 1.2141, Curve Loss: 0.1339, Chamfer Loss: 0.3470
Epoch [9/1000]  Loss D: 2.3085, Loss Adv_d: 1.3265 ------ L

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Exporting Point Cloud for Shape: None (File Number: 0)
Exporting Point Cloud for Shape: None (File Number: 0)
Exporting Point Cloud for Shape: None (File Number: 0)
Exporting Point Cloud for Shape: None (File Number: 0)
Exporting Point Cloud for Shape: None (File Number: 0)
Exporting Point Cloud for Shape: None (File Number: 0)
Exporting Point Cloud for Shape: None (File Number: 0)
Exporting Point Cloud for Shape: None (File Number: 0)
Exporting Point Cloud for Shape: None (File Number: 0)
Exporting Point Cloud for Shape: None (File Number: 0)
Exporting Point Cloud for Shape: None (File Number: 0)
Exporting Point Cloud for Shape: None (File Number: 0)
Exporting Point Cloud for Shape: None (File Number: 0)
Exporting Point Cloud for Shape: None (File Number: 0)
Exporting Point Cloud for Shape: None (File Number: 0)
Exporting Point Cloud for Shape: None (File Number: 0)
Exporting Point Cloud for Shape: None (Fil

In [None]:
print(vertices[1])