In [1]:
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
import torch
import torch.autograd.profiler as profiler
from torch.nn.parallel import DistributedDataParallel as DDP

In [2]:
# Make the project-local dist_ir package importable from this notebook.
# Note that "__file__" is a string literal here, not the module attribute, so
# realpath() resolves it against the current working directory; the two
# dirname() calls then yield the parent of the working directory, which is
# what gets appended to sys.path (presumably the repository root that
# contains dist_ir -- TODO confirm).
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath("__file__"))))
from dist_ir.ir import FunctionMaker, Topology
from dist_ir.ir.type import Tensor, Float
from dist_ir.executor import infer_types, Simulator
from dist_ir.executor.cost_model import CostModel

In [3]:
# Hardware parameters passed to the dist_ir Topology inside simulate().
DEVICE_THROUGHPUT = 1.38e13 # FLOPS per device; 6.7e12 is an alternative value left by the author
DRAM_BANDWIDTH = 7e11 # units are not stated in this file -- presumably bytes/s; TODO confirm
PCIE_BANDWIDTH = 128 # Gbps; used as the worker-to-worker link bandwidth in simulate()

In [4]:
def simulate(world_size, batch_size, hidden_dim):
    """Simulate a data-parallel MatMul with the dist_ir simulator.

    A topology with one source device (d0) and ``world_size`` worker devices
    is built. The inputs ``x`` of shape (batch_size, hidden_dim) and ``w`` of
    shape (hidden_dim, hidden_dim) are declared on d0. For ``world_size == 1``
    a single MatMul of ``x`` and ``w`` is added. Otherwise ``x`` is scattered
    along dim 0 and ``w`` is broadcast to the worker devices, and one MatMul
    is added per worker shard.

    Args:
        world_size: Number of worker devices; must be at least 1.
        batch_size: Number of rows of the (global) input ``x``.
        hidden_dim: Number of columns of ``x`` and side length of ``w``.

    Returns:
        The largest value found in ``simulation.timestamps`` (presumably the
        finish time of the slowest device -- confirm units against the
        dist_ir cost model).

    Raises:
        ValueError: If ``world_size`` is smaller than 1.
    """
    if world_size < 1:
        raise ValueError(f"world_size must be >= 1, got {world_size}")

    topology = Topology()
    d0 = topology.add_device(
        "gpu", throughput=DEVICE_THROUGHPUT, dram_bandwidth=DRAM_BANDWIDTH
    )
    for i in range(world_size):
        di = topology.add_device(
            "gpu", throughput=DEVICE_THROUGHPUT, dram_bandwidth=DRAM_BANDWIDTH
        )
        # The link between the source device and every worker is free.
        topology.set_bandwidth(d0, di, float("inf"))
        # Connect the new worker to every previously added worker
        # (topology.devices[0] is d0, so workers start at index 1).
        for j in range(1, i + 1):
            dj = topology.devices[j]
            topology.set_bandwidth(di, dj, PCIE_BANDWIDTH)
    workers = topology.devices[1:]

    function = FunctionMaker()
    x = function.add_input_value(
        "x", Tensor(dtype=Float(), shape=(batch_size, hidden_dim), device=d0)
    )
    w = function.add_input_value(
        "w", Tensor(dtype=Float(), shape=(hidden_dim, hidden_dim), device=d0)
    )
    if world_size == 1:
        function.add_op("MatMul", inputs=[x, w], output_names=["y"])
    else:
        # One shard of x, one replica of w and one MatMul per worker. The
        # names are 1-based so world_size == 2 still yields x1/x2, w1/w2 and
        # y1/y2; previously only two shards were consumed even when more
        # workers were requested.
        shard_ids = range(1, world_size + 1)
        x_shards = function.add_op(
            "MPIScatter",
            inputs=[x],
            attributes={"dim": 0, "devices": workers},
            output_names=[f"x{k}" for k in shard_ids],
        )
        w_replicas = function.add_op(
            "MPIBroadcast",
            inputs=[w],
            attributes={"devices": workers},
            output_names=[f"w{k}" for k in shard_ids],
        )
        for k, x_k, w_k in zip(shard_ids, x_shards, w_replicas):
            function.add_op("MatMul", inputs=[x_k, w_k], output_names=[f"y{k}"])

    function = function.finalize()
    function = infer_types(function, function.inputs)
    simulator = Simulator(CostModel(topology))
    simulation = simulator.interpret(
        function,
        (v.type for v in function.inputs),
    )
    return max(simulation.timestamps[d] for d in simulation.timestamps)

