# In [None]:
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import numpy as np

# Function to load, process, and calculate rolling mean for an experiment
def process_experiment(file_path, time_range):
    """Load ``rx1day`` from a NetCDF file and return its 30-step rolling mean.

    The first ``len(time_range)`` time steps are kept, the ``time``
    coordinate is replaced by ``time_range`` and renamed to ``year``, and a
    centred rolling mean over 30 steps is taken along ``year``.
    ``min_periods=1`` means the windows at both ends of the series are
    averaged over fewer than 30 values instead of becoming NaN.

    Args:
        file_path: Path of the NetCDF file; it must contain an ``rx1day``
            variable with a ``time`` dimension.
        time_range: Sized sequence of time labels assigned to the retained
            time steps (the script passes a ``pandas.DatetimeIndex``).

    Returns:
        The smoothed ``rx1day`` array with ``time`` renamed to ``year``.

    Raises:
        ValueError: If the file holds fewer time steps than ``time_range``.
    """
    n_steps = len(time_range)

    # decode_times=False leaves the file's raw time values undecoded; they are
    # overwritten by ``time_range`` below. The context manager closes the file
    # handle as soon as the needed values have been read into memory.
    with xr.open_dataset(file_path, decode_times=False) as dataset:
        data = dataset.rx1day.isel(time=slice(None, n_steps)).load()

    if data.sizes["time"] != n_steps:
        raise ValueError(
            f"{file_path} has {data.sizes['time']} time steps, "
            f"but {n_steps} were requested"
        )

    data = data.assign_coords(time=time_range).rename({'time': 'year'})
    return data.rolling(year=30, center=True, min_periods=1).mean()

# Generate time ranges: one timestamp per year (1 January), starting in 1750.
# "YS" is the year-start frequency alias; the older "AS" spelling is
# deprecated in recent pandas releases and yields the same index.
# The A/B/C files span steps 0001-0341 and the D files 0001-0451 (per their
# file names), hence the two different lengths.
MStimeA = pd.date_range("1750-01-01", freq="YS", periods=341)
MStimeD = pd.date_range("1750-01-01", freq="YS", periods=451)

# Process each experiment dataset with appropriate time ranges
pr_A1 = process_experiment('rx1day_cdr_A1p_01_03_0001-0341.nc', MStimeA)
pr_A2 = process_experiment('rx1day_cdr_A2p_01_03_0001-0341.nc', MStimeA)
pr_B1 = process_experiment('rx1day_cdr_B1p_01_03_0001-0341.nc', MStimeA)
pr_B2 = process_experiment('rx1day_cdr_B2p_01_03_0001-0341.nc', MStimeA)
pr_C1 = process_experiment('rx1day_cdr_C1p_01_03_0001-0341.nc', MStimeA)
pr_C2 = process_experiment('rx1day_cdr_C2p_01_03_0001-0341.nc', MStimeA)
pr_D1 = process_experiment('rx1day_cdr_D1p_01_03_0001-0451.nc', MStimeD)
pr_D2 = process_experiment('rx1day_cdr_D2p_01_03_0001-0451.nc', MStimeD)

# Latitude/longitude bounding box of each analysis region, keyed by region name.
# Longitude values appear to follow a 0-360 convention (they run up to 360).
# North Africa carries an additional lon_min_alt/lon_max_alt pair (345-360),
# presumably the part of the box lying west of the prime meridian.
regions = {
    "East Asia": dict(lat_min=30, lat_max=55, lon_min=110, lon_max=145),
    "North Africa": dict(
        lat_min=5, lat_max=22,
        lon_min=0, lon_max=40,
        lon_min_alt=345, lon_max_alt=360,
    ),
    "South Asia": dict(lat_min=5, lat_max=28, lon_min=62, lon_max=110),
    "Central America": dict(lat_min=4, lat_max=28, lon_min=250, lon_max=300),
    "South Africa": dict(lat_min=-25, lat_max=-4, lon_min=10, lon_max=40),
    "Australia": dict(lat_min=-26, lat_max=-8, lon_min=120, lon_max=150),
    "South America": dict(lat_min=-30, lat_max=-6, lon_min=280, lon_max=325),
}


