In [None]:
!pip install ultralytics opencv-python

In [None]:
# System packages installed with apt (assumes a Debian/Ubuntu-based image with sudo):
# - libgl1: presumably needed by opencv-python (installed above), which loads libGL.so.1
# - poppler-utils: not used by the cells visible here; presumably for PDF -> image
#   conversion elsewhere in the workflow (TODO confirm it is still required)
# Each `!` line runs independently, so a failure in one does not stop the next.
!sudo apt-get update 
!sudo apt-get install -y libgl1
!sudo apt-get install -y poppler-utils

In [None]:
from ultralytics import YOLO
import torch
import mlflow
import mlflow.pyfunc
from mlflow.models.signature import ModelSignature
from mlflow.types.schema import Schema, ColSpec

import os
import pandas as pd



In [None]:
# Close any MLflow run left active by an earlier (e.g. interrupted) execution of
# this notebook, so the logging cell below can start its own run.
mlflow.end_run()

In [None]:
# --- Configuration: replace both placeholders before running the logging cell ---
experiment_name: str = "DEFAULT REPLACE ME"  # MLflow experiment that receives the logging run
model_name: str = "DEFAULT REPLACE ME"       # Model Registry name the model is registered under

In [None]:
# Package the YOLO forgery-detection weights as an MLflow pyfunc model (custom
# `loader_module`) and register it in the Model Registry. Uses `experiment_name`
# and `model_name` from the configuration cell.

PLACEHOLDER_NAME = "DEFAULT REPLACE ME"
if PLACEHOLDER_NAME in (experiment_name, model_name):
    # Don't silently create an experiment / registered model named after the placeholder.
    raise ValueError(
        "Set `experiment_name` and `model_name` in the configuration cell "
        "before logging the model."
    )

model_path = "forgery_detect_model.pt"  # YOLO weights; log_model copies them into the model's data/ dir
module_path = "yolo_wrapper.py"         # generated loader module, shipped with the model via code_paths

if not os.path.isfile(model_path):
    # Fail fast with a clear message instead of an error from deep inside log_model.
    raise FileNotFoundError(f"YOLO weights not found: {os.path.abspath(model_path)}")

# Source of the loader module. At load time MLflow imports it by name
# (`loader_module="yolo_wrapper"`) and calls `_load_pyfunc(data_path)`.
YOLO_WRAPPER_SOURCE = '''"""MLflow pyfunc loader module for the YOLO forgery-detection model."""
import json

import mlflow
import numpy as np
import pandas as pd


def _split_predict_args(args, kwargs):
    """Return (model_input, params) however predict() was invoked.

    Supports predict(model_input), predict(model_input, params),
    predict(context, model_input[, params]) and keyword calls using
    model_input= / data= / params=.
    """
    positional = list(args)
    # Drop a leading context argument (a PythonModelContext, or None).
    if positional and (
        positional[0] is None
        or isinstance(positional[0], mlflow.pyfunc.PythonModelContext)
    ):
        positional.pop(0)

    if positional:
        model_input = positional[0]
        params = positional[1] if len(positional) > 1 else kwargs.get("params")
    else:
        model_input = kwargs.get("model_input", kwargs.get("data"))
        params = kwargs.get("params")
    return model_input, params


def _is_missing(value):
    """True for None or a float NaN (how pandas marks missing cells)."""
    return value is None or (isinstance(value, float) and np.isnan(value))


def _as_yolo_source(value):
    """Convert one input value into a source for YOLO.predict.

    Nested lists of pixel values become a uint8 array; anything else
    (e.g. a file path string or a numpy array) is passed through unchanged.
    """
    if isinstance(value, list):
        return np.array(value, dtype=np.uint8)
    return value


def _extract_sources(model_input):
    """Return one YOLO source per input image (one per DataFrame row)."""
    if isinstance(model_input, pd.DataFrame):
        # MLflow hands pyfunc models a DataFrame; fall back to the first column
        # when there is no "image" column (e.g. numeric column names).
        if "image" in model_input.columns:
            column = model_input["image"]
        else:
            column = model_input.iloc[:, 0]
        return [_as_yolo_source(value) for value in column]
    if isinstance(model_input, dict):
        return [_as_yolo_source(model_input.get("image"))]
    return [_as_yolo_source(model_input)]


def _summarize(result):
    """Reduce one Ultralytics result to {"prediction", "confidence"}.

    Uses the first entry of result.summary(); an empty summary means
    nothing was detected.
    """
    summary = result.summary()
    if not summary:
        return {"prediction": "no_detection", "confidence": 0.0}
    first = summary[0]
    return {
        "prediction": first.get("name", "unknown"),
        "confidence": float(first.get("confidence", 0.0)),
    }


class YOLOForgeryModel(mlflow.pyfunc.PythonModel):
    """Pyfunc wrapper around an Ultralytics YOLO forgery-detection model.

    predict() returns a DataFrame with one "results_json" row per input image,
    each a JSON object {"prediction": <name>, "confidence": <float>}.
    """

    def __init__(self, model=None):
        self.model = model

    def load_context(self, context):
        """Load the YOLO weights from context.artifacts["model"]."""
        from ultralytics import YOLO

        self.model = YOLO(context.artifacts["model"])

    def predict(self, *args, **kwargs):
        """Run YOLO on every input image.

        params is accepted for interface compatibility and currently unused.
        """
        model_input, _params = _split_predict_args(args, kwargs)

        if self.model is None:
            raise RuntimeError("Model not loaded. load_context must be called first.")

        sources = _extract_sources(model_input)
        if any(_is_missing(source) for source in sources):
            # Passing None on to YOLO would make it fall back to its default sources.
            raise ValueError("Missing image input; expected an image path or pixel array.")
        if not sources:
            return pd.DataFrame({"results_json": pd.Series([], dtype=object)})

        # For still images YOLO returns one result per source in the list.
        results = self.model.predict(sources)
        return pd.DataFrame(
            {"results_json": [json.dumps(_summarize(result)) for result in results]}
        )


def _load_pyfunc(data_path):
    """Entry point MLflow calls for loader_module="yolo_wrapper".

    data_path is the local path of the logged YOLO weights file.
    """
    from ultralytics import YOLO

    return YOLOForgeryModel(YOLO(data_path))
'''

