# Notebook 2 - Training DoMINO Model on the Ahmed body surface dataset

In this notebook, we first explain the DoMINO architecture, a multi-scale, iterative neural operator designed for modeling large-scale engineering simulations. We break down its key components, including local geometry representations, multi-scale point convolution kernels, and its efficient handling of complex geometries. We then train the model on the **Ahmed body surface dataset**, a widely used dataset in automotive aerodynamics simulations. *As noted in the previous notebook, this dataset was created by the NVIDIA PhysicsNeMo development team and differs from similar datasets hosted on cloud platforms such as AWS.*

The DoMINO model can be trained on both volume fields (such as velocity and pressure) and surface fields (such as pressure and wall shear stress). However, for simplicity and for educational purposes, this notebook *focuses solely on training the surface fields* of the Ahmed body surface dataset.

## Table of Contents
- [DoMINO Architecture](#domino-architecture)
  - [Global Geometry Representation](#global-geometry-representation)
  - [Local Geometry Representation](#local-geometry-representation)
  - [Basis Function Neural Network](#basis-function-neural-network)
  - [Concatenating the Latent Vector with the Local Geometry Encoding](#concatenating-the-latent-vector-with-the-local-geometry-encoding)
  - [Passing Through Additional Neural Network Layers](#passing-through-additional-neural-network-layers)
  - [Aggregation Network](#aggregation-network)
- [Training Process](#training)
  - [Step 1: Define Experiment Parameters and Dependencies](#step-1-define-experiment-parameters-and-dependencies)
    - [Loading Required Libraries](#loading-required-libraries)
    - [Dependencies](#dependencies)
    - [Experiment Parameters and Variables](#experiment-parameters-and-variables)
  - [Step 2: Train the DoMINO Model](#step-2-train-the-domino-model)
    - [Understanding the Training Process](#understanding-the-training-process)
    - [Key Components and Libraries](#key-components-and-libraries)
    - [Important Training Parameters](#important-training-parameters)
    - [Implementation Overview](#implementation-overview)
- [Load Model Checkpoint & Run Inference](#Load-Model-Checkpoint-&-Run-Inference)
- [Visualizing the predicted results](#Visualizing-the-predicted-results)


## DoMINO Architecture
Machine learning (ML) models have been proposed as surrogates to speed up simulations, but they face limitations in accuracy, scalability, and generalization to new geometries.
DoMINO is an ML model designed to address these challenges. It is a multi-scale, iterative neural operator that uses local geometric information to predict flow fields in large-scale simulations. It has been validated specifically on the automotive aerodynamics use case, where it demonstrates scalability, accuracy, and the ability to generalize across different simulation scenarios. Let's walk through the DoMINO architecture step by step, from the *Global Geometry Representation, through the Local Geometry Representation, to the Aggregation Network.*

The DoMINO model evaluates solution fields within a computational domain by leveraging geometry representations in STL file format. It encodes global geometry information on a fixed-size grid, defined in the computational domain, through a combination of learnable point convolution kernels, CNNs, and dense networks. Local geometric encoding is extracted, using point convolution kernels, from the global encoding by dynamically constructing local subdomains around sampled points where the solution fields are evaluated. This approach enables the prediction of volume and surface solutions by combining local geometry encoding with basis functions computed for sampled points and their neighboring points.

### Global Geometry Representation:
The Global Geometry Representation refers to the overall shape and structure of the entire object or domain that you are modeling. This representation captures all the geometric details across the entire computational domain.\
Step-by-Step Explanation of Global Geometry Representation:

- **Step 1**: Construct Bounding Boxes
    - A tight-fitting surface bounding box is created around the STL (3D geometry) to hold the geometry.
    - A computational-domain bounding box, larger than the surface bounding box, is also defined to enclose the whole computational domain.
    - Both bounding boxes can be specified in `conf.yaml`.
- **Step 2**: Project STL Vertices onto a Structured Grid
    - The geometric features of the point cloud, such as spatial coordinates, are projected onto an N-dimensional structured grid of resolution m×m×m×f, which is overlaid on the surface bounding box, using **learnable point convolution kernels**.
    - The learnable point convolution kernels are built from **differentiable ball query layers**. This means that the method:
        - Uses a "ball" (a sphere in 3D space) around each point to find its neighbors.
        - Is differentiable, so it can be included in the neural network and updated via backpropagation (i.e., during training, the network learns how to adjust the kernels to improve performance).
        - Uses the ball radius (radius of influence) to set how far around each point neighbors are gathered for the convolution, which in turn sets how far the geometry can affect the grid. Specifying several radii learns a range of point convolution kernel sizes. Separate kernels are also learned for the surface bounding box and the computational-domain bounding box. Together, this enables multi-scale geometry encoding that represents both short- and long-range interactions between the surface and the flow fields. The radii of influence are defined as **lists** in the `conf.yaml` file:
          ```yaml
          volume_radii: [0.1, 0.5]
          surface_radii: [0.05]
          ```

These radii are used in the DoMINO model (`physicsnemo/models/domino/model.py`) to construct one **BQWarp** layer per radius:

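```python
import torch.nn as nn

# BQWarp and GeoProcessor are defined in physicsnemo/models/domino/model.py
# (simplified excerpt; exact signatures may differ between PhysicsNeMo versions)
from physicsnemo.models.domino.model import BQWarp, GeoProcessor


class GeometryRep(nn.Module):
    """Geometry representation from STLs block"""

    def __init__(self, input_features, radii, model_parameters=None):
        super().__init__()
        geometry_rep = model_parameters.geometry_rep

        # One ball-query layer and one grid processor per radius of influence
        self.bq_warp = nn.ModuleList()
        self.geo_processors = nn.ModuleList()
        for radius in radii:
            self.bq_warp.append(
                BQWarp(
                    input_features=input_features,
                    grid_resolution=model_parameters.interp_res,
                    radius=radius,
                )
            )
            self.geo_processors.append(
                GeoProcessor(
                    input_filters=geometry_rep.geo_conv.base_neurons_out,
                    model_parameters=geometry_rep.geo_processor,
                )
            )
```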
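The core idea behind a ball-query point convolution can be illustrated without PhysicsNeMo. The NumPy sketch below is not the `BQWarp` implementation: it uses a brute-force neighbor search and a fixed mean over the neighbors, whereas DoMINO learns how neighbor features are combined. The helper names `ball_query_project` and `make_grid` are hypothetical. Running the projection with two radii and concatenating the results mimics the multi-scale setup configured by `volume_radii`/`surface_radii`.

```python
import numpy as np


def ball_query_project(points, features, grid_points, radius):
    """Average the features of all points within `radius` of each grid node.

    points:      (N, 3) point-cloud coordinates, e.g. STL vertices
    features:    (N, F) per-point features
    grid_points: (G, 3) coordinates of the structured-grid nodes
    Returns (G, F) aggregated features and (G,) neighbor counts.
    """
    # Brute-force pairwise distances between grid nodes and points: (G, N)
    dists = np.linalg.norm(grid_points[:, None, :] - points[None, :, :], axis=-1)
    in_ball = dists <= radius                              # neighbor mask
    counts = in_ball.sum(axis=1)                           # neighbors per grid node
    summed = in_ball.astype(features.dtype) @ features     # (G, F) sum over neighbors
    # Grid nodes with no neighbors keep zero features
    return summed / np.maximum(counts, 1)[:, None], counts


def make_grid(bbox_min, bbox_max, resolution):
    """Flattened (nx * ny * nz, 3) array of grid nodes spanning the bounding box."""
    axes = [np.linspace(lo, hi, n) for lo, hi, n in zip(bbox_min, bbox_max, resolution)]
    mesh = np.meshgrid(*axes, indexing="ij")
    return np.stack([m.ravel() for m in mesh], axis=-1)


rng = np.random.default_rng(0)
stl_vertices = rng.uniform(-1.0, 1.0, size=(500, 3))     # stand-in point cloud
grid = make_grid([-1.0, -1.0, -1.0], [1.0, 1.0, 1.0], [8, 8, 8])

# Multi-scale: one projection per radius, concatenated along the feature axis
per_radius = [ball_query_project(stl_vertices, stl_vertices, grid, r)[0] for r in (0.1, 0.5)]
multi_scale = np.concatenate(per_radius, axis=-1)
print(multi_scale.shape)  # (512, 6)
```

Larger radii pull in points from farther away, so each grid node sees a smoother, longer-range summary of the geometry, while smaller radii preserve local detail.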
       
        
- **Step 3**: Use a Multi-Resolution Approach for Detailed and Coarse Features
    - The grid resolution in the bounding box determines the level of geometric detail:
        - Finer resolution captures more detailed features of the geometry.
        - Coarser resolution captures larger, broader features.
    - A multi-resolution approach maintains multiple grids at different resolutions (levels) to capture both fine and coarse features of the geometry.
    - The grid resolution is set in the `conf.yaml` file as the number of grid points along x, y, and z (this notebook passes the same values as `GRID_RESOLUTION` to `interp_res`):
      ```yaml
      model:
        interp_res: [128, 64, 48]  # resolution of the interpolation grid
      ```
    - Currently, the DoMINO model accepts a single resolution; configuring multiple resolution levels will be supported in a future release.

- **Step 4**: Propagate Geometry Features into the Computational Domain
    - The computational domain is much larger than the surface bounding box, so the geometry information must be extended into it.
    - Geometry features are propagated into the computational domain in two ways:
        - As explained in **Step 2**, multi-scale **point convolution kernels** project the **geometry information** onto the computational-domain grid.
        - **Features** from the surface bounding-box grid (Gs) are propagated into the computational-domain grid (Gc) using **CNN blocks** that contain convolution, pooling, and unpooling layers.
    - The CNN blocks are iterated for a specified number of steps to refine the geometry representation. Currently, the DoMINO model runs a single iteration; an option to change this in `conf.yaml` will be provided in a future release.

- **Step 5**: Calculate Signed Distance Function (SDF) and its Gradients
	- Additionally, the Signed Distance Function (SDF) and its gradient components are calculated on the computational domain grid.
	- These SDF and gradient values are added to the learned features, providing additional information about the topology of the geometry (i.e., the geometry's shape, distances to surfaces, etc.).

- **Step 6**: Final Global Geometry Representation
	- The final geometry representation of the STL is formed by combining the learned features from the structured grids at different resolutions in both the bounding box and the computational domain.
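Steps 3 to 6 can be tied together in a small NumPy/PyTorch sketch, assuming a sphere as a stand-in geometry and random tensors in place of the ball-query features from Step 2. `TinyGeoProcessor` is a hypothetical conv/pool/unpool block, not PhysicsNeMo's `GeoProcessor`, and it only shows the iteration on a single grid rather than the transfer from Gs to Gc. For real STL meshes, this notebook imports `physicsnemo.utils.sdf.signed_distance_field`, which this sketch does not use.

```python
import numpy as np
import torch
import torch.nn as nn

# Step 3: a single structured grid; the resolution is the node count along x, y, z.
# DoMINO's [128, 64, 48] would give 128 * 64 * 48 = 393216 nodes; a small stand-in is used here.
resolution = (32, 16, 12)
bbox_min = np.array([-1.5, -0.01, -0.01])     # surface bounding box from this notebook
bbox_max = np.array([0.01, 0.6, 0.4])
axes = [np.linspace(lo, hi, n) for lo, hi, n in zip(bbox_min, bbox_max, resolution)]
nodes = np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1)   # (nx, ny, nz, 3)

# Step 5: SDF of a stand-in geometry (a sphere) and its finite-difference gradient
center, radius = np.array([-0.75, 0.3, 0.2]), 0.15
sdf = np.linalg.norm(nodes - center, axis=-1) - radius         # < 0 inside, > 0 outside
spacing = [a[1] - a[0] for a in axes]
grad = np.gradient(sdf, *spacing)                              # [dSDF/dx, dSDF/dy, dSDF/dz]
sdf_channels = torch.tensor(np.stack([sdf, *grad]), dtype=torch.float32)[None]  # (1, 4, nx, ny, nz)


# Step 4: conv -> pool -> conv -> unpool block applied to the learned grid features
class TinyGeoProcessor(nn.Module):
    def __init__(self, channels, hidden=8):
        super().__init__()
        self.encode = nn.Sequential(nn.Conv3d(channels, hidden, 3, padding=1), nn.GELU())
        self.pool = nn.MaxPool3d(2, return_indices=True)
        self.bottleneck = nn.Sequential(nn.Conv3d(hidden, hidden, 3, padding=1), nn.GELU())
        self.unpool = nn.MaxUnpool3d(2)
        self.decode = nn.Conv3d(hidden, channels, 3, padding=1)

    def forward(self, x):
        h = self.encode(x)
        pooled, idx = self.pool(h)
        h = self.unpool(self.bottleneck(pooled), idx, output_size=h.shape)
        return x + self.decode(h)                              # residual update


learned = torch.randn(1, 2, *resolution)      # stand-in for the ball-query grid features
processor = TinyGeoProcessor(channels=2)
num_iterations = 1                            # DoMINO currently runs a single iteration
with torch.no_grad():
    for _ in range(num_iterations):
        learned = processor(learned)

# Step 6: the global representation stacks learned features with the SDF channels
global_rep = torch.cat([learned, sdf_channels], dim=1)
print(global_rep.shape)  # torch.Size([1, 6, 32, 16, 12])
```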

Once the computational-domain grid has been built for each resolution, the next step is the local geometry representation.

### Local Geometry Representation
The Local Geometry Representation focuses on the geometry in the immediate vicinity of a sampled point p (a point of the simulation mesh). While the Global Geometry Representation gives the big picture, the Local Geometry Representation zooms in on a small region around each sampled point, capturing the small-scale features close to that point that matter for accurate predictions. For each sampled point p, neighboring points are sampled randomly around it to form a computational stencil, similar to those used in finite volume and finite element methods. The local geometry representation is learned by drawing a subregion around this stencil (the point p plus its neighbors). The sizes of these subregions are defined as **lists** in the `conf.yaml` file:

  ```yaml
    geometry_local.volume_radii: [0.05, 0.1]
    geometry_local.surface_radii: [0.05]
  ```

**How Does the Multi-Resolution Global Geometry Affect the Local Geometry Representation?**
 - Coarse resolution: At the coarse resolution, you get a broad view of the object. This can give information about the general shape and large-scale features of the geometry (e.g., the overall shape of the object, major boundaries, etc.). When local geometry is extracted from the coarse resolution, the features are relatively less detailed, and it might capture larger, more general features of the object.
 - Fine resolution: At the fine resolution, you get a detailed view of the geometry, capturing small features such as intricate surface details, small holes, or sharp edges. The local geometry representation derived from the fine resolution will be more detailed and capture smaller variations in the geometry near each sampled point.

**Thus, the global multi-resolution geometry allows the local geometry to be learned at different levels of detail, depending on the resolution of the grid that is used to represent the geometry.** 
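A minimal NumPy sketch of building a stencil and extracting a local encoding, assuming neighbors are drawn uniformly at random inside a radius and that the local encoding is a plain mean of global-grid features inside a ball around each stencil point (DoMINO learns this aggregation with point convolution kernels instead). `build_stencil` and `local_encoding` are hypothetical helpers, and `k=7` mirrors `NUM_SURFACE_NEIGHBORS` used later in this notebook.

```python
import numpy as np


def build_stencil(point, candidates, k, radius, rng):
    """Return `point` followed by up to k candidates drawn at random from within `radius`."""
    d = np.linalg.norm(candidates - point, axis=-1)
    inside = np.flatnonzero((d > 0) & (d <= radius))          # exclude the point itself
    chosen = rng.choice(inside, size=min(k, inside.size), replace=False)
    return np.vstack([point[None, :], candidates[chosen]])    # (k + 1, 3)


def local_encoding(stencil, grid_points, grid_features, radius):
    """Average the global-grid features inside a ball of `radius` around each stencil point."""
    encodings = []
    for p in stencil:
        mask = np.linalg.norm(grid_points - p, axis=-1) <= radius
        if mask.any():
            encodings.append(grid_features[mask].mean(axis=0))
        else:
            encodings.append(np.zeros(grid_features.shape[-1]))  # empty ball -> zeros
    return np.stack(encodings)                                # (k + 1, F)


rng = np.random.default_rng(0)
surface_points = rng.uniform(-0.2, 0.2, size=(2000, 3))      # stand-in surface samples
grid_points = rng.uniform(-0.3, 0.3, size=(5000, 3))         # stand-in global-grid nodes
grid_features = rng.normal(size=(5000, 4))                   # stand-in global encoding

stencil = build_stencil(surface_points[0], surface_points, k=7, radius=0.1, rng=rng)
local_geo = local_encoding(stencil, grid_points, grid_features, radius=0.05)
print(stencil.shape, local_geo.shape)
```

Using a larger local radius (for example the `0.1` entry of `volume_radii`) averages over more of the global grid, which corresponds to reading the local geometry at a coarser scale.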

### Basis Function Neural Network (Latent Vector):
Once the local geometry representation is built, the points of each stencil are passed through a Basis Function Neural Network. What happens here:
- The input features for each point in the stencil (coordinates, SDF, normal vectors, etc., and their Fourier features) are fed into the Basis Function Neural Network, a fully connected network that processes these features.
- The network computes a latent vector for each point in the stencil. A latent vector is a compact numerical representation that encodes the important information about the point's geometry and position.
- Purpose: the latent vector captures the essential characteristics of each point's geometry in compact form; it is used in later steps to predict the solution at that point.
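A minimal PyTorch sketch of a basis-function network: Fourier features of the per-point inputs followed by a small MLP that outputs one latent vector per stencil point. The frequency schedule (2^k π), the layer sizes, and the input layout (3 coordinates plus 3 surface-normal components per point) are assumptions for illustration; `fourier_features` and `BasisFunctionNet` are hypothetical names, not PhysicsNeMo classes.

```python
import math

import torch
import torch.nn as nn


def fourier_features(x, num_freqs=4):
    """Concatenate x with sin(2^k * pi * x) and cos(2^k * pi * x) for k = 0..num_freqs-1."""
    freqs = (2.0 ** torch.arange(num_freqs, dtype=x.dtype)) * math.pi   # (num_freqs,)
    angles = x[..., None] * freqs                                        # (..., D, num_freqs)
    return torch.cat(
        [x, torch.sin(angles).flatten(-2), torch.cos(angles).flatten(-2)], dim=-1
    )                                                                    # (..., D * (1 + 2 * num_freqs))


class BasisFunctionNet(nn.Module):
    """Fully connected network mapping per-point features to a latent vector."""

    def __init__(self, in_features, latent_dim=64, hidden=128, num_freqs=4):
        super().__init__()
        self.num_freqs = num_freqs
        encoded_dim = in_features * (1 + 2 * num_freqs)
        self.mlp = nn.Sequential(
            nn.Linear(encoded_dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, hidden),
            nn.GELU(),
            nn.Linear(hidden, latent_dim),
        )

    def forward(self, point_features):
        return self.mlp(fourier_features(point_features, self.num_freqs))


# 256 sampled points, each with a stencil of 1 + 7 points and 6 input features
# (xyz coordinates + surface normal) per stencil point
stencil_inputs = torch.randn(256, 8, 6)
latent = BasisFunctionNet(in_features=6)(stencil_inputs)
print(latent.shape)  # torch.Size([256, 8, 64])
```

The Fourier features let a small MLP represent rapid spatial variation that raw coordinates alone would make hard to learn.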

### Concatenating the Latent Vector with the Local Geometry Encoding:
- After calculating the latent vector for each point, this vector is concatenated with the local geometry encoding — which includes the previously computed information from the surrounding points and the global geometry.
- Why this is done: Concatenating these two representations allows the network to use both the specific local features of each point and the broader context of the surrounding geometry to make predictions.

### Passing Through Additional Neural Network Layers (Solution Prediction):
- The combined information (latent vector + local geometry encoding) is passed through another set of fully connected layers (a new neural network).
- What happens here: These layers process the combined information and predict a solution vector for each point in the stencil. The solution vector could represent various physical quantities such as temperature, pressure, or other simulation results at the sampled point.
- Purpose: This step produces the predicted solution at each point, based on the local and global geometry.
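The concatenation and solution layers can be sketched in PyTorch as follows, assuming each sampled point's local geometry encoding is shared across all points of its stencil. The tensor sizes are arbitrary, except `num_vars = 4`, which matches the surface variables predicted in this notebook (pressure plus three wall-shear-stress components); `solution_head` is a hypothetical name.

```python
import torch
import torch.nn as nn

num_points, stencil_size = 256, 8            # sampled points, and 1 + 7 stencil points each
latent_dim, geo_dim, num_vars = 64, 32, 4    # num_vars = 4: p plus 3 wall-shear components

basis_latent = torch.randn(num_points, stencil_size, latent_dim)   # from the basis network
local_geo = torch.randn(num_points, geo_dim)                       # one encoding per sampled point

# Share each sampled point's local encoding across its stencil, then concatenate
geo_per_stencil = local_geo[:, None, :].expand(-1, stencil_size, -1)
combined = torch.cat([basis_latent, geo_per_stencil], dim=-1)      # (N, S, latent_dim + geo_dim)

# Additional fully connected layers predict a solution vector per stencil point
solution_head = nn.Sequential(
    nn.Linear(latent_dim + geo_dim, 256),
    nn.GELU(),
    nn.Linear(256, 256),
    nn.GELU(),
    nn.Linear(256, num_vars),
)
solution = solution_head(combined)
print(solution.shape)  # torch.Size([256, 8, 4])
```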

### Aggregation Network:
The aggregation network is a fully connected neural network with a **DeepONet**-like structure, in which the **local geometry representation acts as the branch net and the basis functions act as the trunk net**. It is used to compute the solutions.

- Local Geometry Representation is the Branch Net: in DeepONet, the branch network encodes the input function; here, it processes the local geometry representation, i.e., the geometric features around the sampled point (i) and its neighbors (j).
- Basis Functions are the Trunk Net: in DeepONet, the trunk network encodes the locations at which the solution is evaluated. Here, the basis functions computed for each point play that role, providing the components from which the solution in the computational domain is assembled.

- The aggregation network computes the solution field at the sampled point i and its neighbors j. These solutions are then averaged using an inverse-distance-weighted interpolation scheme.\
In simpler terms:

- The aggregation network computes the solution values at the sampled point (i) and its surrounding neighbors (j). The solution is a value that corresponds to a field (e.g., temperature, pressure) at these points.
- After computing the solution at the sampled point and its neighbors, the results are combined (averaged) using inverse distance weighting: each point's weight is proportional to the inverse of its distance from the sampled point (i), so points closer to the sampled point contribute more to the final solution than points farther away.
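A sketch of the aggregation step under stated assumptions: a hypothetical `AggregationNet` whose branch net encodes the local geometry and whose trunk net encodes the per-point basis vectors, combined with a DeepONet-style dot product per output variable, followed by inverse distance weighting over the stencil. The `eps` added to the distances keeps the sampled point itself (distance 0) from receiving an infinite weight; its value is an assumption, not DoMINO's choice.

```python
import torch
import torch.nn as nn


class AggregationNet(nn.Module):
    """Hypothetical DeepONet-style head: branch(local geometry) . trunk(basis functions)."""

    def __init__(self, geo_dim, basis_dim, num_vars, width=64):
        super().__init__()
        self.num_vars, self.width = num_vars, width
        self.branch = nn.Sequential(
            nn.Linear(geo_dim, 128), nn.GELU(), nn.Linear(128, num_vars * width)
        )
        self.trunk = nn.Sequential(
            nn.Linear(basis_dim, 128), nn.GELU(), nn.Linear(128, num_vars * width)
        )

    def forward(self, local_geo, basis):
        # local_geo: (N, geo_dim), one encoding per sampled point
        # basis:     (N, S, basis_dim), one basis vector per stencil point
        n, s, _ = basis.shape
        b = self.branch(local_geo).view(n, 1, self.num_vars, self.width)
        t = self.trunk(basis).view(n, s, self.num_vars, self.width)
        return (b * t).sum(dim=-1)  # (N, S, num_vars): solution at i and its neighbors j


def inverse_distance_average(values, dists, eps=1e-3):
    """Average values over the stencil with weights proportional to 1 / (distance + eps)."""
    weights = 1.0 / (dists + eps)                            # (N, S)
    weights = weights / weights.sum(dim=1, keepdim=True)     # normalize per sampled point
    return (weights[..., None] * values).sum(dim=1)          # (N, num_vars)


n_points, stencil_size = 16, 8
model = AggregationNet(geo_dim=32, basis_dim=64, num_vars=4)
per_stencil = model(torch.randn(n_points, 32), torch.randn(n_points, stencil_size, 64))

dists = torch.rand(n_points, stencil_size) * 0.05
dists[:, 0] = 0.0                                            # the sampled point itself
prediction = inverse_distance_average(per_stencil, dists)
print(prediction.shape)  # torch.Size([16, 4])
```

As a quick sanity check of the weighting, two stencil values 0 and 4 at distances 1 and 3 (with `eps=0`) receive weights 0.75 and 0.25, giving an average of 1.0.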

## Training
### **Step 1: Define Experiment Parameters and Dependencies**

The first step in training the DoMINO model on the Ahmed body dataset is to set up our experiment environment and define the necessary parameters. This includes specifying paths to our data, configuring training settings, and ensuring all required libraries are available.

Key components we need to set up:
- Data paths for training and validation sets
- Model hyperparameters and training configurations
- Visualization settings for results
- Required Python libraries for mesh processing and deep learning

#### Loading Required Libraries

Before we proceed with the experiment setup, let's import all the necessary libraries. They are used for:
- Deep learning and numerical computations (torch, numpy)
- Progress tracking and visualization (tqdm, matplotlib, pyvista)
- Configuration handling (hydra, omegaconf)
- The DoMINO model, data pipeline, and training utilities (physicsnemo)

#### Dependencies
Ensure that the required Python libraries are installed:

```bash
pip install numpy torch matplotlib tqdm mlflow torchinfo pyvista hydra-core
```
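As an optional check, a short Python snippet can report which of the packages imported below are available in the current environment. Import names can differ from pip package names (for example, the `hydra-core` package is imported as `hydra`); `physicsnemo` is included because the notebook imports it, on the assumption that it was installed as part of the previous notebook.

```python
import importlib.util

# Import names (not pip names) of the packages this notebook imports
required = [
    "numpy", "torch", "matplotlib", "tqdm", "mlflow", "torchinfo",
    "pyvista", "hydra", "omegaconf", "physicsnemo",
]
missing = [name for name in required if importlib.util.find_spec(name) is None]
print("missing packages:", missing or "none")
```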

Next, import the libraries used throughout this notebook:

In [None]:
import time
import os
import re
import torch
import torchinfo


import pyvista as pv
from tqdm import tqdm
from pathlib import Path
from types import SimpleNamespace
import matplotlib.pyplot as plt
import numpy as np
import hydra
from hydra.utils import to_absolute_path
from omegaconf import DictConfig, OmegaConf

from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from physicsnemo.distributed import DistributedManager
from physicsnemo.launch.utils import load_checkpoint, save_checkpoint
from physicsnemo.utils.sdf import signed_distance_field

from physicsnemo.datapipes.cae.domino_datapipe import DoMINODataPipe
from physicsnemo.models.domino.model import DoMINO
from physicsnemo.utils.domino.utils import *

#### Experiment Parameters and Variables

In this section, we define all the necessary parameters and variables for our Ahmed body experiment. These parameters control various aspects of the training process, data processing, and model configuration.

These parameters are carefully chosen based on:
- The physical dimensions of the Ahmed body
- The computational requirements of the DoMINO model
- The desired resolution for accurate flow prediction
- The available computational resources
- The specific requirements of the aerodynamic analysis

The bounding box parameters are particularly important as they define the computational domain for both volume and surface meshes, ensuring we capture all relevant flow features around the Ahmed body.
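The datapipe receives these bounding boxes together with `normalize_coordinates=True`, and `NORMALIZATION = "min_max_scaling"` applies the same min-max idea to the surface fields. The NumPy sketch below assumes a linear map of each column from `[min, max]` to `[-1, 1]`; the exact target range used inside `DoMINODataPipe` is an assumption here, and `min_max_normalize`/`min_max_unnormalize` are hypothetical helpers.

```python
import numpy as np


def min_max_normalize(x, x_max, x_min):
    """Map each column of x linearly from [x_min, x_max] to [-1, 1]."""
    return 2.0 * (x - x_min) / (x_max - x_min) - 1.0


def min_max_unnormalize(x, x_max, x_min):
    """Inverse of min_max_normalize."""
    return (x + 1.0) * (x_max - x_min) / 2.0 + x_min


# Surface bounding box from this notebook (BOUNDING_BOX_SURF)
surf_max = np.array([0.01, 0.6, 0.4])
surf_min = np.array([-1.5, -0.01, -0.01])

points = np.array([
    [-1.5, -0.01, -0.01],    # min corner -> [-1, -1, -1]
    [0.01, 0.6, 0.4],        # max corner -> [ 1,  1,  1]
    [-0.745, 0.295, 0.195],  # box center -> [ 0,  0,  0]
])
scaled = min_max_normalize(points, surf_max, surf_min)
restored = min_max_unnormalize(scaled, surf_max, surf_min)
print(np.round(scaled, 6))
```

Model predictions are mapped back to physical units with the inverse function, using the same max/min values that were used for scaling.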

In [None]:
# Directory and Path Configuration
EXPERIMENT_TAG = 4  # Unique identifier for this experiment run
PROJECT_NAME = "ahmed_body_dataset"  # Name of the project
OUTPUT_DIR = Path(
    f"./outputs/{PROJECT_NAME}/{EXPERIMENT_TAG}"
)  # Directory for experiment outputs
DATA_DIR = Path("./ahmed_body_dataset/")  # Root directory for dataset
PROCESSED_DIR = (
    DATA_DIR / "prepared_surface_data"
)  # Directory for processed surface data
CHECKPOINT_DIR = OUTPUT_DIR / "models"  # Directory for saving model checkpoints
SAVE_PATH = DATA_DIR / "mesh_predictions_surf_final1"  # path to save prediction results

# Ensure directories exist
os.makedirs(PROCESSED_DIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Physical Variables
VOLUME_VARS = ["p"]  # Volume variables to predict (pressure)
SURFACE_VARS = ["p", "wallShearStress"]  # Surface variables to predict
MODEL_TYPE = "surface"  # Type of model (surface-only prediction)
AIR_DENSITY = 1.205  # Air density in kg/m³
STREAM_VELOCITY = 60  # Free-stream velocity in m/s

# Training Hyperparameters
NUM_EPOCHS = 10  # Number of training epochs
LR = 0.001  # Learning rate
BATCH_SIZE = 1  # Batch size for training
GRID_RESOLUTION = [128, 64, 48]  # Resolution of the interpolation grid
NUM_SURFACE_NEIGHBORS = 7  # Number of neighbors for surface operations
NORMALIZATION = "min_max_scaling"  # Data normalization method
INTEGRAL_LOSS_SCALING = 0  # Scaling factor for integral loss
NUM_SURF_VARS = 4  # Number of surface variables to predict: 1 for the scalar (p) and 3 for the vector (wallShearStress)
CHECKPOINT_INTERVAL = 1  # Save checkpoint every N epochs

# Dataset Paths
DATA_PATHS = {
    "train": "./ahmed_body_dataset/train_prepared_surface_data",
    "val": "./ahmed_body_dataset/validation_prepared_surface_data",
    "test": "./ahmed_body_dataset/test",
}

# Model and Scaling Factor Paths
MODEL_SAVE_DIR = "./outputs/ahmed_body_dataset/4/models"
SURF_SAVE_PATH = "./outputs/ahmed_body_dataset/surface_scaling_factors.npy"

# Bounding Box Configuration for Volume and Surface Meshes
BOUNDING_BOX = SimpleNamespace(
    max=[0.5, 0.6, 0.6],  # Maximum coordinates for volume mesh
    min=[-2.5, -0.5, -0.5],  # Minimum coordinates for volume mesh
)
BOUNDING_BOX_SURF = SimpleNamespace(
    max=[0.01, 0.6, 0.4],  # Maximum coordinates for surface mesh
    min=[-1.5, -0.01, -0.01],  # Minimum coordinates for surface mesh
)

# Set cuDNN benchmark mode
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False

In [None]:
def setup_distributed():
    """
    Initialize distributed training environment.

    Returns:
        tuple: (device, rank, world_size)
            - device: torch.device for computation
            - rank: process rank in distributed setup
            - world_size: total number of processes
    """
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    # Set up CUDA device
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")
    print(f"rank: {rank}, local_rank: {local_rank}, world_size: {world_size}")

    # Initialize distributed process group if needed
    if world_size > 1:
        torch.distributed.init_process_group(backend="nccl", init_method="env://")

    return device, rank, world_size


def create_model(device, rank, world_size):
    """
    Create and configure DoMINO model with distributed training support.

    Args:
        device (torch.device): Computation device
        rank (int): Process rank
        world_size (int): Total number of processes

    Returns:
        DoMINO: Configured model (wrapped in DistributedDataParallel if world_size > 1)
    """

    # Initialize model with configuration
    model = DoMINO(
        input_features=3,
        output_features_vol=None,
        output_features_surf=NUM_SURF_VARS,
        model_parameters=SimpleNamespace(
            interp_res=GRID_RESOLUTION,
            surface_neighbors=NUM_SURFACE_NEIGHBORS,
            use_surface_normals=True,
            use_only_normals=True,
            encode_parameters=True,
            positional_encoding=False,
            integral_loss_scaling_factor=INTEGRAL_LOSS_SCALING,
            normalization=NORMALIZATION,
            use_sdf_in_basis_func=True,
            geometry_rep=SimpleNamespace(
                base_filters=16,
                geo_conv=SimpleNamespace(
                    base_neurons=32,
                    base_neurons_out=1,
                    radius_short=0.1,
                    radius_long=0.5,
                    hops=1,
                ),
                geo_processor=SimpleNamespace(base_filters=8),
                geo_processor_sdf=SimpleNamespace(base_filters=8),
            ),
            nn_basis_functions=SimpleNamespace(base_layer=512),
            parameter_model=SimpleNamespace(
                base_layer=512, scaling_params=[60.0, 1.226]
            ),
            position_encoder=SimpleNamespace(base_neurons=512),
            geometry_local=SimpleNamespace(
                neighbors_in_radius=64, radius=0.05, base_layer=512
            ),
            aggregation_model=SimpleNamespace(base_layer=512),
            model_type=MODEL_TYPE,
        ),
    ).to(device)

    # Wrap model for distributed training if needed
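    if world_size > 1:
        # device_ids must be the local GPU index, which differs from the global rank on multi-node runs
        model = DistributedDataParallel(
            model,
            device_ids=[device.index],
            output_device=device.index,
            find_unused_parameters=True,
        )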

    return model

In [None]:
def mse_loss_fn(output, target, padded_value=-10):
    """
    Compute masked MSE loss, ignoring padded values.

    Args:
        output (torch.Tensor): Model predictions
        target (torch.Tensor): Ground truth values
        padded_value (float): Value used for padding (default: -10)

    Returns:
        torch.Tensor: Mean squared error loss
    """
    # Move target to same device as output
    target = target.to(output.device)
    # Create mask for non-padded values
    mask = torch.abs(target - padded_value) > 1e-3
    # Compute masked loss
    masked_loss = torch.sum(((output - target) ** 2) * mask) / torch.sum(mask)
    return masked_loss.mean()


def create_dataset(phase):
    """
    Create DoMINO dataset for specified phase (train/val).

    Args:
        phase (str): Dataset phase ('train' or 'val')

    Returns:
        DoMINODataPipe: Configured dataset
    """
    return DoMINODataPipe(
        DATA_PATHS[phase],
        phase=phase,
        grid_resolution=GRID_RESOLUTION,
        surface_variables=SURFACE_VARS,
        normalize_coordinates=True,
        sampling=True,
        sample_in_bbox=True,
        volume_points_sample=4096,  ## 8192 -- original
        surface_points_sample=2048,  ## 4096 -- original
        geom_points_sample=120000,  ## 200000 -- original
        positional_encoding=False,
        surface_factors=np.load(SURF_SAVE_PATH),
        scaling_type=NORMALIZATION,
        model_type=MODEL_TYPE,
        bounding_box_dims=BOUNDING_BOX,
        bounding_box_dims_surf=BOUNDING_BOX_SURF,
        num_surface_neighbors=NUM_SURFACE_NEIGHBORS,
    )


def compute_scaling_factors():
    """
    Compute min-max scaling factors for the surface fields and save them to SURF_SAVE_PATH.

    The factors are saved as [max, min], each with one entry per surface-field
    component. Points flagged by `mean_std_sampling` (called with tolerance=12.0)
    are removed first, and at most 22 training samples are scanned to keep this
    step fast. Nothing is recomputed if the file already exists.
    """
    if MODEL_TYPE not in ("surface", "combined") or os.path.exists(SURF_SAVE_PATH):
        return

    # Only min-max scaling factors are computed here; any other scheme would
    # otherwise reach np.save with undefined factors
    if NORMALIZATION != "min_max_scaling":
        return

    fm_dict = DoMINODataPipe(
        DATA_PATHS["train"],
        phase="train",
        grid_resolution=GRID_RESOLUTION,
        surface_variables=SURFACE_VARS,
        normalize_coordinates=True,
        sampling=True,
        sample_in_bbox=True,
        volume_points_sample=4096,
        surface_points_sample=2048,
        geom_points_sample=120000,
        positional_encoding=False,
        scaling_type=NORMALIZATION,
        model_type=MODEL_TYPE,
        bounding_box_dims=BOUNDING_BOX,
        bounding_box_dims_surf=BOUNDING_BOX_SURF,
        num_surface_neighbors=NUM_SURFACE_NEIGHBORS,
        compute_scaling_factors=True,
    )

    surf_fields_max, surf_fields_min = None, None
    for j in range(len(fm_dict)):
        surf_fields = fm_dict[j]["surface_fields"]

        if surf_fields is not None:
            # Drop outliers before taking the per-component max/min
            surf_mean = np.mean(surf_fields, 0)
            surf_std = np.std(surf_fields, 0)
            surf_idx = mean_std_sampling(
                surf_fields, surf_mean, surf_std, tolerance=12.0
            )
            surf_fields_sampled = np.delete(surf_fields, surf_idx, axis=0)
            sample_max = np.amax(surf_fields_sampled, 0)
            sample_min = np.amin(surf_fields_sampled, 0)

            # Running element-wise max/min across samples
            if surf_fields_max is None:
                surf_fields_max, surf_fields_min = sample_max, sample_min
            else:
                surf_fields_max = np.maximum(surf_fields_max, sample_max)
                surf_fields_min = np.minimum(surf_fields_min, sample_min)

        if j > 20:  # stop after 22 samples (j = 0..21)
            break

    np.save(SURF_SAVE_PATH, [surf_fields_max, surf_fields_min])


def create_dataloaders(rank, world_size):
    """
    Create train and validation dataloaders with distributed sampling.

    Args:
        rank (int): Process rank
        world_size (int): Total number of processes

    Returns:
        tuple: (train_loader, val_loader, train_sampler, val_sampler)
    """
    # Create datasets
    train_dataset, val_dataset = create_dataset("train"), create_dataset("val")

    # Configure distributed samplers if needed
    train_sampler = (
        DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
        if world_size > 1
        else None
    )
    val_sampler = (
        DistributedSampler(val_dataset, num_replicas=world_size, rank=rank)
        if world_size > 1
        else None
    )

    # Create dataloaders
    return (
        DataLoader(
            train_dataset,
            batch_size=BATCH_SIZE,
            sampler=train_sampler,
            shuffle=train_sampler is None,
        ),
        DataLoader(
            val_dataset, batch_size=BATCH_SIZE, sampler=val_sampler, shuffle=False
        ),
        train_sampler,
        val_sampler,
    )


def run_epoch(
    train_loader,
    val_loader,
    model,
    optimizer,
    scaler,
    device,
    epoch,
    best_vloss,
    rank,
    world_size,
):
    """
    Run one training epoch with validation.

    Args:
        train_loader (DataLoader): Training data loader
        val_loader (DataLoader): Validation data loader
        model (DoMINO): Model to train
        optimizer (torch.optim.Optimizer): Optimizer
        scaler (GradScaler): Gradient scaler for mixed precision
        device (torch.device): Computation device
        epoch (int): Current epoch number
        best_vloss (float): Best validation loss so far
        rank (int): Process rank
        world_size (int): Total number of processes

    Returns:
        float: Validation loss for this epoch
    """
    # Training phase
    model.train()
    train_loss = 0.0
    pbar = (
        tqdm(train_loader, desc=f"Epoch {epoch + 1}/{NUM_EPOCHS}")
        if rank == 0
        else train_loader
    )

    for step, batch in enumerate(pbar):
        # Move batch to device
        batch = dict_to_device(batch, device)

        # Forward pass with mixed precision
        with autocast():
            _, pred_surf = model(batch)
            loss = mse_loss_fn(pred_surf, batch["surface_fields"])

        # Backward pass with gradient scaling
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        # Update loss tracking (running mean over the batches seen so far)
        train_loss += loss.item()
        if rank == 0:
            pbar.set_postfix(
                {
                    "train_loss": f"{train_loss / (step + 1):.5e}",
                    "lr": f"{optimizer.param_groups[0]['lr']:.2e}",
                }
            )

    # Compute average training loss
    avg_train_loss = train_loss / len(train_loader)

    # Validation phase
    model.eval()
    with torch.no_grad():
        val_loss = sum(
            mse_loss_fn(
                model(dict_to_device(batch, device))[1],
                batch["surface_fields"].to(device),
            ).item()
            for batch in val_loader
        ) / len(val_loader)

    # Handle distributed training metrics
    if world_size > 1:
        avg_train_loss, val_loss = [
            torch.tensor(v, device=device) for v in [avg_train_loss, val_loss]
        ]
        torch.distributed.all_reduce(avg_train_loss, op=torch.distributed.ReduceOp.SUM)
        torch.distributed.all_reduce(val_loss, op=torch.distributed.ReduceOp.SUM)
        avg_train_loss, val_loss = (
            avg_train_loss.item() / world_size,
            val_loss.item() / world_size,
        )

    # Save checkpoints on main process
    if rank == 0:
        if val_loss < best_vloss:
            save_checkpoint(
                os.path.join(MODEL_SAVE_DIR, "best_model"),
                models=model,
                optimizer=optimizer,
                scaler=scaler,
            )

        if (epoch + 1) % CHECKPOINT_INTERVAL == 0:
            save_checkpoint(
                MODEL_SAVE_DIR,
                models=model,
                optimizer=optimizer,
                scaler=scaler,
                epoch=epoch,
            )

    return val_loss


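def train(model, device, rank, world_size):
    """
    Function that orchestrates the training process.
    Handles data loading, optimizer and gradient-scaler setup, and the epoch loop.
    The distributed environment and the model must be created beforehand
    with `setup_distributed()` and `create_model()`.
    """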
    compute_scaling_factors()

    # Create output directory on main process
    if rank == 0:
        os.makedirs(MODEL_SAVE_DIR, exist_ok=True)

    # Set up data
    train_loader, val_loader, train_sampler, val_sampler = create_dataloaders(
        rank, world_size
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Initialize the gradient scaler for mixed-precision training.
    # Optional learning-rate scheduler (disabled; also uncomment scheduler.step() below):
    # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[1, 2], gamma=0.5)
    scaler = GradScaler()

    # Training loop
    best_vloss = float("inf")
    for epoch in range(NUM_EPOCHS):
        if world_size > 1:
            train_sampler.set_epoch(epoch)
        best_vloss = min(
            best_vloss,
            run_epoch(
                train_loader,
                val_loader,
                model,
                optimizer,
                scaler,
                device,
                epoch,
                best_vloss,
                rank,
                world_size,
            ),
        )
        # scheduler.step()
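
The `scaler.scale(loss).backward()`, `scaler.step(optimizer)`, `scaler.update()` sequence in `run_epoch` implements dynamic loss scaling. The pure-Python sketch below mimics the update rule that PyTorch's `GradScaler` follows with its default settings (initial scale 2**16, growth factor 2.0, backoff factor 0.5, growth interval 2000): if any unscaled gradient is inf or NaN, the optimizer step is skipped and the scale is reduced; after `growth_interval` consecutive clean steps, the scale is increased. `ToyLossScaler` is a hypothetical teaching class, not part of PyTorch, and its gradients are plain floats supplied by hand rather than computed by a backward pass.

```python
import math
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ToyLossScaler:
    """Hypothetical, simplified model of dynamic loss scaling (not PyTorch's class)."""

    scale: float = 2.0**16       # same default as GradScaler's init_scale
    growth_factor: float = 2.0   # applied after `growth_interval` clean steps
    backoff_factor: float = 0.5  # applied after any overflow
    growth_interval: int = 2000
    _clean_steps: int = field(default=0, init=False)
    _found_inf: bool = field(default=False, init=False)

    def scale_loss(self, loss: float) -> float:
        # Like scaler.scale(loss): enlarge the loss so small fp16 gradients don't underflow
        return loss * self.scale

    def step(self, scaled_grads: List[float]) -> Optional[List[float]]:
        # Like scaler.step(optimizer): unscale, then skip the update on inf/NaN
        grads = [g / self.scale for g in scaled_grads]
        self._found_inf = not all(math.isfinite(g) for g in grads)
        return None if self._found_inf else grads  # None means "step skipped"

    def update(self) -> None:
        # Like scaler.update(): back off after an overflow, grow after a clean run
        if self._found_inf:
            self.scale *= self.backoff_factor
            self._clean_steps = 0
        else:
            self._clean_steps += 1
            if self._clean_steps == self.growth_interval:
                self.scale *= self.growth_factor
                self._clean_steps = 0


scaler = ToyLossScaler(scale=8.0, growth_interval=2)
print(scaler.step([16.0, float("inf")]))  # None: overflow, so the step is skipped
scaler.update()
print(scaler.scale)  # 4.0
```

In the real `GradScaler`, the overflow check runs on the actual parameter gradients, and a skipped step means `optimizer.step()` is not called for that batch.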

### **Step 2: Train the DoMINO Model**

This step of the workflow trains the DoMINO (Decomposable Multi-scale Iterative Neural Operator) model on the processed CFD data. This step matters because it:
- Lets the model learn how the geometry determines the surface pressure and wall shear stress fields
- Provides the foundation for accurate surface-field predictions
- Enables efficient inference on new geometries
- Can run as distributed training to spread the work across multiple GPUs

#### Understanding the Training Process

The training process involves several key components:
1. Setting up the distributed training environment
2. Creating and configuring datasets and dataloaders
3. Initializing the DoMINO model architecture
4. Implementing training and validation loops
5. Managing model checkpoints and metrics

#### Key Components and Libraries

We'll use the following for training:

- **PyTorch**
   - `torch.distributed`: For distributed training
   - `torch.cuda`: For GPU acceleration
   - `torch.optim`: For optimization algorithms
   - `autocast` and `GradScaler`: For mixed-precision training

- **Data Management**
   - Custom dataset classes for CFD data
   - Distributed samplers for efficient data loading
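To see what a distributed sampler does, the stdlib-only sketch below reproduces how `torch.utils.data.DistributedSampler` assigns dataset indices to ranks when `shuffle=False` and `drop_last=False`, assuming `create_dataloaders` uses that sampler. The index list is padded (by repeating indices from the start) to a multiple of the number of processes, and rank `r` takes every `world_size`-th index starting at `r`. With `shuffle=True`, PyTorch first permutes the indices with a generator seeded by `seed + epoch`, which is why `train()` calls `train_sampler.set_epoch(epoch)` each epoch. `shard_indices` is a hypothetical helper that models the sampler's behavior; it is not part of PyTorch.

```python
import math
from typing import List


def shard_indices(dataset_len: int, world_size: int, rank: int) -> List[int]:
    """Hypothetical model of DistributedSampler(shuffle=False, drop_last=False)."""
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    num_samples = math.ceil(dataset_len / world_size)  # samples per rank
    total_size = num_samples * world_size
    indices = list(range(dataset_len))
    # Pad by repeating indices from the start so every rank gets the same count
    padding = total_size - len(indices)
    if padding <= len(indices):
        indices += indices[:padding]
    else:
        indices += (indices * math.ceil(padding / len(indices)))[:padding]
    # Interleaved split: rank r takes positions r, r + world_size, ...
    return indices[rank:total_size:world_size]


for r in range(3):
    print(r, shard_indices(10, world_size=3, rank=r))
# 0 [0, 3, 6, 9]
# 1 [1, 4, 7, 0]
# 2 [2, 5, 8, 1]
```

Because every rank receives the same number of samples, all ranks run the same number of batches (given a common batch size), so averaging the per-rank losses with `all_reduce(SUM)` and dividing by `world_size`, as `run_epoch` does, weights every rank equally. When the dataset size is not divisible by `world_size`, the padded duplicates mean a few samples are counted twice per epoch.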

#### Important Training Parameters

During the training process, we need to consider:
- Batch size and learning rate
- Number of epochs and validation frequency
- Model architecture parameters
- Loss function configuration
- Checkpointing strategy
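The checkpointing strategy in `run_epoch` combines two rules: save `best_model` whenever an epoch's validation loss beats the best seen so far, and save a periodic checkpoint every `CHECKPOINT_INTERVAL` epochs. The stdlib-only sketch below replays those rules over a list of validation losses so the schedule can be checked without training. `checkpoint_schedule` and the sample losses are hypothetical illustrations, not values from this notebook.

```python
from typing import List, Tuple


def checkpoint_schedule(
    val_losses: List[float], checkpoint_interval: int
) -> Tuple[List[int], List[int]]:
    """Return (epochs saving best_model, epochs saving a periodic checkpoint).

    Epochs are 0-based, matching the `epoch` loop variable in train().
    """
    best_vloss = float("inf")
    best_epochs, periodic_epochs = [], []
    for epoch, val_loss in enumerate(val_losses):
        if val_loss < best_vloss:  # same test as run_epoch
            best_epochs.append(epoch)
        if (epoch + 1) % checkpoint_interval == 0:  # same test as run_epoch
            periodic_epochs.append(epoch)
        best_vloss = min(best_vloss, val_loss)  # same update as train()
    return best_epochs, periodic_epochs


best, periodic = checkpoint_schedule([0.9, 0.7, 0.8, 0.6, 0.65, 0.5], checkpoint_interval=2)
print(best)      # [0, 1, 3, 5]
print(periodic)  # [1, 3, 5]
```

Note that a single epoch can trigger both saves (epochs 1, 3, and 5 in this example).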

#### Implementation Overview

The training is implemented through several key components:

1. **Distributed Setup**
```python
def setup_distributed():
    """Initialize distributed training environment."""
    # Sets up CUDA devices and process groups
    # Returns: (device, rank, world_size)
    ...
```

2. **Model Creation**
```python
def create_model(device, rank, world_size):
    """Create and configure DoMINO model."""
    # Initializes model with specified parameters
    # Returns: the configured model
    ...
```

3. **Training Loop**
```python
def train(model, device, rank, world_size):
    """Orchestrates the training process."""
    # Handles training loop, validation, and checkpointing
    ...
```
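When a script is launched with `torchrun`, each process receives its identity through the `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` environment variables; in a plain notebook kernel these are usually unset, which corresponds to the single-process case. The stdlib-only sketch below shows one way that information could be read, falling back to a single process. `read_launch_env` and `LaunchInfo` are hypothetical names, and the notebook's own `setup_distributed()` may differ in detail. A full setup would go on to call `torch.distributed.init_process_group(...)` and `torch.cuda.set_device(local_rank)`, which are omitted here so the sketch runs without PyTorch or GPUs.

```python
import os
from typing import Mapping, NamedTuple, Optional


class LaunchInfo(NamedTuple):
    rank: int        # global process index
    local_rank: int  # process index on this node; selects the GPU
    world_size: int  # total number of processes


def read_launch_env(env: Optional[Mapping[str, str]] = None) -> LaunchInfo:
    """Hypothetical helper: read torchrun-style variables, defaulting to one process."""
    env = os.environ if env is None else env
    info = LaunchInfo(
        rank=int(env.get("RANK", 0)),
        local_rank=int(env.get("LOCAL_RANK", 0)),
        world_size=int(env.get("WORLD_SIZE", 1)),
    )
    if not 0 <= info.rank < info.world_size:
        raise ValueError(f"inconsistent launch environment: {info}")
    return info


print(read_launch_env({}))
# LaunchInfo(rank=0, local_rank=0, world_size=1)
print(read_launch_env({"RANK": "3", "LOCAL_RANK": "1", "WORLD_SIZE": "4"}))
# LaunchInfo(rank=3, local_rank=1, world_size=4)
```

With `world_size == 1`, the `if world_size > 1` branches in `train()` and `run_epoch` are skipped, so the same code path serves both single-GPU and multi-GPU runs.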

With these components defined, let's train the model for a few epochs:

In [None]:
# Initialize distributed training
device, rank, world_size = setup_distributed()
# Set up model
model = create_model(device, rank, world_size)
# Run training
train(model, device, rank, world_size)

In [None]:
# Terminate the kernel process to release its GPU memory and other resources
import os

os._exit(0)