# Extended Exploratory Data Analysis for Climate Emulation

This notebook expands on the basic EDA, following the plan outlined in `exploratory_data_analysis_plan.md`.

In [None]:
import zarr
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Set some default plotting styles
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_context('notebook')

In [None]:
# Load the data. The original notebook used a path relative to the 'notebooks'
# directory, so try that first and fall back to the project root.
import os

candidate_paths = [
    "../processed_data_cse151b_v2_corrupted_ssp245.zarr",
    "processed_data_cse151b_v2_corrupted_ssp245.zarr",
]
data_path = next((p for p in candidate_paths if os.path.exists(p)), None)
if data_path is None:
    raise FileNotFoundError(f"Data file not found at any of: {candidate_paths}")
print(f"Loading data from {data_path}")
data = xr.open_zarr(data_path)
data
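The cells below skip ssp245 whenever the file name contains "corrupted", without checking what the corruption looks like. A quick data-driven check is to count non-finite values per variable and scenario. This is a minimal sketch that assumes corruption shows up as NaN or inf; other forms (for example out-of-range values) would need separate range checks. `nonfinite_fraction_by_ssp` is a hypothetical helper, demonstrated here on a small synthetic dataset.

```python
import numpy as np
import xarray as xr


def nonfinite_fraction_by_ssp(ds: xr.Dataset) -> dict:
    """Fraction of NaN/inf values for each variable and SSP.

    Variables without an 'ssp' dimension are skipped.
    """
    report = {}
    for name, da in ds.data_vars.items():
        if "ssp" not in da.dims:
            continue
        bad = (~np.isfinite(da)).astype(float)  # 1.0 where the value is NaN or inf
        other_dims = [d for d in da.dims if d != "ssp"]
        frac = bad.mean(dim=other_dims) if other_dims else bad
        report[name] = {str(s): float(f) for s, f in zip(frac["ssp"].values, frac.values)}
    return report


# Synthetic example: ssp245 has two NaNs out of four values.
demo = xr.Dataset(
    {"tas": (("ssp", "time"), np.array([[1.0, 2.0, 3.0, 4.0], [1.0, np.nan, np.nan, 4.0]]))},
    coords={"ssp": ["ssp126", "ssp245"], "time": np.arange(4)},
)
report = nonfinite_fraction_by_ssp(demo)
print(report)  # {'tas': {'ssp126': 0.0, 'ssp245': 0.5}}
```

Calling `nonfinite_fraction_by_ssp(data)` on the full store reads every variable, so it may take a while on the real dataset.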

## 1. Understanding Variable Distributions & Characteristics

### Q1.1: What are the distributions (histograms, box plots) of each input forcing (`CO2`, `SO2`, `CH4`, `BC`, `rsdt`) and output variable (`tas`, `pr`) across all scenarios and time points?

In [None]:
input_forcings = ['CO2', 'SO2', 'CH4', 'BC', 'rsdt']
output_variables = ['tas', 'pr']
all_vars_to_plot = input_forcings + output_variables

for var_name in all_vars_to_plot:
    da = data[var_name]
    # Collapse each variable to a 1-D sample pooled over SSPs and time. CO2/CH4 are
    # used as stored; gridded variables get an unweighted spatial mean (no
    # cos(latitude) area weighting), and tas/pr are also averaged over members.
    if var_name in ['CO2', 'CH4']:
        reduced, how = da, 'as stored'
    elif var_name in ['SO2', 'BC']:
        reduced, how = da.mean(dim=['latitude', 'longitude']), 'spatially averaged'
    elif var_name == 'rsdt':
        reduced, how = da.mean(dim=['y', 'x']), 'spatially averaged'
    elif var_name in ['tas', 'pr']:
        reduced, how = da.mean(dim=['x', 'y', 'member_id']), 'spatial & ensemble mean'
    else:
        print(f"Skipping {var_name}: unknown dimension structure for a simple histogram.")
        continue
    flat_data = reduced.values.ravel()  # .values also computes dask-backed arrays

    # Create the figure only once the variable is known to be plottable.
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    sns.histplot(flat_data, kde=True, bins=50)
    plt.title(f'Histogram of {var_name} ({how})')
    plt.xlabel(var_name)
    plt.ylabel('Frequency')

    plt.subplot(1, 2, 2)
    sns.boxplot(y=flat_data)
    plt.title(f'Boxplot of {var_name} ({how})')
    plt.ylabel(var_name)

    plt.tight_layout()
    plt.show()

#### Observations for Q1.1: General Distributions
*   (Comment on typical ranges, presence of outliers, and overall shape for each variable based on the plots above).
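The spatial means above (`.mean(dim=[...])`) weight every grid cell equally. On a regular latitude-longitude grid, cells shrink toward the poles, so a global mean is usually area-weighted with weights proportional to cos(latitude). The sketch below assumes the latitude coordinate is `y` and is in degrees, which this notebook does not confirm (for `SO2`/`BC` the dimension would be `latitude`); `area_weighted_mean` is a hypothetical helper built on xarray's `DataArray.weighted`.

```python
import numpy as np
import xarray as xr


def area_weighted_mean(da: xr.DataArray, lat_dim: str = "y", lon_dim: str = "x") -> xr.DataArray:
    """Mean over (lat_dim, lon_dim) weighted by cos(latitude).

    Assumes the latitude coordinate is in degrees on a regular lat-lon grid.
    """
    weights = np.cos(np.deg2rad(da[lat_dim]))
    weights.name = "weights"
    return da.weighted(weights).mean(dim=[lat_dim, lon_dim])


# Synthetic check: a field equal to 1 in the tropics and 0 at high latitudes.
lat = np.array([-75.0, -15.0, 15.0, 75.0])
lon = np.array([0.0, 90.0, 180.0, 270.0])
field = xr.DataArray(
    np.where(np.abs(lat)[:, None] < 30, 1.0, 0.0) * np.ones((lat.size, lon.size)),
    coords={"y": lat, "x": lon},
    dims=("y", "x"),
)
unweighted = float(field.mean(dim=["y", "x"]))
weighted = float(area_weighted_mean(field))
print(unweighted, weighted)  # the weighted mean gives the (larger) tropical cells more weight
```

Because tropical cells cover more area, the weighted mean here is larger than the unweighted 0.5.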

### Q1.2: How do these distributions change if you look at them per SSP scenario?

In [None]:
ssp_scenarios = data.ssp.values

for var_name in all_vars_to_plot:
    plt.figure(figsize=(15, 7))
    for i, ssp in enumerate(ssp_scenarios):
        if ssp == 'ssp245' and "corrupted" in data_path:
            print(f"Skipping potentially corrupted ssp245 for {var_name} plotting.")
            # Plot a placeholder for corrupted data
            ax = plt.subplot(2, len(ssp_scenarios), i + 1)
            plt.text(0.5, 0.5, f'{var_name} ({ssp})\nData Corrupted', ha='center', va='center', transform=ax.transAxes)
            plt.xticks([])
            plt.yticks([])
            ax_box = plt.subplot(2, len(ssp_scenarios), i + 1 + len(ssp_scenarios))
            plt.text(0.5, 0.5, f'{var_name} ({ssp})\nData Corrupted', ha='center', va='center', transform=ax_box.transAxes)
            plt.xticks([])
            plt.yticks([])
            continue
        try:
            ssp_data = data.sel(ssp=ssp)
            if var_name in ['CO2', 'CH4']:
                flat_ssp_data = ssp_data[var_name].data.compute().flatten()
            elif var_name in ['SO2', 'BC']:
                flat_ssp_data = ssp_data[var_name].mean(dim=['latitude', 'longitude']).data.compute().flatten()
            elif var_name == 'rsdt':
                flat_ssp_data = ssp_data[var_name].mean(dim=['y', 'x']).data.compute().flatten()
            elif var_name in ['tas', 'pr']:
                flat_ssp_data = ssp_data[var_name].mean(dim=['x', 'y', 'member_id']).data.compute().flatten()
            else:
                continue
            
            plt.subplot(2, len(ssp_scenarios), i + 1)
            sns.histplot(flat_ssp_data, kde=False, bins=30) 
            plt.title(f'{var_name} ({ssp})')
            plt.xlabel(var_name)
            if i == 0: plt.ylabel('Frequency')

            plt.subplot(2, len(ssp_scenarios), i + 1 + len(ssp_scenarios))
            sns.boxplot(y=flat_ssp_data)
            plt.title(f'{var_name} ({ssp})')
            plt.ylabel(var_name)

        except Exception as e:
            print(f"Could not plot {var_name} for {ssp}: {e}")
            # Annotate both panels of this SSP explicitly. plt.gca() could point at a
            # panel of a previous SSP if the error happened before this SSP's subplot existed.
            for pos in (i + 1, i + 1 + len(ssp_scenarios)):
                ax_err = plt.subplot(2, len(ssp_scenarios), pos)
                ax_err.cla()
                ax_err.text(0.5, 0.5, f'Error plotting\n{var_name} for {ssp}', ha='center', va='center', transform=ax_err.transAxes)
                ax_err.set_xticks([])
                ax_err.set_yticks([])
    plt.suptitle(f'Distributions of {var_name} per SSP Scenario (globally/spatially averaged)', fontsize=16, y=1.02)
    plt.tight_layout(rect=[0, 0, 1, 0.98])
    plt.show()

#### Observations for Q1.2: Distributions per SSP
*   (Comment on how distributions differ across SSPs. E.g., higher SSPs show higher CO2 values, SO2 might decrease in some SSPs over time due to emission controls, etc.).
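The same `if/elif` reduction by variable name is repeated in Q1.1, Q1.2, Q1.3, and Q2.3. A small helper that averages over whichever spatial or ensemble dimensions a variable actually has keeps those cells consistent. This is a sketch: the dimension names are taken from the reductions used in this notebook, and `reduce_to_series` is a hypothetical name.

```python
import numpy as np
import xarray as xr

# Dimensions averaged away; names taken from the reductions used in the cells above.
REDUCE_DIMS = ("latitude", "longitude", "y", "x", "member_id")


def reduce_to_series(da: xr.DataArray) -> xr.DataArray:
    """Unweighted mean over whichever spatial/ensemble dims `da` actually has."""
    dims = [d for d in REDUCE_DIMS if d in da.dims]
    return da.mean(dim=dims) if dims else da


# Synthetic dataset mimicking the three dimension layouts handled above.
demo = xr.Dataset({
    "CO2": (("ssp", "time"), np.ones((2, 3))),
    "SO2": (("ssp", "time", "latitude", "longitude"), np.ones((2, 3, 4, 5))),
    "tas": (("ssp", "time", "member_id", "y", "x"), np.ones((2, 3, 2, 4, 5))),
})
reduced_dims = {name: reduce_to_series(demo[name]).dims for name in demo.data_vars}
print(reduced_dims)
```

In the plotting loops this would replace the branching with something like `flat = reduce_to_series(data[var_name]).values.ravel()`. Note that any variable carrying these dimensions gets averaged, so the helper assumes `CO2`/`CH4` have no spatial dimensions.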

### Q1.3: For `pr` (precipitation), which is noted as skewed, what does its distribution look like? Would a transformation (e.g., log transform) make it more suitable for modeling?

In [None]:
var_name = 'pr'
plt.figure(figsize=(18, 6))

pr_flat_all_ssp_avg = data[var_name].mean(dim=['x', 'y', 'member_id']).data.compute().flatten()

plt.subplot(1, 3, 1)
sns.histplot(pr_flat_all_ssp_avg, kde=True, bins=50)
plt.title('Histogram of Precipitation (pr) - Original')
plt.xlabel('Precipitation (mm/day)')
plt.ylabel('Frequency')

epsilon = 1e-6  # Small constant
pr_log_transformed = np.log(pr_flat_all_ssp_avg + epsilon)

plt.subplot(1, 3, 2)
sns.histplot(pr_log_transformed, kde=True, bins=50)
plt.title('Histogram of log(pr + epsilon)')
plt.xlabel('log(Precipitation + epsilon)')
plt.ylabel('Frequency')

from scipy import stats
pr_positive = pr_flat_all_ssp_avg[pr_flat_all_ssp_avg > 0] 
if len(pr_positive) > 0:
    pr_boxcox, lmbda = stats.boxcox(pr_positive)
    print(f"Optimal lambda for Box-Cox on pr (positive values only): {lmbda}")
    plt.subplot(1, 3, 3)
    sns.histplot(pr_boxcox, kde=True, bins=50)
    plt.title(f'Histogram of Box-Cox(pr) (lambda={lmbda:.2f})')
    plt.xlabel('Box-Cox Transformed Precipitation')
    plt.ylabel('Frequency')
else:
    print("No positive precipitation values found for Box-Cox transformation.")

plt.tight_layout()
plt.show()

for ssp in ssp_scenarios:
    if ssp == 'ssp245' and "corrupted" in data_path:
        print(f"Skipping potentially corrupted ssp245 for pr transformation plots.")
        continue
    
    plt.figure(figsize=(12, 4))
    try:
        ssp_pr_data = data.sel(ssp=ssp)[var_name].mean(dim=['x', 'y', 'member_id']).data.compute().flatten()
        
        plt.subplot(1, 2, 1)
        sns.histplot(ssp_pr_data, kde=True, bins=30)
        plt.title(f'pr Original ({ssp})')
        plt.xlabel('Precipitation (mm/day)')

        ssp_pr_log_transformed = np.log(ssp_pr_data + epsilon)
        plt.subplot(1, 2, 2)
        sns.histplot(ssp_pr_log_transformed, kde=True, bins=30)
        plt.title(f'log(pr + epsilon) ({ssp})')
        plt.xlabel('log(Precipitation + epsilon)')
        
        plt.tight_layout()
        plt.show()
    except Exception as e:
        print(f"Could not plot pr transformations for {ssp}: {e}")
        plt.close() # Close the figure if an error occurred

#### Observations for Q1.3: Precipitation Skewness & Transformation
*   (Confirm skewness of `pr`. Does log transform or Box-Cox make it more symmetrical? Which seems better?).
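The cell above measures skewness on `pr` after averaging over `x`, `y`, and `member_id`. Averaging many grid cells tends to pull a distribution toward a normal one (a central-limit effect), so the skew that matters for a per-grid-cell emulator lives in the unaveraged values. The sketch below illustrates this with synthetic gamma-distributed numbers standing in for grid-point precipitation; the shape and scale are arbitrary assumptions, not fitted to the dataset.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
# Synthetic stand-in (NOT the dataset): 200 time steps x 500 grid cells of
# strictly positive, right-skewed "precipitation" drawn from a gamma distribution.
pr_grid = rng.gamma(shape=0.5, scale=4.0, size=(200, 500))

skew_gridpoint = stats.skew(pr_grid.ravel())          # pooled grid-cell values
skew_spatial_mean = stats.skew(pr_grid.mean(axis=1))  # one spatial mean per time step
skew_log = stats.skew(np.log(pr_grid.ravel() + 1e-6))
pr_bc, lmbda = stats.boxcox(pr_grid.ravel())
skew_boxcox = stats.skew(pr_bc)

print(f"grid-point skew:        {skew_gridpoint:.2f}")
print(f"spatial-mean skew:      {skew_spatial_mean:.2f}")
print(f"log(pr + eps) skew:     {skew_log:.2f}")
print(f"Box-Cox skew (lambda={lmbda:.2f}): {skew_boxcox:.2f}")
```

On the real data, a random subsample of grid-point values (for example from one SSP and one ensemble member) keeps this cheap. A log transform can overshoot into a left skew when many values are close to zero, so it is worth comparing its skew against the Box-Cox result rather than assuming log is sufficient.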

## 2. Exploring Temporal Patterns & Seasonality

### Q2.1: Beyond global means, how does the seasonal cycle (e.g., monthly averages over all years) look for `tas` and `pr` at different, representative grid points?

In [None]:
representative_points = {
    'Polar (Antarctica approx)': {'y': 5, 'x': 35}, 
    'Mid-Latitude (NH land approx)': {'y': 30, 'x': 15}, 
    'Tropical (Ocean approx)': {'y': 23, 'x': 50} 
}
output_vars_temporal = ['tas', 'pr']

if 'ssp' in data.coords and hasattr(data.ssp, 'values'):
    all_ssp_scenarios_q2_1 = data.ssp.values
    print(f"Found SSP scenarios for Q2.1: {all_ssp_scenarios_q2_1}")
else:
    print("Error: 'ssp' coordinate not found in data. Cannot plot for multiple SSPs for Q2.1.")
    all_ssp_scenarios_q2_1 = [] 

for ssp_value_q2_1 in all_ssp_scenarios_q2_1:
    ssp_to_plot_seasonal_q2_1 = str(ssp_value_q2_1) 
    if ssp_to_plot_seasonal_q2_1 == 'ssp245' and "corrupted" in data_path:
        print(f"\n--- Skipping Seasonal Cycle for SSP: {ssp_to_plot_seasonal_q2_1} (Corrupted Data) ---")
        continue

    for var_name_q2_1 in output_vars_temporal:
        print(f"\n--- Seasonal Cycle for {var_name_q2_1} (SSP: {ssp_to_plot_seasonal_q2_1}) ---")
        plt.figure(figsize=(15, 5 * len(representative_points)))
        plot_num_q2_1 = 1
        for point_name_q2_1, coords_q2_1 in representative_points.items():
            try:
                y_coord_val_q2_1 = data.y[coords_q2_1['y']].item()
                x_coord_val_q2_1 = data.x[coords_q2_1['x']].item()
                
                data_filtered_ssp_q2_1 = data[var_name_q2_1].sel(ssp=ssp_to_plot_seasonal_q2_1)
                point_data_selected_q2_1 = data_filtered_ssp_q2_1.sel(
                    y=y_coord_val_q2_1,
                    x=x_coord_val_q2_1,
                    method='nearest'
                )
                point_data_q2_1 = point_data_selected_q2_1.mean(dim='member_id')
                seasonal_cycle_q2_1 = point_data_q2_1.groupby('time.month').mean().compute()
                
                numeric_values_q2_1 = seasonal_cycle_q2_1.data.astype(float)
                months_q2_1 = seasonal_cycle_q2_1.month.data

                plt.subplot(len(representative_points), 1, plot_num_q2_1)
                plt.plot(months_q2_1, numeric_values_q2_1, label=f'{var_name_q2_1} at {point_name_q2_1}')
                plt.title(f'Mean Seasonal Cycle of {var_name_q2_1} at {point_name_q2_1} (y={y_coord_val_q2_1:.2f}, x={x_coord_val_q2_1:.2f}) - SSP: {ssp_to_plot_seasonal_q2_1}')
                plt.xlabel('Month')
                plt.ylabel(f'{var_name_q2_1} ({data[var_name_q2_1].attrs.get("units", "unknown units")})')
                plt.xticks(np.arange(1, 13), ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'])
                plt.grid(True)
                plot_num_q2_1 += 1
            except Exception as e_q2_1:
                print(f"Could not plot seasonal cycle for {var_name_q2_1} at {point_name_q2_1} (SSP: {ssp_to_plot_seasonal_q2_1}): {e_q2_1}")
        plt.tight_layout()
        plt.show()

#### Observations for Q2.1: Seasonal Cycles at Representative Points
*   (Describe the seasonal cycles for `tas` and `pr` at each point. Are they as expected for these locations? E.g., strong temperature seasonality at mid/high latitudes, different precipitation patterns for tropical vs. polar.).
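The representative points above are hand-picked grid indices labelled "approx". One way to check or choose them is to map the amplitude of the seasonal cycle (max minus min of the monthly climatology) at every grid cell. This is a sketch on synthetic data; `seasonal_amplitude` is hypothetical and, like the `groupby('time.month')` calls above, assumes a datetime-like `time` coordinate.

```python
import numpy as np
import pandas as pd
import xarray as xr


def seasonal_amplitude(da: xr.DataArray) -> xr.DataArray:
    """Range (max - min) of the mean monthly climatology at each grid cell."""
    clim = da.groupby("time.month").mean("time")
    return clim.max("month") - clim.min("month")


# Synthetic check: pure sinusoidal seasonal cycles with known amplitudes per latitude.
time = pd.date_range("2015-01-01", periods=36, freq="MS")
amp_true = xr.DataArray([0.5, 2.0, 8.0], dims="y", coords={"y": [-60.0, 0.0, 60.0]})
phase = xr.DataArray(2 * np.pi * (time.month.values - 1) / 12, dims="time", coords={"time": time})
da = amp_true * np.sin(phase)  # dims: (y, time)

amp = seasonal_amplitude(da)
print(amp.values)  # peak-to-trough range, i.e. twice each sinusoid's amplitude
```

On the real data this would be applied per SSP to the ensemble mean, e.g. `seasonal_amplitude(data['tas'].sel(ssp=...).mean('member_id'))`, and the result shown with `.plot.imshow()`.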

### Q2.2: How does the seasonal cycle of `tas` and `pr` vary across the different SSP scenarios for these selected grid points?

In [None]:
if 'ssp' in data.coords and hasattr(data.ssp, 'values'):
    ssp_scenarios_valid_q2_2 = [ssp_val for ssp_val in data.ssp.values if not (str(ssp_val) == 'ssp245' and 'corrupted' in data_path)]
    print(f"Valid SSP scenarios for Q2.2 comparison: {ssp_scenarios_valid_q2_2}")
else:
    print("Warning: 'ssp' coordinate not found or not iterable in data. Cannot proceed with Q2.2 SSP comparison.")
    ssp_scenarios_valid_q2_2 = []

for var_name_q2_2 in output_vars_temporal: # Defined in Q2.1
    print(f"\n--- Seasonal Cycle Comparison for {var_name_q2_2} across SSPs ---")
    for point_name_q2_2, coords_q2_2 in representative_points.items(): # Defined in Q2.1
        plt.figure(figsize=(12, 6))
        any_ssp_plotted_q2_2 = False
        try:
            y_coord_val_q2_2 = data.y[coords_q2_2['y']].item()
            x_coord_val_q2_2 = data.x[coords_q2_2['x']].item()
        except Exception as e_coord_q2_2:
            print(f"  ERROR extracting coordinates for {point_name_q2_2} in Q2.2: {e_coord_q2_2}")
            plt.close()
            continue

        for ssp_value_q2_2 in ssp_scenarios_valid_q2_2:
            current_ssp_str_q2_2 = str(ssp_value_q2_2)
            try:
                data_filtered_ssp_var_q2_2 = data[var_name_q2_2].sel(ssp=current_ssp_str_q2_2)
                point_data_ssp_selected_q2_2 = data_filtered_ssp_var_q2_2.sel(
                    y=y_coord_val_q2_2,
                    x=x_coord_val_q2_2,
                    method='nearest'
                )
                point_data_ssp_q2_2 = point_data_ssp_selected_q2_2.mean(dim='member_id')
                seasonal_cycle_ssp_q2_2 = point_data_ssp_q2_2.groupby('time.month').mean().compute()
                
                numeric_values_ssp_q2_2 = seasonal_cycle_ssp_q2_2.data.astype(float)
                months_ssp_q2_2 = seasonal_cycle_ssp_q2_2.month.data
                
                plt.plot(months_ssp_q2_2, numeric_values_ssp_q2_2, label=f'{current_ssp_str_q2_2}')
                any_ssp_plotted_q2_2 = True
            except Exception as e_ssp_q2_2:
                print(f"  Could not process/plot data for {var_name_q2_2} at {point_name_q2_2} (SSP: {current_ssp_str_q2_2}) in Q2.2: {e_ssp_q2_2}")
                continue
        
        if any_ssp_plotted_q2_2:
            plt.title(f'Mean Seasonal Cycle of {var_name_q2_2} at {point_name_q2_2} (y={y_coord_val_q2_2:.2f}, x={x_coord_val_q2_2:.2f})')
            plt.xlabel('Month')
            plt.ylabel(f'{var_name_q2_2} ({data[var_name_q2_2].attrs.get("units", "unknown units")})')
            plt.xticks(np.arange(1, 13), ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'])
            plt.legend(title='SSP Scenario')
            plt.grid(True)
            plt.show()
        else:
            print(f"No valid SSP data to plot for {var_name_q2_2} at {point_name_q2_2} in Q2.2.")
            plt.close()

#### Observations for Q2.2: Seasonal Cycles Across SSPs
*   (How do seasonal cycles change with SSP? E.g., warmer temperatures year-round, shifts in precipitation timing or intensity?).
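Overlaid seasonal cycles can hide small shifts. Two summaries that make scenario differences explicit are the month-by-month difference from a baseline SSP and the month in which each climatology peaks (via `DataArray.idxmax`). A sketch on synthetic data; `climatology_shift` and the choice of baseline are assumptions.

```python
import numpy as np
import pandas as pd
import xarray as xr


def climatology_shift(da: xr.DataArray, baseline: str, scenario: str):
    """Return (scenario minus baseline monthly climatology, month of peak for each SSP)."""
    clim = da.groupby("time.month").mean("time")  # dims: (ssp, month)
    diff = clim.sel(ssp=scenario) - clim.sel(ssp=baseline)
    peak_month = clim.idxmax("month")
    return diff, peak_month


# Synthetic example: the second scenario is 1.5 units higher and peaks one month later.
time = pd.date_range("2015-01-01", periods=120, freq="MS")
phase = 2 * np.pi * (time.month.values - 1) / 12
values = np.stack([np.sin(phase), 1.5 + np.sin(phase - 2 * np.pi / 12)])
da = xr.DataArray(values, dims=("ssp", "time"),
                  coords={"ssp": ["ssp126", "ssp585"], "time": time})

diff, peak_month = climatology_shift(da, baseline="ssp126", scenario="ssp585")
print(peak_month.values)   # [4 5]
print(float(diff.mean()))  # approximately 1.5
```

A constant offset in `diff` indicates a uniform shift; structure in `diff` or a change in `peak_month` points to a change in seasonal timing or shape.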

### Q2.3: Are there any apparent long-term trends in the input forcings (e.g., `CO2`) and how do they correspond to trends in `tas` and `pr` for each SSP?

In [None]:
print("--- Executing Q2.3: Long-term trends in input forcings and output variables ---")
if 'ssp' in data.coords and hasattr(data.ssp, 'values'):
    all_ssp_values_q2_3 = data.ssp.values
    print(f"Found SSP scenarios for trend analysis: {all_ssp_values_q2_3}")
else:
    print("Warning: 'ssp' coordinate not found in data. Cannot plot trends for Q2.3.")
    all_ssp_values_q2_3 = []

for ssp_val_trend in all_ssp_values_q2_3:
    ssp_id_str_trend = str(ssp_val_trend)
    if ssp_id_str_trend == 'ssp245' and "corrupted" in data_path:
        print(f"\n--- Plotting Long-term Trends for SSP: {ssp_id_str_trend} (Note: This SSP data may be corrupted) ---")
    else:
        print(f"\n--- Plotting Long-term Trends for SSP: {ssp_id_str_trend} ---")

    vars_for_trend_plot_q2_3 = input_forcings + output_variables 
    num_total_vars_q2_3 = len(vars_for_trend_plot_q2_3)
    if num_total_vars_q2_3 == 0:
        print(f"  No variables defined for trend plotting for SSP {ssp_id_str_trend}.")
        continue

    fig_q2_3, axes_q2_3 = plt.subplots(num_total_vars_q2_3, 1, figsize=(18, 5 * num_total_vars_q2_3), sharex=True)
    if num_total_vars_q2_3 == 1: axes_q2_3 = [axes_q2_3]

    for i, var_name_trend in enumerate(vars_for_trend_plot_q2_3):
        ax_curr = axes_q2_3[i]
        try:
            data_for_ssp_var_trend = data.sel(ssp=ssp_id_str_trend)[var_name_trend]
            avg_method_label = ""
            if var_name_trend in ['CO2', 'CH4']:
                averaged_data_trend = data_for_ssp_var_trend
            elif var_name_trend in ['SO2', 'BC']:
                averaged_data_trend = data_for_ssp_var_trend.mean(dim=['latitude', 'longitude'])
                avg_method_label = " (Spatially Averaged)"
            elif var_name_trend == 'rsdt':
                averaged_data_trend = data_for_ssp_var_trend.mean(dim=['y', 'x'])
                avg_method_label = " (Spatially Averaged)"
            elif var_name_trend in ['tas', 'pr']:
                averaged_data_trend = data_for_ssp_var_trend.mean(dim=['x', 'y', 'member_id'])
                avg_method_label = " (Spatially & Member Averaged)"
            else:
                print(f"    Skipping {var_name_trend} for SSP {ssp_id_str_trend}: Unknown structure.")
                ax_curr.text(0.5, 0.5, f'Skipped\n{var_name_trend}\nUnknown structure', ha='center', va='center', transform=ax_curr.transAxes)
                if i == num_total_vars_q2_3 - 1 : ax_curr.set_xlabel('Year')
                continue
            
            annual_means_trend = averaged_data_trend.groupby('time.year').mean().compute()
            annual_means_trend.plot(ax=ax_curr, label=f'{var_name_trend} trend')
            ax_curr.set_title(f'Annual Mean Trend of {var_name_trend}{avg_method_label} (SSP: {ssp_id_str_trend})')
            ax_curr.set_ylabel(f'{data_for_ssp_var_trend.attrs.get("units", "N/A")}')
            ax_curr.grid(True)
            if i == num_total_vars_q2_3 - 1 : ax_curr.set_xlabel('Year')
        except Exception as e_trend:
            print(f"    ERROR plotting trend for {var_name_trend} (SSP: {ssp_id_str_trend}): {str(e_trend)[:200]}")
            ax_curr.cla()
            ax_curr.text(0.5, 0.5, f'Error plotting {var_name_trend}\nSSP: {ssp_id_str_trend}\n{str(e_trend)[:100]}', ha='center', va='center', transform=ax_curr.transAxes, fontsize=9, wrap=True)
            if i == num_total_vars_q2_3 - 1 : ax_curr.set_xlabel('Year')

    fig_q2_3.suptitle(f'Long-term Climate Variable Trends for SSP: {ssp_id_str_trend}', fontsize=18, y=1.02)
    fig_q2_3.tight_layout(rect=[0, 0.03, 1, 0.98])
    plt.show()
print("--- Finished Q2.3 plotting ---")

#### Observations for Q2.3: Long-term Trends
*   **CO2, CH4:** (Describe observed trends across SSPs. E.g., generally increasing, rate of increase varies by SSP, ssp126 might show stabilization/decrease later).
*   **SO2, BC:** (Describe observed trends. These might be more complex, decreasing in some SSPs due to pollution controls).
*   **rsdt:** (Describe trend. Likely stable or minor variations unless specific SSPs model solar radiation management, which is not typical for these base SSPs).
*   **tas (Temperature Anomaly):** (Describe trends. Expected to increase in all SSPs, with magnitude depending on the SSP. Compare with forcing trends. E.g., higher CO2/CH4 correlates with higher tas).
*   **pr (Precipitation):** (Describe trends. More complex, might show increases in some regions/globally, decreases in others. Global average trend might be subtle. Relate to `tas` trends if possible - warmer atmosphere holds more moisture).
*   **Corrupted ssp245:** (If ssp245 was plotted, comment on whether its trends look anomalous or if the corruption is evident in these long-term aggregated plots).
*   **Correspondence:** (Overall comments on how input forcing trends appear to drive output variable trends. Are the relationships clear from these plots?)
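To compare trends numerically rather than by eye, each annual-mean series plotted above can be summarized by a least-squares slope, and a forcing can be correlated with an output. The sketch uses `numpy.polyfit` on an idealized synthetic series (not the dataset); `trend_per_decade` is a hypothetical helper that expects a DataArray with a `year` dimension, like the output of `groupby('time.year').mean()`.

```python
import numpy as np
import xarray as xr


def trend_per_decade(series: xr.DataArray) -> float:
    """Least-squares linear slope of an annual-mean series, in units per decade."""
    slope, _intercept = np.polyfit(series["year"].values, series.values, deg=1)
    return 10.0 * slope


years = np.arange(2015, 2101)
co2 = xr.DataArray(400.0 + 2.0 * (years - 2015), dims="year", coords={"year": years})
tas = 288.0 + 0.01 * (co2 - 400.0)  # idealized linear response, not real data

co2_trend = trend_per_decade(co2)
tas_trend = trend_per_decade(tas)
r = np.corrcoef(co2.values, tas.values)[0, 1]
print(co2_trend, tas_trend, r)
```

Two series that both trend over time will correlate strongly whether or not one drives the other, so a high `r` is weak evidence on its own. Comparing how the slopes scale across SSPs, or correlating detrended series, can help separate a shared trend from a response.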

## 3. Investigating Spatial Patterns & Regional Differences


### Q3.1: How do the spatial patterns of mean `tas` and `pr` differ across the SSPs for specific decades (e.g., early, mid, late century)?


In [None]:
print("--- Executing Q3.1: Spatial patterns of mean tas and pr for specific decades ---")
decades_q3_1 = {
    '2020-2029': (2020, 2029),
    '2050-2059': (2050, 2059),
    '2090-2099': (2090, 2099)
}
vars_for_spatial_q3_1 = output_variables
# Valid SSPs: all scenarios except ssp245 when the corrupted file is loaded.
ssp_scenarios_valid_q3_1 = [ssp_val for ssp_val in data.ssp.values
                            if not (str(ssp_val) == 'ssp245' and 'corrupted' in data_path)]
print(f"Valid SSP scenarios for Q3.1: {ssp_scenarios_valid_q3_1}")

for var_name_q3_1 in vars_for_spatial_q3_1:
    print(f"\n-- Plotting spatial patterns for {var_name_q3_1} --")
    all_decadal_means_q3_1 = []
    for ssp_str_q3_1 in ssp_scenarios_valid_q3_1:
        for decade_name_q3_1, (start_year_q3_1, end_year_q3_1) in decades_q3_1.items():
            try:
                decade_data_q3_1 = data[var_name_q3_1].sel(ssp=ssp_str_q3_1, time=slice(str(start_year_q3_1), str(end_year_q3_1)))
                if decade_data_q3_1.time.size > 0:
                    mean_for_decade_ssp_q3_1 = decade_data_q3_1.mean(dim=['time', 'member_id']).compute()
                    all_decadal_means_q3_1.append(mean_for_decade_ssp_q3_1)
            except Exception as e_fetch_clim:
                print(f"    Could not fetch data for {var_name_q3_1}, {ssp_str_q3_1}, {decade_name_q3_1} for clim: {e_fetch_clim}")
    
    if not all_decadal_means_q3_1:
        print(f"    No data to determine colorbar limits for {var_name_q3_1}. Skipping plots.")
        continue
        
    vmin_q3_1 = np.min([da.min().item() for da in all_decadal_means_q3_1 if da.size > 0])
    vmax_q3_1 = np.max([da.max().item() for da in all_decadal_means_q3_1 if da.size > 0])
    print(f"    Colorbar limits for {var_name_q3_1}: vmin={vmin_q3_1:.2f}, vmax={vmax_q3_1:.2f}")

    num_decades_q3_1 = len(decades_q3_1)
    num_ssps_q3_1 = len(ssp_scenarios_valid_q3_1)
    
    fig_q3_1, axes_q3_1 = plt.subplots(num_decades_q3_1, num_ssps_q3_1, figsize=(5 * num_ssps_q3_1, 4.5 * num_decades_q3_1), sharex=True, sharey=True, squeeze=False)
    fig_q3_1.suptitle(f'Spatial Patterns of Mean {var_name_q3_1} Across SSPs and Decades', fontsize=16, y=1.0)
    plot_obj_ref_q3_1 = None 

    for i_dec_q3_1, (decade_name_q3_1, (start_year_q3_1, end_year_q3_1)) in enumerate(decades_q3_1.items()):
        for j_ssp_q3_1, ssp_str_q3_1 in enumerate(ssp_scenarios_valid_q3_1):
            ax_q3_1 = axes_q3_1[i_dec_q3_1, j_ssp_q3_1]
            try:
                decade_data_q3_1 = data[var_name_q3_1].sel(ssp=ssp_str_q3_1, time=slice(str(start_year_q3_1), str(end_year_q3_1)))
                if decade_data_q3_1.time.size == 0:
                    ax_q3_1.text(0.5, 0.5, 'No data', ha='center', va='center', transform=ax_q3_1.transAxes)
                else:
                    mean_for_plot_q3_1 = decade_data_q3_1.mean(dim=['time', 'member_id']).compute()
                    im = mean_for_plot_q3_1.plot.imshow(ax=ax_q3_1, add_colorbar=False, vmin=vmin_q3_1, vmax=vmax_q3_1, cmap='viridis')
                    plot_obj_ref_q3_1 = im # Store for colorbar
                
                ax_q3_1.set_title(f'{ssp_str_q3_1}')
                if j_ssp_q3_1 == 0: ax_q3_1.set_ylabel(decade_name_q3_1)
                else: ax_q3_1.set_ylabel('')
                ax_q3_1.set_xlabel('')
            except Exception as e_plot_q3_1:
                print(f"    Could not plot {var_name_q3_1} for {ssp_str_q3_1}, {decade_name_q3_1}: {e_plot_q3_1}")
                ax_q3_1.text(0.5, 0.5, 'Error', ha='center', va='center', transform=ax_q3_1.transAxes)
            ax_q3_1.set_xticks([]); ax_q3_1.set_yticks([])

    if plot_obj_ref_q3_1:
        fig_q3_1.colorbar(plot_obj_ref_q3_1, ax=axes_q3_1, orientation='vertical', shrink=0.8, aspect=40, pad=0.02, label=f'{var_name_q3_1} ({data[var_name_q3_1].attrs.get("units", "")})')
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()
print("--- Finished Q3.1 plotting ---")

#### Observations for Q3.1: Spatial Patterns of Mean `tas` and `pr` by Decade
*   **`tas` (Temperature Anomaly):**
    *   (Describe general patterns: e.g., greater warming over land vs. ocean, Arctic amplification).
    *   (How do these patterns evolve across decades? Do differences between SSPs become more pronounced in later decades?).
    *   (Are there specific regions that show particularly strong warming or interesting changes?).
*   **`pr` (Precipitation):**
    *   (Describe general patterns: e.g., increases in some regions, decreases in others, changes in wet/dry zones).
    *   (How do these patterns evolve across decades and differ between SSPs?).
    *   (Note any correspondence with temperature changes, e.g., "wet-get-wetter, dry-get-drier" patterns).
*   **Comparison across SSPs:**
    *   (For each decade, how do the spatial patterns differ between low-emission (e.g., ssp126) and high-emission (e.g., ssp585) scenarios?).
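Absolute decadal-mean maps of `tas` are dominated by the equator-to-pole gradient, which can mask where the change happens. Differencing a late decade against an early one, with symmetric colour limits on a diverging colormap, isolates the change. A sketch on synthetic data; `decadal_change` is hypothetical and uses the same `time=slice(...)` selection as the Q3.1 cell.

```python
import numpy as np
import pandas as pd
import xarray as xr


def decadal_change(da: xr.DataArray, early=("2020", "2029"), late=("2090", "2099")) -> xr.DataArray:
    """Late-decade mean minus early-decade mean along 'time'."""
    early_mean = da.sel(time=slice(*early)).mean("time")
    late_mean = da.sel(time=slice(*late)).mean("time")
    return late_mean - early_mean


# Synthetic field: a fixed base value plus a latitude-dependent linear trend.
time = pd.date_range("2020-01-01", "2099-12-01", freq="MS")
years_elapsed = xr.DataArray(time.year.values - 2020, dims="time", coords={"time": time})
lat = [-60.0, 0.0, 60.0]
rate = xr.DataArray([0.01, 0.03, 0.06], dims="y", coords={"y": lat})
base = xr.DataArray([270.0, 300.0, 260.0], dims="y", coords={"y": lat})
field = base + rate * years_elapsed

change = decadal_change(field)
lim = float(abs(change).max())  # symmetric limits for a diverging colormap
print(change.values, lim)
```

On the real data, something like `change.plot.imshow(cmap='RdBu_r', vmin=-lim, vmax=lim)` would show the change map; using one `lim` across all SSPs keeps the panels comparable.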

### Q3.2: What are the spatial patterns of *variance* or *standard deviation* for `tas` and `pr` over time for each SSP? Where is climate variability highest/lowest?

In [None]:
print("--- Executing Q3.2: Spatial patterns of temporal standard deviation for tas and pr ---")
vars_for_std_q3_2 = output_variables
# Valid SSPs: all scenarios except ssp245 when the corrupted file is loaded.
ssp_scenarios_valid_q3_2 = [ssp_val for ssp_val in data.ssp.values
                            if not (str(ssp_val) == 'ssp245' and 'corrupted' in data_path)]
print(f"Valid SSP scenarios for Q3.2: {ssp_scenarios_valid_q3_2}")

for var_name_q3_2 in vars_for_std_q3_2:
    print(f"\n-- Plotting temporal standard deviation for {var_name_q3_2} --")
    all_ssp_std_devs_q3_2 = []
    for ssp_str_q3_2 in ssp_scenarios_valid_q3_2:
        try:
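            ssp_data_q3_2 = data[var_name_q3_2].sel(ssp=ssp_str_q3_2).mean(dim='member_id')
            # Std over all monthly time steps: it includes the seasonal cycle and the
            # long-term trend, so it is not a pure interannual-variability measure.
            # Averaging over member_id first also damps internal variability.
            temporal_std_dev_q3_2 = ssp_data_q3_2.std(dim='time').compute()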
            all_ssp_std_devs_q3_2.append(temporal_std_dev_q3_2)
        except Exception as e_fetch_clim_std:
            print(f"    Could not fetch/compute std data for {var_name_q3_2}, {ssp_str_q3_2} for clim: {e_fetch_clim_std}")
            
    if not all_ssp_std_devs_q3_2:
        print(f"    No std dev data for {var_name_q3_2}. Skipping plots.")
        continue
        
    vmin_std_q3_2 = np.min([da.min().item() for da in all_ssp_std_devs_q3_2 if da.size > 0])
    vmax_std_q3_2 = np.max([da.max().item() for da in all_ssp_std_devs_q3_2 if da.size > 0])
    vmin_std_q3_2 = max(0, vmin_std_q3_2)
    if vmax_std_q3_2 <= vmin_std_q3_2 : vmax_std_q3_2 = vmin_std_q3_2 + 1e-6 
    print(f"    Colorbar limits for std dev of {var_name_q3_2}: vmin={vmin_std_q3_2:.2e}, vmax={vmax_std_q3_2:.2e}")

    num_ssps_q3_2 = len(ssp_scenarios_valid_q3_2)
    fig_q3_2, axes_q3_2 = plt.subplots(1, num_ssps_q3_2, figsize=(5 * num_ssps_q3_2, 4.5), sharex=True, sharey=True, squeeze=False)
    fig_q3_2.suptitle(f'Spatial Patterns of Temporal Standard Deviation for {var_name_q3_2}', fontsize=16, y=1.0)
    plot_obj_ref_q3_2 = None

    for j_ssp_q3_2, ssp_str_q3_2 in enumerate(ssp_scenarios_valid_q3_2):
        ax_q3_2 = axes_q3_2[0, j_ssp_q3_2]
        try:
            # Reuse the std map computed above for this SSP (matched on its scalar 'ssp'
            # coordinate); otherwise recompute it.
            std_dev_to_plot_q3_2 = next((s for s in all_ssp_std_devs_q3_2
                                         if 'ssp' in s.coords and str(s['ssp'].item()) == str(ssp_str_q3_2)), None)
            if std_dev_to_plot_q3_2 is None:
                ssp_data_to_plot_q3_2 = data[var_name_q3_2].sel(ssp=ssp_str_q3_2).mean(dim='member_id')
                std_dev_to_plot_q3_2 = ssp_data_to_plot_q3_2.std(dim='time').compute()

            im = std_dev_to_plot_q3_2.plot.imshow(ax=ax_q3_2, add_colorbar=False, vmin=vmin_std_q3_2, vmax=vmax_std_q3_2, cmap='magma')
            plot_obj_ref_q3_2 = im
            ax_q3_2.set_title(f'{ssp_str_q3_2}')
            ax_q3_2.set_xlabel(''); ax_q3_2.set_ylabel('')
        except Exception as e_plot_q3_2:
            print(f"    Could not plot std dev for {var_name_q3_2}, {ssp_str_q3_2}: {e_plot_q3_2}")
            ax_q3_2.text(0.5, 0.5, 'Error', ha='center', va='center', transform=ax_q3_2.transAxes)
        ax_q3_2.set_xticks([]); ax_q3_2.set_yticks([])

    if plot_obj_ref_q3_2:
        fig_q3_2.colorbar(plot_obj_ref_q3_2, ax=axes_q3_2, orientation='vertical', shrink=0.8, aspect=20, pad=0.02, label=f'Std Dev of {var_name_q3_2} ({data[var_name_q3_2].attrs.get("units", "")})')
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()
print("--- Finished Q3.2 plotting ---")

#### Observations for Q3.2: Spatial Patterns of Temporal Standard Deviation
*   **`tas` (Temperature Anomaly Standard Deviation):**
    *   (Where is temporal temperature variability highest/lowest? Because the std is computed on raw monthly values, it mixes the seasonal cycle, the long-term trend, and interannual variability. E.g., higher over continents, specific ocean regions like ENSO areas, or polar regions?).
    *   (How does this pattern of variability change across SSPs? Does higher warming lead to higher/lower variability in certain areas?).
*   **`pr` (Precipitation Standard Deviation):**
    *   (Where is temporal precipitation variability highest/lowest? The same caveat about the seasonal cycle and trend applies. E.g., tropical convective regions, monsoon areas, storm tracks).
    *   (How do these patterns of variability differ across SSPs?).
*   **General Comments:**
    *   (Are there regions that consistently show high variability for both temperature and precipitation?).
    *   (How might these patterns of variability impact the predictability or modeling difficulty for these regions?).
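The std in the Q3.2 cell is taken over raw monthly values of the ensemble mean, so the seasonal cycle and the long-term trend contribute to it; it does not measure year-to-year (interannual) variability alone. A sketch that isolates interannual variability removes the monthly climatology, averages to annual means, and removes a linear trend per grid cell before taking the std. `interannual_std` is hypothetical; it applies `scipy.signal.detrend` through `xarray.apply_ufunc` and assumes monthly data with a datetime-like `time` coordinate.

```python
import numpy as np
import pandas as pd
import xarray as xr
from scipy import signal


def interannual_std(da: xr.DataArray) -> xr.DataArray:
    """Std of linearly detrended annual-mean anomalies (monthly input assumed)."""
    # 1. Remove the mean seasonal cycle (monthly climatology).
    climatology = da.groupby("time.month").mean("time")
    anomalies = da.groupby("time.month") - climatology
    # 2. Average the anomalies to one value per year.
    annual = anomalies.groupby("time.year").mean("time")
    # 3. Remove a linear trend along 'year' independently at every grid cell.
    detrended = xr.apply_ufunc(
        signal.detrend, annual,
        input_core_dims=[["year"]], output_core_dims=[["year"]],
    )
    return detrended.std("year")


# Synthetic series: strong seasonal cycle + warming trend + alternating +/-1 yearly signal.
time = pd.date_range("2015-01-01", periods=30 * 12, freq="MS")
months, years = time.month.values, time.year.values
values = (10.0 * np.sin(2 * np.pi * (months - 1) / 12)  # seasonal cycle
          + 0.1 * (years - 2015)                       # linear warming trend
          + np.where(years % 2 == 0, 1.0, -1.0))       # interannual signal
series = xr.DataArray(values, dims="time", coords={"time": time})

raw_std = float(series.std("time"))
iav_std = float(interannual_std(series))
print(f"raw monthly std: {raw_std:.2f}, interannual std: {iav_std:.2f}")
```

The raw std is dominated by the seasonal cycle, while the interannual std recovers the size of the alternating signal. For dask-backed data, `apply_ufunc` needs either `.compute()` first (e.g. on one SSP's ensemble mean) or `dask="parallelized"` with `year` in a single chunk.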

### Q3.3: How do the spatial patterns of the input forcings (especially `SO2` and `BC`, which can have more localized effects than `CO2` or `CH4`) look? Are they globally uniform or do they show regional concentrations?
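Whether a forcing is "globally uniform" or "regionally concentrated" can be summarized by one number: the share of the field's total carried by its largest 10% of grid cells. A uniform field gives about 0.1, and values near 1 indicate strong concentration. This is a sketch with an assumed metric, not a standard diagnostic of this dataset: `top_share` is hypothetical, ignores cell area, and assumes non-negative values with a positive total.

```python
import numpy as np


def top_share(field, top_fraction: float = 0.1) -> float:
    """Share of a field's total carried by its largest `top_fraction` of grid cells.

    Unweighted by cell area; assumes non-negative values with a positive total.
    """
    values = np.asarray(field, dtype=float).ravel()
    values = values[np.isfinite(values)]  # drop NaN/inf before sorting
    values = np.sort(values)[::-1]        # largest first
    k = max(1, int(round(top_fraction * values.size)))
    return float(values[:k].sum() / values.sum())


uniform = np.ones((10, 10))
concentrated = np.zeros((10, 10))
concentrated[2, 3] = 5.0
print(top_share(uniform), top_share(concentrated))  # 0.1 1.0
```

Applied to a time-mean map, e.g. `top_share(data['SO2'].sel(ssp=...).mean('time').values)`, this gives one comparable number per forcing and SSP.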

In [None]:
print("--- Executing Q3.3: Spatial patterns of input forcings ---")
spatial_forcings_q3_3 = ['SO2', 'BC', 'rsdt']
# Valid SSPs: all scenarios except ssp245 when the corrupted file is loaded.
ssp_scenarios_valid_q3_3 = [ssp_val for ssp_val in data.ssp.values
                            if not (str(ssp_val) == 'ssp245' and 'corrupted' in data_path)]
print(f"Valid SSP scenarios for Q3.3: {ssp_scenarios_valid_q3_3}")

for forcing_name_q3_3 in spatial_forcings_q3_3:
    print(f"\n-- Plotting spatial patterns for input forcing: {forcing_name_q3_3} --")
    all_ssp_means_q3_3 = []
    for ssp_str_q3_3 in ssp_scenarios_valid_q3_3:
        try:
            ssp_forcing_data_q3_3 = data[forcing_name_q3_3].sel(ssp=ssp_str_q3_3).mean(dim='time').compute()
            all_ssp_means_q3_3.append(ssp_forcing_data_q3_3)
        except Exception as e_fetch_clim_forcing:
            print(f"    Could not fetch/compute mean data for {forcing_name_q3_3}, {ssp_str_q3_3} for clim: {e_fetch_clim_forcing}")

    if not all_ssp_means_q3_3:
        print(f"    No mean data for {forcing_name_q3_3}. Skipping plots.")
        continue
        
    vmin_forcing_q3_3 = np.min([da.min().item() for da in all_ssp_means_q3_3 if da.size > 0])
    vmax_forcing_q3_3 = np.max([da.max().item() for da in all_ssp_means_q3_3 if da.size > 0])
    if vmax_forcing_q3_3 <= vmin_forcing_q3_3 : vmax_forcing_q3_3 = vmin_forcing_q3_3 + 1e-9 
    print(f"    Colorbar limits for mean {forcing_name_q3_3}: vmin={vmin_forcing_q3_3:.2e}, vmax={vmax_forcing_q3_3:.2e}")

    num_ssps_q3_3 = len(ssp_scenarios_valid_q3_3)
    fig_q3_3, axes_q3_3 = plt.subplots(1, num_ssps_q3_3, figsize=(5 * num_ssps_q3_3, 4.5), sharex=True, sharey=True, squeeze=False)
    fig_q3_3.suptitle(f'Mean Spatial Patterns of Input Forcing: {forcing_name_q3_3}', fontsize=16, y=1.0)
    plot_obj_ref_q3_3 = None

    for j_ssp_q3_3, ssp_str_q3_3 in enumerate(ssp_scenarios_valid_q3_3):
        ax_q3_3 = axes_q3_3[0, j_ssp_q3_3]
        try:
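            # Reuse the time mean computed above; each entry keeps its scalar `ssp` coordinate after .sel()
            mean_forcing_to_plot_q3_3 = next((s for s in all_ssp_means_q3_3 if 'ssp' in s.coords and str(s['ssp'].item()) == str(ssp_str_q3_3)), None)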
            if mean_forcing_to_plot_q3_3 is None: # Fallback
                mean_forcing_to_plot_q3_3 = data[forcing_name_q3_3].sel(ssp=ssp_str_q3_3).mean(dim='time').compute()
            
            im = mean_forcing_to_plot_q3_3.plot.imshow(ax=ax_q3_3, add_colorbar=False, vmin=vmin_forcing_q3_3, vmax=vmax_forcing_q3_3, cmap='inferno')
            plot_obj_ref_q3_3 = im
            ax_q3_3.set_title(f'{ssp_str_q3_3}')
            ax_q3_3.set_xlabel(''); ax_q3_3.set_ylabel('')
        except Exception as e_plot_q3_3:
            print(f"    Could not plot mean for {forcing_name_q3_3}, {ssp_str_q3_3}: {e_plot_q3_3}")
            ax_q3_3.text(0.5, 0.5, f'Error plotting mean\n{forcing_name_q3_3}', ha='center', va='center', transform=ax_q3_3.transAxes)
        ax_q3_3.set_xticks([]); ax_q3_3.set_yticks([])

    if plot_obj_ref_q3_3:
        fig_q3_3.colorbar(plot_obj_ref_q3_3, ax=axes_q3_3, orientation='vertical', shrink=0.8, aspect=20, pad=0.02, label=f'Mean {forcing_name_q3_3} ({data[forcing_name_q3_3].attrs.get("units", "")})')
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()
print("--- Finished Q3.3 plotting ---")

#### Observations for Q3.3: Spatial Patterns of Input Forcings
*   **`SO2` (Sulfur Dioxide):**
    *   (Describe spatial patterns. E.g., concentrations over industrial regions in the Northern Hemisphere, shipping lanes?).
    *   (How do these patterns differ across SSPs, reflecting different emission control policies or economic development pathways?).
*   **`BC` (Black Carbon):**
    *   (Describe spatial patterns. E.g., concentrations over regions with biomass burning, industrial activity, or heavy transportation).
    *   (How do these patterns vary by SSP?).
*   **`rsdt` (Incoming Shortwave Radiation):**
    *   (Describe spatial patterns. `rsdt` is the CMIP name for top-of-atmosphere incident shortwave radiation, so it is not affected by clouds and should be essentially zonal, following the latitudinal distribution of solar insolation. Note any longitudinal structure that does appear.)
    *   (Do these patterns change significantly across SSPs? Unlikely for `rsdt` unless specific geoengineering is implied, which is not standard for these SSPs).
*   **Global Uniformity vs. Regional Concentrations:**
    *   (Confirm that `SO2` and `BC` show strong regional concentrations, while `rsdt` is more zonally uniform. `CO2` and `CH4` are not plotted spatially here as they are globally mixed in the model inputs).
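You can make the "globally uniform vs. regional concentrations" judgement quantitative with the coefficient of variation along longitude for each latitude row. It is near 0 for a zonally uniform field such as TOA insolation and large for a field dominated by a few source regions. This sketch is not part of the original analysis, and `zonal_cv` is a hypothetical helper. For the real data, it assumes the longitude-like dimension is `longitude` for `SO2`/`BC` and `x` for `rsdt`, matching the dimension names used in the cells above.

```python
import numpy as np
import xarray as xr


def zonal_cv(field, lon_dim):
    """Coefficient of variation (std / mean) along `lon_dim` for each latitude row.

    Rows whose zonal mean is zero are returned as NaN instead of dividing by zero.
    """
    zonal_mean = field.mean(dim=lon_dim)
    zonal_std = field.std(dim=lon_dim)  # population std (ddof=0), xarray's default
    return zonal_std / zonal_mean.where(np.abs(zonal_mean) > 0)


lat = np.linspace(-80, 80, 9)
lon = np.arange(0, 360, 30.0)

# Zonally uniform field with a latitudinal gradient, loosely like TOA insolation
uniform = xr.DataArray(
    np.repeat(np.cos(np.deg2rad(lat))[:, None], lon.size, axis=1),
    coords={'latitude': lat, 'longitude': lon},
    dims=('latitude', 'longitude'),
)
# Field that is zero except for a single equatorial "source" cell
hotspot = xr.zeros_like(uniform)
hotspot[4, 2] = 10.0

cv_uniform = zonal_cv(uniform, 'longitude')
cv_hotspot = zonal_cv(hotspot, 'longitude')
print(float(cv_uniform.max()), float(cv_hotspot.isel(latitude=4)))

# On the notebook's data (ssp_name is a placeholder), e.g.:
# zonal_cv(data['SO2'].sel(ssp=ssp_name).mean('time'), 'longitude')
```

A single non-zero cell among `n` longitudes gives a CV of `sqrt(n - 1)`, which here is `sqrt(11)`. Emission-driven `SO2`/`BC` rows would be expected to fall between that extreme and the zero of the uniform field.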

## 4. Analyzing Relationships Between Variables (Forcings & Outputs)


### Q4.1: For selected grid points or regions, can you plot time series of individual forcings (e.g., `CO2`, `SO2`) alongside `tas` and `pr` for different SSPs? Are there visible correlations or lagged effects?


In [None]:
print("--- Executing Q4.1: Time series of forcings and outputs at selected points ---")
if 'representative_points' not in globals():
    print("Warning: `representative_points` not defined (run Q2.1 first). Falling back to default grid indices.")
    representative_points = {
        'Polar (Antarctica approx)': {'y': 5, 'x': 35},
        'Mid-Latitude (NH land approx)': {'y': 30, 'x': 15},
        'Tropical (Ocean approx)': {'y': 23, 'x': 50}
    }

if 'ssp_scenarios_valid_q3_3' in globals(): # Prefer more specific var if available
    ssp_scenarios_valid_q4_1 = ssp_scenarios_valid_q3_3
elif 'ssp_scenarios_valid' in globals():
    ssp_scenarios_valid_q4_1 = ssp_scenarios_valid
else:
    if 'ssp' in data.coords and hasattr(data.ssp, 'values'):
        ssp_scenarios_valid_q4_1 = [ssp_val for ssp_val in data.ssp.values if not (str(ssp_val) == 'ssp245' and 'corrupted' in data_path)]
        print(f"Redefined ssp_scenarios_valid for Q4.1: {ssp_scenarios_valid_q4_1}")
    else:
        print("Error: Cannot define ssp_scenarios_valid for Q4.1.")
        ssp_scenarios_valid_q4_1 = [] 

forcings_to_plot_q4_1 = ['CO2', 'SO2', 'BC', 'rsdt'] 
outputs_to_plot_q4_1 = ['tas', 'pr'] 
variables_for_timeseries_q4_1 = forcings_to_plot_q4_1 + outputs_to_plot_q4_1

for point_name_q4_1, coords_q4_1 in representative_points.items():
    print(f"\n-- Plotting time series at {point_name_q4_1} --")
    num_vars_to_plot_q4_1 = len(variables_for_timeseries_q4_1)
    fig_q4_1, axes_q4_1 = plt.subplots(num_vars_to_plot_q4_1, 1, figsize=(18, 4 * num_vars_to_plot_q4_1), sharex=True)
    if num_vars_to_plot_q4_1 == 1: axes_q4_1 = [axes_q4_1]
    fig_q4_1.suptitle(f'Time Series at {point_name_q4_1} (y={data.y[coords_q4_1["y"]].item():.2f}, x={data.x[coords_q4_1["x"]].item():.2f})', fontsize=16, y=1.0)

    for i_var_q4_1, var_name_q4_1 in enumerate(variables_for_timeseries_q4_1):
        ax_q4_1 = axes_q4_1[i_var_q4_1]
        for ssp_str_q4_1 in ssp_scenarios_valid_q4_1:
            try:
                ssp_var_data_q4_1 = data[var_name_q4_1].sel(ssp=ssp_str_q4_1)
                if var_name_q4_1 in ['CO2', 'CH4']:
                    point_timeseries_q4_1 = ssp_var_data_q4_1.compute()
                elif var_name_q4_1 in ['SO2', 'BC']:
                    if 'latitude' in ssp_var_data_q4_1.dims and 'longitude' in ssp_var_data_q4_1.dims:
                         print(f"Warning: For {var_name_q4_1} at {point_name_q4_1} (SSP {ssp_str_q4_1}), using spatial mean due to lat/lon coords.")
                         point_timeseries_q4_1 = ssp_var_data_q4_1.mean(dim=['latitude', 'longitude']).compute()
                    elif 'y' in ssp_var_data_q4_1.dims and 'x' in ssp_var_data_q4_1.dims:
                         point_timeseries_q4_1 = ssp_var_data_q4_1.sel(y=data.y[coords_q4_1['y']], x=data.x[coords_q4_1['x']], method='nearest').compute()
                    else: 
                         point_timeseries_q4_1 = ssp_var_data_q4_1.mean(dim=[d for d in ssp_var_data_q4_1.dims if d not in ['ssp', 'time']]).compute()
                elif var_name_q4_1 == 'rsdt': 
                    point_timeseries_q4_1 = ssp_var_data_q4_1.sel(y=data.y[coords_q4_1['y']], x=data.x[coords_q4_1['x']], method='nearest').compute()
                elif var_name_q4_1 in outputs_to_plot_q4_1: 
                    point_timeseries_q4_1 = ssp_var_data_q4_1.sel(y=data.y[coords_q4_1['y']], x=data.x[coords_q4_1['x']], method='nearest').mean(dim='member_id').compute()
                else:
                    print(f"    Skipping {var_name_q4_1} for {point_name_q4_1}, SSP {ssp_str_q4_1}: unknown structure.")
                    continue
                annual_mean_ts_q4_1 = point_timeseries_q4_1.groupby('time.year').mean()
                annual_mean_ts_q4_1.plot(ax=ax_q4_1, label=f'{ssp_str_q4_1}')
            except Exception as e_plot_ts_q4_1:
                print(f"    Could not plot time series for {var_name_q4_1}, {ssp_str_q4_1} at {point_name_q4_1}: {e_plot_ts_q4_1}")
        
        ax_q4_1.set_title(f'{var_name_q4_1}')
        ax_q4_1.set_ylabel(f'{data[var_name_q4_1].attrs.get("units", "")}')
        ax_q4_1.grid(True)
        if i_var_q4_1 == 0: ax_q4_1.legend(title='SSP', loc='upper left', bbox_to_anchor=(1.02, 1))
        if i_var_q4_1 == num_vars_to_plot_q4_1 -1 : ax_q4_1.set_xlabel('Year')
        else: ax_q4_1.set_xlabel('')

    plt.tight_layout(rect=[0, 0, 0.88, 0.97]) 
    plt.show()
print("--- Finished Q4.1 plotting ---")

#### Observations for Q4.1: Time Series of Forcings and Outputs at Selected Points
*   **For each Representative Point (`Polar`, `Mid-Latitude`, `Tropical`):**
    *   **`CO2` vs. `tas`/`pr`:**
        *   (Describe the relationship. Is there a clear correlation between rising CO2 and `tas` trends at this point across SSPs? How does `pr` respond?).
    *   **`SO2` (local/regional) vs. `tas`/`pr`:**
        *   (Describe. SO2 can have a cooling effect. Is this visible, especially in SSPs with high SO2 early on, followed by reductions? How does this interact with the CO2 warming signal for `tas` and `pr` at this point?).
    *   **`BC` (local/regional) vs. `tas`/`pr`:**
        *   (Describe. BC has a warming effect. Can you see its influence, perhaps more localized or differing from the CO2 signal? How does it affect `pr` patterns at this point?).
    *   **`rsdt` (local) vs. `tas`/`pr`:**
        *   (Describe. `rsdt` is a primary driver of seasonal cycles. Here we are looking at annual means. Does the interannual variability or long-term trend of local `rsdt` (if any) correlate with `tas` or `pr` changes beyond the dominant greenhouse gas signal?).
*   **General Correlations & Lagged Effects:**
    *   (Are there any apparent lagged responses, e.g., does `tas` seem to follow CO2 changes with a delay? This might be hard to see without more formal cross-correlation analysis but visual impressions can be noted).
    *   (How do the relationships differ across the selected points? E.g., is the CO2-tas relationship stronger/clearer at some locations than others?).
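A simple lagged Pearson correlation can back up the visual check for lagged effects. The sketch below uses `lagged_corr`, an illustrative helper that is not from the original notebook. It correlates `response[t + lag]` with `forcing[t]` for non-negative lags, so a peak at lag k suggests the response follows the forcing by k steps. Because CO2 and `tas` both trend strongly under most SSPs, raw annual series will correlate highly at every lag. Applying the helper to detrended or differenced series is likely to be more informative.

```python
import numpy as np


def lagged_corr(forcing, response, max_lag):
    """Pearson r between forcing[t] and response[t + lag] for lag = 0..max_lag."""
    forcing = np.asarray(forcing, dtype=float)
    response = np.asarray(response, dtype=float)
    n = len(forcing)
    result = {}
    for lag in range(max_lag + 1):
        # Pair each forcing value with the response `lag` steps later
        result[lag] = float(np.corrcoef(forcing[: n - lag], response[lag:])[0, 1])
    return result


# Synthetic example: white-noise forcing, response delayed by 3 steps plus small noise
rng = np.random.default_rng(1)
forcing = rng.normal(size=200)
response = np.roll(forcing, 3) + 0.1 * rng.normal(size=200)

corrs = lagged_corr(forcing, response, max_lag=6)
best_lag = max(corrs, key=corrs.get)
print(best_lag, round(corrs[best_lag], 3))
```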

### Q4.2: Can you compute correlation maps between each forcing and `tas`/`pr`? For example, a map showing the correlation of local `tas` with global `CO2` over time.
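Correlating local `tas` with a global `CO2` series means computing one Pearson coefficient per grid cell, with the 1-D series broadcast against the `(year, y, x)` field. This is what `xr.corr(..., dim='common_time')` does in the cell below. The NumPy sketch here spells out that computation on synthetic data and checks it against `xr.corr`. `corr_map` and the data are hypothetical.

```python
import numpy as np
import xarray as xr


def corr_map(series, field):
    """Pearson r between a 1-D series (year,) and each cell of a field (year, y, x)."""
    s = series - series.mean()
    f = field - field.mean(axis=0)
    numerator = np.einsum('t,tij->ij', s, f)  # sum over years of the anomaly products
    denominator = np.sqrt((s ** 2).sum() * (f ** 2).sum(axis=0))
    return numerator / denominator


rng = np.random.default_rng(2)
n_years = 86
co2_like = np.linspace(400.0, 600.0, n_years) + rng.normal(0.0, 5.0, n_years)
field = rng.normal(size=(n_years, 3, 4))
field[:, 0, 0] += 0.05 * co2_like  # one cell tied to the global series

r_numpy = corr_map(co2_like, field)
r_xarray = xr.corr(
    xr.DataArray(co2_like, dims='year'),
    xr.DataArray(field, dims=('year', 'y', 'x')),
    dim='year',
).transpose('y', 'x')
print(np.round(r_numpy, 2))
```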


In [None]:
print("--- Executing Q4.2: Correlation maps between forcings and outputs ---")
if 'ssp_scenarios_valid_q4_1' in globals(): # Prefer more specific var if available
    ssp_scenarios_valid_q4_2 = ssp_scenarios_valid_q4_1
elif 'ssp_scenarios_valid' in globals():
    ssp_scenarios_valid_q4_2 = ssp_scenarios_valid
else:
    if 'ssp' in data.coords and hasattr(data.ssp, 'values'):
        ssp_scenarios_valid_q4_2 = [ssp_val for ssp_val in data.ssp.values if not (str(ssp_val) == 'ssp245' and 'corrupted' in data_path)]
        print(f"Redefined ssp_scenarios_valid for Q4.2: {ssp_scenarios_valid_q4_2}")
    else:
        print("Error: Cannot define ssp_scenarios_valid for Q4.2.")
        ssp_scenarios_valid_q4_2 = [] 

for output_var_q4_2 in output_variables:
    for forcing_var_q4_2 in input_forcings:
        print(f"\n-- Plotting Correlation Maps: {forcing_var_q4_2} vs {output_var_q4_2} --")
        num_ssps_q4_2 = len(ssp_scenarios_valid_q4_2)
        ncols_q4_2 = 2 
        nrows_q4_2 = (num_ssps_q4_2 + ncols_q4_2 - 1) // ncols_q4_2
        fig_q4_2, axes_q4_2 = plt.subplots(nrows_q4_2, ncols_q4_2, figsize=(6 * ncols_q4_2, 5.5 * nrows_q4_2), sharex=True, sharey=True, squeeze=False)
        fig_q4_2.suptitle(f'Correlation: Annual Mean {forcing_var_q4_2} vs Annual Mean {output_var_q4_2}', fontsize=16, y=1.0)
        axes_flat_q4_2 = axes_q4_2.flatten()
        plot_obj_ref_q4_2 = None 

        for i_ssp_q4_2, ssp_str_q4_2 in enumerate(ssp_scenarios_valid_q4_2):
            ax_q4_2 = axes_flat_q4_2[i_ssp_q4_2]
            try:
                output_data_ssp_q4_2 = data[output_var_q4_2].sel(ssp=ssp_str_q4_2).mean(dim='member_id')
                output_annual_mean_q4_2 = output_data_ssp_q4_2.groupby('time.year').mean(dim='time')
                forcing_data_ssp_q4_2 = data[forcing_var_q4_2].sel(ssp=ssp_str_q4_2)
                forcing_annual_mean_q4_2 = forcing_data_ssp_q4_2.groupby('time.year').mean(dim='time')

                if forcing_var_q4_2 in ['CO2', 'CH4']:
                    correlation_map_q4_2 = xr.corr(forcing_annual_mean_q4_2.rename({'year':'common_time'}), output_annual_mean_q4_2.rename({'year':'common_time'}), dim='common_time').compute()
                else: 
                    forcing_dims = set(forcing_annual_mean_q4_2.dims)
                    output_dims = set(output_annual_mean_q4_2.dims)
                    can_correlate_spatially = False
                    if ('y' in forcing_dims and 'x' in forcing_dims and 'y' in output_dims and 'x' in output_dims) or \
                       ('latitude' in forcing_dims and 'longitude' in forcing_dims and 'latitude' in output_dims and 'longitude' in output_dims):
                        can_correlate_spatially = True
                    
                    if can_correlate_spatially:
                         correlation_map_q4_2 = xr.corr(forcing_annual_mean_q4_2.rename({'year':'common_time'}), output_annual_mean_q4_2.rename({'year':'common_time'}), dim='common_time').compute()
                    else:
                        print(f"    Warning: For {forcing_var_q4_2} vs {output_var_q4_2} (SSP: {ssp_str_q4_2}), spatial dims mismatch. Correlating global mean of forcing.")
                        spatial_dims_forcing = [d for d in forcing_annual_mean_q4_2.dims if d != 'year']
                        global_forcing_annual_mean_q4_2 = forcing_annual_mean_q4_2.mean(dim=spatial_dims_forcing)
                        correlation_map_q4_2 = xr.corr(global_forcing_annual_mean_q4_2.rename({'year':'common_time'}), output_annual_mean_q4_2.rename({'year':'common_time'}), dim='common_time').compute()

                im = correlation_map_q4_2.plot.imshow(ax=ax_q4_2, cmap='RdBu_r', vmin=-1, vmax=1, add_colorbar=False)
                plot_obj_ref_q4_2 = im 
                ax_q4_2.set_title(ssp_str_q4_2)
                ax_q4_2.set_xlabel(''); ax_q4_2.set_ylabel('')
            except Exception as e_corr_map_q4_2:
                print(f"    Could not compute/plot correlation map for {forcing_var_q4_2} vs {output_var_q4_2}, SSP {ssp_str_q4_2}: {e_corr_map_q4_2}")
                ax_q4_2.text(0.5, 0.5, 'Error', ha='center', va='center', transform=ax_q4_2.transAxes)
            ax_q4_2.set_xticks([]); ax_q4_2.set_yticks([])
        
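        # Remove unused panels when the SSP count does not fill the grid
        for i_extra_ax in range(num_ssps_q4_2, nrows_q4_2 * ncols_q4_2):
            fig_q4_2.delaxes(axes_flat_q4_2[i_extra_ax])

        if plot_obj_ref_q4_2 is not None:
            # Attach the shared colorbar only to panels that were not deleted
            fig_q4_2.colorbar(plot_obj_ref_q4_2, ax=axes_flat_q4_2[:num_ssps_q4_2].tolist(), orientation='vertical', shrink=0.7, aspect=30, pad=0.03, label='Pearson Correlation')

        # Lay out and show each forcing/output figure inside the inner loop,
        # so tight_layout is applied to every figure rather than only the last one
        fig_q4_2.tight_layout(rect=[0, 0, 1, 0.95])
        plt.show()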
print("--- Finished Q4.2 plotting ---")

#### Observations for Q4.2: Correlation Maps
*   **`tas` vs. Forcings:**
    *   **`CO2` (global) vs. `tas` (local):** (Describe spatial patterns of correlation. Expect positive correlations globally, perhaps stronger in some regions like land areas or high latitudes?).
    *   **`SO2` (local/global mean) vs. `tas` (local):** (Describe. If local-local, expect negative correlations where SO2 concentrations are high, indicating cooling effect? How localized is this? If global mean SO2, what does the pattern look like?).
    *   **`BC` (local/global mean) vs. `tas` (local):** (Describe. If local-local, expect positive correlations where BC is high, indicating local warming? Compare with SO2 patterns. If global mean BC, what does the pattern look like?).
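    *   **`rsdt` (local) vs. `tas` (local):** (Describe. Solar radiation sets the spatial and seasonal structure of temperature. However, annual-mean TOA `rsdt` at a fixed point changes very little from year to year, so this temporal correlation may be weak or reflect only whatever small solar variability the input contains. Are there regional variations in its strength?)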
*   **`pr` vs. Forcings:**
    *   **`CO2` (global) vs. `pr` (local):** (Describe. More complex patterns expected. Some regions might show positive correlation (more rain with warming), others negative. Look for large-scale patterns).
    *   **`SO2` (local/global mean) vs. `pr` (local):** (Describe. Aerosols like SO2 can affect cloud formation and precipitation. Are there discernible patterns of correlation, positive or negative, in regions with high SO2?).
    *   **`BC` (local/global mean) vs. `pr` (local):** (Describe. BC can also influence clouds and precipitation. What do the correlation maps suggest?).
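    *   **`rsdt` (local) vs. `pr` (local):** (Describe. Regions with high solar radiation might be drier or have distinct precipitation regimes (e.g., convective rainfall). Note, though, that these maps measure year-to-year co-variation at each grid cell, not that spatial contrast. What do the correlations show?)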
*   **Differences across SSPs:**
    *   (Do the correlation patterns change significantly between SSPs? This might indicate that the relationships themselves are state-dependent or that different dominant processes are at play under different forcing scenarios?).
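When a forcing and an output both trend over the scenario period, their Pearson correlation on raw annual means can approach 1 even without a strong year-to-year link. This can make correlation maps look similar across SSPs. One minimal check compares the raw correlation with the correlation of first differences (`np.diff`). It assumes the annual-mean series are available as NumPy arrays. The synthetic series below share only a linear trend and have independent noise. `trend_and_diff_corr` is an illustrative helper, not from the original notebook.

```python
import numpy as np


def trend_and_diff_corr(x, y):
    """Return (Pearson r of the raw series, Pearson r of their year-to-year differences)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    raw_r = np.corrcoef(x, y)[0, 1]
    diff_r = np.corrcoef(np.diff(x), np.diff(y))[0, 1]
    return float(raw_r), float(diff_r)


# Two synthetic annual series that share only a linear trend; their noise is independent
rng = np.random.default_rng(3)
t = np.arange(86)
forcing_like = 2.0 * t + rng.normal(0.0, 1.0, t.size)
output_like = 0.03 * t + rng.normal(0.0, 0.2, t.size)

raw_r, diff_r = trend_and_diff_corr(forcing_like, output_like)
print(f"raw r = {raw_r:.2f}, differenced r = {diff_r:.2f}")
```

First differencing removes a linear trend but amplifies high-frequency noise. Fitting and subtracting a linear trend per series is a gentler alternative if the differenced maps look too noisy.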