In [None]:
import argparse
from collections import defaultdict, OrderedDict
import logging
import numpy as np
import time
import matplotlib as mpl
import matplotlib.pyplot as plt

import dist_ir
from dist_ir.importer import import_from_onnx, parse_tensor_from_file
from dist_ir.ir import FunctionMaker, cpprint, pformat, Device, Topology, Value
from dist_ir.executor import infer_types, SequentialExecutor, Simulator
from dist_ir.executor.cost_model import CostModel
from dist_ir.ir.type import Bool, Float, Int64, Tensor
from dist_ir.transforms import (
    parallel_transform_3d,
    steady_state_transform,
    PipeDreamScheduler,
)

In [None]:
# Bandwidth value that add_devices_to_topology assigns to every pair of
# devices via Topology.set_bandwidth.
# NOTE(review): the name suggests Gbps; confirm the unit set_bandwidth expects.
DGX_BANDWIDTH_GBPS = 200

## Utils

In [None]:
def mlp(batch_size, input_dim, hidden_dim, output_dim, num_hidden_layers, device):
    """Build the forward and backward graph of a multi-layer perceptron.

    The function takes an input ``x`` of shape ``(batch_size, input_dim)``,
    labels ``z`` of shape ``(batch_size, output_dim)`` and one weight matrix
    per layer (named ``wA``, ``wB``, ... by offsetting from ``'A'``). Each
    layer is a ``MatMul`` followed by a ``Relu``. A ``Loss`` / ``LossGrad``
    pair and the matching ``ReluGrad`` / ``MatMulGrad`` ops for every layer
    (last layer first) are appended after the forward pass.

    Args:
        batch_size: Number of rows of ``x`` and ``z``; also passed as the
            ``N`` attribute of the loss ops.
        input_dim: Number of columns of ``x`` (rows of the first weight).
        hidden_dim: Width of every intermediate layer.
        output_dim: Number of columns of the last weight and of ``z``.
        num_hidden_layers: Total number of weight matrices; must be >= 1.
        device: Device attached to the type of every input value.

    Returns:
        The result of ``FunctionMaker.finalize()``.

    Raises:
        ValueError: If ``num_hidden_layers`` is smaller than 1.
    """
    if num_hidden_layers < 1:
        raise ValueError(
            f"num_hidden_layers must be at least 1, got {num_hidden_layers}"
        )

    function = FunctionMaker(name="mlp")
    x = function.add_input_value(
        "x",
        Tensor(dtype=Float(), shape=(batch_size, input_dim), device=device),
    )
    z = function.add_input_value(
        "z",
        Tensor(dtype=Float(), shape=(batch_size, output_dim), device=device),
    )

    # One weight per layer. The first layer maps input_dim -> hidden_dim, the
    # middle layers hidden_dim -> hidden_dim and the last one -> output_dim.
    # Building all layers in a single loop keeps the single-layer case valid
    # (input_dim -> output_dim) instead of relying on a leftover loop index.
    weights = []
    layer_input_dim = input_dim
    for i in range(num_hidden_layers):
        is_last_layer = i == num_hidden_layers - 1
        layer_output_dim = output_dim if is_last_layer else hidden_dim
        w = function.add_input_value(
            f"w{chr(ord('A')+i)}",
            Tensor(
                dtype=Float(),
                shape=(layer_input_dim, layer_output_dim),
                device=device,
            ),
        )
        weights.append(w)
        layer_input_dim = layer_output_dim

    # Forward pass: ops are appended as MatMul_0, Relu_0, MatMul_1, Relu_1, ...
    a = x
    for i, weight in enumerate(weights):
        y = function.add_op("MatMul", inputs=[a, weight], output_names=[f"y{i}"])
        a = function.add_op("Relu", inputs=[y], output_names=[f"a{i}"])

    # The loss value itself is not consumed by the backward pass; the op is
    # still added so that it is part of the function.
    function.add_op(
        "Loss", inputs=[a, z], attributes={"N": batch_size}, output_names=["l"]
    )
    dl = function.add_op(
        "LossGrad",
        inputs=[a, z],
        attributes={"N": batch_size},
        output_names=["dl"],
    )

    # Backward pass, last layer first. Because of the forward op ordering,
    # function.ops[2 * i] is layer i's MatMul and function.ops[2 * i + 1] is
    # its Relu; their first inputs are the values the gradient ops need.
    dy = dl
    for i in reversed(range(len(weights))):
        da = function.add_op(
            "ReluGrad",
            inputs=[function.ops[2 * i + 1].inputs[0], dy],
            output_names=[f"da{i}"],
        )
        dy, _dw = function.add_op(
            "MatMulGrad",
            inputs=[function.ops[2 * i].inputs[0], weights[i], da],
            output_names=[f"dy{i}", f"dw{chr(ord('A')+i)}"],
        )
    return function.finalize()