# Rolling-mean rx1day fields keyed by experiment label.
# Insertion order matters: the plotting code splits list(datasets.keys())
# into its first four entries (the "*1" runs) and last four (the "*2" runs).
datasets = dict(
    prA1=pr_A1,
    prB1=pr_B1,
    prC1=pr_C1,
    prD1=pr_D1,
    prA2=pr_A2,
    prB2=pr_B2,
    prC2=pr_C2,
    prD2=pr_D2,
)

# Experiment slices: for every experiment label, the (start, end) index
# windows along ``year`` that are compared. Each label gets the common
# baseline window (0, 30) followed by a scenario-specific later window, and
# the "1" and "2" variants of a scenario share the same windows.
# NOTE(review): the B window (71, 102) spans 31 steps whereas every other
# window spans 30 - confirm whether that is intentional.
experiment_slices = {
    f"pr{scenario}{variant}": [(0, 30), later_window]
    for variant in ("1", "2")
    for scenario, later_window in (
        ("A", (58, 88)),
        ("B", (71, 102)),
        ("C", (110, 140)),
        ("D", (250, 280)),
    )
}

# Function to apply slicing and calculate mean over lat/lon
def apply_slices(data, slices):
    """Return the lat/lon mean of ``data`` for each index window in ``slices``.

    Args:
        data: Array-like object supporting ``isel(year=...)`` and
            ``mean(dim=...)`` with ``lat`` and ``lon`` dimensions.
        slices: Iterable of ``(start, end)`` positional index pairs along
            the ``year`` dimension (``end`` exclusive, as in ``slice``).

    Returns:
        A list with one spatially averaged result per window, in the same
        order as ``slices``.
    """
    # NOTE(review): this is a plain mean over grid cells with no area
    # weighting applied here - confirm that is the intended regional average.
    spatial_means = []
    for start, end in slices:
        window = data.isel(year=slice(start, end))
        spatial_means.append(window.mean(dim=["lat", "lon"]))
    return spatial_means

# Function to calculate reversibility as (mean of last - mean of first) / mean of first * 100%
def calculate_reversibility(slice_data):
    """Return the percent change of the last window's mean relative to the first.

    Computes ``(mean(last) - mean(first)) / mean(first) * 100``.

    Args:
        slice_data: Non-empty sequence of array-like objects whose
            ``.mean()`` result exposes ``.item()``. Only the first and the
            last elements are used.

    Returns:
        The percent change as a float. NaN is returned when the first
        window's mean is exactly zero, because a relative change is
        undefined there (previously this raised ``ZeroDivisionError``).
    """
    first_slice_mean = slice_data[0].mean().item()
    last_slice_mean = slice_data[-1].mean().item()

    # A NaN baseline falls through and propagates NaN via the arithmetic below.
    if first_slice_mean == 0:
        return float("nan")

    return (last_slice_mean - first_slice_mean) / first_slice_mean * 100

# Results dictionary to store reversibility:
# region name -> experiment label -> percent change between the two windows
reversibility_results = {region: {} for region in regions}
for region_name, coords in regions.items():
    for exp, data in datasets.items():
        # Select region: grid cells outside the box become NaN
        lat_mask = (data.lat >= coords["lat_min"]) & (data.lat <= coords["lat_max"])
        lon_mask = (
            (data.lon >= coords.get("lon_min", 0)) &
            (data.lon <= coords.get("lon_max", 360))
        )
        # Some regions define a second longitude band (lon_min_alt/lon_max_alt,
        # e.g. 345-360 for North Africa). It must be OR-ed into the mask;
        # otherwise that part of the region is silently left out.
        if "lon_min_alt" in coords and "lon_max_alt" in coords:
            lon_mask = lon_mask | (
                (data.lon >= coords["lon_min_alt"]) &
                (data.lon <= coords["lon_max_alt"])
            )
        region_data = data.where(lat_mask & lon_mask)

        slices = apply_slices(region_data, experiment_slices[exp])
        reversibility = calculate_reversibility(slices)
        reversibility_results[region_name][exp] = reversibility

# Plotting code
# Experiment labels in dataset insertion order; the panel loop treats the
# first four as one bar group and the last four as the other.
experiments = list(datasets.keys())
colors = ['blue', 'green', 'orange', 'red']  # one colour per bar within a group

bar_width = 0.5  # Reduced bar width
bar_spacing = 0.0  # Adjusted spacing between bars within a group

