In [43]:
import numpy as np
import plotly.graph_objects as go
from collections import defaultdict
import xarray as xr
import matplotlib.colors as mcolors
import pandas as pd
import plotly.io as pio

In [44]:
# Input file paths for every model. All four lists use the same model order:
# CESM2-WACCM, CNRM-ESM2-1, IPSL-CM6A-LR, MPI-ESM1-2-LR, MPI-ESM1-2-HR, UKESM1-0-LL.
_KTC_MODEL_ORDER = ("CESM2-WACCM", "CNRM-ESM2-1", "IPSL-CM6A-LR",
                    "MPI-ESM1-2-LR", "MPI-ESM1-2-HR", "UKESM1-0-LL")

# Future (SSP2-4.5, 2085-2099) KTC datasets share one file-name template.
future_paths = [
    r"D:\2085-99 KTC Data\{}_complete_KTC_SSP2-4.5_dataset_2085-2099.nc".format(name)
    for name in _KTC_MODEL_ORDER
]

# Historical KTC datasets (these files have no .nc extension).
historical_paths = [
    r"D:\Historical KTC Datasets\{}_KTC_Historical_dataset".format(name)
    for name in _KTC_MODEL_ORDER
]

# Fixed (fx) fields live in a per-model folder. The cell-area (areacella) and
# land-fraction (sftlf) files of one model share the same experiment/grid suffix.
_FX_SOURCES = (
    (r"D:\CESM2-WACCM Data", "CESM2-WACCM_G6sulfur_r1i1p1f2_gn.nc"),
    (r"D:\CNRM Data", "CNRM-ESM2-1_historical_r11i1p1f2_gr.nc"),
    (r"D:\IPSL Data", "IPSL-CM6A-LR_G6sulfur_r1i1p1f1_gr.nc"),
    (r"D:\MPI LR Data", "MPI-ESM1-2-LR_ssp245_r11i1p1f1_gn.nc"),
    (r"D:\MPI HR Data", "MPI-ESM1-2-HR_G6sulfur_r1i1p1f1_gn.nc"),
    (r"D:\UKESM Data", "UKESM1-0-LL_piControl_r1i1p1f2_gn.nc"),
)
area_paths = [folder + "\\areacella_fx_" + suffix for folder, suffix in _FX_SOURCES]
land_paths = [folder + "\\sftlf_fx_" + suffix for folder, suffix in _FX_SOURCES]

In [45]:
# Load every model's future, historical, cell-area and land-fraction datasets.
# The four path lists and `model_names` are parallel lists (same model order).
# Zip them instead of indexing with a hard-coded range(6), so adding or removing
# a model cannot silently drop one or raise an IndexError halfway through.

model_names = ["CESM2-WACCM", "CNRM-ESM2-1", "IPSL-CM6A-LR", "MPI-ESM1-2-LR", "MPI-ESM1-2-HR", "UKESM1-0-LL"]

_path_lists = (future_paths, historical_paths, area_paths, land_paths)
if any(len(paths) != len(model_names) for paths in _path_lists):
    raise ValueError(
        "path lists and model_names must have the same length: "
        f"{[len(paths) for paths in _path_lists]} vs {len(model_names)}"
    )

datasets = []
for model, future_path, historical_path, area_path, land_path in zip(model_names, *_path_lists):
    datasets.append({
        "model": model,
        "future": xr.open_dataset(future_path),
        # The historical files have no ".nc" extension, so force the netCDF4
        # engine instead of relying on xarray's extension-based guess.
        "historical": xr.open_dataset(historical_path, engine="netcdf4"),
        # xarray warns that areacella/sftlf declare two fill values (float32 and
        # float64 1e+20, i.e. the same number); both are decoded to NaN.
        "area": xr.open_dataset(area_path),
        "landfrac": xr.open_dataset(land_path),
    })


variable 'areacella' has multiple fill values {np.float32(1e+20), np.float64(1e+20)} defined, decoding all values to NaN.


variable 'sftlf' has multiple fill values {np.float32(1e+20), np.float64(1e+20)} defined, decoding all values to NaN.



In [46]:
# Put each model's historical/future KTC data on the grid of its land-only
# cell-area field, so later per-cell comparisons are aligned by coordinates.

