# Cloud Height Measurement System: Technical Overview

## System Architecture

This system combines aerial imagery with LiDAR measurements to create detailed cloud height fields using a deep learning approach. The process consists of four main stages:

1. Data Collection and Preprocessing
   - HD camera images from FEGS (Fly's Eye GLM Simulator)
   - LiDAR measurements from Cloud Physics LiDAR (CPL)
   - Aircraft flight metadata (altitude, speed, orientation)
   - Image synchronization and alignment

2. Motion Analysis and Height Estimation
   - Fine-tuned RAFT (Recurrent All-Pairs Field Transforms) model for optical flow
   - Parallax-based height calculation using motion fields
   - Integration with aircraft metadata for scale calibration
   - Confidence mapping for measurement reliability

3. Height Field Generation and Validation
   - Height field calculation from motion vectors
   - LiDAR measurement integration for calibration
   - Uncertainty estimation and quality control
   - Individual sequence height field generation

4. Height Field Stitching and Global Integration
   - Global coordinate system creation using GPS data
   - Weighted height field merging
   - Multi-sequence confidence integration
   - Large-scale cloud topology reconstruction

## Key Processing Steps

### 1. Data Input Preparation
- Temporal alignment of image sequences (5-second intervals)
- Image preprocessing and augmentation
- Metadata normalization and feature engineering
- Sequence packaging for batch processing

### 2. Motion Field Generation
- RAFT model application to image pairs
- Bidirectional flow calculation
- Motion field refinement and consistency checking
- Confidence map generation

### 3. Height Field Calculation
- Parallax principle application
- Aircraft motion compensation
- Scale factor calibration using LiDAR
- Height field uncertainty estimation

### 4. Global Field Stitching
- GPS-based sequence positioning
- Distance-weighted height field merging
- Confidence-based blending
- Banding artifact removal
- Seamless transition handling

## System Advantages

1. Wide Field Coverage
   - 90-degree field of view compared to single-point LiDAR
   - Continuous spatial coverage
   - Higher spatial resolution
   - Large-scale cloud field reconstruction through stitching

2. Cost Efficiency
   - Uses existing camera hardware
   - Reduces dependency on expensive LiDAR systems
   - Simplified instrument package
   - Maximizes value from existing flight data

3. Measurement Quality
   - LiDAR-calibrated accuracy
   - Uncertainty quantification
   - Real-time quality assessment
   - Multi-sequence validation
   - Robust to measurement gaps

4. Comprehensive Cloud Mapping
   - Seamless integration of multiple sequences
   - Preservation of local detail
   - Global context maintenance
   - Confidence-weighted merging
   - Artifact-free reconstruction

This system bridges the gap between limited single-point LiDAR measurements and the need for cloud height fields, enabling more accurate atmospheric observations and improved weather predictions. The stitching capability allows for reconstruction of large-scale cloud fields, providing coverage and detail in cloud height mapping from aerial platforms.


In [None]:
from google.colab import drive

# Mount Google Drive
drive.mount('/content/drive')

In [None]:
!pip install azure-storage-blob azure-identity --quiet
!pip install transformers

# Imports

In [None]:
import base64
import cv2
import io
import json
import math
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import os
import random
import seaborn as sns
import time
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import matplotlib.colors
import warnings
from collections import defaultdict
from datetime import datetime
from io import BytesIO
from IPython.display import clear_output, display, HTML
from PIL import Image, ImageEnhance, ImageOps
from azure.storage.blob import BlobServiceClient
from google.colab import userdata
from google.colab import runtime
from matplotlib.colors import LinearSegmentedColormap
from matplotlib import cm, gridspec, animation
from mpl_toolkits.mplot3d import Axes3D
from plotly.subplots import make_subplots
from scipy.ndimage import gaussian_filter, gaussian_filter1d
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from timm import create_model
from torchvision import transforms
from torch.amp import autocast, GradScaler
from torch.nn.utils.rnn import pad_sequence
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR, OneCycleLR, ReduceLROnPlateau
from torch.serialization import add_safe_globals
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import PackedSequence, pack_sequence, pad_packed_sequence, pack_padded_sequence
from torchvision.models.optical_flow import raft_large
from torchvision.models.optical_flow import Raft_Large_Weights
from tqdm.notebook import tqdm_notebook as tqdm
from transformers import AutoImageProcessor, ConvNextModel, ConvNextV2Model, TimesformerModel, TimesformerConfig, SwinModel
from typing import Tuple, Dict, Optional

def show_warnings(func):
    def wrapper(*args, **kwargs):
        with warnings.catch_warnings(record=True) as caught_warnings:
            result = func(*args, **kwargs)
            for warning in caught_warnings:
                print(f"Warning: {warning.message}")
                print(f"  In file: {warning.filename}")
                print(f"  Line number: {warning.lineno}")
                print("--------------------")
        return result
    return wrapper

# Notebook parameters

In [None]:
# Memory
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Number of epochs for training
num_epochs = 2

# If true, train model from scratch. If false, load a previously trained model
train_model = True

# If true, use a previously generated dataset from the image files. If false, regenerate it.
# Note that regenerating the dataset takes a bit of time.
cache_cloud_dataset = False

# Available flight data dates. Modify the train_dates variable below as needed
# "20170418", "20170422", "20170508", "20170512”, "20170512"
train_dates = ["20170418"]

In [None]:
!nvidia-smi

# DebugControl Class

In [None]:
class DebugControl:
    """
    Controls debug output for different components of the cloud height measurement system.

    This class provides granular control over debug messages for different system components,
    allowing selective enabling/disabling of debug output for specific processing stages.

    Attributes:
        enabled (bool): Master switch for all debug output.
        components (dict): Dictionary of component-specific debug flags:
            - training: Training loop messages
            - optical_flow: Optical flow computation messages
            - loss: Loss computation messages
            - model: Model forward pass messages
            - shapes: Tensor shape debug messages
            - memory: Memory usage debug messages
            - heights: Height calculation debug messages
            - motion: Motion field debug messages
            - bias: Bias calculation debug messages
            - confidence: Confidence calculation messages
            - finetune: Fine-tuning process messages

    Methods:
        enable_all(): Enables debug output for all components.
        disable_all(): Disables debug output for all components.
        enable_only(*components): Enables debug output only for specified components.
        debug_print(component, *args, **kwargs): Prints debug message if component is enabled.

    Example:
        >>> debug = DebugControl()
        >>> debug.enable_only('memory', 'shapes')
        >>> debug.debug_print('memory', 'Current GPU memory usage: 2GB')
        Current GPU memory usage: 2GB
        >>> debug.debug_print('loss', 'Loss: 0.5')  # Won't print (component disabled)
    """
    def __init__(self):
        self.enabled = False
        self.components = {
            'training': False,      # Training loop messages
            'optical_flow': False,  # Optical flow computation messages
            'loss': False,         # Loss computation messages
            'model': False,        # Model forward pass messages
            'shapes': False,        # Tensor shape debug messages
            'memory': False,        # Memory usage debug messages
            'heights': False,       # height debug messages
            'motion': False,        # motion debug messages
            'bias': False,          # bias debug messages
            'confidence': False,     # confidence debug messages
            'finetune': False       # finetune debug messages
        }

    def enable_all(self):
        self.enabled = True
        for key in self.components:
            self.components[key] = True

    def disable_all(self):
        self.enabled = False
        for key in self.components:
            self.components[key] = False

    def enable_only(self, *components):
        self.enabled = True
        for key in self.components:
            self.components[key] = key in components

    def debug_print(self, component, *args, **kwargs):
        if self.enabled and self.components.get(component, False):
            print(*args, **kwargs)

# Create global debug controller
debug = DebugControl()

# Disable all debug output
debug.disable_all()

# Example usage

# Enable all debug output
# debug.enable_all()

# Enable only specific components
# debug.enable_only('memory', 'shapes')

In [None]:
def embed_matplotlib_jpeg(fig, quality=50, dpi=100, width="100%"):
    """
    Save a Matplotlib figure as a compressed JPEG and embed it in the notebook.
    Since this notebook is heavily image-based, this is to reduce the overall size.
    Parameters:
        fig (matplotlib.figure.Figure): The Matplotlib figure to embed.
        quality (int): Compression quality for the JPEG (1-100, higher is better quality).
        dpi (int): Dots per inch (resolution) for the figure.
        width (str): Width of the image in the notebook (e.g., "100%", "600px").
    Returns:
        HTML object: Embedded image in HTML format.
    """
    # Save the figure to a BytesIO buffer as a raw PNG
    buffer = io.BytesIO()
    fig.savefig(buffer, format='png', dpi=dpi, bbox_inches='tight')
    buffer.seek(0)

    # Open the PNG image with Pillow
    image = Image.open(buffer)

    # Convert the image to RGB if it has an alpha channel (RGBA)
    if image.mode == 'RGBA':
        image = image.convert('RGB')

    # Save the image as a compressed JPEG in a new buffer
    jpeg_buffer = io.BytesIO()
    image.save(jpeg_buffer, format='JPEG', quality=quality)
    jpeg_buffer.seek(0)

    # Encode the compressed JPEG as base64
    encoded = base64.b64encode(jpeg_buffer.read()).decode('utf-8')
    jpeg_buffer.close()

    # Return an HTML <img> tag with a specified width
    return HTML(f'<img src="data:image/jpeg;base64,{encoded}" style="width:{width};" />')

# RAFT Architecture and Memory Usage

In this section we create a number of helper functions to monitor and report memory usage.

RAFT (Recurrent All-Pairs Field Transforms) is an optical flow model that works by building dense correlations between all pixels in a pair of images and iteratively refining flow predictions. It uses a surprising a significant amount of memory.

## Core Components

## Feature Extraction
The first step in RAFT is feature extraction, where each input image is processed through a convolutional encoder network. This encoder transforms each pixel into a rich feature representation containing hundreds of channels of information. Rather than producing a single feature map, RAFT creates a feature pyramid with multiple resolution levels. This pyramid structure allows the model to capture both fine details and broader image context, but it means storing multiple feature maps of varying sizes for each image.

## All-Pairs Correlation
The correlation computation is where RAFT's memory usage truly explodes. For every single pixel in the first image, RAFT computes correlation scores with every possible pixel in the second image. This creates an enormous 4D correlation volume of size H×W×H×W, where H and W are the image height and width. To put this in perspective, for our 384×384 pixel images, the correlation volume alone requires 384⁴ elements. Even using 4-byte floats, that's approximately 55GB of memory needed just to store a single pair of images' correlation volume! This quadratic scaling with image size is the primary reason for RAFT's intense memory requirements.

## Iterative Refinement
The final major component is RAFT's iterative refinement mechanism, which uses a Gated Recurrent Unit (GRU) to progressively update and refine flow estimates. As the GRU processes each pixel's motion, it looks up relevant correlation features based on the current estimated position. This process typically runs for 12-32 iterations, maintaining hidden states throughout. While not as memory-intensive as the correlation volume, these iterations require storing both the hidden states and intermediate results for backpropagation during training.

## Why Sequences Are Memory-Intensive

When processing sequences, memory usage multiplies because:
- Need correlation volumes for each consecutive pair of frames
- Intermediate activations are stored for backpropagation
- GRU hidden states maintained across sequence
- Each additional frame essentially adds another full RAFT computation

For a sequence of length T:
```
Memory ∝ (T-1) × (H²W²)
```

The quadratic scaling with image size (H²W²) makes this particularly challenging. Even with optimizations like limiting correlation range or using lower resolutions, the memory requirements grow rapidly with sequence length.

This is why RAFT, despite having a relatively modest number of parameters (~5-6M), can easily consume gigabytes of GPU memory during training, especially with longer sequences. Each additional frame in the sequence means another massive correlation volume must be computed and stored for backpropagation.

In [None]:
def print_gpu_memory(location=""):
    """Print GPU memory usage with location tag"""
    if torch.cuda.is_available():
        debug.debug_print('memory', f"\nMemory at {location}:")
        debug.debug_print('memory', f"Allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB")
        debug.debug_print('memory', f"Cached: {torch.cuda.memory_reserved()/1e9:.2f} GB\n")

# Example
# >>>print_gpu_memory("Start of finetune_raft")

def format_memory_string(bytes_value):
    """Convert bytes to human readable string"""
    for unit in ['B', 'KB', 'MB', 'GB']:
        if bytes_value < 1024:
            return f"{bytes_value:.2f} {unit}"
        bytes_value /= 1024
    return f"{bytes_value:.2f} TB"

def get_detailed_memory_info():
    """Get detailed GPU memory information"""
    if not torch.cuda.is_available():
        return "GPU not available"

    info = []
    info.append(f"Total GPU memory: {format_memory_string(torch.cuda.get_device_properties(0).total_memory)}")
    info.append(f"Allocated memory: {format_memory_string(torch.cuda.memory_allocated())}")
    info.append(f"Reserved memory: {format_memory_string(torch.cuda.memory_reserved())}")
    info.append(f"Max allocated: {format_memory_string(torch.cuda.max_memory_allocated())}")
    return "\n".join(info)

def get_model_size(model):
    """Calculate total parameters and their memory footprint"""
    total_params = sum(p.numel() for p in model.parameters())
    total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    param_size = sum(p.numel() * p.element_size() for p in model.parameters())
    buffer_size = sum(b.numel() * b.element_size() for b in model.buffers())

    return {
        'total_params': total_params,
        'trainable_params': total_trainable_params,
        'param_memory': format_memory_string(param_size),
        'buffer_memory': format_memory_string(buffer_size),
        'total_memory': format_memory_string(param_size + buffer_size)
    }

def print_oom_report(batch_idx, images, sequence_lengths, model):
    """Print detailed OOM error report"""
    print("\n" + "="*50)
    print(f"OUT OF MEMORY ERROR - Batch {batch_idx}")
    print("="*50)

    print("\nBATCH INFORMATION:")
    print(f"Batch size: {images.shape[0]}")
    print(f"Max sequence length: {images.shape[1]}")
    print(f"Image dimensions: {images.shape[2:]} (C,H,W)")
    print(f"Actual sequence lengths: {sequence_lengths.tolist()}")
    print(f"Total elements in batch: {images.numel():,}")
    print(f"Batch memory: {format_memory_string(images.numel() * images.element_size())}")

    print("\nMEMORY STATE:")
    print(get_detailed_memory_info())

    print("\nMODEL INFORMATION:")
    model_info = get_model_size(model)
    print(f"Total parameters: {model_info['total_params']:,}")
    print(f"Trainable parameters: {model_info['trainable_params']:,}")
    print(f"Parameter memory: {model_info['param_memory']}")
    print(f"Buffer memory: {model_info['buffer_memory']}")
    print(f"Total model memory: {model_info['total_memory']}")

    print("\nRECOMMENDATIONS:")
    if images.shape[0] > 1:
        print("- Consider reducing batch size")
    if max(sequence_lengths) > 5:
        print("- Consider truncating sequence lengths")
    print("- Try enabling gradient checkpointing")
    print("- Consider using mixed precision training")
    print("="*50 + "\n")

In [None]:
# Initialize GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

In [None]:
# Authentication details
account_name = userdata.get('storage_account_name')
account_key = userdata.get('storage_account_key')
container_name = userdata.get('blob_container_name')

# Connection string to Azure Blob Storage
connection_string = f"DefaultEndpointsProtocol=https;AccountName={account_name};AccountKey={account_key};EndpointSuffix=core.windows.net"

# Setup to load file from blob
blob_service_client = BlobServiceClient.from_connection_string(connection_string)
container_client = blob_service_client.get_container_client(container_name)

# Available Aircraft Metadata

## Temporal and Position Data
- **DateTime_UTC**: ISO-8601 formatted date and time in UTC (yyyy-mm-ddThh:mm:ss)
- **Lat**: Platform Latitude (degrees N, -90 to 90)
- **Lon**: Platform Longitude (degrees E, -180 to 179.9999)

## Altitude Measurements
- **GPS_MSL_Alt**: GPS Altitude above Mean Sea Level (meters)
- **WGS_84_Alt**: WGS 84 Geoid Altitude (meters)
- **Press_Alt**: Pressure Altitude (feet)

## Velocity Parameters
- **Grnd_Spd**: Ground Speed (m/s)
- **True_Airspeed**: True Airspeed (m/s)
- **Mach_Number**: Aircraft Mach Number (dimensionless)
- **Vert_Velocity**: Aircraft Vertical Velocity (m/s, negative=downward, positive=upward)

## Aircraft Orientation
- **True_Hdg**: True Heading (degrees true, 0 to 359.9999)
- **Track**: Track Angle (degrees true, 0 to 359.9999)
- **Drift**: Drift Angle (degrees)
- **Pitch**: Pitch (degrees, -90 to 90, negative=nose down, positive=nose up)
- **Roll**: Roll (degrees, -90 to 90, negative=left wing down, positive=right wing down)

## Environmental Conditions
- **Ambient_Temp**: Ambient Temperature (°C)
- **Total_Temp**: Total Temperature (°C)
- **Static_Press**: Static Pressure (millibars)
- **Dynamic_Press**: Dynamic Pressure - total minus static (millibars)
- **Cabin_Pressure**: Cabin Pressure/Altitude (millibars)

## Wind Data
- **Wind_Speed**: Wind Speed (m/s, ≥0)
- **Wind_Dir**: Wind Direction (degrees true, 0 to 359.9999)

## Solar Position
- **Solar_Zenith**: Solar Zenith Angle (degrees)
- **Sun_Elev_AC**: Sun Elevation from Aircraft (degrees)
- **Sun_Az_Grd**: Sun Azimuth from Ground (degrees true, 0 to 359.9999)
- **Sun_Az_AC**: Sun Azimuth from Aircraft (degrees true, 0 to 359.9999)



In [None]:
aircraft_metadata_params = [
    'DateTime_UTC', 'Lat', 'Lon', 'GPS_MSL_Alt', 'WGS_84_Alt', 'Press_Alt',
    'Grnd_Spd', 'True_Airspeed', 'Mach_Number', 'Vert_Velocity', 'True_Hdg',
    'Track', 'Drift', 'Pitch', 'Roll', 'Ambient_Temp', 'Total_Temp',
    'Static_Press', 'Dynamic_Press', 'Cabin_Pressure', 'Wind_Speed',
    'Wind_Dir', 'Solar_Zenith', 'Sun_Elev_AC', 'Sun_Az_Grd', 'Sun_Az_AC'
]

CTH_col = 'top_height'

# Aircraft Metadata
def load_metadata(blob_name):
    blob_client = container_client.get_blob_client(blob_name)
    streamdownloader = blob_client.download_blob()
    metadata_df = pd.read_csv(io.BytesIO(streamdownloader.readall()))
    return metadata_df

# LiDAR Validation Heights
def load_validation_heights(blob_name):
    blob_client = container_client.get_blob_client(blob_name)
    streamdownloader = blob_client.download_blob()
    validation_df = pd.read_csv(io.BytesIO(streamdownloader.readall()))
    return validation_df

# CloudDataset Class

In [None]:
# CloudDataset classes integrating all 3 data sources: FEGS Images, Aircraft Metadata and LiDAR Validation Heights with temporal alignment
class CloudDataset(Dataset):
    def __init__(self, date_folders, transform=None, drop_after_last_validation=True):
        self.date_folders = date_folders
        self.transform = transform
        self.drop_after_last_validation = drop_after_last_validation
        self.data_df = self._prepare_dataframe()

    def _prepare_dataframe(self):
        """
        Iterates over the date folders in azure blob storage and loads:
          1. .jpg Images from each sub-directory in the folder with '_crop_corrected_aligned' in the name.
          2. Aircraft Metadata with 1-1 time alignment with the images.
          3. LiDAR Validation Heights, mapped using timestamp, if not available filled with NaN.
        Creates a df with following columns:
            timestamp, image_path, [...aircraft_metadata_params...], validation_height
        """
        image_paths, timestamps, metadata_rows, validation_heights = [], [], [], []

        for folder in self.date_folders:
            print(f"Processing folder: {folder}")
            folder_image_paths, folder_timestamps, folder_metadata_rows, folder_validation_heights = [], [], [], []

            blob_list = container_client.list_blobs(name_starts_with=folder)
            metadata_path, validation_path = None, None
            for blob in blob_list:
                # extract image paths of all .jpg images in cropped folders
                if blob.name.endswith(".jpg") and "_crop_corrected_aligned" in blob.name:
                    folder_image_paths.append(blob.name)
                    folder_timestamps.append(self._extract_timestamp_from_filename(blob.name))
                # extract the aircraft metadata file path
                if blob.name.startswith(f"{folder}/IWG1.") and "processed" in blob.name:
                    metadata_path = blob.name
                # extract the LiDAR validation file path
                if blob.name.startswith(f"{folder}/goesrplt_CPL_layers_") and blob.name.endswith("_processed.txt"):
                    validation_path = blob.name

            # load aircraft metadata and LiDAR validation data
            if metadata_path:
                metadata_df = load_metadata(metadata_path)
            if validation_path:
                validation_df = load_validation_heights(validation_path)

            # prepare LiDAR validation data
            validation_df['datetime_combined'] = validation_df['date'] + ' ' + validation_df['timestamp']
            validation_df['datetime_combined'] = validation_df['datetime_combined'].str.split('.').str[0]
            validation_df['datetime_combined'] = pd.to_datetime(validation_df['datetime_combined'], format="%Y-%m-%d %H:%M:%S")

            # prepare aircraft metadata
            metadata_df = metadata_df[aircraft_metadata_params]
            metadata_df['DateTime_UTC'] = metadata_df['DateTime_UTC'].str.split('.').str[0]
            metadata_df = self._extract_time_features(metadata_df)  # Add hour_of_day and day_of_year

            metadata_timestamps = pd.to_datetime(metadata_df['DateTime_UTC'], format="%Y-%m-%d %H:%M:%S")
            metadata_df = metadata_df.set_index(metadata_timestamps)
            aligned_metadata = pd.DataFrame(index=pd.to_datetime(folder_timestamps, format="%H:%M:%S"))
            aligned_metadata = aligned_metadata.join(metadata_df, how='left')
            aligned_metadata = aligned_metadata[aircraft_metadata_params + ['hour_of_day', 'day_of_year']]

            folder_metadata_rows.extend(aligned_metadata.values.tolist())

            # extract LiDAR validation height exactly matching the timestamp where available, else NaN
            for ts in folder_timestamps:
                cth = self._map_timestamp_to_lidar(ts, validation_df)
                folder_validation_heights.append(cth)

            # Create a folder-level DataFrame
            folder_data = {
                'timestamp': folder_timestamps,
                'image_path': folder_image_paths,
                **{param: [row[i] for row in folder_metadata_rows] for i, param in enumerate(aircraft_metadata_params + ['hour_of_day', 'day_of_year'])},
                'validation_height': folder_validation_heights
            }
            folder_df = pd.DataFrame(folder_data)

            # Conditionally remove rows after the last valid validation_height in this folder
            if self.drop_after_last_validation:
                last_valid_index = folder_df['validation_height'].last_valid_index()
                if last_valid_index is not None:
                    folder_df_cleaned = folder_df.loc[:last_valid_index].copy()  # Use .copy() to ensure independence
                else:
                    folder_df_cleaned = folder_df.copy()  # In case there are no valid entries
            else:
                folder_df_cleaned = folder_df  # Keep all rows if dropping is disabled

            # Extend to the global lists
            image_paths.extend(folder_df_cleaned['image_path'].tolist())
            timestamps.extend(folder_df_cleaned['timestamp'].tolist())
            metadata_rows.extend(folder_df_cleaned[aircraft_metadata_params + ['hour_of_day', 'day_of_year']].values.tolist())
            validation_heights.extend(folder_df_cleaned['validation_height'].tolist())

            # Print the lengths for the current folder
            print(f"Folder {folder}:")
            print(f"  Number of images: {len(folder_df_cleaned['image_path'])}")
            print(f"  Number of timestamps: {len(folder_df_cleaned['timestamp'])}")
            print(f"  Number of metadata rows: {len(folder_df_cleaned)}")
            print(f"  Number of validation heights: {len(folder_df_cleaned['validation_height'])}")

        # Print the final lengths after processing all folders
        print("After processing all folders combined:")
        print(f"  Total number of images: {len(image_paths)}")
        print(f"  Total number of timestamps: {len(timestamps)}")
        print(f"  Total number of metadata rows: {len(metadata_rows)}")
        print(f"  Total number of validation heights: {len(validation_heights)}")

        # Check for any mismatches
        if not (len(image_paths) == len(timestamps) == len(metadata_rows) == len(validation_heights)):
            print("Error: Length mismatch detected!")
            print(f"  Images: {len(image_paths)}")
            print(f"  Timestamps: {len(timestamps)}")
            print(f"  Metadata rows: {len(metadata_rows)}")
            print(f"  Validation heights: {len(validation_heights)}")
            return None

        # combine all aligned data in a df
        data = {
            'timestamp': timestamps,
            'image_path': image_paths,
            **{param: [row[i] for row in metadata_rows] for i, param in enumerate(aircraft_metadata_params + ['hour_of_day', 'day_of_year'])},
            'validation_height': validation_heights
        }
        df = pd.DataFrame(data)
        df = df.drop(columns=['DateTime_UTC'])

        # Print columns with NaN values before interpolation
        self._print_columns_with_nan(df)

        # Interpolate missing values, excluding 'validation_height'
        df = self._interpolate_missing_values(df)

        # Add sequence length information for RNN
        self._add_sequence_length_column(df)
        return df

    def _extract_time_features(self, df):
        """
        Extracts hour of day and day of year from the DateTime_UTC column.
        Adds 'hour_of_day' (with fractional hour) and 'day_of_year' as new columns in the DataFrame.
        """
        df['DateTime_UTC'] = pd.to_datetime(df['DateTime_UTC'], format="%Y-%m-%d %H:%M:%S")
        df['hour_of_day'] = df['DateTime_UTC'].dt.hour + df['DateTime_UTC'].dt.minute / 60
        df['day_of_year'] = df['DateTime_UTC'].dt.dayofyear
        return df

    def _extract_timestamp_from_filename(self, filename):
        """
        Extracts the timestamp from the image filename on the blob.
        path/to/blob/YYYYMMDD_HHMMSS_frame_n_cropped.jpg -> %Y%m%d%H%M%S
        """
        filename = os.path.basename(filename)
        date_str = filename.split("_")[0]
        time_str = filename.split("_")[1]
        timestamp = datetime.strptime(date_str + time_str, "%Y%m%d%H%M%S")
        return timestamp


    def _map_timestamp_to_lidar(self, timestamp, validation_df):
        """
        extract LiDAR validation height exactly matching the timestamp where available, else NaN
        """
        validation_df['datetime_combined'] = pd.to_datetime(validation_df['datetime_combined'], format="%Y-%m-%d %H:%M:%S")
        timestamp_dt = pd.to_datetime(timestamp, format="%Y-%m-%d %H:%M:%S")
        exact_match = validation_df[validation_df['datetime_combined'] == timestamp_dt]
        return exact_match[CTH_col].values[0] if not exact_match.empty else np.nan


    def __len__(self):
        return len(self.data_df)

    def __getitem__(self, idx):
        """
        Retrieve a data record from the dataset for a given index.
        Returns loaded and transformed image, metadata and validation height associated with that image if available.
        """
        row = self.data_df.iloc[idx]

        image_path = row['image_path']
        blob_client = container_client.get_blob_client(image_path)
        streamdownloader = blob_client.download_blob()
        img_data = streamdownloader.readall()
        img = Image.open(io.BytesIO(img_data)).convert("RGB")
        if self.transform:
            img = self.transform(img)

        metadata = torch.tensor([row[param] for param in aircraft_metadata_params], dtype=torch.float32)
        validation_height = torch.tensor([row['validation_height']], dtype=torch.float32)

        return img, metadata, validation_height

    def _add_sequence_length_column(self, df):
        # Initialize a new column to NaN
        df['sequence_length'] = np.nan

        # Track the start of each sequence
        sequence_start = 0

        # Iterate through the dataframe to detect when a validation_height exists
        for i in range(len(df)):
            if not pd.isna(df.loc[i, 'validation_height']):
                # We found the end of a sequence, so mark the previous sequence images
                sequence_length = i - sequence_start
                df.loc[sequence_start:i, 'sequence_length'] = sequence_length + 1  # Using count (1-based index)
                sequence_start = i + 1  # Move the start to the next sequence

        # Ensure all sequence_length values are integers (if any were missed)
        df['sequence_length'] = df['sequence_length'].astype(int)

        return df

    def _print_columns_with_nan(self, df):
        """
        Prints the names of columns in the DataFrame that contain NaN values.
        """
        columns_with_nan = df.columns[df.isna().any()].tolist()
        if columns_with_nan:
            print("Columns with NaN values:")
            for col in columns_with_nan:
                print(col)
        else:
            print("No columns with NaN values.")

    def _interpolate_missing_values(self, df, exclude_columns=['validation_height', 'timestamp', 'image_path']):
        """
        Interpolates missing values in the DataFrame for all columns except specified ones.
        """
        # Create a copy to avoid modifying the original DataFrame
        df = df.copy()

        # Store excluded columns
        excluded_data = {col: df[col].copy() for col in exclude_columns if col in df.columns}

        # Convert all object-type columns in df to numeric where possible
        df = df.infer_objects()

        # Get columns for interpolation (excluding specified columns)
        columns_to_interpolate = [col for col in df.columns if col not in exclude_columns]

        # Convert columns to numeric where possible
        for col in columns_to_interpolate:
            try:
                df[col] = pd.to_numeric(df[col], errors='coerce')
            except (ValueError, TypeError):
                continue

        # Perform interpolation only on numeric columns
        numeric_columns = df.select_dtypes(include=['float64', 'int64']).columns
        if not numeric_columns.empty:
            df[numeric_columns] = df[numeric_columns].interpolate(
                method='linear',
                axis=0,
                limit_direction='both'
            )

        # Restore excluded columns
        for col, data in excluded_data.items():
            df[col] = data

        return df

# Create the CloudDataset or retrieve it from cache

In [None]:
transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
])

