# Flow Field Reconstruction with 8 Sparse Sensors

This notebook implements a physics-informed machine learning approach to reconstruct flow fields using data from 8 sparse sensors (the sensor layout, e.g. edge or random, is determined by the configuration loaded in Section 2). The implementation uses:
- Variational Autoencoder (VAE) for dimensionality reduction and flow field reconstruction
- Fourier feature embeddings for coordinate information
- FLRNet architecture to predict flow fields from sparse sensor measurements

## 1. Import Required Libraries

We'll import the necessary libraries for data manipulation, visualization, and deep learning.

In [12]:
# Standard libraries
import os

# Suppress TensorFlow C++ info and warning messages.
# This must happen BEFORE the first `import tensorflow`: the native logging
# level is picked up when the TensorFlow runtime is loaded, so setting the
# variable after the import does not silence the start-up messages.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import json
import time
import datetime
import importlib
import sys

# TensorFlow and Keras
import tensorflow as tf
from tensorflow import keras

# Force reload of project modules that are already loaded in this kernel,
# so that re-running this cell re-executes their current source.
for _module_name in ('models_improved', 'config_manager', 'data.flow_field_dataset'):
    if _module_name in sys.modules:
        importlib.reload(sys.modules[_module_name])

# Now import everything
import models_improved
import config_manager
from data.flow_field_dataset import FlowFieldDatasetCreator

print(f"TensorFlow version: {tf.__version__}")

# Configure GPU
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    print(f"Found {len(gpus)} GPU(s): {gpus}")
    for gpu in gpus:
        print(f"Using GPU: {gpu}")
        # Enable memory growth to avoid allocation issues.
        # TensorFlow raises RuntimeError if the device has already been
        # initialized (e.g. another cell touched the GPU first); report it
        # instead of aborting the whole setup cell.
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as err:
            print(f"Could not enable memory growth for {gpu}: {err}")
else:
    print("No GPUs found, using CPU")

# Set random seeds for reproducibility
tf.random.set_seed(42)
np.random.seed(42)