# Per-model sftlf threshold; a cell counts as land when sftlf exceeds it.
# (Presumably percent, per the CMIP sftlf convention - confirm.)
land_thresh = {
    'CESM2-WACCM': 50,
    'CNRM-ESM2-1': 45,
    'IPSL-CM6A-LR': 30,
    'MPI-ESM1-2-LR': 35,
    'MPI-ESM1-2-HR': 50,
    'UKESM1-0-LL': 40
}

for data in datasets:
    model = data['model']
    area_grid = data['area']['areacella']
    land_grid = data['landfrac']['sftlf']

    # Index (not .get) so a model missing from land_thresh fails with a KeyError
    # naming it, instead of an obscure TypeError from `land_grid > None`.
    threshold = land_thresh[model]
    land_mask = land_grid > threshold
    # Non-land cells become NaN, so nansum-based totals ignore them.
    area_land = area_grid.where(land_mask)

    # Downstream cells threshold the zone fields with `> 0`, i.e. treat them as
    # categorical indicators. Linear interpolation would blend neighbouring zones
    # into fractional values that then bleed across zone boundaries after `> 0`.
    # Nearest-neighbour copies an existing source value into each target cell.
    historical = data['historical'].interp_like(area_land, method="nearest")
    future = data['future'].interp_like(area_land, method="nearest")

    # USE THESE VARIABLES FROM NOW ON
    data["hist_p"] = historical
    data["future_p"] = future
    data["area_land"] = area_land
    # Total land area per model, the denominator for percentages later.
    data["total_land_area"] = np.nansum(area_land)

In [47]:
# Collect the raw zone values of each model, taken from the grid-aligned data.

zones_f = ['xAr', 'xAw', 'A', 'B', 'C', 'D', 'E', 'F']
zones_h = ['A', 'B', 'C', 'D', 'E', 'F']

historical_zones = []
future_zones = []

for data in datasets:
    # Use the aligned copies from the alignment cell ("USE THESE VARIABLES FROM
    # NOW ON"), not the raw files, so values share area_land's grid.
    hist_ds = data['hist_p']
    fut_ds = data['future_p']

    historical_zones.append({z: hist_ds[z].values for z in zones_h})

    fut_zones = {z: fut_ds[z].values for z in zones_f}
    # Combined novel-zone field xA = xAr + xAw, kept per model. It previously
    # went into a loop variable that each iteration overwrote.
    fut_zones['xA'] = fut_zones['xAr'] + fut_zones['xAw']
    future_zones.append(fut_zones)

In [48]:
# Build the historical -> future transition matrix.
# Each land cell gets at most one historical zone and one future zone. For each
# model, the land area of cells moving from zone X to zone Y is appended to
# transition_matrix[(X, Y)] as a percentage of that model's total land area.
# Result: one entry per model for every (X, Y) pair.
transition_matrix = defaultdict(list)

# Destination zones: A-F plus the merged novel zone xA.
zones_to = zones_h + ["xA"]

for data in datasets:
    # Use the grid-aligned data (hist_p / future_p). The raw 'historical' /
    # 'future' datasets are not guaranteed to share area_land's grid, and
    # indexing the area array with their masks would pair cells by position
    # rather than by coordinates.
    ds_h = data['hist_p']
    ds_f = data['future_p']
    area = data['area_land'].values
    model_total_land = data["total_land_area"]

    # Historical zone masks.
    hist_z = {z: (ds_h[z] > 0).values for z in zones_h}

    # Future novel mask xA = xAr ∪ xAw. It overrides A-F in the future map.
    mask_xA = ((ds_f["xAr"] > 0) | (ds_f["xAw"] > 0)).values
    fut_z = {z: (ds_f[z] > 0).values & ~mask_xA for z in zones_h}
    fut_z["xA"] = mask_xA

    # Label each cell with its zone ('' = unclassified). If a cell matches
    # several masks, the last zone in loop order wins.
    hist_map = np.full(area.shape, '', dtype='<U3')
    fut_map = np.full(area.shape, '', dtype='<U3')
    for z in zones_h:
        hist_map[hist_z[z]] = z
    for z in zones_to:
        fut_map[fut_z[z]] = z

    # Only A-F can appear in hist_map (xAr/xAw are future-only), so the source
    # side iterates zones_h. Iterating zones_f only produced all-zero rows.
    for from_zone in zones_h:
        from_mask = hist_map == from_zone
        for to_zone in zones_to:
            joint_mask = from_mask & (fut_map == to_zone)
            # area is NaN outside land, so nansum counts land area only.
            land_area = np.nansum(area[joint_mask])
            percent = 100 * (land_area / model_total_land)
            transition_matrix[(from_zone, to_zone)].append(percent)

