# In [72]:
import plotly.graph_objects as go
from constants import colors

# Top-level ethnicity categories and the fill color used for each one.
# The color values are attributes of the imported `colors` constants module.
# Both a category tile and all of its sub-category tiles are drawn in this
# color (see the `all_colors` construction further down).
parent_colors = {
    "White": colors.RCPCH_LIGHT_BLUE,
    "Asian": colors.RCPCH_PINK,
    "Black": colors.RCPCH_MID_GREY,
    "Mixed": colors.RCPCH_YELLOW,
    "Other": colors.RCPCH_DARK_BLUE,
}

# Maps each detailed ethnicity label to its top-level (parent) category.
# Every value used here is a key of `parent_colors`, and every key of `data`
# has an entry here; the code below indexes this dict directly, so a label
# missing from this mapping would raise KeyError.
# Note that "Not known" and "Not Stated" are both grouped under "Other".
ethnicity_parent_map = {
    "Not known": "Other",
    "Any other mixed background": "Mixed",
    "African": "Black",
    "Pakistani or British Pakistani": "Asian",
    "Caribbean": "Black",
    "British, Mixed British": "White",
    "Any other White background": "White",
    "Any other Black background": "Black",
    "Mixed (White and Black Caribbean)": "Mixed",
    "Irish": "White",
    "Any other ethnic group": "Other",
    "Chinese": "Asian",
    "Any other Asian background": "Asian",
    "Mixed (White and Asian)": "Mixed",
    "Indian or British Indian": "Asian",
    "Not Stated": "Other",
    "Mixed (White and Black African)": "Mixed",
    "Bangladeshi or British Bangladeshi": "Asian",
}

# Ethnicity data: one numeric value per detailed ethnicity label, listed in
# descending order of value.
# NOTE(review): the values sum to 165, not 100, and the hover template below
# shows them as "N=...", so these look like counts rather than percentages.
# Confirm against the data source.
data = {
    "Any other mixed background": 13,
    "Caribbean": 13,
    "Mixed (White and Black African)": 13,
    "Any other Asian background": 12,
    "Any other Black background": 11,
    "Mixed (White and Asian)": 11,
    "Indian or British Indian": 11,
    "Any other ethnic group": 10,
    "Not Stated": 10,
    "Pakistani or British Pakistani": 9,
    "African": 8,
    "British, Mixed British": 8,
    "Bangladeshi or British Bangladeshi": 8,
    "Mixed (White and Black Caribbean)": 7,
    "Irish": 7,
    "Chinese": 6,
    "Not known": 4,
    "Any other White background": 4,
}

# Build the index-aligned label / parent / value / color lists for a two-level
# treemap (top-level category -> detailed ethnicity) and render it.

# Extract lists. `data` preserves insertion order, so `ethnicities`,
# `percentages` and `parents` are index-aligned with each other.
ethnicities = list(data)
percentages = list(data.values())
parents = [ethnicity_parent_map[ethnicity] for ethnicity in ethnicities]  # Assign parents

# Unique parent categories in first-seen order. dict.fromkeys() is used rather
# than set() so the order is identical on every run; set iteration order for
# strings changes between interpreter processes because of hash randomization.
parent_labels = list(dict.fromkeys(parents))

# Total of the children's values for each parent category.
parent_values = [
    sum(value for owner, value in zip(parents, percentages) if owner == parent)
    for parent in parent_labels
]

# Define all labels (parents first, then children).
# "ALL" is only ever used as a parent name and is not itself listed in
# `all_labels`; Plotly is relied on to treat that single unlisted parent as the
# root of the hierarchy.
all_labels = parent_labels + ethnicities
all_parents = ["ALL"] * len(parent_labels) + parents

# Define values (parents first, then children). This list must line up
# one-to-one with `all_labels`: there is no entry for the "ALL" root because
# the root has no entry in `all_labels` either. An extra leading element here
# would shift every value onto the wrong label.
all_values = parent_values + percentages

# Assign the same color to subcategories as their parent
all_colors = {parent: parent_colors[parent] for parent in parent_labels}
all_colors.update(
    {
        ethnicity: parent_colors[owner]
        for ethnicity, owner in zip(ethnicities, parents)
    }
)

# Create Treemap
fig = go.Figure(
    go.Treemap(
        labels=all_labels,  # Labels including parents
        parents=all_parents,  # Hierarchical structure
        values=all_values,  # Sizes
        # Parent values above are already the sum of their children, so tell
        # Plotly they are totals. With the default ("remainder") each parent
        # value would be added on top of its children and counted twice.
        branchvalues="total",
        textinfo="label+percent root",  # Show labels and percentages
        marker=dict(
            # Apply parent colors to subcategories
            colors=[all_colors[label] for label in all_labels]
        ),
        hovertemplate=(
            "<b>%{label}</b><br>"
            "N=%{value} (%{percentRoot:.0%})<br><extra></extra>"
        ),
    )
)

# Customize layout
fig.update_layout(
    title="Ethnicity Distribution",
    margin=dict(t=50, l=25, r=25, b=25),
)

fig.show()