# Unified MDTF/GFDL/NCAR Analysis Notebook Template

More details on the development process:
[MDTF Planning Document](https://docs.google.com/document/d/1P8HqL8O5304qwR3ik9RmgFDwSWwlkPgOjnp39PIkLfY/edit?usp=sharing)

In [1]:
# Development mode: constantly refreshes module code
# (`%autoreload 2` reloads imported modules before each cell executes, so edits
# to helper modules such as esnb / cosp_lib are picked up without a kernel restart)
%load_ext autoreload
%autoreload 2

## Framework Code and Diagnostic Setup

In [2]:
from esnb import NotebookDiagnostic, RequestedVariable, CaseGroup2
from esnb.sites.gfdl import call_dmget

Matplotlib is building the font cache; this may take a moment.


In [3]:
%%time

# Define a mode (leave "prod" for now)
mode = "prod"

# Verbosity
verbose = True

cosp_on = True
# Give your diagnostic a name and a short description
diag_name = "clouds"
diag_desc = "cloud diganostics using model outputs and cosp simulator"

# Define what variables you would like to analyze. The first entry is the
# variable name and the second entry is the realm (post-processing dir).
#   (By default, monthly timeseries data will be loaded. TODO: add documentation
#    on how to select different frequencies, multiple realms to search, etc.)
variables = [
    RequestedVariable("IWP", "atmos"),
    RequestedVariable("LWP", "atmos"),
    RequestedVariable("low_cld_amt", "atmos"),
    RequestedVariable("high_cld_amt", "atmos"),
    RequestedVariable("mid_cld_amt", "atmos"),
    RequestedVariable("olr", "atmos"),
    RequestedVariable("olr_clr", "atmos"),
    RequestedVariable("lwdn_sfc", "atmos"),
    RequestedVariable("lwdn_sfc_clr", "atmos"),
    RequestedVariable("swup_toa", "atmos"),
    RequestedVariable("swup_toa_clr", "atmos"),
    RequestedVariable("swdn_sfc", "atmos"),
    RequestedVariable("swdn_sfc_clr", "atmos"),
    RequestedVariable("swup_sfc", "atmos"),
    RequestedVariable("swup_sfc_clr", "atmos"),
    #RequestedVariable("reff_modis", "atmos"),
]

if cosp_on:
    variables.extend([
        #RequestedVariable("tauctpmodis_1", ["atmos_modis", "atmos_cosp"]),
        #RequestedVariable("tauctpmodis_2", ["atmos_modis", "atmos_cosp"]),
        #RequestedVariable("tauctpmodis_3", ["atmos_modis", "atmos_cosp"]),
        #RequestedVariable("tauctpmodis_4", ["atmos_modis", "atmos_cosp"]),
        #RequestedVariable("tauctpmodis_5", ["atmos_modis", "atmos_cosp"]),
        #RequestedVariable("tauctpmodis_6", ["atmos_modis", "atmos_cosp"]),
        #RequestedVariable("tauctpmodis_7", ["atmos_modis", "atmos_cosp"]),
        RequestedVariable("ctpmodis", ["atmos_modis", "atmos_cosp"]),
        RequestedVariable("lwpmodis", ["atmos_modis", "atmos_cosp"]),
        RequestedVariable("iwpmodis", ["atmos_modis", "atmos_cosp"]),
        RequestedVariable("locldmodis", ["atmos_modis", "atmos_cosp"]),
        RequestedVariable("hicldmodis", ["atmos_modis", "atmos_cosp"]),
        RequestedVariable("mdcldmodis", ["atmos_modis", "atmos_cosp"]),
        RequestedVariable("lremodis", ["atmos_modis", "atmos_cosp"]),
        RequestedVariable("iremodis", ["atmos_modis", "atmos_cosp"]),
        RequestedVariable("ttaumodis", ["atmos_modis", "atmos_cosp"]),
        RequestedVariable("ltaumodis", ["atmos_modis", "atmos_cosp"]),
        RequestedVariable("itaumodis", ["atmos_modis", "atmos_cosp"]),
        RequestedVariable("tlogtaumodis", ["atmos_modis", "atmos_cosp"]),
        RequestedVariable("llogtaumodis", ["atmos_modis", "atmos_cosp"]),
        RequestedVariable("ilogtaumodis", ["atmos_modis", "atmos_cosp"]),
    ])


# Initialize the diagnostic with its name, description, vars, and options
diag = NotebookDiagnostic(diag_name, diag_desc, variables=variables)

# Define the groups of experiments to analyze. Provide a single dora id for one experiment
# or a list of IDs to aggregate multiple experiments into one; e.g. historical+future runs
# dora id: cm5-9
groups = [
    CaseGroup2("3302", name="am5d11_amip_cosp_sphere", date_range=("2001-01-01", "2020-12-31")),
    CaseGroup2("3296", name="am5d11_amip_cosp", date_range=("2001-01-01", "2020-12-31")),
    CaseGroup2("3101", name="am4mg2_cosp", date_range=("2001-01-01", "2020-12-31")),
    #CaseGroup2("2916",date_range=("0101-01-01", "0149-12-31")),
    #CaseGroup2("3031",date_range=("0101-01-01", "0149-12-31")),
    #CaseGroup2("2198",date_range=("0101-01-01", "0149-12-31")),
    #CaseGroup2("esm45-109",date_range=("0101-01-01", "0149-12-31")),
    #CaseGroup2("895",date_range=("0101-01-01",  "0149-12-31")),
]
shorttitle = ["AM5_D11_sphere","AM5_D11","AM4MG2"]#"OM4_D5", "B11_D6", "CM4X", "B11_ESM4.5", "CM4"]

ref_id = 2 # the experiment that used as the control run to be compared with

# Combine the experiments with the diag request and determine what files need to be loaded:
diag.resolve(groups)

CPU times: user 32.6 s, sys: 315 ms, total: 32.9 s
Wall time: 1min 12s


<i>(The files above are necessary to run the diagnostic.)</i>

In [4]:
# Report the dmget status of every resolved file before opening anything
# (this call prints the "dmget: ..." status line).
resolved_files = diag.files
call_dmget(resolved_files, status=True)

# Open the resolved files as xarray datasets, then summarise the case groups
diag.open()
print(groups)

dmget: All files are online
[CaseGroup am5d11_amip_cosp_sphere <>  resolved=True  loaded=True, CaseGroup am5d11_amip_cosp <>  resolved=True  loaded=True, CaseGroup am4mg2_cosp <>  resolved=True  loaded=True]


In [5]:
import matplotlib.pyplot as plt
import cartopy.crs 
import numpy as np
from numpy import ma

In [6]:
# Derive cloud radiative effects (all-sky minus clear-sky) for every
# "<name>_clr" variable whose all-sky partner "<name>" was also requested, and
# store them as "<name>_cld" in each group's all-sky dataset.
requested_by_name = {v.varname: v for v in variables}
for clr_var in variables:
    clr_name = clr_var.varname
    if not clr_name.endswith("_clr"):
        continue
    base_name = clr_name[:-len("_clr")]
    # Use the requested all-sky variable itself as the dataset key instead of
    # re-creating it with a hard-coded "atmos" realm.
    base_var = requested_by_name.get(base_name)
    if base_var is None:
        continue  # no all-sky partner was requested
    cld_name = f"{base_name}_cld"
    for group in groups:
        allsky_ds = group.datasets[base_var]
        allsky_ds[cld_name] = allsky_ds[base_name] - group.datasets[clr_var][clr_name]

# Variables drawn by the plotting cells below
PLOT_ATMOS_NAMES = [
    "IWP", "LWP", "low_cld_amt", "high_cld_amt", "mid_cld_amt",
    "olr", "lwdn_sfc", "swup_toa", "swdn_sfc",  # "swup_sfc",
]
PLOT_COSP_NAMES = [
    "ctpmodis", "lwpmodis", "iwpmodis",
    "locldmodis", "hicldmodis", "mdcldmodis",
    "lremodis", "iremodis",
    "ttaumodis", "ltaumodis", "itaumodis",
    "tlogtaumodis", "llogtaumodis", "ilogtaumodis",
]
variables_plot = [RequestedVariable(name, "atmos") for name in PLOT_ATMOS_NAMES]
if cosp_on:
    variables_plot.extend(
        RequestedVariable(name, ["atmos_modis", "atmos_cosp"]) for name in PLOT_COSP_NAMES
    )

# Radiative fluxes: the plotting cells show their "<name>_cld" field instead of the raw flux
var_rad = [RequestedVariable(name, "atmos")
           for name in ("olr", "lwdn_sfc", "swup_toa", "swdn_sfc")]  # "swup_sfc"

# (short name, long name) pairs, kept together so `vars` and `longnames`
# cannot drift out of alignment.
var_longnames = [
    ("LWP", "Liquid Water Path"),
    ("IWP", "Ice Water Path"),
    ("low_cld_amt", "Low-Level Cloud Fraction"),
    ("high_cld_amt", "High-Level Cloud Fraction"),
    ("mid_cld_amt", "Mid-Level Cloud Fraction"),
]
# Add MODIS-related variables only if cosp_on
if cosp_on:
    var_longnames.extend([
        ("locldmodis", "COSPMODIS Low-Level Cloud Fraction"),
        ("hicldmodis", "COSPMODIS High-Level Cloud Fraction"),
        ("mdcldmodis", "COSPMODIS Mid-Level Cloud Fraction"),
        ("cldfrac", "COSPMODIS Total Cloud Fraction"),  # derived in the observation-regrid cell
        ("ctpmodis", "COSPMODIS Cloud Top Pressure"),
        # NOTE(review): reff_modis is commented out of the request in the setup cell
        ("reff_modis", "MODIS Effective Radius"),
        ("lremodis", "COSPMODIS Liquid Effective Radius"),
        ("iremodis", "COSPMODIS Ice Effective Radius"),
        ("ttaumodis", "COSPMODIS Total Cloud Optical Depth"),
        ("ltaumodis", "COSPMODIS Liquid Cloud Optical Depth"),
        ("itaumodis", "COSPMODIS Ice Cloud Optical Depth"),
        ("tlogtaumodis", "COSPMODIS Log of Total Cloud Optical Depth"),
        ("llogtaumodis", "COSPMODIS Log of Liquid Cloud Optical Depth"),
        ("ilogtaumodis", "COSPMODIS Log of Ice Cloud Optical Depth"),
    ])

# NOTE(review): `vars` shadows the Python builtin; the name is kept because the
# observation-regrid cells iterate over it.
vars = [pair[0] for pair in var_longnames]
longnames = [pair[1] for pair in var_longnames]

In [7]:
import xarray as xr
import xesmf as xe

# CERES EBAF Ed4.2.1 observations, regridded onto the control run's grid
# (groups[ref_id]) -- the grid the CERES difference maps interpolate to.
CERES_PATH = "/work/rjk/data/CERES/EBAF-All/CERES_EBAF_Ed4.2.1_Subset_200003-202412.nc"
# Restrict the observations to the model analysis period (date_range in the
# setup cell) so model and observed climatologies cover the same years.
CERES_PERIOD = ("2001-01-01", "2020-12-31")

ceres_raw = xr.open_dataset(CERES_PATH).sel(time=slice(*CERES_PERIOD))
model_grid_ds = groups[ref_id].datasets[diag.variables[0]]
ds_ceres = xr.Dataset({
    'lat': (['lat'], model_grid_ds['lat'].values),
    'lon': (['lon'], model_grid_ds['lon'].values),
})
regridder = xe.Regridder(ceres_raw, ds_ceres, method='bilinear', periodic=True)

# Cloud radiative effects in the model's sign convention (all-sky minus clear-sky).
# The TOA CRE fields are negated; presumably EBAF defines them as clear-sky
# minus all-sky for upward fluxes -- confirm against the EBAF documentation.
ds_ceres["olr_cld"] = -regridder(ceres_raw["toa_cre_lw_mon"])
ds_ceres["swup_toa_cld"] = -regridder(ceres_raw["toa_cre_sw_mon"])
# NOTE(review): sfc_cre_net_lw is a *net* LW CRE, while the model lwdn_sfc_cld is
# downward-only; TODO confirm, or use sfc_lw_down_all_mon - sfc_lw_down_clr_t_mon
# if the subset file provides them.
ds_ceres["lwdn_sfc_cld"] = regridder(ceres_raw["sfc_cre_net_lw_mon"])
ds_ceres["swdn_sfc_cld"] = (regridder(ceres_raw["sfc_sw_down_all_mon"])
                            - regridder(ceres_raw["sfc_sw_down_clr_t_mon"]))


In [10]:
import math

from cosp_lib import cal_gbl_mean, monthly_mean

# global-mean seasonal cycle: one panel per plotted variable, one line per
# experiment (labels come from `shorttitle` in the setup cell).
panel_colors = ["orange", "green", "grey", "blue", "black"]

n_vars = len(variables_plot)
n_cols = math.ceil(math.sqrt(n_vars))
n_rows = math.ceil(n_vars / n_cols)
months = np.arange(1, 13)

# squeeze=False keeps `axs` a 2-D array even for a 1x1 grid
fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, 12), squeeze=False)
for ax, var in zip(axs.flat, variables_plot):
    # Radiative fluxes are shown as their cloud radiative effect ("<name>_cld")
    varname = f"{var.varname}_cld" if var in var_rad else var.varname

    for n, group in enumerate(groups):
        ds = group.datasets[var]
        monthly_clim = monthly_mean(ds, varname)
        ax.plot(months, monthly_clim.values, label=shorttitle[n], color=panel_colors[n])
        std_dev = monthly_clim.std(dim='month').values
        # NOTE(review): for var_rad entries this is the std of "<name>_cld" but it
        # is recorded under the raw flux name -- confirm that is intended.
        group.add_metric(var.varname, ("seasonal std", float(std_dev)))

    # Units of the raw variable in the last group's dataset
    units = ds[var.varname].units
    ax.set_title(f"{varname} [{units}]", fontsize=16)
    ax.set_xlabel("Month")
    ax.set_ylabel(f"{varname} [{units}]")
    ax.legend()
    ax.grid(True)

