[core] Fix mxfp4 model loading #42070
Open
MekkCyber wants to merge 11 commits into fix-bnb from fix-mxfp4
+155 −7
Commits (11)
f0cf8d9  add mxfp4 (MekkCyber)
0364bcb  Merge branch 'fix-bnb' into fix-mxfp4 (MekkCyber)
c4097db  fix missing keys (MekkCyber)
0189b1d  fix dequantize path (MekkCyber)
86082d2  first poc + tests passing (MekkCyber)
ee709ca  style (MekkCyber)
32bec2b  Merge branch 'fix-bnb' into fix-mxfp4 (MekkCyber)
95933aa  add comment (MekkCyber)
501ed80  fix (MekkCyber)
3d7ce14  Merge branch 'fix-bnb' into fix-mxfp4 (MekkCyber)
b8d2409  fix (MekkCyber)
```diff
@@ -12,6 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Optional
+
+from ..core_model_loading import ConversionOps, get_loaded_parameter_class
 from ..utils import is_accelerate_available, is_torch_available, logging
```
```diff
@@ -25,6 +28,8 @@
 import re
 from contextlib import contextmanager

+from ..quantizers.quantizers_utils import get_module_from_name
+
 logger = logging.get_logger(__name__)
```
```diff
@@ -48,6 +53,61 @@
 ]


+class Mxfp4Quantize(ConversionOps):
+    def __init__(self, hf_quantizer):
+        self.hf_quantizer = hf_quantizer
+
+    def convert(
+        self, input_dict: dict[str, torch.Tensor], model: Optional[torch.nn.Module] = None, missing_keys: Optional[list[str]] = None, **kwargs
+    ) -> dict[str, torch.Tensor]:
+        target_key, value = tuple(input_dict.items())[0]
+        value = value[0] if isinstance(value, list) else value
+        if not self.hf_quantizer.pre_quantized:
+            module, _ = get_module_from_name(model, target_key)
+            with torch.device(value.device):
+                if isinstance(module, Mxfp4GptOssExperts):
+                    triton_weight_tensor, weight_scale = quantize_to_mxfp4(value, triton_kernels_hub)
+                    PrecisionConfig, FlexCtx, InFlexData = (
+                        triton_kernels_hub.matmul_ogs.PrecisionConfig,
+                        triton_kernels_hub.matmul_ogs.FlexCtx,
+                        triton_kernels_hub.matmul_ogs.InFlexData,
+                    )
+                    triton_weight_tensor, weight_scale = swizzle_mxfp4(
+                        triton_weight_tensor, weight_scale, triton_kernels_hub
+                    )
+
+                    proj = "gate_up_proj" if "gate_up_proj" in target_key else "down_proj"
+                    setattr(module, proj, triton_weight_tensor)
+                    setattr(
+                        module,
+                        f"{proj}_precision_config",
+                        PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData())),
+                    )
+                    missing_keys.discard(f"{target_key}_blocks")
+                    missing_keys.discard(f"{target_key}_scales")
+                    delattr(module, f"{proj}_blocks")
+                    delattr(module, f"{proj}_scales")
+        else:
+            if ("blocks" in target_key or "scales" in target_key) and self.hf_quantizer.quantization_config.dequantize:
+                # blocks and scales have the same length that's why this works for both
+                module, _ = get_module_from_name(model, target_key[: -len("_blocks")])
+            else:
+                module, _ = get_module_from_name(model, target_key)
+
+            if self.hf_quantizer.quantization_config.dequantize:
+                dequantize_convertops(module, target_key, value, value.device, missing_keys)
+            else:
+                # Eagerly set tensors on the module and perform swizzle; this function will
+                # set the appropriate attributes and remove *_blocks/_scales when both are loaded.
+                load_and_swizzle_mxfp4_convertops(module, target_key, value, value.device, missing_keys, triton_kernels_hub)
+
+        # We return an empty mapping since the module was updated in-place. This prevents
+        # the loader from trying to materialize the original meta-parameter names again.
+        # We don't use set_param_for_module since it expects mainly a torch.nn.Parameter or a safetensors pointer
+        return {}
+
+
 @contextmanager
 def on_device(dev):
     if is_torch_available():
```

Review comment on lines +79 to +89 (Member): since we are setting this here, check what I did for loadedparam. This is to make sure that we don't re-initialize the weights.
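As a hedged aside, the in-place contract behind that `return {}` can be pictured with a tiny, hypothetical caller. This is not the actual transformers loader, and the callback name and signature below are assumptions made only for illustration: an op that has already attached its tensors to the module returns an empty mapping, so the caller has nothing left to set and the weights are not re-initialized.

```python
# Hypothetical sketch of the caller side of `convert` returning {}.
# Not the transformers loader; the callback name/signature is assumed for illustration.
from typing import Callable, Dict

import torch


def apply_converted(
    converted: Dict[str, torch.Tensor],
    set_param_for_module: Callable[[str, torch.Tensor], None],
) -> None:
    # Mxfp4Quantize.convert already called setattr() for the triton weight tensor and
    # its precision config and discarded the *_blocks/_scales keys, so `converted` is
    # empty and nothing is materialized (or re-initialized) here.
    for key, tensor in converted.items():
        set_param_for_module(key, tensor)


apply_converted({}, set_param_for_module=lambda key, tensor: None)  # no-op by design
```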
```diff
@@ -88,13 +148,11 @@ def swizzle_mxfp4(w, w_scale, triton_kernels_hub):
     )
     layout = triton_kernels_hub.tensor_details.layout
     StridedLayout = triton_kernels_hub.tensor_details.layout.StridedLayout

     value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
     w = convert_layout(wrap_torch_tensor(w, dtype=FP4), value_layout, **value_layout_opts)
     w_scale = convert_layout(wrap_torch_tensor(w_scale), StridedLayout)
     return w, w_scale


 # Copied from GPT_OSS repo
 # TODO: Add absolute link when the repo is public
 def convert_moe_packed_tensors(
```
```diff
@@ -355,6 +413,22 @@ def dequantize(module, param_name, param_value, target_device, dq_param_name, **kwargs):
                 delattr(module, blocks_attr)
                 delattr(module, scales_attr)


+def dequantize_convertops(module, param_name, param_value, target_device, missing_keys):
+    for proj in ["gate_up_proj", "down_proj"]:
+        if proj in param_name:
+            blocks_attr = f"{proj}_blocks"
+            scales_attr = f"{proj}_scales"
+            setattr(module, param_name.rsplit(".", 1)[1], param_value)
+            if hasattr(module, blocks_attr) and hasattr(module, scales_attr):
+                dequantized = convert_moe_packed_tensors(getattr(module, blocks_attr), getattr(module, scales_attr))
+                if target_device == "cpu" and torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+                dequantized = torch.nn.Parameter(dequantized.to(target_device))
+                dequantized = get_loaded_parameter_class(dequantized.__class__)(from_existing=dequantized)
+                setattr(module, proj, dequantized)
+                missing_keys.discard(param_name.rsplit("_", 1)[0])
+                delattr(module, blocks_attr)
+                delattr(module, scales_attr)
+
+
 def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, triton_kernels_hub, **kwargs):
     """
```
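For context on what `convert_moe_packed_tensors` undoes, here is an illustrative, self-contained sketch of MXFP4 dequantization. The lookup table, the low-nibble-first packing, and the e8m0 scale bias of 127 are assumptions about the standard MXFP4 layout, not code copied from this file.

```python
# Illustrative MXFP4 dequantization sketch (assumed packing, not this file's helper):
# each uint8 in `blocks` holds two 4-bit e2m1 values, each uint8 in `scales` is a
# shared power-of-two exponent with bias 127 for a block of 32 values.
import torch

FP4_VALUES = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
)

def dequant_mxfp4_sketch(blocks: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    # blocks: uint8 [..., num_blocks, 16]; scales: uint8 [..., num_blocks]
    lo = FP4_VALUES[(blocks & 0x0F).long()]              # low nibble of each byte
    hi = FP4_VALUES[(blocks >> 4).long()]                # high nibble of each byte
    values = torch.stack((lo, hi), dim=-1).flatten(-2)   # [..., num_blocks, 32]
    exponent = scales.float() - 127.0                    # e8m0 scale, bias 127
    return values * torch.exp2(exponent).unsqueeze(-1)

packed = torch.randint(0, 256, (2, 4, 16), dtype=torch.uint8)  # 2 rows, 4 blocks of 32 values
scale = torch.full((2, 4), 127, dtype=torch.uint8)             # exponent 0 -> scale 1.0
print(dequant_mxfp4_sketch(packed, scale).shape)               # torch.Size([2, 4, 32])
```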
```diff
@@ -423,6 +497,68 @@ def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, triton_kernels_hub, **kwargs):
         del blocks


+def load_and_swizzle_mxfp4_convertops(module, param_name, param_value, target_device, missing_keys, triton_kernels_hub):
+    """
+    This transforms the weights obtained using `convert_gpt_oss.py` to load them into `Mxfp4GptOssExperts`.
+    """
+    PrecisionConfig, FlexCtx, InFlexData = (
+        triton_kernels_hub.matmul_ogs.PrecisionConfig,
+        triton_kernels_hub.matmul_ogs.FlexCtx,
+        triton_kernels_hub.matmul_ogs.InFlexData,
+    )
+
+    if "blocks" in param_name:
+        proj = param_name.split(".")[-1].split("_blocks")[0]
+    if "scales" in param_name:
+        proj = param_name.split(".")[-1].split("_scales")[0]
+
+    setattr(module, param_name.rsplit(".", 1)[1], torch.nn.Parameter(param_value, requires_grad=False))
+    missing_keys.discard(param_name)
+
+    blocks_attr = f"{proj}_blocks"
+    scales_attr = f"{proj}_scales"
+    blocks = getattr(module, blocks_attr)  # at this point values were loaded from ckpt
+    scales = getattr(module, scales_attr)
+
+    # check if blocks or scales are not on meta device
+    # if so, it means param_value is either a blocks or scales tensor
+    # and the other blocks or scales tensor is on the correct device
+    if blocks.device.type != "meta" and scales.device.type != "meta":
+        local_experts = blocks.size(0)
+        if blocks.device.type == "meta":
+            blocks = param_value
+        elif scales.device.type == "meta":
+            scales = param_value
+
+        if proj == "gate_up_proj":
+            blocks = blocks.reshape(local_experts, module.intermediate_size * 2, -1)
+        else:
+            blocks = blocks.reshape(local_experts, -1, module.intermediate_size // 2)
+        if getattr(target_device, "type", target_device) == "cpu":
+            target_device = "cuda"
+
+        blocks = blocks.to(target_device).contiguous()
+        scales = scales.to(target_device).contiguous()
+        with on_device(target_device):
+            triton_weight_tensor, weight_scale = swizzle_mxfp4(
+                blocks.transpose(-2, -1), scales.transpose(-2, -1), triton_kernels_hub
+            )
+            # need to overwrite the shapes for the kernels
+            if proj == "gate_up_proj":
+                triton_weight_tensor.shape = torch.Size([local_experts, module.hidden_size, module.intermediate_size * 2])
+            else:
+                triton_weight_tensor.shape = torch.Size([local_experts, module.intermediate_size, module.hidden_size])
+
+        # triton_weight_tensor is what needs to be passed in oai kernels. It stores the data, the shapes and any more objects. It's like a subtensor
+        setattr(module, proj, triton_weight_tensor)
+        setattr(module, f"{proj}_precision_config", PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData())))
+        delattr(module, scales_attr)
+        delattr(module, blocks_attr)
+        del blocks
+        del scales
+
+
 def _replace_with_mxfp4_linear(
     model,
     modules_to_not_convert=None,
```
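To make the meta-device gating concrete, here is a minimal toy (class and helper names invented for illustration, not part of the PR) that reproduces the "act only once both halves are materialized" check; the real function additionally reshapes the blocks, swizzles them via the triton kernels hub, and attaches a precision config.

```python
# Minimal, self-contained illustration (assumed names, toy shapes) of the gating above:
# blocks and scales start on the "meta" device, each load call materializes one half,
# and the expensive swizzle should only run once neither half is still "meta".
import torch


class ToyExperts(torch.nn.Module):
    def __init__(self):
        super().__init__()
        with torch.device("meta"):
            self.gate_up_proj_blocks = torch.nn.Parameter(torch.empty(2, 4, 16), requires_grad=False)
            self.gate_up_proj_scales = torch.nn.Parameter(torch.empty(2, 4), requires_grad=False)


def toy_load(module: torch.nn.Module, name: str, value: torch.Tensor) -> bool:
    # Mirrors the start of load_and_swizzle_mxfp4_convertops: attach the incoming tensor,
    # then report "ready to swizzle" only when neither half remains on the meta device.
    setattr(module, name, torch.nn.Parameter(value, requires_grad=False))
    blocks, scales = module.gate_up_proj_blocks, module.gate_up_proj_scales
    return blocks.device.type != "meta" and scales.device.type != "meta"


m = ToyExperts()
print(toy_load(m, "gate_up_proj_blocks", torch.zeros(2, 4, 16, dtype=torch.uint8)))  # False: scales still meta
print(toy_load(m, "gate_up_proj_scales", torch.zeros(2, 4, dtype=torch.uint8)))      # True: both real, swizzle now
```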
Review comment: Mxfp4 should not even need to delete anything AFAIK, because it just uses the blocks and scales!