# Flow Field Reconstruction - Inference Results Analysis

This notebook analyzes the reconstructed flow fields from various models (FLRNet, MLP, POD) using inference results stored in the checkpoints directory. The analysis includes comprehensive visualization with enhanced color maps for better visual interpretation.

## Overview
- Load pre-computed inference results from different models
- Compare reconstruction accuracy across various conditions
- Generate publication-quality plots with optimized color schemes
- Analyze effects of sensor count, layout, noise, and Reynolds number

## Import Required Libraries

In [6]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap
import os
import glob
from pathlib import Path
import seaborn as sns

# Publication-style matplotlib defaults (similar to the PDF figures), applied in one call
plt.rcParams.update({
    # resolution for on-screen and saved figures
    'figure.dpi': 300,
    'savefig.dpi': 300,
    # typography
    'font.size': 10,
    'font.family': 'sans-serif',
    'font.sans-serif': ['Arial', 'DejaVu Sans', 'Liberation Sans'],
    # line weights
    'axes.linewidth': 0.8,
    'lines.linewidth': 1.0,
    # open-frame axes: hide the top and right spines
    'axes.spines.top': False,
    'axes.spines.right': False,
    # short, outward-pointing ticks
    'xtick.direction': 'out',
    'ytick.direction': 'out',
    'xtick.major.size': 3,
    'ytick.major.size': 3,
})

# One fixed line colour per model family, shared by the comparison plots
COLORS = dict(
    flrnet='#2E86AB',          # Ocean blue
    flrnet_fourier='#A23B72',  # Deep magenta
    mlp='#F18F01',             # Golden orange
    pod='#C73E1D',             # Crimson red
    ground_truth='#3A5F0B',    # Forest green
)

# Enhanced colormap for field visualization
def create_enhanced_colormap():
    """Build the 256-level colormap used for flow-field images.

    Returns:
        LinearSegmentedColormap: colormap named ``'enhanced_field'`` interpolated
        through nine colour stops.
    """
    # Colour stops, in order: dark navy -> blues -> white -> yellow -> orange -> reds
    color_stops = [
        '#000428', '#004e92', '#009ffd', '#00d2ff',
        '#ffffff',
        '#ffff00', '#ff6600', '#ff0000', '#8b0000',
    ]
    return LinearSegmentedColormap.from_list('enhanced_field', color_stops, N=256)

# Enhanced diverging colormap for error visualization  
def create_error_colormap():
    """Build the 256-level blue-white-red diverging colormap used for error maps.

    Returns:
        LinearSegmentedColormap: colormap named ``'enhanced_error'`` interpolated
        through nine colour stops, with the near-white stop in the middle.
    """
    cool_side = ['#2166ac', '#4393c3', '#92c5de', '#d1e5f0']
    midpoint = ['#f7f7f7']
    warm_side = ['#fddbc7', '#f4a582', '#d6604d', '#b2182b']
    return LinearSegmentedColormap.from_list(
        'enhanced_error', cool_side + midpoint + warm_side, N=256
    )

# Confirm that the setup cell ran to completion (one message per line)
print(
    "Libraries imported successfully!",
    "Enhanced color schemes initialized.",
    sep="\n",
)

Libraries imported successfully!
Enhanced color schemes initialized.


## Load Reconstructed Field Data

Load pre-computed inference results from the checkpoints directory. The results include reconstructed fields from different models and configurations.

In [None]:
# Define paths
# NOTE(review): both paths are absolute, machine-specific Windows locations; anyone
# running this notebook elsewhere must edit them here before executing later cells.
inference_results_path = r"E:\Research\Physics-informed-machine-learning\flow_field_recon_parc\checkpoints\inference_results"
data_path = r"E:\Research\Data\NavierStokes\test"

# Check if inference results directory exists
# When it is missing, list whatever the parent checkpoints directory contains so the
# user can see which results are actually available. Nothing is raised here.
if not os.path.exists(inference_results_path):
    print(f"Warning: Inference results directory not found at {inference_results_path}")
    print("Available checkpoints:")
    checkpoints_path = r"E:\Research\Physics-informed-machine-learning\flow_field_recon_parc\checkpoints"
    if os.path.exists(checkpoints_path):
        for item in os.listdir(checkpoints_path):
            print(f"  - {item}")

def load_inference_results():
    """Load every inference-result ``.npz`` archive found under the checkpoints tree.

    Searches ``inference_results_path`` (top level and all subdirectories) plus the
    main checkpoints directory, opens each archive with ``np.load`` and tags it with
    a model type derived from the file name:

    * ``inference_{layout}_{no_of_sensor}_{config}`` -> one of the FLRNet variants
      (``'flrnet_unknown'`` when the config is missing or not recognised)
    * ``mlp_*`` -> ``'mlp'``
    * ``pod_*`` -> ``'pod'``
    * anything else -> ``'unknown'``

    Files that fail to load are reported and skipped.

    Returns:
        dict: maps the file name without its extension to a dict with the keys
        ``'data'`` (the object returned by ``np.load``), ``'model_type'`` and
        ``'file_path'``.
    """
    # FLRNet config suffix -> model type
    flrnet_configs = {
        'standard': 'flrnet_fourier_percep',        # With fourier and perceptual
        'fourier': 'flrnet_fourier',                # Fourier only
        'no_fourier': 'flrnet_percep',              # Perceptual only
        'no_fourier_no_percep': 'flrnet_standard',  # Vanilla VAE
    }

    results = {}

    # Look for .npz files in the inference_results directory and subdirectories
    search_patterns = [
        os.path.join(inference_results_path, "*.npz"),
        os.path.join(inference_results_path, "**", "*.npz"),
        # Also check main checkpoints directory
        r"E:\Research\Physics-informed-machine-learning\flow_field_recon_parc\checkpoints\*.npz"
    ]

    # With recursive=True, "**" also matches zero directories, so the first two
    # patterns both return the top-level files. Keep each path once, in discovery order.
    found_files = []
    seen_paths = set()
    for pattern in search_patterns:
        for match in glob.glob(pattern, recursive=True):
            path_key = os.path.normcase(os.path.normpath(match))
            if path_key not in seen_paths:
                seen_paths.add(path_key)
                found_files.append(match)

    print(f"Found {len(found_files)} .npz files:")
    for file_path in found_files:
        print(f"  - {os.path.basename(file_path)}")

        try:
            # Load the file
            data = np.load(file_path)
            # splitext strips only the trailing extension (str.replace would also
            # remove '.npz' if it appeared in the middle of the name)
            file_name = os.path.splitext(os.path.basename(file_path))[0]

            # Extract key information from filename based on new convention
            if file_name.startswith('inference_'):
                # FLRNet models: inference_{layout}_{no_of_sensor}_{config}
                parts = file_name.split('_')
                if len(parts) >= 4:
                    # Handle multi-part configs like 'no_fourier_no_percep'
                    config = '_'.join(parts[3:])
                    model_type = flrnet_configs.get(config, 'flrnet_unknown')
                else:
                    model_type = 'flrnet_unknown'
            elif file_name.startswith('mlp_'):
                model_type = 'mlp'
            elif file_name.startswith('pod_'):
                model_type = 'pod'
            else:
                model_type = 'unknown'

            # Results are keyed by bare file name, so two files with the same name in
            # different directories collide; make the overwrite visible.
            if file_name in results:
                print(f"    Warning: {file_name} was already loaded from "
                      f"{results[file_name]['file_path']}; replacing it")

            # Store data with metadata
            results[file_name] = {
                'data': data,
                'model_type': model_type,
                'file_path': file_path
            }

            print(f"    Loaded: {file_name} ({model_type}) - Keys: {list(data.keys())}")

        except Exception as e:
            print(f"    Error loading {file_path}: {e}")

    return results

# Load all inference results
# `inference_data` maps file name -> {'data', 'model_type', 'file_path'} and is read
# by later cells (ground-truth extraction, debug listing).
inference_data = load_inference_results()

print(f"\nSuccessfully loaded {len(inference_data)} inference result files.")

: 

: 

## Data Preprocessing and Normalization

Organize and preprocess the loaded data for consistent analysis across different models and configurations.

In [None]:
def categorize_flrnet_model(filename):
    """Map an FLRNet result file name to its model category.

    The name is expected to follow ``inference_{layout}_{no_of_sensor}_{config}``;
    everything from the fourth underscore-separated token onwards is the config.

    Args:
        filename: file name without extension.

    Returns:
        str: ``'flrnet_fourier_percep'``, ``'flrnet_fourier'``, ``'flrnet_percep'``
        or ``'flrnet_standard'``; ``'flrnet_unknown'`` when the name has fewer than
        four tokens or the config is not recognised.
    """
    category_by_config = {
        'standard': 'flrnet_fourier_percep',        # With fourier and perceptual
        'fourier': 'flrnet_fourier',                # Fourier only
        'no_fourier': 'flrnet_percep',              # Perceptual only
        'no_fourier_no_percep': 'flrnet_standard',  # Vanilla VAE
    }

    tokens = filename.split('_')
    if len(tokens) < 4:
        return 'flrnet_unknown'  # Fallback

    # Re-join so multi-part configs such as 'no_fourier_no_percep' stay intact
    return category_by_config.get('_'.join(tokens[3:]), 'flrnet_unknown')

def load_and_organize_data():
    """Load all result ``.npz`` files and group them by model category.

    Files under ``inference_results_path`` (top level and subdirectories) are
    categorised from their name: ``inference_*`` via ``categorize_flrnet_model``,
    ``mlp_*`` as ``'mlp'``, ``pod_*`` as ``'pod'``. Files with any other prefix, or
    with an FLRNet config that maps to no known category, are skipped without being
    opened. The ``'targets'`` array of the first file that has one is also stored as
    the ground truth.

    Returns:
        tuple: ``(organized_data, gt_data, gt_info)`` where

        * ``organized_data`` maps category -> {file name -> {'data', 'info', 'file_path'}}
          and has an extra ``'ground_truth'`` category holding a single ``'targets'`` entry,
        * ``gt_data`` is the ground-truth array, or ``None`` if no file had ``'targets'``,
        * ``gt_info`` is the metadata of the file the ground truth came from (plus a
          ``'source'`` key), or ``{}``.
    """
    # Search patterns for finding .npz files
    search_patterns = [
        os.path.join(inference_results_path, "*.npz"),
        os.path.join(inference_results_path, "**", "*.npz"),
    ]

    # With recursive=True, "**" also matches zero directories, so both patterns return
    # the top-level files. Keep each path once, in discovery order.
    all_files = []
    seen_paths = set()
    for pattern in search_patterns:
        for match in glob.glob(pattern, recursive=True):
            path_key = os.path.normcase(os.path.normpath(match))
            if path_key not in seen_paths:
                seen_paths.add(path_key)
                all_files.append(match)

    # Initialize organized data structure
    organized_data = {
        'flrnet_fourier_percep': {},  # standard config
        'flrnet_fourier': {},         # fourier config
        'flrnet_percep': {},          # no_fourier config
        'flrnet_standard': {},        # no_fourier_no_percep config (vanilla VAE)
        'mlp': {},
        'pod': {},
        'ground_truth': {}
    }

    # Ground truth data
    gt_data = None
    gt_info = {}

    # Process each file
    for file_path in all_files:
        try:
            filename = os.path.splitext(os.path.basename(file_path))[0]

            # Determine model category based on filename pattern
            if filename.startswith('inference_'):
                category = categorize_flrnet_model(filename)
            elif filename.startswith('mlp_'):
                category = 'mlp'
            elif filename.startswith('pod_'):
                category = 'pod'
            else:
                continue  # Skip unknown files

            # An unrecognised FLRNet config yields a category with no bucket. Report
            # it clearly instead of letting organized_data[category] raise a KeyError
            # that the handler below would print as a load error.
            if category not in organized_data:
                print(f"⚠ Skipping {filename}: unrecognized model configuration ({category})")
                continue

            # Open the archive only once the file is known to be wanted
            data = np.load(file_path)

            # Extract model features from filename
            info = extract_model_info(filename)

            # Store in organized structure
            organized_data[category][filename] = {
                'data': data,
                'info': info,
                'file_path': file_path
            }

            # Extract ground truth from first available file
            if gt_data is None and 'targets' in data:
                gt_data = data['targets']
                gt_info = info.copy()
                gt_info['source'] = filename
                organized_data['ground_truth']['targets'] = {
                    'data': {'targets': gt_data},
                    'info': gt_info,
                    'file_path': file_path
                }
                print(f"✓ Ground truth extracted from {filename}")

        except Exception as e:
            print(f"✗ Error loading {file_path}: {e}")

    return organized_data, gt_data, gt_info

