In [1]:
import argparse
from collections import defaultdict, OrderedDict
import logging
import numpy as np
import time
import matplotlib as mpl
import matplotlib.pyplot as plt
import csv
import pandas as pd

import dist_ir
from dist_ir.importer import import_from_onnx, parse_tensor_from_file
from dist_ir.ir import FunctionMaker, cpprint, pformat, Device, Topology, Value
from dist_ir.executor import infer_types, SequentialExecutor, Simulator
from dist_ir.executor.cost_model import CostModel
from dist_ir.ir.type import Bool, Float, Int64, Tensor
from dist_ir.transforms import (
    mlp_dhp_transform,
    filter_transform,
    PipeDreamScheduler,
)
from examples.mlp import mlp

## Isolated parallelism simulation results

### Constants

In [2]:
# Bandwidth value passed to Topology.set_bandwidth for every device pair linked
# by add_devices_to_topology.
# NOTE(review): the name suggests Gbps on a DGX-class interconnect; the unit
# that set_bandwidth actually expects is not visible in this notebook — confirm.
DGX_BANDWIDTH_GBPS = 200

### Utils

In [3]:
def add_devices_to_topology(topology, num_devices):
    """Grow ``topology`` by ``num_devices`` GPUs and link every pair of devices.

    After the new devices are added, ``set_bandwidth`` is called once for each
    unordered pair of devices currently in the topology (including devices that
    were already present), always with ``DGX_BANDWIDTH_GBPS``.

    Args:
        topology: Object exposing ``add_device``, ``devices`` and
            ``set_bandwidth``; it is modified in place.
        num_devices: How many ``"gpu"`` devices to add.

    Returns:
        The same ``topology`` object, for convenience.
    """
    for _ in range(num_devices):
        topology.add_device("gpu")

    all_devices = topology.devices
    for src_idx, src in enumerate(all_devices):
        # Only look at later devices so each unordered pair is visited once.
        for dst_idx in range(src_idx + 1, len(all_devices)):
            topology.set_bandwidth(src, all_devices[dst_idx], DGX_BANDWIDTH_GBPS)
    return topology

### Data parallelism

In [4]:
# Data-parallel sweep: for every batch size, simulate the MLP on a single
# device, then on each device count in all_num_devices, and record the speedup
# (sequential time / distributed time) of every distributed run.
all_num_devices = [2, 4, 8, 16, 32]
all_batch_sizes = [512, 1024, 2048, 4096]
input_dim = 4096
hidden_dim = input_dim
output_dim = hidden_dim
num_hidden_layers = 64
# batch size -> list of speedups, in the same order as all_num_devices
dp_results = defaultdict(list)

# Number of devices to add at each step so the topology grows from one target
# device count to the next instead of being rebuilt.
device_increments = [
    current - previous
    for previous, current in zip([0] + all_num_devices[:-1], all_num_devices)
]

for batch_size in all_batch_sizes:
    # Fresh topology per batch size; d0 hosts the sequential baseline.
    topology = Topology()
    d0 = topology.add_device("gpu")
    function = mlp(batch_size, input_dim, hidden_dim, output_dim, num_hidden_layers, d0)
    function = infer_types(function, function.inputs)

    baseline = Simulator(CostModel(topology)).interpret(
        function,
        (v.type for v in function.inputs),
    )
    # The run finishes when the slowest device finishes.
    sequential_running_time = max(baseline.timestamps[d] for d in baseline.timestamps)

    for num_devices, num_new_devices in zip(all_num_devices, device_increments):
        add_devices_to_topology(topology, num_new_devices)
        # d0 plus the num_devices GPUs added so far.
        assert len(topology.devices) == num_devices + 1
        simulator = Simulator(CostModel(topology))

        # NOTE(review): the positional arguments after `function` look like
        # (dp_degree, hp_degree, pp_degree, devices, num_microbatches) — confirm
        # against mlp_dhp_transform. topology.devices still includes d0 here.
        dp_function = mlp_dhp_transform(function, num_devices, 1, 1, topology.devices, 1)
        dp_function = infer_types(dp_function, dp_function.inputs)
        dp_function, typed_input_values = filter_transform(
            dp_function, filter_set={"MPIBroadcast", "Send"}
        )
        # Re-run type inference using the inputs returned by filter_transform.
        dp_function = infer_types(dp_function, typed_input_values)

        simulation = simulator.interpret(
            dp_function,
            (v.type for v in dp_function.inputs),
        )
        distributed_running_time = max(
            simulation.timestamps[d] for d in simulation.timestamps
        )
        dp_results[batch_size].append(sequential_running_time / distributed_running_time)