In [5]:
def setup(
    local_rank, world_size, backend="nccl", master_addr="localhost", master_port="12355"
):
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = master_port
    torch.distributed.init_process_group(
        backend, world_size=world_size, rank=local_rank
    )

In [6]:
def cleanup():
    """Destroy the torch.distributed process group if one is active.

    Calling this when no process group was initialized is a no-op, so it is
    safe to use in ``finally`` blocks and in single-process runs.
    """
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()

In [7]:
def run(config):
    """Profile a bias-free Linear layer on one GPU and return its median runtime.

    Args:
        config: Tuple of ``(local_rank, world_size, batch_size, hidden_dim,
            num_warmup_iterations, num_profiling_iterations)``. It is a single
            tuple so the function can be mapped over a process pool.

    Returns:
        Median ``cuda_time`` of the ``aten::linear`` profiler events recorded
        after the warmup iterations, divided by 1e6 (the PyTorch profiler
        reports ``cuda_time`` in microseconds, so this should be seconds).

    Raises:
        RuntimeError: If no ``aten::linear`` events remain after discarding
            the warmup iterations.
    """
    (local_rank, world_size, batch_size, hidden_dim,
     num_warmup_iterations, num_profiling_iterations) = config
    distributed = world_size > 1

    # NOTE(review): the input is not split by rank here -- every rank runs a
    # full (batch_size, hidden_dim) input. Callers that pass a global batch
    # size should confirm this is the intended per-rank workload.
    x = torch.randn((batch_size, hidden_dim)).cuda(local_rank)
    model = torch.nn.Linear(hidden_dim, hidden_dim, bias=False).cuda(local_rank)

    if distributed:
        setup(local_rank, world_size)
    try:
        if distributed:
            model = DDP(model, device_ids=[local_rank], output_device=local_rank)
        # Warmup and measured iterations are profiled together; the warmup
        # events are sliced off below.
        with profiler.profile(use_cuda=True) as prof:
            with profiler.record_function("matmul_benchmark"):
                for _ in range(num_warmup_iterations + num_profiling_iterations):
                    model(x)
        if distributed:
            torch.distributed.barrier()
    finally:
        # Release the process group created by setup() so it does not leak.
        if distributed:
            cleanup()

    runtimes = []
    for event in prof.function_events:
        if event.name == "aten::linear":
            assert event.cuda_time > 0
            runtimes.append(event.cuda_time)

    measured = runtimes[num_warmup_iterations:]
    if not measured:
        # np.median([]) would silently return nan and poison downstream means.
        raise RuntimeError(
            "no 'aten::linear' profiler events were recorded after warmup"
        )
    return np.median(measured) / 1e6

In [8]:
def distributed_driver(batch_size, hidden_dim, world_size):
    """Run the benchmark on ``world_size`` worker processes and average the result.

    One configuration per rank is built (10 warmup and 100 profiled
    iterations each) and mapped over a process pool of ``world_size`` workers
    that execute ``run``. The mean of the per-rank results is printed and
    returned.

    Args:
        batch_size: Batch size forwarded unchanged to every rank.
        hidden_dim: Hidden dimension forwarded unchanged to every rank.
        world_size: Number of ranks, and therefore of pool processes.

    Returns:
        The mean of the values returned by ``run`` across all ranks.
    """
    num_warmup, num_profiled = 10, 100
    configs = []
    for rank in range(world_size):
        configs.append(
            (rank, world_size, batch_size, hidden_dim, num_warmup, num_profiled)
        )

    with torch.multiprocessing.Pool(world_size) as pool:
        per_rank_runtimes = pool.map(run, configs)

    mean_runtime = np.mean(per_rank_runtimes)
    print(
        f"world_size={world_size}, "
        f"batch_size={batch_size}, "
        f"hidden_dim={hidden_dim}, "
        f"runtime={mean_runtime}"
    )
    return mean_runtime

