# Example 1: 2D Waterflood Tutorial


## Understanding preprocessor.py 

Before running preprocessing, it helps to know what the `preprocessor.py` script does, so that you can interpret its outputs and troubleshoot any issues.

#### Main Classes and Components

**1. `ReservoirPreprocessor` - Main orchestrator class**

```python
class ReservoirPreprocessor:
    def __init__(self, cfg):
        # Set up directories
        self.dataset_dir = get_dataset_dir(cfg)
        self.graphs_dir = os.path.join(self.dataset_dir, "graphs")
        self.partitions_dir = os.path.join(self.dataset_dir, "partitions")
        self.stats_file = os.path.join(self.dataset_dir, "global_stats.json")
```

#### The Preprocessing Pipeline (execute() method)

The preprocessing happens in **5 main steps**:

##### **Step 1: Create Raw Graphs** (unless `skip_graphs=True`)

```python
processor = ReservoirGraphBuilder(self.cfg)
self.generated_files = processor.execute()
```

What happens:
- Reads ECLIPSE binary files (`.INIT`, `.EGRID`, `.UNRST`)
- Extracts grid geometry and cell connections
- Builds graph structure with nodes (cells) and edges (connections)
- Creates training sequences: Input (t-2, t-1, t) → Target (t+1)
- Saves raw graphs to `graphs/` directory

Output: `graph_name_timestep.pt` files (e.g., `CASE_001_000.pt`)
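
The sliding-window construction in the list above can be sketched in plain Python. This is a minimal illustration, not the code in `ReservoirGraphBuilder`; the function name `make_sequences` and the `window` parameter are hypothetical.

```python
def make_sequences(num_timesteps, window=3):
    """Return (input_steps, target_step) pairs for a sliding window.

    Each sample uses `window` consecutive timesteps as input
    (t-2, t-1, t for window=3) and the following timestep (t+1) as target.
    """
    samples = []
    for t in range(window - 1, num_timesteps - 1):
        input_steps = tuple(range(t - window + 1, t + 1))
        samples.append((input_steps, t + 1))
    return samples


# Six timesteps (0..5) give three input/target pairs
for inputs, target in make_sequences(num_timesteps=6):
    print(f"input {inputs} -> target {target}")
```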

##### **Step 2: Create Partitions from Graphs**

```python
self.create_partitions_from_graphs(graph_file_list=self.graph_file_list)
```

What happens:
- Loads each raw graph
- Uses **METIS** (or fallback: simple sequential partitioning) to divide graph into `num_partitions` subgraphs
- For each partition:
  - Extract inner nodes (the actual partition)
  - Add **halo region** (neighboring nodes) for communication
  - Create a partition object with: `edge_index`, `node_features`, `inner_node` indices
- Saves partitions for each graph

Output: `partitions_graph_name_timestep.pt` files

**Why Partition?**
- Large graphs may not fit in single GPU memory
- Enables multi-GPU training (each GPU handles different partitions)
- Halo regions allow information exchange between partitions

##### **Step 3: Split Samples by Case**

```python
splits = self.split_samples_by_case(
    train_ratio=0.8, val_ratio=0.1, test_ratio=0.1
)
```

What happens:
- Groups all timesteps belonging to the same simulation case
- Randomly assigns entire cases to train/val/test splits
- Ensures all timesteps from one case stay together (prevents data leakage!)
- Moves partition files to appropriate subdirectories: `train/`, `val/`, `test/`

Output: Organized partition files in split-specific directories
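
A case-level split can be sketched as follows. This is an illustration under assumptions, not the body of `split_samples_by_case`: the helper name `split_by_case`, the `seed` argument, and the `<case>_<timestep>` naming pattern are hypothetical.

```python
import random
from collections import defaultdict


def split_by_case(sample_names, train_ratio=0.8, val_ratio=0.1, seed=0):
    """Assign whole cases to train/val/test so no case spans two splits.

    Assumes names look like "<case>_<timestep>", e.g. "CASE_001_000".
    """
    by_case = defaultdict(list)
    for name in sample_names:
        case, _timestep = name.rsplit("_", 1)
        by_case[case].append(name)

    cases = sorted(by_case)
    random.Random(seed).shuffle(cases)  # reproducible random assignment

    n_train = int(len(cases) * train_ratio)
    n_val = int(len(cases) * val_ratio)
    groups = {
        "train": cases[:n_train],
        "val": cases[n_train:n_train + n_val],
        "test": cases[n_train + n_val:],  # remainder goes to test
    }
    return {split: sorted(s for c in members for s in by_case[c])
            for split, members in groups.items()}


# 10 cases with 3 timesteps each
names = [f"CASE_{c:03d}_{t:03d}" for c in range(1, 11) for t in range(3)]
splits = split_by_case(names)
print({split: len(items) for split, items in splits.items()})
```

Because the shuffle operates on case names rather than individual samples, every timestep of a case ends up in the same split.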

##### **Step 4: Compute Global Statistics**

```python
stats = compute_global_statistics(graph_files, self.stats_file)
```

What happens:
- Iterates through all training graphs
- Computes mean and standard deviation for:
  - Node features (PERMX, PORV, PRESSURE, SWAT, etc.)
  - Edge features (transmissibilities)
  - Target features (next-timestep values)
- Saves statistics to `global_stats.json`

Output: `global_stats.json` with normalization parameters

**Why Global Statistics?**
- Neural networks train better with normalized inputs (typically mean=0, std=1)
- Ensures all features are on similar scales
- Allows denormalization of predictions for physical interpretation
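
The normalization round trip can be sketched in plain Python. The helper names and the small `eps` guard against a zero standard deviation are assumptions for illustration; the pressure values are made up.

```python
import math


def mean_std(values):
    """Population mean and standard deviation of a list of floats."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return mean, math.sqrt(var)


def normalize(values, mean, std, eps=1e-8):
    """Map values to roughly zero mean and unit variance."""
    return [(v - mean) / (std + eps) for v in values]


def denormalize(values, mean, std, eps=1e-8):
    """Invert `normalize` to recover physical units."""
    return [v * (std + eps) + mean for v in values]


pressure = [250.0, 260.0, 270.0, 280.0]  # made-up values
mean, std = mean_std(pressure)
scaled = normalize(pressure, mean, std)
restored = denormalize(scaled, mean, std)
print(mean, round(std, 3), [round(s, 3) for s in scaled])
```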

##### **Step 5: Save Metadata**

```python
self.save_dataset_metadata()
```

What happens:
- Records all preprocessing configuration
- Saves paths to datasets, partitions, statistics
- Stores partition topology (num_partitions, halo_size)
- Creates metadata for inference stage

Output: `dataset_metadata.json`

#### Key Helper Functions

**Graph Partitioning with METIS:**
```python
cluster_data = pyg.loader.ClusterData(
    graph, num_parts=num_partitions
)
```
- Uses METIS algorithm to minimize edge cuts between partitions
- Balances partition sizes
- Falls back to sequential partitioning if METIS unavailable
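
One plausible form of a sequential fallback is to cut the node index range into contiguous chunks of near-equal size. The sketch below illustrates that idea only; the actual fallback in `preprocessor.py` may differ, and `sequential_partition` is a hypothetical name.

```python
def sequential_partition(num_nodes, num_partitions):
    """Split node indices 0..num_nodes-1 into contiguous, near-equal chunks."""
    base, extra = divmod(num_nodes, num_partitions)
    parts, start = [], 0
    for p in range(num_partitions):
        size = base + (1 if p < extra else 0)  # first `extra` parts get one more
        parts.append(list(range(start, start + size)))
        start += size
    return parts


parts = sequential_partition(num_nodes=10, num_partitions=3)
print(parts)
```

Unlike METIS, this ignores connectivity, so it balances sizes but makes no attempt to reduce the number of edges cut between partitions.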

**Halo Region Creation:**
```python
part_node, part_edge_index, inner_node_mapping, edge_mask = (
    pyg.utils.k_hop_subgraph(
        part_inner_node,
        num_hops=halo_size,  # Usually 1-3
        edge_index=graph.edge_index,
        num_nodes=graph.num_nodes,
        relabel_nodes=True,
    )
)
```
- Extends partition by `halo_size` layers of neighboring nodes
- Allows information to propagate between partitions during training
- Critical for maintaining accuracy with partitioned graphs
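
To see what a k-hop halo contains, the expansion can be reproduced with a breadth-first search in plain Python. This sketch does not call PyTorch Geometric; `add_halo` is a hypothetical helper and edges are treated as undirected.

```python
def add_halo(inner_nodes, edges, halo_size):
    """Return inner nodes plus all nodes within `halo_size` hops of them."""
    neighbors = {}
    for u, v in edges:  # treat every edge as undirected
        neighbors.setdefault(u, set()).add(v)
        neighbors.setdefault(v, set()).add(u)

    visited = set(inner_nodes)
    frontier = set(inner_nodes)
    for _ in range(halo_size):
        # Nodes adjacent to the current frontier that are not yet included
        frontier = {n for f in frontier for n in neighbors.get(f, ())} - visited
        visited |= frontier
    return sorted(visited)


# 1D chain of 8 cells: 0-1-2-3-4-5-6-7
chain = [(i, i + 1) for i in range(7)]
with_halo = add_halo(inner_nodes=[3, 4], edges=chain, halo_size=2)
halo_only = [n for n in with_halo if n not in (3, 4)]
print(with_halo, halo_only)
```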

#### Directory Structure After Preprocessing

```
outputs/XMGN_2D_Q5SP_Waterflood/
├── preprocessed_data/
│   ├── graphs/                     # Raw graph files
│   │   ├── CASE_001_000.pt
│   │   ├── CASE_001_001.pt
│   │   └── ...
│   ├── partitions/
│   │   ├── train/                  # Training partitions
│   │   │   ├── partitions_CASE_001_000.pt
│   │   │   └── ...
│   │   ├── val/                    # Validation partitions
│   │   └── test/                   # Test partitions
│   ├── global_stats.json           # Normalization statistics
│   └── dataset_metadata.json       # Preprocessing metadata
```

#### Common Issues and Solutions

**Issue**: "METIS partitioning failed"
- **Solution**: Automatically falls back to simple sequential partitioning
- **Impact**: May result in less balanced partitions, but still functional

**Issue**: "Insufficient samples for train/val/test split"
- **Cause**: The split needs at least 3 simulation cases
- **Fix**: Increase `num_samples` in config or adjust split ratios

**Issue**: "NaN in global statistics"
- **Cause**: Invalid or missing data in simulation files
- **Fix**: Check simulation outputs for completeness

## Understanding train.py

Now let's understand what happens during training. The `train.py` script handles the entire training pipeline, from loading data to saving checkpoints.

#### Main Classes and Components

**1. `Trainer` - Main training orchestrator class**

```python
class Trainer:
    def __init__(self, cfg, dist, logger):
        # Initialize distributed training
        self.dist = dist
        self.device = dist.device
        
        # Set up dataloaders
        self._initialize_dataloaders(cfg)
        
        # Create model
        self._initialize_model(cfg)
        
        # Set up optimizer and scheduler
        self._initialize_optimizer(cfg)
        
        # Configure loss functions
        self._initialize_loss_functions(cfg)
        
        # Set up early stopping
        self._initialize_early_stopping(cfg)
```

#### The Training Pipeline

##### **Initialization Steps**

**1. Initialize Dataloaders** (`_initialize_dataloaders`)

```python
def _initialize_dataloaders(self, cfg):
    # Load global statistics for normalization
    self.stats = load_stats(self.stats_file)
    
    # Create dataset from partition files
    dataset = GraphDataset(
        file_paths,
        node_mean, node_std,
        edge_mean, edge_std,
        target_mean, target_std
    )
    
    # Create DistributedSampler for multi-GPU
    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True
    )
    
    # Create DataLoader
    dataloader = DataLoader(
        dataset,
        batch_size=cfg.training.batch_size,
        sampler=sampler,
        collate_fn=custom_collate_fn
    )
```

What happens:
- Loads partitioned graphs from train/val directories
- Applies normalization using statistics from preprocessing
- Creates distributed sampler for multi-GPU training
- Wraps in DataLoader for batch processing

**2. Initialize Model** (`_initialize_model`)

```python
def _initialize_model(self, cfg):
    # Get feature dimensions from statistics
    input_dim_nodes = len(self.stats["node_features"]["mean"])
    input_dim_edges = len(self.stats["edge_features"]["mean"])
    output_dim = len(cfg.dataset.graph.target_vars.node_features)
    
    # Create MeshGraphNet model
    self.model = MeshGraphNet(
        input_dim_nodes=input_dim_nodes,
        input_dim_edges=input_dim_edges,
        output_dim=output_dim,
        processor_size=cfg.model.num_message_passing_layers,
        hidden_dim_node_encoder=cfg.model.hidden_dim,
        hidden_dim_edge_encoder=cfg.model.hidden_dim,
        hidden_dim_node_decoder=cfg.model.hidden_dim,
        mlp_activation_fn=cfg.model.activation,
        do_concat_trick=cfg.performance.use_concat_trick,
        num_processor_checkpoint_segments=cfg.performance.checkpoint_segments,
    ).to(self.device)
    
    # Wrap for distributed training
    if world_size > 1:
        self.model = DistributedDataParallel(self.model, ...)
```

What happens:
- Creates X-MeshGraphNet (MeshGraphNet) architecture
- **Encoder**: Maps node/edge features to hidden dimension
- **Processor**: Message passing layers (graph convolution)
- **Decoder**: Maps back to target variables
- Wraps with DistributedDataParallel for multi-GPU training

**3. Initialize Optimizer** (`_initialize_optimizer`)

```python
def _initialize_optimizer(self, cfg):
    # AdamW optimizer with weight decay
    self.optimizer = optim.AdamW(
        self.model.parameters(),
        lr=cfg.training.start_lr,
        weight_decay=cfg.training.weight_decay,
        betas=(0.9, 0.99),
        eps=1e-8,
    )
    
    # Cosine annealing learning rate schedule
    self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
        self.optimizer,
        T_max=cfg.training.num_epochs,
        eta_min=cfg.training.end_lr
    )
    
    # Gradient scaler for mixed precision
    self.scaler = GradScaler() if use_cuda else None
```

What happens:
- **AdamW**: Adaptive learning rate with decoupled weight decay (applied directly to the weights rather than added to the loss as an L2 penalty)
- **Cosine Annealing**: Learning rate gradually decreases from `start_lr` to `end_lr`
- **GradScaler**: Scales the loss to avoid FP16 gradient underflow in mixed precision training (BF16 generally does not need loss scaling)
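
The cosine schedule has a simple closed form, sketched here in plain Python. The learning rates and epoch count are illustrative values, not defaults taken from the config, and `cosine_lr` is a hypothetical helper.

```python
import math


def cosine_lr(epoch, start_lr, end_lr, num_epochs):
    """Closed-form cosine annealing from start_lr (epoch 0) to end_lr."""
    cosine = math.cos(math.pi * epoch / num_epochs)
    return end_lr + 0.5 * (start_lr - end_lr) * (1 + cosine)


# Illustrative values only
schedule = [cosine_lr(e, start_lr=1e-3, end_lr=1e-5, num_epochs=100)
            for e in (0, 50, 100)]
print(schedule)
```

The rate falls slowly at first, fastest around the midpoint, and flattens out again as it approaches `end_lr`.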

**4. Initialize Loss Functions** (`_initialize_loss_functions`)

```python
def _initialize_loss_functions(self, cfg):
    # Load per-variable loss function configuration
    self.loss_functions = cfg.dataset.graph.target_vars.loss_functions
    # Example: ["L2", "L1"] for [PRESSURE, SWAT]
    
    # Create PyTorch loss objects
    self.loss_fn_objects = []
    for loss_func in self.loss_functions:
        if loss_func == "L1":
            self.loss_fn_objects.append(torch.nn.L1Loss())
        elif loss_func == "L2":
            self.loss_fn_objects.append(torch.nn.MSELoss())
        elif loss_func == "Huber":
            self.loss_fn_objects.append(
                torch.nn.HuberLoss(delta=self.huber_delta)
            )
```

What happens:
- Creates separate loss function for each target variable
- **L1 (MAE)**: Mean Absolute Error, good for saturation (bounded [0,1])
- **L2 (MSE)**: Mean Squared Error, good for pressure (unbounded)
- **Huber**: Combines L1 and L2, robust to outliers
- Allows weighting: `total_loss = w1 * loss_pressure + w2 * loss_swat`
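
How Huber loss blends the two regimes is easiest to see written out. The sketch below follows the standard definition (quadratic for errors up to `delta`, linear beyond it) in plain Python rather than calling `torch.nn.HuberLoss`; the function name `huber` is hypothetical.

```python
def huber(pred, target, delta=1.0):
    """Mean Huber loss: quadratic for small errors, linear for large ones."""
    total = 0.0
    for p, t in zip(pred, target):
        err = abs(p - t)
        if err <= delta:
            total += 0.5 * err ** 2              # L2-like region
        else:
            total += delta * (err - 0.5 * delta)  # L1-like region
    return total / len(pred)


small = huber([0.5], [0.0])  # error below delta
large = huber([3.0], [0.0])  # error above delta
print(small, large)
```

For the large error, plain squared error would contribute 4.5 (half of 3 squared), so the linear branch reduces the influence of outliers.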

##### **Training Loop** (`train()` method)

```python
def train(self):
    for epoch in range(1, num_epochs + 1):
        # Set epoch for distributed sampler
        self.train_sampler.set_epoch(epoch)
        
        # Training step
        train_loss = self.train_epoch()
        
        # Validation step (every validation_freq epochs)
        if epoch % validation_freq == 0:
            val_loss, val_denorm_loss, val_metrics = self.validate_epoch()
            
            # Save best model
            if val_loss < best_val_loss:
                save_checkpoint(**self.bst_ckpt_args, epoch=epoch)
                best_val_loss = val_loss
            
            # Check early stopping
            if self.early_stopping.should_stop():
                break
        
        # Save regular checkpoint
        if epoch % validation_freq == 0:
            save_checkpoint(**self.ckpt_args, epoch=epoch)
        
        # Update learning rate
        self.scheduler.step()
```
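
The loop above only shows the `should_stop()` call. A patience-based early stopper could look like the sketch below; the class layout, the `update` method, and the `patience` and `min_delta` parameters are assumptions for illustration, not the implementation in `train.py`.

```python
class EarlyStopping:
    """Stop when validation loss has not improved for `patience` checks."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_checks = 0

    def update(self, val_loss):
        if val_loss < self.best - self.min_delta:
            self.best = val_loss      # improvement: reset the counter
            self.bad_checks = 0
        else:
            self.bad_checks += 1      # no improvement at this check

    def should_stop(self):
        return self.bad_checks >= self.patience


stopper = EarlyStopping(patience=2)
stopped_at = None
for check, val_loss in enumerate([0.50, 0.40, 0.41, 0.42, 0.39], start=1):
    stopper.update(val_loss)
    if stopper.should_stop():
        stopped_at = check
        break
print(stopped_at, stopper.best)
```

Note that the run stops at the fourth check and never sees the later improvement to 0.39, which is the usual trade-off when choosing a small patience.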

**Single Training Epoch** (`train_epoch()`)

```python
def train_epoch(self):
    self.model.train()  # Set to training mode
    total_loss = 0.0
    
    for batch_idx, batch in enumerate(self.train_dataloader):
        partitions_list, labels = batch
        self.optimizer.zero_grad()
        batch_loss = 0.0  # Sum of scaled partition losses for this batch
        
        # Process each sample in batch
        for partitions in partitions_list:
            # Process each partition
            for partition in partitions:
                # 1. Move data to GPU
                partition = partition.to(self.device)
                
                # 2. Forward pass
                pred = self.model(partition.x, partition.edge_attr, partition)
                
                # 3. Get inner nodes (exclude halo)
                pred_inner = pred[partition.inner_node]
                target_inner = partition.y[partition.inner_node]
                
                # 4. Compute loss, scaled so that the accumulated gradient
                #    is an average over all partitions and samples
                loss = self.compute_weighted_loss(pred_inner, target_inner)
                loss = loss / (num_partitions * num_samples)
                batch_loss += loss.item()
                
                # 5. Backward pass (accumulate gradients)
                if use_cuda:
                    self.scaler.scale(loss).backward()
                else:
                    loss.backward()
        
        # 6. Update weights (after all partitions)
        if use_cuda:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()
        
        # Accumulate the full batch loss, not just the last partition's
        total_loss += batch_loss
    
    return total_loss / len(self.train_dataloader)
```

**Key Points:**
1. **Partition Processing**: Each sample may have multiple partitions
2. **Inner Nodes**: Only compute loss on inner nodes (not halo)
3. **Gradient Accumulation**: Accumulate gradients across partitions before updating
4. **Mixed Precision**: Uses GradScaler for loss scaling in FP16 training (BF16 generally does not need it)

**Single Validation Epoch** (`validate_epoch()`)

```python
def validate_epoch(self):
    self.model.eval()  # Set to evaluation mode
    total_loss = 0.0
    all_predictions = []
    all_targets = []
    
    with torch.no_grad():  # Disable gradient computation
        for batch in self.val_dataloader:
            partitions_list, labels = batch
            
            for partitions in partitions_list:
                for partition in partitions:
                    # Move data to GPU, then run the forward pass
                    partition = partition.to(self.device)
                    pred = self.model(partition.x, partition.edge_attr, partition)
                    pred_inner = pred[partition.inner_node]
                    target_inner = partition.y[partition.inner_node]
                    
                    # Compute loss
                    loss = self.compute_weighted_loss(pred_inner, target_inner)
                    total_loss += loss.item()
                    
                    # Collect for metrics
                    all_predictions.append(pred_inner.cpu().numpy())
                    all_targets.append(target_inner.cpu().numpy())
    
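    # Average the loss over all evaluated partitions
    avg_loss = total_loss / max(len(all_predictions), 1)
    
    # Compute per-variable metrics
    all_predictions = np.concatenate(all_predictions, axis=0)
    all_targets = np.concatenate(all_targets, axis=0)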
    
    metrics = {}
    for i, var_name in enumerate(target_names):
        mae = np.mean(np.abs(all_predictions[:, i] - all_targets[:, i]))
        rmse = np.sqrt(np.mean((all_predictions[:, i] - all_targets[:, i])**2))
        metrics[f"mae_{var_name}"] = mae
        metrics[f"rmse_{var_name}"] = rmse
    
    return avg_loss, metrics
```

**Key Points:**
1. **No Gradients**: Disables gradient computation for speed and memory
2. **Collect Predictions**: Saves all predictions for detailed metrics
3. **Per-Variable Metrics**: Computes MAE and RMSE for each target variable
4. **Denormalization**: Can compute metrics in physical units
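
Because denormalization is a linear map, an error measured in normalized space converts to physical units by multiplying with the target standard deviation. The sketch below checks this for MAE; the statistics and predictions are made-up numbers and the helper names are hypothetical.

```python
def mae(pred, target):
    """Mean absolute error of two equal-length lists."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def to_physical(values, mean, std):
    """Undo z-score normalization."""
    return [v * std + mean for v in values]


target_mean, target_std = 265.0, 12.0  # illustrative statistics
pred_norm = [0.10, -0.50, 1.20]
true_norm = [0.00, -0.40, 1.00]

mae_norm = mae(pred_norm, true_norm)
mae_phys = mae(to_physical(pred_norm, target_mean, target_std),
               to_physical(true_norm, target_mean, target_std))
print(mae_norm, mae_phys)
```

The mean cancels in the difference, so only the standard deviation affects the conversion; the same scaling applies to RMSE.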

#### Loss Computation

**Weighted Multi-Variable Loss:**

```python
def compute_weighted_loss(self, predictions, targets):
    losses_per_var = []
    
    for i, loss_fn in enumerate(self.loss_fn_objects):
        pred_var = predictions[:, i]    # e.g., pressure predictions
        target_var = targets[:, i]       # e.g., pressure targets
        
        # Compute loss for this variable
        loss = loss_fn(pred_var, target_var)
        losses_per_var.append(loss)
    
    # Apply weights and sum
    losses_tensor = torch.stack(losses_per_var)
    weighted_loss = torch.sum(self.target_weights * losses_tensor)
    
    return weighted_loss
```

**Example:**
```python
# Config: weights=[1.0, 1.0], loss_functions=["L2", "L1"]
# PRESSURE: L2 loss with weight 1.0
# SWAT: L1 loss with weight 1.0
total_loss = (1.0 * MSE(pred_pressure, true_pressure) +
              1.0 * MAE(pred_swat, true_swat))
```

#### Checkpointing

**Two Types of Checkpoints:**

1. **Regular Checkpoints** (every `validation_freq` epochs)
   - Saved to: `checkpoints/`
   - Used for resuming training
   - Contains: model, optimizer, scheduler, epoch

2. **Best Checkpoints** (when validation improves)
   - Saved to: `best_checkpoints/`
   - Used for inference
   - Contains: best model based on validation loss

**Checkpoint Contents:**
```python
{
    'epoch': 42,
    'model_state_dict': ...,
    'optimizer_state_dict': ...,
    'scheduler_state_dict': ...,
    'best_val_loss': 0.0123,
}
```

#### Distributed Training (Multi-GPU)

**Key Components:**

1. **DistributedDataParallel (DDP)**:
   - Replicates model on each GPU
   - Each GPU processes different data
   - Synchronizes gradients across GPUs

2. **DistributedSampler**:
   - Splits dataset across GPUs
   - Each GPU sees unique samples
   - `set_epoch()` ensures different shuffling each epoch

3. **Gradient Synchronization**:
   - After backward pass, gradients are averaged across GPUs
   - Ensures all GPUs update with same gradients
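
The way a distributed sampler hands out indices can be imitated in plain Python. This is a simplified illustration of the idea (shuffle with an epoch-dependent seed, pad, then stride by rank), not PyTorch's `DistributedSampler` itself; `shard_indices` is a hypothetical helper.

```python
import random


def shard_indices(dataset_size, world_size, rank, epoch, seed=0):
    """Return the sample indices one rank would process in one epoch."""
    indices = list(range(dataset_size))
    # Every rank uses the same seed, so all ranks agree on the order
    random.Random(seed + epoch).shuffle(indices)
    # Pad by repeating leading indices so the list divides evenly
    total = -(-dataset_size // world_size) * world_size  # ceiling division
    indices += indices[: total - dataset_size]
    return indices[rank::world_size]


shards = [shard_indices(10, world_size=4, rank=r, epoch=0) for r in range(4)]
for rank, shard in enumerate(shards):
    print(f"rank {rank}: {shard}")
```

Passing the epoch into the seed is what `set_epoch()` achieves: the order changes between epochs while remaining identical across ranks within one epoch.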

#### MLflow Logging

**Automatic Logging with PhysicsNeMo:**

```python
with LaunchLogger(name_space="train", epoch=epoch) as log:
    train_loss = self.train_epoch()
    log.log_epoch({
        "train_loss": train_loss,
        "learning_rate": lr,
        "best_val_loss": best_val_loss
    })
```

**Logged Metrics:**
- Training loss (per epoch)
- Validation loss (per validation step)
- Learning rate schedule
- Per-variable MAE and RMSE
- Denormalized metrics (physical units)

#### Common Training Issues

**Issue**: "CUDA out of memory"
- **Solutions**:
  - Reduce `batch_size`
  - Increase `num_partitions`
  - Reduce `hidden_dim`
  - Enable gradient checkpointing (`checkpoint_segments`)

**Issue**: "Loss becomes NaN"
- **Causes**:
  - Learning rate too high
  - Invalid normalization statistics
  - Numerical instability
- **Solutions**:
  - Reduce `start_lr`
  - Check global_stats.json for NaN values
  - Use gradient clipping
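
Gradient clipping by global norm rescales all gradients together when their combined norm exceeds a threshold. The plain-Python sketch below shows the arithmetic; `clip_by_global_norm` is a hypothetical helper and the gradient values are made up.

```python
import math


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale gradients so their combined L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (total_norm + eps))
    return [g * scale for g in grads], total_norm


clipped, norm_before = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
norm_after = math.sqrt(sum(g * g for g in clipped))
print(norm_before, clipped, norm_after)
```

In PyTorch the corresponding call is `torch.nn.utils.clip_grad_norm_`, placed between the backward pass and the optimizer step; when a `GradScaler` is in use, gradients are normally unscaled first with `scaler.unscale_(optimizer)`. Whether `train.py` exposes a config option for clipping is not shown in this section.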

**Issue**: "Validation loss not improving"
- **Causes**:
  - Overfitting
  - Insufficient model capacity
  - Learning rate too low
- **Solutions**:
  - Increase `weight_decay` if the model is overfitting (reduce it if the model is underfitting)
  - Increase `hidden_dim` or `num_message_passing_layers`
  - Increase `start_lr`
  - Add more training data

**Issue**: "Training very slow"
- **Solutions**:
  - Use multiple GPUs
  - Enable mixed precision training
  - Increase `batch_size` if memory allows
  - Reduce `num_message_passing_layers`

#### Training Output Directory Structure

```
outputs/XMGN_2D_Q5SP_Waterflood/
├── checkpoints/               # Regular checkpoints
│   ├── checkpoint.epoch_010.pt
│   ├── checkpoint.epoch_020.pt
│   └── ...
├── best_checkpoints/          # Best model
│   └── checkpoint.epoch_042.pt
├── mlruns/                    # MLflow tracking
│   └── 0/
│       └── run_id/
│           ├── metrics/
│           ├── params/
│           └── artifacts/
└── logs/                      # Training logs
```

This training infrastructure provides:
- Distributed multi-GPU training
- Mixed precision for speed
- Flexible loss functions per variable
- Comprehensive metrics and logging
- Checkpoint management and resume
- Early stopping to prevent overtraining

# Example 2: Norne Field

## Understanding preprocessor.py for Norne 

The preprocessing for Norne is fundamentally the same as the 2D example, but with important differences due to scale and complexity.

#### Key Differences for Norne

##### **1. Scale Differences**

```python
# 2D Example
num_nodes ≈ 100-1,000
num_edges ≈ 400-4,000
graph_size ≈ 10-50 MB

# Norne Field
num_nodes ≈ 45,000
num_edges ≈ 220,000 (including ~5,000 NNCs!)
graph_size ≈ 300-500 MB
```

##### **2. Non-Neighbor Connections (NNCs)**

The most important difference: **Faults create NNCs**!

```python
# In ReservoirGraphBuilder (called by preprocessor):
# Regular connections (6-face neighbors in 3D)
for each cell:
    connect to neighbors: [i-1, i+1, j-1, j+1, k-1, k+1]

# PLUS: Non-Neighbor Connections from .INIT file
# Example from Norne:
NNC connections: [(1234, 5678), (2345, 6789), ...]
# Cell 1234 connects to cell 5678 across a fault!
```

**How NNCs are Handled:**

```python
# In graph builder:
# 1. Read regular grid connections
regular_edges = build_6_face_connectivity(grid)

# 2. Read NNC data from .INIT file
nnc_data = read_nnc_transmissibilities(init_file)
# Contains: cell1_index, cell2_index, transmissibility

# 3. Add NNCs as additional edges
nnc_edges = [(cell1, cell2) for cell1, cell2, trans in nnc_data]
all_edges = regular_edges + nnc_edges

# 4. Create edge features
edge_features = []
for edge in all_edges:
    if edge in regular_edges:
        trans = get_transmissibility(edge, direction)  # TRANX, TRANY, or TRANZ
    else:  # NNC
        trans = nnc_transmissibilities[edge]  # TRANNNC
    edge_features.append(trans)
```

**Result**: Graph naturally handles faults as edges!
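
A runnable miniature of this idea is shown below for a small, fully active grid. It is a sketch under assumptions: `build_edges` is a hypothetical helper, cells use natural ordering, each connection is stored once, and the NNC pair is invented for the example.

```python
def build_edges(nx, ny, nz, nnc_pairs):
    """Return (edges, edge_types) for a fully active nx*ny*nz grid.

    Edge types: 0 = X, 1 = Y, 2 = Z, 3 = NNC.
    Cell index = i + j*nx + k*nx*ny (natural ordering).
    """
    def idx(i, j, k):
        return i + j * nx + k * nx * ny

    edges, edge_types = [], []
    # (di, dj, dk) offset for each direction, paired with its edge type
    directions = [((1, 0, 0), 0), ((0, 1, 0), 1), ((0, 0, 1), 2)]
    for k in range(nz):
        for j in range(ny):
            for i in range(nx):
                for (di, dj, dk), edge_type in directions:
                    ni, nj, nk = i + di, j + dj, k + dk
                    if ni < nx and nj < ny and nk < nz:
                        edges.append((idx(i, j, k), idx(ni, nj, nk)))
                        edge_types.append(edge_type)

    # Non-neighbor connections are appended as ordinary edges
    for cell_a, cell_b in nnc_pairs:
        edges.append((cell_a, cell_b))
        edge_types.append(3)
    return edges, edge_types


# 3x2x1 grid with one made-up fault connection between cells 0 and 5
edges, edge_types = build_edges(3, 2, 1, nnc_pairs=[(0, 5)])
print(len(edges), edge_types)
```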

##### **3. Multi-Phase Complexity**

```python
# 2D Example - 2 phases
dynamic_features = ["PRESSURE", "SWAT"]  # Water and oil (implicit)
target_vars = ["PRESSURE", "SWAT"]

# Norne - 3 phases
dynamic_features = ["PRESSURE", "SOIL", "SWAT", "SGAS"]
target_vars = ["PRESSURE", "SOIL", "SWAT", "SGAS"]
```

**Feature Dimension Growth:**
```python
# 2D Example
static_features: 5 (PERMX, PORV, X, Y, Z)
dynamic_features: 2 variables × 3 timesteps = 6
total_node_features: 5 + 6 = 11

# Norne
static_features: 7 (PERMX, PERMY, PERMZ, PORO, X, Y, Z)
dynamic_features: 5 variables × 3 timesteps = 15
total_node_features: 7 + 15 = 22
```

##### **4. Partitioning Strategy**

For Norne, partitioning is **essential**, not optional:

```python
# Recommended config for Norne:
preprocessing:
  num_partitions: 4-8  # More partitions for large model
  halo_size: 3         # Larger halo for better accuracy
```

**Why More Partitions?**
```python
# Memory estimate (rough):
45,000 nodes × 22 features × 4 bytes = ~4 MB (node features)
220,000 edges × 4 features × 4 bytes = ~3.5 MB (edge features)
45,000 × 4 targets × 4 bytes = ~0.7 MB (targets)
Plus: edge indices, metadata, model activations → ~10-15 MB total

# With 4 partitions:
Per-partition: ~2.5-4 MB (more manageable)
# With halo regions (size=3):
Per-partition: ~3-6 MB (includes overlap)
```
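
The raw tensor sizes in the estimate above follow from rows × columns × bytes per value. A short check, assuming float32 storage (4 bytes) and using a hypothetical helper `tensor_megabytes`:

```python
def tensor_megabytes(rows, cols, bytes_per_value=4):
    """Size of a dense rows x cols tensor in megabytes (1 MB = 1e6 bytes)."""
    return rows * cols * bytes_per_value / 1e6


node_mb = tensor_megabytes(45_000, 22)    # node features
edge_mb = tensor_megabytes(220_000, 4)    # edge features
target_mb = tensor_megabytes(45_000, 4)   # targets
raw_total_mb = node_mb + edge_mb + target_mb
print(node_mb, edge_mb, target_mb, raw_total_mb)
```

The three tensors add up to roughly 8 MB; edge indices, metadata, and activations account for the rest of the quoted total.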

**METIS Partitioning for Norne:**
```python
# METIS tries to:
# 1. Balance partition sizes (each ≈ 45,000/4 = 11,250 nodes)
# 2. Minimize edge cuts between partitions
# 3. Handle irregular connectivity (faults/NNCs)

# Example partition assignment:
partition_map = {
    0: nodes_in_west_region,    # ~11,000 nodes
    1: nodes_in_east_region,    # ~11,500 nodes
    2: nodes_in_north_region,   # ~11,200 nodes
    3: nodes_in_south_region,   # ~11,300 nodes
}

# Saved to: CASE_NAME_partitions.json
{
    "case_name": "NORNE_ATW2013_DOE_0001",
    "num_partitions": 4,
    "num_nodes": 45123,
    "partition_assignment": [1, 1, 1, 2, 2, ...]  # 1-indexed partition ID per node
}
```

##### **5. Processing Time**

```
2D Example: 1000 cases × 0.1s/case = ~2 minutes
Norne: 500 cases × 10s/case = ~83 minutes

Why slower?
- At least 45x more nodes to process
- Reading larger binary files (200MB vs 5MB per case)
- More complex graph operations
- Larger data to write to disk
```

#### The Preprocessing Pipeline for Norne

The **same 5 steps** as 2D, but with modifications:

##### **Step 1: Create Raw Graphs**

```python
# ReservoirGraphBuilder for Norne:
for case in cases:
    # 1. Read .EGRID (geometry)
    grid = read_eclipse_grid(f"{case}.EGRID")
    # → 46×112×22 = 113,344 total cells
    # → ~45,000 active cells
    # → Corner-point geometry (irregular hexahedra)
    
    # 2. Read .INIT (static properties)
    init_data = read_eclipse_init(f"{case}.INIT")
    # → PERMX, PERMY, PERMZ, PORO
    # → TRANX, TRANY, TRANZ
    # → NNC transmissibilities (TRANNNC)
    
    # 3. Read .UNRST (dynamic properties)
    restart_data = read_eclipse_restart(f"{case}.UNRST")
    # → 64 timesteps
    # → PRESSURE, SOIL, SWAT, SGAS per timestep
    # → ~2.9 million values per variable (~11.5 million per case)!
    
    # 4. Build graph
    graph = build_graph(
        nodes=active_cells,           # 45,000
        regular_edges=6_face_conn,    # ~215,000
        nnc_edges=fault_connections,  # ~5,000
        node_features=static_props,   # PERM*, PORO, coords
        edge_features=trans_values,   # TRAN*, TRANNNC
        dynamic_features=timestep_data # PRESSURE, S*
    )
    
    # 5. Create training sequences
    for t in range(2, 63):  # 64 timesteps → 62 training samples
        input_graph = create_input(
            static=static_props,
            dynamic_t0=restart_data[t-2],
            dynamic_t1=restart_data[t-1],
            dynamic_t2=restart_data[t]
        )
        target = restart_data[t+1]
        
        save_graph(f"{case}_{t:03d}.pt", input_graph, target)
```

Output: `500 cases × 62 timesteps = 31,000 graph files` (~15 GB total)

##### **Step 2: Create Partitions**

```python
# For each of 31,000 graphs:
graph = torch.load(f"NORNE_case_{timestep}.pt")

# Use METIS to partition into 4 subgraphs
cluster_data = pyg.loader.ClusterData(graph, num_parts=4)

# For each partition:
partitions = []
for part_idx in range(4):
    # Get inner nodes (actual partition)
    inner_nodes = cluster_data.partition.node_perm[
        cluster_data.partition.partptr[part_idx]:
        cluster_data.partition.partptr[part_idx + 1]
    ]
    # → ~11,250 nodes per partition
    
    # Add halo region (3-hop neighbors)
    part_nodes, part_edges, inner_mapping, edge_mask = (
        pyg.utils.k_hop_subgraph(
            inner_nodes,
            num_hops=3,  # halo_size
            edge_index=graph.edge_index,
            num_nodes=graph.num_nodes,
            relabel_nodes=True,  # Edge indices must be local to the partition
        )
    )
    # → ~13,000-14,000 nodes with halo (includes overlap)
    
    # Extract partition data
    partition = Data(
        x=graph.x[part_nodes],              # Node features
        edge_index=part_edges,               # Connectivity
        edge_attr=graph.edge_attr[edge_mask], # Edge features
        y=graph.y[part_nodes],               # Targets
        inner_node=inner_mapping,            # Which nodes are "real"
        part_node=part_nodes                 # Original node indices
    )
    
    partitions.append(partition)

# Save all 4 partitions together
torch.save(partitions, f"partitions_{case}_{timestep}.pt")
```

Output: 31,000 partition files (~200 GB total with redundancy from halos)

##### **Step 3: Split Samples by Case**

```python
# Norne specific:
# 500 cases → 
#   400 training (80%)
#   50 validation (10%)
#   50 testing (10%)

# Each case has 62 timesteps
# → Training: 400 × 62 = 24,800 graphs
# → Validation: 50 × 62 = 3,100 graphs
# → Testing: 50 × 62 = 3,100 graphs
```

##### **Step 4: Compute Global Statistics**

```python
# Critical for Norne due to heterogeneity:

# PERMX statistics:
mean_permx = 145.3 mD (millidarcies)
std_permx = 892.4 mD
min_permx = 0.01 mD
max_permx = 8500 mD

# → Spans 6 orders of magnitude!
# → Log-transform essential: PERMX:LOG10

# After log-transform:
log_permx_mean = 1.84 (log10 mD)
log_permx_std = 0.95
# → Much more normalized distribution
```

**Why Log-Transform Matters:**

```python
# Without log:
normalized_perm = (perm - mean) / std
# Problem: Most values near zero, outliers dominate
# Network struggles to learn

# With log:
log_perm = np.log10(perm + epsilon)  # epsilon=1e-10 to avoid log(0)
normalized_log_perm = (log_perm - log_mean) / log_std
# Benefit: More uniform distribution, easier to learn
```
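
The effect of the log transform can be checked on a few made-up permeability values using only the standard library. `log10_zscore` is a hypothetical helper and the `eps` value mirrors the one quoted above.

```python
import math


def log10_zscore(values, eps=1e-10):
    """Log10-transform positive values, then standardize them."""
    logs = [math.log10(v + eps) for v in values]
    mean = sum(logs) / len(logs)
    std = math.sqrt(sum((x - mean) ** 2 for x in logs) / len(logs))
    return [(x - mean) / std for x in logs], mean, std


perm_md = [0.01, 1.0, 100.0, 10000.0]  # made-up permeabilities in mD
scaled, log_mean, log_std = log10_zscore(perm_md)
print([round(s, 3) for s in scaled], round(log_mean, 3), round(log_std, 3))
```

Values spanning six orders of magnitude become evenly spaced after the transform, so no single large permeability dominates the statistics.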

#### Norne-Specific Preprocessing Challenges

##### **Challenge 1: Inactive Cells**

```python
# Total grid cells: 46 × 112 × 22 = 113,344
# Active cells (in reservoir): ~45,000
# Inactive cells (outside reservoir): ~68,000

# How handled:
# 1. EGRID file contains ACTNUM array
ACTNUM = [1, 1, 0, 0, 1, ...]  # 1=active, 0=inactive

# 2. Only process active cells
active_indices = np.where(ACTNUM == 1)[0]
nodes = cells[active_indices]

# 3. Renumber for graph
# Global index → Graph index mapping
global_to_graph = {global_idx: graph_idx 
                   for graph_idx, global_idx in enumerate(active_indices)}
```
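
The renumbering step can be tried on a tiny example. The sketch below uses plain lists instead of NumPy, and the `actnum` flags and edge list are made up; connections that touch an inactive cell are dropped.

```python
actnum = [1, 1, 0, 0, 1, 0, 1]  # 1 = active, 0 = inactive (made-up flags)

# Global indices of active cells, in order
active_indices = [g for g, flag in enumerate(actnum) if flag == 1]

# Global index -> compact graph index
global_to_graph = {g: n for n, g in enumerate(active_indices)}

# Remap edges, keeping only those between two active cells
global_edges = [(0, 1), (1, 2), (4, 6), (3, 4)]
graph_edges = [
    (global_to_graph[a], global_to_graph[b])
    for a, b in global_edges
    if a in global_to_graph and b in global_to_graph
]
print(active_indices, global_to_graph, graph_edges)
```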

##### **Challenge 2: Memory Management**

```python
# Strategies for handling large graphs:

# 1. Process in chunks
for case_batch in chunks(cases, chunk_size=10):
    process_batch(case_batch)
    torch.cuda.empty_cache()  # Free GPU memory

# 2. Use memory-efficient dtypes
node_features = torch.tensor(features, dtype=torch.float32)  # Not float64
edge_indices = torch.tensor(edges, dtype=torch.long)          # torch.long is int64, the index dtype PyG expects

# 3. Stream to disk immediately
graph = create_graph(...)
torch.save(graph, filename)
del graph  # Free memory immediately
```

##### **Challenge 3: NNC Edge Ordering**

```python
# NNCs don't follow regular ordering
# Regular edges: predictable (i, j, k) → (i+1, j, k)
# NNCs: arbitrary (1234) → (5678)

# Solution: Store edge type
edge_type = []
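# Index differences below assume global natural-order cell indices
# (not the renumbered graph indices of active cells)
for edge in edges: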
    if edge in regular_edges:
        if edge[1] - edge[0] == 1:
            edge_type.append(0)  # X-direction
        elif edge[1] - edge[0] == nx:
            edge_type.append(1)  # Y-direction
        elif edge[1] - edge[0] == nx * ny:
            edge_type.append(2)  # Z-direction
    else:
        edge_type.append(3)  # NNC (fault)

# Helps model learn directional vs fault connections differently
```

#### Output Directory Structure for Norne

```
outputs/XMGN_Norne_Field/
├── preprocessed_data/
│   ├── graphs/                          # Raw graphs
│   │   ├── NORNE_ATW2013_DOE_0001_000.pt
│   │   ├── NORNE_ATW2013_DOE_0001_001.pt
│   │   └── ... (31,000 files, ~15 GB)
│   │
│   ├── partitions/
│   │   ├── train/                       # 24,800 files
│   │   │   ├── partitions_NORNE_..._000.pt
│   │   │   └── ...
│   │   ├── val/                         # 3,100 files
│   │   └── test/                        # 3,100 files
│   │
│   ├── global_stats.json                # Normalization stats
│   ├── dataset_metadata.json            # Metadata
│   │
│   └── NORNE_ATW2013_DOE_*_partitions.json  # Partition assignments (500 files)
│       # For visualization in ResInsight
```

#### Preprocessing Time Estimates

```
Step 1: Create raw graphs
- 500 cases × 10s/case = ~83 minutes
- Bottleneck: Reading large binary files

Step 2: Create partitions  
- 31,000 graphs × 3s/graph = ~26 hours
- Bottleneck: METIS partitioning

Step 3: Split and organize
- ~5 minutes (mostly file operations)

Step 4: Compute statistics
- ~30 minutes (iterate through all training graphs)

Total: ~28-30 hours for full preprocessing
```

**Optimization Tips:**
```python
# Use multiple workers:
num_preprocess_workers: 8  # Process 8 cases in parallel

# Estimated speedup: ~4-5x
# New total time: ~6-8 hours
```

#### Verification After Preprocessing

```python
import json
import os

import numpy as np

# Check key outputs:
print("Graphs created:", len(os.listdir("graphs/")))
# Expected: 31,000

print("Partitions created:", len(os.listdir("partitions/train/")))
# Expected: 24,800

# Check statistics
with open("global_stats.json") as f:
    stats = json.load(f)

print("Node features:", len(stats["node_features"]["mean"]))
# Expected: 22 (7 static + 5 dynamic × 3 timesteps)

print("Edge features:", len(stats["edge_features"]["mean"]))
# Expected: 4 (TRANX, TRANY, TRANZ, TRANNNC)

print("Target features:", len(stats["target_features"]["mean"]))
# Expected: 4 (PRESSURE, SOIL, SWAT, SGAS)

# Check for NaN in statistics
assert not any(np.isnan(stats["node_features"]["mean"]))
assert not any(np.isnan(stats["edge_features"]["mean"]))
```
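
The two assertions above only cover the `mean` entries. A hedged sketch of a more general check that walks the whole statistics file and reports every NaN or infinite value; the `std` key and the demo values are assumptions made for illustration, since only the `mean` entries are shown in this tutorial:

```python
import json
import math
import os
import tempfile


def find_non_finite(obj, path=""):
    """Recursively collect JSON paths whose numeric value is NaN or infinite."""
    bad = []
    if isinstance(obj, dict):
        for key, value in obj.items():
            bad += find_non_finite(value, f"{path}/{key}")
    elif isinstance(obj, list):
        for i, value in enumerate(obj):
            bad += find_non_finite(value, f"{path}[{i}]")
    elif isinstance(obj, float) and not math.isfinite(obj):
        bad.append(path)
    return bad


# Demo on a small, made-up stats file (keys and values are placeholders).
demo_stats = {
    "node_features": {"mean": [0.1, 0.2], "std": [1.0, float("nan")]},
    "edge_features": {"mean": [0.0], "std": [float("inf")]},
}
with tempfile.TemporaryDirectory() as tmp:
    stats_path = os.path.join(tmp, "global_stats.json")
    with open(stats_path, "w") as f:
        json.dump(demo_stats, f)  # Python's json module writes NaN/Infinity literals
    with open(stats_path) as f:
        problems = find_non_finite(json.load(f))

print(problems)
```

An empty list means every value in the file is finite; any reported path points at the feature whose statistics need attention.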

This preprocessing pipeline transforms 500 complex 3D reservoir simulations with faults into ~31,000 graph structures ready for training!

## Understanding train.py for Norne 

Training for Norne uses the **same training script** as the 2D example, but the scale and complexity require different strategies and configurations.

#### Key Differences for Norne Training

##### **1. Model Capacity Requirements**

```python
# 2D Example Configuration
model:
  num_message_passing_layers: 3
  hidden_dim: 64
  
# Typical model size: ~500K parameters

# Norne Configuration  
model:
  num_message_passing_layers: 4  # Deeper for complex patterns
  hidden_dim: 128                 # Wider for more features
  
# Typical model size: ~2-3M parameters

# Why more capacity needed?
# - More input features (22 vs 11)
# - More output features (4 vs 2)
# - Longer-range interactions (3D geometry)
# - More complex physics (3-phase, faults)
```

##### **2. Memory and Batch Size**

```python
# Memory requirements per sample:

# 2D Example (1 partition):
graph_size = 1,000 nodes × 11 features = ~44 KB
model_activations = ~5 MB
total_per_sample = ~5-10 MB

# Norne (4 partitions with halos):
partition_size = 13,000 nodes × 22 features = ~1.1 MB per partition
4 partitions = ~4.5 MB
model_activations = ~100-200 MB per partition
total_per_sample = ~400-800 MB!

# Consequence:
training:
  batch_size: 1  # Must use batch_size=1 for Norne!
```
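
The feature-memory figures above follow from nodes × features × 4 bytes (float32). A small sketch of that calculation, assuming dense float32 storage and 1 MB = 10^6 bytes (the helper name is illustrative):

```python
def feature_memory_mb(num_nodes, num_features, bytes_per_value=4):
    """Memory of a dense float32 node-feature matrix, in MB (1 MB = 1e6 bytes)."""
    return num_nodes * num_features * bytes_per_value / 1e6


mem_2d = feature_memory_mb(1_000, 11)                  # 2D example, single partition
mem_norne_partition = feature_memory_mb(13_000, 22)    # one Norne partition with halo
mem_norne_sample = 4 * mem_norne_partition             # four partitions per sample

print(f"2D graph:        {mem_2d * 1000:.0f} KB")
print(f"Norne partition: {mem_norne_partition:.2f} MB")
print(f"Norne sample:    {mem_norne_sample:.2f} MB")
```

The input features are a small part of the budget; the activation memory quoted above dominates and is what forces `batch_size: 1`.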

**Why batch_size=1 Works:**
```python
# Even with batch_size=1, we have:
# - 4 partitions per sample
# - Multiple GPUs processing in parallel
# → Effective parallelism maintained

# Example with 4 GPUs:
# GPU 0: processes partition 0 of sample
# GPU 1: processes partition 1 of sample
# GPU 2: processes partition 2 of sample
# GPU 3: processes partition 3 of sample
# → All GPUs busy, efficient utilization
```

##### **3. Training Time Estimates**

```
Component                2D Example      Norne Field
─────────────────────────────────────────────────────
Forward pass (1 sample)   0.1s            2-5s
Backward pass (1 sample)  0.2s            5-10s
Total per sample          0.3s            7-15s
─────────────────────────────────────────────────────
Samples per epoch         800             24,800
Time per epoch (1 GPU)    4 min           ~70 hours (!)
Time per epoch (4 GPUs)   2 min           ~20 hours
Time per epoch (8 GPUs)   1 min           ~10 hours
─────────────────────────────────────────────────────
Target epochs             1000            500
Total training (1 GPU)    ~67 hours       ~35,000 hours (!!)
Total training (4 GPUs)   ~33 hours       ~10,000 hours (!!)
Total training (8 GPUs)   ~17 hours       ~5,000 hours (~208 days)
```

**Actual Training Time (with early stopping):**
```
# Typically converges in 100-200 epochs for Norne
# With 8 GPUs:
100 epochs × 10 hours/epoch = ~1,000 hours = ~42 days
# With early stopping patience=30:
# Realistic: 150 epochs × 10 hours/epoch = ~1,500 hours = ~63 days
```
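
The per-epoch and total figures can be reproduced with two small helpers. This is a sketch under stated assumptions: roughly 10 s per sample (the middle of the 7-15 s range) and a 3.5x speedup on 4 GPUs; both helper names are illustrative.

```python
def epoch_hours(num_samples, sec_per_sample, speedup=1.0):
    """Wall-clock hours for one epoch given a per-sample time and a parallel speedup."""
    return num_samples * sec_per_sample / speedup / 3600.0


def training_days(hours_per_epoch, num_epochs):
    """Total training time in days for a fixed time per epoch."""
    return hours_per_epoch * num_epochs / 24.0


# Figures quoted above: 24,800 samples, ~10 s per sample.
single_gpu = epoch_hours(24_800, 10.0)               # about 69 hours
four_gpus = epoch_hours(24_800, 10.0, speedup=3.5)   # about 20 hours
days_100_epochs = training_days(10.0, 100)           # about 42 days at 10 h/epoch

print(f"1 GPU:  {single_gpu:.1f} h/epoch")
print(f"4 GPUs: {four_gpus:.1f} h/epoch")
print(f"100 epochs at 10 h/epoch: {days_100_epochs:.1f} days")
```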

##### **4. Multi-GPU Training Strategy**

**Distributed Data Parallel (DDP) for Norne:**

```python
# Launch training with 4 GPUs:
# torchrun --nproc_per_node=4 src/train.py --config-name=config_norne

# What happens:
# 1. Model replicated on each GPU
Rank 0 (GPU 0): model_copy_0
Rank 1 (GPU 1): model_copy_1
Rank 2 (GPU 2): model_copy_2
Rank 3 (GPU 3): model_copy_3

# 2. Data distributed across GPUs
DistributedSampler splits 24,800 samples:
Rank 0: samples [0, 4, 8, 12, ...]      # 6,200 samples
Rank 1: samples [1, 5, 9, 13, ...]      # 6,200 samples
Rank 2: samples [2, 6, 10, 14, ...]     # 6,200 samples
Rank 3: samples [3, 7, 11, 15, ...]     # 6,200 samples

# 3. Forward pass (parallel)
Each GPU: processes its assigned samples independently

# 4. Backward pass (synchronized)
Each GPU: computes gradients for its samples
All GPUs: average gradients via AllReduce
Result: synchronized gradients on all GPUs

# 5. Optimizer step (synchronized)
Each GPU: updates model with averaged gradients
Result: all model copies stay in sync
```
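
The strided split in step 2 can be emulated without a GPU. The sketch below is a plain-Python stand-in for what PyTorch's `DistributedSampler` does when shuffling is disabled; with shuffling enabled, the indices are permuted each epoch before the same strided split is applied. `shard_indices` is an illustrative helper, not part of `train.py`.

```python
def shard_indices(num_samples, world_size, rank):
    """Indices a given rank would see with strided, unshuffled sharding."""
    return list(range(rank, num_samples, world_size))


world_size = 4
shards = [shard_indices(24_800, world_size, rank) for rank in range(world_size)]

for rank, shard in enumerate(shards):
    print(f"Rank {rank}: {shard[:4]} ... ({len(shard)} samples)")

# Every sample is assigned to exactly one rank.
all_indices = sorted(i for shard in shards for i in shard)
print("Full coverage:", all_indices == list(range(24_800)))
```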

**Communication Pattern:**

```python
# Every training step:
# 1. Forward pass: No communication (independent)
# 2. Loss computation: No communication
# 3. Backward pass: AllReduce of gradients
#    - Size: ~2-3M parameters × 4 bytes = ~8-12 MB
#    - Time: ~10-50ms depending on interconnect
# 4. Optimizer step: No communication (uses synced gradients)

# Key insight: Communication overhead is small compared to computation!
# Computation: 5-10 seconds
# Communication: 0.01-0.05 seconds
# Efficiency: >99%
```
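
A quick way to sanity-check the communication estimate, assuming FP32 gradients and no overlap between communication and computation (DDP normally overlaps them, so this is a pessimistic bound). Both helper names are illustrative.

```python
def allreduce_payload_mb(num_params, bytes_per_param=4):
    """Gradient payload exchanged per step, in MB (1 MB = 1e6 bytes)."""
    return num_params * bytes_per_param / 1e6


def compute_fraction(compute_s, comm_s):
    """Fraction of a step spent computing if communication is not overlapped."""
    return compute_s / (compute_s + comm_s)


small_model = allreduce_payload_mb(2_000_000)
large_model = allreduce_payload_mb(3_000_000)
worst_case = compute_fraction(compute_s=5.0, comm_s=0.05)
best_case = compute_fraction(compute_s=10.0, comm_s=0.01)

print(f"Payload: {small_model:.0f}-{large_model:.0f} MB per step")
print(f"Efficiency: {worst_case:.1%} to {best_case:.1%}")
```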

#### Norne-Specific Training Configurations

##### **Learning Rate Schedule**

```python
# For Norne, use more conservative learning rates:

training:
  num_epochs: 500
  start_lr: 5e-5   # Lower than 2D (1e-4)
  end_lr: 1e-6     # Same as 2D
  weight_decay: 1e-3

# Why lower start_lr?
# - More complex model (easier to destabilize)
# - More features (higher dimensional space)
# - More parameters (larger gradient magnitudes)

# Cosine annealing schedule:
# lr = end_lr + 0.5 × (start_lr - end_lr) × (1 + cos(π × epoch / num_epochs))
epoch  0: lr = 5e-5
epoch 50: lr = 4.9e-5
epoch 100: lr = 4.5e-5
epoch 250: lr = 2.55e-5
epoch 400: lr = 5.7e-6
epoch 500: lr = 1e-6
```
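
For reference, the standard closed form of cosine annealing (the one used by PyTorch's `CosineAnnealingLR` with a minimum learning rate) can be evaluated directly. This sketch assumes the training script follows that form with `num_epochs` as the period; a different variant, such as one with warmup or restarts, would give different values.

```python
import math


def cosine_lr(epoch, num_epochs=500, start_lr=5e-5, end_lr=1e-6):
    """Standard cosine annealing from start_lr (epoch 0) to end_lr (last epoch)."""
    cos_term = 0.5 * (1.0 + math.cos(math.pi * epoch / num_epochs))
    return end_lr + (start_lr - end_lr) * cos_term


for epoch in (0, 50, 100, 250, 400, 500):
    print(f"epoch {epoch:3d}: lr = {cosine_lr(epoch):.2e}")
```

The curve is flat near both ends and steepest at the midpoint, where the learning rate is the average of `start_lr` and `end_lr`.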

##### **Loss Function Configuration**

```python
# Norne has 4 target variables, each with its own loss:

dataset:
  graph:
    target_vars:
      node_features: ["PRESSURE", "SOIL", "SWAT", "SGAS"]
      weights: [1.0, 1.0, 1.0, 1.0]  # Equal weighting
      loss_functions: ["L2", "L1", "L1", "L1"]

# Why this choice?
# - PRESSURE: L2 (MSE) - continuous, wide range
# - SOIL: L1 (MAE) - bounded [0,1], sparse in some regions
# - SWAT: L1 (MAE) - bounded [0,1]
# - SGAS: L1 (MAE) - bounded [0,1], often very small values

# Total loss computation:
total_loss = (
    1.0 × MSE(pred_pressure, true_pressure) +
    1.0 × MAE(pred_soil, true_soil) +
    1.0 × MAE(pred_swat, true_swat) +
    1.0 × MAE(pred_sgas, true_sgas)
)
```
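
The weighted sum above can be written as a short runnable function. This is a plain-Python sketch with made-up values, not the loss code from `train.py`; the dictionary layout and helper names are assumptions chosen for readability.

```python
def mse(pred, true):
    """Mean squared error (the "L2" loss)."""
    return sum((p - t) ** 2 for p, t in zip(pred, true)) / len(pred)


def mae(pred, true):
    """Mean absolute error (the "L1" loss)."""
    return sum(abs(p - t) for p, t in zip(pred, true)) / len(pred)


LOSS_FNS = {"L2": mse, "L1": mae}


def total_loss(pred, true, names, weights, loss_functions):
    """Weighted sum of per-variable losses, one loss type per variable."""
    total = 0.0
    for name, weight, fn_name in zip(names, weights, loss_functions):
        total += weight * LOSS_FNS[fn_name](pred[name], true[name])
    return total


names = ["PRESSURE", "SOIL", "SWAT", "SGAS"]
# Made-up normalized predictions and targets for two nodes.
pred = {"PRESSURE": [0.5, 1.0], "SOIL": [0.25, 0.5], "SWAT": [0.5, 0.25], "SGAS": [0.0, 0.25]}
true = {"PRESSURE": [0.0, 1.0], "SOIL": [0.5, 0.5], "SWAT": [0.5, 0.5], "SGAS": [0.0, 0.0]}

loss = total_loss(pred, true, names, [1.0, 1.0, 1.0, 1.0], ["L2", "L1", "L1", "L1"])
print("total loss:", loss)
```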

**Loss Weighting Strategy:**

You might want to adjust weights based on variable importance:

```python
# Option 1: Equal weights (default)
weights: [1.0, 1.0, 1.0, 1.0]

# Option 2: Emphasize pressure (most important for production)
weights: [2.0, 1.0, 1.0, 1.0]

# Option 3: Normalize by typical ranges
# PRESSURE: 100-400 bar → range ~ 300
# SOIL: 0-1 → range ~ 1
# SWAT: 0-1 → range ~ 1
# SGAS: 0-0.5 → range ~ 0.5
weights: [0.0033, 1.0, 1.0, 2.0]  # Inverse of ranges (1/300, 1/1, 1/1, 1/0.5)
```
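
Inverse-range weights are a one-line calculation. The sketch below uses the typical ranges listed in Option 3; note that this scaling only matters when losses are computed in physical units, and for an L2 term the squared error grows with the square of the range, so `1 / range**2` may be the more consistent choice there.

```python
# Typical value ranges quoted in Option 3 (physical units).
ranges = {"PRESSURE": 300.0, "SOIL": 1.0, "SWAT": 1.0, "SGAS": 0.5}

# Weight each variable by the inverse of its range.
weights = {name: 1.0 / value_range for name, value_range in ranges.items()}

for name, weight in weights.items():
    print(f"{name:<9} weight = {weight:.4f}")
```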

##### **Early Stopping Configuration**

```python
# Norne requires more patience:

training:
  early_stopping:
    patience: 30  # vs 20 for 2D
    min_delta: 1e-6

# Why more patience?
# - Larger model takes longer to converge
# - More data means more epochs to see all samples
# - Validation loss may plateau then improve

# Example training curve:
Epoch 10: val_loss = 0.0156
Epoch 20: val_loss = 0.0121
Epoch 30: val_loss = 0.0098
Epoch 40: val_loss = 0.0092  # Plateau starts
Epoch 50: val_loss = 0.0089
Epoch 60: val_loss = 0.0087
Epoch 70: val_loss = 0.0086  # Still slowly improving
Epoch 80: val_loss = 0.0081  # Breaks through!
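# The plateau improvements (0.0001-0.0003 per validation) are small:
# a short patience combined with a larger min_delta would have stopped before epoch 80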
```
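
The interaction between `patience` and `min_delta` is easier to see in code. This is a minimal sketch of a common early-stopping rule, assuming patience is counted in validation checks; the class is illustrative and the actual logic in `train.py` may differ.

```python
class EarlyStopping:
    """Stop when validation loss has not improved by min_delta for `patience` checks."""

    def __init__(self, patience=30, min_delta=1e-6):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_checks = 0

    def step(self, val_loss):
        """Record one validation result; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience


# Validation curve from the example above (epochs 10, 20, ..., 80).
curve = [0.0156, 0.0121, 0.0098, 0.0092, 0.0089, 0.0087, 0.0086, 0.0081]

patient = EarlyStopping(patience=30, min_delta=1e-6)
patient_stops = [patient.step(v) for v in curve]

# A stricter setting treats the plateau as "no improvement" and stops early.
strict = EarlyStopping(patience=2, min_delta=1e-3)
strict_stop_index = next(i for i, v in enumerate(curve) if strict.step(v))

print("patient run stopped:", any(patient_stops))
print("strict run stopped at epoch:", 10 * (strict_stop_index + 1))
```

With the strict settings the run ends at epoch 50 and never reaches the improvement at epoch 80, which is the failure mode a generous patience guards against.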

##### **Validation Frequency**

```python
# Balance between monitoring and training time:

training:
  validation_freq: 10  # Validate every 10 epochs

# Time impact:
# - Training epoch: 10 hours
# - Validation: 2 hours (50 validation cases, 3,100 graphs)
# - Total per validation: 12 hours

# Alternatives:
validation_freq: 5   # More frequent, slower training
validation_freq: 20  # Less frequent, faster but less monitoring
```

#### Memory Optimization Techniques for Norne

##### **1. Gradient Checkpointing**

```python
performance:
  checkpoint_segments: 4  # More segments = less memory

# How it works:
# Normal: Store all activations during forward pass
# Memory: num_layers × hidden_dim × num_nodes = large!

# With checkpointing:
# - Divide model into 4 segments
# - Only store activations at segment boundaries
# - Recompute intermediate activations during backward

# Trade-off:
# Memory: Reduced by ~50-70%
# Speed: Slowed by ~10-20% (recomputation overhead)
# → Worth it for large models!
```

##### **2. Mixed Precision Training**

```python
# Automatically enabled for CUDA:
self.scaler = GradScaler()

# Forward pass in FP16/BF16:
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    pred = self.model(x, edge_attr, graph)

# Backward pass with scaling:
self.scaler.scale(loss).backward()
self.scaler.step(optimizer)
self.scaler.update()

# Benefits for Norne:
# - Memory: up to ~2x reduction for activations (16-bit vs FP32)
# - Speed: 1.5-2x faster (tensor cores)
# - Accuracy: Minimal impact with proper scaling

# Note: BF16 often better than FP16 for reservoir sim
# (wider dynamic range handles pressure variations better)
```

##### **3. Concatenation Trick**

```python
performance:
  use_concat_trick: true

# Standard MeshGraphNet:
# node_update = MLP(node_feat)
# edge_update = MLP(edge_feat)
# Requires: 2 separate MLP passes

# With concat trick:
# combined = concat([node_feat, edge_feat])
# updates = MLP(combined)  # Single pass!
# node_update, edge_update = split(updates)

# Benefits:
# - Fewer MLP calls
# - Better GPU utilization
# - ~10-15% speedup
```

#### Monitoring Training for Norne

##### **Key Metrics to Watch**

```python
# 1. Training Loss
# Should decrease smoothly
epoch 10: train_loss = 0.0234
epoch 20: train_loss = 0.0178
epoch 30: train_loss = 0.0145
# Good: Steady decrease
# Bad: Oscillating or NaN

# 2. Validation Loss  
# Should track training loss
epoch 10: val_loss = 0.0241 (train: 0.0234) → gap: 0.0007
epoch 20: val_loss = 0.0185 (train: 0.0178) → gap: 0.0007
# Good: Small gap (< 10%)
# Bad: Growing gap (overfitting)

# 3. Per-Variable Metrics (Denormalized)
# PRESSURE RMSE: 5-15 bar (acceptable for 100-400 bar range)
# SOIL MAE: 0.05-0.10 (5-10% error)
# SWAT MAE: 0.05-0.10
# SGAS MAE: 0.02-0.05

# 4. Learning Rate
# Should follow cosine schedule
epoch 0: lr = 5e-5
epoch 250: lr = 2.55e-5
epoch 500: lr = 1e-6
```
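
The "small gap" criterion can be checked automatically from logged losses. A minimal sketch using the example values above and the 10% threshold; the helper name and the history format are illustrative.

```python
def relative_gap(train_loss, val_loss):
    """Validation/training gap as a fraction of the training loss."""
    return (val_loss - train_loss) / train_loss


# (epoch, train_loss, val_loss) taken from the example above.
history = [(10, 0.0234, 0.0241), (20, 0.0178, 0.0185)]

for epoch, train_loss, val_loss in history:
    gap = relative_gap(train_loss, val_loss)
    status = "ok" if gap < 0.10 else "possible overfitting"
    print(f"epoch {epoch}: gap = {gap:.1%} ({status})")
```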

##### **MLflow Metrics Logged**

```python
# Training metrics (every epoch):
- train_loss
- learning_rate
- best_val_loss

# Validation metrics (every validation_freq epochs):
- val_loss
- val_denorm_loss

# Per-variable metrics (normalized):
- val_mae_pressure, val_rmse_pressure
- val_mae_soil, val_rmse_soil
- val_mae_swat, val_rmse_swat
- val_mae_sgas, val_rmse_sgas

# Per-variable metrics (denormalized, physical units):
- val_mae_pressure_denorm, val_rmse_pressure_denorm
- val_mae_soil_denorm, val_rmse_soil_denorm
- val_mae_swat_denorm, val_rmse_swat_denorm
- val_mae_sgas_denorm, val_rmse_sgas_denorm
```
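
The relationship between normalized and denormalized metrics can be illustrated with a small example. It assumes z-score normalization, `(x - mean) / std`, in which case MAE and RMSE in physical units equal the normalized values multiplied by `std`. The statistics and predictions below are made up, and the normalization scheme itself is an assumption about the pipeline.

```python
import math


def denormalize(values, mean, std):
    """Invert z-score normalization: x = z * std + mean."""
    return [v * std + mean for v in values]


def mae(pred, true):
    return sum(abs(p - t) for p, t in zip(pred, true)) / len(pred)


def rmse(pred, true):
    return math.sqrt(sum((p - t) ** 2 for p, t in zip(pred, true)) / len(pred))


mean, std = 250.0, 50.0            # made-up pressure statistics (bar)
pred_norm = [0.0, 0.5, -1.0]       # made-up normalized predictions
true_norm = [0.25, 0.5, -0.5]      # made-up normalized targets

mae_norm = mae(pred_norm, true_norm)
rmse_norm = rmse(pred_norm, true_norm)
mae_denorm = mae(denormalize(pred_norm, mean, std), denormalize(true_norm, mean, std))
rmse_denorm = rmse(denormalize(pred_norm, mean, std), denormalize(true_norm, mean, std))

print(f"MAE:  {mae_norm:.3f} normalized -> {mae_denorm:.2f} bar")
print(f"RMSE: {rmse_norm:.3f} normalized -> {rmse_denorm:.2f} bar")
```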

#### Common Training Issues for Norne

##### **Issue 1: CUDA Out of Memory**

```python
# Error: RuntimeError: CUDA out of memory. Tried to allocate X GB

# Solutions (in order of preference):
# 1. Increase num_partitions
num_partitions: 8  # vs 4 (roughly halves memory per partition; halo nodes add some overhead)

# 2. Increase checkpoint_segments
checkpoint_segments: 6  # vs 4

# 3. Reduce hidden_dim
hidden_dim: 96  # vs 128 (reduces model size)

# 4. Reduce batch_size (if > 1)
batch_size: 1  # Already at minimum for Norne

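# 5. Prefer GPUs with more memory per device, even if that means fewer GPUs
# DDP replicates the model on every GPU, so adding GPUs does not reduce per-GPU memory;
# 4 large-memory GPUs can work where 8 small-memory GPUs fail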
```

##### **Issue 2: Slow Training**

```python
# If training is slower than expected:

# Check 1: GPU utilization
nvidia-smi
# Should show ~95%+ GPU utilization
# If low: bottleneck is data loading or CPU

# Solution for data loading:
num_preprocess_workers: 8  # More workers
pin_memory: true           # Faster CPU→GPU transfer

# Check 2: Multi-GPU efficiency
# With 4 GPUs, expect ~3.5x speedup (not 4x due to overhead)
# If < 3x: Communication bottleneck

# Solution:
# Use faster interconnect (NVLink > PCIe)
# Or reduce num_partitions (less communication)

# Check 3: Gradient checkpointing overhead
checkpoint_segments: 2  # Reduce if memory allows
# Fewer segments = less recomputation = faster
```

##### **Issue 3: Validation Loss Plateau**

```python
# Validation loss stops improving:

# Cause 1: Learning rate too low
# Solution: Increase start_lr
start_lr: 1e-4  # vs 5e-5

# Cause 2: Model capacity insufficient
# Solution: Increase model size
hidden_dim: 192  # vs 128
num_message_passing_layers: 5  # vs 4

# Cause 3: Overfitting
# Solution: Increase regularization
weight_decay: 5e-3  # vs 1e-3

# Cause 4: Not enough training data
# Solution: Add more simulation cases
```

##### **Issue 4: NaN Loss**

```python
# Loss becomes NaN during training

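# Cause 1: Gradient explosion
# Check: Large gradients (run after loss.backward())
for name, param in model.named_parameters():
    if param.grad is not None:  # Unused or frozen parameters have no gradient
        print(name, param.grad.abs().max().item())  # If > 100: problem!

# Solution: Gradient clipping (before optimizer.step(); with GradScaler,
# call scaler.unscale_(optimizer) first so the norm is computed on unscaled gradients)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)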

# Cause 2: Invalid normalization statistics
# Check: global_stats.json for NaN or inf values

# Solution: Rerun preprocessing with data validation

# Cause 3: Learning rate too high
# Solution: Reduce start_lr
start_lr: 1e-5  # Very conservative

# Cause 4: Numerical instability in model
# Solution: Use BF16 instead of FP16 (wider dynamic range)
```

#### Checkpoint Management for Norne

```python
# With long training times, checkpoint management is critical:

# Regular checkpoints (every 10 epochs):
checkpoints/
├── checkpoint.epoch_010.pt  # ~50 MB
├── checkpoint.epoch_020.pt
├── checkpoint.epoch_030.pt
└── ...

# Best checkpoint (validation improvements):
best_checkpoints/
└── checkpoint.epoch_087.pt  # Best val_loss

# Checkpoint cleanup strategy:
# - Keep last 3 regular checkpoints (in case of corruption)
# - Keep all best checkpoints (for analysis)
# - Delete old regular checkpoints

# Resume training:
training:
  resume: true  # Automatically loads latest checkpoint

# Load specific checkpoint for inference:
inference:
  checkpoint_path: "outputs/.../best_checkpoints/checkpoint.epoch_087.pt"
```
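
The cleanup strategy above (keep the last 3 regular checkpoints, delete older ones) can be scripted. This is a minimal sketch that assumes the `checkpoint.epoch_NNN.pt` naming shown above; `prune_checkpoints` is an illustrative helper, not part of the training code, and it should only be pointed at the regular `checkpoints/` directory, never at `best_checkpoints/`.

```python
import re
import tempfile
from pathlib import Path

CHECKPOINT_RE = re.compile(r"checkpoint\.epoch_(\d+)\.pt$")


def prune_checkpoints(directory, keep_last=3):
    """Delete all but the newest `keep_last` checkpoints; return deleted file names."""
    found = []
    for path in Path(directory).iterdir():
        match = CHECKPOINT_RE.match(path.name)
        if match:
            found.append((int(match.group(1)), path))
    found.sort()  # oldest epoch first
    to_delete = found[:-keep_last] if keep_last > 0 else found
    deleted = []
    for _, path in to_delete:
        path.unlink()
        deleted.append(path.name)
    return deleted


# Demo on empty placeholder files in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    for epoch in (10, 20, 30, 40, 50):
        (Path(tmp) / f"checkpoint.epoch_{epoch:03d}.pt").touch()
    deleted = prune_checkpoints(tmp, keep_last=3)
    remaining = sorted(p.name for p in Path(tmp).iterdir())

print("deleted:  ", deleted)
print("remaining:", remaining)
```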

#### Training Monitoring Dashboard

```bash
# Launch MLflow UI (in separate terminal):
cd outputs/XMGN_Norne_Field
mlflow ui --host 0.0.0.0 --port 5000

# Open browser: http://localhost:5000

# What to monitor:
# 1. Training loss curve (should be smooth decrease)
# 2. Validation loss vs training loss (check gap)
# 3. Per-variable RMSE (check each variable converging)
# 4. Learning rate schedule (verify cosine annealing)
# 5. Training time per epoch (should be consistent)
```

This training infrastructure handles the complexity and scale of Norne while providing comprehensive monitoring and error handling!