facts-total:
- This is a minimal prototype of a total module for summing sea-level rise projections generated by different sources and modules. facts-total is a CLI tool that accepts the path of each netCDF file to be summed, plus an output path where the summed result will be written. Each input netCDF file is the output of a FACTS sea-level component module. It is the user's responsibility to specify the correct files: check that each file path is correct and that every file belongs to the same scale ('global' or 'local').

- It is possible to run multiple FACTS sea-level components with different default values for common parameters such as pyear-start and pyear-end. If that happens, facts-total will not fail, but it will show a message similar to the following:



In [1]:
import sys

# Put the FACTS 2.0 checkout at the very front of the module search path so
# its packages are found before any other copies on sys.path.
sys.path[0:0] = ['/discover/nobackup/projects/eis_freshwater/gtamkin/facts2.0']

In [2]:
import os

# Report where the kernel started, for orientation when re-running cells.
cwd = os.getcwd()
print(f"Current working directory: {cwd}")

# Later cells use relative paths (./data/..., ./containers/...), so the
# kernel must sit in the FACTS 2.0 checkout.
new_directory_path = "/discover/nobackup/projects/eis_freshwater/gtamkin/facts2.0"

try:
    os.chdir(new_directory_path)
except FileNotFoundError:
    print(f"Directory not found: {new_directory_path}")
except Exception as e:
    print(f"An error occurred: {e}")
else:
    print(f"Directory successfully changed to: {os.getcwd()}")


Current working directory: /gpfsm/dnb06/projects/p151/gtamkin/facts2.0/notebooks
Directory successfully changed to: /gpfsm/dnb06/projects/p151/gtamkin/facts2.0


In [3]:
import asyncio
import logging
import time
import os
import shlex

from radical.asyncflow import WorkflowEngine
from radical.asyncflow import ConcurrentExecutionBackend

from concurrent.futures import ThreadPoolExecutor

from radical.asyncflow.logging import init_default_logger

# Module-level logger used by the workflow cells below. Inside a Jupyter
# kernel __name__ is "__main__", so this logger is named "__main__".
logger = logging.getLogger(__name__)

In [4]:
async def main():
    """Run the FACTS FAIR / LWS / sterodynamics workflow once.

    FAIR (temperature) and LWS (land water storage) are submitted together.
    Sterodynamics is submitted with the FAIR future as its argument, because
    it reads FAIR's ``climate.nc`` through the ``/fair`` bind mount. Each
    task returns a shell command string that launches a Singularity
    container with host directories under ``./data`` bound in. The kernel's
    working directory must therefore be the FACTS checkout.

    The workflow engine is shut down in a ``finally`` clause so that a
    failing task does not leave the engine running in the kernel.
    """
    init_default_logger(logging.DEBUG)

    # Create backend and workflow
    engine = await ConcurrentExecutionBackend(ThreadPoolExecutor())
    flow = await WorkflowEngine.create(engine)

    # Singularity binary shared by every task below.
    singularity = '/usr/local/other/singularity/4.0.3/bin/singularity'

    def setup_directories():
        """Create the host output directories that are bound to /output."""
        os.makedirs('./data/output/fair', exist_ok=True)
        os.makedirs('./data/output/lws', exist_ok=True)
        os.makedirs('./data/output/sterodynamics', exist_ok=True)

    @flow.executable_task
    async def fair_task():
        """FAIR temperature model task (no upstream dependencies)."""
        cmd = [
            singularity, 'exec',
            '--bind', './data/input:/input',
            '--bind', './data/output/fair:/output',
            './containers/fair-temperature.sif',
            'fair-temperature',
            '--pipeline-id=1234',
            '--output-oceantemp-file=/output/oceantemp.nc',
            '--nsamps=20',
            '--output-ohc-file=/output/ohc.nc',
            '--output-gsat-file=/output/gsat.nc',
            '--output-climate-file=/output/climate.nc',
            '--rcmip-file=/input/rcmip/rcmip-emissions-annual-means-v5-1-0.csv',
            '--param-file=/input/parameters/fair_ar6_climate_params_v4.0.nc'
        ]
        return shlex.join(cmd)

    @flow.executable_task
    async def lws_task():
        """Land Water Storage task - can run independently of FAIR."""
        cmd = [
            singularity, 'exec',
            '--bind', './data/input:/input',
            '--bind', './data/output/lws:/output',
            './containers/ssp-landwaterstorage.sif',
            'ssp-landwaterstorage',
            '--pipeline-id=1234',
            '--nsamps=20',
            '--output-gslr-file=/output/gslr.nc',
            '--output-lslr-file=/output/lslr.nc',
            '--location-file=/input/location.lst',
            '--pophist-file=/input/UNWPP2012 population historical.csv',
            '--reservoir-file=/input/Chao2008 groundwater impoundment.csv',
            '--popscen-file=/input/ssp_iam_baseline_popscenarios2100.csv',
            '--gwd-file=/input/Konikow2011 GWD.csv',
            '--gwd-file=/input/Wada2012 GWD.csv',
            '--gwd-file=/input/Pokhrel2012 GWD.csv',
            '--fp-file=/input/REL_GROUNDWATER_NOMASK.nc'
        ]
        # shlex.join quotes the file names that contain spaces.
        return shlex.join(cmd)

    @flow.executable_task
    async def sterodynamics_task(fair_task):
        """Sterodynamics task - reads FAIR's climate.nc, so pass the FAIR future."""
        cmd = [
            singularity, 'exec',
            '--bind', './data/output/fair:/fair',
            '--bind', './data/input:/input',
            '--bind', './data/output/sterodynamics:/output',
            '--nv',
            './containers/tlm-sterodynamics.sif',
            'tlm-sterodynamics',
            '--pipeline-id=1234',
            '--scenario=ssp585',
            '--nsamps=20',
            '--model-dir=/input/cmip6/',
            '--location-file=/input/location.lst',
            '--output-lslr-file=/output/lslr.nc',
            '--output-gslr-file=/output/gslr.nc',
            '--expansion-coefficients-file=/input/scmpy2LM_RCMIP_CMIP6calpm_n18_expcoefs.nc',
            '--gsat-rmses-file=/input/scmpy2LM_RCMIP_CMIP6calpm_n17_gsat_rmse.nc',
            '--climate-data-file=/fair/climate.nc'
        ]
        return shlex.join(cmd)

    async def run_climate_workflow(pipeline_id):
        """Run the complete climate workflow and return each task's result."""
        logger.info(f'Starting climate workflow {pipeline_id} at {time.time()}')

        # Setup directories
        setup_directories()

        # Start FAIR and LWS tasks (they can run in parallel)
        fair_future = fair_task()
        lws_future = lws_task()

        # Wait for FAIR to complete (sterodynamics depends on it)
        fair_result = await fair_future
        logger.info(f'FAIR task completed for pipeline {pipeline_id}')

        # Passing the FAIR future declares the dependency to the engine.
        sterodynamics_future = sterodynamics_task(fair_future)

        # Wait for all tasks to complete
        lws_result = await lws_future
        sterodynamics_result = await sterodynamics_future

        logger.info(f'Climate workflow {pipeline_id} finished at {time.time()}')

        return {
            'fair': fair_result,
            'lws': lws_result,
            'sterodynamics': sterodynamics_result
        }

    try:
        # Run workflow(s)
        results = await run_climate_workflow(1)
        logger.info("All workflows completed successfully")
        logger.info(results)
    finally:
        # Shut the engine down even when a task raised, so that re-running
        # the cell does not leave the previous engine active in the kernel.
        await flow.shutdown()

# Top-level ``await`` only works where an event loop is already running
# (Jupyter/IPython); in a plain Python script use asyncio.run(main()).
await main()