TensorFlow version: 2.8.0
Found 1 GPU(s): [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
Using GPU: PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')


## 2. Configuration Setup

Let's set up the configuration for our model training:

In [13]:
# Load configuration using ConfigManager
config_name = "random_8_fourier"
config_mgr = config_manager.ConfigManager()
hierarchical_config = config_mgr.load_config(config_name)

# Print configuration summary
print(config_mgr.create_config_summary(hierarchical_config))

# Flatten the config for training, then normalise every key to lowercase
# for consistent Python style.
flattened_config = config_manager.flatten_config_for_training(hierarchical_config)
config = {key.lower(): value for key, value in flattened_config.items()}

# Verify essential config keys BEFORE they are indexed below, so that a
# missing key is reported with a clear message instead of a bare KeyError
# raised by one of the print statements. The list also contains the keys
# that later cells of this notebook read from `config`.
required_keys = ['model_name', 'use_fourier', 'use_perceptual_loss', 'input_shape',
                 'n_sensors', 'latent_dims', 'n_base_features', 'batch_size',
                 'vae_epochs', 'flr_epochs', 'vae_learning_rate', 'flr_learning_rate',
                 'dataset_path', 'checkpoint_dir', 'logs_dir',
                 'test_split', 'save_best_model', 'save_last_model',
                 'gradient_clip_norm', 'patience', 'reduce_lr_patience']
missing_keys = [key for key in required_keys if key not in config]
if missing_keys:
    print(f"⚠️  Missing required config keys: {missing_keys}")
    raise KeyError(
        f"Configuration '{config_name}' is missing required keys: {missing_keys}"
    )
print("✅ All required configuration keys are present")

print("\n🔧 Final Configuration (lowercase keys):")
print(f"   Model name: {config['model_name']}")
print(f"   Use Fourier: {config['use_fourier']}")
print(f"   Use perceptual loss: {config['use_perceptual_loss']}")
print(f"   Input shape: {config['input_shape']}")
print(f"   Number of sensors: {config['n_sensors']}")
print(f"   Latent dims: {config['latent_dims']}")
print(f"   Base features: {config['n_base_features']}")
print(f"   Batch size: {config['batch_size']}")
print(f"   VAE epochs: {config['vae_epochs']}")
print(f"   FLRNet epochs: {config['flr_epochs']}")
print(f"   VAE learning rate: {config['vae_learning_rate']}")
print(f"   FLRNet learning rate: {config['flr_learning_rate']}")
print(f"   Dataset path: {config['dataset_path']}")
print(f"   Checkpoint dir: {config['checkpoint_dir']}")
print(f"   Logs dir: {config['logs_dir']}")


📋 Configuration Summary
🏗️  Model Architecture:
   - Fourier Enhancement: True
   - Perceptual Loss: True
   - Input Shape: [128, 256, 1]
   - Latent Dimensions: 8
   - Base Features: 64

📡 Sensor Configuration:
   - Layout: random
   - Number of Sensors: 8
   - Dataset: data/datasets\dataset_random_8.npz

🚀 Training Parameters:
   - VAE Epochs: 250
   - FLRNet Epochs: 150
   - VAE Learning Rate: 0.0001
   - FLRNet Learning Rate: 0.0001
   - Batch Size: 8
   - Test Split: 0.2

💾 Output Configuration:
   - Model Name: fourierTrue_percepTrue_random_8
   - Checkpoints: ./checkpoints\fourierTrue_percepTrue_random_8
   - Logs: ./logs\fourierTrue_percepTrue_random_8
   - Save Best Model: True
   - Save Last Model: True


🔧 Final Configuration (lowercase keys):
   Model name: fourierTrue_percepTrue_random_8
   Use Fourier: True
   Use perceptual loss: True
   Input shape: (128, 256, 1)
   Number of sensors: 8
   Latent dims: 8
   Base features: 64
   Batch size: 8
   VAE epochs: 250
   FLRNe

## 3. Load and Prepare Dataset

We'll load the flow field dataset and the sensor layout (8 sensors, with the layout type taken from the configured dataset file):

In [14]:
# Load and prepare dataset created from data_creation_and_viz.ipynb
print(f"📂 Loading dataset from: {config['dataset_path']}")

# Parse the dataset filename to get layout and n_sensors.
# Expected format: dataset_<layout>_<n_sensors>.npz (e.g. dataset_random_8.npz)
dataset_filename = Path(config['dataset_path']).name
parts = dataset_filename.split('_')
try:
    layout_type = parts[1]                   # e.g. 'random' or 'edge'
    n_sensors = int(parts[2].split('.')[0])  # e.g. 8
except (IndexError, ValueError) as err:
    # Turn the opaque IndexError/ValueError into an actionable message.
    raise ValueError(
        f"Dataset filename '{dataset_filename}' does not match the expected "
        f"'dataset_<layout>_<n_sensors>.npz' pattern"
    ) from err

print(f"📊 Dataset parameters:")
print(f"   Layout type: {layout_type}")
print(f"   Number of sensors: {n_sensors}")

# The sensor count in the filename and the one in the configuration are read
# independently; flag a disagreement early rather than during training.
if n_sensors != config['n_sensors']:
    print(f"⚠️  Sensor count in filename ({n_sensors}) differs from "
          f"config['n_sensors'] ({config['n_sensors']})")

# Load the dataset directly from the NPZ file
print(f"📁 Loading dataset file: {config['dataset_path']}")
# np.load returns a file-backed archive for .npz files; the context manager
# closes the underlying file once every array has been copied into `dataset`.
# Use `dataset` (not `data`) after this point.
with np.load(config['dataset_path']) as data:
    # Check what keys are available in the dataset
    print(f"📋 Available keys in dataset: {list(data.keys())}")

    # Create dataset dictionary (materialises every array in memory)
    dataset = {key: data[key] for key in data.keys()}

# Print dataset information
print(f"📊 Dataset loaded successfully:")
for key, value in dataset.items():
    if isinstance(value, np.ndarray):
        print(f"   {key}: {value.shape} ({value.dtype})")
    else:
        print(f"   {key}: {value}")

# Extract sensor positions
sensor_positions = dataset['sensor_positions']
print(f"📍 Sensor positions shape: {sensor_positions.shape}")

# Create dataset creator instance for TensorFlow dataset creation
creator = FlowFieldDatasetCreator(
    output_path="./data/",
    domain_shape=config['input_shape'][:2],  # (height, width)
    use_synthetic_data=False  # Don't create synthetic data, just use for TF dataset creation
)

# Create TensorFlow datasets using the creator's method
train_dataset, test_dataset = creator.create_tensorflow_dataset(
    dataset,
    batch_size=config['batch_size'],
    shuffle=True,
    test_split=config['test_split']
)

print(f"\n📊 TensorFlow datasets created:")
print(f"   Train dataset: {train_dataset}")
print(f"   Test dataset: {test_dataset}")
# Function to add coordinate grids to field data for Fourier-aware VAE training
def add_coordinate_grid(batch):
    """Attach a normalized (x, y) coordinate grid to a batch of fields.

    Args:
        batch: mapping with at least the keys 'field_data' and 'sensor_data'.
            'field_data' is indexed as (batch, height, width, ...).

    Returns:
        A dict with 'field_data' and 'sensor_data' passed through unchanged
        and a new 'coordinates' entry of shape (batch, height, width, 2),
        where channel 0 is x in [0, 1] along the width axis and channel 1 is
        y in [0, 1] along the height axis.
    """
    fields = batch['field_data']

    # Dynamic shape, so the function also works inside tf.data graphs
    dims = tf.shape(fields)
    n_samples, n_rows, n_cols = dims[0], dims[1], dims[2]

    # One normalized axis per spatial dimension
    xs = tf.linspace(0.0, 1.0, n_cols)  # along width
    ys = tf.linspace(0.0, 1.0, n_rows)  # along height

    # Cartesian ('xy') indexing yields two (height, width) grids laid out
    # like an image: x varies along columns, y along rows.
    x_grid, y_grid = tf.meshgrid(xs, ys, indexing='xy')

    # (height, width, 2) with x first, then y
    grid = tf.stack([x_grid, y_grid], axis=-1)

    # Replicate the single grid for every sample in the batch
    coordinates = tf.tile(grid[tf.newaxis, ...], [n_samples, 1, 1, 1])

    return {
        'field_data': fields,
        'sensor_data': batch['sensor_data'],
        'coordinates': coordinates
    }

# Attach the coordinate grid to every batch of both splits
coord_train_dataset = train_dataset.map(add_coordinate_grid)
coord_test_dataset = test_dataset.map(add_coordinate_grid)

print(f"\n📊 Coordinate-aware datasets created:")
print(f"   Train dataset with coordinates: {coord_train_dataset}")
print(f"   Test dataset with coordinates: {coord_test_dataset}")


# Element mappers: each turns a coordinate-aware batch dict into the tuple
# structure consumed by one of the two training stages.
def _fourier_vae_example(batch):
    """((field, coordinates), field) pair for the Fourier-aware VAE."""
    return ((batch['field_data'], batch['coordinates']), batch['field_data'])


def _standard_vae_example(batch):
    """(field, field) pair for the standard VAE."""
    return (batch['field_data'], batch['field_data'])


def _fourier_flrnet_example(batch):
    """Flat (sensor, field, coordinates) triple for the Fourier-aware FLRNet."""
    return (batch['sensor_data'], batch['field_data'], batch['coordinates'])


def _standard_flrnet_example(batch):
    """Flat (sensor, field) pair for the standard FLRNet."""
    return (batch['sensor_data'], batch['field_data'])


# VAE datasets: field reconstruction, with coordinates when Fourier features are on
if config['use_fourier']:
    print("\n🌊 Creating Fourier-aware VAE datasets...")
    vae_train_dataset = coord_train_dataset.map(_fourier_vae_example)
    vae_test_dataset = coord_test_dataset.map(_fourier_vae_example)
    print("✅ Fourier-aware VAE datasets created")
else:
    print("\n🔄 Creating standard VAE datasets...")
    vae_train_dataset = coord_train_dataset.map(_standard_vae_example)
    vae_test_dataset = coord_test_dataset.map(_standard_vae_example)
    print("✅ Standard VAE datasets created")

# FLRNet datasets: sensor-to-field reconstruction
if config['use_fourier']:
    print("\n🌊 Creating Fourier-aware FLRNet datasets...")
    flrnet_train_dataset = coord_train_dataset.map(_fourier_flrnet_example)
    flrnet_test_dataset = coord_test_dataset.map(_fourier_flrnet_example)
    print("✅ Fourier-aware FLRNet datasets created")
else:
    print("\n🔄 Creating standard FLRNet datasets...")
    flrnet_train_dataset = coord_train_dataset.map(_standard_flrnet_example)
    flrnet_test_dataset = coord_test_dataset.map(_standard_flrnet_example)
    print("✅ Standard FLRNet datasets created")

print(f"\n📊 Specialized datasets:")
print(f"   VAE train dataset: {vae_train_dataset}")
print(f"   VAE test dataset: {vae_test_dataset}")
print(f"   FLRNet train dataset: {flrnet_train_dataset}")
print(f"   FLRNet test dataset: {flrnet_test_dataset}")

# Pull one batch to confirm the tensor shapes (take(1) yields at most one element)
print(f"\n📊 Data shape verification:")
for batch in coord_train_dataset.take(1):
    print(f"   Sensor data shape: {batch['sensor_data'].shape}")
    print(f"   Field data shape: {batch['field_data'].shape}")
    print(f"   Coordinates shape: {batch['coordinates'].shape}")

# Confirm the (inputs, targets) structure of the VAE dataset
print(f"\n📊 VAE dataset structure verification:")
for vae_batch in vae_train_dataset.take(1):
    inputs, targets = vae_batch
    if not config['use_fourier']:
        print(f"   VAE input shape: {inputs.shape}")
        print(f"   VAE target shape: {targets.shape}")
    elif isinstance(inputs, tuple):
        field_input, coord_input = inputs
        print(f"   VAE field input shape: {field_input.shape}")
        print(f"   VAE coordinate input shape: {coord_input.shape}")
        print(f"   VAE target shape: {targets.shape}")
    else:
        print(f"   Unexpected VAE input format: {type(inputs)}")

📂 Loading dataset from: data/datasets\dataset_random_8.npz
📊 Dataset parameters:
   Layout type: random
   Number of sensors: 8
📁 Loading dataset file: data/datasets\dataset_random_8.npz
📋 Available keys in dataset: ['sensor_data', 'field_data', 'sensor_positions', 'reynolds_numbers', 'layout_type', 'n_sensors']
📊 Dataset loaded successfully:
   sensor_data: (28, 8, 39) (float64)
   field_data: (28, 128, 256, 39) (float64)
   sensor_positions: (8, 2) (float64)
   reynolds_numbers: (28,) (int32)
   layout_type: () (<U6)
   n_sensors: () (int32)
📍 Sensor positions shape: (8, 2)
📊 Dataset loaded successfully:
   sensor_data: (28, 8, 39) (float64)
   field_data: (28, 128, 256, 39) (float64)
   sensor_positions: (8, 2) (float64)
   reynolds_numbers: (28,) (int32)
   layout_type: () (<U6)
   n_sensors: () (int32)
📍 Sensor positions shape: (8, 2)
Dataset reshaped:
  Original sensor data: (28, 8, 39)
  Reshaped sensor data: (1092, 8)
  Original field data: (28, 128, 256, 39)
  Reshaped field d

## 5. Train VAE and FLRNet Models

Now we'll initialize the FLRTrainer, which manages the two models used in the rest of this notebook:
1. First, the Variational Autoencoder (VAE) to learn a compressed representation of flow fields
2. Then, the FLRNet to reconstruct flow fields from sparse sensor measurements

In [15]:
import os
import shutil
import time
import glob

# Reset Keras global state before building models in this kernel
print("🔄 Clearing TensorFlow session...")
tf.keras.backend.clear_session()

# Build the FLRTrainer (defined in models_improved.py)
print("🚀 Initializing FLRTrainer...")

# Every constructor argument is taken from the identically named config entry
trainer_kwargs = {
    name: config[name]
    for name in (
        'input_shape',
        'use_fourier',
        'checkpoint_dir',
        'logs_dir',
        'model_name',
        'save_best_model',
        'save_last_model',
        'gradient_clip_norm',
    )
}
trainer = models_improved.FLRTrainer(**trainer_kwargs)

print(f"✅ FLRTrainer initialized:")
print(f"   Input shape: {config['input_shape']}")
print(f"   Use Fourier: {config['use_fourier']}")
print(f"   Model name: {config['model_name']}")
print(f"   Gradient clipping: {config['gradient_clip_norm']}")
print(f"   Checkpoints: {config['checkpoint_dir']}")
print(f"   Logs: {config['logs_dir']}")
print(f"   Save best model: {config['save_best_model']}")
print(f"   Save last model: {config['save_last_model']}")

# Stage switches read by the cells below
train_vae_model = True     # VAE stage
train_flrnet_model = True  # FLRNet stage

print(f"\n🔧 Training configuration:")
print(f"   Train VAE: {train_vae_model}")
print(f"   Train FLRNet: {train_flrnet_model}")

# Sanity check on one batch of the coordinate-aware dataset
# (take(1) yields at most one element, so the loop body runs once)
print(f"\n🔍 Verifying dataset shapes...")
for batch in coord_train_dataset.take(1):
    print(f"   Field data shape: {batch['field_data'].shape}")
    print(f"   Coordinates shape: {batch['coordinates'].shape}")

# Sanity check on the VAE dataset: spatial dims of field and grid must agree
for vae_batch in vae_train_dataset.take(1):
    inputs, targets = vae_batch
    if isinstance(inputs, tuple):
        field_input, coord_input = inputs
        print(f"   VAE field input shape: {field_input.shape}")
        print(f"   VAE coordinate input shape: {coord_input.shape}")
        print(f"   VAE target shape: {targets.shape}")

        # Compare (height, width) of the two inputs
        shapes_agree = field_input.shape[1:3] == coord_input.shape[1:3]
        if shapes_agree:
            print(f"   ✅ Coordinate shapes now match field shapes!")
        else:
            print(f"   ❌ Shape mismatch: field {field_input.shape[1:3]} vs coord {coord_input.shape[1:3]}")



🔄 Clearing TensorFlow session...
🚀 Initializing FLRTrainer...
✅ FLRTrainer initialized:
   Input shape: (128, 256, 1)
   Use Fourier: True
   Model name: fourierTrue_percepTrue_random_8
   Gradient clipping: 2.0
   Checkpoints: ./checkpoints\fourierTrue_percepTrue_random_8
   Logs: ./logs\fourierTrue_percepTrue_random_8
   Save best model: True
   Save last model: True

🔧 Training configuration:
   Train VAE: True
   Train FLRNet: True

🔍 Verifying dataset shapes...
   Field data shape: (8, 128, 256, 1)
   Coordinates shape: (8, 128, 256, 2)
   Field data shape: (8, 128, 256, 1)
   Coordinates shape: (8, 128, 256, 2)
   VAE field input shape: (8, 128, 256, 1)
   VAE coordinate input shape: (8, 128, 256, 2)
   VAE target shape: (8, 128, 256, 1)
   ✅ Coordinate shapes now match field shapes!
   VAE field input shape: (8, 128, 256, 1)
   VAE coordinate input shape: (8, 128, 256, 2)
   VAE target shape: (8, 128, 256, 1)
   ✅ Coordinate shapes now match field shapes!


## Load Trained Models from Checkpoint

Now let's load the trained models from saved checkpoints using the FLRTrainer's built-in methods. The trainer provides several loading options:

- **`load_vae_from_checkpoint()`**: Load only the VAE model
- **`load_flrnet_from_checkpoint()`**: Load only the FLRNet model  
- **`load_models_from_checkpoint()`**: Load both models at once

### Features:
- **Automatic checkpoint detection**: Finds the best available checkpoint (best → last → final_weights)
- **Robust path handling**: Uses correct TensorFlow checkpoint format (no file extensions)
- **Error handling**: Graceful fallback and clear error messages
- **Architecture consistency**: Ensures loaded model matches trainer configuration
- **Smart dependencies**: FLRNet loading automatically handles VAE dependencies

In [16]:
import os
import glob
from pathlib import Path
import tensorflow as tf

# Restore the VAE through the trainer's checkpoint loader
print("🚀 Loading trained VAE model from checkpoint using FLRTrainer...")
print("=" * 60)

# The loader is given the checkpoint directory rather than a single file
checkpoint_directory = Path(config['checkpoint_dir'])
print(f"📂 Using checkpoint directory: {checkpoint_directory}")

# Show what is on disk to make loading problems easier to diagnose
if not checkpoint_directory.exists():
    print(f"❌ Checkpoint directory not found: {checkpoint_directory}")
else:
    print(f"📋 Available checkpoint files:")
    for file in sorted(checkpoint_directory.iterdir()):
        print(f"   - {file.name}")

vae_model = trainer.load_vae_from_checkpoint(
    checkpoint_dir=checkpoint_directory,
    latent_dims=config['latent_dims'],
    n_base_features=config['n_base_features'],
    use_perceptual_loss=config['use_perceptual_loss'],
    verbose=True
)

if vae_model is None:
    print("\n❌ Failed to load VAE model")
    print("🔧 Troubleshooting steps:")
    print("   1. Check if training completed successfully")
    print("   2. Verify checkpoint files exist in the directory")
    print("   3. Ensure model configuration matches training setup")
    print("   4. Try running the training cell again if checkpoints are missing")
else:
    print("\n🎉 VAE model loaded successfully!")
    print("📈 Model is ready for inference and visualization")

    # Smoke test: run a single sample through the restored model
    try:
        test_batch = next(iter(vae_test_dataset.take(1)))
        test_inputs, test_targets = test_batch

        # A two-element input means (field, coordinates), i.e. the Fourier VAE
        has_coordinates = isinstance(test_inputs, (list, tuple)) and len(test_inputs) == 2
        if has_coordinates:
            test_field, test_coord = test_inputs
            test_prediction = vae_model([test_field[:1], test_coord[:1]])
            print(f"✅ Fourier VAE verification successful!")
            print(f"   - Field input shape: {test_field[:1].shape}")
            print(f"   - Coordinate input shape: {test_coord[:1].shape}")
        else:
            # Otherwise the model takes the field tensor alone
            test_prediction = vae_model(test_inputs[:1])
            print(f"✅ Standard VAE verification successful!")
            print(f"   - Input shape: {test_inputs[:1].shape}")

        prediction_values = test_prediction.numpy()
        print(f"   - Output shape: {test_prediction.shape}")
        print(f"   - Output range: [{prediction_values.min():.4f}, {prediction_values.max():.4f}]")

    except Exception as e:
        # Report the failure with a traceback but keep the loaded model
        print(f"⚠️  Model verification failed: {str(e)}")
        print("Model loaded but may have compatibility issues")
        import traceback
        traceback.print_exc()

🚀 Loading trained VAE model from checkpoint using FLRTrainer...
📂 Using checkpoint directory: checkpoints\fourierTrue_percepTrue_random_8
📋 Available checkpoint files:
   - checkpoint
   - checkpoint_fourierTrue_percepTrue_random_8_vae_best.data-00000-of-00001
   - checkpoint_fourierTrue_percepTrue_random_8_vae_best.index
🔍 Loading VAE model from checkpoint directory: checkpoints\fourierTrue_percepTrue_random_8
✅ Found vae_best checkpoint: checkpoints\fourierTrue_percepTrue_random_8\checkpoint_fourierTrue_percepTrue_random_8_vae_best
🌊 VAE model built for Fourier features
📋 VAE Model Architecture:
   - Input shape: (128, 256, 1)
   - Latent dims: 8
   - Base features: 64
   - Use Fourier: True
   - Perceptual loss: True
🌊 VAE model built for Fourier features
📋 VAE Model Architecture:
   - Input shape: (128, 256, 1)
   - Latent dims: 8
   - Base features: 64
   - Use Fourier: True
   - Perceptual loss: True
✅ Successfully loaded VAE model from vae_best checkpoint!

🎉 VAE model loaded su

In [17]:
# FLRNet stage: sensor measurements -> reconstructed field
if not train_flrnet_model:
    print("⚠️ Skipping FLRNet training. Loading existing model...")
    # No model is built in this branch; flr_model is left as None here
    flr_model = None
else:
    print("\n🔄 Starting FLRNet Training using FLRTrainer...")
    print("="*60)

    # Hyper-parameters come from the flattened config
    flr_training_args = dict(
        train_dataset=flrnet_train_dataset,
        val_dataset=flrnet_test_dataset,
        n_sensors=config['n_sensors'],
        epochs=config['flr_epochs'],
        learning_rate=config['flr_learning_rate'],
        pretrained_vae=vae_model,  # VAE restored by the checkpoint-loading cell above
        latent_dims=config['latent_dims'],
        n_base_features=config['n_base_features'],
        use_perceptual_loss=config['use_perceptual_loss'],
        patience=config['patience'],
        reduce_lr_patience=config['reduce_lr_patience'],
    )
    flr_model = trainer.train_flr_net(**flr_training_args)

    print(f"✅ FLRNet training completed successfully!")


🔄 Starting FLRNet Training using FLRTrainer...
🚀 Training FLRNet Model...
🛡️ Checkpoint saving will be disabled for epochs 1-30, enabled from epoch 31
🛡️ Checkpoint saving will be disabled for epochs 1-30, enabled from epoch 31
🛡️ Checkpoint saving will be disabled for epochs 1-30, enabled from epoch 31
🛡️ Checkpoint saving will be disabled for epochs 1-30, enabled from epoch 31
Epoch 1/150
Epoch 1/150
Epoch 1: val_reconstruction_loss improved from inf to 15008.85742 (saving disabled for epochs 1-30)
Epoch 1: val_reconstruction_loss improved from inf to 15008.85742 (saving disabled for epochs 1-30)
Epoch 1: val_reconstruction_loss improved from inf to 15008.85742 (saving disabled for epochs 1-30)
Epoch 2/150
Epoch 2/150
Epoch 2: val_reconstruction_loss improved from 15008.85742 to 14591.94922 (saving disabled for epochs 1-30)
Epoch 2: val_reconstruction_loss improved from 15008.85742 to 14591.94922 (saving disabled for epochs 1-30)
Epoch 2: val_reconstruction_loss improved from 15008.

KeyboardInterrupt: 

## FLRNet Validation and Visualization

After training the FLRNet model, we can evaluate its performance by:
1. Testing sensor-to-field reconstruction accuracy
2. Comparing predicted vs ground truth flow fields
3. Analyzing reconstruction quality metrics
4. Visualizing sensor positions overlaid on predictions

This section validates the main FLRNet model performance.

In [None]:
# 2. Load FLRNet Model
print("🚀 Loading trained FLRNet model from checkpoint using FLRTrainer...")
print("=" * 60)

# Same checkpoint directory as the VAE
checkpoint_directory = Path(config['checkpoint_dir'])
print(f"📂 Using checkpoint directory: {checkpoint_directory}")

# Show the FLRNet-related files on disk to make loading problems easier to diagnose
if not checkpoint_directory.exists():
    print(f"❌ Checkpoint directory not found: {checkpoint_directory}")
else:
    print(f"📋 Available FLRNet checkpoint files:")
    flrnet_files = [f for f in sorted(checkpoint_directory.iterdir()) if 'flrnet' in f.name.lower()]
    if flrnet_files:
        for file in flrnet_files:
            print(f"   - {file.name}")
    else:
        print("   ⚠️  No FLRNet checkpoint files found")

# Restore FLRNet through the trainer's checkpoint loader
flr_model = trainer.load_flrnet_from_checkpoint(
    checkpoint_dir=checkpoint_directory,
    n_sensors=config['n_sensors'],
    latent_dims=config['latent_dims'],
    n_base_features=config['n_base_features'],
    use_perceptual_loss=config['use_perceptual_loss'],
    pretrained_vae=vae_model,  # VAE restored by the VAE checkpoint-loading cell
    freeze_autoencoder=True,   # Keep autoencoder frozen for inference
    verbose=True
)

if flr_model is None:
    print("\n❌ Failed to load FLRNet model")
    print("🔧 Troubleshooting steps:")
    print("   1. Check if FLRNet training completed successfully")
    print("   2. Verify FLRNet checkpoint files exist in the directory")
    print("   3. Ensure model configuration matches training setup")
    print("   4. Make sure VAE model is loaded first (FLRNet depends on it)")
    print("   5. Try running the FLRNet training cell if checkpoints are missing")
else:
    print("\n🎉 FLRNet model loaded successfully!")
    print("📈 Model is ready for sensor-to-field reconstruction")

    # Smoke test: run a single sample through the restored model
    try:
        test_batch = next(iter(flrnet_test_dataset.take(1)))

        # The element structure depends on whether Fourier features are enabled
        if config['use_fourier']:
            # (sensor_data, field_data, coordinates)
            sensor_data, field_data, coord_data = test_batch
            sensor_sample, field_sample, coord_sample = sensor_data[:1], field_data[:1], coord_data[:1]
            print(f"✅ Fourier FLRNet input format detected:")
            print(f"   - Sensor data shape: {sensor_sample.shape}")
            print(f"   - Field data shape: {field_sample.shape}")
            print(f"   - Coordinate data shape: {coord_sample.shape}")
            full_inputs = [sensor_sample, field_sample, coord_sample]
            success_message = "✅ Fourier FLRNet verification successful!"
        else:
            # (sensor_data, field_data)
            sensor_data, field_data = test_batch
            sensor_sample, field_sample = sensor_data[:1], field_data[:1]
            print(f"✅ Standard FLRNet input format detected:")
            print(f"   - Sensor data shape: {sensor_sample.shape}")
            print(f"   - Field data shape: {field_sample.shape}")
            full_inputs = [sensor_sample, field_sample]
            success_message = "✅ Standard FLRNet verification successful!"

        # Forward pass with all inputs (training format)
        test_prediction = flr_model(full_inputs, training=False)
        print(success_message)

        # Forward pass with the sensor readings alone (inference format)
        sensor_only_prediction = flr_model(sensor_sample, training=False)
        print(f"✅ Sensor-only prediction also works!")

        prediction_values = test_prediction.numpy()
        print(f"   - Output shape: {test_prediction.shape}")
        print(f"   - Output range: [{prediction_values.min():.4f}, {prediction_values.max():.4f}]")

        # Report the perceptual-loss setting when the model exposes it
        if not hasattr(flr_model, 'use_perceptual_loss'):
            print("   - Legacy model (no perceptual loss configuration)")
        else:
            print(f"   - Perceptual loss enabled: {flr_model.use_perceptual_loss}")
            print(f"   - Available metrics: {[metric.name for metric in flr_model.metrics]}")

    except Exception as e:
        # Report the failure with a traceback but keep the loaded model
        print(f"⚠️  Model verification failed: {str(e)}")
        print("Model loaded but may have compatibility issues")
        import traceback
        traceback.print_exc()

In [18]:
# Optional fine-tuning of the FLRNet restored from checkpoint
# Switch to True to run the additional training epochs below
continue_flrnet_training = False  # Change to True to enable

# Perceptual loss for the continuation run (independent of the original training)
use_perceptual_loss_continuation = True  # Change to False to disable perceptual loss

if not continue_flrnet_training:
    print("ℹ️  FLRNet training continuation is disabled")
    print("   Set continue_flrnet_training = True to enable")
    print(f"   Current perceptual loss setting: {'✅ ENABLED' if use_perceptual_loss_continuation else '❌ DISABLED'}")
    print("   (can be changed via use_perceptual_loss_continuation variable)")

elif flr_model is None:
    print("❌ Cannot continue FLRNet training: No model loaded from checkpoint")
    print("   First load a model using the FLRNet checkpoint loading cell above")

else:
    print("🔄 Continuing FLRNet training from loaded checkpoint...")
    print("=" * 60)

    # Settings for the continuation run
    additional_epochs = 50
    learning_rate = 1e-5  # Lower learning rate for fine-tuning

    print(f"📋 Continuation Configuration:")
    print(f"   - Additional epochs: {additional_epochs}")
    print(f"   - Learning rate: {learning_rate} (reduced for fine-tuning)")
    print(f"   - Perceptual loss: {'✅ ENABLED' if use_perceptual_loss_continuation else '❌ DISABLED'}")
    print(f"   - Model weights: ✅ PRESERVED from checkpoint")
    print(f"   - VAE dependency: {'✅ LOADED' if vae_model is not None else '❌ MISSING'}")

    # FLRNet needs the VAE; restore it from checkpoint if it is not in memory
    if vae_model is None:
        print("\n🔄 Loading VAE model first (required for FLRNet)...")
        vae_model = trainer.load_vae_from_checkpoint(
            checkpoint_dir=checkpoint_directory,
            latent_dims=config['latent_dims'],
            n_base_features=config['n_base_features'],
            use_perceptual_loss=config['use_perceptual_loss'],
            verbose=False
        )

    if vae_model is None:
        print("❌ Could not load VAE model - required for FLRNet training")
        print("   FLRNet training requires a pre-trained VAE model")
    else:
        print("\n🚀 Method 1: Using FLRTrainer (Recommended)")

        # Hand the restored models to the trainer before resuming
        trainer.flr_model = flr_model
        trainer.vae_model = vae_model

        continuation_args = dict(
            train_dataset=flrnet_train_dataset,
            val_dataset=flrnet_test_dataset,
            epochs=additional_epochs,
            learning_rate=learning_rate,
            patience=config['patience'],
            reduce_lr_patience=config['reduce_lr_patience'],
            use_perceptual_loss=use_perceptual_loss_continuation,
        )
        continued_flrnet = trainer.continue_flrnet_training(**continuation_args)

        print("✅ FLRNet training continuation completed!")

        # Point flr_model at the fine-tuned model when one is returned
        if continued_flrnet is not None:
            flr_model = continued_flrnet
            print("🔄 Model reference updated to continued version")

ℹ️  FLRNet training continuation is disabled
   Set continue_flrnet_training = True to enable
   Current perceptual loss setting: ✅ ENABLED
   (can be changed via use_perceptual_loss_continuation variable)


In [None]:
# Enhanced FLRNet Test Visualization
#
# Runs the FLRNet on at most MAX_VIS_SAMPLES samples of one test batch and
# compares the reconstruction with the ground-truth field: global metrics,
# per-sample field/error maps, value/error histograms, and a summary of the
# sensor inputs.
#
# Reads from the notebook namespace: `config`, `config_name`,
# `flrnet_test_dataset`, and a model taken from `test_flr_model` (preferred)
# or `trainer.flr_model`. `sensor_positions` is used for the overlay when it
# is defined.
#
# If the forward pass fails, the error is reported and every metric/plot is
# skipped, so the cell can never run on undefined or stale `predictions`.
import traceback

import matplotlib.pyplot as plt
import numpy as np

MAX_VIS_SAMPLES = 4    # number of columns in the comparison figure
MSE_EXCELLENT = 0.01   # MSE below this is rated "Excellent"
MSE_GOOD = 0.05        # MSE below this (but >= MSE_EXCELLENT) is rated "Good"

# Resolve the model without evaluating names that may not exist on a fresh
# kernel: a non-None `test_flr_model` wins, otherwise fall back to the trainer.
flr_model_to_use = globals().get('test_flr_model')
if flr_model_to_use is None:
    flr_model_to_use = getattr(globals().get('trainer'), 'flr_model', None)

# Flipped to True only after a successful forward pass in THIS run of the cell.
reconstruction_ok = False

if flr_model_to_use is None:
    print("❌ No FLRNet model available for visualization")
    print("Please run the FLRNet model creation cell above first")
else:
    print("🎨 === FLRNet Test Visualization ===")
    print(f"Model type: {'Fourier-aware' if config['use_fourier'] else 'Standard'} FLRNet")

    # Get one batch directly from the properly formatted dataset; the default
    # of None turns an empty dataset into a message instead of StopIteration.
    test_batch = next(iter(flrnet_test_dataset.take(1)), None)

    if test_batch is None:
        print("❌ FLRNet test dataset is empty - nothing to visualize")
    else:
        print("Using FLRNet test dataset")

        # The batch layout depends on the Fourier configuration.
        if config['use_fourier']:
            # Fourier FLRNet expects (sensor_data, field_data, coordinates)
            sensor_data, field_data, coord_data = test_batch
            print("✅ Fourier FLRNet inputs detected:")
            print(f"   Sensor data shape: {sensor_data.shape}")
            print(f"   Field data shape: {field_data.shape}")
            print(f"   Coordinate data shape: {coord_data.shape}")
        else:
            # Standard FLRNet expects (sensor_data, field_data)
            sensor_data, field_data = test_batch
            print("✅ Standard FLRNet inputs:")
            print(f"   Sensor data shape: {sensor_data.shape}")
            print(f"   Field data shape: {field_data.shape}")

        # Limit to a manageable number of samples for visualization.
        max_samples = min(MAX_VIS_SAMPLES, sensor_data.shape[0])
        limited_sensor = sensor_data[:max_samples]
        limited_field = field_data[:max_samples]
        model_inputs = [limited_sensor, limited_field]
        if config['use_fourier']:
            limited_coord = coord_data[:max_samples]
            model_inputs.append(limited_coord)

        # Generate FLRNet predictions
        print("\n🔮 Generating FLRNet reconstructions...")
        try:
            predictions = flr_model_to_use(model_inputs, training=False)
            print(f"✅ Reconstruction successful! Shape: {predictions.shape}")

            # Use limited_field as targets for comparison
            targets = limited_field
            reconstruction_ok = True
        except Exception as e:
            print(f"❌ FLRNet prediction failed: {e}")
            traceback.print_exc()
            print("⚠️  Skipping metrics and plots because no reconstruction is available")

if reconstruction_ok:
    # NumPy copies for all numeric work. `targets`, `predictions` and
    # `limited_sensor` themselves are left untouched for later cells.
    targets_np = targets.numpy()
    predictions_np = predictions.numpy()
    sensor_np = limited_sensor.numpy()
    abs_error = np.abs(targets_np - predictions_np)

    # Calculate reconstruction metrics (over the visualized samples only)
    mse = np.mean(abs_error ** 2)
    mae = np.mean(abs_error)
    max_error = np.max(abs_error)

    # Qualitative rating; a NaN MSE falls through to "Needs improvement".
    if mse < MSE_EXCELLENT:
        quality_label, quality_badge = 'Excellent', '🏆 Excellent'
    elif mse < MSE_GOOD:
        quality_label, quality_badge = 'Good', '✅ Good'
    else:
        quality_label, quality_badge = 'Needs improvement', '⚠️ Needs improvement'

    print(f"\n📊 FLRNet Reconstruction Metrics:")
    print(f"   MSE: {mse:.6f} ({quality_label})")
    print(f"   MAE: {mae:.6f}")
    print(f"   Max Error: {max_error:.6f}")

    # Per-sample error statistics on channel 0, reused by the error-map
    # annotations and by the quality analysis at the end of the cell.
    n_samples = max_samples
    sample_mses = [np.mean(abs_error[i, :, :, 0] ** 2) for i in range(n_samples)]
    sample_maxes = [np.max(abs_error[i, :, :, 0]) for i in range(n_samples)]

    # Optional sensor layout; positions are treated as normalized (x, y)
    # pairs and scaled to pixel indices below.
    grid_height = config['input_shape'][0]
    grid_width = config['input_shape'][1]
    sensor_xy = globals().get('sensor_positions')
    if sensor_xy is not None:
        sensor_xy = np.asarray(sensor_xy)

    # Create comprehensive visualization: one column per sample, rows are
    # sensors / ground truth / prediction / error. squeeze=False keeps `axes`
    # two-dimensional even for a single sample.
    fig, axes = plt.subplots(4, n_samples, figsize=(5*n_samples, 16), squeeze=False)

    for i in range(n_samples):
        # Sensor readings visualization (a sparse field with one pixel per sensor)
        ax1 = axes[0, i]
        sensor_field = np.zeros((grid_height, grid_width))

        if sensor_xy is not None:
            # zip() stops at the shorter of positions/readings, so a mismatch
            # between the two never indexes out of range.
            for (x_pos, y_pos), reading in zip(sensor_xy, sensor_np[i]):
                # Convert normalized positions to pixel coordinates
                x_idx = int(x_pos * (grid_width - 1))
                y_idx = int(y_pos * (grid_height - 1))
                if 0 <= x_idx < grid_width and 0 <= y_idx < grid_height:
                    sensor_field[y_idx, x_idx] = reading

        im1 = ax1.imshow(sensor_field, cmap='RdBu_r', origin='lower')
        ax1.set_title(f'Sensor Readings {i+1}', fontweight='bold')
        ax1.set_xlabel('X Position')
        ax1.set_ylabel('Y Position')
        plt.colorbar(im1, ax=ax1, shrink=0.8)

        # Add sensor positions as scatter points if available
        if sensor_xy is not None:
            sensor_x = sensor_xy[:, 0] * (grid_width - 1)
            sensor_y = sensor_xy[:, 1] * (grid_height - 1)
            ax1.scatter(sensor_x, sensor_y, c='black', s=50, marker='o', edgecolors='white', linewidth=1)

        # Ground truth field
        im2 = axes[1, i].imshow(targets_np[i, :, :, 0], cmap='RdBu_r', origin='lower')
        axes[1, i].set_title(f'Ground Truth Field {i+1}', fontweight='bold')
        axes[1, i].set_xlabel('X Position')
        axes[1, i].set_ylabel('Y Position')
        plt.colorbar(im2, ax=axes[1, i], shrink=0.8)

        # FLRNet prediction
        im3 = axes[2, i].imshow(predictions_np[i, :, :, 0], cmap='RdBu_r', origin='lower')
        axes[2, i].set_title(f'FLRNet Prediction {i+1}', fontweight='bold')
        axes[2, i].set_xlabel('X Position')
        axes[2, i].set_ylabel('Y Position')
        plt.colorbar(im3, ax=axes[2, i], shrink=0.8)

        # Error map
        im4 = axes[3, i].imshow(abs_error[i, :, :, 0], cmap='hot', origin='lower')
        axes[3, i].set_title(f'Reconstruction Error {i+1}', fontweight='bold')
        axes[3, i].set_xlabel('X Position')
        axes[3, i].set_ylabel('Y Position')
        plt.colorbar(im4, ax=axes[3, i], shrink=0.8)

        # Add error statistics as text
        axes[3, i].text(0.02, 0.98, f'MSE: {sample_mses[i]:.4f}\nMax: {sample_maxes[i]:.4f}',
                        transform=axes[3, i].transAxes, verticalalignment='top',
                        bbox=dict(boxstyle='round', facecolor='white', alpha=0.8),
                        fontsize=9, fontweight='bold')

    fig.tight_layout()
    fig.suptitle(f'FLRNet Reconstruction Results - {config_name}', y=1.02, fontsize=16, fontweight='bold')
    plt.show()

    # Statistical analysis
    print(f"\n📈 Statistical Comparison:")
    orig_stats = {
        'mean': np.mean(targets_np),
        'std': np.std(targets_np),
        'min': np.min(targets_np),
        'max': np.max(targets_np)
    }

    pred_stats = {
        'mean': np.mean(predictions_np),
        'std': np.std(predictions_np),
        'min': np.min(predictions_np),
        'max': np.max(predictions_np)
    }

    print(f"   Original  - Mean: {orig_stats['mean']:.4f}, Std: {orig_stats['std']:.4f}, Range: [{orig_stats['min']:.4f}, {orig_stats['max']:.4f}]")
    print(f"   Predicted - Mean: {pred_stats['mean']:.4f}, Std: {pred_stats['std']:.4f}, Range: [{pred_stats['min']:.4f}, {pred_stats['max']:.4f}]")

    # Distribution comparison
    fig_dist, (ax_values, ax_abs, ax_rel) = plt.subplots(1, 3, figsize=(15, 5))

    # Value distributions
    ax_values.hist(targets_np.flatten(), bins=50, alpha=0.7, label='Ground Truth', density=True, color='blue')
    ax_values.hist(predictions_np.flatten(), bins=50, alpha=0.7, label='FLRNet Prediction', density=True, color='red')
    ax_values.set_xlabel('Field Value')
    ax_values.set_ylabel('Density')
    ax_values.set_title('Value Distribution Comparison', fontweight='bold')
    ax_values.legend()
    ax_values.grid(True, alpha=0.3)

    # Error distribution
    errors = abs_error.flatten()
    ax_abs.hist(errors, bins=50, alpha=0.7, color='red', edgecolor='black')
    ax_abs.set_xlabel('Absolute Error')
    ax_abs.set_ylabel('Frequency')
    ax_abs.set_title('Reconstruction Error Distribution', fontweight='bold')
    ax_abs.axvline(np.mean(errors), color='darkred', linestyle='--', linewidth=2, label=f'Mean: {np.mean(errors):.4f}')
    ax_abs.legend()
    ax_abs.grid(True, alpha=0.3)

    # Relative error distribution. The 1e-8 only prevents division by zero;
    # pixels whose true value is close to zero still yield very large ratios.
    relative_errors = errors / (np.abs(targets_np.flatten()) + 1e-8)
    ax_rel.hist(relative_errors, bins=50, alpha=0.7, color='orange', edgecolor='black')
    ax_rel.set_xlabel('Relative Error')
    ax_rel.set_ylabel('Frequency')
    ax_rel.set_title('Relative Error Distribution', fontweight='bold')
    ax_rel.axvline(np.mean(relative_errors), color='darkorange', linestyle='--', linewidth=2, label=f'Mean: {np.mean(relative_errors):.4f}')
    ax_rel.legend()
    ax_rel.grid(True, alpha=0.3)

    fig_dist.tight_layout()
    fig_dist.suptitle('FLRNet Performance Analysis', y=1.02, fontsize=14, fontweight='bold')
    plt.show()

    # Sensor analysis
    print(f"\n📡 Sensor Analysis:")
    print(f"   Number of sensors: {sensor_np.shape[1]}")
    print(f"   Sensor value range: [{sensor_np.min():.4f}, {sensor_np.max():.4f}]")
    print(f"   Sensor value std: {sensor_np.std():.4f}")

    # Show sensor readings distribution
    fig_sensor, (ax_box, ax_line) = plt.subplots(1, 2, figsize=(12, 4))

    # Tick labels are set explicitly instead of through boxplot(labels=...),
    # which newer Matplotlib releases deprecate; boxes sit at x = 1..n_samples.
    ax_box.boxplot([sensor_np[i, :] for i in range(n_samples)])
    ax_box.set_xticks(list(range(1, n_samples + 1)))
    ax_box.set_xticklabels([f'Sample {i+1}' for i in range(n_samples)])
    ax_box.set_ylabel('Sensor Reading Value')
    ax_box.set_title('Sensor Readings Distribution per Sample', fontweight='bold')
    ax_box.grid(True, alpha=0.3)

    for i in range(n_samples):
        ax_line.plot(sensor_np[i, :], 'o-', label=f'Sample {i+1}', alpha=0.7)
    ax_line.set_xlabel('Sensor Index')
    ax_line.set_ylabel('Sensor Reading Value')
    ax_line.set_title('Sensor Readings by Position', fontweight='bold')
    ax_line.legend()
    ax_line.grid(True, alpha=0.3)

    fig_sensor.tight_layout()
    fig_sensor.suptitle('Sensor Input Analysis', y=1.02, fontsize=14, fontweight='bold')
    plt.show()

    # Reconstruction quality analysis per sample
    print(f"\n🔗 Reconstruction Quality Analysis:")
    sensor_means = [np.mean(sensor_np[i, :]) for i in range(n_samples)]
    sensor_stds = [np.std(sensor_np[i, :]) for i in range(n_samples)]

    print(f"   Sample MSEs: {[f'{value:.6f}' for value in sample_mses]}")
    print(f"   Best reconstruction: Sample {np.argmin(sample_mses) + 1} (MSE: {min(sample_mses):.6f})")
    print(f"   Worst reconstruction: Sample {np.argmax(sample_mses) + 1} (MSE: {max(sample_mses):.6f})")

    # Correlation analysis (Pearson, over all pixels of the visualized samples)
    correlation_field = np.corrcoef(targets_np.flatten(), predictions_np.flatten())[0, 1]
    print(f"   Field correlation coefficient: {correlation_field:.4f}")

    print(f"\n🎉 FLRNet visualization completed successfully!")
    print(f"   📊 Reconstruction Quality: {quality_badge}")
    print(f"   🌊 Fourier Features: {'✅ Working correctly' if config['use_fourier'] else '➖ Not used'}")
    print(f"   📡 Sensor Count: {sensor_np.shape[1]} sensors")
    print(f"   🔗 Field Correlation: {correlation_field:.4f}")