6 changes: 2 additions & 4 deletions auto_round/compressors/base.py
@@ -1195,9 +1195,6 @@ def _quant_rtn_with_imatrix(self, all_to_quantized_module_names: list[str]) -> N
# Load dataset
from auto_round.calib_dataset import get_dataloader

if _is_fp8_model(self.model):
convert_fp8_model_to_16b_model(self.model, self.amp_dtype)

if isinstance(self.dataset, str):
if self.tokenizer is None:
raise ValueError("A tokenizer must be set for the model when using a dataset string.")
@@ -1244,6 +1241,8 @@ def get_imatrix_hook(module, input, output):
dispatch_model(self.model, self.model.hf_device_map)
else:
model = model.to(self.device)
if _is_fp8_model(self.model):
convert_fp8_model_to_16b_model(self.model, self.amp_dtype)
cnt = 0

# Run forward pass to accumulate imatrix
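
Note on the two hunks above: they move the FP8-to-16-bit conversion so it runs after the model has been dispatched or moved to its target device, instead of before the calibration dataset is loaded. A minimal sketch of the resulting order, using simplified stand-ins for the auto_round helpers (not the real implementations):

import torch
import torch.nn as nn

def _is_fp8_model(model: nn.Module) -> bool:
    # Simplified stand-in: the real helper checks for FP8 layer classes by name.
    return any(m.__class__.__name__ == "FP8Linear" for m in model.modules())

def convert_fp8_model_to_16b_model(model: nn.Module, dtype=torch.bfloat16) -> nn.Module:
    # Simplified stand-in: the real helper rebuilds each FP8Linear as a plain
    # nn.Linear with dequantized 16-bit weights (see auto_round/utils.py below).
    return model.to(dtype)

def prepare_for_imatrix(model: nn.Module, device, amp_dtype=torch.bfloat16) -> nn.Module:
    model = model.to(device)                      # 1) device placement first
    if _is_fp8_model(model):                      # 2) then widen FP8 weights to 16-bit
        convert_fp8_model_to_16b_model(model, amp_dtype)
    return model
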
@@ -1422,7 +1421,6 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
"""
m = get_module(self.model, name)

# if m.__class__.__name__ == "FP8Linear":
if _is_fp8_linear(m):
m = convert_fp8_layer_to_linear(m, self.amp_dtype)
set_module(self.model, name, m)
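
For the per-layer path above, an FP8 layer is dequantized to a regular nn.Linear right before RTN quantization. A rough sketch of what such a conversion can look like, assuming a hypothetical FP8 layer that exposes weight and a per-tensor weight_scale; the real convert_fp8_layer_to_linear handles the actual FP8Linear layout (e.g. block-wise scales):

import torch
import torch.nn as nn

def convert_fp8_layer_to_linear_sketch(m: nn.Module, dtype=torch.bfloat16) -> nn.Linear:
    # Hypothetical dequantization: widen the FP8 weight and fold in its scale.
    weight = m.weight.to(dtype) * m.weight_scale.to(dtype)
    bias = getattr(m, "bias", None)
    new_m = nn.Linear(weight.shape[1], weight.shape[0], bias=bias is not None, dtype=dtype)
    new_m.weight.data.copy_(weight)
    if bias is not None:
        new_m.bias.data.copy_(bias.to(dtype))
    return new_m
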
5 changes: 4 additions & 1 deletion auto_round/export/export_to_gguf/convert.py
@@ -394,7 +394,10 @@ def prepare_tensors(cls):
clean_weight_list = []

modify_name = _special_name_handle(cls, name)
orig_device = data_torch.device
data_torch = data_torch.to("cpu")
for new_name, data_torch in cls.modify_tensors(data_torch, modify_name, bid):
data_torch.to(orig_device)
skip = False
for tensor_info in cls.gguf_writer.tensors:
if new_name in tensor_info:
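
The addition above records the original device and moves each tensor to CPU before modify_tensors runs. As a general pattern (keeping in mind that Tensor.to() returns a new tensor rather than modifying in place, so its result must be assigned to take effect), a sketch with a hypothetical modify callable:

import torch

def modify_on_cpu(data_torch: torch.Tensor, modify):
    # Do the tensor rewrites on CPU, then restore the original device per result.
    orig_device = data_torch.device
    data_torch = data_torch.to("cpu")
    out = []
    for new_name, t in modify(data_torch):
        out.append((new_name, t.to(orig_device)))  # .to() is not in-place; keep the result
    return out
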
@@ -545,7 +548,7 @@ def prepare_tensors(cls):
attr_tensor = getattr(module, attr)
else:
attr_tensor = getattr(module, "w_" + attr)
if attr_tensor is None:
if attr_tensor is None or not isinstance(attr_tensor, torch.Tensor):
continue
kv_b = attr_tensor.view(n_head_kv, v_head_dim + qk_nope_head_dim, -1)
k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
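
The guard added above skips attributes that are missing or not tensors before the kv_b reshape. A toy example of the view/split that follows, with made-up MLA-style dimensions purely for illustration:

import torch

n_head_kv, qk_nope_head_dim, v_head_dim, kv_lora_rank = 4, 8, 6, 16  # made-up sizes
attr_tensor = torch.randn(n_head_kv * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

if attr_tensor is not None and isinstance(attr_tensor, torch.Tensor):
    kv_b = attr_tensor.view(n_head_kv, v_head_dim + qk_nope_head_dim, -1)
    k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
    # k_b: (n_head_kv, qk_nope_head_dim, kv_lora_rank)
    # v_b: (n_head_kv, v_head_dim, kv_lora_rank)
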
2 changes: 2 additions & 0 deletions auto_round/export/export_to_gguf/convert_hf_to_gguf.py
@@ -41,9 +41,11 @@

if "NO_LOCAL_GGUF" not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / "gguf-py"))

from auto_round.utils import LazyImport

gguf = LazyImport("gguf")

MistralTokenizerType = LazyImport("gguf.vocab.MistralTokenizerType")
MistralVocab = LazyImport("gguf.vocab.MistralVocab")
DATASET_MEAN = LazyImport("mistral_common.tokens.tokenizers.multimodal.DATASET_MEAN")
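
With this change, gguf and the Mistral tokenizer symbols are resolved through LazyImport, so the underlying modules are only imported when first used. A minimal sketch of the lazy-import idea; the real auto_round.utils.LazyImport may differ in detail:

import importlib

class LazyImport:
    # Minimal sketch: defer importing until an attribute is accessed or the
    # object is called. Supports dotted paths like "gguf.vocab.MistralVocab".
    def __init__(self, module_name: str):
        self.module_name = module_name
        self._target = None

    def _load(self):
        if self._target is None:
            parts = self.module_name.split(".")
            for i in range(len(parts), 0, -1):
                try:
                    target = importlib.import_module(".".join(parts[:i]))
                except ImportError:
                    continue
                for attr in parts[i:]:
                    target = getattr(target, attr)
                self._target = target
                break
        return self._target

    def __getattr__(self, name):
        return getattr(self._load(), name)

    def __call__(self, *args, **kwargs):
        return self._load()(*args, **kwargs)

For example, gguf = LazyImport("gguf") creates only a thin proxy; the gguf package itself is imported the first time one of its attributes is touched.
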
8 changes: 5 additions & 3 deletions auto_round/utils.py
@@ -79,8 +79,7 @@ def __getitem__(self, key):

# Changed to str as it relies on triton or other libs to load this
INNER_SUPPORTED_LAYER_TYPES = ("FP8Linear",)
# INNER_SUPPORTED_LAYER_TYPES = (transformers.integrations.finegrained_fp8.FP8Linear,)

# transformers.integrations.finegrained_fp8.FP8Linear
if deepspeed_exists:
from deepspeed.module_inject import LinearAllreduce, LinearLayer

@@ -2298,10 +2297,14 @@ def convert_fp8_model_to_16b_model(model, dtype=torch.bfloat16):
Convert a model with FP8 quantized layers to a model with 16-bit linear layers.
This is useful for compatibility with other frameworks or for further processing.
"""
cnt = 0
for n, m in model.named_modules():
if m.__class__.__name__ == "FP8Linear":
new_module = convert_fp8_layer_to_linear(m, dtype=dtype)
set_module(model, n, new_module)
cnt += 1
if cnt % 10 == 0: # Tricky setting
clear_memory()
return model
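
The counter added above triggers a memory cleanup every 10 layer conversions, since rebuilding many FP8 layers as 16-bit linears can spike memory. A sketch of the pattern with a simplified clear_memory stand-in (the real helper lives in auto_round):

import gc
import torch
import torch.nn as nn

def clear_memory():
    # Simplified stand-in for auto_round's clear_memory helper.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def convert_fp8_layers(model: nn.Module, convert_one, clear_every: int = 10) -> nn.Module:
    cnt = 0
    for name, module in model.named_modules():
        if module.__class__.__name__ == "FP8Linear":
            convert_one(model, name, module)   # e.g. replace with a 16-bit nn.Linear
            cnt += 1
            if cnt % clear_every == 0:         # periodically release cached blocks
                clear_memory()
    return model
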


@@ -2344,7 +2347,6 @@ def download_hf_model(repo_id, cache_dir=None, repo_type=None, revision=None):
"""Download hugging face model from hf hub."""
from huggingface_hub.constants import DEFAULT_REVISION, HUGGINGFACE_HUB_CACHE
from huggingface_hub.file_download import REGEX_COMMIT_HASH, repo_folder_name
from huggingface_hub.utils import EntryNotFoundError

if cache_dir is None:
cache_dir = HUGGINGFACE_HUB_CACHE