Bin format tests #27242

Closed
199 changes: 199 additions & 0 deletions tests/test_modeling_common.py
@@ -76,6 +76,7 @@
from transformers.utils import (
CONFIG_NAME,
GENERATION_CONFIG_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_NAME,
is_accelerate_available,
is_flax_available,
@@ -91,6 +92,7 @@

if is_torch_available():
import torch
from safetensors.torch import load_file as safe_load_file
from safetensors.torch import save_file as safe_save_file
from torch import nn

@@ -265,6 +267,46 @@ def check_save_load(out1, out2):
else:
check_save_load(first, second)

def test_save_load_bin_format(self):
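# Same checks as `test_save_load`, but the checkpoint is written with `safe_serialization=False`
# so the legacy pickle-based `pytorch_model.bin` path is exercised instead of safetensors.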
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

def check_save_load(out1, out2):
# make sure we don't have nans
out_2 = out2.cpu().numpy()
out_2[np.isnan(out_2)] = 0

out_1 = out1.cpu().numpy()
out_1[np.isnan(out_1)] = 0
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)

for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
first = model(**self._prepare_for_class(inputs_dict, model_class))[0]

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, safe_serialization=False)

# the config file (and the generation config file, if it can generate) should be saved
self.assertTrue(os.path.exists(os.path.join(tmpdirname, CONFIG_NAME)))
self.assertEqual(
model.can_generate(), os.path.exists(os.path.join(tmpdirname, GENERATION_CONFIG_NAME))
)

model = model_class.from_pretrained(tmpdirname)
model.to(torch_device)
with torch.no_grad():
second = model(**self._prepare_for_class(inputs_dict, model_class))[0]

if isinstance(first, tuple) and isinstance(second, tuple):
for tensor1, tensor2 in zip(first, second):
check_save_load(tensor1, tensor2)
else:
check_save_load(first, second)

def test_from_pretrained_no_checkpoint(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
@@ -295,6 +337,24 @@ def test_keep_in_fp32_modules(self):
else:
self.assertTrue(param.dtype == torch.float16, name)

def test_keep_in_fp32_modules_bin_format(self):
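# Mirrors `test_keep_in_fp32_modules` for the legacy `.bin` format: modules listed in
# `_keep_in_fp32_modules` must stay in float32 even when the rest of the model is loaded in float16.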
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
if model_class._keep_in_fp32_modules is None:
return

model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, safe_serialization=False)

model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16)

for name, param in model.named_parameters():
if any(n in model_class._keep_in_fp32_modules for n in name.split(".")):
self.assertTrue(param.dtype == torch.float32)
else:
self.assertTrue(param.dtype == torch.float16, name)

def test_save_load_keys_to_ignore_on_save(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

@@ -324,6 +384,38 @@ def test_save_load_keys_to_ignore_on_save(self):
)
self.assertTrue(len(load_result.unexpected_keys) == 0)

def test_save_load_keys_to_ignore_on_save_bin_format(self):
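# Mirrors `test_save_load_keys_to_ignore_on_save` for the legacy `.bin` format: keys listed in
# `_keys_to_ignore_on_save` must not end up in the serialized state dict.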
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

for model_class in self.all_model_classes:
model = model_class(config)
_keys_to_ignore_on_save = getattr(model, "_keys_to_ignore_on_save", None)
if _keys_to_ignore_on_save is None:
continue

# check the keys are in the original state_dict
for k in _keys_to_ignore_on_save:
self.assertIn(k, model.state_dict().keys(), "\n".join(model.state_dict().keys()))

# check that certain keys didn't get saved with the model
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, safe_serialization=False)
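# with safe_serialization=False the weights are written to `pytorch_model.bin` (WEIGHTS_NAME),
# so read the raw state dict back with `torch.load`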
output_model_file = os.path.join(tmpdirname, WEIGHTS_NAME)
state_dict_saved = torch.load(output_model_file)

for k in _keys_to_ignore_on_save:
self.assertNotIn(k, state_dict_saved.keys(), "\n".join(state_dict_saved.keys()))

# Test we can load the state dict in the model, necessary for the checkpointing API in Trainer.
load_result = model.load_state_dict(state_dict_saved, strict=False)
keys = set(model._keys_to_ignore_on_save)

if hasattr(model, "_tied_weights_keys"):
keys.update(set(model._tied_weights_keys))

self.assertTrue(len(load_result.missing_keys) == 0 or set(load_result.missing_keys) == keys)
self.assertTrue(len(load_result.unexpected_keys) == 0)

def test_gradient_checkpointing_backward_compatibility(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

@@ -420,6 +512,58 @@ class CopyClass(model_class):
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

def test_save_load_fast_init_from_base_bin_format(self):
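# `.bin` variant of `test_save_load_fast_init_from_base`: a head model class is loaded from a
# base-model checkpoint that is missing one key, and fast init must produce the same weights
# as loading with `_fast_init=False`.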
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if config.__class__ not in MODEL_MAPPING:
return
base_class = MODEL_MAPPING[config.__class__]

if isinstance(base_class, tuple):
base_class = base_class[0]

for model_class in self.all_model_classes:
if model_class == base_class:
continue

# make a copy of model class to not break future tests
# from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class
class CopyClass(model_class):
pass

model_class_copy = CopyClass

# make sure that all keys are expected for test
model_class_copy._keys_to_ignore_on_load_missing = []

# make init deterministic, but make sure that
# non-initialized weights throw errors nevertheless
model_class_copy._init_weights = _mock_init_weights
model_class_copy.init_weights = _mock_all_init_weights

model = base_class(config)
state_dict = model.state_dict()

# this will often delete a single weight of a multi-weight module
# to test an edge case
random_key_to_del = random.choice(list(state_dict.keys()))
del state_dict[random_key_to_del]

# save a checkpoint whose state dict is missing the deleted key, then reload it
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, safe_serialization=False)
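# `save_pretrained` with safe_serialization=False already wrote `pytorch_model.bin`; overwrite it
# with the pruned state dict so the missing-key code path is exercised on a pickle checkpoint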
torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))

model_fast_init = model_class_copy.from_pretrained(tmpdirname)
model_slow_init = model_class_copy.from_pretrained(tmpdirname, _fast_init=False)
# compare the state dicts produced by fast init and slow init

for key in model_fast_init.state_dict().keys():
if isinstance(model_slow_init.state_dict()[key], torch.BoolTensor):
max_diff = (model_slow_init.state_dict()[key] ^ model_fast_init.state_dict()[key]).sum().item()
else:
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

def test_save_load_fast_init_to_base(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if config.__class__ not in MODEL_MAPPING:
@@ -475,6 +619,61 @@ class CopyClass(base_class):
).item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

def test_save_load_fast_init_to_base_bin_format(self):
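# `.bin` variant of `test_save_load_fast_init_to_base`: here the checkpoint comes from a head
# model and is loaded back into the base class, again with one randomly deleted key.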
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if config.__class__ not in MODEL_MAPPING:
return
base_class = MODEL_MAPPING[config.__class__]

if isinstance(base_class, tuple):
base_class = base_class[0]

for model_class in self.all_model_classes:
if model_class == base_class:
continue

# make a copy of model class to not break future tests
# from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class
class CopyClass(base_class):
pass

base_class_copy = CopyClass

# make sure that all keys are expected for test
base_class_copy._keys_to_ignore_on_load_missing = []

# make init deterministic, but make sure that
# non-initialized weights throw errors nevertheless
base_class_copy._init_weights = _mock_init_weights
base_class_copy.init_weights = _mock_all_init_weights

model = model_class(config)
state_dict = model.state_dict()

# this will often delete a single weight of a multi-weight module
# to test an edge case
random_key_to_del = random.choice(list(state_dict.keys()))
del state_dict[random_key_to_del]

# save a checkpoint whose state dict is missing the deleted key, then reload it
with tempfile.TemporaryDirectory() as tmpdirname:
model.config.save_pretrained(tmpdirname)
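# only the config is written by `save_pretrained` here; the checkpoint itself is the pruned
# state dict saved below as a pickle file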
torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))

model_fast_init = base_class_copy.from_pretrained(tmpdirname)
model_slow_init = base_class_copy.from_pretrained(tmpdirname, _fast_init=False)

for key in model_fast_init.state_dict().keys():
if isinstance(model_slow_init.state_dict()[key], torch.BoolTensor):
max_diff = torch.max(
model_slow_init.state_dict()[key] ^ model_fast_init.state_dict()[key]
).item()
else:
max_diff = torch.max(
torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key])
).item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

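Note: since these tests live in ModelTesterMixin, they are collected through each model's own test file rather than through tests/test_modeling_common.py directly. To run only the new legacy-format tests against a single model locally, a pytest keyword filter should work, for example (assuming the usual transformers development setup):

    python -m pytest tests/models/bert/test_modeling_bert.py -k "bin_format" -v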