In [13]:
from pathlib import Path

# Kernel working directory, and the project root three directory levels above it.
CUR_ABS_DIR = Path.cwd().resolve()
PROJ_DIR = CUR_ABS_DIR.joinpath('..', '..', '..').resolve()
# NOTE(review): not read by any of the cells shown here; presumably consumed further down.
LOAD_LATEST_RESULTS = True

# Coloured HTML prefixes for status messages rendered by the HTML-based `print` helper.
ISSUE_TAG, INFO_TAG, SUCCESS_TAG = (
    '<span style="color:red;">[issue]</span> ',
    '<span style="color:blue;">[info]</span> ',
    '<span style="color:green;">[success]</span> ',
)

def print(text=None, *more, sep=' '):
    """Render values as HTML in the notebook output (deliberately shadows the builtin ``print``).

    Args:
        text: First value to render. When it is ``None`` and no further values are
            given, an HTML line break (``<br>``) is emitted instead.
        *more: Further values; like the builtin, they are joined to ``text`` with ``sep``.
            Accepting them keeps ordinary ``print(a, b)`` calls working even though the
            builtin is shadowed for the whole notebook.
        sep: Separator placed between the rendered values.
    """
    from IPython.display import display, HTML
    if text is None and not more:
        display(HTML("<br>"))
        return
    # Values are inserted as raw HTML so the coloured *_TAG prefixes render.
    display(HTML(sep.join(str(part) for part in (text, *more))))


# Report which directories this session resolved to.
for _dir_line in (
    f'{INFO_TAG} Project directory: {PROJ_DIR}',
    f'{INFO_TAG} Current directory: {CUR_ABS_DIR}',
):
    print(_dir_line)

In [14]:
import plotly.graph_objects as go
import plotly
import pandas as pd

# --- Basic settings ---
# Sets plotly's default static-image export format; the export at the end of this
# cell also passes format="pdf" explicitly.
plotly.io.defaults.default_format = "pdf"

# --------------------------
# 1. Strict data logic backtracking
# --------------------------
# The mapping below is authored as "attribute -> methods". To draw one Parcats
# path per method we invert it: for every method, find which attribute node of
# each column points to it.

# All node names (indices 0-20), in column order.
nodes = {
    # Col 1: Resources
    0: "High", 1: "Medium", 2: "Low",
    # Col 2: Accuracy
    3: "High Accuracy", 4: "High Stability",
    # Col 3: Speed
    5: "Speed: Fast", 6: "Speed: Medium", 7: "Speed: Slow",
    # Col 4: Window
    8: "Window: Short", 9: "Window: Long",
    # Col 5: Horizon
    10: "Nowcasting", 11: "Short Horizon", 12: "Medium Horizon",
    # Col 6: Methods
    13: "Isolated (L)", 14: "Graphlet (L)", 15: "Embedded (L)", 16: "ST-GNN (L)",
    17: "Isolated (T)", 18: "Graphlet (T)", 19: "Embedded (T)", 20: "ST-GNN (T)"
}

# Attribute node -> method nodes it leads to (Source -> Targets).
mappings = {
    # Stage 1: Resource -> Method
    0: [13, 14, 17, 18],
    1: [19, 20],
    2: [15, 16],

    # Stage 2: Accuracy -> Method
    3: [18, 20],
    4: [13, 14, 15, 16, 17, 19],

    # Stage 3: Speed -> Method
    5: [15],
    6: [16, 19, 20],
    7: [13, 14, 17, 18],

    # Stage 4: Window -> Method
    8: [14, 18],
    9: [13, 15, 16, 17, 19, 20],

    # Stage 5: Horizon -> Method
    10: [13, 14, 15, 16, 17, 18, 19, 20], # All methods handle Nowcasting
    11: [13, 14, 15, 16], # (L) group
    12: [17, 18, 19, 20]  # (T) group
}

# Method node indices (13-20).
method_indices = range(13, 21)

# Attribute node indices per column. Each of these columns must point to every
# method exactly once, otherwise the method's path would be ambiguous.
stage_candidates = {
    "Resource": (0, 1, 2),
    "Accuracy": (3, 4),
    "Speed": (5, 6, 7),
    "Window": (8, 9),
}
# Horizon is special: a method may be reachable from several horizons
# (e.g. Isolated (L) from both Nowcasting and Short Horizon) -> one path each.
horizon_candidates = (10, 11, 12)


def _sources_of(method_idx, candidates):
    """Return the nodes in `candidates` whose mapping lists `method_idx`, in `mappings` order."""
    return [k for k, targets in mappings.items() if k in candidates and method_idx in targets]


# One row per (method, horizon) path.
rows = []
for m_idx in method_indices:
    method_name = nodes[m_idx]

    attribute_labels = []
    for stage, candidates in stage_candidates.items():
        found = _sources_of(m_idx, candidates)
        if len(found) != 1:
            raise ValueError(
                f"{method_name!r} must be reached from exactly one {stage} node, got {found}"
            )
        attribute_labels.append(nodes[found[0]])

    hor_nodes = _sources_of(m_idx, horizon_candidates)
    if not hor_nodes:
        # Without a horizon the method would silently vanish from the chart.
        raise ValueError(f"{method_name!r} is not reached from any Horizon node")

    for h_node in hor_nodes:
        rows.append([*attribute_labels, nodes[h_node], method_name])

df = pd.DataFrame(rows, columns=["Resource", "Accuracy", "Speed", "Window", "Horizon", "Method"])
# Give each link a default weight of 1
df['Count'] = 1

# --------------------------
# 2. Visual styling (Colors & Style)
# --------------------------
# Muted, paper-style palette with one colour per method:
# blues for the (L) variants, greens for the (T) variants.
method_colors = {
    # ----- Blues (L) -----
    "Isolated (L)": "#3c6dd9",   # deep blue (stable)
    "Graphlet (L)": "#5e92f3",   # brighter blue-cyan
    "Embedded (L)": "#73b9f7",   # brightest blue
    "ST-GNN (L)":  "#b3ddfb",    # light blue for distinction

    # ----- Greens (T) -----
    "Isolated (T)": "#3d8a44",   # deep green
    "Graphlet (T)": "#5ca860",   # brighter green
    "Embedded (T)": "#8dd191",   # brightest green
    "ST-GNN (T)":  "#c4e6c6",    # very light green for contrast
}

# Parcats colours paths through a numeric value + colorscale. Each method gets an
# integer code in order of first appearance in `df`, and the colorscale lists the
# method colours in exactly that order, so code i lands on colour i.
unique_methods = [*method_colors]
df_methods = df["Method"].unique()
method_to_int = dict(zip(df_methods, range(len(df_methods))))
df['color_val'] = df['Method'].map(method_to_int)
colorscale = [method_colors[method] for method in df_methods]

# --------------------------
# 3. Enforced sorting (Sorting)
# --------------------------
# Category order (top to bottom) of each Parcats column, following the logical hierarchy.
(order_resource, order_accuracy, order_speed, order_window, order_horizon) = (
    ["High", "Medium", "Low"],
    ["High Accuracy", "High Stability"],
    ["Speed: Slow", "Speed: Medium", "Speed: Fast"],
    ["Window: Short", "Window: Long"],
    ["Nowcasting", "Short Horizon", "Medium Horizon"],
)
# Methods: the four (L) variants first, then the four (T) variants.
# Alternative layout: interleave per family, i.e. Isolated (L), Isolated (T), Graphlet (L), ...
order_method = (
    ["Isolated (L)", "Graphlet (L)", "Embedded (L)", "ST-GNN (L)"]
    + ["Isolated (T)", "Graphlet (T)", "Embedded (T)", "ST-GNN (T)"]
)

# --------------------------
# 4. Plotting (Parcats)
# --------------------------
# One dimension per column: (title shown in the figure, DataFrame column, category order).
parcats_dimensions = [
    dict(label=title, values=df[column], categoryarray=order)
    for title, column, order in (
        ("Resource", "Resource", order_resource),
        ("Accuracy vs. Stability", "Accuracy", order_accuracy),
        ("Speed", "Speed", order_speed),
        ("Window", "Window", order_window),
        ("Horizon", "Horizon", order_horizon),
        ("Method", "Method", order_method),
    )
]

# Each path is coloured by its method code through `colorscale`.
path_style = dict(
    color=df['color_val'],
    colorscale=colorscale,
    shape='hspline',  # smooth curved lines
    # opacity=0.6     # optional transparency to clarify overlaps
)

fig = go.Figure(data=[go.Parcats(
    domain={'y': [0, 1]},  # occupy the full vertical space
    dimensions=parcats_dimensions,
    line=path_style,
    hoveron='color',
    hoverinfo='count+probability',
    arrangement='freeform',
    labelfont=dict(size=18, family="Arial", color="black"),   # column title font
    tickfont=dict(size=17, family="Arial", color="#333333"),  # category option font
)])

# --------------------------
# 5. Layout tweaks (compact & paper-style)
# --------------------------
width, height = 1000, 280  # also reused by the PDF export below

fig.update_layout(
    width=width,
    height=height,  # kept low so the category bars look less elongated
    paper_bgcolor='white',
    plot_bgcolor='white',
    margin=dict(l=30, r=75, t=20, b=15),  # very narrow margins
    font=dict(family="Arial"),
)

fig.show()

# --------------------------
# 6. Save compact PDF
# --------------------------
# Static export requires Kaleido: pip install kaleido
output_path = CUR_ABS_DIR / 'figures' / 'sankey_choice.pdf'
output_path.parent.mkdir(parents=True, exist_ok=True)
try:
    # Export with the same size as the on-screen figure.
    fig.write_image(output_path, format="pdf", width=width, height=height)
except Exception as e:
    # Best-effort export: report the problem in the notebook instead of aborting the cell.
    print(f"{ISSUE_TAG}Error saving PDF. Make sure 'kaleido' is installed (pip install kaleido). Error details: {e}")
    saved = False
else:
    print(f'{SUCCESS_TAG}Saved grid plot to {output_path}.')
    saved = True

from swissrivernetwork.util.os import make_open_button
# Only offer the "open" button when a file was actually written. Leaving it as the
# cell's last expression lets the notebook display whatever make_open_button returns
# (a Button widget, judging by the recorded cell output).
open_button = make_open_button(output_path) if saved else None
open_button

Button(button_style='success', description='Open in File Manager', style=ButtonStyle())

In [15]:
# The Sankey diagram version of the above logic

import plotly.graph_objects as go
import plotly.io as pio
import plotly

# --- Key setting: compact PDF export for academic paper ---
# Ensure kaleido is installed: pip install kaleido
# Set default export format; sometimes explicit specification is needed
# (Same setting as in the Parcats cell; presumably repeated so this cell can run on its own.)
plotly.io.defaults.default_format = "pdf"

# --------------------------
# 1. Helper: Hex to RGBA (for semi-transparent links)
# --------------------------
def hex_to_rgba(hex_color, opacity=0.4):
    """Convert a ``#RRGGBB`` or ``#RGB`` hex colour to a CSS ``rgba(...)`` string.

    Args:
        hex_color: Colour string, with or without a leading ``#``.
        opacity: Alpha value written into the result.

    Returns:
        ``"rgba(r,g,b,opacity)"`` for 3- or 6-digit hex input. Any other input
        (named colours such as ``"orange"``, ``rgb(...)`` strings, ...) is returned
        unchanged, including its ``#``, so it remains a valid colour.
    """
    digits = hex_color.lstrip('#')
    is_hex = digits != '' and all(ch in '0123456789abcdefABCDEF' for ch in digits)
    if is_hex and len(digits) == 3:
        # CSS shorthand: every digit is doubled (#abc == #aabbcc).
        digits = ''.join(ch * 2 for ch in digits)
    if is_hex and len(digits) == 6:
        r, g, b = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
        return f"rgba({r},{g},{b},{opacity})"
    # Not a hex colour we understand: hand it back untouched.
    return hex_color

# --------------------------
# 2. Define nodes (labels)
# --------------------------
# Node labels column by column (left to right), top to bottom within each column.
# A label's position in the flattened list is the node index used by the flows.
labels = [
    label
    for column in (
        # Col 1: Computational Resources (indices 0-2)
        ("Comp. Res: High", "Comp. Res: Medium", "Comp. Res: Low"),
        # Col 2: Accuracy Requirement (indices 3-4)
        ("High Accuracy", "High Stability"),
        # Col 3: Training Speed Req. (indices 5-7)
        ("Speed: Fast", "Speed: Medium", "Speed: Slow"),
        # Col 4: Historical Window (indices 8-9)
        ("Window: Short", "Window: Long"),
        # Col 5: Future Horizon (indices 10-12)
        ("Nowcasting", "Short Horizon", "Medium Horizon"),
        # Col 6: Final Methods, rightmost (indices 13-20)
        ("Isolated (L)", "Graphlet (L)", "Embedded (L)", "ST-GNN (L)",
         "Isolated (T)", "Graphlet (T)", "Embedded (T)", "ST-GNN (T)"),
    )
    for label in column
]

# --------------------------
# 3. Define node colors (distinct palettes per column)
# --------------------------
# One colour per entry of `labels`, in the same order (21 entries). Links later
# reuse their source node's colour with transparency.
colors = [
    # Col 1 (Blues): Comp. Res High / Medium / Low (indices 0-2)
    "#5e92f3",
    "#73b9f7",
    "#b3ddfb",

    # Col 2 (Greens): High Accuracy / High Stability (indices 3-4)
    "#5ca860",
    "#8dd191",
    # "#c4e6c6",

    # Col 3 (Oranges): Speed Fast / Medium / Slow (indices 5-7)
    "#f49844",
    "#ffbd5c",
    "#ffd9a3",

    # Col 4 (Purples): Window Short / Long (indices 8-9)
    "#9854c4",
    "#c77bd1",

    # Col 5 (Reds): Nowcasting / Short / Medium Horizon (indices 10-12)
    "#d96060",
    "#f27c79",
    "#efa3a1",

    # Col 6, (L) methods: Isolated / Graphlet / Embedded / ST-GNN (indices 13-16)
    "#6d8694",
    "#846b5f",
    "#a08876",
    "#ad9b8f",

    # Col 6, (T) methods: Isolated / Graphlet / Embedded / ST-GNN (indices 17-20)
    # NOTE(review): "#ad9b8f" below (Graphlet (T)) duplicates ST-GNN (L) above, so the
    # two nodes share a colour - confirm whether this is intended.
    "#8ba3b0",
    "#ad9b8f",
    "#baa8a0",
    "#d0c4bd",
]

# --------------------------
# 4. Define flows (core logic)
# --------------------------
# Every attribute node links directly to each method it recommends, and every link
# carries the same weight (1). Indices refer to positions in `labels`.
flow_table = (
    # Stage 1: Resources -> Final Methods
    (2, (15, 16)),
    (1, (19, 20)),
    (0, (13, 14, 17, 18)),
    # Stage 2: Accuracy -> Final Methods
    (3, (18, 20)),
    (4, (13, 14, 15, 16, 17, 19)),
    # Stage 3: Speed -> Final Methods
    (5, (15,)),
    (6, (16, 19, 20)),
    (7, (13, 14, 17, 18)),
    # Stage 4: Window -> Final Methods
    (8, (14, 18)),
    (9, (13, 15, 16, 17, 19, 20)),
    # Stage 5: Horizon -> Final Methods
    (10, (13, 14, 15, 16, 17, 18, 19, 20)),
    (11, (13, 14, 15, 16)),
    (12, (17, 18, 19, 20)),
)

# Flatten the table into Plotly's parallel source/target/value lists.
sources = [src for src, dsts in flow_table for _ in dsts]
targets = [dst for _, dsts in flow_table for dst in dsts]
values = [1] * len(targets)


# Each link takes its source node's colour at 35% opacity.
link_colors = [hex_to_rgba(colors[src_idx], opacity=0.35) for src_idx in sources]

# --------------------------
# 5. Create figure object
# --------------------------
fig = go.Figure(data=[go.Sankey(
    node=dict(
        pad=10,
        thickness=15,
        line=dict(color="black", width=0.5),
        label=labels,
        color=colors,
    ),
    link=dict(
        source=sources,
        target=targets,
        value=values,
        color=link_colors
    )
)])

# --------------------------
# 6. Compact layout for paper export
# --------------------------
fig.update_layout(
    # title_text="Methodology Selection Flow",
    height=500,
    width=1100,
    plot_bgcolor='white',
    paper_bgcolor='white',
    # Minimal margins to remove white borders
    margin=dict(l=5, r=5, t=5, b=5),
    # Global font, specified in one place only.
    font=dict(
        family="Arial",
        size=10,
        color="black"
    ),
)

# Display figure
fig.show()

# # --------------------------
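# # 7. Save compact PDF
# # NOTE: this writes to the same 'sankey_choice.pdf' as the Parcats cell above;
# #       change the file name before re-enabling it, or that figure gets overwritten.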
# # --------------------------
# # Kaleido is required: pip install kaleido
# output_path = CUR_ABS_DIR / 'figures' / 'sankey_choice.pdf'
# output_path.parent.mkdir(parents=True, exist_ok=True)
# try:
#     # Export PDF using kaleido
#     fig.write_image(output_path, format="pdf")
#     print(f'{SUCCESS_TAG}Saved grid plot to {output_path}.')
# except Exception as e:
#     print(f"Error saving PDF. Make sure 'kaleido' is installed (pip install kaleido). Error details: {e}")
#
# from swissrivernetwork.util.os import make_open_button
# make_open_button(output_path)


In [16]:
#  This is an example generated by Gemini 3 Pro (accessed on 2025-12-04). This is crazily good!

import plotly.graph_objects as go

# --------------------------
# 1. Helper: Hex to RGBA
# --------------------------
def hex_to_rgba(hex_color, opacity=0.5):
    """Convert a ``#RRGGBB`` or ``#RGB`` hex colour to ``rgba(...)`` to control opacity.

    Args:
        hex_color: Colour string, with or without a leading ``#``.
        opacity: Alpha value written into the result.

    Returns:
        ``"rgba(r,g,b,opacity)"`` for 3- or 6-digit hex input. Any other input
        (named colours such as ``"orange"``, ``rgb(...)`` strings, ...) is returned
        unchanged, including its ``#``, so it remains a valid colour.
    """
    digits = hex_color.lstrip('#')
    is_hex = digits != '' and all(ch in '0123456789abcdefABCDEF' for ch in digits)
    if is_hex and len(digits) == 3:
        # CSS shorthand: every digit is doubled (#abc == #aabbcc).
        digits = ''.join(ch * 2 for ch in digits)
    if is_hex and len(digits) == 6:
        r, g, b = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
        return f"rgba({r},{g},{b},{opacity})"
    # Not a hex colour we understand: hand it back untouched.
    return hex_color

# --------------------------
# 2. Data definition
# --------------------------
# (label, node colour) pairs; the list position is the node index used by the links.
_example_nodes = [
    # Left-side nodes (0-8)
    ("Netherlands", "#7c7cba"),     # purple
    ("Canada", "#b2ebf2"),          # light blue
    ("Belgium", "#f5b7b1"),         # light red
    ("Italy", "#ffe0b2"),           # light yellow
    ("Mexico", "#a5d6a7"),          # light green
    ("Russia", "#d1c4e9"),          # light purple
    ("Spain", "#b3e5fc"),           # light blue
    ("South Korea", "#d7ccc8"),     # light brown
    ("Germany", "#26c6da"),         # teal
    # Middle nodes (9-10)
    ("China", "#ef5350"),           # red
    ("European Union", "#fbc02d"),  # yellow
    # Right/bottom nodes (11-18)
    ("United Kingdom", "#9575cd"),
    ("United States", "#4fc3f7"),
    ("Japan", "#a5d6a7"),
    ("Hong Kong", "#80cbc4"),
    ("France", "#7986cb"),
    ("Switzerland", "#ff8a65"),
    ("Austria", "#ffe082"),
    ("Sweden", "#c5e1a5"),
]
labels = [name for name, _ in _example_nodes]
colors = [hex_code for _, hex_code in _example_nodes]

# (source, target, value) per link, so each link is stated in one place.
_example_links = [
    (0, 11, 5), (0, 10, 10),                  # Netherlands -> UK, EU
    (1, 12, 35),                              # Canada -> US
    (2, 10, 8),                               # Belgium -> EU
    (3, 10, 8),                               # Italy -> EU
    (4, 12, 15),                              # Mexico -> US
    (5, 9, 5),                                # Russia -> China
    (6, 10, 6),                               # Spain -> EU
    (7, 9, 6),                                # South Korea -> China
    (8, 9, 15), (8, 10, 20),                  # Germany -> China, EU
    (9, 12, 10), (9, 13, 8), (9, 14, 5),      # China -> US, Japan, Hong Kong
    (10, 11, 10), (10, 12, 25), (10, 15, 8),  # EU -> UK, US, France
    (10, 16, 6), (10, 17, 4), (10, 18, 3),    # EU -> Switzerland, Austria, Sweden
]
sources = [src for src, _, _ in _example_links]
targets = [dst for _, dst, _ in _example_links]
values = [amount for _, _, amount in _example_links]

# --------------------------
# 3. Key step: generate link colors
# --------------------------
# A link is drawn in its source node's colour at 40% opacity.
link_colors = [hex_to_rgba(colors[source_index], opacity=0.4) for source_index in sources]

# --------------------------
# 4. Plotting
# --------------------------
sankey_nodes = dict(
    pad=15,
    thickness=20,
    line=dict(color="black", width=0.5),
    label=labels,
    color=colors,
)
sankey_links = dict(
    source=sources,
    target=targets,
    value=values,
    color=link_colors,  # per-link colours computed above
)

fig = go.Figure(data=[go.Sankey(node=sankey_nodes, link=sankey_links)])

fig.update_layout(
    title_text="Global Flow Sankey Diagram (Colored Links)",
    font_size=12,
    height=600,
    width=1000,
    plot_bgcolor='white',
)

fig.show()
