In [1]:
# register_pickle_as_pyfunc.py
import argparse, os, time, pickle
import pandas as pd
import mlflow
from mlflow.pyfunc import PythonModel
from mlflow.models.signature import infer_signature
from mlflow.tracking import MlflowClient

class PicklePyFunc(PythonModel):
    """MLflow ``pyfunc`` adapter for an estimator stored as a pickle artifact.

    ``load_context`` unpickles the estimator from the ``model_pkl`` artifact;
    ``predict`` forwards to the estimator's ``predict_proba`` when it has one,
    otherwise to its ``predict``.

    NOTE(review): unpickling can execute arbitrary code -- only wrap pickles
    that come from a trusted source.
    """

    def load_context(self, context):
        """Unpickle the wrapped estimator into ``self._m``."""
        artifact_file = context.artifacts["model_pkl"]
        with open(artifact_file, "rb") as stream:
            self._m = pickle.load(stream)

    def predict(self, context, model_input):
        """Score ``model_input`` with the wrapped estimator.

        Input that is not already a ``pd.DataFrame`` is wrapped with
        ``pd.DataFrame`` first. Returns the output of ``predict_proba`` if
        the estimator defines it, otherwise the output of ``predict``.
        """
        if isinstance(model_input, pd.DataFrame):
            frame = model_input
        else:
            frame = pd.DataFrame(model_input)

        estimator = self._m
        # Probabilities are exposed whenever available; change this line to
        # always serve plain `predict` instead.
        method_name = "predict_proba" if hasattr(estimator, "predict_proba") else "predict"
        return getattr(estimator, method_name)(frame)

def _infer_signature_from_csv(pkl_path, sample_csv):
    """Return a best-effort MLflow signature inferred from ``sample_csv``, or None.

    The sample is scored with the same method the pyfunc wrapper exposes
    (``predict_proba`` when available, else ``predict``) so the recorded
    output schema matches what the registered model returns.
    """
    if not sample_csv:
        return None
    if not os.path.exists(sample_csv):
        # Previously ignored silently; stay best-effort but tell the user.
        import warnings

        warnings.warn(
            f"sample CSV {sample_csv!r} not found; registering without a signature",
            stacklevel=3,
        )
        return None

    X_sample = pd.read_csv(sample_csv)
    # NOTE(review): pickle.load can execute arbitrary code -- trusted files only.
    with open(pkl_path, "rb") as f:
        model = pickle.load(f)
    if hasattr(model, "predict_proba"):
        y_sample = model.predict_proba(X_sample)
    else:
        y_sample = model.predict(X_sample)
    return infer_signature(X_sample, y_sample)


def _wait_until_ready(client, name, version, timeout, poll_interval):
    """Poll a model version until READY / FAILED_REGISTRATION or until ``timeout`` expires.

    ``timeout`` is in seconds; ``None`` waits indefinitely. The version is
    fetched at least once, and the last-fetched object is returned either way.
    """
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        mv = client.get_model_version(name, version)
        if mv.status in ("READY", "FAILED_REGISTRATION"):
            return mv
        if deadline is not None and time.monotonic() >= deadline:
            return mv
        time.sleep(poll_interval)


def main(
    pkl_path,
    registry_name,
    experiment=None,
    sample_csv=None,
    *,
    extra_requirements=None,
    timeout=300.0,
    poll_interval=1.0,
):
    """Log a pickled estimator as an MLflow pyfunc model and register it.

    Args:
        pkl_path: Path to the pickled estimator; logged as the ``model_pkl`` artifact.
        registry_name: Model Registry name to register the logged model under.
        experiment: Optional experiment name passed to ``mlflow.set_experiment``.
        sample_csv: Optional CSV used to infer a model signature. A path that
            does not exist is skipped with a warning.
        extra_requirements: Extra pip requirement strings appended to the base
            list -- typically the library the pickle was trained with, e.g.
            ``["scikit-learn==1.4.2"]``.
        timeout: Max seconds to wait for the new version to reach READY or
            FAILED_REGISTRATION; ``None`` waits indefinitely.
        poll_interval: Seconds to sleep between registry status polls.

    Returns:
        The last-fetched model version object; inspect ``.status`` to see
        whether registration finished.
    """
    if experiment:
        mlflow.set_experiment(experiment)

    # Optional signature so the model is usable in UIs.
    signature = _infer_signature_from_csv(pkl_path, sample_csv)

    # The estimator's own library must be listed (via extra_requirements)
    # unless it is vendored into the pickle, or loading it elsewhere fails.
    pip_requirements = ["mlflow>=2.9.0", "pandas>=2.0.0", *(extra_requirements or [])]

    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            artifact_path="model",
            python_model=PicklePyFunc(),
            artifacts={"model_pkl": pkl_path},
            pip_requirements=pip_requirements,
            signature=signature,
        )

    # Register to the Model Registry, then wait (bounded) for it to settle.
    mv = mlflow.register_model(model_uri=model_info.model_uri, name=registry_name)
    mv = _wait_until_ready(MlflowClient(), registry_name, mv.version, timeout, poll_interval)

    print(f"Registered: name={mv.name} version={mv.version} status={mv.status}")
    return mv

ap = argparse.ArgumentParser()
ap.add_argument("--pkl", required=True, help="Path to the pickle file")
ap.add_argument("--name", required=True, help="MLflow Model Registry name")
ap.add_argument("--experiment", default=None, help="Experiment name")
ap.add_argument("--sample-csv", default=None, help="Optional CSV to infer signature")
args = ap.parse_args()
main(args.pkl, args.name, args.tracking_uri, args.experiment, args.sample_csv)

usage: ipykernel_launcher.py [-h] --pkl PKL --name NAME
                             [--experiment EXPERIMENT]
                             [--sample-csv SAMPLE_CSV]
ipykernel_launcher.py: error: the following arguments are required: --pkl, --name


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
