In [10]:
# Load libraries
import pandas as pd
import numpy as np
import matplotlib

# Directly displays plotly plots in notebook
import plotly
import plotly.io as pio                           
pio.renderers.default = "notebook_connected"

# Alluvial diagram
import plotly.graph_objects as go
import plotly.express as px
import matplotlib.colors as mcolors  

In [9]:
# Alluvial diagram (Sankey workaround)
# Needs columns: (Ambiguous Cluster Name, Disambiguous Cluster Name), (Ambiguous Question, Disambiguous Question), Weight (column of 1's)

# CHANGE .CSV NAME BELOW TO LOAD DATA
CSV_PATH = "20AnimalsDataset_FinalWithNames.csv"
df = pd.read_csv(CSV_PATH)

# Fail early with a readable message if the CSV lacks a column used below
required_columns = [
    "Ambiguous Cluster Name",
    "Disambiguous Cluster Name",
    "Ambiguous Question",
    "Disambiguous Question",
    "Weight",
]
missing_columns = [col for col in required_columns if col not in df.columns]
if missing_columns:
    raise KeyError(f"{CSV_PATH} is missing required column(s): {missing_columns}")

# Step 1: Create node lists and mapping
# `labels` holds the left-hand nodes first (positions 0..len(left)-1),
# followed by the right-hand nodes.
left = df["Ambiguous Cluster Name"].unique().tolist()
right = df["Disambiguous Cluster Name"].unique().tolist()
labels = left + right

# Each side gets its own name -> node-position lookup. A single lookup over
# `labels` breaks when the same cluster name occurs in both columns: the
# right-hand position overwrites the left-hand one, so links would start from
# the right-hand node (and a same-name flow would become a self-loop).
left_idx = {lab: i for i, lab in enumerate(left)}
right_idx = {lab: len(left) + i for i, lab in enumerate(right)}

# Combined lookup with the same contents as before (right-hand position wins
# on a shared name); retained in case other code still reads `idx`.
idx = {**left_idx, **right_idx}

# Step 2: Assign colors to left nodes
# The palette is cycled when there are more left nodes than palette entries.
palette = px.colors.qualitative.Pastel
node_colors = {lab: palette[i % len(palette)] for i, lab in enumerate(left)}

# Step 3: Assign link colors based on source node's color
link_colors = df["Ambiguous Cluster Name"].map(node_colors)

### Make colors transparent
# NOTE: this string rewrite only works for colors written as "rgb(r, g, b)";
# it assumes every entry of `palette` has that form.
alpha = 0.5
link_colors_transparent = [c.replace("rgb", "rgba").replace(")", f", {alpha})") for c in link_colors]

# Step 4: Convert names to node indices for links
# Sources are resolved among the left nodes only, targets among the right nodes only.
sources = df["Ambiguous Cluster Name"].map(left_idx)
targets = df["Disambiguous Cluster Name"].map(right_idx)

# Step 5: Create custom data for tooltip, with Ambiguous Question (Column 1) and Disambiguous Question (Column 2)
custom = np.stack([df["Ambiguous Question"], df["Disambiguous Question"]], axis=1)

# Step 6: Create custom tooltip for each node
### Compute incoming/outgoing sums of links entering/leaving node
out_by_left = df.groupby("Ambiguous Cluster Name")["Weight"].sum()
in_by_right = df.groupby("Disambiguous Cluster Name")["Weight"].sum()

### Build a hovertemplate string for each node (same order as `labels`)
# Built per side rather than by testing `lab in left`, so a right-hand node
# that shares its name with a left-hand node still reports its incoming flow.
node_text = [
    f"<b>{lab}</b><br>Outgoing flow count: {float(out_by_left.get(lab, 0))}"
    for lab in left
] + [
    f"<b>{lab}</b><br>Incoming flow count: {float(in_by_right.get(lab, 0))}"
    for lab in right
]

# Step 7: Build alluvial diagram
# Node colors are chosen by position in `labels`, not by name: the first
# len(left) entries are left-hand nodes and take their palette color, every
# later entry is a right-hand node and is drawn in neutral gray. A lookup by
# name alone would tint a right-hand node that happens to share its name with
# a left-hand node.
node_fill = [
    node_colors.get(lab, "#cccccc") if pos < len(left) else "#cccccc"
    for pos, lab in enumerate(labels)
]

fig = go.Figure(go.Sankey(
    arrangement="snap",
    node=dict(
        label=labels,
        pad=10,
        thickness=30,
        color=node_fill,
        # One pre-built tooltip string per node, in the same order as `labels`
        customdata=node_text,
        hovertemplate="%{customdata}<extra></extra>"
    ),
    link=dict(
        # One link per row of `df`; source/target are node positions in `labels`
        source=sources,
        target=targets,
        value=df["Weight"],
        # customdata[0] / customdata[1] are the two question columns stacked in `custom`
        customdata=custom,
        color=link_colors_transparent,
        hovertemplate=(
            "<b>%{source.label}</b> → <b>%{target.label}</b><br>"
            "Ambiguous: %{customdata[0]}<br>"
            "Disambiguated: %{customdata[1]}<br>"
            "<extra></extra>"
        )
    )
))

# Add title to figure
fig.update_layout(
    title_text="Alluvial Diagram of Ambiguous → Disambiguated Cluster Topics"
)

fig.show()

# Export to .html
# fig.write_html("alluvial_diagram.html", include_plotlyjs="cdn")