In [860]:
import xarray as xr
import numpy as np
import plotly.graph_objects as go
import matplotlib.colors as mcolors
import pandas as pd
import plotly.io as pio

In [861]:
ds20_path = r"D:\Historical KTC Datasets\IPSL-CM6A-LR_KTC_Historical_dataset"
G6_20 = xr.open_dataset(ds20_path)
ds90_path = r"D:\2085-99 KTC Data\IPSL-CM6A-LR_complete_KTC_SSP5-8.5_dataset_2085-2099.nc"
G6_90 = xr.open_dataset(ds90_path)
# replace this with however you have your area grid loaded:
area_path = r"D:\IPSL Data\areacella_fx_IPSL-CM6A-LR_G6sulfur_r1i1p1f1_gr.nc"
area_ds = xr.open_dataset(area_path)
path_land = r"D:\IPSL Data\sftlf_fx_IPSL-CM6A-LR_G6sulfur_r1i1p1f1_gr.nc"
land_frac = xr.open_dataset(path_land)
area = area_ds['areacella']


# If land_frac has a 'bnds' dimension of any length
if 'bnds' in land_frac.dims:
    land_frac = land_frac.isel(bnds=0)  #set bnds dim to 0

**don't forget to change the land fraction value!!

In [862]:
area_dataset = area.where(land_frac['sftlf'] > 30)

In [863]:
A20 = G6_20['A'].values
B20 = G6_20['B'].values
C20 = G6_20['C'].values
D20 = G6_20['D'].values
E20 = G6_20['E'].values
F20 = G6_20['F'].values

A90 = G6_90['A'].values
B90 = G6_90['B'].values
C90 = G6_90['C'].values
D90 = G6_90['D'].values
E90 = G6_90['E'].values
F90 = G6_90['F'].values
xA = G6_90['xAr'].values + G6_90['xAw'].values

In [864]:
B_20 = (B20*area_dataset).sum()/area_dataset.sum()*100
A_20 = (A20*area_dataset).sum()/area_dataset.sum()*100
C_20 = (C20*area_dataset).sum()/area_dataset.sum()*100
D_20 = (D20*area_dataset).sum()/area_dataset.sum()*100
E_20 = (E20*area_dataset).sum()/area_dataset.sum()*100
F_20 = (F20*area_dataset).sum()/area_dataset.sum()*100

B_90 = (B90*area_dataset).sum()/area_dataset.sum()*100
A_90 = (A90*area_dataset).sum()/area_dataset.sum()*100
C_90 = (C90*area_dataset).sum()/area_dataset.sum()*100
D_90 = (D90*area_dataset).sum()/area_dataset.sum()*100
E_90 = (E90*area_dataset).sum()/area_dataset.sum()*100
F_90 = (F90*area_dataset).sum()/area_dataset.sum()*100
xA = (xA*area_dataset).sum()/area_dataset.sum()*100
xAr = (G6_90['xAr']*area_dataset).sum()/area_dataset.sum()*100
xAw = (G6_90['xAw']*area_dataset).sum()/area_dataset.sum()*100

In [865]:
A = A_90 - A_20
B = B_90 - B_20
C = C_90 - C_20
D = D_90 - D_20
E = E_90 - E_20
F = F_90 - F_20

In [866]:
ds20_1 = xr.open_dataset(ds20_path)
ds90_1 = xr.open_dataset(ds90_path)
area_ds  = xr.open_dataset(area_path)
land_ds  = xr.open_dataset(path_land)

area_grid = area_ds["areacella"]         # (lat,lon) in m²
land_mask = (land_ds["sftlf"] > 30)       # percent_land > 50%
area_land = area_grid.where(land_mask)    # mask ocean

ds20 = ds20_1.interp_like(area_land)
ds90 = ds90_1.interp_like(area_land) #added because of MPI german models 

In [867]:
#-----------------------------------------------------------------

# 1) List your core Koppen classes
core_zones = ["A","B","C","D","E","F"]

# 2) Build 2020s masks from the one-hot fields (catch any >0)
zone_masks20 = {
    z: (ds20[z] > 0)
    for z in core_zones
}

