Skip to content

Commit

Permalink
cleanup
Browse files (browse the repository at this point in the history)
  • Loading branch information
mertalev committed Feb 3, 2024
1 parent 2c2cf59 commit bb56bd3
Showing 1 changed file with 10 additions and 16 deletions.
26 changes: 10 additions & 16 deletions machine-learning/app/models/base.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -9,7 +9,7 @@
import onnx
import onnxruntime as ort
from huggingface_hub import snapshot_download
from onnx.shape_inference import infer_shapes
from onnx.shape_inference import infer_shapes_path
from onnx.tools.update_model_dims import update_inputs_outputs_dims
from typing_extensions import Buffer
import ann.ann
Expand Down Expand Up @@ -117,8 +117,7 @@ def _make_session(self, model_path: Path) -> AnnSession | ort.InferenceSession:
model_path = onnx_path

if any(provider in STATIC_INPUT_PROVIDERS for provider in self.providers):
static_path = model_path.parent / "static_1" / "model.onnx"
static_path.parent.mkdir(parents=True, exist_ok=True)
static_path = model_path.parent / "model_static_1.onnx"
if not static_path.is_file():
self._convert_to_static(model_path, static_path)
model_path = static_path
Expand All @@ -138,29 +137,24 @@ def _make_session(self, model_path: Path) -> AnnSession | ort.InferenceSession:
return session

def _convert_to_static(self, source_path: Path, target_path: Path) -> None:
inferred = infer_shapes(onnx.load(source_path))
inputs = self._get_static_dims(inferred.graph.input)
outputs = self._get_static_dims(inferred.graph.output)
infer_shapes_path(source_path, strict_mode=True)
proto = onnx.load(source_path, load_external_data=False)
inputs = self._get_static_dims(proto.graph.input)
outputs = self._get_static_dims(proto.graph.output)

# check_model gets called in update_inputs_outputs_dims and doesn't work for large models
# check_model gets called in update_inputs_outputs_dims
check_model = onnx.checker.check_model
try:

def check_model_stub(*args: Any, **kwargs: Any) -> None:
pass

onnx.checker.check_model = check_model_stub
updated_model = update_inputs_outputs_dims(inferred, inputs, outputs)
updated_model = update_inputs_outputs_dims(proto, inputs, outputs)
finally:
onnx.checker.check_model = check_model

onnx.save(
updated_model,
target_path,
save_as_external_data=True,
all_tensors_to_one_file=False,
size_threshold=1048576,
)
onnx.save(updated_model, target_path)

def _get_static_dims(self, graph_io: Any, dim_size: int = 1) -> dict[str, list[int]]:
return {
Expand Down

0 comments on commit bb56bd3

Please sign in to comment.