In [49]:
# Average every transition across models. A model with no area in a transition
# contributed an explicit 0, so each mean is taken over all models.
avg_transitions = {key: np.mean(per_model) for key, per_model in transition_matrix.items()}

zones_h = ['A', 'B', 'C', 'D', 'E', 'F']
zones_f_raw = ['A', 'B', 'C', 'D', 'E', 'F', 'xAr', 'xAw']

# Future node order: drop the novel sub-zones xAr/xAw. If any were dropped,
# put a single merged 'xA' first.
zones_f_cleaned = [z for z in zones_f_raw if z not in ('xAr', 'xAw')]
if len(zones_f_cleaned) < len(zones_f_raw):
    zones_f_cleaned.insert(0, 'xA')
zones_f_cleaned

['xA', 'A', 'B', 'C', 'D', 'E', 'F']

In [50]:
# Fold the novel sub-zones xAr/xAw into one 'xA' key on either side of each
# transition. Averaged percentages that land on the same pair are summed.
merged_avg_trans = defaultdict(float)
for (from_zone, to_zone), value in avg_transitions.items():
    from_zone = 'xA' if from_zone in ('xAr', 'xAw') else from_zone
    to_zone = 'xA' if to_zone in ('xAr', 'xAw') else to_zone
    merged_avg_trans[from_zone, to_zone] += value

# Future zones that can receive transitions: the cleaned future order, with
# 'xA' appended at the end only if it is missing.
future_zones_trans = list(filter(lambda z: z not in ('xAr', 'xAw'), zones_f_cleaned))
if 'xA' not in future_zones_trans:
    future_zones_trans += ['xA']

In [51]:
# Model-average land share (%) of each zone in each period, derived from the
# merged transitions:
#   - a future zone's share is the sum of all flows into it;
#   - a historical zone's share is the sum of all flows out of it.
future_pct = defaultdict(float)
for to_zone in future_zones_trans:
    future_pct[to_zone] = sum(merged_avg_trans.get((src, to_zone), 0) for src in zones_h)

historical_pct = defaultdict(float)
for from_zone in zones_h:
    historical_pct[from_zone] = sum(merged_avg_trans.get((from_zone, dst), 0) for dst in future_zones_trans)

# Net change in share per zone (future minus historical). xA has no historical share.
net_change = {
    zone: future_pct.get(zone, 0) - historical_pct.get(zone, 0)
    for zone in zones_h + ['xA']
}
net_change

{'A': np.float32(2.2170315),
 'B': np.float32(1.3773594),
 'C': np.float32(-1.0166273),
 'D': np.float32(3.9199028),
 'E': np.float32(-4.1130056),
 'F': np.float32(-2.5635386),
 'xA': np.float32(0.17887549)}

In [52]:
# Sanity check: expected to be ~100 when every land cell has both a historical
# and a future zone.
total_sum = sum(merged_avg_trans.values())
print("Total sum of merged_avg_trans values: {:.4f}".format(total_sum))

Total sum of merged_avg_trans values: 100.0000


In [53]:
# Sanity check: every transition counts once as an outflow (historical share)
# and once as an inflow (future share), so net changes should sum to ~0.
total_change = sum(net_change.values())
print("Total net_change sum: {:.4f}".format(total_change))

Total net_change sum: -0.0000


In [54]:
# Sankey node layout (Plotly node coordinates in 0-1).

# Historical (left) column: zones A-F evenly spaced.
x_left = [0.3 for _ in zones_h]
y_left = np.linspace(0.10, 0.90, len(zones_h))

