In [14]:
from google.cloud import bigquery
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
import kaleido

In [13]:
# Clients
client = bigquery.Client(project="starbucks-uk")

In [15]:
# BQ Read

query = """
    SELECT market, starLoyaltyLevel, segment, user_type, source_step, target_step, transition_count, sequence_number
    FROM `starbucks-uk.monks_homepage_analysis.conversion_pathways_long_event_aggregated_android`
"""
data = client.query(query).to_dataframe()
if data.empty:
    print("No data returned from BigQuery. Check your table and query.")
    exit()

# Data Transformation
data["target_step"] = data["target_step"].replace({None: "end", "null": "end"})

# 1. Remove transitions where source_step == target_step (self-loops)
data = data[data["source_step"] != data["target_step"]]

# 2. Filter out cases where the first node doesn't start with sequence_number == 1
data = data[
    (data["source_step"] == "start") & (data["sequence_number"] == 1)
    | (data["source_step"] != "start")
]

# Aggregation
aggregated_data = (
    data.groupby(
        [
            "market",
            "starLoyaltyLevel",
            "segment",
            "user_type",
            "source_step",
            "target_step",
        ]
    )["transition_count"]
    .sum()
    .reset_index()
)

# Print distinct combinations
market_star_segment_user_combinations = aggregated_data[
    ["market", "starLoyaltyLevel", "segment", "user_type"]
].drop_duplicates()
print("Available market, starLoyaltyLevel, segment, and user_type combinations:")
print(market_star_segment_user_combinations)

Available market, starLoyaltyLevel, segment, and user_type combinations:
      market starLoyaltyLevel                 segment  user_type
0         AE             GOLD            bounced user        new
4         AE             GOLD            bounced user  returning
356       AE             GOLD  frequent non-purchaser        new
473       AE             GOLD  frequent non-purchaser  returning
1645      AE             GOLD      frequent purchaser        new
1670      AE             GOLD      frequent purchaser  returning
2667      AE            GREEN            bounced user        new
2699      AE            GREEN            bounced user  returning
3055      AE            GREEN  frequent non-purchaser        new
3368      AE            GREEN  frequent non-purchaser  returning
4542      AE            GREEN      frequent purchaser        new
4547      AE            GREEN      frequent purchaser  returning
5199      AE          WELCOME  frequent non-purchaser        new
5205      AE     

