In [1]:
import numpy as np
import os
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from matplotlib.ticker import ScalarFormatter
import pylab
import plotly.express as px

In [2]:
# Global matplotlib text settings: 18 pt font size, Times New Roman family.
plt.rcParams.update(
    {
        "font.size": 18,
        "font.family": "Times New Roman",
    }
)

In [3]:
# Directory prefix joined onto every figure output filename in the plotting
# cells below (relative path).
FIGURES_DIR = "figures/mlsys22"

In [4]:
def get_parallelism_style(dp, hp, pp):
    """Return a short label naming which parallelism degrees exceed 1.

    Each degree contributes one letter when it is greater than 1: ``dp`` -> "D",
    ``hp`` -> "T", ``pp`` -> "P" (presumably data / tensor ("horizontal") /
    pipeline parallelism -- confirm against the simulator). The letters are
    joined with "/" in that fixed order, e.g. "D/T/P". When all three degrees
    equal 1 the label is "Base".

    Args:
        dp: Degree associated with the "D" label.
        hp: Degree associated with the "T" label.
        pp: Degree associated with the "P" label.

    Returns:
        One of "Base", "D", "T", "P", "D/T", "T/P", "D/P", "D/T/P".

    Raises:
        ValueError: If any degree is neither equal to 1 nor greater than 1
            (for example 0, a negative number, or NaN).
    """
    active_axes = []
    for letter, degree in (("D", dp), ("T", hp), ("P", pp)):
        if degree > 1:
            active_axes.append(letter)
        elif degree != 1:
            # Neither "== 1" nor "> 1": no label exists for this combination.
            raise ValueError(f"Invalid degree combination dp={dp}, hp={hp}, pp={pp}")
    if not active_axes:
        return "Base"
    return "/".join(active_axes)

## Memory usage vs throughput / latency

In [14]:
def plot_memory_usage_vs_metric(
    simulation_filename,
    model_size,
    min_batch_size,
    metric,
    xlabel,
    ylabel,
    output_filename,
    legend_output_filename=None,
):
    """Show an interactive scatter of peak memory against a performance metric.

    Reads the simulation results CSV, keeps the rows for ``model_size``,
    rescales ``peak_memory`` by 1e9, drops rows whose rescaled peak memory
    exceeds 32 or whose batch size is below half of ``min_batch_size``, labels
    every remaining row with its parallelism style, and displays a plotly
    scatter plot coloured and marked by that style.

    Args:
        simulation_filename: Path of the CSV to load with ``pd.read_csv``. The
            columns read here are ``model_size``, ``peak_memory``,
            ``batch_size``, ``dp_degree``, ``hp_degree``, ``pp_degree``,
            ``num_microbatches`` and the column named by ``metric``.
        model_size: Value of the ``model_size`` column to keep.
        min_batch_size: Rows with ``batch_size < min_batch_size / 2`` are
            dropped (note the threshold is *half* of this value).
        metric: Name of the column plotted on the x axis (e.g. "throughput").
        xlabel: Axis label shown for ``metric``.
        ylabel: Axis label shown for ``peak_memory``.
        output_filename: Accepted for interface compatibility only; the plotly
            figure is displayed but NOT written to this path.
        legend_output_filename: Accepted for interface compatibility only; no
            separate legend figure is produced.

    Returns:
        None. The figure is rendered via ``fig.show()``.

    Raises:
        ValueError: Propagated from ``get_parallelism_style`` when a row holds
            a degree that is neither 1 nor greater than 1.
    """
    results = pd.read_csv(simulation_filename)
    results = results[results["model_size"] == model_size]
    # Use assign() rather than item assignment on a filtered frame so pandas
    # does not emit SettingWithCopyWarning.
    results = results.assign(peak_memory=results["peak_memory"] / 1e9)
    # Upper bound on the rescaled peak memory (presumably a 32 GB device
    # limit -- confirm against the simulated hardware).
    results = results[results["peak_memory"] <= 32]
    results = results[results["batch_size"] >= (min_batch_size / 2)]
    results = results.assign(
        # Constant size column so every marker is drawn at the same size.
        dummy_column_for_size=1.0,
        parallelism_style=[
            get_parallelism_style(dp, hp, pp)
            for (dp, hp, pp) in results[
                ["dp_degree", "hp_degree", "pp_degree"]
            ].values
        ],
    )

    fig = px.scatter(
        results,
        x=metric,
        y="peak_memory",
        symbol="parallelism_style",
        color="parallelism_style",
        hover_data=[
            "batch_size",
            "dp_degree",
            "hp_degree",
            "pp_degree",
            "num_microbatches",
        ],
        labels={
            metric: xlabel,
            "peak_memory": ylabel,
            "parallelism_style": "Parallelism style",
        },
        color_discrete_sequence=[
            "#1f77b4",  # muted blue
            "#ff7f0e",  # safety orange
            "#2ca02c",  # cooked asparagus green
            "#d62728",  # brick red
            "#9467bd",  # muted purple
            "#8c564b",  # chestnut brown
            "#e377c2",  # raspberry yogurt pink
            "#7f7f7f",  # middle gray
            "#bcbd22",  # curry yellow-green
            "#17becf",  # blue-teal
        ],
        category_orders={
            # Fixed legend order; "Base" is the label get_parallelism_style
            # returns when all three degrees are 1.
            "parallelism_style": [
                "D",
                "T",
                "P",
                "D/T",
                "T/P",
                "D/P",
                "D/T/P",
                "Base",
            ],
        },
        size="dummy_column_for_size",
        size_max=10,
    )
    fig.show()

