diff --git a/generate.py b/generate.py
index 34f9625d..7f30de0a 100644
--- a/generate.py
+++ b/generate.py
@@ -219,7 +219,7 @@ def _load_model(checkpoint_path, device, precision, use_tp):
         simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
         model = simple_quantizer.convert_for_runtime()
 
-    checkpoint = torch.load(str(checkpoint_path), mmap=True)
+    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
     model.load_state_dict(checkpoint, assign=True)
 
     if use_tp:
diff --git a/quantize.py b/quantize.py
index a27a7cab..58a5ec72 100644
--- a/quantize.py
+++ b/quantize.py
@@ -540,7 +540,7 @@ def quantize(
     with torch.device('meta'):
         model = Transformer.from_name(checkpoint_path.parent.name)
 
-    checkpoint = torch.load(str(checkpoint_path), mmap=True)
+    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
     model.load_state_dict(checkpoint, assign=True)
     model = model.to(dtype=precision, device=device)
 
diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py
index 34475791..4aa5265c 100644
--- a/scripts/convert_hf_checkpoint.py
+++ b/scripts/convert_hf_checkpoint.py
@@ -64,7 +64,7 @@ def permute(w, n_head):
 
     merged_result = {}
     for file in sorted(bin_files):
-        state_dict = torch.load(str(file), map_location="cpu", mmap=True)
+        state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
         merged_result.update(state_dict)
     final_result = {}
     for key, value in merged_result.items():
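
For review context, here is a minimal, self-contained sketch of what the new flag changes (toy tensor and a hypothetical `/tmp` path, not code from this repo). `torch.load` defaults to full pickle deserialization, so a crafted checkpoint can execute arbitrary code at load time; `weights_only=True` restricts unpickling to tensors, primitive types, and containers of them, which is all a plain state_dict needs.

```python
import torch

# Save a toy state_dict to a hypothetical path (illustration only).
state_dict = {"w": torch.randn(4, 4)}
torch.save(state_dict, "/tmp/ckpt.pt")

# Same call pattern as the diff: memory-map the checkpoint and allow
# only weight-like objects (tensors, primitives, containers) through
# the unpickler, so a malicious pickle payload cannot execute code.
loaded = torch.load("/tmp/ckpt.pt", mmap=True, weights_only=True)
print(loaded["w"].shape)  # torch.Size([4, 4])
```

Since all three call sites load plain state_dicts, the restricted unpickler loses nothing here; a checkpoint carrying arbitrary pickled objects would now raise an error at load time instead of silently running them.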