In [2]:
import xarray as xr
import numpy as np

# --- 1. Define File Paths ---
model_data_path = r"C:\Users\benra\lithk_exp05\lithk_GIS_AWI_ISSM1_exp05.nc"
basin_mask_path = r"C:\Users\benra\output_combined_masked_basins.nc"
output_path = r'C:\Users\benra\output_combined_masked_basins.nc'


# --- 2. Load the NetCDF files ---
try:
    model_data = xr.open_dataset(model_data_path)
    basin_mask_data = xr.open_dataset(basin_mask_path)
    print("NetCDF files loaded successfully.")
except FileNotFoundError as e:
    print(f"Error loading file: {e}. Please ensure the paths are correct.")


# --- 3. Create a 2D NaN Mask from Model Data ---
# Select the variable from the model data to use for masking
model_variable = model_data['lithk']

# Create a boolean mask where True indicates a NaN value
nan_mask_raw = np.isnan(model_variable)

# Reduce the mask to 2D by checking for NaNs across any non-spatial dimensions (like time)
spatial_dims_model = [dim for dim in nan_mask_raw.dims if dim in ['x', 'y']]
non_spatial_dims_model = [dim for dim in nan_mask_raw.dims if dim not in spatial_dims_model]

if non_spatial_dims_model:
    # Aggregate across the first non-spatial dimension found
    nan_mask_2d = nan_mask_raw.any(dim=non_spatial_dims_model[0])
    print(f"NaN mask reduced to 2D by aggregating over the '{non_spatial_dims_model[0]}' dimension.")
else:
    # If no extra dimensions, assume it's already 2D
    nan_mask_2d = nan_mask_raw
    print("NaN mask is already 2D.")


# --- 4. Align the NaN Mask with the Basin Mask ---
# Select the basin ID variable
basin_mask_variable = basin_mask_data['IDs']

# Ensure the dimension names of the NaN mask match the basin mask ('x', 'y')
# This is crucial for alignment.
try:
    dim_rename_map = {old_dim: new_dim for old_dim, new_dim in zip(nan_mask_2d.dims, basin_mask_variable.dims)}
    nan_mask_2d_renamed = nan_mask_2d.rename(dim_rename_map)
except ValueError as e:
    print(f"Error renaming dimensions: {e}. Check if spatial dimension names are consistent.")

# Reindex the NaN mask to match the exact coordinates of the basin mask
# 'nearest' method snaps the grid of the mask to the grid of the basin data.
nan_mask_aligned = nan_mask_2d_renamed.reindex_like(basin_mask_variable, method='nearest')

print("NaN mask aligned with basin mask.")


# --- 5. Apply the NaN Mask to the Basin IDs ---
# Where the aligned NaN mask is True, set the basin ID to NaN.
basin_mask_nan_removed = basin_mask_variable.where(~nan_mask_aligned)
print("NaN values from the model have been masked in the basin data.")


# --- 6. Re-number the Basins According to the New Scheme ---
# Create a copy of the masked data to work on
basin_mask_renumbered = basin_mask_nan_removed.copy(deep=True)

# Define the re-numbering scheme as a dictionary for clarity
# key = new basin ID, value = list of old basin IDs
renumbering_map = {
    1: [1, 2, 3],
    2: [4, 5],
    3: [6, 7, 8],
    4: [9, 10, 11, 12, 13],
    5: [14, 15, 16, 17],
    6: [18, 19, 20],
    7: [21, 22, 23, 24, 25] 
}

# Apply the re-numbering using xarray's `isin` and `where` for a clean, vectorized operation
for new_id, old_ids in renumbering_map.items():
    basin_mask_renumbered = xr.where(
        basin_mask_nan_removed.isin(old_ids), 
        new_id,                           
        basin_mask_renumbered              
    )

print("Basin IDs have been re-numbered.")
print("New Basin ID Distribution:")
for new_id, old_ids in renumbering_map.items():
    print(f"  - Basins {old_ids} are now Basin {new_id}")


# --- 7. Update the Dataset and Save to a New NetCDF File ---
# Create a new dataset to save, ensuring metadata (like coordinates) is preserved
output_ds = basin_mask_nan_removed.to_dataset(name='original_masked_IDs') # Keep the NaN-masked original for reference
output_ds['IDs'] = basin_mask_renumbered # Add the new re-numbered variable

# Update attributes for the new variable for clarity
output_ds['IDs'].attrs['long_name'] = 'Combined and Masked Basin IDs'
output_ds['IDs'].attrs['description'] = 'Basin IDs re-numbered and masked based on lithk NaNs.'
output_ds['IDs'].attrs['renumbering_scheme'] = str(renumbering_map)


# Save the final dataset to a new NetCDF file
output_ds.to_netcdf(output_path)

print(f"\n Success! Modified basin mask saved to: {output_path}")

NetCDF files loaded successfully.
NaN mask reduced to 2D by aggregating over the 'time' dimension.
NaN mask aligned with basin mask.
NaN values from the model have been masked in the basin data.
Basin IDs have been re-numbered.
New Basin ID Distribution:
  - Basins [1, 2, 3] are now Basin 1
  - Basins [4, 5] are now Basin 2
  - Basins [6, 7, 8] are now Basin 3
  - Basins [9, 10, 11, 12, 13] are now Basin 4
  - Basins [14, 15, 16, 17] are now Basin 5
  - Basins [18, 19, 20] are now Basin 6
  - Basins [21, 22, 23, 24, 25] are now Basin 7

 Success! Modified basin mask saved to: C:\Users\benra\output_combined_masked_basins.nc
