In [1]:
import argparse
from collections import defaultdict, OrderedDict
import logging
import numpy as np

import dist_ir
from dist_ir.importer import import_from_onnx, parse_tensor_from_file
from dist_ir.ir import FunctionMaker, cpprint, pformat, Device, Topology, Value
from dist_ir.executor import infer_types, SequentialExecutor, Simulator
from dist_ir.executor.cost_model import CostModel
from dist_ir.ir.type import Bool, Float, Int64, Tensor
from dist_ir.transforms import (
    parallel_transform_3d,
    hybrid_transform_unrolled,
    PipeDreamScheduler,
)

In [2]:
# Per-device-type speed table handed to Topology(...). The "gpu" key is the same
# device-type string later passed to topology.add_device.
# Presumably the value is peak throughput in FLOP/s — TODO confirm against CostModel.
device_speeds = {
    "gpu": 1.0e13,
}

In [3]:
# Single-device baseline: simulate one (1024 x 512) @ (512 x 512) MatMul on one GPU
# and report the simulated running time.
topology = Topology(device_speeds)
d0 = topology.add_device("gpu")

# Keep the builder under its own name so that `function` only ever refers to the
# finalized function (the original rebound `function` from builder to result).
function_maker = FunctionMaker()
x = function_maker.add_input_value("x", Tensor(Float(), shape=(1024, 512), device=d0))
w = function_maker.add_input_value("w", Tensor(Float(), shape=(512, 512), device=d0))
y = function_maker.add_op("MatMul", inputs=[x, w], output_names=["y"])
function = function_maker.finalize()
function = infer_types(function, function.inputs)

# The simulator is driven by the input *types* rather than concrete tensors.
# Materialize them as a tuple so the argument is a reusable sequence instead of
# a one-shot generator.
simulator = Simulator(CostModel(topology))
input_types = tuple(v.type for v in function.inputs)
simulation = simulator.interpret(function, input_types)

# `timestamps` is keyed by device; the overall running time is the latest one.
# It is scaled by 1e3 for the millisecond printout, so the raw value is
# presumably in seconds.
running_time = max(simulation.timestamps.values())
print(f"Running time: {running_time * 1e3} ms")

Running time: 0.0536870912 ms


In [4]:
# Point-to-point transfer: simulate sending one (1024 x 512) float tensor from
# device d0 to device d1 and report the simulated running time.
topology = Topology(device_speeds)
d0 = topology.add_device("gpu")
d1 = topology.add_device("gpu")
# Bandwidth of the d0 <-> d1 link.
# NOTE(review): presumably 200 GB/s expressed in Gbit/s (hence the * 8) — confirm units.
topology.set_bandwidth(d0, d1, 200.0 * 8)

# Keep the builder under its own name so that `function` only ever refers to the
# finalized function (the original rebound `function` from builder to result).
function_maker = FunctionMaker()
y = function_maker.add_input_value("y", Tensor(Float(), shape=(1024, 512), device=d0))
y1 = function_maker.add_op(
    "Send", inputs=[y], attributes={"device": d1}, output_names=["y1"]
)
function = function_maker.finalize()
function = infer_types(function, function.inputs)

# The simulator is driven by the input *types* rather than concrete tensors.
# Materialize them as a tuple so the argument is a reusable sequence instead of
# a one-shot generator.
simulator = Simulator(CostModel(topology))
input_types = tuple(v.type for v in function.inputs)
simulation = simulator.interpret(function, input_types)

# `timestamps` is keyed by device; the overall running time is the latest one.
# It is scaled by 1e3 for the millisecond printout, so the raw value is
# presumably in seconds.
running_time = max(simulation.timestamps.values())
print(f"Running time: {running_time * 1e3} ms")

Running time: 0.00065536 ms


In [5]:
# Collective communication: simulate an MPIAllreduce over two (1024 x 512) float
# tensors, one resident on each of two GPUs, and report the simulated running time.
#
# NOTE(review): the output saved with this cell was captured while a
# pdb.set_trace() breakpoint was active inside
# dist_ir/executor/cost_model.py::_mpi_allreduce_cost_fn. Remove that breakpoint
# before re-running the notebook unattended, otherwise this cell blocks on ipdb.
topology = Topology(device_speeds)
d0 = topology.add_device("gpu")
d1 = topology.add_device("gpu")
# Bandwidth of the d0 <-> d1 link.
# NOTE(review): presumably 200 GB/s expressed in Gbit/s (hence the * 8) — confirm units.
topology.set_bandwidth(d0, d1, 200.0 * 8)