### Pipeline parallelism

In [5]:
# Pipeline-parallel sweep: simulate one fixed-size MLP sequentially, then for
# every (num_microbatches, num_devices) combination simulate the pipelined
# version and record its speedup over the sequential run.
all_num_devices = [2, 4, 8, 16, 32]
all_num_microbatches = [4, 8, 16, 32]
input_dim = 4096
hidden_dim = input_dim
output_dim = hidden_dim
num_hidden_layers = 64
batch_size = 4096
# num_microbatches -> list of speedups, in the same order as all_num_devices
pp_results = defaultdict(list)

# Sequential baseline on a topology that contains only d0.
topology = Topology()
d0 = topology.add_device("gpu")
function = mlp(batch_size, input_dim, hidden_dim, output_dim, num_hidden_layers, d0)
function = infer_types(function, function.inputs)
simulator = Simulator(CostModel(topology))
simulation = simulator.interpret(
    function,
    (v.type for v in function.inputs),
)
# The run finishes when the slowest device finishes.
sequential_running_time = max(simulation.timestamps[d] for d in simulation.timestamps)

# Number of devices to add at each step so the topology grows from one target
# device count to the next instead of being rebuilt.
device_increments = [
    current - previous
    for previous, current in zip([0] + all_num_devices[:-1], all_num_devices)
]

for num_microbatches in all_num_microbatches:
    # Start every microbatch sweep from a fresh single-device topology.
    # NOTE(review): `function` was built on the baseline topology's d0, not on
    # this new one — confirm mlp_dhp_transform does not require the original
    # device to be part of `topology`.
    topology = Topology()
    d0 = topology.add_device("gpu")

    for num_devices, num_new_devices in zip(all_num_devices, device_increments):
        add_devices_to_topology(topology, num_new_devices)
        # d0 plus the num_devices GPUs added so far.
        assert len(topology.devices) == num_devices + 1
        simulator = Simulator(CostModel(topology))

        # NOTE(review): the positional arguments after `function` look like
        # (dp_degree, hp_degree, pp_degree, devices, num_microbatches) — confirm
        # against mlp_dhp_transform. topology.devices still includes d0 here.
        pp_function = mlp_dhp_transform(
            function, 1, 1, num_devices, topology.devices, num_microbatches
        )
        pp_function = infer_types(pp_function, pp_function.inputs)
        # The first two inputs of the transformed function are passed as the
        # exception set of the "Send" filter.
        pp_function, typed_input_values = filter_transform(
            pp_function,
            filter_set={"Send"},
            exception_set=set(pp_function.inputs[:2]),
        )
        # Re-run type inference using the inputs returned by filter_transform.
        pp_function = infer_types(pp_function, typed_input_values)

        simulation = simulator.interpret(
            pp_function,
            (v.type for v in pp_function.inputs),
        )
        distributed_running_time = max(
            simulation.timestamps[d] for d in simulation.timestamps
        )
        pp_results[num_microbatches].append(
            sequential_running_time / distributed_running_time
        )

### Figure 3

In [None]:
# Figure 3: simulated speedup versus device count, one panel for data
# parallelism (one line per batch size) and one for pipeline parallelism (one
# line per microbatch count). Reads dp_results, pp_results and all_num_devices
# produced by the sweep cells above and writes dp_and_pp.pdf.
plt.rcParams["font.size"] = 12
fig, axes = plt.subplots(2, 1, figsize=(5, 4), sharex=True, sharey=True)
dp_ax, pp_ax = axes
markers = ["o", "D", "v", "s", "<", "x"]
styles = ["-", "--", "-.", ":", (0, (3, 1, 1, 1, 1, 1))]

