**Setup and Imports**

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import os
import json

# Pick the compute device once: CUDA when a GPU is visible, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


Using device: cuda


In [None]:
# Make sure trimesh and scipy are importable, installing them into the running
# kernel's own interpreter (sys.executable) only when they are missing.
# This replaces an unconditional `!pip install trimesh`, which re-ran pip on
# every execution, and an IPython-only `!pip install scipy`, which could target
# a different environment than the kernel.
try:
    import trimesh
except ImportError:
    import subprocess, sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "trimesh"])
    import trimesh

try:
    from scipy.spatial import ConvexHull
except ImportError:
    import subprocess, sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "scipy"])
    from scipy.spatial import ConvexHull

import os
import re
from glob import glob
import json
import pandas as pd
from glob import glob
import random
import numpy as np
import pandas as pd
import subprocess, sys
import matplotlib.pyplot as plt
from scipy.spatial import ConvexHull
import trimesh

Collecting trimesh
  Downloading trimesh-4.6.12-py3-none-any.whl.metadata (18 kB)
Downloading trimesh-4.6.12-py3-none-any.whl (711 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/712.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m706.6/712.0 kB[0m [31m29.4 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m712.0/712.0 kB[0m [31m20.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: trimesh
Successfully installed trimesh-4.6.12


**DATA LOADING AND VALIDATION**

In [None]:
# Mount Google Drive at /content/drive so the dataset folder under
# /content/drive/MyDrive (used by the cells below) is readable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# List the patient folders of the dataset on Drive to confirm it is reachable.
# Drive is mounted by the previous cell; mounting it a second time here was a
# no-op ("Drive already mounted ..."), so it is not repeated.
import os
root_dir = '/content/drive/MyDrive/Dental_Dataset'

if not os.path.isdir(root_dir):
    raise FileNotFoundError(
        f"Dataset folder not found: {root_dir} (is Google Drive mounted?)"
    )

# NOTE(review): some folder names carry what look like quality flags
# (e.g. 000_OK_Template, 008_OK_schlecht, 028_OK_Fehler). DentalDataset does
# not filter on them — confirm whether they should be used for training.
print("Patient folders:")
for name in sorted(os.listdir(root_dir)):
    if os.path.isdir(os.path.join(root_dir, name)):
        print(" ", name)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Patient folders:
  000_OK_Template
  001_OK
  002_OK
  003_OK
  004_OK
  005_OK
  006_OK
  007_OK
  008_OK_schlecht
  009_OK
  010_OK
  011_OK
  012_OK
  013_OK
  014_OK
  015_OK
  016_OK
  017_OK
  018_OK
  019_OK
  020_OK
  021_OK
  022_OK
  023_OK
  024_OK
  025_OK
  026_OK
  027_OK
  028_OK_Fehler
  029_OK
  030_OK
  031_OK
  032_OK
  033_OK
  034_OK
  035_OK
  036_OK
  037_OK
  038_OK
  039_OK
  040_OK
  041_OK
  042_OK
  043_OK
  044_OK
  045_OK
  046_OK
  047_OK
  048_OK
  049_OK
  050_OK
  051_OK
  052_OK
  053_OK
  054_OK
  055_OK
  056_OK
  057_OK
  058_OK
  059_OK
  060_OK
  061_OK
  062_OK
  063_OK
  064_OK
  065_OK
  066_OK
  067_OK
  068_OK
  069_OK
  070_OK
  071_OK
  072_OK
  073_OK
  074_OK
  075_OK
  076_OK
  077_OK
  078_OK
  079_OK


**DATASET DEFINITION**

In [None]:
class DentalDataset(Dataset):
    """Per-patient dataset: jaw-scan mesh, prosthetic mesh and optional ``.mrk`` annotations.

    Files anywhere under ``root_dir`` are grouped by filename prefix. A file is
    assigned to a modality when its lower-cased name ends with that modality's
    suffix. Everything before the suffix (in the original case) is the prefix.
    Only prefixes that have both a ``jaw_scan`` and a ``prosthetic`` mesh are kept.

    NOTE(review): prefixes are global, so equally named files in different
    folders overwrite each other (the last one walked wins). The dataset root
    also contains folders such as ``000_OK_Template``, ``008_OK_schlecht`` and
    ``028_OK_Fehler``, which are not filtered out here — confirm they belong.
    """

    def __init__(self, root_dir):
        self.root_dir = root_dir
        # Modality -> filename suffix (compared against the lower-cased name).
        self.modalities = {
            "jaw_points":        "_scan_points.mrk",
            "prosthetic_points": "_proth_points.mrk",
            "prosthetic_curve":  "_proth_curve.mrk",
            "jaw_scan":          "_scan.stl",
            "prosthetic":        "_proth.stl",
        }

        # Walk the tree and collect files per prefix.
        samples = {}
        for dpath, _, files in os.walk(root_dir):
            for fname in files:
                lower_name = fname.lower()
                for mod, suf in self.modalities.items():
                    if lower_name.endswith(suf):
                        prefix = fname[:-len(suf)]
                        samples.setdefault(prefix, {})[mod] = os.path.join(dpath, fname)
                        break

        # Sorted so that sample indices (and any seeded split built on them)
        # do not depend on the filesystem's os.walk order.
        self.prefixes = sorted(
            p for p, m in samples.items() if "jaw_scan" in m and "prosthetic" in m
        )
        self.samples = samples
        print(f"[Dataset] Found {len(self.prefixes)} valid patients "
              f"(out of {len(samples)} prefixes)")

    def __len__(self):
        return len(self.prefixes)

    @staticmethod
    def _load_json(path):
        """Parse a JSON ``.mrk`` file; return None if the path is empty, missing or unparsable."""
        if not path or not os.path.exists(path):
            return None
        try:
            # `with` closes the handle (the previous json.load(open(...)) leaked it).
            with open(path, "r", encoding="utf-8") as fh:
                return json.load(fh)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError. A
            # bare `except:` would also swallow KeyboardInterrupt/SystemExit.
            return None

    def __getitem__(self, idx):
        """Load the meshes (via trimesh) and any annotations for the idx-th patient.

        Returns:
            dict with the prefix, the two loaded meshes, and the parsed JSON
            (or None) for each optional ``.mrk`` modality.
        """
        prefix = self.prefixes[idx]
        files = self.samples[prefix]

        # Mandatory meshes; both are guaranteed present by the filter in __init__.
        jaw_mesh = trimesh.load(files["jaw_scan"])
        prosthetic = trimesh.load(files["prosthetic"])

        return {
            "prefix":              prefix,
            "jaw_scan":            jaw_mesh,
            "prosthetic":          prosthetic,
            "jaw_points":          self._load_json(files.get("jaw_points")),
            "prosthetic_points":   self._load_json(files.get("prosthetic_points")),
            "prosthetic_curve":    self._load_json(files.get("prosthetic_curve")),
        }


In [None]:
#Helper Functions

# Convert mesh to point cloud tensor
def mesh_to_pc(mesh):
    """Return the mesh's vertex array (one row per vertex) as a float32 tensor."""
    vertex_array = mesh.vertices
    as_tensor = torch.from_numpy(vertex_array)
    return as_tensor.float()

# Extract border points from JSON
def load_border(curve):
    """Return ``curve["border_points"]`` as a float32 tensor.

    Falls back to an empty (0, 3) tensor when ``curve`` is falsy (e.g. None
    from a missing/unparsable .mrk file) or has no non-empty "border_points".
    """
    border = curve.get("border_points") if curve else None
    if not border:
        return torch.zeros((0, 3))
    return torch.from_numpy(np.asarray(border)).float()

# Augment point cloud with rotation and noise
def augment_pc(pc, rotate=True, noise=True, noise_std=0.005):
    """Randomly rotate a point cloud about the z axis and/or add Gaussian jitter.

    Args:
        pc: tensor of points with last dimension 3.
        rotate: apply a uniformly random rotation about the z axis.
        noise: add i.i.d. Gaussian noise to every coordinate.
        noise_std: standard deviation of that noise (same units as ``pc``).

    Returns:
        The augmented tensor; ``pc`` is never modified in place.
    """
    if rotate:
        theta = random.random() * 2 * np.pi
        c, s = float(np.cos(theta)), float(np.sin(theta))
        # Build R with pc's dtype/device: a hard-coded float32 CPU matrix made
        # float64 or CUDA inputs fail in the matmul below.
        R = torch.tensor([
            [c,  -s,  0.0],
            [s,   c,  0.0],
            [0.0, 0.0, 1.0],
        ], dtype=pc.dtype, device=pc.device)
        # Row-vector convention: pc @ R rotates by -theta, which is equivalent
        # here because theta is uniform on [0, 2*pi).
        pc = pc @ R
    if noise:
        pc = pc + torch.randn_like(pc) * noise_std
    return pc

# Sample fixed number of points from point cloud
def sample_pc(pc, n_pts=2048):
    """Return exactly ``n_pts`` rows of ``pc``.

    With at least ``n_pts`` points, a random subset is drawn without
    replacement. Otherwise every point is kept once (in order) and the rest are
    random duplicates.

    Raises:
        ValueError: if ``n_pts`` is negative, or if ``pc`` is empty while
            ``n_pts > 0`` (there is nothing to sample from).
    """
    if n_pts < 0:
        raise ValueError(f"sample_pc: n_pts must be non-negative, got {n_pts}")
    M = pc.size(0)
    if M == 0 and n_pts > 0:
        # Previously surfaced as an opaque torch.randint(0, 0, ...) error.
        raise ValueError("sample_pc: cannot sample points from an empty point cloud")
    if M >= n_pts:
        idx = torch.randperm(M)[:n_pts]
    else:
        extra = n_pts - M
        idx = torch.cat([
            torch.arange(M),
            torch.randint(0, M, (extra,))
        ], dim=0)
    return pc[idx]

# Custom collate function for batching
def _random_z_rotation(dtype=torch.float32, device=None):
    """Return a 3x3 matrix for a uniformly random rotation about z (row-vector convention)."""
    theta = random.random() * 2 * np.pi
    c, s = float(np.cos(theta)), float(np.sin(theta))
    return torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]],
                        dtype=dtype, device=device)


def collate_fn(batch, augment=True):
    """Convert a list of DentalDataset samples into per-sample point-cloud lists.

    Samples are not stacked because meshes have different vertex counts.

    Args:
        batch: list of dicts as returned by ``DentalDataset.__getitem__``.
        augment: when True, draw ONE random z-rotation per sample and apply it
            to the jaw scan, the border points and the prosthetic target
            together, so input and target stay in the same frame. Gaussian
            jitter is then added to the jaw-scan input only. (Previously only
            the input was rotated, which misaligned it with its target.) Pass
            False, e.g. ``functools.partial(collate_fn, augment=False)``, for
            deterministic validation/testing.

    Returns:
        (pcs, borders, targets): three lists of float tensors, one entry per sample.
    """
    pcs, borders, targets = [], [], []
    for s in batch:
        pc = mesh_to_pc(s["jaw_scan"])
        brd = load_border(s["prosthetic_curve"])
        tgt = mesh_to_pc(s["prosthetic"])

        if augment:
            R = _random_z_rotation(pc.dtype, pc.device)
            pc = pc @ R
            tgt = tgt @ R
            # Only rotate borders shaped (K, 3); anything else is passed through untouched.
            if brd.dim() == 2 and brd.size(-1) == 3:
                brd = brd @ R
            pc = augment_pc(pc, rotate=False, noise=True)

        pcs.append(pc)
        borders.append(brd)
        targets.append(tgt)
    return pcs, borders, targets

# Create train-test split
def create_train_test_split(full_ds, test_ratio=0.2, seed=None):
    """Randomly partition ``full_ds`` into disjoint train and test Subsets.

    The test set gets ``int(test_ratio * n)`` samples and train gets the rest.
    (The old ``idxs[:-test_size]`` slicing put EVERY sample into test and none
    into train whenever ``test_size`` rounded to 0.)

    Args:
        full_ds: any sized, indexable dataset.
        test_ratio: fraction of samples for the test set, in [0, 1].
        seed: optional seed for a private RNG. With None, the global ``random``
            module is used as before.

    Returns:
        (train_ds, test_ds) as ``torch.utils.data.Subset`` objects.
    """
    if not 0 <= test_ratio <= 1:
        raise ValueError(f"test_ratio must be in [0, 1], got {test_ratio}")
    n = len(full_ds)
    idxs = list(range(n))
    rng = random.Random(seed) if seed is not None else random
    rng.shuffle(idxs)

    test_size = int(test_ratio * n)
    n_train = n - test_size
    train_idx = idxs[:n_train]
    test_idx = idxs[n_train:]

    train_ds = torch.utils.data.Subset(full_ds, train_idx)
    test_ds = torch.utils.data.Subset(full_ds, test_idx)

    for label, subset in (("Training", train_ds), ("Test", test_ds)):
        pct = len(subset) / n * 100 if n else 0.0  # avoid ZeroDivisionError for empty datasets
        print(f"{label} samples: {len(subset)} ({pct:.1f}%)")

    return train_ds, test_ds


def create_train_val_test_split(full_ds, val_ratio=0.1, test_ratio=0.2, seed=None):
    """Randomly partition ``full_ds`` into disjoint train / validation / test Subsets.

    Validation and test get ``int(val_ratio * n)`` and ``int(test_ratio * n)``
    samples respectively; train gets the remainder. ``seed`` behaves as in
    ``create_train_test_split``.

    Returns:
        (train_ds, val_ds, test_ds) as ``torch.utils.data.Subset`` objects.
    """
    if val_ratio < 0 or test_ratio < 0 or val_ratio + test_ratio > 1:
        raise ValueError(
            f"need val_ratio, test_ratio >= 0 and val_ratio + test_ratio <= 1, "
            f"got {val_ratio} and {test_ratio}"
        )
    n = len(full_ds)
    idxs = list(range(n))
    rng = random.Random(seed) if seed is not None else random
    rng.shuffle(idxs)

    val_size = int(val_ratio * n)
    test_size = int(test_ratio * n)
    n_train = n - val_size - test_size
    split_idx = (
        idxs[:n_train],
        idxs[n_train:n_train + val_size],
        idxs[n_train + val_size:],
    )
    subsets = tuple(torch.utils.data.Subset(full_ds, ix) for ix in split_idx)

    for label, subset in zip(("Training", "Validation", "Test"), subsets):
        pct = len(subset) / n * 100 if n else 0.0
        print(f"{label} samples: {len(subset)} ({pct:.1f}%)")

    return subsets


**Model Architecture**

In [None]:
import torch.nn.functional as F

class SetAbstraction(nn.Module):
    """PointNet++-style set-abstraction layer: sample/group points, shared MLP, max-pool.

    Takes channel-first coordinates and features and returns channel-first
    centroids plus pooled features (see ``forward``).

    NOTE(review): ``sample_and_group`` is not defined anywhere in this
    notebook; it must be provided (e.g. from a PointNet++ utilities module)
    before this layer can run.
    """

    def __init__(self, npoint, radius, nsample, in_channel, mlp):
        """
        Args:
            npoint: number of centroids passed to ``sample_and_group``.
            radius: grouping radius passed to ``sample_and_group``.
            nsample: max neighbours per group passed to ``sample_and_group``.
            in_channel: channels of the grouped input fed to the first 1x1 conv.
            mlp: output channel sizes of the successive 1x1 Conv2d + BatchNorm2d layers.
        """
        super(SetAbstraction, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel

    def forward(self, xyz, points):
        """
        Args:
            xyz: coordinates, assumed channel-first (B, 3, N) — they are
                permuted to (B, N, 3) before grouping. TODO confirm against sample_and_group.
            points: optional features, assumed (B, D, N), or None.

        Returns:
            (new_xyz, new_points): centroids permuted back to channel-first, and
            features max-pooled over each group (dim 2 after the permute below).
        """
        # Channel-first -> channel-last for sample_and_group.
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)

        new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
        # Presumably (B, npoint, nsample, C) -> (B, C, nsample, npoint) so Conv2d sees C as channels.
        new_points = new_points.permute(0, 3, 2, 1)

        # Shared point-wise MLP: 1x1 conv -> BatchNorm -> ReLU per layer.
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points = F.relu(bn(conv(new_points)))

        # Max-pool over the neighbours within each group.
        new_points = torch.max(new_points, 2)[0]
        # Return centroids channel-first, matching the expected input layout of the next layer.
        new_xyz = new_xyz.permute(0, 2, 1)
        return new_xyz, new_points

class EnhancedPointNetEncoder(nn.Module):
    """Four-level PointNet++-style encoder producing one flat feature vector per cloud.

    Input ``xyz`` is (B, N, C) (channel-last). The last level has 16 centroids
    with 512 channels, so the output is (B, 512 * 16), assuming
    ``sample_and_group`` returns ``npoint`` centroids. With ``points=None`` the
    first level sees only coordinates, so ``in_dim`` should presumably be 3 —
    TODO confirm.
    """

    def __init__(self, in_dim=3):
        super().__init__()
        self.sa1 = SetAbstraction(1024, 0.1, 32, in_dim, [32, 32, 64])
        self.sa2 = SetAbstraction(256, 0.2, 32, 64 + 3, [64, 64, 128])
        self.sa3 = SetAbstraction(64, 0.4, 32, 128 + 3, [128, 128, 256])
        self.sa4 = SetAbstraction(16, 0.8, 32, 256 + 3, [256, 256, 512])

    def forward(self, xyz):
        B, N, C = xyz.shape
        # SetAbstraction takes channel-first coordinates (B, C, N): it permutes
        # them to channel-last internally and returns channel-first centroids
        # for the next level. The raw (B, N, C) input must therefore be
        # transposed first; it was previously passed in as-is, so sa1 grouped
        # along the wrong axis.
        l1_xyz, l1_points = self.sa1(xyz.transpose(1, 2), None)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
        _, l4_points = self.sa4(l3_xyz, l3_points)

        # Flatten (B, 512, npoint) per-centroid features into one vector per cloud.
        return l4_points.reshape(B, -1)


In [None]:
class SpatialBorderAttention(nn.Module):
    """Multi-head self-attention over point features with an optional border emphasis.

    Args:
        feature_dim: channel size C of the input features; must be divisible by ``num_heads``.
        num_heads: number of attention heads.

    Raises:
        ValueError: if ``feature_dim`` is not divisible by ``num_heads``. Without
            this check the per-head ``view`` in ``forward`` fails later with an
            opaque shape error.
    """

    def __init__(self, feature_dim=512, num_heads=8):
        super().__init__()
        if num_heads <= 0 or feature_dim % num_heads != 0:
            raise ValueError(
                f"feature_dim ({feature_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.feature_dim = feature_dim
        self.head_dim = feature_dim // num_heads

        self.query_conv = nn.Linear(feature_dim, feature_dim)
        self.key_conv = nn.Linear(feature_dim, feature_dim)
        self.value_conv = nn.Linear(feature_dim, feature_dim)
        # Learnable strength of the border emphasis.
        self.border_weight = nn.Parameter(torch.ones(1))

    def forward(self, features, border_points=None):
        """Attend over the N tokens of ``features`` (B, N, C) and add a residual connection.

        Returns:
            (B, N, C) tensor.
        """
        B, N, C = features.shape

        # Multi-head attention: split C into (num_heads, head_dim).
        Q = self.query_conv(features).view(B, N, self.num_heads, self.head_dim)
        K = self.key_conv(features).view(B, N, self.num_heads, self.head_dim)
        V = self.value_conv(features).view(B, N, self.num_heads, self.head_dim)

        # Scaled dot-product scores, (B, N_query, N_key, heads).
        attention_scores = torch.einsum('bnhd,bmhd->bnmh', Q, K) / (self.head_dim ** 0.5)

        # Apply border weighting if available.
        # NOTE(review): this scales the scores multiplicatively, so a negative
        # score becomes MORE negative (less attention) in the "emphasised"
        # region. An additive bias may be what was intended — confirm.
        if border_points is not None and border_points.numel() > 0:
            border_mask = self.create_border_mask(features, border_points)
            attention_scores = attention_scores * (1 + self.border_weight * border_mask.unsqueeze(-1))

        # Softmax over the key axis.
        attention_weights = F.softmax(attention_scores, dim=2)
        attended_features = torch.einsum('bnmh,bmhd->bnhd', attention_weights, V)

        return attended_features.reshape(B, N, C) + features

    def create_border_mask(self, features, border_points):
        """Return a (B, N, N) mask: 2.0 on the leading border block, 1.0 elsewhere.

        Only the NUMBER of border points is used (capped at N // 4); their
        coordinates are ignored, and the first ``border_size`` token positions
        stand in for the border. NOTE(review): placeholder heuristic — a
        spatial proximity mask would need the token coordinates.
        """
        B, N, C = features.shape
        mask = torch.ones(B, N, N, device=features.device)
        if border_points.numel() > 0:
            border_size = min(border_points.size(0), N//4)
            mask[:, :border_size, :border_size] *= 2.0
        return mask


**Updated Loss Functions**

In [None]:
def density_aware_chamfer_distance(pred, target, alpha=1000, k=8, detach_weights=True):
    """Density-weighted, symmetric Chamfer distance between two point sets.

    Each predicted point's nearest-target distance is weighted by
    ``exp(-alpha * mean distance to its k nearest target points)``. The
    target->pred term is a plain mean nearest-neighbour distance.

    Args:
        pred: (N, D) predicted points.
        target: (M, D) target points.
        alpha: decay rate of the weight. Its scale must match the coordinate
            units: with alpha=1000, a mean k-NN distance of 0.01 already gives
            weight exp(-10).
        k: neighbours per weight (clipped to M).
        detach_weights: if True (default), weights are constants for backprop.
            If False, gradients also flow through the weights. Because a weight
            shrinks as its point moves away, the weighted term can then
            DECREASE as the distance grows, rewarding predictions that drift
            off the target.

    Returns:
        Scalar tensor.

    Raises:
        ValueError: if either point set is empty (the nearest-neighbour
            reductions are undefined).
    """
    # Cast to float32 (does not move devices; both inputs must already share one).
    pred = pred.float()
    target = target.float()
    if pred.size(0) == 0 or target.size(0) == 0:
        raise ValueError("density_aware_chamfer_distance: pred and target must be non-empty")

    # Pairwise Euclidean distances, (N, M). Memory grows as N * M.
    distances = torch.norm(pred.unsqueeze(1) - target.unsqueeze(0), dim=2)

    # Nearest-neighbour distances in both directions.
    pred_to_target_dist = distances.min(dim=1)[0]  # (N,)
    target_to_pred_dist = distances.min(dim=0)[0]  # (M,)

    # Per-predicted-point weights from the mean distance to its k nearest targets.
    k_neighbors = min(k, target.size(0))
    knn_distances, _ = torch.topk(distances, k_neighbors, dim=1, largest=False)
    density_weights = torch.exp(-alpha * knn_distances.mean(dim=1))
    if detach_weights:
        density_weights = density_weights.detach()

    weighted_pred_to_target = (pred_to_target_dist * density_weights).mean()
    target_to_pred_loss = target_to_pred_dist.mean()

    return weighted_pred_to_target + target_to_pred_loss

def compute_normal_consistency_loss(pred, target, k=6):
    """Surface-normal consistency loss between predicted and target point sets.

    Each predicted point is matched to its nearest target point. The loss is
    ``1 - mean(|cos(n_pred, n_target)|)`` over the predicted points. The
    absolute value makes it orientation-invariant: normals from local PCA are
    only defined up to sign, so n and -n must count as perfectly aligned.

    Args:
        pred: (N, D) predicted points.
        target: (M, D) target points.
        k: neighbourhood size passed to ``estimate_normals``.

    Returns:
        Scalar tensor in [0, 1].
    """
    pred_normals = estimate_normals(pred, k)
    target_normals = estimate_normals(target, k)

    # Nearest target point for every predicted point.
    distances = torch.norm(pred.unsqueeze(1) - target.unsqueeze(0), dim=2)
    closest_indices = distances.argmin(dim=1)
    corresponding_normals = target_normals[closest_indices]

    normal_alignment = F.cosine_similarity(pred_normals, corresponding_normals, dim=1).abs()
    return 1.0 - normal_alignment.mean()

def estimate_normals(points, k=6):
    """Estimate per-point surface normals by local PCA.

    For each point, take its k nearest neighbours (normally including itself),
    form their covariance matrix, and use the eigenvector of the smallest
    eigenvalue (the direction of least variance) as the normal. This replaces
    a placeholder that returned random unit vectors.

    Args:
        points: (N, D) tensor of points (D = 3 for surface normals).
        k: neighbourhood size, clipped to [1, N]. At least 3 non-collinear
            neighbours are needed for a well-defined normal.

    Returns:
        (N, D) tensor of unit normals. Their sign is arbitrary (PCA does not
        orient normals), so compare them orientation-invariantly.

    NOTE(review): eigenvector gradients of torch.linalg.eigh are only finite
    for distinct eigenvalues. Degenerate neighbourhoods (duplicate or collinear
    points) can yield inf/NaN gradients when ``points`` requires grad.
    """
    n = points.size(0)
    if n == 0:
        return points.new_zeros(points.shape)
    k_eff = max(1, min(k, n))

    # Indices of the k nearest neighbours of every point, (N, k).
    knn_idx = torch.cdist(points, points).topk(k_eff, dim=1, largest=False).indices
    neighbors = points[knn_idx]                                   # (N, k, D)
    centered = neighbors - neighbors.mean(dim=1, keepdim=True)
    cov = centered.transpose(1, 2) @ centered / k_eff             # (N, D, D)

    # eigh returns eigenvalues in ascending order with eigenvectors as columns,
    # so column 0 is the direction of least local variance.
    _, eigvecs = torch.linalg.eigh(cov)
    return F.normalize(eigvecs[..., 0], dim=1)


**Training Improvements**

In [None]:
if __name__ == "__main__":
    # Setup (keep your existing setup)
    DATA_ROOT = '/content/drive/MyDrive/Dental_Dataset'
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {DEVICE}")

    # Seed the RNGs this pipeline draws from, so Restart & Run All reproduces
    # the same split, augmentation and weight initialisation. The split
    # helpers shuffle with `random`; augment_pc uses `random` and torch.
    # torch.manual_seed also seeds the CUDA generators.
    SEED = 42
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    # Dataset & DataLoaders
    # NOTE(review): create_train_val_test_split must be defined in an earlier
    # cell. collate_fn (used for val/test too) applies random augmentation, so
    # val/test losses are not deterministic — consider an eval-mode collate.
    full_ds = DentalDataset(DATA_ROOT)
    train_ds, val_ds, test_ds = create_train_val_test_split(full_ds, val_ratio=0.10, test_ratio=0.20)

    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True,
                              num_workers=2, collate_fn=collate_fn)
    val_loader = DataLoader(val_ds, batch_size=2, shuffle=False,
                            num_workers=2, collate_fn=collate_fn)
    test_loader = DataLoader(test_ds, batch_size=2, shuffle=False,
                             num_workers=2, collate_fn=collate_fn)

    # Enhanced Model Architecture (using your improved model)
    # NOTE(review): ImprovedDentureGenModel is not defined in the cells above.
    model = ImprovedDentureGenModel().to(DEVICE)

    # Advanced optimizer setup
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=10, T_mult=2, eta_min=1e-6
    )
    # NOTE(review): torch.cuda.amp.GradScaler is deprecated in recent PyTorch
    # releases in favour of torch.amp.GradScaler("cuda").
    scaler = torch.cuda.amp.GradScaler() if DEVICE == "cuda" else None

    # Initialize loss tracking lists
    train_losses = []
    val_losses = []
    test_losses = []
    train_chamfer_losses = []
    val_chamfer_losses = []
    test_chamfer_losses = []
    learning_rates = []

    # Training configuration
    EPOCHS = 100
    best_val_loss = float('inf')
    best_test_loss = float('inf')
    patience = 15
    patience_counter = 0

    print("Starting enhanced training with comprehensive monitoring...")

    # Main loop over 1-based epoch numbers; `epoch` is read again after the loop.
    # NOTE(review): train_epoch_advanced, validate_epoch_advanced and
    # test_epoch_monitoring are not defined in the cells above.
    for epoch in range(1, EPOCHS + 1):
        # Training phase
        train_metrics = train_epoch_advanced(model, train_loader, optimizer, DEVICE, epoch, EPOCHS, scaler)
        train_losses.append(train_metrics['total_loss'])
        train_chamfer_losses.append(train_metrics['chamfer_loss'])

        # Validation phase
        val_metrics = validate_epoch_advanced(model, val_loader, DEVICE, epoch, EPOCHS)
        val_losses.append(val_metrics['total_loss'])
        val_chamfer_losses.append(val_metrics['chamfer_loss'])

        # Test phase monitoring (for research purposes)
        test_metrics = test_epoch_monitoring(model, test_loader, DEVICE, epoch, EPOCHS)
        test_losses.append(test_metrics['total_loss'])
        test_chamfer_losses.append(test_metrics['chamfer_loss'])

        # Record learning rate (read before scheduler.step(), i.e. the LR used this epoch)
        current_lr = optimizer.param_groups[0]['lr']
        learning_rates.append(current_lr)

        # Learning rate scheduling: stepped once per epoch, so T_0=10 means a 10-epoch first cycle
        scheduler.step()

        # Enhanced progress reporting
        print(f"\nEpoch {epoch}/{EPOCHS}")
        print(f"  Train - Total: {train_metrics['total_loss']:.6f}, Chamfer: {train_metrics['chamfer_loss']:.6f}")
        print(f"  Val   - Total: {val_metrics['total_loss']:.6f}, Chamfer: {val_metrics['chamfer_loss']:.6f}")
        print(f"  Test  - Total: {test_metrics['total_loss']:.6f}, Chamfer: {test_metrics['chamfer_loss']:.6f}")
        print(f"  LR: {current_lr:.8f}")

        # Model checkpointing with comprehensive state saving
        is_best_val = val_metrics['total_loss'] < best_val_loss
        is_best_test = test_metrics['total_loss'] < best_test_loss

        if is_best_val:
            best_val_loss = val_metrics['total_loss']
            patience_counter = 0

            # Save comprehensive checkpoint.
            # Note: best_test_loss here is the best from PREVIOUS epochs; it is
            # updated further below.
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_val_loss': best_val_loss,
                'best_test_loss': best_test_loss,
                'train_metrics': train_metrics,
                'val_metrics': val_metrics,
                'test_metrics': test_metrics,
                'training_history': {
                    'train_losses': train_losses,
                    'val_losses': val_losses,
                    'test_losses': test_losses,
                    'train_chamfer_losses': train_chamfer_losses,
                    'val_chamfer_losses': val_chamfer_losses,
                    'test_chamfer_losses': test_chamfer_losses,
                    'learning_rates': learning_rates
                }
            }

            torch.save(checkpoint, "best_enhanced_model.pth")
            print(f"  → Saved best validation model (Val Loss: {best_val_loss:.6f})")
        else:
            # Early-stopping counter: consecutive epochs without validation improvement
            patience_counter += 1

        # NOTE(review): selecting a checkpoint by TEST loss leaks the test set
        # into model selection. Do not report best_test_model.pth (or
        # best_test_loss) as an unbiased test result.
        if is_best_test:
            best_test_loss = test_metrics['total_loss']
            torch.save(model.state_dict(), "best_test_model.pth")
            print(f"  → Saved best test model (Test Loss: {best_test_loss:.6f})")

        # Regular checkpoint saving
        if epoch % 20 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'train_losses': train_losses,
                'val_losses': val_losses,
                'test_losses': test_losses
            }, f"checkpoint_epoch_{epoch}.pth")
            print(f"  → Saved regular checkpoint at epoch {epoch}")

        # Early stopping
        if patience_counter >= patience:
            print(f"\nEarly stopping triggered after {patience} epochs without validation improvement")
            break

    # Save final model regardless of performance (weights of the last epoch run,
    # not the best-validation weights)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'final_train_loss': train_losses[-1],
        'final_val_loss': val_losses[-1],
        'final_test_loss': test_losses[-1]
    }, "final_model.pth")

    print(f"\n=== TRAINING COMPLETED ===")
    print(f"Total epochs: {epoch}")
    print(f"Best validation loss: {best_val_loss:.6f}")
    print(f"Best test loss: {best_test_loss:.6f}")
    print(f"Final train loss: {train_losses[-1]:.6f}")

    # ===================================================================
    # POST-TRAINING ANALYSIS AND VISUALIZATION
    # ===================================================================

    print("\n" + "="*60)
    print("POST-TRAINING ANALYSIS AND VISUALIZATION")
    print("="*60)

    # Comprehensive Loss Plotting
    plt.style.use('default')
    fig = plt.figure(figsize=(20, 15))

    # Create a 3x3 grid of subplots
    gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)

    # 1. Total Loss Comparison
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.plot(train_losses, label="Train Loss", linewidth=2, color='blue', alpha=0.8)
    ax1.plot(val_losses, label="Val Loss", linewidth=2, color='orange', alpha=0.8)
    ax1.plot(test_losses, label="Test Loss", linewidth=2, color='red', alpha=0.8, linestyle='--')
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Total Loss")
    ax1.set_title("Training, Validation, and Test Loss")
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    ax1.set_yscale('log')

    # 2. Chamfer Distance Evolution
    ax2 = fig.add_subplot(gs[0, 1])
    ax2.plot(train_chamfer_losses, label="Train Chamfer", linewidth=2, color='blue', alpha=0.8)
    ax2.plot(val_chamfer_losses, label="Val Chamfer", linewidth=2, color='orange', alpha=0.8)
    ax2.plot(test_chamfer_losses, label="Test Chamfer", linewidth=2, color='red', alpha=0.8, linestyle='--')
    ax2.set_xlabel("Epoch")
    ax2.set_ylabel("Chamfer Distance")
    ax2.set_title("Geometric Accuracy Evolution")
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    ax2.set_yscale('log')

    # 3. Learning Rate Schedule
    ax3 = fig.add_subplot(gs[0, 2])
    ax3.plot(learning_rates, linewidth=2, color='green')
    ax3.set_xlabel("Epoch")
    ax3.set_ylabel("Learning Rate")
    ax3.set_title("Learning Rate Schedule")
    ax3.grid(True, alpha=0.3)
    ax3.set_yscale('log')

    # 4. Loss Smoothing (Moving Average)
    window_size = 5
    ax4 = fig.add_subplot(gs[1, 0])
    if len(train_losses) >= window_size:
        train_smooth = np.convolve(train_losses, np.ones(window_size)/window_size, mode='valid')
        val_smooth = np.convolve(val_losses, np.ones(window_size)/window_size, mode='valid')
        test_smooth = np.convolve(test_losses, np.ones(window_size)/window_size, mode='valid')

        epochs_smooth = range(window_size-1, len(train_losses))
        ax4.plot(epochs_smooth, train_smooth, label="Train (Smoothed)", linewidth=2, color='blue')
        ax4.plot(epochs_smooth, val_smooth, label="Val (Smoothed)", linewidth=2, color='orange')
        ax4.plot(epochs_smooth, test_smooth, label="Test (Smoothed)", linewidth=2, color='red', linestyle='--')
    ax4.set_xlabel("Epoch")
    ax4.set_ylabel("Smoothed Loss")
    ax4.set_title("Loss Trends (Moving Average)")
    ax4.legend()
    ax4.grid(True, alpha=0.3)

    # 5. Validation vs Training Loss Ratio
    ax5 = fig.add_subplot(gs[1, 1])
    if len(train_losses) > 0 and len(val_losses) > 0:
        loss_ratio = np.array(val_losses) / np.array(train_losses)
        ax5.plot(loss_ratio, linewidth=2, color='purple')
        ax5.axhline(y=1.0, color='red', linestyle='--', alpha=0.7, label='Perfect Generalization')
        ax5.set_xlabel("Epoch")
        ax5.set_ylabel("Val Loss / Train Loss")
        ax5.set_title("Generalization Gap")
        ax5.legend()
        ax5.grid(True, alpha=0.3)

    # 6. Loss Distribution
    ax6 = fig.add_subplot(gs[1, 2])
    if len(train_losses) > 10:
        ax6.hist(train_losses[-20:], alpha=0.7, label='Train (Last 20)', bins=10, color='blue')
        ax6.hist(val_losses[-20:], alpha=0.7, label='Val (Last 20)', bins=10, color='orange')
        ax6.hist(test_losses[-20:], alpha=0.7, label='Test (Last 20)', bins=10, color='red')
    ax6.set_xlabel("Loss Value")
    ax6.set_ylabel("Frequency")
    ax6.set_title("Recent Loss Distribution")
    ax6.legend()
    ax6.grid(True, alpha=0.3)

    # 7. Performance Improvement Rate
    ax7 = fig.add_subplot(gs[2, 0])
    if len(val_losses) > 1:
        improvement_rate = np.diff(val_losses)
        ax7.plot(improvement_rate, linewidth=2, color='green')
        ax7.axhline(y=0, color='red', linestyle='--', alpha=0.7)
        ax7.set_xlabel("Epoch")
        ax7.set_ylabel("Loss Change")
        ax7.set_title("Validation Loss Improvement Rate")
        ax7.grid(True, alpha=0.3)

    # 8. Cumulative Best Performance
    ax8 = fig.add_subplot(gs[2, 1])
    val_best_cumulative = np.minimum.accumulate(val_losses)
    test_best_cumulative = np.minimum.accumulate(test_losses)
    ax8.plot(val_best_cumulative, label="Best Val Loss", linewidth=2, color='orange')
    ax8.plot(test_best_cumulative, label="Best Test Loss", linewidth=2, color='red')
    ax8.set_xlabel("Epoch")
    ax8.set_ylabel("Best Loss So Far")
    ax8.set_title("Cumulative Best Performance")
    ax8.legend()
    ax8.grid(True, alpha=0.3)
    ax8.set_yscale('log')

    # 9. Training Summary Statistics
    ax9 = fig.add_subplot(gs[2, 2])
    ax9.axis('off')

    # Calculate summary statistics
    final_train_loss = train_losses[-1] if train_losses else 0
    final_val_loss = val_losses[-1] if val_losses else 0
    final_test_loss = test_losses[-1] if test_losses else 0
    best_val_epoch = np.argmin(val_losses) + 1 if val_losses else 0
    best_test_epoch = np.argmin(test_losses) + 1 if test_losses else 0

    summary_text = f"""
Training Summary

Total Epochs: {epoch}
Best Val Loss: {best_val_loss:.6f} (Epoch {best_val_epoch})
Best Test Loss: {best_test_loss:.6f} (Epoch {best_test_epoch})

Final Performance:
Train Loss: {final_train_loss:.6f}
Val Loss: {final_val_loss:.6f}
Test Loss: {final_test_loss:.6f}

Improvement:
Val: {((val_losses[0] - best_val_loss) / val_losses[0] * 100):.1f}%
Test: {((test_losses[0] - best_test_loss) / test_losses[0] * 100):.1f}%
    """

    ax9.text(0.1, 0.9, summary_text, transform=ax9.transAxes, fontsize=12,
             verticalalignment='top', fontfamily='monospace',
             bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8))

    plt.suptitle("3D Dental Prosthetic Generator - Training Analysis", fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig("training_analysis_comprehensive.png", dpi=300, bbox_inches='tight')
    plt.show()

    # ===================================================================
    # SAVE COMPREHENSIVE TRAINING HISTORY
    # ===================================================================

    # Imported before first use. The old `'pd' in globals()` guard ran before
    # this import and was always True here anyway (pandas is imported in the
    # setup cell).
    import pandas as pd

    print("\nSaving comprehensive training history...")

    # Save training history as JSON and CSV
    training_history = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'test_losses': test_losses,
        'train_chamfer_losses': train_chamfer_losses,
        'val_chamfer_losses': val_chamfer_losses,
        'test_chamfer_losses': test_chamfer_losses,
        'learning_rates': learning_rates,
        'epochs': list(range(1, len(train_losses) + 1)),
        'best_val_loss': best_val_loss,
        'best_test_loss': best_test_loss,
        'final_epoch': epoch,
        'timestamp': pd.Timestamp.now().isoformat()
    }

    # Save as JSON. default=float converts scalars json cannot encode natively
    # (e.g. numpy float32 or 0-d tensors, if the metric helpers return them).
    with open('training_history.json', 'w') as f:
        json.dump(training_history, f, indent=2, default=float)

    # Save as CSV for easy analysis
    df = pd.DataFrame({
        'epoch': training_history['epochs'],
        'train_loss': train_losses,
        'val_loss': val_losses,
        'test_loss': test_losses,
        'train_chamfer': train_chamfer_losses,
        'val_chamfer': val_chamfer_losses,
        'test_chamfer': test_chamfer_losses,
        'learning_rate': learning_rates
    })

    df.to_csv('training_history.csv', index=False)
    print("✓ Training history saved to training_history.json and training_history.csv")

    # ===================================================================
    # FINAL MODEL EVALUATION
    # ===================================================================

    print("\n" + "="*60)
    print("FINAL MODEL EVALUATION")
    print("="*60)

    # Reference loss the printout below compares against.
    # TODO: document where 23.32 comes from (earlier run? baseline model?).
    ORIGINAL_TARGET_LOSS = 23.32

    # Load the best-validation checkpoint. map_location places the tensors on
    # the current DEVICE even if the file was written from another device.
    best_model = ImprovedDentureGenModel().to(DEVICE)
    checkpoint = torch.load("best_enhanced_model.pth", map_location=DEVICE)
    best_model.load_state_dict(checkpoint['model_state_dict'])
    # Inference mode for layers such as BatchNorm before evaluating.
    best_model.eval()

    # Comprehensive test evaluation
    # NOTE(review): comprehensive_test_evaluation is not defined in the cells above.
    final_test_results = comprehensive_test_evaluation(best_model, test_loader, DEVICE)

    print(f"\nFinal Test Results:")
    print(f"  Test Loss: {final_test_results['test_loss']:.6f}")
    print(f"  Chamfer Distance: {final_test_results['chamfer_distance']:.6f}")
    print(f"  Samples Evaluated: {final_test_results['num_samples']}")
    print(f"  Performance vs Original Target ({ORIGINAL_TARGET_LOSS}): {((ORIGINAL_TARGET_LOSS - final_test_results['test_loss']) / ORIGINAL_TARGET_LOSS) * 100:.1f}% improvement")

    # ===================================================================
    # DEPLOY INFERENCE ENGINE
    # ===================================================================

    print("\n" + "="*60)
    print("DEPLOYING INFERENCE ENGINE")
    print("="*60)

    # TODO: restore DentalProstheticInferenceEngine. Its definition was
    # truncated mid-signature (`def __init__(self, model_path,`), producing
    # "SyntaxError: incomplete input" and preventing this ENTIRE cell (training
    # included) from running. Define the class in its own cell, then
    # instantiate it here with the best checkpoint, e.g.
    # DentalProstheticInferenceEngine("best_enhanced_model.pth", ...).


