Skip to content

Commit

Permalink
[DB] Fix backup and migrations not to happen unwantedly [Backport 0.10.x] (#1832)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hedingber committed Mar 22, 2022
1 parent 8a140a1 commit 089a21e
Show file tree
Hide file tree
Showing 11 changed files with 233 additions and 57 deletions.
1 change: 1 addition & 0 deletions mlrun/api/api/deps.py
Expand Up @@ -40,6 +40,7 @@ def verify_api_state(request: Request):
enabled_endpoints = [
"healthz",
"background-tasks",
"client-spec",
"migrations",
]
if not any(enabled_endpoint in path for enabled_endpoint in enabled_endpoints):
Expand Down
2 changes: 1 addition & 1 deletion mlrun/api/api/endpoints/operations.py
Expand Up @@ -22,7 +22,7 @@
http.HTTPStatus.ACCEPTED.value: {"model": mlrun.api.schemas.BackgroundTask},
},
)
def start_migration(
def trigger_migrations(
background_tasks: fastapi.BackgroundTasks, response: fastapi.Response,
):
# we didn't yet decide who should have permissions to such actions, therefore no authorization at the moment
Expand Down
50 changes: 30 additions & 20 deletions mlrun/api/initial_data.py
Expand Up @@ -13,38 +13,35 @@
import mlrun.api.db.sqldb.helpers
import mlrun.api.db.sqldb.models
import mlrun.api.schemas
import mlrun.api.utils.db.alembic
import mlrun.api.utils.db.backup
import mlrun.api.utils.db.mysql
import mlrun.api.utils.db.sqlite_migration
import mlrun.artifacts
from mlrun.api.db.init_db import init_db
from mlrun.api.db.session import close_session, create_session
from mlrun.config import config
from mlrun.utils import logger

from .utils.db.alembic import AlembicUtil
from .utils.db.backup import DBBackupUtil
from .utils.db.mysql import MySQLUtil
from .utils.db.sqlite_migration import SQLiteMigrationUtil


def init_data(
from_scratch: bool = False, perform_migrations_if_needed: bool = False
) -> None:
MySQLUtil.wait_for_db_liveness(logger)
logger.info("Initializing DB data")
mlrun.api.utils.db.mysql.MySQLUtil.wait_for_db_liveness(logger)

sqlite_migration_util = None
if not from_scratch and config.httpdb.db.database_migration_mode == "enabled":
sqlite_migration_util = SQLiteMigrationUtil()
sqlite_migration_util = (
mlrun.api.utils.db.sqlite_migration.SQLiteMigrationUtil()
)
alembic_util = _create_alembic_util()
(
is_migration_needed,
is_migration_from_scratch,
is_backup_needed,
) = _resolve_needed_operations(alembic_util, sqlite_migration_util, from_scratch)

if is_backup_needed:
logger.info("DB Backup is needed, backing up...")
db_backup = DBBackupUtil()
db_backup.backup_database()

if (
not is_migration_from_scratch
and not perform_migrations_if_needed
Expand All @@ -55,6 +52,11 @@ def init_data(
config.httpdb.state = state
return

if is_backup_needed:
logger.info("DB Backup is needed, backing up...")
db_backup = mlrun.api.utils.db.backup.DBBackupUtil()
db_backup.backup_database()

logger.info("Creating initial data")
config.httpdb.state = mlrun.api.schemas.APIStates.migrations_in_progress

Expand Down Expand Up @@ -94,18 +96,22 @@ def init_data(


def _resolve_needed_operations(
alembic_util: AlembicUtil,
sqlite_migration_util: typing.Optional[SQLiteMigrationUtil],
alembic_util: mlrun.api.utils.db.alembic.AlembicUtil,
sqlite_migration_util: typing.Optional[
mlrun.api.utils.db.sqlite_migration.SQLiteMigrationUtil
],
force_from_scratch: bool = False,
) -> typing.Tuple[bool, bool, bool]:
is_database_migration_needed = False
if sqlite_migration_util is not None:
is_database_migration_needed = (
sqlite_migration_util.is_database_migration_needed()
)
# the util checks whether the target DB has data, when database migration needed, it obviously does not have data
# but in that case it's not really a migration from scratch
is_migration_from_scratch = (
force_from_scratch or alembic_util.is_migration_from_scratch()
)
) and not is_database_migration_needed
is_schema_migration_needed = alembic_util.is_schema_migration_needed()
is_data_migration_needed = (
not _is_latest_data_version()
Expand Down Expand Up @@ -134,20 +140,22 @@ def _resolve_needed_operations(
return is_migration_needed, is_migration_from_scratch, is_backup_needed


def _create_alembic_util() -> AlembicUtil:
def _create_alembic_util() -> mlrun.api.utils.db.alembic.AlembicUtil:
alembic_config_file_name = "alembic.ini"
if MySQLUtil.get_mysql_dsn_data():
if mlrun.api.utils.db.mysql.MySQLUtil.get_mysql_dsn_data():
alembic_config_file_name = "alembic_mysql.ini"

# run schema migrations on existing DB or create it with alembic
dir_path = pathlib.Path(os.path.dirname(os.path.realpath(__file__)))
alembic_config_path = dir_path / alembic_config_file_name

alembic_util = AlembicUtil(alembic_config_path, _is_latest_data_version())
alembic_util = mlrun.api.utils.db.alembic.AlembicUtil(
alembic_config_path, _is_latest_data_version()
)
return alembic_util


def _perform_schema_migrations(alembic_util: AlembicUtil):
def _perform_schema_migrations(alembic_util: mlrun.api.utils.db.alembic.AlembicUtil):
logger.info("Performing schema migration")
alembic_util.init_alembic()

Expand All @@ -165,7 +173,9 @@ def _is_latest_data_version():


def _perform_database_migration(
sqlite_migration_util: typing.Optional[SQLiteMigrationUtil],
sqlite_migration_util: typing.Optional[
mlrun.api.utils.db.sqlite_migration.SQLiteMigrationUtil
],
):
if sqlite_migration_util:
logger.info("Performing database migration")
Expand Down
3 changes: 3 additions & 0 deletions mlrun/api/main.py
Expand Up @@ -238,6 +238,9 @@ def _cleanup_runtimes():

def main():
init_data()
logger.info(
"Starting API server", port=config.httpdb.port, debug=config.httpdb.debug,
)
uvicorn.run(
"mlrun.api.main:app",
host="0.0.0.0",
Expand Down
10 changes: 7 additions & 3 deletions mlrun/api/utils/db/mysql.py
Expand Up @@ -28,9 +28,13 @@ def wait_for_db_liveness(logger, retry_interval=3, timeout=2 * 60):
logger.debug("Waiting for database liveness")
mysql_dsn_data = MySQLUtil.get_mysql_dsn_data()
if not mysql_dsn_data:
logger.warn(
f"Invalid mysql dsn: {MySQLUtil.get_dsn()}, assuming sqlite and skipping liveness verification"
)
dsn = MySQLUtil.get_dsn()
if "sqlite" in dsn:
logger.debug("SQLite DB is used, liveness check not needed")
else:
logger.warn(
f"Invalid mysql dsn: {MySQLUtil.get_dsn()}, assuming live and skipping liveness verification"
)
return

tmp_connection = mlrun.utils.retry_until_successful(
Expand Down
48 changes: 46 additions & 2 deletions mlrun/db/httpdb.py
Expand Up @@ -113,6 +113,7 @@ def __init__(self, base_url, user="", password="", token=""):
self.server_version = ""
self.session = None
self._wait_for_project_terminal_state_retry_interval = 3
self._wait_for_background_task_terminal_state_retry_interval = 3
self._wait_for_project_deletion_interval = 3
self.client_version = version.Version().get()["version"]

Expand Down Expand Up @@ -1226,16 +1227,24 @@ def remote_start(self, func_url) -> schemas.BackgroundTask:
def get_project_background_task(
self, project: str, name: str,
) -> schemas.BackgroundTask:
""" Retrieve updated information on a background task being executed."""
"""Retrieve updated information on a project background task being executed."""

project = project or config.default_project
path = f"projects/{project}/background-tasks/{name}"
error_message = (
f"Failed getting background task. project={project}, name={name}"
f"Failed getting project background task. project={project}, name={name}"
)
response = self.api_call("GET", path, error_message)
return schemas.BackgroundTask(**response.json())

def get_background_task(self, name: str) -> schemas.BackgroundTask:
    """Fetch the current state of a (non-project-scoped) background task.

    :param name: Name of the background task, as returned by the API when the
        task was created.
    :returns: :py:class:`~mlrun.api.schemas.BackgroundTask` with up-to-date status.
    """
    endpoint = f"background-tasks/{name}"
    failure_message = f"Failed getting background task. name={name}"
    task_response = self.api_call("GET", endpoint, failure_message)
    return schemas.BackgroundTask(**task_response.json())

def remote_status(self, project, name, kind, selector):
""" Retrieve status of a function being executed remotely (relevant to ``dask`` functions).
Expand Down Expand Up @@ -2112,6 +2121,26 @@ def _verify_project_in_terminal_state():
_verify_project_in_terminal_state,
)

def _wait_for_background_task_to_reach_terminal_state(
    self, name: str
) -> schemas.BackgroundTask:
    """Poll a background task until its state is terminal, then return it.

    Polls every ``_wait_for_background_task_terminal_state_retry_interval``
    seconds and gives up after one hour.
    """
    one_hour = 60 * 60

    def _fetch_if_terminal():
        task = self.get_background_task(name)
        state = task.status.state
        if state in mlrun.api.schemas.BackgroundTaskState.terminal_states():
            return task
        # not done yet - raise so retry_until_successful tries again
        raise Exception(
            f"Background task not in terminal state. name={name}, state={state}"
        )

    return mlrun.utils.helpers.retry_until_successful(
        self._wait_for_background_task_terminal_state_retry_interval,
        one_hour,
        logger,
        False,
        _fetch_if_terminal,
    )

def _wait_for_project_to_be_deleted(self, project_name: str):
def _verify_project_deleted():
projects = self.list_projects(
Expand Down Expand Up @@ -2693,6 +2722,21 @@ def verify_authorization(
body=dict_to_json(authorization_verification_input.dict()),
)

def trigger_migrations(self) -> Optional[schemas.BackgroundTask]:
    """Ask the API server to run DB migrations, and wait for them if triggered.

    The server answers 202 (ACCEPTED) with a background task only when it
    actually starts migrations; in that case we block until the task reaches
    a terminal state and return it. If no migrations were needed, return None.

    :returns: :py:class:`~mlrun.api.schemas.BackgroundTask` or ``None``.
    """
    migration_response = self.api_call(
        "POST", "operations/migrations", "Failed triggering migrations",
    )
    if migration_response.status_code != http.HTTPStatus.ACCEPTED:
        # nothing was triggered - no migrations were needed
        return None
    task = schemas.BackgroundTask(**migration_response.json())
    return self._wait_for_background_task_to_reach_terminal_state(
        task.metadata.name
    )


def _as_json(obj):
fn = getattr(obj, "to_json", None)
Expand Down
5 changes: 1 addition & 4 deletions mlrun/runtimes/base.py
Expand Up @@ -708,10 +708,7 @@ def _store_run_dict(self, rundict: dict):
self._get_db().store_run(rundict, uid, project, iter=iter)

def _update_run_state(
self,
resp: dict = None,
task: RunObject = None,
err=None,
self, resp: dict = None, task: RunObject = None, err=None,
) -> dict:
"""update the task state in the DB"""
was_none = False
Expand Down
20 changes: 5 additions & 15 deletions tests/api/api/test_functions.py
Expand Up @@ -59,22 +59,15 @@ async def test_multiple_store_function_race_condition(
"""
This is testing the case that the retry_on_conflict decorator is coming to solve, see its docstring for more details
"""
project = {
"metadata": {
"name": "project-name",
}
}
response = await async_client.post(
"projects",
json=project,
)
project = {"metadata": {"name": "project-name"}}
response = await async_client.post("projects", json=project,)
assert response.status_code == HTTPStatus.CREATED.value
# Make the get function method to return None on the first two calls, and then use the original function
get_function_mock = tests.conftest.MockSpecificCalls(
mlrun.api.utils.singletons.db.get_db()._get_class_instance_by_uid, [1, 2], None
).mock_function
mlrun.api.utils.singletons.db.get_db()._get_class_instance_by_uid = (
unittest.mock.Mock(side_effect=get_function_mock)
mlrun.api.utils.singletons.db.get_db()._get_class_instance_by_uid = unittest.mock.Mock(
side_effect=get_function_mock
)
function = {
"kind": "job",
Expand All @@ -97,10 +90,7 @@ async def test_multiple_store_function_race_condition(
json=function,
)
)
response1, response2 = await asyncio.gather(
request1_task,
request2_task,
)
response1, response2 = await asyncio.gather(request1_task, request2_task,)

assert response1.status_code == HTTPStatus.OK.value
assert response2.status_code == HTTPStatus.OK.value
Expand Down

0 comments on commit 089a21e

Please sign in to comment.