augmentations = transforms.Compose([
])

# Change the cache_cloud_dataset variable at the start to make this either
# generate the CloudDataset (takes some time) or use a cached version on Google Drive
if cache_cloud_dataset:
    full_dataframe = pd.read_csv('/content/drive/My Drive/datasets/train_dataset.csv')
else:
    # Gather Data Dynamically from Azure
    # Create the full CloudDataset without splitting initially
    full_dataset = CloudDataset(train_dates, transform=None)
    # Extract the full dataframe from CloudDataset
    full_dataframe = full_dataset.data_df
    # Ensure folder exists
    os.makedirs('/content/drive/My Drive/datasets', exist_ok=True)
    # Save the dataframe directly to Google Drive
    full_dataframe.to_csv('/content/drive/My Drive/datasets/train_dataset.csv', index=False, header=True)

In [None]:
full_dataframe.to_csv('train_dataset.csv', index=False, header=True)

In [None]:
full_dataframe.head(20)

In [None]:
full_dataframe.tail(20)

# Random Scaling and Cropping Strategy

## Motivation
The goal behind this data augmentation strategy is to help the model become more robust at estimating cloud heights when the LiDAR measurement point isn't always in the center of the image. This is important because:

1. **Original Data Limitation**
   - LiDAR measurements are always taken at the exact center of each image
   - This could cause the model to overfit to center-focused features
   - Real-world applications may need height estimates from different parts of the image

2. **Spatial Understanding**
   - Want the model to understand cloud height features regardless of their position
   - Need to learn relationships between cloud features across the entire image
   - Important for generalizing to different viewing angles and positions

## Implementation Strategy

### 1. Random Cropping (random_crop_sequence)
$$
\begin{align*}
& \textbf{Algorithm: Random Crop Sequence} \\
& \textbf{Input: } \text{image_sequence, min_crop_size = 384} \\
& \textbf{Output: } \text{cropped_sequence, new_center_coords} \\
& \hline \\
& \text{orig_height, orig_width} \gets \text{image_size} \\
& \text{crop_size} \gets \text{random_int(min_crop_size, min(orig_height, orig_width))} \\
& \text{top} \gets \text{random_int(0, orig_height - crop_size)} \\
& \text{left} \gets \text{random_int(0, orig_width - crop_size)} \\
& \text{orig_center} \gets \text{(orig_width}/2\text{, orig_height}/2\text{)} \\
& \text{new_center} \gets \text{(orig_center.x - left, orig_center.y - top)} \\
& \text{cropped_sequence} \gets \text{crop_images(sequence, top, left, crop_size)}
\end{align*}
$$

Key Features:
- Random crop size between minimum (384) and image dimension
- Random position for the crop window
- Maintains aspect ratio (square crop)
- Original center point is transformed to new coordinates

### 2. Resizing (resize_sequence_and_adjust_center)
$$
\begin{align*}
& \textbf{Algorithm: Resize and Adjust Center} \\
& \textbf{Input: } \text{cropped_sequence, center_coords, target_size = 384} \\
& \textbf{Output: } \text{resized_sequence, adjusted_center_coords} \\
& \hline \\
& \text{scale_factor} \gets \text{target_size}/\text{crop_size} \\
& \text{new_center.x} \gets \text{center_coords.x} \times \text{scale_factor} \\
& \text{new_center.y} \gets \text{center_coords.y} \times \text{scale_factor} \\
& \text{resized_sequence} \gets \text{resize_images(cropped_sequence, target_size)}
\end{align*}
$$

## Benefits of This Approach

1. **Data Augmentation**
   - Effectively multiplies training data
   - Creates variations in LiDAR measurement position
   - Introduces scale variation while preserving aspect ratios

2. **Model Robustness**
   - Forces model to look at cloud features everywhere in image
   - Prevents overfitting to center-specific patterns
   - Better generalization to different viewing conditions

3. **Training Stability**
   - Consistent final image size (384x384)
   - Maintains original aspect ratios
   - Preserves relative spatial relationships

4. **Validation Strategy**
   - During validation, no random cropping is applied
   - Center point remains fixed for validation
   - Allows fair comparison with ground truth

# Random Scaling and Cropping Example

Original Image:
- Size: 800 x 800
- LiDAR measurement point: (400, 400) [exact center]

Let's say our random crop generates:
- Crop size: 500 x 500 (randomly chosen between 384 and 800)
- Top-left position: (220, 150) [randomly chosen to fit the 500x500 crop]

Step 1: Crop Transformation
```
Original center: (400, 400)
Crop offset: (220, 150)
New center = Original center - Crop offset
New center = (400-220, 400-150) = (180, 250)
```
So in our 500x500 cropped image, the LiDAR point is now at (180, 250) - notably off-center

Step 2: Resize to 384x384
```
Scale factor = 384/500 = 0.768
Final center = (180 * 0.768, 250 * 0.768) = (138, 192)
```

Final Result:
- Image size: 384 x 384
- LiDAR point: (138, 192)
- Original center point has moved significantly off-center
  - Horizontally: shifted from center by 54 pixels (384/2 - 138 = 54)
  - Vertically: shifted from center by 0 pixels (384/2 - 192 = 0)

This asymmetric transformation better illustrates how the cropping strategy creates variations in the LiDAR measurement position, forcing the model to learn to estimate cloud heights from features across the entire image rather than just the center region.

Here's a visual example of the process:

In [1]:
%%html
<svg width="1000" height="400" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1000 400">
    <!-- Original 800x800 Image -->
    <rect x="20" y="20" width="320" height="320" fill="none" stroke="#666666" stroke-width="2"/>
    <text x="20" y="15" font-family="Arial" font-size="14">Original Image (800 x 800)</text>
    <circle cx="180" cy="180" r="4" fill="red"/>
    <text x="140" y="150" font-family="Arial" font-size="12" fill="red">LiDAR Point (400, 400)</text>

    <!-- Crop Box -->
    <rect x="108" y="80" width="200" height="200" fill="rgba(0, 102, 204, 0.1)" stroke="#0066cc" stroke-width="2"/>
    <text x="108" y="75" font-family="Arial" font-size="14" fill="#0066cc">Crop Region (500 x 500)</text>
    <text x="108" y="300" font-family="Arial" font-size="12" fill="#0066cc">Starting at (220, 150)</text>

    <!-- Arrow -->
    <path d="M 370 180 L 430 180" stroke="black" stroke-width="2" marker-end="url(#arrowhead)"/>

    <!-- Cropped Image -->
    <rect x="460" y="20" width="200" height="200" fill="none" stroke="#0066cc" stroke-width="2"/>
    <text x="460" y="15" font-family="Arial" font-size="14">Cropped Image (500 x 500)</text>
    <circle cx="532" cy="130" r="4" fill="red"/>
    <text x="492" y="100" font-family="Arial" font-size="12" fill="red">LiDAR Point (180, 250)</text>

    <!-- Arrow -->
    <path d="M 690 180 L 750 180" stroke="black" stroke-width="2" marker-end="url(#arrowhead)"/>

    <!-- Final Resized Image -->
    <rect x="780" y="58" width="154" height="154" fill="none" stroke="#009933" stroke-width="2"/>
    <text x="780" y="53" font-family="Arial" font-size="14">Resized (384 x 384)</text>
    <circle cx="836" cy="147" r="4" fill="red"/>
    <text x="796" y="117" font-family="Arial" font-size="12" fill="red">LiDAR Point (138, 192)</text>

    <!-- Arrowhead definition -->
    <defs>
        <marker id="arrowhead" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
            <polygon points="0 0, 10 3.5, 0 7" fill="black"/>
        </marker>
    </defs>

    <!-- Grid overlay on middle image to show scale -->
    <g stroke="rgba(128,128,128,0.2)">
        <line x1="460" y1="70" x2="660" y2="70"/>
        <line x1="460" y1="120" x2="660" y2="120"/>
        <line x1="460" y1="170" x2="660" y2="170"/>

        <line x1="510" y1="20" x2="510" y2="220"/>
        <line x1="560" y1="20" x2="560" y2="220"/>
        <line x1="610" y1="20" x2="610" y2="220"/>
    </g>
</svg>

# Image loading and augmentation

In [None]:
def load_image_from_blob_cv(blob_img):
    """
    Loads the image from the Azure Blob Storage using OpenCV and returns it as a numpy array.
    Args:
        blob_img (str): name of the blob image in the container
    Returns:
        (numpy.ndarray): loaded image in greyscale
    """
    blob_client = container_client.get_blob_client(blob_img)
    streamdownloader = blob_client.download_blob()
    blob_data = streamdownloader.readall()
    image_array = np.asarray(bytearray(blob_data), dtype=np.uint8)
    img_bgr = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    return img_rgb

def augment_color_image(img, contrast_factor=1.2, brightness_beta=10, kernel_size=(5, 5), blur=True):
    """
    Augments the color image to enhance cloud features. Includes:
      1. Contrast and Brightness Adjustment
      2. Gaussian Blur (optional)
      3. CLAHE (Contrast Limited Adaptive Histogram Equalization)
      4. Sharpening (optional)

    Args:
        img (numpy.ndarray): color image in RGB format
        contrast_factor (float): contrast factor
        brightness_beta (int): brightness factor
        kernel_size (tuple): kernel size for Gaussian Blur
        blur (bool): whether to apply Gaussian Blur

    Returns:
        (numpy.ndarray): augmented image
    """
    # Step 1: Adjust contrast and brightness
    img_enhanced = cv2.convertScaleAbs(img, alpha=contrast_factor, beta=brightness_beta)

    # Step 2: Optionally apply Gaussian Blur
    if blur:
        img_enhanced = cv2.GaussianBlur(img_enhanced, kernel_size, 0)

    # Step 3: Convert to LAB color space to apply CLAHE on the L channel
    img_lab = cv2.cvtColor(img_enhanced, cv2.COLOR_RGB2LAB)
    l, a, b = cv2.split(img_lab)

    # Step 4: Apply CLAHE (Clip Limit Adaptive Histogram Equalization)
    clahe = cv2.createCLAHE(clipLimit=1.5, tileGridSize=(4, 4))
    cl = clahe.apply(l)

    # Step 5: Merge CLAHE-enhanced L channel back with A and B channels
    img_clahe = cv2.merge((cl, a, b))

    # Step 6: Convert back to RGB color space
    img_clahe_rgb = cv2.cvtColor(img_clahe, cv2.COLOR_LAB2RGB)

    # Step 7: Optionally apply sharpening
    kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
    img_sharpened = cv2.filter2D(img_clahe_rgb, -1, kernel)

    return img_sharpened

# Modify augment_color_image to use the original CLAHE-only method
def augment_color_image_clahe_only(img, clip_limit=1.5, tile_grid_size=(4, 4)):
    """
    Augments the color image by applying CLAHE (Contrast Limited Adaptive Histogram Equalization) only.

    Args:
        img (numpy.ndarray): color image in RGB format
        clip_limit (float): clip limit for CLAHE (higher values give more contrast)
        tile_grid_size (tuple): size of the grid for applying CLAHE

    Returns:
        (numpy.ndarray): augmented image with CLAHE applied to the L channel in LAB color space
    """
    # Convert to LAB color space to apply CLAHE on the L channel (lightness)
    img_lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
    l, a, b = cv2.split(img_lab)

    # Apply CLAHE to the L channel
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    cl = clahe.apply(l)

    # Merge CLAHE-enhanced L channel back with A and B channels
    img_clahe = cv2.merge((cl, a, b))

    # Convert back to RGB color space
    img_clahe_rgb = cv2.cvtColor(img_clahe, cv2.COLOR_LAB2RGB)

    return img_clahe_rgb

def augment_greyscale_image(img, contrast_factor=1.5, brightness_beta=30, kernel_size=(5, 5), blur=True):
    """
    Augments the Greyscale image to enhance the feature of the cloud. Includes:
      1. Contrast and Brightness
      2. Gaussian Blur
      3. Histogram Equalization
      4. Sharpening
    Args:
        img (numpy.ndarray): greyscale image
        contrast_factor (float): contrast factor
        brightness_beta (int): brightness factor
        kernel_size (tuple): kernel size for Gaussian Blur
        blur (bool): whether to apply Gaussian Blur
    Returns:
        (numpy.ndarray): augmented image
    """
    img_enhanced = cv2.convertScaleAbs(img, alpha=contrast_factor, beta=brightness_beta)
    if blur:
        img_enhanced = cv2.GaussianBlur(img_enhanced, kernel_size, 0)
    img_enhanced = cv2.equalizeHist(img_enhanced)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    img_enhanced = clahe.apply(img_enhanced)
    kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
    return cv2.filter2D(img_enhanced, -1, kernel)

def visualize_augmentations_cv(image_list, enhance=True):
    """
    Visualizes original and augmented greyscale images using OpenCV.
    Args:
        image_list (list): List of image paths in Azure Blob Storage
        enhance (bool): Whether to apply image augmentation
    Returns:
        None
    """
    fig, axes = plt.subplots(len(image_list), 3, figsize=(15, 5 * len(image_list)))

    for idx, blob_img in enumerate(image_list):
        original_image = load_image_from_blob_cv(blob_img)
        greyscale_image = cv2.cvtColor(original_image, cv2.COLOR_RGB2GRAY)
        augmented_image = greyscale_image.copy()

        if enhance:
            augmented_image = augment_greyscale_image(augmented_image, contrast_factor=1.5, brightness_beta=30, kernel_size=(5, 5))

        axes[idx, 0].imshow(original_image)
        axes[idx, 0].set_title("Original Image (RGB)")
        axes[idx, 0].axis('off')

        axes[idx, 1].imshow(greyscale_image, cmap='gray')
        axes[idx, 1].set_title("Greyscale Image")
        axes[idx, 1].axis('off')

        axes[idx, 2].imshow(augmented_image, cmap='gray')
        axes[idx, 2].set_title("Augmented Image")
        axes[idx, 2].axis('off')

    plt.tight_layout()
    plt.show()

def random_crop_sequence(image_sequence, crop_size_minimum=224):
    """
    Apply identical random crop to a sequence of images and ensure randomness in cropping location.
    The crop size is randomly selected between the minimum size (crop_size_minimum) and the full image size.

    Args:
        image_sequence: A list or tensor of images in the sequence (each image should be the same size).
        crop_size_minimum: The minimum crop size (default is 224).

    Returns:
        cropped_sequence: The cropped sequence of images.
        new_center_coords: The new coordinates of the center pixel after cropping.
    """
    # Get the height and width of the original images
    orig_height, orig_width = image_sequence[0].shape[-2:]

    # Randomly select the crop size between the minimum and full image size
    max_crop_size = min(orig_height, orig_width)
    crop_h = random.randint(crop_size_minimum, max_crop_size)
    crop_w = crop_h  # Keeping the crop square for simplicity

    # Ensure that the crop size is smaller than the original image
    assert crop_h <= orig_height and crop_w <= orig_width, "Crop size must be smaller than the original image size."

    # Original center coordinates of the image
    orig_center_x = orig_width // 2
    orig_center_y = orig_height // 2

    # Randomly select the top-left corner for the crop
    top = random.randint(0, orig_height - crop_h)
    left = random.randint(0, orig_width - crop_w)

    # Apply the same crop to each image in the sequence
    cropped_sequence = [TF.crop(img, top, left, crop_h, crop_w) for img in image_sequence]

    # Calculate the new center coordinates based on the original center relative to the cropped region
    new_center_x = orig_center_x - left  # Adjust the original center x based on the crop left offset
    new_center_y = orig_center_y - top   # Adjust the original center y based on the crop top offset

    # Ensure the new center coordinates are still within the cropped image bounds
    new_center_x = max(0, min(new_center_x, crop_w - 1))
    new_center_y = max(0, min(new_center_y, crop_h - 1))

    return cropped_sequence, (new_center_x, new_center_y)

def resize_sequence_and_adjust_center(cropped_sequence, new_center_coords, target_size=(224, 224)):
    """
    Resize the cropped sequence to a target size and adjust the center coordinates accordingly.

    Args:
        cropped_sequence: The list of cropped images.
        new_center_coords: The (x, y) coordinates of the center in the cropped image.
        target_size: The desired output size (height, width), default is (224, 224).

    Returns:
        resized_sequence: The resized sequence of images.
        resized_center_coords: The adjusted (x, y) coordinates of the center in the resized images.
    """
    crop_h, crop_w = cropped_sequence[0].shape[-2:]  # Get height and width of the cropped image
    target_h, target_w = target_size  # Desired target size (e.g., 224x224)

    # Calculate the scaling factors for height and width
    scale_x = target_w / crop_w
    scale_y = target_h / crop_h

    # Resize each image in the sequence to the target size
    resized_sequence = [TF.resize(img, size=target_size) for img in cropped_sequence]

    # Adjust the center coordinates according to the scaling factor
    new_center_x, new_center_y = new_center_coords
    resized_center_x = int(new_center_x * scale_x)
    resized_center_y = int(new_center_y * scale_y)

    return resized_sequence, (resized_center_x, resized_center_y)

In [None]:
images = [
    "20170418/170418_175706_183328_frames/20170418_175706_frame_0.jpg",
    "20170418/170418_175706_183328_frames/20170418_175707_frame_60.jpg",
    "20170418/170418_175706_183328_frames_tight_crop/20170418_175708_frame_120_cropped.jpg",
    "20170418/170418_175706_183328_frames_tight_crop/20170418_175709_frame_180_cropped.jpg",
    "20170418/170418_175706_183328_frames_tight_crop/20170418_175710_frame_240_cropped.jpg",
    "20170418/170418_175706_183328_frames_crop_corrected_aligned/20170418_175708_frame_120.jpg",
    "20170418/170418_175706_183328_frames_crop_corrected_aligned/20170418_175709_frame_180.jpg",
    "20170418/170418_175706_183328_frames_crop_corrected_aligned/20170418_175710_frame_240.jpg",
]

visualize_augmentations_cv(images, True)

# FlightMetadataManager Design

The FlightMetadataManager is designed to handle the complex array of sensor and flight data collected during NASA's ER-2 missions. Here's why different types of metadata are handled in specific ways:

## 1. Metadata Categorization

The metadata is categorized into distinct groups based on their physical meaning and units:

### Geographic Parameters
- `Lat`, `Lon`: Geographic coordinates
- `GPS_MSL_Alt`, `WGS_84_Alt`, `Press_Alt`: Different altitude measurements
- **Motivation**: These parameters require consistent handling as they're used for spatial positioning and often need to be used together for accurate location determination.

### Aircraft Motion Parameters
- `Grnd_Spd`, `True_Airspeed`, `Mach_Number`: Speed measurements
- `Vert_Velocity`: Vertical movement
- `True_Hdg`, `Track`, `Drift`, `Pitch`, `Roll`: Aircraft orientation
- **Motivation**: These parameters describe the aircraft's motion and orientation, which affect image capture conditions and need to be considered when analyzing cloud heights.

### Environmental Parameters
- `Ambient_Temp`, `Total_Temp`: Temperature readings
- `Static_Press`, `Dynamic_Press`, `Cabin_Pressure`: Pressure measurements
- `Wind_Speed`, `Wind_Dir`: Wind conditions
- **Motivation**: Environmental conditions can affect both the aircraft's performance and the cloud formation being studied.

### Solar Parameters
- `Solar_Zenith`, `Sun_Elev_AC`, `Sun_Az_Grd`, `Sun_Az_AC`: Sun position relative to aircraft
- **Motivation**: Helps in understanding lighting conditions which affect image quality and interpretation.

## 2. Special Handling Features

### Temporal Feature Engineering
The system employs cyclic encoding for temporal features through trigonometric transformations:

$$
\begin{align*}
& \text{hour}_{\sin} = \sin(2\pi \cdot \text{hour}/24) \\
& \text{hour}_{\cos} = \cos(2\pi \cdot \text{hour}/24) \\
& \text{day}_{\sin} = \sin(2\pi \cdot \text{day}/365) \\
& \text{day}_{\cos} = \cos(2\pi \cdot \text{day}/365)
\end{align*}
$$

**Motivation**:
- Preserves cyclic nature of time (e.g., hour 23 is close to hour 0)
- Natural range of [-1, 1] eliminates need for additional normalization
- Smooth transitions between time periods
- Better representation for neural network processing

### Normalization Strategy

The system applies different normalization strategies based on parameter type:

$$
\begin{align*}
& \textbf{For standard parameters:} \\\\
& x_{\text{normalized}} = \frac{x - x_{\min}}{x_{\max} - x_{\min}} \\\\
& \textbf{For angular parameters:} \\\\
& \theta_{\text{normalized}} = \begin{cases}
\sin(\theta) & \text{for sine component} \\\\
\cos(\theta) & \text{for cosine component}
\end{cases} \\\\
& \textbf{For pre-normalized parameters:} \\\\
& x_{\text{final}} = x_{\text{original}}
\end{align*}
$$

## 3. Parameter Organization

### Feature Type Grouping

The system organizes parameters into hierarchical categories:

1. Primary Features
   $$\mathbf{F}_{\text{primary}} = \{\mathbf{F}_{\text{flight}}, \mathbf{F}_{\text{temporal}}, \mathbf{F}_{\text{target}}\}$$

2. Flight Data Subset
   $$\mathbf{F}_{\text{flight}} = \{\mathbf{P}_{\text{geographic}}, \mathbf{P}_{\text{motion}}, \mathbf{P}_{\text{environmental}}, \mathbf{P}_{\text{solar}}\}$$

3. Parameter Vectors
   $$\mathbf{P}_{\text{geographic}} = [{\text{Lat}, \text{Lon}, \text{Alt}_{\text{GPS}}, \text{Alt}_{\text{WGS}}, \text{Alt}_{\text{Press}}}]$$
   $$\mathbf{P}_{\text{motion}} = [\text{Speed}_{\text{ground}}, \text{Speed}_{\text{air}}, \text{Mach}, \text{Velocity}_{\text{vert}}, ...]$$

### Index Management

The system maintains a strict ordering system:

1. Base Index Structure:
   $$I(p) : \mathbf{F}_{\text{all}} \rightarrow \mathbb{N}_0$$
   
2. Parameter Access:
   $$p_i = I^{-1}(i) \text{ where } i \in [0, |\mathbf{F}_{\text{all}}| - 1]$$

This organization ensures:
- Consistent parameter ordering across batches
- O(1) parameter lookup time
- Maintainable structure for feature additions
- Clear separation of concerns for different parameter types

## 4. Key Benefits

1. **Data Integrity**
   - Consistent handling of related parameters
   - Proper normalization based on parameter type
   - Preservation of meaningful relationships between parameters

2. **Model Performance**
   - Well-organized features for neural network input
   - Properly scaled values for better training
   - Efficient batch processing

3. **Maintenance and Extensibility**
   - Clear structure for adding new parameters
   - Easy modification of normalization strategies
   - Simple parameter grouping updates

4. **Error Prevention**
   - Type checking for parameters
   - Validation of normalization requirements
   - Consistent parameter ordering

In [None]:
class FlightMetadataManager:
    def __init__(self):
        # Define all metadata parameters with their properties
        self.metadata_config = {
            # Temporal parameters (non-normalized)
            'timestamp': {'normalize': False, 'type': 'temporal'},

            # Height validation (normalized)
            'validation_height': {'normalize': True, 'type': 'height'},

            # Geographic coordinates (normalized)
            'Lat': {'normalize': True, 'type': 'coordinate'},
            'Lon': {'normalize': True, 'type': 'coordinate'},
            'GPS_MSL_Alt': {'normalize': True, 'type': 'altitude'},
            'WGS_84_Alt': {'normalize': True, 'type': 'altitude'},
            'Press_Alt': {'normalize': True, 'type': 'altitude'},

            # Speed parameters (normalized)
            'Grnd_Spd': {'normalize': True, 'type': 'speed'},
            'True_Airspeed': {'normalize': True, 'type': 'speed'},
            'Mach_Number': {'normalize': True, 'type': 'dimensionless'},
            'Vert_Velocity': {'normalize': True, 'type': 'speed'},

            # Angular parameters (normalized)
            'True_Hdg': {'normalize': True, 'type': 'angle'},
            'Track': {'normalize': True, 'type': 'angle'},
            'Drift': {'normalize': True, 'type': 'angle'},
            'Pitch': {'normalize': True, 'type': 'angle'},
            'Roll': {'normalize': True, 'type': 'angle'},

            # Temperature parameters (normalized)
            'Ambient_Temp': {'normalize': True, 'type': 'temperature'},
            'Total_Temp': {'normalize': True, 'type': 'temperature'},

            # Pressure parameters (normalized)
            'Static_Press': {'normalize': True, 'type': 'pressure'},
            'Dynamic_Press': {'normalize': True, 'type': 'pressure'},
            'Cabin_Pressure': {'normalize': True, 'type': 'pressure'},

            # Wind parameters (normalized)
            'Wind_Speed': {'normalize': True, 'type': 'speed'},
            'Wind_Dir': {'normalize': True, 'type': 'angle'},

            # Solar parameters (normalized)
            'Solar_Zenith': {'normalize': True, 'type': 'angle'},
            'Sun_Elev_AC': {'normalize': True, 'type': 'angle'},
            'Sun_Az_Grd': {'normalize': True, 'type': 'angle'},
            'Sun_Az_AC': {'normalize': True, 'type': 'angle'},

            # Encoded temporal features (pre-normalized)
            'hour_sin': {'normalize': False, 'type': 'encoded_temporal'},
            'hour_cos': {'normalize': False, 'type': 'encoded_temporal'},
            'day_sin': {'normalize': False, 'type': 'encoded_temporal'},
            'day_cos': {'normalize': False, 'type': 'encoded_temporal'}
        }

        # Create ordered lists of parameters for different purposes
        self.flight_data = [param for param, config in self.metadata_config.items()
                                if config['type'] != 'temporal' and param != 'validation_height']

        self.temporal_features = [param for param, config in self.metadata_config.items()
                                if config['type'] == 'temporal']

        # All input features (flight data + temporal)
        self.all_features = self.flight_data + self.temporal_features

        # Create indices mapping for quick lookup
        self.param_indices = {param: idx for idx, param in enumerate(self.all_features)}

        # All columns that need normalization (including validation_height)
        self.columns_to_normalize = [param for param, config in self.metadata_config.items()
                                        if config['normalize']]

    def get_index(self, param_name):
        """Get the index of a parameter in the normalized feature list."""
        return self.param_indices[param_name]

    def get_feature_types(self):
        """Get flight data and temporal features as separate lists.

        Returns:
            tuple: (flight_data, temporal_features) where
                flight_data: List of normalized flight parameters
                temporal_features: List of temporal features
        """
        return self.flight_data, self.temporal_features

    def print_feature_indices(self):
      """Print all features and their corresponding indices in a formatted table."""
      print("\nFlight Data Features:")
      print("-" * 50)
      for param in self.flight_data:
          print(f"{self.param_indices[param]:2d}: {param}")

      print("\nTemporal Features:")
      print("-" * 50)
      for param in self.temporal_features:
          print(f"{self.param_indices[param]:2d}: {param}")

metadata_manager = FlightMetadataManager()
metadata_manager.print_feature_indices()

# Cloud2CloudDataset Design

## Purpose and Core Functionality

The Cloud2CloudDataset class implements a custom PyTorch Dataset that manages three synchronized data streams:
1. FEGS HD camera images
2. Aircraft metadata
3. LiDAR validation height measurements

## Dataset Initialization

### Input Parameters
- `date_folders`: List of target flight dates
- `normalization_params`: Optional pre-computed normalization parameters
- `transform`: Image transformations pipeline
- `augmentations`: Optional additional image augmentations
- `apply_normalization`: Boolean flag for data normalization
- `apply_crop_and_scale`: Boolean flag for random cropping
- `resize_size`: Target image size (default 384)

### Key Components

1. **Temporal Feature Engineering**
   
   Cyclic encoding of time features:
  $$
  \begin{align*}
  & \text{hour_sin} = \sin(2\pi \cdot \text{hour}/24) \\
  & \text{hour_cos} = \cos(2\pi \cdot \text{hour}/24) \\
  & \text{day_sin} = \sin(2\pi \cdot \text{day}/365) \\
  & \text{day_cos} = \cos(2\pi \cdot \text{day}/365)
  \end{align*}
  $$

2. **Data Normalization**

   For each numerical column:
  $$
  x_{\text{normalized}} = \frac{x - x_{\min}}{x_{\max} - x_{\min}}
  $$

   Special handling for validation heights:
  $$
  \begin{align*}
  \text{height}{\text{normalized}} = \begin{cases}
  \frac{\text{height} - \text{height}{\min}}{\text{height}{\max} - \text{height}{\min}} & \text{if height is valid} \\
  \text{NaN} & \text{if height is missing}
  \end{cases}
  \end{align*}
  $$

3. **Sequence Management**
   
   For a sequence $S$ of length $n$:
  $$
  \begin{align*}
  & S = {f_1, f_2, ..., f_n} \text{ where } f_i \text{ is a frame} \\
  & \text{length}(S) = \min(n, \text{max_sequence_length}) \\
  & \text{valid}(S) \iff \text{length}(S) \geq \text{min_sequence_length}
  \end{align*}
  $$

## Data Processing Pipeline

### 1. Image Processing
$$
\begin{align*}
& \textbf{Input Image } I_{800 \times 800} \\
& \downarrow \text{ Convert to Grayscale} \\
& I_{\text{gray}} = \text{RGB2GRAY}(I_{800 \times 800}) \\
& \downarrow \text{ Apply Augmentations} \\
& I_{\text{aug}} = \text{Augment}(I_{\text{gray}}) \text{ where:} \\
& \quad \bullet \text{ Contrast Enhancement} \\
& \quad \bullet \text{ Histogram Equalization} \\
& \quad \bullet \text{ CLAHE} \\
& \downarrow \text{ Random Crop and Resize} \\
& I_{\text{final}} = \text{Resize}(\text{RandomCrop}(I_{\text{aug}}), 384 \times 384)
\end{align*}
$$


### 2. Sequence Construction

For each valid sequence:

