In [1]:
!pip install plotly kaleido nbformat>=4.2.0

[0m

In [19]:
def plot_condense_time_series(
    date_data: dict,
    dashed_lines: list[dict],
    title: str,
    shortname_description: dict,
    output_dir: str = "results",
    series_colors: "list[str] | None" = None,
):
    """Plot benchmark accuracy over time per sampling method, with dashed reference lines.

    Args:
        date_data: Mapping of ``"YYYY-MM-DD"`` date strings to
            ``{sampling_method: {"accuracy": float, "it": {"mean": ..., "std": ...}}}``.
            Dates may be in any order (they are plotted chronologically) and a
            sampling method does not need to be present on every date.
        dashed_lines: Horizontal reference lines. Each dict needs ``"value"``
            (plotted on the 0-1 accuracy axis) and may give ``"label"`` and
            ``"color"`` (default ``"#EF553B"``).
        title: Figure title. Also used as the output file stem, with characters
            that are not portable in file names (``<>:"/\\|?*``) replaced by ``_``.
        shortname_description: Abbreviation -> full description. Listed beside
            the plot and expanded (whole-token matches only) in legend entries.
        output_dir: Directory for the ``.html`` and ``.png`` output; created if missing.
        series_colors: Colors cycled across sampling-method series. Defaults to a
            Plotly-style palette, minus colors already used by ``dashed_lines``,
            whose first entry is the original series blue ``#636EFA``.

    Returns:
        The ``plotly.graph_objects.Figure`` that was written to disk.

    Raises:
        ValueError: If ``date_data`` is empty.
    """
    import os
    import re
    from datetime import datetime, timedelta

    import plotly.graph_objects as go

    if not date_data:
        raise ValueError("date_data is empty: nothing to plot")

    def get_marker_size(tokens):
        """Bucket a mean input-token count into a marker size (small/medium/large)."""
        if tokens < 1024:
            return 8  # small
        elif tokens <= 4096:
            return 15  # medium
        return 25  # large

    def expand_shortnames(text):
        """Replace every shortname occurring as a standalone token with its description."""
        for short, full in shortname_description.items():
            # Look-arounds stop e.g. "IT" from matching inside "SPLIT"; the lambda
            # keeps backslashes in `full` from being read as regex escapes.
            text = re.sub(
                rf"(?<!\w){re.escape(short)}(?!\w)",
                lambda _match, full=full: full,
                text,
            )
        return text

    fig = go.Figure()

    # Plot chronologically regardless of dict insertion order so series lines
    # never run backwards in time.
    parsed = {d: datetime.strptime(d, "%Y-%m-%d") for d in date_data}
    dates = sorted(parsed, key=parsed.get)

    # x-extent for the dashed lines; pad a single day so the lines are visible.
    first_day, last_day = parsed[dates[0]], parsed[dates[-1]]
    if first_day == last_day:
        pad = timedelta(days=5)
        first_day, last_day = first_day - pad, last_day + pad
    x_min = first_day.strftime("%Y-%m-%d")
    x_max = last_day.strftime("%Y-%m-%d")

    for ref in dashed_lines:
        value = ref["value"]
        label = ref.get("label", "") + f" - {value:.1%}"
        color = ref.get("color", "#EF553B")
        fig.add_trace(
            go.Scatter(
                x=[x_min, x_max],
                y=[value, value],
                mode="lines+text",
                name=label,
                line=dict(color=color, width=2, dash="dash"),
                # Label only the left end point: a label at the right end would
                # collide with the legend anchored in the top-right corner.
                text=[label, ""],
                textposition="top right",
                textfont=dict(color=color),
            )
        )

    default_palette = ["#636EFA", "#AB63FA", "#FFA15A", "#19D3F3", "#FF6692", "#B6E880", "#FECB52"]
    if not series_colors:
        # Avoid reusing a reference-line color for a data series.
        reserved = {ref.get("color", "#EF553B") for ref in dashed_lines}
        series_colors = [c for c in default_palette if c not in reserved] or default_palette

    # Union of methods across all dates, in first-seen (chronological) order.
    methods = list(dict.fromkeys(m for d in dates for m in date_data[d]))
    for idx, method in enumerate(methods):
        present = [d for d in dates if method in date_data[d]]
        points = [date_data[d][method] for d in present]
        color = series_colors[idx % len(series_colors)]
        fig.add_trace(
            go.Scatter(
                x=[parsed[d].strftime("%Y-%m-%d") for d in present],
                y=[p["accuracy"] for p in points],
                mode="lines+markers+text",
                name=expand_shortnames(method),
                line=dict(color=color, width=2),
                marker=dict(
                    size=[get_marker_size(p["it"]["mean"]) for p in points],
                    color=color,
                    symbol="circle",
                    line=dict(color="white", width=2),
                ),
                text=[
                    f"{method} - IT-{p['it']['mean']:,}+/-{p['it']['std']} - {p['accuracy']:.1%}"
                    for p in points
                ],
                textposition="top center",
                textfont=dict(color=color),
            )
        )

    # Shortname glossary, stacked bottom-up in the right margin.
    glossary = [f"{short}: {full}" for short, full in shortname_description.items()]
    for i, entry in enumerate(glossary):
        fig.add_annotation(
            xref="paper",
            yref="paper",
            x=1.02,  # just right of the plot area
            y=0.02 + 0.05 * i,
            # Explicit anchors: with "auto", paper-referenced annotations anchor to
            # the nearest side, which right-aligns text at x=1.02 back over the plot
            # and shifts the vertical anchor as y grows.
            xanchor="left",
            yanchor="bottom",
            text=entry,
            showarrow=False,
            font=dict(size=12),
            align="left",
        )

    # Rough width estimate (~7 px per character at font size 12) so the glossary
    # is not clipped by the canvas edge.
    longest = max((len(entry) for entry in glossary), default=0)
    right_margin = max(150, 7 * longest + 40)

    fig.update_layout(
        title={
            "text": title,
            "y": 0.95,
            "x": 0.5,
            "xanchor": "center",
            "yanchor": "top",
            "font": dict(size=20),
        },
        template="plotly_white",
        hovermode=False,  # labels are drawn on the points; no hover needed
        height=600,
        width=1200,
        showlegend=True,
        legend=dict(yanchor="top", y=0.99, xanchor="right", x=0.99, tracegroupgap=5),
        yaxis=dict(
            title="Accuracy",
            tickformat=".0%",
            range=[0, 1],  # fixed 0-100% so charts are comparable
        ),
        xaxis=dict(title="Date"),
        margin=dict(r=right_margin),
    )

    # The output directory may not exist on a fresh checkout.
    os.makedirs(output_dir, exist_ok=True)
    file_stem = re.sub(r'[<>:"/\\|?*]', "_", title).strip() or "figure"
    base = os.path.join(output_dir, file_stem)
    fig.write_html(f"{base}.html")
    fig.write_image(f"{base}.png")
    return fig


In [20]:
# Input-token count of the "Original" baseline, shown in its reference-line label.
original_token_counts = 3361

# Results keyed by date, then by sampling method. Each point is labelled on the
# chart as "<method> - IT-<mean>+/-<std> - <accuracy>", e.g. "top-1 - IT-963+/-156 - 73.0%".
date_data = {
    "2024-12-04": {
        "top-1": {"accuracy": 0.73, "it": {"mean": 963, "std": 156}},
        "top-5": {"accuracy": 0.59, "it": {"mean": 958, "std": 148}},
    },
}

# Horizontal baselines drawn as dashed lines across the chart.
dashed_lines = [
    dict(value=0.93, label=f"Original - IT-{original_token_counts:,}", color="#EF553B"),
    dict(value=0.29, label="Knorm (kvpress) - CR-0.72", color="#00CC96"),
    dict(value=0.70, label="ExpectedAttentionPress (kvpress) - CR-0.72", color="#FF97D9"),
]

# Abbreviation glossary listed beside the chart.
shortname_description = dict(
    [
        ("IT", "Input Token"),
        ("CR", "Compression Rate"),
        ("top-x", "Top-x Elo Sampling"),
    ]
)

title = "Condense Subnet Benchmark: RULER-4K-qa"

plot_condense_time_series(
    date_data=date_data,
    dashed_lines=dashed_lines,
    title=title,
    shortname_description=shortname_description,
)