# Hide grid cells that have no variable assigned
for ax in axs.flat[n_vars:]:
    ax.set_visible(False)

fig.tight_layout()
# Save before showing: with the inline backend plt.show() closes the figure,
# so a savefig afterwards writes an empty canvas.
fig.savefig("seasonal.png")
plt.show()



  plt.show()


In [None]:
# zonal-mean climatology: one panel per plotted variable, one line per experiment

from cosp_lib import zonal_mean

# Grid shape computed here so the cell does not depend on the seasonal-cycle cell
n_vars = len(variables_plot)
n_cols = int(np.ceil(np.sqrt(n_vars)))
n_rows = int(np.ceil(n_vars / n_cols))

fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, 12), squeeze=False)
for ax, var in zip(axs.flat, variables_plot):
    # Radiative fluxes are shown as their cloud radiative effect ("<name>_cld")
    varname = f"{var.varname}_cld" if var in var_rad else var.varname

    for n, group in enumerate(groups):
        ds = group.datasets[var]
        zonal_clim = zonal_mean(ds, varname)
        ax.plot(zonal_clim.values, ds['lat'], label=shorttitle[n], color=panel_colors[n])
        std_dev = zonal_clim.std(dim='lat').values
        # NOTE(review): for var_rad entries the "_cld" std is stored under the raw flux name
        group.add_metric(var.varname, ("zonal std", float(std_dev)))

    units = ds[var.varname].units
    ax.set_title(f"{varname} [{units}]", fontsize=16)
    ax.set_ylabel("Latitude")
    ax.set_xlabel(f"{varname} [{units}]")
    ax.legend()
    ax.grid(True)

