In [None]:
import numpy as np
import pandas as pd
import plotly.express as px

In [None]:
def get_parallelism_style(dp, hp, pp):
    """Return a short label naming which parallelism axes a configuration uses.

    Each degree equal to 1 means that axis is unused; a degree greater than 1
    contributes its letter -- "D" for ``dp``, "T" for ``hp``, "P" for ``pp`` --
    and the letters are joined with "/" in D, T, P order (e.g. "D/P").
    When all three degrees are 1 the label is "Base".

    Args:
        dp: data-parallel degree.
        hp: degree labelled "T" (presumably horizontal/tensor parallelism).
        pp: pipeline-parallel degree.

    Returns:
        One of "Base", "D", "T", "P", "D/T", "T/P", "D/P", "D/T/P".

    Raises:
        ValueError: if any degree is neither equal to 1 nor greater than 1
            (e.g. 0, a negative number, a fraction below 1, or NaN).
    """
    used_axes = []
    for degree, letter in ((dp, "D"), (hp, "T"), (pp, "P")):
        if degree == 1:
            continue  # axis not in use
        if not degree > 1:
            raise ValueError(f"Invalid degree combination dp={dp}, hp={hp}, pp={pp}")
        used_axes.append(letter)
    return "/".join(used_axes) if used_axes else "Base"

In [None]:
# Shared styling constants.
# NOTE(review): none of these names is referenced by
# plot_memory_usage_vs_throughput, which sets its own palette and category
# order; they may be consumed by cells outside this view -- confirm.

# One marker code per parallelism style. These look like matplotlib marker
# specifiers rather than plotly symbol names -- confirm where they are used.
markers = list("oP^*XDHs")

# First eight colours of the D3 "category10" palette.
colors = [
    "#1f77b4",  # muted blue
    "#ff7f0e",  # safety orange
    "#2ca02c",  # cooked asparagus green
    "#d62728",  # brick red
    "#9467bd",  # muted purple
    "#8c564b",  # chestnut brown
    "#e377c2",  # raspberry yogurt pink
    "#7f7f7f",  # middle gray
]
# The last two category10 colours, "#bcbd22" (curry yellow-green) and
# "#17becf" (blue-teal), are deliberately left out of `colors`.

# Display order of the non-trivial labels produced by get_parallelism_style
# (the all-ones "Base" label is not included).
parallelism_styles = "D T P D/T T/P D/P D/T/P".split()

In [None]:
def _load_simulation_results(simulation_filename, model_size, min_batch_size):
    """Load the grid-search CSV and prepare the rows plotted for one model size.

    Returns a new DataFrame (never a view of the raw CSV) restricted to rows
    whose ``model_size`` equals ``model_size`` and whose ``batch_size`` is at
    least ``min_batch_size / 2``. ``peak_memory`` is divided by 1e9 so it reads
    in GB (the CSV value is presumably bytes -- confirm). The derived columns
    ``log_peak_memory``, ``dummy_column_for_size``, ``parallelism_style`` and
    ``config_name`` are added.

    Raises:
        ValueError: if no rows survive the filters.
    """
    results_raw = pd.read_csv(simulation_filename)
    # NOTE(review): the batch-size cutoff is half of ``min_batch_size``, not
    # ``min_batch_size`` itself. Preserved from the original -- confirm intent.
    batch_cutoff = min_batch_size / 2
    keep = (results_raw["model_size"] == model_size) & (
        results_raw["batch_size"] >= batch_cutoff
    )
    # .copy() so the column assignments below act on an independent frame,
    # not on a slice of ``results_raw`` (avoids SettingWithCopyWarning).
    results = results_raw.loc[keep].copy()
    if results.empty:
        raise ValueError(
            f"No simulation results in {simulation_filename!r} for "
            f"model_size={model_size!r} with batch_size >= {batch_cutoff}"
        )

    results["peak_memory"] = results["peak_memory"] / 1e9
    results["log_peak_memory"] = np.log(results["peak_memory"])
    # Constant column so every marker is drawn at the same size (``size_max``).
    results["dummy_column_for_size"] = 1.0

    dp_col = results["dp_degree"]
    hp_col = results["hp_degree"]
    pp_col = results["pp_degree"]
    results["parallelism_style"] = [
        get_parallelism_style(dp, hp, pp) for dp, hp, pp in zip(dp_col, hp_col, pp_col)
    ]
    # Iterate column-wise so each value keeps its own column's dtype in the
    # label; DataFrame.values would upcast mixed int/float columns to float.
    results["config_name"] = [
        f"{dp}/{hp}/{pp}/{k}"
        for dp, hp, pp, k in zip(dp_col, hp_col, pp_col, results["num_microbatches"])
    ]
    return results