# Both panels share one colour normalisation, so size it from whichever panel
# has more series (previously only pp_results was consulted, which would give
# two data-parallel lines the same colour if dp_results had more entries).
num_series = max(len(dp_results), len(pp_results))
norm = mpl.colors.Normalize(vmin=1, vmax=num_series + 2)
dp_cmap = mpl.cm.ScalarMappable(norm=norm, cmap=mpl.cm.Blues)
dp_cmap.set_array([])
pp_cmap = mpl.cm.ScalarMappable(norm=norm, cmap=mpl.cm.Greens)
pp_cmap.set_array([])


def _plot_speedup_series(ax, results, label_template, color_map):
    """Draw one speedup-vs-devices line on ``ax`` for every entry of ``results``.

    Args:
        ax: Axes to draw on.
        results: Mapping from a sweep parameter to its list of speedups, which
            must line up with ``all_num_devices``.
        label_template: ``str.format`` template that receives the mapping key.
        color_map: ScalarMappable used to pick each line's colour.

    Returns:
        ``(lines, labels)`` in plotting order, ready to hand to ``legend``.
    """
    lines, labels = [], []
    for idx, (key, speedups) in enumerate(results.items()):
        labels.append(label_template.format(key))
        (line,) = ax.plot(
            all_num_devices,
            speedups,
            marker=markers[idx],
            linestyle=styles[idx],
            label=labels[-1],
            # Start at 3, well inside the normalised range, so that no line is
            # drawn in the faintest colours of the map.
            c=color_map.to_rgba(idx + 3),
        )
        lines.append(line)
    return lines, labels


dp_lines, dp_labels = _plot_speedup_series(dp_ax, dp_results, "Batch size {}", dp_cmap)

# Tick locations are set once instead of on every loop iteration. The two
# panels were created with sharex/sharey, so setting them on one axis is
# expected to apply to both. This stays between the two plotting passes (where
# the last of the original per-iteration calls ran) so autoscaling of the axis
# limits is unchanged.
pp_ax.set_xticks([2, 4, 8, 16, 32])
pp_ax.set_yticks([5, 10, 15, 20])

pp_lines, pp_labels = _plot_speedup_series(
    pp_ax, pp_results, "{} microbatches", pp_cmap
)

dp_ax.set_title("Data parallelism")
pp_ax.set_title("Pipeline parallelism")
# Shared axis labels, placed in figure coordinates just outside the panels.
fig.text(0.5, -0.025, "# Devices", ha="center", va="center")
fig.text(-0.01, 0.5, "Speedup", va="center", ha="center", rotation="vertical")
fig.tight_layout()

# Frameless legend for each panel, moved out to the right of its axes.
legend_x_offset = 0.75
for ax, lines, labels in (
    (dp_ax, dp_lines, dp_labels),
    (pp_ax, pp_lines, pp_labels),
):
    legend = ax.legend(lines, labels, loc="upper center", ncol=1)
    legend.get_frame().set_linewidth(0.0)
    # Get the legend's anchor box in axes coordinates.
    anchor_box = legend.get_bbox_to_anchor().transformed(ax.transAxes.inverted())
    # Change the location of the legend.
    anchor_box.x0 += legend_x_offset
    anchor_box.x1 += legend_x_offset
    legend.set_bbox_to_anchor(anchor_box, transform=ax.transAxes)

plt.savefig("dp_and_pp.pdf", dpi=600, bbox_inches="tight")

## Grid search simulation results

In [None]:
# Load the grid-search results. header=0 combined with names= replaces the
# file's own header row with the short column names below, which stand for:
# dp_degree / hp_degree / pp_degree / num_microbatches / throughput
grid_search_columns = ["D", "H", "P", "num_microbatches", "throughput"]
data = pd.read_csv(
    "euromlsys21_grid_search_results.csv",
    header=0,
    names=grid_search_columns,
)

