# Notebook that generates Fig. 1 of Albright et al. (in prep.)

In [1]:
import warnings
from pathlib import Path

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cmaps as gvcmaps
import geocat.viz.util as gvutil
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from cartopy.mpl.ticker import LatitudeFormatter, LongitudeFormatter
from cartopy.util import add_cyclic_point
from matplotlib.colors import BoundaryNorm, ListedColormap, TwoSlopeNorm

warnings.filterwarnings("ignore")

## Read in files:

In [2]:
path = "/glade/u/home/malbright/nam_manuscript_figures/climatology_files_remapped"
topo_path = "/glade/derecho/scratch/malbright/FROM_CHEYENNE/remap/"  # ends with "/"
################################################
### Preindustrial (PI) climatologies:
## JJAS precipitation, JJAS winds on pressure levels (the 1000 hPa level is
## selected when plotting), and model surface elevation.
## "ne30" files are the low-resolution run, "ne120" the high-resolution run.
################################################

#### low res (ne30)
u_interp = xr.open_dataset(
    f"{path}/U/b.e13.B1850C5CN.ne30_g16.pi.001.cam.h0.U.023906-028809.JJAS.interp.avg.remap.nc"
).U
v_interp = xr.open_dataset(
    f"{path}/V/b.e13.B1850C5CN.ne30_g16.pi.001.cam.h0.V.023906-028809.JJAS.interp.avg.remap.nc"
).V
prec = xr.open_dataset(
    f"{path}/PRECT/b.e13.B1850C5CN.ne30_g16.pi.001.cam.h0.PRECT.023906-028809.JJAS.avg.remap.nc"
).PRECT
topo = xr.open_dataset(
    f"{topo_path}b.e13.B1850C5CN.ne30_g16.pi.001.cam.h0.ELEVATION.nc"
).elevation

# Grid coordinates of the low-resolution wind fields (used for the quiver plot)
lat_uv = u_interp["lat"]
lon_uv = u_interp["lon"]

#### high res (ne120)
h_u_interp = xr.open_dataset(
    f"{path}/U/b.e13.B1850C5CN.ne120_g16.tuning.005.cam.h0.U.007006-009909.JJAS.interp.avg.remap.nc"
).U
h_v_interp = xr.open_dataset(
    f"{path}/V/b.e13.B1850C5CN.ne120_g16.tuning.005.cam.h0.V.007006-009909.JJAS.interp.avg.remap.nc"
).V
h_prec = xr.open_dataset(
    f"{path}/PRECT/b.e13.B1850C5CN.ne120_g16.tuning.005.cam.h0.PRECT.007006-009909.JJAS.avg.remap.nc"
).PRECT
h_topo = xr.open_dataset(
    f"{topo_path}b.e13.B1850C5CN.ne120_g16.tuning.005.cam.h0.ELEVATION.nc"
).elevation

# Grid coordinates of the high-resolution wind fields (used for the quiver plot)
h_lat_uv = h_u_interp["lat"]
h_lon_uv = h_u_interp["lon"]

################################################
### Plio - PI differences:
## JJ (presumably June-July) precipitation and 850 hPa wind differences.
## Note these files are "JJ" averages, not JJAS like the PI fields above.
################################################
diff_prec = xr.open_dataset(
    f"{path}/PRECT/ne30_PRECT.JJ.plio-minus-pi.avg.remap.nc"
).PRECT
h_diff_prec = xr.open_dataset(
    f"{path}/PRECT/ne120_PRECT.JJ.plio-minus-pi.avg.remap.nc"
).PRECT

diff_u = xr.open_dataset(f"{path}/U/ne30_U.JJ.interp.plio-minus-pi.avg.remap.nc").U
diff_v = xr.open_dataset(f"{path}/V/ne30_V.JJ.interp.plio-minus-pi.avg.remap.nc").V

h_diff_u = xr.open_dataset(f"{path}/U/ne120_U.JJ.interp.plio-minus-pi.avg.remap.nc").U
h_diff_v = xr.open_dataset(f"{path}/V/ne120_V.JJ.interp.plio-minus-pi.avg.remap.nc").V

