Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 64 additions & 2 deletions src/winml/modelkit/export/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def export_pytorch(
task: str | None = None,
verbose: bool = False,
enable_reporting: bool = False,
normalize: bool = True,
**kwargs: Any,
) -> dict[str, Any]:
"""Export a PyTorch nn.Module to ONNX.
Expand All @@ -63,9 +64,15 @@ def export_pytorch(
task: Task for auto-input generation fallback.
verbose: Enable verbose logging.
enable_reporting: Generate export report file.
normalize: If True (default), run optimize_onnx on the exported model
to apply graph-level optimizations and shape inference. Set False
to keep the raw torch.onnx.export output (useful when debugging
the exporter or running custom downstream optimization).

Returns:
Export statistics dict from HTPExporter.
Export statistics dict from HTPExporter, with an extra
`model_normalization_status` entry: one of `"not_run"` (when
`normalize=False`), `"succeeded"`, or `"failed"`.
"""
from .htp.exporter import HTPExporter

Expand All @@ -86,11 +93,66 @@ def export_pytorch(
)
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
return exporter.export(
stats = exporter.export(
model=model,
output_path=str(output_path),
export_config=export_config,
model_name_or_path=model_name_or_path,
task=task,
**kwargs,
)

if normalize:
stats["model_normalization_status"] = (
"succeeded" if _normalize_exported_model(output_path) else "failed"
)
else:
stats["model_normalization_status"] = "not_run"

Comment thread
vortex-captain marked this conversation as resolved.
return stats


def _normalize_exported_model(output_path: Path) -> bool:
    """Normalize the exported ONNX model in place via ``optimize_onnx``.

    The optimized model is first written into a temporary directory created
    next to ``output_path``; ``copy_onnx_model`` then replaces the original
    export (and its ``.data`` sidecar, if any) with it. The temporary
    directory is removed on every exit path.

    This step is best-effort and never raises for ordinary errors: any
    ``Exception`` raised while creating the temporary directory, optimizing,
    or copying is logged and reported through the return value.

    Failure modes are not symmetric:
        - A failure creating the temp directory or inside ``optimize_onnx``
          leaves the original export untouched: the temp directory is the
          only write target, and it is cleaned up.
        - A ``copy_onnx_model`` failure may leave the original ``.onnx``
          and/or ``.data`` sidecar partially overwritten: copy_onnx_model
          writes directly to the destination (no temp-and-rename), so a
          process kill or full disk mid-copy can corrupt the destination.

    Args:
        output_path: Path of the ONNX file produced by the exporter; it is
            both the input of the optimization and the final destination.

    Returns:
        True if normalization succeeded, False otherwise. On False, the
        traceback is included in the warning log to aid debugging.
    """
    import shutil
    import tempfile

    from ..onnx import copy_onnx_model
    from ..optim import optimize_onnx

    logger.info("Normalizing model")

    # Stays None until mkdtemp succeeds so the cleanup below knows whether
    # there is anything to remove.
    tmp_dir: Path | None = None
    try:
        # Place the temp dir next to the output so copy_onnx_model stays on
        # the same volume — avoids a cross-volume data transfer for multi-GB
        # models and keeps the system drive's %TEMP% free of large sidecars.
        # Created inside the try block so that a failure here (e.g. an
        # unwritable parent directory) is reported as a failed normalization
        # instead of escaping as an exception.
        tmp_dir = Path(tempfile.mkdtemp(dir=output_path.parent))
        tmp_path = tmp_dir / output_path.name

        optimize_onnx(model=output_path, output=tmp_path)
        copy_onnx_model(tmp_path, output_path)
    except Exception:
        logger.warning(
            "Normalization failed; keeping un-normalized export",
            exc_info=True,
        )
        return False
    finally:
        if tmp_dir is not None:
            shutil.rmtree(tmp_dir, ignore_errors=True)
    return True
63 changes: 63 additions & 0 deletions tests/unit/export/test_pytorch_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

from __future__ import annotations

from unittest.mock import patch

import onnx
import pytest
import torch
Expand All @@ -23,6 +25,20 @@
)


def _all_value_info_have_shape(model: onnx.ModelProto) -> bool:
    """Report whether every ``value_info`` entry of the graph is fully shaped.

    An entry counts as shaped when its tensor shape lists at least one
    dimension and each dimension has either ``dim_value`` or ``dim_param``
    set. A graph with no ``value_info`` entries at all yields False.
    """
    entries = model.graph.value_info
    if not entries:
        return False

    def _is_fully_shaped(entry) -> bool:
        dims = entry.type.tensor_type.shape.dim
        # An empty dimension list is treated as "no shape recorded".
        return bool(dims) and all(
            d.HasField("dim_value") or d.HasField("dim_param") for d in dims
        )

    return all(_is_fully_shaped(entry) for entry in entries)


# =============================================================================
# Test Models (pure PyTorch, no HF)
# =============================================================================
Expand Down Expand Up @@ -268,6 +284,53 @@ def forward(self, ids):
# This is expected — the test verifies the flow works up to export
pass

