In [2]:
import xarray as xr
import numpy as np
import pandas as pd
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
import cartopy.feature as cfeature
import xskillscore as xs
import ipywidgets as widgets

In [3]:
import plotly.graph_objects as go

prc_m1 = xr.open_zarr(
    "gs://cmip6/CMIP6/ScenarioMIP/AS-RCEC/TaiESM1/ssp245/r1i1p1f1/Amon/prc/gn/v20201124",
    chunks={},
    consolidated=True,
)
prc_m2 = xr.open_zarr(
    "gs://cmip6/CMIP6/ScenarioMIP/IPSL/IPSL-CM6A-LR/ssp245/r1i1p1f1/Amon/prc/gr/v20190119",
    chunks={},
    consolidated=True,
)

# Convert calendars to standard
prc_m1 = prc_m1.convert_calendar("standard", use_cftime=False)
prc_m2 = prc_m2.convert_calendar("standard", use_cftime=False)

# Extract time ranges for both datasets
time1 = prc_m1.time
time2 = prc_m2.time

# Find overlapping time range
start_date = max(str(time1.min().values), str(time2.min().values))
end_date = min(str(time1.max().values), str(time2.max().values))

print(f"Overlapping time period: {start_date} to {end_date}")

# Select data within overlapping time period
prc1 = prc_m1["prc"].sel(time=slice(start_date, end_date))
prc2 = prc_m2["prc"].sel(time=slice(start_date, end_date))

# Find overlapping lat/lon range
lat_min = max(prc1.lat.min().item(), prc2.lat.min().item())
lat_max = min(prc1.lat.max().item(), prc2.lat.max().item())
lon_min = max(prc1.lon.min().item(), prc2.lon.min().item())
lon_max = min(prc1.lon.max().item(), prc2.lon.max().item())

print(f"Overlapping latitude range: {lat_min} to {lat_max}")
print(f"Overlapping longitude range: {lon_min} to {lon_max}")

# Subset data within overlapping spatial range
prc1 = prc1.sel(lat=slice(lat_min, lat_max), lon=slice(lon_min, lon_max))
prc2 = prc2.sel(lat=slice(lat_min, lat_max), lon=slice(lon_min, lon_max))

# Interpolate prc2 to prc1 grid
prc2_interp = prc2.interp(lat=prc1.lat, lon=prc1.lon)

# Convert units from kg/m²/s to mm/day
prc1_mm_day = prc1 * 86400
prc2_mm_day = prc2_interp * 86400

# Compute squared error and RMSE per calendar month
diff_sq = (prc1_mm_day - prc2_mm_day) ** 2
mse_monthly = diff_sq.groupby("time.month").mean(dim=("time", "lat", "lon"))
rmse_monthly = np.sqrt(mse_monthly)


months = rmse_monthly["month"].values
rmse_values = rmse_monthly.values

fig = go.Figure()

fig.add_trace(go.Scatter(
    x=months,
    y=rmse_values,
    mode='lines+markers',
    name='Monthly RMSE',
    line=dict(color='blue'),
    marker=dict(size=8)
))

fig.update_layout(
    title="Monthly RMSE of Precipitation (TaiESM1 vs IPSL-CM6A-LR)",
    xaxis_title="Month",
    yaxis_title="RMSE (mm/day)",
    xaxis=dict(tickmode='array', tickvals=list(range(1, 13)), ticktext=[str(m) for m in range(1, 13)]),
    template='plotly_white',
    width=800,
    height=500
)

fig.show()

Overlapping time period: 2015-01-16T12:00:00.000000000 to 2100-12-16T12:00:00.000000000
Overlapping latitude range: -90.0 to 90.0
Overlapping longitude range: 0.0 to 357.5


In [4]:
ssp245 = pd.read_csv("ssp245.csv")

In [5]:
ssp245_query = ssp245[
    (ssp245["grid_label"] == "gn") &
    (ssp245["member_id"] == "r1i1p1f1") &
    (ssp245["variable_id"] == "tos")
]

In [6]:
from itertools import combinations

