# In [None]:
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt

# Function to load, process, and calculate rolling mean for an experiment
def process_experiment(file_path, time_range, skip_years=14, window=30):
    """Load one experiment's ``rx1day`` series, drop leading years and smooth it.

    Parameters
    ----------
    file_path : str
        NetCDF file containing an ``rx1day`` variable with a ``time`` dimension.
    time_range : sized
        Only its length is used: the number of time steps to read from the file.
    skip_years : int, optional
        Number of leading time steps to discard (default 14). The kept steps
        are relabelled ``skip_years + 1 .. len(time_range)`` (15.. by default).
    window : int, optional
        Width of the centred rolling mean along ``year`` (default 30).

    Returns
    -------
    xarray.DataArray
        Smoothed, in-memory series whose ``time`` dimension is renamed ``year``.

    Raises
    ------
    ValueError
        If the file holds fewer than ``len(time_range)`` time steps.
    """
    n_steps = len(time_range)
    # Times are not decoded because they are replaced by plain sequential
    # year numbers below.
    with xr.open_dataset(file_path, decode_times=False) as ds:
        rx1day = ds.rx1day
        available = rx1day.sizes['time']
        if available < n_steps:
            raise ValueError(
                f"{file_path}: expected at least {n_steps} time steps, found {available}"
            )
        data = rx1day.isel(time=slice(skip_years, n_steps))  # skip leading years
        data = data.assign_coords(
            time=range(skip_years + 1, n_steps + 1)
        ).rename({'time': 'year'})
        # Compute and load while the file is still open, so the handle can be
        # released when the `with` block exits.
        return data.rolling(year=window, center=True, min_periods=1).mean().load()

# Generate time ranges. process_experiment only uses their lengths (number of
# years read per file). 'YS' (year start) replaces the 'AS' alias, which is
# deprecated since pandas 2.2 and yields the same dates.
MStimeA = pd.date_range("1750-01-01", freq="YS", periods=341)
MStimeD = pd.date_range("1750-01-01", freq="YS", periods=451)

# Process each experiment dataset with appropriate time ranges
# (A-C runs cover 341 years, D runs cover 451 years, per the file names).
pr_A1 = process_experiment('rx1day_cdr_A1p_01_03_0001-0341.nc', MStimeA)
pr_A2 = process_experiment('rx1day_cdr_A2p_01_03_0001-0341.nc', MStimeA)
pr_B1 = process_experiment('rx1day_cdr_B1p_01_03_0001-0341.nc', MStimeA)
pr_B2 = process_experiment('rx1day_cdr_B2p_01_03_0001-0341.nc', MStimeA)
pr_C1 = process_experiment('rx1day_cdr_C1p_01_03_0001-0341.nc', MStimeA)
pr_C2 = process_experiment('rx1day_cdr_C2p_01_03_0001-0341.nc', MStimeA)
pr_D1 = process_experiment('rx1day_cdr_D1p_01_03_0001-0451.nc', MStimeD)
pr_D2 = process_experiment('rx1day_cdr_D2p_01_03_0001-0451.nc', MStimeD)

# Dictionary of all regions with coordinates (degrees; longitudes on a 0-360
# east grid). Boxes straddling the prime meridian carry an extra
# lon_min_alt/lon_max_alt band for their part west of 0 deg.
regions = {
    "Europe": {"lat_min": 37, "lat_max": 71, "lon_min": 0, "lon_max": 58, "lon_min_alt": 350, "lon_max_alt": 360},
    # 115W-60W expressed on the 0-360 grid (the previous 60-115 selected
    # Central Asia, not North America).
    "North America": {"lat_min": 30, "lat_max": 60, "lon_min": 245, "lon_max": 300},
    "East Asia": {"lat_min": 30, "lat_max": 55, "lon_min": 110, "lon_max": 145},
    "North Africa": {"lat_min": 5, "lat_max": 22, "lon_min": 0, "lon_max": 40, "lon_min_alt": 345, "lon_max_alt": 360},
    "South Asia": {"lat_min": 5, "lat_max": 28, "lon_min": 62, "lon_max": 110},
    "Central America": {"lat_min": 4, "lat_max": 28, "lon_min": 250, "lon_max": 300},
    "South Africa": {"lat_min": -25, "lat_max": -4, "lon_min": 10, "lon_max": 40},
    "Australia": {"lat_min": -26, "lat_max": -8, "lon_min": 120, "lon_max": 150},
    "South America": {"lat_min": -30, "lat_max": -6, "lon_min": 280, "lon_max": 325},
}

# Assign datasets. Order matters: it must match `labels` / `line_styles`
# used when plotting (A1..D1, then A2..D2).
datasets = {
    "prA1": pr_A1, "prB1": pr_B1, "prC1": pr_C1, "prD1": pr_D1,
    "prA2": pr_A2, "prB2": pr_B2, "prC2": pr_C2, "prD2": pr_D2
}


