In [2]:
# Load libraries
import pandas as pd
import numpy as np
import matplotlib

# Directly displays plotly plots in notebook
import plotly
import plotly.io as pio                           
pio.renderers.default = "notebook_connected"

# Alluvial diagram
import plotly.graph_objects as go
import plotly.express as px
import matplotlib.colors as mcolors  

In [3]:
# Read the cluster-flow table from the results workbook, then print a
# schema summary (column names, non-null counts, dtypes).
df = pd.read_excel(
    io="Version_A_Results.xlsx",
    sheet_name="Alluvial Diagram Table",
)
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 17464 entries, 0 to 17463
Data columns (total 9 columns):
 #   Column                     Non-Null Count  Dtype 
---  ------                     --------------  ----- 
 0   ID                         16328 non-null  object
 1   file_num                   17464 non-null  int64 
 2   sent_num                   17464 non-null  int64 
 3   Ambiguous Question         17464 non-null  object
 4   Ambiguous Cluster Name     16329 non-null  object
 5   Disambiguous Question      16734 non-null  object
 6   Disambiguous Cluster Name  16538 non-null  object
 7   Cluster                    16538 non-null  object
 8   Weight                     17464 non-null  int64 
dtypes: int64(3), object(6)
memory usage: 1.2+ MB


In [8]:
def make_cluster_sankey(
    df,
    filter_ambig_cluster=None,
    filter_disambig_cluster=None,
    alpha=0.5,
    title="Ambiguous → Disambiguous"
):
    """Render a Sankey (alluvial) diagram of ambiguous → disambiguous cluster flows.

    Rows of ``df`` are optionally filtered by cluster name, then ``Weight`` is
    summed per ("Ambiguous Cluster Name", "Disambiguous Cluster Name") pair.
    Each distinct ambiguous cluster becomes a left-hand node, each distinct
    disambiguous cluster a right-hand node, and each pair a link whose value
    is the summed weight. Rows with a missing name in either column are
    dropped by the grouping.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns "Ambiguous Cluster Name",
        "Disambiguous Cluster Name" and "Weight". The caller's frame is not
        modified.
    filter_ambig_cluster : str | list | None
        Ambiguous cluster name(s) to keep. A single string is treated as a
        one-element list; ``None`` keeps every cluster.
    filter_disambig_cluster : str | list | None
        Disambiguous cluster name(s) to keep, same conventions as above.
    alpha : float
        Opacity written into the ``rgba(...)`` colour of each link.
    title : str
        Figure title.

    Returns
    -------
    None
        The figure is displayed with ``fig.show()``.
    """
    amb_col = "Ambiguous Cluster Name"
    dis_col = "Disambiguous Cluster Name"

    # Normalize inputs: accept a bare string as a single-cluster filter
    if isinstance(filter_ambig_cluster, str):
        filter_ambig_cluster = [filter_ambig_cluster]

    if isinstance(filter_disambig_cluster, str):
        filter_disambig_cluster = [filter_disambig_cluster]

    # Apply filters (rebinding the local name only)
    if filter_ambig_cluster is not None:
        df = df[df[amb_col].isin(filter_ambig_cluster)]

    if filter_disambig_cluster is not None:
        df = df[df[dis_col].isin(filter_disambig_cluster)]

    cluster_df = (
        df.groupby([amb_col, dis_col], dropna=True)["Weight"]
          .sum()
          .reset_index()
    )

    # Build node lists from the raw cluster names. Left and right nodes live
    # in separate index maps (right-hand indices are offset by the number of
    # left-hand nodes), so a name that occurs on both sides still yields two
    # distinct nodes and no self-loop. Keeping the raw names avoids having to
    # strip string prefixes later, which would corrupt any cluster name that
    # happens to contain the prefix text.
    left_names = list(cluster_df[amb_col].unique())
    right_names = list(cluster_df[dis_col].unique())
    left_idx = {name: pos for pos, name in enumerate(left_names)}
    right_idx = {name: len(left_names) + pos for pos, name in enumerate(right_names)}

    # Map sources/targets to node positions
    sources = [left_idx[name] for name in cluster_df[amb_col]]
    targets = [right_idx[name] for name in cluster_df[dis_col]]

    # Assign colors: left nodes cycle through the palette, right nodes are grey
    palette = px.colors.qualitative.Pastel
    left_colors = {
        name: palette[pos % len(palette)] for pos, name in enumerate(left_names)
    }
    node_colors = [left_colors[name] for name in left_names] + ["#cccccc"] * len(right_names)

    # Links take the colour of their source cluster with the requested opacity.
    # The replace-based conversion assumes palette entries are "rgb(r, g, b)" strings.
    link_colors_transparent = [
        left_colors[name].replace("rgb", "rgba").replace(")", f", {alpha})")
        for name in cluster_df[amb_col]
    ]

    # Compute incoming/outgoing sums, keyed by the raw cluster name
    out_by_left = cluster_df.groupby(amb_col)["Weight"].sum()
    in_by_right = cluster_df.groupby(dis_col)["Weight"].sum()

    # Node hover text, in the same order as the node list (left, then right)
    node_text = [
        f"<b>{name}</b><br>Outgoing flow count: {float(out_by_left.get(name, 0))}"
        for name in left_names
    ] + [
        f"<b>{name}</b><br>Incoming flow count: {float(in_by_right.get(name, 0))}"
        for name in right_names
    ]

    # Build figure
    fig = go.Figure(
        go.Sankey(
            arrangement="snap",
            node=dict(
                label=[str(name) for name in left_names + right_names],
                pad=10,
                thickness=30,
                color=node_colors,
                customdata=node_text,
                hovertemplate="%{customdata}<extra></extra>",
            ),
            link=dict(
                source=sources,
                target=targets,
                value=cluster_df["Weight"],
                color=link_colors_transparent,
                # Per-link (source name, target name) pairs for the hover text
                customdata=np.stack([
                    cluster_df[amb_col],
                    cluster_df[dis_col]
                ], axis=1),
                hovertemplate=(
                    "<b>%{customdata[0]}</b> → <b>%{customdata[1]}</b><br>"
                    "Weight: %{value}<br>"
                    "<extra></extra>"
                ),
            ),
        )
    )

    fig.update_layout(title_text=title)
    fig.show()

In [13]:
# Flows from two ambiguous clusters into a single disambiguous cluster.
make_cluster_sankey(
    df,
    filter_ambig_cluster=["schooling", "Russia, Russians"],
    filter_disambig_cluster="German occupation",
)