################################################
### Monsoon extent (variable pct_JJAS) for the PI and Pliocene runs,
## low resolution (ne30) and high resolution (ne120).
################################################
monsoon = xr.open_dataset(
    f"{path}/PRECT/b.e13.B1850C5CN.ne30_g16.pi.001.cam.h0.PRECT.023906-028809.monsoon_extent_pct_JJAS.avg.remap.nc"
).pct_JJAS
monsoon_plio = xr.open_dataset(
    f"{path}/PRECT/b.e13.B1850C5CN.ne30_g16.plio.001.cam.h0.PRECT.037606-042509.monsoon_extent_pct_JJAS.avg.remap.nc"
).pct_JJAS

h_monsoon = xr.open_dataset(
    f"{path}/PRECT/b.e13.B1850C5CN.ne120_g16.tuning.005.cam.h0.PRECT.007006-009909.monsoon_extent_pct_JJAS.avg.remap.nc"
).pct_JJAS
h_monsoon_plio = xr.open_dataset(
    f"{path}/PRECT/b.e13.B1850C5CN.ne120_g16.pliohiRes.002.cam.h0.PRECT.003006-005909.monsoon_extent_pct_JJAS.avg.remap.nc"
).pct_JJAS

################################################
### Significance of the JJ Plio - PI precipitation difference
## (t-test masks: 1 = significant, 0 = not significant)
################################################
SIG_BASE = "/glade/work/malbright/final_nam_manuscript_files/significance/"  # ends with "/"

sig_lr_JJ = xr.open_dataset(f"{SIG_BASE}prect_plio_minus_pi_LR_ttest_JJ.nc")[
    "sig_mask_JJ"
]  # 1 = sig, 0 = not sig

sig_hr_JJ = xr.open_dataset(f"{SIG_BASE}prect_plio_minus_pi_HR_ttest_JJ.nc")[
    "sig_mask_JJ"
]

## River Channels:

In [3]:
# MOSART river-routing input file; only its river-width variable is needed.
mosart_file = "/glade/campaign/cesm/cesmdata/cseg/inputdata/rof/mosart/MOSART_Global_8th_20191007.nc"
rwid = xr.open_dataset(mosart_file)["rwid"]

In [4]:
# Turn river width into a binary channel mask (1 = channel, 0 = no channel).
# Default rule: a cell is a channel when its width is strictly greater than 40
# (a width of exactly 40, or NaN, gives 0).
# Inside the box lat < 20 and lon > -97 a stricter rule applies: the width
# must be at least 80.
lat_mask = rwid.lat < 20
lon_mask = rwid.lon > -97
special_region_mask = lat_mask & lon_mask

is_channel = xr.where(special_region_mask, rwid >= 80, rwid > 40)
rwid = xr.where(is_channel, 1, 0)

In [5]:
# Keep only channels strictly west of 100°W; every cell at or east of 100°W
# becomes NaN and is therefore not drawn. Because -97 > -100, this also drops
# the whole lat < 20 / lon > -97 box that gets a special threshold in the
# previous cell.
keep_west = rwid["lon"] < -100
rwid = rwid.where(keep_west)

## Helper Functions

In [None]:
def _monsoon_threshold(da, *, pct=60.0):
    """Return the ``pct``-percent threshold expressed in the units of ``da``.

    Monsoon-extent fields may be stored as percentages (0-100) or as
    fractions (0-1). If the largest non-NaN value exceeds 1.5, the field is
    treated as a percentage and ``pct`` is returned. Otherwise ``pct / 100``
    is returned.

    Parameters
    ----------
    da : xarray.DataArray or array-like
        Monsoon-extent field. NaNs are ignored when looking for the maximum,
        for both xarray and plain numpy inputs.
    pct : float, optional
        Threshold in percent (default 60).

    Returns
    -------
    float
        ``pct`` (e.g. 60.0) for percent data, ``pct / 100`` (e.g. 0.6) for
        fraction data. An empty or all-NaN field falls back to the fraction
        value.
    """
    values = np.asarray(da, dtype=float)
    valid = values[~np.isnan(values)]
    # No valid data gives no unit information; fall back to "fraction".
    mx = valid.max() if valid.size else 0.0
    return float(pct) if mx > 1.5 else pct / 100.0