[90m2026-02-02 10:37:44.480[0m │ [94mINFO[0m │ [38;5;165m[root][0m │ Logger configured successfully - Console: DEBUG, File: disabled (N/A), Structured: disabled, Style: modern
[90m2026-02-02 10:37:44.481[0m │ [94mINFO[0m │ [38;5;165m[execution.backend(concurrent)][0m │ ThreadPoolExecutor execution backend started successfully
[90m2026-02-02 10:37:44.481[0m │ [96mDEBUG[0m │ [38;5;165m[workflow_manager][0m │ Registered signal handler for SIGHUP
[90m2026-02-02 10:37:44.482[0m │ [96mDEBUG[0m │ [38;5;165m[workflow_manager][0m │ Registered signal handler for SIGTERM
[90m2026-02-02 10:37:44.482[0m │ [96mDEBUG[0m │ [38;5;165m[workflow_manager][0m │ Registered signal handler for SIGINT
[90m2026-02-02 10:37:44.483[0m │ [96mDEBUG[0m │ [38;5;165m[workflow_manager][0m │ Started run component
[90m2026-02-02 10:37:44.483[0m │ [94mINFO[0m │ [38;5;165m[main][0m │ Starting climate workflow 1 at 1770046664.4834452
[90m2026-02-02 10:37:44.488[0m │ [96mDEBUG[0m

In [None]:
async def main_emulandice():
    """Run FAIR, LWS, sterodynamics and the emulandice glacier task once.

    Data dependencies, each a file read through a bind mount:

    * FAIR -> sterodynamics (``/fair/climate.nc``)
    * FAIR -> emulandice glaciers (``/output/fair/gsat.nc``)
    * LWS has no upstream dependency.

    The FAIR-dependent tasks are therefore submitted with the FAIR future as
    their argument, after FAIR has been started. This cell uses the
    ``*-sandbox`` container directories rather than ``.sif`` images. The
    engine is shut down in a ``finally`` clause even if a task fails.
    """
    init_default_logger(logging.DEBUG)

    # Create backend and workflow
    engine = await ConcurrentExecutionBackend(ThreadPoolExecutor())
    flow = await WorkflowEngine.create(engine)

    # Singularity binary shared by every task below.
    singularity = '/usr/local/other/singularity/4.0.3/bin/singularity'

    def setup_directories():
        """Create the host directories the containers write into."""
        os.makedirs('./data/output/fair', exist_ok=True)
        os.makedirs('./data/output/lws', exist_ok=True)
        os.makedirs('./data/output/sterodynamics', exist_ok=True)
        os.makedirs('./data/output/emulandice', exist_ok=True)
        # The glacier task writes --output-glacier-dir=/output/glacier, and
        # ./data/output is bound to /output.
        os.makedirs('./data/output/glacier', exist_ok=True)

    @flow.executable_task
    async def fair_task():
        """FAIR temperature model task (no upstream dependencies)."""
        cmd = [
            singularity, 'exec',
            '--bind', './data/input:/input',
            '--bind', './data/output/fair:/output',
            './containers/fair-temperature-sandbox',
            'fair-temperature',
            '--pipeline-id=1234',
            '--output-oceantemp-file=/output/oceantemp.nc',
            '--nsamps=20',
            '--output-ohc-file=/output/ohc.nc',
            '--output-gsat-file=/output/gsat.nc',
            '--output-climate-file=/output/climate.nc',
            '--rcmip-file=/input/rcmip/rcmip-emissions-annual-means-v5-1-0.csv',
            '--param-file=/input/parameters/fair_ar6_climate_params_v4.0.nc'
        ]
        return shlex.join(cmd)

    @flow.executable_task
    async def lws_task():
        """Land Water Storage task - can run independently of FAIR."""
        cmd = [
            singularity, 'exec',
            '--bind', './data/input:/input',
            '--bind', './data/output/lws:/output',
            './containers/ssp-landwaterstorage-sandbox',
            'ssp-landwaterstorage',
            '--pipeline-id=1234',
            '--nsamps=20',
            '--output-gslr-file=/output/gslr.nc',
            '--output-lslr-file=/output/lslr.nc',
            '--location-file=/input/location.lst',
            '--pophist-file=/input/UNWPP2012 population historical.csv',
            '--reservoir-file=/input/Chao2008 groundwater impoundment.csv',
            '--popscen-file=/input/ssp_iam_baseline_popscenarios2100.csv',
            '--gwd-file=/input/Konikow2011 GWD.csv',
            '--gwd-file=/input/Wada2012 GWD.csv',
            '--gwd-file=/input/Pokhrel2012 GWD.csv',
            '--fp-file=/input/REL_GROUNDWATER_NOMASK.nc'
        ]
        return shlex.join(cmd)

    @flow.executable_task
    async def sterodynamics_task(fair_task):
        """Sterodynamics task - reads FAIR's climate.nc, so pass the FAIR future."""
        cmd = [
            singularity, 'exec',
            '--bind', './data/output/fair:/fair',
            '--bind', './data/input:/input',
            '--bind', './data/output/sterodynamics:/output',
            '--nv',
            './containers/tlm-sterodynamics-sandbox',
            'tlm-sterodynamics',
            '--pipeline-id=1234',
            '--scenario=ssp585',
            '--nsamps=20',
            '--model-dir=/input/cmip6/',
            '--location-file=/input/location.lst',
            '--output-lslr-file=/output/lslr.nc',
            '--output-gslr-file=/output/gslr.nc',
            '--expansion-coefficients-file=/input/scmpy2LM_RCMIP_CMIP6calpm_n18_expcoefs.nc',
            '--gsat-rmses-file=/input/scmpy2LM_RCMIP_CMIP6calpm_n17_gsat_rmse.nc',
            '--climate-data-file=/fair/climate.nc'
        ]
        return shlex.join(cmd)

    # NOTE: The three emulandice tasks below all write
    # /output/emulandice/{gslr,lslr}.nc. Enabling more than one in the same
    # run makes them overwrite each other's output.
    # Arguments are plain list elements: shlex.join does the quoting, so
    # embedding '"' in an element would pass a literal quote to the tool.

    @flow.executable_task
    async def emulandice_ais_task(fair_task):
        """Emulandice AIS task - reads FAIR's gsat.nc, so pass the FAIR future."""
        cmd = [
            singularity, 'exec',
            '--bind', './data/input:/input/:ro',
            '--bind', './data/output:/output',
            '--nv',
            './containers/emulandice-sandbox',
            'ais',
            '--pipeline-id=1234',
            '--fprint-wais-file=/input/FPRINT/fprint_wais.nc',
            '--fprint-eais-file=/input/FPRINT/fprint_eais.nc',
            '--input-data-file=/output/fair/gsat.nc',
            '--location-file=/input/location.lst',
            '--output-gslr-file=/output/emulandice/gslr.nc',
            '--output-lslr-file=/output/emulandice/lslr.nc'
        ]
        return shlex.join(cmd)

    @flow.executable_task
    async def emulandice_gris_task(fair_task):
        """Emulandice GRIS task - reads FAIR's gsat.nc, so pass the FAIR future."""
        cmd = [
            singularity, 'exec',
            '--bind', './data/input:/input/:ro',
            '--bind', './data/output:/output',
            '--nv',
            './containers/emulandice-sandbox',
            'gris',
            '--pipeline-id=1234',
            '--fprint-gis-file=/input/FPRINT/fprint_gis.nc',
            '--input-data-file=/output/fair/gsat.nc',
            '--location-file=/input/location.lst',
            '--output-gslr-file=/output/emulandice/gslr.nc',
            '--output-lslr-file=/output/emulandice/lslr.nc'
        ]
        return shlex.join(cmd)

    @flow.executable_task
    async def emulandice_glacier_task(fair_task):
        """Emulandice glacier task - reads FAIR's gsat.nc, so pass the FAIR future."""
        # NOTE(review): the reference shell invocation for this tool also
        # passes --fprint-map-file="/input/fingerprint_region_map.csv", which
        # is not passed here. Confirm whether the glacier CLI requires it.
        cmd = [
            singularity, 'exec',
            '--bind', './data/input:/input/:ro',
            '--bind', './data/output:/output',
            '--nv',
            './containers/emulandice-sandbox',
            'glaciers',
            '--pipeline-id=1234',
            '--fprint-glacier-dir=/input/FPRINT',
            '--input-data-file=/output/fair/gsat.nc',
            '--location-file=/input/location.lst',
            '--output-gslr-file=/output/emulandice/gslr.nc',
            '--output-lslr-file=/output/emulandice/lslr.nc',
            '--output-glacier-dir=/output/glacier'
        ]
        command = shlex.join(cmd)
        logger.debug(f'emulandice glacier command: {command}')
        return command

    async def run_climate_workflow(pipeline_id):
        """Run the complete climate workflow and return each task's result."""
        logger.info(f'Starting climate workflow {pipeline_id} at {time.time()}')

        # Setup directories
        setup_directories()

        # FAIR and LWS have no upstream dependencies; start them together.
        fair_future = fair_task()
        lws_future = lws_task()

        # Wait for FAIR to complete (sterodynamics and emulandice depend on it)
        fair_result = await fair_future
        logger.info(f'FAIR task completed for pipeline {pipeline_id}')

        # Passing the FAIR future declares the dependency to the engine.
        # Swap in emulandice_ais_task / emulandice_gris_task as needed.
        sterodynamics_future = sterodynamics_task(fair_future)
        emulandice_future = emulandice_glacier_task(fair_future)

        # Wait for all tasks to complete
        lws_result = await lws_future
        sterodynamics_result = await sterodynamics_future
        emulandice_result = await emulandice_future
        logger.info(f'EMULANDICE task completed for pipeline {pipeline_id}')

        logger.info(f'Climate workflow {pipeline_id} finished at {time.time()}')

        return {
            'fair': fair_result,
            'lws': lws_result,
            'sterodynamics': sterodynamics_result,
            'emulandice': emulandice_result
        }

    try:
        # Run workflow(s)
        results = await run_climate_workflow(1)
        logger.info("All workflows completed successfully")
        logger.info(results)
    finally:
        # Shut the engine down even when a task raised.
        await flow.shutdown()

# Top-level ``await`` only works where an event loop is already running
# (Jupyter/IPython); in a plain Python script use asyncio.run(main_emulandice()).
await main_emulandice()


[90m2026-02-03 10:52:57.486[0m │ [94mINFO[0m │ [38;5;165m[root][0m │ Logger configured successfully - Console: DEBUG, File: disabled (N/A), Structured: disabled, Style: modern
[90m2026-02-03 10:52:57.487[0m │ [94mINFO[0m │ [38;5;165m[execution.backend(concurrent)][0m │ ThreadPoolExecutor execution backend started successfully
[90m2026-02-03 10:52:57.487[0m │ [96mDEBUG[0m │ [38;5;165m[workflow_manager][0m │ Registered signal handler for SIGHUP
[90m2026-02-03 10:52:57.487[0m │ [96mDEBUG[0m │ [38;5;165m[workflow_manager][0m │ Registered signal handler for SIGTERM
[90m2026-02-03 10:52:57.488[0m │ [96mDEBUG[0m │ [38;5;165m[workflow_manager][0m │ Registered signal handler for SIGINT
[90m2026-02-03 10:52:57.488[0m │ [96mDEBUG[0m │ [38;5;165m[workflow_manager][0m │ Started run component
[90m2026-02-03 10:52:57.488[0m │ [94mINFO[0m │ [38;5;165m[main][0m │ Starting climate workflow 1 at 1770133977.4888594


Revised Example from ChatGSFC - 01/20/2026 - 10:16 AM:
- https://chat.gsfc.nasa.gov/c/25924100-fa4e-4543-b5dd-debc29cfe648

In [None]:
def validate_inputs(required_files=None):
    """Fail fast if a container image or critical input file is missing.

    Args:
        required_files: Optional iterable of paths to check. When omitted,
            the check covers the three ``.sif`` container images and the
            RCMIP emissions CSV used by the FAIR/LWS/sterodynamics workflow.
            Relative paths resolve against the current working directory.

    Raises:
        FileNotFoundError: If any path does not exist. The message lists
            every missing path, not just the first one.
    """
    if required_files is None:
        required_files = [
            './containers/fair-temperature.sif',
            './containers/ssp-landwaterstorage.sif',
            './containers/tlm-sterodynamics.sif',
            './data/input/rcmip/rcmip-emissions-annual-means-v5-1-0.csv',
            # TODO: add the remaining inputs referenced by the task commands
        ]

    missing = [f for f in required_files if not os.path.exists(f)]
    if missing:
        raise FileNotFoundError(f"Missing required files: {missing}")

async def main():
    """Run the FAIR / LWS / sterodynamics workflow twice, in two styles.

    ``run_climate_workflow`` lets exceptions propagate.
    ``run_climate_workflow2`` first validates inputs, gathers the remaining
    tasks with ``asyncio.gather``, and returns a status dict instead of
    raising. The success/failure message logged after it reflects that
    status. The engine is shut down in a ``finally`` clause even if a task
    fails.
    """
    init_default_logger(logging.DEBUG)

    # Create backend and workflow
    engine = await ConcurrentExecutionBackend(ThreadPoolExecutor())
    flow = await WorkflowEngine.create(engine)

    # Singularity binary shared by every task below.
    singularity = '/usr/local/other/singularity/4.0.3/bin/singularity'

    def setup_directories():
        """Create the host output directories that are bound to /output."""
        os.makedirs('./data/output/fair', exist_ok=True)
        os.makedirs('./data/output/lws', exist_ok=True)
        os.makedirs('./data/output/sterodynamics', exist_ok=True)

    @flow.executable_task
    async def fair_task():
        """FAIR temperature model task (no upstream dependencies)."""
        cmd = [
            singularity, 'exec',
            '--bind', './data/input:/input',
            '--bind', './data/output/fair:/output',
            './containers/fair-temperature.sif',
            'fair-temperature',
            '--pipeline-id=1234',
            '--output-oceantemp-file=/output/oceantemp.nc',
            '--nsamps=20',
            '--output-ohc-file=/output/ohc.nc',
            '--output-gsat-file=/output/gsat.nc',
            '--output-climate-file=/output/climate.nc',
            '--rcmip-file=/input/rcmip/rcmip-emissions-annual-means-v5-1-0.csv',
            '--param-file=/input/parameters/fair_ar6_climate_params_v4.0.nc'
        ]
        return shlex.join(cmd)

    @flow.executable_task
    async def lws_task():
        """Land Water Storage task - can run independently of FAIR."""
        cmd = [
            singularity, 'exec',
            '--bind', './data/input:/input',
            '--bind', './data/output/lws:/output',
            './containers/ssp-landwaterstorage.sif',
            'ssp-landwaterstorage',
            '--pipeline-id=1234',
            '--nsamps=20',
            '--output-gslr-file=/output/gslr.nc',
            '--output-lslr-file=/output/lslr.nc',
            '--location-file=/input/location.lst',
            '--pophist-file=/input/UNWPP2012 population historical.csv',
            '--reservoir-file=/input/Chao2008 groundwater impoundment.csv',
            '--popscen-file=/input/ssp_iam_baseline_popscenarios2100.csv',
            '--gwd-file=/input/Konikow2011 GWD.csv',
            '--gwd-file=/input/Wada2012 GWD.csv',
            '--gwd-file=/input/Pokhrel2012 GWD.csv',
            '--fp-file=/input/REL_GROUNDWATER_NOMASK.nc'
        ]
        return shlex.join(cmd)

    @flow.executable_task
    async def sterodynamics_task(fair_task):
        """Sterodynamics task - reads FAIR's climate.nc, so pass the FAIR future."""
        cmd = [
            singularity, 'exec',
            '--bind', './data/output/fair:/fair',
            '--bind', './data/input:/input',
            '--bind', './data/output/sterodynamics:/output',
            '--nv',
            './containers/tlm-sterodynamics.sif',
            'tlm-sterodynamics',
            '--pipeline-id=1234',
            '--scenario=ssp585',
            '--nsamps=20',
            '--model-dir=/input/cmip6/',
            '--location-file=/input/location.lst',
            '--output-lslr-file=/output/lslr.nc',
            '--output-gslr-file=/output/gslr.nc',
            '--expansion-coefficients-file=/input/scmpy2LM_RCMIP_CMIP6calpm_n18_expcoefs.nc',
            '--gsat-rmses-file=/input/scmpy2LM_RCMIP_CMIP6calpm_n17_gsat_rmse.nc',
            '--climate-data-file=/fair/climate.nc'
        ]
        return shlex.join(cmd)

    async def run_climate_workflow(pipeline_id):
        """Run the complete climate workflow and return each task's result."""
        logger.info(f'Starting climate workflow {pipeline_id} at {time.time()}')

        # Setup directories
        setup_directories()

        # Start FAIR and LWS tasks (they can run in parallel)
        fair_future = fair_task()
        lws_future = lws_task()

        # Wait for FAIR to complete (sterodynamics depends on it)
        fair_result = await fair_future
        logger.info(f'FAIR task completed for pipeline {pipeline_id}')

        # Passing the FAIR future declares the dependency to the engine.
        sterodynamics_future = sterodynamics_task(fair_future)

        # Wait for all tasks to complete
        lws_result = await lws_future
        sterodynamics_result = await sterodynamics_future

        logger.info(f'Climate workflow {pipeline_id} finished at {time.time()}')

        return {
            'fair': fair_result,
            'lws': lws_result,
            'sterodynamics': sterodynamics_result
        }

    async def run_climate_workflow2(pipeline_id):
        """Run the complete climate workflow, reporting failure as a status dict."""
        logger.info(f'Starting climate workflow {pipeline_id} at {time.time()}')

        try:
            # Setup and validate
            setup_directories()
            validate_inputs()

            # Start FAIR and LWS tasks (parallel). The task functions take no
            # pipeline argument (--pipeline-id is fixed inside each command),
            # so they are called with the same signatures as in
            # run_climate_workflow.
            fair_future = fair_task()
            lws_future = lws_task()

            # Wait for FAIR (sterodynamics depends on it)
            fair_result = await fair_future
            logger.info(f'FAIR task completed for pipeline {pipeline_id}')

            # Start sterodynamics; the FAIR future is its only argument.
            sterodynamics_future = sterodynamics_task(fair_future)

            # Wait for remaining tasks
            lws_result, sterodynamics_result = await asyncio.gather(
                lws_future,
                sterodynamics_future
            )

            logger.info(f'Climate workflow {pipeline_id} completed at {time.time()}')

            return {
                'pipeline_id': pipeline_id,
                'fair': fair_result,
                'lws': lws_result,
                'sterodynamics': sterodynamics_result,
                'status': 'success'
            }

        except Exception as e:
            # Deliberate boundary handler: log with traceback, report failure.
            logger.exception(f'Workflow {pipeline_id} failed: {str(e)}')
            return {
                'pipeline_id': pipeline_id,
                'status': 'failed',
                'error': str(e)
            }

    try:
        # Run workflow(s)
        results = await run_climate_workflow(1)
        logger.info("Workflow#1 completed successfully")
        logger.info(results)

        results = await run_climate_workflow2(1)
        # run_climate_workflow2 swallows exceptions, so check its status
        # rather than assuming success.
        if results.get('status') == 'success':
            logger.info("Workflow#2 completed successfully")
        else:
            logger.error("Workflow#2 failed")
        logger.info(results)
    finally:
        # Shut the engine down even when a task raised.
        await flow.shutdown()

# Top-level ``await`` only works where an event loop is already running
# (Jupyter/IPython); in a plain Python script use asyncio.run(main()).
await main()


Overview

Specific Prototype Goal:

Workflow (GI=Graphcast Initial Source file, OB=Obs_*_ges):  
1. As agreed in the conversation with Amal & Mark:
2.	For each pressure level in GI
    * Aggregate the OB values within +/- 10 hPa of the GI level 
    * For example, where GI = 850, get all OB values between 840 and 860
    * Average these OB values
3.  Replace the corresponding GI value with this average
4.	Save results as new NetCDF GI file 
5.	Run GraphCast prediction with new GI file
6.	Compare prediction results with results from original GI file


Notes from Amal/Mark discussion: 
1. The 'Observation' column in the OB file contains the raw variable values
2. Kx index needs to be involved, but not for 1st prototype
3.	Use range of +/- 10 when mapping OBS pressure to GI.  So, if GI is 850, take OBS 840 to 860
4.	Conventional (u,v,q,t,s) variables are more important in the short-term than non-conventional
5.  Radiosondes have a range of data because they collect data in a moving column
6.	ERA5/GI is a subset of pressure levels in OBS (13 vs. > 13)
7.	Map OBS temperature to 1D Temperature in GI, not 2TM
8.	Override existing values in GI with average range, not all values between levels
9.	The landmask variable [1,0] indicates whether cell is over land or not (not used in prototype)
10.	Iterate with one variable replacement at a time, starting with Pressure = 850
11.	Time [-3..3] delta from Xz (not used for prototype because it will be in 6hr batches like GI)

ToDo:  
1. Add support for multiple levels
2. Add support for multiple variables
3. Add support for multiple time steps
4. Enhance QA
5. Add visualization for scattered and gridded values
6. Containerize workflow functionality
7. Make inputs configurable (e.g., variable, pressure level(s) of interest, time, etc.)
8. Replace Data Dictionary below with configuration file

In [None]:
import json
from pathlib import Path

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import xarray as xr

from scipy.interpolate import griddata
import matplotlib.pyplot as plt # Optional: for visualization
import cartopy.crs as ccrs
import hvplot.xarray
from pathlib import Path
import subprocess,os

Specify input files:
     OB_file = Global observation file that contains the Temperature parameter at a specific timestep and pressure level (850 hPa)
     GI_file = Global ERA5 source input file for GraphCast at the same timestep (superset of parameters at 13 levels)
     GI_OBS_file  = GI_file containing Temperature values overridden by corresponding lat/lon cell averages from the 840-860 level range in OB_file

NOTES/ASSUMPTIONS: 
- Since the input files have the same lat/lon values, regridding is not necessary.  
- The OB_file usually contains multiple levels & parameter values per lat/lon cell
- These values are collected within a window of +/- 10 hPa around the level of interest (e.g., 840–860 hPa for 850 hPa), as specified by Amal.

In [None]:
##########################################

<!-- /discover/nobackup/projects/gmao/merra21c/TSE_staging/e5303_m21c_jan18/archive/obs/Y2022/M01/D01/H00/e5303_m21c_jan18.diag_conv_aura_t_ges.20220101_00z.nc4
/discover/nobackup/projects/gmao/merra21c/TSE_staging/e5303_m21c_jan18/archive/obs/Y2022/M01/D01/H00/e5303_m21c_jan18.diag_conv_ges.20220101_00z.bin
/discover/nobackup/projects/gmao/merra21c/TSE_staging/e5303_m21c_jan18/archive/obs/Y2022/M01/D01/H00/e5303_m21c_jan18.diag_conv_gps_ges.20220101_00z.nc4
/discover/nobackup/projects/gmao/merra21c/TSE_staging/e5303_m21c_jan18/archive/obs/Y2022/M01/D01/H00/e5303_m21c_jan18.diag_conv_prof_t_ges.20220101_00z.nc4
/discover/nobackup/projects/gmao/merra21c/TSE_staging/e5303_m21c_jan18/archive/obs/Y2022/M01/D01/H00/e5303_m21c_jan18.diag_conv_prof_uv_ges.20220101_00z.nc4
/discover/nobackup/projects/gmao/merra21c/TSE_staging/e5303_m21c_jan18/archive/obs/Y2022/M01/D01/H00/e5303_m21c_jan18.diag_conv_ps_ges.20220101_00z.nc4
/discover/nobackup/projects/gmao/merra21c/TSE_staging/e5303_m21c_jan18/archive/obs/Y2022/M01/D01/H00/e5303_m21c_jan18.diag_conv_q_ges.20220101_00z.nc4
/discover/nobackup/projects/gmao/merra21c/TSE_staging/e5303_m21c_jan18/archive/obs/Y2022/M01/D01/H00/e5303_m21c_jan18.diag_conv_tcp_ges.20220101_00z.nc4
/discover/nobackup/projects/gmao/merra21c/TSE_staging/e5303_m21c_jan18/archive/obs/Y2022/M01/D01/H00/e5303_m21c_jan18.diag_conv_t_ges.20220101_00z.nc4
/discover/nobackup/projects/gmao/merra21c/TSE_staging/e5303_m21c_jan18/archive/obs/Y2022/M01/D01/H00/e5303_m21c_jan18.diag_conv_uv_ges.20220101_00z.nc4 -->

Data Dictionary 
- Modify the values in the next two cells to configure the notebook
- See comments for variable definitions

In [None]:

# Column indices into each data-dictionary row (dd_* lists), with [example value]
index_short_name = 0 # variable short name ['t']
index_long_name = 1 # variable long name ['temperature']
index_level_of_interest = 2 # pressure level of interest in the OBS file, hPa [850]
index_GI_level = 3 # index of that level along the GI ERA5 level axis [10]
index_lower_bound = 4 # window extent BELOW 'index_level_of_interest', hPa [10]
index_upper_bound = 5 # window extent ABOVE 'index_level_of_interest', hPa [10]
index_OB_scale_factor = 6 # factor aligning the level value with OBS units [100]
index_OB_file = 7 # index of OB file in OB dataset catalog where key = variable+level
index_GI_file = 8 # index of GI file in GI dataset catalog where key = variable+level
index_GI_OBS_file = 9 # index of GI OBS file in GI OBS dataset catalog where key = variable+level
index_GI_PRED_file = 10 # index of GI PRED file (original) in GI PRED dataset catalog where key = variable+level
index_GI_OBS_PRED_file = 11 # index of GI PRED OBS file (obs-modified GI) in GI PRED dataset catalog where key = variable+level

ds_obs = {} # dataset catalog for obs [OBS] files
ds_gis = {} # dataset catalog for original GraphCast input (GI)
ds_gi_mods = {} # dataset catalog for modified GraphCast input (GI)
cell_averages_obs = {} # collection of averaged observation values
cell_averages_gi = {} # collection of GI cells to be updated with cell_averages_obs
subdir = '20260108b' # directory name for output
container = '/discover/nobackup/projects/QEFM/qefm-core/../containers/qefm-core-debian-all-aifs-20250609-sandbox' # Graphcast runtime
override = True # Flag indicating whether to overwrite outputs (set to False for faster execution)
show_scatter_plot = True # Flag for scatter plot of observation values
current_dd = None

output_dir = '/discover/nobackup/projects/QEFM/qefm-core/data/NSE/output/graphcast/'+subdir+'/' # root output directory
os.makedirs(output_dir, exist_ok=True) # create output path if it does not exist

Define metadata for run 

In [None]:
# Data-dictionary row(s) for temperature at 850 hPa; column order matches
# the index_* constants.
dd_850 = [
  ['t', # short name
  'temperature', # long name
  850, # primary level of interest in OBS file (hPa)
  10, # level index in GI ERA5 (int)
  10, # lower bound: window extent below the primary level (hPa)
  10, # upper bound: window extent above the primary level (hPa)
  100, # OBS scale factor
  '/discover/nobackup/projects/gmao/merra21c/TSE_staging/e5303_m21c_jan18/archive/obs/Y2022/M01/D01/H00/e5303_m21c_jan18.diag_conv_t_ges.20220101_00z.nc4', # OB
  '/discover/nobackup/projects/QEFM/qefm-core/data/NSE/input/graphcast/source-era5_date-2022-01-01_res-1.0_levels-13_steps-01.nc', # GI
  '/discover/nobackup/projects/QEFM/qefm-core/data/NSE/output/graphcast/'+subdir+'/source-era5_date-2022-01-01_res-1.0_levels-13_steps-01_obs_t_850.nc', # GI_OBS (obs-modified GI)
  '/discover/nobackup/projects/QEFM/qefm-core/qefm/models/checkpoints/graphcast/pred-era5_date-2022-01-01_res-1.0_levels-13_steps-01.nc', # GI_PRED (prediction from original GI)
  '/discover/nobackup/projects/QEFM/qefm-core/data/NSE/output/graphcast/'+subdir+'/pred-era5_date-2022-01-01_res-1.0_levels-13_steps-01_obs_t_850.nc', # GI_OBS_PRED (prediction from obs-modified GI)
 ],
]

# Data-dictionary row(s) for temperature at 500 hPa; column order matches
# the index_* constants.
dd_single = [
  ['t', # short name
  'temperature', # long name
  500, # primary level of interest in OBS file (hPa)
  7, # level index in GI ERA5 (int)
  10, # lower bound: window extent below the primary level (hPa)
  10, # upper bound: window extent above the primary level (hPa)
  100, # OBS scale factor
  '/discover/nobackup/projects/gmao/merra21c/TSE_staging/e5303_m21c_jan18/archive/obs/Y2022/M01/D01/H00/e5303_m21c_jan18.diag_conv_t_ges.20220101_00z.nc4', # OB
  '/discover/nobackup/projects/QEFM/qefm-core/data/NSE/input/graphcast/source-era5_date-2022-01-01_res-1.0_levels-13_steps-01.nc', # GI
  '/discover/nobackup/projects/QEFM/qefm-core/data/NSE/output/graphcast/'+subdir+'/source-era5_date-2022-01-01_res-1.0_levels-13_steps-01_obs_t_500.nc', # GI_OBS (obs-modified GI)
  '/discover/nobackup/projects/QEFM/qefm-core/qefm/models/checkpoints/graphcast/pred-era5_date-2022-01-01_res-1.0_levels-13_steps-01.nc', # GI_PRED (prediction from original GI)
  '/discover/nobackup/projects/QEFM/qefm-core/data/NSE/output/graphcast/'+subdir+'/pred-era5_date-2022-01-01_res-1.0_levels-13_steps-01_obs_t_500.nc', # GI_OBS_PRED (prediction from obs-modified GI)
 ],
]

# Multi-level run metadata: temperature at 500, 850 and 925 hPa. The main
# workflow picks which table to run via `current_dd`. Field layout as in the
# other dd_* tables (read by position via the index_* constants).
dd_multi = [
  ['t', # short name
  'temperature', # long name (name of the variable in the GI/ERA5 dataset)
  500, # primary level of interest in OBS file (hPa)
  7, # level index in GI ERA5 (int)
  10, # lower bound: hPa subtracted from the primary level to form the lower limit of the OBS pressure window
  10, # upper bound: hPa added to the primary level to form the upper limit of the OBS pressure window
  100, # OBS scale factor: multiplies the hPa window limits before comparing with OBS 'Pressure' (presumably hPa -> Pa)
  '/discover/nobackup/projects/gmao/merra21c/TSE_staging/e5303_m21c_jan18/archive/obs/Y2022/M01/D01/H00/e5303_m21c_jan18.diag_conv_t_ges.20220101_00z.nc4', # OB: observation file
  '/discover/nobackup/projects/QEFM/qefm-core/data/NSE/input/graphcast/source-era5_date-2022-01-01_res-1.0_levels-13_steps-01.nc', # GI: original ERA5 input
  '/discover/nobackup/projects/QEFM/qefm-core/data/NSE/output/graphcast/'+subdir+'/source-era5_date-2022-01-01_res-1.0_levels-13_steps-01_obs_t_500.nc', # GI OBS: presumably the output path for the OBS-modified input
  '/discover/nobackup/projects/QEFM/qefm-core/qefm/models/checkpoints/graphcast/pred-era5_date-2022-01-01_res-1.0_levels-13_steps-01.nc', # GI prediction: presumably the original GraphCast prediction
  '/discover/nobackup/projects/QEFM/qefm-core/data/NSE/output/graphcast/'+subdir+'/pred-era5_date-2022-01-01_res-1.0_levels-13_steps-01_obs_t_500.nc', # GI OBS prediction: presumably the output path for the prediction from the OBS-modified input
 ],
  ['t', # short name
  'temperature', # long name (name of the variable in the GI/ERA5 dataset)
  850, # primary level of interest in OBS file (hPa)
  10, # level index in GI ERA5 (int)
  10, # lower bound: hPa subtracted from the primary level to form the lower limit of the OBS pressure window
  10, # upper bound: hPa added to the primary level to form the upper limit of the OBS pressure window
  100, # OBS scale factor: multiplies the hPa window limits before comparing with OBS 'Pressure' (presumably hPa -> Pa)
  '/discover/nobackup/projects/gmao/merra21c/TSE_staging/e5303_m21c_jan18/archive/obs/Y2022/M01/D01/H00/e5303_m21c_jan18.diag_conv_t_ges.20220101_00z.nc4', # OB: observation file
  '/discover/nobackup/projects/QEFM/qefm-core/data/NSE/input/graphcast/source-era5_date-2022-01-01_res-1.0_levels-13_steps-01.nc', # GI: original ERA5 input
  '/discover/nobackup/projects/QEFM/qefm-core/data/NSE/output/graphcast/'+subdir+'/source-era5_date-2022-01-01_res-1.0_levels-13_steps-01_obs_t_850.nc', # GI OBS: presumably the output path for the OBS-modified input
  '/discover/nobackup/projects/QEFM/qefm-core/qefm/models/checkpoints/graphcast/pred-era5_date-2022-01-01_res-1.0_levels-13_steps-01.nc', # GI prediction: presumably the original GraphCast prediction
  '/discover/nobackup/projects/QEFM/qefm-core/data/NSE/output/graphcast/'+subdir+'/pred-era5_date-2022-01-01_res-1.0_levels-13_steps-01_obs_t_850.nc', # GI OBS prediction: presumably the output path for the prediction from the OBS-modified input
 ],
  ['t', # short name
  'temperature', # long name (name of the variable in the GI/ERA5 dataset)
  925, # primary level of interest in OBS file (hPa)
  11, # level index in GI ERA5 (int)
  10, # lower bound: hPa subtracted from the primary level to form the lower limit of the OBS pressure window
  10, # upper bound: hPa added to the primary level to form the upper limit of the OBS pressure window
  100, # OBS scale factor: multiplies the hPa window limits before comparing with OBS 'Pressure' (presumably hPa -> Pa)
  '/discover/nobackup/projects/gmao/merra21c/TSE_staging/e5303_m21c_jan18/archive/obs/Y2022/M01/D01/H00/e5303_m21c_jan18.diag_conv_t_ges.20220101_00z.nc4', # OB: observation file
  '/discover/nobackup/projects/QEFM/qefm-core/data/NSE/input/graphcast/source-era5_date-2022-01-01_res-1.0_levels-13_steps-01.nc', # GI: original ERA5 input
  '/discover/nobackup/projects/QEFM/qefm-core/data/NSE/output/graphcast/'+subdir+'/source-era5_date-2022-01-01_res-1.0_levels-13_steps-01_obs_t_925.nc', # GI OBS: presumably the output path for the OBS-modified input
  '/discover/nobackup/projects/QEFM/qefm-core/qefm/models/checkpoints/graphcast/pred-era5_date-2022-01-01_res-1.0_levels-13_steps-01.nc', # GI prediction: presumably the original GraphCast prediction
  '/discover/nobackup/projects/QEFM/qefm-core/data/NSE/output/graphcast/'+subdir+'/pred-era5_date-2022-01-01_res-1.0_levels-13_steps-01_obs_t_925.nc', # GI OBS prediction: presumably the output path for the prediction from the OBS-modified input
 ],
 # NOTE(review): the disabled entries below have only 10 fields (no GI
 # prediction / GI OBS prediction paths) and their output paths omit `subdir`.
 # The main loop reads 12 fields per entry, so add those before re-enabling.
 #  ['t', # short name
 #  'temperature', # long name
 #  1000, # primary level of interest in OBS file (hPa)
 #  12, # level index in GI ERA5 (int)
 #  10, # lower bound (window around primary level (degrees)
 #  10, # upper bound (window around primary level (degrees)
 #  100, # OBS scale factor
 #  '/discover/nobackup/projects/gmao/merra21c/TSE_staging/e5303_m21c_jan18/archive/obs/Y2022/M01/D01/H00/e5303_m21c_jan18.diag_conv_t_ges.20220101_00z.nc4', # OB
 #  '/discover/nobackup/projects/QEFM/qefm-core/data/NSE/input/graphcast/source-era5_date-2022-01-01_res-1.0_levels-13_steps-01.nc', # GI
 #  '/discover/nobackup/projects/QEFM/qefm-core/data/NSE/output/graphcast/source-era5_date-2022-01-01_res-1.0_levels-13_steps-01_obs_t_1000.nc',
 # ],
 # ['q', # short name
 #  'specific_humidity', # long name
 #  850, # primary level of interest in OBS file (hPa)
 #  10, # level index in GI ERA5 (int)
 #  10, # lower bound (window around primary level (degrees)
 #  10, # upper bound (window around primary level (degrees)
 #  100, # OBS scale factor
 #  '/discover/nobackup/projects/gmao/merra21c/TSE_staging/e5303_m21c_jan18/archive/obs/Y2022/M01/D01/H00/e5303_m21c_jan18.diag_conv_q_ges.20220101_00z.nc4', # OB
 #  '/discover/nobackup/projects/QEFM/qefm-core/data/NSE/input/graphcast/source-era5_date-2022-01-01_res-1.0_levels-13_steps-01.nc',# GI
 #  '/discover/nobackup/projects/QEFM/qefm-core/data/NSE/output/graphcast/source-era5_date-2022-01-01_res-1.0_levels-13_steps-01_obs_q_850.nc',
 # ]
]


Show scatter plot of OB values:

In [None]:
def show_OB_scatter_plot(obs_path, long_name):
    """
    Draw a global scatter map of the observations in an OBS netCDF file.

    Each observation is drawn at its (Longitude, Latitude) position and
    coloured by its ``Observation`` value on a PlateCarree map with gridlines.
    The source path is printed under the map.

    Parameters
    ----------
    obs_path : str
        Path to the observation file. It must contain ``Observation``,
        ``Latitude`` and ``Longitude`` variables.
    long_name : str
        Label for the colour bar.

    Returns
    -------
    xarray.Dataset
        The dataset opened from ``obs_path``. The caller is responsible for
        closing it.
    """
    # Imported locally because this cell comes before the cell that imports
    # the plotting libraries at notebook level.
    import cartopy.crs as ccrs
    import matplotlib.pyplot as plt
    import xarray as xr

    ds_ob = xr.open_dataset(obs_path)
    values = ds_ob["Observation"].values
    lats = ds_ob["Latitude"].values
    lons = ds_ob["Longitude"].values

    plt.figure(figsize=(10, 6))
    ax = plt.axes(projection=ccrs.PlateCarree())
    ax.set_global()
    ax.gridlines(draw_labels=True)

    # Observations are unstructured points rather than a grid, so use a
    # scatter plot coloured by the observation value.
    scatter = ax.scatter(lons, lats, c=values, cmap='viridis', s=1, transform=ccrs.PlateCarree())
    plt.colorbar(scatter, label=long_name, ax=ax)

    plt.title(f'Observation Scatter Plot ({values.size} points)')
    plt.figtext(0.5, 0.01, f'Observation Plot \n [{obs_path}]', ha="center", fontsize=10, color="blue")
    plt.show()

    return ds_ob

Plot the differences between OB (ERA+OBS) and GI (ERA):

In [None]:
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from pathlib import Path

def _resolve_era5_dataset(source, cache_name, owned):
    """
    Return the dataset to use for one side of the ERA5 comparison.

    If ``source`` names an existing file, it is opened with ``xr.open_dataset``
    and its ``close`` is registered on the ``owned`` ExitStack. This ensures
    that only datasets opened here get closed afterwards.

    Otherwise, the dataset is read from the notebook-level dict named
    ``cache_name`` (``ds_gis`` or ``ds_gi_mods``). This is the case in the main
    workflow, which passes the variable short name rather than a path. The
    lookup key is built from the globals ``short_name`` and
    ``level_of_interest``. Cached datasets belong to the notebook and are left
    open.
    """
    if isinstance(source, (str, os.PathLike)) and Path(source).is_file():
        ds = xr.open_dataset(source)
        owned.callback(ds.close)
        return ds
    return globals()[cache_name][f"{short_name}_{str(level_of_interest)}_ds"]


def _diff_stats(diff):
    """Return mean, std, min, max and RMSE of a difference field as plain floats."""
    return {
        'mean': float(diff.mean().values),
        'std': float(diff.std().values),
        'min': float(diff.min().values),
        'max': float(diff.max().values),
        'rmse': float(np.sqrt((diff**2).mean()).values)
    }


def _plot_file_tag(era5_obs_file):
    """
    Return the suffix used in saved plot file names.

    Uses the workflow globals ``short_name`` and ``level_of_interest``, as the
    original naming scheme did. Falls back to the stem of ``era5_obs_file``
    when those globals are not defined, e.g. when called outside the main loop.
    """
    try:
        return f"{short_name}_{str(level_of_interest)}"
    except NameError:
        return Path(era5_obs_file).stem


def _print_diff_summary(statistics):
    """Print the per-parameter, per-level table of difference statistics."""
    print("\n" + "="*60)
    print("SUMMARY STATISTICS")
    print("="*60)

    for param, level_stats in statistics.items():
        print(f"\n{param.upper().replace('_', ' ')}")
        print("-" * 40)
        print(f"{'Level (hPa)':<12} {'Mean':<8} {'RMSE':<8} {'Min':<8} {'Max':<8}")
        print("-" * 40)

        for level in sorted(level_stats.keys(), reverse=True):
            stats = level_stats[level]
            print(f"{level:<12} {stats['mean']:<8.3f} {stats['rmse']:<8.3f} "
                  f"{stats['min']:<8.3f} {stats['max']:<8.3f}")


def plot_era5_obs_differences(era5_file, era5_obs_file, output_dir=None, 
                             levels=None, parameters=None, 
                             time_index=0, batch_index=0,
                             vmin=-0.5, vmax=0.5, cmap='coolwarm',
                             figsize_per_plot=(4, 3), save_plots=True):
    """
    Create difference maps (ERA5 - ERA5+OBS), one row of level panels per parameter.

    Parameters:
    -----------
    era5_file : str
        Path to the original ERA5 file. If it is not an existing file (the
        main workflow passes the variable short name here), the dataset is
        taken from the notebook cache ``ds_gis``, keyed by the globals
        ``short_name`` and ``level_of_interest``.
    era5_obs_file : str
        Path to the ERA5+OBS modified file, or a non-path value as above, in
        which case ``ds_gi_mods`` is used. Also used for the default output
        directory and the figure title.
    output_dir : str, optional
        Directory to save plots (default: directory of ``era5_obs_file``).
    levels : list, optional
        Pressure levels to plot (default: all levels of the ERA5 dataset).
        Levels missing for a parameter are skipped.
    parameters : list, optional
        Variables to plot. The default is every variable from 'temperature',
        'specific_humidity', 'geopotential', 'u_component_of_wind' and
        'v_component_of_wind' that is present in both datasets with a
        'level' dimension.
    time_index : int
        Time step index to plot (default: 0)
    batch_index : int
        Batch index to plot (default: 0)
    vmin, vmax : float
        Color scale limits for difference plots
    cmap : str
        Colormap for difference plots
    figsize_per_plot : tuple
        Size of each subplot (width, height)
    save_plots : bool
        Whether to save plots to files

    Returns:
    --------
    dict : {'differences': {param: {level: difference DataArray}},
            'statistics': {param: {level: {'mean', 'std', 'min', 'max', 'rmse'}}},
            'plots_saved': [paths of saved PNG files]}

    Datasets opened from paths here are closed before returning. Datasets
    taken from the notebook caches are left open, because the notebook
    keeps them there for other cells.
    """
    from contextlib import ExitStack

    results = {
        'differences': {},
        'statistics': {},
        'plots_saved': []
    }

    with ExitStack() as owned:
        print("Loading datasets...")
        ds_era5 = _resolve_era5_dataset(era5_file, 'ds_gis', owned)
        ds_era5_obs = _resolve_era5_dataset(era5_obs_file, 'ds_gi_mods', owned)

        if parameters is None:
            candidates = ['temperature', 'specific_humidity', 'geopotential',
                          'u_component_of_wind', 'v_component_of_wind']
            parameters = [
                var for var in candidates
                if var in ds_era5.data_vars and var in ds_era5_obs.data_vars
                and 'level' in ds_era5[var].dims
            ]

        if levels is None:
            levels = ds_era5.level.values

        output_dir = Path(era5_obs_file).parent if output_dir is None else Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        print(f"Creating plots for {len(parameters)} parameters and {len(levels)} levels...")

        for param in parameters:
            print(f"Processing parameter: {param}")

            if param not in ds_era5.data_vars or param not in ds_era5_obs.data_vars:
                print(f"Warning: {param} not found in both datasets, skipping...")
                continue
            if 'level' not in ds_era5[param].dims:
                print(f"Warning: {param} has no level dimension, skipping...")
                continue

            available_levels = ds_era5[param].level.values
            plot_levels = [lvl for lvl in levels if lvl in available_levels]
            if not plot_levels:
                print(f"Warning: No matching levels found for {param}, skipping...")
                continue

            n_levels = len(plot_levels)
            fig, axes = plt.subplots(1, n_levels,
                                     figsize=(figsize_per_plot[0] * n_levels, figsize_per_plot[1]),
                                     subplot_kw={'projection': ccrs.PlateCarree()})
            if n_levels == 1:
                axes = [axes]  # plt.subplots returns a bare Axes for a single panel

            param_diffs = {}
            param_stats = {}
            # Mesh of the last successfully drawn panel. It stays None if every
            # level fails, and then no colorbar is drawn (previously this
            # raised or reused another parameter's mesh).
            im = None

            for ax, level in zip(axes, plot_levels):
                try:
                    era5_data = ds_era5[param].sel(level=level).isel(time=time_index, batch=batch_index)
                    era5_obs_data = ds_era5_obs[param].sel(level=level).isel(time=time_index, batch=batch_index)

                    # Difference convention: ERA5 - ERA5+OBS
                    diff = era5_data - era5_obs_data
                    param_diffs[level] = diff

                    stats = _diff_stats(diff)
                    param_stats[level] = stats

                    im = diff.plot(ax=ax,
                                   transform=ccrs.PlateCarree(),
                                   vmin=vmin, vmax=vmax,
                                   cmap=cmap,
                                   add_colorbar=False)

                    ax.coastlines(resolution='50m', alpha=0.5)
                    ax.gridlines(draw_labels=False, alpha=0.3)
                    ax.set_title(f'{level} hPa\n'
                                 f'Mean: {stats["mean"]:.3f}\n'
                                 f'RMSE: {stats["rmse"]:.3f}',
                                 fontsize=8)

                    print(f"  Level {level} hPa: Mean diff = {stats['mean']:.4f}, "
                          f"RMSE = {stats['rmse']:.4f}")

                except Exception as e:
                    # Keep going with the other levels; mark this panel as failed.
                    print(f"Error processing {param} at {level} hPa: {e}")
                    ax.text(0.5, 0.5, f'Error\n{level} hPa',
                            transform=ax.transAxes, ha='center', va='center')

            param_clean = param.replace('_', ' ').title()
            fig.suptitle(f'{param_clean} Differences (ERA5 - ERA5+OBS)\n'
                         f'Time: {ds_era5.time.values[time_index]} | '
                         f'File: {Path(era5_obs_file).name}',
                         fontsize=10, y=0.95)

            plt.tight_layout()
            if im is not None:
                cbar = plt.colorbar(im, ax=axes, orientation='horizontal',
                                    pad=0.1, shrink=0.8, aspect=30)
                cbar.set_label(f'{param_clean} Difference', fontsize=9)

            if save_plots:
                plot_filename = f'diff_{param}_levels_era5_vs_era5obs_{_plot_file_tag(era5_obs_file)}.png'
                plot_path = output_dir / plot_filename
                plt.savefig(plot_path, dpi=150, bbox_inches='tight')
                results['plots_saved'].append(str(plot_path))
                print(f"Plot saved: {plot_path}")

            plt.show()

            results['differences'][param] = param_diffs
            results['statistics'][param] = param_stats

        _print_diff_summary(results['statistics'])

    return results

def plot_single_parameter_all_levels(era5_file, era5_obs_file, parameter='temperature',
                                   levels=None, time_index=0, batch_index=0,
                                   figsize=(15, 3), save_plot=True, output_dir=None):
    """
    Plot one parameter's ERA5 vs ERA5+OBS differences as a single row of level panels.

    The total figure width ``figsize[0]`` is shared evenly among ``levels``.
    When ``levels`` is None or empty, each panel is 3 units wide. The panel
    height is ``figsize[1]``.

    Parameters
    ----------
    era5_file : str
        Original ERA5 file, forwarded to ``plot_era5_obs_differences``.
    era5_obs_file : str
        ERA5+OBS modified file, forwarded likewise.
    parameter : str
        The single variable to plot.
    levels : list, optional
        Pressure levels to plot (None means all available).
    time_index, batch_index : int
        Time step and batch to plot.
    figsize : tuple
        Overall (width, height) of the row.
    save_plot : bool
        Whether to save the figure.
    output_dir : str, optional
        Directory for the saved figure.

    Returns
    -------
    dict
        The result dictionary of ``plot_era5_obs_differences``.
    """
    if levels:
        panel_width = figsize[0] / len(levels)
    else:
        panel_width = 3
    panel_size = (panel_width, figsize[1])

    return plot_era5_obs_differences(
        era5_file=era5_file,
        era5_obs_file=era5_obs_file,
        output_dir=output_dir,
        levels=levels,
        parameters=[parameter],
        time_index=time_index,
        batch_index=batch_index,
        figsize_per_plot=panel_size,
        save_plots=save_plot,
    )

def plot_single_parameter_at_level(era5_file, era5_obs_file, parameter='temperature',
                                   levels=None, time_index=0, batch_index=0,
                                   figsize=(15, 15), save_plot=True, output_dir=None):
    """
    Plot one parameter's ERA5 vs ERA5+OBS differences at the given level(s).

    This behaves exactly like ``plot_single_parameter_all_levels`` (to which it
    delegates) but has a larger default figure (15 x 15). It is used by the
    main workflow loop with a single level per call. ``figsize[0]`` is shared
    evenly among ``levels``, with 3 units per panel when ``levels`` is None or
    empty.

    Parameters
    ----------
    era5_file : str
        Original ERA5 file (forwarded unchanged).
    era5_obs_file : str
        ERA5+OBS modified file (forwarded unchanged).
    parameter : str
        Variable name to plot.
    levels : list, optional
        Pressure levels to plot (default: all available).
    time_index : int
        Time step index to plot.
    batch_index : int
        Batch index to plot.
    figsize : tuple
        Figure size (width, height).
    save_plot : bool
        Whether to save the plot.
    output_dir : str, optional
        Directory to save the plot.

    Returns
    -------
    dict
        The result dictionary of ``plot_era5_obs_differences``.
    """
    return plot_single_parameter_all_levels(
        era5_file=era5_file,
        era5_obs_file=era5_obs_file,
        parameter=parameter,
        levels=levels,
        time_index=time_index,
        batch_index=batch_index,
        figsize=figsize,
        save_plot=save_plot,
        output_dir=output_dir,
    )
# Example usage functions for your specific data
def analyze_temperature_modifications(era5_file, era5_obs_file, output_dir=None):
    """
    Plot temperature differences (ERA5 - ERA5+OBS) at seven common pressure levels.

    The levels requested are 200, 300, 500, 700, 850, 925 and 1000 hPa. Any of
    these missing from the data are skipped by ``plot_era5_obs_differences``.
    The colour range is fixed to [-1.0, 1.0] and each panel is 3 x 2.5.

    Returns the result dictionary of ``plot_era5_obs_differences``.
    """
    return plot_era5_obs_differences(
        era5_file=era5_file,
        era5_obs_file=era5_obs_file,
        output_dir=output_dir,
        levels=[200, 300, 500, 700, 850, 925, 1000],
        parameters=['temperature'],
        vmin=-1.0,
        vmax=1.0,
        figsize_per_plot=(3, 2.5),
    )

def analyze_all_modifications(era5_file, era5_obs_file, output_dir=None):
    """
    Plot ERA5 - ERA5+OBS differences for temperature, specific humidity and geopotential.

    All levels available in the data are used, with 2.5 x 2 panels and the
    default colour range of ``plot_era5_obs_differences``. Returns that
    function's result dictionary.
    """
    variables = ['temperature', 'specific_humidity', 'geopotential']
    return plot_era5_obs_differences(
        era5_file=era5_file,
        era5_obs_file=era5_obs_file,
        output_dir=output_dir,
        parameters=variables,
        figsize_per_plot=(2.5, 2),
    )

# # Usage example for your data:
# if __name__ == "__main__":
#     # Your file paths
#     era5_original = '/discover/nobackup/projects/QEFM/qefm-core/data/NSE/input/graphcast/source-era5_date-2022-01-01_res-1.0_levels-13_steps-01.nc'
#     era5_modified = '/discover/nobackup/projects/QEFM/qefm-core/data/NSE/output/graphcast/source-era5_date-2022-01-01_res-1.0_levels-13_steps-01_obs_t_850.nc'
    
#     # Analyze temperature modifications
#     results = analyze_temperature_modifications(
#         era5_file=era5_original,
#         era5_obs_file=era5_modified,
#         output_dir='/discover/nobackup/projects/QEFM/qefm-core/data/NSE/output/plots'
#     )
    
#     # Print summary
#     print("\nAnalysis complete!")
#     print(f"Plots saved: {len(results['plots_saved'])}")
#     for plot_path in results['plots_saved']:
#         print(f"  - {plot_path}")

Calculate the average observation value per grid cell for points whose pressure falls within the window around the level of interest

In [None]:
def calculate_avg_per_cell(ds_ob, variable, level_of_interest, level_index, lower_bound, upper_bound, scale_factor):
    """
    Average observations per 1-degree grid cell inside a pressure window.

    Observations are kept when their ``Pressure`` lies in the inclusive range
    ``[(level_of_interest - lower_bound) * scale_factor,
    (level_of_interest + upper_bound) * scale_factor]``. The kept points'
    Latitude/Longitude are rounded to whole degrees, and ``Observation`` is
    averaged per (lat, lon) cell. All four window inputs are converted with
    ``int()`` first.

    Parameters
    ----------
    ds_ob : xarray.Dataset
        Observation dataset with ``Latitude``, ``Longitude``, ``Pressure`` and
        ``Observation`` variables.
    variable : str
        Short variable name, used only for the output column label.
    level_of_interest : int
        Primary level (hPa).
    level_index : int
        Unused; kept for call compatibility.
    lower_bound, upper_bound : int
        hPa subtracted from / added to ``level_of_interest`` to form the window.
    scale_factor : int
        Multiplier applied to the hPa window to match the units of
        ``Pressure`` (100 in the run metadata, presumably hPa -> Pa).

    Returns
    -------
    pandas.DataFrame
        Columns ``Latitude``, ``Longitude`` (whole degrees wrapped into
        [0, 360)) and ``Avg_Observation_<variable>_<level_of_interest>``.
    """
    window_low = (int(level_of_interest) - int(lower_bound)) * int(scale_factor)
    window_high = (int(level_of_interest) + int(upper_bound)) * int(scale_factor)

    print("\nFilter window: ", window_low, "-", window_high)
    obs_df = ds_ob[['Latitude', 'Longitude', 'Pressure', 'Observation']].to_dataframe()
    in_window = obs_df[(obs_df['Pressure'] >= window_low) & (obs_df['Pressure'] <= window_high)].copy()

    max_obs = in_window['Observation'].max()
    min_obs = in_window['Observation'].min()
    print(f"The Min observation value for Pressure between {window_low} and {window_high} is: {min_obs}")
    print(f"The Max observation value for Pressure between {window_low} and {window_high} is: {max_obs}")

    # Round to whole degrees to define 1-degree cells. Longitude is rounded
    # *before* wrapping: wrapping first lets e.g. 359.7 round up to 360. That
    # is presumably not a GI grid longitude, and the exact .loc assignment in
    # override_cells would fail on it.
    in_window['lat_grid'] = in_window['Latitude'].round()
    in_window['lon_grid'] = in_window['Longitude'].round() % 360
    # TODO: Consider applying weights here
    cell_averages = in_window.groupby(['lat_grid', 'lon_grid'])['Observation'].mean().reset_index()

    avg_observation_label = "Avg_Observation_" + variable + "_" + str(level_of_interest)
    cell_averages.columns = ['Latitude', 'Longitude', avg_observation_label]
    return cell_averages
    

Identify the GI cells to be overridden with the averaged observation values

In [None]:
def filter_cells(ds_gi, variable, cell_averages, level_of_interest, long_name='temperature'):
    """
    Attach GI background values and the observation bias to per-cell averages.

    For each (Latitude, Longitude) row of ``cell_averages``, the GI value of
    ``long_name`` is sampled at ``level_of_interest`` and the first time step,
    using the nearest lat/lon grid point. Two columns are then added to
    ``cell_averages`` in place:

    * ``GI_Background_<variable>_<level_of_interest>``: the sampled GI value
    * ``Observation_Bias_<variable>_<level_of_interest>``:
      ``Avg_Observation_...`` minus ``GI_Background_...``

    Parameters
    ----------
    ds_gi : xarray.Dataset
        GI (ERA5) dataset with ``level``, ``time``, ``lat`` and ``lon`` coordinates.
    variable : str
        Short variable name used in the column labels (e.g. 't').
    cell_averages : pandas.DataFrame
        Output of ``calculate_avg_per_cell`` for the same variable and level.
    level_of_interest : int
        Pressure level (hPa); must be one of the ``ds_gi.level`` values.
    long_name : str, optional
        GI data variable to sample. Defaults to 'temperature', the variable
        this function previously hard-coded.

    Returns
    -------
    pandas.DataFrame
        ``cell_averages`` itself (modified in place). It is now returned, so
        callers that assign the result no longer receive None.

    Raises
    ------
    ValueError
        If ``level_of_interest`` is not a GI level.
    """
    if level_of_interest not in ds_gi.level.values:
        raise ValueError(f"Level {int(level_of_interest)} not found in GI dataset levels: {ds_gi.level.values}")

    # Reduce to one variable at one level/time once, instead of indexing every
    # variable of the Dataset for each cell.
    gi_field = ds_gi[long_name].sel(
        level=level_of_interest,
        time=ds_gi.time.values[0],
        method='nearest',
    )
    # 'nearest' because the rounded observation cell centres may not exactly
    # equal the GI grid coordinates.
    gi_background_values = [
        gi_field.sel(lat=lat_val, lon=lon_val, method='nearest').values
        for lat_val, lon_val in zip(cell_averages['Latitude'], cell_averages['Longitude'])
    ]
    print("Number of Points to modify: " + str(len(gi_background_values)))

    label_suffix = variable + '_' + str(level_of_interest)
    avg_obs_label = 'Avg_Observation_' + label_suffix
    gi_background_label = 'GI_Background_' + label_suffix
    ob_bias_label = 'Observation_Bias_' + label_suffix

    cell_averages[gi_background_label] = np.array(gi_background_values)
    cell_averages[ob_bias_label] = cell_averages[avg_obs_label] - cell_averages[gi_background_label]
    return cell_averages

Override corresponding values in GI with OBS avg values

In [None]:
def override_cells(ds_gi, variable, cell_averages, level_of_interest, output_path, long_name='temperature'):
    """
    Write a copy of the GI dataset with the observed cell averages substituted in.

    A deep copy of ``ds_gi`` is made so the original stays untouched. For each
    row of ``cell_averages``, the value of ``long_name`` is replaced by
    ``Avg_Observation_<variable>_<level_of_interest>``. The replaced value is
    the one at batch 0, the first time step, ``level_of_interest`` and the
    row's exact Latitude/Longitude. The modified copy is then written to
    ``output_path`` with ``to_netcdf``.

    Parameters
    ----------
    ds_gi : xarray.Dataset
        GI (ERA5) dataset with ``batch``, ``time``, ``level``, ``lat`` and
        ``lon`` dimensions.
    variable : str
        Short variable name used in the average column label (e.g. 't').
    cell_averages : pandas.DataFrame
        Per-cell averages from ``calculate_avg_per_cell``.
    level_of_interest : int
        Pressure level (hPa) to modify.
    output_path : str
        Destination netCDF path.
    long_name : str, optional
        GI data variable to modify. Defaults to 'temperature', the variable
        this function previously hard-coded.

    Returns
    -------
    xarray.Dataset
        The modified in-memory copy.

    Raises
    ------
    KeyError
        Presumably raised when a row's Latitude/Longitude is not an exact GI
        coordinate, because ``.loc`` is label-based with no 'nearest' fallback.
    """
    ds_gi_modified = ds_gi.copy(deep=True)
    # Like the original attribute access, ds[name] wraps the dataset's own
    # variable, so the .loc assignments below modify ds_gi_modified itself.
    target = ds_gi_modified[long_name]
    first_time = ds_gi.time.values[0]
    avg_obs_label = 'Avg_Observation_' + variable + "_" + str(level_of_interest)

    for lat_val, lon_val, obs_val in zip(cell_averages['Latitude'],
                                         cell_averages['Longitude'],
                                         cell_averages[avg_obs_label]):
        target.loc[{
            'batch': 0,
            'time': first_time,
            'level': level_of_interest,
            'lat': lat_val,
            'lon': lon_val
        }] = obs_val

    ds_gi_modified.to_netcdf(output_path)

    print(f"File saved successfully to {output_path}")
    return ds_gi_modified


Calculate the difference between datasets for a specific variable and level

In [None]:
def calculate_diff(ds1, ds2, variable_name, level_of_interest, level_index):
    """
    Return ``ds1 - ds2`` for one variable at the first time step and a given level.

    Parameters
    ----------
    ds1, ds2 : xarray.Dataset
        Datasets to compare. In the workflow these are the original GI and the
        OBS-modified GI.
    variable_name : str
        Data variable to difference.
    level_of_interest : int
        Level label selected with ``.sel(level=...)``.
    level_index : int
        Unused; kept for call compatibility.

    Returns
    -------
    The difference ``ds1[variable_name] - ds2[variable_name]`` at ``time`` index
    0 and ``level_of_interest``. Returns None, after printing a message, when
    the variable is missing from either dataset. Previously this case raised
    UnboundLocalError and the difference was never returned.
    """
    if variable_name not in ds1 or variable_name not in ds2:
        print(f"Variable '{variable_name}' not found in both files. Check variable names.")
        return None

    var1 = ds1[variable_name].isel(time=0).sel(level=level_of_interest)
    var2 = ds2[variable_name].isel(time=0).sel(level=level_of_interest)
    return var1 - var2

Run GraphCast to create new Prediction based on modified input state (GI OBS)

In [None]:
def run_graphcast(container, gi_obs_path, gi_obs_pred_path, override=None):
    """
    Run the GraphCast container on ``gi_obs_path`` to produce ``gi_obs_pred_path``.

    The run is skipped when ``gi_obs_pred_path`` already exists, unless
    ``override`` is true.

    Parameters
    ----------
    container : str
        Singularity image used to run the GraphCast script.
    gi_obs_path : str
        Input file passed to the script (the OBS-modified GI file).
    gi_obs_pred_path : str
        Output prediction path passed to the script.
    override : bool, optional
        Re-run even if the output exists. Defaults to the notebook-level
        ``override`` flag, or False if that global is not defined.

    Returns
    -------
    subprocess.CompletedProcess or None
        The completed process when GraphCast was run, or None when skipped.
        A warning is printed if the command exits with a non-zero status.
    """
    import shlex
    import subprocess
    from pathlib import Path

    if override is None:
        override = globals().get('override', False)

    if Path(gi_obs_pred_path).exists() and not override:
        print("GraphCast prediction exists and override is False. Skipping: " + str(gi_obs_pred_path))
        return None

    print("Run GraphCast now to create: ", gi_obs_pred_path)
    # 'time' is a shell builtin, so the command still goes through the shell.
    # Every argument is shlex-quoted so that paths with spaces or shell
    # metacharacters are passed literally; plain paths are left unchanged.
    args = [
        "/usr/local/other/singularity/4.0.3/bin/singularity", "exec", "--nv",
        "-B", "/home/gtamkin,/discover/nobackup/projects/QEFM/qefm-core,/discover/nobackup/gtamkin",
        container, "python",
        "/discover/nobackup/projects/QEFM/qefm-core/qefm/models/src/FMGraphCast/fm_graphcast_nse.py",
        gi_obs_path, gi_obs_pred_path,
    ]
    command = "time " + " ".join(shlex.quote(str(arg)) for arg in args)

    completed = subprocess.run(command, shell=True)
    if completed.returncode != 0:
        print(f"WARNING: GraphCast exited with status {completed.returncode} for {gi_obs_pred_path}")
    return completed

In [None]:
#################################################
#        MAIN WORKFLOW                          #
#################################################

Load observational Temperature parameter values at a specific timestep (20220101_00z) and pressure level window (lower_bound, upper_bound).  Then calculate the mean of these Temperature observational values per cell.

Load ERA5 Temperature parameter values at a specific timestep (20220101_00z) at the primary pressure level of interest

Loop through the cell averages and build the corresponding array of GI background values for the cells to be overridden. Check the bias (averaged observation minus GI background).

Loop through lat/lon cells of GI file and replace values with the OBS average per cell calculated above

Calculate difference statistics between original ERA5 source and obs-modified source

Run GraphCast to produce the OBS-augmented prediction

In [None]:
# Main workflow: for each metadata entry, average the observations per cell,
# attach GI background/bias, write the OBS-modified GI file, difference it,
# run GraphCast on it and plot the GI differences.
# NOTE: the loop variables below are module-level globals that other code
# reads: plot_era5_obs_differences looks up ds_gis / ds_gi_mods via
# short_name and level_of_interest, and the QA cell uses obs_path / long_name
# (i.e. the values from the last entry processed).
current_dd = dd_multi
for entry in current_dd:
    # Unpack the entry by position (see the dd_* tables for the field layout)
    short_name   = entry[index_short_name]
    long_name    = entry[index_long_name]
    obs_path     = entry[index_OB_file]
    gi_path      = entry[index_GI_file]
    level_of_interest          = entry[index_level_of_interest]
    level_index  = entry[index_GI_level]
    lower_bound  = entry[index_lower_bound]
    upper_bound  = entry[index_upper_bound] 
    scale_factor = entry[index_OB_scale_factor] 
    gi_obs_path  = entry[index_GI_OBS_file] 
    gi_pred_path = entry[index_GI_PRED_file] 
    gi_obs_pred_path    = entry[index_GI_OBS_PRED_file] 
    
    print(f"Processing {long_name} ({short_name}) at {level_of_interest}hPa...")
    print(f" - Observation File: {obs_path}")
    print(f" - ERA5 Input File:  {gi_path}")
    print("-" * 30)
    
    # Open the OB dataset and keep it in the ds_obs cache. It stays open;
    # close it once it is no longer needed.
    ds_ob = xr.open_dataset(obs_path)
    key_index = f"{short_name}_{str(level_of_interest)}"
    ds_obs[f"{key_index}_ds"] = ds_ob

    # Calculate the average value per cell within the pressure window around the primary level of interest
    cell_averages_ob = calculate_avg_per_cell(ds_ob, short_name, level_of_interest, level_index, lower_bound, upper_bound, scale_factor)
    cell_averages_obs[f"{key_index}_ds"] = cell_averages_ob
    
    # Open the GI dataset and keep it in the ds_gis cache (read later by the
    # plotting functions). It stays open; close it once it is no longer needed.
    ds_gi = xr.open_dataset(gi_path)
    ds_gis[f"{key_index}_ds"] = ds_gi
   
    # Adds GI background and observation-bias columns to cell_averages_ob in
    # place; gi_cell_averages is not used below.
    gi_cell_averages = filter_cells(ds_gi, short_name, cell_averages_ob, level_of_interest)
    
    # Write the OBS-modified GI copy to gi_obs_path and cache it in ds_gi_mods
    ds_gi_mods[f"{key_index}_ds"] = override_cells(ds_gi, short_name, cell_averages_ob, level_of_interest, gi_obs_path)

    # Difference between the original and the modified GI for this variable/level (result not kept)
    calculate_diff(ds_gi, ds_gi_mods[f"{key_index}_ds"], long_name, level_of_interest, level_index)

    # Run GraphCast on gi_obs_path to produce gi_obs_pred_path (skipped if the output exists and override is False)
    run_graphcast (container, gi_obs_path, gi_obs_pred_path)
    
    # Plot variable on current level.
    # NOTE(review): era5_file / era5_obs_file receive short_name, not file
    # paths; the plotting code takes the datasets from ds_gis / ds_gi_mods via
    # the globals short_name and level_of_interest.
    temp_results = plot_single_parameter_at_level(
        era5_file=short_name,
        era5_obs_file=short_name,
        levels=[level_of_interest],
        parameter=long_name,
        output_dir=output_dir
    )


In [None]:
##########################################

Quality Assurance plots:

In [None]:
# Show the observation scatter plot once, for the last entry processed by the
# main loop (obs_path / long_name), then clear the flag so that re-running
# this cell does not redraw it.
if show_scatter_plot:
    # show_OB_scatter_plot returns the dataset it opened; nothing else uses it, so close it.
    show_OB_scatter_plot(obs_path, long_name).close()
    show_scatter_plot = False


Compare Prediction results in output netcdf files.

In [None]:
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
from pathlib import Path

def _compute_impact_stats(impact):
    """Summarise one impact field (modified minus original) for a single level.

    Parameters
    ----------
    impact : xarray.DataArray or numpy.ndarray
        Difference field; must support ``mean``/``std``/``max``/``min``/``size``.

    Returns
    -------
    dict
        ``mean_impact``, ``std_impact``, ``rmse_impact``, ``max_impact``,
        ``min_impact``, ``cells_modified``, ``total_cells`` plus the derived
        ``modification_fraction``, ``impact_consistency``,
        ``rmse_verification`` and ``impact_type`` entries.
    """
    mean_impact = float(impact.mean())
    std_impact = float(impact.std())
    stats = {
        'mean_impact': mean_impact,
        'std_impact': std_impact,
        'rmse_impact': float(np.sqrt((impact**2).mean())),
        'max_impact': float(impact.max()),
        'min_impact': float(impact.min()),
        # ``NaN != 0`` evaluates True, so mask non-finite cells explicitly;
        # otherwise missing values would be reported as "modified".
        'cells_modified': int(((impact != 0) & np.isfinite(impact)).sum()),
        'total_cells': int(impact.size),
    }

    stats['modification_fraction'] = stats['cells_modified'] / stats['total_cells']
    stats['impact_consistency'] = (abs(mean_impact) / std_impact
                                   if std_impact != 0 else float('inf'))

    # With a population STD (ddof=0, the xarray/numpy default),
    # RMSE**2 == mean**2 + STD**2. A relative tolerance keeps this check
    # meaningful for variables of any magnitude, not just values near 1.
    expected_rmse = np.sqrt(mean_impact**2 + std_impact**2)
    stats['rmse_verification'] = bool(
        np.isclose(stats['rmse_impact'], expected_rmse, rtol=1e-5, atol=1e-6))

    stats['impact_type'] = _classify_impact(mean_impact, std_impact)
    return stats


def _classify_impact(mean_impact, std_impact):
    """Label an impact as systematic, spatially variable, or mixed."""
    if abs(mean_impact) > 2 * std_impact:
        return 'Systematic bias correction'
    if std_impact > 2 * abs(mean_impact):
        return 'Spatially variable corrections'
    return 'Mixed systematic and variable'


def _print_impact_summary(statistics, levels):
    """Print a fixed-width table of the per-level impact statistics."""
    print("\n" + "="*80)
    print("MULTI-LEVEL IMPACT ANALYSIS SUMMARY")
    print("="*80)
    print(f"{'Level':<8} {'Mean':<8} {'STD':<8} {'RMSE':<8} {'Modified':<8} {'Type':<25}")
    print("-"*80)

    for level in levels:
        stats = statistics[level]
        print(f"{level:<8} {stats['mean_impact']:<8.3f} {stats['std_impact']:<8.3f} "
              f"{stats['rmse_impact']:<8.3f} {stats['cells_modified']:<8} "
              f"{stats['impact_type']:<25}")

    print("\nKey Interpretations:")
    print("- Mean: Systematic bias correction magnitude")
    print("- STD:  Spatial variability of corrections")
    print("- RMSE: Total impact magnitude (√(Mean² + STD²))")
    print("- Modified: Number of grid cells with non-zero changes")


def analyze_multi_level_impact_optimized(ds_original, ds_modified, levels=(500, 850, 925),
                                         parameter='temperature', figsize=(15, 4),
                                         units='K'):
    """
    Analyse and plot the observation impact (modified minus original) per level.

    All panels share one symmetric colour range, so the single colorbar drawn
    under the row is valid for every panel.

    Parameters
    ----------
    ds_original : xarray.Dataset
        Original ERA5 dataset.
    ds_modified : xarray.Dataset
        Modified ERA5+OBS dataset.
    levels : sequence of int
        Pressure levels to analyze (selected with ``.sel(level=...)``).
    parameter : str
        Variable to analyze.
    figsize : tuple
        Figure size for plots.
    units : str
        Unit label used in panel titles and the colorbar (default ``'K'``).

    Returns
    -------
    dict
        ``{'statistics': {level: stats_dict}, 'plots': matplotlib Figure}``.

    Raises
    ------
    ValueError
        If *levels* is empty.
    """
    levels = list(levels)
    if not levels:
        raise ValueError("levels must contain at least one pressure level")

    print(f"Analyzing {parameter} impact across levels: {levels}")

    results = {'statistics': {}, 'plots': None}

    # Pass 1: compute every impact field and its statistics, so that a
    # common colour range can be chosen before anything is drawn.
    per_level = []
    for level in levels:
        # First time step, first batch member at this pressure level.
        orig_data = ds_original[parameter].sel(level=level).isel(time=0, batch=0)
        mod_data = ds_modified[parameter].sel(level=level).isel(time=0, batch=0)
        impact = mod_data - orig_data
        stats = _compute_impact_stats(impact)
        results['statistics'][level] = stats
        per_level.append((level, impact, stats))

    # Shared symmetric range: the largest absolute impact over all levels.
    vmax = float(np.nanmax([max(abs(s['min_impact']), abs(s['max_impact']))
                            for _, _, s in per_level]))
    if not vmax > 0:
        # All-zero (or all-NaN) impact: use a unit range so the diverging
        # colour scale is not degenerate.
        vmax = 1.0

    # Pass 2: single-row figure. squeeze=False always yields a 2-D grid, so
    # axes is an indexable 1-D array even for a single level.
    fig, axes = plt.subplots(1, len(levels), figsize=figsize, squeeze=False,
                             subplot_kw={'projection': ccrs.PlateCarree()})
    axes = axes[0]

    for ax, (level, impact, stats) in zip(axes, per_level):
        im = impact.plot(ax=ax, transform=ccrs.PlateCarree(),
                         cmap='coolwarm', vmin=-vmax, vmax=vmax,
                         add_colorbar=False)

        ax.coastlines(resolution='50m', alpha=0.7, linewidth=0.5)
        ax.gridlines(draw_labels=False, alpha=0.3)

        ax.set_title(f'{level} hPa\n'
                     f'Mean: {stats["mean_impact"]:.3f}{units}\n'
                     f'STD: {stats["std_impact"]:.3f}{units}',
                     fontsize=10)

        # Remove axis labels for a cleaner look.
        ax.set_xlabel('')
        ax.set_ylabel('')

    param_title = parameter.replace('_', ' ').title()
    fig.suptitle(f'{param_title} Observation Impact Analysis\n'
                 f'ERA5+OBS - ERA5 Differences', fontsize=12, y=0.98)

    # One colorbar for all panels; valid because they share vmin/vmax.
    plt.tight_layout()
    cbar = plt.colorbar(im, ax=axes, orientation='horizontal',
                        pad=0.1, shrink=0.8, aspect=40)
    cbar.set_label(f'{param_title} Change ({units})', fontsize=10)

    results['plots'] = fig

    _print_impact_summary(results['statistics'], levels)

    plt.show()

    return results

def compare_rmse_vs_std_components(results, levels=(500, 850, 925),
                                   ylabel='Temperature Impact (K)'):
    """
    Plot the RMSE decomposition of per-level impact statistics.

    The left panel compares |mean|, STD and RMSE per level. The right panel
    plots the actual RMSE against sqrt(mean**2 + STD**2) with a y = x
    reference line.

    Parameters
    ----------
    results : dict
        Output of ``analyze_multi_level_impact_optimized``.
        ``results['statistics']`` must contain an entry for every level in
        *levels*.
    levels : sequence of int
        Pressure levels (hPa) to include, in display order.
    ylabel : str
        Y-axis label of the component bar chart.

    Raises
    ------
    ValueError
        If *levels* is empty.
    KeyError
        If a level is missing from ``results['statistics']``.
    """
    levels = list(levels)
    if not levels:
        raise ValueError("levels must contain at least one pressure level")

    # Gather statistics before creating the figure, so a missing level fails
    # without leaving an empty figure behind.
    level_stats = [results['statistics'][level] for level in levels]
    mean_values = [s['mean_impact'] for s in level_stats]
    std_values = [s['std_impact'] for s in level_stats]
    rmse_values = [s['rmse_impact'] for s in level_stats]
    # RMSE reconstructed from its components.
    calculated_rmse = [np.sqrt(mean**2 + std**2) for mean, std in zip(mean_values, std_values)]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Plot 1: grouped bars, one group per level.
    x = np.arange(len(levels))
    width = 0.25

    ax1.bar(x - width, np.abs(mean_values), width, label='|Mean| (Systematic)', alpha=0.8)
    ax1.bar(x, std_values, width, label='STD (Variability)', alpha=0.8)
    ax1.bar(x + width, rmse_values, width, label='RMSE (Total)', alpha=0.8)

    ax1.set_xlabel('Pressure Level (hPa)')
    ax1.set_ylabel(ylabel)
    ax1.set_title('RMSE Decomposition by Level')
    ax1.set_xticks(x)
    ax1.set_xticklabels(levels)
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Plot 2: actual vs reconstructed RMSE. The reference line spans both
    # series so no point falls outside it.
    ax2.scatter(rmse_values, calculated_rmse, s=100, alpha=0.7)
    upper = max(max(rmse_values), max(calculated_rmse))
    ax2.plot([0, upper], [0, upper], 'r--', alpha=0.7)

    ax2.set_xlabel('Actual RMSE')
    ax2.set_ylabel('Calculated RMSE (√(Mean² + STD²))')
    ax2.set_title('RMSE Verification')
    ax2.grid(True, alpha=0.3)

    for level, actual, calculated in zip(levels, rmse_values, calculated_rmse):
        ax2.annotate(f'{level} hPa', (actual, calculated),
                     xytext=(5, 5), textcoords='offset points')

    plt.tight_layout()
    plt.show()

# Usage with your existing workflow data
def run_optimized_analysis():
    """
    Run the impact analysis on the notebook's global dataset catalogues.

    Reads the module-level dicts ``ds_gis`` (original GraphCast inputs) and
    ``ds_gi_mods`` (modified GraphCast inputs), keyed ``"t_<level>_ds"``.
    For each level present in both, it runs a single-level analysis, which
    displays a figure and prints a summary. If a global ``aggregated_ds`` is
    defined, it also runs a multi-level analysis against it and shows the
    RMSE decomposition.

    Returns
    -------
    dict or None
        Results of the aggregated multi-level analysis. None when
        ``aggregated_ds`` is not defined or no baseline dataset is available.
    """
    levels_to_analyze = [500, 850, 925]

    for level in levels_to_analyze:
        ds_key = f"t_{level}_ds"
        if ds_key not in ds_gis or ds_key not in ds_gi_mods:
            continue

        print(f"\n{'='*60}")
        print(f"ANALYZING LEVEL {level} hPa")
        print('='*60)

        # Figures and summaries are produced as side effects; the returned
        # per-level statistics are not needed here.
        analyze_multi_level_impact_optimized(
            ds_original=ds_gis[ds_key],
            ds_modified=ds_gi_mods[ds_key],
            levels=[level],
            parameter='temperature',
            figsize=(6, 4),
        )

    aggregated = globals().get('aggregated_ds')
    if aggregated is None:
        return None

    # Any original GI dataset can serve as the baseline (per the original
    # author's note). Take the first one that exists rather than assuming the
    # first level is present.
    template_key = next((f"t_{lvl}_ds" for lvl in levels_to_analyze
                         if f"t_{lvl}_ds" in ds_gis), None)
    if template_key is None:
        print("No original GI dataset available as baseline for the aggregated analysis.")
        return None

    print(f"\n{'='*60}")
    print("MULTI-LEVEL AGGREGATED ANALYSIS")
    print('='*60)

    multi_results = analyze_multi_level_impact_optimized(
        ds_original=ds_gis[template_key],
        ds_modified=aggregated,
        levels=levels_to_analyze,
        parameter='temperature',
        figsize=(15, 4),
    )

    compare_rmse_vs_std_components(multi_results, levels_to_analyze)

    return multi_results

# Execute the optimized analysis
if __name__ == "__main__":
    results = run_optimized_analysis()

In [None]:
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import numpy as np
from pathlib import Path

def _select_prediction_slice(ds, variable_name, level_index):
    """Return *variable_name* at the first ``time`` step and ``batch`` member.

    ``level`` is indexed only when the variable has a ``level`` dimension.
    ``time``/``batch`` are indexed only when present, so files without those
    dimensions also work.
    """
    data = ds[variable_name]
    indexers = {dim: 0 for dim in ('time', 'batch') if dim in data.dims}
    if 'level' in data.dims:
        indexers['level'] = level_index
    return data.isel(indexers)


def _load_prediction_pair(gi_pred_path, gi_obs_pred_path, variable_name, level_index):
    """Read the plotted slice from both prediction files into memory.

    Both files are opened as context managers. They are therefore closed even
    if the second open or the variable lookup fails. The slices are
    ``.load()``-ed before closing so the returned arrays remain usable
    afterwards.
    """
    with xr.open_dataset(gi_pred_path) as ds_original, \
            xr.open_dataset(gi_obs_pred_path) as ds_modified:
        original = _select_prediction_slice(ds_original, variable_name, level_index).load()
        modified = _select_prediction_slice(ds_modified, variable_name, level_index).load()
    return original, modified


def plot_prediction_comparison_optimized(gi_pred_path, gi_obs_pred_path,
                                         level_of_interest, level_index,
                                         variable_name='temperature',
                                         figsize=(18, 5), diff_vlim=0.05):
    """
    Plot Original, Modified and Difference (Orig - Mod) predictions in one row.

    Parameters
    ----------
    gi_pred_path : str
        Path to the original GraphCast prediction.
    gi_obs_pred_path : str
        Path to the observation-modified GraphCast prediction.
    level_of_interest : int
        Pressure level of interest (hPa); used for labels.
    level_index : int
        Index of that pressure level along the ``level`` dimension.
    variable_name : str
        Variable to plot (default: ``'temperature'``).
    figsize : tuple
        Figure size (width, height).
    diff_vlim : float
        Symmetric colour limit (+/-) of the difference panel (default 0.05).

    Returns
    -------
    dict or None
        ``statistics``, ``figure`` and the in-memory ``temp_original`` /
        ``temp_modified`` / ``temp_difference`` arrays. None if either file
        does not exist.
    """
    print(f"Loading and comparing predictions at {level_of_interest} hPa...")

    try:
        temp_original, temp_modified = _load_prediction_pair(
            gi_pred_path, gi_obs_pred_path, variable_name, level_index)
    except FileNotFoundError as e:
        print(f"✗ Error loading files: {e}")
        return None
    print("✓ Datasets loaded successfully")

    temp_difference = temp_original - temp_modified

    stats = {
        'mean_diff': float(temp_difference.mean()),
        'std_diff': float(temp_difference.std()),
        'rmse_diff': float(np.sqrt((temp_difference**2).mean())),
        'max_diff': float(temp_difference.max()),
        'min_diff': float(temp_difference.min()),
        'original_mean': float(temp_original.mean()),
        'modified_mean': float(temp_modified.mean())
    }

    print(f"Statistics: Mean diff = {stats['mean_diff']:.4f}K, "
          f"STD diff = {stats['std_diff']:.4f}K")

    fig, axes = plt.subplots(1, 3, figsize=figsize,
                             subplot_kw={'projection': ccrs.PlateCarree()})

    plot_configs = [
        {
            'data': temp_original,
            'title': f'Original GraphCast\n{level_of_interest} hPa\nMean: {stats["original_mean"]:.2f}K',
            'cmap': 'RdYlBu_r',
            'vmin': None, 'vmax': None
        },
        {
            'data': temp_modified,
            'title': f'Modified GraphCast\n{level_of_interest} hPa\nMean: {stats["modified_mean"]:.2f}K',
            'cmap': 'RdYlBu_r',
            'vmin': None, 'vmax': None
        },
        {
            'data': temp_difference,
            'title': f'Difference (Orig - Mod)\n{level_of_interest} hPa\nMean: {stats["mean_diff"]:.4f}K',
            'cmap': 'coolwarm',
            'vmin': -diff_vlim, 'vmax': diff_vlim
        }
    ]

    images = []
    for i, (ax, config) in enumerate(zip(axes, plot_configs)):
        im = config['data'].plot(
            ax=ax,
            transform=ccrs.PlateCarree(),
            cmap=config['cmap'],
            vmin=config['vmin'],
            vmax=config['vmax'],
            add_colorbar=False
        )
        images.append(im)

        ax.coastlines(resolution='50m', alpha=0.7, linewidth=0.5)
        # Grid labels only on the first panel, to avoid clutter.
        ax.gridlines(draw_labels=(i == 0), alpha=0.3, x_inline=False, y_inline=False)
        ax.set_global()

        ax.set_title(config['title'], fontsize=11, pad=10)

        ax.set_xlabel('')
        ax.set_ylabel('')

    fig.suptitle(f'GraphCast Prediction Comparison - {variable_name.title()}\n'
                 f'Level: {level_of_interest} hPa | '
                 f'Difference Stats: Mean={stats["mean_diff"]:.3f}K, STD={stats["std_diff"]:.3f}K',
                 fontsize=12, y=0.95)

    plt.tight_layout()

    # Shared colorbar for the two absolute-value panels.
    temp_cbar = plt.colorbar(images[0], ax=axes[:2], orientation='horizontal',
                             pad=0.05, shrink=0.8, aspect=30)
    temp_cbar.set_label(f'{variable_name.title()} (K)', fontsize=10)

    # Separate colorbar for the difference panel (different scale).
    diff_cbar = plt.colorbar(images[2], ax=axes[2], orientation='horizontal',
                             pad=0.05, shrink=0.8, aspect=15)
    diff_cbar.set_label(f'{variable_name.title()} Difference (K)', fontsize=10)

    # File names as a footnote, for provenance.
    plt.figtext(0.02, 0.02,
                f'Original: {Path(gi_pred_path).name}\n'
                f'Modified: {Path(gi_obs_pred_path).name}',
                fontsize=8, color='gray', va='bottom')

    plt.show()

    return {
        'statistics': stats,
        'figure': fig,
        'temp_original': temp_original,
        'temp_modified': temp_modified,
        'temp_difference': temp_difference
    }

def _first_step_at_level(data, level_idx):
    """Index *data* at *level_idx* (if it has a ``level`` dim) and at the
    first ``time``/``batch`` entry (when those dims exist)."""
    indexers = {dim: 0 for dim in ('time', 'batch') if dim in data.dims}
    if 'level' in data.dims:
        indexers['level'] = level_idx
    return data.isel(indexers)


def plot_multi_level_predictions(prediction_files, level_configs,
                                 variable_name='temperature', figsize=(20, 12)):
    """
    Plot prediction comparisons for several levels in a (levels x 3) grid.

    Each row shows Original, Modified and Difference (Original - Modified).
    Rows whose level has no entry in *prediction_files* are hidden.

    Parameters
    ----------
    prediction_files : dict
        Must contain ``'original'`` (baseline prediction path). Pressure levels
        (hPa) map to the modified prediction path for that level.
    level_configs : list
        List of tuples ``[(level_hPa, level_index), ...]``.
    variable_name : str
        Variable to plot.
    figsize : tuple
        Figure size.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If *level_configs* is empty.
    """
    n_levels = len(level_configs)
    if n_levels == 0:
        raise ValueError("level_configs must contain at least one (level_hPa, level_index) pair")

    # squeeze=False keeps axes 2-D even for a single row.
    fig, axes = plt.subplots(n_levels, 3, figsize=figsize, squeeze=False,
                             subplot_kw={'projection': ccrs.PlateCarree()})

    cmaps = ('RdYlBu_r', 'RdYlBu_r', 'coolwarm')
    vlims = ((None, None), (None, None), (-0.05, 0.05))

    # The baseline prediction is the same for every row: open it once, and
    # let the context manager close it even if a later step raises.
    with xr.open_dataset(prediction_files['original']) as ds_orig:
        for row, (level_hPa, level_idx) in enumerate(level_configs):
            if level_hPa not in prediction_files:
                # No modified prediction for this level: hide the unused row
                # instead of leaving three empty map frames.
                for ax in axes[row]:
                    ax.set_visible(False)
                continue

            with xr.open_dataset(prediction_files[level_hPa]) as ds_mod:
                temp_orig = _first_step_at_level(ds_orig[variable_name], level_idx)
                temp_mod = _first_step_at_level(ds_mod[variable_name], level_idx)
                temp_diff = temp_orig - temp_mod

                plot_data = (temp_orig, temp_mod, temp_diff)
                titles = (f'Original\n{level_hPa} hPa',
                          f'Modified\n{level_hPa} hPa',
                          f'Difference\n{level_hPa} hPa')

                # Plot while the file is still open (data may be read lazily).
                for ax, data, title, cmap, (vmin, vmax) in zip(
                        axes[row], plot_data, titles, cmaps, vlims):
                    data.plot(ax=ax, transform=ccrs.PlateCarree(),
                              cmap=cmap, vmin=vmin, vmax=vmax,
                              add_colorbar=False)
                    ax.coastlines(resolution='50m', alpha=0.7, linewidth=0.5)
                    ax.gridlines(alpha=0.3)
                    ax.set_global()
                    ax.set_title(title, fontsize=10)

    plt.tight_layout()
    fig.suptitle('Multi-Level GraphCast Prediction Analysis', y=0.98, fontsize=14)
    plt.show()

    return fig

# Usage with your existing workflow
def run_optimized_prediction_plots():
    """
    Plot original-vs-modified GraphCast predictions for every entry in ``current_dd``.

    Reads the notebook global ``current_dd`` and the ``index_*`` column
    constants to get each entry's prediction paths and level.

    Returns
    -------
    dict or None
        The comparison result for the LAST entry. None when ``current_dd`` is
        empty or the last entry's files could not be loaded.
    """
    # Initialised so an empty current_dd returns None instead of raising
    # UnboundLocalError at the return statement.
    results = None

    for entry in current_dd:
        # NOTE(review): variable_name is fixed to 'temperature' regardless of
        # the entry's long_name — confirm this is intended for non-T entries.
        results = plot_prediction_comparison_optimized(
            gi_pred_path=entry[index_GI_PRED_file],
            gi_obs_pred_path=entry[index_GI_OBS_PRED_file],
            level_of_interest=entry[index_level_of_interest],
            level_index=entry[index_GI_level],
            variable_name='temperature',
            figsize=(18, 5),
        )

    return results
    
# Usage with your existing workflow
def _run_optimized_prediction_plots():
    """
    Compare predictions for the single entry currently bound in the notebook.

    Does not loop over ``current_dd``. It reads the globals ``gi_pred_path``,
    ``gi_obs_pred_path``, ``level_of_interest`` and ``level_index`` (presumably
    left bound by the last processed entry — confirm), then returns whatever
    ``plot_prediction_comparison_optimized`` returns.
    """
    comparison_kwargs = {
        'gi_pred_path': gi_pred_path,
        'gi_obs_pred_path': gi_obs_pred_path,
        'level_of_interest': level_of_interest,
        'level_index': level_index,
        'variable_name': 'temperature',
        'figsize': (18, 5),
    }
    return plot_prediction_comparison_optimized(**comparison_kwargs)

# # Execute the optimized plotting
# if __name__ == "__main__":
#     results = run_optimized_prediction_plots()
    
#     if results:
#         print("\n" + "="*50)
#         print("PREDICTION COMPARISON COMPLETE")
#         print("="*50)
#         print("Key Statistics:")
#         for key, value in results['statistics'].items():
#             print(f"  {key}: {value:.4f}")

In [None]:
    # Run the per-entry prediction comparison; if the last comparison
    # produced results, print its statistics.
    results = run_optimized_prediction_plots()

    if results:
        separator = "=" * 50
        print("\n" + separator)
        print("PREDICTION COMPARISON COMPLETE")
        print(separator)
        print("Key Statistics:")
        for key, value in results['statistics'].items():
            print(f"  {key}: {value:.4f}")