### Figure 4

In [None]:
# Figure 4: speedup versus total device count for the grid-search results in
# `data`, highlighting pure data, horizontal and pipeline parallelism plus the
# balanced D+H configurations. Writes grid_search.pdf.
plt.rcParams["font.size"] = 12

# Throughput that every configuration is normalised against.
# NOTE(review): presumably the single-device (D = H = P = 1) throughput from
# the grid search — confirm where this number comes from.
BASELINE_THROUGHPUT = 3725.917956
data["speedup"] = data["throughput"] / BASELINE_THROUGHPUT
total_devices = data["D"] * data["H"] * data["P"]

# Select each strategy with boolean masks and take the x values from the data
# itself, sorted by degree. The previous version plotted the data/horizontal
# series against a hard-coded [1, 2, 4, 8, 16] list, which silently relied on
# the CSV rows being in exactly that order.
no_pipeline = data["P"] == 1
data_parallel = data[(data["H"] == 1) & no_pipeline].sort_values("D")
horizontal_parallel = data[(data["D"] == 1) & no_pipeline].sort_values("H")
dh_parallel = data[(data["D"] == data["H"]) & no_pipeline].sort_values("D")
# For each P degree, find the num_microbatches with best throughput (note it's not always 32!)
pp_data = data[(data["D"] == 1) & (data["H"] == 1)]
best_P = pp_data.loc[pp_data.groupby("P")["throughput"].idxmax()]

# (legend label, device counts, speedups) for every highlighted line.
series = [
    ("Data parallelism", data_parallel["D"], data_parallel["speedup"]),
    ("Horizontal parallelism", horizontal_parallel["H"], horizontal_parallel["speedup"]),
    ("Pipeline parallelism", best_P["P"], best_P["speedup"]),
    (
        "D+H parallelism",
        dh_parallel["D"] * dh_parallel["H"] * dh_parallel["P"],
        dh_parallel["speedup"],
    ),
]

markers = ["o", "D", "v", "s", "<", "x"]
styles = ["-", "--", "-.", ":", (0, (3, 1, 1, 1, 1, 1))]
# The colour range is sized from the number of data-parallel points, exactly
# as before, so the rendered line colours do not change.
norm = mpl.colors.Normalize(vmin=1, vmax=len(data_parallel) + 2)
cmap = mpl.cm.ScalarMappable(norm=norm, cmap=mpl.cm.Reds)
cmap.set_array([])

fig, ax = plt.subplots(figsize=(4, 2))
lines = []
labels = []
for idx, (label, num_devices_used, speedups) in enumerate(series):
    (line,) = ax.plot(
        num_devices_used,
        speedups,
        marker=markers[idx],
        linestyle=styles[idx],
        # Start at 3, well inside the normalised range, so that no line is
        # drawn in the faintest colours of the map.
        c=cmap.to_rgba(idx + 3),
        label=label,
    )
    lines.append(line)
    labels.append(label)

# Every grid-search configuration as a faint background point.
ax.scatter(
    total_devices,
    data["speedup"],
    marker="x",
    color="lightgray",
    alpha=0.5,
)
ax.set_yticks([5, 10, 15, 20])
ax.set_xticks([2, 4, 8, 16])

leg = fig.legend(lines, labels, loc="upper center", ncol=2)
leg.get_frame().set_linewidth(0.0)
# Get the legend's anchor box in axes coordinates.
bb = leg.get_bbox_to_anchor().transformed(ax.transAxes.inverted())

# Change the location of the legend (move it up, above the axes).
yOffset = 0.45
bb.y0 += yOffset
bb.y1 += yOffset
leg.set_bbox_to_anchor(bb, transform=ax.transAxes)

# A single layout pass is enough; the call used to be duplicated. The axis
# labels are still added after it, as before, and bbox_inches="tight" keeps
# them inside the saved file.
fig.tight_layout()
ax.set_xlabel("# Devices")
ax.set_ylabel("Speedup")
fig.savefig("grid_search.pdf", dpi=600, bbox_inches="tight")