diff --git a/README.md b/README.md
index 00674fb442f..2e6cd40a9c8 100644
--- a/README.md
+++ b/README.md
@@ -245,6 +245,7 @@ Here is an example of Accuracy Aware Quantization pipeline where model weights a
 ```python
 import nncf
+import nncf.torch
 import torch
 from torchvision import datasets, models
@@ -271,7 +272,7 @@ quantized_model = nncf.quantize(model, calibration_dataset)
 # Save quantization modules and the quantized model parameters
 checkpoint = {
     'state_dict': model.state_dict(),
-    'nncf_config': model.nncf.get_config(),
+    'nncf_config': nncf.torch.get_config(model),
     ... # the rest of the user-defined objects to save
 }
 torch.save(checkpoint, path_to_checkpoint)
diff --git a/docs/usage/training_time_compression/quantization_aware_training/Usage.md b/docs/usage/training_time_compression/quantization_aware_training/Usage.md
index 79033ac09a6..aed7dfafba2 100644
--- a/docs/usage/training_time_compression/quantization_aware_training/Usage.md
+++ b/docs/usage/training_time_compression/quantization_aware_training/Usage.md
@@ -75,17 +75,19 @@ ov_quantized_model = ov.convert_model(stripped_model)
 The complete information about compression is defined by a compressed model and a NNCF config.
-The model characterizes the weights and topology of the network. The NNCF config - how to restore additional modules introduced by NNCF.
-The NNCF config can be obtained by `quantized_model.nncf.get_config()` on saving and passed to the
+The model characterizes the weights and topology of the network. The NNCF config describes how to restore additional modules introduced by NNCF.
+The NNCF config can be obtained by calling `nncf.torch.get_config` on saving and passed to the
 `nncf.torch.load_from_config` helper function to load additional modules from the given NNCF config.
-The quantized model saving allows to load quantized modules to the target model in a new python process and
-requires only example input for the target module, corresponding NNCF config and the quantized model state dict.
+Saving the quantized model makes it possible to load the quantized modules into the target model in a new Python process;
+this requires only an example input for the target module, the corresponding NNCF config, and the quantized model state dict.

 ```python
+import nncf.torch
+
 # save part
 quantized_model = nncf.quantize(model, calibration_dataset)
 checkpoint = {
-    'state_dict':quantized_model.state_dict(),
-    'nncf_config': quantized_model.nncf.get_config(),
+    'state_dict': quantized_model.state_dict(),
+    'nncf_config': nncf.torch.get_config(quantized_model),
     ...
 }
 torch.save(checkpoint, path)
diff --git a/examples/quantization_aware_training/torch/resnet18/main.py b/examples/quantization_aware_training/torch/resnet18/main.py
index a8a96d041c3..e1bdec21b27 100644
--- a/examples/quantization_aware_training/torch/resnet18/main.py
+++ b/examples/quantization_aware_training/torch/resnet18/main.py
@@ -278,11 +278,11 @@ def transform_fn(data_item):
     print(f"Train epoch: {epoch}")
     train_epoch(train_loader, quantized_model, criterion, optimizer, device=device)
     acc1_int8 = validate(val_loader, quantized_model, device)
-    print(f"Accyracy@1 of INT8 model after {epoch} epoch finetuning: {acc1_int8:.3f}")
+    print(f"Accuracy@1 of INT8 model after {epoch} epoch finetuning: {acc1_int8:.3f}")
     # Save the compression checkpoint for model with the best accuracy metric.
     if acc1_int8 > acc1_int8_best:
         state_dict = quantized_model.state_dict()