def add_monsoon_contours(
    ax,
    da_plio,
    *,
    color="#D62728",  # crimson (Tableau red)
    lw=1.0,
    linestyles="solid",
    zorder=6,
):
    """
    Outline the 60% monsoon-extent contour of ONE field on ``ax``.

    Only a single contour line is drawn. The second parameter keeps its
    historical name ``da_plio`` for backward compatibility, but any
    monsoon-extent field can be passed. The contour level is 60 when the
    field looks like a percentage and 0.60 when it looks like a fraction
    (decided by ``_monsoon_threshold``).

    Parameters
    ----------
    ax : cartopy GeoAxes
        Target map axes.
    da_plio : xarray.DataArray
        Monsoon-extent field with ``lat``/``lon`` coordinates in degrees;
        plotted with a PlateCarree transform.
    color : str
        Contour line colour.
    lw : float
        Contour line width.
    linestyles : str
        Matplotlib line style for the contour (default ``"solid"``).
    zorder : float
        Drawing order of the contour (default 6).

    Returns
    -------
    The contour set returned by ``DataArray.plot.contour``.
    """
    thr = _monsoon_threshold(da_plio)

    return da_plio.plot.contour(
        ax=ax,
        levels=[thr],
        colors=[color],
        linewidths=lw,
        linestyles=linestyles,
        transform=ccrs.PlateCarree(),
        add_colorbar=False,
        zorder=zorder,
    )


def fix_lon(ds):
    """Return a copy of ``ds`` whose ``lon`` coordinate lies in [-180, 180).

    If no longitude exceeds 180, the data are assumed to already use the
    -180..180 convention and the copy is returned as is. Otherwise every
    longitude is wrapped into [-180, 180) and the result is re-sorted so that
    ``lon`` increases monotonically.
    """
    out = ds.copy()
    if not (out.lon.max() > 180):
        return out
    wrapped_lon = ((out.lon + 180) % 360) - 180
    return out.assign_coords(lon=wrapped_lon).sortby("lon")

def add_not_sig_shading(
    ax,
    sig_da,
    xlim=(-115, -90),
    ylim=(10, 35),
    shade_color="0.1",
    alpha=0.2,
    zorder=5,
    pad=2.0,
):
    """
    Shade NOT-significant (sig == 0) regions using contourf on a binary mask.

    The significance mask is converted to a binary field: 1 = not significant,
    and 0 = significant or missing (NaN). The "1" class is then filled with
    ``contourf``.

    Only the lon/lat window given by ``xlim`` x ``ylim`` is contoured. The
    window is in degrees, with longitudes in the -180..180 convention, and is
    widened by ``pad`` degrees on every side so the fill still reaches the map
    edges. Pass ``xlim=None`` and/or ``ylim=None`` to keep the full extent in
    that direction.

    Parameters
    ----------
    ax : cartopy GeoAxes
        Target map axes.
    sig_da : xarray.DataArray
        Significance mask (1 = significant, 0 = not) with ``lat``/``lon``
        coordinates; longitudes in 0..360 are converted by ``fix_lon``.
    xlim, ylim : (float, float) or None
        Longitude / latitude window to shade.
    shade_color : color
        Fill colour.
    alpha : float
        Fill transparency.
    zorder : float
        Drawing order of the fill.
    pad : float
        Extra degrees kept around the window (default 2).

    Returns
    -------
    The ContourSet, or None if no grid points fall inside the window.
    """
    sig_da = fix_lon(sig_da)

    # 1 = not significant, 0 = significant (NaN in the mask also maps to 0)
    notsig = xr.where(sig_da == 0, 1.0, 0.0)

    # Crop to the requested window. The window is a product of 1-D lon and lat
    # conditions, so the cropped block is rectangular and no extra NaNs appear.
    window = None
    if xlim is not None:
        x0, x1 = sorted(xlim)
        window = (notsig.lon >= x0 - pad) & (notsig.lon <= x1 + pad)
    if ylim is not None:
        y0, y1 = sorted(ylim)
        lat_ok = (notsig.lat >= y0 - pad) & (notsig.lat <= y1 + pad)
        window = lat_ok if window is None else window & lat_ok
    if window is not None:
        notsig = notsig.where(window, drop=True)
    if notsig.size == 0:
        return None

    return ax.contourf(
        notsig.lon.values,
        notsig.lat.values,
        notsig.values,
        levels=[0.5, 1.5],  # fill only the "1" class
        colors=[shade_color],
        alpha=alpha,
        transform=ccrs.PlateCarree(),
        zorder=zorder,
    )

