# PointNet Training & Verification with α,β-CROWN

## LiDAR Point Cloud Classification + Formal Verification

This notebook:
1. **Loads** raw LiDAR frames (~14M points) from your repository
2. **Loads NSGA-III Pareto genomes** for vulnerability-based labeling
3. **Runs QUICK SANITY CHECK** before full training (5 epochs + 3 verification tests)
4. **Trains** PointNet with data augmentation on GPU (50 epochs)
5. **Verifies** robustness properties using α,β-CROWN API

### Key Innovation: NSGA-III Vulnerability Labeling
Labels are computed using **MAX vulnerability across all Pareto-optimal genomes**:
- 5 genomes from NSGA-III optimization (different attack strategies)
- A region is CRITICAL if vulnerable to ANY of these attacks
- Threshold derived from data (median vulnerability)
- Direct link between "what breaks SLAM" and "what PointNet should detect"

### Quick Sanity Check (NEW!)
Before wasting 30+ minutes on full training, we run a quick check:
- Train 5 epochs on 2000 samples (~3 min)
- Check accuracy >= 55%
- Test α,β-CROWN verification on 3 samples
- **FAIL FAST** if something is wrong!

### Architecture: Original PointNet (Qi et al., CVPR 2017)
- Input T-Net (3x3) for spatial alignment
- Feature T-Net (64x64) for feature alignment  
- Point-wise MLP: 3→64→64→64→128→1024 (5 conv layers with BatchNorm)
- Global max pooling
- Classifier MLP: 1024→512→256→2 (with BatchNorm + Dropout)
- **~3.5M parameters**
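
As a rough cross-check of the ~3.5M figure, the layer sizes listed above can be tallied by hand. The sketch below assumes that both T-Nets use the k→64→128→1024 shared MLP and the 1024→512→256→k² head shown in Section 3, and that every layer except the last one of each head is followed by an affine BatchNorm. The helper names are illustrative and not part of the notebook.

```python
def linear(n_in, n_out):
    """Weights plus biases of a Linear layer or a kernel-size-1 Conv1d layer."""
    return n_in * n_out + n_out


def batchnorm(n):
    """Learnable scale and shift of an affine BatchNorm1d layer."""
    return 2 * n


def tnet(k):
    """T-Net predicting a k x k transform (assumed layout, see lead-in)."""
    convs = linear(k, 64) + linear(64, 128) + linear(128, 1024)
    fcs = linear(1024, 512) + linear(512, 256) + linear(256, k * k)
    norms = sum(batchnorm(n) for n in (64, 128, 1024, 512, 256))
    return convs + fcs + norms


def backbone(in_channels=3, num_classes=2):
    """Point-wise MLP plus classifier head, without the T-Nets."""
    widths = [in_channels, 64, 64, 64, 128, 1024]
    convs = sum(linear(a, b) + batchnorm(b) for a, b in zip(widths, widths[1:]))
    head = linear(1024, 512) + batchnorm(512) + linear(512, 256) + batchnorm(256)
    return convs + head + linear(256, num_classes)


total = tnet(3) + tnet(64) + backbone()
print(f"T-Net 3x3: {tnet(3):,}")
print(f"T-Net 64x64: {tnet(64):,}")
print(f"Backbone + classifier: {backbone():,}")
print(f"Total: {total:,}")
```

Under these assumptions the total comes to about 3.47M parameters, with more than half of them in the 64x64 feature T-Net.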

### Input Format: (N, 1024, 3) - xyz only!
- xyz coordinates (3 channels)
- **Note**: Geometric features (linearity, curvature, density_var, planarity) 
  are used for **labeling only**, NOT as input to the model.
  This ensures the model must learn geometry from raw coordinates!

### Label Computation (NSGA-III vulnerability):
```python
max_vulnerability = max([vuln(genome) for genome in pareto_set])
label = CRITICAL if max_vulnerability >= threshold else NON_CRITICAL
```
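
The pseudocode above can be turned into a small runnable sketch. The scoring function `toy_vulnerability` and the two genomes are made up for illustration; the real score comes from `compute_max_vulnerability` in `nsga3_integration`, whose internals are not shown here. The label encoding (CRITICAL = 0, NON_CRITICAL = 1) follows the fallback labeling used later in this notebook.

```python
CRITICAL, NON_CRITICAL = 0, 1  # encoding used by the fallback labeling in this notebook


def toy_vulnerability(region_features, genome):
    """Hypothetical stand-in: weighted sum of region features, clipped to [0, 1]."""
    score = sum(w * f for w, f in zip(genome, region_features))
    return min(max(score, 0.0), 1.0)


def label_region(region_features, pareto_set, threshold):
    """A region is CRITICAL if ANY genome pushes its vulnerability over the threshold."""
    max_vulnerability = max(toy_vulnerability(region_features, g) for g in pareto_set)
    return CRITICAL if max_vulnerability >= threshold else NON_CRITICAL


pareto_set = [(0.9, 0.1), (0.1, 0.9)]  # two made-up genomes
print(label_region((0.8, 0.0), pareto_set, threshold=0.5))  # one genome exceeds the threshold
print(label_region((0.2, 0.2), pareto_set, threshold=0.5))  # no genome exceeds the threshold
```

Taking the maximum rather than the mean means a single effective attack strategy is enough to mark a region as CRITICAL.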

### Properties Verified:

**Property 1: Local Robustness (L∞)**
```
∀x' with ||x' - x₀||∞ ≤ ε : f(x') = f(x₀)
```

**Property 2: Safety Property**
```
∀x' with ||x' - x₀||∞ ≤ ε ∧ f(x₀)=CRITICAL : f(x') ≠ NON_CRITICAL
```
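
To make "verified" concrete: for a purely linear two-class model, the worst case of Property 1 can be computed in closed form, because the smallest margin over the L∞ ball is the nominal margin minus ε times the L1 norm of the weight difference. The sketch below uses a made-up linear model and is only an illustration. For PointNet, α,β-CROWN instead computes a sound lower bound on the same margin. With two classes, Property 2 is the same check restricted to samples whose clean prediction is CRITICAL.

```python
def worst_case_margin(w_true, w_other, b_true, b_other, x0, eps):
    """Exact minimum of logit_true - logit_other over the L-inf ball of radius eps around x0."""
    diff = [wt - wo for wt, wo in zip(w_true, w_other)]
    nominal = sum(d * x for d, x in zip(diff, x0)) + (b_true - b_other)
    return nominal - eps * sum(abs(d) for d in diff)


# Made-up two-class linear model on a 3-dimensional input
w_critical, w_non_critical = [1.0, -2.0, 0.5], [0.0, 1.0, 0.5]
x0 = [0.3, -0.1, 0.2]

for eps in (0.001, 0.01, 0.2):
    margin = worst_case_margin(w_critical, w_non_critical, 0.0, 0.0, x0, eps)
    print(f"eps={eps}: worst-case margin={margin:.4f}, verified={margin > 0}")
```

The property holds at a given ε exactly when the worst-case margin stays positive, which is the same pass/fail criterion applied to the bounds reported by the verifier.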

## 1. Setup and Installation

In [11]:
# Check GPU
!nvidia-smi

Fri Jan  2 17:52:23 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   40C    P8              9W /   70W |       2MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
# Install dependencies
!pip install torch numpy onnx onnxruntime pyyaml packaging appdirs sortedcontainers path.py -q

# Clone alpha-beta-CROWN repository (for complete_verifier API)
!git clone https://github.com/Verified-Intelligence/alpha-beta-CROWN.git 2>/dev/null || true

# Install auto_LiRPA directly from GitHub
!pip install git+https://github.com/Verified-Intelligence/auto_LiRPA.git --no-deps -q

print("Dependencies installed!")

In [None]:
# Setup paths and imports
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import json
from datetime import datetime
from torch.utils.data import TensorDataset, DataLoader

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
print(f"Device: {device}")

In [None]:
# Import auto_LiRPA for verification
from auto_LiRPA import BoundedModule, BoundedTensor, PerturbationLpNorm
print("auto_LiRPA imported successfully!")

In [None]:
# Set random seeds for reproducibility
import random
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

In [None]:
# Configuration
N_POINTS = 1024        # Points per sample (original PointNet)
IN_CHANNELS = 3        # xyz only! Model must learn geometry from raw coordinates
NUM_CLASSES = 2        # CRITICAL vs NON_CRITICAL
INPUT_DIM = N_POINTS * IN_CHANNELS  # 3072

# Training config
EPOCHS = 50
BATCH_SIZE = 32
LEARNING_RATE = 0.001

# Verification config
EPSILONS = [0.001, 0.003, 0.005, 0.007, 0.01]  # L-inf perturbation budgets
N_VERIFY_SAMPLES = 20  # Samples to verify

print(f"Configuration:")
print(f"  Points per sample: {N_POINTS}")
print(f"  Input channels: {IN_CHANNELS} (xyz only - features used for labeling only)")
print(f"  Input dimension: {INPUT_DIM}")
print(f"  Classes: {NUM_CLASSES}")

## 2. Load Data from GitHub

Data files are stored in the repository using Git LFS. The cells below read the raw frames from `data/raw/`.
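
A clone without working Git LFS leaves small text pointer files in place of the real data. Besides checking the file size, a pointer can be recognized by its content, since Git LFS pointer files begin with a fixed `version` line. The helper below is a minimal sketch and is not part of the repository; the name `is_lfs_pointer` is an assumption.

```python
import os
import tempfile

LFS_POINTER_PREFIX = b"version https://git-lfs.github.com/spec/v1"


def is_lfs_pointer(path):
    """Return True if the file at `path` is a Git LFS pointer rather than real data."""
    with open(path, "rb") as fh:
        return fh.read(len(LFS_POINTER_PREFIX)) == LFS_POINTER_PREFIX


# Demonstration on a temporary file that mimics a pointer
with tempfile.TemporaryDirectory() as tmp:
    pointer_path = os.path.join(tmp, "frame_sequence.npy")
    with open(pointer_path, "wb") as fh:
        fh.write(LFS_POINTER_PREFIX + b"\noid sha256:" + b"0" * 64 + b"\nsize 12345\n")
    print(is_lfs_pointer(pointer_path))
```

In the notebook this could be called on `frame_sequence.npy` after cloning, as a complement to the size check in the next cell.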

In [None]:
# Clone repository and fetch Git LFS data
import os
import shutil

REPO_URL = "https://github.com/francescacraievich/mola-pointnet-verification.git"
REPO_DIR = "/content/mola-pointnet-verification"
RAW_DATA_PATH = f"{REPO_DIR}/data/raw"

# Force re-clone to get latest data (remove old clone if exists)
if os.path.exists(REPO_DIR):
    print(f"Removing old clone at {REPO_DIR}...")
    shutil.rmtree(REPO_DIR)

# Install and setup Git LFS
print("Setting up Git LFS...")
!git lfs install

# Clone with LFS
print("\nCloning repository...")
!git clone {REPO_URL} {REPO_DIR}

# Pull LFS files explicitly
print("\nFetching LFS files...")
%cd {REPO_DIR}
!git lfs pull
%cd /content

print("\nDone!")

# Verify raw data files exist
print("\nChecking data files:")
raw_files = ['frame_sequence.npy', 'frame_sequence.timestamps.npy']
for f in raw_files:
    path = os.path.join(RAW_DATA_PATH, f)
    if os.path.exists(path):
        size = os.path.getsize(path) / 1e6
        print(f"  ✓ {f}: {size:.1f} MB")
    else:
        print(f"  ✗ {f}: NOT FOUND")
        
# Also check if files are LFS pointers (small size = pointer, not actual data)
frame_path = os.path.join(RAW_DATA_PATH, 'frame_sequence.npy')
if os.path.exists(frame_path):
    size = os.path.getsize(frame_path)
    if size < 1000:  # Less than 1KB = probably a pointer
        print(f"\n⚠ WARNING: frame_sequence.npy is only {size} bytes - likely a Git LFS pointer!")
        print("   Running 'git lfs fetch --all' and 'git lfs checkout'...")
        %cd {REPO_DIR}
        !git lfs fetch --all
        !git lfs checkout
        %cd /content
        # Check again
        size = os.path.getsize(frame_path)
        print(f"   After LFS fetch: {size / 1e6:.1f} MB")

In [None]:
# Load raw frames and compute features ONCE (cached)
from scipy.spatial import cKDTree

def compute_local_features(points, k=15):
    """
    Compute geometric features for each point in the cloud.
    
    Returns:
        linearity: Edge/line feature strength
        curvature: Surface curvature  
        density_var: Local density variation (scanline vulnerability)
        planarity: How planar the local neighborhood is
    """
    n = len(points)
    xyz = points[:, :3]
    
    # Subsample for speed if too large
    max_points = 50000
    if n > max_points:
        sample_idx = np.random.choice(n, max_points, replace=False)
        xyz_sample = xyz[sample_idx]
    else:
        sample_idx = np.arange(n)
        xyz_sample = xyz
    
    tree = cKDTree(xyz_sample)
    distances, neighbors_idx = tree.query(xyz_sample, k=min(k + 1, len(xyz_sample)))
    
    linearity = np.zeros(len(xyz_sample))
    curvature = np.zeros(len(xyz_sample))
    planarity = np.zeros(len(xyz_sample))
    
    # Compute density variation
    mean_dist = distances[:, 1:].mean(axis=1)
    std_dist = distances[:, 1:].std(axis=1)
    density_var = std_dist / (mean_dist + 1e-10)
    
    # Compute eigenvalue-based features
    for i in range(len(xyz_sample)):
        neighbors = xyz_sample[neighbors_idx[i]]
        centered = neighbors - neighbors.mean(axis=0)
        
        if len(centered) >= 3:
            cov = np.cov(centered.T)
            try:
                eigvals = np.sort(np.linalg.eigvalsh(cov))[::-1]
                total = eigvals.sum() + 1e-10
                linearity[i] = (eigvals[0] - eigvals[1]) / (eigvals[0] + 1e-10)
                curvature[i] = eigvals[2] / total
                planarity[i] = (eigvals[1] - eigvals[2]) / (eigvals[0] + 1e-10)
            except np.linalg.LinAlgError:
                # Eigendecomposition did not converge: leave this point's features at 0
                pass
    
    # Map back to full point cloud if subsampled
    if n > max_points:
        _, nearest = tree.query(xyz, k=1)
        return linearity[nearest], curvature[nearest], density_var[nearest], planarity[nearest]
    
    return linearity, curvature, density_var, planarity


# Load frames
print("Loading raw frame sequence...")
frames = np.load(os.path.join(RAW_DATA_PATH, 'frame_sequence.npy'), allow_pickle=True)
print(f"Loaded {len(frames)} frames")

# Count total points
total_points = sum(len(f) for f in frames)
print(f"Total points: {total_points:,}")
print(f"Average points per frame: {total_points // len(frames):,}")

# Pre-compute features for each frame (do this ONCE)
print("\nPre-computing geometric features for each frame...")
print("(This takes a few minutes but only happens once)")

frame_features = []
for i, frame in enumerate(frames):
    if (i + 1) % 10 == 0:
        print(f"  Processing frame {i+1}/{len(frames)}...")
    
    linearity, curvature, density_var, planarity = compute_local_features(frame, k=15)
    
    # Store features alongside frame
    frame_features.append({
        'xyz': frame[:, :3],
        'linearity': linearity,
        'curvature': curvature,
        'density_var': density_var,
        'planarity': planarity,
    })

print(f"\nFeatures computed for all {len(frames)} frames!")
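
To build intuition for the eigenvalue-based features, the sketch below applies the same formulas as `compute_local_features` to two idealized neighborhoods: points on a straight line and points on a flat grid. The helper `eigen_features` and the synthetic shapes are illustrative and not part of the notebook.

```python
import numpy as np


def eigen_features(neighbors):
    """Linearity, planarity and curvature of one neighborhood of shape (n, 3)."""
    centered = neighbors - neighbors.mean(axis=0)
    eigvals = np.sort(np.linalg.eigvalsh(np.cov(centered.T)))[::-1]  # descending order
    linearity = (eigvals[0] - eigvals[1]) / (eigvals[0] + 1e-10)
    planarity = (eigvals[1] - eigvals[2]) / (eigvals[0] + 1e-10)
    curvature = eigvals[2] / (eigvals.sum() + 1e-10)
    return linearity, planarity, curvature


t = np.linspace(0.0, 1.0, 16)
line = np.stack([t, np.zeros_like(t), np.zeros_like(t)], axis=1)  # points along the x axis
gx, gy = np.meshgrid(t[:4], t[:4])
plane = np.stack([gx.ravel(), gy.ravel(), np.zeros(16)], axis=1)  # 4x4 grid in the z=0 plane

print("line  (linearity, planarity, curvature):", np.round(eigen_features(line), 3))
print("plane (linearity, planarity, curvature):", np.round(eigen_features(plane), 3))
```

The line scores linearity close to 1 and the flat grid scores planarity close to 1, with curvature close to 0 in both cases. Curvature only becomes large when the smallest eigenvalue is comparable to the other two, that is, when the neighborhood spreads in all three directions.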

In [None]:
# Load NSGA-III derived weights for labeling
# This creates a direct link between adversarial attack results and PointNet training

# Detect environment: Colab or local
import sys
import os

# Check if running on Colab
ON_COLAB = 'google.colab' in sys.modules

if ON_COLAB:
    # On Colab, use the cloned repo path
    SRC_PATH = REPO_DIR + '/src'
    RUNS_PATH = REPO_DIR + '/runs'
else:
    # Running locally: look for src/ relative to the current working directory
    # (notebooks/ is at the same level as src/)
    if os.path.exists('../src'):
        SRC_PATH = '../src'
        RUNS_PATH = '../runs'
    elif os.path.exists('src'):
        SRC_PATH = 'src'
        RUNS_PATH = 'runs'
    else:
        # Last resort: hard-coded absolute path
        SRC_PATH = '/home/francesca/mola-pointnet-verification/src'
        RUNS_PATH = '/home/francesca/mola-pointnet-verification/runs'

print(f"Environment: {'Colab' if ON_COLAB else 'Local'}")
print(f"Source path: {SRC_PATH}")
print(f"Runs path: {RUNS_PATH}")

# Add src to Python path
if SRC_PATH not in sys.path:
    sys.path.insert(0, SRC_PATH)

try:
    from nsga3_integration import get_criticality_weights, get_pareto_front_summary
    print("✓ nsga3_integration module loaded successfully!")
    
    # Get weights - will use fallback if NSGA-III results not found
    CRITICALITY_WEIGHTS = get_criticality_weights(
        nsga3_results_dir=RUNS_PATH if os.path.exists(RUNS_PATH) else None,
        run_id=12,  # Use latest run
        fallback_weights={
            "linearity": 0.0,
            "curvature": 0.15,
            "density_var": 0.25,
            "nonplanarity": 0.60,
        }
    )
    
    # Try to get Pareto front summary for analysis
    if os.path.exists(RUNS_PATH):
        try:
            pareto_summary = get_pareto_front_summary(RUNS_PATH, run_id=12)
            if pareto_summary:
                def fmt_cm(value):
                    """Format numeric values as cm; pass placeholders such as 'N/A' through."""
                    try:
                        return f"{float(value):.1f} cm"
                    except (TypeError, ValueError):
                        return str(value)

                print("\nNSGA-III Pareto Front Summary:")
                print(f"  Solutions: {pareto_summary.get('n_solutions', 'N/A')}")
                print(f"  Best ATE: {fmt_cm(pareto_summary.get('best_ate_cm', 'N/A'))}")
                print(f"  Baseline ATE: {fmt_cm(pareto_summary.get('baseline_ate_cm', 23))}")
                print(f"  Critical threshold: {fmt_cm(pareto_summary.get('critical_threshold_cm', 1.5))} perturbation")
        except Exception as e:
            print(f"Could not load Pareto summary: {e}")
            
except ImportError as e:
    print(f"nsga3_integration module not found: {e}")
    print("Using default weights...")
    CRITICALITY_WEIGHTS = {
        "linearity": 0.0,
        "curvature": 0.15,
        "density_var": 0.25,
        "nonplanarity": 0.60,
    }

print(f"\nCriticality weights for labeling:")
for feat, weight in CRITICALITY_WEIGHTS.items():
    print(f"  {feat}: {weight:.4f}")
print(f"\nThese weights determine which regions are labeled as CRITICAL vs NON_CRITICAL")

In [None]:
# Load Pareto set from NSGA-III results for vulnerability-based labeling
# This must come BEFORE creating datasets!

from nsga3_integration import (
    load_pareto_set,
    compute_max_vulnerability,
    compute_vulnerability_label,
)

# Load Pareto-optimal genomes (5 solutions from NSGA-III)
# RUNS_PATH was defined in previous cell
PARETO_SET = load_pareto_set(RUNS_PATH, run_id=12)

if PARETO_SET is not None:
    print(f"✓ Loaded Pareto set: {PARETO_SET.shape} ({PARETO_SET.shape[0]} genomes)")
    
    # Compute threshold from data (median vulnerability)
    print("\nComputing vulnerability distribution for threshold selection...")
    test_vulns = []
    for i in range(100):
        # Sample random neighborhood
        frame_idx = np.random.randint(0, len(frame_features))
        ff = frame_features[frame_idx]
        seed_idx = np.random.randint(0, len(ff['xyz']))
        
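        # Get neighborhood
        tree = cKDTree(ff['xyz'])
        _, neighbor_idx = tree.query(ff['xyz'][seed_idx], k=N_POINTS)
        points = ff['xyz'][neighbor_idx]
        points = points - points.mean(axis=0)  # Center
        # Scale like the datasets do before labeling, so the threshold
        # is computed on the same coordinate range as the labels
        max_dist = np.abs(points).max()
        if max_dist > 0:
            points = points / max_dist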
        
        curvature = ff['curvature'][neighbor_idx]
        linearity = ff['linearity'][neighbor_idx]
        
        vuln = compute_max_vulnerability(points, PARETO_SET, curvature, linearity)
        test_vulns.append(vuln)
    
    VULNERABILITY_THRESHOLD = float(np.median(test_vulns))
    print(f"\nVulnerability statistics:")
    print(f"  Min: {np.min(test_vulns):.4f}")
    print(f"  Max: {np.max(test_vulns):.4f}")
    print(f"  Mean: {np.mean(test_vulns):.4f}")
    print(f"  Median (THRESHOLD): {VULNERABILITY_THRESHOLD:.4f}")
else:
    print("⚠ Pareto set not found - using fallback threshold")
    PARETO_SET = None
    VULNERABILITY_THRESHOLD = 0.4

## 2.5 Create Datasets with NSGA-III Vulnerability Labeling

Now we create the training and test datasets using the Pareto genomes for labeling.

In [None]:
# On-the-fly Dataset class with Data Augmentation
# Uses NSGA-III vulnerability-based labeling (genome-driven, not linear formula!)
from torch.utils.data import Dataset

class LiDAROnTheFlyDataset(Dataset):
    """
    Dataset that samples point cloud groups ON-THE-FLY from raw LiDAR frames.
    
    Each __getitem__ call samples a random local neighborhood from a random frame,
    so the model sees different samples every epoch.
    
    **IMPORTANT**: Returns only xyz (3 channels) as input!
    Geometric features are computed for LABELING only, not as model input.
    
    **LABELING**: Uses NSGA-III Pareto genomes to compute vulnerability.
    A region is CRITICAL if it's vulnerable to ANY of the Pareto-optimal attacks.
    """
    
    def __init__(self, frame_features, pareto_set=None, threshold=0.5, 
                 fallback_weights=None, n_points=1024, samples_per_epoch=10000, 
                 seed=None, augment=False):
        self.frame_features = frame_features
        self.pareto_set = pareto_set
        self.threshold = threshold
        self.fallback_weights = fallback_weights or {
            "linearity": 0.0,
            "curvature": 0.15,
            "density_var": 0.25,
            "nonplanarity": 0.60,
        }
        self.n_points = n_points
        self.samples_per_epoch = samples_per_epoch
        self.seed = seed
        self.augment = augment
        
        # Build KD-trees for each frame (once)
        self.trees = [cKDTree(ff['xyz']) for ff in frame_features]
        
        # Frame weights based on number of points
        self.frame_weights = np.array([len(ff['xyz']) for ff in frame_features])
        self.frame_weights = self.frame_weights / self.frame_weights.sum()
    
    def __len__(self):
        return self.samples_per_epoch
    
    def _augment_xyz(self, xyz):
        """Apply augmentation to xyz coordinates."""
        theta = np.random.uniform(0, 2 * np.pi)
        cos_t, sin_t = np.cos(theta), np.sin(theta)
        rotation = np.array([[cos_t, -sin_t, 0], [sin_t, cos_t, 0], [0, 0, 1]])
        xyz = xyz @ rotation.T
        scale = np.random.uniform(0.9, 1.1)
        xyz = xyz * scale
        jitter = np.random.normal(0, 0.01, size=xyz.shape)
        xyz = xyz + jitter
        return xyz.astype(np.float32)
    
    def __getitem__(self, idx):
        if self.seed is not None:
            np.random.seed(self.seed + idx)
        
        frame_idx = np.random.choice(len(self.frame_features), p=self.frame_weights)
        ff = self.frame_features[frame_idx]
        tree = self.trees[frame_idx]
        xyz = ff['xyz']
        
        seed_idx = np.random.randint(0, len(xyz))
        _, neighbor_idx = tree.query(xyz[seed_idx], k=self.n_points)
        
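        # cKDTree reports missing neighbors (frame smaller than n_points) as index len(xyz)
        neighbor_idx = neighbor_idx[neighbor_idx < len(xyz)]
        if len(neighbor_idx) < self.n_points:
            neighbor_idx = np.pad(neighbor_idx, (0, self.n_points - len(neighbor_idx)), mode='edge')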
        
        group_xyz = xyz[neighbor_idx].copy()
        group_xyz = group_xyz - group_xyz.mean(axis=0)
        max_dist = np.abs(group_xyz).max()
        if max_dist > 0:
            group_xyz = group_xyz / max_dist
        
        if self.augment:
            group_xyz = self._augment_xyz(group_xyz)
        
        # Extract features FOR LABELING ONLY
        group_linearity = ff['linearity'][neighbor_idx]
        group_curvature = ff['curvature'][neighbor_idx]
        group_density_var = ff['density_var'][neighbor_idx]
        group_planarity = ff['planarity'][neighbor_idx]
        
        # Random point dropout during training
        if self.augment and np.random.random() < 0.3:
            dropout_ratio = np.random.uniform(0.05, 0.15)
            n_dropout = int(self.n_points * dropout_ratio)
            dropout_idx = np.random.choice(self.n_points, n_dropout, replace=False)
            keep_idx = np.setdiff1d(np.arange(self.n_points), dropout_idx)
            replace_idx = np.random.choice(keep_idx, n_dropout, replace=True)
            group_xyz[dropout_idx] = group_xyz[replace_idx]
        
        group = group_xyz.astype(np.float32)
        
        # Compute label using NSGA-III vulnerability (if available)
        if self.pareto_set is not None:
            label = compute_vulnerability_label(
                group_xyz, self.pareto_set, self.threshold,
                group_curvature, group_linearity
            )
        else:
            # Fallback: use linear formula
            def normalize(f):
                f_min, f_max = f.min(), f.max()
                return (f - f_min) / (f_max - f_min + 1e-6)
            
            score = (
                normalize(group_linearity).mean() * self.fallback_weights.get("linearity", 0.0) +
                normalize(group_curvature).mean() * self.fallback_weights.get("curvature", 0.0) +
                normalize(group_density_var).mean() * self.fallback_weights.get("density_var", 0.0) +
                (1 - normalize(group_planarity).mean()) * self.fallback_weights.get("nonplanarity", 0.0)
            )
            label = 0 if score >= 0.4 else 1
        
        return torch.from_numpy(group), label


# Create datasets
print("Creating on-the-fly datasets with NSGA-III vulnerability labeling...")

if PARETO_SET is not None:
    print(f"Using {PARETO_SET.shape[0]} Pareto-optimal genomes for labeling")
    print(f"Vulnerability threshold: {VULNERABILITY_THRESHOLD:.4f}")
else:
    print("Pareto set not available - using fallback weights")

train_dataset = LiDAROnTheFlyDataset(
    frame_features, pareto_set=PARETO_SET, threshold=VULNERABILITY_THRESHOLD,
    fallback_weights=CRITICALITY_WEIGHTS, n_points=N_POINTS,
    samples_per_epoch=20000, seed=None, augment=True
)

test_dataset = LiDAROnTheFlyDataset(
    frame_features, pareto_set=PARETO_SET, threshold=VULNERABILITY_THRESHOLD,
    fallback_weights=CRITICALITY_WEIGHTS, n_points=N_POINTS,
    samples_per_epoch=4000, seed=42, augment=False
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)

print(f"\nTrain: {len(train_dataset)} samples/epoch ({len(train_loader)} batches)")
print(f"Test: {len(test_dataset)} samples ({len(test_loader)} batches)")
print(f"Labeling: {'NSGA-III vulnerability (MAX across genomes)' if PARETO_SET is not None else 'Fallback weights'}")

## 2.6 QUICK SANITY CHECK

Before full training (~30 min), we run a quick check (~3 min):
1. Train on 2000 samples for 5 epochs
2. Check accuracy is reasonable (>= 55%)
3. Test α,β-CROWN verification on 3 samples

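If the check fails, we stop and debug before wasting time on full training.

**Note**: The cells in this section use `PointNetForVerification` (defined in Section 3) as well as `PointNetVerify`, `transfer_weights_to_verify`, and `verify_robustness_lirpa`. Run the cells that define them first, otherwise the quick check stops with a `NameError`.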

In [None]:
###########################################
# QUICK SANITY CHECK (before full training)
###########################################

print("=" * 60)
print("QUICK SANITY CHECK")
print("=" * 60)
print("This check validates the pipeline before full training.")
print("If it fails, we can fix issues without waiting 30+ minutes.\n")

# 1. Create small datasets for quick check
# Use vulnerability-based labeling if PARETO_SET is available
class QuickCheckDataset(Dataset):
    """Small dataset for quick validation with NSGA-III vulnerability labeling."""
    
    def __init__(self, frame_features, pareto_set, threshold, n_points=1024, 
                 samples_per_epoch=2000, seed=None, augment=False):
        self.frame_features = frame_features
        self.pareto_set = pareto_set
        self.threshold = threshold
        self.n_points = n_points
        self.samples_per_epoch = samples_per_epoch
        self.seed = seed
        self.augment = augment
        
        # Build KD-trees
        self.trees = [cKDTree(ff['xyz']) for ff in frame_features]
        self.frame_weights = np.array([len(ff['xyz']) for ff in frame_features])
        self.frame_weights = self.frame_weights / self.frame_weights.sum()
    
    def __len__(self):
        return self.samples_per_epoch
    
    def __getitem__(self, idx):
        if self.seed is not None:
            np.random.seed(self.seed + idx)
        
        # Sample frame and point
        frame_idx = np.random.choice(len(self.frame_features), p=self.frame_weights)
        ff = self.frame_features[frame_idx]
        tree = self.trees[frame_idx]
        xyz = ff['xyz']
        
        seed_idx = np.random.randint(0, len(xyz))
        _, neighbor_idx = tree.query(xyz[seed_idx], k=self.n_points)
        
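        # cKDTree reports missing neighbors (frame smaller than n_points) as index len(xyz)
        neighbor_idx = neighbor_idx[neighbor_idx < len(xyz)]
        if len(neighbor_idx) < self.n_points:
            neighbor_idx = np.pad(neighbor_idx, (0, self.n_points - len(neighbor_idx)), mode='edge')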
        
        # Extract and normalize xyz
        group_xyz = xyz[neighbor_idx].copy()
        group_xyz = group_xyz - group_xyz.mean(axis=0)
        max_dist = np.abs(group_xyz).max()
        if max_dist > 0:
            group_xyz = group_xyz / max_dist
        
        # Extract features for labeling
        curvature = ff['curvature'][neighbor_idx]
        linearity = ff['linearity'][neighbor_idx]
        
        # Compute label using NSGA-III vulnerability
        if self.pareto_set is not None:
            label = compute_vulnerability_label(
                group_xyz, self.pareto_set, self.threshold, curvature, linearity
            )
        else:
            # Fallback: use original linear formula
            density_var = ff['density_var'][neighbor_idx].mean()
            planarity = ff['planarity'][neighbor_idx].mean()
            score = (
                curvature.mean() * CRITICALITY_WEIGHTS.get("curvature", 0.15) +
                density_var * CRITICALITY_WEIGHTS.get("density_var", 0.25) +
                (1 - planarity) * CRITICALITY_WEIGHTS.get("nonplanarity", 0.60)
            )
            label = 0 if score >= 0.4 else 1
        
        return torch.from_numpy(group_xyz.astype(np.float32)), label


# Create quick check datasets
quick_train = QuickCheckDataset(
    frame_features, PARETO_SET, VULNERABILITY_THRESHOLD,
    n_points=N_POINTS, samples_per_epoch=2000, seed=None, augment=True
)
quick_test = QuickCheckDataset(
    frame_features, PARETO_SET, VULNERABILITY_THRESHOLD,
    n_points=N_POINTS, samples_per_epoch=500, seed=42, augment=False
)

quick_train_loader = DataLoader(quick_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
quick_test_loader = DataLoader(quick_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print(f"Quick train: {len(quick_train)} samples")
print(f"Quick test: {len(quick_test)} samples")

# Check class distribution
train_labels = [quick_train[i][1] for i in range(100)]
print(f"\nLabel distribution (first 100): {sum(train_labels)} NON_CRITICAL, {100 - sum(train_labels)} CRITICAL")

In [None]:
# 2. Quick training (5 epochs)
print("\n" + "-" * 60)
print("Step 1: Quick Training (5 epochs)")
print("-" * 60)

quick_model = PointNetForVerification(
    num_points=N_POINTS,
    num_classes=NUM_CLASSES,
    use_tnet=True,
    feature_transform=True,
    in_channels=IN_CHANNELS,
).to(device)

quick_criterion = nn.CrossEntropyLoss()
quick_optimizer = torch.optim.Adam(quick_model.parameters(), lr=LEARNING_RATE)

for epoch in range(5):
    quick_model.train()
    train_loss, train_correct, train_total = 0, 0, 0
    
    for batch_data, batch_labels in quick_train_loader:
        batch_data = batch_data.to(device)
        batch_labels = batch_labels.to(device)
        
        quick_optimizer.zero_grad()
        outputs = quick_model(batch_data)
        loss = quick_criterion(outputs, batch_labels)
        loss.backward()
        quick_optimizer.step()
        
        train_loss += loss.item() * batch_data.size(0)
        _, predicted = outputs.max(1)
        train_correct += predicted.eq(batch_labels).sum().item()
        train_total += batch_data.size(0)
    
    train_acc = 100.0 * train_correct / train_total
    print(f"  Epoch {epoch+1}/5: loss={train_loss/train_total:.4f}, acc={train_acc:.1f}%")

# Evaluate on quick test set
quick_model.eval()
test_correct, test_total = 0, 0
with torch.no_grad():
    for batch_data, batch_labels in quick_test_loader:
        batch_data = batch_data.to(device)
        batch_labels = batch_labels.to(device)
        outputs = quick_model(batch_data)
        _, predicted = outputs.max(1)
        test_correct += predicted.eq(batch_labels).sum().item()
        test_total += batch_data.size(0)

quick_acc = 100.0 * test_correct / test_total
print(f"\n  Quick test accuracy: {quick_acc:.1f}%")

In [None]:
# 3. Check accuracy threshold
print("\n" + "-" * 60)
print("Step 2: Accuracy Check")
print("-" * 60)

MIN_ACCURACY = 55.0  # Minimum acceptable accuracy

if quick_acc < MIN_ACCURACY:
    print(f"❌ FAIL: Accuracy {quick_acc:.1f}% is below threshold {MIN_ACCURACY}%")
    print("\n   Possible issues:")
    print("   - Labels may be incorrect or too noisy")
    print("   - NSGA-III vulnerability threshold may need adjustment")
    print("   - Model may need different hyperparameters")
    raise ValueError(f"Quick check failed - accuracy too low ({quick_acc:.1f}%)")
else:
    print(f"✓ PASS: Accuracy {quick_acc:.1f}% >= {MIN_ACCURACY}%")

In [None]:
# 4. Test α,β-CROWN verification on 3 samples
print("\n" + "-" * 60)
print("Step 3: α,β-CROWN Verification Test")
print("-" * 60)

# Create verification model (remove dropout)
quick_verify_model = PointNetVerify(
    num_points=N_POINTS,
    num_classes=NUM_CLASSES,
    in_channels=IN_CHANNELS,
    use_tnet=True,
    feature_transform=True,
)
quick_verify_model = transfer_weights_to_verify(quick_model.cpu(), quick_verify_model)
quick_verify_model.eval()

# Get 3 test samples
quick_test_samples = [quick_test[i][0].numpy() for i in range(3)]
quick_test_labels = [quick_test[i][1] for i in range(3)]

# Test verification
QUICK_EPSILON = 0.005  # Test epsilon
verification_ok = True
verification_errors = []

print(f"Testing with ε = {QUICK_EPSILON}")
print()

for i in range(3):
    sample = quick_test_samples[i]
    label = quick_test_labels[i]
    label_str = "CRITICAL" if label == 0 else "NON_CRITICAL"
    
    try:
        result = verify_robustness_lirpa(quick_verify_model, sample, label, QUICK_EPSILON)
        if result['verified']:
            status = f"✓ VERIFIED (margin={result['margin']:.4f})"
        else:
            status = f"✗ NOT VERIFIED (margin={result['margin']:.4f})"
        print(f"  Sample {i} ({label_str}): {status}")
        
    except Exception as e:
        error_msg = str(e)
        verification_errors.append(error_msg)
        print(f"  Sample {i} ({label_str}): ❌ ERROR")
        print(f"    {error_msg[:80]}...")
        verification_ok = False

if verification_errors:
    print(f"\n❌ FAIL: Verification encountered {len(verification_errors)} error(s)")
    print("\n   Possible issues:")
    print("   - Model architecture may be incompatible with auto_LiRPA")
    print("   - Batch normalization issues with batch_size=1")
    print("   - T-Net torch.bmm operation may need special handling")
    raise ValueError("Quick check failed - verification error")
else:
    print("\n✓ PASS: Verification works correctly")

In [None]:
# 5. Quick Check Summary
print("\n" + "=" * 60)
print("✓ QUICK CHECK PASSED!")
print("=" * 60)
print(f"  - Training: 5 epochs completed")
print(f"  - Accuracy: {quick_acc:.1f}% (threshold: {MIN_ACCURACY}%)")
print(f"  - Verification: α,β-CROWN works on 3 samples")
print(f"  - Labeling: NSGA-III vulnerability-based" if PARETO_SET is not None else "  - Labeling: Fallback weights")
print()
print("Proceeding with full training (50 epochs)...")
print("=" * 60)

# Clean up quick check objects to free memory
del quick_model, quick_verify_model, quick_train, quick_test
del quick_train_loader, quick_test_loader
if torch.cuda.is_available():
    torch.cuda.empty_cache()

## 3. PointNet Model (Original Architecture)

Based on Qi et al., "PointNet: Deep Learning on Point Sets" (CVPR 2017)

Architecture:
- **T-Net 3x3**: Spatial transformer for xyz coordinates
- **T-Net 64x64**: Feature transformer after first MLP
- **Point MLP**: 3→64→64→[feat_trans]→64→128→1024 with BatchNorm
- **Global MaxPool**: Symmetric aggregation
- **Classifier**: 1024→512→256→2 with BatchNorm + Dropout(0.3)

**Input**: (batch, 1024, 3) - xyz coordinates only!
The model must learn to classify critical vs non-critical regions
purely from the geometric structure of the point cloud.
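
The "symmetric aggregation" in the list above is what makes the network independent of point order: a channel-wise maximum over points gives the same result for any permutation of the input. The sketch below shows this on random per-point features in plain Python; `global_max_pool` is an illustrative helper, while the model itself uses `torch.max(x, dim=2)`.

```python
import random


def global_max_pool(point_features):
    """Channel-wise maximum over all points: (n_points, channels) -> (channels,)."""
    return [max(channel) for channel in zip(*point_features)]


random.seed(0)
features = [[random.random() for _ in range(4)] for _ in range(8)]  # 8 points, 4 channels
shuffled = features[:]
random.shuffle(shuffled)  # same points, different order

print(global_max_pool(features) == global_max_pool(shuffled))
```

Because the layers before the pooling are applied to each point independently, reordering the input points only reorders the per-point features, so the pooled global feature and the classifier output are unchanged.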

In [ ]:
class TNet(nn.Module):
    """
    T-Net: Spatial Transformer Network for PointNet.
    Predicts a k x k transformation matrix.
    
    Architecture (original):
    - Conv1d: k → 64 → 128 → 1024
    - MaxPool
    - FC: 1024 → 512 → 256 → k*k
    - All with BatchNorm
    """
    def __init__(self, k=3):
        super().__init__()
        self.k = k
        
        # Shared MLP (implemented as Conv1d)
        self.conv1 = nn.Conv1d(k, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        
        # FC layers
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k * k)
        
        self.bn_fc1 = nn.BatchNorm1d(512)
        self.bn_fc2 = nn.BatchNorm1d(256)
        
        # Initialize to identity
        self.fc3.weight.data.zero_()
        self.fc3.bias.data.copy_(torch.eye(k).view(-1))
    
    def forward(self, x):
        """
        Args:
            x: (batch, k, n_points)
        Returns:
            transform: (batch, k, k) transformation matrix
        """
        batch_size = x.shape[0]
        
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        
        # Max pool over points
        x = torch.max(x, dim=2)[0]  # (batch, 1024)
        
        x = F.relu(self.bn_fc1(self.fc1(x)))
        x = F.relu(self.bn_fc2(self.fc2(x)))
        x = self.fc3(x)
        
        # Reshape to k x k matrix
        x = x.view(batch_size, self.k, self.k)
        
        return x


class PointNetForVerification(nn.Module):
    """
    PointNet for Verification - IDENTICAL to original PointNet architecture.
    
    Based on: Qi et al., "PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation" (CVPR 2017)
    
    Features:
    - Input T-Net (3x3) only transforms xyz, not extra features
    - Feature T-Net (64x64) after conv2
    - 5 conv layers: in_channels→64→64→64→128→1024
    - BatchNorm on all layers
    - Dropout(0.3) in classifier
    
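    Input format: (batch, n_points, in_channels)
    - in_channels=3: xyz only (the setting used in this notebook)
    - in_channels=7 (default): xyz(3) + features(4); the input T-Net still sees xyz only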
    """
    def __init__(
        self,
        num_points=1024,
        num_classes=2,
        use_tnet=True,
        feature_transform=True,
        in_channels=7,
    ):
        super().__init__()
        
        self.num_points = num_points
        self.num_classes = num_classes
        self.use_tnet = use_tnet
        self.feature_transform = feature_transform
        self.in_channels = in_channels
        self.input_dim = num_points * in_channels
        
        # Input T-Net (3x3) - only for xyz
        if use_tnet:
            self.input_tnet = TNet(k=3)
        
        # Feature T-Net (64x64)
        if feature_transform:
            self.feat_tnet = TNet(k=64)
        
        # Point-wise MLP (5 conv layers)
        self.conv1 = nn.Conv1d(in_channels, 64, 1)
        self.conv2 = nn.Conv1d(64, 64, 1)
        self.conv3 = nn.Conv1d(64, 64, 1)
        self.conv4 = nn.Conv1d(64, 128, 1)
        self.conv5 = nn.Conv1d(128, 1024, 1)
        
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(64)
        self.bn3 = nn.BatchNorm1d(64)
        self.bn4 = nn.BatchNorm1d(128)
        self.bn5 = nn.BatchNorm1d(1024)
        
        # Classifier MLP
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, num_classes)
        
        self.bn_fc1 = nn.BatchNorm1d(512)
        self.bn_fc2 = nn.BatchNorm1d(256)
        self.dropout = nn.Dropout(p=0.3)
    
    def forward(self, x):
        batch_size = x.shape[0]
        
        # Handle flattened input
        if x.dim() == 2:
            x = x.view(batch_size, self.num_points, self.in_channels)
        
        # Separate xyz and extra features
        if self.in_channels > 3:
            xyz = x[:, :, :3]  # (batch, n_points, 3)
            extra_features = x[:, :, 3:]  # (batch, n_points, in_channels-3)
        else:
            xyz = x
            extra_features = None
        
        # Input T-Net (only on xyz)
        if self.use_tnet:
            xyz_t = xyz.transpose(1, 2)  # (batch, 3, n_points)
            input_trans = self.input_tnet(xyz_t)  # (batch, 3, 3)
            xyz = torch.bmm(xyz, input_trans)  # (batch, n_points, 3)
        
        # Recombine
        if extra_features is not None:
            x = torch.cat([xyz, extra_features], dim=2)
        else:
            x = xyz
        
        # Point-wise MLP
        x = x.transpose(1, 2)  # (batch, in_channels, n_points)
        x = F.relu(self.bn1(self.conv1(x)))  # (batch, 64, n_points)
        x = F.relu(self.bn2(self.conv2(x)))  # (batch, 64, n_points)
        
        # Feature T-Net
        if self.feature_transform:
            feat_trans = self.feat_tnet(x)  # (batch, 64, 64)
            x = x.transpose(1, 2)  # (batch, n_points, 64)
            x = torch.bmm(x, feat_trans)  # (batch, n_points, 64)
            x = x.transpose(1, 2)  # (batch, 64, n_points)
        
        x = F.relu(self.bn3(self.conv3(x)))  # (batch, 64, n_points)
        x = F.relu(self.bn4(self.conv4(x)))  # (batch, 128, n_points)
        x = F.relu(self.bn5(self.conv5(x)))  # (batch, 1024, n_points)
        
        # Global max pooling
        x = torch.max(x, dim=2)[0]  # (batch, 1024)
        
        # Classifier
        x = F.relu(self.bn_fc1(self.fc1(x)))
        x = self.dropout(x)
        x = F.relu(self.bn_fc2(self.fc2(x)))
        x = self.dropout(x)
        x = self.fc3(x)
        
        return x
    
    def get_transforms(self, x):
        """Return input and feature transforms for regularization."""
        batch_size = x.shape[0]
        
        if x.dim() == 2:
            x = x.view(batch_size, self.num_points, self.in_channels)
        
        if self.in_channels > 3:
            xyz = x[:, :, :3]
        else:
            xyz = x
        
        input_trans = None
        feat_trans = None
        
        if self.use_tnet:
            xyz_t = xyz.transpose(1, 2)
            input_trans = self.input_tnet(xyz_t)
            xyz = torch.bmm(xyz, input_trans)
        
        if self.in_channels > 3:
            x = torch.cat([xyz, x[:, :, 3:]], dim=2)
        else:
            x = xyz
        
        x = x.transpose(1, 2)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        
        if self.feature_transform:
            feat_trans = self.feat_tnet(x)
        
        return input_trans, feat_trans


# Create model
model = PointNetForVerification(
    num_points=N_POINTS,
    num_classes=NUM_CLASSES,
    use_tnet=True,
    feature_transform=True,
    in_channels=IN_CHANNELS,
).to(device)

n_params = sum(p.numel() for p in model.parameters())
print(f"PointNet parameters: {n_params:,}")
print(f"Expected: ~3.5M parameters")

## 4. Training
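
The training cell below adds an orthogonality penalty on the 64x64 feature transform, computed as the Frobenius norm of `I - A·Aᵀ` and weighted by 0.001. As a small illustration of what that penalty measures (plain Python on hypothetical 2x2 matrices, not the notebook's tensors), a rotation matrix scores zero while a uniformly scaled matrix does not:

```python
import math

def orthogonality_penalty(a):
    """Frobenius norm of (I - A @ A^T) for a square matrix given as nested lists."""
    n = len(a)
    total = 0.0
    for i in range(n):
        for j in range(n):
            aat_ij = sum(a[i][k] * a[j][k] for k in range(n))
            identity_ij = 1.0 if i == j else 0.0
            total += (identity_ij - aat_ij) ** 2
    return math.sqrt(total)

# Hypothetical example matrices, chosen only for illustration.
theta = math.pi / 6
rotation = [[math.cos(theta), -math.sin(theta)],
            [math.sin(theta), math.cos(theta)]]
scaled = [[2.0, 0.0],
          [0.0, 2.0]]

print(orthogonality_penalty(rotation))  # close to 0: a rotation is orthogonal
print(orthogonality_penalty(scaled))    # sqrt(18): A @ A^T = 4I, so I - 4I = -3I
```

Keeping the learned transform near-orthogonal discourages it from collapsing or rescaling the 64-dimensional point features.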

In [ ]:
def feature_transform_regularizer(trans):
    """Regularization loss for feature transform to be close to orthogonal."""
    d = trans.size()[1]
    I = torch.eye(d, device=trans.device).unsqueeze(0)
    loss = torch.mean(torch.norm(I - torch.bmm(trans, trans.transpose(2, 1)), dim=(1, 2)))
    return loss


def train_epoch(model, loader, criterion, optimizer):
    """Train for one epoch."""
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    for batch_data, batch_labels in loader:
        batch_data = batch_data.to(device)
        batch_labels = batch_labels.to(device)
        
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(batch_data)
        loss = criterion(outputs, batch_labels)
        
        # Add feature transform regularization
        if model.feature_transform:
            _, feat_trans = model.get_transforms(batch_data)
            if feat_trans is not None:
                loss = loss + 0.001 * feature_transform_regularizer(feat_trans)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item() * batch_data.size(0)
        _, predicted = outputs.max(1)
        correct += predicted.eq(batch_labels).sum().item()
        total += batch_data.size(0)
    
    return total_loss / total, 100.0 * correct / total


def evaluate(model, loader):
    """Evaluate model on dataset."""
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch_data, batch_labels in loader:
            batch_data = batch_data.to(device)
            batch_labels = batch_labels.to(device)
            
            outputs = model(batch_data)
            _, predicted = outputs.max(1)
            correct += predicted.eq(batch_labels).sum().item()
            total += batch_data.size(0)
    
    return 100.0 * correct / total

In [None]:
# Training loop
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

best_acc = 0
history = {'train_loss': [], 'train_acc': [], 'test_acc': []}

print("="*60)
print("Training PointNet")
print("="*60)

for epoch in range(EPOCHS):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer)
    test_acc = evaluate(model, test_loader)
    scheduler.step()
    
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['test_acc'].append(test_acc)
    
    if test_acc > best_acc:
        best_acc = test_acc
        # Save best model
        torch.save({
            'model_state_dict': model.state_dict(),
            'n_points': N_POINTS,
            'num_classes': NUM_CLASSES,
            'in_channels': IN_CHANNELS,
            'use_tnet': True,
            'feature_transform': True,
            'test_accuracy': best_acc,
        }, 'pointnet_best.pth')
    
    if (epoch + 1) % 5 == 0 or epoch == 0:
        print(f"Epoch {epoch+1:3d}/{EPOCHS} | Loss: {train_loss:.4f} | "
              f"Train Acc: {train_acc:.1f}% | Test Acc: {test_acc:.1f}% | Best: {best_acc:.1f}%")

print(f"\nTraining complete! Best accuracy: {best_acc:.2f}%")
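
The `StepLR(step_size=20, gamma=0.5)` scheduler used above halves the learning rate every 20 epochs. The sketch below reproduces that schedule in plain Python from the closed form `base_lr * gamma ** (epoch // step_size)`. The base rate of `1e-3` and the 50-epoch horizon are assumptions for illustration; the notebook's actual values come from `LEARNING_RATE` and `EPOCHS`.

```python
def step_lr_schedule(base_lr, step_size, gamma, epochs):
    """Learning rate used during each epoch under a step decay schedule."""
    return [base_lr * gamma ** (epoch // step_size) for epoch in range(epochs)]

# Assumed values for illustration only (not read from the notebook config).
schedule = step_lr_schedule(base_lr=1e-3, step_size=20, gamma=0.5, epochs=50)

for epoch in (0, 19, 20, 39, 40, 49):
    print(f"epoch {epoch + 1:2d}: lr = {schedule[epoch]:.6f}")
```

Under these assumptions the last ten epochs run at a quarter of the initial rate.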

In [ ]:
# Plot training curves
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Loss
axes[0].plot(history['train_loss'], 'b-', label='Train Loss')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Accuracy
axes[1].plot(history['train_acc'], 'b-', label='Train Acc')
axes[1].plot(history['test_acc'], 'r-', label='Test Acc')
axes[1].axhline(y=best_acc, color='g', linestyle='--', label=f'Best: {best_acc:.1f}%')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy (%)')
axes[1].set_title('Training & Test Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_curves.png', dpi=150)
plt.show()

In [None]:
# Load best model for verification
checkpoint = torch.load('pointnet_best.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
print(f"Loaded best model with test accuracy: {checkpoint['test_accuracy']:.2f}%")

## 5. Verification with auto_LiRPA

Using auto_LiRPA API directly (native PyTorch support for Conv1d, MaxPool, etc.)

**IMPORTANT**: PointNet's T-Nets apply their predicted transforms with `torch.bmm` (batch matrix multiplication), 
which auto_LiRPA should handle as a bounded matrix multiplication (`BoundMatMul`). The verification model 
therefore KEEPS both T-Nets and removes only Dropout, which has no effect in eval mode.
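
To make the `lb`, `ub`, and `margin` values reported later more concrete, here is a minimal sketch of interval bounds for a single affine layer under an L∞ perturbation. It is plain Python with hypothetical weights and is not how auto_LiRPA computes CROWN bounds. It only illustrates the kind of certificate being checked: a prediction is certified when the lower bound of the true-class logit exceeds the upper bound of the other logit.

```python
def affine_interval_bounds(weights, bias, x0, eps):
    """Bounds on W @ x + b for all x with ||x - x0||_inf <= eps."""
    lower, upper = [], []
    for row, b in zip(weights, bias):
        center = sum(w * x for w, x in zip(row, x0)) + b
        radius = eps * sum(abs(w) for w in row)
        lower.append(center - radius)
        upper.append(center + radius)
    return lower, upper

# Hypothetical 2-class linear classifier on a 3-dimensional input.
W = [[1.0, -2.0, 0.5],
     [-0.5, 1.0, 0.0]]
b = [0.0, 0.1]
x0 = [1.0, -1.0, 2.0]
label = 0  # class whose logit should stay on top

for eps in (0.1, 2.0):
    lb, ub = affine_interval_bounds(W, b, x0, eps)
    margin = lb[label] - ub[1 - label]
    print(f"eps={eps}: margin={margin:.2f}, verified={margin > 0}")
```

A non-positive margin means "not certified", not "an adversarial example exists": bounds of this kind are conservative, so some robust samples can still fail the check.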

In [None]:
# Create a verification-friendly model WITH T-Net but WITHOUT Dropout
# 
# auto_LiRPA limitations:
# 1. Dropout + BatchNorm1d with batch_size=1 → NOT supported
# 2. BoundReduceMax with perturbed indices → requires fixed_reducemax_index=True
# 3. torch.bmm → Should work (maps to BoundMatMul)
#
# We keep T-Net but remove Dropout for verification compatibility.

class PointNetVerify(nn.Module):
    """
    PointNet for α,β-CROWN Verification - WITH T-Net, WITHOUT Dropout.
    
    This is the SAME architecture as PointNetForVerification but without Dropout,
    which is not supported by auto_LiRPA with BatchNorm1d and batch_size=1.
    
    Architecture (original PointNet):
    - Input T-Net (3x3) for spatial alignment
    - Feature T-Net (64x64) after conv2
    - Point-wise MLP: in_channels→64→64→64→128→1024 with BatchNorm
    - Global MaxPool
    - Classifier: 1024→512→256→2 with BatchNorm (NO Dropout)
    """
    def __init__(self, num_points=1024, num_classes=2, in_channels=7,
                 use_tnet=True, feature_transform=True):
        super().__init__()
        self.num_points = num_points
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.use_tnet = use_tnet
        self.feature_transform = feature_transform
        self.input_dim = num_points * in_channels
        
        # Input T-Net (3x3) - only for xyz
        if use_tnet:
            self.input_tnet = TNet(k=3)
        
        # Feature T-Net (64x64)
        if feature_transform:
            self.feat_tnet = TNet(k=64)
        
        # Point-wise MLP (same as original)
        self.conv1 = nn.Conv1d(in_channels, 64, 1)
        self.conv2 = nn.Conv1d(64, 64, 1)
        self.conv3 = nn.Conv1d(64, 64, 1)
        self.conv4 = nn.Conv1d(64, 128, 1)
        self.conv5 = nn.Conv1d(128, 1024, 1)
        
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(64)
        self.bn3 = nn.BatchNorm1d(64)
        self.bn4 = nn.BatchNorm1d(128)
        self.bn5 = nn.BatchNorm1d(1024)
        
        # Classifier MLP (NO Dropout for verification compatibility)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, num_classes)
        
        self.bn_fc1 = nn.BatchNorm1d(512)
        self.bn_fc2 = nn.BatchNorm1d(256)
        # NO self.dropout - removed for auto_LiRPA compatibility
    
    def forward(self, x):
        batch_size = x.shape[0]
        
        # Handle flattened input
        if x.dim() == 2:
            x = x.view(batch_size, self.num_points, self.in_channels)
        
        # Separate xyz and extra features
        if self.in_channels > 3:
            xyz = x[:, :, :3]  # (batch, n_points, 3)
            extra_features = x[:, :, 3:]  # (batch, n_points, in_channels-3)
        else:
            xyz = x
            extra_features = None
        
        # Input T-Net (only on xyz)
        if self.use_tnet:
            xyz_t = xyz.transpose(1, 2)  # (batch, 3, n_points)
            input_trans = self.input_tnet(xyz_t)  # (batch, 3, 3)
            xyz = torch.bmm(xyz, input_trans)  # (batch, n_points, 3)
        
        # Recombine
        if extra_features is not None:
            x = torch.cat([xyz, extra_features], dim=2)
        else:
            x = xyz
        
        # Point-wise MLP
        x = x.transpose(1, 2)  # (batch, in_channels, n_points)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        
        # Feature T-Net
        if self.feature_transform:
            feat_trans = self.feat_tnet(x)  # (batch, 64, 64)
            x = x.transpose(1, 2)  # (batch, n_points, 64)
            x = torch.bmm(x, feat_trans)  # (batch, n_points, 64)
            x = x.transpose(1, 2)  # (batch, 64, n_points)
        
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.relu(self.bn5(self.conv5(x)))
        
        # Global max pooling
        x = torch.max(x, dim=2)[0]  # (batch, 1024)
        
        # Classifier (NO Dropout)
        x = F.relu(self.bn_fc1(self.fc1(x)))
        x = F.relu(self.bn_fc2(self.fc2(x)))
        x = self.fc3(x)
        
        return x


def transfer_weights_to_verify(full_model, verify_model):
    """
    Transfer weights from full PointNet (with Dropout) to verification model (no Dropout).
    All layers are identical except Dropout is removed.
    """
    # Copy T-Net weights if present
    if hasattr(full_model, 'input_tnet') and hasattr(verify_model, 'input_tnet'):
        verify_model.input_tnet.load_state_dict(full_model.input_tnet.state_dict())
    
    if hasattr(full_model, 'feat_tnet') and hasattr(verify_model, 'feat_tnet'):
        verify_model.feat_tnet.load_state_dict(full_model.feat_tnet.state_dict())
    
    # Copy conv and bn layers
    for i in range(1, 6):
        getattr(verify_model, f'conv{i}').load_state_dict(
            getattr(full_model, f'conv{i}').state_dict()
        )
        getattr(verify_model, f'bn{i}').load_state_dict(
            getattr(full_model, f'bn{i}').state_dict()
        )
    
    # Copy fc layers
    for i in range(1, 4):
        getattr(verify_model, f'fc{i}').load_state_dict(
            getattr(full_model, f'fc{i}').state_dict()
        )
    
    # Copy fc batchnorm
    verify_model.bn_fc1.load_state_dict(full_model.bn_fc1.state_dict())
    verify_model.bn_fc2.load_state_dict(full_model.bn_fc2.state_dict())
    
    return verify_model


# Create verification model WITH T-Net
verify_model = PointNetVerify(
    num_points=N_POINTS,
    num_classes=NUM_CLASSES,
    in_channels=IN_CHANNELS,
    use_tnet=True,           # Keep T-Net!
    feature_transform=True   # Keep feature transform!
)

# Transfer weights from trained model
verify_model = transfer_weights_to_verify(model, verify_model)
verify_model.eval()

# Compare accuracy
verify_model_gpu = verify_model.to(device)
verify_acc = evaluate(verify_model_gpu, test_loader)
print(f"Original model (with Dropout) accuracy: {best_acc:.2f}%")
print(f"Verification model (no Dropout) accuracy: {verify_acc:.2f}%")
print(f"Accuracy difference: {abs(best_acc - verify_acc):.2f}%")
print("\nNote: T-Net is KEPT in verification model!")
print("      Only Dropout is removed (not needed in eval mode anyway).")

In [None]:
# Verification functions using α,β-CROWN
# Key settings for PointNet verification:
# 1. fixed_reducemax_index=True: Assume max indices don't change (an assumption, plausible only for small ε)
# 2. method='CROWN' or 'alpha-CROWN': Use CROWN bounds, not just IBP

def verify_robustness_lirpa(model, sample, label, epsilon, method='CROWN'):
    """
    Verify local robustness using α,β-CROWN.
    
    Property: ∀x' with ||x' - x₀||∞ ≤ ε : f(x') = f(x₀)
    
    Args:
        model: PyTorch model (PointNetVerify)
        sample: Input sample (n_points, in_channels)
        label: Ground truth label (0=CRITICAL, 1=NON_CRITICAL)
        epsilon: L∞ perturbation budget
        method: Verification method ('CROWN', 'alpha-CROWN', 'CROWN-Optimized')
    
    Returns:
        dict with 'verified', 'margin', 'lb', 'ub'
    """
    model.eval()
    model_cpu = model.cpu()
    
    # Prepare input
    sample_tensor = torch.FloatTensor(sample).unsqueeze(0)  # (1, n_points, in_channels)
    
    # Create bounded model with CRITICAL options for max pooling
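    # fixed_reducemax_index=True assumes the argmax indices don't change under perturbation.
    # This is an assumption, not a guarantee: it is plausible for small epsilon, but if the
    # argmax can change inside the ε-ball, the resulting bounds are no longer sound.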
    bounded_model = BoundedModule(
        model_cpu, 
        sample_tensor, 
        device='cpu',
        bound_opts={
            'fixed_reducemax_index': True,  # Required for CROWN with max pooling
        }
    )
    
    # Define perturbation
    ptb = PerturbationLpNorm(norm=float('inf'), eps=epsilon)
    bounded_input = BoundedTensor(sample_tensor, ptb)
    
    # Compute bounds using CROWN (α,β-CROWN)
    lb, ub = bounded_model.compute_bounds(x=(bounded_input,), method=method)
    
    # Check if correct class is always highest
    # Margin = lower_bound(correct_class) - upper_bound(other_class)
    if label == 0:
        margin = lb[0, 0] - ub[0, 1]
    else:
        margin = lb[0, 1] - ub[0, 0]
    
    verified = margin.item() > 0
    
    return {
        'verified': verified,
        'margin': margin.item(),
        'lb': lb.detach().numpy(),
        'ub': ub.detach().numpy(),
        'method': method
    }


def verify_safety_lirpa(model, sample, epsilon, method='CROWN'):
    """
    Verify safety property using α,β-CROWN.
    
    Property: For CRITICAL samples, no perturbation causes NON_CRITICAL classification.
    
    Args:
        model: PyTorch model (PointNetVerify)
        sample: Input sample (n_points, in_channels)
        epsilon: L∞ perturbation budget
        method: Verification method ('CROWN', 'alpha-CROWN', 'CROWN-Optimized')
    
    Returns:
        dict with 'verified', 'margin', status info
    """
    model.eval()
    model_cpu = model.cpu()
    
    # Prepare input
    sample_tensor = torch.FloatTensor(sample).unsqueeze(0)
    
    # First check prediction
    with torch.no_grad():
        output = model_cpu(sample_tensor)
        pred = output.argmax(dim=1).item()
        confidence = torch.softmax(output, dim=1)[0]
    
    # Only verify if predicted as CRITICAL (class 0)
    if pred != 0:
        return {
            'verified': False,
            'status': 'skipped_wrong_prediction',
            'original_prediction': pred,
            'confidence': confidence.numpy()
        }
    
    # Create bounded model with options for max pooling
    bounded_model = BoundedModule(
        model_cpu, 
        sample_tensor, 
        device='cpu',
        bound_opts={
            'fixed_reducemax_index': True,  # Required for CROWN with max pooling
        }
    )
    
    # Define perturbation
    ptb = PerturbationLpNorm(norm=float('inf'), eps=epsilon)
    bounded_input = BoundedTensor(sample_tensor, ptb)
    
    # Compute bounds using CROWN
    lb, ub = bounded_model.compute_bounds(x=(bounded_input,), method=method)
    
    # Safety: CRITICAL (class 0) should always have higher score than NON_CRITICAL (class 1)
    # If lb[CRITICAL] > ub[NON_CRITICAL], the sample is safe
    margin = lb[0, 0] - ub[0, 1]
    verified = margin.item() > 0
    
    return {
        'verified': verified,
        'margin': margin.item(),
        'lb': lb.detach().numpy(),
        'ub': ub.detach().numpy(),
        'method': method,
        'original_prediction': pred,
        'confidence': confidence.numpy()
    }


print("Verification functions defined with α,β-CROWN support!")
print("\nKey settings for PointNet verification:")
print("  - fixed_reducemax_index=True: Assumes max indices stable under perturbation")
print("  - method='CROWN': Uses CROWN bound propagation (not just IBP)")
print("\nAvailable methods:")
print("  - 'CROWN': Standard CROWN bounds (fast)")
print("  - 'alpha-CROWN': Optimized CROWN with learnable α (tighter, slower)")
print("  - 'CROWN-Optimized': Same as alpha-CROWN")
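
`fixed_reducemax_index=True` treats the arg-max point of each pooled channel as fixed across the whole ε-ball. A rough way to see when that holds: if each candidate value is known to lie in an interval, the arg-max cannot change when the winner's lower bound exceeds every other upper bound. The helper below is a hypothetical sanity check in plain Python, not part of auto_LiRPA, and the activations and radii are made up for illustration.

```python
def argmax_is_stable(lowers, uppers):
    """True if one index has a lower bound above all other upper bounds."""
    winner = max(range(len(lowers)), key=lambda i: lowers[i])
    return all(lowers[winner] > uppers[i]
               for i in range(len(uppers)) if i != winner)

def intervals(values, radius):
    """Symmetric intervals [v - radius, v + radius] around nominal values."""
    return [v - radius for v in values], [v + radius for v in values]

# Hypothetical per-point activations feeding one max-pooled channel.
activations = [0.2, 1.5, 1.1, -0.3]

for radius in (0.1, 0.3):
    lo, hi = intervals(activations, radius)
    print(radius, argmax_is_stable(lo, hi))
```

When a check like this fails on the real network's pre-pooling bounds, results obtained with the fixed-index option should be read with caution.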

In [None]:
# Generate fixed verification samples from test dataset
# We need numpy arrays for verification, so we extract them once

print("Generating fixed verification samples...")

# Create a fixed set of samples for verification (with seed for reproducibility)
# Uses the same NSGA-III vulnerability labeling as training
verify_dataset = LiDAROnTheFlyDataset(
    frame_features,
    pareto_set=PARETO_SET,
    threshold=VULNERABILITY_THRESHOLD,
    fallback_weights=CRITICALITY_WEIGHTS,
    n_points=N_POINTS,
    samples_per_epoch=N_VERIFY_SAMPLES * 2,  # Extra samples to ensure enough of each class
    seed=12345,  # Fixed seed for verification
    augment=False
)

# Extract samples and labels as numpy arrays
test_groups = []
test_labels = []
for i in range(len(verify_dataset)):
    sample, label = verify_dataset[i]
    test_groups.append(sample.numpy())
    test_labels.append(label)

test_groups = np.array(test_groups)
test_labels = np.array(test_labels)

print(f"Verification samples: {len(test_groups)}")
print(f"  CRITICAL (0): {sum(test_labels == 0)}")
print(f"  NON_CRITICAL (1): {sum(test_labels == 1)}")
print(f"\nLabeling: {'NSGA-III vulnerability-based' if PARETO_SET is not None else 'Fallback weights'}")

In [None]:
# Property 1: Local Robustness Verification with α,β-CROWN
# Using PointNetVerify (with T-Net, no Dropout) with CROWN method

print("="*70)
print("PROPERTY 1: LOCAL ROBUSTNESS (L∞) with α,β-CROWN")
print("Verifying: ∀x' with ||x' - x₀||∞ ≤ ε : f(x') = f(x₀)")
print("Model: PointNetVerify (with T-Net, no Dropout)")
print("Method: CROWN (backward bound propagation)")
print("="*70)

# Verification method to use
VERIFY_METHOD = 'CROWN'  # Options: 'CROWN', 'alpha-CROWN', 'CROWN-Optimized'

robustness_results = {}
errors_log = []  # Track full error messages

for eps in EPSILONS:
    print(f"\nε = {eps}")
    print("-"*40)
    
    verified_count = 0
    total = 0
    
    for i in range(min(N_VERIFY_SAMPLES, len(test_groups))):
        sample = test_groups[i]
        label = int(test_labels[i])
        label_str = "CRITICAL" if label == 0 else "NON_CRITICAL"
        
        try:
            # Use verify_model (PointNetVerify) with CROWN method
            result = verify_robustness_lirpa(verify_model, sample, label, eps, method=VERIFY_METHOD)
            if result['verified']:
                verified_count += 1
                status = f"✓ VERIFIED (margin={result['margin']:.4f})"
            else:
                status = f"✗ NOT VERIFIED (margin={result['margin']:.4f})"
            total += 1
        except Exception as e:
            error_msg = str(e)
            errors_log.append(f"Sample {i}, eps={eps}: {error_msg}")
            status = f"⚠ ERROR: {error_msg[:60]}..."
        
        print(f"  Sample {i:3d} ({label_str:12}): {status}")
    
    robustness_results[str(eps)] = {
        'epsilon': eps,
        'verified': verified_count,
        'total': total,
        'verified_pct': 100 * verified_count / total if total > 0 else 0,
        'method': VERIFY_METHOD
    }
    
    print(f"\n  Summary: {verified_count}/{total} verified ({robustness_results[str(eps)]['verified_pct']:.1f}%)")

# Print any errors encountered
if errors_log:
    print("\n" + "="*70)
    print("ERRORS ENCOUNTERED:")
    print("="*70)
    for err in errors_log[:5]:  # Show first 5 errors
        print(f"  {err}")
    if len(errors_log) > 5:
        print(f"  ... and {len(errors_log) - 5} more errors")

In [None]:
# Property 2: Safety Verification with α,β-CROWN
# Using PointNetVerify (with T-Net, no Dropout) with CROWN method

print("="*70)
print("PROPERTY 2: SAFETY PROPERTY with α,β-CROWN")
print("Verifying: For CRITICAL samples, never misclassified as NON_CRITICAL")
print("Model: PointNetVerify (with T-Net, no Dropout)")
print("Method: CROWN (backward bound propagation)")
print("="*70)

# Get CRITICAL samples (label=0)
critical_indices = np.where(test_labels == 0)[0]
n_critical = min(N_VERIFY_SAMPLES, len(critical_indices))

safety_results = {}
safety_errors_log = []

for eps in EPSILONS:
    print(f"\nε = {eps}")
    print("-"*40)
    
    verified_count = 0
    skipped_count = 0
    total = 0
    
    for i, idx in enumerate(critical_indices[:n_critical]):
        sample = test_groups[idx]
        
        try:
            # Use verify_model (PointNetVerify) with CROWN method
            result = verify_safety_lirpa(verify_model, sample, eps, method=VERIFY_METHOD)
            
            if result.get('status') == 'skipped_wrong_prediction':
                skipped_count += 1
                status = f"⊘ SKIPPED (model predicts NON_CRITICAL)"
            elif result['verified']:
                verified_count += 1
                status = f"✓ SAFE (margin={result['margin']:.4f})"
                total += 1
            else:
                status = f"✗ UNSAFE (margin={result['margin']:.4f})"
                total += 1
        except Exception as e:
            error_msg = str(e)
            safety_errors_log.append(f"Sample {idx}, eps={eps}: {error_msg}")
            status = f"⚠ ERROR: {error_msg[:60]}..."
        
        print(f"  Sample {idx:3d} (CRITICAL): {status}")
    
    safety_results[str(eps)] = {
        'epsilon': eps,
        'verified': verified_count,
        'total': total,
        'skipped': skipped_count,
        'verified_pct': 100 * verified_count / total if total > 0 else 0,
        'method': VERIFY_METHOD
    }
    
    print(f"\n  Summary: {verified_count}/{total} safe ({safety_results[str(eps)]['verified_pct']:.1f}%), {skipped_count} skipped")

# Print any errors encountered
if safety_errors_log:
    print("\n" + "="*70)
    print("ERRORS ENCOUNTERED:")
    print("="*70)
    for err in safety_errors_log[:5]:
        print(f"  {err}")
    if len(safety_errors_log) > 5:
        print(f"  ... and {len(safety_errors_log) - 5} more errors")

## 6. Results Summary

In [None]:
# Results Summary
print("\n" + "="*70)
print("FINAL RESULTS SUMMARY")
print("="*70)

# Property 1 Table
print("\n### PROPERTY 1: LOCAL ROBUSTNESS ###")
print(f"{'Epsilon':>10} | {'Verified':>10} | {'Total':>10} | {'Verified %':>12}")
print("-"*50)
for eps_str, r in robustness_results.items():
    print(f"{float(eps_str):>10.4f} | {r['verified']:>10} | {r['total']:>10} | {r['verified_pct']:>10.1f}%")

# Property 2 Table
print(f"\n### PROPERTY 2: SAFETY (CRITICAL -> never NON_CRITICAL) ###")
print(f"{'Epsilon':>10} | {'Safe':>10} | {'Total':>10} | {'Safe %':>12} | {'Skipped':>10}")
print("-"*65)
for eps_str, r in safety_results.items():
    print(f"{float(eps_str):>10.4f} | {r['verified']:>10} | {r['total']:>10} | {r['verified_pct']:>10.1f}% | {r['skipped']:>10}")

In [None]:
# Visualization
import matplotlib.pyplot as plt

eps_values = [float(e) for e in robustness_results.keys()]
robustness_pct = [r['verified_pct'] for r in robustness_results.values()]
safety_pct = [r['verified_pct'] for r in safety_results.values()]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: Robustness
ax1 = axes[0]
ax1.plot(eps_values, robustness_pct, 'b-o', linewidth=2, markersize=8)
ax1.axhline(y=50, color='r', linestyle='--', label='50% threshold')
ax1.set_xlabel('Perturbation (ε)', fontsize=12)
ax1.set_ylabel('Verified (%)', fontsize=12)
ax1.set_title('Property 1: Local Robustness', fontsize=14)
ax1.legend()
ax1.grid(True, alpha=0.3)
ax1.set_ylim([0, 105])

# Plot 2: Safety
ax2 = axes[1]
colors = ['green' if p == 100 else ('orange' if p > 50 else 'red') for p in safety_pct]
ax2.bar(range(len(eps_values)), safety_pct, color=colors, alpha=0.7)
ax2.set_xticks(range(len(eps_values)))
ax2.set_xticklabels([f'{e:.3f}' for e in eps_values])
ax2.axhline(y=100, color='green', linestyle='-', linewidth=2, label='Safety verified')
ax2.set_xlabel('Perturbation (ε)', fontsize=12)
ax2.set_ylabel('Safe Samples (%)', fontsize=12)
ax2.set_title('Property 2: Safety (CRITICAL → never NON_CRITICAL)', fontsize=14)
ax2.legend()
ax2.grid(True, alpha=0.3, axis='y')
ax2.set_ylim([0, 105])

plt.tight_layout()
plt.savefig('verification_results.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Save results locally
final_results = {
    "timestamp": datetime.now().isoformat(),
    "model_trained": {
        "type": "PointNetForVerification",
        "n_points": N_POINTS,
        "in_channels": IN_CHANNELS,
        "num_classes": NUM_CLASSES,
        "use_tnet": True,
        "feature_transform": True,
        "parameters": n_params,
        "test_accuracy": best_acc,
    },
    "model_verified": {
        "type": "PointNetVerify (with T-Net, no Dropout)",
        "use_tnet": True,
        "feature_transform": True,
        "note": "Dropout removed for auto_LiRPA compatibility (Dropout+BatchNorm1d with batch_size=1 not supported)",
        "test_accuracy": verify_acc,
    },
    "verification_method": f"α,β-CROWN ({VERIFY_METHOD})",
    "verification_settings": {
        "method": VERIFY_METHOD,
        "fixed_reducemax_index": True,
        "note": "Assumes max indices stable under small perturbations"
    },
    "n_verify_samples": N_VERIFY_SAMPLES,
    "property1_robustness": robustness_results,
    "property2_safety": safety_results,
}

with open('verification_results.json', 'w') as f:
    json.dump(final_results, f, indent=2)

print("Results saved locally:")
print("  - pointnet_best.pth (model checkpoint)")
print("  - verification_results.json")
print("  - verification_results.png")
print("  - training_curves.png")
print(f"\nVerification performed using α,β-CROWN with method='{VERIFY_METHOD}'")
print(f"  Trained model accuracy: {best_acc:.2f}%")
print(f"  Verified model accuracy: {verify_acc:.2f}%")
print(f"\nVerification model keeps T-Net and feature transform!")
print("Only Dropout was removed for compatibility.")
print("\nRun the next cell to download all files to your computer.")

In [None]:
# Download all files to your computer
from google.colab import files

print("Downloading files...")
files.download('pointnet_best.pth')
files.download('verification_results.json')
files.download('verification_results.png')
files.download('training_curves.png')
print("Done! Check your Downloads folder.")