In [31]:
import pandas as pd
import plotly.graph_objects as go
import random

def group_small_values(df, col, threshold):
    """
    Relabel values of ``col`` whose summed TVL is below ``threshold`` as 'Others'.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain ``col`` and a numeric ``app_token_tvl_usd`` column.
    col : str
        Column whose values are consolidated.
    threshold : float
        Absolute cutoff in the same units as ``app_token_tvl_usd``. A value is
        replaced only when its total is strictly below the cutoff.

    Returns
    -------
    pd.DataFrame
        A copy of ``df`` with ``col`` relabelled. ``df`` itself is left
        unmodified (the previous version overwrote the caller's column).
    """
    totals = df.groupby(col)["app_token_tvl_usd"].sum()
    small_values = totals[totals < threshold].index

    out = df.copy()
    # Vectorised membership test instead of a per-row Python lambda:
    # keep the value where it is NOT small, otherwise substitute "Others".
    out[col] = out[col].where(~out[col].isin(small_values), "Others")
    return out

def build_sankey_df(df, columns_order, threshold=0.03):
    """
    Aggregate TVL along a three-column path and fold minor values into 'Others'.

    Parameters
    ----------
    df : pd.DataFrame
        Source rows. Must contain the three ``columns_order`` columns and
        ``app_token_tvl_usd``.
    columns_order : sequence of 3 column names
        Left, middle and right Sankey levels, e.g. ``[left, mid, right]``.
    threshold : float, optional
        Fraction of the grand TVL total (0.03 = 3%). In each of the three
        columns, values whose summed TVL falls below ``threshold * total``
        are relabelled 'Others'.

    Returns
    -------
    pd.DataFrame
        One row per distinct (left, mid, right) path with summed
        ``app_token_tvl_usd``, re-aggregated after the relabelling.

    Notes
    -----
    ``groupby`` uses its default ``dropna=True``, so rows with NaN in any
    path column do not contribute to the result.
    """
    left_col, mid_col, right_col = columns_order
    path_cols = [left_col, mid_col, right_col]

    def _sum_by_path(frame):
        # Collapse rows sharing the same (left, mid, right) path.
        return frame.groupby(path_cols, as_index=False).agg({"app_token_tvl_usd": "sum"})

    summary = _sum_by_path(df)
    cutoff = threshold * summary["app_token_tvl_usd"].sum()

    # Each column is consolidated independently against the same absolute cutoff.
    for level_col in path_cols:
        summary = group_small_values(summary, level_col, cutoff)

    # Relabelling can make several paths identical; merge them.
    return _sum_by_path(summary)

def build_sankey_links_and_nodes(df_sankey, columns_order):
    """
    Flatten a three-level TVL frame into Plotly Sankey node and link arrays.

    Nodes are numbered level by level: left values first, then middle, then
    right. Each level keeps its values in order of first appearance. A label
    occurring in two levels gets a separate node in each.

    Returns
    -------
    tuple
        ``(node_labels, link_source, link_target, link_value)``.
        ``node_labels`` is a list. The other three are Series with a fresh
        0..n-1 index: first the left->middle links, then the middle->right
        links.
    """
    left_col, mid_col, right_col = columns_order
    value_col = "app_token_tvl_usd"

    level_values = [df_sankey[c].unique().tolist() for c in (left_col, mid_col, right_col)]

    # Node ids for each level continue where the previous level stopped.
    node_ids = []
    offset = 0
    for vals in level_values:
        node_ids.append(dict(zip(vals, range(offset, offset + len(vals)))))
        offset += len(vals)

    sources, targets, weights = [], [], []
    layer_pairs = ((left_col, mid_col), (mid_col, right_col))
    for (src_col, dst_col), src_ids, dst_ids in zip(layer_pairs, node_ids[:-1], node_ids[1:]):
        edges = df_sankey.groupby([src_col, dst_col], as_index=False).agg({value_col: "sum"})
        sources.append(edges[src_col].map(src_ids))
        targets.append(edges[dst_col].map(dst_ids))
        weights.append(edges[value_col])

    node_labels = level_values[0] + level_values[1] + level_values[2]
    return (
        node_labels,
        pd.concat(sources, ignore_index=True),
        pd.concat(targets, ignore_index=True),
        pd.concat(weights, ignore_index=True),
    )