$$
\begin{align*}
& \textbf{Sequence } S_i = {(I_j, M_j, h_j)}_{j=1}^n \text{ where:} \\
& \quad I_j = \text{processed image} \\
& \quad M_j = \text{metadata vector} \\
& \quad h_j = \text{validation height (if available)} \\
& \textbf{Constraint: } \text{timestamp}(j+1) - \text{timestamp}(j) \approx 5 \text{ seconds}
\end{align*}
$$

### 3. Batch Formation
$$
\begin{align*}
& \textbf{Batch } B = {S_1, S_2, ..., S_b} \text{ where:} \\
& \quad b = \text{batch size} \\
& \quad \text{max_len} = \max(\text{length}(S_i)) \\
& \quad \text{Pad shorter sequences to max_len}
\end{align*}
$$

## Access Methods

1. **Length Function**:
   $$|D| = \text{number of valid sequences}$$

2. **Item Access**:
  $$
  \begin{align*}
  & D[i] \rightarrow (I_i, M_i, h_i, c_i, p_i, t_i) \text{ where:} \\
  & \quad I_i = \text{image sequence tensor} \\
  & \quad M_i = \text{metadata tensor} \\
  & \quad h_i = \text{validation height} \\
  & \quad c_i = \text{center coordinates} \\
  & \quad p_i = \text{image paths} \\
  & \quad t_i = \text{timestamps}
  \end{align*}
  $$

In [None]:
class Cloud2CloudDataset(Dataset):
    def __init__(self,
                dataframe,
                normalization_params=None,
                transform=None, augmentations=None,
                apply_normalization=True,
                apply_crop_and_scale=True,
                resize_size=224):
        """
        Args:
            dataframe (pd.DataFrame): dataframe containing image paths, flight data, and ground truth.
            normalization_params (dict, optional): Dictionary containing min and max values for all the fields.
            transform (callable, optional): optional transform to apply to each image.
            augmentations (callable, optional): optional augmentations to apply to each image.
            apply_normalization (bool): whether to apply normalization to the dataframe.
            resize_size (int): The target size to resize the images (square dimensions).
        """
        self.dataframe = dataframe.copy()
        self.transform = transform
        self.augmentations = augmentations
        self.apply_normalization = apply_normalization
        self.apply_crop_and_scale = apply_crop_and_scale
        self.resize_size = resize_size

        # Initialize metadata manager
        self.metadata_manager = FlightMetadataManager()

        # Add temporal encoding columns
        hour = self.dataframe['hour_of_day']
        day = self.dataframe['day_of_year']
        self.dataframe['hour_sin'] = np.sin(2 * np.pi * hour / 24.0)
        self.dataframe['hour_cos'] = np.cos(2 * np.pi * hour / 24.0)
        self.dataframe['day_sin'] = np.sin(2 * np.pi * day / 365.0)
        self.dataframe['day_cos'] = np.cos(2 * np.pi * day / 365.0)

        if self.apply_normalization:
            # Calculate or use provided normalization parameters
            self.normalization_params = normalization_params or self._calculate_normalization_params(
                self.dataframe, self.metadata_manager.columns_to_normalize
            )

            for col in self.metadata_manager.columns_to_normalize:
                print(f"\nNormalizing {col}:")
                print(f"Min: {self.normalization_params[col]['min']}")
                print(f"Max: {self.normalization_params[col]['max']}")

                col_min = self.normalization_params[col]['min']
                col_max = self.normalization_params[col]['max']

                if col == 'validation_height':
                    # Only normalize non-NaN values for validation_height
                    # (we use NaNs to indicate frames with no LiDAR data)
                    mask = ~self.dataframe[col].isna()
                    self.dataframe.loc[mask, col] = (self.dataframe.loc[mask, col] - col_min) / (col_max - col_min)
                else:
                    # Normalize all values for other columns
                    self.dataframe[col] = (self.dataframe[col] - col_min) / (col_max - col_min)
                    if self.dataframe[col].isna().any():
                        print("WARNING: NaN values detected after normalization!")

        # Track the indices where sequences start
        self.sequence_indices = self._generate_sequence_indices()

    def _calculate_normalization_params(self, dataframe, columns_to_normalize):
        """
        Manually calculate min and max values for the columns and store them for consistency across datasets.
        """
        params = {}
        for col in columns_to_normalize:
            params[col] = {
                'min': dataframe[col].min(),
                'max': dataframe[col].max()
            }
        return params

    def denormalize_validation_height(self, normalized_height):
        """
        Denormalize the validation height using the stored min and max values for validation_height.

        Args:
            normalized_height (float or np.array): The normalized validation height(s) to denormalize.

        Returns:
            float or np.array: The denormalized validation height(s).
        """
        # Check if normalization_params exists; if not, return the input as-is
        if not hasattr(self, 'normalization_params') or 'validation_height' not in self.normalization_params:
            return normalized_height

        # Get min and max for validation_height
        col_min = self.normalization_params['validation_height']['min']
        col_max = self.normalization_params['validation_height']['max']

        print(f"  Max: {col_max} Min: {col_min}")

        # Denormalize using the stored min and max values
        original_height = normalized_height * (col_max - col_min) + col_min
        return original_height

    def denormalize_flight_data(self, normalized_flight_data):
        """
        Denormalize flight data using the stored min and max values for each column.

        Args:
            normalized_flight_data (np.array): Normalized flight data array to denormalize.

        Returns:
            np.array: Denormalized flight data.
        """
        # Check if normalization_params exists; if not, return the input as-is
        if not hasattr(self, 'normalization_params'):
            return normalized_flight_data

        denormalized_flight_data = []
        # Use flight_data which does not contain validation_height
        for i, col in enumerate(self.metadata_manager.flight_data):
            if col not in self.normalization_params:
                # If normalization parameters for a column are missing, return the normalized value as-is
                denormalized_flight_data.append(normalized_flight_data[i])
            else:
                col_min = self.normalization_params[col]['min']
                col_max = self.normalization_params[col]['max']
                denorm_value = normalized_flight_data[i] * (col_max - col_min) + col_min
                denormalized_flight_data.append(denorm_value)

        return np.array(denormalized_flight_data)

    def _generate_sequence_indices(self):
        """
        Generate a list of indices where each sequence starts.
        If a NaN is encountered in the sequence_length, the process stops.
        """
        sequence_indices = []
        idx = 0

        while idx < len(self.dataframe):
            sequence_length = self.dataframe.iloc[idx]['sequence_length']

            if pd.isna(sequence_length):
                # Stop if sequence_length is NaN (no more sequences)
                break

            # Convert sequence_length to an integer
            sequence_length = int(sequence_length)
            sequence_indices.append(idx)
            idx += sequence_length  # Move to the start of the next sequence

        return sequence_indices

    def __len__(self):
        return len(self.sequence_indices)

    def __getitem__(self, idx):
        # Get the starting index of the sequence
        start_idx = self.sequence_indices[idx]

        # Get the sequence length for this specific starting point
        sequence_length = int(self.dataframe.iloc[start_idx]['sequence_length'])

        # Get image paths for the sequence
        image_paths = self.dataframe.iloc[start_idx:start_idx + sequence_length]['image_path'].tolist()

        # Initialize sequences
        image_sequence = []
        flight_data_sequence = []
        timestamp_sequence = []  # New sequence for timestamps

        # Get feature types from metadata manager
        numeric_features, temporal_features = self.metadata_manager.get_feature_types()

        # Single loop to process both images and flight data
        for i in range(sequence_length):
            # Process image
            image_path = self.dataframe.iloc[start_idx + i]['image_path']

            # Load image using OpenCV
            img_rgb = load_image_from_blob_cv(image_path)

            # Get the height and width of the loaded image
            h, w = img_rgb.shape[:2]

            # # [Optional] Convert RGB to grayscale
            # img_gray = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)

            # # Apply augmentations to grayscale image
            # img_gray = augment_greyscale_image(img_gray)

            # # Convert grayscale image to RGB format
            # img_rgb_converted = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2RGB)

            # Apply the color augmentation instead of grayscale
            img_rgb_converted = augment_color_image_clahe_only(img_rgb, clip_limit=3.0, tile_grid_size=(8, 8))

            # Apply additional augmentations to the RGB image if any
            if self.augmentations:
                img_rgb_converted = self.augmentations(img_rgb_converted)

            # Convert back to tensor and append to the sequence
            img_tensor = TF.to_tensor(img_rgb_converted)
            image_sequence.append(img_tensor)

            # Process flight data
            row = self.dataframe.iloc[start_idx + i]

            # Extract flight data and ensure numeric values
            flight_data = []
            for feature in numeric_features:  # Changed from motion_features to numeric_features
                value = row[feature]
                # Convert to float and handle any non-numeric values
                try:
                    flight_data.append(float(value))
                except (ValueError, TypeError):
                    print(f"Error converting feature {feature} with value {value}")
                    flight_data.append(0.0)  # or some other default value

            flight_data_sequence.append(flight_data)

            # Process timestamps separately
            for feature in temporal_features:
                timestamp_sequence.append(row[feature])

        # Apply random cropping and resizing only if `apply_crop_and_scale` is True
        if self.apply_crop_and_scale:
            # Apply random cropping to the entire sequence
            cropped_sequence, new_center_coords = random_crop_sequence(image_sequence, crop_size_minimum=self.resize_size)
            # Resize the cropped sequence and adjust the center coordinates
            resized_sequence, resized_center_coords = resize_sequence_and_adjust_center(cropped_sequence, new_center_coords, target_size=(self.resize_size, self.resize_size))
            # Stack images into a tensor
            image_sequence = torch.stack(resized_sequence)
        else:
            # Resize original images to resize_size x resize_size
            resized_sequence = [TF.resize(img, size=(self.resize_size, self.resize_size)) for img in image_sequence]
            # Stack the resized images
            image_sequence = torch.stack(resized_sequence)
            # Calculate center of the resized square image
            resized_center_coords = (self.resize_size // 2, self.resize_size // 2)

        # print(f"Image sequence shape before model: {image_sequence.shape}")
        if image_sequence.shape[-1] != 384 or image_sequence.shape[-2] != 384:
            raise ValueError(f"Incorrect image dimensions: {image_sequence.shape}. Expected 384x384")

        # Convert flight data to tensor
        flight_data = torch.tensor(flight_data_sequence, dtype=torch.float32)

        # Fetch the validation height (target)
        validation_height = self.dataframe.iloc[start_idx + sequence_length - 1]['validation_height']
        validation_height = torch.tensor(validation_height, dtype=torch.float32)

        # Return with timestamp_sequence added
        return image_sequence, flight_data, validation_height, resized_center_coords, image_paths, timestamp_sequence

# Train/Validation Split

# Training and Validation Split Design

## 1. Initial Dataset Creation

First, we create a full dataset without normalization:

$$D_{\text{full}} = \{S_1, S_2, ..., S_n\}$$

where each $S_i$ is a sequence of frames.

## 2. Sequence-Based Splitting

Rather than splitting individual frames randomly, we split entire sequences:

$$
\begin{align*}
& I_{\text{sequences}} = {1, 2, ..., n} \text{ where } n = |D_{\text{full}}| \\
& \text{Split } I_{\text{sequences}} \text{ into:} \\
& \quad I_{\text{train}} \text{ (80% of sequences)} \\
& \quad I_{\text{val}} \text{ (20% of sequences)}
\end{align*}
$$

### Motivation for Sequence-Based Splitting:
- Preserves temporal continuity within sequences
- Prevents data leakage between train/validation
- Maintains sequence integrity

## 3. Row Extraction Process

For each set (train and validation):

$$
\begin{align*}
& \text{For each sequence index } i: \\
& \quad L_i = \text{sequence_length}(i) \\
& \quad R_i = {i, i+1, ..., i+L_i-1} \\
& R_{\text{train}} = \bigcup_{i \in I_{\text{train}}} R_i \\
& R_{\text{val}} = \bigcup_{i \in I_{\text{val}}} R_i
\end{align*}
$$

## 4. Dataset Creation

### Training Dataset
$$
\begin{align*}
& D_{\text{train}} = \text{Cloud2CloudDataset}(R_{\text{train}}) \text{ with:} \\
& \quad \bullet \text{ apply_normalization = True} \\
& \quad \bullet \text{ apply_crop_and_scale = True} \\
& \quad \bullet \text{ Calculate normalization parameters:} \\
& \quad\quad \text{For each feature } f: \\
& \quad\quad\quad \mu_f = \text{mean}(f) \\
& \quad\quad\quad \sigma_f = \text{std}(f)
\end{align*}
$$

### Validation Dataset
$$
\begin{align*}
& D_{\text{val}} = \text{Cloud2CloudDataset}(R_{\text{val}}) \text{ with:} \\
& \quad \bullet \text{ apply_normalization = True} \\
& \quad \bullet \text{ apply_crop_and_scale = False} \\
& \quad \bullet \text{ Use training normalization parameters:} \\
& \quad\quad f_{\text{normalized}} = \frac{f - \mu_f}{\sigma_f}
\end{align*}
$$

## 5. Key Differences Between Train and Validation

### Training Set Features:
- Random cropping and scaling
- Data augmentation
- Shuffled sequences
- Calculates normalization parameters

### Validation Set Features:
- No random cropping (center crop only)
- Limited augmentation
- Sequential processing
- Uses training set's normalization parameters

### **Note: We use training-only normalization parameters**

This is an important machine learning practice for several reasons:

a) **Preventing Data Leakage**
   - If we calculated normalization parameters using all data (including validation), we'd be letting information from the validation set influence our training data
   - This would create a subtle form of data leakage where our model gets to "peek" at validation data statistics

b) **Real-world Scenario Simulation**
   - In production, we'll need to normalize new, unseen data
   - We won't be able to calculate new normalization parameters for this data
   - Using training-only parameters better simulates this real-world scenario

c) **Validation Integrity**
   - Validation should measure how well our model generalizes
   - Using training-derived normalization parameters tests if our normalization approach itself generalizes
   - If validation performance is good, it suggests our normalization is robust

Mathematically, for any feature $f$:
```
Training:
μ_train = mean(f_train)
σ_train = std(f_train)
f_train_normalized = (f_train - μ_train) / σ_train

Validation:
f_val_normalized = (f_val - μ_train) / σ_train  # Using training parameters
```

This ensures that our validation metrics genuinely reflect how well our model will perform on unseen data.

## 6. DataLoader Configuration

$$
\begin{align*}
& \text{Loader}{\text{train}} = \text{DataLoader}(D{\text{train}}) \text{ with:} \\
& \quad \bullet \text{ batch_size = 1} \\
& \quad \bullet \text{ shuffle = True} \\
& \quad \bullet \text{ custom collate function} \\
& \text{Loader}{\text{val}} = \text{DataLoader}(D{\text{val}}) \text{ with:} \\
& \quad \bullet \text{ batch_size = 1} \\
& \quad \bullet \text{ shuffle = False} \\
& \quad \bullet \text{ same collate function}
\end{align*}
$$

At this stage, the dataloader using only a batch size of 1 is useful for:

- Visualizing individual sequences
- Debugging data transformations
- Checking center point calculations
- Verifying normalization

In [None]:
# Create the full Cloud2CloudDataset (no normalization yet)
full_cloud2cloud_dataset = Cloud2CloudDataset(dataframe=full_dataframe, transform=transform, augmentations=augmentations, apply_normalization=False, resize_size=384)

# Split sequence indices into training and validation sets
train_sequence_indices, val_sequence_indices = train_test_split(full_cloud2cloud_dataset.sequence_indices, test_size=0.2, random_state=42)

# Extract rows for training dataset based on sequence indices
train_rows = []
for seq_start in train_sequence_indices:
    sequence_length = int(full_cloud2cloud_dataset.dataframe.iloc[seq_start]['sequence_length'])
    train_rows.extend(range(seq_start, seq_start + sequence_length))

train_dataframe = full_dataframe.iloc[train_rows].reset_index(drop=True)

# Create the training dataset and calculate normalization parameters
train_cloud2cloud_dataset = Cloud2CloudDataset(dataframe=train_dataframe, transform=transform, augmentations=augmentations, apply_normalization=False, apply_crop_and_scale=True, resize_size=384)

# Get normalization parameters from the training dataset if it exists
if hasattr(train_cloud2cloud_dataset, 'normalization_params'):
    normalization_params = train_cloud2cloud_dataset.normalization_params
else:
    normalization_params = None

# Extract rows for validation dataset based on sequence indices
val_rows = []
for seq_start in val_sequence_indices:
    sequence_length = int(full_cloud2cloud_dataset.dataframe.iloc[seq_start]['sequence_length'])
    val_rows.extend(range(seq_start, seq_start + sequence_length))

val_dataframe = full_dataframe.iloc[val_rows].reset_index(drop=True)

# Create the validation dataset using the same normalization parameters
val_cloud2cloud_dataset = Cloud2CloudDataset(dataframe=val_dataframe, normalization_params=normalization_params, transform=transform, augmentations=augmentations, apply_normalization=False, apply_crop_and_scale=False, resize_size=384)

# Create DataLoaders for training and validation sets
train_dataloader_single_batch = DataLoader(train_cloud2cloud_dataset, batch_size=1, shuffle=True)
val_dataloader_single_batch = DataLoader(val_cloud2cloud_dataset, batch_size=1, shuffle=False)

# Modelling

## Helper functions

In [None]:
def print_num_parameters(model):
    """
    Prints the number of trainable parameters in a PyTorch model.
    """
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total trainable parameters: {total_params}")

def get_denormalized_metadata(dataset, data, sample_idx=0):
    """Get denormalized metadata values for height calculation."""
    # Get normalized metadata for the last frame
    metadata_tensor = data['metadata'][sample_idx, -1]

    # Get the denormalized values using the dataset's parameters
    denorm_data = dataset.denormalize_flight_data(metadata_tensor)

    # Create metadata dictionary using metadata manager indices
    metadata = {
        'GPS_MSL_Alt': float(denorm_data[dataset.metadata_manager.get_index('GPS_MSL_Alt')]),
        'True_Airspeed': float(denorm_data[dataset.metadata_manager.get_index('True_Airspeed')]),
        'Track': float(denorm_data[dataset.metadata_manager.get_index('Track')]),
        'Pitch': float(denorm_data[dataset.metadata_manager.get_index('Pitch')]),
        'Roll': float(denorm_data[dataset.metadata_manager.get_index('Roll')]),
        'Wind_Speed': float(denorm_data[dataset.metadata_manager.get_index('Wind_Speed')]),
        'Wind_Dir': float(denorm_data[dataset.metadata_manager.get_index('Wind_Dir')])
    }

    return metadata

## CloudMotionModel Class

The `CloudMotionModel` is a dual-RAFT architecture designed for cloud motion estimation that combines a trainable RAFT model with a frozen reference model.

## Model Components

### 1. Dual RAFT Models
The architecture employs two RAFT instances:
- **Trainable RAFT**: Fine-tuned for cloud-specific motion
- **Frozen Reference RAFT**: Maintains baseline motion estimates

```python
self.raft = raft_large(weights=Raft_Large_Weights.DEFAULT if use_pretrained else None)
self.reference_raft = raft_large(weights=Raft_Large_Weights.DEFAULT)  # Frozen
```

### 2. Bidirectional Anomaly Detection and Preservation

The preserve_bidirectional_anomalies method addresses a key challenge in cloud motion sequences: temporal inconsistency in feature detection. Cloud features, particularly at high altitudes, may be detected strongly in some frame pairs but weakly or not at all in others.

The model identifies motion anomalies using statistical thresholds:

$$
\begin{align*}
\text{strong_anomalies} &= \text{motion} > (\mu + \sigma \cdot \text{threshold}) \\
\text{weak_anomalies} &= \text{motion} < (\mu - \sigma \cdot \text{threshold})
\end{align*}
$$

where:
- $\mu$ is the mean motion across the sequence (calculated per pixel)
- $\sigma$ is the standard deviation of motion
- threshold is the anomaly detection parameter (default 2.0)

For each pixel location $(x,y)$, the final motion is determined by:

$$
\text{final_motion}(x,y) = \begin{cases}
\max(\text{motion}(x,y)) & \text{if strong anomaly detected at any time} \\
\min(\text{motion}(x,y)) & \text{if weak anomaly detected at any time} \\
\mu(\text{motion}(x,y)) & \text{otherwise}
\end{cases}
$$

This preservation mechanism is helpful because cloud motion features in a sequence $\{I_1, I_2, ..., I_T\}$ may only be distinctly visible in specific frame pairs. For example, a high-altitude cloud feature might show significant motion between frames $I_1$ and $I_2$, but become less distinct in subsequent pairs due to:
- Changes in viewing angle
- Varying illumination conditions
- Temporary occlusions by other cloud layers
- Varying contrast against the background

By preserving both strong and weak anomalies across the entire sequence, we ensure that important motion features are retained even if they're only detected in a subset of frame pairs. This is particularly important for high-altitude clouds where parallax effects might be subtle and inconsistent across the sequence.

This approach significantly improves the robustness of height estimation by:
- Preserving rare but significant motion features
- Reducing the impact of temporary feature occlusions
- Maintaining consistency in height estimates for high-altitude clouds
- Combining evidence of cloud motion across multiple temporal samples

### 3. Confidence Estimation

Confidence scores are computed using motion magnitude:

$$
\text{confidence} = \sigma(\|\text{refined_motion}\|_2)
$$

where $\sigma$ is the sigmoid function to normalize confidence to [0,1].

## Forward Pass Flow

1. **Sequence Processing**:
   - Input: Sequence of images $\{I_1, I_2, ..., I_T\}$
   - For each consecutive pair $(I_t, I_{t+1})$:
     - Scale images to [0, 255] range
     - Process through both RAFT models
     - Store motion fields

2. **Motion Stack Creation**:
   - Stack motion fields temporally
   - Handle packed sequences if present
   - Motion dimensions: $(B, T-1, 2, H, W)$
   where B=batch size, T=sequence length, H=height, W=width

3. **Motion Refinement**:
   - Apply bidirectional anomaly preservation
   - Generate confidence maps
   - Handle packed sequence conversions

4. **Output Generation**:
```python
return {
    'motion_fields': Original motion sequence
    'refined_motion': Anomaly-preserved motion
    'reference_motion_fields': Reference RAFT sequence
    'reference_motion': Reference refined motion
    'confidence': Confidence maps
}
```

In [1]:
class CloudMotionModel(nn.Module):
    """
    Neural network model that combines trainable and reference RAFT models for cloud motion estimation.

    This model uses two RAFT (Recurrent All-Pairs Field Transforms) instances:
    - A trainable model that can be fine-tuned for cloud-specific motion
    - A frozen reference model that maintains baseline motion estimates

    The model processes sequences of images to estimate motion fields, preserving motion anomalies
    that may indicate important cloud features that are only visible in specific frame pairs.

    Attributes:
        raft (RAFT): Trainable RAFT model for motion estimation
        reference_raft (RAFT): Frozen RAFT model providing baseline estimates
        anomaly_threshold (float): Threshold for detecting motion anomalies (default: 2.0)

    Example:
        >>> model = CloudMotionModel(use_pretrained=True)
        >>> image_sequence = torch.randn(1, 5, 3, 384, 384)  # Batch x Frames x Channels x Height x Width
        >>> outputs = model(image_sequence)
        >>> motion_fields = outputs['motion_fields']  # Sequence of motion fields
        >>> refined_motion = outputs['refined_motion']  # Final motion with preserved anomalies
    """
    def __init__(self, use_pretrained=True, anomaly_threshold=2.0):
        super().__init__()
        # Initialize trainable RAFT
        self.raft = raft_large(
            weights=Raft_Large_Weights.DEFAULT if use_pretrained else None
        )

        # Initialize frozen reference RAFT
        self.reference_raft = raft_large(weights=Raft_Large_Weights.DEFAULT)
        # Freeze all parameters of reference model
        for param in self.reference_raft.parameters():
            param.requires_grad = False

        self.anomaly_threshold = anomaly_threshold

    def preserve_bidirectional_anomalies(self, motion_stack):
        """
        Preserves significant motion features detected across a sequence of frame pairs.

        Important cloud features may only be visible in specific frame pairs due to viewing angles
        or illumination. This method identifies and preserves these features by detecting motion
        anomalies (both strong and weak) across the sequence.

        Args:
            motion_stack (torch.Tensor or PackedSequence): Stack of motion fields from sequence
                Shape: Batch x Sequence x 2 x Height x Width

        Returns:
            torch.Tensor: Final motion field with preserved anomalies
        """
        # If packed sequence, unpack first
        if isinstance(motion_stack, PackedSequence):
            motion_data, lengths = pad_packed_sequence(motion_stack, batch_first=True)
        else:
            motion_data = motion_stack

        mean_motion = torch.mean(motion_data, dim=1)
        std_motion = torch.std(motion_data, dim=1)

        # Identify both strong and weak anomalies
        strong_anomalies = motion_data > (mean_motion.unsqueeze(1) + self.anomaly_threshold * std_motion.unsqueeze(1))
        weak_anomalies = motion_data < (mean_motion.unsqueeze(1) - self.anomaly_threshold * std_motion.unsqueeze(1))

        # Combine anomalies
        anomalies = strong_anomalies | weak_anomalies

        # For non-anomalous pixels, use the mean motion
        final_motion = mean_motion.clone()

        # For anomalous pixels, use the max or min value depending on whether it's a strong or weak anomaly
        max_motion, _ = torch.max(motion_data, dim=1)
        min_motion, _ = torch.min(motion_data, dim=1)

        final_motion[torch.any(strong_anomalies, dim=1)] = max_motion[torch.any(strong_anomalies, dim=1)]
        final_motion[torch.any(weak_anomalies, dim=1)] = min_motion[torch.any(weak_anomalies, dim=1)]

        return final_motion

    def forward(self, image_sequence):
        """
        Processes an image sequence through both RAFT models to estimate motion.

        Args:
            image_sequence (torch.Tensor or PackedSequence): Sequence of images
                Shape: Batch x Sequence x 3 x Height x Width
                Values should be in [0,1] range

        Returns:
            dict: Dictionary containing:
                - motion_fields: Original motion sequence from trainable RAFT
                - refined_motion: Motion with preserved anomalies from trainable RAFT
                - reference_motion_fields: Original motion sequence from reference RAFT
                - reference_motion: Motion with preserved anomalies from reference RAFT
                - confidence: Confidence maps based on motion magnitude
        """
        # If packed sequence, unpack first
        if isinstance(image_sequence, PackedSequence):
            images_unpacked, lengths = pad_packed_sequence(image_sequence, batch_first=True)
        else:
            images_unpacked = image_sequence

        motion_fields = []
        reference_motion_fields = []

        # Process all frames in the batch together
        for t in range(images_unpacked.shape[1] - 1):
            img1 = images_unpacked[:, t] * 255.0
            img2 = images_unpacked[:, t + 1] * 255.0

            # Get predictions from trainable RAFT
            flow_predictions = self.raft(img1, img2)
            motion_fields.append(flow_predictions[-1])

            # Get predictions from reference RAFT (no gradients)
            with torch.no_grad():
                reference_predictions = self.reference_raft(img1, img2)
                reference_motion_fields.append(reference_predictions[-1])

        motion_stack = torch.stack(motion_fields, dim=1)
        reference_stack = torch.stack(reference_motion_fields, dim=1)

        # Pack sequences if input was packed
        if isinstance(image_sequence, PackedSequence):
            motion_stack = pack_padded_sequence(motion_stack, [l-1 for l in lengths], batch_first=True)
            reference_stack = pack_padded_sequence(reference_stack, [l-1 for l in lengths], batch_first=True)

        # Apply bidirectional anomaly preservation
        refined_motion = self.preserve_bidirectional_anomalies(motion_stack)
        reference_motion = self.preserve_bidirectional_anomalies(reference_stack)

        # Calculate confidence based on motion magnitude
        if isinstance(refined_motion, PackedSequence):
            refined_unpacked, _ = pad_packed_sequence(refined_motion, batch_first=True)
            confidence = torch.norm(refined_unpacked, dim=1, keepdim=True)
        else:
            confidence = torch.norm(refined_motion, dim=1, keepdim=True)
        confidence = torch.sigmoid(confidence)

        return {
            'motion_fields': motion_stack,
            'refined_motion': refined_motion,
            'reference_motion_fields': reference_stack,
            'reference_motion': reference_motion,
            'confidence': confidence,
        }

NameError: name 'nn' is not defined

## ParallaxHeightCalculator Class

The ParallaxHeightCalculator implements the parallax principle to convert motion fields into cloud heights using a combination of LiDAR calibration and statistical methods.

## Core Components

### 1. Scale Management

The calculator maintains three critical scales that transform raw pixel motion into meaningful physical cloud heights. Each scale solves a specific challenge in the height estimation process:

**Motion Scale**: Maps pixel motion to physical distance
- Raw optical flow gives motion in pixels
- Need to convert to real-world distances (meters)
- LiDAR provides ground truth for calibration
- Higher altitude clouds show less pixel motion
- Bias towards larger values helps detect high clouds

**Height Scale**: Converts relative motion to absolute height
- Parallax effect creates relative motion proportional to height ratio
- Converts $(h_{\text{cloud}}/h_{\text{aircraft}})$ ratio to absolute height
- Calibrated using LiDAR measurements
- Updates slowly to maintain stability

**Vertical Extent Scale**: Controls the vertical range of cloud features
- Raw height calculations can compress or stretch vertical structure
- Scales cloud thickness to physically reasonable ranges
- Limited to 30% of cloud-top height or 3000m
- Preserves realistic cloud vertical development

Each scale uses asymmetric smoothing with bias towards larger values:

$$
\text{scale}_{\text{new}} = (1-\alpha)\text{scale}_{\text{old}} + \alpha(\text{scale}_{\text{update}})
$$

where $\alpha$ is:
```
0.05 if scale_update > scale_old
0.01 if scale_update ≤ scale_old
```

This asymmetric smoothing is particularly important for detecting high-altitude clouds which might only be visible in certain frames or conditions while maintaining stability in the height calculations over time.

### 2. Scale Calculation

For each frame with LiDAR data:

