Skip to content

Commit

Permalink
adds debug options to dump onnx graphs (#1789)
Browse files Browse the repository at this point in the history
add debug options
  • Loading branch information
prathikr committed Apr 5, 2024
1 parent 35a81dc commit dac8645
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 1 deletion.
9 changes: 8 additions & 1 deletion optimum/onnxruntime/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,14 @@ def _inner_training_loop(

# Wrap the model with `ORTModule`
logger.info("Wrap ORTModule for ONNX Runtime training.")
model = ORTModule(self.model)
if self.args.save_onnx:
from torch_ort import DebugOptions

model = ORTModule(
self.model, DebugOptions(save_onnx=self.args.save_onnx, onnx_prefix=self.args.onnx_prefix)
)
else:
model = ORTModule(self.model)
self.model_wrapped = model
self.model = model

Expand Down
30 changes: 30 additions & 0 deletions optimum/onnxruntime/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,29 @@ class ORTTrainingArguments(TrainingArguments):
},
)

save_onnx: Optional[bool] = field(
default=False,
metadata={
"help": "Configure ORTModule to save onnx models. Defaults to False. \
The output directory of the onnx models by default is set to args.output_dir. \
To change the output directory, the environment variable ORTMODULE_SAVE_ONNX_PATH can be \
set to the destination directory path."
},
)

onnx_prefix: Optional[str] = field(
default=None,
metadata={"help": "Prefix for the saved ORTModule file names. Must be provided if save_onnx is True."},
)

onnx_log_level: Optional[str] = field(
default="WARNING",
metadata={
"help": "Configure ORTModule log level. Defaults to WARNING. \
onnx_log_level can also be set to one of VERBOSE, INFO, WARNING, ERROR, FATAL."
},
)

# This method will not need to be overriden after the deprecation of `--adafactor` in version 5 of 馃 Transformers.
def __post_init__(self):
# expand paths, if not os.makedirs("~/bar") will make directory
Expand Down Expand Up @@ -244,6 +267,13 @@ def __post_init__(self):
if version.parse(version.parse(torch.__version__).base_version) == version.parse("2.0.0") and self.fp16:
raise ValueError("--optim adamw_torch_fused with --fp16 requires PyTorch>2.0")

if self.save_onnx:
if not self.onnx_prefix:
raise ValueError("onnx_prefix must be provided if save_onnx is True")
if not os.getenv("ORTMODULE_SAVE_ONNX_PATH", None):
os.environ["ORTMODULE_SAVE_ONNX_PATH"] = self.output_dir
os.environ["ORTMODULE_LOG_LEVEL"] = self.onnx_log_level

if (
is_torch_available()
and (self.device.type != "cuda")
Expand Down

0 comments on commit dac8645

Please sign in to comment.