# Future (right) column: hand-tuned vertical position for each zone.
custom_y_right = dict(xA=0.08, A=0.15, B=0.27, C=0.43, D=0.58, E=0.75, F=0.90)
x_right = [0.6 for _ in zones_f_cleaned]
y_right = list(map(custom_y_right.__getitem__, zones_f_cleaned))

# Final node positions: historical nodes first, then future nodes.
x_nodes = x_left + x_right
y_nodes = [*y_left, *y_right]

# Node labels in the same order, and a label -> node index lookup.
labels = [f"Historical {z}" for z in zones_h] + [f"Future {z}" for z in zones_f_cleaned]
zone_to_index = dict(zip(labels, range(len(labels))))
zone_to_index

{'Historical A': 0,
 'Historical B': 1,
 'Historical C': 2,
 'Historical D': 3,
 'Historical E': 4,
 'Historical F': 5,
 'Future xA': 6,
 'Future A': 7,
 'Future B': 8,
 'Future C': 9,
 'Future D': 10,
 'Future E': 11,
 'Future F': 12}

In [55]:
# Zone names beside the diagram nodes. The future column adds the novel zone
# xA in front of A-F.
zone_names_left = list(zones_h)
zone_names_right = ['xA', *zones_h]

# Explicit zone order of each column.
left_zones = 'A B C D E F'.split()
right_zones = ['xA'] + left_zones

In [56]:
#build sankey lists
sources = []
targets = []
values = []

for (from_zone, to_zone), val in merged_avg_trans.items():
    if from_zone in ['xAr', 'xAw']:
        from_zone = 'xA'
    if to_zone in ['xAr', 'xAw']:
        to_zone = 'xA'

    from_label = f"Historical {from_zone}"
    to_label   = f"Future {to_zone}"

    if from_label in zone_to_index and to_label in zone_to_index:
        sources.append(zone_to_index[from_label])
        targets.append(zone_to_index[to_label])
        values.append(val)

In [57]:
# Diagram colors. Helper: turn a matplotlib color spec into the 'rgba(...)'
# string Plotly accepts, so links can be drawn semi-transparent.
def to_rgba_str(hex_color: str, alpha: float = 0.3) -> str:
    """Return ``hex_color`` as a Plotly ``'rgba(r, g, b, a)'`` string.

    ``hex_color`` is anything ``matplotlib.colors.to_rgba`` accepts (e.g.
    ``"#fc8d59"``). Its own alpha channel is discarded and replaced by
    ``alpha``. The RGB channels are scaled to 0-255 and truncated to int.
    """
    red, green, blue = (int(255 * channel) for channel in mcolors.to_rgba(hex_color)[:3])
    return "rgba({}, {}, {}, {})".format(red, green, blue, alpha)

# Node labels, in node-index order: historical column, then future column.
labels = [f"Historical {z}" for z in zones_h] + [f"Future {z}" for z in zones_f_cleaned]

# One base color per zone.
base_colors = {
    "xA": "#191970",  # Midnight Blue
    "A":  "#0000ff",  # Blue
    "B":  "#ff0000",  # Red
    "C":  "#ffff00",  # Yellow
    "D":  "#20b2aa",  # Light Sea Green
    "E":  "#ff00ff",  # Magenta
    "F":  "#4b0082",  # Indigo
}

# Node colors come from the same zone sequence that built `labels`, so node i
# is always colored by the zone of labels[i]. A separately hard-coded zone
# order could drift out of sync and silently miscolor nodes.
node_zones = zones_h + zones_f_cleaned
node_colors = [base_colors[z] for z in node_zones]

# Each link takes its source node's zone color, made translucent. Look the zone
# up by node index rather than re-parsing it out of the label string.
link_colors = [
    to_rgba_str(base_colors[node_zones[src]], alpha=0.3)
    for src in sources
]

In [58]:
# Assemble the Sankey trace. arrangement="fixed" keeps nodes at the given x/y.
sankey_nodes = dict(x=x_nodes, y=y_nodes, color=node_colors, pad=45, thickness=30)
sankey_links = dict(source=sources, target=targets, value=values, color=link_colors)
fig = go.Figure(go.Sankey(arrangement="fixed", node=sankey_nodes, link=sankey_links))

