Skip to content

Commit

Permalink
consolidate communication for tensor metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
tohtana committed May 1, 2024
1 parent 822aeee commit 9f96ad4
Showing 1 changed file with 40 additions and 66 deletions.
106 changes: 40 additions & 66 deletions deepspeed/runtime/pipe/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
PIPE_RECV_INPUT_TIMER = 'pipe_recv_input'
PIPE_RECV_GRAD_TIMER = 'pipe_recv_grad'

TENSOR_META_SIZE = 256


def is_even(number):
return number % 2 == 0
Expand Down Expand Up @@ -930,17 +932,17 @@ def _send_tensor_meta(self, buffer, recv_stage):
* ndims
* shape
"""
send_bytes = 0
meta_buffer = torch.empty(TENSOR_META_SIZE, dtype=torch.int32, device=self.device)
if isinstance(buffer, torch.Tensor):
send_dtype = torch.LongTensor(data=[self.DTYPE_TO_ID[buffer.dtype]]).to(self.device)
type_tensor = torch.LongTensor(data=[0]).to(self.device)
p2p.send(type_tensor, recv_stage)
send_shape = torch.LongTensor(data=buffer.size()).to(self.device)
send_ndims = torch.LongTensor(data=[len(buffer.size())]).to(self.device)
p2p.send(send_dtype, recv_stage)
p2p.send(send_ndims, recv_stage)
p2p.send(send_shape, recv_stage)
send_bytes += _tensor_bytes(buffer)
meta_buf_list = [
0, # type of data (0: tensor, 1: list, 2: tuple)
self.DTYPE_TO_ID[buffer.dtype], # dtype
len(buffer.size()) # ndims
]
meta_buf_list.extend(buffer.size())
meta_buffer[:len(meta_buf_list)].copy_(torch.tensor(meta_buf_list, dtype=torch.int32))
p2p.send(meta_buffer, recv_stage)

elif isinstance(buffer, list):
assert (False)
type_tensor = torch.LongTensor(data=[1]).to(self.device)
Expand All @@ -953,30 +955,21 @@ def _send_tensor_meta(self, buffer, recv_stage):
send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device)
p2p.send(send_ndims, recv_stage)
p2p.send(send_shape, recv_stage)
send_bytes += _tensor_bytes(tensor)
elif isinstance(buffer, tuple):
type_tensor = torch.LongTensor(data=[2]).to(self.device)
p2p.send(type_tensor, recv_stage)
count_tensor = torch.LongTensor(data=[len(buffer)]).to(self.device)
p2p.send(count_tensor, recv_stage)
for idx, tensor in enumerate(buffer):
meta_buf_list = [
2, # type of data (0: tensor, 1: list, 2: tuple)
len(buffer) # num_tensors
]

for tensor in buffer:
assert isinstance(tensor, torch.Tensor)
send_shape = torch.LongTensor(data=tensor.size()).to(self.device)
send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device)
send_dtype = torch.LongTensor(data=[self.DTYPE_TO_ID[tensor.dtype]]).to(self.device)
p2p.send(send_dtype, recv_stage)
p2p.send(send_ndims, recv_stage)
p2p.send(send_shape, recv_stage)
# Useful for performance debugging.
'''
new_bytes = _tensor_bytes(tensor)
send_bytes += _tensor_bytes(tensor)
# Useful for performance debugging.
if self.grid.data_parallel_id == 0:
print(
f'STAGE={self.stage_id} pipe-send-volume[{idx}]: shape={send_shape} {new_bytes/1024**2:0.2f}MB'
)
'''
meta_buf_list.append(self.DTYPE_TO_ID[tensor.dtype])
meta_buf_list.append(len(tensor.size()))
meta_buf_list.extend(tensor.size())

meta_buffer[:len(meta_buf_list)].copy_(torch.tensor(meta_buf_list, dtype=torch.int32))
p2p.send(meta_buffer, recv_stage)

else:
raise NotImplementedError(f'Could not send meta type {type(buffer)}')

Expand All @@ -989,53 +982,34 @@ def _send_tensor_meta(self, buffer, recv_stage):
def _recv_tensor_meta(self, send_stage):
"""Receive metadata about upcoming p2p transfers and return allocated buffers.
Metadata is communicated in this order:
* type (0: tensor, 1: list)
* num_tensors if type=list
foreach tensor in buffer:
* ndims
* shape
Returns:
Allocated buffer for receiving from send_stage.
"""
buffer = torch.empty(TENSOR_META_SIZE, dtype=torch.int32, device=self.device)
p2p.recv(buffer, send_stage)

type_tensor = torch.LongTensor(data=[0]).to(self.device)
p2p.recv(type_tensor, send_stage)
recv_type = type_tensor.item()
recv_type = buffer[0].item()

# A single tensor will be sent.
if recv_type == 0:
recv_dtype = torch.LongTensor(data=[0]).to(self.device)
p2p.recv(recv_dtype, send_stage)
recv_dtype = self.ID_TO_DTYPE[recv_dtype.item()]
recv_ndims = torch.LongTensor(data=[0]).to(self.device)
p2p.recv(recv_ndims, send_stage)
recv_ndims = recv_ndims.item()
recv_shape = torch.LongTensor([1] * recv_ndims).to(self.device)
p2p.recv(recv_shape, send_stage)
recv_shape = recv_shape.tolist()
recv_dtype = self.ID_TO_DTYPE[buffer[1].item()]
recv_ndims = buffer[2].item()
recv_shape = buffer[3:3 + recv_ndims].tolist()
return self._allocate_or_extend_buffers(0, recv_shape, recv_dtype)

# List or tuple of tensors
elif recv_type == 1 or recv_type == 2:
count_tensor = torch.LongTensor(data=[0]).to(self.device)
p2p.recv(count_tensor, send_stage)
num_tensors = count_tensor.item()
recv_shapes_and_dtypes = []
num_tensors = buffer[1].item()

buffers = []
offset = 2
for idx in range(num_tensors):
recv_dtype = torch.LongTensor(data=[0]).to(self.device)
p2p.recv(recv_dtype, send_stage)
recv_dtype = self.ID_TO_DTYPE[recv_dtype.item()]
recv_ndims = torch.LongTensor(data=[0]).to(self.device)
p2p.recv(recv_ndims, send_stage)
recv_ndims = recv_ndims.item()
recv_shape = torch.LongTensor([1] * recv_ndims).to(self.device)
p2p.recv(recv_shape, send_stage)
recv_shapes_and_dtypes.append((recv_shape.tolist(), recv_dtype))

buffers.append(self._allocate_or_extend_buffers(idx, recv_shape.tolist(), recv_dtype))
recv_dtype = self.ID_TO_DTYPE[buffer[offset].item()]
recv_ndims = buffer[offset + 1].item()
recv_shape = buffer[offset + 2:offset + 2 + recv_ndims].tolist()
offset += 2 + recv_ndims

buffers.append(self._allocate_or_extend_buffers(idx, recv_shape, recv_dtype))

# Convert to tuples if requested.
if recv_type == 2:
Expand Down

0 comments on commit 9f96ad4

Please sign in to comment.