In [16]:
def generate_sankey(
    market_input=None,
    starLoyaltyLevel_input=None,
    segment_input=None,
    user_type_input=None,
    top_n=250,
    file_prefix="sankey_diagram", 
):
    # Filtered input
    filtered_data = data.copy()

    if market_input:
        filtered_data = filtered_data[filtered_data["market"] == market_input]
    if starLoyaltyLevel_input:
        filtered_data = filtered_data[
            filtered_data["starLoyaltyLevel"] == starLoyaltyLevel_input
        ]
    if segment_input:
        filtered_data = filtered_data[filtered_data["segment"] == segment_input]
    if user_type_input:
        filtered_data = filtered_data[filtered_data["user_type"] == user_type_input]

    if filtered_data.empty:
        print("No data available for the specified filters.")
        return

    # Level 1 Set up
    level_1_data = filtered_data[
        (filtered_data["source_step"] == "start")
        & (filtered_data["sequence_number"] == 1)
    ]

    if level_1_data.empty:
        print("No valid level 1 data found.")
        return

    # Collect targets of level 1
    valid_targets_level_1 = level_1_data["target_step"].unique()

    # Filter second-level transitions starting from valid_targets_level_1
    level_2_data = filtered_data[
        filtered_data["source_step"].isin(valid_targets_level_1)
    ]

    valid_targets_level_2 = level_2_data["target_step"].unique()

    # Filter third-level transitions starting from valid_targets_level_2
    level_3_data = filtered_data[
        filtered_data["source_step"].isin(valid_targets_level_2)
    ]

    # Combine all three levels
    combined_data = pd.concat([level_1_data, level_2_data, level_3_data])

    if combined_data.empty:
        print("No valid data found after combining levels.")
        return

    # Ensure we only take the top N transitions based on transition_count (Otherweise the chart can get too crowded)
    combined_data = combined_data.nlargest(top_n, "transition_count")

    # Group nodes under 'Other' if they are below x% of the total transitions. You can specify that in the threshold variable
    total_transitions = combined_data["transition_count"].sum()
    threshold = total_transitions * 0.02

    combined_data["source_step"] = combined_data["source_step"].apply(
        lambda x: (
            "Other"
            if combined_data[combined_data["source_step"] == x][
                "transition_count"
            ].sum()
            < threshold
            else x
        )
    )
    combined_data["target_step"] = combined_data["target_step"].apply(
        lambda x: (
            "Other"
            if combined_data[combined_data["target_step"] == x][
                "transition_count"
            ].sum()
            < threshold
            else x
        )
    )

    # Aggregation
    combined_data = (
        combined_data.groupby(["source_step", "target_step"])["transition_count"]
        .sum()
        .reset_index()
    )

    if combined_data.empty:
        print("No data left after applying 'Other' grouping.")
        return

    # Source - Target Mapping
    nodes = list(
        pd.concat([combined_data["source_step"], combined_data["target_step"]]).unique()
    )
    source_indices = [nodes.index(src) for src in combined_data["source_step"]]
    target_indices = [nodes.index(tgt) for tgt in combined_data["target_step"]]
    values = combined_data["transition_count"].tolist()

    # Debugging Outputs
    print("Nodes:", nodes)
    print("Source Indices:", source_indices)
    print("Target Indices:", target_indices)
    print("Values:", values)

    # Count occurrences of each node
    node_counts = {node: 0 for node in nodes}
    for _, row in combined_data.iterrows():
        node_counts[row["source_step"]] += row["transition_count"]
        node_counts[row["target_step"]] += row["transition_count"]

    # Calculate percentages for nodes
    node_percentages = {
        node: (node_counts[node] / total_transitions) * 100 for node in nodes
    }

    # Compact number format
    def compact_number(num):
        if num >= 1_000_000:
            return f"{num / 1_000_000:.1f}M"
        elif num >= 1_000:
            return f"{num / 1_000:.1f}K"
        else:
            return str(num)

    # Sankey
    fig = go.Figure(
        go.Sankey(
            node=dict(
                pad=20,
                thickness=20,
                label=[
                    (
                        f"{node} ({compact_number(node_counts[node])} - {node_percentages[node]:.1f}%)"
                        if node not in ["start"]
                        else f"{node}"
                    )
                    for node in nodes
                ],
            ),
            link=dict(source=source_indices, target=target_indices, value=values),
        )
    )

    market_display = market_input if market_input else "All Markets"
    starLoyaltyLevel_display = (
        starLoyaltyLevel_input if starLoyaltyLevel_input else "All Star Loyalty Levels"
    )
    segment_display = segment_input if segment_input else "All Segments"
    user_type_display = user_type_input if user_type_input else "All User Types"
    fig.update_layout(
        title={
            "text": f"User Journey Analysis<br><sub>Market - {market_display} | Star Loyalty Level - {starLoyaltyLevel_display} | Segment - {segment_display} | User Type - {user_type_display}</sub>",
            "font": {"size": 18},
            "x": 0.5,
            "xanchor": "center",
        },
        font=dict(size=12),
        height=800,
        width=1600,
    )

    # Save as PNG
    file_name = f"{file_prefix}_{market_input or 'AllMarkets'}_{starLoyaltyLevel_input or 'AllLevels'}_{segment_input or 'AllSegments'}_{user_type_input or 'AllUserTypes'}.png"
    fig.write_image(file_name, format="png", scale=2)
    print(f"Sankey diagram saved to '{file_name}'.")
    fig.show()

In [20]:
# User input for market, starLoyaltyLevel, segment, and user_type
market_input = input(
    "Enter the market you want to analyze (or press Enter for all markets): "
)
starLoyaltyLevel_input = input(
    "Enter the starLoyaltyLevel you want to analyze (or press Enter for all levels): "
)
segment_input = input(
    "Enter the segment you want to analyze (or press Enter for all segments): "
)
user_type_input = input(
    "Enter the user type you want to analyze (or press Enter for all user types): "
)
file_prefix = (
    input("Enter the prefix for the output file name (default: sankey_diagram): ")
    or "sankey_diagram"
)

# Generate the Sankey diagram for the specified inputs
generate_sankey(
    market_input=market_input,
    starLoyaltyLevel_input=starLoyaltyLevel_input,
    segment_input=segment_input,
    user_type_input=user_type_input,  # Pass the user_type_input to the function
    top_n=50,  # You can specify another value here if needed
    file_prefix=file_prefix,  # Pass the user-specified file prefix
)

Nodes: ['start', 'tap_bottom_navigation_order_tab', 'tap_change_store', 'tap_home_link', 'tap_menu_menu_navigation_link', 'tap_scan_link', 'Other', 'end']
Source Indices: [0, 0, 0, 0, 0, 0, 1, 1, 2, 3, 3, 4, 5, 5]
Target Indices: [6, 7, 1, 3, 4, 5, 6, 4, 4, 7, 5, 1, 7, 3]
Values: [48165, 102331, 64103, 88834, 73959, 883981, 57224, 187064, 167330, 221662, 60792, 146844, 976040, 350558]
Sankey diagram saved to 'android_AllMarkets_GOLD_AllSegments_AllUserTypes.png'.
