update trace replay mode for post analysis (#20)
Summary:
Pull Request resolved: #20

- Output new traces with sequence number and latency for post analysis, one file per rank since imbalanced workloads can skew timing across ranks (see the post-analysis sketch after this list)
- Remove extra barrier in non-blocking mode
- Add option `--allow-list`/`--allow-ops` to specify which collectives to replay, ignoring the rest of the trace (default is 'all', which replays all supported ops)
- Cleaner summary output at the end (most messages moved to logging.info)
- Correctly complete GPU collectives in non-blocking mode
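
For post analysis, each rank now writes its replayed trace (with `seqnum` and `latency_us` per collective) to `<output-path>/replayedCommsPerf.rank<N>.json`. Below is a minimal aggregation sketch assuming the default output path from this diff; the script itself is hypothetical and not part of the commit:

```python
# Hypothetical post-analysis sketch: aggregate the per-rank replayed traces.
# Assumes files named replayedCommsPerf.rank<N>.json (as written by
# writeCommDetails in this diff), each holding a list of comms with
# "comms", "seqnum", and "latency_us" fields.
import glob
import json
from collections import defaultdict

latency_by_op = defaultdict(list)
for path in glob.glob("/tmp/replayedTrace/replayedCommsPerf.rank*.json"):
    with open(path) as f:
        for comm in json.load(f):
            latency_by_op[comm["comms"]].append(comm["latency_us"])

for op, lats in sorted(latency_by_op.items()):
    print(f"{op}: n={len(lats)} avg={sum(lats) / len(lats):.2f} us max={max(lats):.2f} us")
```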

Reviewed By: srinivas212

Differential Revision: D26730965

fbshipit-source-id: ff07b643039e7140fbafebc270231e99c11c40a2
kingchc authored and facebook-github-bot committed Mar 16, 2021
1 parent 7d475d0 commit 4a9e7ee
Showing 2 changed files with 97 additions and 51 deletions.
134 changes: 84 additions & 50 deletions train/comms/pt/commsTraceReplay.py
@@ -22,15 +22,15 @@
logger = logging.getLogger(__name__)


def writeCommDetails(commsTracePerf, folder="./"):
def writeCommDetails(commsTracePerf, rank, folder="./"):
try:
import subprocess
subprocess.check_output(["mkdir", "-p", str(folder)], universal_newlines=True)
except Exception as err:
print("\t Error: %s while creating directory: %s " % (err, folder))
pass
comms_file = folder + "/replayedCommsPerf.json"
logger.info(f"Writing comms details to {comms_file}")
comms_file = folder + f"/replayedCommsPerf.rank{rank}.json"
logger.info(f"[Rank {rank:3}] Writing comms details to {comms_file}")
with open(comms_file, "w") as write_file:
json.dump(commsTracePerf, write_file, indent=2)

@@ -46,6 +46,8 @@ def __init__(self):
self.num_msg = 0
self.is_blocking = True
self.do_warm_up = True
self.allowList = ""
self.out_path = "/tmp/replayedTrace"

self.collInMsgSizes: Dict[str, List] = {}
self.collInUniMsgSizes: Dict[str, Set] = {}
@@ -56,6 +58,7 @@ def __init__(self):
self.collTraceStat: Dict[str, List] = {}

self.comms_blocks: Dict[str, List] = {}
self.traceWithPerf = []
self.blockStack = []
self.totalCommsLatency = 0.0

@@ -108,6 +111,18 @@ def readArgs(self, parser):
default=False,
help="Toggle to disable performing extra replaying for warm-up",
)
parser.add_argument(
"--allow-ops", "--allow-list",
type=str,
default="all",
help="List of desired collectives (separate by comma) to be replayed, e.g., `--allow-ops all_reduce,all_to_allv,wait`, typo or not supported collectives will be ignored.",
)
parser.add_argument(
"--output-path",
type=str,
default=self.out_path,
help="Output path to write the replayed trace for post performance analysis",
)
return parser.parse_args()

def checkArgs(self, args):
@@ -137,11 +152,11 @@ def reportBenchTime(self, commsParams):
lat_list = [comm["latency_us"] for comm in blockComms]
Lats = np.array(lat_list)

print(
logger.info(
f"+ {len(blockComms)} comms in block {curBlock}: {Lats.sum():.2f} us in total"
)

print("\n{} Message size Statistcs {}".format("=" * 20, "=" * 20))
logger.info("\n{} Message size Statistcs {}".format("=" * 20, "=" * 20))

for (name, collMsgs) in self.collInMsgSizes.items():
# input tensor
Expand Down Expand Up @@ -339,6 +354,33 @@ def commStack(self, blockStack, blockname, curComm):
self.commStack(blockStack, nextBlock, nextComm)
pass

def warmUpBench(self, commsParams):
for cnt, curComm in enumerate(self.comms_trace[:self.max_msg_cnt]):
if curComm["comms"] not in self.allowList:
continue
if self.backendFuncs.get_global_rank() == 0:
logger.debug(f"[Rank {self.collectiveArgs.global_rank:3}] Replaying \n{str(curComm)}\n")
print(
f"[Warm-up][{cnt} / {self.max_msg_cnt}] Replaying {curComm['comms']:>10}...", end="\r"
)

# read fields and prepare the tensors
self.prepComms(curComm)

if curComm["comms"] in self.backendFuncs.collectiveFunc.keys():
self.collectiveArgs.waitObj.append(
self.backendFuncs.collectiveFunc[curComm["comms"]](
self.collectiveArgs, retFlag=self.collectiveArgs.asyncOp
)
)
elif curComm["comms"] == "wait":
self.backendFuncs.complete_single_op(self.collectiveArgs)
else:
# not supported collective, skip
pass

self.backendFuncs.complete_accel_ops(self.collectiveArgs)

def benchTime(self, commsParams):
"""
The json format is expected to be either
Expand Down Expand Up @@ -368,33 +410,9 @@ def benchTime(self, commsParams):
- this format is subject to be changed/defined later
- the unit of all size fields is # of elements (not bytes)
"""
# FIXME: ideally, we need to know the actual element size of the datatype indicated in the trace
elem_size = 4

# warm-up
if self.do_warm_up:
for cnt, curComm in enumerate(self.comms_trace[:self.max_msg_cnt]):

if self.backendFuncs.get_global_rank() == 0:
print(
f"[Warm-up][{cnt} / {self.max_msg_cnt}] Replaying {curComm['comms']:>10}...", end="\r"
)
# read fields and prepare the tensors
self.prepComms(curComm)

if curComm["comms"] in self.backendFuncs.collectiveFunc.keys():
self.collectiveArgs.waitObj.append(
self.backendFuncs.collectiveFunc[curComm["comms"]](
self.collectiveArgs, retFlag=self.collectiveArgs.asyncOp
)
)
elif curComm["comms"] == "wait":
self.backendFuncs.complete_single_op(self.collectiveArgs)
else:
# not supported collective, skip
pass

self.backendFuncs.complete_accel_ops(self.collectiveArgs)
self.warmUpBench(commsParams)

# sync everything before starting real runs
self.collectiveArgs.waitObj.append(
@@ -404,29 +422,25 @@ def benchTime(self, commsParams):

if self.backendFuncs.get_global_rank() == 0:
print(
f"\n+ {self.max_msg_cnt} messages to be replayed..."
f"\n+ {self.max_msg_cnt} messages in the trace...replaying (if present) {(self.allowList)}"
)
for coll, sizes in self.collInMsgSizes.items():
print(f"\t{coll}: {len(sizes)}")
logger.info(f"\t{coll}: {len(sizes)}")

# second pass to perform collectives
for cnt, curComm in enumerate(self.comms_trace[:self.max_msg_cnt]):
if curComm["comms"] not in self.allowList:
continue
collName = curComm["comms"]
curBlocks = curComm["marker_stack"]
curBlocks = curComm["marker_stack"] if "marker_stack" in curComm else []
curBlockStack = ' '.join(curBlocks) if len(curBlocks) > 0 else "Unnamed/Unknown"

debug_msg = (
f"[Rank {self.collectiveArgs.global_rank:3}] Replaying {collName} in block {curBlocks}, in_size {curComm['in_msg_size']}, out_size {curComm['out_msg_size']}, dtype {curComm['dtype']}"
if ("in_msg_size" in curComm)
else f"[Rank {self.collectiveArgs.global_rank:3}] Got {collName} in block {curBlocks}"
)
logger.debug(debug_msg)

if self.backendFuncs.get_global_rank() == 0:
logger.debug(f"[Rank {self.collectiveArgs.global_rank:3}] Replaying \n{str(curComm)}\n")
print(
f"[{cnt} / {self.max_msg_cnt}] Replaying {collName} in block [{curBlockStack}]...",
end="",
f"[{cnt} / {self.max_msg_cnt}]", end="\r"
)

# read fields and prepare the tensors
self.prepComms(curComm)

@@ -446,7 +460,11 @@ def benchTime(self, commsParams):
pass

if self.is_blocking:
self.collectiveArgs.waitObj.append(
self.backendFuncs.barrier(self.collectiveArgs, retFlag=self.collectiveArgs.asyncOp)
)
self.backendFuncs.complete_accel_ops(self.collectiveArgs)

end = time.monotonic()
latency = (end - begin) * 1e6 # make it microsecond

@@ -457,14 +475,19 @@ def benchTime(self, commsParams):
else (0, 0, latency)
)

curComm["latency(us)"] = latency
curComm["seqnum"] = cnt
curComm["latency_us"] = latency
self.totalCommsLatency += latency
# Keep a copy of trace with performance (latency) and seqnum
self.traceWithPerf.append(curComm)

# categorized by the marker
for curBlock in curComm["marker_stack"]:
elem_size = self.collectiveArgs.ipTensor.element_size()
self.comms_blocks[curBlock].append(
{
"comms": collName,
"seqnum": cnt,
"blocked": "Y" if (self.is_blocking) else "N",
"in_msg_size_bytes": curComm["in_msg_size"] * elem_size if "in_msg_size" in curComm else 0,
"out_msg_size_bytes": curComm["out_msg_size"] * elem_size if "out_msg_size" in curComm else 0,
Expand All @@ -473,12 +496,14 @@ def benchTime(self, commsParams):
)

if self.backendFuncs.get_global_rank() == 0:
print(f"{latency:.2f} us")

if not self.is_blocking:
self.backendFuncs.barrier(self.collectiveArgs)
logger.info(
f"[{cnt} / {self.max_msg_cnt}] Replayed {collName} in block [{curBlockStack}]... {latency:.2f} us"
)

# make sure all ops are completed
self.collectiveArgs.waitObj.append(
self.backendFuncs.barrier(self.collectiveArgs, retFlag=True)
)
self.backendFuncs.complete_accel_ops(self.collectiveArgs)
self.backendFuncs.clear_memory()

@@ -502,11 +527,12 @@ def runBench(self, comms_world_info, commsParams):
# rank 0 reports statistics
if comms_world_info.global_rank == 0:
self.reportBenchTime(commsParams)
# dump trace sorted with block and with latency if not dry run
writeCommDetails(self.comms_blocks)
# TODO: collect perf. from all ranks to rank 0 and detect any imbalanced perf?
writeCommDetails(self.comms_blocks, rank=comms_world_info.global_rank)

if not self.is_dry_run:
# dump trace sorted with block and with latency if not dry run
writeCommDetails(self.traceWithPerf, folder=self.out_path, rank=comms_world_info.global_rank)
# TODO: collect perf. from all ranks to rank 0 and detect any imbalanced perf?
self.backendFuncs.barrier(self.collectiveArgs)
self.backendFuncs.complete_accel_ops(self.collectiveArgs)

@@ -558,12 +584,20 @@ def setBench(self, comms_world_info, commsParams):
self.collectiveArgs.opTensor = None
self.collectiveArgs.quant_threshold = commsParams.quant_threshold

# set of collectives to be replayed
if (self.allowList in ("all", "default", "*")):
self.allowList = self.backendFuncs.collectiveFunc.keys()
else:
self.allowList = self.allowList.split(',')

def initBench(self, comms_world_info, commsParams, args):
self.is_dry_run = args.dry_run
self.shrink = args.auto_shrink
self.max_msg_cnt = args.max_msg_cnt
self.is_blocking = args.z
self.do_warm_up = not args.no_warm_up
self.allowList = args.allow_ops
self.out_path = args.output_path

if commsParams.bitwidth < 32:
logger.info(f"communication bitwidth set to {commsParams.bitwidth}")
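
To summarize the `--allow-ops` handling added in `setBench()` and the replay loops above, here is a minimal standalone sketch; the `supported` set is illustrative only, while the real set comes from `backendFuncs.collectiveFunc.keys()`:

```python
# Illustrative sketch of the --allow-ops filtering added above.
# "all"/"default"/"*" selects every supported collective; otherwise the value is
# split on commas, and unknown names simply never match, so they are skipped.
supported = {"all_reduce", "all_to_all", "all_to_allv", "all_gather", "broadcast", "reduce", "barrier", "wait"}

def build_allow_list(allow_ops: str) -> set:
    if allow_ops in ("all", "default", "*"):
        return set(supported)
    return set(allow_ops.split(","))

allow_list = build_allow_list("all_reduce,all_to_allv,wait")
trace = [{"comms": "all_reduce"}, {"comms": "all_gather"}, {"comms": "wait"}]
replayed = [c for c in trace if c["comms"] in allow_list]  # all_gather is skipped
```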
14 changes: 13 additions & 1 deletion train/comms/pt/pytorch_dist_backend.py
@@ -206,13 +206,23 @@ def complete_accel_ops(self, collectiveArgs, initOp=False):
if dev_str == "cuda":
torch.cuda.synchronize(collectiveArgs.device)

def complete_single_op(self, collectiveArgs):
# retFlag not used
def complete_single_op(self, collectiveArgs, retFlag=False):
""" only wait the first op in the queue """
if len(collectiveArgs.waitObj) > 0:
waitReq = collectiveArgs.waitObj.pop(0)
if waitReq is not None:
waitReq.wait()

# to ensure GPU collective is completed
dev_str = (
self.commsParams["device"]
if isinstance(self.commsParams, dict)
else self.commsParams.device
)
if dev_str == "cuda":
torch.cuda.synchronize(collectiveArgs.device)


def barrier(self, collectiveArgs, name="dummy", retFlag=False):
retObj = dist.barrier(collectiveArgs.group, async_op=collectiveArgs.asyncOp)
@@ -327,6 +337,8 @@ def __init__(self, comms_world_info, commsParams):
super().__init__()
self.comms_world_info = comms_world_info
self.commsParams = commsParams
# Add single wait op (Note this is not supported in pytorch_tpu_backend.py now)
self.collectiveFunc["wait"] = self.complete_single_op

# Import ucc plugin
if commsParams.backend == "ucc":
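
The `complete_single_op` change above pairs the wait on the async work handle with an explicit device synchronize so that GPU collectives are actually finished before latency is recorded. A minimal sketch of that completion pattern, with a hypothetical helper name:

```python
# Sketch of the completion pattern used by complete_single_op above (hypothetical helper).
# req.wait() alone does not guarantee the CUDA kernels have finished, so an
# explicit torch.cuda.synchronize() is issued for GPU devices before timing.
import torch

def wait_single_op(wait_queue, device):
    if wait_queue:
        req = wait_queue.pop(0)   # only wait for the first op in the queue
        if req is not None:
            req.wait()            # wait on the async collective's work handle
    if device.type == "cuda":
        torch.cuda.synchronize(device)  # ensure GPU completion
```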