# 3) Build the two novel 2090s masks first (catch any >0)
mask_xA = (ds90["xAr"] > 0) | (ds90["xAw"] > 0)

# 4) Now build core-class 2090s masks, *excluding* any novel cells
zone_masks90 = {}
for z in core_zones:
    zone_masks90[z] = (ds90[z] > 0) & (~mask_xA)

# 5) Finally assign the grouped novel masks
zone_masks90["xA"] = mask_xA

In [868]:
# 1) define the core zones
core_zones   = ["A","B","C","D","E","F"]

# 2) define your left-side (2020s) names
zone_names20 = core_zones.copy()
#   → ["A","B","C","D","E","F"]

# 3) define your right-side (2090s) names, novel ones first
zone_names90 = ["xA"] + core_zones
#   → ["xA","xB","A","B","C","D","E","F"]

In [869]:
# ─── initialize your transition matrix ─────────────────────────────────────────
# rows: true 2020s zones A–F
# cols: grouped 2090s zones xA, xB, A–F
transition_matrix = pd.DataFrame(
    0.0,
    index=zone_names20,      # ["A","B","C","D","E","F"]
    columns=zone_names90     # ["xA","A","B","C","D","E","F"]
)

#beginner matrix, all zeros to be filled later

In [870]:
#******************************
labels = [f"Historical {z}" for z in zone_names20] + \
         [f"Future {z}" for z in zone_names90]

sources, targets, values = [], [], []
for i, z0 in enumerate(zone_names20):
    for j, z9 in enumerate(zone_names90):
        # mask where 2020s==z0 AND 2090s==z9
        m = zone_masks20[z0] & zone_masks90[z9]
        area_m = area_land.where(m).sum(dim=["lat","lon"]).item()
                # — write the raw‐m² area into the matrix —
        transition_matrix.loc[z0, z9] = area_m
        if area_m > 0:
            sources.append(i)
            targets.append(len(zone_names20) + j)
            values.append(area_m / 1e12)   # convert to Mkm²

# 1) Grab the m² that stayed in A
v_AA_m2 = transition_matrix.loc['A','A']

# 2) Only proceed if it’s nonzero
if v_AA_m2 > 0:
    idx_A_hist = labels.index('Historical A')
    idx_A_future = labels.index('Future A')
    sources.append(idx_A_hist)
    targets.append(idx_A_future)
    values.append(v_AA_m2 / 1e12)


In [871]:
transition_matrix

Unnamed: 0,xA,A,B,C,D,E,F
A,0.0,31343340000000.0,1021632000000.0,0.0,0.0,0.0,0.0
B,0.0,1334998000000.0,29280950000000.0,126819100000.0,93125230000.0,0.0,0.0
C,0.0,6630737000000.0,3618198000000.0,7428912000000.0,0.0,0.0,0.0
D,0.0,0.0,2077644000000.0,9709759000000.0,18466830000000.0,0.0,0.0
E,0.0,0.0,155432800000.0,0.0,19862120000000.0,62356990000.0,0.0
F,0.0,0.0,70097910000.0,226014600000.0,1990565000000.0,2272342000000.0,17921340000000.0


In [872]:
import matplotlib.colors as mcolors

def to_rgba_str(hex_color: str, alpha: float = 0.3) -> str:
    """
    Convert a 6-digit hex color (e.g. "#fc8d59") into an 'rgba(r,g,b,a)' string
    with the given alpha for Plotly.
    """
    r, g, b, _ = mcolors.to_rgba(hex_color)  
    return f"rgba({int(r*255)}, {int(g*255)}, {int(b*255)}, {alpha})"


In [873]:
# 4) BUILD labels and node‐positions (6 left, 8 right)
nL, nR = len(zone_names20), len(zone_names90)
x_nodes = [0.0]*nL + [1.0]*nR
y_nodes = list(np.linspace(0,1,nL,endpoint=False)) + list(np.linspace(0,1,nR,endpoint=False))

