From 1ca2503d8f76809b92a449a5d19130f30be610d7 Mon Sep 17 00:00:00 2001 From: Jonathan Daniel <36337649+jond01@users.noreply.github.com> Date: Sun, 19 May 2024 14:58:51 +0300 Subject: [PATCH] [Tests] Fix an uncleaned DB file (#5589) --- tests/model_monitoring/test_store.py | 29 ++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/tests/model_monitoring/test_store.py b/tests/model_monitoring/test_store.py index 6f45a5f53d0..422db8f0725 100644 --- a/tests/model_monitoring/test_store.py +++ b/tests/model_monitoring/test_store.py @@ -11,32 +11,32 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# + import string import time -import typing import unittest.mock +from collections.abc import Iterator +from pathlib import Path from random import choice, randint -from typing import Optional +from typing import Optional, cast import pytest import mlrun.common.schemas import mlrun.model_monitoring -import mlrun.model_monitoring.db.stores.sqldb.sql_store from mlrun.common.schemas.model_monitoring import ResultData, WriterEvent -from mlrun.model_monitoring.db.stores import ( # noqa: F401 - StoreBase, -) +from mlrun.model_monitoring.db.stores.sqldb.sql_store import SQLStoreBase from mlrun.model_monitoring.writer import _AppResultEvent -SQLStoreBase = typing.TypeVar("SQLStoreBase", bound="StoreBase") - class TestSQLStore: _TEST_PROJECT = "test_model_endpoints" _MODEL_ENDPOINT_ID = "some-ep-id" - _STORE_CONNECTION = "sqlite:///test.db" + + @staticmethod + @pytest.fixture + def store_connection(tmp_path: Path) -> str: + return f"sqlite:///{tmp_path / 'test.db'}" @pytest.fixture() def _mock_random_endpoint( @@ -107,15 +107,16 @@ def init_sql_tables(new_sql_store: SQLStoreBase): @classmethod @pytest.fixture - def new_sql_store(cls) -> SQLStoreBase: + def new_sql_store(cls, store_connection: str) -> Iterator[SQLStoreBase]: # Generate store object target store_type_object = mlrun.model_monitoring.db.ObjectStoreFactory(value="sql") with unittest.mock.patch( "mlrun.model_monitoring.helpers.get_connection_string", - return_value=cls._STORE_CONNECTION, + return_value=store_connection, ): - sql_store: SQLStoreBase = store_type_object.to_object_store( - project=cls._TEST_PROJECT + sql_store = cast( + SQLStoreBase, + store_type_object.to_object_store(project=cls._TEST_PROJECT), ) yield sql_store sql_store.delete_model_endpoints_resources()