## Plot:

In [None]:
# Specify projection for maps. Every panel is a lon/lat (PlateCarree) map and
# every field is on a regular lon/lat grid, so this one CRS serves both as the
# map projection and as the data transform.
proj = ccrs.PlateCarree()

# Map window shared by every panel (degrees east, degrees north).
map_xlim = (-115, -90)
map_ylim = (10, 35)

# Generate figure (set its size (width, height) in inches).
# NOTE(review): the right column is presumably wider so the right-hand maps
# stay about as large as the left-hand ones after their colorbars take space.
fig = plt.figure(figsize=(9, 12))
grid = fig.add_gridspec(
    ncols=2, nrows=3, width_ratios=[0.85, 1.061], wspace=0.15, hspace=0.05
)

# Left column: low resolution (ne30); right column: high resolution (ne120).
# Row 1: topography; row 2: PI precipitation + winds; row 3: Plio - PI diffs.
ax1 = fig.add_subplot(grid[0, 0], projection=proj)
ax2 = fig.add_subplot(grid[0, 1], projection=proj)
ax4 = fig.add_subplot(grid[1, 0], projection=proj)
ax5 = fig.add_subplot(grid[1, 1], projection=proj)
ax7 = fig.add_subplot(grid[2, 0], projection=proj)
ax8 = fig.add_subplot(grid[2, 1], projection=proj)

for ax in (ax1, ax2, ax4, ax5, ax7, ax8):
    ax.add_feature(cfeature.COASTLINE)
    # Country borders on every row except the topography row.
    if ax not in (ax1, ax2):
        ax.add_feature(cfeature.BORDERS)

    # Format ticks and ticklabels for the map axes
    gvutil.set_axes_limits_and_ticks(
        ax,
        xlim=map_xlim,
        ylim=map_ylim,
        xticks=np.linspace(-130, -90, 13),
        yticks=np.linspace(15, 50, 7),
    )

# Tick labels only on the outer edges of the panel grid:
# longitude labels on the bottom row, latitude labels on the left column.
for ax in (ax1, ax2, ax4, ax5):
    ax.xaxis.set_tick_params(labelbottom=False)
for ax in (ax7, ax8):
    ax.xaxis.set_major_formatter(LongitudeFormatter())
for ax in (ax1, ax4, ax7):
    ax.yaxis.set_major_formatter(LatitudeFormatter())
for ax in (ax2, ax5, ax8):
    ax.yaxis.set_tick_params(labelleft=False)

###################################################
# TOPO PLOTS (model surface elevation)
###################################################

topo_kwargs = dict(
    levels=np.arange(50, 2550, 250),  # 50, 300, ..., 2300 m
    cmap=gvcmaps.topo_15lev,
    # xarray sets these tick locations on the axes after plotting
    xticks=np.arange(-115, -90, 10),  # -115, -105, -95
    yticks=np.arange(10, 35, 10),  # 10, 20, 30
    add_colorbar=False,  # one shared colorbar is added below
    transform=proj,
)

# A 2-D DataArray.plot() draws a pcolormesh; `levels` makes it discrete.
fillplot = topo.plot(ax=ax1, **topo_kwargs)
gvutil.set_titles_and_labels(ax1, xlabel="", ylabel="")
ax1.set_title("Low Resolution", fontsize=12, loc="center")

fillplot_h = h_topo.plot(ax=ax2, **topo_kwargs)
gvutil.set_titles_and_labels(ax2, xlabel="", ylabel="")