# 5) OPTIONAL: colors by zone
base_colors = {
    "xA": "#191970",  # Midnight Blue
    "A":  "#0000ff",  # Blue
    "B":  "#ff0000",  # Red
    "C":  "#ffff00",  # Yellow
    "D":  "#20b2aa",  # Light Sea Green
    "E":  "#ff00ff",  # Magenta
    "F":  "#4b0082",  # Indigo
}
node_colors = [base_colors[z] for z in zone_names20 + zone_names90]
# suppose `base_colors` is your dict mapping zone → 6-digit hex
link_colors = [
    to_rgba_str(base_colors[ labels[src].split()[1] ], alpha=0.3)
    for src in sources
]

In [874]:
# ─── CUSTOM NODE POSITIONING ───────────────────────────────────────────────────
# Define your 6 originals and 8 grouped zones (in the exact order you want them)
orig_zones    = ['A','B','C','D','E','F']
grouped_zones = ['xA'] + orig_zones

# Explicit x‐positions: 0 for left, 1 for right
x_nodes = [0.3]*len(orig_zones) + [0.6]*len(grouped_zones)

left_zones  = ['A','B','C','D','E','F']
right_zones = ['xA', 'A','B','C','D','E','F']

# Generate evenly spaced y‐values from 0.1 to 0.9
y_left  = np.linspace(0.1, 0.9, len(left_zones))
#y_right = np.linspace(0.1, 0.9, len(right_zones)) #use for data w/ xA
y_right = [  0.1, #A --> use for data w/o xA 
    0.26,  #B
    0.42,  #C
    0.59,  #D
    0.74,  #E
    0.9,   #F
    0.0   #placeholder (don't remove)
]

# Build the mappings
y_map_left  = dict(zip(left_zones,  y_left))
#y_map_right = dict(zip(right_zones, y_right)) #use for data w/ xA
y_map_right = dict(zip(right_zones, y_right))

y_nodes = [y_map_left[z] for z in orig_zones] + [y_map_right[g] for g in grouped_zones]

In [875]:
# ─── YOUR EXISTING SANKEY CREATION ─────────────────────────────────────────────
fig = go.Figure(go.Sankey(
    arrangement="fixed",
    node = dict(
        x         = x_nodes,
        y         = y_nodes,
        color     = node_colors,
        pad       = 45,
        thickness = 30,
    ),
    link = dict(
        source = sources,
        target = targets,
        value  = values,
        color  = link_colors,
    )
))

In [876]:
# area_land = your land‐only m² grid from before

area90_raw = {
    g: area_land.where(zone_masks90[g]).sum(dim=["lat","lon"]).item()
    for g in ['xA','A','B','C','D','E','F']
}

area20_raw = {
    g: area_land.where(zone_masks20[g]).sum(dim=["lat","lon"]).item()
    for g in ['A','B','C','D','E','F']
}

total_land = sum(area20_raw.values())

# confirm sums match
sum20 = sum(area20_raw.values())
sum90 = sum(area90_raw.values())
print(f"sum20 == total_land? {sum20==total_land}")
print(f"sum90 == total_land? {sum90==total_land}")

sum20 == total_land? True
sum90 == total_land? False


In [877]:

grouped_zones = ['xA', 'A','B','C','D','E','F']

# 2) compute the raw absolute %‐of‐global‐land‐change
pct_change = {
    g: (area90_raw.get(g, 0.0) - area20_raw.get(g, 0.0))
       / total_land * 100
    for g in grouped_zones
}

# 3) compute the tiny floating‐point residual
residual = sum(pct_change.values())   # e.g. 4.4e-16

# 4) subtract it from one zone (here 'F') so the sum is exactly 0.0
pct_change['F'] -= residual
raw_sum = sum(pct_change.values())
print("Raw sum of pct_change:", raw_sum)

Raw sum of pct_change: 0.0


In [878]:
# 1) zone lists
zone_names20 = ["A","B","C","D","E","F"]
zone_names90 = ["xA"] + zone_names20
zones        = zone_names20 + zone_names90

# 2) recompute node positions, but from top→bottom
n_left, n_right = len(zone_names20), len(zone_names90)
x_left, x_right = 0.3, 0.6

# flip the linspace endpoints so index 0 → top (0.85), index last → bottom (0.15)
y_left  = np.linspace(0.85, 0.15, n_left)
y_right = np.linspace(0.85, 0.15, n_right)

node_x = [x_left]*n_left  + [x_right]*n_right
node_y = list(y_left)     + list(y_right)