In [9]:
def distributed_benchmark():
    """Compare measured and simulated throughput across a parameter sweep.

    For every (batch_size, hidden_dim) pair and every world size in
    [1, 2, 4], the global batch size is ``batch_size * world_size``. The
    PyTorch runtime comes from ``distributed_driver`` and the simulated
    runtime from ``simulate``; each is converted to a throughput of
    ``global_batch / runtime / 1000``.

    Returns:
        A ``(pytorch_throughputs, simulated_throughputs)`` pair of
        defaultdicts keyed by ``(batch_size, hidden_dim)``. Each value is a
        list with one throughput per world size, in sweep order.
    """
    batch_sizes = [256, 512, 1024]
    hidden_dims = [2048, 4096]
    world_sizes = [1, 2, 4]

    pytorch_throughputs = defaultdict(list)
    simulated_throughputs = defaultdict(list)

    sweep = [(b, h) for b in batch_sizes for h in hidden_dims]
    for batch_size, hidden_dim in sweep:
        key = (batch_size, hidden_dim)
        for world_size in world_sizes:
            # The global batch grows with the number of ranks.
            global_batch = batch_size * world_size
            pytorch_runtime = distributed_driver(global_batch, hidden_dim, world_size)
            simulated_runtime = simulate(world_size, global_batch, hidden_dim)
            pytorch_throughputs[key].append(global_batch / pytorch_runtime / 1000)
            simulated_throughputs[key].append(global_batch / simulated_runtime / 1000)

    return pytorch_throughputs, simulated_throughputs

In [10]:
# Run the full sweep; both results are dicts keyed by (batch_size, hidden_dim).
pytorch_throughputs, simulated_throughputs = distributed_benchmark()

world_size=1, batch_size=256, hidden_dim=2048, runtime=0.000249857421875
world_size=2, batch_size=512, hidden_dim=2048, runtime=0.00024473583984375
world_size=4, batch_size=1024, hidden_dim=2048, runtime=0.00025881591796875
world_size=1, batch_size=256, hidden_dim=4096, runtime=0.000728578125
world_size=2, batch_size=512, hidden_dim=4096, runtime=0.0007841276855468749
world_size=4, batch_size=1024, hidden_dim=4096, runtime=0.0006968329467773438
world_size=1, batch_size=512, hidden_dim=2048, runtime=0.00043827294921875
world_size=2, batch_size=1024, hidden_dim=2048, runtime=0.00045209765624999997
world_size=4, batch_size=2048, hidden_dim=2048, runtime=0.00038399554443359373
world_size=1, batch_size=512, hidden_dim=4096, runtime=0.00135936328125
world_size=2, batch_size=1024, hidden_dim=4096, runtime=0.00136780908203125
world_size=4, batch_size=2048, hidden_dim=4096, runtime=0.0014635521240234375
world_size=1, batch_size=1024, hidden_dim=2048, runtime=0.0007669765625
world_size=2, batch_

In [11]:
# One row per (batch_size, hidden_dim): measured throughputs first, then the
# simulated ones, each listing the first three recorded world sizes.
for key in pytorch_throughputs:
    measured = pytorch_throughputs[key]
    measured_cols = f"[{measured[0]:.2f}, {measured[1]:.2f}, {measured[2]:.2f}]"
    simulated = simulated_throughputs[key]
    simulated_cols = f"[{simulated[0]:.2f}, {simulated[1]:.2f}, {simulated[2]:.2f}]"
    print(f"{str(key):10}:\t{measured_cols}\t{simulated_cols}")

(256, 2048):	[1024.58, 2092.05, 3956.48]	[1402.14, 2804.28, 5608.56]
(256, 4096):	[351.37, 652.95, 1469.51]	[353.43, 706.87, 1413.74]
(512, 2048):	[1168.22, 2265.00, 5333.40]	[1500.64, 3001.27, 6002.54]
(512, 4096):	[376.65, 748.64, 1399.34]	[378.48, 756.96, 1513.93]
(1024, 2048):	[1335.11, 2539.68, 5761.62]	[1555.26, 3110.52, 6221.05]
(1024, 4096):	[376.93, 721.89, 1553.92]	[392.39, 784.77, 1569.55]


In [12]:
def single_device_benchmark():
    all_batch_sizes = [64, 128, 256, 512, 1024, 2048, 4096]
    all_hidden_dims = [64, 128, 256, 512, 1024, 2048, 4096]
    simulation_results = defaultdict(list)
    pytorch_results = defaultdict(list)
    for batch_size in all_batch_sizes:
        for hidden_dim in all_hidden_dims:
            simulated_runtime = simulate(batch_size, hidden_dim)
            pytorch_runtime = run(batch_size, hidden_dim)
            simulation_results[batch_size].append(simulated_runtime)
            pytorch_results[batch_size].append(pytorch_runtime)
            print(f"{batch_size},{hidden_dim},{simulated_runtime:.2f},{pytorch_runtime:.2f}")