In [None]:
import pandas as pd
import xarray as xr
import numpy as np
import scipy.stats as stats
import matplotlib.pyplot as plt

# ----------------------------
# Rolling mean preprocessing
# ----------------------------
def process_experiment(file_path, time_range, window=30):
    """Load an ``rx1day`` series, relabel its time axis, and apply a running mean.

    Parameters
    ----------
    file_path : str or path-like
        NetCDF file containing an ``rx1day`` variable with a ``time`` dimension.
    time_range : sequence (e.g. ``pandas.DatetimeIndex``)
        Labels assigned to the first ``len(time_range)`` time steps. The
        ``time`` dimension is then renamed to ``year``.
    window : int, default 30
        Length, in time steps, of the centred running mean.

    Returns
    -------
    xarray.DataArray
        The smoothed ``rx1day`` data, loaded into memory, indexed by ``year``.

    Raises
    ------
    ValueError
        If the file holds fewer time steps than ``len(time_range)``.
    """
    n_steps = len(time_range)
    # Load eagerly inside a context manager so the NetCDF file handle is closed
    # deterministically instead of lingering until garbage collection.
    with xr.open_dataset(file_path, decode_times=False) as ds:
        data = ds.rx1day.isel(time=slice(None, n_steps)).load()
    if data.sizes["time"] < n_steps:
        raise ValueError(
            f"{file_path}: expected at least {n_steps} time steps, "
            f"found {data.sizes['time']}"
        )
    data = data.assign_coords(time=time_range).rename({'time': 'year'})
    # center=True with min_periods=1: values within ~window/2 steps of either
    # end are averaged over fewer samples than interior values (edge effect).
    return data.rolling(year=window, center=True, min_periods=1).mean()

# Generate annual time axes: one timestamp per year (1 January), starting in 1750.
# "YS" (year start) replaces the "AS" alias, which pandas deprecated in 2.2.
MStimeA = pd.date_range("1750-01-01", freq="YS", periods=341)  # 1750 .. 2090
MStimeD = pd.date_range("1750-01-01", freq="YS", periods=451)  # 1750 .. 2200

# ----------------------------
# Load and process datasets
# ----------------------------
# Map each experiment to {"rx1day": smoothed series}. Files are read in the
# order prA, prB, prC, prD; A-C use the 341-step axis and D the 451-step axis.
datasets = {
    exp: {"rx1day": process_experiment(path, years)}
    for exp, path, years in (
        ("prA", 'rx1day_cdr_A1p_01_03_0001-0341.nc', MStimeA),
        ("prB", 'rx1day_cdr_B1p_01_03_0001-0341.nc', MStimeA),
        ("prC", 'rx1day_cdr_C1p_01_03_0001-0341.nc', MStimeA),
        ("prD", 'rx1day_cdr_D1p_01_03_0001-0451.nc', MStimeD),
    )
}

# ----------------------------
# Region definitions
# ----------------------------
# Latitude/longitude boxes in degrees. Longitudes appear to use a 0-360
# convention (e.g. Central America at 250-300): confirm against the input grid.
# lon_min_alt/lon_max_alt appear to describe a second longitude band for a box
# that straddles the prime meridian. NOTE(review): check that the masking code
# actually applies that band.
regions = {
    "East Asia": dict(lat_min=30, lat_max=55, lon_min=110, lon_max=145),
    "North Africa": dict(lat_min=5, lat_max=22, lon_min=0, lon_max=40,
                         lon_min_alt=345, lon_max_alt=360),
    "South Asia": dict(lat_min=5, lat_max=28, lon_min=62, lon_max=110),
    "Central America": dict(lat_min=4, lat_max=28, lon_min=250, lon_max=300),
    "South Africa": dict(lat_min=-25, lat_max=-4, lon_min=10, lon_max=40),
    "Australia": dict(lat_min=-26, lat_max=-8, lon_min=120, lon_max=150),
    "South America": dict(lat_min=-30, lat_max=-6, lon_min=280, lon_max=325),
}