1. **Height Ratio**:
   $$\text{height_ratio} = \frac{\text{lidar_height}}{\text{aircraft_altitude}}$$

2. **Expected Motion**:
   $$\text{expected_motion} = \text{aircraft_speed} \times \text{height_ratio}$$

3. **Motion Scale**:
   $$\text{motion_scale} = \frac{\text{expected_motion}}{\text{center_motion}}$$

4. **Vertical Extent Scale**:
   $$\text{vertical_extent} = \min(\frac{0.3 \times \text{lidar_height}}{{\text{current_range}}}, 3000)$$

### 3. Height Calculation

For each pixel $(x,y)$ with valid motion:

1. **Relative Motion**:
   $$\text{relative_motion}_{x,y} = \frac{\text{motion_magnitude}_{x,y} \times \text{motion_scale}}{\text{aircraft_speed}}$$

2. **Raw Height**:
   $$\text{height}_{x,y} = \text{aircraft_altitude} \times \text{relative_motion}_{x,y} \times \text{height_scale}$$

3. **Confidence-weighted Smoothing**:
   $$\text{height}_{x,y} = c \times \text{height}_{x,y} + (1-c) \times \text{smooth}(\text{height}_{x,y})$$
   where $c$ is the confidence value.

### 4. Height Field Adjustment

When LiDAR calibration is used:
1. Center height subtraction: $h' = h - h_{\text{center}}$
2. Vertical extent scaling: $h'' = h' \times \text{vertical_extent_scale}$
3. LiDAR height addition: $h_{\text{final}} = h'' + h_{\text{lidar}}$

### 5. Confidence Calculation

Combines motion magnitude and local consistency:
1. Base confidence: $c_1 = 1 - e^{-\text{motion}/\text{threshold}}$
2. Local consistency: $c_2 = 1 - |\text{motion} - \text{smooth}(\text{motion})| /\text{smooth}(\text{motion})$
3. Final confidence: $c = c_1 \times c_2$

### 6. Robust Scale Updates

Uses Median Absolute Deviation (MAD) for outlier rejection:
$$\text{MAD} = \text{median}(|\text{scales} - \text{median}(\text{scales})|)$$

Inliers are defined as:
$$|\text{scale} - \text{median}(\text{scales})| < 5.0 \times \text{MAD}$$

This ensures stable height estimation even with noisy motion fields and varying aircraft conditions, while the asymmetric smoothing preserves sensitivity to high-altitude cloud features.

In [None]:
class ParallaxHeightCalculator:
    """
    Calculates cloud heights from motion fields using the parallax principle.

    This calculator uses a combination of motion fields, aircraft metadata, and LiDAR
    measurements to estimate cloud heights. It maintains running scales that convert
    pixel motion into physical heights and updates these scales using statistical methods.

    Three key scales are maintained:
    - Motion scale: Converts pixel motion to physical distance
    - Height scale: Converts relative motion to absolute height
    - Vertical extent scale: Controls reasonable vertical range of clouds

    Attributes:
        smoothing_factor (float): Base factor for scale smoothing
        history_size (int): Maximum number of historical values to maintain
        use_lidar_to_calculate_heights (bool): Whether to use LiDAR for direct calibration
        running_motion_scale (float): Current motion scale value
        running_height_scale (float): Current height scale value
        running_vertical_extent_scale (float): Current vertical extent scale value
    """
    def __init__(self, smoothing_factor=0.1, history_size=2000, use_lidar_to_calculate_heights=False):
        self.smoothing_factor = smoothing_factor
        self.running_motion_scale = None
        self.running_height_scale = None
        self.running_vertical_extent_scale = 1.0  # Initialize with 1.0
        self.use_lidar_to_calculate_heights = use_lidar_to_calculate_heights

        self.motion_scale_history = []
        self.height_scale_history = []
        self.vertical_extent_scale_history = []  # New history list
        self.timestamp_history = []
        self.altitude_history = []
        self.aircraft_speed_history = []

    def calculate_scales(self, motion_magnitude, metadata, lidar_height):
        """
        Calculates scaling factors using LiDAR measurements for calibration.

        Args:
            motion_magnitude (np.ndarray): Magnitude of motion field
            metadata (dict): Aircraft metadata including altitude and speed
            lidar_height (float): LiDAR-measured cloud height for calibration

        Returns:
            tuple: (motion_scale, height_scale, vertical_extent_scale) or (None, None)
                if calculation fails
        """
        MIN_MOTION_THRESHOLD = 0.1

        aircraft_altitude = float(metadata['GPS_MSL_Alt'])
        aircraft_speed = float(metadata['True_Airspeed'])

        debug.debug_print('motion', "\nScale Calculation Debug:")
        debug.debug_print('motion', f"Aircraft altitude: {aircraft_altitude:.1f} m")
        debug.debug_print('motion', f"LiDAR height: {lidar_height:.1f} m")
        debug.debug_print('motion', f"Aircraft speed: {aircraft_speed:.1f} m/s")

        if aircraft_altitude <= 0 or aircraft_speed <= 0:
            debug.debug_print('motion', "Warning: Invalid altitude or speed")
            return None, None

        # Get center motion with local averaging
        center_y, center_x = motion_magnitude.shape[0] // 2, motion_magnitude.shape[1] // 2
        kernel_size = 5
        motion_smooth = cv2.GaussianBlur(motion_magnitude, (kernel_size, kernel_size), 0)
        center_motion = motion_smooth[center_y, center_x]
        debug.debug_print('motion', f"Raw center motion (pixels): {center_motion:.3f}")

        if center_motion < MIN_MOTION_THRESHOLD:
            debug.debug_print('motion', f"Warning: Center motion ({center_motion:.3f} px) below threshold ({MIN_MOTION_THRESHOLD} px)")
            return None, None

        # Calculate height ratio
        height_ratio = lidar_height / aircraft_altitude

        # Calculate expected center motion
        expected_center_motion = aircraft_speed * height_ratio

        debug.debug_print('motion', "\nHeight Ratio Calculation:")
        debug.debug_print('motion', f"LiDAR height: {lidar_height}")
        debug.debug_print('motion', f"Aircraft altitude: {aircraft_altitude}")
        debug.debug_print('motion', f"Resulting Height ratio: {height_ratio:.3f}")
        debug.debug_print('motion', "\nMotion Scale Calculation:")
        debug.debug_print('motion', f"Aircraft speed: {aircraft_speed}")
        debug.debug_print('motion', f"Height ratio: {height_ratio}")
        debug.debug_print('motion', f"Resulting Expected center motion (px/s): {expected_center_motion:.3f}")

        # Calculate expected vs actual relative motion
        actual_relative_motion = center_motion / aircraft_speed
        expected_relative_motion = lidar_height / aircraft_altitude

        debug.debug_print('motion', "\nRelative Motion Analysis:")
        debug.debug_print('motion', f"Expected relative motion at center: {expected_relative_motion:.3f}")
        debug.debug_print('motion', f"Actual relative motion at center: {actual_relative_motion:.3f}")
        debug.debug_print('motion', f"Relative motion difference: {(expected_relative_motion - actual_relative_motion):.3f}")
        debug.debug_print('motion', f"Relative motion ratio: {(expected_relative_motion / actual_relative_motion):.3f}")

        # Calculate scales
        motion_scale = expected_center_motion / (center_motion + 1e-6)
        height_scale = height_ratio

        # Calculate vertical extent scale
        current_range = np.ptp(motion_magnitude * motion_scale * aircraft_speed)
        target_range = min(3000, lidar_height * 0.3)  # 30% of LiDAR height, minimum of 3000m
        vertical_extent_scale = max(target_range / (current_range + 1e-6), 0.5)  # Never scale below 50%

        # Update running vertical extent scale
        if self.running_vertical_extent_scale is None:
            self.running_vertical_extent_scale = vertical_extent_scale
        else:
            self.running_vertical_extent_scale = (
                self.running_vertical_extent_scale * 0.9 + vertical_extent_scale * 0.1
            )

        debug.debug_print('motion', f"Calculated motion scale: {motion_scale:.3f}")
        debug.debug_print('motion', f"Calculated height scale: {height_scale:.3f}")
        debug.debug_print('motion', f"Calculated vertical extent scale: {self.running_vertical_extent_scale:.3f}")
        debug.debug_print('motion', f"Current running motion scale: {self.running_motion_scale}")
        debug.debug_print('motion', f"Current running height scale: {self.running_height_scale}")

        return motion_scale, height_scale, self.running_vertical_extent_scale

    def adjust_heights(self, heights, valid_mask, lidar_height, center_y, center_x, vertical_extent_scale):
        """
        Adjusts height field using LiDAR measurement as reference.

        The adjustment process:
        1. Centers the height field on the LiDAR measurement
        2. Scales the vertical extent to reasonable ranges
        3. Maintains the LiDAR height at the center point

        Args:
            heights (np.ndarray): Raw height field
            valid_mask (np.ndarray): Boolean mask of valid height values
            lidar_height (float): LiDAR-measured height for calibration
            center_y, center_x (int): Coordinates of center point
            vertical_extent_scale (float): Scale for vertical extent adjustment

        Returns:
            np.ndarray: Adjusted height field
        """
        calculated_center_height = heights[center_y, center_x]
        debug.debug_print('heights', f"Initial calculated center height: {calculated_center_height:.1f} m")

        if calculated_center_height > 0:
            # Step 1: Subtract center height
            heights[valid_mask] -= calculated_center_height

            # Step 2: Scale the cloud
            heights[valid_mask] *= vertical_extent_scale

            # Step 3: Add LiDAR height
            heights[valid_mask] += lidar_height

            debug.debug_print('heights', f"Heights adjusted. Vertical extent scale: {vertical_extent_scale:.3f}")
            debug.debug_print('heights', f"New height range: {np.ptp(heights[valid_mask]):.1f} m")

            # Verify the adjustment
            adjusted_center_height = heights[center_y, center_x]
            debug.debug_print('heights', f"Adjusted center height: {adjusted_center_height:.1f} m")

            if not np.isclose(adjusted_center_height, lidar_height, rtol=1e-5):
                debug.debug_print('heights', "Warning: Adjusted center height does not match LiDAR height!")
        else:
            debug.debug_print('heights', "Warning: Calculated center height is zero or negative. Skipping adjustment.")

        return heights

    def update_running_averages(self, new_motion_scale, new_height_scale, new_vertical_extent_scale, metadata):
        """
        Updates running averages of scales using robust statistical methods.

        Uses asymmetric smoothing to bias towards larger scales, which helps preserve
        detection of high-altitude clouds. Employs MAD (Median Absolute Deviation) for
        outlier rejection.

        Args:
            new_motion_scale (float): New calculated motion scale
            new_height_scale (float): New calculated height scale
            new_vertical_extent_scale (float): New calculated vertical extent scale
            metadata (dict): Aircraft metadata for tracking conditions
        """
        # Add to history
        self.motion_scale_history.append(new_motion_scale)
        self.height_scale_history.append(new_height_scale)
        self.vertical_extent_scale_history.append(new_vertical_extent_scale)
        self.timestamp_history.append(len(self.timestamp_history))
        self.altitude_history.append(metadata['GPS_MSL_Alt'])
        self.aircraft_speed_history.append(metadata['True_Airspeed'])

        # Use robust statistics for motion scale if we have enough history
        if len(self.motion_scale_history) >= 10:
            # Get recent history (increased from 50 to 100 for better statistics)
            recent_motion_scales = np.array(self.motion_scale_history[-100:])

            # Calculate median and MAD
            median_motion = np.median(recent_motion_scales)
            mad = np.median(np.abs(recent_motion_scales - median_motion))

            # More permissive threshold (increased from 3.0 to 5.0)
            threshold = 5.0
            inlier_mask = np.abs(recent_motion_scales - median_motion) < threshold * mad

            if np.sum(inlier_mask) > 0:
                good_scales = recent_motion_scales[inlier_mask]
                # Use 75th percentile instead of mean to bias towards higher values
                new_running_scale = np.percentile(good_scales, 75)

                debug.debug_print('motion', f"\nRobust Scale Update:")
                debug.debug_print('motion', f"Median motion scale: {median_motion:.3f}")
                debug.debug_print('motion', f"MAD: {mad:.3f}")
                debug.debug_print('motion', f"Inlier range: {np.min(good_scales):.3f} to {np.max(good_scales):.3f}")
                debug.debug_print('motion', f"Number of inliers: {np.sum(inlier_mask)} out of {len(recent_motion_scales)}")

                # Much slower smoothing (reduced from 0.1 to 0.02)
                if self.running_motion_scale is None:
                    self.running_motion_scale = new_running_scale
                else:
                    # Bias towards larger scales with asymmetric smoothing
                    alpha = 0.05 if new_running_scale > self.running_motion_scale else 0.01
                    self.running_motion_scale = (1 - alpha) * self.running_motion_scale + alpha * new_running_scale
        else:
            # Not enough history yet, use simple update with bias towards larger scales
            if self.running_motion_scale is None:
                self.running_motion_scale = new_motion_scale
            else:
                alpha = 0.05 if new_motion_scale > self.running_motion_scale else 0.01
                self.running_motion_scale = (1 - alpha) * self.running_motion_scale + alpha * new_motion_scale

        # Height scale updates more simply
        if self.running_height_scale is None:
            self.running_height_scale = new_height_scale
        else:
            # Very slow update for height scale
            alpha = 0.01
            self.running_height_scale = (1 - alpha) * self.running_height_scale + alpha * new_height_scale

        # Vertical extent scale updates
        if self.running_vertical_extent_scale is None:
            self.running_vertical_extent_scale = new_vertical_extent_scale
        else:
            # Use a slow update for vertical extent scale
            alpha = 0.01
            self.running_vertical_extent_scale = (1 - alpha) * self.running_vertical_extent_scale + alpha * new_vertical_extent_scale

        debug.debug_print('motion', f"Updated running motion scale: {self.running_motion_scale:.3f}")
        debug.debug_print('motion', f"Updated running height scale: {self.running_height_scale:.3f}")
        debug.debug_print('motion', f"Updated running vertical extent scale: {self.running_vertical_extent_scale:.3f}")


    def calculate_motion_confidence(self, motion_magnitude, min_motion_threshold=0.01):
        """
        Calculates confidence values for motion measurements.

        Combines motion magnitude and local consistency to estimate reliability
        of motion measurements.

        Args:
            motion_magnitude (np.ndarray): Motion field magnitude
            min_motion_threshold (float): Minimum motion for confidence calculation

        Returns:
            np.ndarray: Confidence values in range [0,1]
        """
        # Calculate base confidence from motion magnitude
        confidence = 1.0 - np.exp(-motion_magnitude / min_motion_threshold)

        # Look for local motion consistency
        kernel_size = 5
        motion_smooth = cv2.GaussianBlur(motion_magnitude, (kernel_size, kernel_size), 0)
        local_consistency = np.abs(motion_magnitude - motion_smooth) / (motion_smooth + 1e-6)

        # Combine confidences
        confidence *= (1.0 - local_consistency)

        # Normalize to [0,1]
        confidence = np.clip(confidence, 0, 1)

        return confidence

    def calculate_heights(self, motion_field, metadata, lidar_height=None):
        """
        Calculates cloud heights from motion field using parallax principle.

        Main height calculation pipeline:
        1. Calculate motion magnitude and confidence
        2. Update scaling factors if LiDAR data available
        3. Convert motion to relative heights
        4. Apply confidence-weighted smoothing
        5. Adjust heights using LiDAR if available

        Args:
            motion_field (np.ndarray): Motion field from optical flow
            metadata (dict): Aircraft metadata including altitude and speed
            lidar_height (float, optional): LiDAR measurement for calibration

        Returns:
            tuple: (heights, confidence)
                - heights: np.ndarray of calculated cloud heights
                - confidence: np.ndarray of confidence values
        """
        MIN_MOTION_THRESHOLD = 0.1  # pixels

        aircraft_altitude = float(metadata['GPS_MSL_Alt'])
        aircraft_speed = float(metadata['True_Airspeed'])

        # Calculate motion magnitude
        dx = motion_field[0]
        dy = motion_field[1]
        debug.debug_print('motion', "\nMotion Component Debug:")
        debug.debug_print('motion', f"dx range: {dx.min():.3f} to {dx.max():.3f}")
        debug.debug_print('motion', f"dy range: {dy.min():.3f} to {dy.max():.3f}")

        motion_magnitude = np.sqrt(dx**2 + dy**2 + 1e-6)
        debug.debug_print('motion', f"motion_magnitude range: {motion_magnitude.min():.3f} to {motion_magnitude.max():.3f}")

        # Apply minimum motion threshold
        valid_motion = motion_magnitude > MIN_MOTION_THRESHOLD

        # Calculate motion confidence
        confidence = self.calculate_motion_confidence(motion_magnitude)
        confidence[~valid_motion] *= 0.5

        # Initialize scales
        motion_scale = self.running_motion_scale
        height_scale = self.running_height_scale
        vertical_extent_scale = self.running_vertical_extent_scale

        # Try to calculate new scales if we have LiDAR data
        if lidar_height is not None:
            scales = self.calculate_scales(motion_magnitude, metadata, lidar_height)

            if scales[0] is not None:  # Only update if we got valid scales
                new_motion_scale, new_height_scale, new_vertical_extent_scale = scales
                debug.debug_print('motion', "\nRunning Average Update Debug:")
                debug.debug_print('motion', f"New motion scale: {new_motion_scale:.3f}")
                debug.debug_print('motion', f"New height scale: {new_height_scale:.3f}")
                debug.debug_print('motion', f"New vertical extent scale: {new_vertical_extent_scale:.3f}")

                # Update running averages only if center motion is valid
                center_y, center_x = motion_magnitude.shape[0] // 2, motion_magnitude.shape[1] // 2
                if valid_motion[center_y, center_x] and confidence[center_y, center_x] > 0.5:
                    debug.debug_print('motion', f"Center motion valid (>{MIN_MOTION_THRESHOLD}) and confidence high (>0.5)")
                    self.update_running_averages(new_motion_scale, new_height_scale, new_vertical_extent_scale, metadata)

                if self.use_lidar_to_calculate_heights:
                    motion_scale = new_motion_scale
                    height_scale = new_height_scale
                    vertical_extent_scale = new_vertical_extent_scale
                else:
                    motion_scale = self.running_motion_scale
                    height_scale = self.running_height_scale if self.running_height_scale is not None else new_height_scale
                    vertical_extent_scale = self.running_vertical_extent_scale

        # Initialize heights array with NaN
        heights = np.full_like(motion_magnitude, np.nan)

        if motion_scale is not None and height_scale is not None and np.any(valid_motion):
            debug.debug_print('heights', "\nHeight Calculation Debug:")
            debug.debug_print('heights', f"Current motion_scale: {motion_scale}")
            debug.debug_print('heights', f"Current height_scale: {height_scale}")
            debug.debug_print('heights', f"Current vertical_extent_scale: {vertical_extent_scale}")

            relative_motion = np.zeros_like(motion_magnitude)
            debug.debug_print('heights', f"aircraft_speed before division: {aircraft_speed}")

            relative_motion[valid_motion] = np.abs(motion_magnitude[valid_motion] * motion_scale) / aircraft_speed

            # Get center point values for relative motion analysis
            center_y, center_x = motion_magnitude.shape[0] // 2, motion_magnitude.shape[1] // 2
            center_relative_motion = relative_motion[center_y, center_x]
            expected_relative_motion = lidar_height / aircraft_altitude if lidar_height is not None else None

            debug.debug_print('heights', "\nRelative Motion Analysis in Height Calculation:")
            debug.debug_print('heights', f"Center relative motion (actual): {center_relative_motion:.3f}")
            if expected_relative_motion is not None:
                debug.debug_print('heights', f"Expected relative motion: {expected_relative_motion:.3f}")
                debug.debug_print('heights', f"Relative motion difference: {(expected_relative_motion - center_relative_motion):.3f}")
                debug.debug_print('heights', f"Relative motion ratio: {(expected_relative_motion / (center_relative_motion + 1e-6)):.3f}")

            debug.debug_print('heights', f"relative_motion range: {relative_motion.min():.3f} to {relative_motion.max():.3f}")
            debug.debug_print('heights', f"aircraft_speed: {aircraft_speed:.3f}")

            heights[valid_motion] = aircraft_altitude * relative_motion[valid_motion] * height_scale
            debug.debug_print('heights', f"heights range before smoothing: {np.nanmin(heights):.3f} to {np.nanmax(heights):.3f}")

            kernel_size = 5
            heights_smooth = cv2.GaussianBlur(heights, (kernel_size, kernel_size), 0)
            valid_mask = ~np.isnan(heights)
            heights[valid_mask] = confidence[valid_mask] * heights[valid_mask] + \
                                (1 - confidence[valid_mask]) * heights_smooth[valid_mask]

            if self.use_lidar_to_calculate_heights and lidar_height is not None:
                heights = self.adjust_heights(heights, valid_mask, lidar_height, center_y, center_x, vertical_extent_scale)

        # Flip height field and confidence horizontally to match RAFT coordinate system
        heights = heights[:, ::-1]
        confidence = confidence[:, ::-1]


        center_y, center_x = motion_magnitude.shape[0] // 2, motion_magnitude.shape[1] // 2
        center_height = heights[center_y, center_x]
        debug.debug_print('heights', f"Final center height: {center_height:.1f} m")

        center_confidence = confidence[center_y, center_x]

        debug.debug_print('heights', "\nFinal Height Calculation:")
        debug.debug_print('heights', f"Valid height calculations: {np.sum(~np.isnan(heights)) / heights.size * 100:.1f}%")
        if np.any(~np.isnan(heights)):
            debug.debug_print('heights', f"Height range: {np.nanmin(heights):.1f} to {np.nanmax(heights):.1f} m")
        debug.debug_print('heights', f"Center height: {center_height if not np.isnan(center_height) else 'NaN'}")
        debug.debug_print('heights', f"Center confidence: {center_confidence:.3f}")
        if lidar_height is not None:
            debug.debug_print('heights', f"LiDAR height: {lidar_height:.1f} m")
            if not np.isnan(center_height):
                debug.debug_print('heights', f"Center height error: {abs(center_height - lidar_height):.1f} m")

        return heights, confidence

    def get_parameters(self):
        """
        Returns current calculator parameters.

        Returns:
            dict: Current scale values and history size
        """
        return {
            'motion_scale': self.running_motion_scale if self.running_motion_scale is not None else 1.0,
            'height_scale': self.running_height_scale if self.running_height_scale is not None else 1.0,
            'vertical_extent_scale': self.running_vertical_extent_scale if self.running_vertical_extent_scale is not None else 1.0,
            'history_size': len(self.motion_scale_history)
        }

## plot_scale_histories

In [None]:
def plot_scale_histories(model):
    """
    Creates a comprehensive visualization of scale histories from both fine-tuned and reference RAFT height calculators.

    This function generates a 6-panel figure showing:
    1. Motion Scale History: Evolution of motion scaling factors over time
    2. Height Scale History: Changes in height scaling factors
    3. Vertical Extent Scale History: Adaptations in vertical range scaling
    4. Fine-tuned RAFT Scales vs Altitude: How scales vary with aircraft altitude
    5. Original RAFT Scales vs Altitude: Reference model's scale behavior with altitude
    6. Fine-tuned RAFT Scales vs Aircraft Speed: Scale relationships with aircraft velocity

    Each plot includes:
    - Raw scale values over time or against metadata
    - Running averages where applicable
    - Clear labeling and legends
    - Grid lines for easier reading

    Args:
        model (CombinedModel): Model containing both fine-tuned and reference
            height calculators with their scale histories

    Returns:
        IPython.display.HTML: JPEG-compressed visualization embedded in notebook

    Note:
        The function automatically compresses the output to JPEG format to maintain
        notebook performance with large history visualizations.
    """
    ft_calc = model.finetuned_height_calculator
    ref_calc = model.reference_height_calculator

    fig = plt.figure(figsize=(20, 20))  # Increased height to accommodate new subplot

    # Motion Scale History
    plt.subplot(321)
    plt.plot(ft_calc.motion_scale_history, label='Fine-tuned RAFT', alpha=0.7)
    plt.plot(ref_calc.motion_scale_history, label='Original RAFT', alpha=0.7)
    if ft_calc.running_motion_scale is not None:
        plt.axhline(y=ft_calc.running_motion_scale, color='r', linestyle='--',
                  label=f'Fine-tuned Running Avg: {ft_calc.running_motion_scale:.3f}')
    if ref_calc.running_motion_scale is not None:
        plt.axhline(y=ref_calc.running_motion_scale, color='g', linestyle='--',
                  label=f'Original Running Avg: {ref_calc.running_motion_scale:.3f}')
    plt.title('Motion Scale History')
    plt.xlabel('Sample')
    plt.ylabel('Scale Factor')
    plt.legend()
    plt.grid(True)

    # Height Scale History
    plt.subplot(322)
    plt.plot(ft_calc.height_scale_history, label='Fine-tuned RAFT', alpha=0.7)
    plt.plot(ref_calc.height_scale_history, label='Original RAFT', alpha=0.7)
    if ft_calc.running_height_scale is not None:
        plt.axhline(y=ft_calc.running_height_scale, color='r', linestyle='--',
                  label=f'Fine-tuned Running Avg: {ft_calc.running_height_scale:.3f}')
    if ref_calc.running_height_scale is not None:
        plt.axhline(y=ref_calc.running_height_scale, color='g', linestyle='--',
                  label=f'Original Running Avg: {ref_calc.running_height_scale:.3f}')
    plt.title('Height Scale History')
    plt.xlabel('Sample')
    plt.ylabel('Scale Factor')
    plt.legend()
    plt.grid(True)

    # Vertical Extent Scale History
    plt.subplot(323)
    plt.plot(ft_calc.vertical_extent_scale_history, label='Fine-tuned RAFT', alpha=0.7)
    plt.plot(ref_calc.vertical_extent_scale_history, label='Original RAFT', alpha=0.7)
    if ft_calc.running_vertical_extent_scale is not None:
        plt.axhline(y=ft_calc.running_vertical_extent_scale, color='r', linestyle='--',
                  label=f'Fine-tuned Running Avg: {ft_calc.running_vertical_extent_scale:.3f}')
    if ref_calc.running_vertical_extent_scale is not None:
        plt.axhline(y=ref_calc.running_vertical_extent_scale, color='g', linestyle='--',
                  label=f'Original Running Avg: {ref_calc.running_vertical_extent_scale:.3f}')
    plt.title('Vertical Extent Scale History')
    plt.xlabel('Sample')
    plt.ylabel('Scale Factor')
    plt.legend()
    plt.grid(True)

    # Fine-tuned RAFT Scales vs Altitude
    plt.subplot(324)
    plt.scatter(ft_calc.altitude_history, ft_calc.motion_scale_history,
              alpha=0.5, label='Motion Scale')
    plt.scatter(ft_calc.altitude_history, ft_calc.height_scale_history,
              alpha=0.5, label='Height Scale')
    plt.scatter(ft_calc.altitude_history, ft_calc.vertical_extent_scale_history,
              alpha=0.5, label='Vertical Extent Scale')
    plt.xlabel('Aircraft Altitude (m)')
    plt.ylabel('Scale Factor')
    plt.title('Fine-tuned RAFT Scale Factors vs Altitude')
    plt.legend()
    plt.grid(True)

    # Original RAFT Scales vs Altitude
    plt.subplot(325)
    plt.scatter(ref_calc.altitude_history, ref_calc.motion_scale_history,
              alpha=0.5, label='Motion Scale')
    plt.scatter(ref_calc.altitude_history, ref_calc.height_scale_history,
              alpha=0.5, label='Height Scale')
    plt.scatter(ref_calc.altitude_history, ref_calc.vertical_extent_scale_history,
              alpha=0.5, label='Vertical Extent Scale')
    plt.xlabel('Aircraft Altitude (m)')
    plt.ylabel('Scale Factor')
    plt.title('Original RAFT Scale Factors vs Altitude')
    plt.legend()
    plt.grid(True)

    # Fine-tuned RAFT Scales vs Aircraft Speed
    plt.subplot(326)
    plt.scatter(ft_calc.aircraft_speed_history, ft_calc.motion_scale_history,
              alpha=0.5, label='Motion Scale')
    plt.scatter(ft_calc.aircraft_speed_history, ft_calc.height_scale_history,
              alpha=0.5, label='Height Scale')
    plt.scatter(ft_calc.aircraft_speed_history, ft_calc.vertical_extent_scale_history,
              alpha=0.5, label='Vertical Extent Scale')
    plt.xlabel('Aircraft Speed (m/s)')
    plt.ylabel('Scale Factor')
    plt.title('Fine-tuned RAFT Scale Factors vs Aircraft Speed')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()

    # Create JPEG and close the figure
    jpeg_output = embed_matplotlib_jpeg(fig, dpi=72)
    plt.close(fig)  # Close the figure explicitly
    return jpeg_output

## CombinedModel Class

The CombinedModel integrates the `CloudMotionModel` with the `ParallaxHeightCalculator` to create a complete cloud-height measurement pipeline.

## Core Components

### 1. Motion Estimation
Employs dual RAFT instances from `CloudMotionModel`:
- Trainable RAFT for cloud-specific motion
- Frozen reference RAFT for baseline comparison
- Both produce motion fields $M \in \mathbb{R}^{B \times T \times 2 \times H \times W}$

### 2. Height Calculation
Maintains two parallel height calculators:
```python
self.finetuned_height_calculator = ParallaxHeightCalculator(
    smoothing_factor=0.1,
    history_size=1000,
    use_lidar_to_calculate_heights=True
)
self.reference_height_calculator = ParallaxHeightCalculator(...)
```
This allows direct comparison between fine-tuned and baseline performance.

### 3. Forward Pass Pipeline

1. **Motion Field Generation**:
   - Input sequence: $I \in \mathbb{R}^{B \times T \times 3 \times H \times W}$
   - Generate motion fields from both RAFT models
   - Preserve bidirectional anomalies
   - Calculate motion confidence