# Updated layout: 2 rows × 4 columns
fig, axes = plt.subplots(2, 4, figsize=(7.5, 4.5), dpi=300, sharex=False, sharey=False)  # Wider aspect for 8 panels
axes = axes.flatten()

# Define regions to plot (only 7 regions), leave last subplot blank.
# These lists are defined once; an earlier duplicate pair (whose label list
# repeated 'a' and 'b') was overwritten before use and has been removed.
regions_plot_order = ['East Asia', 'North Africa', 'South Asia', 'Central America',
                      'South Africa', 'Australia', 'South America']
subplot_labels = ['a', 'b', 'c', 'd', 'e', 'f', 'g']

# Draw one bar-chart panel per region: two groups of four bars (labelled
# '1%' and '2%'), separated by a dashed vertical line.
for plot_idx, ax in enumerate(axes):
    if plot_idx >= len(regions_plot_order):
        ax.axis('off')  # Hide the 8th (empty) subplot
        continue

    region = regions_plot_order[plot_idx]
    subplot_label = subplot_labels[plot_idx]

    # First four experiments form the left group, the remaining ones the right
    values_1 = [reversibility_results[region][exp] for exp in experiments[:4]]
    values_2 = [reversibility_results[region][exp] for exp in experiments[4:]]

    # The right group reuses the left group's spacing, shifted along x
    x_positions_1 = np.arange(len(values_1)) * (bar_width + bar_spacing)
    x_positions_2 = x_positions_1 + len(values_1) + bar_spacing

    for positions, values in ((x_positions_1, values_1), (x_positions_2, values_2)):
        for x_pos, value, color in zip(positions, values, colors):
            ax.bar(x_pos, value, bar_width, color=color, edgecolor='black', linewidth=0.5)

    # Dashed divider halfway between the last left bar and the first right bar
    ax.axvline(x=(x_positions_1[-1] + x_positions_2[0]) / 2, color='black', linestyle='--', linewidth=0.4)

    ax.set_title(f'{region}', fontsize=7.5, fontfamily='serif')

    # Every panel gets the same ticks: one major tick per group (at its
    # centre) and one minor tick per bar.
    ax.set_xticks([x_positions_1.mean(), x_positions_2.mean()])
    ax.set_xticklabels(['1%', '2%'], fontsize=5.5, fontfamily='serif')
    ax.set_xticks(x_positions_1.tolist() + x_positions_2.tolist(), minor=True)

    # Only the first panel of each row carries the y-axis label
    if region in ['East Asia', 'South Africa']:
        ax.set_ylabel('Irreversibility(%)', fontfamily='serif', fontsize=6.5, )
        ax.tick_params(axis='x', which='minor', length=2, color='black')

    ax.tick_params(axis='both', which='major', labelsize=5)
    # NOTE(review): minorticks_on() installs automatic minor tick locators,
    # which appears to supersede the per-bar minor ticks set above - confirm
    # which behaviour is intended.
    ax.minorticks_on()
    # Applied last, so this font setting is what the existing tick labels end up with
    for label in (ax.get_xticklabels() + ax.get_yticklabels()):
        label.set_fontname('serif')
        label.set_fontsize(6.5)

    for spine in ax.spines.values():
        spine.set_linewidth(0.5)

    # Bold panel letter just above the top-left corner of the axes
    ax.text(0.1, 1.11, subplot_label, transform=ax.transAxes,
            fontsize=7.5, fontfamily='serif', fontweight='bold', va='top', ha='right')

# Tighter spacing between subplots
plt.subplots_adjust(hspace=0.35, wspace=0.3)

# Add legend: one coloured patch per scenario letter, anchored in figure
# coordinates (presumably over the blank eighth panel - adjust if the layout changes)
legend_labels = ['A', 'B', 'C', 'D']
handles = [
    plt.Rectangle((0, 0), 1, 1, color=patch_color)
    for patch_color, _ in zip(colors, legend_labels)
]
plt.figlegend(
    handles,
    legend_labels,
    loc='upper right',
    bbox_to_anchor=(0.8, 0.45),
    ncol=1,
    fontsize=6.5,
)

# Save the figure to disk, then display it
plt.savefig("Figure_3.tif", dpi=300, bbox_inches='tight', facecolor='white')
plt.show()