# Keep the builder under its own name so that `function` only ever refers to the
# finalized function (the original rebound `function` from builder to result).
function_maker = FunctionMaker()
y0 = function_maker.add_input_value("y0", Tensor(Float(), shape=(1024, 512), device=d0))
# Separate handle for the d1 input so it is not lost when `y1` is bound to the
# op output below.
y1_input = function_maker.add_input_value(
    "y1", Tensor(Float(), shape=(1024, 512), device=d1)
)
# NOTE(review): the op output is named "y1", the same as the second input value,
# and a single output name is given for two inputs — confirm this is intended.
y1 = function_maker.add_op("MPIAllreduce", inputs=[y0, y1_input], output_names=["y1"])
function = function_maker.finalize()
function = infer_types(function, function.inputs)

# The simulator is driven by the input *types* rather than concrete tensors.
# Materialize them as a tuple so the argument is a reusable sequence instead of
# a one-shot generator.
simulator = Simulator(CostModel(topology))
input_types = tuple(v.type for v in function.inputs)
simulation = simulator.interpret(function, input_types)

# `timestamps` is keyed by device; the overall running time is the latest one.
# It is scaled by 1e3 for the millisecond printout, so the raw value is
# presumably in seconds.
running_time = max(simulation.timestamps.values())
print(f"Running time: {running_time * 1e3} ms")

> [0;32m/Users/keshavsanthanam/workspace/dist-ir/dist_ir/executor/cost_model.py[0m(145)[0;36m_mpi_allreduce_cost_fn[0;34m()[0m
[0;32m    143 [0;31m[0;34m[0m[0m
[0m[0;32m    144 [0;31m        [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 145 [0;31m        [0minput_size[0m [0;34m=[0m [0mxs[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m.[0m[0msize[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    146 [0;31m        [0mdevices[0m [0;34m=[0m [0;34m[[0m[0mx[0m[0;34m.[0m[0mdevice[0m [0;32mfor[0m [0mx[0m [0;32min[0m [0mxs[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    147 [0;31m        [0mnum_devices[0m [0;34m=[0m [0mlen[0m[0;34m([0m[0mdevices[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> n
> [0;32m/Users/keshavsanthanam/workspace/dist-ir/dist_ir/executor/cost_model.py[0m(146)[0;36m_mpi_allreduce_cost_fn[0;34m()[0m
[0;32m    144 [0;31m        

ipdb> n
> [0;32m/Users/keshavsanthanam/workspace/dist-ir/dist_ir/executor/cost_model.py[0m(154)[0;36m_mpi_allreduce_cost_fn[0;34m()[0m
[0;32m    152 [0;31m            [0;32mfor[0m [0mj[0m [0;32min[0m [0mrange[0m[0;34m([0m[0mi[0m [0;34m+[0m [0;36m1[0m[0;34m,[0m [0mlen[0m[0;34m([0m[0mdevices[0m[0;34m)[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    153 [0;31m                all_bandwidths.append(
[0m[0;32m--> 154 [0;31m                    [0mself[0m[0;34m.[0m[0m_topology[0m[0;34m.[0m[0mget_bandwidth[0m[0;34m([0m[0mdevices[0m[0;34m[[0m[0mi[0m[0;34m][0m[0;34m,[0m [0mdevices[0m[0;34m[[0m[0mj[0m[0;34m][0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    155 [0;31m                )
[0m[0;32m    156 [0;31m        [0maverage_bandwidth[0m [0;34m=[0m [0mnp[0m[0;34m.[0m[0mmean[0m[0;34m([0m[0mall_bandwidths[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> n
> [0;32m/Users/keshavsanthanam

ipdb> cost
1.6384000000000002e-07
ipdb> n
--Return--
{Device(device...ariable=False): 1.6384000000000002e-07, Device(device...ariable=False): 1.6384000000000002e-07}
> [0;32m/Users/keshavsanthanam/workspace/dist-ir/dist_ir/executor/cost_model.py[0m(159)[0;36m_mpi_allreduce_cost_fn[0;34m()[0m
[0;32m    157 [0;31m        [0mcost[0m [0;34m=[0m [0mper_device_data_gb[0m [0;34m/[0m [0maverage_bandwidth[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    158 [0;31m[0;34m[0m[0m
[0m[0;32m--> 159 [0;31m        [0;32mreturn[0m [0;34m{[0m[0mdevice[0m[0;34m:[0m [0mcost[0m [0;32mfor[0m [0mdevice[0m [0;32min[0m [0mdevices[0m[0;34m}[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    160 [0;31m[0;34m[0m[0m
[0m[0;32m    161 [0;31m    [0;32mdef[0m [0m_mpi_broadcast_cost_fn[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mop[0m[0;34m,[0m [0mx[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> n
> [0;32m/Users/keshavsanthanam/workspace/dist-ir/dist_ir/ex

ipdb> c
Running time: 0.00016384 ms