# ----------------------------
# Time slices for P1 and P3 (positional indices along `year`, end-exclusive)
# ----------------------------
# Every experiment uses the first 30 positions for P1; P3 is experiment-specific.
# NOTE(review): prB's P3 window (71, 102) covers 31 positions, while every other
# window covers 30. Confirm whether this is intentional.
experiment_slices = {
    exp: [(0, 30), p3_window]
    for exp, p3_window in (
        ("prA", (58, 88)),
        ("prB", (71, 102)),
        ("prC", (110, 140)),
        ("prD", (250, 280)),
    )
}

# ----------------------------
# Helper: cut a series into (start, end) windows along `year`
# ----------------------------
def apply_slices(data, slices):
    """Cut *data* into one sub-selection per ``(start, end)`` pair in *slices*.

    Each pair is applied positionally to the ``year`` dimension through
    ``data.isel(year=slice(start, end))`` (end-exclusive). The returned list has
    the same length and order as *slices*.
    """
    def _window(bounds):
        start, end = bounds
        return data.isel(year=slice(start, end))

    return list(map(_window, slices))

# ----------------------------
# Slice and store regional rx1day data
# ----------------------------
# For each experiment, region and variable: mask the region, average over
# lat/lon, then cut the regional series into the experiment's P1/P3 windows.
sliced_data_results = {}
for exp, season_data in datasets.items():
    sliced_data_results[exp] = {}
    for region_name, coords in regions.items():
        sliced_data_results[exp][region_name] = {}
        for season_name, data in season_data.items():
            lat_mask = (data.lat >= coords["lat_min"]) & (data.lat <= coords["lat_max"])
            lon_mask = (data.lon >= coords.get("lon_min", 0)) & (data.lon <= coords.get("lon_max", 360))
            # Boxes straddling the prime meridian (e.g. North Africa) carry a
            # second longitude band in lon_min_alt/lon_max_alt. OR it into the
            # mask, otherwise that part of the box is silently dropped.
            if "lon_min_alt" in coords or "lon_max_alt" in coords:
                lon_mask = lon_mask | (
                    (data.lon >= coords.get("lon_min_alt", 0))
                    & (data.lon <= coords.get("lon_max_alt", 360))
                )
            # NOTE(review): this is an unweighted mean over grid cells. On a
            # regular lat/lon grid, cos(lat) area weighting is usually expected;
            # confirm which is intended.
            region_data = data.where(lat_mask & lon_mask).mean(dim=["lat", "lon"])
            slices = apply_slices(region_data, experiment_slices[exp])
            sliced_data_results[exp][region_name][season_name] = slices

# ----------------------------
# Calculate error bars
# ----------------------------
def calculate_error_bars(data1, data3, confidence=0.95):
    """Return the P3 - P1 mean difference and a normal-approximation confidence interval.

    Parameters
    ----------
    data1, data3 : array-like of float
        1-D samples for the first (P1) and second (P3) period.
    confidence : float, default 0.95
        Two-sided confidence level, strictly between 0 and 1.
        The default 0.95 uses z = norm.ppf(0.975), about 1.96.

    Returns
    -------
    tuple of float
        ``(mean_diff, ci_lower, ci_upper)``, where
        ``mean_diff = mean(data3) - mean(data1)``.

    Raises
    ------
    ValueError
        If *confidence* is not in the open interval (0, 1).

    Notes
    -----
    The standard error assumes independent samples. NOTE(review): if the inputs
    are running means, neighbouring values are strongly correlated and this
    interval is likely too narrow.
    """
    if not 0 < confidence < 1:
        raise ValueError(f"confidence must be in (0, 1), got {confidence!r}")
    data1 = np.asarray(data1)
    data3 = np.asarray(data3)
    mean_diff = np.mean(data3) - np.mean(data1)
    # Standard error of each mean (sample std, ddof=1), combined in quadrature.
    se1 = np.std(data1, ddof=1) / np.sqrt(len(data1))
    se3 = np.std(data3, ddof=1) / np.sqrt(len(data3))
    se_diff = np.sqrt(se1**2 + se3**2)
    # Two-sided interval: (1 - confidence) / 2 of the probability in each tail.
    z_score = stats.norm.ppf(0.5 + confidence / 2)
    ci_lower = mean_diff - z_score * se_diff
    ci_upper = mean_diff + z_score * se_diff
    return mean_diff, ci_lower, ci_upper

