In [6]:
import os
import sys
from typing import Optional

# Make `src.*` importable; assumes this notebook sits two directories below the repo root — confirm.
sys.path.append("../../.")

import mlflow
from mlflow import MlflowClient
from mlflow.models.signature import infer_signature
from sentence_transformers import SentenceTransformer

from src.utils.logger import get_logger

# from services.embedding_service.app.utils.mlflow_utils import log_deployment_ready_model


logger = get_logger(__name__)

# Tracking server. An MLFLOW_TRACKING_URI environment variable takes precedence so the
# notebook can target another server without editing code; otherwise use the local default.
MLFLOW_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000/")
# Passed to SentenceTransformer(...) and reused as the registered model name.
MODEL_NAME = "all-MiniLM-L12-v2"
EXPERIMENT_NAME = "deployment-ready-embedding-model"
# Registry alias resolved via the "models:/<name>@<alias>" URI form.
ALIAS = "champion"

mlflow.set_tracking_uri(uri=MLFLOW_TRACKING_URI)
# Created after set_tracking_uri so it talks to the server configured above.
client = MlflowClient()


def check_existing_experiment(experiment_name: str) -> None:
    """Restore the named MLflow experiment if it exists but is soft-deleted.

    ``mlflow.set_experiment`` refuses to activate a soft-deleted experiment, so
    run this before logging into ``experiment_name``. A missing experiment is
    left alone: ``mlflow.set_experiment`` creates it on first use.

    Args:
        experiment_name: Name of the experiment to look up.
    """
    # A fresh client picks up whatever tracking URI is currently configured.
    client = MlflowClient()
    logger.info(f"Checking past experiment with name {experiment_name}")
    exp = client.get_experiment_by_name(experiment_name)
    if exp is None:
        logger.info(
            f"No experiment named {experiment_name}; it will be created when first set."
        )
        return
    if exp.lifecycle_stage == "deleted":
        logger.info(
            f"Found soft-deleted experiment with name {experiment_name}, restoring..."
        )
        client.restore_experiment(exp.experiment_id)


def log_deployment_ready_model(
    model_name: str, experiment_name: Optional[str] = None
):
    """Load a sentence-transformers model and log it to MLflow with a signature.

    The model is loaded (which may download it) and its signature inferred
    *before* the MLflow run starts, so a failed load does not leave an empty
    run behind in the experiment.

    Args:
        model_name: Model id passed to ``SentenceTransformer``; also used as the
            logged model's ``name``.
        experiment_name: If given, activated with ``mlflow.set_experiment``
            before the run starts; otherwise the currently active experiment
            is used.

    Returns:
        The object returned by ``mlflow.sentence_transformers.log_model``; its
        ``model_uri`` can be passed to ``mlflow.register_model``.
    """
    model = SentenceTransformer(model_name)

    # A one-item batch is enough for MLflow to infer the input/output schema.
    sample_input = ["input text"]
    sample_output = model.encode(sample_input)
    signature = infer_signature(model_input=sample_input, model_output=sample_output)

    if experiment_name is not None:
        mlflow.set_experiment(experiment_name)
    with mlflow.start_run() as run:
        logger.debug(f"Started run with info: {run.info}")
        model_info = mlflow.sentence_transformers.log_model(
            model=model,
            name=model_name,
            signature=signature,
        )
        logger.debug(f"Logged model URI: {model_info.model_uri}")

    return model_info

In [5]:
# Restore the target experiment first in case it was soft-deleted.
check_existing_experiment(experiment_name=EXPERIMENT_NAME)

In [8]:
# Load the embedding model and log it (with signature) as a new run.
model_info = log_deployment_ready_model(
    model_name=MODEL_NAME, experiment_name=EXPERIMENT_NAME
)

🏃 View run debonair-grouse-982 at: http://localhost:5000/#/experiments/4/runs/b3f46e5b2dc04abbbb4113fb94e3bb05
🧪 View experiment at: http://localhost:5000/#/experiments/4


In [9]:
# Register the logged model and point the alias at the version just created.
# register_model returns that ModelVersion directly; `latest_versions` is
# stage-based (deprecated) and is not guaranteed to be this version.
registered_version = mlflow.register_model(
    model_uri=model_info.model_uri, name=MODEL_NAME
)
client.set_registered_model_alias(
    name=MODEL_NAME, alias=ALIAS, version=registered_version.version
)

Successfully registered model 'all-MiniLM-L12-v2'.
2025/11/06 23:43:43 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: all-MiniLM-L12-v2, version 1
Created version '1' of model 'all-MiniLM-L12-v2'.


In [2]:
# Resolve the embedding model through its registry alias. If it cannot be loaded
# (e.g. the registered model or alias does not exist yet), log, register and
# alias a fresh copy, then load it through the same alias URI.
model_uri = f"models:/{MODEL_NAME}@{ALIAS}"

try:
    logger.debug("Loading registered model from MLflow registry...")
    model = mlflow.sentence_transformers.load_model(model_uri)
    logger.debug("Embedding model loaded successfully.")
except mlflow.exceptions.MlflowException as e:
    logger.warning(
        f"Embedding model URI {model_uri} could not be loaded ({e}). Download, log and register..."
    )
    # set_experiment (inside log_deployment_ready_model) fails on a soft-deleted experiment.
    check_existing_experiment(EXPERIMENT_NAME)
    # log_deployment_ready_model has no alias parameter; the alias is set below.
    model_info = log_deployment_ready_model(
        experiment_name=EXPERIMENT_NAME, model_name=MODEL_NAME
    )

    # Use the version returned by register_model rather than `latest_versions[0]`.
    registered_version = mlflow.register_model(
        model_uri=model_info.model_uri, name=MODEL_NAME
    )
    client.set_registered_model_alias(
        name=MODEL_NAME, alias=ALIAS, version=registered_version.version
    )
    model = mlflow.sentence_transformers.load_model(model_uri)

Downloading artifacts: 100%|██████████| 15/15 [02:00<00:00,  8.03s/it]
2025/11/04 02:34:54 INFO mlflow.sentence_transformers: 'models:/all-MiniLM-L12-v2@champion' resolved as 'mlflow-artifacts:/4/models/m-6398fa4bdde54276a6fa0cc1a98244d9/artifacts'
Downloading artifacts: 100%|██████████| 1/1 [00:00<00:00, 22.65it/s]