with open(module_path, "w", encoding="utf-8") as f:
    f.write(YOLO_WRAPPER_SOURCE)

# One string column in (the image), one JSON-string column out.
input_schema = Schema([ColSpec(type="string", name="image")])
output_schema = Schema([ColSpec(type="string", name="results_json")])
signature = ModelSignature(inputs=input_schema, outputs=output_schema)

mlflow.set_experiment(experiment_name)

# The context manager ends the run itself (FINISHED, or FAILED if an exception
# escapes), so no manual start_run()/end_run() bookkeeping is needed inside it.
with mlflow.start_run() as run:
    mlflow.pyfunc.log_model(
        artifact_path="model",
        loader_module="yolo_wrapper",
        data_path=model_path,
        code_paths=[module_path],
        signature=signature,
        pip_requirements=[
            "mlflow>=2.0.0",
            "ultralytics>=8.0.0",
            "torch>=1.7.0",
            "numpy>=1.18.0",
            "pillow>=7.0.0",
            "pandas>=1.0.0"
        ]
    )

    run_id = run.info.run_id
    model_version = mlflow.register_model(
        model_uri=f"runs:/{run_id}/model",
        name=model_name
    )

    print(f"Model registered as {model_name} version {model_version.version}")



In [6]:
# Display the most recently ended run; its tags (e.g. `mlflow.log-model.history`)
# show the logged model's flavor configuration.
mlflow.last_active_run()

<Run: data=<RunData: metrics={}, params={}, tags={'mlflow.log-model.history': '[{"run_id": "74320470f12a473490f9230d6e3a08d7", '
                             '"artifact_path": "model", "utc_time_created": '
                             '"2025-05-22 18:00:45.473336", "model_uuid": '
                             '"66115c8403f146fea946faa224b4446c", "flavors": '
                             '{"python_function": {"streamable": false, '
                             '"loader_module": "yolo_wrapper", '
                             '"python_version": "3.12.7", "data": '
                             '"data/forgery_detect_model.pt", "env": {"conda": '
                             '"conda.yaml", "virtualenv": "python_env.yaml"}, '
                             '"code": "code"}}}]',
 'mlflow.runName': 'marvelous-shad-913',
 'mlflow.source.name': '/opt/conda/lib/python3.12/site-packages/ipykernel_launcher.py',
 'mlflow.source.type': 'LOCAL',
 'mlflow.user': 'jovyan'}>, info=<RunInfo: artifact_uri='/