39 changes: 20 additions & 19 deletions deepspeed/module_inject/layers.py
@@ -124,7 +124,7 @@ def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor]:
return None, grad_output


class Replaced_Layer(nn.Module, ABC):
class TensorParallel_Layer(nn.Module, ABC):
"""
A base class for model layers with tensor parallelism support.
This class is designed to be extended by specific layers that require distributed
@@ -141,7 +141,7 @@ class Replaced_Layer(nn.Module, ABC):

def __init__(self, mp_group: Optional[dist.ProcessGroup], **kwargs: Any):
"""
Initializes the Replaced_Layer with optional model parallelism group and layer name.
Initializes the TensorParallel_Layer with optional model parallelism group and layer name.

Args:
mp_group (Optional[dist.ProcessGroup]): The process group for model parallelism.
@@ -177,7 +177,7 @@ def gather_params(self, params_list):
pass

@abstractmethod
def partition(self, params_list: List[torch.Tensor]):
def _tp_partition(self, params_list: List[torch.Tensor]):
"""
Partitions the parameters for tensor parallelism.
It is necessary to ensure that this function only involves the logic of params partitioning.
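For readers tracking the rename, the sketch below shows what a concrete layer built on this interface looks like once partition becomes _tp_partition: the two hooks carry the slicing and gathering logic and nothing else. This is a hypothetical subclass (RowShardLinear), not DeepSpeed's actual LinearAllreduce or LinearLayer; it assumes an initialized torch.distributed process group and an input dimension that divides evenly across ranks.

import torch
import torch.nn as nn
import torch.distributed as dist


class RowShardLinear(nn.Module):
    # Illustrative only: mirrors the gather_params/_tp_partition contract of
    # TensorParallel_Layer without relying on DeepSpeed internals.
    def __init__(self, module: nn.Linear, mp_group: dist.ProcessGroup):
        super().__init__()
        self.mp_group = mp_group
        self.tp_world_size = dist.get_world_size(mp_group)
        self.tp_index = dist.get_rank(mp_group)
        self.weight = module.weight
        self._tp_partition([self.weight])

    @torch.no_grad()
    def _tp_partition(self, params_list):
        # Keep only this rank's slice of each parameter (row sharding along the
        # input dimension; assumes it divides evenly by tp_world_size).
        for param in params_list:
            if param is None:
                continue
            shard = torch.chunk(param.data, self.tp_world_size, dim=1)[self.tp_index]
            param.data = shard.contiguous()

    @torch.no_grad()
    def gather_params(self, params_list):
        # Reassemble the full parameter from every rank's shard.
        for param in params_list:
            shards = [torch.empty_like(param.data) for _ in range(self.tp_world_size)]
            dist.all_gather(shards, param.data, group=self.mp_group)
            param.data = torch.cat(shards, dim=1).contiguous()

    def forward(self, x):
        # Local matmul on the shard; a real row-parallel layer would follow
        # this with an all-reduce over mp_group (omitted for brevity).
        return torch.matmul(x, self.weight.transpose(-1, -2))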
@@ -205,7 +205,7 @@ def config_tp_params(self, weight):
setattr(weight, DS_TENSOR_MODEL_PARALLEL, True)
setattr(weight, DS_IS_REPLACED_MODULE, True)
weight.gather_params = self.gather_params
weight.partition = self.partition
weight._tp_partition = self._tp_partition

def is_training_mode(self):
global DEEPSPEED_AUTOTP_MODE
@@ -294,17 +294,17 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:
"""
#TODO : Check whether there are any missing attributes.
if self.enabled:
self.params[0].partition(self.params)
self.params[0]._tp_partition(self.params)


class LinearAllreduce(Replaced_Layer):
class LinearAllreduce(TensorParallel_Layer):

def __init__(self, module, mp_group, **kwargs):
super(LinearAllreduce, self).__init__(mp_group, **kwargs)
self.weight = module.weight
self.bias = module.bias

self.partition([self.weight, self.bias])
self._tp_partition([self.weight, self.bias])
self.support_training = True
self.config_tp_params(self.weight)
if self.bias is not None:
@@ -335,7 +335,7 @@ def gather_params(self, params_list):
return

@torch.no_grad()
def partition(self, params_list):
def _tp_partition(self, params_list):

if not self.is_training_mode():
self.uneven_partition(params_list)
@@ -367,21 +367,22 @@ def uneven_partition(self, params_list):


#remove kwargs from partition.
class LinearLayer(Replaced_Layer):
class LinearLayer(TensorParallel_Layer):

def __init__(self, module, mp_group, skip_partition=False, **kwargs):
def __init__(self, module, mp_group=None, skip_partition=False, **kwargs):
super(LinearLayer, self).__init__(mp_group, **kwargs)
self.weight = module.weight
self.bias = module.bias
if not skip_partition:
self.partition([self.weight, self.bias])
self._tp_partition([self.weight, self.bias])
self.support_training = True
self.config_tp_params(self.weight)
if self.bias is not None:
self.config_tp_params(self.bias)

def forward(self, input):
input = ColumnParallel.apply(self.mp_group, input)
if getattr(self, 'mp_group', None) is not None:
input = ColumnParallel.apply(self.mp_group, input)
output = torch.matmul(input, self.weight.transpose(-1, -2))
if self.bias is not None:
output += self.bias
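Because mp_group now defaults to None and the ColumnParallel collective is guarded by the getattr check above, a LinearLayer created without a process group reduces to a plain linear transform. The sketch below mirrors only that no-TP branch of forward; it is not the class's full code path.

import torch

def forward_no_tp(x, weight, bias=None):
    # What the guarded forward computes when mp_group is None: no collective,
    # just a matmul against the (unsharded) weight plus the optional bias.
    output = torch.matmul(x, weight.transpose(-1, -2))
    if bias is not None:
        output = output + bias
    return output

x = torch.randn(4, 16)
w = torch.randn(32, 16)              # same layout as nn.Linear.weight
b = torch.randn(32)
print(forward_no_tp(x, w, b).shape)  # torch.Size([4, 32])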
@@ -401,7 +402,7 @@ def gather_params(self, params_list):
params_list[idx].data = output_param.contiguous()

@torch.no_grad()
def partition(self, params_list):
def _tp_partition(self, params_list):

if not self.is_training_mode():
self.uneven_partition(params_list)
@@ -466,7 +467,7 @@ def __init__(self, module, mp_group, skip_partition=False, **kwargs):
super().__init__(module, mp_group, skip_partition, **kwargs)

@torch.no_grad()
def partition(self, params_list):
def _tp_partition(self, params_list):
for idx, param in enumerate(params_list):
if param is None:
return
@@ -481,7 +482,7 @@ def partition(self, params_list):
class conv_LinearLayer(LinearLayer):

@torch.no_grad()
def partition(self, params_list):
def _tp_partition(self, params_list):
weight = None
bias = None
if len(params_list) == 1:
@@ -506,7 +507,7 @@ class Yuan_LinearAllreduce(LinearAllreduce):

#Yuan2
@torch.no_grad()
def partition(self, params_list):
def _tp_partition(self, params_list):
weight, bias = shard_value_with_share_qk(params_list[0].data, params_list[1], self.tp_index,
self.tp_world_size, False)
params_list[0].data = weight
@@ -517,7 +518,7 @@ def partition(self, params_list):
class Yuan_LinearLayer(LinearLayer):
#Yuan2
@torch.no_grad()
def partition(self, params_list):
def _tp_partition(self, params_list):
weight, bias = shard_value_with_share_qk(params_list[0].data, params_list[1], self.tp_index,
self.tp_world_size, True)
params_list[0].data = move(weight, get_accelerator().current_device_name()).detach()
@@ -528,7 +529,7 @@ def partition(self, params_list):
class GateUpPack_LinearLayer(LinearLayer):
# chatGLM2, chatGLM2
@torch.no_grad()
def partition(self, params_list):
def _tp_partition(self, params_list):
weight, bias = shard_chunk_mlp(params_list[0].data, params_list[1], self.tp_index, self.tp_world_size)
params_list[0].data = move(weight, device=get_accelerator().current_device_name()).detach()
if bias is not None:
@@ -538,7 +539,7 @@ def partition(self, params_list):
class Conv_LinearALlreduce(LinearAllreduce):

@torch.no_grad()
def partition(self, params_list):
def _tp_partition(self, params_list):
for idx, param in enumerate(params_list):
if param is None:
return
9 changes: 7 additions & 2 deletions deepspeed/runtime/hybrid_engine.py
@@ -290,8 +290,13 @@ def create_inference_containers(self, module, layer_id=0):

layer_id += 1
else:
self._other_layers.append(self.inference_policies[child.__class__][0](
weight=child.weight, bias=child.bias if hasattr(child, 'bias') else None))
if self.inference_policies[child.__class__][0] == LinearLayer:
self._other_layers.append(self.inference_policies[child.__class__][0](module=child,
mp_group=None,
skip_partition=True))
else:
self._other_layers.append(self.inference_policies[child.__class__][0](
weight=child.weight, bias=child.bias if hasattr(child, 'bias') else None))
self._orig_modules_others.append(child)
self._orig_fwds_others.append(child.forward)
else:
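The new branch routes the LinearLayer policy through the module-based constructor with partitioning skipped, while every other policy keeps the legacy weight/bias call. A condensed sketch of that dispatch, with 'policies' standing in for self.inference_policies and the container bookkeeping omitted:

from deepspeed.module_inject.layers import LinearLayer

def build_other_layer(policies, child):
    layer_cls = policies[child.__class__][0]
    if layer_cls == LinearLayer:
        # New path: hand over the original module and skip tensor-parallel partitioning.
        return layer_cls(module=child, mp_group=None, skip_partition=True)
    # Unchanged path for every other policy class.
    return layer_cls(weight=child.weight,
                     bias=child.bias if hasattr(child, 'bias') else None)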
2 changes: 1 addition & 1 deletion tests/unit/model_parallelism/test_autotp_training.py
@@ -330,7 +330,7 @@ def test(self, layer_type):
assert total_params == params1

for name, param in tp_layer.named_parameters(recurse=False):
param.partition([param])
param._tp_partition([param])

tp_params2 = sum(p.numel() for p in tp_layer.parameters())

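The renamed hook matters here because config_tp_params (see the layers.py hunk above) attaches gather_params and _tp_partition to the parameter tensors themselves, so the test drives re-sharding through the parameter handle rather than the layer. A sketch of that round trip, assuming a tp_layer built as in the test above:

def reshard_in_place(tp_layer):
    # Assumes the layer's parameters are currently gathered (full, unsharded).
    full = sum(p.numel() for p in tp_layer.parameters())
    for _, param in tp_layer.named_parameters(recurse=False):
        param._tp_partition([param])    # re-shard this rank's slice in place
    shard = sum(p.numel() for p in tp_layer.parameters())
    assert shard <= full
    return full, shard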