# Hide grid cells that have no variable assigned
for ax in axs.flat[n_vars:]:
    ax.set_visible(False)

fig.tight_layout()
# Save before showing: plt.show() on the inline backend closes the figure
fig.savefig("zonal.pdf")
plt.show()



In [None]:
# global-mean time series (12-month centred running mean), one figure per variable

for var in variables_plot:
    # Radiative fluxes are shown as their cloud radiative effect ("<name>_cld")
    varname = f"{var.varname}_cld" if var in var_rad else var.varname
    fig, ax = plt.subplots(figsize=(8, 5))
    fig.subplots_adjust(top=0.95, bottom=0.1, left=0.1, right=0.95)

    for n, group in enumerate(groups):
        ds = group.datasets[var]
        gbl_mean_time = cal_gbl_mean(ds, varname)
        # min_periods=1: partial windows at either end give a mean instead of NaN
        rolled = gbl_mean_time.rolling(time=12, center=True, min_periods=1).mean()
        ax.plot(ds['time'], rolled.values, label=shorttitle[n], color=panel_colors[n])

    units = ds[var.varname].units
    ax.set_title(f"Global Mean Time Series, {varname}, {units}", fontsize=16)
    ax.set_xlabel("Time")
    ax.set_ylabel(f"{varname} [{units}]")
    ax.legend()
    ax.grid(True)
    # Save before showing: plt.show() on the inline backend closes the figure
    fig.savefig(f"{varname}.ts.png")
    plt.show()
    plt.close(fig)



