compat megatron.core 0.18 #77
Code Review
This pull request introduces support for partial objects within the _build_mlp method of the transformer_layer.py module. A review comment identifies a potential TypeError because the implementation unconditionally passes arguments like pg_collection and is_mtp_layer to the partial object, regardless of the underlying module's signature. It also notes that additional_mlp_kwargs is redundant in the new branch and suggests a more robust way to handle different module types using build_module.
| if isinstance(mlp_spec, partial): | ||
| return mlp_spec( | ||
| config=self.config, | ||
| pg_collection=pg_collection, | ||
| is_mtp_layer=self.is_mtp_layer, | ||
| **additional_mlp_kwargs) | ||
| else: | ||
| return build_module(mlp_spec, config=self.config, **additional_mlp_kwargs) |
The current implementation for handling partial objects is problematic for two reasons:
- Correctness: It unconditionally passes pg_collection and is_mtp_layer to the partial object. However, if the partial wraps a standard MLP or TEFusedMLP, it will likely fail with a TypeError because these modules expect tp_group instead of pg_collection, and do not accept is_mtp_layer in their constructors.
- Redundancy: additional_mlp_kwargs is only populated if mlp_spec is a ModuleSpec (lines 222-235). Since a partial is not a ModuleSpec, additional_mlp_kwargs is guaranteed to be empty when line 241 is reached, making **additional_mlp_kwargs redundant in that branch.
It is better to determine the underlying module type and populate additional_mlp_kwargs accordingly, then use build_module to perform the instantiation consistently.
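The two points above can be reproduced outside Megatron with a minimal sketch. The classes ToyMLP and ToyMoELayer and the helper build_from_partial are hypothetical stand-ins invented for illustration, not Megatron APIs; the only assumption carried over from the review is that the dense module takes tp_group while the MoE module takes pg_collection and is_mtp_layer.

```python
from functools import partial
from types import SimpleNamespace


class ToyMLP:
    """Hypothetical stand-in for a dense MLP: takes tp_group, not pg_collection."""

    def __init__(self, config, tp_group=None):
        self.config = config
        self.tp_group = tp_group


class ToyMoELayer:
    """Hypothetical stand-in for an MoE layer: takes pg_collection and is_mtp_layer."""

    def __init__(self, config, pg_collection=None, is_mtp_layer=False):
        self.config = config
        self.pg_collection = pg_collection
        self.is_mtp_layer = is_mtp_layer


def build_from_partial(spec, config, pg_collection, is_mtp_layer):
    """Choose kwargs from the class wrapped by the partial, then call it."""
    kwargs = {}
    if spec.func is ToyMoELayer:
        kwargs["pg_collection"] = pg_collection
        kwargs["is_mtp_layer"] = is_mtp_layer
    elif spec.func is ToyMLP:
        kwargs["tp_group"] = pg_collection.tp
    return spec(config=config, **kwargs)


# Stand-in for a process-group collection exposing a .tp attribute.
pg = SimpleNamespace(tp="tp-group")

# Passing the MoE-style kwargs unconditionally fails for the dense stand-in.
try:
    partial(ToyMLP)(config={}, pg_collection=pg, is_mtp_layer=False)
    unconditional_error = None
except TypeError as exc:
    unconditional_error = exc

# Dispatching on partial.func gives each class only the kwargs it declares.
dense = build_from_partial(partial(ToyMLP), {}, pg, False)
moe = build_from_partial(partial(ToyMoELayer), {}, pg, True)
```

The key detail is that functools.partial exposes the wrapped callable as .func, so the caller can branch on the underlying class before deciding which keyword arguments to supply.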
| if isinstance(mlp_spec, partial): | |
| module = mlp_spec.func | |
| if module in (MoELayer, TEGroupedMLP, SequentialMLP): | |
| additional_mlp_kwargs['pg_collection'] = pg_collection | |
| if module == MoELayer and 'is_mtp_layer' in inspect.signature(MoELayer).parameters: | |
| additional_mlp_kwargs['is_mtp_layer'] = self.is_mtp_layer | |
| elif module in (MLP, Glm4MLP) or (TEFusedMLP is not None and module == TEFusedMLP): | |
| additional_mlp_kwargs['tp_group'] = pg_collection.tp | |
| return build_module(mlp_spec, config=self.config, **additional_mlp_kwargs) |
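The inspect.signature check in the suggestion guards against versions of a module whose constructor does not yet accept is_mtp_layer. A small sketch of that idea as a reusable helper follows; accepts_kwarg and the three toy classes are hypothetical names for illustration, and treating a **kwargs catch-all as accepting the argument is an added assumption that the suggestion itself does not make.

```python
import inspect
from functools import partial


def accepts_kwarg(target, name):
    """True if target declares a parameter called `name` or a **kwargs catch-all."""
    if isinstance(target, partial):
        # Look through the partial at the callable it wraps.
        target = target.func
    params = inspect.signature(target).parameters
    if name in params:
        return True
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


class OldStyleLayer:
    """Hypothetical older constructor without is_mtp_layer."""

    def __init__(self, config, pg_collection=None):
        self.config = config


class NewStyleLayer:
    """Hypothetical newer constructor that accepts is_mtp_layer."""

    def __init__(self, config, pg_collection=None, is_mtp_layer=False):
        self.config = config
        self.is_mtp_layer = is_mtp_layer


class FlexibleLayer:
    """Hypothetical constructor that swallows extra keyword arguments."""

    def __init__(self, config, **kwargs):
        self.config = config
        self.extra = kwargs


# Only add the kwarg when the wrapped constructor can take it.
spec = partial(NewStyleLayer)
kwargs = {"is_mtp_layer": True} if accepts_kwarg(spec, "is_mtp_layer") else {}
layer = spec(config={}, **kwargs)
```

Note that inspect.signature on a class reports the constructor parameters without self, which is why the membership test on .parameters works directly on the class object.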
No description provided.