diff --git a/src/mcore_bridge/model/modules/transformer_layer.py b/src/mcore_bridge/model/modules/transformer_layer.py index c1a0bc4..66212f2 100644 --- a/src/mcore_bridge/model/modules/transformer_layer.py +++ b/src/mcore_bridge/model/modules/transformer_layer.py @@ -2,6 +2,7 @@ import enum import inspect import torch +from functools import partial from megatron.core.extensions.transformer_engine import TEFusedMLP from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.tensor_parallel.mappings import (gather_from_sequence_parallel_region, @@ -210,6 +211,8 @@ def can_recompute_pre_mlp_layernorm_for_cudagraph(): def _build_mlp(self, mlp_spec): pg_collection = self.pg_collection + if isinstance(mlp_spec, partial): + return mlp_spec(config=self.config, pg_collection=pg_collection, is_mtp_layer=self.is_mtp_layer) additional_mlp_kwargs = {} # import here to avoid circular import from mcore_bridge.model.gpts.glm4 import Glm4MLP