In [None]:
import matplotlib.pyplot as plt
from matplotlib.path import Path
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import numpy as np
import warnings

# spatial map differences to a control-run (global, arctic, antarctic)
#
# Every experiment's time-mean field is interpolated onto the control run's grid
# (groups[ref_id]) and differenced against it. Each panel is annotated with the
# area-weighted bias and RMSE over the region shown.

warnings.filterwarnings("ignore", category=UserWarning, module="cartopy.io")

views = {
    "global": {
        "proj": ccrs.Robinson(central_longitude=-180),
        "extent": None
    },
    "arctic": {
        "proj": ccrs.NorthPolarStereo(),
        "extent": [-180, 180, 60, 90]
    },
    "antarctic": {
        "proj": ccrs.SouthPolarStereo(),
        "extent": [-180, 180, -90, -60]
    }
}

coastline = cfeature.NaturalEarthFeature(
    category='physical', name='coastline',
    scale='110m', facecolor='none')

# Circular boundary (axes coordinates) for the polar-stereographic panels
theta = np.linspace(0, 2 * np.pi, 100)
verts = np.vstack([np.sin(theta), np.cos(theta)]).T * 0.5 + np.array([0.5, 0.5])
circle_path = Path(verts)

# Experiments compared against the control; their index in `groups` is kept so
# titles use the matching `shorttitle` entry wherever the control sits.
compare_ids = [i for i in range(len(groups)) if i != ref_id]
n_cols = 2
n_rows = max(2, int(np.ceil(len(compare_ids) / n_cols)))

