From f5056f71aab113b23cade07bbc9a2d0fc6cfa022 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Thu, 30 Nov 2023 15:03:55 -0800 Subject: [PATCH] Use weights_only for load --- generate.py | 2 +- quantize.py | 2 +- scripts/convert_hf_checkpoint.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/generate.py b/generate.py index 34f9625d..7f30de0a 100644 --- a/generate.py +++ b/generate.py @@ -219,7 +219,7 @@ def _load_model(checkpoint_path, device, precision, use_tp): simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize) model = simple_quantizer.convert_for_runtime() - checkpoint = torch.load(str(checkpoint_path), mmap=True) + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) model.load_state_dict(checkpoint, assign=True) if use_tp: diff --git a/quantize.py b/quantize.py index a27a7cab..58a5ec72 100644 --- a/quantize.py +++ b/quantize.py @@ -540,7 +540,7 @@ def quantize( with torch.device('meta'): model = Transformer.from_name(checkpoint_path.parent.name) - checkpoint = torch.load(str(checkpoint_path), mmap=True) + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) model.load_state_dict(checkpoint, assign=True) model = model.to(dtype=precision, device=device) diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py index 34475791..4aa5265c 100644 --- a/scripts/convert_hf_checkpoint.py +++ b/scripts/convert_hf_checkpoint.py @@ -64,7 +64,7 @@ def permute(w, n_head): merged_result = {} for file in sorted(bin_files): - state_dict = torch.load(str(file), map_location="cpu", mmap=True) + state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True) merged_result.update(state_dict) final_result = {} for key, value in merged_result.items():