Experiment and model pruning logic
- determine retention period: (e.g. 10 days)
- determine minimum to retain (even if older than expiry period)
- determine if we can know a priori if a model/experiment run can be linked to a deployed model


In [None]:
from datetime import timedelta

# Age threshold for pruning: the current time minus this value is used as the
# cutoff, and only runs that ended before that cutoff are considered.
MAX_RUN_AGE: timedelta = timedelta(days=1)

In [None]:
import os

# SECURITY: credentials must never be committed to source. A token was
# previously hard-coded in this cell; treat it as compromised and rotate it.
# The token is taken from the environment when already set, and requested
# interactively (without echoing) only when it is missing.
if not os.environ.get("MLFLOW_TRACKING_TOKEN"):
    from getpass import getpass

    os.environ["MLFLOW_TRACKING_TOKEN"] = getpass("MLFLOW_TRACKING_TOKEN: ")

In [None]:
from mlflow import MlflowClient

In [None]:
# The tracking URI and the model-registry URI are set to the same server URL.
tracking_uri: str = "https://mlflow-tracking-server-jburt-poc.aip.anaconda.com/"
registry_uri: str = "https://mlflow-tracking-server-jburt-poc.aip.anaconda.com/"
# Module-level client used by the helper functions in this file.
client = MlflowClient(tracking_uri=tracking_uri, registry_uri=registry_uri)

In [None]:
from mlflow.store.entities import PagedList
from mlflow.entities import Experiment


def get_experiments() -> list[Experiment]:
    """Collect the experiments from every page of ``client.search_experiments``.

    The search is repeated with the ``token`` of the previous page until a
    page comes back without one, and the items of all pages are merged.

    Returns:
        A flat list holding the experiments of all result pages.
    """
    experiments: list[Experiment] = []
    page_token: Union[str, None] = None
    while True:
        page: PagedList[Experiment] = client.search_experiments(page_token=page_token)
        # Extend rather than append: a page is itself a list of experiments,
        # and appending it whole would nest it as a single element.
        experiments.extend(page)
        page_token = page.token
        # A missing (or empty) token means there is no further page to fetch.
        if not page_token:
            return experiments

In [None]:
from typing import Union
from mlflow.entities import Run


def get_experiment_runs(experiment_id: str) -> list[Run]:
    """Collect the runs from every page of ``client.search_runs`` for one experiment.

    The search is repeated with the ``token`` of the previous page until a
    page comes back without one, and the items of all pages are merged.

    Args:
        experiment_id: Identifier of the experiment whose runs are fetched.

    Returns:
        A flat list holding the runs of all result pages.
    """
    results: list[Run] = []
    page_token: Union[str, None] = None
    while True:
        page: PagedList[Run] = client.search_runs(experiment_ids=[experiment_id], page_token=page_token)
        # Extend rather than append: a page is itself a list of runs, and
        # appending it whole would nest it as a single element.
        results.extend(page)
        page_token = page.token
        # A missing (or empty) token means there is no further page to fetch.
        if not page_token:
            return results

In [None]:
from mlflow.entities.model_registry import ModelVersion


def get_prunable_model(run_id: str) -> Union[ModelVersion, None]:
    """Look up the model version tied to a run and return it if it may be pruned.

    Args:
        run_id: Run identifier used in the ``run_id = '...'`` version filter.

    Returns:
        The matching model version when the search yields exactly one result
        and its ``current_stage`` equals the string ``"None"``; otherwise
        ``None``.
    """
    matches: PagedList[ModelVersion] = client.search_model_versions(f"run_id = '{run_id}'")

    # Exactly one version is expected per run; any other count is not prunable.
    if len(matches) == 1:
        candidate: ModelVersion = matches[0]
        # Only versions without a stage qualify (not staging, production, or
        # archived); the stage is compared against the string "None".
        if candidate.current_stage == "None":
            return candidate

    return None

In [None]:
from datetime import datetime
from mlflow.entities import Run
import json


def get_prunable_runs(runs: list[Run]) -> list[dict]:
    """Select the runs for which both the run and its model may be pruned.

    A run is reported when all of the following hold:

    * it has an end time, and that end time is older than ``MAX_RUN_AGE``;
    * it carries a non-empty ``mlflow.log-model.history`` tag;
    * ``get_prunable_model`` returns a model version for the ``run_id``
      recorded in the first entry of that tag.

    Runs without the tag are skipped instead of raising.

    Args:
        runs: Runs to inspect.

    Returns:
        One dict per prunable run, shaped as
        ``{"experiment_run_id": <run id>, "model": {"name": ..., "version": ...}}``.
    """
    from datetime import timezone

    # Computed once for the whole batch. Both sides of the age comparison are
    # timezone-aware UTC, so the result does not depend on the local timezone.
    cutoff: datetime = datetime.now(tz=timezone.utc) - MAX_RUN_AGE

    prunables: list[dict] = []
    for run in runs:
        end_time_ms = run.info.end_time
        if not end_time_ms:
            # No end time recorded for this run; leave it alone.
            continue

        # The end time is divided by 1000 to turn milliseconds into the
        # seconds that ``fromtimestamp`` expects.
        ended_at: datetime = datetime.fromtimestamp(end_time_ms / 1000, tz=timezone.utc)
        if not ended_at < cutoff:
            continue

        # Runs that never logged a model have no history tag; skip them.
        history_json = run.data.tags.get("mlflow.log-model.history")
        if not history_json:
            continue
        history = json.loads(history_json)
        if not history:
            continue
        model_run_id: str = history[0]["run_id"]

        model: Union[ModelVersion, None] = get_prunable_model(run_id=model_run_id)
        if not model:
            continue

        # At this point both the run and the model are prunable.
        prunables.append(
            {
                "experiment_run_id": run.info.run_id,
                "model": {"name": model.name, "version": model.version},
            }
        )
    return prunables

In [None]:
from mlflow.entities import Experiment

# Driver: for each experiment returned by get_experiments(), fetch its runs
# and print the run/model pairs that get_prunable_runs() reports.
# Nothing is deleted here; the results are only printed.
experiments: list[Experiment] = get_experiments()
for experiment in experiments:
    print(f"Reviewing experiment {experiment.experiment_id}")
    runs: list[Run] = get_experiment_runs(experiment_id=experiment.experiment_id)
    prunables = get_prunable_runs(runs=runs)
    print(prunables)