def test_normalization_succeeds_and_shape_inferences(self, tmp_path) -> None:
    """Default export normalizes the model and leaves a fully shaped graph.

    Checks that the returned stats report ``"succeeded"`` and that the ONNX
    file written to disk passes ``_all_value_info_have_shape``.
    """
    net = TwoLayerNet()
    onnx_file = tmp_path / "model.onnx"
    export_config = WinMLExportConfig(
        input_tensors=[InputTensorSpec(name="x", dtype="float32", shape=(1, 10))],
    )

    stats = export_pytorch(net, onnx_file, export_config)
    assert stats["model_normalization_status"] == "succeeded"

    # Reload from disk: the check is about the file the caller ends up with.
    exported = onnx.load(str(onnx_file))
    assert _all_value_info_have_shape(exported)

def test_failed_normalization_skips_shape_inference(self, tmp_path) -> None:
    """A normalization helper that returns False yields status ``"failed"``.

    The helper is patched out, so the file on disk is the raw export and is
    expected not to pass ``_all_value_info_have_shape``.
    """
    net = TwoLayerNet()
    onnx_file = tmp_path / "model.onnx"
    export_config = WinMLExportConfig(
        input_tensors=[InputTensorSpec(name="x", dtype="float32", shape=(1, 10))],
    )

    failing_helper = patch(
        "winml.modelkit.export.pytorch._normalize_exported_model",
        return_value=False,
    )
    with failing_helper:
        stats = export_pytorch(net, onnx_file, export_config)

    assert stats["model_normalization_status"] == "failed"

    exported = onnx.load(str(onnx_file))
    assert not _all_value_info_have_shape(exported)

def test_normalize_false_skips_normalization(self, tmp_path) -> None:
    """With ``normalize=False`` the helper is never invoked.

    The status entry must be ``"not_run"`` and the file on disk is expected
    to be the raw export, i.e. not to pass ``_all_value_info_have_shape``.
    """
    net = TwoLayerNet()
    onnx_file = tmp_path / "model.onnx"
    export_config = WinMLExportConfig(
        input_tensors=[InputTensorSpec(name="x", dtype="float32", shape=(1, 10))],
    )

    helper_patch = patch("winml.modelkit.export.pytorch._normalize_exported_model")
    with helper_patch as mock_normalize:
        stats = export_pytorch(net, onnx_file, export_config, normalize=False)

    mock_normalize.assert_not_called()
    assert stats["model_normalization_status"] == "not_run"

    exported = onnx.load(str(onnx_file))
    assert not _all_value_info_have_shape(exported)