def extract_model_info(filename):
    """Parse a result file name into a metadata dictionary.

    Recognised patterns:

    * ``inference_{layout}_{no_of_sensor}_{config}`` (FLRNet)
    * ``mlp_{layout}_{no_of_sensor}_{config}`` / ``pod_{layout}_{no_of_sensor}_{config}``

    Args:
        filename: file name without extension.

    Returns:
        dict: keys ``'filename'``, ``'variant'``, ``'num_sensors'``,
        ``'sensor_layout'``, ``'has_fourier'``, ``'has_perceptual'`` and ``'config'``.
        Fields that cannot be parsed keep their defaults (``'unknown'``, ``0``,
        ``False``); a non-numeric sensor token gives ``num_sensors == 0``.
    """
    # FLRNet config -> (variant, has_fourier, has_perceptual)
    flrnet_features = {
        'standard': ('fourier_percep', True, True),
        'fourier': ('fourier', True, False),
        'no_fourier': ('percep', False, True),
        'no_fourier_no_percep': ('standard', False, False),
    }

    info = dict(
        filename=filename,
        variant='unknown',
        num_sensors=0,
        sensor_layout='unknown',
        has_fourier=False,
        has_perceptual=False,
        config='unknown',
    )

    tokens = filename.split('_')
    is_flrnet = filename.startswith('inference_')
    is_baseline = (not is_flrnet) and filename.startswith(('mlp_', 'pod_'))

    # Both patterns need at least {prefix}_{layout}_{no_of_sensor}_{config}
    if not (is_flrnet or is_baseline) or len(tokens) < 4:
        return info

    info['sensor_layout'] = tokens[1]
    info['num_sensors'] = int(tokens[2]) if tokens[2].isdigit() else 0

    if is_flrnet:
        # The config may itself contain underscores, so take everything after token 2
        config = '_'.join(tokens[3:])
        info['config'] = config
        if config in flrnet_features:
            variant, uses_fourier, uses_perceptual = flrnet_features[config]
            info['variant'] = variant
            info['has_fourier'] = uses_fourier
            info['has_perceptual'] = uses_perceptual
    else:
        # Baselines use only the fourth token as config; the variant is the prefix
        info['config'] = tokens[3]
        info['variant'] = tokens[0]  # 'mlp' or 'pod'

    return info

# Load and organize all data
# `organized_data` maps category -> {file name -> {'data', 'info', 'file_path'}};
# `gt_data` / `gt_info` describe the ground truth taken from the first file with 'targets'.
organized_data, gt_data, gt_info = load_and_organize_data()

# Print detailed summary
print("Detailed Data Organization Summary:")
print("=" * 60)
print()

# Count files in each category
# (the 'ground_truth' category holds at most one entry and is subtracted from the total below)
category_counts = {category: len(files) for category, files in organized_data.items()}

print("FLRNet Variants:")
print(f"  • FLRNet (Fourier + Perceptual) [standard]: {category_counts['flrnet_fourier_percep']} files")
print(f"  • FLRNet (Fourier only) [fourier]: {category_counts['flrnet_fourier']} files")
print(f"  • FLRNet (Perceptual only) [no_fourier]: {category_counts['flrnet_percep']} files")
print(f"  • FLRNet (Vanilla VAE) [no_fourier_no_percep]: {category_counts['flrnet_standard']} files")
print()
print("Other Models:")
print(f"  • MLP: {category_counts['mlp']} files")
print(f"  • POD: {category_counts['pod']} files")
print()
print("Ground Truth:")
print(f"  • Ground Truth: {category_counts['ground_truth']} extracted")
print()

# Show detailed breakdown for each category
for category, files in organized_data.items():
    if files and category != 'ground_truth':
        print(f"{category.upper().replace('_', ' ')}:")
        for filename, model_data in files.items():
            info = model_data['info']
            data = model_data['data']
            # Indexing the archive reads the array; only its shape is reported here
            pred_shape = data['predictions'].shape if 'predictions' in data else "N/A"
            targets_shape = data['targets'].shape if 'targets' in data else "N/A"
            
            print(f"  - {filename}")
            print(f"    Config: {info['config']}, Variant: {info['variant']}")
            print(f"    Sensors: {info['num_sensors']}, Layout: {info['sensor_layout']}")
            print(f"    Fourier: {info['has_fourier']}, Perceptual: {info['has_perceptual']}")
            print(f"    Predictions: {pred_shape}, Targets: {targets_shape}")
        print()

# Report where the ground truth came from, or warn when no file carried 'targets'
if gt_data is not None:
    print("GROUND TRUTH:")
    print(f"  - Source: {gt_info['source']}")
    print(f"  - Shape: {gt_data.shape}")
else:
    print("⚠ WARNING: No ground truth data found!")

print()
print("=" * 60)
# Total excludes the synthetic ground-truth entry so it counts model result files only
print(f"TOTAL FILES ORGANIZED: {sum(category_counts.values()) - category_counts['ground_truth']}")
print("=" * 60)

: 

: 

## Load Ground Truth Data for Comparison

Load the corresponding ground truth flow field data for error calculation and comparative analysis.

In [None]:
def load_ground_truth_from_inference_data():
    """Take the ground-truth fields from the ``'targets'`` array of an inference file.

    The first entry of ``inference_data`` whose archive has a ``'targets'`` key is
    used (all files are expected to share the same targets). A 4-D array has its
    first and last axes swapped so that the original last axis comes first.

    Returns:
        tuple: ``(gt_data, min_val, max_val)`` with the (possibly transposed) array
        and its global minimum and maximum, or ``(None, None, None)`` when no file
        provides targets.
    """
    print("Loading ground truth data from inference files:")

    # Name of the first file that carries embedded targets, if any
    sample_file = next(
        (name for name, entry in inference_data.items() if 'targets' in entry['data']),
        None,
    )

    gt_data = None
    if sample_file is not None:
        gt_data = inference_data[sample_file]['data']['targets']

    if gt_data is None:
        print("  ✗ No ground truth data found in inference files!")
        return None, None, None

    print(f"  ✓ Ground truth loaded from: {sample_file}")
    print(f"  ✓ Ground truth shape: {gt_data.shape}")

    # Reshape data to have time as first dimension:
    # (batch, height, width, time) -> (time, height, width, batch)
    if gt_data.ndim == 4:
        time_first_axes = (3, 1, 2, 0)  # Move time from last to first
        gt_data = np.transpose(gt_data, time_first_axes)
        print(f"  ✓ Reshaped to: {gt_data.shape} (time, height, width, channels)")

    # Normalization parameters come from the ground truth data itself
    min_val, max_val = np.min(gt_data), np.max(gt_data)

    print(f"  ✓ Ground truth range: [{min_val:.6f}, {max_val:.6f}]")

    return gt_data, min_val, max_val

# Pull the ground truth (and its value range) out of the loaded inference files
gt_data, min_val, max_val = load_ground_truth_from_inference_data()

if gt_data is None:
    print("Warning: No ground truth data available for analysis")
    # Fall back to a unit range so later percentage scaling still has a divisor
    min_val, max_val = 0.0, 1.0  # Default values
else:
    # The data is already in the correct format, no additional normalization needed
    print(f"Ground truth data successfully loaded: {gt_data.shape}")
    print(f"Value range: [{np.min(gt_data):.6f}, {np.max(gt_data):.6f}]")

## Calculate Error Metrics

Compute comprehensive error metrics between ground truth and reconstructed fields for quantitative analysis.

In [None]:
def calculate_error_metrics(gt_data, pred_data):
    """Compute scalar error metrics between ground truth and prediction.

    Args:
        gt_data: ground-truth values (array-like).
        pred_data: predicted values, broadcastable against ``gt_data``.

    Returns:
        dict with the keys

        * ``'mae'``  - mean absolute error,
        * ``'rmse'`` - root mean square error,
        * ``'mape'`` - mean absolute percentage error in percent, computed as
          ``mean(|gt - pred| / (|gt| + 1e-8)) * 100``,
        * ``'mae_normalized'`` - MAE divided by the ground-truth value range
          (``max - min``); ``nan`` when the ground truth is constant.
    """
    gt = np.asarray(gt_data)
    pred = np.asarray(pred_data)
    abs_error = np.abs(gt - pred)

    # Mean Absolute Error
    mae = np.mean(abs_error)

    # Root Mean Square Error
    rmse = np.sqrt(np.mean(abs_error ** 2))

    # Mean Absolute Percentage Error
    # The epsilon is added to |gt|, not gt: the fields are signed, and `gt + eps`
    # would shrink the denominator for negative values and hit zero at gt == -eps.
    mape = np.mean(abs_error / (np.abs(gt) + 1e-8)) * 100

    # Normalized MAE; undefined for a constant field, so report nan rather than
    # dividing by zero
    value_range = np.max(gt) - np.min(gt)
    mae_normalized = mae / value_range if value_range > 0 else np.nan

    return {
        'mae': mae,
        'rmse': rmse,
        'mape': mape,
        'mae_normalized': mae_normalized
    }
def process_inference_results_for_analysis():
    """Compute error metrics for every organized result file.

    Walks ``organized_data`` for the six model categories and, for each file that
    has both ``'predictions'`` and ``'targets'``, evaluates
    ``calculate_error_metrics(targets, predictions)``. Files with missing keys or
    mismatched shapes are reported and skipped.

    Returns:
        tuple: ``(model_results, error_metrics, configs)``; three dicts keyed by
        model category whose lists hold, per processed file, the prediction array,
        the metrics dict and the file's info dict. All lists are empty when the
        notebook-level ``gt_data`` is ``None``.
    """
    categories = (
        'flrnet_fourier_percep',  # standard config
        'flrnet_fourier',         # fourier config
        'flrnet_percep',          # no_fourier config
        'flrnet_standard',        # no_fourier_no_percep config
        'mlp',
        'pod',
    )
    model_results = {name: [] for name in categories}
    error_metrics = {name: [] for name in categories}
    configs = {name: [] for name in categories}

    if gt_data is None:
        print("Warning: No ground truth data available for error calculation")
        return model_results, error_metrics, configs

    print("Processing inference results for analysis:")
    print("=" * 50)

    time_first_axes = (3, 1, 2, 0)

    for model_type in categories:
        entries = organized_data.get(model_type)
        if not entries:
            print(f"\n{model_type.upper()}: No data found")
            continue

        print(f"\n{model_type.upper()}:")

        for file_name, file_info in entries.items():
            data = file_info['data']
            info = file_info['info']

            # Predictions and targets must both come from the same file
            if not ('predictions' in data and 'targets' in data):
                print(f"  ✗ {file_name}: Missing 'predictions' or 'targets' key")
                print(f"    Available keys: {list(data.keys())}")
                continue

            predictions = data['predictions']
            targets = data['targets']

            # Move the last axis to the front when it is longer than the first one
            if predictions.ndim == 4 and predictions.shape[3] > predictions.shape[0]:
                predictions = np.transpose(predictions, time_first_axes)
                targets = np.transpose(targets, time_first_axes)

            # Metrics are only meaningful for element-wise comparable arrays
            if predictions.shape != targets.shape:
                print(f"  Warning: Shape mismatch for {file_name}")
                print(f"    Predictions: {predictions.shape}")
                print(f"    Targets: {targets.shape}")
                continue

            metrics = calculate_error_metrics(targets, predictions)

            model_results[model_type].append(predictions)
            error_metrics[model_type].append(metrics)
            configs[model_type].append(info)

            print(f"  ✓ {file_name}")
            print(f"    Config: {info['config']}")
            print(f"    MAE: {metrics['mae']:.6f}")
            print(f"    RMSE: {metrics['rmse']:.6f}")

    return model_results, error_metrics, configs

# Run the metric computation over every organized result file
model_results, error_metrics, configs = process_inference_results_for_analysis()