In [62]:
# Annotate the Sankey figure:
#   - net-change percentages and zone names beside each node;
#   - the title;
#   - the period labels under each column.
# Positions are paper coordinates (0-1), hand-tuned to sit next to the nodes.


def _add_paper_text(fig, x, y, text, *, xanchor, yanchor='middle', size=14, color=None):
    """Add a plain text annotation (no arrow) at paper coordinates (x, y).

    The font family is always Times New Roman. `color` is set only when given;
    otherwise Plotly's default annotation color applies.
    """
    font = dict(family='Times New Roman', size=size)
    if color is not None:
        font['color'] = color
    fig.add_annotation(
        x=x, y=y, xref='paper', yref='paper', text=text, showarrow=False,
        xanchor=xanchor, yanchor=yanchor, font=font,
    )


# Net percent change of each zone, right of the future-zone names.
pct_pos = {
    'xA': (0.65, 0.93),
    'A' : (0.65, 0.85),
    'B' : (0.65, 0.73),
    'C' : (0.65, 0.57),
    'D' : (0.65, 0.42),
    'E' : (0.65, 0.25),
    'F' : (0.65, 0.10),
}
for zone, (xpos, ypos) in pct_pos.items():
    _add_paper_text(fig, xpos, ypos, f"{net_change.get(zone, 0.0):+.2f}%", xanchor='left')

# Bold zone names to the left of the historical nodes. Every zone is bolded,
# so no conditional is needed. The text is not bound to a global named
# `display`, which would shadow IPython's display() builtin.
name_pos = {
    "A": (0.28, 0.90),
    "B": (0.28, 0.74),
    "C": (0.28, 0.58),
    "D": (0.28, 0.42),
    "E": (0.28, 0.26),
    "F": (0.28, 0.10),
}
for zone, (xpos, ypos) in name_pos.items():
    _add_paper_text(fig, xpos, ypos, f"<b>{zone}</b>", xanchor='right', color='black')

# Bold zone names to the right of the future nodes.
name_pos_right = {
    'xA': (0.62, 0.93),
    'A' : (0.62, 0.85),
    'B' : (0.62, 0.73),
    'C' : (0.62, 0.57),
    'D' : (0.62, 0.42),
    'E' : (0.62, 0.25),
    'F' : (0.62, 0.10),
}
for zone, (xpos, ypos) in name_pos_right.items():
    _add_paper_text(fig, xpos, ypos, f"<b>{zone}</b>", xanchor='left', color='black')

fig.update_layout(
    title={
        "text":    "Model Average Climate Zone Transitions: SSP2-4.5",
        "x":       0.47,      # 0 = left, 0.5 = center, 1 = right
        "y":       0.92,      # 0 = bottom, 1 = top
        "xanchor": "center",  # align the title's x according to x
        "yanchor": "top",     # align the title's y according to y
        "font":    {"family": "Times New Roman", "size": 20, "color": "black"},
    },
    margin={"t": 60},
)

# Period labels at the bottom, roughly under the historical (x=0.3) and
# future (x=0.6) node columns.
_add_paper_text(fig, 0.30, 0.001, '1985-2014', xanchor='center', yanchor='bottom', size=20, color='black')
_add_paper_text(fig, 0.60, 0.001, '2085-2099', xanchor='center', yanchor='bottom', size=20, color='black')

fig.update_layout(width=1200, height=600)

pio.renderers.default = 'notebook_connected'
fig.show()

In [63]:
#fig.write_html('/Users/jaybr/OneDrive/Desktop/RESEARCH/Model_Avg_sankey_SSP2-4.5.html')

In [61]:
# Total land area of each model (the percentage denominator), in m².
for entry in datasets:
    print("{}: {:.2e} m²".format(entry['model'], entry['total_land_area']))


CESM2-WACCM: 1.49e+14 m²
CNRM-ESM2-1: 1.49e+14 m²
IPSL-CM6A-LR: 1.54e+14 m²
MPI-ESM1-2-LR: 1.47e+14 m²
MPI-ESM1-2-HR: 1.47e+14 m²
UKESM1-0-LL: 1.50e+14 m²
