# Using Plotly

Plotly is an interactive graphing library with built-in support for Sankey diagrams through its `go.Sankey` trace. Its figures are interactive (for example, hovering over a link shows its value), and most aspects of their appearance can be customized.

## Plotly basics

Let's get you up and running with some basic Sankey diagrams using Plotly.

In [None]:
# Import necessary packages
import pandas as pd
import plotly.graph_objects as go

Let's start by considering a simple dataset. 
Imagine we have some farms, which grow apples and bananas to sell to a few different customers. 
We can describe the *flow* of fruit from the farms (the *source* of the flow) to the customers (the *target* of the flow).
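
The CSV file loaded in the next cell is not reproduced here. As a rough sketch of the shape such data takes, assuming the `source`, `target`, `type` and `value` columns that the later cells rely on, a small table could be built by hand. The numbers below are invented for illustration and are not taken from the example file.

```python
import pandas as pd

# Illustrative rows only: the values are made up, not read from
# example_data/simple_fruit_sales.csv.
example_flows = pd.DataFrame(
    {
        "source": ["farm1", "farm1", "farm2", "farm3"],
        "target": ["Mary", "James", "Fred", "Susan"],
        "type": ["apples", "bananas", "apples", "bananas"],
        "value": [5, 3, 4, 2],
    }
)

# Each row is one flow: `value` units of `type` moving from `source` to `target`.
print(example_flows)
print("Total fruit sold:", example_flows["value"].sum())
```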

In [None]:
"""Import data using pandas"""
flows = pd.read_csv("example_data/simple_fruit_sales.csv")
display(flows)

In [None]:
"""We start by creating a helper function for converting flow data into Plotly Sankey data."""


def get_sankey_data(flows: pd.DataFrame) -> tuple[list, pd.DataFrame]:
    """Helper function to convert flow data into plotly-compatible sankey data.

    Args:
        flows (pd.DataFrame): source-target-value data
    Returns:
        tuple[list, pd.DataFrame]:
            1. List of unique node labels,
            2. copy of `flows` with two extra columns, `source_idx` and
               `target_idx`, giving the position of each flow's source and
               target in the list of node labels.
    """
    # Get all unique nodes (sources first, then targets, in order of appearance)
    all_nodes = list(pd.concat([flows["source"], flows["target"]]).unique())
    node_dict = {node: idx for idx, node in enumerate(all_nodes)}
    # Map flows to node indices, working on a copy so the input is not modified
    flows = flows.copy()
    flows["source_idx"] = flows["source"].map(node_dict)
    flows["target_idx"] = flows["target"].map(node_dict)
    return all_nodes, flows


# Let's take a look at the results of our helper function

all_nodes, sankey_flows = get_sankey_data(flows)
display(all_nodes)
display(sankey_flows)
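
Plotly's Sankey trace identifies nodes by their position in the list of labels, which is why the helper turns names into integer indices. To make that mapping concrete, here is the same logic worked through on a tiny hand-made table (the rows are invented for illustration):

```python
import pandas as pd

tiny = pd.DataFrame(
    {
        "source": ["farm1", "farm1", "farm2"],
        "target": ["Mary", "Fred", "Mary"],
        "value": [5, 3, 4],
    }
)

# Unique labels in order of first appearance: sources first, then targets.
labels = list(pd.concat([tiny["source"], tiny["target"]]).unique())
index_of = {label: i for i, label in enumerate(labels)}

# Replace each name by its position in `labels`.
source_idx = tiny["source"].map(index_of).tolist()
target_idx = tiny["target"].map(index_of).tolist()

print(labels)      # ['farm1', 'farm2', 'Mary', 'Fred']
print(source_idx)  # [0, 0, 1]
print(target_idx)  # [2, 3, 2]
```

One consequence of this scheme is that a label appearing in both the `source` and `target` columns is treated as a single node, so node names need to be unique across layers unless that merging is intended.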

In [None]:
"""Let's draw our first Sankey!"""

all_nodes, sankey_flows = get_sankey_data(flows)

# Create the Sankey diagram
fig = go.Figure(
    data=[
        go.Sankey(
            node=dict(pad=15, thickness=20, label=all_nodes),
            link=dict(
                source=sankey_flows["source_idx"],
                target=sankey_flows["target_idx"],
                value=sankey_flows["value"],
            ),
        )
    ]
)
fig.update_layout(title_text="Simple Fruit Sales", font_size=10, height=400)
fig.show()

But you don't always want a direct correspondence between the flows in your data and the links that you see in the Sankey diagram. 

For example:

* Farms 4, 5 and 6 are all pretty small, and to make the diagram clearer we might want to group them in an "other" category.
* The flows of apples are mixed in with the flows of bananas -- we might want to group the kinds of fruit together to make them easier to compare.
* We might want to group farms or customers based on some other attributes -- to see differences between genders, locations, or organic/non-organic farms, say.

This tutorial shows how to use Plotly to achieve these groupings and create more meaningful visualizations.

### Grouping nodes

Let's start with grouping farms 4, 5 and 6 into an "other" category. 
With Plotly, we need to transform our data before creating the diagram by replacing specific farm names with the group name:

In [None]:
"""Group some farms as 'other'"""
# Create a copy of flows and replace farm4, farm5, farm6 with "other"
flows_grouped = flows.copy()
flows_grouped["source"] = flows_grouped["source"].replace(
    {"farm4": "other", "farm5": "other", "farm6": "other"}
)

# Aggregate flows that now share the same source, target and fruit type
flows_grouped = (
    flows_grouped.groupby(["source", "target", "type"])["value"].sum().reset_index()
)
display(flows_grouped)

Now we can create the Sankey diagram with the grouped data:

In [None]:
"""Create Sankey with grouped farms"""
# Get all unique nodes (preserving the order: farms first, then customers)
all_nodes, sankey_flows = get_sankey_data(flows_grouped)
fig = go.Figure(
    data=[
        go.Sankey(
            node=dict(pad=15, thickness=20, label=all_nodes),
            link=dict(
                source=sankey_flows["source_idx"],
                target=sankey_flows["target_idx"],
                value=sankey_flows["value"],
            ),
        )
    ]
)
fig.update_layout(title_text="Grouped Farms", font_size=10, height=400)
fig.show()

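Great! Now we can see farms 1, 2, and 3 separately, with farms 4, 5, and 6 grouped as "other".

Starting from the grouped farms, we can apply the same approach to the customers and group them by gender: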

In [None]:
"""Group customers by gender"""
# Define gender mapping
gender_map = {"Fred": "Men", "James": "Men", "Susan": "Women", "Mary": "Women"}

flows_by_gender = flows_grouped.copy()
flows_by_gender["target"] = flows_by_gender["target"].map(gender_map)

# Aggregate by the new grouping
flows_by_gender = (
    flows_by_gender.groupby(["source", "target", "type"])["value"].sum().reset_index()
)

Now let's visualize sales grouped by gender:

In [None]:
"""Visualize with gender grouping"""
all_nodes, sankey_flows = get_sankey_data(flows_by_gender)
fig = go.Figure(
    data=[
        go.Sankey(
            node=dict(pad=15, thickness=20, label=all_nodes),
            link=dict(
                source=sankey_flows["source_idx"],
                target=sankey_flows["target_idx"],
                value=sankey_flows["value"],
            ),
        )
    ]
)
fig.update_layout(title_text="Sales by Gender", font_size=10, height=400)
fig.show()

Perfect! Now we can see the distribution of fruit sales to men and women.
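
Both groupings above follow the same pattern: relabel one column, then sum the flows that have become identical. One difference is worth noting. `Series.replace` leaves unlisted labels unchanged, whereas `Series.map` turns them into `NaN`, and `groupby` drops `NaN` keys by default, so a customer missing from `gender_map` would silently disappear from the diagram. A possible way to wrap the pattern in a single function is sketched below. The name `regroup` and the demo rows are hypothetical, not part of pandas, Plotly or the example data.

```python
import pandas as pd


def regroup(flows: pd.DataFrame, column: str, mapping: dict) -> pd.DataFrame:
    """Relabel `column` using `mapping`, then sum the values of merged flows.

    Labels that are missing from `mapping` are kept unchanged rather than
    being turned into NaN.
    """
    out = flows.copy()
    out[column] = out[column].map(mapping).fillna(out[column])
    # Group by every column except the one being summed.
    keys = [c for c in out.columns if c != "value"]
    return out.groupby(keys, as_index=False)["value"].sum()


# Invented demo rows; "Alex" is deliberately absent from the gender mapping.
demo = pd.DataFrame(
    {
        "source": ["farm1", "farm4", "farm5", "farm1"],
        "target": ["Mary", "Mary", "Mary", "Alex"],
        "type": ["apples", "apples", "apples", "bananas"],
        "value": [5, 1, 2, 3],
    }
)

by_farm = regroup(demo, "source", {"farm4": "other", "farm5": "other"})
by_gender = regroup(by_farm, "target", {"Mary": "Women"})
print(by_gender)

# Regrouping should never change the total amount of fruit.
assert by_gender["value"].sum() == demo["value"].sum()
```

Comparing the totals before and after, as in the last line, is a cheap way to catch flows that were dropped during grouping.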

### Adding colors to distinguish flow types

These diagrams have lost sight of the kind of fruit that is actually being sold. Let's add colors to distinguish apples from bananas:

In [None]:
"""Color flows by fruit type"""
# Define color mapping for fruit types
color_map = {"apples": "yellowgreen", "bananas": "gold"}

# Assign colors to each link based on fruit type
link_colors = [color_map[fruit_type] for fruit_type in flows_by_gender["type"]]

fig = go.Figure(
    data=[
        go.Sankey(
            node=dict(pad=15, thickness=20, label=all_nodes),
            link=dict(
                source=sankey_flows["source_idx"],
                target=sankey_flows["target_idx"],
                value=sankey_flows["value"],
                color=link_colors,
            ),
        )
    ]
)
fig.update_layout(
    title_text="Sales by Gender (Colored by Fruit Type)", font_size=10, height=400
)
fig.show()

Now we can see which fruit types are being sold, indicated by the colors!
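
Plotly accepts any CSS color string for links, including `rgba()` values, which can help when opaque links hide one another. The sketch below shows one way to build such a list, with a fallback color for fruit types missing from the map (the list comprehension above would raise a `KeyError` in that case). The 0.6 opacity, the grey fallback and the "pears" row are arbitrary choices for illustration.

```python
import pandas as pd

# Same hues as "yellowgreen" and "gold", written as rgba() strings so that
# the links become partly transparent.
link_color_map = {
    "apples": "rgba(154, 205, 50, 0.6)",  # yellowgreen
    "bananas": "rgba(255, 215, 0, 0.6)",  # gold
}
fallback = "rgba(128, 128, 128, 0.4)"  # grey for any type not in the map

# Invented rows; "pears" has no entry in link_color_map.
demo = pd.DataFrame({"type": ["apples", "bananas", "pears"], "value": [5, 3, 1]})

# One color per link, in the same row order as the link data.
link_colors = [link_color_map.get(fruit, fallback) for fruit in demo["type"]]
print(link_colors)
```

The resulting list can be passed as `color=link_colors` in the `link` dictionary, exactly like the list of named colors used above.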

### Adding intermediate nodes for flow types

To make the fruit types even more explicit, we can add intermediate nodes that represent the different types of fruit. This creates a three-layer diagram: farms → fruit types → customers.

In [None]:
"""First, we modify our helper function to allow for intermediate flows."""


def get_sankey_data_with_intermediates(
    flows: pd.DataFrame, intermediate_flow: list
) -> tuple[list, pd.DataFrame]:
    """Helper function to convert flow data into plotly-compatible sankey data.

    Args:
        flows (pd.DataFrame): source-target-value data
        intermediate_flow (list):
            Column names whose values become intermediate node layers,
            ordered from source to target.

    Returns:
        tuple[list, pd.DataFrame]:
            1. List of unique node labels,
            2. one row per link, with the summed `value` and the
               `source_idx` and `target_idx` of its two end nodes in the
               list of node labels.
    """
    # Get all unique nodes
    all_nodes = list(
        pd.concat(
            [flows["source"], flows["target"], *[flows[i] for i in intermediate_flow]]
        ).unique()
    )
    node_dict = {node: idx for idx, node in enumerate(all_nodes)}
    # Map flows to node indices
    start = "source"
    sankey_flows = []
    for inter in intermediate_flow:
        inter_flows = flows.groupby([start, inter])["value"].sum().reset_index()
        inter_flows["source_idx"] = inter_flows[start].map(node_dict)
        inter_flows["target_idx"] = inter_flows[inter].map(node_dict)
        sankey_flows.append(inter_flows)
        start = inter

    # Final flow to target
    final_flows = flows.groupby([start, "target"])["value"].sum().reset_index()
    final_flows["source_idx"] = final_flows[start].map(node_dict)
    final_flows["target_idx"] = final_flows["target"].map(node_dict)

    all_flows = pd.concat(sankey_flows + [final_flows], ignore_index=True)
    return all_nodes, all_flows


all_nodes, sankey_flows = get_sankey_data_with_intermediates(flows_by_gender, ["type"])

In [None]:
# Every link in the three-layer data carries the fruit type it belongs to,
# so its colour can be looked up directly from the "type" column.
colour_mapping = {"apples": "yellowgreen", "bananas": "gold"}
colours_all = sankey_flows["type"].map(colour_mapping).tolist()
display(colours_all)

In [None]:
"""Create three-layer diagram with fruit types in the middle"""
# We need two sets of links: farms -> fruit types, and fruit types -> customers
all_nodes, sankey_flows = get_sankey_data_with_intermediates(flows_by_gender, ["type"])

fig = go.Figure(
    data=[
        go.Sankey(
            node=dict(pad=15, thickness=20, label=all_nodes),
            link=dict(
                source=sankey_flows["source_idx"],
                target=sankey_flows["target_idx"],
                value=sankey_flows["value"],
                color=colours_all,
            ),
        )
    ]
)
fig.update_layout(
    title_text="Three-Layer Diagram with Fruit Types", font_size=10, height=400
)
fig.show()

Perfect! Now we can clearly see:
- Which farms produce which types of fruit
- How much of each fruit type goes to men vs. women

This three-layer approach makes the flow of different product types much clearer.
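
Splitting each flow into two links only makes sense if nothing is gained or lost at the middle layer: the amount entering each fruit-type node should equal the amount leaving it. A minimal sketch of that check, using invented rows and the same two aggregations as the helper function, might look like this:

```python
import pandas as pd

# Invented demo rows with the same columns as the tutorial data.
flows = pd.DataFrame(
    {
        "source": ["farm1", "farm1", "farm2", "farm2"],
        "target": ["Men", "Women", "Men", "Women"],
        "type": ["apples", "bananas", "apples", "apples"],
        "value": [5, 3, 4, 2],
    }
)

# Layer 1: farms -> fruit types. Layer 2: fruit types -> customers.
first = flows.groupby(["source", "type"], as_index=False)["value"].sum()
second = flows.groupby(["type", "target"], as_index=False)["value"].sum()

# Total arriving at, and leaving, each fruit-type node.
inflow = first.groupby("type")["value"].sum()
outflow = second.groupby("type")["value"].sum()
print(pd.DataFrame({"in": inflow, "out": outflow}))

assert inflow.equals(outflow)
```

Note the trade-off: because each layer is aggregated separately, the three-layer diagram no longer shows which farm sold to which customer group, only the totals into and out of each fruit type.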

### Summary

This section has demonstrated the basic usage of Plotly for Sankey diagrams:
- Creating node lists and mapping flows to node indices
- Grouping and aggregating data to simplify diagrams
- Adding colors to distinguish different flow types
- Creating multi-layer diagrams with intermediate nodes

If you like, why not go back and try out some different ways to present the data? 

Here are some suggestions:

1. Farms 1, 3 and 5 are organic. 
   Can you modify the grouping to show two farm groups: organic and non-organic?
2. Try creating a diagram that shows individual customers (not grouped by gender) but with fruit types in the middle layer.

## Advanced Plotly tutorials

You've now learnt the basics of Plotly Sankey diagrams. 

Plotly offers many additional features including:
- Interactive tooltips and hover information
- Custom node colors and positioning
- Export to static images (PNG, SVG, PDF)

### Additional resources
For more information on Plotly Sankey diagrams, check out the [official Plotly documentation](https://plotly.com/python/sankey-diagram/).