In [None]:
def add_devices_to_topology(topology, num_devices):
    """Add ``num_devices`` "gpu" devices and link every pair of devices.

    After the new devices are added, the bandwidth between each unordered
    pair in ``topology.devices`` (including devices that were already
    present) is set to ``DGX_BANDWIDTH_GBPS``.

    Args:
        topology: Object providing ``add_device``, ``devices`` and
            ``set_bandwidth``.
        num_devices: Number of devices to add.

    Returns:
        The same ``topology`` object that was passed in.
    """
    for _ in range(num_devices):
        topology.add_device("gpu")

    devices = topology.devices
    for src_idx, src in enumerate(devices):
        # Only visit partners with a higher index so each pair is set once.
        for dst_idx in range(src_idx + 1, len(devices)):
            topology.set_bandwidth(src, devices[dst_idx], DGX_BANDWIDTH_GBPS)
    return topology

In [None]:
def get_all_degrees(n):
    """List every power-of-two triple ``(d, h, p)`` with ``d * h * p == n``.

    Triples are ordered by increasing ``d`` and, for equal ``d``, by
    increasing ``h``. The result is empty when ``n`` cannot be written as a
    product of three powers of two (for example when ``n`` is not itself a
    power of two, or ``n < 1``).
    """
    # Candidate factors: 1, 2, 4, ... up to and including n.
    powers_of_two = []
    candidate = 1
    while candidate <= n:
        powers_of_two.append(candidate)
        candidate *= 2

    return [
        (d, h, p)
        for d in powers_of_two
        for h in powers_of_two
        for p in powers_of_two
        if d * h * p == n
    ]

In [None]:
def measure_communication_overhead(simulation):
    """Return the fraction of traced time spent in communication events.

    Every event in ``simulation.trace`` contributes its ``"dur"`` to the
    total; events named ``"Send"`` or whose name contains ``"MPI"`` are also
    counted as communication.

    Args:
        simulation: Object exposing ``trace``, an iterable of mappings with
            ``"name"`` and ``"dur"`` entries.

    Returns:
        ``communication_time / total_time``, or ``0.0`` when the trace is
        empty or its durations sum to zero (instead of dividing by zero).
    """
    total_time = 0.0
    communication_time = 0.0
    for event in simulation.trace:
        duration = event["dur"]
        name = event["name"]
        total_time += duration
        if name == "Send" or "MPI" in name:
            communication_time += duration
    if total_time == 0:
        # Nothing was traced, so there is no overhead to report.
        return 0.0
    return communication_time / total_time

## Grid Search

In [None]:
%%time
# Grid search over 3D-parallel configurations of the MLP. For every
# (data, horizontal, pipeline) degree triple that multiplies to the cluster
# size, the transformed function is simulated and its throughput
# (samples per simulated second) is stored in
# results[dp_degree][num_microbatches][hp_degree].
input_dim = 8192
hidden_dim = input_dim
output_dim = input_dim
all_cluster_sizes = [16]  # [64, 128, 512, 1024]
all_num_hidden_layers = [64]  # [4, 8, 16, 32]
all_batch_sizes = [8192]  # [512, 1024, 2048, 4096, 8192]
# NOTE(review): the keys do not include the layer count, batch size or
# cluster size, so sweeping more than one value of those overwrites entries.
results = defaultdict(lambda: defaultdict(dict))
for num_hidden_layers in all_num_hidden_layers:
    for batch_size in all_batch_sizes:
        topology = Topology()
        d0 = topology.add_device("gpu")
        function = mlp(
            batch_size, input_dim, hidden_dim, output_dim, num_hidden_layers, d0
        )
        function = infer_types(function, function.inputs)
        simulator = Simulator(CostModel(topology))
        # Baseline: simulate the untransformed function on the single device.
        simulation = simulator.interpret(function, (v.type for v in function.inputs))
        sequential_running_time = max(
            [simulation.timestamps[d] for d in simulation.timestamps]
        )
        print(f"Sequential running time: {sequential_running_time * 1e3} ms")
        for i, cluster_size in enumerate(all_cluster_sizes):
            # The topology is grown incrementally, so only the difference to
            # the previous cluster size is added.
            if i == 0:
                add_devices_to_topology(topology, cluster_size)
            else:
                add_devices_to_topology(
                    topology, all_cluster_sizes[i] - all_cluster_sizes[i - 1]
                )
            all_degrees = get_all_degrees(cluster_size)
            for (dp_degree, hp_degree, pp_degree) in all_degrees:
                # Skip pipeline degrees that do not divide the layer count.
                if num_hidden_layers % pp_degree != 0:
                    continue
                dp_batch_size = batch_size // dp_degree
                if pp_degree == 1:
                    # Without pipeline parallelism a single microbatch is used.
                    all_num_microbatches = [1]
                else:
                    all_num_microbatches = [
                        int(2 ** k)
                        for k in range(1, int(np.floor(np.log2(dp_batch_size) / 2)))
                    ]
                for num_microbatches in all_num_microbatches:
                    transformed_function = parallel_transform_3d(
                        function,
                        dp_degree,
                        hp_degree,
                        pp_degree,
                        topology.devices,
                        num_microbatches,
                    )
                    transformed_function = infer_types(
                        transformed_function, transformed_function.inputs
                    )
                    transformed_function, typed_input_values = steady_state_transform(
                        transformed_function
                    )
                    transformed_function = infer_types(
                        transformed_function, typed_input_values
                    )
                    simulation = simulator.interpret(
                        transformed_function,
                        (v.type for v in transformed_function.inputs),
                    )
                    distributed_running_time = max(
                        [simulation.timestamps[d] for d in simulation.timestamps]
                    )
                    # Computed for inspection only; it is not stored in `results`.
                    communication_overhead = measure_communication_overhead(simulation)
                    throughput = batch_size / distributed_running_time
                    results[dp_degree][num_microbatches][hp_degree] = throughput
                    print(dp_degree, hp_degree, pp_degree, num_microbatches, throughput)

In [None]:
# Plot the grid-search throughput: one panel per data-parallel degree, one
# line per microbatch count, horizontal-parallel degree on the log-scaled
# x axis. The figure is written to grid_search.pdf.
plt.rcParams['font.size'] = 14
dp_degrees = sorted(results.keys())
# One panel per data-parallel degree present in `results` (at least one, so
# the figure can always be created). squeeze=False keeps `axes` indexable
# even when there is a single panel.
num_panels = max(1, len(dp_degrees))
fig, axes = plt.subplots(
    1, num_panels, figsize=(16, 3), sharex=True, sharey=True, squeeze=False
)
axes = axes[0]
plt.xscale('log')
plt.setp(axes, xticks=[1, 2, 4, 8, 16], xticklabels=[1, 2, 4, 8, 16])
fig.text(0.5, -.025, '# Horizontal parallel partitions (log scale)', ha='center', va='center', size=16)
fig.text(-.025, 0.5, 'Throughput\n(samples/second)', va='center', ha='center', rotation='vertical', size=16)
fig.tight_layout()
# Colour and marker are keyed by the number of microbatches.
colors = {1: 'C0', 2: 'C1', 4: 'C2', 8: 'C3', 16: 'C4', 32: 'C5'}
markers = {1: 'o', 2: 'D', 4: 'v', 8: 's', 16: '<', 32: 'x'}
lines = []
labels = []
for i, dp_degree in enumerate(dp_degrees):
    ax = axes[i]
    if dp_degree == 1:
        ax.set_title("No data parallelism", size=16)
    else:
        ax.set_title(f"{dp_degree} data parallel partitions", size=16)
    for k in sorted(results[dp_degree].keys()):
        x = sorted(results[dp_degree][k].keys())
        y = [results[dp_degree][k][hp_degree] for hp_degree in x]
        label = "No pipeline parallelism" if k == 1 else f"{k} microbatches"
        line = ax.plot(x, y, marker=markers[k], color=colors[k], label=label)[0]
        # The shared legend is built from the first panel's lines only.
        if i == 0:
            lines.append(line)
            labels.append(label)

# Build the shared legend once, after all panels are drawn, instead of
# stacking one identical legend per panel.
if lines:
    leg = plt.figlegend(lines, labels, loc='lower center', ncol=len(lines))
    leg.get_frame().set_linewidth(0.0)
    # Get the bounding box of the original legend.
    bb = leg.get_bbox_to_anchor().transformed(fig.gca().transAxes.inverted())
    # Shift the legend upwards, in axes coordinates of the current axes.
    yOffset = 1.4
    bb.y0 += yOffset
    bb.y1 += yOffset
    leg.set_bbox_to_anchor(bb, transform = fig.gca().transAxes)
plt.savefig("grid_search.pdf", dpi=600, bbox_inches="tight")

## Simulation Scaling

In [None]:
# Measure how the simulator's wall-clock time scales with the number of ops:
# for each layer count, simulate the same function `num_trials` times and
# record every duration (in seconds) under the function's op count.
batch_size = 64
input_dim = 64
hidden_dim = 64
output_dim = 64
num_trials = 5
topology = Topology()
d0 = topology.add_device("gpu")
all_num_hidden_layers = [32, 64, 128, 256, 512, 1024, 2048, 4096]
all_simulation_times = defaultdict(list)
for num_hidden_layers in all_num_hidden_layers:
    function = mlp(
        batch_size, input_dim, hidden_dim, output_dim, num_hidden_layers, d0
    )
    function = infer_types(function, function.inputs)
    num_ops = len(function.ops)
    for i in range(num_trials):
        # A fresh simulator is created for every trial.
        simulator = Simulator(CostModel(topology))
        # perf_counter is the clock intended for measuring short intervals;
        # the clock is stopped before the sanity check so that the check is
        # not part of the measurement.
        start = time.perf_counter()
        simulation = simulator.interpret(function, (v.type for v in function.inputs))
        duration = time.perf_counter() - start
        assert len(simulation.trace) == num_ops
        all_simulation_times[num_ops].append(duration)

In [None]:
# Summarise the timing runs as the median simulation time (milliseconds) per
# op count, print the pairs and plot them to simulator_scaling.pdf.
x = sorted(all_simulation_times)
y = [1e3 * np.median(all_simulation_times[k]) for k in x]
print(list(zip(x, y)))
# plt.rcParams["font.family"] = "serif"
# plt.rcParams["font.serif"] = "Times"
plt.rcParams.update(
    {
        "font.size": 12,
        "axes.labelsize": 16,
        "figure.figsize": (8, 3),
    }
)
# NOTE(review): the second measurement (index 1) is left out of the plot;
# the reason is not visible in this notebook.
plotted_x = [x[0], *x[2:]]
plotted_y = [y[0], *y[2:]]
plt.plot(plotted_x, plotted_y, marker='o')
plt.xlabel("# Ops")
plt.ylabel("Milliseconds")
plt.tight_layout()
plt.savefig("simulator_scaling.pdf", dpi=600, bbox_inches="tight")

## Isolated Parallelism Scaling

In [None]:
# Data-parallel scaling: for every batch size, simulate the MLP with the
# data-parallel degree set to each device count (horizontal and pipeline
# degree 1, one microbatch). Throughputs are appended to
# dp_results[batch_size] in the order of all_num_devices.
all_num_devices = [2, 4, 8, 16, 32, 64, 128]
all_batch_sizes = [512, 1024, 2048, 4096]
input_dim = 8192
hidden_dim = input_dim
output_dim = hidden_dim
num_hidden_layers = 64
dp_results = defaultdict(list)
for batch_size in all_batch_sizes:
    topology = Topology()
    d0 = topology.add_device("gpu")
    # The untransformed model only depends on the batch size, so it is built
    # and type-checked once per batch size rather than once per device count.
    function = mlp(
        batch_size, input_dim, hidden_dim, output_dim, num_hidden_layers, d0
    )
    function = infer_types(function, function.inputs)
    for num_devices in all_num_devices:
        # Grow the topology so that it holds `num_devices` devices in
        # addition to d0; only the missing devices are added.
        num_existing_devices = len(topology.devices) - 1
        add_devices_to_topology(topology, num_devices - num_existing_devices)
        assert len(topology.devices) == num_devices + 1
        simulator = Simulator(CostModel(topology))
        transformed_function = parallel_transform_3d(
            function,
            num_devices,
            1,
            1,
            topology.devices,
            1,
        )
        transformed_function = infer_types(
            transformed_function, transformed_function.inputs
        )
        transformed_function, typed_input_values = steady_state_transform(
            transformed_function
        )
        transformed_function = infer_types(transformed_function, typed_input_values)
        simulation = simulator.interpret(
            transformed_function,
            (v.type for v in transformed_function.inputs),
        )
        distributed_running_time = max(
            [simulation.timestamps[d] for d in simulation.timestamps]
        )
        dp_results[batch_size].append(batch_size / distributed_running_time)

In [None]:
# Plot data-parallel throughput against the number of partitions, one line
# per batch size, and save it to data_parallelism.pdf.
markers = ["o", "D", "v", "s", "<", "x"]
styles = ["-", "--", "-.", ":"]
# Size the colour scale from the series plotted in this cell (dp_results),
# so the cell does not depend on results that a later cell defines.
c = np.arange(1, len(dp_results) + 3)
norm = mpl.colors.Normalize(vmin=c.min(), vmax=c.max())
cmap = mpl.cm.ScalarMappable(norm=norm, cmap=mpl.cm.Blues)
cmap.set_array([])
plt.figure(figsize=(5, 3))
lines = []
labels = []
for i, batch_size in enumerate(dp_results):
    labels.append(f"Batch size {batch_size}")
    # The last device count and its result are left out of the plot.
    l = plt.plot(
        all_num_devices[:-1],
        dp_results[batch_size][:-1],
        # Cycle through markers/styles so extra series cannot index past the lists.
        marker=markers[i % len(markers)],
        linestyle=styles[i % len(styles)],
        label=labels[-1],
        # Offset by 3 to skip the lightest shades of the colour map.
        c=cmap.to_rgba(i + 3),
    )[0]
    lines.append(l)
# Axis decoration is the same for every series, so it is set once.
plt.xticks([2, 4, 8, 16, 32, 64])
plt.xlabel("# Data parallel partitions")
plt.ylabel("Throughput\n(samples/second)")
leg = plt.figlegend(lines, labels, loc="upper center", ncol=2)
# Get the bounding box of the original legend.
bb = leg.get_bbox_to_anchor().transformed(plt.gca().transAxes.inverted())

# Change to location of the legend.
yOffset = 0.2
bb.y0 += yOffset
bb.y1 += yOffset
leg.set_bbox_to_anchor(bb, transform=plt.gca().transAxes)
plt.tight_layout()
plt.savefig("data_parallelism.pdf", dpi=600, bbox_inches="tight")

In [None]:
# Horizontal-parallel scaling: for every weight dimension, simulate the MLP
# with the horizontal-parallel degree set to each device count (data and
# pipeline degree 1, one microbatch). Throughputs are appended to
# hp_results[input_dim] in the order of all_num_devices.
all_num_devices = [2, 4, 8, 16, 32, 64]
all_input_dims = [1024, 2048, 4096, 8192]
batch_size = 8192
num_hidden_layers = 128
hp_results = defaultdict(list)
for input_dim in all_input_dims:
    hidden_dim = input_dim
    output_dim = hidden_dim
    topology = Topology()
    d0 = topology.add_device("gpu")
    # The untransformed model only depends on the weight dimension, so it is
    # built and type-checked once per dimension rather than once per device
    # count.
    function = mlp(
        batch_size, input_dim, hidden_dim, output_dim, num_hidden_layers, d0
    )
    function = infer_types(function, function.inputs)
    for num_devices in all_num_devices:
        # Grow the topology so that it holds `num_devices` devices in
        # addition to d0; only the missing devices are added.
        num_existing_devices = len(topology.devices) - 1
        add_devices_to_topology(topology, num_devices - num_existing_devices)
        assert len(topology.devices) == num_devices + 1
        simulator = Simulator(CostModel(topology))
        transformed_function = parallel_transform_3d(
            function,
            1,
            num_devices,
            1,
            topology.devices,
            1,
        )
        transformed_function = infer_types(
            transformed_function, transformed_function.inputs
        )
        transformed_function, typed_input_values = steady_state_transform(
            transformed_function
        )
        transformed_function = infer_types(transformed_function, typed_input_values)
        simulation = simulator.interpret(
            transformed_function,
            (v.type for v in transformed_function.inputs),
        )
        distributed_running_time = max(
            [simulation.timestamps[d] for d in simulation.timestamps]
        )
        hp_results[input_dim].append(batch_size / distributed_running_time)

In [None]:
# Plot horizontal-parallel throughput against the number of partitions, one
# line per weight dimension, and save it to horizontal_parallelism.pdf.
markers = ["o", "D", "v", "s", "<", "x"]
styles = ["-", "--", "-.", ":"]
# Size the colour scale from the series plotted in this cell (hp_results),
# so the cell does not depend on results that a later cell defines.
c = np.arange(1, len(hp_results) + 3)
norm = mpl.colors.Normalize(vmin=c.min(), vmax=c.max())
cmap = mpl.cm.ScalarMappable(norm=norm, cmap=mpl.cm.Blues)
cmap.set_array([])
plt.figure(figsize=(5, 3))
lines = []
labels = []
for i, input_dim in enumerate(hp_results):
    labels.append(f"Weight dim {input_dim}")
    l = plt.plot(
        all_num_devices,
        hp_results[input_dim],
        # Cycle through markers/styles so extra series cannot index past the lists.
        marker=markers[i % len(markers)],
        linestyle=styles[i % len(styles)],
        label=labels[-1],
        # Offset by 3 to skip the lightest shades of the colour map.
        c=cmap.to_rgba(i + 3),
    )[0]
    lines.append(l)
# Axis decoration is the same for every series, so it is set once.
plt.xticks([2, 4, 8, 16, 32, 64])
plt.xlabel("# Horizontal parallel partitions")
plt.ylabel("Throughput\n(samples/second)")
leg = plt.figlegend(lines, labels, loc='upper center', ncol=2)
# Get the bounding box of the original legend.
bb = leg.get_bbox_to_anchor().transformed(plt.gca().transAxes.inverted())

# Change to location of the legend.
yOffset = 0.2
bb.y0 += yOffset
bb.y1 += yOffset
leg.set_bbox_to_anchor(bb, transform = plt.gca().transAxes)
plt.tight_layout()
plt.savefig("horizontal_parallelism.pdf", dpi=600, bbox_inches="tight")

In [None]:
# Pipeline-parallel scaling: for every batch size, simulate the MLP with the
# pipeline-parallel degree set to each device count (data and horizontal
# degree 1) using a fixed number of microbatches. Throughputs are appended
# to pp_results[batch_size] in the order of all_num_devices.
all_num_devices = [2, 4, 8, 16, 32, 64]
all_batch_sizes = [512, 1024, 2048, 4096]
num_microbatches = 8
input_dim = 8192
hidden_dim = input_dim
output_dim = hidden_dim
num_hidden_layers = 64
pp_results = defaultdict(list)
for batch_size in all_batch_sizes:
    topology = Topology()
    d0 = topology.add_device("gpu")
    # The untransformed model only depends on the batch size, so it is built
    # and type-checked once per batch size rather than once per device count.
    function = mlp(
        batch_size, input_dim, hidden_dim, output_dim, num_hidden_layers, d0
    )
    function = infer_types(function, function.inputs)
    for num_devices in all_num_devices:
        # Grow the topology so that it holds `num_devices` devices in
        # addition to d0; only the missing devices are added.
        num_existing_devices = len(topology.devices) - 1
        add_devices_to_topology(topology, num_devices - num_existing_devices)
        assert len(topology.devices) == num_devices + 1
        simulator = Simulator(CostModel(topology))
        transformed_function = parallel_transform_3d(
            function,
            1,
            1,
            num_devices,
            topology.devices,
            num_microbatches,
        )
        transformed_function = infer_types(
            transformed_function, transformed_function.inputs
        )
        transformed_function, typed_input_values = steady_state_transform(
            transformed_function
        )
        transformed_function = infer_types(transformed_function, typed_input_values)
        simulation = simulator.interpret(
            transformed_function,
            (v.type for v in transformed_function.inputs),
        )
        distributed_running_time = max(
            [simulation.timestamps[d] for d in simulation.timestamps]
        )
        pp_results[batch_size].append(batch_size / distributed_running_time)

In [None]:
# Plot pipeline-parallel throughput against the number of partitions, one
# line per batch size, and save it to pipeline_parallelism.pdf.
markers = ["o", "D", "v", "s", "<", "x"]
styles = ["-", "--", "-.", ":"]
# Colour scale sized from the number of plotted series.
c = np.arange(1, len(pp_results) + 3)
norm = mpl.colors.Normalize(vmin=c.min(), vmax=c.max())
cmap = mpl.cm.ScalarMappable(norm=norm, cmap=mpl.cm.Blues)
cmap.set_array([])
plt.figure(figsize=(5, 3))
lines = []
labels = []
for i, batch_size in enumerate(pp_results):
    labels.append(f"Batch size {batch_size}")
    l = plt.plot(
        all_num_devices,
        pp_results[batch_size],
        # Cycle through markers/styles so extra series cannot index past the lists.
        marker=markers[i % len(markers)],
        label=labels[-1],
        linestyle=styles[i % len(styles)],
        # Offset by 3 to skip the lightest shades of the colour map.
        c=cmap.to_rgba(i + 3)
    )[0]
    lines.append(l)
# Axis decoration is the same for every series, so it is set once.
plt.xticks([2, 4, 8, 16, 32, 64])
plt.xlabel("# Pipeline parallel partitions")
plt.ylabel("Throughput\n(samples/second)")
leg = plt.figlegend(lines, labels, loc='upper center', ncol=2)
# Get the bounding box of the original legend.
bb = leg.get_bbox_to_anchor().transformed(plt.gca().transAxes.inverted())

# Change to location of the legend.
yOffset = 0.2
bb.y0 += yOffset
bb.y1 += yOffset
leg.set_bbox_to_anchor(bb, transform = plt.gca().transAxes)
plt.tight_layout()
plt.savefig("pipeline_parallelism.pdf", dpi=600, bbox_inches="tight")