In [1]:
import torch
from torch.utils.data import DataLoader
from neuralop.models import GINO
from cfd_dataset import *
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

def load_model(model_path, model_config, map_location="cpu"):
    """Build a GINO model from ``model_config`` and load saved weights into it.

    Args:
        model_path: Path to a checkpoint whose contents are accepted by
            ``model.load_state_dict`` (presumably written with
            ``torch.save(model.state_dict(), ...)`` -- TODO confirm).
        model_config: Keyword arguments forwarded to the ``GINO`` constructor;
            they must match the configuration the weights were trained with.
        map_location: Device the checkpoint tensors are materialised on while
            loading. Defaults to ``"cpu"`` so a checkpoint saved on a GPU can
            still be loaded on a CPU-only machine; move the model afterwards
            with ``model.to(device)``.

    Returns:
        The ``GINO`` instance with the loaded parameters.
    """
    model = GINO(**model_config)
    # NOTE: depending on the torch version's ``weights_only`` default,
    # torch.load may unpickle arbitrary objects -- only load trusted files.
    state_dict = torch.load(model_path, map_location=map_location)
    model.load_state_dict(state_dict)
    return model

def evaluate_model(model, val_loader, input_geom, latent_queries, device, grid_shape=(8, 8, 8)):
    """Run ``model`` over ``val_loader``; return the mean MSE plus stacked predictions.

    Each batch from ``val_loader`` is unpacked as ``(x, y, output_queries)``.
    A leading dimension of size 1 is squeezed off each of them before the
    forward pass. ``input_geom`` and ``latent_queries`` are passed unchanged
    to every forward call.

    Args:
        model: Module called as
            ``model(x, input_geom, latent_queries, output_queries, ada_in=...)``.
        val_loader: Iterable of ``(x, y, output_queries)`` batches.
        input_geom: Geometry tensor passed to every forward call (the caller
            is expected to have moved it to ``device`` already).
        latent_queries: Latent query tensor passed to every forward call
            (likewise expected to be on ``device``).
        device: Device the per-batch tensors are moved to.
        grid_shape: Shape each prediction and target is reshaped to before
            being collected. Its product must equal the number of elements in
            one output. Defaults to ``(8, 8, 8)``.

    Returns:
        ``(avg_loss, preds, targets)``:
        - ``avg_loss`` is the mean of the per-batch MSE values.
        - ``preds`` and ``targets`` are NumPy arrays of shape
          ``(num_batches, *grid_shape)``.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    model.eval()
    mse_loss = torch.nn.MSELoss()
    total_loss = 0.0
    num_batches = 0
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for x, y, output_queries in tqdm(val_loader, desc="Evaluating"):
            x = x.squeeze(0).to(device)
            y = y.squeeze(0).to(device)
            output_queries = output_queries.squeeze(0).to(device)

            # NOTE(review): ada_in is drawn at random for every batch, so the
            # reported loss varies between runs unless the RNG is seeded.
            # Confirm this matches how ada_in was supplied during training.
            ada_in = torch.randn(1, device=device)
            output = model(x, input_geom, latent_queries, output_queries, ada_in=ada_in)

            # Compare element-wise on flattened tensors. If output and y
            # differed only by a singleton dim (e.g. (N, 1) vs (N,)), MSELoss
            # would broadcast them to (N, N) and report a meaningless value.
            loss = mse_loss(output.reshape(-1), y.reshape(-1))
            total_loss += loss.item()
            num_batches += 1

            # reshape (unlike view) also handles non-contiguous outputs.
            all_preds.append(output.reshape(grid_shape).cpu().numpy())
            all_targets.append(y.reshape(grid_shape).cpu().numpy())

    if num_batches == 0:
        raise ValueError("val_loader yielded no batches; cannot compute an average loss")

    avg_loss = total_loss / num_batches
    return avg_loss, np.array(all_preds), np.array(all_targets)

def plot_slices(pred, target, slice_indices, axis=0, filename=None):
    """Save a side-by-side figure of prediction vs. ground-truth slices of a 3-D field.

    One row is drawn per entry in ``slice_indices``:
    - the left panel shows that slice of ``pred``;
    - the right panel shows the same slice of ``target``.
    Every panel shares one colour scale (the global min/max over ``pred`` and
    ``target``), so the panels can be compared directly.

    Args:
        pred: 3-D array of predicted values.
        target: 3-D array of ground-truth values, indexed like ``pred``.
        slice_indices: Indices along ``axis`` to plot; must be non-empty.
        axis: Axis to slice along. ``0`` and ``1`` select those axes; any
            other value slices along axis 2.
        filename: Output image path. Defaults to
            ``cfd_slices_comparison_axis_{axis}.png`` in the working directory.

    Raises:
        ValueError: if ``slice_indices`` is empty.
    """
    slice_indices = list(slice_indices)
    n_rows = len(slice_indices)
    if n_rows == 0:
        raise ValueError("slice_indices must contain at least one index")
    if filename is None:
        filename = f'cfd_slices_comparison_axis_{axis}.png'

    # squeeze=False keeps `axes` 2-D (n_rows x 2) even for a single row.
    # With the default, plt.subplots(1, 2) returns a 1-D array and
    # axes[i, 0] would raise IndexError.
    fig, axes = plt.subplots(n_rows, 2, figsize=(14, 4 * n_rows), squeeze=False)

    # Global min and max so every panel uses the same colour scale.
    vmin = min(pred.min(), target.min())
    vmax = max(pred.max(), target.max())

    # Axis values other than 0 and 1 fall back to slicing along axis 2.
    slice_axis = axis if axis in (0, 1) else 2

    for i, idx in enumerate(slice_indices):
        index = [slice(None)] * 3
        index[slice_axis] = idx
        index = tuple(index)
        panels = (
            (pred[index], 'Prediction'),
            (target[index], 'Ground Truth'),
        )
        for col, (data, label) in enumerate(panels):
            ax = axes[i, col]
            im = ax.imshow(data, cmap='viridis', vmin=vmin, vmax=vmax)
            ax.set_title(f'{label} (Slice {idx}, Axis {axis})')
            ax.axis('off')
            fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    fig.tight_layout()
    fig.savefig(filename, dpi=300, bbox_inches='tight')
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)

def main():
    """Evaluate the saved GINO CFD model on the validation split and plot slices.

    Steps:
    1. Build the model from a fixed config and load ``gino_cfd_model.pth``.
    2. Create the datasets from ``data/``; only the validation split is used.
    3. Run ``evaluate_model`` and print the average MSE.
    4. Save one slice-comparison figure per axis for the first validation
       sample.
    """
    # These hyper-parameters must match the ones the checkpoint was trained with.
    gino_kwargs = dict(
        in_channels=1,
        out_channels=1,
        gno_coord_dim=3,
        gno_radius=0.3,
        projection_channels=16,
        in_gno_mlp_hidden_layers=[16, 16],
        out_gno_mlp_hidden_layers=[16, 16],
        in_gno_transform_type="nonlinear",
        out_gno_transform_type="nonlinear",
        fno_n_modes=(16, 16, 16),
        fno_hidden_channels=64,
        fno_lifting_channels=16,
        fno_projection_channels=16,
        fno_norm="ada_in",
    )
    net = load_model('gino_cfd_model.pth', gino_kwargs)

    # Only the validation split is used below; the training split is discarded.
    _unused_train, val_split = create_datasets(
        'data/', num_train_samples=250, num_val_samples=50, shuffle=True, seed=42,
    )
    loader = DataLoader(
        CFDDatasetWrapper(val_split), batch_size=1, shuffle=False, num_workers=1,
    )

    # Geometry and latent queries are read from the first validation item and
    # reused for every batch. This assumes they are shared across samples
    # (TODO confirm).
    _, _, geometry, latent_grid, _ = val_split[0]

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    net.to(device)
    geometry = geometry.to(device)
    latent_grid = latent_grid.to(device)

    avg_loss, preds, targets = evaluate_model(net, loader, geometry, latent_grid, device)
    print(f"Average MSE Loss on validation set: {avg_loss:.6f}")

    # First, near-middle and last index of each axis, plotted for the first
    # validation sample only.
    slices_to_show = [0, 3, 7]
    sample_pred, sample_target = preds[0], targets[0]
    for axis in (0, 1, 2):
        plot_slices(sample_pred, sample_target, slices_to_show, axis)


if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not on import.
    main()

Using device: cpu


Evaluating: 100%|██████████| 50/50 [00:02<00:00, 17.80it/s]


Average MSE Loss on validation set: 0.019644