for var in variables_plot:
    # Radiative fluxes are compared via their cloud radiative effect ("<name>_cld")
    varname = f"{var.varname}_cld" if var in var_rad else var.varname

    # Reference data and grid
    ref_ds = groups[ref_id].datasets[var]
    ref_data = ref_ds[varname].mean(dim='time')
    lon_ref = ref_ds['lon']
    lat_ref = ref_ds['lat']
    units = ref_ds[var.varname].units

    # cos(lat) area weights on the reference grid, where every difference lives
    weights_2d = xr.broadcast(np.cos(np.deg2rad(lat_ref)), lon_ref)[0]
    # 2-D coordinates for the regional masks; longitudes wrapped to [-180, 180)
    # so extents given in that range also select the right points on a 0-360 grid.
    lat_grid, lon_grid = xr.broadcast(lat_ref, lon_ref)
    lon_grid = ((lon_grid + 180) % 360) - 180

    # Precompute all differences on the reference grid once
    all_diffs = []
    for i in compare_ids:
        group = groups[i]
        model_data = group.datasets[var][varname].mean(dim='time')
        try:
            # Regrid to reference grid
            model_interp = model_data.interp(lon=lon_ref, lat=lat_ref)
            all_diffs.append(model_interp - ref_data)
        except Exception as e:
            print(f"[Error] Interpolation failed for {group.name}: {e}")
            all_diffs.append(None)

    valid_diffs = [d for d in all_diffs if d is not None and np.any(np.isfinite(d.values))]
    if not valid_diffs:
        print(f"[Warning] No valid data for {var.varname}. Skipping plots.")
        continue

    # Colour limits from the 5th/95th percentiles of all differences. These are
    # nan-aware because missing points (e.g. outside the interpolation range) are
    # NaN, and np.percentile would then return NaN limits.
    all_values = np.concatenate([d.values.ravel() for d in valid_diffs])
    p05, p95 = np.nanpercentile(all_values, [5, 95])
    if p95 * p05 < 0:
        # Differences of both signs: symmetric limits centred on zero
        vmax = np.nanpercentile(np.abs(all_values), 95)
        vmin = -vmax
    else:
        # One-signed differences: keep zero inside the colour range
        vmax = max(0, p95)
        vmin = min(0, p05)

    cmap = 'RdBu_r'

    for view_name, view_cfg in views.items():
        extent = view_cfg['extent']
        is_polar = isinstance(view_cfg['proj'], (ccrs.NorthPolarStereo, ccrs.SouthPolarStereo))

        if extent:
            lon_min, lon_max, lat_min, lat_max = extent
            spatial_mask = ((lat_grid >= lat_min) & (lat_grid <= lat_max) &
                            (lon_grid >= lon_min) & (lon_grid <= lon_max))
            # xarray's weighted() rejects NaN weights, so masked-out points get weight 0
            weights_masked = weights_2d.where(spatial_mask, 0.0)
        else:
            spatial_mask = None
            weights_masked = weights_2d

        fig, axs = plt.subplots(n_rows, n_cols, figsize=(14, 4 * n_rows), squeeze=False,
                                subplot_kw={'projection': view_cfg['proj']})
        fig.subplots_adjust(wspace=0.05, hspace=0.15, top=0.9, bottom=0.1)

        mesh = None
        for n, (group_idx, diff) in enumerate(zip(compare_ids, all_diffs)):
            ax = axs.flat[n]
            if diff is None or not np.any(np.isfinite(diff.values)):
                ax.set_visible(False)
                continue

            mesh = ax.pcolormesh(lon_ref, lat_ref, diff,
                                 vmin=vmin, vmax=vmax,
                                 transform=ccrs.PlateCarree(),
                                 cmap=cmap, shading="nearest")
            ax.set_title(f"({chr(97 + n)}) {shorttitle[group_idx]} – {shorttitle[ref_id]}", fontsize=11)

            ax.coastlines()
            if extent:
                ax.set_extent(extent, crs=ccrs.PlateCarree())
            else:
                ax.set_global()
            if is_polar:
                ax.set_boundary(circle_path, transform=ax.transAxes)

            diff_masked = diff if spatial_mask is None else diff.where(spatial_mask)
            bias = diff_masked.weighted(weights_masked).mean(dim=['lat', 'lon'], skipna=True)
            rmse = ((diff_masked**2).weighted(weights_masked).mean(dim=['lat', 'lon'], skipna=True))**0.5

            ax.text(0.95, 0.05, f"Bias = {bias.values:.2f}\nRMSE={rmse.values:.2f}",
                    transform=ax.transAxes, fontsize=9, va="bottom", ha="right",
                    bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.7))

        # Hide grid cells that have no experiment assigned
        for ax in axs.flat[len(all_diffs):]:
            ax.set_visible(False)

        if mesh is not None:
            # Shared colorbar at bottom
            cbar_ax = fig.add_axes([0.3, 0.06, 0.4, 0.02])
            cbar = fig.colorbar(mesh, cax=cbar_ax, orientation='horizontal')
            cbar.ax.tick_params(labelsize=9)

        fig.suptitle(f"{varname} [{units}]:Differences Relative to {shorttitle[ref_id]}, {view_name.capitalize()} View", fontsize=14)
        # Save before showing: plt.show() on the inline backend closes the figure
        fig.savefig(f"{varname}.{view_name}.png")
        plt.show()
        plt.close(fig)