# One summary line per model category
print(f"\nSummary of processed results:")
for model_type, processed_predictions in model_results.items():
    count = len(processed_predictions)
    print(f"  {model_type}: {count} files processed")

## Enhanced Plotting Functions

Define improved plotting functions with enhanced color schemes and better visual presentation.

In [None]:
def create_enhanced_comparison_plot(x_data, y_data_dict, title, xlabel, ylabel, 
                                   figsize=(10, 7), save_name=None):
    """Draw one styled line per model on a dual-axis comparison plot.

    The left axis shows the raw values. The right axis repeats every series
    rescaled as ``value / max_val * 100`` (``max_val`` is the notebook-level
    ground-truth maximum) and drawn semi-transparently.

    Args:
        x_data: x positions; each series uses the first ``len(y)`` of them.
        y_data_dict: mapping of model name -> sequence of y values. Empty series
            are skipped. Styles cycle with period four.
        title, xlabel, ylabel: text for the title and the primary axes.
        figsize: figure size in inches.
        save_name: when given, the figure is written to ``{save_name}.png`` and
            ``{save_name}.pdf``.
    """
    fig, ax1 = plt.subplots(figsize=figsize)

    # Per-series style cycles
    palette = [COLORS[key] for key in ('flrnet', 'flrnet_fourier', 'mlp', 'pod')]
    marker_cycle = ['o', 's', '^', 'd']
    dash_cycle = ['-', '--', '-.', ':']

    def series_style(position):
        """Line/marker keyword arguments for the series at `position` in the dict."""
        return dict(
            color=palette[position % len(palette)],
            linestyle=dash_cycle[position % len(dash_cycle)],
            linewidth=2.5,
            marker=marker_cycle[position % len(marker_cycle)],
            markersize=8,
            markerfacecolor='white',
            markeredgewidth=2,
        )

    # Non-empty series together with their original position (which fixes the style)
    drawable = [
        (position, model_name, y_data)
        for position, (model_name, y_data) in enumerate(y_data_dict.items())
        if len(y_data) > 0
    ]

    # Primary axis: raw values, with legend entries
    for position, model_name, y_data in drawable:
        ax1.plot(x_data[:len(y_data)], y_data,
                 label=model_name.replace('_', ' ').title(),
                 **series_style(position))

    ax1.set_title(title, fontsize=18, fontweight='bold', pad=20)
    ax1.set_xlabel(xlabel, fontsize=14, fontweight='bold')
    ax1.set_ylabel(ylabel, fontsize=14, fontweight='bold')
    ax1.legend(loc='upper right', fontsize=12, frameon=True, fancybox=True, shadow=True)
    ax1.grid(True, alpha=0.3, linestyle='--')
    ax1.tick_params(axis='both', which='major', labelsize=12)

    # Secondary axis: the same series as a percentage of max_val
    ax2 = ax1.twinx()
    for position, _model_name, y_data in drawable:
        y_percentage = np.array(y_data) / max_val * 100
        ax2.plot(x_data[:len(y_data)], y_percentage,
                 alpha=0.7,
                 **series_style(position))

    ax2.set_ylabel('Mean Absolute Percentage Error (%)', fontsize=14, fontweight='bold')
    ax2.tick_params(axis='y', labelcolor='black', labelsize=12)

    plt.tight_layout()

    if save_name:
        plt.savefig(f"{save_name}.png", dpi=300, bbox_inches='tight', facecolor='white')
        plt.savefig(f"{save_name}.pdf", dpi=300, bbox_inches='tight', facecolor='white')

    plt.show()


## Effect of Number of Sensors Analysis

Analyze how reconstruction accuracy varies with the number of sensors (8, 16, 32) using enhanced visualizations.

In [None]:
# Debug: Check organized data structure
print("Debug: Organized data structure:")
print("=" * 50)
for model_type, model_data in organized_data.items():
    print(f"\n{model_type}: {len(model_data) if model_data else 0} files")
    if model_data:
        for file_name in list(model_data.keys())[:3]:  # Show first 3 files
            print(f"  - {file_name}")

print("\nDebug: Original inference data:")
print("=" * 50)
for file_name in list(inference_data.keys())[:5]:  # Show first 5 files
    file_info = inference_data[file_name]
    print(f"{file_name}:")
    print(f"  Model type: {file_info['model_type']}")
    print(f"  Data keys: {list(file_info['data'].keys())}")

# Now proceed with the analysis using the correct data structure
print("\n\nAnalyzing sensor count effects:")
print("=" * 50)

# Collect per-file errors, nested as
#   sensor_count_results[sensor_count][model_type][layout] -> list of records
sensor_count_results = {}

for model_type, model_data in organized_data.items():
    if model_type == 'ground_truth' or not model_data:
        continue

    for file_name, file_info in model_data.items():
        info = file_info['info']  # This is the info dictionary
        sensor_count = info['num_sensors']
        layout = info['sensor_layout']

        # num_sensors is 0 when the file name could not be parsed
        if not sensor_count:
            continue

        data = file_info['data']
        if 'predictions' not in data or 'targets' not in data:
            print(f"  Skipping {file_name}: missing 'predictions' or 'targets'")
            continue

        predictions = data['predictions']
        targets = data['targets']
        if predictions.shape != targets.shape:
            print(f"  Skipping {file_name}: shape mismatch "
                  f"{predictions.shape} vs {targets.shape}")
            continue

        # Reduce to one scalar per file. Storing the full element-wise error array
        # here would make every later np.mean over the records average (or fail on)
        # whole arrays instead of per-file scores.
        error = predictions - targets
        mae = float(np.mean(np.abs(error)))
        mse = float(np.mean(error ** 2))

        records = (sensor_count_results
                   .setdefault(sensor_count, {})
                   .setdefault(model_type, {})
                   .setdefault(layout, []))
        records.append({
            'file': file_name,
            'mae': mae,
            'mse': mse,  # kept next to MAE so either metric can be plotted downstream
            'predictions': predictions,
            'targets': targets
        })

# Display sensor count analysis
for sensor_count in sorted(sensor_count_results.keys()):
    print(f"\nSensor count: {sensor_count}")
    # Do not name this loop variable `model_results`: that would overwrite the
    # notebook-level dict of processed predictions created in an earlier cell.
    for model_type, layouts_by_model in sensor_count_results[sensor_count].items():
        for layout, layout_results in layouts_by_model.items():
            avg_mae = np.mean([r['mae'] for r in layout_results])
            print(f"  {model_type} ({layout}): {len(layout_results)} files, Avg MAE: {avg_mae:.6f}")

# The 'ground_truth' category is bookkeeping, not a model configuration
total_configurations = sum(
    len(model_data)
    for model_type, model_data in organized_data.items()
    if model_type != 'ground_truth' and model_data
)
print(f"\nTotal configurations found: {total_configurations}")

## Effect of Sensor Layout Analysis

Compare different sensor layouts (Random, Circular, Edge) and their impact on reconstruction quality.

In [None]:
# Create sensor count comparison plots
#
# Every panel reads `sensor_count_results`, nested as
#   sensor_count_results[sensor_count][model_type][layout] -> list of records,
# where each record carries its error under the 'mae' key (there is no MSE-based
# panel, so all axes are labelled as mean absolute error).

# Prepare data for visualization
sensor_counts = sorted(sensor_count_results.keys())
# Must match the category keys used when the results were organized; a bare
# 'flrnet' entry never matches any of them.
model_types = [
    'flrnet_fourier_percep',
    'flrnet_fourier',
    'flrnet_percep',
    'flrnet_standard',
    'mlp',
    'pod',
]
layouts = ['random', 'circular', 'edge']


def _file_maes(sensor_count, model_type, layout=None):
    """Return one scalar MAE per result file for a model at a sensor count.

    With ``layout=None`` the files of all layouts are pooled. Each record's 'mae'
    is reduced with ``np.mean`` so both scalar and element-wise entries work.
    """
    by_layout = sensor_count_results.get(sensor_count, {}).get(model_type, {})
    selected = by_layout.values() if layout is None else [by_layout.get(layout, [])]
    return [float(np.mean(record['mae'])) for records in selected for record in records]


# Create comprehensive sensor count analysis plot
fig, axes = plt.subplots(2, 2, figsize=(16, 12))
fig.suptitle('Model Performance Analysis Across Different Configurations', fontsize=16, fontweight='bold')

# Plot 1: MAE by Sensor Count (averaged across layouts)
ax1 = axes[0, 0]
for model_type in model_types:
    counts_with_data = []
    mean_maes = []
    for sensor_count in sensor_counts:
        file_maes = _file_maes(sensor_count, model_type)
        if file_maes:
            counts_with_data.append(sensor_count)
            mean_maes.append(np.mean(file_maes))

    if counts_with_data:
        ax1.plot(counts_with_data, mean_maes,
                 'o-', linewidth=2, markersize=8, label=model_type.upper())

ax1.set_xlabel('Number of Sensors', fontsize=12)
ax1.set_ylabel('Mean Absolute Error', fontsize=12)
ax1.set_title('Model Performance vs Sensor Count', fontsize=14)
ax1.legend()
ax1.grid(True, alpha=0.3)
ax1.set_yscale('log')

# Plot 2: Layout comparison for 32 sensors
ax2 = axes[0, 1]
if 32 in sensor_count_results:
    # Only layouts that at least one model has results for get a tick
    layouts_available = [
        layout for layout in layouts
        if any(_file_maes(32, model_type, layout) for model_type in model_types)
    ]

    x = np.arange(len(layouts_available))
    width = 0.8 / len(model_types)  # all model bars of one layout share 0.8 of a slot

    for i, model_type in enumerate(model_types):
        bar_heights = []
        for layout in layouts_available:
            file_maes = _file_maes(32, model_type, layout)
            bar_heights.append(np.mean(file_maes) if file_maes else np.nan)
        bar_heights = np.array(bar_heights, dtype=float)

        # Draw each bar at the slot of its own layout; a model without data for a
        # layout leaves a gap instead of shifting its other bars to the left.
        has_data = ~np.isnan(bar_heights)
        if has_data.any():
            ax2.bar(x[has_data] + i * width, bar_heights[has_data], width,
                    label=model_type.upper(), alpha=0.8)

    ax2.set_xlabel('Sensor Layout', fontsize=12)
    ax2.set_ylabel('Mean Absolute Error', fontsize=12)
    ax2.set_title('Layout Comparison (32 Sensors)', fontsize=14)
    ax2.set_xticks(x + width * (len(model_types) - 1) / 2)  # centre of each bar group
    ax2.set_xticklabels(layouts_available)
    ax2.legend()
    ax2.grid(True, alpha=0.3)

# Plot 3: Model variant comparison
ax3 = axes[1, 0]

# Group per-file MAEs by main model family (all FLRNet variants pooled)
model_groups = {'FLRNet': [], 'MLP': [], 'POD': []}
for sensor_count in sensor_counts:
    for model_type in sensor_count_results[sensor_count]:
        file_maes = _file_maes(sensor_count, model_type)
        if model_type.startswith('flrnet'):
            model_groups['FLRNet'].extend(file_maes)
        elif model_type.startswith('mlp'):
            model_groups['MLP'].extend(file_maes)
        elif model_type.startswith('pod'):
            model_groups['POD'].extend(file_maes)

box_labels = [model for model in ['FLRNet', 'MLP', 'POD'] if model_groups[model]]
box_data = [model_groups[model] for model in box_labels]

if box_data:
    bp = ax3.boxplot(box_data, patch_artist=True)
    # Label the boxes afterwards: the `labels=` keyword of boxplot is deprecated in
    # recent matplotlib releases, while set_xticklabels works on every version.
    ax3.set_xticklabels(box_labels)
    colors = ['lightblue', 'lightgreen', 'lightcoral']
    for patch, color in zip(bp['boxes'], colors[:len(bp['boxes'])]):
        patch.set_facecolor(color)

ax3.set_ylabel('Mean Absolute Error', fontsize=12)
ax3.set_title('Model Type Comparison', fontsize=14)
ax3.grid(True, alpha=0.3)
ax3.set_yscale('log')

