28 changes: 20 additions & 8 deletions src/transformers/integrations/tensor_parallel.py
@@ -846,7 +846,7 @@ def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
"""
Imagine you had 4 tokens, top_k = 4, and 128 experts.
With EP = 8.
With EP = 8, num_local_experts would be 128 / 8 = 16.
Imagine router_indices being:
[ 52, 42, 119, 67],
[102, 89, 61, 40],
@@ -860,12 +860,12 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me
[5, 6, 0, 2],
[5, 1, 6, 0],

Thus for say rank 0, you fill with 0 the index tensor
Thus for, say, rank 0, you fill the index tensor with 16 (num_local_experts):

[ 0, 0, 0, 0],
[ 0, 0, 0, 0],
[ 0, 0, 4, 0],
[ 0, 0, 0, 11],
[ 16, 16, 16, 16],
[ 16, 16, 16, 16],
[ 16, 16, 4, 16],
[ 16, 16, 16, 11],

This works well. For another rank, you need to make sure you reduce the indices modulo num_local_experts,
because the next operation will one-hot encode the router index vector.
@@ -876,13 +876,25 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me

The somewhat naive training loop that we use for device_map="auto" uses similar logic.
Here we are just making each rank believe that it is alone, so it computes its part of the hidden states.
Invalid indices are masked with num_local_experts before one-hot encoding, so the compute skips the masking index.
"""
ep_rank, ep_size = device_mesh.get_local_rank(), device_mesh.size()
if mod.num_experts % ep_size != 0:
raise ValueError(
f"The number of experts must be divisible by number of ep_size: {mod.num_experts} % {ep_size} != 0"
)
num_local_experts = mod.num_experts // ep_size
router_scores, router_indices = outputs
router_scores = router_scores[:, ep_rank * num_local_experts : (ep_rank + 1) * num_local_experts]
router_indices = router_indices.masked_fill((router_indices // num_local_experts) != ep_rank, 0)
router_indices = router_indices % num_local_experts
router_indices = router_indices.masked_fill((router_indices // num_local_experts) != ep_rank, -1)
# Since fmod(-1, 1) is 0, fmod would erase the -1 mask, so we fall back to masked_fill when num_local_experts is 1
if num_local_experts > 1:
router_indices = torch.fmod(router_indices, num_local_experts)
else:
router_indices = router_indices.masked_fill(router_indices > 0, 0).masked_fill(router_indices < 0, -1)
router_indices = router_indices.masked_fill(
router_indices == -1, num_local_experts
) # masking class for one hot
return router_scores, router_indices

def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
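To make the masking above concrete, here is a minimal standalone sketch (not part of the diff) that reproduces the rank-0 example from the docstring. It assumes EP = 8 and 128 experts (so num_local_experts = 16); the first two rows of router_indices come from the docstring, while the last two are made-up values chosen so that the output matches the matrix shown above.

import torch

num_experts, ep_size, ep_rank = 128, 8, 0
num_local_experts = num_experts // ep_size  # 16

# first two rows from the docstring, last two hypothetical
router_indices = torch.tensor(
    [[ 52,  42, 119,  67],
     [102,  89,  61,  40],
     [ 82,  56,   4,  27],
     [ 75,  58,  11,  67]]
)

# mark experts that live on another rank with -1
local_indices = router_indices.masked_fill(router_indices // num_local_experts != ep_rank, -1)
# turn the surviving global ids into local ids; fmod keeps the -1 marker (num_local_experts > 1 here)
local_indices = torch.fmod(local_indices, num_local_experts)
# replace the marker with num_local_experts, the extra "masking" class used by one_hot later
local_indices = local_indices.masked_fill(local_indices == -1, num_local_experts)
print(local_indices)
# tensor([[16, 16, 16, 16],
#         [16, 16, 16, 16],
#         [16, 16,  4, 16],
#         [16, 16, 16, 11]])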
7 changes: 6 additions & 1 deletion src/transformers/models/gpt_oss/modeling_gpt_oss.py
@@ -98,14 +98,19 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
if hidden_states.device.type == "cpu" or self.training:
next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
expert_mask = torch.nn.functional.one_hot(
router_indices, num_classes=num_experts + 1
) # masking is also a class
expert_mask = expert_mask.permute(2, 1, 0)
# we sum on the top_k and on the sequence length to get which experts
# are hit this time around
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit[:]:
# expert_idx only has 1 element, so we can use it as a scalar for fast indexing
expert_idx = expert_idx[0]
# skip masking index
if expert_idx == num_experts:
continue
with torch.no_grad():
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
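To see why the one_hot call above uses num_experts + 1 classes, here is a small illustrative sketch with toy shapes and values (an assumption for illustration, not the model code): masked router entries equal num_experts, so they land in a dedicated extra class that the expert loop skips.

import torch

num_experts = 4
# 3 tokens, top_k = 2; entries equal to num_experts (4) are the masked, off-rank slots
router_indices = torch.tensor([[0, 4], [2, 4], [4, 4]])

expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts + 1)
expert_mask = expert_mask.permute(2, 1, 0)  # (num_experts + 1, top_k, num_tokens)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

for expert_idx in expert_hit[:]:
    expert_idx = expert_idx[0]
    if expert_idx == num_experts:  # the masking class: nothing to compute
        continue
    _, token_idx = torch.where(expert_mask[expert_idx])
    print(f"expert {int(expert_idx)} handles token(s) {token_idx.tolist()}")
# expert 0 handles token(s) [0]
# expert 2 handles token(s) [1]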
7 changes: 6 additions & 1 deletion src/transformers/models/gpt_oss/modular_gpt_oss.py
@@ -97,14 +97,19 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
if hidden_states.device.type == "cpu" or self.training:
next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
expert_mask = torch.nn.functional.one_hot(
router_indices, num_classes=num_experts + 1
) # masking is also a class
expert_mask = expert_mask.permute(2, 1, 0)
# we sum on the top_k and on the sequence length to get which experts
# are hit this time around
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit[:]:
# expert_idx only has 1 element, so we can use it as a scalar for fast indexing
expert_idx = expert_idx[0]
# skip masking index
if expert_idx == num_experts:
continue
with torch.no_grad():
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
1 change: 1 addition & 0 deletions src/transformers/quantizers/auto.py
@@ -326,6 +326,7 @@ def get_hf_quantizer(config, quantization_config, dtype, from_tf, from_flax, dev
dtype = hf_quantizer.update_dtype(dtype)
device_map = hf_quantizer.update_device_map(device_map)
config = hf_quantizer.update_tp_plan(config)
config = hf_quantizer.update_ep_plan(config)

# In order to ensure popular quantization methods are supported. Can be disabled with `disable_telemetry`
if not getattr(hf_quantizer.quantization_config, "dequantize", False):
4 changes: 4 additions & 0 deletions src/transformers/quantizers/base.py
@@ -219,6 +219,10 @@ def update_tp_plan(self, config):
"updates the tp plan for the scales"
return config

def update_ep_plan(self, config):
"updates the tp plan for the scales"
return config

def preprocess_model(self, model: "PreTrainedModel", **kwargs):
"""
Setting model attributes and/or converting model before weights loading. At this point
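The base-class hook above is intentionally a no-op; a quantizer that needs expert-parallel-specific sharding overrides it, and get_hf_quantizer calls it right after update_tp_plan (the auto.py change above). A rough sketch of the pattern, using a hypothetical ToyQuantizer and ToyConfig (names and plan contents are assumptions; the real override for mxfp4 follows in the next file):

class ToyConfig:
    base_model_ep_plan = {"layers.*.mlp.experts": "grouped_gemm"}

class ToyQuantizer:
    def update_tp_plan(self, config):
        return config  # default: leave the TP plan untouched

    def update_ep_plan(self, config):
        # add EP entries for the extra parameter names this quantization method introduces
        if getattr(config, "base_model_ep_plan", None) is not None:
            config.base_model_ep_plan["layers.*.mlp.experts.gate_up_proj_blocks"] = "grouped_gemm"
        return config

quantizer = ToyQuantizer()
config = quantizer.update_tp_plan(ToyConfig())
config = quantizer.update_ep_plan(config)
print(config.base_model_ep_plan)
# {'layers.*.mlp.experts': 'grouped_gemm', 'layers.*.mlp.experts.gate_up_proj_blocks': 'grouped_gemm'}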
13 changes: 13 additions & 0 deletions src/transformers/quantizers/quantizer_mxfp4.py
@@ -353,6 +353,19 @@ def update_tp_plan(self, config):
)
return config

def update_ep_plan(self, config):
if "GptOssConfig" in config.__class__.__name__:
if getattr(config, "base_model_ep_plan", None) is not None:
config.base_model_ep_plan.update(
{
"layers.*.mlp.experts.gate_up_proj_blocks": "grouped_gemm",
"layers.*.mlp.experts.gate_up_proj_scales": "grouped_gemm",
"layers.*.mlp.experts.down_proj_blocks": "grouped_gemm",
"layers.*.mlp.experts.down_proj_scales": "grouped_gemm",
}
)
return config

def update_param_name(self, param_name: str) -> str:
if self.quantization_config.dequantize:
if "_blocks" in param_name: