In [1]:
import argparse
from collections import defaultdict, OrderedDict
import logging
import numpy as np

import dist_ir
from dist_ir.importer import import_from_onnx, parse_tensor_from_file
from dist_ir.ir import FunctionMaker, cpprint, pformat, Device, Topology, Value
from dist_ir.executor import infer_types, SequentialExecutor, Simulator
from dist_ir.executor.cost_model import CostModel
from dist_ir.ir.type import Bool, Float, Int64, Tensor
from dist_ir.transforms import (
    parallel_transform_3d,
    hybrid_transform_unrolled,
    PipeDreamScheduler,
)

In [2]:
# Bandwidth value that add_devices_to_topology assigns to GPU-to-GPU links.
# NOTE(review): 200 * 8.0 looks like 200 GB/s converted to Gbit/s -- confirm.
DGX_BANDWIDTH_GBPS = 200 * 8.0
# Per-device-type speed table passed to Topology(...) in the sweep cell.
# NOTE(review): presumably FLOP/s for each device type -- confirm the units.
device_speeds = {"gpu": 1.0e13}

In [3]:
def mlp(batch_size, input_dim, hidden_dim, output_dim, num_hidden_layers, device):
    """Build the forward and backward graph of an MLP training step.

    The function has inputs ``x`` (the batch), ``z`` (the labels) and one
    weight per layer named ``wA``, ``wB``, ...; every input type is placed on
    ``device``. Each layer is a ``MatMul`` followed by a ``Relu``. The forward
    pass ends with ``Loss`` and ``LossGrad`` ops, and is followed by a
    ``ReluGrad`` / ``MatMulGrad`` pair per layer, emitted from the last layer
    back to the first.

    Args:
        batch_size: Number of rows of ``x`` and ``z``; also passed as the
            ``N`` attribute of the ``Loss`` and ``LossGrad`` ops.
        input_dim: Number of columns of ``x``.
        hidden_dim: Width of every layer except the last one.
        output_dim: Width of the last layer and number of columns of ``z``.
        num_hidden_layers: Total number of weight matrices, including the
            final projection to ``output_dim``. Must be at least 1.
        device: Device attached to every input tensor type.

    Returns:
        The result of ``FunctionMaker.finalize()`` for the assembled graph.

    Raises:
        ValueError: If ``num_hidden_layers`` is smaller than 1.
    """
    if num_hidden_layers < 1:
        raise ValueError(
            f"num_hidden_layers must be at least 1, got {num_hidden_layers}"
        )

    function = FunctionMaker(name="mlp")
    x = function.add_input_value(
        "x",
        Tensor(dtype=Float(), shape=(batch_size, input_dim), device=device),
    )
    z = function.add_input_value(
        "z",
        Tensor(dtype=Float(), shape=(batch_size, output_dim), device=device),
    )

    # Layer i maps layer_dims[i] -> layer_dims[i + 1]. Building the list up
    # front keeps the single-layer case correct: its only weight is
    # (input_dim, output_dim) and no loop variable is needed to name it.
    layer_dims = [input_dim] + [hidden_dim] * (num_hidden_layers - 1) + [output_dim]
    weights = [
        function.add_input_value(
            f"w{chr(ord('A')+i)}",
            Tensor(
                dtype=Float(),
                shape=(layer_dims[i], layer_dims[i + 1]),
                device=device,
            ),
        )
        for i in range(num_hidden_layers)
    ]

    # Forward pass: exactly two ops per layer, so layer i's MatMul ends up at
    # function.ops[2 * i] and its Relu at function.ops[2 * i + 1].
    a = x
    for i, weight in enumerate(weights):
        y = function.add_op("MatMul", inputs=[a, weight], output_names=[f"y{i}"])
        a = function.add_op("Relu", inputs=[y], output_names=[f"a{i}"])

    # The loss value is not consumed by any later op; only its gradient is.
    function.add_op(
        "Loss", inputs=[a, z], attributes={"N": batch_size}, output_names=["l"]
    )
    dl = function.add_op(
        "LossGrad",
        inputs=[a, z],
        attributes={"N": batch_size},
        output_names=["dl"],
    )

    # Backward pass, last layer first. The forward ops are looked up by
    # position to recover each layer's Relu input and MatMul input.
    dy = dl
    for i in reversed(range(len(weights))):
        da = function.add_op(
            "ReluGrad",
            inputs=[function.ops[2 * i + 1].inputs[0], dy],
            output_names=[f"da{i}"],
        )
        # The weight gradient becomes an output of the function; only the
        # input gradient is threaded on to the previous layer.
        dy, _ = function.add_op(
            "MatMulGrad",
            inputs=[function.ops[2 * i].inputs[0], weights[i], da],
            output_names=[f"dy{i}", f"dw{chr(ord('A')+i)}"],
        )
    return function.finalize()

In [4]:
def add_devices_to_topology(topology, num_devices):
    """Append ``num_devices`` "gpu" devices to ``topology`` and set link bandwidths.

    After the new devices are added, every device in ``topology.devices``
    except the first one gets an infinite-bandwidth link to the first device
    and a ``DGX_BANDWIDTH_GBPS`` link to itself and to each device that
    follows it in the list. Links are set for all devices, not only the
    newly added ones.

    Args:
        topology: Object providing ``add_device``, ``devices`` and
            ``set_bandwidth``.
        num_devices: How many "gpu" devices to add.

    Returns:
        The same ``topology`` object, for chaining.
    """
    for _ in range(num_devices):
        topology.add_device("gpu")

    all_devices = topology.devices
    device_count = len(all_devices)
    for src_idx in range(1, device_count):
        src = all_devices[src_idx]
        topology.set_bandwidth(all_devices[0], src, float("inf"))
        # NOTE(review): the range starts at src_idx, so src also gets a link
        # to itself -- confirm whether the self-link is intended.
        for dst_idx in range(src_idx, device_count):
            topology.set_bandwidth(src, all_devices[dst_idx], DGX_BANDWIDTH_GBPS)
    return topology

In [5]:
def get_all_degrees(n):
    """List every (d, h, p) triple of powers of two whose product equals ``n``.

    Each component ranges over the powers of two that do not exceed ``n``.
    Triples are returned in ascending order of ``d``, then ``h``, then ``p``.
    The list is empty when ``n`` is not a power of two or is smaller than 1.
    """
    powers_of_two = []
    candidate = 1
    while candidate <= n:
        powers_of_two.append(candidate)
        candidate *= 2

    return [
        (d, h, p)
        for d in powers_of_two
        for h in powers_of_two
        for p in powers_of_two
        if d * h * p == n
    ]

In [None]:
# Model dimensions shared by every configuration in the sweep.
input_dim = 512
hidden_dim = 512
output_dim = 1
# Sweep axes; the trailing comments keep the wider grids for later runs.
all_cluster_sizes = [64]  # [64, 128, 512, 1024]
all_num_hidden_layers = [8]  # [4, 8, 16, 32]
all_batch_sizes = [1024]  # [512, 1024, 2048, 4096, 8192]
microbatch_size = 256
# One tuple per simulated configuration:
# (num_hidden_layers, batch_size, cluster_size, dp, hp, pp, running time * 1e3).
results = []
for num_hidden_layers in all_num_hidden_layers:
    for batch_size in all_batch_sizes:
        num_microbatches = batch_size // microbatch_size

        # Single-device baseline: a new topology holding only d0.
        topology = Topology(device_speeds)
        d0 = topology.add_device("gpu")
        function = mlp(
            batch_size, input_dim, hidden_dim, output_dim, num_hidden_layers, d0
        )
        function = infer_types(function, function.inputs)
        simulator = Simulator(CostModel(topology))
        simulation = simulator.interpret(function, (v.type for v in function.inputs))
        sequential_running_time = max(
            simulation.timestamps[device] for device in simulation.timestamps
        )
        print(f"Sequential running time: {sequential_running_time * 1e3} ms")

        for i, cluster_size in enumerate(all_cluster_sizes):
            # The topology is reused across cluster sizes, so after the first
            # size only the difference to the previous size is added.
            num_new_devices = (
                cluster_size
                if i == 0
                else all_cluster_sizes[i] - all_cluster_sizes[i - 1]
            )
            add_devices_to_topology(topology, num_new_devices)

            all_degrees = get_all_degrees(cluster_size)
            for (dp_degree, hp_degree, pp_degree) in all_degrees:
                # Only pipeline degrees that divide the layer count are simulated.
                if num_hidden_layers % pp_degree == 0:
                    transformed_function = parallel_transform_3d(
                        function,
                        dp_degree,
                        hp_degree,
                        pp_degree,
                        topology.devices,
                        num_microbatches,
                    )
                    transformed_function = infer_types(
                        transformed_function, transformed_function.inputs
                    )
                    simulation = simulator.interpret(
                        transformed_function,
                        (v.type for v in transformed_function.inputs),
                    )
                    distributed_running_time = max(
                        simulation.timestamps[device]
                        for device in simulation.timestamps
                    )
                    results.append(
                        (
                            num_hidden_layers,
                            batch_size,
                            cluster_size,
                            dp_degree,
                            hp_degree,
                            pp_degree,
                            distributed_running_time * 1e3,
                        )
                    )

Sequential running time: 1.127743488 ms


In [None]:
# Display the collected tuples: (num_hidden_layers, batch_size, cluster_size,
# dp_degree, hp_degree, pp_degree, distributed running time * 1e3).
results