Support PyTorch FX diagnosis in Neural Insights (#1190)
Signed-off-by: bmyrcha <bartosz.myrcha@intel.com>
Co-authored-by: Cheng, Penghui <penghui.cheng@intel.com>
Co-authored-by: Agata Radys <agata.radys@intel.com>
3 people committed Sep 12, 2023
1 parent 8d2bf27 commit 74a785e
Showing 52 changed files with 883 additions and 262 deletions.
2 changes: 1 addition & 1 deletion .azure-pipelines/scripts/codeScan/pylint/pylint.sh
@@ -52,7 +52,7 @@ elif [ "${scan_module}" = "neural_insights" ]; then
fi

python -m pylint -f json --disable=R,C,W,E1129 --enable=line-too-long --max-line-length=120 --extension-pkg-whitelist=numpy --ignored-classes=TensorProto,NodeProto \
--ignored-modules=tensorflow,torch,torch.quantization,torch.tensor,torchvision,fairseq,mxnet,onnx,onnxruntime,intel_extension_for_pytorch,intel_extension_for_tensorflow /neural-compressor/${scan_module} \
--ignored-modules=tensorflow,torch,torch.quantization,torch.tensor,torchvision,fairseq,mxnet,onnx,onnxruntime,intel_extension_for_pytorch,intel_extension_for_tensorflow,torchinfo /neural-compressor/${scan_module} \
>$log_dir/pylint.json

exit_code=$?
6 changes: 4 additions & 2 deletions neural_compressor/adaptor/pytorch.py
@@ -968,8 +968,10 @@ def eval_func(self, model, dataloader, postprocess, metrics, measurer, iteration
for idx, input in enumerate(dataloader):
if isinstance(input, dict) or isinstance(input, UserDict):
if not self.benchmark:
assert "label" in input, "The dataloader must include label to measure the metric!"
label = input["label"].to("cpu")
assert (
"label" in input or "labels" in input
), "The dataloader must include label to measure the metric!"
label = input["label"].to("cpu") if "label" in input else input["labels"].to("cpu")
elif not self.benchmark:
assert False, "The dataloader must include label to measure the metric!"

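The hunk above widens metric collection to accept HuggingFace-style batches keyed "labels" in addition to the original "label". A minimal sketch of the two batch shapes that now pass the check (tensor contents are illustrative placeholders):

    import torch

    # Either key now satisfies the assertion in eval_func.
    batch_with_label = {"input_ids": torch.randint(0, 100, (1, 8)), "label": torch.tensor([1])}
    batch_with_labels = {"input_ids": torch.randint(0, 100, (1, 8)), "labels": torch.tensor([1])}

    for batch in (batch_with_label, batch_with_labels):
        label = batch["label"].to("cpu") if "label" in batch else batch["labels"].to("cpu")
        print(label)  # tensor([1]) in both cases
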
30 changes: 17 additions & 13 deletions neural_compressor/benchmark.py
@@ -518,19 +518,23 @@ def fit(model, conf, b_dataloader=None, b_func=None):
set_env_var("NC_ENV_CONF", True, overwrite_existing=True)

if conf.diagnosis and os.environ.get("NC_ENV_CONF", None) in [None, "False"]:
logger.info("Start to run Profiling")
ni_workload_id = register_neural_insights_workload(
workload_location=os.path.abspath(os.path.abspath(options.workspace)),
model=wrapped_model,
workload_mode="benchmark",
)
try:
update_neural_insights_workload(ni_workload_id, "wip")
profile(wrapped_model, conf, b_dataloader)
update_neural_insights_workload(ni_workload_id, "success")
except Exception as e:
logger.error(e)
update_neural_insights_workload(ni_workload_id, "failure")
if b_dataloader is not None:
logger.info("Start to run Profiling")
ni_workload_id = register_neural_insights_workload(
workload_location=os.path.abspath(os.path.abspath(options.workspace)),
model=wrapped_model,
workload_mode="benchmark",
workload_name=conf.ni_workload_name,
)
try:
update_neural_insights_workload(ni_workload_id, "wip")
profile(wrapped_model, conf, b_dataloader=b_dataloader)
update_neural_insights_workload(ni_workload_id, "success")
except Exception as e:
logger.error(e)
update_neural_insights_workload(ni_workload_id, "failure")
else:
logger.warning("Profiling is only supported with b_dataloader.")

logger.info("Start to run Benchmark.")
if os.environ.get("NC_ENV_CONF") == "True":
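A usage sketch for the guarded profiling path; the model path, workload name, and my_dataloader below are placeholders, not values from this commit:

    from neural_compressor.benchmark import fit
    from neural_compressor.config import BenchmarkConfig

    conf = BenchmarkConfig(diagnosis=True, ni_workload_name="resnet50_profiling")
    # Profiling and Neural Insights registration run only because b_dataloader
    # is supplied; without it, fit now logs a warning and proceeds straight to
    # the benchmark instead of failing.
    fit(model="./resnet50.pt", conf=conf, b_dataloader=my_dataloader)
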
33 changes: 33 additions & 0 deletions neural_compressor/config.py
@@ -302,6 +302,7 @@ def __init__(
inter_num_of_threads=None,
intra_num_of_threads=None,
diagnosis=False,
ni_workload_name="profiling",
):
"""Init a BenchmarkConfig object."""
self.inputs = inputs
@@ -316,6 +317,7 @@ def __init__(
self.inter_num_of_threads = inter_num_of_threads
self.intra_num_of_threads = intra_num_of_threads
self.diagnosis = diagnosis
self.ni_workload_name = ni_workload_name
self._framework = None

def keys(self):
@@ -464,6 +466,17 @@ def diagnosis(self, diagnosis):
if _check_value("diagnosis", diagnosis, bool):
self._diagnosis = diagnosis

@property
def ni_workload_name(self):
"""Get Neural Insights workload name."""
return self._ni_workload_name

@ni_workload_name.setter
def ni_workload_name(self, ni_workload_name):
"""Set Neural Insights workload name."""
if _check_value("ni_workload_name", ni_workload_name, str):
self._ni_workload_name = ni_workload_name

@property
def model_name(self):
"""Get model name."""
@@ -801,6 +814,7 @@ def __init__(
accuracy_criterion=accuracy_criterion,
tuning_criterion=tuning_criterion,
diagnosis=False,
ni_workload_name="quantization",
):
"""Initialize _BaseQuantizationConfig class."""
self.inputs = inputs
@@ -822,6 +836,7 @@ def __init__(
self.quant_level = quant_level
self._framework = None
self.diagnosis = diagnosis
self.ni_workload_name = ni_workload_name
self._example_inputs = example_inputs

@property
@@ -1259,6 +1274,8 @@ class PostTrainingQuantConfig(_BaseQuantizationConfig):
Please refer to docstring of AccuracyCriterion class.
diagnosis(bool): This flag indicates whether to do diagnosis.
Default value is False.
ni_workload_name: Custom workload name for Neural Insights diagnosis workload.
Default value is 'quantization'.
Example::
@@ -1293,6 +1310,7 @@ def __init__(
accuracy_criterion=accuracy_criterion,
tuning_criterion=tuning_criterion,
diagnosis=False,
ni_workload_name="quantization",
):
"""Init a PostTrainingQuantConfig object."""
super().__init__(
@@ -1313,9 +1331,13 @@ def __init__(
accuracy_criterion=accuracy_criterion,
tuning_criterion=tuning_criterion,
diagnosis=diagnosis,
ni_workload_name=ni_workload_name,
)
self.approach = approach
self.diagnosis = diagnosis
self.ni_workload_name = ni_workload_name
if self.diagnosis:
self.tuning_criterion.max_trials = 1

@property
def approach(self):
@@ -1343,6 +1365,17 @@ def diagnosis(self, diagnosis):
if _check_value("diagnosis", diagnosis, bool):
self._diagnosis = diagnosis

@property
def ni_workload_name(self):
"""Get Neural Insights workload name."""
return self._ni_workload_name

@ni_workload_name.setter
def ni_workload_name(self, ni_workload_name):
"""Set Neural Insights workload name."""
if _check_value("ni_workload_name", ni_workload_name, str):
self._ni_workload_name = ni_workload_name


class QuantizationAwareTrainingConfig(_BaseQuantizationConfig):
"""Config Class for Quantization Aware Training.
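Taken together, these config additions let callers label the Neural Insights workload that a diagnosis run registers. A sketch, assuming fp32_model and calib_loader are defined elsewhere; note that enabling diagnosis now pins tuning to a single trial:

    from neural_compressor.config import PostTrainingQuantConfig
    from neural_compressor.quantization import fit

    conf = PostTrainingQuantConfig(diagnosis=True, ni_workload_name="bert_ptq_diagnosis")
    print(conf.tuning_criterion.max_trials)  # 1, forced when diagnosis is enabled
    q_model = fit(model=fp32_model, conf=conf, calib_dataloader=calib_loader)
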
11 changes: 11 additions & 0 deletions neural_compressor/model/torch_model.py
@@ -48,6 +48,7 @@ def __init__(self, model, **kwargs):
torch.nn.Module.__init__(self)
self._model = model
assert isinstance(model, torch.nn.Module), "model should be pytorch nn.Module."
self._model_path = None if not isinstance(model, str) else model
self.handles = []
self.tune_cfg = None
self.q_config = None
@@ -99,6 +100,16 @@ def model(self, model):
"""Setter to model."""
self._model = model

@property
def model_path(self):
"""Return model path."""
return self._model_path

@model_path.setter
def model_path(self, path):
"""Set model path."""
self._model_path = path

@property
def fp32_model(self):
"""Getter to model."""
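A small sketch of the new property. Since the assert above requires an nn.Module, the isinstance(model, str) branch never fires here, so _model_path starts as None and is meant to be filled in later by callers such as get_model_path:

    import torch

    from neural_compressor.model.torch_model import PyTorchModel

    wrapped = PyTorchModel(torch.nn.Linear(4, 2))
    print(wrapped.model_path)  # None: an in-memory module has no file path
    wrapped.model_path = "/tmp/net.pt"  # setter introduced by this commit
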
1 change: 1 addition & 0 deletions neural_compressor/quantization.py
@@ -216,6 +216,7 @@ def eval_func(model):
workload_location=os.path.abspath(options.workspace),
model=wrapped_model,
workload_mode="quantization",
workload_name=conf.ni_workload_name,
)
if ni_workload_id:
update_neural_insights_workload(ni_workload_id, "wip")
71 changes: 60 additions & 11 deletions neural_compressor/utils/neural_insights_utils.py
@@ -14,23 +14,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Neural Insights utils functions."""
import os
from typing import Any, Optional

from neural_compressor.model.onnx_model import ONNXModel
from neural_compressor.utils import logger


def register_neural_insights_workload(
workload_location: str,
model: Any,
workload_mode: str,
workload_name: str,
) -> Optional[str]:
"""Register workload to Neural Insights.
Args:
workload_location: path to workload directory
model: Neural Compressor's model instance to be registered
workload_mode: workload mode
workload_name: Name of the workload
Returns:
String with Neural Insights workload UUID if registered else None
@@ -46,22 +48,15 @@ def register_neural_insights_workload(
except ValueError:
raise Exception(f"Workload mode '{workload_mode}' is not supported.")

model_path = None
if isinstance(model.model_path, str):
model_path: str = os.path.abspath(model.model_path)
elif isinstance(model, ONNXModel):
import onnx

model_path: str = os.path.join(workload_location, "input_model.onnx")
os.makedirs(workload_location, exist_ok=True)
onnx.save(model.model, model_path)
assert isinstance(model_path, str), "Model path not detected"
model_path, model_summary_file = get_model_path(model, workload_location)

neural_insights = NeuralInsights(workdir_location=WORKDIR_LOCATION)
ni_workload_uuid = neural_insights.add_workload(
workload_location=workload_location,
workload_mode=mode,
model_path=model_path,
workload_name=workload_name,
model_summary_file=model_summary_file,
)
logger.info(f"Registered {workload_mode} workload to Neural Insights.")
return ni_workload_uuid
@@ -123,3 +118,57 @@ def update_neural_insights_workload_accuracy_data(
logger.info("Neural Insights not found.")
except Exception as err:
logger.warning(f"Could not update workload accuracy data: {err}.")


def get_model_path(model: Any, workload_location: str) -> Any:
"""Get model path."""
from neural_insights.utils.exceptions import ClientErrorException
from neural_insights.utils.utils import check_module

model_path = None
model_summary_file = None
onnx_installed = False
pytorch_installed = False

try:
check_module("onnx")
onnx_installed = True
except ClientErrorException:
pass

try:
check_module("torch")
pytorch_installed = True
except ClientErrorException:
pass

if isinstance(model.model_path, str):
return os.path.abspath(model.model_path), model_summary_file  # tuple, matching the two-value unpack at the call site
if onnx_installed:
import onnx

from neural_compressor.model.onnx_model import ONNXModel

if isinstance(model, ONNXModel):
model_path: str = os.path.join(workload_location, "input_model.onnx")
os.makedirs(workload_location, exist_ok=True)
onnx.save(model.model, model_path)
return model_path, model_summary_file
if pytorch_installed:
import torch
from torchinfo import summary

from neural_compressor.model.torch_model import PyTorchModel

if isinstance(model, PyTorchModel):
model_path: str = os.path.join(workload_location, "input_model.pt")
os.makedirs(workload_location, exist_ok=True)
torch.save(model.model.state_dict(), model_path)

model_stats = summary(model.model, depth=5, verbose=0)
summary_str = str(model_stats)
model_summary_file = os.path.join(workload_location, "model_summary.txt")
with open(model_summary_file, "w", encoding="utf-8") as summary_file:
summary_file.write(summary_str)
return model_path, model_summary_file
assert isinstance(model_path, str), "Model path not detected"
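A hedged sketch of the PyTorch branch above (requires neural_insights, torch, and torchinfo to be installed; the workload directory is a placeholder):

    import torch

    from neural_compressor.model.torch_model import PyTorchModel
    from neural_compressor.utils.neural_insights_utils import get_model_path

    wrapped = PyTorchModel(torch.nn.Linear(4, 2))
    model_path, summary_file = get_model_path(wrapped, "/tmp/ni_workload")
    # model_path   -> /tmp/ni_workload/input_model.pt     (saved state_dict)
    # summary_file -> /tmp/ni_workload/model_summary.txt  (torchinfo summary, depth=5)
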
3 changes: 3 additions & 0 deletions neural_compressor/utils/utility.py
Expand Up @@ -961,6 +961,9 @@ def print_op_list(workload_location: str):
None
"""
minmax_file_path = os.path.join(workload_location, "inspect_saved", "activation_min_max.pkl")
if not os.path.exists(minmax_file_path):
logging.getLogger("neural_compressor").warning("Could not find activation min max data.")
return
input_model_tensors = get_tensors_info(
workload_location,
model_type="input",
41 changes: 33 additions & 8 deletions neural_insights/components/diagnosis/diagnosis.py
@@ -24,6 +24,7 @@
from neural_insights.components.model.model import Model
from neural_insights.components.workload_manager.workload import Workload
from neural_insights.utils.exceptions import ClientErrorException, InternalException
from neural_insights.utils.logger import log
from neural_insights.utils.utils import check_module


@@ -69,7 +70,17 @@ def load_quantization_config(self) -> dict:
"cfg.pkl",
)
if not os.path.exists(config_path):
raise ClientErrorException("Could not find config data for specified optimization.")
log.debug("Could not find config data for specified optimization. Getting data from inspect files.")
input_model_tensors: dict = self.get_tensors_info(model_type="input")["activation"][0]
optimized_model_tensors: dict = self.get_tensors_info(model_type="optimized")["activation"][0]
common_ops = list(set(input_model_tensors.keys()) & set(optimized_model_tensors.keys()))
config_data = {
"op": {},
}
for op in common_ops:
config_data["op"].update({(op,): {}})
return config_data

with open(config_path, "rb") as config_pickle:
config_data = pickle.load(config_pickle)
return config_data
@@ -79,23 +90,37 @@ def get_op_list(self) -> List[dict]:
check_module("numpy")
import numpy as np

op_list: List[dict] = []

input_model_tensors: dict = self.get_tensors_info(model_type="input")["activation"][0]
optimized_model_tensors: dict = self.get_tensors_info(model_type="optimized")["activation"][0]

minmax_file_path = os.path.join(
self.workload_location,
"inspect_saved",
"activation_min_max.pkl",
)
with open(minmax_file_path, "rb") as min_max_file:
min_max_data: dict = pickle.load(min_max_file)

op_list: List[dict] = []
input_model_tensors: dict = self.get_tensors_info(model_type="input")["activation"][0]
optimized_model_tensors: dict = self.get_tensors_info(model_type="optimized")["activation"][0]
try:
with open(minmax_file_path, "rb") as min_max_file:
min_max_data: dict = pickle.load(min_max_file)
except FileNotFoundError:
log.debug("Could not find minmax file.")
common_ops = list(set(input_model_tensors.keys()) & set(optimized_model_tensors.keys()))
min_max_data = dict(zip(common_ops, [{"min": None, "max": None}] * len(common_ops)))

for op_name, min_max in min_max_data.items():
mse = self.calculate_mse(op_name, input_model_tensors, optimized_model_tensors)
if mse is None or np.isnan(mse):
continue
min = float(min_max.get("min", None))
max = float(min_max.get("max", None))
min = min_max.get("min", None)
max = min_max.get("max", None)

if min is not None:
min = float(min)
if max is not None:
max = float(max)

op_entry = OpEntry(op_name, mse, min, max)
op_list.append(op_entry.serialize())
return op_list
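One subtlety in the FileNotFoundError fallback above: multiplying a one-element list reuses the same dict object for every op. That is harmless here because the placeholder values are only read, but mutating one entry would mutate them all:

    common_ops = ["conv1", "conv2"]
    min_max_data = dict(zip(common_ops, [{"min": None, "max": None}] * len(common_ops)))
    min_max_data["conv1"]["min"] = 0.0
    print(min_max_data["conv2"]["min"])  # 0.0: both keys share one dict
    # Independent copies, if mutation were ever needed:
    min_max_data = {op: {"min": None, "max": None} for op in common_ops}
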
2 changes: 2 additions & 0 deletions neural_insights/components/diagnosis/factory.py
@@ -15,6 +15,7 @@
"""Diagnosis class factory."""
from neural_insights.components.diagnosis.diagnosis import Diagnosis
from neural_insights.components.diagnosis.onnx_diagnosis.onnxrt_diagnosis import OnnxRtDiagnosis
from neural_insights.components.diagnosis.pytorch_diagnosis.pytorch_diagnosis import PyTorchDiagnosis
from neural_insights.components.diagnosis.tensorflow_diagnosis.tensorflow_diagnosis import TensorflowDiagnosis
from neural_insights.components.workload_manager.workload import Workload
from neural_insights.utils.consts import Frameworks
@@ -33,6 +34,7 @@ def get_diagnosis(
diagnosis_map = {
Frameworks.ONNX: OnnxRtDiagnosis,
Frameworks.TF: TensorflowDiagnosis,
Frameworks.PT: PyTorchDiagnosis,
}
diagnosis = diagnosis_map.get(workload.framework, None)
if diagnosis is None:
