# In [2]:
import plotly.graph_objects as go

# Node table: one (split condition, success score) pair per Sankey node.
# The position of a pair in this table is the node index that the link
# lists refer to.
_NODES = (
    ("Runtime <= 115.5", 0.657),
    ("proportion_actress <= 0.577", 0.604),
    ("Year <= 2010.5", 0.488),
    ("incl_score <= 0.631", 0.511),
    ("Runtime > 115.5", 0.732),
    ("Runtime <= 151.5", 0.712),
    ("proportion_actress <= 0.173", 0.833),
    ("incl_score <= 0.337", 0.803),
    ("Bechdel_score <= 1.5", 0.891),
    ("Bechdel_score <= 2.5", 0.879),
    ("incl_score <= 0.57", 0.852),
)

# Display text for each node: the condition followed by its score on a
# second line of the string.
labels = [
    f"{condition}\n(Score = {score:.3f})" for condition, score in _NODES
]

# Raw score per node, in the same order as `labels` (used as custom data
# for the hover text).
customdata = [score for _, score in _NODES]

# Define links as (source node, target node, sample count) triples.
# Node numbers are indices into `labels`. Keeping each link on one row
# guarantees that `sources`, `targets` and `values` always have the same
# length; as three hand-maintained lists they had drifted apart
# (22 sources against 16 targets and 16 values).
link_rows = [
    (0, 1, 804), (0, 4, 473),    # from node 0: Runtime <= 115.5
    (1, 2, 447), (1, 5, 331),    # from node 1: proportion_actress <= 0.577
    (2, 3, 276), (2, 6, 55),     # from node 2: Year <= 2010.5
    (3, 6, 36), (3, 7, 19),      # from node 3: incl_score <= 0.631
    (4, 7, 19), (4, 8, 7),       # from node 4: Runtime > 115.5
    (5, 8, 7), (5, 9, 4),        # from node 5: Runtime <= 151.5
    (6, 9, 4), (6, 10, 23),      # from node 6: proportion_actress <= 0.173
    (7, 10, 23), (7, 3, 4),      # from node 7: incl_score <= 0.337
]
# NOTE(review): the earlier `sources` list also named nodes 8, 9 and 10 as
# link origins, but no target or sample count was ever given for them, so
# they could not form links. Add rows above once that data is known.
# NOTE(review): 3 -> 7 together with 7 -> 3 forms a cycle, and several
# nodes have more than one incoming link - confirm this is intended.

# Parallel lists in the shape the Sankey `link` settings expect.
sources, targets, values = (list(column) for column in zip(*link_rows))

# Node appearance and data: spacing, bar thickness, outline, the label
# text, and the per-node score exposed to the hover template.
node_settings = {
    "pad": 15,
    "thickness": 20,
    "line": {"color": "black", "width": 0.5},
    "label": labels,
    "customdata": customdata,
    # Hover text shows the node label and its score with three decimals.
    "hovertemplate": 'Node: %{label}<br>Success Metric: %{customdata:.3f}<extra></extra>',
    "color": "lightblue",
}

# Link endpoints, widths (sample counts) and a translucent grey fill.
link_settings = {
    "source": sources,
    "target": targets,
    "value": values,
    "color": "rgba(192, 192, 192, 0.4)",
}

# Assemble the single-trace Sankey figure.
sankey_trace = go.Sankey(node=node_settings, link=link_settings)
fig = go.Figure(data=[sankey_trace])

# Title and base font size for the whole figure.
layout_options = {
    "title_text": "Sankey Diagram of the Success Metric",
    "font_size": 12,
}
fig.update_layout(**layout_options)

# Render the diagram, then save it as an HTML file.
fig.show()
fig.write_html("sankey_diagram.html")