SyntaxError: incomplete input (ipython-input-10-4167146098.py, line 378)

In [None]:
# Mixed-precision training (snippet)
# NOTE(review): relies on `epochs`, `model`, `input_pc`, `border_pts`,
# `compute_loss`, `target` and `optimizer`, none of which are defined in the
# cells above, so this cell cannot run on a fresh kernel. torch.cuda.amp is
# also deprecated in recent PyTorch in favour of torch.amp.

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for epoch in range(epochs):
    # Clear the previous step's gradients. Without this, .backward() keeps
    # accumulating into .grad and every update uses the running sum.
    optimizer.zero_grad(set_to_none=True)

    # Forward pass and loss under autocast (reduced-precision where safe).
    with autocast():
        output = model(input_pc, border_pts)
        loss = compute_loss(output, target)

    # Backprop the scaled loss to avoid fp16 gradient underflow. step()
    # unscales the gradients before the optimizer update, and update() adjusts
    # the scale for the next iteration.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

**Updated Training Loop**

In [None]:
# TODO: the "Updated Training Loop" cell has no implementation yet.
# (It previously held only the stray token `in`, which raised a SyntaxError
# and stopped Restart & Run All at this cell.)

**Loss Visualization**

In [None]:
    # Comprehensive Loss Plotting
    plt.style.use('seaborn-v0_8')
    fig = plt.figure(figsize=(20, 15))

    # Create a 3x3 grid of subplots
    gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)

    # 1. Total Loss Comparison
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.plot(train_losses, label="Train Loss", linewidth=2, color='blue', alpha=0.8)
    ax1.plot(val_losses, label="Val Loss", linewidth=2, color='orange', alpha=0.8)
    ax1.plot(test_losses, label="Test Loss", linewidth=2, color='red', alpha=0.8, linestyle='--')
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Total Loss")
    ax1.set_title("Training, Validation, and Test Loss")
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    ax1.set_yscale('log')

    # 2. Chamfer Distance Evolution
    ax2 = fig.add_subplot(gs[0, 1])
    ax2.plot(train_chamfer_losses, label="Train Chamfer", linewidth=2, color='blue', alpha=0.8)
    ax2.plot(val_chamfer_losses, label="Val Chamfer", linewidth=2, color='orange', alpha=0.8)
    ax2.plot(test_chamfer_losses, label="Test Chamfer", linewidth=2, color='red', alpha=0.8, linestyle='--')
    ax2.set_xlabel("Epoch")
    ax2.set_ylabel("Chamfer Distance")
    ax2.set_title("Geometric Accuracy Evolution")
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    ax2.set_yscale('log')

    # 3. Learning Rate Schedule
    ax3 = fig.add_subplot(gs[0, 2])
    ax3.plot(learning_rates, linewidth=2, color='green')
    ax3.set_xlabel("Epoch")
    ax3.set_ylabel("Learning Rate")
    ax3.set_title("Learning Rate Schedule")
    ax3.grid(True, alpha=0.3)
    ax3.set_yscale('log')

    # 4. Loss Smoothing (Moving Average)
    window_size = 5
    ax4 = fig.add_subplot(gs[1, 0])
    if len(train_losses) >= window_size:
        train_smooth = np.convolve(train_losses, np.ones(window_size)/window_size, mode='valid')
        val_smooth = np.convolve(val_losses, np.ones(window_size)/window_size, mode='valid')
        test_smooth = np.convolve(test_losses, np.ones(window_size)/window_size, mode='valid')

        epochs_smooth = range(window_size-1, len(train_losses))
        ax4.plot(epochs_smooth, train_smooth, label="Train (Smoothed)", linewidth=2, color='blue')
        ax4.plot(epochs_smooth, val_smooth, label="Val (Smoothed)", linewidth=2, color='orange')
        ax4.plot(epochs_smooth, test_smooth, label="Test (Smoothed)", linewidth=2, color='red', linestyle='--')
    ax4.set_xlabel("Epoch")
    ax4.set_ylabel("Smoothed Loss")
    ax4.set_title("Loss Trends (Moving Average)")
    ax4.legend()
    ax4.grid(True, alpha=0.3)

    # 5. Validation vs Training Loss Ratio
    ax5 = fig.add_subplot(gs[1, 1])
    if len(train_losses) > 0 and len(val_losses) > 0:
        loss_ratio = np.array(val_losses) / np.array(train_losses)
        ax5.plot(loss_ratio, linewidth=2, color='purple')
        ax5.axhline(y=1.0, color='red', linestyle='--', alpha=0.7, label='Perfect Generalization')
        ax5.set_xlabel("Epoch")
        ax5.set_ylabel("Val Loss / Train Loss")
        ax5.set_title("Generalization Gap")
        ax5.legend()
        ax5.grid(True, alpha=0.3)

    # 6. Loss Distribution
    ax6 = fig.add_subplot(gs[1, 2])
    if len(train_losses) > 10:
        ax6.hist(train_losses[-20:], alpha=0.7, label='Train (Last 20)', bins=10, color='blue')
        ax6.hist(val_losses[-20:], alpha=0.7, label='Val (Last 20)', bins=10, color='orange')
        ax6.hist(test_losses[-20:], alpha=0.7, label='Test (Last 20)', bins=10, color='red')
    ax6.set_xlabel("Loss Value")
    ax6.set_ylabel("Frequency")
    ax6.set_title("Recent Loss Distribution")
    ax6.legend()
    ax6.grid(True, alpha=0.3)

    # 7. Performance Improvement Rate
    ax7 = fig.add_subplot(gs[2, 0])
    if len(val_losses) > 1:
        improvement_rate = np.diff(val_losses)
        ax7.plot(improvement_rate, linewidth=2, color='green')
        ax7.axhline(y=0, color='red', linestyle='--', alpha=0.7)
        ax7.set_xlabel("Epoch")
        ax7.set_ylabel("Loss Change")
        ax7.set_title("Validation Loss Improvement Rate")
        ax7.grid(True, alpha=0.3)

    # 8. Cumulative Best Performance
    ax8 = fig.add_subplot(gs[2, 1])
    val_best_cumulative = np.minimum.accumulate(val_losses)
    test_best_cumulative = np.minimum.accumulate(test_losses)
    ax8.plot(val_best_cumulative, label="Best Val Loss", linewidth=2, color='orange')
    ax8.plot(test_best_cumulative, label="Best Test Loss", linewidth=2, color='red')
    ax8.set_xlabel("Epoch")
    ax8.set_ylabel("Best Loss So Far")
    ax8.set_title("Cumulative Best Performance")
    ax8.legend()
    ax8.grid(True, alpha=0.3)
    ax8.set_yscale('log')

    # 9. Training Summary Statistics
    ax9 = fig.add_subplot(gs[2, 2])
    ax9.axis('off')

    # Calculate summary statistics
    final_train_loss = train_losses[-1] if train_losses else 0
    final_val_loss = val_losses[-1] if val_losses else 0
    final_test_loss = test_losses[-1] if test_losses else 0
    best_val_epoch = np.argmin(val_losses) + 1 if val_losses else 0
    best_test_epoch = np.argmin(test_losses) + 1 if test_losses else 0

    summary_text = f"""
Training Summary

Total Epochs: {epoch}
Best Val Loss: {best_val_loss:.6f} (Epoch {best_val_epoch})
Best Test Loss: {best_test_loss:.6f} (Epoch {best_test_epoch})

Final Performance:
Train Loss: {final_train_loss:.6f}
Val Loss: {final_val_loss:.6f}
Test Loss: {final_test_loss:.6f}

Improvement:
Val: {((val_losses[0] - best_val_loss) / val_losses[0] * 100):.1f}%
Test: {((test_losses[0] - best_test_loss) / test_losses[0] * 100):.1f}%
    """

    ax9.text(0.1, 0.9, summary_text, transform=ax9.transAxes, fontsize=12,
             verticalalignment='top', fontfamily='monospace',
             bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8))

    plt.suptitle("3D Dental Prosthetic Generator - Training Analysis", fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig("training_analysis_comprehensive.png", dpi=300, bbox_inches='tight')
    plt.show()


**Checkpoint and Training-History Save/Load Functions**

In [None]:
def save_model_checkpoint(model, optimizer, scheduler, epoch, train_loss, val_loss, test_loss,
                         best_val_loss, filepath, additional_info=None):
    """
    Save a full training checkpoint plus a weights-only copy of the model.

    Two files are written:
      * ``filepath``: a dict holding the model/optimizer/scheduler state dicts,
        ``epoch``, the three loss values, ``best_val_loss`` and metadata
        (printed model architecture, timestamp, device of the first parameter,
        PyTorch version). Entries of ``additional_info`` are merged in last
        and may override these keys.
      * ``<stem>_model_only<ext>`` next to it (``.pth`` when ``filepath`` has no
        extension), holding only ``model.state_dict()``.

    Args:
        model: the module being trained.
        optimizer: its optimizer.
        scheduler: learning-rate scheduler, or ``None``.
        epoch, train_loss, val_loss, test_loss, best_val_loss: stored as given.
        filepath: destination path (``str`` or path-like).
        additional_info: optional dict of extra entries to store.
    """
    filepath = os.fspath(filepath)

    try:
        param_device = str(next(model.parameters()).device)
    except StopIteration:
        # Parameter-free model: there is no device to report.
        param_device = 'unknown'

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
        'train_loss': train_loss,
        'val_loss': val_loss,
        'test_loss': test_loss,
        'best_val_loss': best_val_loss,
        'model_architecture': str(model),
        'timestamp': pd.Timestamp.now().isoformat(),
        'device': param_device,
        # torch.__version__ is a str subclass; store it as a plain str.
        'pytorch_version': str(torch.__version__),
    }

    if additional_info:
        checkpoint.update(additional_info)

    torch.save(checkpoint, filepath)
    print(f"Checkpoint saved to {filepath}")

    # Also save the weights (state_dict only) separately for deployment.
    # Build the sibling name from the extension: str.replace('.pth', ...) returned
    # `filepath` unchanged for names like 'ckpt.pt' and overwrote the checkpoint.
    root, ext = os.path.splitext(filepath)
    model_only_path = f"{root}_model_only{ext or '.pth'}"
    torch.save(model.state_dict(), model_only_path)