ssp245 = pd.read_csv("ssp245.csv")
prc_df = ssp245[ssp245["variable_id"] == "prc"]

# Collect model names and their zarr URLs (prefer gn grid if available)
model_zarrs = {}
for _, row in prc_df.sort_values("grid_label").iterrows():
    if row["source_id"] not in model_zarrs:
        model_zarrs[row["source_id"]] = row["zstore"]

# Get all unique pairs and select only the first 5
pairs = list(combinations(model_zarrs.items(), 2))[:5]

fig = go.Figure()

for (name1, url1), (name2, url2) in pairs:
    try:
        ds1 = xr.open_zarr(url1, consolidated=True)
        ds2 = xr.open_zarr(url2, consolidated=True)
        ds1 = ds1.convert_calendar("standard", use_cftime=False)
        ds2 = ds2.convert_calendar("standard", use_cftime=False)

        # Overlapping time
        t1, t2 = ds1.time, ds2.time
        start = max(str(t1.min().values), str(t2.min().values))
        end = min(str(t1.max().values), str(t2.max().values))
        prc1 = ds1["prc"].sel(time=slice(start, end))
        prc2 = ds2["prc"].sel(time=slice(start, end))

        # Overlapping lat/lon
        lat_min = max(prc1.lat.min().item(), prc2.lat.min().item())
        lat_max = min(prc1.lat.max().item(), prc2.lat.max().item())
        lon_min = max(prc1.lon.min().item(), prc2.lon.min().item())
        lon_max = min(prc1.lon.max().item(), prc2.lon.max().item())
        prc1 = prc1.sel(lat=slice(lat_min, lat_max), lon=slice(lon_min, lon_max))
        prc2 = prc2.sel(lat=slice(lat_min, lat_max), lon=slice(lon_min, lon_max))
        prc2_interp = prc2.interp(lat=prc1.lat, lon=prc1.lon)

        # Convert units
        prc1_mm_day = prc1 * 86400
        prc2_mm_day = prc2_interp * 86400

        # RMSE per month
        diff_sq = (prc1_mm_day - prc2_mm_day) ** 2
        mse_monthly = diff_sq.groupby("time.month").mean(dim=("time", "lat", "lon"))
        rmse_monthly = np.sqrt(mse_monthly)
        months = rmse_monthly["month"].values
        rmse_values = rmse_monthly.values

        fig.add_trace(go.Scatter(
            x=months,
            y=rmse_values,
            mode='lines+markers',
            name=f"{name1} vs {name2}",
            marker=dict(size=6)
        ))
    except Exception as e:
        print(f"Skipping {name1} vs {name2}: {e}")

fig.update_layout(
    title="Monthly RMSE of Precipitation (5 Model Pairs)",
    xaxis_title="Month",
    yaxis_title="RMSE (mm/day)",
    xaxis=dict(tickmode='array', tickvals=list(range(1, 13)), ticktext=[str(m) for m in range(1, 13)]),
    template='plotly_white',
    width=900,
    height=600
)

fig.show()

In [7]:
import ipywidgets as widgets
from IPython.display import display, clear_output

# Widget for region and variable selection for RMSE plot

def get_region_slices(region, lat):
    if region == "Global":
        return slice(-90, 90)
    elif region == "Tropics":
        return slice(-23.5, 23.5)
    elif region == "Northern Hemisphere":
        return slice(0, 90)
    elif region == "Southern Hemisphere":
        return slice(-90, 0)
    else:
        return slice(-90, 90)

def available_variables(ds1, ds2):
    # Return intersection of variable names in both datasets
    return list(set(ds1.data_vars).intersection(ds2.data_vars))

