Refactoring et_replay
TaekyungHeo committed May 17, 2024
1 parent e99ef20 · commit 0d91353
Showing 45 changed files with 4,485 additions and 5,668 deletions.
File renamed without changes.
6 changes: 6 additions & 0 deletions et_replay/lib/comm/__init__.py
@@ -0,0 +1,6 @@
from .comm_op_args import CommOpArgs
from .comm_stats import CommStats
from .comm_tensor_allocator import CommTensorAllocator
from .comm_trace_reader import CommTraceReader
from .comm_trace_replay_args import CommTraceReplayArgs
from .comm_trace_replayer import CommTraceReplayer
4 changes: 4 additions & 0 deletions et_replay/lib/comm/backend/__init__.py
@@ -0,0 +1,4 @@
from .base_backend import BaseBackend
from .coll_args import CollArgs
from .mock import MockBackend
from .pytorch_dist import PyTorchDistBackend
327 changes: 327 additions & 0 deletions et_replay/lib/comm/backend/base_backend.py
@@ -0,0 +1,327 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import List, Optional

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from .coll_args import CollArgs

SupportedP2pOps = [
"send",
"recv",
"isend",
"irecv",
]


class BaseBackend(ABC):
"""
    Abstract base class that defines the common interface for distributed computation backends.
Attributes:
tcp_store (Optional[dist.TCPStore]): A TCP store for communication during distributed operations.
collective_func (Dict[str, Callable]): Mapping of collective function names to their respective callable functions.
"""

def __init__(self) -> None:
"""
Initialize common backend functionalities.
"""
self.tcp_store = None

@abstractmethod
def initialize_backend(
self, master_ip: str, master_port: str, backend: str = "gloo"
) -> None:
"""
Initialize the backend for distributed operations.
Args:
master_ip (str): IP address of the master node.
master_port (str): Port number of the master node.
backend (str): Backend to be used for initialization.
"""
pass

@abstractmethod
def say_hello(
self, global_rank: int, local_rank: int, world_size: int, master_ip: str
) -> None:
"""
Print startup information for a specific backend instance.
Args:
global_rank (int): Global rank of the process in the distributed setup.
local_rank (int): Local rank of the process on the node.
world_size (int): Total number of processes in the distributed setup.
master_ip (str): IP address of the master node.
"""
pass

# Memory management functions
@abstractmethod
def alloc_ones(
self,
size_arr: int,
cur_rank_device: str = "cuda",
dtype: torch.dtype = torch.float32,
scale_factor: float = 1.0,
) -> torch.Tensor:
"""
Allocate a tensor of ones, scaled by a given factor.
Args:
size_arr (int): Number of elements in the tensor.
cur_rank_device (str): Device to allocate tensor on.
dtype (torch.dtype): Data type of the tensor.
scale_factor (float): Factor to scale the tensor values by.
Returns:
torch.Tensor: The allocated tensor.
"""
pass

@abstractmethod
def alloc_random(
self,
size_arr: int,
cur_rank_device: str,
dtype: torch.dtype,
scale_factor: float = 1.0,
) -> torch.Tensor:
"""
Allocate a tensor with random values.
Args:
size_arr (int): Size of the tensor to allocate.
cur_rank_device (str): Device to allocate tensor on.
dtype (torch.dtype): Data type of the tensor.
scale_factor (float): Scaling factor for tensor values.
Returns:
torch.Tensor: The allocated tensor.
"""
pass

@abstractmethod
def alloc_empty(
self, size_arr: int, dtype: torch.dtype, cur_rank_device: str
) -> torch.Tensor:
"""
Allocate an uninitialized tensor.
Args:
size_arr (int): Size of the tensor.
dtype (torch.dtype): Data type of the tensor.
cur_rank_device (str): Device to allocate tensor on.
Returns:
torch.Tensor: The allocated tensor.
"""
pass

@abstractmethod
def clear_memory(self, collective_args: CollArgs) -> None:
"""
Clear memory allocated for the collective operations.
Args:
collective_args (CollArgs): Holder of collective arguments which contains tensors to be cleared.
"""
pass

# Collective communication functions
@abstractmethod
def all_reduce(self, collective_args: CollArgs, ret_flag: bool = False) -> None:
"""
Perform an all-reduce operation on the data within the collective_args.
Args:
collective_args (CollArgs): The collective arguments.
ret_flag (bool): Flag to indicate if the operation should return something.
"""
pass

@abstractmethod
def reduce(self, collective_args: CollArgs, ret_flag: bool = False) -> None:
"""
Perform a reduce operation on the data within the collective_args.
Args:
collective_args (CollArgs): The collective arguments.
ret_flag (bool): Flag to indicate if the operation should return something.
"""
pass

@abstractmethod
def all_to_all(self, collective_args: CollArgs, ret_flag: bool = False) -> None:
"""
Perform an all-to-all operation on the data within the collective_args.
Args:
collective_args (CollArgs): The collective arguments.
ret_flag (bool): Flag to indicate if the operation should return something.
"""
pass

@abstractmethod
def all_to_allv(self, collective_args: CollArgs, ret_flag: bool = False) -> None:
"""
Perform an all-to-all variable operation on the data within the collective_args.
Args:
collective_args (CollArgs): The collective arguments.
ret_flag (bool): Flag to indicate if the operation should return something.
"""
pass

@abstractmethod
def barrier(self, collective_args: CollArgs, name: str = "dummy") -> None:
"""
Synchronize all processes in the distributed environment.
Args:
collective_args (CollArgs): The collective arguments.
name (str): Name of the barrier for debugging.
"""
pass

# Placeholder and utility functions
def noop(
self,
        collective_args: Optional[CollArgs] = None,
ret_flag: bool = False,
pair: bool = False,
) -> None:
"""
A no-operation function used as a placeholder.
Args:
collective_args (CollArgs, optional): The collective arguments.
ret_flag (bool, optional): Flag to indicate if the operation should return something.
pair (bool, optional): Flag to indicate if the operation involves pairs.
"""
pass

def sync_barrier(self, collective_args: CollArgs, desc: str = "world") -> None:
"""
Synchronize all processes in the distributed environment, ensuring all previous operations are completed.
Args:
collective_args (CollArgs): The collective arguments.
desc (str): Description of the sync point for debugging.
"""
self.barrier(collective_args, name=desc)

@abstractmethod
def get_reduce_op(self, op_name: str) -> dist.ReduceOp:
"""
Get the corresponding reduce operation.
Args:
op_name (str): Name of the operation.
Returns:
dist.ReduceOp: The reduce operation.
"""
pass

@abstractmethod
def get_world_size(self) -> int:
"""
Get the total number of processes in the distributed environment.
Returns:
int: The world size.
"""
pass

@abstractmethod
def get_local_size(self) -> int:
"""
Get the number of processes on the local node.
Returns:
int: The local size.
"""
pass

@abstractmethod
def get_global_rank(self) -> int:
"""
Get the global rank of the process.
Returns:
int: The global rank.
"""
pass

@abstractmethod
def get_local_rank(self) -> int:
"""
Get the local rank of the process.
Returns:
int: The local rank.
"""
pass

@abstractmethod
def get_device(self) -> str:
"""
Get the current device the process is using.
Returns:
str: The device identifier.
"""
pass

@abstractmethod
def get_hw_device(self) -> str:
"""
Get the hardware device information.
Returns:
str: The hardware device identifier.
"""
pass

@abstractmethod
def get_default_group(self) -> ProcessGroup:
"""
Get the default process group for collective operations.
Returns:
ProcessGroup: The default process group.
"""
pass

@abstractmethod
def get_groups(self) -> List[ProcessGroup]:
"""
Get all the process groups available.
Returns:
list: List of process groups.
"""
pass

@abstractmethod
def get_num_pgs(self) -> int:
"""
Get the number of process groups.
Returns:
int: Number of process groups.
"""
pass

@abstractmethod
    def complete_accel_ops(self, collective_args: CollArgs) -> None:
        """Wait for all outstanding accelerator and collective operations to complete."""
        pass
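For orientation, the following is a minimal sketch, not part of this commit, of what a concrete subclass of BaseBackend could look like when built directly on torch.distributed. The method bodies, the environment-variable rendezvous, and the import paths are assumptions based on the layout introduced here; the PyTorchDistBackend added by this commit may differ, and only a handful of the abstract methods are shown.

import os

import torch
import torch.distributed as dist

# Import paths assume the repository layout introduced by this commit.
from et_replay.lib.comm.backend.base_backend import BaseBackend
from et_replay.lib.comm.backend.coll_args import CollArgs


class SimpleDistBackend(BaseBackend):
    """Hypothetical sketch: implements only a subset of the abstract methods."""

    def initialize_backend(self, master_ip: str, master_port: str, backend: str = "gloo") -> None:
        # Rendezvous via environment variables; assumes RANK and WORLD_SIZE are already set.
        os.environ["MASTER_ADDR"] = master_ip
        os.environ["MASTER_PORT"] = master_port
        dist.init_process_group(backend=backend)

    def alloc_ones(self, size_arr, cur_rank_device="cuda", dtype=torch.float32, scale_factor=1.0):
        tensor = torch.ones(size_arr, device=cur_rank_device, dtype=dtype)
        if scale_factor != 1.0:
            tensor = tensor * scale_factor
        return tensor

    def all_reduce(self, collective_args: CollArgs, ret_flag: bool = False) -> None:
        # input_tensor is expected to hold a torch.Tensor at replay time.
        work = dist.all_reduce(
            collective_args.input_tensor,
            op=dist.ReduceOp.SUM,
            group=collective_args.group,
            async_op=bool(collective_args.asyncOp),
        )
        if collective_args.asyncOp:
            collective_args.waitObj.append(work)

    def barrier(self, collective_args: CollArgs, name: str = "dummy") -> None:
        dist.barrier(group=collective_args.group)

    def get_world_size(self) -> int:
        return dist.get_world_size()

    def get_global_rank(self) -> int:
        return dist.get_rank()

    # The remaining abstract methods (say_hello, alloc_random, alloc_empty, clear_memory,
    # reduce, all_to_all, all_to_allv, get_reduce_op, get_local_size, get_local_rank,
    # get_device, get_hw_device, get_default_group, get_groups, get_num_pgs,
    # complete_accel_ops) would also need concrete implementations before this class
    # could be instantiated.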
54 changes: 54 additions & 0 deletions et_replay/lib/comm/backend/coll_args.py
@@ -0,0 +1,54 @@
class CollArgs:
"""Class holding object for all the parameters related to a collective operation/experiment."""

def __init__(self) -> None:
self.group = None
        self.groups = {}  # mapping of pg_id to ProcessGroup
self.num_pgs = 0
self.device = {}
self.world_size = 0
self.data_type = ""

self.numIters = 0
self.numWarmupIters = 0
self.global_rank = -1
self.backendFuncs = {}
self.collective = ""
self.collectiveId = 0
self.pt2pt = ""
self.src_rank = -1
self.dst_rank = -1
self.p2pOps = []

self.reuseTensors = False

self.batch_size = 0

self.input_tensor_split = []
self.output_tensor_split = []

self.input_tensor = []
self.output_tensor = []
self.srcOrDst = -1
self.asyncOp = -1
self.dataSize = 0
self.numElements = 0
self.waitObj = []
self.waitObjIds = {} # mapping of reqID to future of async collectives

self.input_tensor_split_pair = []
self.output_tensor_split_pair = []

self.input_tensor_pair = None
self.output_tensor_pair = None
self.dataSize_pair = 0
self.numElements_pair = 0

self.all2all_qcomm = None
self.reducescatter_allgather_qcomm = None
        self.allreduce_qcomm = 32  # TODO: set to the bitwidth for now until the quantization kernels are supported
self.reduce_qcomm = 32
self.quant_threshold = 0
self.enable_profiler = False

self.use_ext_dist = False
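
To make the role of these fields concrete, here is a small usage sketch, not taken from this commit, of how a replay step might populate a CollArgs instance before handing it to a backend. The specific values, the choice of fields set, and the import path are illustrative assumptions.

import torch
import torch.distributed as dist

from et_replay.lib.comm.backend.coll_args import CollArgs  # path as introduced by this commit

# Assumes the backend has already called dist.init_process_group(...).
coll_args = CollArgs()
coll_args.collective = "all_reduce"
coll_args.data_type = "float32"
coll_args.numElements = 1024
coll_args.dataSize = coll_args.numElements * 4  # in bytes, assuming 4-byte float32 elements
coll_args.group = dist.group.WORLD              # default process group
coll_args.world_size = dist.get_world_size()
coll_args.global_rank = dist.get_rank()
coll_args.device = "cuda" if torch.cuda.is_available() else "cpu"
coll_args.asyncOp = True
coll_args.input_tensor = torch.ones(coll_args.numElements, device=coll_args.device)

# A backend implementing BaseBackend.all_reduce would then consume coll_args and,
# for asynchronous ops, append the returned work handle to coll_args.waitObj.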
