## 1. Load model


In [13]:
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.data import DataLoader

# Define the model class (copy from your training script)
class ParticleDynamicsGNN(MessagePassing):
    """Message-passing GNN that predicts a 2-D vector per particle (SPH dynamics).

    The architecture must mirror the training script, because the attribute names
    below become the ``state_dict`` keys loaded from the checkpoint:

    * ``input_embedding``: Linear -> ReLU -> Dropout (``in_channels`` -> ``hidden_channels``).
    * ``num_layers`` rounds of message passing. Each round has its own message MLP
      and update MLP (``2*hidden -> hidden -> hidden``).
    * ``skip_connections``: one Linear per layer. Layer ``i > 0`` adds a projection
      of a residual anchor that is refreshed after every odd-indexed layer.
    * ``output_layer``: ``hidden -> hidden // 2 -> 2``.

    Args:
        in_channels: Number of input features per node.
        hidden_channels: Width of the hidden node representation.
        num_layers: Number of message-passing rounds.
        dropout: Dropout probability used inside every MLP.
    """

    def __init__(self, in_channels, hidden_channels=64, num_layers=3, dropout=0.1):
        super().__init__(aggr='add')

        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.num_layers = num_layers
        self.dropout = dropout

        width = hidden_channels

        # Lift the raw node features into the hidden space.
        self.input_embedding = nn.Sequential(
            nn.Linear(in_channels, width),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        def two_layer_mlp():
            # Maps a concatenated pair of hidden vectors back to one hidden vector.
            return nn.Sequential(
                nn.Linear(2 * width, width),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(width, width),
            )

        self.message_mlps = nn.ModuleList()
        self.update_mlps = nn.ModuleList()
        # Build the message and update nets interleaved, layer by layer.
        # This keeps the same construction (and parameter-initialisation) order
        # as the training script.
        for _layer in range(num_layers):
            self.message_mlps.append(two_layer_mlp())
            self.update_mlps.append(two_layer_mlp())

        half_width = hidden_channels // 2
        self.output_layer = nn.Sequential(
            nn.Linear(width, half_width),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(half_width, 2),  # two outputs per node (2-D position)
        )

        # One projection per layer. Index 0 is created (so it is in the
        # state_dict) but forward() never applies it.
        self.skip_connections = nn.ModuleList(
            [nn.Linear(width, width) for _layer in range(num_layers)]
        )

    def forward(self, x, edge_index, batch=None):
        """Embed the nodes, run all message-passing rounds, and project to 2 values per node.

        Args:
            x: Node feature matrix. Its last dimension must equal ``in_channels``.
            edge_index: Edge list unpacked as ``(row, col)`` by ``propagate``.
            batch: Accepted for API compatibility; not used.

        Returns:
            A tensor with a trailing dimension of size 2 for every node.
        """
        h = self.input_embedding(x)
        anchor = h  # source of the skip term; refreshed after every odd-indexed layer

        for layer in range(self.num_layers):
            h_next = self.propagate(edge_index, x=h, layer_idx=layer)
            if layer:  # layer 0 gets no skip term
                h_next = h_next + self.skip_connections[layer](anchor)
            h = h_next
            if layer % 2:  # after layers 1, 3, 5, ...
                anchor = h

        return self.output_layer(h)

    def message(self, x_i, x_j, layer_idx):
        """Compute per-edge messages from the concatenation ``[x_i || x_j]``."""
        return self.message_mlps[layer_idx](torch.cat((x_i, x_j), dim=1))

    def update(self, aggr_out, x, layer_idx):
        """Combine each node's state with its summed incoming messages, as ``[x || aggr_out]``."""
        return self.update_mlps[layer_idx](torch.cat((x, aggr_out), dim=1))

    def propagate(self, edge_index, x, layer_idx):
        """Custom sum aggregation that replaces ``MessagePassing.propagate``.

        For every edge ``(r, c)``, a message is computed from ``x[r]`` (as ``x_i``)
        and ``x[c]`` (as ``x_j``). The messages are then summed into node ``r``.
        """
        receivers, senders = edge_index
        messages = self.message(x[receivers], x[senders], layer_idx)
        summed = torch.zeros_like(x).index_add_(0, receivers, messages)
        return self.update(summed, x, layer_idx)

# Load the trained model on the fastest available backend: CUDA, then Apple MPS, then CPU.
# (To force CPU instead: device = torch.device('cpu'))
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Device in use: {device}")

# Rebuild the architecture from the saved hyper-parameters, then restore the weights.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints
# (consider weights_only=True on torch>=1.13 if the checkpoint holds only tensors/plain types).
checkpoint = torch.load('./complete_gnn_physics_model_1.pth', map_location=device)
model_config = checkpoint['model_config']
model = ParticleDynamicsGNN(**model_config)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()  # disables dropout for inference

n_model_params = sum(param.numel() for param in model.parameters())
print(f"Model loaded successfully!")
print(f"Model parameters: {n_model_params:,}")
print(f"Model config: {model_config}")

Device in use: mps
Model loaded successfully!
Model parameters: 470,466
Model config: {'in_channels': 5, 'hidden_channels': 128, 'num_layers': 4, 'dropout': 0.15}


## 2. Prepare test data

In [20]:
# Create test data from your training pipeline
# First, let's load the same data preparation functions
import torch
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
import h5py
import os

def build_neighbor_graph(positions, radius=0.01, max_neighbors=10):
    """Connect every particle to the other particles that lie within ``radius``.

    Args:
        positions: Tensor with one row per particle. Distances are L2 norms over the
            coordinate axis (presumably (N, 2) positions).
        radius: Inclusive distance cutoff.
        max_neighbors: If not None, a particle with more than this many in-range
            neighbours keeps only its ``max_neighbors`` closest ones. The cut is
            applied per particle, so the resulting graph can be asymmetric.

    Returns:
        ``edge_index`` of shape (2, E). Row 0 is the particle and row 1 is its
        neighbour, in row-major order.

    Note:
        Pairs at distance exactly 0 are excluded. This removes self-loops, but also
        removes coincident distinct particles. The full N x N distance matrix is
        materialised.
    """
    num_particles = positions.size(0)

    # All pairwise Euclidean distances, shape (N, N).
    dist = torch.norm(positions[:, None] - positions[None], dim=2)
    adjacency = (dist > 0) & (dist <= radius)

    if max_neighbors is not None:
        for p in range(num_particles):
            in_range = adjacency[p]
            if in_range.sum() > max_neighbors:
                # Push out-of-range candidates to +inf so topk ignores them.
                # This edits row p of `dist` in place, which only row p reads.
                row_dist = dist[p]
                row_dist[~in_range] = float('inf')
                _, nearest = row_dist.topk(max_neighbors, largest=False)
                adjacency[p] = False
                adjacency[p, nearest] = True

    return adjacency.nonzero().t().contiguous()

def create_graph_data(positions, particle_types, timestep_idx, target_positions=None,
                      num_types=None, radius=0.1, timestep_scale=100.0):
    """Create a PyTorch Geometric ``Data`` object for one particle snapshot.

    Node features are ``[positions, one_hot(particle_types), timestep_idx / timestep_scale]``,
    concatenated along dim 1. Edges come from ``build_neighbor_graph`` with ``radius``.

    Args:
        positions: Float tensor with one row per particle (presumably (N, 2)).
        particle_types: Integer tensor of non-negative type ids, one per particle.
        timestep_idx: Snapshot index. It is broadcast to every node as a scaled constant feature.
        target_positions: Optional tensor stored as ``data.y``.
        num_types: Number of one-hot classes. ``None`` (the default, legacy behaviour)
            infers ``particle_types.max() + 1``. That makes the feature width depend
            on which types appear in this snapshot, so pass it explicitly to guarantee
            the width the model was trained with (``in_channels``).
        radius: Neighbour cutoff forwarded to ``build_neighbor_graph``.
        timestep_scale: Divisor used to normalise ``timestep_idx``.

    Returns:
        ``Data(x=..., edge_index=...)``, plus ``y`` when ``target_positions`` is given.
    """
    if num_types is None:
        num_types = particle_types.max().item() + 1
    type_one_hot = F.one_hot(particle_types, num_classes=num_types).float()

    # Create the constant timestep column on the same device as `positions`;
    # otherwise torch.cat fails for non-CPU inputs.
    timestep_feature = torch.full(
        (len(positions), 1), timestep_idx / timestep_scale, device=positions.device
    )

    x = torch.cat([positions, type_one_hot, timestep_feature], dim=1)
    edge_index = build_neighbor_graph(positions, radius=radius)

    data = Data(x=x, edge_index=edge_index)
    if target_positions is not None:
        data.y = target_positions
    return data

def load_h5_dataset(filepath, max_episodes=None, skip_timesteps=1):
    """Load particle positions and types for the episodes stored in an HDF5 file.

    Each top-level key is one episode and must contain ``position`` and
    ``particle_type`` datasets. Episodes are processed in sorted (lexicographic)
    key order.

    Args:
        filepath: Path to the ``.h5`` file.
        max_episodes: Keep only the first ``max_episodes`` episodes after sorting.
            ``None`` keeps all of them.
        skip_timesteps: Stride applied to the first axis of each ``position`` dataset.

    Returns:
        ``(positions, particle_types)``: a float32 tensor and a long tensor, each
        stacked over episodes (the positions shape printed in this notebook is
        [episodes, timesteps, particles, 2]). All episodes must share shapes for
        ``torch.stack``.

    Raises:
        ValueError: If no episode is selected (empty file or ``max_episodes=0``).
    """
    with h5py.File(filepath, 'r') as f:
        episodes = sorted(f.keys())
        # `is not None`, so that max_episodes=0 is not silently treated as "load everything".
        if max_episodes is not None:
            episodes = episodes[:max_episodes]
        if not episodes:
            raise ValueError(
                f"No episodes selected from {filepath} (max_episodes={max_episodes})"
            )

        print(f"Loading {len(episodes)} episodes from {os.path.basename(filepath)}")

        all_positions = []
        all_particle_types = []
        for i, episode_id in enumerate(episodes):
            if i % 10 == 0:
                print(f"  Loading episode {i+1}/{len(episodes)}...")

            # Slicing reads the data into memory, so the file can be closed afterwards.
            positions = f[f'{episode_id}/position'][::skip_timesteps]
            particle_types = f[f'{episode_id}/particle_type'][:]

            all_positions.append(torch.tensor(positions, dtype=torch.float32))
            all_particle_types.append(torch.tensor(particle_types, dtype=torch.long))

    # Stack outside the `with` block; everything is already copied into tensors.
    positions = torch.stack(all_positions)
    particle_types = torch.stack(all_particle_types)

    print(f"Final shape - Positions: {positions.shape}, Types: {particle_types.shape}")
    return positions, particle_types

def prepare_training_data(positions, particle_types, sequence_length=5):
    """Flatten stacked episodes into a list of one-step (state -> next state) graphs.

    For every episode ``ep`` (outer loop) and every ``t`` in
    ``range(positions.shape[1] - sequence_length)`` (inner loop), one graph is built
    by ``create_graph_data`` from ``positions[ep, t]``, the episode's
    ``particle_types[ep]``, the index ``t``, and target ``positions[ep, t + 1]``.

    NOTE(review): only the current frame is used as input. ``sequence_length`` just
    drops the last start frames; ``sequence_length=1`` would already be enough to
    keep ``t + 1`` in range. Keep it unchanged if it must match training.
    """
    num_episodes, num_steps = positions.shape[0], positions.shape[1]
    return [
        create_graph_data(positions[ep, t], particle_types[ep], t, positions[ep, t + 1])
        for ep in range(num_episodes)
        for t in range(num_steps - sequence_length)
    ]

print("="*60)
print("DATA PREPARATION PHASE")
print("="*60)

# Load test data (same as in your training pipeline)
data_dir = '/Volumes/Meida/01-CodeLab/01-personal-project/GNN/2D_DAM_5740_20kevery100'
test_path = os.path.join(data_dir, 'test.h5')

print("Loading test data...")
test_pos, test_types = load_h5_dataset(test_path, max_episodes=2, skip_timesteps=5)

print("Preparing test data...")
test_data_list = prepare_training_data(test_pos, test_types)
print(f"Created {len(test_data_list)} test samples")

# Create test loader
batch_size = 2  # Same as training
test_loader = DataLoader(test_data_list, batch_size=batch_size, shuffle=False)

print(f"Test loader created with {len(test_loader)} batches")
print(f"Sample test data shape: {test_data_list[0].x.shape}")

print("\n" + "="*60)
print("DATA PREPARATION COMPLETE")
print("="*60)
print(f"✓ Loaded {len(test_data_list)} test samples")
print(f"✓ Created data loader with {len(test_loader)} batches")
print(f"✓ Sample graph: {test_data_list[0].x.shape[0]} nodes, {test_data_list[0].edge_index.shape[1]} edges")
print(f"✓ Node features: {test_data_list[0].x.shape[1]} dimensions")

DATA PREPARATION PHASE
Loading test data...
Loading 2 episodes from test.h5
  Loading episode 1/2...
Final shape - Positions: torch.Size([2, 81, 5740, 2]), Types: torch.Size([2, 5740])
Preparing test data...
Created 152 test samples
Test loader created with 76 batches
Sample test data shape: torch.Size([5740, 5])

DATA PREPARATION COMPLETE
✓ Loaded 152 test samples
✓ Created data loader with 76 batches
✓ Sample graph: 5740 nodes, 56878 edges
✓ Node features: 5 dimensions


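## 3. Measure inference time

> **Note:** the outputs saved below came from `In [16]`, which ran *before* the data-preparation cell (`In [20]`). They report 108,451 edges for the first sample, while the current data-preparation output reports 56,878. These timings were therefore measured on an older `test_data_list`. Restart the kernel and run all cells before relying on the numbers.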

In [16]:
def _synchronize(device):
    """Block until queued work on ``device`` has finished (no-op on CPU)."""
    if device.type == 'cuda':
        torch.cuda.synchronize()
    elif device.type == 'mps':
        torch.mps.synchronize()


def measure_single_timestep_inference(model, test_data_list, device, num_runs=100, warmup_runs=10):
    """Measure the forward-pass latency of one physics timestep (one graph, batch_size=1).

    Only the first sample of ``test_data_list`` is timed. The same graph is
    forwarded ``warmup_runs`` times untimed, then ``num_runs`` times timed. Each
    timed run synchronises the device so that asynchronous kernels are included.

    Args:
        model: Trained GNN, called as ``model(x, edge_index, batch)``.
        test_data_list: Non-empty sequence of graph samples.
        device: ``torch.device`` to run on.
        num_runs: Number of timed iterations (>= 1).
        warmup_runs: Number of untimed warmup iterations (>= 0).

    Returns:
        ``(metrics, inference_times)``. ``metrics`` is a dict of timing statistics
        in ms, throughput, graph size, and per-node/per-edge time in microseconds
        (NaN when the graph has no nodes/edges). ``inference_times`` is a numpy
        array of per-run times in ms.

    Raises:
        ValueError: If ``num_runs < 1``, ``warmup_runs < 0`` or ``test_data_list`` is empty.
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")
    if warmup_runs < 0:
        raise ValueError(f"warmup_runs must be >= 0, got {warmup_runs}")
    if len(test_data_list) == 0:
        raise ValueError("test_data_list is empty; nothing to time")

    model.eval()

    # batch_size=1 so that exactly one timestep's graph is forwarded.
    single_loader = DataLoader(test_data_list, batch_size=1, shuffle=False)
    sample_batch = next(iter(single_loader)).to(device)
    num_nodes = sample_batch.x.shape[0]
    num_edges = sample_batch.edge_index.shape[1]

    print(f"Single timestep info: {num_nodes} nodes, {num_edges} edges")
    print(f"Batch size: {sample_batch.batch.max().item() + 1 if sample_batch.batch is not None else 1}")

    print(f"Running {warmup_runs} warmup iterations...")
    with torch.no_grad():
        for _ in range(warmup_runs):
            model(sample_batch.x, sample_batch.edge_index, sample_batch.batch)
    # Drain warmup work so that it does not leak into the first timed run.
    _synchronize(device)

    print(f"Running {num_runs} timed iterations...")
    inference_times = []
    with torch.no_grad():
        for i in range(num_runs):
            start_time = time.perf_counter()
            model(sample_batch.x, sample_batch.edge_index, sample_batch.batch)
            # GPU/MPS kernels are asynchronous; wait so the interval covers the full pass.
            _synchronize(device)
            inference_times.append((time.perf_counter() - start_time) * 1000)  # ms

            if (i + 1) % 20 == 0:
                print(f"  Completed {i + 1}/{num_runs} runs")

    inference_times = np.array(inference_times)
    mean_ms = np.mean(inference_times)

    metrics = {
        'mean_time_ms': mean_ms,
        'std_time_ms': np.std(inference_times),
        'min_time_ms': np.min(inference_times),
        'max_time_ms': np.max(inference_times),
        'median_time_ms': np.median(inference_times),
        'p95_time_ms': np.percentile(inference_times, 95),
        'p99_time_ms': np.percentile(inference_times, 99),
        'timesteps_per_sec': 1000 / mean_ms,
        'num_nodes': num_nodes,
        'num_edges': num_edges,
        'time_per_node_us': (mean_ms * 1000) / num_nodes if num_nodes else float('nan'),
        'time_per_edge_us': (mean_ms * 1000) / num_edges if num_edges else float('nan'),
    }

    return metrics, inference_times

def compare_batch_vs_single_timestep(model, test_data_list, device,
                                     batch_sizes=(1, 2, 4, 8), num_runs=50, warmup_runs=5):
    """Time the model at several batch sizes and report the per-timestep cost.

    For each batch size, a loader over the first ``batch_size * 5`` samples is
    timed with ``measure_inference_time``. That function is expected to return
    ``(metrics, _)``, where ``metrics`` contains ``mean_time_ms``, ``num_nodes``
    and ``num_edges``.

    NOTE(review): ``measure_inference_time`` is not defined anywhere in the visible
    notebook (hidden kernel state). Define it in an earlier cell so that
    Restart & Run All works.

    Args:
        model: Model forwarded to ``measure_inference_time``.
        test_data_list: Sequence of graph samples.
        device: Device forwarded to ``measure_inference_time``.
        batch_sizes: Batch sizes to compare, in order.
        num_runs: Timed iterations per configuration.
        warmup_runs: Untimed warmup iterations per configuration.

    Returns:
        Dict keyed by batch size. Each value holds total time, per-timestep time,
        timesteps/sec, and per-timestep node/edge counts.
    """
    print("="*60)
    print("BATCH VS SINGLE TIMESTEP COMPARISON")
    print("="*60)

    results = {}

    for batch_size in batch_sizes:
        description = "Single Timestep" if batch_size == 1 else f"Batch of {batch_size}"
        print(f"\n--- {description} (batch_size={batch_size}) ---")

        loader = DataLoader(test_data_list[:batch_size*5],
                            batch_size=batch_size, shuffle=False)

        metrics, _ = measure_inference_time(
            model, loader, device, num_runs=num_runs, warmup_runs=warmup_runs
        )

        mean_ms = metrics['mean_time_ms']
        result = {
            'batch_size': batch_size,
            'total_time_ms': mean_ms,
            'time_per_timestep_ms': mean_ms / batch_size,
            'timesteps_per_sec': batch_size * 1000 / mean_ms,
            # Node/edge counts cover the whole batch, so divide to get one timestep.
            'nodes_per_timestep': metrics['num_nodes'] // batch_size,
            'edges_per_timestep': metrics['num_edges'] // batch_size,
        }
        results[batch_size] = result

        print(f"  Total time: {result['total_time_ms']:.2f} ms")
        print(f"  Time per timestep: {result['time_per_timestep_ms']:.2f} ms")
        print(f"  Timesteps per second: {result['timesteps_per_sec']:.2f}")
        print(f"  Nodes per timestep: {result['nodes_per_timestep']}")
        print(f"  Edges per timestep: {result['edges_per_timestep']}")

    return results

def simulate_realtime_performance(model, test_data_list, device, target_fps=60,
                                  steps_ahead=(1, 5, 10, 20, 50), num_runs=100,
                                  single_metrics=None):
    """Check whether the mean single-timestep latency fits a real-time frame budget.

    Args:
        model, test_data_list, device: Forwarded to ``measure_single_timestep_inference``.
        target_fps: Frame rate to test against. The per-frame budget is ``1000 / target_fps`` ms.
        steps_ahead: Rollout lengths for which ``steps * mean step time`` is printed.
            This is a linear extrapolation, not a measured rollout.
        num_runs: Timed iterations used when the measurement is run here.
        single_metrics: Optional metrics dict previously returned by
            ``measure_single_timestep_inference``. When given, no new measurement
            is made and only its ``'mean_time_ms'`` entry is used.

    Returns:
        Dict with ``timestep_time_ms``, ``max_fps``, ``can_achieve_target`` and
        ``speedup_needed`` (1.0 when the target is already met).
    """
    print("="*60)
    print("REAL-TIME SIMULATION ANALYSIS")
    print("="*60)

    if single_metrics is None:
        single_metrics, _ = measure_single_timestep_inference(
            model, test_data_list, device, num_runs=num_runs
        )

    timestep_time_ms = single_metrics['mean_time_ms']
    max_fps = 1000 / timestep_time_ms
    can_achieve = max_fps >= target_fps

    print(f"\nReal-time Performance Analysis:")
    print(f"  Single timestep time: {timestep_time_ms:.2f} ms")
    print(f"  Maximum FPS: {max_fps:.1f}")
    print(f"  Target FPS: {target_fps}")

    if can_achieve:
        print(f"  ✅ Can achieve {target_fps} FPS (with {max_fps - target_fps:.1f} FPS headroom)")
        time_budget_ms = 1000 / target_fps
        # Share of the per-frame time budget used by one timestep. This is not a
        # hardware utilisation counter, and the device may be CPU/MPS.
        utilization = (timestep_time_ms / time_budget_ms) * 100
        print(f"  ⏱️  Frame-time budget used: {utilization:.1f}%")
    else:
        print(f"  ❌ Cannot achieve {target_fps} FPS")
        print(f"  🔧 Need {target_fps / max_fps:.1f}x speedup")

    print(f"\nMulti-step Prediction Times:")
    for steps in steps_ahead:
        total_time = timestep_time_ms * steps
        print(f"  {steps:2d} steps ahead: {total_time:6.1f} ms ({total_time/1000:.2f} seconds)")

    return {
        'timestep_time_ms': timestep_time_ms,
        'max_fps': max_fps,
        'can_achieve_target': can_achieve,
        'speedup_needed': target_fps / max_fps if not can_achieve else 1.0,
    }

# Run the single-timestep benchmarks, the batch-size comparison, and the real-time check.
rule_line = "="*60
print("\n" + rule_line)
print("SINGLE TIMESTEP PERFORMANCE ANALYSIS")
print(rule_line)

# 1. Time one graph (batch_size=1) in isolation.
print("\n1. MEASURING TRUE SINGLE TIMESTEP PERFORMANCE")
print("-" * 50)
single_metrics, single_times = measure_single_timestep_inference(
    model, test_data_list, device, num_runs=100
)

print(f"\nSingle Timestep Statistics:")
for stat_label, stat_key in (
    ("Mean: ", 'mean_time_ms'),
    ("Std:  ", 'std_time_ms'),
    ("Min:  ", 'min_time_ms'),
    ("Max:  ", 'max_time_ms'),
    ("P95:  ", 'p95_time_ms'),
    ("P99:  ", 'p99_time_ms'),
):
    print(f"  {stat_label}{single_metrics[stat_key]:.3f} ms")

print(f"\nSingle Timestep Throughput:")
print(f"  Timesteps/sec: {single_metrics['timesteps_per_sec']:.1f}")
print(f"  Time per node: {single_metrics['time_per_node_us']:.3f} μs")
print(f"  Time per edge: {single_metrics['time_per_edge_us']:.3f} μs")

# 2. Amortised cost per timestep at larger batch sizes.
print("\n2. BATCH SIZE EFFICIENCY COMPARISON")
print("-" * 50)
batch_comparison = compare_batch_vs_single_timestep(model, test_data_list, device)

# 3. Can one timestep fit a 60 FPS frame budget?
print("\n3. REAL-TIME SIMULATION FEASIBILITY")
print("-" * 50)
realtime_analysis = simulate_realtime_performance(model, test_data_list, device, target_fps=60)

# 4. Summary
print("\n" + rule_line)
print("PERFORMANCE SUMMARY")
print(rule_line)
print(f"✨ Single timestep inference: {single_metrics['mean_time_ms']:.1f} ms")
print(f"🚀 Maximum simulation speed: {single_metrics['timesteps_per_sec']:.1f} FPS")
meets_target = realtime_analysis['can_achieve_target']
print(f"🎯 Real-time feasibility (60 FPS): {'✅ YES' if meets_target else '❌ NO'}")

if not meets_target:
    print(f"🔧 Speedup needed: {realtime_analysis['speedup_needed']:.1f}x")

print(f"📊 Nodes per timestep: {single_metrics['num_nodes']:,}")
print(f"📊 Edges per timestep: {single_metrics['num_edges']:,}")
print(f"💾 Time per node: {single_metrics['time_per_node_us']:.3f} μs")
print(f"💾 Time per edge: {single_metrics['time_per_edge_us']:.3f} μs")


SINGLE TIMESTEP PERFORMANCE ANALYSIS

1. MEASURING TRUE SINGLE TIMESTEP PERFORMANCE
--------------------------------------------------
Single timestep info: 5740 nodes, 108451 edges
Batch size: 1
Running 10 warmup iterations...
Running 100 timed iterations...
  Completed 20/100 runs
  Completed 40/100 runs
  Completed 60/100 runs
  Completed 80/100 runs
  Completed 100/100 runs

Single Timestep Statistics:
  Mean: 225.482 ms
  Std:  3.426 ms
  Min:  218.577 ms
  Max:  238.319 ms
  P95:  230.781 ms
  P99:  233.359 ms

Single Timestep Throughput:
  Timesteps/sec: 4.4
  Time per node: 39.283 μs
  Time per edge: 2.079 μs

2. BATCH SIZE EFFICIENCY COMPARISON
--------------------------------------------------
BATCH VS SINGLE TIMESTEP COMPARISON

--- Single Timestep (batch_size=1) ---
Batch info: 5740 nodes, 108451 edges
Running 5 warmup iterations...
Running 50 timed iterations...
  Completed 20/50 runs
  Completed 40/50 runs
  Total time: 226.00 ms
  Time per timestep: 226.00 ms
  Timestep