# Plot 4: Performance improvement trend
ax4 = axes[1, 1]
for model_type in model_types:
    sensor_counts_valid = []
    mean_maes = []
    for sensor_count in sensor_counts:
        file_maes = _file_maes(sensor_count, model_type)
        if file_maes:
            sensor_counts_valid.append(sensor_count)
            mean_maes.append(np.mean(file_maes))

    # A trend needs at least two sensor counts and a non-zero baseline to divide by
    if len(mean_maes) < 2 or mean_maes[0] <= 0:
        continue

    # Baseline is the model's error at its smallest available sensor count, so the
    # first point is 0 % and later points show the relative error reduction.
    baseline_mae = mean_maes[0]
    improvements = [(baseline_mae - value) / baseline_mae * 100 for value in mean_maes]

    ax4.plot(sensor_counts_valid, improvements, 'o-', linewidth=2, markersize=8,
             label=model_type.upper())

ax4.set_xlabel('Number of Sensors', fontsize=12)
ax4.set_ylabel('Performance Improvement (%)', fontsize=12)
ax4.set_title('Performance Improvement with More Sensors', fontsize=14)
ax4.legend()
ax4.grid(True, alpha=0.3)
ax4.axhline(y=0, color='black', linestyle='--', alpha=0.5)

plt.tight_layout()
plt.savefig('comprehensive_model_analysis.png', dpi=300, bbox_inches='tight')
plt.show()

print("Comprehensive analysis plots created and saved as 'comprehensive_model_analysis.png'")

## Field Reconstruction Visualization

Create comprehensive field reconstruction plots with enhanced color maps showing ground truth, predictions, and error distributions.

In [None]:
def _hide_ticks_and_spines(ax):
    """Remove ticks and the surrounding frame from an image panel."""
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)


def _select_reconstruction(model_type):
    """Pick the prediction array to display for one model family.

    Searches the notebook-level ``organized_data[model_type]`` mapping and
    returns the first entry matching, in order of preference:
    random layout with 32 sensors, any random layout, the first entry.

    Args:
        model_type: Key into ``organized_data``.

    Returns:
        Tuple ``(predictions, file_name)``; ``(None, None)`` when the model
        family is absent from ``organized_data`` or has no files.
    """
    if not (model_type in organized_data and organized_data[model_type]):
        return None, None

    preference = (
        ('Selected', lambda info: info['sensor_layout'] == 'random' and info['num_sensors'] == 32),
        ('Fallback', lambda info: info['sensor_layout'] == 'random'),
        ('Default', lambda info: True),
    )
    for tag, accepts in preference:
        for file_name, file_info in organized_data[model_type].items():
            info = file_info['info']
            if accepts(info):
                print(f"{tag} {model_type}: {file_name} (layout: {info['sensor_layout']}, sensors: {info['num_sensors']})")
                return file_info['data']['predictions'], file_name
    return None, None


def create_flrnet_style_field_plot(gt_field, time_indices, Re_value, case_id=0, save_name=None):
    """Plot ground truth and six model reconstructions on a shared color scale.

    The figure has one row for the ground truth followed by one row per model
    (FLRNet-FP, FLRNet-F, FLRNet-P, FLRNet-Std, MLP, POD) and one column per
    entry of ``time_indices``. Model predictions are taken from the
    notebook-level ``organized_data`` dictionary.

    Args:
        gt_field: Ground-truth array indexed as ``[time, row, col, case]``.
        time_indices: 0-based time indices to show, one per column.
        Re_value: Reynolds number, used only in the figure title.
        case_id: Index along the last axis selecting the case to display.
        save_name: If truthy, the figure is written to ``<save_name>.png`` and
            ``<save_name>.pdf``.
    """
    n_times = len(time_indices)

    # (organized_data key, row label) for every model row, in display order
    models_to_compare = [
        ('flrnet_fourier_percep', 'FLRNet-FP'),   # FLRNet with Fourier and Perceptual
        ('flrnet_fourier', 'FLRNet-F'),           # FLRNet with Fourier only
        ('flrnet_percep', 'FLRNet-P'),            # FLRNet with Perceptual only
        ('flrnet_standard', 'FLRNet-Std'),        # FLRNet standard (no fourier, no percep)
        ('mlp', 'MLP'),
        ('pod', 'POD'),
    ]

    # Every panel shares the value range of the ground truth for this case
    global_vmin = np.min(gt_field[:, :, :, case_id])
    global_vmax = np.max(gt_field[:, :, :, case_id])
    print(f"Ground truth colorbar range: [{global_vmin:.6f}, {global_vmax:.6f}]")

    # One row for the ground truth plus one row per model
    n_rows = 1 + len(models_to_compare)
    fig, axes = plt.subplots(n_rows, n_times, figsize=(3*n_times, 1.7*n_rows))
    if n_times == 1:
        # plt.subplots returns a 1-D array for a single column; keep 2-D indexing
        axes = axes.reshape(-1, 1)

    # Column titles: each index is converted to seconds with a 0.0511 s step
    time_labels = [f't={round((idx+1)*0.0511, 2)} [s]' for idx in time_indices]

    # Ground-truth row
    for j, time_idx in enumerate(time_indices):
        im_gt = axes[0, j].imshow(gt_field[time_idx, :, :, case_id],
                                  cmap='RdBu_r', origin='lower',
                                  vmin=global_vmin, vmax=global_vmax)
        axes[0, j].set_title(time_labels[j], fontsize=14, fontweight='normal')
        _hide_ticks_and_spines(axes[0, j])

    fig.text(-0.01, 0.88, 'Ground Truth', rotation=90, va='center', ha='center',
             fontsize=14, fontweight='bold')

    # A single colorbar, placed to the right of the grid, serves all panels
    cbar_ax = fig.add_axes([1.01, 0.25, 0.02, 0.5])  # [left, bottom, width, height]
    cbar = fig.colorbar(im_gt, cax=cbar_ax)
    cbar.set_label('Velocity (m/s)', fontsize=14, fontweight='bold')

    # Model rows
    for row, (model_type, model_label) in enumerate(models_to_compare, start=1):
        recon_field, selected_file = _select_reconstruction(model_type)

        # Normalise the axis order once per model rather than once per panel.
        # NOTE(review): assumes predictions are stored either as
        # (time, H, W, case) or (case, H, W, time) and that the larger of the
        # two outer axes is time - confirm against the inference script.
        if (recon_field is not None and len(recon_field.shape) == 4
                and recon_field.shape[3] > recon_field.shape[0]):
            recon_field = np.transpose(recon_field, (3, 1, 2, 0))

        for j, time_idx in enumerate(time_indices):
            ax = axes[row, j]
            if recon_field is None:
                ax.text(0.5, 0.5, 'No Data', ha='center', va='center',
                        transform=ax.transAxes, fontsize=10)
            elif (len(recon_field.shape) == 4
                    and time_idx < recon_field.shape[0]
                    and case_id < recon_field.shape[3]):
                ax.imshow(recon_field[time_idx, :, :, case_id],
                          cmap='RdBu_r', origin='lower',
                          vmin=global_vmin, vmax=global_vmax)
            else:
                # Requested time step or case is not present in this model's output
                ax.text(0.5, 0.5, 'N/A', ha='center', va='center',
                        transform=ax.transAxes, fontsize=10)
            _hide_ticks_and_spines(ax)

        if recon_field is not None and selected_file:
            label_text = model_label
        else:
            label_text = f"{model_label}\n(No Data)"

        # Row labels are placed in figure coordinates; the 0.135 spacing is
        # tuned for the seven-row layout above.
        fig.text(-0.01, 0.88 - row * 0.135, label_text, rotation=90, va='center', ha='center',
                 fontsize=14, fontweight='bold')

    fig.suptitle(f'Flow Field Reconstruction Results in The Case Re = {Re_value}',
                 fontsize=18, fontweight='bold', y=1.0)

    plt.tight_layout()

    if save_name:
        plt.savefig(f"{save_name}.png", dpi=300, bbox_inches='tight', facecolor='white')
        plt.savefig(f"{save_name}.pdf", dpi=300, bbox_inches='tight', facecolor='white')

    plt.show()