def test_mismatched_input_order_exports_successfully(self, tmp_path) -> None:
"""Export succeeds when InputTensorSpec order differs from forward() param order.

Expand Down
127 changes: 120 additions & 7 deletions tests/unit/onnx/test_external_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper
from onnx import TensorProto, external_data_helper, helper, numpy_helper

from winml.modelkit.onnx.external_data import (
copy_onnx_model,
Expand All @@ -27,14 +27,39 @@ def _make_small_model() -> onnx.ModelProto:
"""Create a minimal ONNX model (no external data)."""
x_info = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 4])
y_info = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2])
weight = numpy_helper.from_array(
np.random.randn(4, 2).astype(np.float32), name="W"
)
weight = numpy_helper.from_array(np.random.randn(4, 2).astype(np.float32), name="W")
node = helper.make_node("MatMul", ["X", "W"], ["Y"])
graph = helper.make_graph([node], "test", [x_info], [y_info], [weight])
return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])


def _make_filled_model(value: float, shape: tuple[int, ...]) -> onnx.ModelProto:
    """Build a deterministic single-``Add`` model with a constant initializer.

    The initializer ``W`` is a float32 tensor of the given ``shape`` filled
    with ``value``; input ``X`` and output ``Y`` share that shape. Overwrite
    tests use it to obtain two models that are easy to tell apart.
    """
    constant = np.full(shape, value, dtype=np.float32)
    initializer = numpy_helper.from_array(constant, name="W")

    # Input and output are declared identically apart from their names.
    x_info, y_info = (
        helper.make_tensor_value_info(tensor_name, TensorProto.FLOAT, list(shape))
        for tensor_name in ("X", "Y")
    )

    add_node = helper.make_node("Add", ["X", "W"], ["Y"])
    graph = helper.make_graph([add_node], "g", [x_info], [y_info], [initializer])
    return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])


def _serialize_without_external_location(model: onnx.ModelProto) -> bytes:
    """Serialize a copy of ``model`` with external ``location`` entries removed.

    Works on a clone, so the caller's model is left unmodified. For every
    tensor marked as externally stored, each ``external_data`` entry whose
    key is ``"location"`` is dropped before deterministic serialization.
    This lets two models that differ only in sidecar filename compare equal.
    """
    stripped = onnx.ModelProto()
    stripped.CopyFrom(model)

    for tensor in external_data_helper._get_all_tensors(stripped):
        if tensor.data_location != TensorProto.EXTERNAL:
            continue
        # Collect first, then remove, so the repeated field is never
        # mutated while it is being iterated.
        location_entries = [kv for kv in tensor.external_data if kv.key == "location"]
        for kv in location_entries:
            tensor.external_data.remove(kv)

    return stripped.SerializeToString(deterministic=True)


class TestGetExternalDataFiles:
"""Tests for get_external_data_files()."""

Expand All @@ -51,7 +76,8 @@ def test_with_external_data(self, tmp_path: Path) -> None:
model = _make_small_model()
path = tmp_path / "ext.onnx"
onnx.save_model(
model, str(path),
model,
str(path),
save_as_external_data=True,
all_tensors_to_one_file=True,
location="ext.onnx.data",
Expand All @@ -74,7 +100,8 @@ def test_with_external(self, tmp_path: Path) -> None:
model = _make_small_model()
path = tmp_path / "ext.onnx"
onnx.save_model(
model, str(path),
model,
str(path),
save_as_external_data=True,
all_tensors_to_one_file=True,
location="ext.onnx.data",
Expand Down Expand Up @@ -106,7 +133,8 @@ def test_copy_with_external_data(self, tmp_path: Path) -> None:
src = tmp_path / "src" / "model.onnx"
src.parent.mkdir()
onnx.save_model(
model, str(src),
model,
str(src),
save_as_external_data=True,
all_tensors_to_one_file=True,
location="model.onnx.data",
Expand Down Expand Up @@ -145,3 +173,88 @@ def test_copy_invalid_file_falls_back(self, tmp_path: Path) -> None:

assert dst.exists()
assert dst.read_text() == "not a real onnx file"

def test_copy_overwrites_existing_dst_no_external_data(self, tmp_path: Path) -> None:
    """An existing destination without external data is replaced by the source.

    After the copy the destination bytes equal the source bytes, differ from
    the previous destination content, and no ``dst.onnx.data`` sidecar exists.
    """
    source, target = tmp_path / "src.onnx", tmp_path / "dst.onnx"

    onnx.save(_make_filled_model(1.0, (4, 4)), str(source))
    # Seed the destination with a different model so an overwrite is observable.
    onnx.save(_make_filled_model(99.0, (8, 8)), str(target))

    stale = target.read_bytes()
    expected = source.read_bytes()
    assert stale != expected

    copy_onnx_model(source, target)

    actual = target.read_bytes()
    assert actual == expected
    assert actual != stale
    assert not (tmp_path / "dst.onnx.data").exists()

def test_copy_overwrites_existing_dst_with_external_data(self, tmp_path: Path) -> None:
    """An existing destination and its sidecar are both replaced by the source.

    Verifies:
        - dst.onnx.data is byte-identical to src.onnx.data
        - dst.onnx matches src.onnx except for the external_data.location field
        - dst.onnx's location field points at dst.onnx.data
        - Loaded initializer arrays are equal
    """
    src, dst = tmp_path / "src.onnx", tmp_path / "dst.onnx"
    src_data, dst_data = tmp_path / "src.onnx.data", tmp_path / "dst.onnx.data"

    # Both models are saved with every tensor pushed into a single sidecar.
    external_opts = {
        "save_as_external_data": True,
        "all_tensors_to_one_file": True,
        "size_threshold": 0,
    }
    onnx.save_model(
        _make_filled_model(2.0, (64, 64)),
        str(src),
        location="src.onnx.data",
        **external_opts,
    )
    onnx.save_model(
        _make_filled_model(999.0, (32, 32)),
        str(dst),
        location="dst.onnx.data",
        **external_opts,
    )

    expected_sidecar = src_data.read_bytes()
    stale_sidecar = dst_data.read_bytes()
    stale_onnx = dst.read_bytes()
    assert expected_sidecar != stale_sidecar

    copy_onnx_model(src, dst)

    # Sidecar: byte-identical to the source's, and no longer the old content.
    copied_sidecar = dst_data.read_bytes()
    assert copied_sidecar == expected_sidecar
    assert copied_sidecar != stale_sidecar

    # Graph file: no longer the old destination content.
    assert dst.read_bytes() != stale_onnx

    # Graph file: equal to the source once the location entries are ignored.
    src_model = onnx.load(str(src), load_external_data=False)
    dst_model = onnx.load(str(dst), load_external_data=False)
    src_canonical = _serialize_without_external_location(src_model)
    dst_canonical = _serialize_without_external_location(dst_model)
    assert src_canonical == dst_canonical

    # Every externally stored tensor in dst.onnx must reference dst.onnx.data.
    external_tensors = [
        t
        for t in external_data_helper._get_all_tensors(dst_model)
        if t.data_location == TensorProto.EXTERNAL
    ]
    for t in external_tensors:
        assert external_data_helper.ExternalDataInfo(t).location == "dst.onnx.data"

    # Semantic check: the initializer values read back identically.
    src_weights = numpy_helper.to_array(
        onnx.load(str(src), load_external_data=True).graph.initializer[0]
    )
    dst_weights = numpy_helper.to_array(
        onnx.load(str(dst), load_external_data=True).graph.initializer[0]
    )
    assert np.array_equal(src_weights, dst_weights)
Loading