Skip to content

Commit

Permalink
remove chdir jank
Browse files Browse the repository at this point in the history
  • Loading branch information
mertalev committed Mar 15, 2024
1 parent 093c89f commit 5ccdacb
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 33 deletions.
18 changes: 6 additions & 12 deletions machine-learning/app/models/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import os
from abc import ABC, abstractmethod
from pathlib import Path
from shutil import rmtree
Expand Down Expand Up @@ -115,17 +114,12 @@ def _make_session(self, model_path: Path) -> AnnSession | ort.InferenceSession:
case ".armnn":
session = AnnSession(model_path)
case ".onnx":
cwd = os.getcwd()
try:
os.chdir(model_path.parent)
session = ort.InferenceSession(
model_path.as_posix(),
sess_options=self.sess_options,
providers=self.providers,
provider_options=self.provider_options,
)
finally:
os.chdir(cwd)
session = ort.InferenceSession(
model_path.as_posix(),
sess_options=self.sess_options,
providers=self.providers,
provider_options=self.provider_options,
)
case _:
raise ValueError(f"Unsupported model file type: {model_path.suffix}")
return session
Expand Down
21 changes: 0 additions & 21 deletions machine-learning/app/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,6 @@ def test_make_session_return_ort_if_available_and_ann_is_not(self, mocker: Mocke

mock_ann = mocker.patch("app.models.base.AnnSession")
mock_ort = mocker.patch("app.models.base.ort.InferenceSession")
mocker.patch("app.models.base.os.chdir")

encoder = OpenCLIPEncoder("ViT-B-32__openai")
encoder._make_session(mock_armnn_path)
Expand All @@ -285,26 +284,6 @@ def test_make_session_raises_exception_if_path_does_not_exist(self, mocker: Mock
mock_ann.assert_not_called()
mock_ort.assert_not_called()

def test_make_session_changes_cwd(self, mocker: MockerFixture) -> None:
mock_model_path = mocker.Mock()
mock_model_path.is_file.return_value = True
mock_model_path.suffix = ".onnx"
mock_model_path.parent = "model_parent"
mock_model_path.with_suffix.return_value = mock_model_path
mock_ort = mocker.patch("app.models.base.ort.InferenceSession")
mock_chdir = mocker.patch("app.models.base.os.chdir")

encoder = OpenCLIPEncoder("ViT-B-32__openai")
encoder._make_session(mock_model_path)

mock_chdir.assert_has_calls(
[
mock.call(mock_model_path.parent),
mock.call(os.getcwd()),
]
)
mock_ort.assert_called_once()

def test_download(self, mocker: MockerFixture) -> None:
mock_snapshot_download = mocker.patch("app.models.base.snapshot_download")

Expand Down

0 comments on commit 5ccdacb

Please sign in to comment.