## 🔧 Windows Checkpoint File Lock Troubleshooting

If you encounter "Access is denied" errors during checkpoint saving, this is a known Windows-specific issue with TensorFlow. Here are the solutions implemented:

### ✅ Current Solutions Applied:

1. **Delayed Checkpoint Saving**: Checkpoint saving is now disabled for the first 30 epochs and only starts from epoch 31 onward
2. **Unique Checkpoint Directories**: Using UUID-based unique directories to avoid conflicts
3. **Error Handling**: Added `try`/`except` blocks around checkpoint operations
4. **Alternative Backup**: Models are saved after training completes
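The delayed-saving and error-handling ideas above can be sketched in plain Python. This is a framework-agnostic illustration, not the notebook's actual callback, which lives in `models_improved` and is not shown. `DelayedBestCheckpoint`, `unique_checkpoint_dir` and their parameters are hypothetical names.

To use it inside a Keras callback, you would call `on_epoch_end(epoch + 1, logs[...], lambda: self.model.save_weights(path))` from `tf.keras.callbacks.Callback.on_epoch_end`. TensorFlow's own file operations usually raise `tf.errors.OpError` subclasses rather than `OSError`, so that class would likely need to be added to `errors` as well.

```python
import uuid
from pathlib import Path


class DelayedBestCheckpoint:
    """Track the best metric every epoch but only write checkpoints from start_epoch on.

    Epochs are 1-based, like the Keras progress log (pass `epoch + 1` from a callback).
    """

    def __init__(self, start_epoch=31, mode="min", errors=(OSError,)):
        self.start_epoch = start_epoch
        self.mode = mode
        self.errors = errors
        self.best = float("inf") if mode == "min" else float("-inf")

    def _improved(self, value):
        return value < self.best if self.mode == "min" else value > self.best

    def on_epoch_end(self, epoch, value, save_fn):
        """Return True only if save_fn was called and completed without error."""
        if not self._improved(value):
            return False
        self.best = value                  # keep tracking even while saving is off
        if epoch < self.start_epoch:
            return False                   # saving disabled for the early epochs
        try:
            save_fn()
            return True
        except self.errors as exc:         # e.g. "Access is denied" on Windows
            print(f"Epoch {epoch}: checkpoint save failed ({exc}); training continues")
            return False


def unique_checkpoint_dir(root, model_name):
    """Create a fresh UUID-suffixed directory so separate runs never share files."""
    path = Path(root) / f"{model_name}_{uuid.uuid4().hex[:8]}"
    path.mkdir(parents=True, exist_ok=False)
    return path
```

Because `best` keeps updating during the silent epochs, the first checkpoint written from epoch 31 onward must beat every value seen so far. This matches the `improved from ...` messages in the training log.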

### 🛠️ If Issues Persist:

1. **Completely Disable Checkpoints**: Set `disable_checkpoints_entirely = True` in the cell above
2. **Manual Cleanup**: Delete the entire checkpoint directory before training
3. **Use Different Drive**: Try running on a different drive (e.g., C: instead of E:)
4. **Run as Administrator**: Run Jupyter/VS Code as administrator
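For the manual-cleanup step, a common Windows recipe is to have `shutil.rmtree` set the write bit on any file it cannot delete (clearing the read-only attribute) and then retry. The sketch below makes no assumptions about the notebook's code, and `remove_checkpoint_dir` is a hypothetical helper. It uses `onexc` on Python 3.12+, where `onerror` is deprecated, and `onerror` on older versions. It would be called with the configured path before training, e.g. `remove_checkpoint_dir(config['checkpoint_dir'])`.

```python
import os
import shutil
import stat
import sys
import time


def _make_writable_and_retry(func, path, _exc):
    """rmtree error handler: add the write bit (Windows read-only flag) and retry once."""
    os.chmod(path, os.stat(path).st_mode | stat.S_IWRITE)
    func(path)


def remove_checkpoint_dir(path, retries=3, delay=1.0):
    """Delete a checkpoint directory, tolerating read-only files and brief file locks."""
    for attempt in range(1, retries + 1):
        if not os.path.exists(path):
            return True
        try:
            if sys.version_info >= (3, 12):
                shutil.rmtree(path, onexc=_make_writable_and_retry)
            else:
                shutil.rmtree(path, onerror=_make_writable_and_retry)
            return True
        except PermissionError as exc:     # a file is still held open by another process
            print(f"Attempt {attempt}/{retries} failed: {exc}")
            time.sleep(delay)
    return not os.path.exists(path)
```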

### 📋 Understanding the Error:

The "Access is denied" error occurs when:
- TensorFlow tries to rename temporary checkpoint files
- Another process has the file locked
- Windows file system permissions prevent the operation
- Previous training runs left orphaned file handles
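Most of these causes come down to a rename or overwrite that Windows refuses while another handle is open. Python surfaces this as `PermissionError` (`[WinError 5] Access is denied`).

TensorFlow performs its own checkpoint renames in C++, so this pattern cannot be wrapped around them directly. The sketch below only illustrates retry with exponential back-off for your own save-to-temp-then-move code. `replace_with_retry` is a hypothetical helper.

```python
import os
import time


def replace_with_retry(src, dst, retries=5, base_delay=0.2):
    """Move src over dst with os.replace, retrying while Windows reports a lock.

    os.replace overwrites dst if it exists. On Windows it raises PermissionError
    while another handle (antivirus scanner, indexer, another Python process)
    keeps src or dst open.
    """
    for attempt in range(retries):
        try:
            os.replace(src, dst)
            return
        except PermissionError:
            if attempt == retries - 1:
                raise                                  # give up: surface the original error
            time.sleep(base_delay * 2 ** attempt)      # exponential back-off
```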

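Deferring checkpoint writes reduces the risk rather than eliminating it. Early in training, the validation loss improves on almost every epoch (see the VAE training log). Without the delay, a best-only checkpoint would be written, and its temporary file renamed, on nearly every epoch. Skipping saves until epoch 31 removes most of those file operations.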

# Flow Field Reconstruction with 8 Edge Sensors

This notebook implements a physics-informed machine learning approach to reconstruct flow fields using data from 8 edge sensors. The implementation uses:
- Variational Autoencoder (VAE) for dimensionality reduction and flow field reconstruction
- Fourier feature embeddings for coordinate information
- FLRNet architecture to predict flow fields from sparse sensor measurements
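To illustrate the second bullet: one widely used form of Fourier feature embedding maps each normalized coordinate vector `v` to `[sin(2*pi*B v), cos(2*pi*B v)]`, where `B` is a random Gaussian matrix. The NumPy sketch below assumes that form. `fourier_features`, the 16 frequencies and the scale of 10 are illustrative choices, and the notebook's `model_fourier` module may use a different frequency scheme.

```python
import numpy as np


def fourier_features(coords, n_frequencies=16, scale=10.0, seed=0):
    """Map (..., 2) coordinates in [0, 1] to (..., 2 * n_frequencies) features."""
    rng = np.random.default_rng(seed)
    B = rng.normal(0.0, scale, size=(coords.shape[-1], n_frequencies))
    projected = 2.0 * np.pi * coords @ B                  # (..., n_frequencies)
    return np.concatenate([np.sin(projected), np.cos(projected)], axis=-1)


# Coordinate grid shaped like the notebook's (height, width, 2) grids
height, width = 128, 256
y, x = np.meshgrid(np.linspace(0, 1, height), np.linspace(0, 1, width), indexing="ij")
coords = np.stack([x, y], axis=-1)                        # (128, 256, 2)
features = fourier_features(coords)
print(features.shape)                                     # (128, 256, 32)
```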

## 1. Import Required Libraries

We'll import the necessary libraries for data manipulation, visualization, and deep learning.

In [1]:
# Standard libraries
import os

# Suppress TensorFlow C++ info/warning messages. This must be set BEFORE
# TensorFlow is imported; otherwise the import-time messages still appear.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import json
import time
import datetime

# TensorFlow and Keras
import tensorflow as tf
from tensorflow import keras

# Load local modules
import model_fourier
import models_improved
import layer as flr_layer
import config_manager
from data.flow_field_dataset import FlowFieldDatasetCreator

# Check available GPUs and set memory growth
print("TensorFlow version:", tf.__version__)
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        # Set memory growth for all GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print(f"Found {len(gpus)} GPU(s): {gpus}")
        # Use first GPU
        gpu = gpus[0]
        tf.config.experimental.set_visible_devices(gpu, 'GPU')
        print(f"Using GPU: {gpu}")
    except RuntimeError as e:
        print(f"Error configuring GPUs: {e}")
else:
    print("No GPUs found. Using CPU.")

# Set random seeds for reproducibility
tf.random.set_seed(42)
np.random.seed(42)

Flow Field Dataset Creation Package v1.0.0 loaded successfully!

Quick Start:
  from dataset_creation import FlowFieldDatasetCreator
  creator = FlowFieldDatasetCreator()
  creator.create_all_datasets()

For more info: print_package_info()
TensorFlow version: 2.8.0
Found 1 GPU(s): [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
Using GPU: PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')


## 2. Configuration Setup

Let's set up the configuration for our model training:

In [2]:
# Load configuration using ConfigManager
config_name = "random_8_no_fourier"
config_mgr = config_manager.ConfigManager()
hierarchical_config = config_mgr.load_config(config_name)

# Print configuration summary
print(config_mgr.create_config_summary(hierarchical_config))

# Flatten the hierarchical config and convert keys to lowercase for consistent Python style
flattened_config = config_manager.flatten_config_for_training(hierarchical_config)
config = {key.lower(): value for key, value in flattened_config.items()}

# Verify essential config keys exist BEFORE they are used below, so a missing key
# is reported by name instead of surfacing as a KeyError partway through the prints.
# The last six keys are also read by later cells (dataset split, trainer, training).
required_keys = ['model_name', 'use_fourier', 'use_perceptual_loss', 'input_shape',
                 'n_sensors', 'latent_dims', 'n_base_features', 'batch_size',
                 'vae_epochs', 'flr_epochs', 'vae_learning_rate', 'flr_learning_rate',
                 'dataset_path', 'checkpoint_dir', 'logs_dir',
                 'test_split', 'save_best_model', 'save_last_model',
                 'gradient_clip_norm', 'patience', 'reduce_lr_patience']
missing_keys = [key for key in required_keys if key not in config]
if missing_keys:
    raise KeyError(f"⚠️  Missing required config keys: {missing_keys}")
print("✅ All required configuration keys are present")

print("\n🔧 Final Configuration (lowercase keys):")
print(f"   Model name: {config['model_name']}")
print(f"   Use Fourier: {config['use_fourier']}")
print(f"   Use perceptual loss: {config['use_perceptual_loss']}")
print(f"   Input shape: {config['input_shape']}")
print(f"   Number of sensors: {config['n_sensors']}")
print(f"   Latent dims: {config['latent_dims']}")
print(f"   Base features: {config['n_base_features']}")
print(f"   Batch size: {config['batch_size']}")
print(f"   VAE epochs: {config['vae_epochs']}")
print(f"   FLRNet epochs: {config['flr_epochs']}")
print(f"   VAE learning rate: {config['vae_learning_rate']}")
print(f"   FLRNet learning rate: {config['flr_learning_rate']}")
print(f"   Dataset path: {config['dataset_path']}")
print(f"   Checkpoint dir: {config['checkpoint_dir']}")
print(f"   Logs dir: {config['logs_dir']}")


📋 Configuration Summary
🏗️  Model Architecture:
   - Fourier Enhancement: False
   - Perceptual Loss: True
   - Input Shape: [128, 256, 1]
   - Latent Dimensions: 8
   - Base Features: 64

📡 Sensor Configuration:
   - Layout: random
   - Number of Sensors: 8
   - Dataset: data/datasets\dataset_random_8.npz

🚀 Training Parameters:
   - VAE Epochs: 250
   - FLRNet Epochs: 150
   - VAE Learning Rate: 0.0001
   - FLRNet Learning Rate: 0.0001
   - Batch Size: 8
   - Test Split: 0.2

💾 Output Configuration:
   - Model Name: fourierFalse_percepTrue_random_8
   - Checkpoints: ./checkpoints\fourierFalse_percepTrue_random_8
   - Logs: ./logs\fourierFalse_percepTrue_random_8
   - Save Best Model: True
   - Save Last Model: True


🔧 Final Configuration (lowercase keys):
   Model name: fourierFalse_percepTrue_random_8
   Use Fourier: False
   Use perceptual loss: True
   Input shape: (128, 256, 1)
   Number of sensors: 8
   Latent dims: 8
   Base features: 64
   Batch size: 8
   VAE epochs: 250
  

## 3. Load and Prepare Dataset

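We'll load the flow field dataset together with its sensor positions. The sensor layout comes from the configuration; the run recorded here uses `dataset_random_8.npz`, i.e. 8 sensors in the `random` layout: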

In [3]:
# Load and prepare dataset created from data_creation_and_viz.ipynb
print(f"📂 Loading dataset from: {config['dataset_path']}")

# Parse the dataset filename to get layout and n_sensors
# Expected format: dataset_<layout>_<n_sensors>.npz (e.g. dataset_random_8.npz)
dataset_filename = Path(config['dataset_path']).name
parts = Path(dataset_filename).stem.split('_')
layout_type = parts[1]        # e.g. 'random' or 'edge'
n_sensors = int(parts[2])     # e.g. 8

print(f"📊 Dataset parameters:")
print(f"   Layout type: {layout_type}")
print(f"   Number of sensors: {n_sensors}")

# Load the dataset directly from the NPZ file. The context manager closes the
# file handle, which matters on Windows, where open handles block deletes and renames.
print(f"📁 Loading dataset file: {config['dataset_path']}")
with np.load(config['dataset_path']) as data:
    # Check what keys are available in the dataset
    print(f"📋 Available keys in dataset: {list(data.keys())}")
    # Copy every array into a plain dict before the file is closed
    dataset = {key: data[key] for key in data.keys()}

# Print dataset information
print(f"📊 Dataset loaded successfully:")
for key, value in dataset.items():
    if isinstance(value, np.ndarray):
        print(f"   {key}: {value.shape} ({value.dtype})")
    else:
        print(f"   {key}: {value}")

# Extract sensor positions
sensor_positions = dataset['sensor_positions']
print(f"📍 Sensor positions shape: {sensor_positions.shape}")

# Create dataset creator instance for TensorFlow dataset creation
creator = FlowFieldDatasetCreator(
    output_path="./data/",
    domain_shape=config['input_shape'][:2],  # (height, width)
    use_synthetic_data=False  # Don't create synthetic data, just use for TF dataset creation
)

# Create TensorFlow datasets using the creator's method
train_dataset, test_dataset = creator.create_tensorflow_dataset(
    dataset,
    batch_size=config['batch_size'],
    shuffle=True,
    test_split=config['test_split']
)

print(f"\n📊 TensorFlow datasets created:")
print(f"   Train dataset: {train_dataset}")
print(f"   Test dataset: {test_dataset}")

# Function to add coordinate grids to field data for Fourier-aware VAE training
def add_coordinate_grid(batch):
    """Add coordinate grid to field data for Fourier VAE training"""
    field_data = batch['field_data']
    
    # Get dimensions
    batch_size = tf.shape(field_data)[0]
    height = tf.shape(field_data)[1]
    width = tf.shape(field_data)[2]
    
    # Create normalized coordinate grids [0, 1] - match field dimensions
    x_coords = tf.linspace(0.0, 1.0, width)   # Width corresponds to x
    y_coords = tf.linspace(0.0, 1.0, height)  # Height corresponds to y
    
    # Create meshgrid to match image indexing: [height, width, 2]
    y_grid, x_grid = tf.meshgrid(y_coords, x_coords, indexing='ij')
    
    # Stack to create coordinate grid (height, width, 2)
    coord_grid = tf.stack([x_grid, y_grid], axis=-1)
    
    # Expand to batch size (batch_size, height, width, 2)
    coord_batch = tf.tile(tf.expand_dims(coord_grid, 0), [batch_size, 1, 1, 1])
    
    # Update batch to include coordinates
    return {
        'field_data': field_data,
        'sensor_data': batch['sensor_data'],
        'coordinates': coord_batch
    }

# Add coordinate grids to datasets
coord_train_dataset = train_dataset.map(add_coordinate_grid)
coord_test_dataset = test_dataset.map(add_coordinate_grid)

print(f"\n📊 Coordinate-aware datasets created:")
print(f"   Train dataset with coordinates: {coord_train_dataset}")
print(f"   Test dataset with coordinates: {coord_test_dataset}")

# Create specialized datasets for VAE and FLRNet training
# The VAE reconstructs fields; coordinates are added to its inputs only when use_fourier is True
if config['use_fourier']:
    print("\n🌊 Creating Fourier-aware VAE datasets...")
    # For Fourier VAE: input = (field, coordinates), output = field
    vae_train_dataset = coord_train_dataset.map(
        lambda batch: ((batch['field_data'], batch['coordinates']), batch['field_data'])
    )
    vae_test_dataset = coord_test_dataset.map(
        lambda batch: ((batch['field_data'], batch['coordinates']), batch['field_data'])
    )
    print("✅ Fourier-aware VAE datasets created")
else:
    print("\n🔄 Creating standard VAE datasets...")
    # Standard VAE: input = field, output = field
    vae_train_dataset = coord_train_dataset.map(
        lambda batch: (batch['field_data'], batch['field_data'])
    )
    vae_test_dataset = coord_test_dataset.map(
        lambda batch: (batch['field_data'], batch['field_data'])
    )
    print("✅ Standard VAE datasets created")

# FLRNet trains on sensor-to-field reconstruction (sensor -> field)
flrnet_train_dataset = coord_train_dataset.map(
    lambda batch: (batch['sensor_data'], batch['field_data'])
)
flrnet_test_dataset = coord_test_dataset.map(
    lambda batch: (batch['sensor_data'], batch['field_data'])
)

print(f"\n📊 Specialized datasets:")
print(f"   VAE train dataset: {vae_train_dataset}")
print(f"   VAE test dataset: {vae_test_dataset}")
print(f"   FLRNet train dataset: {flrnet_train_dataset}")
print(f"   FLRNet test dataset: {flrnet_test_dataset}")

# Get a sample to verify data shapes
print(f"\n📊 Data shape verification:")
for batch in coord_train_dataset.take(1):
    print(f"   Sensor data shape: {batch['sensor_data'].shape}")
    print(f"   Field data shape: {batch['field_data'].shape}")
    print(f"   Coordinates shape: {batch['coordinates'].shape}")
    break

# Verify VAE dataset structure
print(f"\n📊 VAE dataset structure verification:")
for vae_batch in vae_train_dataset.take(1):
    if config['use_fourier']:
        inputs, targets = vae_batch
        if isinstance(inputs, tuple):
            field_input, coord_input = inputs
            print(f"   VAE field input shape: {field_input.shape}")
            print(f"   VAE coordinate input shape: {coord_input.shape}")
            print(f"   VAE target shape: {targets.shape}")
        else:
            print(f"   Unexpected VAE input format: {type(inputs)}")
    else:
        inputs, targets = vae_batch
        print(f"   VAE input shape: {inputs.shape}")
        print(f"   VAE target shape: {targets.shape}")
    break

📂 Loading dataset from: data/datasets\dataset_random_8.npz
📊 Dataset parameters:
   Layout type: random
   Number of sensors: 8
📁 Loading dataset file: data/datasets\dataset_random_8.npz
📋 Available keys in dataset: ['sensor_data', 'field_data', 'sensor_positions', 'reynolds_numbers', 'layout_type', 'n_sensors']
📊 Dataset loaded successfully:
   sensor_data: (28, 8, 39) (float64)
   field_data: (28, 128, 256, 39) (float64)
   sensor_positions: (8, 2) (float64)
   reynolds_numbers: (28,) (int32)
   layout_type: () (<U6)
   n_sensors: () (int32)
📍 Sensor positions shape: (8, 2)
Dataset reshaped:
  Original sensor data: (28, 8, 39)
  Reshaped sensor data: (1092, 8)
  Original field data: (28, 128, 256, 39)
  Reshaped field data: (1092, 128, 256, 1)
  Total samples: 1092
TensorFlow datasets created:
  Train samples: 873
  Test samples: 219

📊 TensorFlow datasets created:
   Train dataset: <ShuffleDataset element_spec={'sensor_data': TensorSpec(shape=(None, 8), dtype=tf.float32, name=None),
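The log shows the creator turning 28 cases × 39 time steps into 1,092 independent samples and splitting them 873/219. The NumPy sketch below reproduces those shapes and counts under two assumptions, since `FlowFieldDatasetCreator` itself is not shown:

- Samples are ordered case-major: all steps of case 0, then case 1, and so on.
- The training set holds `int(n * (1 - test_split))` samples.

`flatten_time_steps` and `split_sizes` are hypothetical helpers.

```python
import numpy as np


def flatten_time_steps(sensor_data, field_data):
    """Turn per-case time series into independent samples.

    sensor_data: (n_cases, n_sensors, n_steps)
    field_data:  (n_cases, height, width, n_steps)
    returns sensors (n_cases * n_steps, n_sensors) and
            fields  (n_cases * n_steps, height, width, 1), both float32
    """
    n_cases, n_sensors, n_steps = sensor_data.shape
    _, height, width, _ = field_data.shape
    # Move the time axis next to the case axis, then merge the two axes
    sensors = np.moveaxis(sensor_data, -1, 1).reshape(n_cases * n_steps, n_sensors)
    fields = np.moveaxis(field_data, -1, 1).reshape(n_cases * n_steps, height, width, 1)
    return sensors.astype(np.float32), fields.astype(np.float32)


def split_sizes(n_samples, test_split):
    """Train/test counts when a (1 - test_split) fraction is used for training."""
    n_train = int(n_samples * (1.0 - test_split))
    return n_train, n_samples - n_train


print(split_sizes(28 * 39, 0.2))   # (873, 219), matching the log above
```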

## 4. Train VAE and FLRNet Models

Now we'll train our models using the FLRTrainer:
1. First, the Variational Autoencoder (VAE) to learn a compressed representation of flow fields
2. Then, the FLRNet to reconstruct flow fields from sparse sensor measurements
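The two-stage idea (first learn a compressed latent space, then learn a map from sensor readings into it) shows up clearly in a deliberately simplified linear analogue. The sketch substitutes PCA for the VAE and least-squares regression for FLRNet, so it is not the notebook's model. It is only a runnable illustration on synthetic fields that lie exactly in a 3-dimensional subspace, so the recovery error should be at round-off level. All function names are hypothetical.

```python
import numpy as np


def fit_linear_autoencoder(fields, latent_dims):
    """Stage 1 stand-in for the VAE: PCA basis via SVD of the flattened fields."""
    X = fields.reshape(len(fields), -1)
    mean = X.mean(axis=0)
    _, _, vt = np.linalg.svd(X - mean, full_matrices=False)
    return mean, vt[:latent_dims]                    # basis rows are orthonormal


def encode(fields, mean, basis):
    return (fields.reshape(len(fields), -1) - mean) @ basis.T


def decode(latents, mean, basis, shape):
    return (latents @ basis + mean).reshape((len(latents),) + shape)


def with_bias(sensors):
    return np.hstack([sensors, np.ones((len(sensors), 1))])


def fit_sensor_to_latent(sensors, latents):
    """Stage 2 stand-in for FLRNet: least-squares affine map sensors -> latent codes."""
    W, *_ = np.linalg.lstsq(with_bias(sensors), latents, rcond=None)
    return W


# Synthetic fields living in a 3-dimensional subspace of 16 x 32 images
rng = np.random.default_rng(0)
shape = (16, 32)
modes = rng.normal(size=(3, shape[0] * shape[1]))
fields = (rng.normal(size=(200, 3)) @ modes).reshape((200,) + shape)
# 8 "sensors" = field values at fixed random pixels
idx = rng.choice(shape[0] * shape[1], size=8, replace=False)
sensors = fields.reshape(200, -1)[:, idx]

mean, basis = fit_linear_autoencoder(fields[:160], latent_dims=3)     # stage 1
W = fit_sensor_to_latent(sensors[:160], encode(fields[:160], mean, basis))  # stage 2
pred = decode(with_bias(sensors[160:]) @ W, mean, basis, shape)       # sensors -> field
rel_err = np.linalg.norm(pred - fields[160:]) / np.linalg.norm(fields[160:])
print(f"relative test error: {rel_err:.2e}")
```

Stage 2 only has to predict `latent_dims` numbers instead of every pixel, because the stage-1 decoder is reused. In the notebook, the nonlinear VAE and FLRNet take the place of these linear maps.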

In [4]:
import os
import shutil
import time
import glob

# Reset Keras global state (layer-name counters, cached graphs) left over from earlier runs
print("🔄 Clearing TensorFlow session...")
tf.keras.backend.clear_session()

# Initialize FLRTrainer from models_improved.py
print("🚀 Initializing FLRTrainer...")

# Create the FLRTrainer instance
trainer = models_improved.FLRTrainer(
    input_shape=config['input_shape'],
    use_fourier=config['use_fourier'],
    checkpoint_dir=config['checkpoint_dir'],
    logs_dir=config['logs_dir'],
    model_name=config['model_name'],
    save_best_model=config['save_best_model'],
    save_last_model=config['save_last_model'],
    gradient_clip_norm=config['gradient_clip_norm']
)

print(f"✅ FLRTrainer initialized:")
print(f"   Input shape: {config['input_shape']}")
print(f"   Use Fourier: {config['use_fourier']}")
print(f"   Model name: {config['model_name']}")
print(f"   Gradient clipping: {config['gradient_clip_norm']}")
print(f"   Checkpoints: {config['checkpoint_dir']}")
print(f"   Logs: {config['logs_dir']}")
print(f"   Save best model: {config['save_best_model']}")
print(f"   Save last model: {config['save_last_model']}")

# Set training flags
train_vae_model = True  # Set to False if you want to skip VAE training
train_flrnet_model = True  # Set to False to skip FLRNet training

print(f"\n🔧 Training configuration:")
print(f"   Train VAE: {train_vae_model}")
print(f"   Train FLRNet: {train_flrnet_model}")

# Quick test to verify dataset shapes are correct
print(f"\n🔍 Verifying dataset shapes...")
for batch in coord_train_dataset.take(1):
    print(f"   Field data shape: {batch['field_data'].shape}")
    print(f"   Coordinates shape: {batch['coordinates'].shape}")
    break

# Test VAE dataset format
for vae_batch in vae_train_dataset.take(1):
    inputs, targets = vae_batch
    if isinstance(inputs, tuple):
        field_input, coord_input = inputs
        print(f"   VAE field input shape: {field_input.shape}")
        print(f"   VAE coordinate input shape: {coord_input.shape}")
        print(f"   VAE target shape: {targets.shape}")
        
        # Check if shapes match now
        if field_input.shape[1:3] == coord_input.shape[1:3]:
            print(f"   ✅ Coordinate shapes now match field shapes!")
        else:
            print(f"   ❌ Shape mismatch: field {field_input.shape[1:3]} vs coord {coord_input.shape[1:3]}")
    break



🔄 Clearing TensorFlow session...
🚀 Initializing FLRTrainer...
✅ FLRTrainer initialized:
   Input shape: (128, 256, 1)
   Use Fourier: False
   Model name: fourierFalse_percepTrue_random_8
   Gradient clipping: 2.0
   Checkpoints: ./checkpoints\fourierFalse_percepTrue_random_8
   Logs: ./logs\fourierFalse_percepTrue_random_8
   Save best model: True
   Save last model: True

🔧 Training configuration:
   Train VAE: True
   Train FLRNet: True

🔍 Verifying dataset shapes...
   Field data shape: (8, 128, 256, 1)
   Coordinates shape: (8, 128, 256, 2)
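The coordinate shape `(8, 128, 256, 2)` matches the field only because `add_coordinate_grid` passes `y` first to `tf.meshgrid(..., indexing='ij')`. `tf.meshgrid` uses the same `'xy'`/`'ij'` conventions as NumPy's `meshgrid`, so the NumPy sketch below illustrates the orderings. Swapping the arguments under `'ij'` silently transposes the grid to `(width, height)`.

```python
import numpy as np

height, width = 128, 256
y_coords = np.linspace(0.0, 1.0, height)
x_coords = np.linspace(0.0, 1.0, width)

# indexing="ij": output axes follow argument order -> (len(y), len(x)) = (H, W)
y_ij, x_ij = np.meshgrid(y_coords, x_coords, indexing="ij")
print(np.stack([x_ij, y_ij], axis=-1).shape)     # (128, 256, 2), matches the field

# Swapping the arguments under indexing="ij" transposes the grid -> (W, H)
x_bad, y_bad = np.meshgrid(x_coords, y_coords, indexing="ij")
print(np.stack([x_bad, y_bad], axis=-1).shape)   # (256, 128, 2), does not match

# indexing="xy" (the default) with (x, y) order also gives (H, W)
x_xy, y_xy = np.meshgrid(x_coords, y_coords, indexing="xy")
print(np.stack([x_xy, y_xy], axis=-1).shape)     # (128, 256, 2)
```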


In [5]:
# Train VAE model using FLRTrainer
# (controlled by the `train_vae_model` flag set in the previous cell; it is not reset
# here, so setting that flag to False really skips VAE training)
if train_vae_model:
    print("\n🧠 Starting VAE Training with Coordinate-Aware Data...")
    print("="*60)
    
    # IMPORTANT: Clear any existing VAE model to ensure fresh training
    if hasattr(trainer, 'vae_model') and trainer.vae_model is not None:
        print("🔄 Clearing existing VAE model to retrain from scratch...")
        del trainer.vae_model
        trainer.vae_model = None
        tf.keras.backend.clear_session()
    
    # Report (but do not delete) any existing VAE checkpoint files
    import glob
    vae_checkpoint_pattern = os.path.join(config['checkpoint_dir'], "*vae*")
    existing_checkpoints = glob.glob(vae_checkpoint_pattern)
    if existing_checkpoints:
        print(f"🗑️  Found {len(existing_checkpoints)} existing VAE checkpoints - will retrain from scratch")
        for checkpoint in existing_checkpoints:
            print(f"   Found: {checkpoint}")
    
    print(f"\n🌊 Training VAE with Fourier features: {config['use_fourier']}")
    if config['use_fourier']:
        print("   VAE will be trained with coordinate-aware data for proper Fourier support")
    else:
        print("   VAE will be trained with standard field data")
    
    print("\n🛡️ Note: Checkpoint saving is disabled for the first 30 epochs to avoid Windows file lock issues.")
    print("   Saving will automatically start from epoch 31 onward.")
    print("   This ensures robust training without checkpoint file access conflicts.")
    
    # Train VAE using the trainer with coordinate-aware datasets
    vae_model = trainer.train_vae(
        train_dataset=vae_train_dataset,
        val_dataset=vae_test_dataset,
        epochs=config['vae_epochs'],
        learning_rate=config['vae_learning_rate'],
        latent_dims=config['latent_dims'],
        n_base_features=config['n_base_features'],
        use_perceptual_loss=config['use_perceptual_loss'],
        patience=config['patience'],
        reduce_lr_patience=config['reduce_lr_patience']
    )
    
    print(f"✅ VAE training completed successfully!")
    
    # Validate that the VAE now works with coordinate data (if using Fourier)
    if config['use_fourier'] and vae_model is not None:
        print("\n🔍 Validating Fourier VAE with coordinate input...")
        try:
            # Test with a small batch
            test_batch = next(iter(vae_test_dataset.take(1)))
            inputs, targets = test_batch
            
            if isinstance(inputs, tuple):
                field_input, coord_input = inputs
                print(f"   Testing VAE with field shape: {field_input.shape}")
                print(f"   Testing VAE with coord shape: {coord_input.shape}")
                
                # Test prediction
                prediction = vae_model([field_input[:1], coord_input[:1]])
                print(f"   ✅ VAE prediction successful! Output shape: {prediction.shape}")
                print(f"   🎉 VAE now properly supports Fourier features with coordinates!")
            else:
                print(f"   ⚠️  Unexpected input format: {type(inputs)}")
                
        except Exception as e:
            print(f"   ❌ VAE validation failed: {e}")
            import traceback
            traceback.print_exc()
    
else:
    print("⚠️ Skipping VAE training. Loading existing model...")
    # The trainer will handle loading existing VAE weights automatically
    vae_model = None


🧠 Starting VAE Training with Coordinate-Aware Data...

🌊 Training VAE with Fourier features: False
   VAE will be trained with standard field data

🛡️ Note: Checkpoint saving is disabled for the first 30 epochs to avoid Windows file lock issues.
   Saving will automatically start from epoch 31 onward.
   This ensures robust training without checkpoint file access conflicts.
🚀 Training VAE Model...
🛡️ Checkpoint saving will be disabled for epochs 1-30, enabled from epoch 31
🛡️ Checkpoint saving will be disabled for epochs 1-30, enabled from epoch 31
Epoch 1/250
Epoch 1: val_reconstruction_loss improved from inf to 22982.37891 (saving disabled for epochs 1-30)
Epoch 2/250
Epoch 2: val_reconstruction_loss improved from 22982.37891 to 15885.23047 (saving disabled for epochs 1-30)
Epoch 3/250
Epoch 3: val_reconstruction_loss improved from 15885.23047 to 14173.91602 (saving disabled for epochs 1-30)
Epoch 4/250
Epoch 4: val_reconstruction_loss improved from 14173.91602 to 12810.75879 (savin

In [None]:
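# Optional extra run: call train_vae again with hand-tuned overrides
# (50 epochs, learning rate 1e-5, perceptual loss forced on) instead of the config values.
# Whether this continues from the current weights depends on FLRTrainer.train_vae.
vae_model = trainer.train_vae(
    train_dataset=vae_train_dataset,
    val_dataset=vae_test_dataset,
    epochs=50,                    # overrides config['vae_epochs']
    learning_rate=1e-5,           # overrides config['vae_learning_rate']
    latent_dims=config['latent_dims'],
    n_base_features=config['n_base_features'],
    use_perceptual_loss=True,     # overrides config['use_perceptual_loss']
    patience=config['patience'],
    reduce_lr_patience=config['reduce_lr_patience']
)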

In [None]:
# Train FLRNet model using FLRTrainer
if train_flrnet_model:
    print("\n🔄 Starting FLRNet Training using FLRTrainer...")
    print("="*60)
    
    # Train FLRNet using the trainer
    flr_model = trainer.train_flr_net(
        train_dataset=flrnet_train_dataset,
        val_dataset=flrnet_test_dataset,
        n_sensors=config['n_sensors'],
        epochs=config['flr_epochs'],
        learning_rate=config['flr_learning_rate'],
        pretrained_vae=vae_model,  # Use the VAE model we just trained
        latent_dims=config['latent_dims'],
        n_base_features=config['n_base_features'],
        use_perceptual_loss=config['use_perceptual_loss'],
        patience=config['patience'],
        reduce_lr_patience=config['reduce_lr_patience']
    )
    
    print(f"✅ FLRNet training completed successfully!")
    
else:
    print("⚠️ Skipping FLRNet training. Loading existing model...")
    # The trainer will handle loading existing FLRNet weights automatically
    flr_model = None

## VAE Validation and Visualization

After training the VAE, we can evaluate its performance by:
1. Visualizing reconstructed flow fields
2. Computing reconstruction errors
3. Examining the latent space representation

This section can be skipped if VAE training was disabled.
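For step 2, the cells below average MSE and MAE over a whole batch. A per-sample breakdown makes it easier to spot individual snapshots that reconstruct poorly. The NumPy sketch below also adds a relative L2 error, which the original cells do not compute. `reconstruction_metrics` is a hypothetical helper; TensorFlow tensors can be passed in via `.numpy()`.

```python
import numpy as np


def reconstruction_metrics(targets, reconstructions, eps=1e-12):
    """Per-sample MSE, MAE and relative L2 error for (batch, H, W, C) arrays."""
    diff = reconstructions - targets
    axes = tuple(range(1, targets.ndim))                   # all axes except batch
    mse = np.mean(diff ** 2, axis=axes)
    mae = np.mean(np.abs(diff), axis=axes)
    rel_l2 = np.sqrt(np.sum(diff ** 2, axis=axes)) / (
        np.sqrt(np.sum(targets ** 2, axis=axes)) + eps)    # eps guards all-zero targets
    return {"mse": mse, "mae": mae, "rel_l2": rel_l2}
```

Every sample has the same number of pixels, so the mean of the per-sample `mse` equals the batch-level MSE printed by the validation cells.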

In [None]:
# Build a Fourier-aware VAE test dataset from the raw field data and check it
if train_vae_model and trainer.vae_model is not None and config['use_fourier']:
    print("=== Creating Fourier-Aware VAE Test Dataset ===")
    
    # Function to add coordinate grids to VAE data
    def add_coordinates_to_vae_data(field_input, field_target):
        # Get batch size and shape
        batch_size = tf.shape(field_input)[0]
        height = tf.shape(field_input)[1]
        width = tf.shape(field_input)[2]
        
        # Create coordinate grid (normalized to [0, 1])
        y_coords = tf.linspace(0.0, 1.0, height)
        x_coords = tf.linspace(0.0, 1.0, width)
        
        # Create meshgrid with 'ij' indexing: pass y first so both grids are (height, width)
        y_grid, x_grid = tf.meshgrid(y_coords, x_coords, indexing='ij')
        
        # Stack to create coordinate grid (height, width, 2)
        coord_grid = tf.stack([x_grid, y_grid], axis=-1)
        
        # Expand to batch size (batch_size, height, width, 2)
        coord_batch = tf.tile(tf.expand_dims(coord_grid, 0), [batch_size, 1, 1, 1])
        
        # Return in the format expected by Fourier VAE: ((img, coord), target)
        return ((field_input, coord_batch), field_target)
    
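    # Create the Fourier-aware VAE dataset from plain (field, field) pairs. vae_test_dataset
    # already carries coordinates when use_fourier is True, so it cannot be mapped again here.
    fourier_vae_test_dataset = test_dataset.map(
        lambda batch: (batch['field_data'], batch['field_data'])
    ).map(add_coordinates_to_vae_data)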
    
    print("✅ Fourier-aware VAE test dataset created")
    
    # Test with a small batch
    test_batch = next(iter(fourier_vae_test_dataset.take(1)))
    
    print(f"Fourier VAE batch structure:")
    inputs, target = test_batch
    if isinstance(inputs, tuple):
        img_input, coord_input = inputs
        print(f"  Image input shape: {img_input.shape}")
        print(f"  Coordinate input shape: {coord_input.shape}")
        print(f"  Target shape: {target.shape}")
        
        # Test VAE prediction with proper format
        try:
            # Take just 4 samples for testing
            test_img = img_input[:4]
            test_coord = coord_input[:4]
            test_target = target[:4]
            
            print(f"\\nTesting VAE with proper Fourier input format...")
            reconstructed = trainer.vae_model([test_img, test_coord])
            
            print(f"✅ Fourier VAE prediction successful!")
            print(f"   Input shape: {test_img.shape}")
            print(f"   Coordinate shape: {test_coord.shape}")
            print(f"   Reconstruction shape: {reconstructed.shape}")
            
            # Calculate metrics
            mse = tf.reduce_mean(tf.square(test_target - reconstructed)).numpy()
            mae = tf.reduce_mean(tf.abs(test_target - reconstructed)).numpy()
            
            print(f"   Reconstruction MSE: {mse:.6f}")
            print(f"   Reconstruction MAE: {mae:.6f}")
            
            # Visualization
            plt.figure(figsize=(16, 8))
            
            # Show first 4 samples: original and reconstructed
            for i in range(4):
                # Original
                plt.subplot(2, 4, i + 1)
                plt.imshow(test_img[i, :, :, 0], cmap='RdBu_r', origin='lower')
                plt.title(f'Original {i+1}')
                plt.colorbar()
                
                # Reconstructed
                plt.subplot(2, 4, i + 5)
                plt.imshow(reconstructed[i, :, :, 0], cmap='RdBu_r', origin='lower')
                plt.title(f'Fourier Reconstructed {i+1}')
                plt.colorbar()
            
            plt.tight_layout()
            plt.suptitle('VAE with Fourier Features: Original vs Reconstructed', y=1.02)
            plt.show()
            
            # Error visualization
            plt.figure(figsize=(12, 4))
            for i in range(4):
                plt.subplot(1, 4, i + 1)
                error = tf.abs(test_img[i, :, :, 0] - reconstructed[i, :, :, 0]).numpy()
                plt.imshow(error, cmap='hot', origin='lower')
                plt.title(f'Error {i+1}')
                plt.colorbar()
            
            plt.tight_layout()
            plt.suptitle('VAE Fourier Reconstruction Errors', y=1.02)
            plt.show()
            
            print("\\n🎉 VAE with Fourier features validation completed successfully!")
            
        except Exception as e:
            print(f"❌ Fourier VAE prediction failed: {e}")
            import traceback
            traceback.print_exc()
    else:
        print("❌ Unexpected data format")
        
else:
    print("Skipping Fourier VAE test (VAE not trained or not using Fourier features)")

# VAE Validation and Visualization
print("=== VAE Model Validation ===")

if train_vae_model and trainer.vae_model is not None:
    print(f"🎯 Validating VAE model (Fourier: {config['use_fourier']})")
    
    try:
        # Get a test batch from the existing VAE test dataset
        test_batch = next(iter(vae_test_dataset.take(1)))
        inputs, targets = test_batch
        
        if config['use_fourier']:
            # For Fourier VAE, inputs should be a tuple (field, coordinates)
            if isinstance(inputs, tuple) and len(inputs) == 2:
                field_input, coord_input = inputs
                
                print(f"✅ Fourier VAE dataset structure correct:")
                print(f"   Field input shape: {field_input.shape}")
                print(f"   Coordinate input shape: {coord_input.shape}")
                print(f"   Target shape: {targets.shape}")
                
                # Test with first 4 samples
                test_field = field_input[:4]
                test_coord = coord_input[:4]
                test_targets = targets[:4]
                
                # VAE prediction with coordinates
                reconstructed = trainer.vae_model([test_field, test_coord])
                
                print(f"✅ Fourier VAE prediction successful!")
                print(f"   Reconstruction shape: {reconstructed.shape}")
                
                # Calculate metrics
                mse = tf.reduce_mean(tf.square(test_targets - reconstructed)).numpy()
                mae = tf.reduce_mean(tf.abs(test_targets - reconstructed)).numpy()
                
                print(f"   Reconstruction MSE: {mse:.6f}")
                print(f"   Reconstruction MAE: {mae:.6f}")
                
                # Visualization
                plt.figure(figsize=(16, 8))
                
                # Show original vs reconstructed
                for i in range(4):
                    # Original
                    plt.subplot(2, 4, i + 1)
                    plt.imshow(test_field[i, :, :, 0], cmap='RdBu_r', origin='lower')
                    plt.title(f'Original {i+1}')
                    plt.colorbar()
                    
                    # Reconstructed
                    plt.subplot(2, 4, i + 5)
                    plt.imshow(reconstructed[i, :, :, 0], cmap='RdBu_r', origin='lower')
                    plt.title(f'VAE Fourier Recon {i+1}')
                    plt.colorbar()
                
                plt.tight_layout()
                plt.suptitle('VAE with Fourier Features: Original vs Reconstructed', y=1.02)
                plt.show()
                
                # Error visualization
                plt.figure(figsize=(12, 4))
                for i in range(4):
                    plt.subplot(1, 4, i + 1)
                    error = tf.abs(test_field[i, :, :, 0] - reconstructed[i, :, :, 0]).numpy()
                    plt.imshow(error, cmap='hot', origin='lower')
                    plt.title(f'Error {i+1}')
                    plt.colorbar()
                
                plt.tight_layout()
                plt.suptitle('VAE Fourier Reconstruction Errors', y=1.02)
                plt.show()
                
                print("🎉 FOURIER VAE VALIDATION SUCCESSFUL!")
                print("   The VAE has been properly retrained with coordinate-aware data!")
                print("   Fourier features are working correctly!")
                
            else:
                print(f"❌ Unexpected Fourier VAE input format: {type(inputs)}")
                print("   Expected tuple (field, coordinates)")
        else:
            # For standard VAE, inputs should be just the field data
            print(f"✅ Standard VAE dataset structure:")
            print(f"   Input shape: {inputs.shape}")
            print(f"   Target shape: {targets.shape}")
            
            # Test with first 4 samples
            test_inputs = inputs[:4]
            test_targets = targets[:4]
            
            # VAE prediction without coordinates
            reconstructed = trainer.vae_model(test_inputs)
            
            print(f"✅ Standard VAE prediction successful!")
            print(f"   Reconstruction shape: {reconstructed.shape}")
            
            # Calculate metrics
            mse = tf.reduce_mean(tf.square(test_targets - reconstructed)).numpy()
            mae = tf.reduce_mean(tf.abs(test_targets - reconstructed)).numpy()
            
            print(f"   Reconstruction MSE: {mse:.6f}")
            print(f"   Reconstruction MAE: {mae:.6f}")
            
            print("✅ STANDARD VAE VALIDATION SUCCESSFUL!")
            
    except Exception as e:
        print(f"❌ VAE validation failed: {e}")
        import traceback
        traceback.print_exc()
        
else:
    print("⚠️ Skipping VAE validation (VAE training disabled or model not available)")

# Test the retrained VAE with coordinate inputs
print("=== Testing Retrained VAE with Coordinate Inputs ===")

if trainer.vae_model is not None and config['use_fourier']:
    print("🌊 Testing Fourier-aware VAE...")
    
    try:
        # Get a test batch from the VAE dataset
        test_batch = next(iter(vae_test_dataset.take(1)))
        inputs, targets = test_batch
        
        if isinstance(inputs, tuple) and len(inputs) == 2:
            field_input, coord_input = inputs
            
            print(f"✅ VAE dataset structure is correct:")
            print(f"   Field input shape: {field_input.shape}")
            print(f"   Coordinate input shape: {coord_input.shape}")
            print(f"   Target shape: {targets.shape}")
            
            # Test with just 1 sample for quick verification
            test_field = field_input[:1]
            test_coord = coord_input[:1]
            test_target = targets[:1]
            
            # Test VAE prediction
            print(f"\n🔍 Testing VAE prediction...")
            reconstructed = trainer.vae_model([test_field, test_coord])
            
            print(f"✅ SUCCESS! VAE now works with coordinate input!")
            print(f"   Input field shape: {test_field.shape}")
            print(f"   Input coord shape: {test_coord.shape}")
            print(f"   Reconstruction shape: {reconstructed.shape}")
            
            # Calculate error
            mse = tf.reduce_mean(tf.square(test_target - reconstructed)).numpy()
            print(f"   Reconstruction MSE: {mse:.6f}")
            
            # Visualization of the first sample
            plt.figure(figsize=(12, 4))
            
            # Original
            plt.subplot(1, 3, 1)
            plt.imshow(test_field[0, :, :, 0], cmap='RdBu_r', origin='lower')
            plt.title('Original Field')
            plt.colorbar()
            
            # Reconstructed
            plt.subplot(1, 3, 2)
            plt.imshow(reconstructed[0, :, :, 0], cmap='RdBu_r', origin='lower')
            plt.title('VAE Reconstruction (with Fourier)')
            plt.colorbar()
            
            # Error
            plt.subplot(1, 3, 3)
            error = tf.abs(test_target[0, :, :, 0] - reconstructed[0, :, :, 0]).numpy()
            plt.imshow(error, cmap='hot', origin='lower')
            plt.title('Reconstruction Error')
            plt.colorbar()
            
            plt.tight_layout()
            plt.suptitle('VAE Successfully Retrained with Fourier Features!', y=1.05)
            plt.show()
            
            print(f"\n🎉 MISSION ACCOMPLISHED!")
            print(f"   The VAE has been successfully retrained with coordinate-aware data!")
            print(f"   Fourier features are now working properly!")
            print(f"   The model expects 2-channel coordinate input as designed!")
            
        else:
            print(f"❌ Unexpected VAE dataset format: {type(inputs)}")
            
    except Exception as e:
        print(f"❌ VAE test failed: {e}")
        import traceback
        traceback.print_exc()

else:
    print("❌ VAE model not available or not using Fourier features")

In [None]:
# CLEAN VAE Validation (No Cached Functions)
print("=== CLEAN VAE Validation Test ===")

if 'trainer' in locals() and trainer.vae_model is not None:
    print(f"🎯 Testing VAE model (Fourier: {config['use_fourier']})")
    
    try:
        # Clear any old cached functions by getting fresh data
        test_data_iter = iter(vae_test_dataset.take(1))
        test_batch = next(test_data_iter)
        inputs, targets = test_batch
        
        if config['use_fourier']:
            # Expected format: ((field, coordinates), targets)
            if isinstance(inputs, tuple) and len(inputs) == 2:
                field_input, coord_input = inputs
                
                print(f"✅ VAE input structure correct:")
                print(f"   Field: {field_input.shape}")
                print(f"   Coord: {coord_input.shape}")
                print(f"   Target: {targets.shape}")
                
                # Test with just 1 sample
                test_field = field_input[:1]
                test_coord = coord_input[:1]
                test_target = targets[:1]
                
                # VAE prediction
                reconstructed = trainer.vae_model([test_field, test_coord])
                
                print(f"✅ VAE prediction successful!")
                print(f"   Reconstruction shape: {reconstructed.shape}")
                
                # Calculate MSE
                mse = tf.reduce_mean(tf.square(test_target - reconstructed)).numpy()
                print(f"   MSE: {mse:.6f}")
                
                print("🎉 SUCCESS: VAE with Fourier features is working correctly!")
                
            else:
                print(f"❌ Unexpected input format: {type(inputs)}")
                
        else:
            # Standard VAE test
            print(f"✅ Standard VAE input: {inputs.shape}")
            reconstructed = trainer.vae_model(inputs[:1])
            print(f"✅ Standard VAE working!")
            
    except Exception as e:
        print(f"❌ VAE test failed: {e}")
        
else:
    print("❌ No trainer or VAE model available")
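The validation cells above repeatedly branch on whether a VAE batch carries `(field, coordinates)` or only the field. Below is a minimal sketch of a helper that normalizes both layouts so the same validation code can serve either model. It assumes batches are `(inputs, targets)`, where `inputs` is either a 2-tuple for the Fourier VAE or a single field array for the standard VAE, as the cells above expect. The names `unpack_vae_batch` and `take_first` are hypothetical and not part of the project's code.

```python
import numpy as np


def unpack_vae_batch(batch):
    """Split a VAE batch into (fields, coords, targets, model_inputs).

    Hypothetical helper: assumes `batch` is (inputs, targets), where `inputs`
    is either (field, coordinates) for a Fourier VAE or a single field array.
    `coords` is None for the standard (field-only) layout.
    """
    inputs, targets = batch
    if isinstance(inputs, (list, tuple)) and len(inputs) == 2:
        fields, coords = inputs
        model_inputs = [fields, coords]  # Fourier VAE is called with [field, coords]
    elif isinstance(inputs, (list, tuple)):
        raise ValueError(f"Expected (field, coords), got {len(inputs)} inputs")
    else:
        fields, coords = inputs, None
        model_inputs = fields  # standard VAE is called with the field only
    return fields, coords, targets, model_inputs


def take_first(model_inputs, n):
    """Slice the first n samples from either input layout."""
    if isinstance(model_inputs, list):
        return [x[:n] for x in model_inputs]
    return model_inputs[:n]
```

In the notebook this could be used as `fields, coords, targets, model_inputs = unpack_vae_batch(next(iter(vae_test_dataset.take(1))))` followed by `trainer.vae_model(take_first(model_inputs, 1))`, which removes the duplicated Fourier/standard branches.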

In [None]:
# Investigate how the VAE model was actually built during training
print("=== VAE Model Architecture Investigation ===")

if trainer.vae_model is not None:
    print(f"VAE model use_fourier: {trainer.vae_model.use_fourier}")
    
    # Check the encoder structure
    encoder = trainer.vae_model.encoder
    print(f"\nEncoder type: {type(encoder)}")
    print(f"Encoder use_fourier: {encoder.use_fourier}")
    
    # Check if encoder has fourier layer
    if hasattr(encoder, 'fourier_layer') and encoder.fourier_layer:
        fourier_layer = encoder.fourier_layer
        print(f"\nFourier layer found: {fourier_layer}")
        print(f"  Type: {type(fourier_layer)}")
        if hasattr(fourier_layer, 'built') and fourier_layer.built:
            print(f"  Layer is built: {fourier_layer.built}")
            # Check the projection kernel
            if hasattr(fourier_layer, 'proj_kernel'):
                proj_kernel = fourier_layer.proj_kernel
                print(f"  Projection kernel: {proj_kernel}")
                if hasattr(proj_kernel, 'kernel') and proj_kernel.kernel is not None:
                    kernel_shape = proj_kernel.kernel.shape
                    print(f"  Conv2D kernel shape: {kernel_shape}")
                    print(f"  Expected input channels: {kernel_shape[2]}")
                    print(f"  Output channels: {kernel_shape[3]}")
    else:
        print("\nNo Fourier layer found in encoder")
    
    # Test whether the model also accepts a field-only input (no coordinates)
    print("\n=== Testing with Simple Field Input ===")
    try:
        # Create a simple test input (just the field), sized from the config
        test_field = tf.zeros((1, config['input_shape'][0], config['input_shape'][1], 1))
        
        # Try calling the full VAE model with just the field input
        print("Testing full VAE with field input only...")
        vae_out = trainer.vae_model(test_field)
        print(f"✅ VAE works with field input only! Output shape: {vae_out.shape}")
        if trainer.vae_model.use_fourier:
            print("This suggests the model was trained without proper Fourier coordinates!")
        
    except Exception as e:
        print(f"❌ Simple field input failed: {e}")
        
    # Check what the training dataset actually contains
    print("\n=== VAE Training Dataset Analysis ===")
    sample_train_batch = next(iter(vae_train_dataset.take(1)))
    print(f"Training batch structure: {type(sample_train_batch)}")
    if isinstance(sample_train_batch, tuple):
        train_inputs, train_targets = sample_train_batch
        if isinstance(train_inputs, (list, tuple)):
            # Coordinate-aware format: ((field, coordinates), targets)
            print(f"  Field input shape: {train_inputs[0].shape}")
            print(f"  Coordinate input shape: {train_inputs[1].shape}")
        else:
            print(f"  Input shape: {train_inputs.shape}")
            print("  The training data contains only field data, no coordinates!")
            if trainer.vae_model.use_fourier:
                print("  This explains why Fourier features don't work properly.")
        print(f"  Target shape: {train_targets.shape}")
        
else:
    print("No VAE model available")

# Verify that the VAE model architecture is now correct for Fourier features
print("=== VAE Model Architecture Verification ===")

if trainer.vae_model is not None:
    print(f"VAE model use_fourier: {trainer.vae_model.use_fourier}")
    
    if trainer.vae_model.use_fourier:
        print("\n🌊 Analyzing Fourier-aware VAE architecture...")
        
        # Check the encoder structure
        encoder = trainer.vae_model.encoder
        print(f"Encoder type: {type(encoder)}")
        print(f"Encoder use_fourier: {encoder.use_fourier}")
        
        # Check if encoder has fourier layer
        if hasattr(encoder, 'fourier_layer') and encoder.fourier_layer:
            fourier_layer = encoder.fourier_layer
            print(f"\n✅ Fourier layer found: {fourier_layer}")
            print(f"  Type: {type(fourier_layer)}")
            
            if hasattr(fourier_layer, 'built') and fourier_layer.built:
                print(f"  Layer is built: {fourier_layer.built}")
                
                # Check the projection kernel
                if hasattr(fourier_layer, 'proj_kernel'):
                    proj_kernel = fourier_layer.proj_kernel
                    print(f"  Projection kernel: {proj_kernel}")
                    if hasattr(proj_kernel, 'kernel') and proj_kernel.kernel is not None:
                        kernel_shape = proj_kernel.kernel.shape
                        print(f"  Conv2D kernel shape: {kernel_shape}")
                        print(f"  Expected input channels: {kernel_shape[2]}")
                        print(f"  Output channels: {kernel_shape[3]}")
                        
                        # Verify proper 2-channel input for coordinates
                        if kernel_shape[2] == 2:
                            print(f"  ✅ Fourier layer correctly expects 2-channel coordinate input!")
                        else:
                            print(f"  ⚠️  Fourier layer expects {kernel_shape[2]} channels, should be 2")
            else:
                print(f"  ⚠️  Fourier layer not yet built")
        else:
            print("\n❌ No Fourier layer found in encoder")
        
        # Test the model with coordinate input to build the layers
        print(f"\n🔍 Testing model to ensure proper layer construction...")
        try:
            # Create test inputs with the right format
            test_field = tf.zeros((1, config['input_shape'][0], config['input_shape'][1], 1))
            test_coord = tf.zeros((1, config['input_shape'][0], config['input_shape'][1], 2))
            
            print(f"  Test field shape: {test_field.shape}")
            print(f"  Test coordinate shape: {test_coord.shape}")
            
            # This should build the layers properly
            output = trainer.vae_model([test_field, test_coord])
            print(f"  ✅ Model successfully processes coordinate input! Output shape: {output.shape}")
            
            # Now check the Fourier layer again
            if hasattr(encoder, 'fourier_layer') and encoder.fourier_layer:
                fourier_layer = encoder.fourier_layer
                if hasattr(fourier_layer, 'proj_kernel') and fourier_layer.proj_kernel.kernel is not None:
                    kernel_shape = fourier_layer.proj_kernel.kernel.shape
                    print(f"  Final Fourier kernel shape: {kernel_shape}")
                    if kernel_shape[2] == 2:
                        print(f"  🎉 SUCCESS: Fourier layer now properly expects 2-channel coordinates!")
                    else:
                        print(f"  ❌ ISSUE: Fourier layer still expects {kernel_shape[2]} channels")
            
        except Exception as e:
            print(f"  ❌ Model test failed: {e}")
            
    else:
        print("\n🔄 Standard VAE (no Fourier features)")
        
        # Test standard VAE
        try:
            test_field = tf.zeros((1, config['input_shape'][0], config['input_shape'][1], 1))
            output = trainer.vae_model(test_field)
            print(f"✅ Standard VAE works correctly! Output shape: {output.shape}")
        except Exception as e:
            print(f"❌ Standard VAE test failed: {e}")
    
    # Show model summary for reference
    print(f"\n📋 VAE Model Summary:")
    try:
        trainer.vae_model.summary()
    except Exception:
        print("Could not display model summary")
        
else:
    print("No VAE model available for analysis")
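The verification above builds the Fourier VAE with an all-zero coordinate tensor, which is enough to create the layers but carries no positional information. For real inference the model needs an actual 2-channel coordinate grid. The following is a minimal sketch under stated assumptions: channel 0 holds x (varying along the width axis), channel 1 holds y (varying along the height axis), and both are normalized to [0, 1]. The dataset's actual range and channel order may differ, so compare against the coordinate ranges printed in the VAE visualization cell. `make_coordinate_grid` is a hypothetical helper.

```python
import numpy as np


def make_coordinate_grid(height, width, batch_size=1, low=0.0, high=1.0):
    """Return a (batch, height, width, 2) grid of normalized (x, y) coordinates.

    Assumption: channel 0 is x (varies along the width axis) and channel 1 is y
    (varies along the height axis), both scaled linearly to [low, high].
    """
    ys = np.linspace(low, high, height, dtype=np.float32)
    xs = np.linspace(low, high, width, dtype=np.float32)
    # indexing="xy" yields arrays of shape (height, width) with x along columns
    grid_x, grid_y = np.meshgrid(xs, ys, indexing="xy")
    grid = np.stack([grid_x, grid_y], axis=-1)  # (height, width, 2)
    # Repeat the same grid for every sample in the batch
    return np.broadcast_to(grid, (batch_size, height, width, 2)).copy()
```

With `config['input_shape'][:2]` as `(height, width)`, the result could replace `tf.zeros(...)` as the coordinate input, e.g. `trainer.vae_model([test_field, tf.constant(make_coordinate_grid(h, w))])`.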

## FLRNet Validation and Visualization

After training the FLRNet model, we can evaluate its performance by:
1. Testing sensor-to-field reconstruction accuracy
2. Comparing predicted vs ground truth flow fields
3. Analyzing reconstruction quality metrics
4. Visualizing sensor positions overlaid on predictions

This section validates the main FLRNet model performance.
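The validation cell reports MSE, MAE, and a pointwise relative error `|e| / (|y| + 1e-8)`. That pointwise ratio can be dominated by the few pixels where the ground truth is close to zero, which is common in fluctuating flow fields, so a per-sample relative L2 error `||y - ŷ|| / ||y||` is a useful companion. The following is a minimal NumPy sketch of both; `flow_field_metrics` is a hypothetical name, and it assumes batched arrays with the sample index on axis 0.

```python
import numpy as np


def flow_field_metrics(ground_truth, prediction, eps=1e-8):
    """Compute reconstruction metrics for batches of flow fields.

    Returns MSE, MAE, the pointwise mean relative error used in the validation
    cell, and a relative L2 error computed per sample and then averaged.
    """
    gt = np.asarray(ground_truth, dtype=np.float64)
    pred = np.asarray(prediction, dtype=np.float64)
    err = gt - pred
    mse = np.mean(err ** 2)
    mae = np.mean(np.abs(err))
    pointwise_rel = np.mean(np.abs(err) / (np.abs(gt) + eps))
    # Relative L2 error per sample: ||gt - pred|| / ||gt||, averaged over the batch
    axes = tuple(range(1, gt.ndim))
    rel_l2 = np.sqrt(np.sum(err ** 2, axis=axes)) / (np.sqrt(np.sum(gt ** 2, axis=axes)) + eps)
    return {"mse": mse, "mae": mae, "pointwise_rel": pointwise_rel,
            "rel_l2": float(np.mean(rel_l2))}
```

A single ground-truth pixel at zero with a 0.01 prediction error can push the pointwise value into the thousands while the relative L2 error stays well below 1%.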

In [None]:
# FLRNet Validation and Visualization
if train_flrnet_model and trainer.flr_model is not None:
    print("=== FLRNet Model Validation ===")
    
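    # Get one batch of test samples. flrnet_test_dataset is assumed to be batched
    # already (it is passed to the trainer like the VAE datasets), so calling
    # .batch() again would add a second batch axis
    test_batch = next(iter(flrnet_test_dataset.take(1)))
    sensor_readings = test_batch[0]  # Input sensor readings
    ground_truth = test_batch[1]     # Target flow fields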
    
    # Get FLRNet predictions
    predictions = trainer.flr_model.predict(sensor_readings)
    
    # Calculate reconstruction metrics
    mse = np.mean((ground_truth.numpy() - predictions) ** 2)
    mae = np.mean(np.abs(ground_truth.numpy() - predictions))
    
    # Calculate relative error
    relative_error = np.mean(np.abs(ground_truth.numpy() - predictions) / (np.abs(ground_truth.numpy()) + 1e-8))
    
    print(f"FLRNet Reconstruction MSE: {mse:.6f}")
    print(f"FLRNet Reconstruction MAE: {mae:.6f}")
    print(f"FLRNet Relative Error: {relative_error:.6f}")
    
    # Visualize predictions vs ground truth
    plt.figure(figsize=(20, 12))
    
    # Show first 4 samples: sensor readings, ground truth, predictions, and errors
    for i in range(4):
        # Sensor readings visualization
        plt.subplot(4, 4, i + 1)
        field_viz = np.zeros((h, w))
        # Place sensor readings on the field for visualization
        for j, (x_pos, y_pos) in enumerate(sensor_positions):
            if j < len(sensor_readings[i]):
                field_viz[int(y_pos), int(x_pos)] = sensor_readings[i, j]
        plt.imshow(field_viz, cmap='RdBu_r', origin='lower')
        plt.title(f'Sensor Readings {i+1}')
        plt.colorbar()
        
        # Ground truth
        plt.subplot(4, 4, i + 5)
        plt.imshow(ground_truth[i, :, :, 0], cmap='RdBu_r', origin='lower')
        plt.title(f'Ground Truth {i+1}')
        plt.colorbar()
        
        # Predictions
        plt.subplot(4, 4, i + 9)
        plt.imshow(predictions[i, :, :, 0], cmap='RdBu_r', origin='lower')
        plt.title(f'FLRNet Prediction {i+1}')
        plt.colorbar()
        
        # Error
        plt.subplot(4, 4, i + 13)
        error = np.abs(ground_truth[i, :, :, 0] - predictions[i, :, :, 0])
        plt.imshow(error, cmap='hot', origin='lower')
        plt.title(f'Prediction Error {i+1}')
        plt.colorbar()
    
    plt.tight_layout()
    plt.suptitle('FLRNet: Sensor Readings → Ground Truth vs Predictions', y=0.98)
    plt.show()
    
    # Additional visualization: Overlay sensor positions on predictions
    plt.figure(figsize=(16, 8))
    
    for i in range(2):
        # Prediction with sensor overlay
        plt.subplot(1, 2, i + 1)
        plt.imshow(predictions[i, :, :, 0], cmap='RdBu_r', origin='lower')
        
        # Overlay sensor positions
        sensor_x = sensor_positions[:, 0]
        sensor_y = sensor_positions[:, 1]
        plt.scatter(sensor_x, sensor_y, c='black', s=100, marker='o', edgecolors='white', linewidth=2)
        
        # Add sensor value annotations
        for j, (x, y) in enumerate(sensor_positions):
            if j < len(sensor_readings[i]):
                plt.annotate(f'{float(sensor_readings[i, j]):.2f}', 
                           (x, y), xytext=(5, 5), textcoords='offset points',
                           fontsize=8, color='white', weight='bold')
        
        plt.title(f'FLRNet Prediction {i+1} with Sensor Positions')
        plt.colorbar()
    
    plt.tight_layout()
    plt.suptitle('FLRNet Predictions with Sensor Position Overlay', y=1.02)
    plt.show()
    
    print("FLRNet validation completed successfully!")
    
else:
    print("Skipping FLRNet validation (FLRNet training was disabled or model not available)")
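The sensor-readings panel places each reading on an otherwise empty grid with a Python loop and `int()` truncation. Below is a vectorized sketch of the same step that also rounds to the nearest pixel and clips indices to the grid bounds. The two approaches agree whenever sensor positions are already integers. It assumes `sensor_positions` is an `(n_sensors, 2)` array of `(x, y)` pixel coordinates, as the overlay code uses. `sensors_to_grid` is a hypothetical helper.

```python
import numpy as np


def sensors_to_grid(readings, positions, height, width):
    """Scatter sensor readings onto an otherwise-zero (height, width) grid.

    Assumes `positions` is an (n_sensors, 2) array of (x, y) pixel coordinates
    and `readings` holds one value per sensor in the same order.
    """
    readings = np.asarray(readings, dtype=np.float32).reshape(-1)
    positions = np.asarray(positions)
    # Round to the nearest pixel and keep indices inside the grid
    cols = np.clip(np.rint(positions[:, 0]).astype(int), 0, width - 1)
    rows = np.clip(np.rint(positions[:, 1]).astype(int), 0, height - 1)
    n = min(len(readings), len(positions))
    grid = np.zeros((height, width), dtype=np.float32)
    grid[rows[:n], cols[:n]] = readings[:n]
    return grid
```

Inside the plotting loop this would read `field_viz = sensors_to_grid(sensor_readings[i], sensor_positions, h, w)`.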

## Training Summary and Model Saving

Complete the training workflow by saving trained models and summarizing results.
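The checkpoint-loading cells rebuild each model from `config` before calling `load_weights()`. Any weights-only save therefore depends on having the same configuration available later. Below is a minimal sketch that stores the config as JSON next to the saved weights, using only the standard library. It assumes `config` is a flat dict of JSON-serializable values. Tuples such as `input_shape` become JSON lists and are converted back on load. Both function names are hypothetical.

```python
import json
from pathlib import Path


def save_run_config(config, path):
    """Write a training config dict to JSON next to the saved weights."""
    Path(path).write_text(json.dumps(config, indent=2, sort_keys=True))


def load_run_config(path, tuple_keys=("input_shape",)):
    """Read the config back, restoring tuple-valued entries (JSON stores lists)."""
    config = json.loads(Path(path).read_text())
    for key in tuple_keys:
        if key in config and isinstance(config[key], list):
            config[key] = tuple(config[key])
    return config
```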

In [None]:
# Training Summary and Model Saving
print("=== Training Summary ===")

# Print configuration summary
print(f"Configuration: {config_name}")
print(f"Layout type: {layout_type}")
print(f"Number of sensors: {n_sensors}")
print(f"VAE training: {'Enabled' if train_vae_model else 'Disabled'}")
print(f"FLRNet training: {'Enabled' if train_flrnet_model else 'Disabled'}")

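# Save trained weights. The models appear to be subclassed Keras models (e.g. models_improved.FLRVAE),
# which cannot be saved as full-model HDF5 files, so save weights only; reload by rebuilding
# the architecture from `config`, calling it once on dummy input, then calling load_weights()
if trainer.vae_model is not None:
    vae_filename = f"vae_model_{layout_type}_{n_sensors}_sensors.weights.h5"
    trainer.vae_model.save_weights(vae_filename)
    print(f"VAE weights saved as: {vae_filename}")

if trainer.flr_model is not None:
    flrnet_filename = f"flrnet_model_{layout_type}_{n_sensors}_sensors.weights.h5"
    trainer.flr_model.save_weights(flrnet_filename)
    print(f"FLRNet weights saved as: {flrnet_filename}")

print("\n=== Training Complete ===")
print("All requested models have been trained and validated.")
print("Model weights have been saved for future use.")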

# 🔄 FLRNet Training Continuation
# Toggle for enabling FLRNet training continuation (set to True to continue from checkpoint)
continue_flrnet_training = False  # Change to True to enable

# Toggle for perceptual loss during continuation (can be different from original training)
use_perceptual_loss_flrnet_continuation = True  # Change to False to disable perceptual loss

if continue_flrnet_training:
    if trainer.flr_model is not None:
        print("🔄 Continuing FLRNet training from loaded checkpoint...")
        print("=" * 60)
        
        # Configuration for continued training
        additional_epochs = 50
        learning_rate = 1e-6  # Lower learning rate for fine-tuning
        
        print(f"📋 Continuation Configuration:")
        print(f"   - Additional epochs: {additional_epochs}")
        print(f"   - Learning rate: {learning_rate} (reduced for fine-tuning)")
        print(f"   - Perceptual loss: {'✅ ENABLED' if use_perceptual_loss_flrnet_continuation else '❌ DISABLED'}")
        print(f"   - Model weights: ✅ PRESERVED from checkpoint")
        
        # Continue training using the proper method that preserves weights
        continued_flrnet = trainer.continue_flrnet_training(
            train_dataset=flrnet_train_dataset,
            val_dataset=flrnet_test_dataset,
            epochs=additional_epochs,
            learning_rate=learning_rate,
            patience=config['patience'],
            reduce_lr_patience=config['reduce_lr_patience'],
            use_perceptual_loss=use_perceptual_loss_flrnet_continuation  # New option for perceptual loss
        )
        
        print("✅ FLRNet training continuation completed!")
        
    elif trainer.vae_model is not None:
        print("❌ Cannot continue FLRNet training: No FLRNet model loaded from checkpoint")
        print("   First load an FLRNet model using the checkpoint loading cell above")
    else:
        print("❌ Cannot continue FLRNet training: No models loaded from checkpoint")
        print("   First load models using the checkpoint loading cell above")
else:
    print("ℹ️  FLRNet training continuation is disabled")
    print("   Set continue_flrnet_training = True to enable")
    print(f"   Current perceptual loss setting: {'✅ ENABLED' if use_perceptual_loss_flrnet_continuation else '❌ DISABLED'}")
    print("   (can be changed via use_perceptual_loss_flrnet_continuation variable)")
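The continuation cell passes `patience` and `reduce_lr_patience` to the trainer. How the trainer uses them is not shown here. In Keras, the `EarlyStopping` and `ReduceLROnPlateau` callbacks each take a `patience` argument, so these are likely wired to such callbacks. Below is a simplified plain-Python illustration of that plateau logic. It is not Keras's exact implementation, which also has options such as `min_delta` and `cooldown`. It is meant for reasoning about how a starting rate of `1e-6` evolves over the extra epochs. `simulate_plateau_schedule`, `factor`, and `min_lr` are illustrative assumptions.

```python
def simulate_plateau_schedule(val_losses, lr, reduce_lr_patience, patience,
                              factor=0.5, min_lr=1e-8):
    """Replay a validation-loss history through simplified plateau rules.

    Illustrative only: the LR is multiplied by `factor` after `reduce_lr_patience`
    epochs without improvement, and training stops after `patience` epochs
    without improvement. Returns (LR used at each epoch, stop epoch or None).
    """
    best = float("inf")
    wait_lr = wait_stop = 0
    lrs = []
    for epoch, loss in enumerate(val_losses):
        lrs.append(lr)
        if loss < best:
            # Improvement: remember the new best and reset both counters
            best, wait_lr, wait_stop = loss, 0, 0
            continue
        wait_lr += 1
        wait_stop += 1
        if wait_stop >= patience:
            return lrs, epoch
        if wait_lr >= reduce_lr_patience:
            lr = max(lr * factor, min_lr)
            wait_lr = 0
    return lrs, None
```

For example, with `reduce_lr_patience=2` and `patience=4`, a loss that stops improving after epoch 1 halves the rate once at epoch 3 and triggers a stop at epoch 5.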

In [None]:
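# 1. Load VAE Model
from pathlib import Path

checkpoint_dir = "E:/Research/Physics-informed-machine-learning/flow_field_recon_parc/checkpoints/fourierTrue_percepTrue_edge_8"

checkpoint_path = Path(checkpoint_dir)
vae_checkpoint_path = checkpoint_path / f"checkpoint_{config['model_name']}_vae_best"
print(f"\n📁 Looking for VAE checkpoint: {vae_checkpoint_path}")

# A TF-format weights checkpoint is stored as "<prefix>.index" plus "<prefix>.data-*"
# files, so the bare prefix may not exist on disk; check for the .index file too
vae_checkpoint_exists = (vae_checkpoint_path.exists()
                         or Path(f"{vae_checkpoint_path}.index").exists())

if vae_checkpoint_exists: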
    print("✅ VAE checkpoint found, creating model...")
    
    # Create VAE model architecture
    vae_model = models_improved.FLRVAE(
        input_shape=config['input_shape'],
        latent_dims=config['latent_dims'],
        n_base_features=config['n_base_features'],
        use_fourier=config['use_fourier'],
        use_perceptual_loss=config['use_perceptual_loss']
    )
    
    # Build the model by calling it once with dummy input
    dummy_input = tf.zeros((1,) + tuple(config['input_shape']))
    if config['use_fourier']:
        dummy_coord = tf.zeros((1, config['input_shape'][0], config['input_shape'][1], 2))
        _ = vae_model([dummy_input, dummy_coord])
        print("🌊 VAE model built for Fourier features")
    else:
        _ = vae_model(dummy_input)
        print("🔄 VAE model built for standard features")
    
    # Load weights
    vae_model.load_weights(str(vae_checkpoint_path))
    print(f"✅ VAE weights loaded successfully!")
    
else:
    print(f"❌ VAE checkpoint not found at: {vae_checkpoint_path}")
    vae_model = None

## VAE Test Visualization

Comprehensive visualization of the VAE model performance, with detailed comparisons between original and reconstructed flow fields. The cell runs the VAE on whichever input format the test dataset provides: field plus coordinate grid for the Fourier-aware VAE, or field only for the standard VAE. This validates that the Fourier features are integrated correctly.
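For context on what "Fourier features" means here, a common form is the random Fourier feature mapping. It sends each coordinate vector v to `[cos(2π vB), sin(2π vB)]` using a fixed random Gaussian matrix `B`. The project's `fourier_layer`, whose Conv2D `proj_kernel` acts on the two coordinate channels, may implement a learned or differently scaled variant. The NumPy sketch below shows only the generic mapping, and `num_features`, `scale`, and `seed` are illustrative choices. Because a 1×1 convolution over the channel axis is a per-pixel matrix multiply, `coords @ B` is the same projection such a layer would compute.

```python
import numpy as np


def fourier_features(coords, num_features=16, scale=10.0, seed=0):
    """Map a (..., 2) coordinate array to (..., 2 * num_features) Fourier features.

    Generic random-Fourier-feature sketch: B ~ N(0, scale^2) is fixed, and
    gamma(v) = [cos(2*pi*v@B), sin(2*pi*v@B)] is computed at every pixel.
    """
    rng = np.random.default_rng(seed)
    B = rng.normal(0.0, scale, size=(coords.shape[-1], num_features))
    proj = 2.0 * np.pi * coords @ B  # (..., num_features), per-pixel projection
    return np.concatenate([np.cos(proj), np.sin(proj)], axis=-1)
```

Larger `scale` values produce higher spatial frequencies, which helps a network represent fine flow structures but can also introduce high-frequency artifacts.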

In [None]:
# Comprehensive VAE Test Visualization
import matplotlib.pyplot as plt
import numpy as np

if trainer.vae_model is not None:
    print("🎨 === VAE Test Visualization ===")
    print(f"Model type: {'Fourier-aware' if trainer.vae_model.use_fourier else 'Standard'} VAE")
    
    # Get a test batch directly from the formatted VAE test dataset
    test_batch = next(iter(vae_test_dataset.take(1)))
    
    # Extract inputs and targets
    if isinstance(test_batch, tuple) and len(test_batch) == 2:
        test_inputs, test_targets = test_batch
        
        # Handle coordinate-aware inputs for Fourier VAE
        if isinstance(test_inputs, (list, tuple)) and len(test_inputs) == 2:
            test_fields, test_coords = test_inputs
            print(f"✅ Fourier VAE inputs detected:")
            print(f"   Field shape: {test_fields.shape}")
            print(f"   Coordinate shape: {test_coords.shape}")
            input_for_prediction = [test_fields, test_coords]
        else:
            test_fields = test_inputs
            print(f"✅ Standard VAE input: {test_fields.shape}")
            input_for_prediction = test_fields
        
        print(f"   Target shape: {test_targets.shape}")
        
        # Generate VAE reconstructions
        print("\n🔮 Generating VAE reconstructions...")
        try:
            # Limit to manageable batch size for visualization
            max_samples = min(4, test_fields.shape[0])
            
            if isinstance(input_for_prediction, list):
                limited_input = [test_fields[:max_samples], test_coords[:max_samples]]
            else:
                limited_input = test_fields[:max_samples]
            limited_targets = test_targets[:max_samples]
            
            reconstructions = trainer.vae_model.predict(limited_input, verbose=0)
            print(f"✅ Reconstruction successful! Shape: {reconstructions.shape}")
            
            # Calculate reconstruction metrics
            mse = np.mean((limited_targets.numpy() - reconstructions) ** 2)
            mae = np.mean(np.abs(limited_targets.numpy() - reconstructions))
            max_error = np.max(np.abs(limited_targets.numpy() - reconstructions))
            
            print(f"\n📊 VAE Reconstruction Metrics:")
            print(f"   MSE: {mse:.6f} ({'Excellent' if mse < 0.01 else 'Good' if mse < 0.05 else 'Needs improvement'})")
            print(f"   MAE: {mae:.6f}")
            print(f"   Max Error: {max_error:.6f}")
            
            # Create comprehensive visualization
            n_samples = max_samples
            fig, axes = plt.subplots(3, n_samples, figsize=(5*n_samples, 12))
            
            if n_samples == 1:
                axes = axes.reshape(-1, 1)
            
            for i in range(n_samples):
                # Original field
                im1 = axes[0, i].imshow(limited_targets[i, :, :, 0], cmap='RdBu_r', origin='lower')
                axes[0, i].set_title(f'Original Field {i+1}', fontweight='bold')
                axes[0, i].set_xlabel('X Position')
                axes[0, i].set_ylabel('Y Position')
                plt.colorbar(im1, ax=axes[0, i], shrink=0.8)
                
                # Reconstructed field
                im2 = axes[1, i].imshow(reconstructions[i, :, :, 0], cmap='RdBu_r', origin='lower')
                axes[1, i].set_title(f'VAE Reconstruction {i+1}', fontweight='bold')
                axes[1, i].set_xlabel('X Position')
                axes[1, i].set_ylabel('Y Position')
                plt.colorbar(im2, ax=axes[1, i], shrink=0.8)
                
                # Error map
                error = np.abs(limited_targets[i, :, :, 0] - reconstructions[i, :, :, 0])
                im3 = axes[2, i].imshow(error, cmap='hot', origin='lower')
                axes[2, i].set_title(f'Reconstruction Error {i+1}', fontweight='bold')
                axes[2, i].set_xlabel('X Position')
                axes[2, i].set_ylabel('Y Position')
                plt.colorbar(im3, ax=axes[2, i], shrink=0.8)
                
                # Add error statistics as text
                sample_mse = np.mean(error**2)
                sample_max = np.max(error)
                axes[2, i].text(0.02, 0.98, f'MSE: {sample_mse:.4f}\nMax: {sample_max:.4f}', 
                               transform=axes[2, i].transAxes, verticalalignment='top',
                               bbox=dict(boxstyle='round', facecolor='white', alpha=0.8),
                               fontsize=9, fontweight='bold')
            
            plt.tight_layout()
            plt.suptitle(f'VAE Reconstruction Results - {config_name}', y=1.02, fontsize=16, fontweight='bold')
            plt.show()
            
            # Statistical analysis
            print(f"\n📈 Statistical Comparison:")
            orig_stats = {
                'mean': np.mean(limited_targets.numpy()),
                'std': np.std(limited_targets.numpy()),
                'min': np.min(limited_targets.numpy()),
                'max': np.max(limited_targets.numpy())
            }
            
            recon_stats = {
                'mean': np.mean(reconstructions),
                'std': np.std(reconstructions),
                'min': np.min(reconstructions),
                'max': np.max(reconstructions)
            }
            
            print(f"   Original  - Mean: {orig_stats['mean']:.4f}, Std: {orig_stats['std']:.4f}, Range: [{orig_stats['min']:.4f}, {orig_stats['max']:.4f}]")
            print(f"   Reconstructed - Mean: {recon_stats['mean']:.4f}, Std: {recon_stats['std']:.4f}, Range: [{recon_stats['min']:.4f}, {recon_stats['max']:.4f}]")
            
            # Distribution comparison
            plt.figure(figsize=(15, 5))
            
            # Value distributions
            plt.subplot(1, 3, 1)
            plt.hist(limited_targets.numpy().flatten(), bins=50, alpha=0.7, label='Original', density=True, color='blue')
            plt.hist(reconstructions.flatten(), bins=50, alpha=0.7, label='Reconstructed', density=True, color='red')
            plt.xlabel('Field Value')
            plt.ylabel('Density')
            plt.title('Value Distribution Comparison', fontweight='bold')
            plt.legend()
            plt.grid(True, alpha=0.3)
            
            # Error distribution
            plt.subplot(1, 3, 2)
            errors = np.abs(limited_targets.numpy() - reconstructions).flatten()
            plt.hist(errors, bins=50, alpha=0.7, color='red', edgecolor='black')
            plt.xlabel('Absolute Error')
            plt.ylabel('Frequency')
            plt.title('Reconstruction Error Distribution', fontweight='bold')
            plt.axvline(np.mean(errors), color='darkred', linestyle='--', linewidth=2, label=f'Mean: {np.mean(errors):.4f}')
            plt.legend()
            plt.grid(True, alpha=0.3)
            
            # Relative error distribution
            plt.subplot(1, 3, 3)
            relative_errors = errors / (np.abs(limited_targets.numpy().flatten()) + 1e-8)
            plt.hist(relative_errors, bins=50, alpha=0.7, color='orange', edgecolor='black')
            plt.xlabel('Relative Error')
            plt.ylabel('Frequency')
            plt.title('Relative Error Distribution', fontweight='bold')
            plt.axvline(np.mean(relative_errors), color='darkorange', linestyle='--', linewidth=2, label=f'Mean: {np.mean(relative_errors):.4f}')
            plt.legend()
            plt.grid(True, alpha=0.3)
            
            plt.tight_layout()
            plt.suptitle('VAE Performance Analysis', y=1.02, fontsize=14, fontweight='bold')
            plt.show()
            
            # Fourier feature analysis (if applicable)
            if trainer.vae_model.use_fourier and isinstance(test_inputs, (list, tuple)):
                print(f"\n🌊 Fourier Feature Analysis:")
                print(f"   Coordinate grid statistics:")
                print(f"   X coordinates - Range: [{np.min(test_coords[0, :, :, 0]):.3f}, {np.max(test_coords[0, :, :, 0]):.3f}]")
                print(f"   Y coordinates - Range: [{np.min(test_coords[0, :, :, 1]):.3f}, {np.max(test_coords[0, :, :, 1]):.3f}]")
                
                # Show coordinate grids for first sample
                fig, axes = plt.subplots(1, 2, figsize=(12, 5))
                
                im1 = axes[0].imshow(test_coords[0, :, :, 0], cmap='viridis', origin='lower')
                axes[0].set_title('X Coordinate Grid', fontweight='bold')
                axes[0].set_xlabel('X Position')
                axes[0].set_ylabel('Y Position')
                plt.colorbar(im1, ax=axes[0])
                
                im2 = axes[1].imshow(test_coords[0, :, :, 1], cmap='plasma', origin='lower')
                axes[1].set_title('Y Coordinate Grid', fontweight='bold')
                axes[1].set_xlabel('X Position')
                axes[1].set_ylabel('Y Position')
                plt.colorbar(im2, ax=axes[1])
                
                plt.tight_layout()
                plt.suptitle('Fourier Feature Coordinate Grids', y=1.02, fontsize=14, fontweight='bold')
                plt.show()
            
            print(f"\n🎉 VAE visualization completed successfully!")
            print(f"   📊 Reconstruction Quality: {'🏆 Excellent' if mse < 0.01 else '✅ Good' if mse < 0.05 else '⚠️ Needs improvement'}")
            print(f"   🌊 Fourier Features: {'✅ Working correctly' if trainer.vae_model.use_fourier else '➖ Not used'}")
            
        except Exception as e:
            print(f"❌ VAE prediction failed: {e}")
            import traceback
            traceback.print_exc()
    
    else:
        print(f"❌ Unexpected test batch structure: {type(test_batch)}")

else:
    print("❌ No VAE model available for visualization")
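The coordinate grids plotted above are the input to a Fourier feature embedding. As an illustration only, here is a dependency-free sketch of the common random Fourier feature mapping γ(v) = [cos(2πBv), sin(2πBv)], where each row of the frequency matrix B is drawn from a Gaussian. The [0, 1] grid normalization, `n_freqs`, `sigma`, and the helper names are assumptions. The notebook's own Fourier layer (toggled by `use_fourier`) may use a different frequency scheme.

```python
import math
import random


def make_coordinate_grid(height, width):
    """Return a height x width grid of (x, y) pairs normalized to [0, 1]."""
    return [
        [(j / (width - 1), i / (height - 1)) for j in range(width)]
        for i in range(height)
    ]


def sample_frequencies(n_freqs, sigma=1.0, seed=0):
    """Draw a frequency matrix B (n_freqs rows of (bx, by)) from N(0, sigma^2)."""
    rng = random.Random(seed)
    return [(rng.gauss(0.0, sigma), rng.gauss(0.0, sigma)) for _ in range(n_freqs)]


def fourier_features(xy, B):
    """Map one (x, y) point to [cos(2*pi*B v), sin(2*pi*B v)], length 2 * len(B)."""
    x, y = xy
    projections = [2.0 * math.pi * (bx * x + by * y) for bx, by in B]
    return [math.cos(p) for p in projections] + [math.sin(p) for p in projections]


# Hypothetical usage: embed every point of a small 4 x 5 grid with 8 frequencies.
grid = make_coordinate_grid(height=4, width=5)
B = sample_frequencies(n_freqs=8, sigma=2.0)
embedded = [[fourier_features(point, B) for point in row] for row in grid]
print(len(embedded), len(embedded[0]), len(embedded[0][0]))  # 4 5 16
```

A larger `sigma` pushes B toward higher frequencies. This lets a network represent finer spatial detail, at the risk of noisier reconstructions, so it is usually tuned per dataset.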

## Load Trained Models from Checkpoint

Now let's load the trained models from saved checkpoints using the FLRTrainer's built-in methods. The trainer provides several loading options:

- **`load_vae_from_checkpoint()`**: Load only the VAE model
- **`load_flrnet_from_checkpoint()`**: Load only the FLRNet model  
- **`load_models_from_checkpoint()`**: Load both models at once

### Features:
- **Automatic checkpoint detection**: Searches for checkpoints in priority order (best → last → final_weights)
- **Robust path handling**: Passes the TensorFlow checkpoint prefix, without the `.index` / `.data-*` suffixes
- **Error handling**: Graceful fallback and clear error messages
- **Architecture consistency**: Ensures loaded model matches trainer configuration
- **Smart dependencies**: FLRNet loading automatically handles VAE dependencies
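The automatic detection described above can be pictured as a priority-ordered search over TensorFlow checkpoint prefixes. The sketch below uses a hypothetical helper and does not reproduce the `FLRTrainer` implementation. It assumes each checkpoint is written as a `<prefix>.index` file plus `<prefix>.data-*` shards, and that the prefix name contains `best`, `last`, or `final_weights`. The real file names in `config['checkpoint_dir']` may differ. The helper returns the bare prefix, which is the form `tf.train.Checkpoint.restore` expects.

```python
import tempfile
from pathlib import Path
from typing import Optional, Sequence


def find_checkpoint_prefix(
    checkpoint_dir,
    priority: Sequence[str] = ("best", "last", "final_weights"),
) -> Optional[str]:
    """Return the first TF checkpoint prefix matching `priority`, or None.

    Hypothetical helper: assumes each checkpoint is a `<prefix>.index` file plus
    `<prefix>.data-*` shards, and that the prefix name contains the priority tag.
    """
    directory = Path(checkpoint_dir)
    if not directory.is_dir():
        return None
    index_files = sorted(directory.glob("*.index"))
    for tag in priority:
        for index_file in index_files:
            if tag in index_file.stem:
                # Drop only the ".index" suffix; TF restores from the bare prefix.
                return str(index_file.with_suffix(""))
    return None


# Usage: with no "best" checkpoint present, the "last" one is chosen.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("vae_last.index", "vae_final_weights.index"):
        (Path(tmp) / name).touch()
    print(Path(find_checkpoint_prefix(tmp)).name)  # vae_last
```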

In [None]:
import os
import glob
from pathlib import Path
import tensorflow as tf

# Load the VAE model using the trainer's built-in method
print("🚀 Loading trained VAE model from checkpoint using FLRTrainer...")
print("=" * 60)

# Point the loader at the checkpoint directory, not at an individual checkpoint file
checkpoint_directory = Path(config['checkpoint_dir'])
print(f"📂 Using checkpoint directory: {checkpoint_directory}")

# List available checkpoints for debugging
if checkpoint_directory.exists():
    print(f"📋 Available checkpoint files:")
    for file in sorted(checkpoint_directory.iterdir()):
        print(f"   - {file.name}")
else:
    print(f"❌ Checkpoint directory not found: {checkpoint_directory}")

# Use the trainer's built-in checkpoint loading method
vae_model = trainer.load_vae_from_checkpoint(
    checkpoint_dir=checkpoint_directory,
    latent_dims=config['latent_dims'],
    n_base_features=config['n_base_features'],
    use_perceptual_loss=config['use_perceptual_loss'],
    verbose=True
)

if vae_model is not None:
    print("\n🎉 VAE model loaded successfully!")
    print("📈 Model is ready for inference and visualization")
    
    # Verify the model works with a test prediction
    try:
        # Get a small batch for testing
        test_batch = next(iter(vae_test_dataset.take(1)))
        test_inputs, test_targets = test_batch
        
        # Handle different input formats (Fourier vs Standard VAE)
        if isinstance(test_inputs, (list, tuple)) and len(test_inputs) == 2:
            # Fourier VAE expects [field, coordinates]
            test_field, test_coord = test_inputs
            test_prediction = vae_model([test_field[:1], test_coord[:1]])
            print(f"✅ Fourier VAE verification successful!")
            print(f"   - Field input shape: {test_field[:1].shape}")
            print(f"   - Coordinate input shape: {test_coord[:1].shape}")
        else:
            # Standard VAE expects just field data
            test_prediction = vae_model(test_inputs[:1])
            print(f"✅ Standard VAE verification successful!")
            print(f"   - Input shape: {test_inputs[:1].shape}")
        
        print(f"   - Output shape: {test_prediction.shape}")
        print(f"   - Output range: [{test_prediction.numpy().min():.4f}, {test_prediction.numpy().max():.4f}]")
        
    except Exception as e:
        print(f"⚠️  Model verification failed: {str(e)}")
        print("Model loaded but may have compatibility issues")
        import traceback
        traceback.print_exc()
        
else:
    print("\n❌ Failed to load VAE model")
    print("🔧 Troubleshooting steps:")
    print("   1. Check if training completed successfully")
    print("   2. Verify checkpoint files exist in the directory")
    print("   3. Ensure model configuration matches training setup")
    print("   4. Try running the training cell again if checkpoints are missing")

## 🔄 Training Continuation with Advanced Options

### What is Training Continuation?
- **Loading for Inference**: Loads model weights for evaluation/visualization only
- **Training Continuation**: Loads model weights AND continues training with preserved state
- **Key Benefit**: Preserves trained weights instead of resetting to random initialization

### 🎛️ Perceptual Loss Control
You can now **override the perceptual loss setting** during training continuation:

**For VAE Continuation:**
- `use_perceptual_loss_continuation = True`: Enable perceptual loss (better visual quality)
- `use_perceptual_loss_continuation = False`: Disable perceptual loss (faster training, MSE only)

**For FLRNet Continuation:**
- `use_perceptual_loss_flrnet_continuation = True`: Enable perceptual loss in VAE component
- `use_perceptual_loss_flrnet_continuation = False`: Disable perceptual loss in VAE component

### 🔄 Continuation vs New Training
| Method | Model Weights | Optimizer State | Use Case |
|--------|--------------|----------------|----------|
| `train_vae()` | ❌ Reset to random | ❌ New optimizer | Fresh training |
| `continue_vae_training()` | ✅ Preserved | ❌ New optimizer | Fine-tuning/Resume |

### ⚠️ Important Notes
- **Perceptual Loss Override**: When changing perceptual loss settings, the system automatically initializes/removes metric trackers
- **Metric Compatibility**: The system ensures metric trackers match the current perceptual loss setting
- **Safe Overrides**: You can safely change perceptual loss settings between continuation sessions
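One way to picture the metric-tracker bookkeeping described in these notes is a helper that adds or drops a perceptual-loss tracker, so the tracked metrics always match the active loss terms. This is a hypothetical sketch, not the FLRTrainer code. The tracker names and the dict storage are assumptions, and `MeanTracker` is a plain-Python stand-in for `tf.keras.metrics.Mean`.

```python
class MeanTracker:
    """Running mean of a scalar; a stand-in for tf.keras.metrics.Mean."""

    def __init__(self, name):
        self.name = name
        self.total = 0.0
        self.count = 0

    def update_state(self, value):
        self.total += value
        self.count += 1

    def result(self):
        return self.total / self.count if self.count else 0.0


# Assumed tracker names; the real VAE may use different ones.
BASE_TRACKERS = ("total_loss", "reconstruction_loss", "kl_loss")


def sync_metric_trackers(trackers, use_perceptual_loss):
    """Add or drop the perceptual-loss tracker so it matches the current setting.

    Existing trackers are kept as-is, so their accumulated state is not reset.
    """
    for name in BASE_TRACKERS:
        if name not in trackers:
            trackers[name] = MeanTracker(name)
    if use_perceptual_loss:
        if "perceptual_loss" not in trackers:
            trackers["perceptual_loss"] = MeanTracker("perceptual_loss")
    else:
        trackers.pop("perceptual_loss", None)
    return trackers


# Usage: turn perceptual loss on, then off again between continuation sessions.
trackers = sync_metric_trackers({}, use_perceptual_loss=True)
print(sorted(trackers))
trackers = sync_metric_trackers(trackers, use_perceptual_loss=False)
print(sorted(trackers))  # ['kl_loss', 'reconstruction_loss', 'total_loss']
```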

**Important**: Continuation methods preserve model weights but create fresh optimizers (standard practice for fine-tuning).
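To make "fresh optimizer" concrete: Adam keeps first- and second-moment estimates plus a step counter, and a newly created optimizer starts all three at zero. Because of Adam's bias correction, the first update after a restart has a magnitude close to the learning rate, whatever the gradient scale. This is one reason the continuation cells below use a reduced learning rate. The pure-Python sketch implements textbook Adam for one scalar parameter with β₁ = 0.9, β₂ = 0.999 and ε = 1e-7 (the Keras default values). It illustrates the algorithm, not FLRTrainer internals.

```python
import math


class ScalarAdam:
    """Textbook Adam for a single scalar parameter (illustration only)."""

    def __init__(self, learning_rate=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-7):
        self.lr = learning_rate
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon
        self.m = 0.0  # first-moment (mean) estimate, starts at zero
        self.v = 0.0  # second-moment (uncentered variance) estimate, starts at zero
        self.t = 0    # step counter used for bias correction

    def step(self, param, grad):
        """Apply one Adam update and return the new parameter value."""
        self.t += 1
        self.m = self.beta_1 * self.m + (1 - self.beta_1) * grad
        self.v = self.beta_2 * self.v + (1 - self.beta_2) * grad * grad
        m_hat = self.m / (1 - self.beta_1 ** self.t)  # bias-corrected moments
        v_hat = self.v / (1 - self.beta_2 ** self.t)
        return param - self.lr * m_hat / (math.sqrt(v_hat) + self.epsilon)


# A fresh optimizer's first step moves the weight by about learning_rate,
# whether the gradient is tiny or large.
for grad in (1e-3, 10.0):
    fresh = ScalarAdam(learning_rate=1e-4)
    delta = fresh.step(param=0.5, grad=grad) - 0.5
    print(f"grad={grad:g}  first update={delta:.3e}")
```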

In [None]:
# 🔍 Pre-Continuation Weight Preservation Verification
print("🔍 === Model Weight Preservation Check ===")

if vae_model is not None:
    print("Testing if model weights are preserved during continuation setup...")
    
    # Get a test sample to establish baseline
    test_batch = next(iter(vae_test_dataset.take(1)))
    test_inputs, test_targets = test_batch
    
    if isinstance(test_inputs, tuple):
        test_field, test_coord = test_inputs
        test_input = [test_field[:1], test_coord[:1]]
        input_format = "Fourier (field + coordinates)"
    else:
        test_input = test_inputs[:1]
        input_format = "Standard (field only)"
    
    print(f"📊 Input format: {input_format}")
    
    # Test prediction BEFORE any training operations
    print("\n📊 BEFORE continuation setup:")
    prediction_before = vae_model.predict(test_input, verbose=0)
    mse_before = np.mean((test_targets[:1].numpy() - prediction_before) ** 2)
    prediction_range_before = [prediction_before.min(), prediction_before.max()]
    
    print(f"   MSE: {mse_before:.6f}")
    print(f"   Prediction range: [{prediction_range_before[0]:.4f}, {prediction_range_before[1]:.4f}]")
    
    # Store some layer weights for comparison
    try:
        layer_weights_before = []
        weight_layer_names = []
        for i, layer in enumerate(vae_model.layers):
            if hasattr(layer, 'get_weights') and layer.get_weights():
                weights = layer.get_weights()
                if len(weights) > 0 and weights[0].size > 0:
                    layer_weights_before.append(weights[0].copy())
                    weight_layer_names.append(f"Layer_{i}_{layer.name}")
                    if len(layer_weights_before) >= 3:  # Just store first 3 layers with weights
                        break
        
        print(f"   Stored weights from {len(layer_weights_before)} layers for comparison")
        
        # Rough sanity check on the baseline MSE (thresholds assume fields normalized to roughly unit scale)
        if mse_before > 1.0:
            print("❌ WARNING: High MSE detected! Model may have been reset or not properly loaded!")
            print("   Expected: Low MSE if weights were preserved from checkpoint")
            print("   Actual: High MSE suggests random/reset weights")
            print(f"   Recommendation: Check checkpoint loading process")
        elif mse_before < 0.1:
            print("✅ EXCELLENT: Very low MSE suggests weights were properly loaded from checkpoint")
        else:
            print("✅ GOOD: Reasonable MSE suggests weights were loaded correctly")
        
        # Store baseline for later comparison
        baseline_mse = mse_before
        baseline_weights = layer_weights_before
        baseline_prediction_range = prediction_range_before
        
    except Exception as e:
        print(f"⚠️  Could not extract weights for comparison: {e}")
        baseline_mse = mse_before
        baseline_weights = None
        baseline_prediction_range = prediction_range_before

else:
    print("❌ No VAE model available for testing")
    baseline_mse = None
    baseline_weights = None
    baseline_prediction_range = None

In [None]:
# Continue VAE Training from Checkpoint
# Set this to True if you want to continue training from the loaded checkpoint
continue_vae_training = True  # Set to False to skip continued training

# Toggle for perceptual loss during continuation (can be different from original training)
use_perceptual_loss_continuation = True  # Change to False to disable perceptual loss

if continue_vae_training and vae_model is not None:
    print("🔄 Continuing VAE training from loaded checkpoint...")
    print("=" * 60)
    
    # Configuration for continued training
    additional_epochs = 100
    learning_rate = 1e-4  # Lower learning rate for fine-tuning
    
    print(f"📋 Continuation Configuration:")
    print(f"   - Additional epochs: {additional_epochs}")
    print(f"   - Learning rate: {learning_rate} (reduced for fine-tuning)")
    print(f"   - Perceptual loss: {'✅ ENABLED' if use_perceptual_loss_continuation else '❌ DISABLED'}")
    print(f"   - Model weights: ✅ PRESERVED from checkpoint")
    
    # Continue training through the trainer (recommended). trainer.continue_vae_training()
    # keeps the loaded weights but creates a fresh optimizer (see the notes above).
    print("\n🚀 Continuing training with FLRTrainer (recommended)")
    
    # Set the loaded model in the trainer
    trainer.vae_model = vae_model
    
    # Continue training with reduced learning rate
    continued_vae = trainer.continue_vae_training(
        train_dataset=vae_train_dataset,
        val_dataset=vae_test_dataset,
        epochs=additional_epochs,
        learning_rate=learning_rate,
        patience=config['patience'],
        reduce_lr_patience=config['reduce_lr_patience'],
        use_perceptual_loss=use_perceptual_loss_continuation  # New option for perceptual loss
    )
    
    print("✅ VAE training continuation completed!")
    
elif continue_vae_training and vae_model is None:
    print("❌ Cannot continue VAE training: No model loaded from checkpoint")
    print("   First load a model using the checkpoint loading cell above")
    
else:
    print("ℹ️  VAE training continuation is disabled")
    print("   Set continue_vae_training = True to enable")
    if not continue_vae_training:
        print(f"   Current perceptual loss setting: {'✅ ENABLED' if use_perceptual_loss_continuation else '❌ DISABLED'}")
        print("   (can be changed via use_perceptual_loss_continuation variable)")
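Judging by their names, the `patience` and `reduce_lr_patience` values passed to `trainer.continue_vae_training()` are early-stopping and learning-rate-plateau settings. Assuming they follow Keras `EarlyStopping` / `ReduceLROnPlateau` semantics (not confirmed by this notebook), the pure-Python simulation below shows how a validation-loss curve drives both. The function name and the `factor=0.5` and `min_lr=1e-7` values are hypothetical.

```python
def simulate_plateau_schedule(val_losses, learning_rate, patience, reduce_lr_patience,
                              factor=0.5, min_lr=1e-7):
    """Simulate early stopping plus learning-rate reduction on a val-loss curve.

    Simplified imitation of Keras EarlyStopping and ReduceLROnPlateau with
    mode="min", min_delta=0 and no cooldown. Returns (lr_per_epoch, stop_epoch),
    where stop_epoch is None if training runs through every epoch.
    """
    best = float("inf")
    epochs_since_best = 0     # drives early stopping
    epochs_since_change = 0   # drives LR reduction; reset after each cut
    lr_per_epoch = []
    for epoch, loss in enumerate(val_losses, start=1):
        lr_per_epoch.append(learning_rate)  # LR used during this epoch
        if loss < best:
            best = loss
            epochs_since_best = 0
            epochs_since_change = 0
            continue
        epochs_since_best += 1
        epochs_since_change += 1
        if epochs_since_change >= reduce_lr_patience:
            learning_rate = max(learning_rate * factor, min_lr)  # applies next epoch
            epochs_since_change = 0
        if epochs_since_best >= patience:
            return lr_per_epoch, epoch
    return lr_per_epoch, None


# Hypothetical curve: improves for 3 epochs, then plateaus.
losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.73, 0.74, 0.75, 0.76]
lrs, stop_epoch = simulate_plateau_schedule(losses, learning_rate=1e-4,
                                            patience=5, reduce_lr_patience=2)
print(stop_epoch)  # 8
print(lrs)
```

With `reduce_lr_patience` smaller than `patience`, the learning rate is cut one or more times before training stops. This ordering is what gives a plateaued run a chance to recover.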

In [None]:
# 🔍 Post-Continuation Weight Verification
print("🔍 === Post-Continuation Weight Verification ===")

if 'baseline_mse' in locals() and baseline_mse is not None and trainer.vae_model is not None:
    print("Verifying that model weights were properly preserved during training continuation...")
    
    # Re-fetch the first test batch; this matches the baseline sample only if vae_test_dataset is not shuffled
    test_batch = next(iter(vae_test_dataset.take(1)))
    test_inputs, test_targets = test_batch
    
    if isinstance(test_inputs, tuple):
        test_field, test_coord = test_inputs
        test_input = [test_field[:1], test_coord[:1]]
    else:
        test_input = test_inputs[:1]
    
    # Test prediction AFTER continuation
    print("\n📊 AFTER continuation:")
    prediction_after = trainer.vae_model.predict(test_input, verbose=0)
    mse_after = np.mean((test_targets[:1].numpy() - prediction_after) ** 2)
    prediction_range_after = [prediction_after.min(), prediction_after.max()]
    
    print(f"   MSE: {mse_after:.6f}")
    print(f"   Prediction range: [{prediction_range_after[0]:.4f}, {prediction_range_after[1]:.4f}]")
    
    # Compare with baseline
    print(f"\n📈 Comparison with baseline:")
    print(f"   Baseline MSE: {baseline_mse:.6f}")
    print(f"   After MSE: {mse_after:.6f}")
    print(f"   MSE change: {mse_after - baseline_mse:.6f}")
    
    # Analyze results
    if mse_after < baseline_mse:
        improvement = ((baseline_mse - mse_after) / baseline_mse) * 100
        print(f"✅ EXCELLENT: Model improved by {improvement:.2f}%!")
        print("   Training continuation was successful and preserved weights correctly")
    elif abs(mse_after - baseline_mse) < 0.01:
        print("✅ GOOD: Model performance maintained (minimal change)")
        print("   Weights were preserved correctly during continuation")
    elif mse_after > baseline_mse * 2:
        print("❌ WARNING: Model performance degraded significantly!")
        print("   This may indicate:")
        print("   - Learning rate too high")
        print("   - Training instability")
        print("   - Possible weight corruption")
    else:
        print("ℹ️  Model performance changed moderately")
        print("   This is normal for continued training")
    
    # Check for weight consistency if we stored baseline weights
    if 'baseline_weights' in locals() and baseline_weights is not None:
        try:
            print(f"\n🔍 Direct weight comparison:")
            current_weights = []
            for i, layer in enumerate(trainer.vae_model.layers):
                if hasattr(layer, 'get_weights') and layer.get_weights():
                    weights = layer.get_weights()
                    if len(weights) > 0 and weights[0].size > 0:
                        current_weights.append(weights[0].copy())
                        if len(current_weights) >= len(baseline_weights):
                            break
            
            weights_changed = False
            for i, (baseline_w, current_w) in enumerate(zip(baseline_weights, current_weights)):
                if baseline_w.shape == current_w.shape:
                    weight_diff = np.mean(np.abs(baseline_w - current_w))
                    print(f"   Layer {i} weight change: {weight_diff:.6f}")
                    if weight_diff > 0.001:  # Weights should change during training
                        weights_changed = True
            
            if weights_changed:
                print("✅ CONFIRMED: Weights changed during training (expected)")
            else:
                print("⚠️  WARNING: Weights didn't change much - training may not be effective")
                
        except Exception as e:
            print(f"⚠️  Could not compare weights directly: {e}")
    
    print(f"\n🎯 FINAL ASSESSMENT:")
    if mse_after <= baseline_mse + 0.01:
        print("✅ Training continuation was SUCCESSFUL")
        print("   - Weights were preserved from checkpoint")
        print("   - Training improved or maintained model performance")
        print("   - Model is ready for inference or further training")
    else:
        print("⚠️  Training continuation had MIXED RESULTS")
        print("   - Weights were preserved, but performance may have degraded")
        print("   - Consider adjusting learning rate or training parameters")

else:
    print("❌ Cannot perform verification - baseline not available")
    print("Run the weight preservation check cell before training continuation")
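The verdict logic in the cell above can be factored into a small pure function, which makes the thresholds explicit and easy to test. The sketch reproduces the cell's rules: any improvement, an absolute change below `tol = 0.01`, or a more-than-2× increase. The function names and return labels are hypothetical. These absolute thresholds only make sense relative to how the fields are normalized. The comparison is also like-for-like only if `vae_test_dataset` is not shuffled, so both cells see the same first batch. A relative-tolerance variant is included as one possible alternative.

```python
import math


def assess_continuation(baseline_mse, mse_after, tol=0.01, degrade_factor=2.0):
    """Classify a before/after MSE pair with the same rules as the cell above."""
    if mse_after < baseline_mse:
        return "improved"
    if abs(mse_after - baseline_mse) < tol:
        return "maintained"
    if mse_after > baseline_mse * degrade_factor:
        return "degraded"
    return "moderate_change"


def assess_continuation_relative(baseline_mse, mse_after, rel_tol=0.05):
    """Variant with a relative tolerance, independent of the MSE's absolute scale."""
    if math.isclose(mse_after, baseline_mse, rel_tol=rel_tol):
        return "maintained"
    return "improved" if mse_after < baseline_mse else "worse"


# With small MSE values, the absolute tolerance hides a 3x regression:
print(assess_continuation(1e-4, 3e-4))           # maintained
print(assess_continuation_relative(1e-4, 3e-4))  # worse
```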

In [None]:
# Continue FLRNet Training from Checkpoint
# Set this to True if you want to continue FLRNet training from checkpoint
continue_flrnet_training = False  # Change to True to enable

if continue_flrnet_training:
    print("🔄 Continuing FLRNet training from checkpoint...")
    print("=" * 60)
    
    # First, ensure we have a VAE model (required for FLRNet)
    if vae_model is None:
        print("🔄 Loading VAE model first (required for FLRNet)...")
        vae_model = trainer.load_vae_from_checkpoint(
            checkpoint_dir=checkpoint_directory,
            latent_dims=config['latent_dims'],
            n_base_features=config['n_base_features'],
            use_perceptual_loss=config['use_perceptual_loss'],
            verbose=False
        )
    
    if vae_model is not None:
        # Try to load existing FLRNet model
        print("🔄 Loading FLRNet model from checkpoint...")
        flrnet_model = trainer.load_flrnet_from_checkpoint(
            n_sensors=config['n_sensors'],
            checkpoint_dir=checkpoint_directory,
            pretrained_vae=vae_model,
            latent_dims=config['latent_dims'],
            n_base_features=config['n_base_features'],
            use_perceptual_loss=config['use_perceptual_loss'],
            verbose=True
        )
        
        if flrnet_model is not None:
            # Configure training parameters for continuation
            additional_epochs = 100  # How many more epochs to train
            learning_rate = 1e-6     # Very low learning rate for fine-tuning
            # Perceptual-loss override for the VAE component (defaults to the original setting)
            use_perceptual_loss_flrnet_continuation = config['use_perceptual_loss']
            
            print(f"\n📋 FLRNet Training Configuration:")
            print(f"   - Additional epochs: {additional_epochs}")
            print(f"   - Learning rate: {learning_rate}")
            print(f"   - Perceptual loss: {'✅ ENABLED' if use_perceptual_loss_flrnet_continuation else '❌ DISABLED'}")
            print(f"   - VAE model: ✅ Loaded")
            print(f"   - FLRNet model: ✅ Loaded")
            
            # Set the loaded models in the trainer
            trainer.flr_model = flrnet_model
            trainer.vae_model = vae_model
            
            # Caution: per the table above, train_*() methods may start from fresh weights.
            # Confirm that train_flr_net() reuses trainer.flr_model before relying on this
            # call to continue (rather than restart) FLRNet training.
            continued_flrnet = trainer.train_flr_net(
                train_dataset=flrnet_train_dataset,
                val_dataset=flrnet_test_dataset,
                n_sensors=config['n_sensors'],
                epochs=additional_epochs,
                learning_rate=learning_rate,
                pretrained_vae=vae_model,
                latent_dims=config['latent_dims'],
                n_base_features=config['n_base_features'],
                use_perceptual_loss=use_perceptual_loss_flrnet_continuation,
                patience=config['patience'],
                reduce_lr_patience=config['reduce_lr_patience']
            )
            
            print("✅ FLRNet training continuation completed!")
            
        else:
            print("❌ Could not load FLRNet model from checkpoint")
            print("You may need to train FLRNet from scratch first")
    else:
        print("❌ Could not load VAE model - required for FLRNet training")
        
else:
    print("ℹ️  FLRNet training continuation is disabled")
    print("Set continue_flrnet_training = True to enable")

In [None]:
# Advanced: Manual Checkpoint Restoration with Optimizer State
# This method provides more control over the restoration process
use_advanced_restoration = False  # Change to True to enable

if use_advanced_restoration:
    print("🔧 Advanced Checkpoint Restoration...")
    print("=" * 60)
    
    import tensorflow as tf
    from pathlib import Path
    
    # Configuration for restoration
    restore_epoch = 50  # Expected resume epoch (display only; the actual value is read from the checkpoint below)
    target_total_epochs = 150  # Total epochs you want to reach
    
    print(f"📋 Advanced Restoration Configuration:")
    print(f"   - Expected restore epoch: {restore_epoch}")
    print(f"   - Target total epochs: {target_total_epochs}")
    print(f"   - Additional epochs (if restored at epoch {restore_epoch}): {target_total_epochs - restore_epoch}")
    
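    # Use TensorFlow's built-in checkpoint manager. The manager only sees checkpoints
    # recorded in <checkpoint_dir>/tf_checkpoints (via that folder's `checkpoint` state
    # file), so it will not find weights saved elsewhere during regular training.
    def setup_checkpoint_manager(model, optimizer, checkpoint_dir):
        """Create a tf.train.Checkpoint (model, optimizer, epoch) and its CheckpointManager."""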
        
        # Create checkpoint object
        checkpoint = tf.train.Checkpoint(
            optimizer=optimizer,
            model=model,
            epoch=tf.Variable(0, dtype=tf.int64)
        )
        
        # Create checkpoint manager
        manager = tf.train.CheckpointManager(
            checkpoint,
            directory=str(checkpoint_dir / "tf_checkpoints"),
            max_to_keep=3
        )
        
        return checkpoint, manager
    
    # Example for VAE restoration
    if vae_model is not None:
        print("\n🔧 Setting up VAE checkpoint restoration...")
        
        # Create optimizer (same as training)
        vae_optimizer = tf.keras.optimizers.Adam(
            learning_rate=1e-5,  # Reduced for fine-tuning
            beta_1=0.9, 
            beta_2=0.999
        )
        
        # Compile model with optimizer
        vae_model.compile(optimizer=vae_optimizer)
        
        # Set up checkpoint manager
        vae_checkpoint, vae_manager = setup_checkpoint_manager(
            vae_model, vae_optimizer, checkpoint_directory
        )
        
        # Try to restore latest checkpoint
        if vae_manager.latest_checkpoint:
            vae_checkpoint.restore(vae_manager.latest_checkpoint)
            restored_epoch = int(vae_checkpoint.epoch.numpy())
            print(f"✅ Restored VAE from epoch {restored_epoch}")
            print(f"🎯 Will continue training from epoch {restored_epoch + 1}")
        else:
            print("⚠️  No TensorFlow checkpoint found, using loaded weights")
            restored_epoch = 0
        
        # Calculate remaining epochs
        remaining_epochs = max(0, target_total_epochs - restored_epoch)
        
        if remaining_epochs > 0:
            print(f"\n🚀 Continuing VAE training for {remaining_epochs} more epochs...")
            
            # Custom training loop with checkpoint saving
            for epoch in range(remaining_epochs):
                current_epoch = restored_epoch + epoch + 1
                print(f"\nEpoch {current_epoch}/{target_total_epochs}")
                
                # Train for one epoch
                history = vae_model.fit(
                    vae_train_dataset,
                    validation_data=vae_test_dataset,
                    epochs=1,
                    verbose=1
                )
                
                # Update epoch counter
                vae_checkpoint.epoch.assign(current_epoch)
                
                # Save checkpoint every 10 epochs
                if current_epoch % 10 == 0:
                    save_path = vae_manager.save()
                    print(f"💾 Saved checkpoint: {save_path}")
                
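                # Early stopping logic (optional). Default to +inf so that a missing
                # 'val_loss' key cannot trigger a spurious stop.
                val_loss = history.history.get('val_loss', [float('inf')])[-1]
                if val_loss < 0.001:  # Example threshold
                    print(f"🎯 Early stopping - validation loss {val_loss:.6f} is below threshold")
                    break
            
            # Save the final state so epochs after the last multiple of 10 are not lost
            if int(vae_checkpoint.epoch.numpy()) % 10 != 0:
                save_path = vae_manager.save()
                print(f"💾 Saved final checkpoint: {save_path}")
            
            print("✅ Advanced VAE training continuation completed!")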
        else:
            print("ℹ️  Target epochs already reached")
    
    else:
        print("❌ No VAE model loaded for advanced restoration")
        
else:
    print("ℹ️  Advanced checkpoint restoration is disabled")
    print("Set use_advanced_restoration = True to enable")
    print("When a checkpoint exists in <checkpoint_dir>/tf_checkpoints, this method provides:")
    print("   • Optimizer state restoration")
    print("   • Exact epoch continuation")
    print("   • Custom training loop control")
    print("   • Proper checkpoint management")
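The epoch bookkeeping in the advanced cell (resume at `restored_epoch + 1`, save every 10 epochs, stop at `target_total_epochs`) can be checked in isolation. The pure-Python sketch below computes which epochs would run and which would trigger a save. It always includes the final epoch in the save list, so the last few epochs are kept even when the total is not a multiple of `save_every`. The function name is hypothetical. As a design note, Keras `Model.fit` also accepts `initial_epoch`. A single `fit(..., initial_epoch=restored_epoch, epochs=target_total_epochs)` call with a checkpoint callback therefore avoids re-entering `fit` once per epoch.

```python
def plan_resumed_epochs(restored_epoch, target_total_epochs, save_every=10):
    """Return (epochs_to_run, epochs_to_save) for a resumed run (1-indexed epochs)."""
    epochs_to_run = list(range(restored_epoch + 1, target_total_epochs + 1))
    epochs_to_save = [e for e in epochs_to_run if e % save_every == 0]
    if epochs_to_run and epochs_to_run[-1] not in epochs_to_save:
        epochs_to_save.append(epochs_to_run[-1])  # always keep the final state
    return epochs_to_run, epochs_to_save


run, save = plan_resumed_epochs(restored_epoch=47, target_total_epochs=75)
print(run[0], run[-1], len(run))  # 48 75 28
print(save)                       # [50, 60, 70, 75]
```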

In [None]:
# Alternative: Load both VAE and FLRNet models at once using trainer
# Uncomment this section if you want to load both models together

# print("\n🔄 Alternative: Loading both VAE and FLRNet models...")
# vae_model, flrnet_model = trainer.load_models_from_checkpoint(
#     n_sensors=config['n_sensors'],
#     checkpoint_dir=checkpoint_directory,
#     latent_dims=config['latent_dims'],
#     n_base_features=config['n_base_features'],
#     use_perceptual_loss=config['use_perceptual_loss'],
#     verbose=True
# )

# if vae_model is not None and flrnet_model is not None:
#     print("🎉 Both models loaded successfully!")
#     print("📈 Models are ready for inference and visualization")
# elif vae_model is not None:
#     print("⚠️  Only VAE model loaded successfully")
# elif flrnet_model is not None:
#     print("⚠️  Only FLRNet model loaded successfully") 
# else:
#     print("❌ Failed to load both models")

# Example: Load FLRNet separately (uncomment if needed)
# print("\n🔄 Loading FLRNet model separately...")
# flrnet_model = trainer.load_flrnet_from_checkpoint(
#     n_sensors=config['n_sensors'],
#     checkpoint_dir=checkpoint_directory,
#     pretrained_vae=vae_model,  # Use the VAE we just loaded
#     latent_dims=config['latent_dims'],
#     n_base_features=config['n_base_features'],
#     use_perceptual_loss=config['use_perceptual_loss'],
#     verbose=True
# )

In [None]:
# Enhanced VAE Model Inference and Visualization
if vae_model is not None:
    print("🎨 Performing enhanced VAE inference and visualization...")
    print("=" * 60)
    
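    # Get multiple test samples for comprehensive evaluation.
    # vae_test_dataset is already batched (the cells above slice [:1] from its
    # elements), so unbatch first; calling .batch() directly would add an extra
    # leading dimension and break the model call.
    test_samples = 4
    test_batch = next(iter(vae_test_dataset.unbatch().batch(test_samples).take(1)))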
    test_inputs, test_targets = test_batch
    
    # Handle different input formats (Fourier vs Standard VAE)
    if isinstance(test_inputs, (list, tuple)) and len(test_inputs) == 2:
        # Fourier VAE expects [field, coordinates]
        test_fields, test_coords = test_inputs
        predictions = vae_model([test_fields, test_coords])
        print(f"🌊 Using Fourier VAE with coordinate inputs")
        print(f"   - Field input shape: {test_fields.shape}")
        print(f"   - Coordinate input shape: {test_coords.shape}")
    else:
        # Standard VAE expects just field data
        predictions = vae_model(test_inputs)
        test_fields = test_inputs
        print(f"🔄 Using Standard VAE")
        print(f"   - Input shape: {test_inputs.shape}")
    
    print(f"   - Output shape: {predictions.shape}")
    
    # Convert to numpy for analysis
    test_targets_np = test_targets.numpy()
    predictions_np = predictions.numpy()
    test_fields_np = test_fields.numpy()
    
    # Calculate comprehensive metrics
    mse_per_sample = np.mean((test_targets_np - predictions_np)**2, axis=(1,2,3))
    mae_per_sample = np.mean(np.abs(test_targets_np - predictions_np), axis=(1,2,3))
    
    overall_mse = np.mean(mse_per_sample)
    overall_mae = np.mean(mae_per_sample)
    
    print(f"📊 Model Performance Metrics:")
    print(f"   - Overall MSE: {overall_mse:.6f}")
    print(f"   - Overall MAE: {overall_mae:.6f}")
    print(f"   - Best sample MSE: {np.min(mse_per_sample):.6f}")
    print(f"   - Worst sample MSE: {np.max(mse_per_sample):.6f}")
    
    # Create comprehensive visualization
    fig, axes = plt.subplots(3, test_samples, figsize=(4*test_samples, 12))
    
    for i in range(test_samples):
        # Original field
        im1 = axes[0, i].imshow(test_targets_np[i, :, :, 0], cmap='RdBu_r', aspect='equal', origin='lower')
        axes[0, i].set_title(f'Original Field {i+1}', fontsize=12)
        axes[0, i].axis('off')
        plt.colorbar(im1, ax=axes[0, i], fraction=0.046, pad=0.04)
        
        # Reconstructed field
        im2 = axes[1, i].imshow(predictions_np[i, :, :, 0], cmap='RdBu_r', aspect='equal', origin='lower')
        axes[1, i].set_title(f'Reconstructed Field {i+1}', fontsize=12)
        axes[1, i].axis('off')
        plt.colorbar(im2, ax=axes[1, i], fraction=0.046, pad=0.04)
        
        # Error field
        error = np.abs(test_targets_np[i, :, :, 0] - predictions_np[i, :, :, 0])
        im3 = axes[2, i].imshow(error, cmap='hot', aspect='equal', origin='lower')
        axes[2, i].set_title(f'Error Field {i+1}\nMSE: {mse_per_sample[i]:.4f}', fontsize=12)
        axes[2, i].axis('off')
        plt.colorbar(im3, ax=axes[2, i], fraction=0.046, pad=0.04)
    
    # Add row labels
    axes[0, 0].text(-0.2, 0.5, 'Original', transform=axes[0, 0].transAxes, 
                    rotation=90, va='center', ha='center', fontsize=14, fontweight='bold')
    axes[1, 0].text(-0.2, 0.5, 'Reconstructed', transform=axes[1, 0].transAxes, 
                    rotation=90, va='center', ha='center', fontsize=14, fontweight='bold')
    axes[2, 0].text(-0.2, 0.5, 'Error', transform=axes[2, 0].transAxes, 
                    rotation=90, va='center', ha='center', fontsize=14, fontweight='bold')
    
    plt.suptitle(f'VAE Model Performance - {test_samples} Test Samples\nOverall MSE: {overall_mse:.6f}, MAE: {overall_mae:.6f}', 
                 fontsize=16, y=0.98)
    plt.tight_layout(rect=(0, 0, 1, 0.93))  # reserve the top strip for the two-line title
    plt.show()
    
    # Statistical analysis
    print(f"\n📈 Statistical Analysis:")
    print(f"   - Target field range: [{test_targets_np.min():.4f}, {test_targets_np.max():.4f}]")
    print(f"   - Prediction range: [{predictions_np.min():.4f}, {predictions_np.max():.4f}]")
    print(f"   - Prediction std: {predictions_np.std():.4f}")
    print(f"   - Target std: {test_targets_np.std():.4f}")
    
    # Correlation analysis
    correlation = np.corrcoef(test_targets_np.flatten(), predictions_np.flatten())[0, 1]
    print(f"   - Correlation coefficient: {correlation:.4f}")
    
    print(f"\n✅ VAE model evaluation complete!")
    
else:
    print("❌ Cannot perform inference - VAE model not loaded")
    print("Please run the model loading cell above first")
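For reference, the per-sample metrics and the correlation coefficient printed above reduce to the formulas below. This dependency-free sketch works on flattened lists. It mirrors `np.mean((t - p)**2)`, `np.mean(np.abs(t - p))`, and `np.corrcoef(t, p)[0, 1]` (the Pearson correlation). The helper name is hypothetical.

```python
import math


def reconstruction_metrics(target, prediction):
    """Return (mse, mae, pearson_r) for two equal-length sequences of floats."""
    if len(target) != len(prediction) or not target:
        raise ValueError("inputs must be non-empty and the same length")
    n = len(target)
    diffs = [t - p for t, p in zip(target, prediction)]
    mse = sum(d * d for d in diffs) / n
    mae = sum(abs(d) for d in diffs) / n
    mean_t = sum(target) / n
    mean_p = sum(prediction) / n
    cov = sum((t - mean_t) * (p - mean_p) for t, p in zip(target, prediction))
    var_t = sum((t - mean_t) ** 2 for t in target)
    var_p = sum((p - mean_p) ** 2 for p in prediction)
    # Correlation is undefined (NaN) when either input is constant.
    pearson_r = cov / math.sqrt(var_t * var_p) if var_t > 0 and var_p > 0 else float("nan")
    return mse, mae, pearson_r


t = [0.0, 1.0, 2.0, 3.0]
print(reconstruction_metrics(t, [0.1, 0.9, 2.2, 2.8]))
# A scaled and shifted prediction is perfectly correlated yet has a large MSE:
print(reconstruction_metrics(t, [2 * x + 1 for x in t]))  # (7.5, 2.5, 1.0)
```

The second call shows why correlation is reported alongside MSE and MAE rather than instead of them: a correlation of 1 only says the reconstruction has the right spatial pattern, not the right amplitude or offset.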

## Workflow Summary

This notebook covers an end-to-end flow field reconstruction workflow, from training through validation:

### ✅ Completed Tasks:

1. **Configuration Management**: Used `ConfigManager` for robust configuration handling
2. **Dataset Loading**: Created coordinate-aware datasets for Fourier features
3. **VAE Training**: Trained VAE with proper coordinate input for Fourier features
4. **VAE Validation**: Comprehensive testing and visualization of VAE performance
5. **FLRNet Training**: Trained sensor-to-field reconstruction model
6. **FLRNet Validation**: Evaluated FLRNet performance with sensor position overlays
7. **Model Saving**: Saved trained models for future use

### 🔧 Key Fixes Applied:

- **Fourier Bug Fix**: Resolved coordinate input issues for Fourier-aware VAE
- **Dataset Coordinate Integration**: Added proper coordinate grid generation
- **Robust Error Handling**: Added validation and error checking throughout
- **Comprehensive Visualization**: Created detailed analysis plots and metrics

### 📊 Results:

The notebook provides a complete workflow for training and validating both the VAE and FLRNet models, with optional Fourier features. With Fourier features enabled, both models are coordinate-aware, taking coordinate grids alongside the field data, and are ready for physics-informed machine learning applications.

### 🚀 Next Steps:

- Models are saved and ready for inference
- Configuration can be easily changed for different sensor layouts
- Fourier features are properly integrated and tested
- Framework is extensible for additional physics-informed constraints