7 changes: 7 additions & 0 deletions onnxscript/ir/_core.py
@@ -671,6 +671,13 @@ def tobytes(self) -> bytes:
length = self._length or self.nbytes
return self.raw[offset : offset + length]

def release(self) -> None:
"""Delete all references to the memory buffer and close the memory-mapped file."""
self._array = None
if self.raw is not None:
self.raw.close()
self.raw = None

@property
def metadata_props(self) -> dict[str, str]:
if self._metadata_props is None:
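For reference, a minimal sketch (not part of this diff) of how the new `release()` method pairs with `tobytes()`: copy the data out first, then close the memory map so the backing file can be moved or deleted. The file name, base directory, offsets, and shape below are hypothetical; the constructor arguments mirror the test added in `_core_test.py`.

```python
import numpy as np

from onnxscript import ir
from onnxscript.ir import _core

# Hypothetical external data file "model.onnx.data" under base_dir "/tmp/model",
# holding 16 float32 values starting at offset 0.
tensor = _core.ExternalTensor(
    "model.onnx.data",
    offset=0,
    length=16 * 4,
    dtype=ir.DataType.FLOAT,
    base_dir="/tmp/model",
    name="weight",
    shape=_core.Shape([4, 4]),
)
data = np.frombuffer(tensor.tobytes(), dtype=np.float32).copy()  # copy before releasing
tensor.release()  # closes the memory-mapped file; tensor.raw becomes None
# A later tobytes() call re-opens the file, as exercised by
# test_release_does_not_invalidate_tensor below.
```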
20 changes: 20 additions & 0 deletions onnxscript/ir/_core_test.py
@@ -244,6 +244,26 @@ def test_initialize(self):
# Ensure repeated reads are consistent
np.testing.assert_equal(tensor, self.data)

def test_release_does_not_invalidate_tensor(self):
external_tensor = self.model.graph.initializer[0]
external_info = onnx.external_data_helper.ExternalDataInfo(external_tensor)
tensor = _core.ExternalTensor(
external_info.location,
offset=external_info.offset,
length=external_info.length,
dtype=ir.DataType.FLOAT,
base_dir=self.base_path,
name="input",
shape=_core.Shape(external_tensor.dims),
)
self.assertEqual(tensor.dtype, ir.DataType.FLOAT)
self.assertEqual(tensor.tobytes(), self.data.tobytes())
# Release tensor
tensor.release()
self.assertEqual(tensor.raw, None)
# Tensor can be re-loaded after release
self.assertEqual(tensor.tobytes(), self.data.tobytes())

def test_initialize_with_relative_path(self):
external_tensor = self.model.graph.initializer[0]
external_info = onnx.external_data_helper.ExternalDataInfo(external_tensor)
11 changes: 11 additions & 0 deletions onnxscript/ir/_external_data.py
@@ -100,6 +100,7 @@ def _load_external_data_file(
if os.path.samefile(tensor.path, os.path.join(base_path, relative_path)):
# Copy the data because .numpy() references memory-mapped data from a file that is modified later
tensor_data = external_tensor.numpy().copy()
external_tensor.release()
tensor = _core.Tensor(
tensor_data, name=external_tensor.name, dtype=external_tensor.dtype
)
@@ -165,6 +166,8 @@ def _save_external_data(
current_offset = tensor_info.offset
assert tensor is not None
raw_data = tensor.tobytes()
if isinstance(tensor, _core.ExternalTensor):
tensor.release()
# Pad file to required offset if needed
file_size = data_file.tell()
if current_offset > file_size:
@@ -223,6 +226,7 @@ def convert_tensors_to_external(
path = os.path.join(base_path, relative_path)
# Check that the file path is valid and create any missing intermediate directories within it
os.makedirs(os.path.dirname(path), exist_ok=True)
tmp_file_created = False
# Check if file exists. Load pre-existing external data if it does.
if os.path.exists(path):
# Check if any tensor in the model is using the destination file
@@ -241,6 +245,7 @@
os.makedirs(tmp_path, exist_ok=True)
# If existing external tensors are not loaded to memory, copy the external data to a temporary location
os.rename(path, os.path.join(tmp_path, relative_path))
tmp_file_created = True
for tensor in tensors:
if (
isinstance(tensor, _core.ExternalTensor)
@@ -270,6 +275,12 @@
external_tensors[i]
for i in sorted(range(len(external_tensors)), key=lambda i: sorted_indices[i])
]

# Clean up the temporary file if one was created
tmp_path = os.path.join(base_path, "tmp", relative_path)
if os.path.exists(tmp_path) and tmp_file_created:
os.remove(tmp_path)

return external_tensors


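The `release()` calls added above matter most on Windows: `convert_tensors_to_external` renames a pre-existing external-data file into a `tmp` subdirectory, and a file cannot be renamed or removed while a memory map into it is still open. A generic, self-contained illustration of that pattern (the file names are assumptions, not the library's code):

```python
import mmap
import os
import tempfile

with tempfile.TemporaryDirectory() as base:
    path = os.path.join(base, "external_data_1.bin")
    with open(path, "wb") as f:
        f.write(b"\x00" * 1024)

    f = open(path, "rb")
    mapped = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
    data = bytes(mapped)  # copy the bytes out, mirroring external_tensor.numpy().copy()

    mapped.close()  # analogous to ExternalTensor.release()
    f.close()

    tmp_path = os.path.join(base, "tmp")
    os.makedirs(tmp_path, exist_ok=True)
    # With the map closed, the rename succeeds on Windows as well
    os.rename(path, os.path.join(tmp_path, "external_data_1.bin"))
```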
16 changes: 14 additions & 2 deletions onnxscript/ir/_external_data_test.py
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import sys
import tempfile
import typing
import unittest
@@ -115,7 +116,10 @@ class OffloadExternalTensorTest(unittest.TestCase):

def setUp(self):
# File paths
self.temp_dir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with
if sys.version_info[:2] >= (3, 10):
self.temp_dir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True) # pylint: disable=consider-using-with
else:
self.temp_dir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with
self.external_data_name = "external_tensors.bin"
self.base_path = self.temp_dir.name
self.ext_data_1 = "external_data_1.bin"
@@ -136,7 +140,15 @@ def setUp(self):
self.model_with_mixed_external_data = self._model_with_mixed_external_data()

def tearDown(self) -> None:
self.temp_dir.cleanup()
# Handle cleanup exceptions on Windows and on Python versions < 3.10
try:
self.temp_dir.cleanup()
except PermissionError as e:
print(f"PermissionError: {e}")
except FileNotFoundError as e:
print(f"FileNotFoundError: {e}")
except Exception as e: # pylint: disable=broad-exception-caught
print(f"An unexpected error occurred: {e}")

def _simple_model(self) -> ir.Model:
tensor1 = ir.Tensor(
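The version check in `setUp` and the tolerant `tearDown` above address the same problem from both ends. A hedged alternative sketch, assuming one wanted to centralize that logic in a hypothetical helper rather than repeat it per test case:

```python
import sys
import tempfile


def make_temp_dir() -> tempfile.TemporaryDirectory:
    """Create a TemporaryDirectory whose cleanup tolerates locked files on Windows."""
    if sys.version_info[:2] >= (3, 10):
        return tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
    return tempfile.TemporaryDirectory()


def cleanup_temp_dir(temp_dir: tempfile.TemporaryDirectory) -> None:
    """Best-effort cleanup; open mmap handles on Windows can raise PermissionError."""
    try:
        temp_dir.cleanup()
    except OSError as error:  # PermissionError and FileNotFoundError are subclasses
        print(f"Temporary directory cleanup failed: {error}")
```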