def random_color():
    """Return a ``#RRGGBB`` hex color built from three draws of the global ``random`` RNG."""
    red, green, blue = (random.randint(0, 255) for _ in range(3))
    return "#" + "".join(format(channel, "02X") for channel in (red, green, blue))

def assign_node_colors(node_labels, seed=42):
    """
    Assign a reproducible pseudo-random color to each node label.

    Identical labels get identical colors, and "Others" is always gray
    (``#999999``). Colors are drawn from a private ``random.Random(seed)``
    instance, so calling this no longer reseeds the module-level ``random``
    RNG as a side effect. With the default seed the palette is the same as
    seeding the global RNG with 42 and drawing three channels per label.

    Parameters
    ----------
    node_labels : list
        Node labels in Sankey node order. Duplicates are allowed.
    seed : int, optional
        Seed for the private RNG (default 42).

    Returns
    -------
    list of str
        One ``#RRGGBB`` string per entry of ``node_labels``, in the same order.
    """
    rng = random.Random(seed)
    color_map = {}
    for lbl in node_labels:
        # Draw for every occurrence (duplicates included) so the sequence of
        # draws, and therefore the palette, matches the historical output.
        color_map[lbl] = "#" + "".join(f"{rng.randint(0, 255):02X}" for _ in range(3))
    if "Others" in color_map:
        color_map["Others"] = "#999999"
    return [color_map[lbl] for lbl in node_labels]

def _paper_annotation(x, y, text, size, **extra):
    """Return an arrow-less Plotly annotation dict positioned in paper coordinates."""
    return dict(
        x=x,
        y=y,
        xref="paper",
        yref="paper",
        text=text,
        showarrow=False,
        font=dict(size=size),
        **extra,
    )


def plot_sankey(df_sankey, column_labels_map, columns_order, snapshot_date, threshold=0.03, note=None):
    """
    Build a three-level Sankey figure from an aggregated TVL frame.

    Parameters
    ----------
    df_sankey : pd.DataFrame
        Frame with the three ``columns_order`` columns plus
        ``app_token_tvl_usd``, e.g. the output of ``build_sankey_df``.
    column_labels_map : dict
        Maps column names to header text. Unmapped columns show their raw name.
    columns_order : sequence of 3 column names
        Left, middle and right node levels.
    snapshot_date : str or date-like
        Appended to the chart title (formatted with ``str``).
    threshold : float, optional
        Not used when rendering. It is kept so callers can pass the same value
        they gave ``build_sankey_df``.
    note : str, optional
        Footnote text (may contain Plotly HTML such as ``<br>``). When None
        or empty, no footnote is drawn.

    Returns
    -------
    go.Figure
        The figure. It is NOT shown; call ``fig.show()`` yourself.

    Raises
    ------
    ValueError
        If ``df_sankey`` is empty.
    """
    if df_sankey.empty:
        raise ValueError("No data available after filters. Sankey would be empty.")

    node_labels, link_source, link_target, link_value = build_sankey_links_and_nodes(df_sankey, columns_order)
    node_colors = assign_node_colors(node_labels)
    total_tvl = df_sankey["app_token_tvl_usd"].sum()

    fig = go.Figure(
        data=[go.Sankey(
            arrangement="snap",
            node=dict(
                label=node_labels,
                pad=15,
                thickness=20,
                line=dict(color="black", width=0.5),
                color=node_colors,
            ),
            link=dict(
                source=link_source,
                target=link_target,
                value=link_value,
                color="#cccccc",  # uniform gray links; node colors carry identity
                customdata=(link_value / 1e6),  # hover shows TVL in millions
                hovertemplate="TVL: %{customdata:.2f}M USD<extra></extra>",
            ),
        )]
    )

    # Column headers above the left, middle and right node levels.
    left_col, mid_col, right_col = columns_order
    annotations = [
        _paper_annotation(x, 1.1, f"<b>{column_labels_map.get(col, col)}</b>", 12, align="center")
        for x, col in ((0, left_col), (0.5, mid_col), (1, right_col))
    ]
    annotations.append(
        _paper_annotation(0, -0.10, f"<b>Total TVL Shown: ${total_tvl/1e9:,.2f}B</b>", 12)
    )
    # Plotly.js substitutes placeholder text for an annotation without text,
    # so only add the footnote when there is something to show.
    if note:
        annotations.append(_paper_annotation(0.0, -0.35, note, 10, align="left"))

    fig.update_layout(
        title={
            "text": f"Protocol Token Lineage - {snapshot_date}",
            "font": {"size": 15},
        },
        annotations=annotations,
        font_size=10,
        margin=dict(l=50, r=50, b=150, t=100),
        autosize=True,
        width=950,
        height=600,
    )
    return fig