-        compression_config = quantized_model.nncf.get_config()
+        compression_config = nncf.torch.get_config(quantized_model)
         torch.save(
             {
                 "model_state_dict": state_dict,
diff --git a/nncf/torch/__init__.py b/nncf/torch/__init__.py
index 4627b8d3f8f..d8cf9aee6a6 100644
--- a/nncf/torch/__init__.py
+++ b/nncf/torch/__init__.py
@@ -50,6 +50,7 @@
 from nncf.torch.model_creation import is_wrapped_model
 from nncf.torch.model_creation import wrap_model
 from nncf.torch.model_creation import load_from_config
+from nncf.torch.model_creation import get_config
 from nncf.torch.checkpoint_loading import load_state
 from nncf.torch.initialization import register_default_init_args
 from nncf.torch.layers import register_module
diff --git a/nncf/torch/model_creation.py b/nncf/torch/model_creation.py
index e64c8b48818..15bab5748b3 100644
--- a/nncf/torch/model_creation.py
+++ b/nncf/torch/model_creation.py
@@ -28,6 +28,8 @@
 from nncf.config.extractors import has_input_info_field
 from nncf.config.telemetry_extractors import CompressionStartedFromConfig
 from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled
+from nncf.experimental.torch2.function_hook.serialization import get_config as pt2_get_config
+from nncf.experimental.torch2.function_hook.serialization import load_from_config as pt2_load_from_config
 from nncf.telemetry import tracked_function
 from nncf.telemetry.events import NNCF_PT_CATEGORY
 from nncf.telemetry.extractors import FunctionCallTelemetryExtractor
@@ -397,18 +399,43 @@ def is_wrapped_model(model: Any) -> bool:
         FunctionCallTelemetryExtractor("nncf.torch.load_from_config"),
     ],
 )
-def load_from_config(model: torch.nn.Module, config: Dict[str, Any], example_input: Any) -> NNCFNetwork:
+def load_from_config(model: Module, config: Dict[str, Any], example_input: Optional[Any] = None) -> Module:
     """
-    Wraps given model to a NNCFNetwork and recovers additional modules from given NNCFNetwork config.
-    Does not recover additional modules weights as they are located in a corresponded state_dict.
+    Wraps the given model and recovers additional modules from the given config.
+    Does not recover additional module weights, as they are located in the corresponding state_dict.

     :param model: PyTorch model.
-    :param config: NNCNetwork config.
+    :param config: NNCFNetwork config.
     :param example_input: An example input that will be used for model tracing. A tuple is interpreted
         as an example input of a set of non keyword arguments, and a dict as an example input of a set
-        of keywords arguments.
-    :return: NNCFNetwork builded from given model with additional modules recovered from given NNCFNetwork config.
+        of keyword arguments. Required when the legacy tracing mode is enabled.
+    :return: Wrapped model with additional modules recovered from the given config.
     """
+    if is_experimental_torch_tracing_enabled():
+        return pt2_load_from_config(model, config)
+
+    if example_input is None:
+        msg = "The 'example_input' parameter must be specified."
+        raise nncf.InternalError(msg)
+
     nncf_network = wrap_model(model, example_input, trace_parameters=config[NNCFNetwork.TRACE_PARAMETERS_KEY])
     transformation_layout = deserialize_transformations(config)
     return PTModelTransformer(nncf_network).transform(transformation_layout)
+
+
+@tracked_function(
+    NNCF_PT_CATEGORY,
+    [
+        FunctionCallTelemetryExtractor("nncf.torch.get_config"),
+    ],
+)
+def get_config(model: Module) -> Dict[str, Any]:
+    """
+    Returns the configuration object of the compressed model.
+
+    :param model: The compressed model.
+    :return: The configuration object of the compressed model.
+    """
+    if is_experimental_torch_tracing_enabled():
+        return pt2_get_config(model)
+    return model.nncf.get_config()
diff --git a/tests/torch2/function_hook/test_serialization.py b/tests/torch2/function_hook/test_serialization.py
index 901e45dc8b0..4b72dff89e7 100644
--- a/tests/torch2/function_hook/test_serialization.py
+++ b/tests/torch2/function_hook/test_serialization.py
@@ -19,8 +19,8 @@
 from nncf.experimental.torch2.function_hook import register_post_function_hook
 from nncf.experimental.torch2.function_hook import register_pre_function_hook
 from nncf.experimental.torch2.function_hook import wrap_model
-from nncf.experimental.torch2.function_hook.serialization import get_config
-from nncf.experimental.torch2.function_hook.serialization import load_from_config
+from nncf.torch import get_config
+from nncf.torch import load_from_config

 from tests.torch2.function_hook.helpers import HookWithState
 from tests.torch2.function_hook.helpers import SimpleModel