Skip to content

Commit

Permalink
changed release method to __del__ & added comments to explain that ge…
Browse files Browse the repository at this point in the history
…nerate_artifacts typically will throw an error if unsuccessful
  • Loading branch information
carzh committed Jun 19, 2024
1 parent c27820e commit 44acd08
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 0 additions & 2 deletions orttraining/orttraining/python/training/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,5 +268,3 @@ def _export_to_ort_format(model_path, output_dir, ort_format, custom_op_library_
onnx.save(optim_model, optimizer_model_path)
_export_to_ort_format(optimizer_model_path, artifact_directory, ort_format, custom_op_library_path)
logging.info("Saved optimizer model to %s", optimizer_model_path)

training_block.release()
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def infer_shapes_on_base(self):
else:
return onnx.shape_inference.infer_shapes(accessor._GLOBAL_ACCESSOR.model)

def release(self):
def __del__(self):
# since the ModelProto does not store the external data parameters themselves, just the metadata
# for where the external data can be found, we retain the external data files for the intermediate
# calls until the Block no longer needs to be used.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1177,6 +1177,8 @@ def test_generate_artifacts_external_data_separate_files():
artifact_directory=temp_dir,
)

# generate_artifacts should have thrown if it didn't complete successfully.
# Below is a sanity check to validate that all the expected files were created.
assert os.path.exists(os.path.join(temp_dir, "training_model.onnx"))
assert os.path.exists(os.path.join(temp_dir, "eval_model.onnx"))
assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx"))
Expand Down

0 comments on commit 44acd08

Please sign in to comment.