In [None]:
import matplotlib.pyplot as plt
from matplotlib.path import Path
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import numpy as np
import warnings

# spatial map (global, arctic, antarctic) of model minus CERES EBAF
#
# Only the radiative variables in `var_rad` are compared, as their cloud
# radiative effect ("<name>_cld"), against `ds_ceres` from the CERES regrid cell.

warnings.filterwarnings("ignore", category=UserWarning, module="cartopy.io")

views = {
    "global": {
        "proj": ccrs.Robinson(central_longitude=-180),
        "extent": None
    },
    "arctic": {
        "proj": ccrs.NorthPolarStereo(),
        "extent": [-180, 180, 60, 90]
    },
    "antarctic": {
        "proj": ccrs.SouthPolarStereo(),
        "extent": [-180, 180, -90, -60]
    }
}

coastline = cfeature.NaturalEarthFeature(
    category='physical', name='coastline',
    scale='110m', facecolor='none')

# Circular boundary (axes coordinates) for the polar-stereographic panels
theta = np.linspace(0, 2 * np.pi, 100)
verts = np.vstack([np.sin(theta), np.cos(theta)]).T * 0.5 + np.array([0.5, 0.5])
circle_path = Path(verts)

n_cols = 3
n_rows = max(2, int(np.ceil(len(groups) / n_cols)))