def plot_rmse_widget(model_zarrs):

    model_names = list(model_zarrs.keys())
    region_options = ["Global", "Tropics", "Northern Hemisphere", "Southern Hemisphere"]

    model1_dd = widgets.Dropdown(options=model_names, value=model_names[0], description="Model 1")
    model2_dd = widgets.Dropdown(options=model_names, value=model_names[1], description="Model 2")
    region_dd = widgets.Dropdown(options=region_options, value="Global", description="Region")
    variable_dd = widgets.Dropdown(options=[], description="Variable")

    output = widgets.Output()

    def update_variables(*args):
        try:
            ds1 = xr.open_zarr(model_zarrs[model1_dd.value], consolidated=True)
            ds2 = xr.open_zarr(model_zarrs[model2_dd.value], consolidated=True)
            vars_avail = available_variables(ds1, ds2)
            variable_dd.options = vars_avail
            if vars_avail:
                variable_dd.value = vars_avail[0]
        except Exception:
            variable_dd.options = []
            variable_dd.value = None

    model1_dd.observe(update_variables, names='value')
    model2_dd.observe(update_variables, names='value')

    def plot_callback(*args):
        output.clear_output()
        with output:
            try:
                ds1 = xr.open_zarr(model_zarrs[model1_dd.value], consolidated=True)
                ds2 = xr.open_zarr(model_zarrs[model2_dd.value], consolidated=True)
                ds1 = ds1.convert_calendar("standard", use_cftime=False)
                ds2 = ds2.convert_calendar("standard", use_cftime=False)
                var = variable_dd.value

                # Overlapping time
                t1, t2 = ds1.time, ds2.time
                start = max(str(t1.min().values), str(t2.min().values))
                end = min(str(t1.max().values), str(t2.max().values))
                da1 = ds1[var].sel(time=slice(start, end))
                da2 = ds2[var].sel(time=slice(start, end))

                # Overlapping lat/lon
                lat_min = max(da1.lat.min().item(), da2.lat.min().item())
                lat_max = min(da1.lat.max().item(), da2.lat.max().item())
                lon_min = max(da1.lon.min().item(), da2.lon.min().item())
                lon_max = min(da1.lon.max().item(), da2.lon.max().item())
                da1 = da1.sel(lat=slice(lat_min, lat_max), lon=slice(lon_min, lon_max))
                da2 = da2.sel(lat=slice(lat_min, lat_max), lon=slice(lon_min, lon_max))
                da2_interp = da2.interp(lat=da1.lat, lon=da1.lon)

                # Region selection
                region_slice = get_region_slices(region_dd.value, da1.lat)
                da1_reg = da1.sel(lat=region_slice)
                da2_reg = da2_interp.sel(lat=region_slice)

                # Convert units for precipitation
                if var == "prc":
                    da1_reg = da1_reg * 86400
                    da2_reg = da2_reg * 86400

                # RMSE per month
                diff_sq = (da1_reg - da2_reg) ** 2
                mse_monthly = diff_sq.groupby("time.month").mean(dim=("time", "lat", "lon"))
                rmse_monthly = np.sqrt(mse_monthly)
                months = rmse_monthly["month"].values
                rmse_values = rmse_monthly.values

                fig = go.Figure()
                fig.add_trace(go.Scatter(
                    x=months,
                    y=rmse_values,
                    mode='lines+markers',
                    name=f"{model1_dd.value} vs {model2_dd.value}",
                    marker=dict(size=8)
                ))

                fig.update_layout(
                    title=f"Monthly RMSE of {var.upper()} ({model1_dd.value} vs {model2_dd.value})<br>Region: {region_dd.value}",
                    xaxis_title="Month",
                    yaxis_title=f"RMSE ({'mm/day' if var == 'prc' else ''})",
                    xaxis=dict(tickmode='array', tickvals=list(range(1, 13)), ticktext=[str(m) for m in range(1, 13)]),
                    template='plotly_white',
                    width=800,
                    height=500
                )
                fig.show()
            except Exception as e:
                print(f"Error: {e}")
                

    variable_dd.observe(plot_callback, names='value')
    region_dd.observe(plot_callback, names='value')
    model1_dd.observe(plot_callback, names='value')
    model2_dd.observe(plot_callback, names='value')

    update_variables()
    plot_callback()

    display(widgets.VBox([widgets.HBox([model1_dd, model2_dd]), variable_dd, region_dd, output]))

plot_rmse_widget(model_zarrs)

VBox(children=(HBox(children=(Dropdown(description='Model 1', options=('TaiESM1', 'MPI-ESM1-2-LR', 'MIROC6', '…