In [None]:
import xarray as xr
import geopandas as gpd
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from shapely.geometry import mapping
import numpy as np
import rasterio.features
from affine import Affine
import matplotlib.patheffects as PathEffects
from scipy.interpolate import griddata

# -------------------------------
# Load Data
# -------------------------------
# Select the 'slr_p50' variable for scenario 'ssp585' and the '2071-2100'
# year bin (presumably the 50th-percentile projection for that window --
# confirm against the dataset metadata).
ds = xr.open_dataset("slr_delta.nc")
slr = ds.sel(ssp='ssp585', year_bins='2071-2100')['slr_p50']

# Province outline and census-division boundaries, reprojected to EPSG:4326
# lon/lat so they can be rasterised against the SLR grid's lat/lon
# coordinates and drawn on the PlateCarree map.
# NOTE(review): assumes the SLR grid's lon/lat are plain degrees in the
# same convention (-180..180) -- confirm.
boundary_dir = "C:/Users/jclabarcena/Documents/EN2902/NovaScotia_province_boundary"
ns_land = gpd.read_file(f"{boundary_dir}/Nova_Scotia_NAD83.shp").to_crs("EPSG:4326")
divs = gpd.read_file(f"{boundary_dir}/cd_boundaries_2021.shp").to_crs("EPSG:4326")

# -------------------------------
# Interpolate to a Finer Grid
# -------------------------------
# Everything below indexes the field as a 2-D (lat, lon) array; make that
# ordering explicit so slr.shape, the land mask and the meshgrid agree.
slr = slr.transpose("lat", "lon")

lat_vals = slr['lat'].values
lon_vals = slr['lon'].values
lon_grid, lat_grid = np.meshgrid(lon_vals, lat_vals)

# Affine transform mapping coarse-grid (col, row) pixel indices to lon/lat.
# The half-cell offset puts each coordinate value at its pixel centre.
# Assumes the grid is regularly spaced.
nlat, nlon = slr.shape
res_lat = (lat_vals[-1] - lat_vals[0]) / (nlat - 1)
res_lon = (lon_vals[-1] - lon_vals[0]) / (nlon - 1)
transform = Affine.translation(lon_vals[0] - res_lon / 2, lat_vals[0] - res_lat / 2) * Affine.scale(res_lon, res_lat)

# Initial land mask (on coarse grid).
# With invert=False, geometry_mask is False for pixels inside the Nova Scotia
# polygons (land) and True everywhere else (ocean).
initial_mask = rasterio.features.geometry_mask(
    [mapping(geom) for geom in ns_land.geometry],
    out_shape=(nlat, nlon),
    transform=transform,
    invert=False
)
# Keep ocean pixels (mask True); land pixels (mask False) become NaN.
slr_masked = slr.where(xr.DataArray(initial_mask, dims=("lat", "lon")))

# Prepare finer grid: 1000 x 1000 points spanning the coarse grid's extent.
fine_lon = np.linspace(lon_vals.min(), lon_vals.max(), 1000)
fine_lat = np.linspace(lat_vals.min(), lat_vals.max(), 1000)
fine_lon2d, fine_lat2d = np.meshgrid(fine_lon, fine_lat)

# Use only valid (ocean, non-NaN) coarse cells as scattered samples.
valid = ~np.isnan(slr_masked.values)
points = np.column_stack((lon_grid[valid], lat_grid[valid]))
values = slr_masked.values[valid]

# Linear interpolation inside the convex hull of the samples. griddata
# returns NaN outside the hull, so fill those gaps from the nearest sample.
fine_interp = griddata(points, values, (fine_lon2d, fine_lat2d), method='linear')
fill_vals = griddata(points, values, (fine_lon2d, fine_lat2d), method='nearest')
gaps = np.isnan(fine_interp)
fine_interp[gaps] = fill_vals[gaps]

# -------------------------------
# Re-mask land on finer grid
# -------------------------------
# Affine transform for the fine grid. fine_lon/fine_lat are pixel-centre
# coordinates, so the origin is moved back by half a cell (the same
# half-cell convention used for the coarse-grid transform). Without the
# shift, the rasterised land mask is offset half a pixel from the data.
fine_dlon = fine_lon[1] - fine_lon[0]
fine_dlat = fine_lat[1] - fine_lat[0]
fine_transform = Affine.translation(
    fine_lon[0] - fine_dlon / 2, fine_lat[0] - fine_dlat / 2
) * Affine.scale(fine_dlon, fine_dlat)

# Land mask on the fine grid: with invert=True, pixels inside the (slightly
# buffered) Nova Scotia polygons are True. The buffer distance is in the
# layer's CRS units, i.e. degrees for EPSG:4326. geopandas may warn about
# buffering in a geographic CRS.
fine_mask = rasterio.features.geometry_mask(
    [mapping(geom) for geom in ns_land.buffer(0.0005).geometry],
    out_shape=(len(fine_lat), len(fine_lon)),
    transform=fine_transform,
    invert=True
)

# Apply final mask: set land pixels to NaN
final_vals = np.where(fine_mask, np.nan, fine_interp)

# Wrap the masked field as a (lat, lon) DataArray for plotting.
slr_interp = xr.DataArray(
    final_vals,
    coords={"lat": fine_lat, "lon": fine_lon},
    dims=["lat", "lon"]
)

# -------------------------------
# Plotting
# -------------------------------
fig, ax = plt.subplots(figsize=(12, 10), subplot_kw={'projection': ccrs.PlateCarree()})

# Global land background
ax.add_feature(cfeature.LAND, facecolor='lightgrey', zorder=0)

# Fill Nova Scotia land explicitly (in case of coastline mismatch)
ns_land.plot(ax=ax, facecolor='lightgrey', edgecolor='none', zorder=1)

# Plot SLR (ocean only, masked). The colour scale is fixed to 88-112.
slr_interp.plot(
    ax=ax,
    transform=ccrs.PlateCarree(),
    cmap='RdYlGn_r',
    vmin=88, vmax=112,
    cbar_kwargs={'label': 'Sea Level Rise (cm)', 'shrink': 0.5},
    zorder=2
)

# NS division boundaries
divs.boundary.plot(ax=ax, edgecolor='dimgray', linewidth=0.6, zorder=3, facecolor='none')

# Division labels with white halo
for _, row in divs.iterrows():
    geom = row.geometry
    # Skip missing or empty geometries; a None geometry has no .is_empty.
    if geom is None or geom.is_empty:
        continue
    # representative_point() is guaranteed to lie within the geometry,
    # whereas the centroid can fall outside concave or multi-part divisions.
    label_pt = geom.representative_point()
    ax.text(
        label_pt.x, label_pt.y, row["CDNAME"],
        fontsize=8, ha='center', va='center', color='black',
        transform=ccrs.PlateCarree(),
        # Drawn above the coastlines (zorder=5) so the halo keeps text legible.
        zorder=6,
        path_effects=[PathEffects.withStroke(linewidth=2.5, foreground="white")]
    )

# Coastlines
ax.coastlines(resolution='10m', color='black', linewidth=0.6, zorder=5)

# Final layout
ax.set_title("Projected Delta Sea Level Rise (cm) by 2100 - SSP585", fontsize=14)
ax.set_extent([-67.5, -58.5, 42.5, 47.8], crs=ccrs.PlateCarree())
ax.gridlines(draw_labels=True, linewidth=0.5, color='gray', alpha=0.4, linestyle='--')

plt.tight_layout()
plt.savefig("NS_SLR_SSP585_Final_Clean_Interpolated.png", dpi=300)
plt.show()