# Both topography panels share levels and colormap; one colorbar beside ax2.
cbar_topo = fig.colorbar(fillplot_h, ax=ax2, shrink=0.9)
cbar_topo.ax.set_ylabel("m")

ax2.set_title("High Resolution", fontsize=12, loc="center")

###################################################
# PRECIP PLOTS (PI JJAS precipitation, 1000 hPa winds, monsoon extent)
###################################################

# Wind-vector styling shared by both precipitation panels.
quiver_kwargs = dict(
    transform=proj,
    color="orange",
    pivot="middle",
    width=0.004,
    scale=30,
    headwidth=5,
    zorder=2,
)

# Low resolution: 1000 hPa winds at every 2nd grid point.
Q = ax4.quiver(
    lon_uv[::2],
    lat_uv[::2],
    u_interp.sel(plev=1000.0)[::2, ::2],
    v_interp.sel(plev=1000.0)[::2, ::2],
    **quiver_kwargs,
)

levels_precip = np.arange(0, 14, 1)  # 0, 1, ..., 13 mm/day
precip_kwargs = dict(
    levels=levels_precip,
    cmap=gvcmaps.precip4_11lev,
    xticks=np.arange(-115, -90, 10),
    yticks=np.arange(10, 35, 10),
    add_colorbar=False,  # one shared colorbar is added below
    transform=proj,
)

cf = prec.plot.contourf(ax=ax4, **precip_kwargs)
add_monsoon_contours(ax4, monsoon, color="deeppink")
gvutil.set_titles_and_labels(ax4, xlabel="", ylabel="")

# High resolution: 1000 hPa winds at every 6th grid point.
Q_h = ax5.quiver(
    h_lon_uv[::6],
    h_lat_uv[::6],
    h_u_interp.sel(plev=1000.0)[::6, ::6],
    h_v_interp.sel(plev=1000.0)[::6, ::6],
    **quiver_kwargs,
)

cf_h = h_prec.plot.contourf(ax=ax5, **precip_kwargs)
add_monsoon_contours(ax5, h_monsoon, color="deeppink")

cbar_precip = fig.colorbar(cf_h, ax=ax5, shrink=0.9)
cbar_precip.ax.set_ylabel("mm day$^{-1}$")

# 3 m/s reference arrow for the high-resolution vectors, in figure coordinates.
ax5.quiverkey(
    Q_h,
    0.745,
    0.622,
    3,
    "$3$ m s$^{-1}$",
    labelpos="E",
    color="orange",
    coordinates="figure",
    fontproperties={"size": 10},
    labelsep=0.1,
)

gvutil.set_titles_and_labels(ax5, xlabel="", ylabel="")

###################################################
# JJ DIFF PLOTS (Plio - PI)
###################################################

hPa_level = 850.0

# Streamline styling shared by both difference panels.
stream_kwargs = dict(
    linewidth=0.8,
    arrowsize=0.8,
    density=1,
    color="blueviolet",
    transform=proj,
)

# -2.4 ... 2.4 mm/day in steps of 0.2 (25 levels), shared by BOTH panels.
# Integer steps divided by 5 give an exact 0 and symmetric end points,
# which a float-step np.arange does not guarantee.
levels_diff = np.arange(-12, 13) / 5
diff_kwargs = dict(
    levels=levels_diff,
    xticks=np.arange(-115, -90, 10),
    yticks=np.arange(10, 35, 10),
    cmap=gvcmaps.precip_diff_12lev,
    extend="both",
    norm=TwoSlopeNorm(vmin=-2.4, vcenter=0, vmax=2.4),
    add_colorbar=False,
    transform=proj,
)

# River channels: rwid is 1 on channels and 0/NaN elsewhere; turn the 0s
# into NaN so only channel cells are drawn.
river_mask = rwid.where(rwid > 0)
rwid_cmap = ListedColormap(["none", "blue"])
# Bin [0, 0.5) -> "none", bin [0.5, 1.5) -> "blue". The top edge sits above 1
# so channel cells fall inside the blue bin rather than on the norm's upper
# boundary, which BoundaryNorm treats as out of range.
norm = BoundaryNorm([0, 0.5, 1.5], rwid_cmap.N)
river_kwargs = dict(cmap=rwid_cmap, norm=norm, transform=proj, alpha=1.0)