2. **Height Field Generation**:
   When metadata and LiDAR available:
   $$
   \begin{align*}
   h_{\text{ft}} &= \text{finetuned_calculator}(M_{\text{ft}}, \text{metadata}, h_{\text{lidar}}) \\
   h_{\text{ref}} &= \text{reference_calculator}(M_{\text{ref}}, \text{metadata}, h_{\text{lidar}})
   \end{align*}
   $$

### 4. Output Structure
```python
{
    'motion_fields': Sequence of motion fields
    'refined_motion': Final motion with preserved anomalies
    'reference_motion': Reference model motion
    'height_field_finetuned': Height estimates from fine-tuned model
    'height_uncertainty_finetuned': Uncertainty estimates
    'height_field_reference': Reference height estimates
    'height_uncertainty_reference': Reference uncertainty
    'calibration_params_finetuned': Current scale parameters
    'calibration_params_reference': Reference scale parameters
}
```

### 5. Scale Management

Both height calculators maintain independent running scales:
- Motion scale: $s_m$ for pixel-to-meter conversion
- Height scale: $s_h$ for relative-to-absolute height
- Vertical extent scale: $s_v$ for cloud thickness

These scales are updated and saved in checkpoints for model continuity.

### 6. State Management
```python
def get_height_calculator_states(self):
    """Returns current states of both calculators."""
    return {
        'finetuned': {
            'running_motion_scale',
            'running_height_scale',
            'running_vertical_extent_scale',
            'scale_histories'
        },
        'reference': {...}
    }
```

### 7. Model Hierarchy
```
CombinedModel
├── CloudMotionModel
│   ├── Trainable RAFT
│   └── Reference RAFT
├── Fine-tuned Height Calculator
│   └── Scale Management
└── Reference Height Calculator
    └── Scale Management
```

This architecture enables:
- Direct comparison of fine-tuned vs baseline performance
- Height estimation with uncertainty quantification
- Scale preservation across training sessions
- Comprehensive motion and height field visualization

In [None]:
class CombinedModel(nn.Module):
    """
    End-to-end model combining motion estimation and height calculation for cloud measurement.

    This model integrates:
    - Dual RAFT motion estimation (trainable and reference)
    - Parallel height calculators for direct comparison
    - Scale management and calibration
    - Uncertainty estimation

    The model processes image sequences to produce both motion fields and height estimates,
    maintaining separate fine-tuned and reference pipelines for comparison.

    Attributes:
        raft (CloudMotionModel): Motion estimation model with dual RAFT instances
        finetuned_height_calculator (ParallaxHeightCalculator): Calculator for fine-tuned model
        reference_height_calculator (ParallaxHeightCalculator): Calculator for reference model
        metadata_manager (FlightMetadataManager): Manages flight metadata processing

    Args:
        raft_model (CloudMotionModel): Initialized motion estimation model
    """
    def __init__(self, raft_model):
        super().__init__()
        self.raft = raft_model
        self.metadata_manager = FlightMetadataManager()

        # Initialize two separate height calculators
        self.finetuned_height_calculator = ParallaxHeightCalculator(
            smoothing_factor=0.1,
            history_size=1000,
            use_lidar_to_calculate_heights=True
        )
        self.reference_height_calculator = ParallaxHeightCalculator(
            smoothing_factor=0.1,
            history_size=1000,
            use_lidar_to_calculate_heights=True
        )

    def forward(self, images, metadata=None, validation_height=None, sequence_lengths=None):
        """
        Processes image sequences through the complete height estimation pipeline.

        Args:
            images (torch.Tensor or PackedSequence): Image sequence
                Shape: Batch x Time x 3 x Height x Width
            metadata (torch.Tensor or PackedSequence, optional): Flight metadata
            validation_height (torch.Tensor, optional): LiDAR measurements
            sequence_lengths (torch.Tensor, optional): Valid sequence lengths

        Returns:
            dict: Results including:
                - motion_fields: Raw motion sequences
                - refined_motion: Motion with preserved anomalies
                - reference_motion: Reference model motion
                - height_field_finetuned: Fine-tuned height estimates
                - height_uncertainty_finetuned: Fine-tuned uncertainty
                - height_field_reference: Reference height estimates
                - height_uncertainty_reference: Reference uncertainty
                - calibration_params_finetuned: Current scale parameters
                - calibration_params_reference: Reference parameters
        """
        # Get device from packed sequence data
        device = images.data.device

        # Get motion fields from RAFT
        raft_outputs = self.raft(images)

        outputs = {
            'motion_fields': raft_outputs['motion_fields'],
            'refined_motion': raft_outputs['refined_motion'],
            'reference_motion': raft_outputs['reference_motion'],
            'confidence': raft_outputs['confidence']
        }

        # If metadata provided, calculate heights
        if metadata is not None and sequence_lengths is not None:
            # Unpack sequences if needed
            if isinstance(metadata, PackedSequence):
                metadata_unpacked, _ = pad_packed_sequence(metadata, batch_first=True)
                batch_size = metadata_unpacked.shape[0]
            else:
                metadata_unpacked = metadata
                batch_size = metadata.shape[0]

            # Use sequence_lengths to get the last valid index for each sequence
            last_valid_indices = sequence_lengths - 1

            # Convert metadata and motion fields to numpy
            aircraft_altitude = metadata_unpacked[torch.arange(batch_size), last_valid_indices, 2].cpu().numpy()

            if isinstance(outputs['refined_motion'], PackedSequence):
                final_motion_unpacked, _ = pad_packed_sequence(outputs['refined_motion'], batch_first=True)
                reference_motion_unpacked, _ = pad_packed_sequence(outputs['reference_motion'], batch_first=True)
                final_motion = final_motion_unpacked.detach().cpu().numpy()
                reference_motion = reference_motion_unpacked.detach().cpu().numpy()
            else:
                final_motion = outputs['refined_motion'].detach().cpu().numpy()
                reference_motion = outputs['reference_motion'].detach().cpu().numpy()

            lidar_height = validation_height.cpu().numpy() if validation_height is not None else None

            # Calculate heights for each sample in batch
            batch_heights_ft = []
            batch_uncertainties_ft = []
            batch_heights_ref = []
            batch_uncertainties_ref = []

            for i in range(batch_size):
                last_valid_idx = last_valid_indices[i]
                # Create metadata dict using manager's indices
                metadata_dict = {
                    'GPS_MSL_Alt': float(metadata_unpacked[i, last_valid_idx,
                        self.metadata_manager.get_index('GPS_MSL_Alt')].cpu().numpy()),
                    'True_Airspeed': float(metadata_unpacked[i, last_valid_idx,
                        self.metadata_manager.get_index('True_Airspeed')].cpu().numpy()),
                    'Track': float(metadata_unpacked[i, last_valid_idx,
                        self.metadata_manager.get_index('Track')].cpu().numpy()),
                    'Pitch': float(metadata_unpacked[i, last_valid_idx,
                        self.metadata_manager.get_index('Pitch')].cpu().numpy()),
                    'Roll': float(metadata_unpacked[i, last_valid_idx,
                        self.metadata_manager.get_index('Roll')].cpu().numpy()),
                    'Wind_Speed': float(metadata_unpacked[i, last_valid_idx,
                        self.metadata_manager.get_index('Wind_Speed')].cpu().numpy()),
                    'Wind_Dir': float(metadata_unpacked[i, last_valid_idx,
                        self.metadata_manager.get_index('Wind_Dir')].cpu().numpy())
                }

                # Calculate heights for fine-tuned motion
                heights_ft, uncertainty_ft = self.finetuned_height_calculator.calculate_heights(
                    final_motion[i],
                    metadata_dict,
                    float(lidar_height[i]) if lidar_height is not None else None
                )
                batch_heights_ft.append(heights_ft)
                batch_uncertainties_ft.append(uncertainty_ft)

                # Calculate heights for reference motion
                heights_ref, uncertainty_ref = self.reference_height_calculator.calculate_heights(
                    reference_motion[i],
                    metadata_dict,
                    float(lidar_height[i]) if lidar_height is not None else None
                )
                batch_heights_ref.append(heights_ref)
                batch_uncertainties_ref.append(uncertainty_ref)

            # Convert to tensors and store outputs
            outputs['height_field_finetuned'] = torch.from_numpy(np.stack(batch_heights_ft)).to(device)
            outputs['height_uncertainty_finetuned'] = torch.from_numpy(np.stack(batch_uncertainties_ft)).to(device)
            outputs['height_field_reference'] = torch.from_numpy(np.stack(batch_heights_ref)).to(device)
            outputs['height_uncertainty_reference'] = torch.from_numpy(np.stack(batch_uncertainties_ref)).to(device)

            # Store calibration parameters and histories for both calculators
            outputs['calibration_params_finetuned'] = self.finetuned_height_calculator.get_parameters()
            outputs['calibration_params_reference'] = self.reference_height_calculator.get_parameters()

        return outputs

    def get_height_calculator_states(self):
        """
        Returns current states of both height calculators.

        Used for checkpoint saving and scale preservation across training sessions.

        Returns:
            dict: State dictionaries for both calculators including:
                - Running scales (motion, height, vertical extent)
                - Scale histories
                - Metadata histories (timestamps, altitudes, speeds)
        """
        return {
            'finetuned': {
                'running_motion_scale': self.finetuned_height_calculator.running_motion_scale,
                'running_height_scale': self.finetuned_height_calculator.running_height_scale,
                'running_vertical_extent_scale': self.finetuned_height_calculator.running_vertical_extent_scale,
                'motion_scale_history': self.finetuned_height_calculator.motion_scale_history,
                'height_scale_history': self.finetuned_height_calculator.height_scale_history,
                'vertical_extent_scale_history': self.finetuned_height_calculator.vertical_extent_scale_history,
                'timestamp_history': self.finetuned_height_calculator.timestamp_history,
                'altitude_history': self.finetuned_height_calculator.altitude_history,
                'aircraft_speed_history': self.finetuned_height_calculator.aircraft_speed_history
            },
            'reference': {
                'running_motion_scale': self.reference_height_calculator.running_motion_scale,
                'running_height_scale': self.reference_height_calculator.running_height_scale,
                'running_vertical_extent_scale': self.reference_height_calculator.running_vertical_extent_scale,
                'motion_scale_history': self.reference_height_calculator.motion_scale_history,
                'height_scale_history': self.reference_height_calculator.height_scale_history,
                'vertical_extent_scale_history': self.reference_height_calculator.vertical_extent_scale_history,
                'timestamp_history': self.reference_height_calculator.timestamp_history,
                'altitude_history': self.reference_height_calculator.altitude_history,
                'aircraft_speed_history': self.reference_height_calculator.aircraft_speed_history
            }
        }

    def load_height_calculator_states(self, states):
        """
        Loads saved states into both height calculators.

        Used when resuming training or inference to maintain scale calibration.

        Args:
            states (dict): State dictionaries for both calculators containing:
                - Running scales
                - Scale histories
                - Metadata histories
        """
        if 'finetuned' in states:
            self.finetuned_height_calculator.running_motion_scale = states['finetuned']['running_motion_scale']
            self.finetuned_height_calculator.running_height_scale = states['finetuned']['running_height_scale']
            self.finetuned_height_calculator.running_vertical_extent_scale = states['finetuned']['running_vertical_extent_scale']
            self.finetuned_height_calculator.motion_scale_history = states['finetuned']['motion_scale_history']
            self.finetuned_height_calculator.height_scale_history = states['finetuned']['height_scale_history']
            self.finetuned_height_calculator.vertical_extent_scale_history = states['finetuned']['vertical_extent_scale_history']
            self.finetuned_height_calculator.timestamp_history = states['finetuned']['timestamp_history']
            self.finetuned_height_calculator.altitude_history = states['finetuned']['altitude_history']
            self.finetuned_height_calculator.aircraft_speed_history = states['finetuned']['aircraft_speed_history']

        if 'reference' in states:
            self.reference_height_calculator.running_motion_scale = states['reference']['running_motion_scale']
            self.reference_height_calculator.running_height_scale = states['reference']['running_height_scale']
            self.reference_height_calculator.running_vertical_extent_scale = states['reference']['running_vertical_extent_scale']
            self.reference_height_calculator.motion_scale_history = states['reference']['motion_scale_history']
            self.reference_height_calculator.height_scale_history = states['reference']['height_scale_history']
            self.reference_height_calculator.vertical_extent_scale_history = states['reference']['vertical_extent_scale_history']
            self.reference_height_calculator.timestamp_history = states['reference']['timestamp_history']
            self.reference_height_calculator.altitude_history = states['reference']['altitude_history']
            self.reference_height_calculator.aircraft_speed_history = states['reference']['aircraft_speed_history']

## SequenceConsistencyLoss Class

`SequenceConsistencyLoss` enforces temporal coherence and motion consistency in cloud height estimation through three complementary loss components:

## Photometric Loss

Photometric loss ensures accurate frame reconstruction by comparing how well the predicted motion fields can reconstruct subsequent frames. This enforces temporal consistency in the motion estimates.

$$L_{photo} = \sum_{t=1}^{T-1} \|I_{t+1} - \text{warp}(I_t, F_t)\| \cdot V_t$$

where:
- $I_t$ is frame t in the sequence
- $F_t$ is the estimated motion field
- $V_t$ is a visibility mask
- `warp()` applies the motion field to warp frame t

**Why this is important for cloud motion:**
Clouds undergo continuous, gradual changes in appearance and shape. At high altitudes, cloud features must maintain temporal consistency to enable accurate tracking. The visibility mask component is particularly crucial as it allows the model to handle complex scenarios where clouds at different heights overlap or occlude each other. This loss helps distinguish between actual cloud movement and changes caused by varying illumination conditions or aircraft motion, which is essential for accurate height estimation.

## Smoothness Loss

Smoothness loss penalizes rapid spatial variations in the motion field, promoting locally coherent motion predictions. This helps avoid spurious motion estimates and encourages physically plausible cloud motion patterns.

$$L_{smooth} = \sum_{t=1}^{T-1} \|\nabla F_t\|$$

where $\nabla F_t$ represents spatial gradients of the motion field

**Why this is important for cloud motion:**
Cloud formations typically exhibit fluid-like behavior with coherent motion patterns. At high altitudes, where clouds tend to move more uniformly, this loss term ensures the estimated motion fields reflect the natural, continuous flow of cloud systems. This is especially important for distinguishing between different cloud layers and preventing physically impossible motion discontinuities that could lead to errors in height estimation.

## Flow Consistency Loss

Flow consistency loss ensures bidirectional consistency between forward and backward motion estimates. This helps identify and penalize invalid motion predictions that are not consistent when applied in both temporal directions.