def create_field_reconstruction_analysis(case_id=0, Re_value=None):
    """Print a run summary and draw the reconstruction comparison for one case.

    Uses the notebook-level ``gt_data`` and ``organized_data``. Five time
    indices are chosen, the available model configurations are listed, and the
    plotting is delegated to ``create_flrnet_style_field_plot``, which saves
    the figure as ``field_reconstruction_random_<Re_value>``.

    Args:
        case_id: Case index forwarded to the plotting function.
        Re_value: Reynolds number for the title and file name; when ``None``
            nothing is plotted.
    """
    if gt_data is None:
        print("No ground truth data available for field visualization")
        return

    n_steps = gt_data.shape[0]

    # With at least 39 steps use the fixed 1st/10th/20th/30th/39th steps
    # (0-indexed below); otherwise spread five indices over what is available.
    time_indices = (
        [0, 9, 19, 29, 38]
        if n_steps >= 39
        else [0, n_steps//4, n_steps//2, 3*n_steps//4, n_steps-1]
    )

    if Re_value is None:
        print("Reynolds number not provided, skipping field reconstruction visualization")
        return

    print("Creating field reconstruction visualization with FLRNet training style:")
    print(f"Total time steps available: {n_steps}")
    print(f"Selected time indices (0-indexed): {time_indices}")
    print(f"Corresponding time steps (1-indexed): {[idx + 1 for idx in time_indices]}")
    print(f"Reynolds number: {Re_value}")
    print("Comparing models: FLRNet variants, MLP, and POD - Random Layout Configuration")

    # List every loaded configuration so the selection can be checked by eye
    print("\nAvailable model configurations:")
    for model_type, model_data in organized_data.items():
        if model_type == 'ground_truth' or not model_data:
            continue
        print(f"\n{model_type}:")
        for file_name, file_info in model_data.items():
            cfg = file_info['info']
            print(f"  - {file_name}: {cfg['sensor_layout']} layout, {cfg['num_sensors']} sensors")

    # Ground truth and predictions only; no error maps are drawn
    create_flrnet_style_field_plot(
        gt_field=gt_data,
        time_indices=time_indices,
        Re_value=Re_value,
        case_id=case_id,
        save_name=f'field_reconstruction_random_{Re_value}',
    )

    print("FLRNet training style field reconstruction visualization completed")

# Create field reconstruction analysis for four cases.
# Each pair is (case index passed as case_id, Reynolds number used for the title and file name).
for _case_id, _re_value in ((1, 60), (3, 300), (5, 950), (6, 3000)):
    create_field_reconstruction_analysis(_case_id, Re_value=_re_value)

## Temporal Error Analysis

Analyze how reconstruction error evolves over time steps for different models.

In [None]:
def analyze_temporal_errors():
    """Compute the per-time-step MAE of each model's random-layout, 32-sensor run.

    Reads the notebook-level ``gt_data`` and ``organized_data``. For every model
    family other than ``'ground_truth'``, the first file whose ``info`` reports
    ``sensor_layout == 'random'`` and ``num_sensors == 32`` is used, and the mean
    absolute difference between its ``targets`` and ``predictions`` is computed
    for each index along the first axis.

    Returns:
        Tuple ``(time_steps, temporal_mae)``. ``temporal_mae`` maps model type to
        a list of per-step MAE values; ``time_steps`` is ``np.arange`` over the
        longest of those lists. When no ground truth is loaded or nothing
        matches, an empty array and an empty dict are returned, so callers can
        always unpack the result.
    """
    if gt_data is None:
        print("No ground truth data available for temporal analysis")
        return np.arange(0), {}

    print("Analyzing temporal errors - ONLY random layout with 32 sensors:")
    print("=" * 60)

    temporal_mae = {}

    for model_type, model_data in organized_data.items():
        if model_type == 'ground_truth' or not model_data:
            continue

        # First file with random layout and 32 sensors, if any
        match = next(
            (
                (name, entry)
                for name, entry in model_data.items()
                if entry['info']['sensor_layout'] == 'random' and entry['info']['num_sensors'] == 32
            ),
            None,
        )
        if match is None:
            print(f"  ✗ {model_type}: No random_32 configuration found")
            continue

        file_name, file_info = match
        data = file_info['data']
        # Both arrays are needed; report the file instead of raising KeyError
        if 'predictions' not in data or 'targets' not in data:
            print(f"  ✗ {model_type}: {file_name} missing predictions/targets")
            continue

        predictions = data['predictions']
        targets = data['targets']

        # NOTE(review): assumes arrays are stored either as (time, H, W, case)
        # or (case, H, W, time) and that the larger outer axis is time - confirm.
        if len(predictions.shape) == 4 and predictions.shape[3] > predictions.shape[0]:
            predictions = np.transpose(predictions, (3, 1, 2, 0))
            targets = np.transpose(targets, (3, 1, 2, 0))

        n_steps = min(predictions.shape[0], targets.shape[0])
        mae_per_time = [np.mean(np.abs(targets[t] - predictions[t])) for t in range(n_steps)]

        temporal_mae[model_type] = mae_per_time
        print(f"  ✓ {model_type}: {file_name}")
        print(f"    Layout: {file_info['info']['sensor_layout']}, Sensors: {file_info['info']['num_sensors']}")
        print(f"    Time steps analyzed: {len(mae_per_time)}")

    # Size the shared time axis from the longest series so that slicing it to
    # any model's length never comes up short.
    longest = max((len(values) for values in temporal_mae.values()), default=0)
    time_steps = np.arange(longest)

    return time_steps, temporal_mae

# Analyze temporal errors with random_32 filter.
# analyze_temporal_errors may return None when no ground truth is loaded, so
# the result is normalised before unpacking.
temporal_result = analyze_temporal_errors()
if temporal_result is None:
    time_steps, temporal_mae = np.arange(0), {}
else:
    time_steps, temporal_mae = temporal_result

# Create enhanced temporal error plot for random_32 only
if any(len(values) > 0 for values in temporal_mae.values()):
    fig, ax1 = plt.subplots(figsize=(12, 7))

    # One marker / line style / color per model; indexed modulo their length
    # below so that extra model families cannot raise IndexError.
    markers = ['o', 's', '^', 'D', 'v', 'p']
    linestyles = ['-', '-', '-', '-', '-', '-']  # All solid lines
    colors = ['#0077BE', '#FF6B35', '#004E89', '#FF9F1C', '#7209B7', '#C73E1D']  # High contrast vibrant

    # Model labels for display
    model_labels = {
        'flrnet_fourier_percep': 'FLRNet-FP',
        'flrnet_fourier': 'FLRNet-F',
        'flrnet_percep': 'FLRNet-P',
        'flrnet_standard': 'FLRNet-Std',
        'mlp': 'MLP',
        'pod': 'POD'
    }

    # Subsample every series once (~20 points each) and reuse it on both axes.
    # The time axis is built from the series' own length, converted to seconds
    # with the 0.0511 s step, so x and y always have matching sizes.
    plotted_series = []
    for i, (model_type, mae_values) in enumerate(temporal_mae.items()):
        if len(mae_values) == 0:
            continue
        step = max(1, len(mae_values) // 20)
        t_sub = np.arange(len(mae_values))[::step] * 0.0511
        mae_sub = np.asarray(mae_values[::step])
        line_style = dict(
            color=colors[i % len(colors)],
            linestyle=linestyles[i % len(linestyles)],
            linewidth=2.5,
            marker=markers[i % len(markers)],
            markersize=6,
            markerfacecolor='white',
            markeredgewidth=2,
        )
        plotted_series.append((model_labels.get(model_type, model_type), t_sub, mae_sub, line_style))

    # Primary axis: absolute error
    for series_label, t_sub, mae_sub, line_style in plotted_series:
        ax1.plot(t_sub, mae_sub, label=series_label, alpha=0.8, **line_style)

    # Styling
    ax1.set_title('Temporal Evolution of Reconstruction Error\n(Random Layout, 32 Sensors)',
                  fontsize=18, fontweight='bold', pad=20)
    ax1.set_xlabel('Time [s]', fontsize=14, fontweight='bold')
    ax1.set_ylabel('Mean Absolute Error (m/s)', fontsize=14, fontweight='bold')
    ax1.legend(loc='upper left', fontsize=14, frameon=True, fancybox=True, shadow=True)
    ax1.grid(True, alpha=0.3, linestyle='--')
    ax1.tick_params(axis='both', which='major', labelsize=14)
    ax1.set_xlim(0, 2)  # Fixed 0-2 s window; data beyond 2 s would be cut off
    # Add bounding box around the plot (rcParams hide the top/right spines)
    for spine in ax1.spines.values():
        spine.set_visible(True)

    # Secondary y-axis: the same curves expressed as a percentage of max_val.
    # NOTE(review): max_val is not defined in this cell; presumably an earlier
    # cell sets it to the reference velocity magnitude - confirm it exists on a
    # fresh Restart & Run All.
    ax2 = ax1.twinx()
    for _, t_sub, mae_sub, line_style in plotted_series:
        mae_percentage = mae_sub / max_val * 100
        ax2.plot(t_sub, mae_percentage, alpha=0.6, **line_style)

    ax2.set_ylabel('Mean Absolute Percentage Error (%)', fontsize=14, fontweight='bold')
    ax2.tick_params(axis='y', labelsize=14)

    # Ensure secondary axis also has visible spines
    for spine in ax2.spines.values():
        spine.set_visible(True)

    plt.tight_layout()
    plt.savefig('temporal_error_analysis_random_32_only.png', dpi=300, bbox_inches='tight', facecolor='white')
    plt.savefig('temporal_error_analysis_random_32_only.pdf', dpi=300, bbox_inches='tight', facecolor='white')
    plt.show()

    print("Temporal error analysis completed (random_32 only)")
    print(f"Models analyzed: {list(temporal_mae.keys())}")
    print("Files used: random layout, 32 sensors configuration only")
else:
    print("No random_32 data available for temporal analysis")

## Vertical Error Profile Analysis

Create vertical profile plots showing how the absolute reconstruction error varies along the vertical coordinate at six horizontal positions in the flow domain.

In [None]:
def create_vertical_error_profiles():
    """Plot averaged absolute-error profiles along six vertical slices.

    For each of six horizontal positions (at 1/7 ... 6/7 of the grid width) and
    each model's random-layout, 32-sensor prediction, the absolute difference
    to the notebook-level ``gt_data`` is averaged over all common time steps
    and over the first and last case channel, then drawn against a vertical
    coordinate normalised to [0, 1]. The 2x3 figure is saved as
    ``vertical_error_profiles_averaged_random_32`` (PNG and PDF).

    Model predictions come from the notebook-level ``organized_data``. Models
    without a random/32 configuration are skipped.
    """
    if gt_data is None:
        print("No ground truth data available for vertical error profile analysis")
        return

    # Slice k sits at k/7 of the grid width
    positions = [1, 2, 3, 4, 5, 6]
    # Normalize vertical coordinate to range [0, 1] [m]
    y_coords = np.linspace(0, 1, gt_data.shape[1])
    # Grid spacing used to report the slice position in meters.
    # NOTE(review): 0.0078125 equals 2 m / 256 columns - confirm it matches the grid.
    grid_spacing = 0.0078125

    print("Creating vertical error profiles averaged across all time steps and cases (channel 0 and last channel)")
    print("Using ONLY random_32 configurations")
    print(f"Positions: {positions}")
    print(f"Available time steps: {gt_data.shape[0]}")
    print(f"Available cases (channels): {gt_data.shape[3]}")

    # Per-model styling, identical in every subplot
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b']
    markers = ['o', 's', '^', 'D', 'v', 'p']
    model_types = ['flrnet_fourier_percep', 'flrnet_fourier', 'flrnet_percep', 'flrnet_standard', 'mlp', 'pod']
    model_labels = ['FLRNet-FP', 'FLRNet-F', 'FLRNet-P', 'FLRNet-Std', 'MLP', 'POD']

    # Pick each model's random_32 prediction once; it does not depend on the slice.
    selected_fields = {}
    for model_type in model_types:
        if not (model_type in organized_data and organized_data[model_type]):
            continue
        for file_name, file_info in organized_data[model_type].items():
            info = file_info['info']
            if info['sensor_layout'] == 'random' and info['num_sensors'] == 32:
                field = file_info['data']['predictions']
                # NOTE(review): assumes (time, H, W, case) or (case, H, W, time)
                # storage with time being the larger outer axis - confirm.
                if len(field.shape) == 4 and field.shape[3] > field.shape[0]:
                    field = np.transpose(field, (3, 1, 2, 0))
                selected_fields[model_type] = field
                print(f"  Using {model_type}: {file_name} (random_32)")
                break
        else:
            print(f"  ✗ {model_type}: No random_32 configuration found")

    # Only the first and the last case channel enter the average
    channels_to_use = [0, gt_data.shape[3] - 1]

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.flatten()

    for pos_idx, pos in enumerate(positions):
        if pos_idx >= len(axes):
            break

        ax = axes[pos_idx]
        x_pos = int(gt_data.shape[2] / 7 * pos)  # Map position to grid coordinates
        x_pos_real = x_pos * grid_spacing  # Real x-coordinate in meters

        for i, (model_type, model_label) in enumerate(zip(model_types, model_labels)):
            recon_field = selected_fields.get(model_type)
            if recon_field is None:
                continue

            n_steps = min(gt_data.shape[0], recon_field.shape[0])
            channels = [c for c in channels_to_use
                        if c < gt_data.shape[3] and c < recon_field.shape[3]]

            if n_steps > 0 and channels:
                # (time, height, channel) absolute error at this column
                gt_slice = gt_data[:n_steps, :, x_pos, :][:, :, channels]
                recon_slice = recon_field[:n_steps, :, x_pos, :][:, :, channels]
                averaged_error_profile = np.abs(gt_slice - recon_slice).mean(axis=(0, 2))

                ax.plot(y_coords, averaged_error_profile, label=model_label,
                        color=colors[i], linestyle='-', linewidth=1.5,
                        marker=markers[i], markersize=3, markevery=5,
                        markerfacecolor='white', markeredgewidth=1, alpha=0.8)

                print(f"    ✓ Averaged {n_steps * len(channels)} error profiles (time steps × cases)")
            else:
                print(f"    ✗ No valid error profiles calculated for {model_type}")

        # Styling for each subplot - PDF style
        ax.set_title(f'Horizontal coordinate \n {x_pos_real:.4f} [m]', fontsize=14, fontweight='bold')
        ax.set_xlabel('Vertical Coordinate [m]', fontsize=14, fontweight='bold')
        ax.set_ylabel('Averaged Absolute Error [m/s]', fontsize=14, fontweight='bold')
        ax.grid(True, alpha=0.3, linestyle='-', linewidth=0.5)
        ax.tick_params(axis='both', labelsize=12)
        ax.spines['top'].set_visible(True)
        ax.spines['right'].set_visible(True)

        # Errors are non-negative, so anchor the y-axis at 0
        ax.set_ylim(bottom=0)

        # Add legend only to the first subplot
        if pos_idx == 0:
            ax.legend(loc='upper right', fontsize=10, frameon=True)

    plt.suptitle('Average Vertical Absolute Error Profiles \nRandom Layout, 32 Sensors',
                 fontsize=16, fontweight='bold', y=1.0)
    plt.tight_layout()
    plt.savefig('vertical_error_profiles_averaged_random_32.png', dpi=300, bbox_inches='tight', facecolor='white')
    plt.savefig('vertical_error_profiles_averaged_random_32.pdf', dpi=300, bbox_inches='tight', facecolor='white')
    plt.show()

    print("Vertical error profile analysis completed (averaged across all time steps and cases, random_32 only)")

# Draw and save the averaged vertical error profiles (random layout, 32 sensors)
create_vertical_error_profiles()

In [None]:
def create_domain_with_annotated_slices():
    """Draw a schematic of the flow domain with the obstacle and six vertical slices.

    The grid size is read from the notebook-level ``gt_data`` (axis 2 as x,
    axis 1 as y). The figure is saved as ``domain_with_error_profile_slices``
    (PNG and PDF) and a text summary of the domain and slice positions is
    printed afterwards.
    """
    if gt_data is None:
        print("No ground truth data available for domain visualization")
        return

    # Grid resolution and physical extent (meters)
    ny, nx = gt_data.shape[1], gt_data.shape[2]
    Lx, Ly = 2.0, 1.0
    dx = Lx / nx
    dy = Ly / ny

    # Circular obstacle: center 0.5 m from the left and bottom, 0.25 m across
    obstacle_center_x = 0.5
    obstacle_center_y = 0.5
    obstacle_diameter = 0.25
    obstacle_radius = obstacle_diameter / 2

    # Slice k sits at k/7 of the grid width (same rule as the error profile analysis)
    positions = [1, 2, 3, 4, 5, 6]
    slice_colors = ['#000000'] * len(positions)

    # Resolve every slice once: (slice number, grid column, x in meters, color)
    slices = []
    for pos, color in zip(positions, slice_colors):
        column = int(nx / 7 * pos)
        slices.append((pos, column, column * dx, color))

    # Figure keeps the 2:1 aspect ratio of the domain
    fig_width = 12
    fig, ax = plt.subplots(figsize=(fig_width, fig_width * (Ly / Lx)))

    # Domain outline with a light gray fill
    ax.add_patch(plt.Rectangle((0, 0), Lx, Ly,
                               facecolor='lightgray', edgecolor='black',
                               linewidth=3, alpha=0.3))

    # Obstacle
    ax.add_patch(plt.Circle((obstacle_center_x, obstacle_center_y), obstacle_radius,
                            facecolor='darkgray', edgecolor='black',
                            linewidth=2, alpha=0.8))

    # One vertical line per slice, with arrow heads on both walls and a label on top
    for pos, _, x_real, color in slices:
        ax.axvline(x=x_real, color=color, linewidth=4,
                   linestyle='-', alpha=0.9, zorder=5)
        ax.plot(x_real, Ly, marker='v', color=color,
                markersize=20, markeredgecolor='white', markeredgewidth=2, zorder=6)
        ax.plot(x_real, 0, marker='^', color=color,
                markersize=20, markeredgecolor='white', markeredgewidth=2, zorder=6)
        ax.text(x_real, Ly + 0.18, f'Slice {pos}\nx={x_real:.3f}m',
                ha='center', va='bottom', fontsize=15, fontweight='bold',
                color=color, alpha=0.9)

    # Equal scaling keeps the obstacle circular
    ax.set_aspect('equal')
    ax.set_xlim(-0.1, Lx + 0.1)
    ax.set_ylim(-0.1, Ly + 0.15)  # Extra space for labels

    # Bare drawing: no ticks, no frame
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

    plt.tight_layout()
    for extension in ('png', 'pdf'):
        plt.savefig(f'domain_with_error_profile_slices.{extension}', dpi=300, bbox_inches='tight', facecolor='white')
    plt.show()

    # Text summary
    print("Domain and Slice Information:")
    print("=" * 60)
    print(f"Physical domain: {Lx:.1f} × {Ly:.1f} m")
    print(f"Grid resolution: {nx} × {ny} points")
    print(f"Grid spacing: Δx = {dx:.6f} m, Δy = {dy:.6f} m")
    print(f"Obstacle: Circular, diameter = {obstacle_diameter:.2f} m")
    print(f"Obstacle position: ({obstacle_center_x:.1f}, {obstacle_center_y:.1f}) m")
    print("Reynolds number: ~750 (estimated)")
    print()
    print("Vertical slice positions for error profile analysis:")
    for pos, column, x_real, color in slices:
        print(f"  Slice {pos}: Grid index {column:3d} → x = {x_real:.4f} m → Color: {color}")
    print()
    print("These slices correspond to the vertical error profiles shown in previous analysis.")
    print("Error profiles are averaged across all time steps and specified cases.")

# Draw and save the domain schematic showing the obstacle and the six slice positions
create_domain_with_annotated_slices()

In [None]:
def analyze_spectral_bias_all_random_cases_all_timesteps(sigma=5):
    """Plot and tabulate spectral bias versus Reynolds number for every model.

    For each model type, every inference result whose sensor layout is
    ``'random'`` (any sensor count) is taken from the notebook-level
    ``organized_data``. For every Reynolds-number channel the imported
    ``spectral_bias`` helper is evaluated on each time step, the values are
    averaged over time steps per configuration, and those averages are then
    averaged over configurations. The resulting curves are drawn on a
    log-scaled Reynolds axis, saved as PNG and PDF in the working directory,
    and several summary tables are printed.

    Reads the notebook globals ``gt_data`` (indexed here as
    ``[time_step, :, :, channel]``) and ``organized_data``.

    Args:
        sigma: Forwarded unchanged to ``spectral_bias``; its meaning is defined
            by that helper. Defaults to 5, the value previously hard-coded.

    Returns:
        dict | None: Maps model type to a dict with the keys ``'model_label'``,
        ``'reynolds_sb'`` (one mean value per Reynolds number, ``np.nan`` where
        nothing could be computed) and ``'num_configs'``. ``None`` when
        ``gt_data`` is ``None`` or no model had any random configuration.
    """
    from spectral_bias import spectral_bias

    if gt_data is None:
        print("No ground truth data available for spectral bias analysis")
        return None

    print("Computing Spectral Bias for All Models - All Random Layout Configurations - ALL TIME STEPS")
    print("=" * 80)

    # Reynolds numbers corresponding to each channel; never use more entries
    # than there are channels in the ground truth.
    reynolds_numbers = [30, 60, 120, 300, 700, 950, 3000, 8000]
    total_channels = gt_data.shape[3]
    reynolds_numbers = reynolds_numbers[:total_channels]

    # Model types to analyze
    model_types = ['flrnet_fourier_percep', 'flrnet_fourier', 'flrnet_percep', 'flrnet_standard', 'mlp', 'pod']
    model_labels = ['FLRNet-FP', 'FLRNet-F', 'FLRNet-P', 'FLRNet-Std', 'MLP', 'POD']
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b']
    markers = ['o', 's', '^', 'D', 'v', 'p']

    # Bind colour/marker to the model itself. Indexing by the position among
    # the models that happened to have data would shift the styling of every
    # later model whenever an earlier one is missing.
    model_styles = dict(zip(model_types, zip(colors, markers)))

    # Store spectral bias data
    spectral_bias_data = {}

    print(f"Reynolds numbers to analyze: {reynolds_numbers}")
    print(f"Time steps used for analysis: ALL ({gt_data.shape[0]} time steps)")
    print()

    # Process each model
    for model_type, model_label in zip(model_types, model_labels):
        print(f"Processing {model_label}...")

        if model_type not in organized_data or not organized_data[model_type]:
            print(f"  ✗ No data available")
            continue

        # Find ALL random configurations for this model type
        random_configs = []
        for file_name, file_info in organized_data[model_type].items():
            info = file_info['info']
            if info['sensor_layout'] != 'random':  # Any random layout, any sensor count
                continue

            # Fetch and orient the predictions once per configuration rather
            # than once per Reynolds number (the lookup may be expensive if
            # 'data' is a lazily loaded archive — TODO confirm).
            recon_field = file_info['data']['predictions']

            # When the last axis is longer than the first, swap them so that
            # axis 0 is the one iterated as time step below.
            if len(recon_field.shape) == 4 and recon_field.shape[3] > recon_field.shape[0]:
                recon_field = np.transpose(recon_field, (3, 1, 2, 0))

            random_configs.append({
                'file_name': file_name,
                'predictions': recon_field,
                'num_sensors': info['num_sensors']
            })
            print(f"  Found: {file_name} (sensors: {info['num_sensors']})")

        if not random_configs:
            print(f"  ✗ No random configurations found")
            continue

        # Calculate spectral bias for each Reynolds number, averaging across
        # all random configs and ALL time steps
        reynolds_sb_list = []

        for ch, re_num in enumerate(reynolds_numbers):
            config_sb_values = []

            for config in random_configs:
                recon_field = config['predictions']

                if not (ch < gt_data.shape[3] and ch < recon_field.shape[3]):
                    continue

                try:
                    n_steps = min(gt_data.shape[0], recon_field.shape[0])
                    timestep_sb_values = [
                        spectral_bias(gt_data[time_step, :, :, ch],
                                      recon_field[time_step, :, :, ch],
                                      sigma=sigma)
                        for time_step in range(n_steps)
                    ]

                    # Average spectral bias across all time steps for this configuration
                    if timestep_sb_values:
                        config_sb_values.append(np.mean(timestep_sb_values))

                except Exception as e:
                    # Best effort: report the failing configuration and keep going.
                    print(f"    Re = {re_num}, Config {config['file_name']}: Error computing SB - {e}")
                    continue

            # Average spectral bias across all random configurations for this Reynolds number
            if config_sb_values:
                avg_sb = np.mean(config_sb_values)
                std_sb = np.std(config_sb_values)
                reynolds_sb_list.append(avg_sb)
                print(f"    Re = {re_num}: SB = {avg_sb:.6f} ± {std_sb:.6f} (n={len(config_sb_values)} configs, {gt_data.shape[0]} timesteps)")
            else:
                reynolds_sb_list.append(np.nan)
                print(f"    Re = {re_num}: No valid data")

        spectral_bias_data[model_type] = {
            'model_label': model_label,
            'reynolds_sb': reynolds_sb_list,
            'num_configs': len(random_configs)
        }
        print(f"  Total random configurations used: {len(random_configs)}")
        print()

    if not spectral_bias_data:
        print("No random configurations available for spectral bias analysis")
        return None

    # Create spectral bias plot
    fig, ax = plt.subplots(figsize=(14, 8))

    # Plot lines for each model
    for model_type, data in spectral_bias_data.items():
        color, marker = model_styles[model_type]

        # Drop Reynolds numbers for which no value could be computed
        valid_points = [(re_num, sb)
                        for re_num, sb in zip(reynolds_numbers, data['reynolds_sb'])
                        if not np.isnan(sb)]
        if not valid_points:
            continue

        valid_reynolds, valid_sb = zip(*valid_points)
        ax.plot(valid_reynolds, valid_sb,
                label=data['model_label'],
                color=color,
                marker=marker,
                markersize=10,
                linewidth=3,
                markerfacecolor='white',
                markeredgewidth=2,
                markeredgecolor=color,
                alpha=0.9)

    # Add horizontal line at SB = 0 for reference
    ax.axhline(y=0, color='black', linestyle='--', alpha=0.7, linewidth=2)
    if reynolds_numbers:  # nothing to anchor the label to when there are no channels
        ax.text(reynolds_numbers[0]*1.1, 0.05, 'SB = 0 (Equal bias)', fontsize=12, alpha=0.7)

    # Styling
    ax.set_xlabel('Reynolds Number', fontsize=16, fontweight='bold')
    ax.set_ylabel('Spectral Bias', fontsize=16, fontweight='bold')
    ax.set_title('Spectral Bias Across Different Reynolds Numbers',
                 fontsize=18, fontweight='bold', pad=20)

    # Set x-axis to show all Reynolds numbers with log scale
    ax.set_xscale('log')
    ax.set_xticks(reynolds_numbers)
    ax.set_xticklabels([str(re) for re in reynolds_numbers])

    # Add grid and legend
    ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)
    ax.legend(loc='best', fontsize=14, frameon=True, fancybox=True, shadow=True)
    ax.tick_params(axis='both', labelsize=14)

    # Add bounding box (the notebook-wide rcParams hide the top/right spines)
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(0.5)
        spine.set_color('black')

    plt.tight_layout()
    plt.savefig('spectral_bias_across_reynolds_all_random_all_timesteps.png', dpi=300, bbox_inches='tight', facecolor='white')
    plt.savefig('spectral_bias_across_reynolds_all_random_all_timesteps.pdf', dpi=300, bbox_inches='tight', facecolor='white')
    plt.show()

    # Print summary statistics
    print("Spectral Bias Analysis Summary (All Random Configurations, All Time Steps)")
    print("=" * 85)
    print(f"{'Model':<18} {'Configs':<8} {'Min SB':<10} {'Max SB':<10} {'Range':<10} {'Avg SB':<10}")
    print("-" * 85)

    for model_type, data in spectral_bias_data.items():
        model_label = data['model_label']
        num_configs = data['num_configs']
        reynolds_sb = [sb for sb in data['reynolds_sb'] if not np.isnan(sb)]

        if reynolds_sb:
            min_sb = np.min(reynolds_sb)
            max_sb = np.max(reynolds_sb)
            range_sb = max_sb - min_sb
            avg_sb = np.mean(reynolds_sb)

            print(f"{model_label:<18} {num_configs:<8} {min_sb:<10.6f} {max_sb:<10.6f} {range_sb:<10.6f} {avg_sb:<10.6f}")

    print("\nReynolds Number-wise Average Spectral Bias (across all models and configs):")
    print("-" * 75)
    for i, re_num in enumerate(reynolds_numbers):
        re_sb = [data['reynolds_sb'][i] for data in spectral_bias_data.values()
                 if i < len(data['reynolds_sb']) and not np.isnan(data['reynolds_sb'][i])]
        if re_sb:
            avg_re_sb = np.mean(re_sb)
            std_re_sb = np.std(re_sb)
            bias_type = "High-freq bias" if avg_re_sb > 0 else "Low-freq bias" if avg_re_sb < 0 else "Balanced"
            print(f"Re = {re_num:4d}: {avg_re_sb:8.6f} ± {std_re_sb:6.6f} ({bias_type})")

    # Summary of configurations used
    print(f"\nConfiguration Summary:")
    print("-" * 60)
    total_configs = sum(data['num_configs'] for data in spectral_bias_data.values())
    total_samples = total_configs * gt_data.shape[0]  # configs × time steps
    print(f"Total random configurations analyzed: {total_configs}")
    print(f"Total time steps per configuration: {gt_data.shape[0]}")
    print(f"Total spectral bias samples: {total_samples}")
    for model_type, data in spectral_bias_data.items():
        if data['num_configs'] > 0:
            samples_per_model = data['num_configs'] * gt_data.shape[0]
            print(f"  {data['model_label']}: {data['num_configs']} configs × {gt_data.shape[0]} timesteps = {samples_per_model} samples")

    print("\nSpectral Bias Interpretation Guide:")
    print("- Positive SB: Model has high-frequency bias (captures fine details but may miss smooth variations)")
    print("- Negative SB: Model has low-frequency bias (captures smooth variations but may miss fine details)")
    print("- SB near 0: Model has balanced frequency representation")
    print("- Magnitude indicates severity of bias")
    print("- Values represent averages across all available random sensor configurations and ALL time steps")
    print("- This provides a more comprehensive temporal assessment of spectral bias")

    return spectral_bias_data

# Run the spectral bias analysis for all random cases using ALL time steps.
# Needs `gt_data` and `organized_data` from earlier cells and an importable
# `spectral_bias` module. The returned per-model dict (or None when nothing
# could be analysed) is kept in a notebook-level variable.
spectral_bias_results_all_random_all_timesteps = analyze_spectral_bias_all_random_cases_all_timesteps()

## Summary and Conclusions

This notebook provides a comprehensive analysis of flow field reconstruction results with enhanced visualizations and improved color mapping for better interpretation of model performance across different conditions.

In [None]:
def print_comprehensive_summary():
    """Emit a plain-text recap of the data, per-model errors and produced figures.

    Relies on the notebook-level names ``gt_data``, ``min_val``, ``max_val``,
    ``inference_data``, ``model_results`` and ``error_metrics``; nothing is
    returned, everything goes to standard output.
    """
    rule = "=" * 80
    print(rule)
    print("COMPREHENSIVE FLOW FIELD RECONSTRUCTION ANALYSIS SUMMARY")
    print(rule)

    # --- data ---------------------------------------------------------------
    print("\nDATA OVERVIEW:")
    gt_shape = 'Not available' if gt_data is None else gt_data.shape
    print(f"  • Ground truth data shape: {gt_shape}")
    print(f"  • Normalization range: [{min_val:.3f}, {max_val:.3f}]")
    print(f"  • Total inference files processed: {len(inference_data)}")

    # --- per-model error figures ----------------------------------------------
    print("\nMODEL RESULTS SUMMARY:")
    for model_type in ('flrnet', 'flrnet_fourier', 'mlp', 'pod'):
        count = len(model_results.get(model_type, []))
        print(f"  • {model_type.replace('_', ' ').title()}: {count} results")

        # Error statistics are only reported for models that have results
        # and an entry in error_metrics.
        if count <= 0 or model_type not in error_metrics:
            continue
        mae_values = [entry['mae'] for entry in error_metrics[model_type]]
        if not mae_values:
            continue

        avg_mae = np.mean(mae_values)
        std_mae = np.std(mae_values)
        print(f"    - Average MAE: {avg_mae:.6f} ± {std_mae:.6f} m/s")
        # Reported as the mean MAE relative to max_val, in percent.
        print(f"    - MAPE: {avg_mae/max_val*100:.3f}%")

    # --- fixed descriptive sections -------------------------------------------
    static_sections = (
        ("\nVISUALIZATIONS CREATED:", (
            "  • Enhanced sensor count effect analysis",
            "  • Enhanced sensor layout comparison",
            "  • Field reconstruction visualization with improved color maps",
            "  • Temporal error evolution analysis",
            "  • Vertical error profile analysis",
        )),
        ("\nCOLOR SCHEME ENHANCEMENTS:", (
            "  • Enhanced field visualization colormap with better contrast",
            "  • Improved diverging colormap for error visualization",
            "  • Consistent color coding across all plots",
            "  • High-resolution output (300 DPI) for publication quality",
        )),
    )
    for heading, bullets in static_sections:
        print(heading)
        for bullet in bullets:
            print(bullet)

    # --- findings depend on whether any model produced results ---------------
    print("\nKEY FINDINGS:")
    has_results = any(len(entries) > 0 for entries in model_results.values())
    if has_results:
        findings = (
            "  • Successfully analyzed multiple model architectures",
            "  • Enhanced visualizations provide better insight into model performance",
            "  • Improved color mapping reveals fine details in reconstruction quality",
        )
    else:
        findings = (
            "  • Limited data available for comprehensive analysis",
            "  • Framework ready for analysis when inference results are available",
        )
    for finding in findings:
        print(finding)

    print(rule)

# Print comprehensive summary.
# The function reads gt_data, min_val, max_val, inference_data, model_results
# and error_metrics, so the cells defining them must have run first.
print_comprehensive_summary()

# Additional utility function for custom analysis
def create_custom_analysis_plot():
    """Template for custom analysis plots.

    Prints a short usage note and, when the notebook-level ``model_results``
    holds at least one ``'flrnet'`` result, draws a 4×4 heatmap of placeholder
    values, saves it as ``custom_analysis_template.png`` and shows it. The
    heatmap data is a stand-in to be replaced with real analysis results.

    Returns:
        None
    """
    print("\nCustom Analysis Template:")
    print("This function can be modified to create specific analysis plots")
    print("based on your research requirements.")

    # A missing 'flrnet' entry is treated as "no results" instead of raising
    # KeyError, so the template cell never aborts the notebook run.
    if len(model_results.get('flrnet', [])) == 0:
        return

    # Example: Create a performance heatmap
    fig, ax = plt.subplots(figsize=(10, 6))

    # Placeholder for heatmap data; this would be filled with actual analysis
    # results. A locally seeded generator keeps the saved figure identical
    # across runs and leaves the global NumPy random state untouched.
    data = np.random.default_rng(0).random((4, 4))  # Example data

    im = ax.imshow(data, cmap='RdYlBu_r', aspect='auto')

    # Styling
    ax.set_title('Model Performance Heatmap (Template)',
                 fontsize=16, fontweight='bold')
    ax.set_xlabel('Configuration', fontsize=14, fontweight='bold')
    ax.set_ylabel('Model Type', fontsize=14, fontweight='bold')

    # Add colorbar
    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label('Performance Metric', fontweight='bold')

    plt.tight_layout()
    plt.savefig('custom_analysis_template.png', dpi=300, bbox_inches='tight')
    plt.show()

    print("Custom analysis template created!")

# Exercise the template plot, then close this part of the report with a banner.
create_custom_analysis_plot()

# One print call with a newline separator emits the same five lines as
# five separate prints would.
print(
    "\n" + "="*50,
    "INFERENCE RESULTS ANALYSIS COMPLETED!",
    "="*50,
    "All enhanced visualizations have been generated with improved color schemes.",
    "High-resolution PNG and PDF files have been saved for publication use.",
    sep="\n",
)

In [None]:
def plot_mae_across_channels():
    """Plot per-model MAE versus Reynolds number for the random, 32-sensor case.

    For each model type the first entry of the notebook-level
    ``organized_data`` whose layout is ``'random'`` with 32 sensors is used.
    For every Reynolds-number channel the mean absolute error against
    ``gt_data`` is computed per time step and then averaged over time steps.
    The curves are drawn on a log-scaled Reynolds axis, saved as
    ``mae_across_reynolds_numbers_random_32.png`` / ``.pdf`` and followed by
    printed summary tables.

    Reads the notebook globals ``gt_data`` (indexed here as
    ``[time_step, :, :, channel]``) and ``organized_data``.

    Returns:
        None
    """
    if gt_data is None:
        print("No ground truth data available for Reynolds number MAE analysis")
        return

    print("Creating line plot of MAE across Reynolds numbers")
    print("Using ONLY random_32 configurations")
    print("=" * 50)

    # Model types to analyze
    model_types = ['flrnet_fourier_percep', 'flrnet_fourier', 'flrnet_percep', 'flrnet_standard', 'mlp', 'pod']
    model_labels = ['FLRNet-FP', 'FLRNet-F', 'FLRNet-P', 'FLRNet-Std', 'MLP', 'POD']
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b']
    markers = ['o', 's', '^', 'D', 'v', 'p']

    # Bind colour/marker to the model itself. Indexing by the position among
    # the models that happened to have data would shift the styling of every
    # later model whenever an earlier one is missing.
    model_styles = dict(zip(model_types, zip(colors, markers)))

    # Reynolds numbers corresponding to each channel
    reynolds_numbers = [30, 60, 120, 300, 700, 950, 3000, 8000]
    total_channels = gt_data.shape[3]

    # Ensure we don't exceed available channels
    if len(reynolds_numbers) > total_channels:
        reynolds_numbers = reynolds_numbers[:total_channels]
        print(f"Warning: Only {total_channels} channels available, using first {len(reynolds_numbers)} Reynolds numbers")

    print(f"Reynolds numbers to analyze: {reynolds_numbers}")
    print(f"Total channels available: {total_channels}")
    print()

    # Store MAE data for each model across Reynolds numbers
    mae_data = {}

    for model_type, model_label in zip(model_types, model_labels):
        print(f"Processing {model_label}...")

        if model_type not in organized_data or not organized_data[model_type]:
            print(f"  ✗ No data available")
            continue

        # Find random_32 configuration for this model type (first match wins)
        recon_field = None
        for file_name, file_info in organized_data[model_type].items():
            info = file_info['info']
            if info['sensor_layout'] == 'random' and info['num_sensors'] == 32:
                recon_field = file_info['data']['predictions']
                print(f"  Using: {file_name}")
                break

        if recon_field is None:
            print(f"  ✗ No random_32 configuration found")
            continue

        # When the last axis is longer than the first, swap them so that
        # axis 0 is the one treated as time step below.
        if len(recon_field.shape) == 4 and recon_field.shape[3] > recon_field.shape[0]:
            recon_field = np.transpose(recon_field, (3, 1, 2, 0))

        # Only time steps present in both arrays are compared
        n_steps = min(gt_data.shape[0], recon_field.shape[0])

        # Calculate MAE for each Reynolds number (channel) across all time steps
        reynolds_maes = []
        for ch, re_num in enumerate(reynolds_numbers):
            if ch < gt_data.shape[3] and ch < recon_field.shape[3]:
                abs_err = np.abs(gt_data[:n_steps, :, :, ch] - recon_field[:n_steps, :, :, ch])
                # Spatial mean per time step first, then mean over time steps
                avg_mae_reynolds = np.mean(np.mean(abs_err, axis=(1, 2)))
                reynolds_maes.append(avg_mae_reynolds)
                print(f"    Re = {re_num}: MAE = {avg_mae_reynolds:.6f}")
            else:
                reynolds_maes.append(np.nan)
                print(f"    Re = {re_num}: No data available")

        mae_data[model_type] = {
            'model_label': model_label,
            'reynolds_maes': reynolds_maes
        }
        print()

    if not mae_data:
        print("No random_32 data available for Reynolds number MAE analysis")
        return

    # Create the line plot
    fig, ax = plt.subplots(figsize=(14, 8))

    # Plot lines for each model
    for model_type, data in mae_data.items():
        color, marker = model_styles[model_type]

        # Drop Reynolds numbers for which no value could be computed
        valid_points = [(re_num, mae)
                        for re_num, mae in zip(reynolds_numbers, data['reynolds_maes'])
                        if not np.isnan(mae)]
        if not valid_points:
            continue

        valid_reynolds, valid_maes = zip(*valid_points)
        ax.plot(valid_reynolds, valid_maes,
                label=data['model_label'],
                color=color,
                marker=marker,
                markersize=8,
                linewidth=2.5,
                markerfacecolor='white',
                markeredgewidth=2,
                markeredgecolor=color,
                alpha=0.9)

    # Styling
    ax.set_xlabel('Reynolds Number', fontsize=16, fontweight='bold')
    ax.set_ylabel('Mean Absolute Error [m/s]', fontsize=16, fontweight='bold')
    ax.set_title('MAE Variation Across Reynolds Numbers for Different Models\n(Random Layout, 32 Sensors)',
                 fontsize=18, fontweight='bold', pad=20)

    # Set x-axis to show all Reynolds numbers with log scale
    ax.set_xscale('log')
    ax.set_xticks(reynolds_numbers)
    ax.set_xticklabels([str(re) for re in reynolds_numbers])

    # Add grid and legend
    ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)
    ax.legend(loc='best', fontsize=16, frameon=True, fancybox=True, shadow=True)
    ax.tick_params(axis='both', labelsize=16)

    # Add bounding box (the notebook-wide rcParams hide the top/right spines)
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(0.5)
        spine.set_color('black')

    plt.tight_layout()
    plt.savefig('mae_across_reynolds_numbers_random_32.png', dpi=300, bbox_inches='tight', facecolor='white')
    plt.savefig('mae_across_reynolds_numbers_random_32.pdf', dpi=300, bbox_inches='tight', facecolor='white')
    plt.show()

    # Print summary statistics
    print("Summary Statistics: MAE Across Reynolds Numbers")
    print("=" * 70)
    print(f"{'Model':<15} {'Min MAE':<10} {'Max MAE':<10} {'Range':<10} {'Avg MAE':<10}")
    print("-" * 70)

    for model_type, data in mae_data.items():
        model_label = data['model_label']
        reynolds_maes = [mae for mae in data['reynolds_maes'] if not np.isnan(mae)]

        if reynolds_maes:
            min_mae = np.min(reynolds_maes)
            max_mae = np.max(reynolds_maes)
            range_mae = max_mae - min_mae
            avg_mae = np.mean(reynolds_maes)

            print(f"{model_label:<15} {min_mae:<10.6f} {max_mae:<10.6f} {range_mae:<10.6f} {avg_mae:<10.6f}")

    print("\nReynolds Number-wise Average MAE (across all models):")
    print("-" * 50)
    for i, re_num in enumerate(reynolds_numbers):
        re_maes = [data['reynolds_maes'][i] for data in mae_data.values()
                   if i < len(data['reynolds_maes']) and not np.isnan(data['reynolds_maes'][i])]
        if re_maes:
            avg_re_mae = np.mean(re_maes)
            print(f"Re = {re_num:4d}: {avg_re_mae:.6f}")

    print("\nInterpretation:")
    print("- Higher MAE indicates more challenging reconstruction for that Reynolds number")
    print("- Generally higher Re flows are more complex and challenging to reconstruct")
    print("- Log scale on x-axis shows the wide range of Reynolds numbers tested")
    print("- Flat lines suggest consistent performance across Reynolds numbers")
    print("- Steep variations indicate Reynolds number-specific challenges")