def load_model_checkpoint(model, optimizer, scheduler, filepath, device='cuda'):
    """
    Restore model (and optionally optimizer/scheduler) state from a checkpoint.

    The checkpoint must be a dict with ``'model_state_dict'`` (as written by
    ``save_model_checkpoint``). ``'optimizer_state_dict'`` is required only when
    an optimizer is passed. ``'scheduler_state_dict'``, ``'epoch'``,
    ``'best_val_loss'`` and ``'timestamp'`` are used when present.

    Args:
        model: module whose weights are overwritten in place.
        optimizer: optimizer to restore, or ``None`` (e.g. for inference only).
        scheduler: scheduler to restore, or ``None``.
        filepath: path to the checkpoint file.
        device: ``map_location`` passed to ``torch.load``. A CUDA device falls
            back to ``'cpu'`` when CUDA is not available.

    Returns:
        The loaded checkpoint dict.
    """
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        device = 'cpu'

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(filepath, map_location=device)

    model.load_state_dict(checkpoint['model_state_dict'])

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if scheduler and checkpoint.get('scheduler_state_dict'):
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    best_val = checkpoint.get('best_val_loss')
    best_val_text = f"{best_val:.6f}" if best_val is not None else "Unknown"

    print(f"Checkpoint loaded from {filepath}")
    print(f"  Epoch: {checkpoint.get('epoch', 'Unknown')}")
    print(f"  Best Val Loss: {best_val_text}")
    print(f"  Saved on: {checkpoint.get('timestamp', 'Unknown')}")

    return checkpoint