def plot_memory_usage_vs_throughput(
    simulation_filename,
    model_size,
    title="",
    batch_size_slider=False,
    min_batch_size=0,
    memory_limit_gb=32,
):
    """Scatter simulated throughput against log peak memory for one model size.

    Each point is one simulated configuration. It is coloured and marked by its
    parallelism style (see ``get_parallelism_style``) and hover-labelled
    "dp/hp/pp/num_microbatches". A horizontal line marks ``memory_limit_gb``.

    Args:
        simulation_filename: path to the grid-search results CSV (``~`` is
            expanded by ``pd.read_csv``).
        model_size: value of the ``model_size`` column to plot.
        title: figure title.
        batch_size_slider: if True, animate over ``batch_size`` with a slider.
        min_batch_size: rows with ``batch_size < min_batch_size / 2`` are dropped.
        memory_limit_gb: memory cutoff in GB, drawn at ``log(memory_limit_gb)``.
            Defaults to 32, the previously hard-coded value.

    Raises:
        ValueError: if no rows match ``model_size`` / ``min_batch_size``.
    """
    results = _load_simulation_results(simulation_filename, model_size, min_batch_size)

    fig = px.scatter(
        results,
        x="throughput",
        y="log_peak_memory",
        symbol="parallelism_style",
        color="parallelism_style",
        hover_name="config_name",
        hover_data=["batch_size", "dp_degree", "hp_degree", "pp_degree", "num_microbatches"],
        labels={
            "throughput": "Simulated Throughput (samples / second)",
            "log_peak_memory": "Simulated Peak Memory (log GB)",
            "parallelism_style": "Parallelism style",
        },
        # D3 "category10": the same ten colours previously listed inline.
        color_discrete_sequence=px.colors.qualitative.D3,
        category_orders={
            # Must match the labels get_parallelism_style returns; the all-ones
            # configuration is labelled "Base" (previously listed as the
            # never-produced "Sequential").
            "parallelism_style": ["D", "T", "P", "D/T", "T/P", "D/P", "D/T/P", "Base"],
        },
        size="dummy_column_for_size",
        size_max=10,
        animation_frame="batch_size" if batch_size_slider else None,
        animation_group="config_name" if batch_size_slider else None,
        # Fixed ranges computed over all rows keep the axes stable when the
        # batch-size slider switches frames.
        range_x=[0, results["throughput"].max()],
        range_y=[results["log_peak_memory"].min(), results["log_peak_memory"].max()],
        title=title,
    )
    # Memory cutoff line, in the same log-GB units as the y axis.
    fig.add_hline(
        y=np.log(memory_limit_gb),
        line_width=2,
        annotation_text=f"{memory_limit_gb} GB limit",
    )
    fig.show()

In [None]:
# MLP results, "mlp-small" configuration.
plot_memory_usage_vs_throughput(
    simulation_filename="~/Downloads/mlp_dgx_simulated_grid_search_results_v2.csv",
    model_size="mlp-small",
    title="MLP 1B",
)

In [None]:
# MLP results, "mlp-medium" configuration.
plot_memory_usage_vs_throughput(
    simulation_filename="~/Downloads/mlp_dgx_simulated_grid_search_results_v2.csv",
    model_size="mlp-medium",
    title="MLP 17B",
)

In [None]:
# MLP results, "mlp-large" configuration.
plot_memory_usage_vs_throughput(
    simulation_filename="~/Downloads/mlp_dgx_simulated_grid_search_results_v2.csv",
    model_size="mlp-large",
    title="MLP 103B",
)

In [None]:
# GPT results, "gpt3-xl" configuration.
plot_memory_usage_vs_throughput(
    simulation_filename="~/Downloads/gpt2_dgx_simulated_grid_search_results_v2.csv",
    model_size="gpt3-xl",
    title="GPT-2 1.6B",
)

In [None]:
# GPT results, "gpt3-13B" configuration.
plot_memory_usage_vs_throughput(
    simulation_filename="~/Downloads/gpt2_dgx_simulated_grid_search_results_v2.csv",
    model_size="gpt3-13B",
    title="GPT-2 13B",
)

In [None]:
# GPT results, "gpt3-175B" configuration (static view).
plot_memory_usage_vs_throughput(
    simulation_filename="~/Downloads/gpt2_dgx_simulated_grid_search_results_v2.csv",
    model_size="gpt3-175B",
    title="GPT-2 175B",
)

In [None]:
# Same "gpt3-175B" data, animated over batch_size with a slider.
plot_memory_usage_vs_throughput(
    simulation_filename="~/Downloads/gpt2_dgx_simulated_grid_search_results_v2.csv",
    model_size="gpt3-175B",
    title="GPT-2 175B",
    batch_size_slider=True,
)