# ----------------------------
# Store results for plotting
# ----------------------------
# plot_data[variable][region] -> list of (mean_diff, ci_lower, ci_upper), one
# entry per experiment, in the iteration order of sliced_data_results.
plot_data = {'rx1day': {}}
for exp, region_data in sliced_data_results.items():
    for region, season_data in region_data.items():
        for season, (slice1, slice3) in season_data.items():
            data1, data3 = slice1.values, slice3.values
            mean_diff, ci_lower, ci_upper = calculate_error_bars(data1, data3)
            plot_data[season].setdefault(region, []).append((mean_diff, ci_lower, ci_upper))

# ----------------------------
# Plotting setup
# ----------------------------
# A 2 x 4 grid of panels (8 axes for 7 regions), flattened for 1-D indexing.
fig, axes = plt.subplots(2, 4, figsize=(7.5, 4.5), dpi=300)
axes = axes.ravel()

regions_plot_order = [
    'East Asia', 'North Africa', 'South Asia', 'Central America',
    'South Africa', 'Australia', 'South America',
]
subplot_labels = list('abcdefg')  # one panel letter per region above
colors = {'rx1day': 'red'}
# CO2 (ppm) x-positions, one per experiment. NOTE(review): assumes these follow
# the same order as the experiments (prA..prD) in the plotted data lists.
scenarios = [440, 470, 570, 1145]

# ----------------------------
# Plot each region with error bars
# ----------------------------
for plot_idx, ax in enumerate(axes):
    # More axes than regions: switch the leftover panel(s) off and move on.
    if plot_idx >= len(regions_plot_order):
        ax.axis('off')
        continue

    region = regions_plot_order[plot_idx]
    subplot_label = subplot_labels[plot_idx]

    # One marker per CO2 scenario. Assumes plot_data[season][region] is ordered
    # like `scenarios`.
    for i, scenario in enumerate(scenarios):
        for season, color in colors.items():
            mean_diff, ci_lower, ci_upper = plot_data[season][region][i]
            # errorbar expects distances below/above the point, not absolute bounds.
            err_below = mean_diff - ci_lower
            err_above = ci_upper - mean_diff
            legend_label = season if (region == 'East Asia' and i == 0) else ""
            ax.errorbar(scenario, mean_diff, yerr=[[err_below], [err_above]],
                        fmt='o', capsize=4, markersize=4, color=color,
                        label=legend_label)

    ax.axhline(0, color="gray", linestyle='dotted', linewidth=0.6)
    ax.set_xticks([400, 600, 800, 1000, 1200])
    ax.minorticks_on()
    ax.set_title(region, fontsize=7.5, fontfamily='serif')
    for label in [*ax.get_xticklabels(), *ax.get_yticklabels()]:
        label.set_fontname('serif')
        label.set_fontsize(6.5)

    # y label on the left column of the grid, x label on the bottom row.
    if region in ('East Asia', 'South Africa'):
        ax.set_ylabel('Down-Up [mm/day]', fontfamily='serif', fontsize=6.5)
    if region in ('South Africa', 'Australia', 'South America'):
        ax.set_xlabel(r'CO$_2$ [ppm]', fontfamily='serif', fontsize=6.5)

    ax.tick_params(axis='both', labelsize=6.5)
    # Bold panel letter just above the top-left corner of the axes.
    ax.text(0.1, 1.11, subplot_label, transform=ax.transAxes,
            fontfamily='serif', fontsize=7.5, fontweight='bold', va='top', ha='right')

# ----------------------------
# Final layout and save
# ----------------------------
# Tighten panel spacing on the figure, write it to disk, then display it.
fig.subplots_adjust(hspace=0.35, wspace=0.33)
fig.savefig('Figure_4.tif', dpi=300, bbox_inches='tight', facecolor='white')
plt.show()