# Run the Reynolds number MAE line plot analysis.
# Needs `gt_data` and `organized_data` from earlier cells; writes
# mae_across_reynolds_numbers_random_32.png / .pdf to the working directory.
plot_mae_across_channels()

In [None]:
def analyze_noise_impact_on_error():
    """Analyze how reconstruction error changes with increasing noise levels"""
    import matplotlib.pyplot as plt
    import numpy as np
    import glob
    import os
    
    print("Analyzing Noise Impact on Reconstruction Error")
    print("=" * 60)
    
    # Define noise levels to analyze
    noise_levels = [0, 5, 10, 15, 20]  # 0 represents the normal file without noise
    
    # Model types and their filename patterns
    model_patterns = {
        'flrnet_fourier_percep': 'inference_random_32_standard',
        'flrnet_fourier': 'inference_random_32_fourier', 
        'flrnet_percep': 'inference_random_32_no_fourier',
        'flrnet_standard': 'inference_random_32_no_fourier_no_percep',
        'mlp': 'mlp_random_32_standard',
        'pod': 'pod_random_32_standard'
    }
    
    model_labels = {
        'flrnet_fourier_percep': 'FLRNet-FP',
        'flrnet_fourier': 'FLRNet-F',
        'flrnet_percep': 'FLRNet-P',
        'flrnet_standard': 'FLRNet-Std',
        'mlp': 'MLP',
        'pod': 'POD'
    }
    
    # Storage for results
    noise_error_data = {}
    
    # Search for files in inference results directory
    inference_dir = r"E:\Research\Physics-informed-machine-learning\flow_field_recon_parc\checkpoints\inference_results"
    
    print(f"Searching for noise files in: {inference_dir}")
    print()
    
    for model_type, base_pattern in model_patterns.items():
        print(f"Processing {model_labels[model_type]}...")
        
        model_errors = []
        found_files = []
        
        for noise_level in noise_levels:
            error_value = None
            file_found = False
            
            if noise_level == 0:
                # Look for file without noise suffix
                if model_type.startswith('flrnet'):
                    search_pattern = f"{base_pattern}.npz"
                else:
                    search_pattern = f"{base_pattern}.npz"
            else:
                # Look for file with noise suffix
                if model_type.startswith('flrnet'):
                    search_pattern = f"{base_pattern}_noise_{noise_level}pct.npz"
                else:
                    search_pattern = f"{base_pattern}_noise_{noise_level}pct.npz"
            
            # Search for the file
            full_pattern = os.path.join(inference_dir, search_pattern)
            matching_files = glob.glob(full_pattern)
            
            if matching_files:
                file_path = matching_files[0]
                file_found = True
                found_files.append(os.path.basename(file_path))
                
                try:
                    # Load the file and calculate error
                    data = np.load(file_path)
                    
                    if 'predictions' in data and 'targets' in data:
                        predictions = data['predictions']
                        targets = data['targets']
                        
                        # Ensure consistent shapes
                        if len(predictions.shape) == 4 and predictions.shape[3] > predictions.shape[0]:
                            predictions = np.transpose(predictions, (3, 1, 2, 0))
                            targets = np.transpose(targets, (3, 1, 2, 0))
                        
                        # Calculate Mean Absolute Error across all data
                        mae = np.mean(np.abs(targets - predictions))
                        error_value = mae
                        
                        print(f"  Noise {noise_level}%: MAE = {mae:.6f} (file: {os.path.basename(file_path)})")
                    else:
                        print(f"  Noise {noise_level}%: File missing predictions/targets")
                        
                except Exception as e:
                    print(f"  Noise {noise_level}%: Error loading file - {e}")
            else:
                print(f"  Noise {noise_level}%: File not found (pattern: {search_pattern})")
            
            model_errors.append(error_value)
        
        if any(error is not None for error in model_errors):
            noise_error_data[model_type] = {
                'label': model_labels[model_type],
                'errors': model_errors,
                'files': found_files
            }
            print(f"  ✓ {model_labels[model_type]}: {sum(1 for e in model_errors if e is not None)}/{len(noise_levels)} files found")
        else:
            print(f"  ✗ {model_labels[model_type]}: No valid files found")
        
        print()
    
    # Create the noise impact plot
    if noise_error_data:
        fig, ax = plt.subplots(figsize=(12, 8))
        
        # Enhanced styling
        colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b']
        markers = ['o', 's', '^', 'D', 'v', 'p']
        linestyles = ['-', '-', '-', '-', '-', '-']
        
        # Plot data for each model
        for i, (model_type, data) in enumerate(noise_error_data.items()):
            model_label = data['label']
            errors = data['errors']
            
            # Filter out None values and corresponding noise levels
            valid_noise_levels = []
            valid_errors = []
            
            for noise_level, error in zip(noise_levels, errors):
                if error is not None:
                    valid_noise_levels.append(noise_level)
                    valid_errors.append(error)
            
            if valid_errors:
                ax.plot(valid_noise_levels, valid_errors,
                       label=model_label,
                       color=colors[i % len(colors)],
                       marker=markers[i % len(markers)],
                       markersize=8,
                       linewidth=3,
                       markerfacecolor='white',
                       markeredgewidth=2,
                       markeredgecolor=colors[i % len(colors)],
                       linestyle=linestyles[i % len(linestyles)],
                       alpha=0.9)
        
        # Styling
        ax.set_xlabel('Noise Level (%)', fontsize=16, fontweight='bold')
        ax.set_ylabel('Mean Absolute Error [m/s]', fontsize=16, fontweight='bold')
        ax.set_title('Impact of Noise Perturbation in Sensor Measurement on Reconstruction Error', 
                     fontsize=18, fontweight='bold', pad=20)
        
        # Set x-axis to show all noise levels
        ax.set_xticks(noise_levels)
        ax.set_xticklabels([f'{level}%' for level in noise_levels])
        
        # Add grid and legend
        ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)
        ax.legend(loc='best', fontsize=14, frameon=True, fancybox=True, shadow=True)
        ax.tick_params(axis='both', labelsize=14)
        
        # Add bounding box
        for spine in ax.spines.values():
            spine.set_visible(True)
            spine.set_linewidth(0.5)
            spine.set_color('black')
        
        # Set y-axis to start from 0
        ax.set_ylim(bottom=0)

        
        plt.tight_layout()
        plt.savefig('noise_impact_on_reconstruction_error.png', dpi=300, bbox_inches='tight', facecolor='white')
        plt.savefig('noise_impact_on_reconstruction_error.pdf', dpi=300, bbox_inches='tight', facecolor='white')
        plt.show()
        
        # Print summary statistics
        print("Summary Statistics: Noise Impact Analysis")
        print("=" * 70)
        print(f"{'Model':<15} {'0% MAE':<12} {'20% MAE':<12} {'Increase':<12} {'% Change':<12}")
        print("-" * 70)
        
        for model_type, data in noise_error_data.items():
            model_label = data['label']
            errors = data['errors']
            
            # Get 0% and 20% noise errors if available
            mae_0 = errors[0] if errors[0] is not None else np.nan
            mae_20 = errors[4] if len(errors) > 4 and errors[4] is not None else np.nan
            
            if not np.isnan(mae_0) and not np.isnan(mae_20):
                increase = mae_20 - mae_0
                percent_change = (increase / mae_0) * 100
                
                print(f"{model_label:<15} {mae_0:<12.6f} {mae_20:<12.6f} {increase:<12.6f} {percent_change:<12.2f}")
        
        print(f"\nNoise Level Analysis:")
        print("-" * 50)
        for i, noise_level in enumerate(noise_levels):
            level_errors = []
            for data in noise_error_data.values():
                if i < len(data['errors']) and data['errors'][i] is not None:
                    level_errors.append(data['errors'][i])
            
            if level_errors:
                avg_error = np.mean(level_errors)
                std_error = np.std(level_errors)
                print(f"Noise {noise_level:2d}%: Avg MAE = {avg_error:.6f} ± {std_error:.6f} ({len(level_errors)} models)")
        
        print(f"\nInterpretation:")
        print("- Higher noise levels generally increase reconstruction error")
        print("- Some models may be more robust to noise than others")
        print("- The slope of each line indicates noise sensitivity")
        print("- Steeper slopes indicate higher noise sensitivity")
        print("- Models with flatter curves are more noise-resistant")
        
        return noise_error_data
    else:
        print("No valid noise data files found for analysis")
        return None

# Run the noise impact analysis.
# When data is available, the call shows the figure, saves it as
# 'noise_impact_on_reconstruction_error.png' and '.pdf' (relative paths), and
# prints the summary tables. The returned value is the per-model dict whose
# entries carry 'label' and 'errors'; it is None when the function reports that
# no valid noise data files were found, so check for None before using it.
noise_analysis_results = analyze_noise_impact_on_error()