# In [None]:
import xarray as xr
import numpy as np

# --- 1. Load the NetCDF files ---
# Input paths are kept in named variables so they are easy to change in one place.
model_path = r"C:\Users\benra\lithk_exp05\lithk_GIS_AWI_ISSM1_exp05.nc"
basin_mask_path = r"C:\Users\benra\ISMIP6_Extensions_05000m.nc"

try:
    model_data = xr.open_dataset(model_path)
    basin_mask_data = xr.open_dataset(basin_mask_path)
except FileNotFoundError as e:
    # SystemExit with a message prints it to stderr and exits with status 1, so
    # shells/pipelines see the failure. Inside IPython/Jupyter it is reported as
    # an exception instead of shutting the kernel down (as the builtin exit() would).
    raise SystemExit(f"Error loading file: {e}. Please ensure the paths are correct.") from e

# --- 2. Identify NaNs from the model data and create a 2D mask ---
model_variable = model_data['lithk']

# Boolean mask that is True wherever the model field is NaN.
nan_mask_raw = np.isnan(model_variable)

# Collapse the mask to 2D: a grid cell is flagged if it is NaN in *any* slice
# along the non-spatial dimensions (e.g. time).
spatial_dims_model = [dim for dim in nan_mask_raw.dims if dim in ['x', 'y', 'lon', 'lat']]
non_spatial_dims_model = [dim for dim in nan_mask_raw.dims if dim not in spatial_dims_model]

if non_spatial_dims_model:
    # Reduce over every non-spatial dimension at once. Reducing only the first
    # one would leave a mask with more than two dimensions whenever the variable
    # carries several extra dimensions.
    nan_mask_2d = nan_mask_raw.any(dim=non_spatial_dims_model)
else:
    # No extra dimensions: the mask is already spatial-only.
    nan_mask_2d = nan_mask_raw

# --- 3. Prepare basin mask variable and align dimensions ---
basin_mask_variable = basin_mask_data['IDs']

# The NaN mask must use the same dimension names as the basin mask before it
# can be reindexed onto the basin-mask grid.
target_dims = basin_mask_variable.dims

if len(nan_mask_2d.dims) != len(target_dims):
    # Message goes to stderr with a non-zero exit status.
    raise SystemExit("Error: Dimension count mismatch after reducing NaN mask to 2D. Cannot align.")

if set(nan_mask_2d.dims) == set(target_dims):
    # Same names, possibly in a different order: match by name. A positional
    # rename here would swap axes, e.g. ('y', 'x') -> ('x', 'y') would relabel
    # the y coordinate values as x.
    nan_mask_2d = nan_mask_2d.transpose(*target_dims)
else:
    # Different names (e.g. lon/lat vs x/y): assume the axes correspond
    # positionally. NOTE(review): confirm both files store their spatial axes
    # in the same order.
    dim_rename_map = dict(zip(nan_mask_2d.dims, target_dims))
    nan_mask_2d = nan_mask_2d.rename(dim_rename_map)

# Snap the mask onto the basin-mask grid (nearest coordinate within 1e-6).
# Basin cells with no matching model cell are given fill_value=True, i.e. treated
# as "no model data" and masked out. An explicit boolean fill keeps the mask
# boolean; the default NaN fill would promote it to float/object and make the
# `~` inversion used when applying the mask fail.
# NOTE(review): use fill_value=False instead to keep basin IDs outside the model grid.
nan_mask_aligned = nan_mask_2d.reindex_like(
    basin_mask_variable, method='nearest', tolerance=1e-6, fill_value=True
)

# --- 4. Apply the NaN mask ---
# `where` keeps a basin ID only where the condition holds, i.e. where the model
# was NOT NaN; every cell flagged in nan_mask_aligned becomes NaN.
valid_model_cells = ~nan_mask_aligned
basin_mask_variable_masked = basin_mask_variable.where(valid_model_cells)

# --- 5. Update the original Dataset with the masked variable ---
# Replace the variable under its original name; the other variables of the
# basin-mask Dataset are left untouched.
masked_var_name = basin_mask_variable.name
basin_mask_data[masked_var_name] = basin_mask_variable_masked

# --- 6. Save the modified basin mask to a new NetCDF file ---
output_path = r'C:\Users\benra\output_masked_basin_mask.nc'
basin_mask_data.to_netcdf(output_path)
print(f"Masked basin mask saved to: {output_path}")