# Red diamond drawn at 109.05°W, 23.05°N on both difference panels.
marker_kwargs = dict(
    marker="D",
    markerfacecolor="red",
    markeredgecolor="black",
    markersize=8,
    transform=proj,
)

# LOW RESOLUTION
uvel, lonu = add_cyclic_point(diff_u.sel(plev=hPa_level), coord=diff_u.lon)
vvel, lonv = add_cyclic_point(diff_v.sel(plev=hPa_level), coord=diff_v.lon)
# Wrap longitudes >= 180 to negative values.
# NOTE(review): this may leave lonu non-monotonic; presumably fine because
# cartopy's streamplot regrids the vectors onto a regular grid — confirm.
lonu = np.where(lonu >= 180.0, lonu - 360.0, lonu)

sp = ax7.streamplot(lonu, diff_u.lat, uvel, vvel, **stream_kwargs)

cf = diff_prec.plot.contourf(ax=ax7, **diff_kwargs)
add_not_sig_shading(
    ax7, sig_lr_JJ, xlim=map_xlim, ylim=map_ylim, alpha=0.3, zorder=20
)

river_plot_low = ax7.pcolormesh(rwid.lon, rwid.lat, river_mask, **river_kwargs)
ax7.plot(-109.05, 23.05, **marker_kwargs)

gvutil.set_titles_and_labels(ax7, xlabel="", ylabel="")

# HIGH RESOLUTION
h_uvel, h_lonu = add_cyclic_point(h_diff_u.sel(plev=hPa_level), coord=h_diff_u.lon)
h_vvel, h_lonv = add_cyclic_point(h_diff_v.sel(plev=hPa_level), coord=h_diff_v.lon)
h_lonu = np.where(h_lonu >= 180.0, h_lonu - 360.0, h_lonu)

sp_h = ax8.streamplot(h_lonu, h_diff_u.lat, h_uvel, h_vvel, **stream_kwargs)

cf_h = h_diff_prec.plot.contourf(ax=ax8, **diff_kwargs)
add_not_sig_shading(
    ax8, sig_hr_JJ, xlim=map_xlim, ylim=map_ylim, alpha=0.3, zorder=20
)

river_plot_high = ax8.pcolormesh(rwid.lon, rwid.lat, river_mask, **river_kwargs)
ax8.plot(-109.05, 23.05, **marker_kwargs)

# Both difference panels share levels/norm; one colorbar beside ax8.
cbar_precip_diff = fig.colorbar(cf_h, ax=ax8, shrink=0.9)
cbar_precip_diff.ax.set_ylabel("mm day$^{-1}$")

gvutil.set_titles_and_labels(ax8, xlabel="", ylabel="")

# Panel letters. Positions are in ax8's lon/lat data coordinates; values
# outside ax8's extent place the letters beside the other panels.
text_kwargs = dict(ha="center", va="center", fontsize=14, fontweight="bold")
panel_label_positions = {
    "a": (-147, 90),
    "b": (-116.5, 90),
    "c": (-147, 62.5),
    "d": (-116.5, 62.5),
    "e": (-147, 35),
    "f": (-116.5, 35),
}
for letter, (x_pos, y_pos) in panel_label_positions.items():
    ax8.text(x_pos, y_pos, letter, **text_kwargs)

# Styling for topography labels; currently not used in this cell.
topo_text_kwargs = dict(
    ha="center",
    va="center",
    fontsize=5.5,
    color="black",
    rotation=-48,
    fontweight="bold",
)

# Show the plot
# plt.show()
out_dir = Path("figures")
out_dir.mkdir(parents=True, exist_ok=True)  # savefig does not create directories
fig.savefig(
    out_dir / "final_new_Figure1_rivers_no_labels_monsoon_significance.pdf",
    dpi=300,
    bbox_inches="tight",
)
# plt.savefig("figures/final_new_Figure1_rivers_no_labels_monsoon_significance.png", dpi=150, bbox_inches="tight")