def _region_mean(data, coords):
    """Return the lat/lon mean of *data* over one region box.

    ``coords`` holds ``lat_min``/``lat_max``/``lon_min``/``lon_max`` (bounds
    inclusive). When it also has ``lon_min_alt``/``lon_max_alt``, grid points
    in that second longitude band are included too. This lets boxes that
    straddle the prime meridian on a 0-360 grid be selected as one region.
    """
    lon = data['lon']
    in_box = (lon >= coords['lon_min']) & (lon <= coords['lon_max'])
    if 'lon_min_alt' in coords:
        in_box = in_box | ((lon >= coords['lon_min_alt']) & (lon <= coords['lon_max_alt']))
    # NOTE(review): assumes `lat` is ascending; a descending lat axis would make
    # this slice empty — confirm against the input files.
    box = data.sel(lat=slice(coords['lat_min'], coords['lat_max']))
    box = box.sel(lon=lon.values[in_box.values])
    # NOTE(review): plain (unweighted) mean, no cos(lat) area weighting —
    # confirm this is intended.
    return box.mean(['lat', 'lon'])


# Extract regional data: region name -> list of regional-mean series, one per
# dataset, in `datasets` order.
regions_data = {
    region: [_region_mean(data, coords) for data in datasets.values()]
    for region, coords in regions.items()
}

# Plot configuration. Panels are filled row by row in a 3x3 grid following
# `regions_plot_order`.
regions_plot_order = ['Europe', 'North America', 'East Asia', 'North Africa', 'South Asia', 'Central America', 'South Africa', 'Australia', 'South America']

# One entry per dataset: solid lines for ensemble member 1, dotted for member 2.
line_styles = ['-'] * 4 + [':'] * 4
labels = [f'{experiment}{member}%' for member in '12' for experiment in 'ABCD']
color_map = dict(zip('ABCD', ['blue', 'green', 'orange', 'red']))

# Panel letters by grid slot. Slots 0 and 1 (Europe, North America) are
# blanked in the plot, so the visible panels carry the letters b-h.
subplot_labels = list('abbcdefgh')

# Custom x-ticks for sequential year numbers (1, then every 50 up to 450)
custom_xticks = [1, *range(50, 451, 50)]
custom_xticklabels = list(custom_xticks)

# Reference markers drawn on every panel.
vertical_lines = [15, 140, 265]
recession_ranges = [(1, 30), (125, 155), (250, 280)]
recession_colors = ['mistyrose'] * 3

# Create subplots: a 3x3 grid filled in `regions_plot_order`; the Europe and
# North America slots are left blank.
fig, axes = plt.subplots(3, 3, figsize=(9.90, 6.29), dpi=300, sharex=False, sharey=False)
axes = axes.flatten()

# Legend handles are collected from the first panel actually drawn only, so the
# figure legend gets exactly one handle per dataset, in `labels` order.
handles = []
for i, region in enumerate(regions_plot_order):
    ax = axes[i]
    if region in ['Europe', 'North America']:
        ax.axis('off')  # Hide this panel by making it a blank white area
        continue

    # Horizontal reference at the first value of the first dataset (A1);
    # isel(year=0) is year 15, the first year kept by process_experiment.
    first_value = regions_data[region][0].isel(year=0).values
    ax.axhline(first_value, color='grey', linestyle='--', linewidth=0.8, label='Start Value')

    collect_handles = not handles
    for series, label, style in zip(regions_data[region], labels, line_styles):
        color = color_map[label[0]]  # colour keyed on the experiment letter A-D
        lines = series.plot(ax=ax, label=label, add_legend=False, color=color, linestyle=style, linewidth=0.7)
        if collect_handles:
            handles.append(lines[0])

    # Shaded recession areas, drawn once per panel. Drawing them once per
    # dataset would stack eight alpha=0.3 spans into a near-opaque band.
    for k, (start, end) in enumerate(recession_ranges):
        ax.axvspan(start, end, alpha=0.3, color=recession_colors[k % len(recession_colors)])

    # Add vertical lines
    for vline in vertical_lines:
        ax.axvline(vline, color='black', linestyle='--', linewidth=0.8)

    # Enable minor grid lines
    ax.minorticks_on()
    ax.grid(which='minor', color='grey', linestyle='--', linewidth=0.01)

    ax.set_title(f"{region}", fontfamily='serif', fontsize=10.5)

    # y label on the leftmost drawn panel of each row; x label on the bottom row.
    if i in [2, 3, 6]:
        ax.set_ylabel('[mm/day]', fontfamily='serif', fontsize=8.5)
    else:
        ax.set_ylabel('')

    if i in [6, 7, 8]:
        ax.set_xlabel('[Year]', fontfamily='serif', fontsize=8.5)
    else:
        ax.set_xlabel('')

    ax.tick_params(axis='both', which='major', labelsize=7.5)

    for tick_label in ax.get_xticklabels() + ax.get_yticklabels():
        tick_label.set_fontname('serif')
        tick_label.set_fontsize(7.5)

    ax.text(0.1, 1.12, subplot_labels[i], transform=ax.transAxes,
            fontsize=10.5, fontfamily='serif', fontweight='bold', va='top', ha='right')

plt.figlegend(handles, labels, loc='upper right', bbox_to_anchor=(0.66, 0.88), ncol=2, fontsize=8.5)
plt.tight_layout()

plt.savefig('Figure_2b_h.tif', dpi=300, bbox_inches='tight', facecolor='white')  # Ensure the entire figure is saved

# Show the plot
plt.show()