# 3) build name→position mapping
offset = 0.03
default_positions = [
    (node_x[i] - offset, node_y[i]) if i < n_left
    else (node_x[i] + offset, node_y[i])
    for i in range(len(zones))
]
custom_pos = { zones[i]: default_positions[i] for i in range(len(zones)) }

In [879]:
# 1) Define which raw 90s fields belong to xA and xB
novel_subs = {
    'xA': ['xAr','xAw']
}

flows = {}
# 2) Loop over the two origins A and B
for orig in ['A','B']:
    # build the 2020s mask for that origin
    mask20 = (ds20[orig] == 1)                      # or (G6_20[orig] > 0)
    area20 = area_land.where(mask20).sum(...).item()
    
    for grp, subs in novel_subs.items():
        # 3) Build a mask = (orig in 20s AND sub in 90s)
        mask = False
        for sub in subs:
            mask |= (mask20 & (ds90[sub] == 1))     # or (G6_90[sub] > 0)
        
        # 4) Sum that to get the raw flow in m²
        m2 = area_land.where(mask).sum(dim=['lat','lon']).item()
        flows[f"{orig}_to_{grp}"] = m2

# 5) Now compute percentages
total_land = area_land.sum(dim=['lat','lon']).item()

pct_of_origin = {}
pct_of_global = {}
for k,m2 in flows.items():
    orig = k.split('_')[0]        # 'A' or 'B'
    pct_of_origin[k] = m2 / area20_raw[orig] * 100
    pct_of_global[k] = m2 / total_land   * 100

# 6) Print them out
print("Flows (in m²):", flows)
print("\n% of origin (A or B):")
for k,v in pct_of_origin.items():
    print(f"  {k}: {v:+.2f}%")
print("\n% of global land:")
for k,v in pct_of_global.items():
    print(f"  {k}: {v:+.2f}%")


Flows (in m²): {'A_to_xA': 0.0, 'B_to_xA': 0.0}

% of origin (A or B):
  A_to_xA: +0.00%
  B_to_xA: +0.00%

% of global land:
  A_to_xA: +0.00%
  B_to_xA: +0.00%


In [880]:
# ─── 1) Compute your “true” %‐of‐global‐land change for C–F ──────────────
#    (i.e. exactly the C_90–C_20 snippet you already ran)
C_pct = C_90 - C_20   # e.g. ≃ 0.12
D_pct = D_90 - D_20   # whatever you got
E_pct = E_90 - E_20
F_pct = F_90 - F_20

# ─── 2) Build your uniform percent change for all eight ──────────────────
pct_uniform = {
    g: (area90_raw[g] - area20_raw.get(g,0.0)) / total_land * 100
    for g in ['xA', 'A','B','C','D','E','F']
}

# ─── 3) Override C–F with the “true” values ───────────────────────────────
pct_uniform.update({
    'C': C_pct,
    'D': D_pct,
    'E': E_pct,
    'F': F_pct,
})

pct_pos = {
  'xA': (0.68, 0.79),
  'A' : (0.68, 0.67),
  'B' : (0.68, 0.55),
  'C' : (0.68, 0.44),
  'D' : (0.68, 0.33),
  'E' : (0.68, 0.22),
  'F' : (0.68, 0.10)}

# ─── 4) (Optional) re‐zero the tiny drift if you need a perfect sum → 0 ───
residual = sum(pct_uniform.values())
pct_uniform['F']  -= residual

# ─── 5) Annotate using this mixed dict ───────────────────────────────────
# ─── 8) Annotate your Sankey with these values at your custom coords ───

# ─── assume you’ve already drawn your Sankey and cleared old annotations ────
fig.layout.annotations = []

# ─── 1) percent labels on the right ────────────────────────────────────────
pct_pos = { #use for data w/ xA
  'xA': (0.65, 0.90),
'A' : (0.65, 0.76),
'B' : (0.65, 0.63),
 'C' : (0.65, 0.50),
'D' : (0.65, 0.36),
'E' : (0.65, 0.23),
'F' : (0.65, 0.10),
}