$$L_{consist} = \sum_{t=1}^{T-1} \|F_t + \text{warp}(F'_t, F_t)\|$$

where $F'_t$ is the backward flow

**Why this is important for cloud motion:**
True cloud movement should be temporally reversible over short time scales. This loss helps differentiate between actual cloud motion and apparent motion caused by changes in cloud shape or illumination. When dealing with high-altitude clouds viewed from an aircraft, consistency checking becomes particularly important as clouds may appear to move differently depending on the viewing angle and aircraft position. This bidirectional verification helps ensure reliable height estimates across varying viewing conditions.

## Combined Loss

The final loss combines these components with configurable weights:

$$L_{total} = L_{photo} + \lambda_{smooth}L_{smooth} + \lambda_{consist}L_{consist}$$

where $\lambda_{smooth}$ and $\lambda_{consist}$ are weighting factors.

## Implementation Features

- Uses PyTorch's packed sequence functionality to efficiently handle variable-length sequences
- Implements visibility masking to exclude occluded regions from loss computation
- Provides bidirectional consistency checking through forward and backward flow comparison
- Applies spatial regularization through smoothness constraints
- Handles batch processing for efficient training

The combined loss terms guide the model to learn motion fields that:
1. Accurately reconstruct subsequent frames
2. Maintain spatial smoothness
3. Exhibit bidirectional consistency
4. Account for occlusions through visibility masking

In [None]:
class SequenceConsistencyLoss(nn.Module):
    """
    A multi-component loss function for training cloud motion estimation models.

    This loss combines photometric reconstruction, motion smoothness, and flow
    consistency terms to ensure physically valid cloud motion estimation. The class
    handles variable-length sequences using PyTorch's packed sequence functionality.

    The total loss is a weighted sum of three components:
        1. Photometric Loss: Ensures accurate frame reconstruction
        2. Smoothness Loss: Promotes coherent motion fields
        3. Flow Consistency Loss: Enforces bidirectional consistency

    Args:
        smoothness_weight (float, optional): Weight for smoothness loss term. Default: 0.1
        consistency_weight (float, optional): Weight for flow consistency term. Default: 0.5
    """
    def __init__(self, smoothness_weight=0.1, consistency_weight=0.5):
        super().__init__()
        self.smoothness_weight = smoothness_weight
        self.consistency_weight = consistency_weight

    def forward_warp(self, image, flow):
        """
        Warps an image according to the given flow field.

        Args:
            image (torch.Tensor or PackedSequence): Image or sequence of images to warp
                Shape: [B, T, C, H, W] if tensor
            flow (torch.Tensor or PackedSequence): Flow fields to apply
                Shape: [B, T, 2, H, W] if tensor

        Returns:
            torch.Tensor or PackedSequence: Warped images with same structure as input
        """
        if isinstance(image, PackedSequence):
            image_data, image_batch_sizes = pad_packed_sequence(image, batch_first=True)
        else:
            image_data = image

        if isinstance(flow, PackedSequence):
            flow_data, flow_batch_sizes = pad_packed_sequence(flow, batch_first=True)
        else:
            flow_data = flow

        B, T, C, H, W = image_data.shape
        _, T_flow, _, _, _ = flow_data.shape

        grid_y, grid_x = torch.meshgrid(
            torch.arange(H, device=image_data.device),
            torch.arange(W, device=image_data.device),
            indexing='ij'
        )

        grid = torch.stack((grid_x, grid_y)).to(image_data.device).float()
        grid = grid.unsqueeze(0).unsqueeze(0).expand(B, T_flow, -1, H, W)

        flow_grid = grid + flow_data
        flow_grid = flow_grid.permute(0, 1, 3, 4, 2)

        # Normalize flow grid to [-1, 1] range
        flow_grid[..., 0] = 2.0 * flow_grid[..., 0] / (W - 1) - 1.0
        flow_grid[..., 1] = 2.0 * flow_grid[..., 1] / (H - 1) - 1.0

        # Reshape for grid_sample
        flow_grid = flow_grid.view(B * T_flow, H, W, 2)
        image_to_warp = image_data[:, :T_flow].contiguous().view(B * T_flow, C, H, W)

        warped = F.grid_sample(
            image_to_warp,
            flow_grid,
            mode='bilinear',
            padding_mode='border',
            align_corners=True
        )

        # Reshape back to original dimensions
        warped = warped.view(B, T_flow, C, H, W)

        # Repack if input was packed
        if isinstance(image, PackedSequence):
            warped = pack_padded_sequence(warped, flow_batch_sizes, batch_first=True)

        return warped

    def compute_visibility_mask(self, flow_forward, flow_backward):
        """
        Computes visibility masks to handle occlusions in the motion field.

        Args:
            flow_forward (torch.Tensor or PackedSequence): Forward flow fields
                Shape: [B, T, 2, H, W] if tensor
            flow_backward (torch.Tensor or PackedSequence): Backward flow fields
                Shape: [B, T, 2, H, W] if tensor

        Returns:
            torch.Tensor or PackedSequence: Binary visibility mask
                Shape: [B, T, 1, H, W] if tensor
        """
        forward_warped = self.forward_warp(flow_backward, flow_forward)
        backward_warped = self.forward_warp(flow_forward, flow_backward)

        # Unpack if needed
        if isinstance(forward_warped, PackedSequence):
            forward_warped, batch_sizes = pad_packed_sequence(forward_warped, batch_first=True)
            flow_backward, _ = pad_packed_sequence(flow_backward, batch_first=True)
            backward_warped, _ = pad_packed_sequence(backward_warped, batch_first=True)
            flow_forward, _ = pad_packed_sequence(flow_forward, batch_first=True)

        forward_diff = torch.norm(forward_warped - flow_backward, dim=2, keepdim=True)
        backward_diff = torch.norm(backward_warped - flow_forward, dim=2, keepdim=True)

        forward_mask = forward_diff < 1.0
        backward_mask = backward_diff < 1.0

        return forward_mask & backward_mask

    def compute_consistency_loss(self, flow_forward, flow_backward):
        """
        Computes bidirectional consistency loss between forward and backward flows.

        Args:
            flow_forward (torch.Tensor or PackedSequence): Forward flow fields
                Shape: [B, T, 2, H, W] if tensor
            flow_backward (torch.Tensor or PackedSequence): Backward flow fields
                Shape: [B, T, 2, H, W] if tensor

        Returns:
            tuple: (consistency_mask, consistency_error)
                - consistency_mask (torch.Tensor): Per-pixel consistency weights
                - consistency_error (torch.Tensor): Mean consistency error
        """
        # Assuming flow_forward and flow_backward have shape (B, T, 2, H, W)

        # Warp forward flow using backward flow
        warped_forward = self.forward_warp(flow_forward, flow_backward)

        # Warp backward flow using forward flow
        warped_backward = self.forward_warp(flow_backward, flow_forward)

        # Unpack if needed
        if isinstance(warped_forward, PackedSequence):
            warped_forward, batch_sizes = pad_packed_sequence(warped_forward, batch_first=True)
            warped_backward, _ = pad_packed_sequence(warped_backward, batch_first=True)
            flow_forward, _ = pad_packed_sequence(flow_forward, batch_first=True)
            flow_backward, _ = pad_packed_sequence(flow_backward, batch_first=True)

        # Compute consistency error
        forward_consistency_error = torch.norm(warped_forward + flow_backward, dim=2)
        backward_consistency_error = torch.norm(warped_backward + flow_forward, dim=2)

        # Combine errors
        consistency_error = (forward_consistency_error + backward_consistency_error) / 2

        # Create mask based on consistency
        consistency_mask = torch.exp(-consistency_error / 0.1)

        return consistency_mask, consistency_error.mean()


    def forward(self, flows, image_sequence, sequence_lengths):
        """
        Computes the combined loss for a batch of sequences.

        Args:
            flows (PackedSequence): Predicted motion fields
                Data shape: [B, T-1, 2, H, W] when unpacked
            image_sequence (PackedSequence): Input image sequences
                Data shape: [B, T, C, H, W] when unpacked
            sequence_lengths (torch.Tensor): Length of each sequence in batch
                Shape: [B]

        Returns:
            torch.Tensor: Scalar loss value combining photometric, smoothness,
                and consistency terms according to their weights

        Note:
            Sequences in the batch must be sorted by length in descending order
            for proper packed sequence handling.
        """
        # Unpack sequences
        flows_unpacked, flows_lengths = pad_packed_sequence(flows, batch_first=True)
        images_unpacked, _ = pad_packed_sequence(image_sequence, batch_first=True)

        total_loss = 0
        valid_count = 0

        # Process each sequence based on its actual length
        for i, length in enumerate(flows_lengths):
            # Get sequence data
            seq_flows = flows_unpacked[i:i+1, :length]
            seq_images = images_unpacked[i:i+1, :length+1]  # +1 because we need one more image than flows

            # Photometric loss
            warped_images = self.forward_warp(seq_images[:, :-1], seq_flows)
            visibility_mask = self.compute_visibility_mask(seq_flows, seq_flows)
            photo_loss = (torch.abs(warped_images - seq_images[:, 1:]) * visibility_mask).mean()

            # Smoothness loss
            smoothness_loss = (
                torch.abs(seq_flows[:, :, :, 1:, :] - seq_flows[:, :, :, :-1, :]).mean() +
                torch.abs(seq_flows[:, :, :, :, 1:] - seq_flows[:, :, :, :, :-1]).mean()
            )

            # Flow consistency loss
            consistency_mask, consistency_loss = self.compute_consistency_loss(seq_flows, -seq_flows)
            photo_loss = (photo_loss * consistency_mask).mean()

            # Combine losses
            sequence_loss = (photo_loss +
                           self.smoothness_weight * smoothness_loss +
                           self.consistency_weight * consistency_loss)

            if torch.isfinite(sequence_loss):
                total_loss += sequence_loss
                valid_count += 1

        return total_loss / max(valid_count, 1)

## CombinedLoss Architecture

The `CombinedLoss` combines motion consistency with LiDAR supervision to enable accurate height field estimation:

### Motion Loss Component
Uses `SequenceConsistencyLoss` to ensure temporal coherence in predicted motion fields:

$$L_{motion} = L_{sequence}(F_{pred}, I, l)$$

where:
- $F_{pred}$ are predicted motion fields
- $I$ are input image sequences
- $l$ are sequence lengths

### LiDAR Supervision Component

While the model can predict motion patterns, converting to absolute heights requires calibration. A single LiDAR measurement at the image center provides this reference point.

1. Calculate motion at image center:
   $$m_{center} = \|F_{pred}(x_c, y_c)\|$$

2. Convert to relative motion:
   $$r_{actual} = \frac{m_{center}}{v_{aircraft}}$$

3. Calculate expected relative motion:
   $$r_{expected} = \frac{h_{lidar}}{h_{aircraft}}$$

4. Determine scaling factor:
   $$s = \frac{r_{expected}}{r_{actual}}$$

5. Apply to motion field:
   $$F_{scaled} = F_{pred} \cdot s$$

The loss is:
$$L_{lidar} = L_{smooth\_l1}(F_{scaled}, F_{pred})$$

### Total Loss
$$L_{total} = L_{motion} + w_{lidar} \cdot L_{lidar} \cdot \alpha$$

where:
- $w_{lidar}$ is the LiDAR weight (0.01)
- $\alpha = \text{clamp}(\frac{L_{motion}}{L_{lidar}}, 0.01, 10)$ dynamically scales the LiDAR loss

### Height Error Monitoring

Tracks error between predicted and LiDAR heights:

$$E_{height} = \frac{1}{N}\sum_{i=1}^N |h_i - h_{ref,i}|$$

For valid predictions where:
- $h_i$ are predicted center heights
- $h_{ref,i}$ are LiDAR measurements
- $N$ is the number of valid samples

In [None]:
class CombinedLoss(nn.Module):
    """
    Loss function combining motion consistency with LiDAR-guided motion field scaling.

    Uses a single LiDAR height measurement to calculate expected motion at image center
    and scale the entire motion field accordingly. This enables accurate absolute height
    estimation from relative motion measurements.

    Args:
        smoothness_weight (float): Weight for motion smoothness term. Default: 0.01
        lidar_weight (float): Weight for LiDAR supervision term. Default: 0.01
    """
    def __init__(self, smoothness_weight=0.01, lidar_weight=0.01):
        super().__init__()
        self.smoothness_weight = smoothness_weight
        self.lidar_weight = lidar_weight
        self.consistency_loss = SequenceConsistencyLoss(smoothness_weight=smoothness_weight)
        self.metadata_manager = FlightMetadataManager()

    def calculate_expected_relative_motion(self, lidar_height, metadata):
        """
        Calculate the expected relative motion at LiDAR measurement point.

        For cloud height estimation, relative motion should be proportional to height:
        relative_motion = cloud_height / aircraft_height

        Args:
            lidar_height (float): Cloud height measured by LiDAR
            metadata (dict): Flight metadata containing aircraft altitude

        Returns:
            float: Expected relative motion based on height ratio
        """
        aircraft_altitude = float(metadata['GPS_MSL_Alt'])
        debug.debug_print('finetune', f"Debug: lidar_height = {lidar_height}, aircraft_altitude = {aircraft_altitude}")
        expected_relative_motion = lidar_height / (aircraft_altitude + 1e-8)
        debug.debug_print('finetune', f"Debug: expected_relative_motion = {expected_relative_motion}")
        return expected_relative_motion

    def forward(self, outputs, images, sequence_lengths, validation_height=None, metadata=None):
        """
        Compute combined loss incorporating motion consistency and LiDAR supervision.

        The loss has two main components:
        1. Motion consistency loss using SequenceConsistencyLoss
        2. LiDAR supervision loss that:
            - Calculates expected motion from LiDAR height
            - Measures actual motion at image center
            - Scales motion field to match expected values
            - Computes smooth L1 loss between scaled and original fields

        Args:
            outputs (dict): Model outputs containing motion fields
            images (PackedSequence): Input image sequences
            sequence_lengths (torch.Tensor): Length of each sequence
            validation_height (torch.Tensor, optional): LiDAR height measurements
            metadata (PackedSequence, optional): Flight metadata per frame

        Returns:
            tuple: (total_loss, loss_dict)
                - total_loss: Combined weighted loss
                - loss_dict: Individual loss components
        """
        device = images.data.device
        motion_loss = torch.tensor(0.0, device=device)
        lidar_loss = torch.tensor(0.0, device=device)
        height_error = torch.tensor(0.0, device=device)
        total_loss = torch.tensor(0.0, device=device)

        # Motion consistency loss
        motion_loss = self.consistency_loss(
            outputs['motion_fields'],
            images,
            sequence_lengths
        )
        debug.debug_print('finetune', f"Debug: motion_loss = {motion_loss}")

        total_loss = motion_loss
        loss_dict = {'motion_loss': motion_loss}

        # Add LiDAR supervision if available
        if validation_height is not None and metadata is not None:
            motion_fields_unpacked, _ = pad_packed_sequence(outputs['motion_fields'], batch_first=True)
            metadata_unpacked, _ = pad_packed_sequence(metadata, batch_first=True)

            lidar_loss = torch.tensor(0.0, device=device)
            valid_samples = 0

            for i, seq_len in enumerate(sequence_lengths):
                debug.debug_print('finetune', f"\nDebug: Processing batch item {i}")
                flow_len = seq_len - 1  # number of flow fields is one less than sequence length

                if flow_len < 1:
                    debug.debug_print('finetune', f"Warning: Sequence {i} is too short (length {seq_len})")
                    continue

                # Use the last valid flow field
                motion_field = motion_fields_unpacked[i, flow_len - 1]
                h, w = motion_field.shape[-2:]
                center_y, center_x = h // 2, w // 2

                # Calculate actual motion magnitude at center
                center_motion = torch.norm(motion_field[:, center_y, center_x])
                debug.debug_print('finetune', f"Debug: center_motion = {center_motion}")

                # Convert center motion to relative motion using aircraft speed
                aircraft_speed = metadata_unpacked[i, seq_len - 1, 7]  # True_Airspeed
                debug.debug_print('finetune', f"Debug: aircraft_speed = {aircraft_speed}")

                if torch.isnan(aircraft_speed) or aircraft_speed <= 0:
                    debug.debug_print('finetune', f"Warning: Invalid aircraft_speed for batch item {i}")
                    continue

                center_relative_motion = center_motion / (aircraft_speed + 1e-8)
                debug.debug_print('finetune', f"Debug: center_relative_motion = {center_relative_motion}")

                # Calculate expected relative motion using metadata manager
                expected_relative_motion = self.calculate_expected_relative_motion(
                    validation_height[i],
                    {'GPS_MSL_Alt': metadata_unpacked[i, seq_len - 1,
                        self.metadata_manager.get_index('GPS_MSL_Alt')].item()}
                )

                if torch.isnan(expected_relative_motion):
                    debug.debug_print('finetune', f"Warning: Invalid expected_relative_motion for batch item {i}")
                    continue

                # Convert to tensor and match device
                expected_relative_motion = torch.full_like(center_relative_motion, expected_relative_motion)
                debug.debug_print('finetune', f"Debug: expected_relative_motion tensor = {expected_relative_motion}")

                # Calculate scaling factor
                scale_factor = expected_relative_motion / (center_relative_motion + 1e-8)
                debug.debug_print('finetune', f"Debug: scale_factor = {scale_factor}")

                # Scale entire motion field and calculate loss
                scaled_motion = motion_field * scale_factor.view(-1, 1, 1)
                debug.debug_print('finetune', f"Debug: scaled_motion shape = {scaled_motion.shape}, min = {scaled_motion.min()}, max = {scaled_motion.max()}")
                debug.debug_print('finetune', f"Debug: motion_field shape = {motion_field.shape}, min = {motion_field.min()}, max = {motion_field.max()}")

                batch_lidar_loss = F.smooth_l1_loss(scaled_motion, motion_field)
                debug.debug_print('finetune', f"Debug: batch_lidar_loss = {batch_lidar_loss}")

                if torch.isfinite(batch_lidar_loss):
                    lidar_loss += batch_lidar_loss
                    valid_samples += 1

            if valid_samples > 0:
                lidar_loss = lidar_loss / valid_samples
                debug.debug_print('finetune', f"Debug: average lidar_loss = {lidar_loss}")
                scaled_lidar_loss = lidar_loss * torch.clamp(motion_loss.detach() / (lidar_loss + 1e-8), min=0.01, max=10)
                total_loss = total_loss + self.lidar_weight * scaled_lidar_loss
                loss_dict['lidar_loss'] = lidar_loss
                loss_dict['scaled_lidar_loss'] = scaled_lidar_loss
            else:
                debug.debug_print('finetune', "Warning: No valid samples for lidar loss calculation")

        # Height error monitoring
        if validation_height is not None and 'height_field_finetuned' in outputs:
            with torch.no_grad():
                height_error = self._calculate_height_error(
                    outputs['height_field_finetuned'],
                    validation_height
                )
                debug.debug_print('finetune', f"Debug: height_error = {height_error}")

        loss_dict['height_error'] = height_error
        loss_dict['total_loss'] = total_loss

        debug.debug_print('finetune', f"Debug: total_loss = {total_loss}")
        debug.debug_print('finetune', f"Debug: loss_dict = {loss_dict}")

        # Final check for NaNs
        for key, value in loss_dict.items():
            if torch.isnan(value) or torch.isinf(value):
                debug.debug_print('finetune', f"Warning: {key} is NaN or Inf")

        return total_loss, loss_dict

    def _calculate_height_error(self, height_field, validation_height):
        """
        Calculate mean absolute error between predicted and LiDAR heights at center point.

        Args:
            height_field (torch.Tensor): Predicted height field, shape [B, H, W]
            validation_height (torch.Tensor): LiDAR measurements, shape [B]

        Returns:
            torch.Tensor: Mean absolute height error for valid samples
        """
        B, H, W = height_field.shape
        center_y, center_x = H // 2, W // 2
        center_heights = height_field[:, center_y, center_x]

        if center_heights.shape != validation_height.shape:
            center_heights = center_heights.view(-1)
            validation_height = validation_height.view(-1)

        valid_mask = torch.isfinite(center_heights) & torch.isfinite(validation_height)
        if valid_mask.sum() > 0:
            height_errors = torch.abs(center_heights[valid_mask] - validation_height[valid_mask])
            # Print individual height errors
            debug.debug_print('finetune', "Individual height errors:")
            for j, error in enumerate(height_errors):
                debug.debug_print('finetune', f"  Sample {j}: {error.item():.2f} m")
            return height_errors.mean()

        debug.debug_print('finetune', "Warning: No valid samples for height error calculation")
        return torch.tensor(0.0, device=height_field.device)

In [None]:
# Initialize model and move to device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

model = CloudMotionModel(use_pretrained=True)

# Print out the size of the model for reference
print_num_parameters(model)

# Move model to device after creation
model = model.to(device)

# Create random test batch
batch = {
    'images': torch.randn(2, 5, 3, 384, 384).to(device)
}

# Forward pass
with torch.no_grad():
    outputs = model(batch['images'])

# Print shapes and devices
for key, value in outputs.items():
    print(f"{key}: shape={value.shape}, device={value.device}")

# Modeling Visualizations

## Helper functions

In [None]:
def safe_normalize(img):
    """
    Normalize image data handling NaN values and edge cases.

    Args:
        img: Image data as numpy array or other format
            For numpy arrays, normalizes to [0,1] range
            For other formats, returns unchanged

    Returns:
        Normalized image with same type as input:
        - For numpy arrays: (img - min) / (max - min) if max > min
        - For numpy arrays with max == min: Array of zeros
        - For non-numpy inputs: Original image unchanged
    """
    if isinstance(img, np.ndarray):
        min_val = np.nanmin(img)
        max_val = np.nanmax(img)
        if max_val > min_val:
            return (img - min_val) / (max_val - min_val)
        return np.zeros_like(img)
    return img

def create_diagnostic_text(heights, motion, metadata, calib_params, center_height, model_type=""):
    """
    Create formatted diagnostic text for visualization overlays.

    Args:
        heights (np.ndarray): Height field array
        motion (np.ndarray): Motion field array
        metadata (dict): Flight metadata containing:
            - GPS_MSL_Alt: Aircraft altitude in meters
            - True_Airspeed: Aircraft speed in m/s
        calib_params (dict): Calibration parameters with:
            - motion_scale: Motion field scaling factor
            - height_scale: Height field scaling factor
        center_height (float): LiDAR height measurement at center
        model_type (str, optional): Model identifier string

    Returns:
        str: HTML-formatted diagnostic text containing:
            - Aircraft altitude and speed
            - LiDAR and predicted center heights
            - Height field range and valid data percentage
            - Motion field range
            - Calibration scale factors
    """
    center_height_value = heights[heights.shape[0]//2, heights.shape[1]//2]

    return (
        f"{model_type} Diagnostics:<br>"
        f"Aircraft Alt: {metadata['GPS_MSL_Alt']:.1f}m<br>"
        f"Aircraft Speed: {metadata['True_Airspeed']:.1f} m/s<br>"
        f"LiDAR height: {center_height:.1f}m<br>"
        f"Center height: {center_height_value:.1f}m<br>"
        f"Height range: {np.nanmin(heights):.1f}m to {np.nanmax(heights):.1f}m<br>"
        f"Valid data: {np.sum(~np.isnan(heights)) / heights.size * 100:.1f}%<br>"
        f"Motion range: {np.min(motion):.3f} to {np.max(motion):.3f} pixels<br>"
        f"Motion scale: {calib_params['motion_scale']:.3f}<br>"
        f"Height scale: {calib_params['height_scale']:.3f}<br>"
    )

## Motion Field Visualizations

## visualize_motion_sequence

Creates a three-panel visualization comparing motion estimates across a sequence:

### Layout
\begin{align*}
\text{Panel 1} &: \text{Original image frames} \\
\text{Panel 2} &: \text{Frame-by-frame motion fields} \\
\text{Panel 3} &: \text{Refined vs Original RAFT comparison}
\end{align*}


Motion fields are represented using:
1. Background image in grayscale
2. Motion magnitude overlay (viridis colormap)
3. Subsampled motion vectors (white arrows)
4. Confidence overlay for refined motion (red-yellow-green)

## `plot_motion_field_with_image()`

Creates composite visualization of motion field over an image:

$$\text{magnitude} = \sqrt{dx^2 + dy^2}$$

Components:
- Background image (grayscale)
- Motion magnitude (viridis colormap, α = 0.7)
- Motion vectors on grid (subsample=16)
- Optional confidence overlay (RdYlGn colormap, α = 0.3)

## `visualize_detailed_motion()`

Four-panel detailed motion analysis:
1. Original image frame
2. Motion field overlay
3. Motion magnitude map:
   $$\text{magnitude} = \sqrt{dx^2 + dy^2}$$
4. Confidence map

Each visualization:
- Uses safe normalization for correct scaling
- Handles packed sequences for variable-length inputs
- Returns compressed JPEG to manage notebook size
- Includes calibrated colorbars for quantitative comparison

Key features:
- Supports both individual frames and sequences
- Visualizes fine-tuned vs original RAFT differences
- Displays motion confidence estimates
- Memory-efficient cleanup of matplotlib objects

In [None]:
@show_warnings
def visualize_motion_sequence(model_outputs, data, sample_idx=0, figsize=(25, 15)):
    """
    Create multi-panel visualization comparing motion fields across a sequence.

    Creates three visualization panels:
    1. Original image sequence
    2. Frame-by-frame motion fields with shared colorbar
    3. Comparison of refined vs original RAFT motion estimation

    Args:
        model_outputs (dict): Model predictions containing:
            - motion_fields: Per-frame motion estimates
            - refined_motion: Final refined motion field
            - reference_motion: Original RAFT motion field
            - confidence: Motion confidence scores
        data (dict or Tensor): Input data with:
            - images: Image sequence
            - sequence_lengths: Length of each sequence
            - image_paths: Optional paths for labeling
        sample_idx (int): Which sample in batch to visualize
        figsize (tuple): Figure dimensions (width, height)

    Returns:
        HTML: JPEG visualization embedded in notebook
    """
    with torch.no_grad():
        # Handle both dictionary and tensor/packed sequence inputs
        if isinstance(data, dict):
            input_images = data['images']
            sequence_lengths = data['sequence_lengths']
            image_paths = data.get('image_paths', None)

            # Unpack if needed
            if isinstance(input_images, PackedSequence):
                input_images, _ = pad_packed_sequence(input_images, batch_first=True)

            seq_length = int(sequence_lengths[sample_idx])
        else:
            input_images = data
            if isinstance(input_images, PackedSequence):
                input_images, lengths = pad_packed_sequence(input_images, batch_first=True)
                seq_length = int(lengths[sample_idx])
            else:
                seq_length = int(input_images.size(1))
            image_paths = None

        def unpack_if_needed(tensor_or_packed):
            if isinstance(tensor_or_packed, PackedSequence):
                unpacked, _ = pad_packed_sequence(tensor_or_packed, batch_first=True)
                return unpacked
            return tensor_or_packed

        # Move tensors to CPU and convert to numpy
        input_images = input_images.detach().cpu()

        # Get motion fields, handling both packed and unpacked cases
        motion_fields = unpack_if_needed(model_outputs['motion_fields'])
        refined_motion = unpack_if_needed(model_outputs['refined_motion'])
        reference_motion = unpack_if_needed(model_outputs['reference_motion'])

        # Convert to numpy
        motion_fields = motion_fields[sample_idx].detach().cpu().numpy()
        refined_motion = refined_motion[sample_idx].detach().cpu().numpy()
        reference_motion = reference_motion[sample_idx].detach().cpu().numpy()
        confidence = safe_normalize(model_outputs['confidence'][sample_idx, 0].detach().cpu().numpy())


        # Create figure and a single GridSpec for main content
        fig = plt.figure(figsize=figsize)
        gs = plt.GridSpec(3, seq_length, height_ratios=[1, 1, 1.2])

        # Plot original images in top row
        for i in range(seq_length):
            ax = plt.subplot(gs[0, i])
            img = input_images[sample_idx, i].numpy().transpose(1, 2, 0)
            img = safe_normalize(img)
            plt.imshow(img, cmap='gray')
            if image_paths is not None:
                try:
                    path = os.path.basename(str(image_paths[sample_idx][i]))
                    plt.title(f'Frame {i+1}\n{path}', fontsize=8)
                except (IndexError, TypeError):
                    plt.title(f'Frame {i+1}')
            else:
                plt.title(f'Frame {i+1}')
            plt.axis('off')

        # Plot individual motion fields in middle row with a shared colorbar
        motion_overlays = []
        for i in range(seq_length - 1):
            ax = plt.subplot(gs[1, i])
            overlay = plot_motion_field_with_image(
                motion_fields[i],
                safe_normalize(input_images[sample_idx, i+1].numpy().transpose(1, 2, 0)),
                title=f'Motion Field {i+1}',
                add_colorbar=False  # Don't add individual colorbars
            )
            motion_overlays.append(overlay)

        # Plot refined motion and reference motion side by side in bottom row
        ax = plt.subplot(gs[2, :seq_length//2])
        plot_motion_field_with_image(
            refined_motion,
            safe_normalize(input_images[sample_idx, -1].numpy().transpose(1, 2, 0)),
            title='Fine-tuned RAFT Motion',
            confidence=confidence
        )

        ax = plt.subplot(gs[2, seq_length//2:])
        plot_motion_field_with_image(
            reference_motion,
            safe_normalize(input_images[sample_idx, -1].numpy().transpose(1, 2, 0)),
            title='Original RAFT Motion'
        )

        # Apply tight_layout first
        plt.tight_layout()

        # Then add the colorbar after tight_layout has been applied
        # Get the position of the middle row plots to align the colorbar
        middle_row_pos = gs[1, :].get_position(fig)
        cax = fig.add_axes([0.92, middle_row_pos.y0, 0.02, middle_row_pos.height])
        fig.colorbar(motion_overlays[0], cax=cax, label='Motion Magnitude')

        # Create JPEG and close the figure
        jpeg_output = embed_matplotlib_jpeg(fig, dpi=50)
        plt.close(fig)  # Close the figure explicitly
        return jpeg_output

@show_warnings
def plot_motion_field_with_image(motion_field, background_image, title='Motion Field',
                              subsample=16, confidence=None, alpha=0.7, ax=None, add_colorbar=True):
    """
    Create composite visualization of motion field overlaid on image.

    Combines:
    - Background image in grayscale
    - Motion magnitude as colormap overlay
    - Subsampled motion vectors as arrows
    - Optional confidence scores as second overlay

    Args:
        motion_field (np.ndarray): Motion vectors [2, H, W]
        background_image (np.ndarray): Image to overlay on [H, W, C]
        title (str): Plot title
        subsample (int): Spacing between motion arrows
        confidence (np.ndarray, optional): Confidence scores [H, W]
        alpha (float): Overlay transparency
        ax (matplotlib.axes, optional): Axes to plot on
        add_colorbar (bool): Whether to add colorbar

    Returns:
        matplotlib.image.AxesImage: Motion magnitude overlay for colorbar
    """
    # Use provided axes or current axes
    if ax is None:
        ax = plt.gca()

    dx = motion_field[0]
    dy = motion_field[1]
    magnitude = np.sqrt(dx**2 + dy**2)

    background_image = safe_normalize(background_image)

    # Create composed visualization
    ax.imshow(background_image, cmap='gray')

    # Create a semi-transparent overlay of the motion magnitude
    motion_overlay = ax.imshow(magnitude, cmap='viridis', alpha=alpha)
    if add_colorbar:
        plt.colorbar(motion_overlay, label='Motion Magnitude', ax=ax)

    # Create grid for quiver plot
    h, w = magnitude.shape
    y, x = np.mgrid[0:h:subsample, 0:w:subsample]

    # Subsample motion field for arrows
    dx_sub = dx[::subsample, ::subsample]
    dy_sub = dy[::subsample, ::subsample]

    # Plot arrows
    ax.quiver(x, y, dx_sub, dy_sub, angles='xy',
             scale_units='xy', scale=0.5,
             color='white', alpha=0.8)

    if confidence is not None:
        confidence_overlay = ax.imshow(confidence, cmap='RdYlGn', alpha=0.3)
        if add_colorbar:
            plt.colorbar(confidence_overlay, label='Confidence', ax=ax)

    ax.set_title(title)
    ax.set_axis_off()

    return motion_overlay  # Return the motion overlay for creating shared colorbar

@show_warnings
def visualize_detailed_motion(model_outputs, image_sequence, sample_idx=0,
                           frame_idx=-1, figsize=(20, 5)):
    """
    Create detailed four-panel motion analysis visualization.

    Panels:
    1. Original image
    2. Motion field overlay
    3. Motion magnitude map
    4. Confidence map

    Args:
        model_outputs (dict): Model predictions with motion fields and confidence
        image_sequence (Tensor or PackedSequence): Input image sequence
        sample_idx (int): Which sample in batch to visualize
        frame_idx (int): Which frame to analyze (-1 for last)
        figsize (tuple): Figure dimensions

    Returns:
        HTML: JPEG visualization embedded in notebook
    """
    with torch.no_grad():
        # Unpack sequences if needed
        if isinstance(image_sequence, PackedSequence):
            images_unpacked, lengths = pad_packed_sequence(image_sequence, batch_first=True)
            seq_length = lengths[sample_idx]
            if frame_idx == -1:
                frame_idx = seq_length - 2  # Last valid frame for motion
        else:
            images_unpacked = image_sequence

        # Get motion fields
        if isinstance(model_outputs['refined_motion'], PackedSequence):
            refined_unpacked, _ = pad_packed_sequence(model_outputs['refined_motion'], batch_first=True)
            motion_fields_unpacked, _ = pad_packed_sequence(model_outputs['motion_fields'], batch_first=True)

            if frame_idx == -1:
                motion_field = refined_unpacked[sample_idx].detach().cpu().numpy()
            else:
                motion_field = motion_fields_unpacked[sample_idx, frame_idx].detach().cpu().numpy()
        else:
            if frame_idx == -1:
                motion_field = model_outputs['refined_motion'][sample_idx].detach().cpu().numpy()
            else:
                motion_field = model_outputs['motion_fields'][sample_idx, frame_idx].detach().cpu().numpy()

        # Get frame to display
        frame = images_unpacked[sample_idx, frame_idx + 1].detach().cpu().numpy()
        frame = safe_normalize(frame)  # Normalize the frame

        # Get confidence
        if isinstance(model_outputs['confidence'], PackedSequence):
            confidence_unpacked, _ = pad_packed_sequence(model_outputs['confidence'], batch_first=True)
            confidence = confidence_unpacked[sample_idx, 0].detach().cpu().numpy()
        else:
            confidence = model_outputs['confidence'][sample_idx, 0].detach().cpu().numpy()

        confidence = safe_normalize(confidence)  # Normalize the confidence

        # Create figure
        fig = plt.figure(figsize=figsize)

        # Original image
        plt.subplot(141)
        plt.imshow(np.transpose(frame, (1, 2, 0)))
        plt.title('Original Image')
        plt.axis('off')

        # Motion field overlay
        plt.subplot(142)
        plot_motion_field_with_image(
            motion_field,
            background_image=np.transpose(frame, (1, 2, 0)),
            title='Motion Field Overlay'
        )

        # Motion magnitude
        plt.subplot(143)
        magnitude = np.sqrt(motion_field[0]**2 + motion_field[1]**2)
        magnitude = safe_normalize(magnitude)  # Normalize the magnitude
        plt.imshow(magnitude, cmap='viridis')
        plt.colorbar(label='Magnitude')
        plt.title('Motion Magnitude')
        plt.axis('off')

        # Confidence map
        plt.subplot(144)
        plt.imshow(confidence, cmap='RdYlGn')
        plt.colorbar(label='Confidence')
        plt.title('Confidence Map')
        plt.axis('off')

        plt.tight_layout()

        # Convert figure to JPEG
        jpeg_output = embed_matplotlib_jpeg(fig, dpi=50)
        plt.close(fig)  # Explicitly close the figure

        return jpeg_output

## Height Field Visualizations

## create_height_field_plot

Creates a 3D surface plot of cloud height fields using Plotly's Surface plots.

### Components
1. Main height surface:
   $$(x,y) = \text{meshgrid}(\text{range}(-w/2, w/2), \text{range}(-h/2, h/2))$$

2. Uncertainty visualization:
  $$
  \text{opacity} =
  \begin{cases}
  0.3 \cdot \text{uncertainty}, & \text{if height valid} \\
  0, & \text{otherwise}
  \end{cases}
  $$

3. LiDAR reference point:
   - Position: $(0,0,h_{lidar})$
   - Marker: Red diamond

### Adaptive tick spacing:
  $$
  \text{spacing} = \begin{cases}
  2000\text{m}, & \text{if range} > 10\text{km} \\
  1000\text{m}, & \text{otherwise}
  \end{cases}
  $$

## visualize_calibrated_height_fields

Creates side-by-side comparison of fine-tuned vs reference RAFT height fields.

### Layout
1. Left subplot: Fine-tuned RAFT
   - Height field surface (Viridis)
   - Uncertainty overlay (Reds)
   - Calibration parameters
   - Validity percentage

2. Right subplot: Original RAFT
   - Identical components
   - Shared z-scale

### Camera Settings
\begin{align*}
\text{aspect ratio} &= [1:1:0.5] \\
\text{camera position} &= (1.5, 1.5, 1)
\end{align*}

Both functions include:
- NaN handling for invalid heights
- Percentage of valid data calculation
- Diagnostic statistics
- Consistent colorbar positioning
- Shared z-axis scaling

In [None]:
def create_height_field_plot(heights, uncertainty, calibration_params, center_height, title,
                          colorbar_x=1.02, show_lidar_legend=False):
    """
    Create a 3D surface plot of cloud height field with uncertainty overlay.

    Creates three visualization layers:
    1. Height field surface with viridis colormap
    2. Uncertainty overlay with red transparency
    3. LiDAR reference point marker

    Args:
        heights (np.ndarray): Cloud height field array [H, W]
        uncertainty (np.ndarray): Uncertainty values [H, W]
        calibration_params (dict): Height field calibration parameters
        center_height (float): LiDAR height measurement at center
        title (str): Plot title
        colorbar_x (float): X-position of colorbar (0-1 range)
        show_lidar_legend (bool): Whether to show LiDAR point in legend

    Returns:
        tuple: (traces, valid_percentage)
            - traces: List of Plotly graph objects for surface plots
            - valid_percentage: Percentage of valid height measurements
    """
    h, w = heights.shape
    x = np.arange(w) - (w // 2)
    y = np.arange(h) - (h // 2)
    X, Y = np.meshgrid(x, y)
    valid_mask = ~np.isnan(heights)
    valid_percentage = np.sum(valid_mask) / heights.size * 100
    plotted_heights = heights.copy()
    plotted_heights[~valid_mask] = np.nanmin(heights)

    # Calculate height range for tick spacing
    height_range = np.nanmax(heights) - np.nanmin(heights)
    tick_spacing = 1000  # 1km spacing
    if height_range > 10000:
        tick_spacing = 2000  # 2km spacing for large ranges

    traces = []

    # Main surface plot
    traces.append(go.Surface(
        x=X, y=Y, z=plotted_heights,
        colorscale='Viridis',
        colorbar=dict(
            title='Height (m)',
            x=colorbar_x,  # Position passed as parameter
            len=0.75,
            y=0.5,
            dtick=tick_spacing,
            thickness=20,
            tickmode='linear'
        ),
        opacity=0.8,
        showscale=True
    ))

    # Uncertainty surface if available
    if uncertainty is not None:
        uncertainty_colors = np.zeros((h, w, 4))
        uncertainty_colors[..., 0] = 1
        uncertainty_colors[..., 3] = uncertainty * 0.3
        uncertainty_colors[~valid_mask] = [0, 0, 0, 0]
        traces.append(go.Surface(
            x=X, y=Y, z=plotted_heights,
            surfacecolor=uncertainty_colors[..., 3],
            colorscale='Reds',
            showscale=False,
            opacity=0.3
        ))

    # LiDAR point
    traces.append(go.Scatter3d(
        x=[0], y=[0], z=[center_height],
        mode='markers',
        marker=dict(size=5, color='red', symbol='diamond'),
        name=f'LiDAR: {center_height:.0f}m',
        showlegend=show_lidar_legend  # Control whether this plot shows in legend
    ))

    return traces, valid_percentage

def visualize_calibrated_height_fields(refined_heights, refined_uncertainty,
                                      reference_heights, reference_uncertainty,
                                      ft_calib_params, ref_calib_params, center_height,
                                      refined_stats, reference_stats):
    """
    Create side-by-side comparison of fine-tuned vs original RAFT height fields.

    Generates a Plotly figure with:
    1. Left subplot: Fine-tuned RAFT height field
    2. Right subplot: Original RAFT height field
    Both with uncertainty overlays and diagnostic statistics

    Args:
        refined_heights (np.ndarray): Fine-tuned height field [H, W]
        refined_uncertainty (np.ndarray): Fine-tuned uncertainty [H, W]
        reference_heights (np.ndarray): Original RAFT height field [H, W]
        reference_uncertainty (np.ndarray): Original uncertainty [H, W]
        ft_calib_params (dict): Fine-tuned calibration parameters
        ref_calib_params (dict): Original calibration parameters
        center_height (float): LiDAR height measurement
        refined_stats (str): Diagnostic text for fine-tuned model
        reference_stats (str): Diagnostic text for original model

    Returns:
        plotly.graph_objects.Figure: Interactive 3D visualization with:
            - Synchronized height scales and camera angles
            - Uncertainty overlays
            - Diagnostic statistics
            - LiDAR reference points
    """
    # Create both plots with different colorbar positions
    refined_traces, refined_valid = create_height_field_plot(
        refined_heights, refined_uncertainty, ft_calib_params,
        center_height, "Height Field (Fine-tuned RAFT)",
        colorbar_x=0.45,
        show_lidar_legend=False  # Don't show legend for first plot
    )
    reference_traces, reference_valid = create_height_field_plot(
        reference_heights, reference_uncertainty, ref_calib_params,
        center_height, "Height Field (Original RAFT)",
        colorbar_x=1.02,
        show_lidar_legend=True  # Show legend only for second plot
    )

    # Print diagnostic information
    print("\n" + refined_stats.replace("<br>", "\n"))
    print("\n" + reference_stats.replace("<br>", "\n"))

    # Create subplot figure
    fig = make_subplots(
        rows=1, cols=2,
        specs=[[{'type': 'surface'}, {'type': 'surface'}]],
        subplot_titles=None,
        horizontal_spacing=0.02  # Reduce space between subplots
    )

    # Add traces to subplots
    for trace in refined_traces:
        fig.add_trace(trace, row=1, col=1)
    for trace in reference_traces:
        fig.add_trace(trace, row=1, col=2)

    # Update layout with increased dimensions and adjusted margins
    fig.update_layout(
        height=800,
        width=1350,  # Increased from 1200
        title_text="Calibrated Height Fields Comparison",
        title_y=0.95,  # Move main title down slightly
        margin=dict(t=180, b=50, l=0, r=20),  # Adjusted all margins
        scene=dict(
            xaxis_title='X Distance (pixels)',
            yaxis_title='Y Distance (pixels)',
            zaxis_title='Height (m)',
            aspectratio=dict(x=1, y=1, z=0.5),
            camera=dict(eye=dict(x=1.5, y=1.5, z=1)),
            domain=dict(x=[0.0, 0.48])  # Adjust left plot position
        ),
        scene2=dict(
            xaxis_title='X Distance (pixels)',
            yaxis_title='Y Distance (pixels)',
            zaxis_title='Height (m)',
            aspectratio=dict(x=1, y=1, z=0.5),
            camera=dict(eye=dict(x=1.5, y=1.5, z=1)),
            domain=dict(x=[0.52, 1.0])
        ),
        annotations=[
            dict(
                text=f"Height Field (Fine-tuned RAFT)<br>Valid data: {refined_valid:.1f}%",
                x=0.42,  # Right align for left plot
                y=1.08,  # Adjust vertical position as needed
                xref="paper",
                yref="paper",
                showarrow=False,
                font=dict(size=14),
                xanchor="right"  # Right justify the text
            ),
            dict(
                text=f"Height Field (Original RAFT)<br>Valid data: {reference_valid:.1f}%",
                x=0.94,  # Right align for right plot
                y=1.08,  # Adjust vertical position as needed
                xref="paper",
                yref="paper",
                showarrow=False,
                font=dict(size=14),
                xanchor="right"  # Right justify the text
            )
        ]
    )

    # Add diagnostic text annotations with adjusted y-position
    fig.add_annotation(
        x=0.01, y=1.15,  # Moved up slightly
        xref="paper", yref="paper",
        text=refined_stats,
        showarrow=False,
        font=dict(size=10),
        bgcolor="white",
        opacity=0.8,
        align="left"
    )

    fig.add_annotation(
        x=0.70, y=1.15,
        xref="paper", yref="paper",
        text=reference_stats,
        showarrow=False,
        font=dict(size=10),
        bgcolor="white",
        opacity=0.8,
        align="left"
    )

    return fig

## Training Progress Visualization

This function creates visualizations during model training to monitor both optical flow estimation and height field generation.

## Components

### 1. Motion Field Visualization
\begin{align*}
& \text{Input}: F_{refined}, F_{reference} \in \mathbb{R}^{2 \times H \times W} \\
& \text{Output}: \text{3-panel comparison showing:} \\
& \quad - \text{Original image sequence} \\
& \quad - \text{Frame-by-frame motion fields} \\
& \quad - \text{Refined vs Original RAFT comparison}
\end{align*}


### 2. Height Field Generation
For valid LiDAR measurements:

\begin{align*}
h_{center} &= \text{denormalize}(h_{lidar}) \\
h_{refined} &= \text{calculate_heights}(F_{refined}, \text{metadata}, h_{center}) \\
h_{reference} &= \text{calculate_heights}(F_{reference}, \text{metadata}, h_{center})
\end{align*}

### 3. Calibration Analysis
\begin{align*}
\text{Parameters tracked:}& \\
s_{motion} &= \text{motion scale factor} \\
s_{height} &= \text{height scale factor} \\
s_{vertical} &= \text{vertical extent scale}
\end{align*}

### 4. Diagnostic Metrics
For each model (refined and reference):
- Height field range
- Valid data percentage
- Motion magnitude range
- Center height accuracy
- Calibration parameters

The function handles:
- PackedSequence unpacking
- Tensor normalization
- GPU memory management
- Exception handling

In [None]:
def visualize_training_progress(outputs, batch_data, dataset, model, sample_idx=0):
    """
    Create comprehensive visualization of model training progress.

    Generates three visualizations:
    1. Motion field comparison showing:
        - Original image sequence
        - Per-frame motion fields
        - Refined vs original RAFT comparison

    2. Height field reconstruction showing:
        - Fine-tuned model heights with uncertainty
        - Original model heights with uncertainty
        - LiDAR reference point
        - Diagnostic statistics

    3. Calibration parameter history showing:
        - Motion scale evolution
        - Height scale evolution
        - Vertical extent scale changes

    Args:
        outputs (dict): Model outputs containing:
            - refined_motion: Fine-tuned motion fields
            - reference_motion: Original RAFT motion fields
        batch_data (dict): Training batch with:
            - metadata: Flight parameters
            - validation_height: LiDAR measurements
        dataset: Dataset object for denormalization
        model: Model instance with height calculators
        sample_idx (int): Which sample in batch to visualize

    Note:
        Handles packed sequences and GPU memory management.
        Catches and reports visualization errors.
    """
    try:
        with torch.no_grad():
            # Motion sequence visualization
            if 'refined_motion' in outputs and 'reference_motion' in outputs:
                display(visualize_motion_sequence(outputs, batch_data, sample_idx=sample_idx))

            # Height field visualization if data available
            if ('metadata' in batch_data and
                'validation_height' in batch_data and
                not torch.isnan(batch_data['validation_height'][sample_idx])):

                # Unpack metadata if it's a packed sequence
                if isinstance(batch_data['metadata'], PackedSequence):
                    metadata_unpacked, _ = pad_packed_sequence(batch_data['metadata'], batch_first=True)
                    last_idx = batch_data['sequence_lengths'][sample_idx] - 1
                    temp_data = {'metadata': metadata_unpacked}
                else:
                    temp_data = batch_data

                metadata = get_denormalized_metadata(dataset, temp_data, sample_idx)

                center_height = float(dataset.denormalize_validation_height(
                    batch_data['validation_height'][sample_idx]))

                # Handle packed motion fields
                if isinstance(outputs['refined_motion'], PackedSequence):
                    refined_unpacked, _ = pad_packed_sequence(outputs['refined_motion'], batch_first=True)
                    reference_unpacked, _ = pad_packed_sequence(outputs['reference_motion'], batch_first=True)
                    refined_motion = refined_unpacked[sample_idx].detach().cpu().numpy()
                    reference_motion = reference_unpacked[sample_idx].detach().cpu().numpy()
                else:
                    refined_motion = outputs['refined_motion'][sample_idx].detach().cpu().numpy()
                    reference_motion = outputs['reference_motion'][sample_idx].detach().cpu().numpy()

                # Calculate heights
                refined_heights, refined_uncertainty = model.finetuned_height_calculator.calculate_heights(
                    refined_motion, metadata, center_height)
                reference_heights, reference_uncertainty = model.reference_height_calculator.calculate_heights(
                    reference_motion, metadata, center_height)

                # Get calibration parameters
                ft_calib_params = model.finetuned_height_calculator.get_parameters()
                ref_calib_params = model.reference_height_calculator.get_parameters()

                # Create diagnostic text
                refined_stats = create_diagnostic_text(
                    refined_heights, refined_motion, metadata,
                    ft_calib_params, center_height, "Fine-tuned RAFT"
                )
                reference_stats = create_diagnostic_text(
                    reference_heights, reference_motion, metadata,
                    ref_calib_params, center_height, "Original RAFT"
                )

                # Create visualization
                fig = visualize_calibrated_height_fields(
                    refined_heights, refined_uncertainty,
                    reference_heights, reference_uncertainty,
                    ft_calib_params, ref_calib_params, center_height,
                    refined_stats, reference_stats
                )

                # Display visualization
                fig.show()

                # Scale histories visualization (static matplotlib - use JPEG)
                display(plot_scale_histories(model))

    except Exception as e:
        print(f"Visualization error: {str(e)}")
        import traceback
        traceback.print_exc()
    finally:
        torch.cuda.empty_cache()

## Test visualization with random data



In [None]:
# Get the correct metadata length
metadata_length = len(metadata_manager.flight_data)

# Create random test batch
batch = {
    'images': torch.randn(2, 5, 3, 384, 384).to(device),
    'sequence_lengths': torch.tensor([5, 5]),  # Both sequences have length 5
    'image_paths': [
        ['path1.jpg', 'path2.jpg', 'path3.jpg', 'path4.jpg', 'path5.jpg'],
        ['path6.jpg', 'path7.jpg', 'path8.jpg', 'path9.jpg', 'path10.jpg']
    ],
    'metadata': torch.randn(2, 5, metadata_length).to(device),  # Using correct metadata length
    'validation_height': torch.randn(2).to(device)
}

# Run model inference
with torch.no_grad():
    outputs = model(batch['images'])

# Print shapes and devices
for key, value in outputs.items():
    print(f"{key}: shape={value.shape}, device={value.device}")

# Visualize entire sequence with motion fields
print("Calling visualization function with:")
print(f"model_outputs keys: {outputs.keys()}")
print(f"batch type: {type(batch)}")
if isinstance(batch, dict):
    print(f"batch keys: {batch.keys()}")
display(visualize_motion_sequence(outputs, batch))

# Visualize detailed motion for a specific frame
display(visualize_detailed_motion(outputs, batch['images']))

# Visualize a specific sample and frame
sample_idx = 0  # First sample in batch
frame_idx = 1   # Second frame pair
display(visualize_detailed_motion(outputs, batch['images'],
                         sample_idx=sample_idx,
                         frame_idx=frame_idx))

# Training

## TrainingLogger

In [None]:
class TrainingLogger:
    """
    Logger for tracking and visualizing training metrics and saving model checkpoints.

    Args:
        save_dir (str): Directory path for saving checkpoints

    Attributes:
        metrics (defaultdict): Stores training metrics by name
        current_epoch (int): Current training epoch
    """
    def __init__(self, save_dir='/content/drive/MyDrive/motion_maps'):
        self.save_dir = save_dir
        os.makedirs(save_dir, exist_ok=True)
        self.metrics = defaultdict(list)
        self.current_epoch = 0
        print(f"Saving checkpoints to: {save_dir}")

    def log_metrics(self, metrics_dict, step_type='batch'):
        """
        Log training metrics for batch or epoch.

        Args:
            metrics_dict (dict): Metrics to log {name: value}
            step_type (str): Either 'batch' or 'epoch'
        """
        for key, value in metrics_dict.items():
            metric_name = f"{step_type}/{key}"
            self.metrics[metric_name].append(value)
            print(f"Logged metric: {metric_name} = {value}")

    def save_checkpoint(self, model, optimizer, scheduler, model_type='motion',
                        additional_info=None):
        """
        Save model checkpoint with training state.

        Args:
            model: Model to save
            optimizer: Optimizer state
            scheduler: Learning rate scheduler state
            model_type (str): Type of model being saved
            additional_info (dict, optional): Extra info to save

        Saves:
            - Model weights
            - Optimizer state
            - Scheduler state
            - Current epoch
            - Training metrics
            - Additional info if provided
        """
        checkpoint = {
            'epoch': self.current_epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'metrics': dict(self.metrics)
        }

        if additional_info:
            checkpoint.update(additional_info)

        save_path = os.path.join(
            self.save_dir,
            f'motion_model_epoch_{self.current_epoch}.pt'
        )
        torch.save(checkpoint, save_path)
        print(f"Saved checkpoint to: {save_path}")

    def plot_training_progress(self):
        """
        Create visualization of training metrics.

        Generates 4-panel plot showing:
        - Height estimation error
        - Motion prediction loss
        - LiDAR supervision loss
        - Total combined loss

        Returns:
            JPEG visualization for notebook display
        """
        print("=" * 50)
        print("Available metrics:")
        for key, values in self.metrics.items():
            print(f"  {key}: {values}")
        print("=" * 50)

        fig = plt.figure(figsize=(20, 15))

        # Plot Height Error
        plt.subplot(221)
        if 'epoch/height_error' in self.metrics:
            plt.plot(self.metrics['epoch/height_error'], label='Height Error', color='red')
        plt.title('Height Error')
        plt.xlabel('Epoch')
        plt.ylabel('Error')
        plt.legend()
        plt.grid(True)

        # Plot Motion Loss
        plt.subplot(222)
        if 'epoch/motion_loss' in self.metrics:
            plt.plot(self.metrics['epoch/motion_loss'], label='Motion Loss', color='blue')
        plt.title('Motion Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)

        # Plot Lidar Loss
        plt.subplot(223)
        if 'epoch/lidar_loss' in self.metrics:
            plt.plot(self.metrics['epoch/lidar_loss'], label='Lidar Loss', color='green')
        plt.title('Lidar Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)

        # Plot Total Loss
        plt.subplot(224)
        if 'epoch/total_loss' in self.metrics:
            plt.plot(self.metrics['epoch/total_loss'], label='Total Loss', color='purple')
        plt.title('Total Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)

        # Adjust x-axis for all subplots
        for subplot in [221, 222, 223, 224]:
            plt.subplot(subplot)
            plt.xlim(0, len(self.metrics['epoch/total_loss']) - 1)  # Set x-axis from 0 to number of epochs - 1

        plt.tight_layout()

        # Convert to JPEG and display
        jpeg_output = embed_matplotlib_jpeg(fig)
        plt.close(fig)
        display(jpeg_output)

## Training Loop Architecture

The training loop operates in epochs, with each epoch processing the full dataset. Here's what happens in each epoch:

1. **Batch Processing**
   - Loads a batch of image sequences with metadata
   - Skips invalid batches
   - Moves data to GPU if available

2. **Model Forward Pass**
   - Processes images through RAFT model
   - Computes motion fields and height estimates
   - Uses automatic mixed precision for efficiency

3. **Loss Calculation**
   - Computes combined loss including:
     - Motion consistency
     - LiDAR supervision
     - Height estimation error
   - Checks for NaN/Inf values

4. **Optimization Step**
   - Computes gradients
   - Clips gradients at 0.5 norm
   - Updates model parameters
   - Cleans GPU memory

5. **Monitoring**
   - Tracks multiple loss components
   - Prints progress every 20 batches
   - Visualizes first batch results
   - Logs metrics and learning rate

6. **End of Epoch**
   - Computes average losses
   - Updates learning rate based on performance
   - Creates visualization plots
   - Saves checkpoint if loss improved

The loop includes error handling for out-of-memory issues and careful GPU memory management throughout due to RAFT's voracious memory appetite.

### Model Setup
\begin{align*}
\text{Model} &= \text{CombinedModel}(\text{RAFT}) \\
\text{Loss} &= \text{CombinedLoss}(\lambda_{smooth}) \\
\text{Optimizer} &= \text{AdamW}(lr=\alpha) \\
\text{Scheduler} &= \text{ReduceLROnPlateau}(\text{factor}=0.5, \text{patience}=5)
\end{align*}

### Per-Epoch Training
For each epoch $e$:

1. Forward Pass:
   \begin{align*}
   \text{outputs} &= \text{Model}(I, M, h_{lidar}, l) \\
   L, L_{dict} &= \text{Loss}(\text{outputs}, I, l, h_{lidar}, M)
   \end{align*}

   where:
   - $I$: image sequence
   - $M$: metadata
   - $h_{lidar}$: validation height
   - $l$: sequence lengths

2. Backward Pass:
   \begin{align*}
   \nabla L &= \text{backward}(L) \\
   \|\nabla\| &\leq 0.5 \text{ (gradient clipping)}
   \end{align*}

3. Loss Tracking:
   $$\bar{L}_e = \frac{1}{N}\sum_{i=1}^N L_i$$

4. Learning Rate Update:
   $$
   \alpha_{e+1} = \begin{cases}
   0.5\alpha_e & \text{if no improvement for 5 epochs} \\
   \alpha_e & \text{otherwise}
   \end{cases}
   $$

### Checkpointing
Saves model when:
$$L_e < \min_{i<e} L_i \text{ or } e \bmod f = 0$$
where $f$ is save frequency.

### Features
- Automatic mixed precision training
- GPU memory management
- Gradient anomaly detection
- Visualization every epoch
- Progress logging

In [None]:
def train_cloud_motion_model(
    model,
    train_loader,
    num_epochs=10,
    device=None,
    learning_rate=1e-6,
    smoothness_weight=0.01,
    save_frequency=1,
    max_sequence_length=5,
    gradient_checkpointing=False):
    """
    Train a cloud motion model for height estimation.

    Core components:
    1. Combined model wrapping RAFT
    2. Combined loss with motion and LiDAR supervision
    3. AdamW optimizer with LR scheduling
    4. Training visualizations and checkpointing

    Args:
        model: Base RAFT model to train
        train_loader: DataLoader with cloud image sequences
        num_epochs (int): Number of training epochs
        device: Torch device for training
        learning_rate (float): Initial learning rate
        smoothness_weight (float): Weight for motion smoothness loss
        save_frequency (int): Save checkpoint every N epochs
        max_sequence_length (int): Maximum sequence length
        gradient_checkpointing (bool): Enable gradient checkpointing

    Returns:
        tuple: (trained_model, logger)
            - trained_model: Trained CombinedModel instance
            - logger: TrainingLogger with metrics history

    Features:
        - Mixed precision training
        - GPU memory management
        - Gradient anomaly detection
        - Progress visualization
        - Checkpoint saving
        - Error handling for OOM
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print(f"Training on device: {device}")
    print_gpu_memory("Before moving model to device")

    combined_model = CombinedModel(model)

    if gradient_checkpointing and hasattr(combined_model.raft, 'gradient_checkpointing_enable'):
        combined_model.raft.gradient_checkpointing_enable()
        print("Enabled gradient checkpointing for RAFT")

    combined_model = combined_model.to(device)
    print_gpu_memory("After moving model to device")

    combined_loss = CombinedLoss(smoothness_weight=smoothness_weight).to(device)
    logger = TrainingLogger()

    optimizer = AdamW([
        {'params': combined_model.raft.parameters(), 'lr': learning_rate}
    ])

    scheduler = ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=5,
        verbose=True,
        min_lr=1e-8
    )

    print_gpu_memory("Before starting training loop")
    best_loss = float('inf')

    try:
        for epoch in range(num_epochs):
            logger.current_epoch = epoch
            combined_model.train()
            epoch_losses = defaultdict(float)
            avg_loss_this_epoch = 0.0
            num_batches = 0

            print(f"\nEpoch {epoch+1}/{num_epochs}")

            for batch_idx, batch_data in enumerate(train_loader):
                # Skip invalid batches
                if batch_data is None:
                    continue

                print_gpu_memory(f"Start of batch {batch_idx}")

                images = batch_data['images'].to(device)
                metadata = batch_data['metadata'].to(device)
                sequence_lengths = batch_data['sequence_lengths'].to(device)
                validation_height = batch_data['validation_height'].to(device)

                try:
                    with autocast(device.type):
                        outputs = combined_model(
                            images=images,
                            metadata=metadata,
                            validation_height=validation_height,
                            sequence_lengths=sequence_lengths
                        )
                        loss, loss_dict = combined_loss(
                            outputs,
                            images,
                            sequence_lengths,
                            validation_height=validation_height,
                            metadata=metadata
                        )

                    for loss_type, loss_value in loss_dict.items():
                        if isinstance(loss_value, torch.Tensor):
                            if torch.isnan(loss_value) or torch.isinf(loss_value):
                                print(f"Warning: {loss_type} is {loss_value}")

                    optimizer.zero_grad(set_to_none=True)
                    loss.backward()

                    for name, param in combined_model.named_parameters():
                        if param.grad is not None:
                            if torch.isnan(param.grad).any() or torch.isinf(param.grad).any():
                                print(f"Warning: NaN or Inf gradients in {name}")

                    torch.nn.utils.clip_grad_norm_(
                        combined_model.raft.parameters(),
                        max_norm=0.5
                    )

                    optimizer.step()

                    for k, v in loss_dict.items():
                        if isinstance(v, torch.Tensor):
                            epoch_losses[k] += v.item()
                        else:
                            epoch_losses[k] += float(v)

                    num_batches += 1
                    avg_loss_this_epoch = sum(epoch_losses.values()) / num_batches

                    current_lr = optimizer.param_groups[0]['lr']
                    avg_losses = {k: v / (batch_idx + 1) for k, v in epoch_losses.items()}
                    if batch_idx % 20 == 0:
                        print(f"Batch {batch_idx+1}/{len(train_loader)}, Epoch {epoch+1}/{num_epochs}, Loss: {avg_losses['total_loss']:.4f}, LR: {current_lr:.2e}")

                    if batch_idx == 0:
                        print(f"\nVisualizing batch {batch_idx+1}/{len(train_loader)} of epoch {epoch+1}/{num_epochs}")
                        with torch.no_grad():
                            visualize_training_progress(
                                outputs,
                                batch_data,
                                train_cloud2cloud_dataset,
                                model=combined_model,
                                sample_idx=0
                            )

                except RuntimeError as e:
                    if "out of memory" in str(e):
                        print_oom_report(batch_idx, images, sequence_lengths, combined_model)
                        torch.cuda.empty_cache()
                    else:
                        raise e
                finally:
                    if 'outputs' in locals(): del outputs
                    if 'images' in locals(): del images
                    if 'metadata' in locals(): del metadata
                    if 'loss' in locals(): del loss
                    if 'validation_height' in locals(): del validation_height
                    torch.cuda.empty_cache()

            print(f"End of epoch {epoch+1}")
            print_gpu_memory(f"End of epoch {epoch+1}")

            avg_epoch_loss = sum(epoch_losses.values()) / len(train_loader)
            scheduler.step(avg_epoch_loss)

            current_lr = optimizer.param_groups[0]['lr']
            logger.log_metrics({'epoch_loss': avg_epoch_loss}, step_type='epoch')
            logger.log_metrics({'learning_rate': current_lr}, step_type='epoch')
            for loss_type, loss_value in epoch_losses.items():
                logger.log_metrics({loss_type: loss_value / len(train_loader)}, step_type='epoch')

            print(f"Epoch {epoch+1} summary:")
            print(f"  Average loss: {avg_epoch_loss:.4f}")
            print(f"  Learning rate: {current_lr:.2e}")

            if epoch >= 1:
                logger.plot_training_progress()

            if avg_epoch_loss < best_loss:
                best_loss = avg_epoch_loss
                logger.save_checkpoint(
                    model=model,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    additional_info={
                        'epoch_loss': avg_epoch_loss,
                        'best_model': True,
                        'height_calculator_state': combined_model.get_height_calculator_states()
                    }
                )
            elif (epoch + 1) % save_frequency == 0:
                logger.save_checkpoint(
                    model=model,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    additional_info={
                        'epoch_loss': avg_epoch_loss,
                        'height_calculator_state': combined_model.get_height_calculator_states()
                    }
                )

    except Exception as e:
        print(f"Error during training: {str(e)}")
        print_gpu_memory("At error")
        raise e

    return combined_model, logger

## Collation and Packing

## Purpose of Collation
When batching sequences for training, we need to handle:
1. Variable-length sequences (2-5 frames)
2. Missing/invalid data
3. Multiple aligned data streams (images, flight data, timestamps)
4. Memory efficiency
5. RNN processing requirements

### 1. Initial Data Unpacking
```
Input Batch Structure:
- image_sequences: (B × T × C × H × W)
- flight_data_list: (B × T × D_flight)
- validation_heights: (B × 1)
- resized_center_coords: (B × 2)
- image_paths: (B × T)
- timestamp_sequences: (B × T)
```

### 2. Sequence Validation and Truncation
For each sequence:

$$
\begin{align*}
& \text{Valid Index Set} = {i : \text{GPS_MSL_Alt}_i \neq \text{NaN}} \\
& \text{last_valid_idx} = \max(\text{Valid Index Set}) \\
& \text{start_idx} = \max(0, \text{last_valid_idx} - \text{max_sequence_length} + 1) \\
& \text{sequence_length} = \min(\text{end_idx} - \text{start_idx}, \text{max_sequence_length})
\end{align*}
$$

### 3. Sequence Length Requirements
```
Constraints:
2 ≤ sequence_length ≤ 5
```

- Minimum 2 frames needed for temporal patterns
- Maximum 5 frames for computational efficiency
- Must have valid GPS altitude data

### 4. Sequence Sorting and Packing

#### Why Pack Sequences?
1. **Memory Efficiency**
   - No need to pad all sequences to maximum length
   - Only stores actual data points

2. **Computational Efficiency**
   - RNN computations only on valid timesteps
   - Avoids processing padding tokens

3. **Better Gradient Flow**
   - No gradients through padding
   - More stable training

#### Packing Process:
1. Sort by length (descending):
   $$\text{sequences} = \text{sort}(\text{sequences}, \text{key}=\text{len}, \text{reverse}=\text{True})$$

2. Pack sequences:
   ```
   PackedSequence Structure:
   - data: All sequence elements concatenated
   - batch_sizes: Number of sequences at each timestep
   - sorted_indices: Original sorting order
   - unsorted_indices: Restore original order
   ```

Example:
```
Original Sequences:
Seq1: [A1, A2, A3]
Seq2: [B1, B2]
Seq3: [C1, C2, C3, C4]

Sorted:
C1, C2, C3, C4
A1, A2, A3, --
B1, B2, --, --

Packed:
data: [C1, A1, B1, C2, A2, B2, C3, A3, C4]
batch_sizes: [3, 3, 2, 1]
```

### 5. Final Output Structure
```python
{
    'images': PackedSequence,      # Packed image sequences
    'metadata': PackedSequence,    # Packed flight data
    'validation_height': Tensor,   # Single height per sequence
    'center_coords': Tensor,       # Coordinate pairs
    'sequence_lengths': Tensor,    # Length of each sequence
    'image_paths': List,          # For debugging/visualization
    'timestamps': List            # Temporal alignment
}
```

This structure ensures:
1. Memory efficient storage
2. Aligned multimodal data
3. Efficient RNN processing
4. Easy access to metadata
5. Debugging capability

# Sequence Truncation Strategy

## Problem Statement
Two main issues with sequence lengths:
1. **Memory Issues**
   - Long sequences (5+ frames) cause Out-of-Memory (OOM) errors
   - Each image is 384x384x3, consuming significant GPU memory when batched

2. **Data Quality Issues**
   - Some sequences are abnormally long (1000+ frames)
   - Usually indicates data collection or synchronization problems
   - Can be caused by:
     * Missing LiDAR measurements
     * GPS data gaps
     * Timestamp misalignment

## Truncation Implementation

### 1. Valid Data Detection
```
For each sequence:
1. Check GPS_MSL_Alt for NaN values
2. Find last valid index where data exists
3. If no valid data found, skip sequence
```

### 2. Sequence Windowing
```
start_idx = max(0, last_valid_idx - max_sequence_length + 1)
end_idx = last_valid_idx + 1
```

Key points:
- Always keeps the most recent valid frames
- Works backward from last valid measurement
- Ensures LiDAR measurement aligns with final frame

### 3. Length Constraints
```
min_sequence_length = 2  # Minimum for temporal patterns
max_sequence_length = 5  # Maximum for memory management
```

Benefits:
- Prevents OOM errors
- Ensures consistent batch sizes
- Maintains temporal coherence
- Filters out problematic sequences

## Example Scenarios

1. **Normal Case**:
   ```
   Original: [F1, F2, F3, F4, F5, F6, F7, F8]  // 8 frames
   Last valid: F8
   Truncated: [F4, F5, F6, F7, F8]  // 5 frames
   ```

2. **Bad Data Case**:
   ```
   Original: [F1, F2, ..., F1000]  // 1000 frames
   Last valid: F1000
   Truncated: [F996, F997, F998, F999, F1000]  // 5 frames
   ```

3. **Short Sequence**:
   ```
   Original: [F1, F2, F3]  // 3 frames
   Last valid: F3
   Result: [F1, F2, F3]  // Kept as is
   ```

4. **Invalid Sequence**:
   ```
   Original: [F1(NaN), F2(NaN), F3(NaN)]  // All invalid
   Result: Sequence discarded
   ```

This truncation strategy effectively:
- Manages memory usage
- Maintains data quality
- Preserves temporal relationships
- Handles edge cases gracefully

In [None]:
def collate_fn(batch, max_sequence_length=5, min_sequence_length=2):
    """
    Custom collate function for cloud image sequences that handles variable lengths and maintains data alignment.

    Key operations:
    1. Truncates sequences to most recent valid frames
    2. Filters sequences based on length requirements
    3. Sorts by sequence length for packed sequence efficiency
    4. Packs sequences for RNN processing

    Args:
        batch: List of tuples containing:
            - image_sequences: List of image tensors [T, C, H, W]
            - flight_data: Flight metadata [T, D]
            - validation_heights: LiDAR measurements
            - resized_center_coords: Image center points
            - image_paths: Paths to source images
            - timestamp_sequences: Frame timestamps
        max_sequence_length (int): Maximum frames to keep (default: 5)
        min_sequence_length (int): Minimum frames required (default: 2)

    Returns:
        dict: Collated batch with:
            - images: PackedSequence of image tensors
            - metadata: PackedSequence of flight data
            - validation_height: LiDAR heights tensor
            - center_coords: Center coordinates tensor
            - sequence_lengths: Sequence length tensor
            - image_paths: Source image paths
            - timestamps: Frame timestamps

        Returns None if no valid sequences found.

    Note:
        Sequences are truncated from the end, keeping the most recent
        frames that have valid metadata (non-NaN GPS altitude).
    """
    metadata_manager = FlightMetadataManager()

    # Unpack with timestamps
    image_sequences, flight_data_list, validation_heights, resized_center_coords_list, image_paths_list, timestamp_sequences = zip(*batch)

    # Truncate sequences, keeping the most recent valid frames
    truncated_sequences = []  # Will store tuples of (images, flight_data, image_paths, timestamps)
    sequence_lengths = []
    valid_indices = []  # Track which sequences we're keeping

    for idx, (imgs, flight_data, image_paths, timestamps) in enumerate(zip(image_sequences, flight_data_list, image_paths_list, timestamp_sequences)):
        # Find last valid index (non-NaN in metadata)
        valid_metadata_indices = torch.where(~torch.isnan(flight_data[:, metadata_manager.get_index('GPS_MSL_Alt')]))[0]
        if len(valid_metadata_indices) == 0:
            last_valid_idx = -1
        else:
            last_valid_idx = valid_metadata_indices[-1].item()

        # Truncate to last max_sequence_length valid frames
        start_idx = max(0, last_valid_idx - max_sequence_length + 1)
        end_idx = last_valid_idx + 1

        # Only include sequences that meet the minimum length requirement
        sequence_length = min(end_idx - start_idx, max_sequence_length)
        if sequence_length >= min_sequence_length:
            # Store truncated sequences together to maintain alignment
            truncated_sequences.append((
                imgs[start_idx:end_idx],
                flight_data[start_idx:end_idx],
                image_paths[start_idx:end_idx],
                timestamps[start_idx:end_idx]  # Add timestamps
            ))
            sequence_lengths.append(sequence_length)
            valid_indices.append(idx)

    # If no valid sequences, return None or raise an error
    if not truncated_sequences:
        return None  # Or raise ValueError("No sequences meet the minimum length requirement")

    # Sort by length in descending order
    sorted_indices = np.argsort(sequence_lengths)[::-1]
    sequence_lengths = [sequence_lengths[i] for i in sorted_indices]
    truncated_sequences = [truncated_sequences[i] for i in sorted_indices]
    # Only include validation heights and center coords for valid sequences
    validation_heights = [validation_heights[valid_indices[i]] for i in sorted_indices]
    resized_center_coords = [resized_center_coords_list[valid_indices[i]] for i in sorted_indices]

    # Separate aligned sequences
    sorted_images, sorted_flight_data, sorted_image_paths, sorted_timestamps = zip(*truncated_sequences)

    # Pack sequences
    packed_images = pack_sequence([seq for seq in sorted_images])
    packed_metadata = pack_sequence([torch.as_tensor(data, dtype=torch.float32)
                                   for data in sorted_flight_data])

    return {
        'images': packed_images,
        'metadata': packed_metadata,
        'validation_height': torch.as_tensor(validation_heights, dtype=torch.float32),
        'center_coords': torch.as_tensor(resized_center_coords, dtype=torch.long),
        'sequence_lengths': torch.tensor(sequence_lengths),
        'image_paths': sorted_image_paths,
        'timestamps': sorted_timestamps  # Add timestamps to output
    }

# Wrapper method to add extra parameters to the collate function as well as filter out sequences that are too short
def filter_collate_fn(batch):
    collated = collate_fn(batch, max_sequence_length=5, min_sequence_length=2)
    if collated is None:
        print("Skipping batch with no valid sequences")
        return None
    return collated

train_loader = DataLoader(
    train_cloud2cloud_dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=filter_collate_fn,
    num_workers=4,
    pin_memory=True
)

if train_model:
  # Train with memory tracking
  trained_model, logger = train_cloud_motion_model(
      model=model,
      train_loader=train_loader,
      num_epochs=num_epochs,
      device=device,
      learning_rate=1e-8,
      save_frequency=2
  )

## Load Model (If not training)

In [None]:
def load_model_checkpoint(checkpoint_path, device='cuda'):
    """
    Load a saved model checkpoint.

    Args:
        checkpoint_path: Path to the checkpoint file
        device: Device to load model to ('cuda' or 'cpu')

    Returns:
        combined_model: Loaded CombinedModel instance
        epoch: Epoch number when checkpoint was saved
    """

    # # Add numpy scalar to safe globals
    # add_safe_globals([
    #     np.core.multiarray.scalar,
    #     np.dtype,
    #     np.dtypes.Float64DType
    # ])

    checkpoint = torch.load(checkpoint_path, weights_only=False)

    # Print what's saved in the checkpoint
    print("Checkpoint contains:", checkpoint.keys())

    # Initialize models
    motion_model = CloudMotionModel()
    combined_model = CombinedModel(motion_model)
    combined_model = combined_model.to(device)

    # Load states
    combined_model.raft.load_state_dict(checkpoint['model_state_dict'])

    if 'height_calculator_state' in checkpoint:
        combined_model.load_height_calculator_states(
            checkpoint['height_calculator_state']
        )

    return combined_model, checkpoint['epoch']

if not train_model:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Loading model on device: {device}")
    trained_model, epoch = load_model_checkpoint(
        '/content/drive/MyDrive/motion_maps/motion_model_epoch_2.pt',
        device=device
    )

# Post training Visualization

## visualize_multiple_sequences

In [None]:
def visualize_multiple_sequences(model, dataset, num_sequences=5, randomize=True,
                               starting_sequence_index=0, num_sequences_to_skip=1):
    """
    Generate visualizations for multiple sequences of cloud height estimation results.

    Creates comprehensive visualizations showing:
    1. Motion field estimation and comparison
    2. Height field reconstruction with uncertainty
    3. Calibration parameter evolution

    Args:
        model: Trained CloudMotionModel
        dataset: CloudDataset instance
        num_sequences (int): Number of sequences to visualize
        randomize (bool): If True, randomly sample sequences
                          If False, use consecutive sequences
        starting_sequence_index (int): First sequence index if not random
        num_sequences_to_skip (int): Stride between sequences if not random

    Visualization components per sequence:
        - Original image sequence
        - Motion field comparison (refined vs original RAFT)
        - Height field reconstruction with uncertainty
        - Diagnostic statistics
        - Calibration parameters

    Note:
        Handles GPU memory management and matplotlib cleanup.
        Catches and reports visualization errors per sequence.
        Maintains model in eval mode during visualization.
    """
    model.eval()
    device = next(model.parameters()).device

    # Get sequence indices to visualize
    if randomize:
        sequence_indices = np.random.choice(
            len(dataset),
            size=min(num_sequences, len(dataset)),
            replace=False
        )
    else:
        end_idx = min(
            starting_sequence_index + (num_sequences * num_sequences_to_skip),
            len(dataset)
        )
        sequence_indices = range(
            starting_sequence_index,
            end_idx,
            num_sequences_to_skip
        )

    # Create a temporary DataLoader for single sequences
    temp_loader = DataLoader(
        dataset,
        batch_size=1,
        collate_fn=lambda b: collate_fn(b, max_sequence_length=5)
    )

    with torch.no_grad():
        for seq_idx, idx in enumerate(sequence_indices):
            try:
                # Get the sequence
                batch_data = temp_loader.dataset[idx]
                batch = collate_fn([batch_data])

                # Move data to device
                images = batch['images'].to(device)
                metadata = batch['metadata'].to(device)
                sequence_lengths = batch['sequence_lengths'].to(device)
                validation_height = batch['validation_height'].to(device)

                # Get model outputs
                outputs = model(
                    images=images,
                    metadata=metadata,
                    validation_height=validation_height,
                    sequence_lengths=sequence_lengths
                )

                print(f"\nSequence {seq_idx + 1}/{len(sequence_indices)} (Dataset index: {idx})")
                print(f"Sequence length: {sequence_lengths[0].item()}")
                if 'image_paths' in batch:
                    print("\nImage paths:")
                    for path in batch['image_paths'][0]:
                        print(f"  {path}")

                # Get and explicitly display visualizations
                viz_outputs = visualize_training_progress(
                    outputs=outputs,
                    batch_data=batch,
                    dataset=dataset,
                    model=model
                )

                # Clear any remaining matplotlib state
                plt.close('all')

            except Exception as e:
                print(f"Error visualizing sequence {idx}: {str(e)}")
                import traceback
                traceback.print_exc()
                continue
            finally:
                # Ensure cleanup
                plt.close('all')
                torch.cuda.empty_cache()

    model.train()

In [None]:
# Visualize random sequences
visualize_multiple_sequences(
    model=trained_model,
    dataset=full_cloud2cloud_dataset,
    num_sequences=1,
    randomize=True
)

In [None]:
# Visualize consecutive sequences
visualize_multiple_sequences(
    model=trained_model,
    dataset=full_cloud2cloud_dataset,
    num_sequences=1,
    randomize=False,
    starting_sequence_index=0,
    num_sequences_to_skip=1
)

# Stitching Process

## Data Collection and Preprocessing

`collect_height_fields` processes sequences to obtain:
   - Height fields
   - Confidence maps
   - Flight metadata

A constraint of `collect_height_fields` is that it requires temporal ordering; obviously to stitch height fields together, it's necessary that sequences maintain their chronological order.

However, due to varying sequence lengths, the default collation process sorts sequences by length for efficient packing (longer sequences first). While this optimizes memory usage, it disrupts temporal order.

By using batch_size=1:
1. Each sequence stays in chronological order
2. We can track motion evolution over time
3. Height fields align with their timestamps

The tradeoff:
- Pros: Simpler code, preserved temporal order
- Cons: Must process sequences one at a time, slower execution

This can be seen in the DataLoader `stitch_loader` internal to the function which enforces the batch size of 1.

```python
# Creates batch_size=1 loader to maintain temporal order
stitch_loader = DataLoader(
    dataset,
    batch_size=1,  # Needed for temporal ordering
    shuffle=False,
    collate_fn=filter_collate_fn,
    num_workers=1,
    pin_memory=True
)
```

## Global Coordinate System Creation

The `create_global_grid` function creates a unified coordinate system to stitch together multiple cloud height fields:

First, it uses the haversine formula to calculate true ground distances between measurement points. Taking the first measurement as a reference point, it calculates how far each subsequent measurement is in meters (north-south and east-west).

These real-world distances are converted to pixel coordinates using a scale factor (0.05 pixels per meter). This creates a grid where each height field's position corresponds to its actual geographic location relative to others.

The function then finds the total bounds needed to contain all height fields by looking at their positions and dimensions. It shifts everything so the minimum x and y coordinates start at zero, adjusting all positions and measurement points accordingly.

### Haversine Distance
The haversine formula calculates great-circle distances between latitude/longitude points on a sphere (Earth):

\begin{align*}
a &= \sin^2(\frac{\Delta\phi}{2}) + \cos(\phi_1)\cos(\phi_2)\sin^2(\frac{\Delta\lambda}{2}) \\
d &= 2R \arcsin(\sqrt{a})
\end{align*}

where:
- $\phi$ is latitude
- $\lambda$ is longitude
- $R$ is Earth's radius (6371km)

### Grid Creation Process

1. **Find Spatial Bounds**
   \begin{align*}
   \text{For each height field } H_i:& \\
   dx_i &= \text{haversine}(\text{ref}_{\text{lat}}, \text{ref}_{\text{lon}}, \text{ref}_{\text{lat}}, \text{lon}_i) \\
   dy_i &= \text{haversine}(\text{ref}_{\text{lat}}, \text{ref}_{\text{lon}}, \text{lat}_i, \text{ref}_{\text{lon}}) \\
   \text{pos}_i &= (dx_i \cdot \text{scale}, dy_i \cdot \text{scale})
   \end{align*}

2. **Calculate Grid Dimensions**
   \begin{align*}
   x_{\text{min}} &= \min_i(\text{pos}_i^x) \\
   x_{\text{max}} &= \max_i(\text{pos}_i^x + w_i) \\
   y_{\text{min}} &= \min_i(\text{pos}_i^y) \\
   y_{\text{max}} &= \max_i(\text{pos}_i^y + h_i)
   \end{align*}

   where $w_i, h_i$ are height field dimensions

3. **Adjust Measurement Points**
   $$
   \text{For each point } p:
   \begin{align*}
   p_{\text{adj}}^x &= p^x - x_{\text{min}} \\
   p_{\text{adj}}^y &= p^y - y_{\text{min}}
   \end{align*}
   $$

The `create_global_grid` function returns:
- Grid dimensions
- Adjusted positions of all height fields
- Measurement point locations
- Reference coordinates and scale
- Relative positions for stitching

This creates a pixel-space coordinate system where all height fields can be positioned relative to each other while maintaining their true geographical relationships.

## Field Stitching

`stitch_height_fields` combines fields using weighted averaging:

1. Create output arrays:

The process begins by creating arrays sized to encompass all height fields. Two main arrays track the sum of weighted heights and sum of weights separately. This dual-array approach allows proper handling of invalid regions marked by NaN values in the input fields.

   \begin{align*}
   H_{\text{sum}} &\in \mathbb{R}^{h \times w} \\
   W_{\text{sum}} &\in \mathbb{R}^{h \times w}
   \end{align*}

2. Generate confidence mask for each valid region:

The system generates confidence masks using a Gaussian falloff from measurement points. Starting with a base confidence of 0.2, confidence increases to 0.8 near measurement points. The σ parameter controls how quickly confidence decreases with distance from measured points.

   $$C(x,y) = 0.2 + 0.6 \exp(-\frac{d^2}{2\sigma^2})$$
   where $d$ is distance from measurement points

The values 0.2 and 0.6 in the confidence mask equation are chosen to create a bounded confidence range:
- Minimum confidence: 0.2 far from measurements
- Maximum confidence: 0.8 (0.2 + 0.6) at measurement points

These specific values allow the weighting to:
1. Never completely discard data (0.2 floor)
2. Never fully trust a single measurement (0.8 ceiling)
3. Provide enough dynamic range (0.6) for smooth blending

The values represent a design choice balancing between measurement trust and field continuity.

3. Add weighted contributions:

Each height field is positioned in the global grid according to its geographic coordinates. The system calculates weights based on distance to measurement points and combines them with local uncertainty values. These weighted contributions accumulate in the running sums, naturally handling overlapping regions.

   \begin{align*}
   H_{\text{sum}}(x,y) &+= h(x,y) \cdot w(x,y) \cdot c(x,y) \\
   W_{\text{sum}}(x,y) &+= w(x,y) \cdot c(x,y)
   \end{align*}

4. Calculate final heights:
   $$H_{\text{final}}(x,y) = \frac{H_{\text{sum}}(x,y)}{W_{\text{sum}}(x,y)}$$

## Post-processing

After combining fields, the system divides weighted sums to calculate final heights. It then identifies sharp discontinuities in the height field and applies selective smoothing in moderate confidence regions while preserving strong gradients where confidence is high.

1. Detect and smooth discontinuities using Gaussian filter
2. Preserve strong gradients in high-confidence regions
3. Create uncertainty visualization overlay

`stitch_height_fields` produces a stitched height field, comprehensive confidence map, and metadata package containing the sequences used, measurement points, geographic bounds, and reference coordinates.

## Visualization

Two complementary visualizations:

1. `visualize_stitching_steps`:
   - 2D overview
   - LiDAR measurement points
   - Data validity percentages
   - Height ranges

2. `visualize_stitched_height_field_3d`:
   - Interactive 3D surface
   - Uncertainty overlay
   - Synchronized views
   - Diagnostic statistics

## Helper functions

In [None]:
class AlignmentError(Exception):
    """Custom exception for alignment failures."""
    pass

def haversine_distance(lat1, lon1, lat2, lon2):
    """
    Calculate great-circle distance between coordinates.

    Uses haversine formula for spherical Earth approximation:
    d = 2R * arcsin(sqrt(sin²(Δφ/2) + cos(φ₁)cos(φ₂)sin²(Δλ/2)))

    Args:
        lat1, lon1: First point coordinates in degrees
        lat2, lon2: Second point coordinates in degrees

    Returns:
        float: Distance in meters
    """
    R = 6371000  # Earth's radius in meters

    # Convert to radians
    lat1, lon1, lat2, lon2 = map(np.radians, [lat1, lon1, lat2, lon2])

    dlat = lat2 - lat1
    dlon = lon2 - lon1

    a = np.sin(dlat/2)**2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon/2)**2
    c = 2 * np.arcsin(np.sqrt(a))

    return R * c

def gaussian_weight(dist, sigma=100):
    """
    Calculate Gaussian weights based on distance.

    Weight = exp(-0.5 * (dist/sigma)²)

    Args:
        dist (float): Distance from measurement point
        sigma (float): Standard deviation controlling falloff rate

    Returns:
        float: Weight in [0,1], maximum at dist=0
    """
    return np.exp(-0.5 * (dist / sigma)**2)

def create_confidence_mask(height, width, measurement_points, max_dist=200):
    """
    Generate confidence mask based on distance from measurements.

    Creates mask with base confidence 0.2 increasing to 0.8 near
    measurement points with Gaussian falloff.

    Args:
        height, width: Output dimensions
        measurement_points: List of measurement locations
        max_dist: Maximum influence distance

    Returns:
        numpy.ndarray: Confidence values [0.2, 0.8]
    """
    confidence = np.zeros((height, width))
    y_grid, x_grid = np.mgrid[:height, :width]

    for point in measurement_points:
        # Calculate distance from this measurement point
        dist_sq = ((x_grid - point['x'])**2 + (y_grid - point['y'])**2) / (max_dist**2)

        # Add confidence contribution with gaussian falloff
        point_conf = np.exp(-0.5 * dist_sq)
        confidence = np.maximum(confidence, point_conf)

    # Normalize confidence to [0.2, 0.8] range
    confidence = 0.2 + 0.6 * confidence

    return confidence

## collect_height_fields

In [None]:
def collect_height_fields(trained_model, dataset, start_batch=0, max_batches=10):
    """
    Collect height fields and metadata while preserving temporal order.

    Uses batch_size=1 to maintain chronological sequence ordering, since
    packed sequence sorting would otherwise disrupt temporal relationships.

    Args:
        trained_model: Trained cloud motion model
        dataset: Source dataset
        start_batch: First batch to process
        max_batches: Maximum number of batches to process

    Returns:
        tuple: (height_fields, confidence_maps, metadata_list)
            height_fields: List of height field arrays
            confidence_maps: List of confidence value arrays
            metadata_list: List of per-sequence metadata
    """
    # Create stitching-specific loader with enforced batch_size=1
    stitch_loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=filter_collate_fn,
        num_workers=1,
        pin_memory=True
    )

    height_fields = []
    confidence_maps = []
    metadata_list = []
    metadata_manager = FlightMetadataManager()

    print("\nCollecting sequences:")
    sequence_count = 0

    with torch.no_grad():
        for batch_idx, batch_data in enumerate(stitch_loader):
            if batch_idx < start_batch:
                continue
            if batch_idx >= start_batch + max_batches:
                break
            if batch_data is None:
                continue

            # Now we can safely assume batch_size=1
            timestamp = batch_data['timestamps'][0]
            print(f"\nSequence {sequence_count} (Batch {batch_idx}):")
            print(f"  Start: {timestamp[0]}")
            print(f"  End: {timestamp[-1]}")

            # Rest of processing remains the same
            outputs = trained_model(
                images=batch_data['images'].to(device),
                metadata=batch_data['metadata'].to(device),
                validation_height=batch_data['validation_height'].to(device),
                sequence_lengths=batch_data['sequence_lengths'].to(device)
            )

            # Extract results (safe to use index 0)
            height_fields.append(outputs['height_field_finetuned'][0].cpu().numpy())
            confidence_maps.append(outputs['height_uncertainty_finetuned'][0].cpu().numpy())

            # Build metadata for sequence
            metadata_unpacked, lengths = torch.nn.utils.rnn.pad_packed_sequence(batch_data['metadata'])
            seq_length = lengths[0]
            seq_metadata = metadata_unpacked[:seq_length, 0].cpu().numpy()
            center_x, center_y = batch_data['center_coords'][0]

            seq_metadata_list = []
            for frame_idx in range(seq_length):
                frame_dict = {
                    'timestamp': timestamp[frame_idx],
                    'Lat': float(seq_metadata[frame_idx][metadata_manager.get_index('Lat')]),
                    'Lon': float(seq_metadata[frame_idx][metadata_manager.get_index('Lon')]),
                    'True_Hdg': float(seq_metadata[frame_idx][metadata_manager.get_index('True_Hdg')]),
                    'True_Airspeed': float(seq_metadata[frame_idx][metadata_manager.get_index('True_Airspeed')]),
                    'GPS_MSL_Alt': float(seq_metadata[frame_idx][metadata_manager.get_index('GPS_MSL_Alt')])
                }

                if frame_idx == seq_length - 1:
                    frame_dict['center_x'] = int(center_x)
                    frame_dict['center_y'] = int(center_y)

                seq_metadata_list.append(frame_dict)

            metadata_list.append(seq_metadata_list)
            sequence_count += 1

            # Memory cleanup
            del outputs
            torch.cuda.empty_cache()

    print(f"\nCollected {sequence_count} sequences")
    return height_fields, confidence_maps, metadata_list

## create_global_grid

In [None]:
def create_global_grid(height_fields, metadata_list):
    """
    Create unified coordinate system for stitching height fields.

    Converts geographic coordinates to pixel space while preserving
    relative positions and distances between measurements.

    Args:
        height_fields: List of height field arrays
        metadata_list: List of metadata with lat/lon coordinates

    Returns:
        dict: Grid information containing:
            - min/max coordinates
            - field positions in pixel space
            - measurement point locations
            - reference coordinates and scale factor
    """
    # Calculate total bounds
    min_x, max_x = float('inf'), float('-inf')
    min_y, max_y = float('inf'), float('-inf')
    global_positions = []
    measurement_points = []

    # Get reference lat/lon (from first measurement)
    ref_lat = metadata_list[0][-1]['Lat']
    ref_lon = metadata_list[0][-1]['Lon']
    pixels_per_meter = 0.05  # Adjust this based on your image resolution

    # First pass: calculate positions using lat/lon
    for i, (height, meta) in enumerate(zip(height_fields, metadata_list)):
        h, w = height.shape

        # Convert lat/lon to meters from reference point
        dx = haversine_distance(ref_lat, ref_lon, ref_lat, meta[-1]['Lon']) * \
             (1 if meta[-1]['Lon'] > ref_lon else -1)
        dy = haversine_distance(ref_lat, ref_lon, meta[-1]['Lat'], ref_lon) * \
             (1 if meta[-1]['Lat'] > ref_lat else -1)

        # Convert to pixel coordinates
        pos_x = int(dx * pixels_per_meter)
        pos_y = int(dy * pixels_per_meter)

        # Store position
        global_positions.append((pos_x, pos_y))

        # Store measurement point
        measurement_points.append({
            'x': pos_x + w // 2,  # Center of image
            'y': pos_y + h // 2,
            'height': meta[-1].get('validation_height', 0),
            'lat': meta[-1]['Lat'],
            'lon': meta[-1]['Lon']
        })

        # Update bounds
        min_x = min(min_x, pos_x)
        max_x = max(max_x, pos_x + w)
        min_y = min(min_y, pos_y)
        max_y = max(max_y, pos_y + h)

    # Adjust positions relative to minimum bounds
    adjusted_positions = []
    adjusted_points = []

    for (pos_x, pos_y), point in zip(global_positions, measurement_points):
        adj_x = pos_x - min_x
        adj_y = pos_y - min_y

        adjusted_positions.append((adj_x, adj_y))
        adjusted_points.append({
            'x': point['x'] - min_x,
            'y': point['y'] - min_y,
            'height': point['height'],
            'lat': point['lat'],
            'lon': point['lon']
        })

    return {
        'min_x': 0,
        'max_x': max_x - min_x,
        'min_y': 0,
        'max_y': max_y - min_y,
        'positions': adjusted_positions,
        'measurement_points': adjusted_points,
        'reference': {
            'lat': ref_lat,
            'lon': ref_lon,
            'pixels_per_meter': pixels_per_meter
        }
    }

## stitch_height_fields

In [None]:
def stitch_height_fields(height_fields, confidence_maps, metadata_list,
                        max_size_pixels=(4096, 4096)):
    """
    Combine multiple height fields into single continuous surface.

    Uses weighted averaging with confidence-based blending and
    selective smoothing of discontinuities.

    Args:
        height_fields: List of height field arrays
        confidence_maps: List of confidence arrays
        metadata_list: List of metadata with positions
        max_size_pixels: Maximum output dimensions

    Returns:
        tuple: (merged_height, merged_conf, metadata)
            merged_height: Combined height field
            merged_conf: Merged confidence map
            metadata: Stitching information
    """
    # Get global coordinate system with measurement points
    grid = create_global_grid(height_fields, metadata_list)

    # Create output arrays
    width = int(grid['max_x'] - grid['min_x'])
    height = int(grid['max_y'] - grid['min_y'])

    if width > max_size_pixels[1] or height > max_size_pixels[0]:
        print(f"Warning: Size {width}x{height} exceeds limit {max_size_pixels}")
        return None, None, None

    # Initialize arrays for weighted averaging
    height_sum = np.zeros((height, width))
    weight_sum = np.zeros((height, width))
    merged_height = np.full((height, width), np.nan)

    print(f"\nCreating global height field of size: ({height}, {width})")

    # Create global confidence mask based on measurement points
    merged_conf = create_confidence_mask(
        height, width,
        grid['measurement_points']
    )

    # For each sequence
    for idx, ((pos_x, pos_y), height_field, conf_map) in enumerate(
        zip(grid['positions'], height_fields, confidence_maps)):

        h, w = height_field.shape

        # Get valid region
        valid_mask = ~np.isnan(height_field)

        if np.any(valid_mask):
            # Get measurement point for this sequence
            meas_point = grid['measurement_points'][idx]

            # Calculate weights based on distance from measurement point
            y_grid, x_grid = np.mgrid[pos_y:pos_y+h, pos_x:pos_x+w]
            dist_sq = ((x_grid - meas_point['x'])**2 +
                      (y_grid - meas_point['y'])**2)
            weights = gaussian_weight(np.sqrt(dist_sq))

            # Get target region in merged array
            region_slice = (
                slice(pos_y, pos_y + h),
                slice(pos_x, pos_x + w)
            )

            # Add weighted heights to sum
            height_sum[region_slice][valid_mask] += (
                height_field[valid_mask] *
                weights[valid_mask] *
                conf_map[valid_mask]
            )

            # Add weights to sum
            weight_sum[region_slice][valid_mask] += (
                weights[valid_mask] * conf_map[valid_mask]
            )

    # Calculate final height field
    valid_weights = weight_sum > 0
    if np.any(valid_weights):
        merged_height[valid_weights] = (
            height_sum[valid_weights] / weight_sum[valid_weights]
        )

        # Find areas with sharp height changes (potential banding)
        gradient_y = np.abs(np.diff(merged_height, axis=0))
        gradient_x = np.abs(np.diff(merged_height, axis=1))

        # Create mask for areas with sharp changes
        threshold = np.nanpercentile(gradient_y, 95)  # Adjust percentile as needed
        band_mask_y = np.zeros_like(merged_height, dtype=bool)
        band_mask_y[:-1, :] = gradient_y > threshold
        band_mask_y[1:, :] |= gradient_y > threshold

        threshold = np.nanpercentile(gradient_x, 95)
        band_mask_x = np.zeros_like(merged_height, dtype=bool)
        band_mask_x[:, :-1] = gradient_x > threshold
        band_mask_x[:, 1:] |= gradient_x > threshold

        band_mask = band_mask_x | band_mask_y

        # Only smooth the banding areas with moderate confidence
        smooth_mask = band_mask & (merged_conf > 0.3) & (merged_conf < 0.7)

        if np.any(smooth_mask):
            # Very gentle smoothing
            smoothed = gaussian_filter(
                np.nan_to_num(merged_height, nan=np.nanmean(merged_height)),
                sigma=0.5  # Reduced sigma
            )

            # Apply smoothing only to masked areas
            merged_height[smooth_mask] = (
                merged_height[smooth_mask] * 0.7 +  # Weighted blend
                smoothed[smooth_mask] * 0.3
            )

    return merged_height, merged_conf, {
        'measurement_points': grid['measurement_points'],
        'used_sequences': list(range(len(height_fields))),
        'bounds': {
            'min_x': int(grid['min_x']),
            'max_x': int(grid['max_x']),
            'min_y': int(grid['min_y']),
            'max_y': int(grid['max_y'])
        },
        'reference': grid['reference']
    }

# Collect

In [None]:
# Collect and visualize
height_fields, confidence_maps, metadata_list = collect_height_fields(
    trained_model,
    full_cloud2cloud_dataset,  # Pass dataset, not loader
    start_batch=200,
    max_batches=10
)

# Stitch

In [None]:
# First, run the stitching
final_height, final_confidence, stitching_info = stitch_height_fields(
    height_fields, confidence_maps, metadata_list,
    max_size_pixels=(4096, 4096)
)

# Visualize

In [None]:
def visualize_stitching_steps(final_height, final_confidence, stitching_info):
    """
    Create 2D visualization of stitched height field result.

    Shows height field, confidence overlay, measurement points,
    and diagnostic information.

    Args:
        final_height: Stitched height field array
        final_confidence: Confidence map array
        stitching_info: Metadata and measurement points

    Returns:
        matplotlib.figure.Figure: Visualization figure
    """
    print(f"\nGenerating visualization:")
    print(f"  Used sequences: {stitching_info['used_sequences']}")
    print(f"  Shape: {final_height.shape}")

    # Create figure with subplots
    fig = plt.figure(figsize=(15, 10))
    gs = gridspec.GridSpec(2, 2, height_ratios=[1, 0.05])

    # Height field subplot
    ax_height = plt.subplot(gs[0, 0])
    height_img = ax_height.imshow(final_height, cmap='viridis')
    ax_height.set_title(f'Merged Height Field\n{len(stitching_info["used_sequences"])} sequences')

    # Add measurement point markers
    legend_handles = []
    for i, point in enumerate(stitching_info['measurement_points']):
        marker = ax_height.plot(point['x'], point['y'], 'r+',
                              label=f'Measurement {i}')[0]
        legend_handles.append(marker)

    if legend_handles:
        ax_height.legend(handles=legend_handles, bbox_to_anchor=(1.05, 1),
                        loc='upper left')

    # Confidence map subplot
    ax_conf = plt.subplot(gs[0, 1])
    conf_img = ax_conf.imshow(final_confidence, cmap='RdYlGn')
    ax_conf.set_title('Confidence Map')

    # Colorbar
    ax_colorbar = plt.subplot(gs[1, :])
    plt.colorbar(height_img, cax=ax_colorbar, orientation='horizontal',
                label='Height (m)')

    plt.tight_layout()
    plt.show()
    plt.close()

In [None]:
# Create the 2D visualization
visualize_stitching_steps(
    final_height, final_confidence, stitching_info
)

In [None]:
def visualize_stitched_height_field_3d(height_field, confidence_map, measurement_points, title="Stitched Cloud Height Field"):
    """
    Create interactive 3D visualization of stitched result.

    Generates Plotly figure with height surface, confidence overlay,
    measurement markers, and diagnostic text.

    Args:
        height_field: Stitched height field array
        confidence_map: Confidence values array
        measurement_points: List of LiDAR measurements
        title: Plot title

    Returns:
        plotly.graph_objects.Figure: Interactive 3D visualization
    """
    h, w = height_field.shape
    x = np.arange(w)
    y = np.arange(h)
    X, Y = np.meshgrid(x, y)

    # Create figure
    fig = go.Figure()

    # Add height field surface
    fig.add_trace(go.Surface(
        x=X,
        y=Y,
        z=height_field,
        colorscale='viridis',
        colorbar=dict(
            title='Height (m)',
            len=0.75,
            y=0.5,
            thickness=20
        ),
        opacity=0.8,
        showscale=True,
        name='Height Field'
    ))

    # Add confidence overlay
    fig.add_trace(go.Surface(
        x=X,
        y=Y,
        z=height_field,
        surfacecolor=confidence_map,
        colorscale='RdYlGn',
        colorbar=dict(
            title='Confidence',
            len=0.75,
            y=0.5,
            x=1.1,
            thickness=20
        ),
        opacity=0.3,
        showscale=True,
        name='Confidence'
    ))

    # Add measurement points
    point_xs = [p['x'] for p in measurement_points]
    point_ys = [p['y'] for p in measurement_points]
    point_zs = [height_field[int(p['y']), int(p['x'])] for p in measurement_points]

    fig.add_trace(go.Scatter3d(
        x=point_xs,
        y=point_ys,
        z=point_zs,
        mode='markers',
        marker=dict(
            size=5,
            color='red',
            symbol='diamond'
        ),
        name='LiDAR Measurements'
    ))

    # Set camera based on which dimension is longer
    if h > w:  # Vertical orientation
        camera = dict(eye=dict(x=2, y=0, z=1.5))
    else:  # Horizontal orientation
        camera = dict(eye=dict(x=0, y=2, z=1.5))

    # Update layout
    fig.update_layout(
        title=title,
        scene=dict(
            xaxis_title='X Distance (pixels)',
            yaxis_title='Y Distance (pixels)',
            zaxis_title='Height (m)',
            aspectratio=dict(x=w/max(h,w), y=h/max(h,w), z=0.3),
            camera=camera
        ),
        width=1200,
        height=800,
        margin=dict(t=50, b=0, l=0, r=0)
    )

    return fig

In [None]:
# Create the 3D visualization
fig = visualize_stitched_height_field_3d(
    final_height,
    final_confidence,
    stitching_info['measurement_points'],
    title="Stitched Cloud Height Field with Confidence Overlay"
)

# Display the 3D figure
fig.show()