### MLP Training

In [23]:
# mlp-small: peak memory vs. throughput, with a minimum batch size of 131072.
simulation_filename = "~/Downloads/mlp_dgx_simulated_grid_search_results_v2.csv"
plot_memory_usage_vs_metric(
    simulation_filename=simulation_filename,
    model_size="mlp-small",
    min_batch_size=131072,
    metric="throughput",
    xlabel="Throughput (samples / second)",
    ylabel="Peak Memory (GB)",
    output_filename=os.path.join(
        FIGURES_DIR, "mlp-small_memory_vs_throughput.pdf"
    ),
    legend_output_filename=os.path.join(
        FIGURES_DIR, "mlp_training_memory_vs_throughput_legend.pdf"
    ),
)

In [24]:
# mlp-medium: peak memory vs. throughput, no minimum batch size (0).
simulation_filename = "~/Downloads/mlp_dgx_simulated_grid_search_results_v2.csv"
plot_memory_usage_vs_metric(
    simulation_filename=simulation_filename,
    model_size="mlp-medium",
    min_batch_size=0,
    metric="throughput",
    xlabel="Throughput (samples / second)",
    ylabel="Peak Memory (GB)",
    output_filename=os.path.join(
        FIGURES_DIR, "mlp-medium_memory_vs_throughput.pdf"
    ),
    legend_output_filename=None,
)

In [21]:
# mlp-large: peak memory vs. throughput, no minimum batch size (0).
simulation_filename = "~/Downloads/mlp_dgx_simulated_grid_search_results_v2.csv"
plot_memory_usage_vs_metric(
    simulation_filename=simulation_filename,
    model_size="mlp-large",
    min_batch_size=0,
    metric="throughput",
    xlabel="Throughput (samples / second)",
    ylabel="Peak Memory (GB)",
    output_filename=os.path.join(
        FIGURES_DIR, "mlp-large_memory_vs_throughput.pdf"
    ),
    legend_output_filename=None,
)

### GPT-2 Inference

In [18]:
# gpt3-xl: peak memory vs. throughput, with a minimum batch size of 131072.
simulation_filename = "~/Downloads/gpt2_dgx_simulated_grid_search_results.csv"
plot_memory_usage_vs_metric(
    simulation_filename=simulation_filename,
    model_size="gpt3-xl",
    min_batch_size=131072,
    metric="throughput",
    xlabel="Throughput (samples / second)",
    ylabel="Peak Memory (GB)",
    output_filename=os.path.join(
        FIGURES_DIR, "gpt3-xl_memory_vs_throughput.pdf"
    ),
    legend_output_filename=None,
)

In [19]:
# gpt3-13B: peak memory vs. throughput, with a minimum batch size of 16384.
simulation_filename = "~/Downloads/gpt2_dgx_simulated_grid_search_results.csv"
plot_memory_usage_vs_metric(
    simulation_filename=simulation_filename,
    model_size="gpt3-13B",
    min_batch_size=16384,
    metric="throughput",
    xlabel="Throughput (samples / second)",
    ylabel="Peak Memory (GB)",
    output_filename=os.path.join(
        FIGURES_DIR, "gpt3-13B_memory_vs_throughput.pdf"
    ),
    legend_output_filename=None,
)

In [22]:
# gpt3-175B: peak memory vs. throughput, with a minimum batch size of 8192.
# NOTE(review): this cell reads the "_v2" results file while the other GPT
# cells read the file without that suffix -- confirm which is intended.
simulation_filename = "~/Downloads/gpt2_dgx_simulated_grid_search_results_v2.csv"
plot_memory_usage_vs_metric(
    simulation_filename=simulation_filename,
    model_size="gpt3-175B",
    min_batch_size=8192,
    metric="throughput",
    xlabel="Throughput (samples / second)",
    ylabel="Peak Memory (GB)",
    output_filename=os.path.join(
        FIGURES_DIR, "gpt3-175B_memory_vs_throughput.pdf"
    ),
    legend_output_filename=None,
)