for var in variables_plot:
    if var not in var_rad:
        continue
    varname = f"{var.varname}_cld"

    # Reference grid is the control run's; reference data are the observations
    ref_ds = groups[ref_id].datasets[var]
    ref_data = ds_ceres[varname].mean(dim='time')
    lon_ref = ref_ds['lon']
    lat_ref = ref_ds['lat']
    units = ref_ds[var.varname].units

    # cos(lat) area weights on the reference grid, where every difference lives
    weights_2d = xr.broadcast(np.cos(np.deg2rad(lat_ref)), lon_ref)[0]
    # 2-D coordinates for the regional masks; longitudes wrapped to [-180, 180)
    # so extents given in that range also select the right points on a 0-360 grid.
    lat_grid, lon_grid = xr.broadcast(lat_ref, lon_ref)
    lon_grid = ((lon_grid + 180) % 360) - 180

    # Precompute all differences on the reference grid once
    all_diffs = []
    for group in groups:
        model_data = group.datasets[var][varname].mean(dim='time')
        try:
            # Regrid to reference grid
            model_interp = model_data.interp(lon=lon_ref, lat=lat_ref)
            all_diffs.append(model_interp - ref_data)
        except Exception as e:
            print(f"[Error] Interpolation failed for {group.name}: {e}")
            all_diffs.append(None)

    valid_diffs = [d for d in all_diffs if d is not None and np.any(np.isfinite(d.values))]
    if not valid_diffs:
        print(f"[Warning] No valid data for {var.varname}. Skipping plots.")
        continue

    # nan-aware percentiles so missing points do not turn the limits into NaN
    all_values = np.concatenate([d.values.ravel() for d in valid_diffs])
    p05, p95 = np.nanpercentile(all_values, [5, 95])
    if p95 * p05 < 0:
        # Differences of both signs: symmetric limits centred on zero
        vmax = np.nanpercentile(np.abs(all_values), 95)
        vmin = -vmax
    else:
        # One-signed differences: keep zero inside the colour range
        vmax = max(0, p95)
        vmin = min(0, p05)

    cmap = 'RdBu_r'

    for view_name, view_cfg in views.items():
        extent = view_cfg['extent']
        is_polar = isinstance(view_cfg['proj'], (ccrs.NorthPolarStereo, ccrs.SouthPolarStereo))

        if extent:
            lon_min, lon_max, lat_min, lat_max = extent
            spatial_mask = ((lat_grid >= lat_min) & (lat_grid <= lat_max) &
                            (lon_grid >= lon_min) & (lon_grid <= lon_max))
            # xarray's weighted() rejects NaN weights, so masked-out points get weight 0
            weights_masked = weights_2d.where(spatial_mask, 0.0)
        else:
            spatial_mask = None
            weights_masked = weights_2d

        fig, axs = plt.subplots(n_rows, n_cols, figsize=(14, 4 * n_rows), squeeze=False,
                                subplot_kw={'projection': view_cfg['proj']})
        fig.subplots_adjust(wspace=0.05, hspace=0.15, top=0.9, bottom=0.1)

        mesh = None
        for n, diff in enumerate(all_diffs):
            ax = axs.flat[n]
            if diff is None or not np.any(np.isfinite(diff.values)):
                ax.set_visible(False)
                continue

            mesh = ax.pcolormesh(lon_ref, lat_ref, diff,
                                 vmin=vmin, vmax=vmax,
                                 transform=ccrs.PlateCarree(),
                                 cmap=cmap, shading="nearest")
            ax.set_title(f"({chr(97 + n)}) {shorttitle[n]} – CERES EBAF", fontsize=11)

            ax.coastlines()
            if extent:
                ax.set_extent(extent, crs=ccrs.PlateCarree())
            else:
                ax.set_global()
            if is_polar:
                ax.set_boundary(circle_path, transform=ax.transAxes)

            diff_masked = diff if spatial_mask is None else diff.where(spatial_mask)
            bias = diff_masked.weighted(weights_masked).mean(dim=['lat', 'lon'], skipna=True)
            rmse = ((diff_masked**2).weighted(weights_masked).mean(dim=['lat', 'lon'], skipna=True))**0.5

            ax.text(0.95, 0.05, f"Bias = {bias.values:.2f}\nRMSE={rmse.values:.2f}",
                    transform=ax.transAxes, fontsize=9, va="bottom", ha="right",
                    bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.7))

        # Hide grid cells that have no experiment assigned
        for ax in axs.flat[len(all_diffs):]:
            ax.set_visible(False)

        if mesh is not None:
            # Shared colorbar at bottom (manually placed, so tight_layout is not used)
            cbar_ax = fig.add_axes([0.3, 0.06, 0.4, 0.02])
            cbar = fig.colorbar(mesh, cax=cbar_ax, orientation='horizontal')
            cbar.ax.tick_params(labelsize=9)

        fig.suptitle(f"{varname} [{units}]:Differences Relative to CERES EBAF v4.2.1, {view_name.capitalize()} View", fontsize=14)
        # Save before showing: plt.show() on the inline backend closes the figure
        fig.savefig(f"{varname}.{view_name}.obs.png")
        plt.show()
        plt.close(fig)


# Zonal mean (mean over time, then longitude) of every requested variable,
# one figure per variable and one line per experiment.
for var in variables:
    fig, ax = plt.subplots(figsize=(8, 5))
    fig.subplots_adjust(top=0.95, bottom=0.1, left=0.1, right=0.95)

    for group in groups:
        ds = group.datasets[var]
        zonal = ds[var.varname].mean(dim='time').mean(dim='lon')
        # Fall back to the group name when the dataset has no "title" attribute
        ax.plot(ds['lat'], zonal, label=ds.attrs.get("title", group.name))

    units = ds[var.varname].units
    ax.set_title(f"Zonal Mean, {var.varname} [{units}]", fontsize=16)
    ax.set_xlabel("Latitude")
    ax.set_ylabel(f"{var.varname} [{units}]")
    ax.legend()
    ax.grid(True)
    plt.show()
    plt.close(fig)