In [5]:
# Load the full prepared dataset; the cells below filter it to one snapshot.
df_all = pd.read_csv("protocol_data_prepped.csv")
print(f"Loaded {len(df_all)} rows from 'protocol_data_prepped.csv'.")

Loaded 5885377 rows from 'protocol_data_prepped.csv'.


In [37]:
# Friendlier Sankey column headings, keyed by DataFrame column name.
# Columns not listed here fall back to their raw name in the chart headers.
# Add other columns (e.g. token_category) here if you use them as a level.
COLUMN_LABELS_MAP = dict(
    source_protocol="Token Issuer",
    parent_protocol="Protocol Destination",
    token="Token",
)

# %%
# 2. Choose the snapshot date and the filter selections.
SNAPSHOT_DATE = "2024-12-17"

# Example filters: each list starts as every distinct value in the data, so
# by default they keep everything; trim a list to restrict the chart.
selected_chains = df_all["chain"].unique().tolist()
selected_protocol_categories = df_all["protocol_category"].unique().tolist()
selected_protocols = df_all["protocol_slug"].unique().tolist()
selected_token_categories = df_all["token_category"].unique().tolist()

# Keep rows on the snapshot date that pass every filter; copy so later
# edits never write back into df_all.
df_focus = df_all.loc[
    (df_all["dt"] == SNAPSHOT_DATE)
    & df_all["chain"].isin(selected_chains)
    & df_all["protocol_category"].isin(selected_protocol_categories)
    & df_all["token_category"].isin(selected_token_categories)
    & df_all["protocol_slug"].isin(selected_protocols)
].copy()

print(f"Rows matching {SNAPSHOT_DATE}: {len(df_focus)}")

# 3. Build the Sankey frame (left -> middle -> right levels).
COLUMNS_ORDER = ["source_protocol", "token", "parent_protocol"]
THRESHOLD = 0.03  # values under 3% of total TVL are folded into 'Others'

df_sankey = build_sankey_df(df_focus, COLUMNS_ORDER, threshold=THRESHOLD)
print(f"df_sankey shape after grouping: {df_sankey.shape}")

def truncate_list(lst, n=10):
    """
    Join the first ``n`` items of ``lst`` with ", ", adding ", ..." if any were cut.

    Items are converted with ``str`` before joining. Lists built from
    ``Series.unique()`` can contain NaN or numbers, which would otherwise
    make ``str.join`` raise TypeError.
    """
    shown = ", ".join(map(str, lst[:n]))
    return shown + (", ..." if len(lst) > n else "")

# Chart footnote: the active filters plus the 'Others' grouping rule.
# build_sankey_df applies the cutoff to every Sankey column (issuers, tokens
# and destinations), not only protocols, so the wording says "entries".
note = (
    f"<b>Filters:</b><br>"
    f"• <b>Chains:</b> {truncate_list(selected_chains)}<br>"
    f"• <b>Protocol Categories:</b> {truncate_list(selected_protocol_categories)}<br>"
    f"• <b>Asset Type:</b> {truncate_list(selected_token_categories)}<br><br>"
    # :g drops float noise such as 7.000000000000001 from THRESHOLD * 100.
    f"Note: Entries with less than {THRESHOLD * 100:g}% of total TVL are grouped under 'Others'."
)

# %%
# 4. Plot: render the Sankey for the filtered snapshot and display it.
fig = plot_sankey(
    df_sankey,
    column_labels_map=COLUMN_LABELS_MAP,
    columns_order=COLUMNS_ORDER,
    snapshot_date=SNAPSHOT_DATE,
    threshold=THRESHOLD,
    note=note,
)
fig.show()

Rows matching 2024-12-17: 96425
df_sankey shape after grouping: (39, 4)