def save_training_history(train_losses, val_losses, test_losses, additional_metrics=None,
                          json_path='training_history.json', csv_path='training_history.csv'):
    """
    Save per-epoch training history as JSON and CSV for analysis.

    The JSON file holds the three loss sequences, a 1-based ``'epochs'`` list,
    a timestamp, and every entry of ``additional_metrics`` (merged last, so it
    may override the built-in keys). The CSV has one row per epoch with columns
    ``epoch``, ``train_loss``, ``val_loss``, ``test_loss``. It also gets one
    column per additional metric that is a list, tuple or 1-D array with one
    value per epoch.

    Args:
        train_losses, val_losses, test_losses: per-epoch loss sequences of equal
            length (pandas raises ValueError otherwise).
        additional_metrics: optional dict of extra metrics/values to record.
        json_path: output path of the JSON file.
        csv_path: output path of the CSV file.
    """
    def _to_jsonable(obj):
        # numpy scalars/arrays and torch tensors are not JSON-serializable as-is;
        # .tolist() converts them to plain Python numbers/lists.
        if hasattr(obj, 'tolist'):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    n_epochs = len(train_losses)
    history = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'test_losses': test_losses,
        'epochs': list(range(1, n_epochs + 1)),
        'timestamp': pd.Timestamp.now().isoformat()
    }

    if additional_metrics:
        history.update(additional_metrics)

    # Save as JSON
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(history, f, indent=2, default=_to_jsonable)

    # Save as CSV for easy analysis
    df = pd.DataFrame({
        'epoch': history['epochs'],
        'train_loss': train_losses,
        'val_loss': val_losses,
        'test_loss': test_losses
    })

    if additional_metrics:
        for key, values in additional_metrics.items():
            is_series = isinstance(values, (list, tuple)) or (
                isinstance(values, np.ndarray) and values.ndim == 1)
            if is_series and len(values) == n_epochs:
                df[key] = values

    df.to_csv(csv_path, index=False)
    print(f"Training history saved to {json_path} and {csv_path}")