#pct_pos = { #use for data w/o xA
  #'A' : (0.65, 0.90),
  #'B' : (0.65, 0.74),
  #'C' : (0.65, 0.58),
  #'D' : (0.65, 0.42),
  #'E' : (0.65, 0.26),
  #'F' : (0.65, 0.10),
#}

for zone, (xpos, ypos) in pct_pos.items():
    val = pct_uniform.get(zone, 0.0)
    fig.add_annotation(
        x         = xpos,
        y         = ypos,
        xref      = 'paper',
        yref      = 'paper',
        text      = f"{val:+.2f}%",
        showarrow = False,
        xanchor   = 'left',
        yanchor   = 'middle',
        font      = dict(family='Times New Roman', size=14)
    )

# ─── 2) zone‐name labels on the left ────────────────────────────────────────
name_pos = {
  "A": (0.28, 0.90),
  "B": (0.28, 0.74),
  "C": (0.28, 0.58),
  "D": (0.28, 0.42),
  "E": (0.28, 0.26),
  "F": (0.28, 0.10),
}

for zone,(xpos,ypos) in name_pos.items():
    display = f"<b>{zone}</b>"  if zone in ['A','B','C','D','E','F'] \
              else zone
    fig.add_annotation(
        x         = xpos,  y         = ypos,
        xref      = 'paper',  yref      = 'paper',
        text      = display,
        showarrow = False,
        xanchor   = 'right',
        yanchor   = 'middle',
        font      = dict(family='Times New Roman', size=14, color='black')
    )

name_pos_right = { #use for data w/ xA
  'xA': (0.62, 0.90),
  'A' : (0.62, 0.76),
  'B' : (0.62, 0.63),
  'C' : (0.62, 0.50),
  'D' : (0.62, 0.36),
  'E' : (0.62, 0.23),
  'F' : (0.62, 0.10),
}

#name_pos_right = { #use for data w/o xA
 # "A": (0.62, 0.90),
 # "B": (0.62, 0.74),
  #"C": (0.62, 0.58),
  #"D": (0.62, 0.42),
  #"E": (0.62, 0.26),
  #"F": (0.62, 0.10),
#}

for zone,(xpos,ypos) in name_pos_right.items():
    display = f"<b>{zone}</b>"  if zone in ['xA', 'A','B','C','D','E','F'] \
              else zone
    fig.add_annotation(
        x         = xpos,  y         = ypos,
        xref      = 'paper',  yref      = 'paper',
        text      = display,
        showarrow = False,
        xanchor   = 'left',
        yanchor   = 'middle',
        font      = dict(family='Times New Roman', size=14, color='black')
    )

fig.update_layout(
  title={
    "text":      "IPSL-CM6A-LR SSP5-8.5 Climate Zone Transitions",
    "x":         0.47,         # 0 = left, 0.5 = center, 1 = right
    "y":         0.94,        # 0 = bottom, 1 = top
    "xanchor":   "center",    # align the title’s x according to x
    "yanchor":   "top",       # align the title’s y according to y
    "font":      {"family":"Times New Roman","size":20,"color":"black"},
  },
  margin={"t":60}             # you may need a bigger top margin
)

fig.add_annotation(
    x=0.30, y=0.001,          # paper‐coords: 30% across, 2% up from bottom
    xref='paper', yref='paper',
    text='1985-2014',
    showarrow=False,
    font=dict(family='Times New Roman', size=20, color='black'),
    xanchor='center',
    yanchor='bottom'
)

fig.add_annotation(
    x=0.60, y=0.001,          # 60% across, 2% up
    xref='paper', yref='paper',
    text='2085-2099',
    showarrow=False,
    font=dict(family='Times New Roman', size=20, color='black'),
    xanchor='center',
    yanchor='bottom',
)
fig.update_layout(
    width=1200,
    height=600
)

pio.renderers.default = 'notebook_connected' 

fig.show()

** this is counting same zone transitions, everything should be correct now, just work on fixing formatting **
**all areas add up to 100 (both 20s and 90s)

In [881]:
A_90 + B_90 + C_90 + D_90 + E_90 + F_90

In [883]:
#fig.write_html('/Users/jaybr/OneDrive/Desktop/RESEARCH/IPSL-CM6A-LR_sankey_SSP5-8.5.html')