In [None]:
# Read the MODIS (MCD06COSP) observational file and list its data variables.
# NOTE(review): the observation-regrid cells below use `MODIS_tau_grp`, which is
# never defined in this notebook -- TODO open the joint-histogram group explicitly.
root = '/archive/Huan.Guo/work/obs/CFMIP-OBS/MODIS/MCD06COSP_D3_MODIS/'
#fname = f'{root}/MCD06COSP_D3_MODIS_2008_Ann.nc'
fname = f'{root}/MCD06COSP_D3_MODIS_2002To2023_Ann.nc'

MODIS_grps = xr.open_dataset(fname)
# Iterating an xarray Dataset yields its data-variable names
for data_var_name in MODIS_grps.data_vars:
    print(data_var_name)

In [None]:
import xarray as xr

# Open the CERES EBAF file (the same path as in the CERES regrid cell) as `ds`.
ceres_file = "/work/rjk/data/CERES/EBAF-All/CERES_EBAF_Ed4.2.1_Subset_200003-202412.nc"
ds = xr.open_dataset(ceres_file)



In [None]:
# observation data regrid
#
# Regrids the MODIS total cloud fraction onto each group's grid. Derives the
# model total cloud fraction from the tau-CTP histogram variables. Converts
# m -> micron and Pa -> hPa for the loaded variables listed in `vars`.

import xarray as xr
from cosp_lib import cal_gbl_mean
import xesmf as xe

# NOTE(review): `MODIS_tau_grp` is not defined anywhere in this notebook (the
# MODIS cell opens the file as `MODIS_grps`). Presumably it should be the netCDF
# group holding the joint histogram -- TODO open it explicitly.
if "MODIS_tau_grp" not in globals():
    raise NameError(
        "MODIS_tau_grp is not defined: open the MODIS joint-histogram group "
        "before running this cell"
    )

# load observations, to be added
cldfrac_obs_raw = (
    MODIS_tau_grp['JHisto_vs_Cloud_Top_Pressure']
    .sum(dim=["jhisto_cloud_optical_thickness_total_7", "jhisto_cloud_top_pressure_7"])
    .squeeze()
)

for group in diag.groups:
    grid_ds = group.datasets[diag.variables[0]]
    target_grid = xr.Dataset({
        'lat': (['lat'], grid_ds['lat'].values),
        'lon': (['lon'], grid_ds['lon'].values),
    })
    regridder = xe.Regridder(cldfrac_obs_raw, target_grid, method='bilinear', periodic=True)

    # group.datasets is keyed by RequestedVariable objects (as in the plotting
    # cells), not by plain strings, so map variable names to their keys first.
    # String keys added by an earlier run of this cell are skipped.
    key_by_name = {rv.varname: rv for rv in group.datasets.keys() if hasattr(rv, "varname")}
    var_names = list(key_by_name)
    print(var_names)

    # Total cloud fraction = sum of the tau-CTP histogram variables
    tau_names = [n for n in var_names if n.startswith('tauctpmodis_')]
    if tau_names:
        cldfrac_sum = sum(group.datasets[key_by_name[n]][n] for n in tau_names)
        cldfrac = cldfrac_sum.sum(dim='modistauindx').squeeze()
        cldfrac.attrs['units'] = '%'
        group.datasets['cldfrac'] = cldfrac
    else:
        # sum([]) would be the int 0, which has no .sum(dim=...)
        print(f"[Warning] no tauctpmodis_* variables loaded for {group.name}; skipping 'cldfrac'")

    group.datasets['cldfrac_obs'] = regridder(cldfrac_obs_raw)

    # Unit conversion (guarded by the units attribute, so re-running is a no-op)
    for name in vars:
        key = key_by_name.get(name)
        if key is None:
            continue  # not a loaded variable (e.g. 'cldfrac', 'reff_modis')
        field = group.datasets[key][name]
        units = field.attrs.get("units", "")
        if units == "m":
            converted, new_units = field * 1e6, "micron"
        elif units == "Pa":
            converted, new_units = field / 100.0, "hPa"
        else:
            continue
        # Arithmetic drops attrs by default; copy them back with the new units
        converted.attrs = {**field.attrs, "units": new_units}
        group.datasets[key][name] = converted

In [None]:
# observation data regrid (duplicate)
#
# This cell was an exact copy of the previous "observation data regrid" cell.
# Re-running it only rebuilt the same xESMF regridders and re-assigned the same
# fields; the unit conversions there are guarded by the "units" attribute. It is
# therefore intentionally empty -- run the previous cell instead.

In [None]:
#To be added: comparisons with observations

In [None]:
# Write out the diagnostic's metrics; presumably these include the values
# recorded with group.add_metric(...) in the plotting cells -- confirm in esnb.
diag.write_metrics()