<!--Author: May Merkle-Tan-->
<!--Author: @may-merkletan_data --> 

# YOLO11n Object Detection on Databricks Serverless GPU (Distributed)

This notebook walks through training [YOLO11n](https://docs.ultralytics.com/models/yolo11/) on the [COCO128 dataset](https://www.kaggle.com/datasets/ultralytics/coco128) using [Databricks Serverless GPU](https://docs.databricks.com/aws/en/compute/serverless/dependencies) with **distributed multi-node training** across multiple A10 GPUs, tracking experiments with [MLflow](https://docs.databricks.com/en/mlflow/index.html), and deploying the model with [Model Serving](https://docs.databricks.com/en/machine-learning/model-serving/index.html).

---

> #### ⚠️ Dataset Size & Overfitting
>
> **COCO128 is for demonstration only.** With only 128 images (80 train, 24 val, and 24 test after this notebook's split), the dataset is too small for production. Expect the model to overfit severely, with validation loss likely rising after the first few epochs.
>
> **For production:** Use a larger dataset (for example, 100K+ general images, or 1K+ domain-specific images). The workflow is production-ready and applies to larger datasets once the data paths are updated. See [NuInsSeg](https://github.com/databricks-industry-solutions/cv-playground/tree/main/projects/NuInsSeg) for a real-world example that applies YOLO instance segmentation to nuclei from multiple cell types.

---
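One way to spot the overfitting pattern described above is to find the epoch with the lowest validation loss and check whether later epochs drift back up. The sketch below assumes you already have per-epoch validation losses as a list of floats (for example, read from the `results.csv` file Ultralytics writes to the training directory); the function name and the `tolerance` parameter are hypothetical, and the sample numbers are illustrative, not from a real run.

```python
def find_overfit_epoch(val_losses, tolerance=0.0):
    """Return (best_epoch, is_overfitting) for a list of per-epoch validation losses.

    best_epoch is the 0-based index of the lowest validation loss (first one on ties).
    is_overfitting is True when the final loss exceeds that minimum by more than `tolerance`.
    """
    if not val_losses:
        raise ValueError("val_losses must contain at least one value")
    best_epoch = min(range(len(val_losses)), key=lambda i: val_losses[i])
    is_overfitting = val_losses[-1] > val_losses[best_epoch] + tolerance
    return best_epoch, is_overfitting


# Illustrative numbers only: loss falls, bottoms out at epoch 2, then rises
best, overfit = find_overfit_epoch([1.9, 1.6, 1.5, 1.55, 1.7, 1.8])
print(f"Lowest val loss at epoch {best}; overfitting: {overfit}")
```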

### Workflow Overview
1. **Environment Setup**: Install packages, import libraries, define helper functions
2. **Unity Catalog Configuration**: Create schema, volume, and configure paths
3. **Data Preparation**: Download COCO128, create train/val/test splits
4. **MLflow Configuration**: Configure experiment and infer model signature
5. **Model Training**: Train YOLO11n with **distributed multi-GPU training** and MLflow tracking, then register the model to Unity Catalog
6. **Model Evaluation**: Test on the validation and test sets, and validate the serving input/output format
7. **Model Deployment**: Deploy to a serving endpoint with [AI Gateway inference tables](https://docs.databricks.com/en/ai-gateway/inference-tables/) automatically enabled for request/response logging

### Key Features
* **Distributed multi-node training** - Automatic scaling across multiple A10 GPUs using `@distributed` decorator
* **[Base64 input format](https://en.wikipedia.org/wiki/Base64)** - Images are sent as plain ASCII strings inside JSON requests, so any HTTP client can call the endpoint
* **Bounding box output** - Complete detection results (class, confidence, coordinates)
* **[Custom MLflow PyFunc wrapper](https://mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#creating-custom-pyfunc-models)** - Clean integration with MLflow
* **[Unity Catalog Volumes](https://docs.databricks.com/en/connect/unity-catalog/volumes.html)** - Organized artifact storage
* **[WorkspaceClient SDK](https://docs.databricks.com/en/dev-tools/sdk-python.html)** - Databricks SDK for Python methods for endpoint management
* **Production-ready** - Wrapper tested locally before deployment

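To make the base64 input format concrete: a client reads raw image bytes, base64-encodes them into an ASCII string, and places that string in an `image_base64` column. The sketch below builds a request body in the `dataframe_split` JSON format that Databricks Model Serving accepts for pyfunc models. The helper names are hypothetical and the image bytes are placeholders for real file contents.

```python
import base64
import json


def encode_image_bytes(image_bytes):
    """Encode raw image bytes (e.g. the contents of a .jpg file) as a base64 ASCII string."""
    return base64.b64encode(image_bytes).decode("utf-8")


def build_serving_payload(images):
    """Build a dataframe_split request body with one 'image_base64' row per image."""
    return {
        "dataframe_split": {
            "columns": ["image_base64"],
            "data": [[encode_image_bytes(img)] for img in images],
        }
    }


# Placeholder bytes stand in for real file contents, e.g. open("img.jpg", "rb").read()
payload = build_serving_payload([b"\xff\xd8\xff\xe0 fake jpeg header"])
body = json.dumps(payload)  # plain ASCII JSON, safe to send in an HTTP request body
print(body[:80])
```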


## Connect to serverless GPU compute and install dependencies

Connect your notebook to serverless A10 GPU:

1. Click the **Connect** dropdown at the top.
1. Select **Serverless GPU**.
1. Open the **Environment** side panel on the right side of the notebook.
1. Set **Accelerator** to A10 for this demo. You do not need to install any dependencies in the environment panel to run this notebook.
1. Under **Environment**, select the **`AI v4 (Beta)`** base environment.
1. Select **Apply** and click **Confirm** to apply this environment to your notebook.

---

> **⚠️ IMPORTANT NOTE:**  
> This notebook requires **`AI v4` Environment** for Serverless GPU and uses the **`@distributed`** decorator for multi-node training, which automatically provisions multiple A10 GPU nodes as needed.

---

## Environment Setup

Install the required packages and configure the Python environment for YOLO training on Serverless GPU.

In [0]:
# ============================================================
# PACKAGE INSTALLATION FOR SERVERLESS GPU
# ============================================================

# %pip install databricks-serverless-gpu         # Serverless GPU for multi-node distributed training (not needed - pre-installed in runtime)
%pip install -U "mlflow>=3.0"                  # MLflow for experiment tracking and model registry (quoted so ">" is not a shell redirect)
%pip install ultralytics==8.3.204              # YOLO11n object detection framework
%pip install nvidia-ml-py==13.580.82           # NVIDIA GPU monitoring
%pip install threadpoolctl==3.1.0              # Controls CPU thread usage
%pip install pyrsmi==0.2.0                     # AMD GPU monitoring (if using AMD GPUs)
dbutils.library.restartPython()

# ============================================================

# Set Ultralytics config directory before importing (avoids permission errors)
import os
import uuid
config_dir = f'/tmp/yolo_config_{uuid.uuid4().hex[:8]}'
os.environ['YOLO_CONFIG_DIR'] = config_dir
os.makedirs(config_dir, exist_ok=True)

print("[OK] Packages installed and Python restarted")

In [0]:
# ============================================================
# PACKAGE VERIFICATION
# ============================================================

import sys
import importlib.metadata

print("Checking required packages...\n")

missing_packages = []
installed_packages = {}

# Check each required package using importlib.metadata
packages_to_check = [
    ('mlflow', 'mlflow>=3.0'),
    ('ultralytics', 'ultralytics==8.3.204'),
    ('opencv-python', 'opencv-python (provides cv2)'),
    ('nvidia-ml-py', 'nvidia-ml-py==13.580.82'),
    ('threadpoolctl', 'threadpoolctl==3.1.0'),
    ('pyrsmi', 'pyrsmi==0.2.0')
]

for package_name, package_spec in packages_to_check:
    try:
        version = importlib.metadata.version(package_name)
        installed_packages[package_name] = version
        print(f"✓ {package_name}: {version}")
    except importlib.metadata.PackageNotFoundError:
        missing_packages.append(package_spec)
        print(f"✗ {package_name}: NOT INSTALLED")

print("\n" + "="*60)
if missing_packages:
    print("[ACTION REQUIRED] Missing packages detected!")
    print("\nInstall missing packages and restart kernel.")
    for pkg in missing_packages:
        print(f"   - {pkg}")
else:
    print("[OK] All required packages are installed!")
    print(f"   Python version: {sys.version.split()[0]}")
print("="*60)

## Helper Functions

Utility functions for the complete YOLO training and deployment workflow:

**Data Management:**
* `download_file()` - Download models and configs to UC Volume
* `download_and_extract_dataset()` - Download and extract COCO128
* `split_dataset()` - Create reproducible train/val/test splits

**MLflow Integration:**
* `infer_model_signature()` - Automatically infer model signature from predictions
* `setup_mlflow_experiment()` - Configure MLflow with system metrics
* `register_yolo_model()` - Register model to Unity Catalog with custom wrapper

**Model Evaluation:**
* `evaluate_model_on_split()` - Evaluate and visualize predictions on data splits

**Custom Wrapper:**
* `YOLOWrapper` - [MLflow PyFunc wrapper](https://mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#creating-custom-pyfunc-models) for YOLO models
  * **Input:** [Base64-encoded](https://en.wikipedia.org/wiki/Base64) images in an `image_base64` column (plain strings that serialize cleanly in JSON requests)
  * **Output:** DataFrame with class, confidence, and bounding boxes (11 columns)
  * **Purpose:** Enables deployment to [Model Serving endpoints](https://docs.databricks.com/en/machine-learning/model-serving/create-manage-serving-endpoints.html)
  * **Validation:** Tested locally before deployment

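The 11 output columns describe each box twice, in pixels: as corners (`bbox_x1`, `bbox_y1`, `bbox_x2`, `bbox_y2`) and as center plus size (`bbox_center_x`, `bbox_center_y`, `bbox_width`, `bbox_height`). The wrapper reads these from Ultralytics' `boxes.xyxy` and `boxes.xywh`. The sketch below uses hypothetical helpers (not part of the wrapper) to show the arithmetic relating the two forms, so a client can derive one from the other.

```python
def xyxy_to_xywh(x1, y1, x2, y2):
    """Convert corner coordinates to (center_x, center_y, width, height)."""
    width = x2 - x1
    height = y2 - y1
    return x1 + width / 2, y1 + height / 2, width, height


def xywh_to_xyxy(cx, cy, w, h):
    """Convert center/size coordinates back to (x1, y1, x2, y2)."""
    return cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2


# A 100x50 box whose top-left corner is at (10, 20)
print(xyxy_to_xywh(10.0, 20.0, 110.0, 70.0))  # (60.0, 45.0, 100.0, 50.0)
print(xywh_to_xyxy(60.0, 45.0, 100.0, 50.0))  # (10.0, 20.0, 110.0, 70.0)
```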
In [0]:
# ============================================================
# HELPER FUNCTIONS
# ============================================================

import os
import shutil
import requests
import zipfile
import io
import random
import yaml
import glob
import json
from datetime import datetime
import pandas as pd
import matplotlib.pyplot as plt
from ultralytics import YOLO
import mlflow
from mlflow import MlflowClient
import importlib.metadata


def download_file(url, destination, description="file"):
    """Download a file from URL to destination path."""
    if os.path.exists(destination):
        print(f"[INFO] {description} already exists at: {destination}")
        print(f"   Skipping download")
        return True
    
    print(f"Downloading {description}...")
    print(f"   From: {url}")
    print(f"   To: {destination}")
    
    try:
        response = requests.get(url, stream=True, timeout=60)  # timeout avoids hanging on a stalled connection
        if response.status_code == 200:
            os.makedirs(os.path.dirname(destination), exist_ok=True)
            with open(destination, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
            print(f"[OK] Downloaded successfully")
            if destination.endswith('.pt'):
                print(f"   Size: {os.path.getsize(destination) / (1024*1024):.2f} MB")
            return True
        else:
            print(f"[ERROR] Download failed with status code: {response.status_code}")
            return False
    except Exception as e:
        print(f"[ERROR] Download failed: {e}")
        return False


def download_and_extract_dataset(download_url, extraction_path):
    """Download and extract a zip dataset."""
    print("Downloading dataset...")
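    response = requests.get(download_url, timeout=300)
    response.raise_for_status()  # fail clearly instead of handing an HTML error page to ZipFile
    
    print("Extracting dataset...")
    with zipfile.ZipFile(io.BytesIO(response.content)) as z:
        z.extractall(extraction_path)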
    
    print(f"[OK] Dataset downloaded and extracted to: {extraction_path}")
    return True


def split_dataset(source_images_dir, source_labels_dir, base_images_dir, base_labels_dir,
                  train_ratio=0.625, val_ratio=0.1875, random_seed=42):
    """Split dataset into train/val/test sets with reproducible random seed."""
    print("=" * 60)
    print("DATASET SPLITTING")
    print("=" * 60)
    
    random.seed(random_seed)
    print(f"\nRandom seed: {random_seed}")
    
    test_ratio = 1.0 - train_ratio - val_ratio
    print(f"Split ratios: Train={train_ratio:.1%}, Val={val_ratio:.1%}, Test={test_ratio:.1%}\n")
    
    # Get all images
    all_images = sorted([f for f in os.listdir(source_images_dir) if f.endswith('.jpg')])
    print(f"Total images: {len(all_images)}")
    
    # Shuffle and split
    random.shuffle(all_images)
    train_size = int(len(all_images) * train_ratio)
    val_size = int(len(all_images) * val_ratio)
    
    train_images = all_images[:train_size]
    val_images = all_images[train_size:train_size + val_size]
    test_images = all_images[train_size + val_size:]
    
    print(f"Split sizes: Train={len(train_images)}, Val={len(val_images)}, Test={len(test_images)}\n")
    
    # Create directories
    for split_name in ['train', 'val', 'test']:
        os.makedirs(f"{base_images_dir}/{split_name}", exist_ok=True)
        os.makedirs(f"{base_labels_dir}/{split_name}", exist_ok=True)
    
    # Copy files
    print("Copying files to splits...")
    for split_name, image_list in [('train', train_images), ('val', val_images), ('test', test_images)]:
        print(f"  Processing {split_name} split ({len(image_list)} images)...")
        for img_name in image_list:
            # Copy image
            src_img = os.path.join(source_images_dir, img_name)
            dst_img = os.path.join(base_images_dir, split_name, img_name)
            shutil.copy2(src_img, dst_img)
            
            # Copy label if exists
            label_name = img_name.replace('.jpg', '.txt')
            src_label = os.path.join(source_labels_dir, label_name)
            dst_label = os.path.join(base_labels_dir, split_name, label_name)
            if os.path.exists(src_label):
                shutil.copy2(src_label, dst_label)
        print(f"    [OK] {split_name}: {len(image_list)} images copied")
    
    print(f"\n[OK] Dataset split complete!")
    print("=" * 60)
    return len(train_images), len(val_images), len(test_images)


def infer_model_signature(model_path, sample_image_path):
    """Infer MLflow model signature using actual model predictions."""
    import base64
    
    print("[INFO] Inferring model signature...\n")
    
    # Read and encode image as base64
    with open(sample_image_path, 'rb') as f:
        image_bytes = f.read()
    image_base64 = base64.b64encode(image_bytes).decode('utf-8')
    
    # Create input example
    input_example = pd.DataFrame({"image_base64": [image_base64]})
    
    # Create YOLOWrapper instance and get predictions to infer output schema
    wrapper = YOLOWrapper()
    
    # Simulate load_context
    class MockContext:
        def __init__(self, model_path):
            self.artifacts = {"yolo_model": model_path}
    
    wrapper.load_context(MockContext(model_path))
    
    # Get output example by running prediction
    output_example = wrapper.predict(None, input_example)
    
    # Use MLflow's infer_signature to automatically create signature
    signature = mlflow.models.infer_signature(input_example, output_example)
    
    print(f"[OK] Model signature inferred successfully!")
    print(f"   Input: DataFrame with 'image_base64' column (base64 string)")
    print(f"   Output: DataFrame with {len(output_example.columns)} columns")
    print(f"   Columns: {', '.join(output_example.columns.tolist())}")
    
    # Optional: Show how to use manual schema (commented out)
    # from mlflow.types.schema import Schema, ColSpec
    # from mlflow.models.signature import ModelSignature
    # input_schema = Schema([ColSpec("string", "image_base64")])
    # output_schema = Schema([
    #     ColSpec("string", "class_name"),
    #     ColSpec("long", "class_num"),
    #     ColSpec("double", "confidence"),
    #     ColSpec("double", "bbox_x1"),
    #     ColSpec("double", "bbox_y1"),
    #     ColSpec("double", "bbox_x2"),
    #     ColSpec("double", "bbox_y2"),
    #     ColSpec("double", "bbox_center_x"),
    #     ColSpec("double", "bbox_center_y"),
    #     ColSpec("double", "bbox_width"),
    #     ColSpec("double", "bbox_height")
    # ])
    # signature = ModelSignature(inputs=input_schema, outputs=output_schema)
    
    return signature, input_example


def setup_mlflow_experiment(notebook_path):
    """Setup MLflow experiment with system metrics enabled."""
    notebook_dir = '/'.join(notebook_path.split('/')[:-1])
    experiment_name = f"{notebook_dir}/Experiments_YOLO_CoCo"
    
    os.environ['MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING'] = "true"
    os.environ['MLFLOW_EXPERIMENT_NAME'] = experiment_name
    
    mlflow.set_experiment(experiment_name)
    experiment_id = mlflow.get_experiment_by_name(experiment_name).experiment_id
    
    if 'MLFLOW_RUN_ID' in os.environ:
        del os.environ['MLFLOW_RUN_ID']
    
    print(f"[OK] MLflow experiment initialized: {experiment_name}")
    print(f"   Experiment ID: {experiment_id}")
    print(f"   System metrics: ENABLED")
    return experiment_name, experiment_id


class YOLOWrapper(mlflow.pyfunc.PythonModel):
    """Custom MLflow wrapper for YOLO models using base64-encoded images."""
    
    def load_context(self, context):
        """Load YOLO model from artifacts."""
        from ultralytics import YOLO
        model_path = context.artifacts["yolo_model"]
        self.model = YOLO(model_path, task='detect')
    
    def _format_predictions(self, predictions):
        """Format YOLO prediction results with bounding boxes.
        
        Args:
            predictions: YOLO prediction results from model.predict()
            
        Returns:
            pd.DataFrame with class, confidence, and bounding box coordinates
        """
        import pandas as pd
        
        all_results = []
        for prediction in predictions:
            if prediction.boxes is not None:
                boxes = prediction.boxes
                for i in range(len(boxes)):
                    # Get bounding box coordinates in both formats
                    box_xyxy = boxes.xyxy[i].cpu().numpy()
                    box_xywh = boxes.xywh[i].cpu().numpy()
                    
                    all_results.append({
                        "class_name": prediction.names[int(boxes.cls[i])],
                        "class_num": int(boxes.cls[i]),
                        "confidence": float(boxes.conf[i]),
                        "bbox_x1": float(box_xyxy[0]),
                        "bbox_y1": float(box_xyxy[1]),
                        "bbox_x2": float(box_xyxy[2]),
                        "bbox_y2": float(box_xyxy[3]),
                        "bbox_center_x": float(box_xywh[0]),
                        "bbox_center_y": float(box_xywh[1]),
                        "bbox_width": float(box_xywh[2]),
                        "bbox_height": float(box_xywh[3])
                    })
        
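        # Fixed column list keeps the 11-column schema even when nothing is detected
        columns = ["class_name", "class_num", "confidence",
                   "bbox_x1", "bbox_y1", "bbox_x2", "bbox_y2",
                   "bbox_center_x", "bbox_center_y", "bbox_width", "bbox_height"]
        return pd.DataFrame(all_results, columns=columns)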
    
    def predict(self, context, model_input):
        """Run YOLO prediction on base64-encoded images.
        
        Args:
            context: MLflow context
            model_input: DataFrame with 'image_base64' column (base64-encoded images)
            
        Returns:
            pd.DataFrame with detection results including bounding boxes
        """
        import pandas as pd
        import base64
        from PIL import Image
        import io
        import numpy as np
        
        if not isinstance(model_input, pd.DataFrame):
            raise ValueError("Input must be a DataFrame with 'image_base64' column")
        
        if 'image_base64' not in model_input.columns:
            raise ValueError("DataFrame must contain 'image_base64' column with base64-encoded images")
        
        # Process base64-encoded images
        all_predictions = []
        for image_base64 in model_input['image_base64'].tolist():
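            # Decode base64 to a PIL image; convert to RGB to handle grayscale/RGBA inputs
            image_bytes = base64.b64decode(image_base64)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            
            # Pass the PIL image directly: Ultralytics treats NumPy arrays as BGR,
            # so np.array(image) would swap the red and blue channels
            predictions = self.model.predict(image, verbose=False)
            all_predictions.extend(predictions)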
        
        return self._format_predictions(all_predictions)


def register_yolo_model(run_id, model_path, catalog_name, schema_name, model_name,
                       signature=None, input_example=None, data_yaml_path=None):
    """Register YOLO model to Unity Catalog with custom wrapper."""
    registered_model_name = f"{catalog_name}.{schema_name}.{model_name}"
    ultralytics_version = importlib.metadata.version('ultralytics')
    cloudpickle_version = importlib.metadata.version('cloudpickle')
    
    print(f"\n[INFO] Registering model to Unity Catalog...")
    print(f"   Model name: {registered_model_name}")
    print(f"   Using custom YOLO wrapper (base64 input, bbox output)")
    print(f"   Pinning CloudPickle version: {cloudpickle_version}")
    
    with mlflow.start_run(run_id=run_id):
        if data_yaml_path:
            mlflow.log_artifact(data_yaml_path, "input_data")
        
        mlflow.pyfunc.log_model(
            artifact_path="model",
            python_model=YOLOWrapper(),
            artifacts={"yolo_model": model_path},
            signature=signature,
            input_example=input_example,
            registered_model_name=registered_model_name,
            pip_requirements=[
                f"ultralytics=={ultralytics_version}",
                f"cloudpickle=={cloudpickle_version}",
                "torch", 
                "torchvision", 
                "pillow", 
                "numpy"
            ]
        )
    
    print(f"   [OK] Model registered: {registered_model_name}")
    return registered_model_name


def evaluate_model_on_split(model, image_dir, split_name, output_dir, run_id, 
                           registered_model_name, organized_run_name, num_samples=3):
    """Evaluate model on a dataset split and save results."""
    print("=" * 60)
    print(f"{split_name.upper()} SET EVALUATION")
    print("=" * 60)
    
    os.makedirs(output_dir, exist_ok=True)
    images = sorted(glob.glob(f"{image_dir}/*.jpg"))  # sorted so the same samples are shown on every run
    
    if not images:
        print(f"[WARNING] No {split_name} images found")
        return
    
    print(f"\n{split_name.capitalize()} set: {len(images)} images\n")
    
    # Visualize sample predictions
    sample_images = images[:num_samples]
    fig, axes = plt.subplots(1, len(sample_images), figsize=(15, 5))
    if len(sample_images) == 1:
        axes = [axes]
    
    results = []
    for i, img_path in enumerate(sample_images):
        print(f"Sample {i+1}/{len(sample_images)}: {img_path.split('/')[-1]}")
        predictions = model.predict(img_path, verbose=False)
        
        if len(predictions) > 0:
            result = predictions[0]
            annotated_img = result.plot()  # Ultralytics returns a BGR-order NumPy array
            axes[i].imshow(annotated_img[..., ::-1])  # convert BGR to RGB for matplotlib
            axes[i].axis('off')
            
            if result.boxes is not None:
                num_detections = len(result.boxes)
                axes[i].set_title(f"{img_path.split('/')[-1]}\n{num_detections} objects", fontsize=10)
                print(f"   [OK] Detections: {num_detections} objects")
                
                img_results = {
                    "image": img_path.split('/')[-1],
                    "num_detections": num_detections,
                    "detections": []
                }
                
                for j in range(min(num_detections, 3)):
                    class_name = result.names[int(result.boxes.cls[j])]
                    confidence = float(result.boxes.conf[j])
                    print(f"      - {class_name}: {confidence:.3f}")
                    img_results["detections"].append({
                        "class_name": class_name,
                        "confidence": confidence
                    })
                results.append(img_results)
        print()
    
    plt.tight_layout()
    plt.suptitle(f"{split_name.capitalize()} Set Predictions - Run {run_id[:8]}", fontsize=14, y=1.02)
    
    plot_path = os.path.join(output_dir, f"{split_name}_predictions.png")
    plt.savefig(plot_path, dpi=150, bbox_inches='tight')
    print(f"[OK] Plot saved to: {plot_path}")
    plt.show()
    
    # Save results JSON
    json_path = os.path.join(output_dir, f"{split_name}_results.json")
    with open(json_path, 'w') as f:
        json.dump({
            "run_id": run_id,
            "registered_model": registered_model_name,
            "timestamp": organized_run_name.split('_run_')[0],
            "num_images": len(images),
            "sample_results": results
        }, f, indent=2)
    print(f"[OK] Results saved to: {json_path}")
    
    # Log to MLflow
    with mlflow.start_run(run_id=run_id):
        mlflow.log_artifact(plot_path, split_name)
        mlflow.log_artifact(json_path, split_name)
    
    print(f"\n[OK] {split_name.upper()} SET EVALUATION COMPLETE")
    print("=" * 60)


print("[OK] Helper functions loaded successfully")

## Setup Unity Catalog Storage

Configure catalog, schema, and volume for persistent storage.

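The next cells build the volume path as `/Volumes/<catalog>/<schema>/<volume>/` and wrap each name in backticks in the SQL statements, which matters because the default schema name `sgc-nightly` contains a hyphen. A minimal sketch of both steps as standalone helpers (the names are hypothetical); escaping an embedded backtick by doubling it follows Databricks SQL identifier-quoting rules.

```python
def quote_identifier(name):
    """Wrap a SQL identifier in backticks, escaping embedded backticks by doubling them."""
    return "`" + name.replace("`", "``") + "`"


def qualified_volume_name(catalog, schema, volume):
    """Return catalog.schema.volume with each part quoted, for CREATE VOLUME statements."""
    return ".".join(quote_identifier(part) for part in (catalog, schema, volume))


def volume_path(catalog, schema, volume):
    """Return the POSIX-style /Volumes path used for file I/O (with trailing slash)."""
    return f"/Volumes/{catalog}/{schema}/{volume}/"


print(f"CREATE VOLUME IF NOT EXISTS {qualified_volume_name('main', 'sgc-nightly', 'yolo_sgc_distributed')}")
print(volume_path("main", "sgc-nightly", "yolo_sgc_distributed"))
```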
In [0]:
dbutils.widgets.removeAll()

In [0]:
# Define widgets for catalog, schema, volume, model name, and deployment approval
dbutils.widgets.text("catalog_name", "main", "Catalog Name")
dbutils.widgets.text("schema_name", "sgc-nightly", "Schema Name")
dbutils.widgets.text("volume_name", "yolo_sgc_distributed", "Volume Name")
dbutils.widgets.text("model_name", "yolo11n_coco128_sgc_distributed", "Model Name")
dbutils.widgets.dropdown("proceed_with_deployment", "false", ["false", "true"], "Proceed with Deployment")

# Get widget values
catalog_name = dbutils.widgets.get("catalog_name")
schema_name = dbutils.widgets.get("schema_name")
volume_name = dbutils.widgets.get("volume_name")
model_name = dbutils.widgets.get("model_name")
proceed_with_deployment_str = dbutils.widgets.get("proceed_with_deployment")

print(f"[Configuration]")
print(f"   Catalog: {catalog_name}")
print(f"   Schema: {schema_name}")
print(f"   Volume: {volume_name}")
print(f"   Model: {model_name}")
print(f"   Proceed with Deployment: {proceed_with_deployment_str}")

print(f"\nUsing catalog: {catalog_name} (must already exist)")

# Create schema if it doesn't exist
spark.sql(f"CREATE SCHEMA IF NOT EXISTS `{catalog_name}`.`{schema_name}`")
print(f"[OK] Schema: {catalog_name}.{schema_name}")

# Create volume for persistent storage
spark.sql(f"CREATE VOLUME IF NOT EXISTS `{catalog_name}`.`{schema_name}`.`{volume_name}`")
print(f"[OK] Volume: {catalog_name}.{schema_name}.{volume_name}")

In [0]:
# Get Unity Catalog parameters
catalog_name = dbutils.widgets.get("catalog_name")
schema_name = dbutils.widgets.get("schema_name")
volume_name = dbutils.widgets.get("volume_name")

# Construct volume path from parameters
project_location = f'/Volumes/{catalog_name}/{schema_name}/{volume_name}/'
print(f"Using Unity Catalog Volume: {catalog_name}.{schema_name}.{volume_name}")
print(f"Volume path: {project_location}")

# Create subdirectories in the volume
os.makedirs(f'{project_location}runs/', exist_ok=True)       # Training runs (organized by task/model/dataset)
os.makedirs(f'{project_location}data/', exist_ok=True)       # Dataset storage
os.makedirs(f'{project_location}raw_model/', exist_ok=True)  # Pretrained models

# Ephemeral /tmp/ location for faster I/O during training
tmp_project_location = "/tmp/training_results/"
os.makedirs(tmp_project_location, exist_ok=True)

print(f"\n[OK] Project directories created:")
print(f"   Runs: {project_location}runs/")
print(f"   Data: {project_location}data/")
print(f"   Raw models: {project_location}raw_model/")
print(f"   Temp (training): {tmp_project_location} (ephemeral, fast I/O)")

## Project Folder Structure

Unity Catalog Volume organization:

```
/Volumes/{catalog}/{schema}/{volume}/
├── data/
│   ├── coco128.yaml                    # Dataset configuration
│   └── coco128/
│       ├── images/
│       │   ├── train2017/              # Original 128 images (from zip)
│       │   └── train/  val/  test/     # Custom splits (80/24/24)
│       └── labels/
│           ├── train2017/              # Original labels
│           └── train/  val/  test/     # Split labels
│
├── raw_model/
│   └── yolo11n.pt                      # Pretrained YOLO11n weights
│
└── runs/
    └── {task}_{model}_{dataset}_{timestamp}_run_{mlflow_run_id}/
        ├── train/                      # MLflow training outputs
        │   ├── weights/ (best.pt, last.pt)
        │   └── results.csv, confusion_matrix.png
        ├── validation_metrics/         # YOLO validation outputs
        ├── validation_samples/         # Custom evaluation samples
        └── test_samples/               # Test evaluation samples
```

**Run Naming:** `detection_yolo11n_coco128_20260120_143052_run_{mlflow_run_id}`
* Includes task, model, dataset, timestamp, and MLflow run ID for easy identification

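A sketch of how a name following this pattern could be assembled and split apart again, assuming a `%Y%m%d_%H%M%S` timestamp as in the example above. The helper names and the sample run ID `abc123` are hypothetical. Splitting on `'_run_'` is similar to what `evaluate_model_on_split()` does when it records the run prefix; this sketch uses the last occurrence.

```python
from datetime import datetime


def build_run_name(task, model, dataset, run_id, when=None):
    """Assemble '{task}_{model}_{dataset}_{timestamp}_run_{run_id}'."""
    when = when or datetime.now()
    return f"{task}_{model}_{dataset}_{when.strftime('%Y%m%d_%H%M%S')}_run_{run_id}"


def split_run_name(run_name):
    """Split a run name into (prefix, run_id) at the last '_run_' separator."""
    prefix, _, run_id = run_name.rpartition("_run_")
    return prefix, run_id


name = build_run_name("detection", "yolo11n", "coco128", "abc123",
                      when=datetime(2026, 1, 20, 14, 30, 52))
print(name)                  # detection_yolo11n_coco128_20260120_143052_run_abc123
print(split_run_name(name))  # ('detection_yolo11n_coco128_20260120_143052', 'abc123')
```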
## Download Pretrained YOLO Model

Download YOLO11n pretrained weights to Unity Catalog Volume.

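`download_file()` skips the download whenever the destination path exists, so a transfer interrupted midway would leave a truncated `.pt` file that later runs silently reuse. One way to avoid that is sketched below, assuming the bytes arrive as an iterable of chunks (for example `response.iter_content(chunk_size=8192)` from `requests`). The helper writes to a temporary file in the same directory and renames it into place only after every chunk is written. The helper name is hypothetical. `os.replace` is atomic on a local POSIX filesystem such as `/tmp`; whether the same holds on a Unity Catalog Volume mount is an assumption worth checking.

```python
import os
import tempfile


def atomic_write_chunks(chunks, destination):
    """Write an iterable of byte chunks to `destination`, replacing it only on success."""
    directory = os.path.dirname(destination) or "."
    os.makedirs(directory, exist_ok=True)
    # Temp file in the same directory so the final rename stays on one filesystem
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".part")
    try:
        with os.fdopen(fd, "wb") as f:
            for chunk in chunks:
                if chunk:  # skip empty keep-alive chunks
                    f.write(chunk)
        os.replace(tmp_path, destination)  # destination never holds a partial file
    except BaseException:
        os.remove(tmp_path)  # clean up the partial temp file, then re-raise
        raise
    return os.path.getsize(destination)
```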
In [0]:
# Download pretrained YOLO11n model to Unity Catalog Volume
model_path = f"{project_location}raw_model/yolo11n.pt"
model_url = "https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11n.pt"

download_file(model_url, model_path, "YOLO11n model")
print(f"\n[OK] Pretrained model ready at: {model_path}")

## Dataset Preparation

Download and configure COCO128 dataset.

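Each COCO128 image has a matching `.txt` label file, which `split_dataset()` copies alongside the image. The file uses the Ultralytics YOLO format: one object per line as `class_id x_center y_center width height`, with coordinates normalized to [0, 1] by image width and height. The hypothetical helper below converts one such line to pixel corner coordinates, the same `x1, y1, x2, y2` convention the serving wrapper returns.

```python
def yolo_label_to_xyxy(line, image_width, image_height):
    """Parse 'class cx cy w h' (normalized) into (class_id, x1, y1, x2, y2) in pixels."""
    parts = line.split()
    if len(parts) != 5:
        raise ValueError(f"Expected 5 fields, got {len(parts)}: {line!r}")
    class_id = int(parts[0])
    cx, cy, w, h = (float(v) for v in parts[1:])
    x1 = (cx - w / 2) * image_width
    y1 = (cy - h / 2) * image_height
    x2 = (cx + w / 2) * image_width
    y2 = (cy + h / 2) * image_height
    return class_id, x1, y1, x2, y2


# Class 0 box centered in a 640x480 image, covering half its width and height
print(yolo_label_to_xyxy("0 0.5 0.5 0.5 0.5", 640, 480))  # (0, 160.0, 120.0, 480.0, 360.0)
```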
In [0]:
# Download COCO128 dataset configuration to UC Volume
import yaml

# Create data directory in UC Volume
os.makedirs(f'{project_location}data/coco128', exist_ok=True)

# Download config directly to UC Volume
config_url = "https://github.com/ultralytics/ultralytics/raw/main/ultralytics/cfg/datasets/coco128.yaml"
config_path = f"{project_location}data/coco128.yaml"  # UC Volume path

download_file(config_url, config_path, "COCO128 config")

# Load and update configuration
with open(config_path, 'r') as f:
    data = yaml.safe_load(f)

print(f"\n[Dataset Configuration]")
print(f"   Dataset: {data.get('path', 'coco128')}")
print(f"   Classes: {data.get('nc', 'unknown')}")
print(f"   Download URL: {data.get('download', 'N/A')}")

# Update paths for Unity Catalog Volume
data['path'] = f"{project_location}data/coco128"

# Check if dataset already exists
dataset_images_dir = f"{project_location}data/coco128/images/train2017"
if os.path.exists(dataset_images_dir) and len(os.listdir(dataset_images_dir)) > 0:
    print(f"\n[INFO] Dataset already exists at: {dataset_images_dir}")
    print(f"   Found {len(os.listdir(dataset_images_dir))} images")
    print(f"   Skipping download")
else:
    # Download and extract dataset
    extraction_path = f"{project_location}data"
    download_and_extract_dataset(data['download'], extraction_path)

# Save updated configuration to UC Volume
data_yaml_path = f"{project_location}data/coco128.yaml"
with open(data_yaml_path, 'w') as f:
    yaml.dump(data, f, default_flow_style=False)

print(f"\n[OK] Dataset configuration saved to UC Volume: {data_yaml_path}")
print(f"   All dataset files in: {project_location}data/coco128/")

## Dataset Splits

Split COCO128 into train (62.5%), val (18.75%), and test (18.75%) sets with reproducible random seed.

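The split sizes follow from how `split_dataset()` rounds. Train and val sizes are truncated with `int()`, and the test split takes every remaining image. For COCO128: 128 × 0.625 = 80, 128 × 0.1875 = 24, leaving 128 − 80 − 24 = 24 for test. The standalone sketch below reproduces only that arithmetic and copies no files, so you can preview sizes for other dataset sizes. The function name is hypothetical.

```python
def split_sizes(num_images, train_ratio=0.625, val_ratio=0.1875):
    """Return (train, val, test) counts using the same truncation as split_dataset()."""
    train = int(num_images * train_ratio)
    val = int(num_images * val_ratio)
    test = num_images - train - val  # remainder goes to test
    return train, val, test


print(split_sizes(128))   # (80, 24, 24)
print(split_sizes(1000))  # (625, 187, 188)
```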
In [0]:
# Split dataset into train/val/test with reproducible random seed
source_images_dir = f"{project_location}data/coco128/images/train2017"
source_labels_dir = f"{project_location}data/coco128/labels/train2017"
base_images_dir = f"{project_location}data/coco128/images"
base_labels_dir = f"{project_location}data/coco128/labels"

train_size, val_size, test_size = split_dataset(
    source_images_dir=source_images_dir,
    source_labels_dir=source_labels_dir,
    base_images_dir=base_images_dir,
    base_labels_dir=base_labels_dir,
    train_ratio=0.625,  # 62.5%
    val_ratio=0.1875,   # 18.75%
    random_seed=42
)

print(f"\nSplit summary:")
print(f"  - Train: {train_size} images (62.5%)")
print(f"  - Val: {val_size} images (18.75%)")
print(f"  - Test: {test_size} images (18.75%)")
print(f"  - Random seed: 42")

In [0]:
# Update data.yaml to use train/val/test splits
with open(data_yaml_path, 'r') as f:
    yaml_content = yaml.safe_load(f)

# Update paths
yaml_content['train'] = f"{project_location}data/coco128/images/train"
yaml_content['val'] = f"{project_location}data/coco128/images/val"
yaml_content['test'] = f"{project_location}data/coco128/images/test"

# Save updated configuration
with open(data_yaml_path, 'w') as f:
    yaml.dump(yaml_content, f, default_flow_style=False)

print(f"[OK] data.yaml updated with train/val/test splits")
print(f"   Train: {yaml_content['train']}")
print(f"   Val: {yaml_content['val']}")
print(f"   Test: {yaml_content['test']}")

## MLflow Configuration

Infer model signature and configure experiment tracking with system metrics.

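`infer_signature()` derives the output schema from the columns of the example prediction. The pandas snippet below illustrates one caveat: building a DataFrame from an empty list of detection dicts yields no columns at all. As a result, an image with zero detections produces a column-less result unless the column list is given explicitly. The `OUTPUT_COLUMNS` constant is a hypothetical name that mirrors the wrapper's 11 columns.

```python
import pandas as pd

# The 11 columns produced by YOLOWrapper._format_predictions()
OUTPUT_COLUMNS = [
    "class_name", "class_num", "confidence",
    "bbox_x1", "bbox_y1", "bbox_x2", "bbox_y2",
    "bbox_center_x", "bbox_center_y", "bbox_width", "bbox_height",
]

no_detections = []  # the list of detection dicts for an image with no objects

implicit = pd.DataFrame(no_detections)
explicit = pd.DataFrame(no_detections, columns=OUTPUT_COLUMNS)

print(len(implicit.columns))  # 0: schema information is lost
print(len(explicit.columns))  # 11: no rows, but the schema is preserved
```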
In [0]:
# Infer model signature from sample prediction
# This defines the input/output schema for the serving endpoint
# Input: base64-encoded images
# Output: class, confidence, bounding boxes (11 columns)

model_path = f"{project_location}raw_model/yolo11n.pt"

# Find a sample image from training set
sample_images = glob.glob(f"{project_location}data/coco128/images/train/*.jpg")

if sample_images:
    signature, input_example = infer_model_signature(model_path, sample_images[0])
    print(f"\n[OK] Signature and input example ready for model registration")
else:
    print("[WARNING] No sample images found. Run dataset preparation first.")
    signature = None
    input_example = None

In [0]:
# signature, input_example

In [0]:
# Configure YOLO to use MLflow
from ultralytics import settings
settings.update({"mlflow": True})

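# Enable MLflow autologging for supported libraries.
# Note: autologging does not turn on system metrics by itself.
mlflow.autolog(disable=False)

# Enable system metrics logging (CPU, GPU, memory) for subsequent runs
mlflow.enable_system_metrics_logging()

print(f"\n[MLflow Configuration]")
print(f"   YOLO MLflow integration: Enabled")
print(f"   MLflow autologging: Enabled")
print(f"   System metrics: Enabled")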

In [0]:
# Setup MLflow experiment with system metrics
import os
import mlflow

notebook_path = dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()

# Extract username from notebook path to create experiment in workspace
# This avoids the error when notebook is in a Git folder
path_parts = notebook_path.split('/')
if 'Users' in path_parts:
    user_idx = path_parts.index('Users')
    username = path_parts[user_idx + 1]
    # Create experiment in user's workspace, not in Git folder
    experiment_name = f"/Users/{username}/Experiments_YOLO_CoCo"
else:
    # Fallback: use a generic workspace location
    experiment_name = "/Shared/Experiments_YOLO_CoCo"

# Enable system metrics
os.environ['MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING'] = "true"
os.environ['MLFLOW_EXPERIMENT_NAME'] = experiment_name

# Set or create experiment
mlflow.set_experiment(experiment_name)
experiment_id = mlflow.get_experiment_by_name(experiment_name).experiment_id

# Clear any existing run ID
if 'MLFLOW_RUN_ID' in os.environ:
    del os.environ['MLFLOW_RUN_ID']

print(f"[OK] MLflow experiment initialized: {experiment_name}")
print(f"   Experiment ID: {experiment_id}")
print(f"   System metrics: ENABLED")
print(f"   Note: Experiment created in workspace (not in Git folder)")

print(f"\n[Ready for Training]")
print(f"   Experiment: {experiment_name}")
print(f"   Experiment ID: {experiment_id}")

## Model Training

Train YOLO11n with MLflow tracking and register to Unity Catalog.

### Distributed Multi-Node GPU Training with Serverless GPU

This notebook is configured for **distributed training across multiple A10 GPU nodes** using Databricks Serverless GPU.

#### Key Features:

* **Serverless GPU API**: Uses `@distributed` decorator for automatic multi-node orchestration    
* **Multi-node training**: Automatically scales across multiple A10 GPU nodes        
* **Data parallelism**: Each GPU processes different batches in parallel      
* **Gradient synchronization**: Efficient all-reduce operations across nodes     
* **No cluster management**: Serverless handles provisioning and scaling     
* **Cost optimization**: Automatic scale-to-zero when not training     

#### Key Differences from Single-GPU Training:

**Traditional Approach (Single GPU):**
```python
from ultralytics import YOLO

model = YOLO('yolo11n.pt')
model.train(device=0, batch=16, epochs=100)
```

**Distributed Approach (8 GPUs):**
```python
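from serverless_gpu import distributed

@distributed(gpus=8, gpu_type='A10', remote=True)
def run_distributed_training():
    import torch.distributed as dist
    from ultralytics import YOLO
    from ultralytics.utils import LOCAL_RANK  # Read from the LOCAL_RANK env var

    dist.init_process_group("nccl")
    model = YOLO('yolo11n.pt')
    model.train(device=[LOCAL_RANK], batch=16, epochs=100)  # 16 per GPU = 128 total

run_distributed_training.distributed()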
```

#### How It Works:

1. Wrap training function with `@distributed(gpus=8, gpu_type='A10', remote=True)`
2. Databricks provisions multiple A10 GPU nodes automatically (e.g., 8 single-GPU nodes or 2 quad-GPU nodes)
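3. The decorator sets `RANK`, `WORLD_SIZE`, and `LOCAL_RANK` on every process; the training function initializes the NCCL process group, which YOLO uses for PyTorch Distributed Data Parallel (DDP)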
4. Training data is distributed across all GPUs
5. Each GPU computes gradients independently
6. Gradients are synchronized using all-reduce (NCCL)
7. Model weights are updated consistently across all nodes
8. Resources are cleaned up automatically when done
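Step 4 can be made concrete. The sketch below mimics how PyTorch's `torch.utils.data.DistributedSampler` splits indices when `shuffle=False` and `drop_last=False`. It pads the index list by wrapping around until the list divides evenly. Each rank then takes every `world_size`-th index, starting at its own rank. This is a pure-Python illustration, not the sampler itself. It is an assumption that the installed Ultralytics version shards training data with exactly this sampler; during training, it also reshuffles every epoch.

```python
import math


def shard_indices(dataset_len: int, rank: int, world_size: int) -> list[int]:
    """Interleaved sharding in the style of DistributedSampler
    (shuffle=False, drop_last=False)."""
    per_rank = math.ceil(dataset_len / world_size)
    total = per_rank * world_size
    indices = list(range(dataset_len))
    # Pad by wrapping around to the start until the list divides evenly
    while len(indices) < total:
        indices += indices[: total - len(indices)]
    # Rank r gets indices r, r + world_size, r + 2 * world_size, ...
    return indices[rank:total:world_size]


# ~80 COCO128 training images across 8 GPUs -> 10 images per rank
for r in range(8):
    print(r, shard_indices(80, r, 8))
```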

#### Queuing and Resource Provisioning:

Before training begins, Serverless GPU goes through a provisioning phase:

1. **Job Submission**: When you call `.distributed()`, the job is submitted to the Serverless GPU queue
2. **Resource Allocation**: Databricks allocates 8 A10 GPUs from the serverless pool
   * May provision 8 single-GPU nodes, or 2 quad-GPU nodes, or other combinations
   * Resources are dynamically allocated based on availability
3. **Queue Wait Time**: Job waits in queue if resources are not immediately available
   * Typical wait: seconds to a few minutes depending on cluster load
   * You'll see "Launching distributed training..." status during this phase
4. **Environment Setup**: Once resources are allocated:
   * Docker containers are initialized on each node
   * Python environment and dependencies are loaded
   * Network connections between nodes are established
5. **Training Starts**: After setup completes, your training function begins executing

**Note**: The entire provisioning process is automatic and transparent. You don't need to manage clusters, instances, or networking.

#### What Happens During Training:

1. **Process Group**: The training function initializes the PyTorch process group (NCCL backend) from the environment variables set by `@distributed`
2. **Data Distribution**: Training data is split across all 8 GPUs
3. **Forward Pass**: Each GPU processes its batch independently (16 images per GPU)
4. **Backward Pass**: Gradients are computed locally on each GPU
5. **All-Reduce**: Gradients are synchronized across all GPUs using NCCL
6. **Weight Update**: Model weights are updated consistently on all GPUs
7. **Repeat**: Process continues for all epochs
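Steps 4 to 6 are the core of DDP. Each rank computes a local gradient. An all-reduce then sums the gradients and divides by the world size. Every replica applies the same averaged gradient, so the weights stay identical. Below is a toy, pure-Python simulation of one such step. The two-parameter model, gradients, and learning rate are made up for illustration. Real DDP performs the all-reduce on GPU tensors with NCCL.

```python
def all_reduce_mean(per_rank_grads: list[list[float]]) -> list[float]:
    """Simulate all-reduce (sum across ranks) followed by division by world size."""
    world_size = len(per_rank_grads)
    summed = [sum(vals) for vals in zip(*per_rank_grads)]
    return [g / world_size for g in summed]


def sgd_step(weights: list[float], grad: list[float], lr: float) -> list[float]:
    # Plain SGD update: w <- w - lr * g
    return [w - lr * g for w, g in zip(weights, grad)]


# Hypothetical 2-parameter model replicated on 4 ranks, each with its own local gradient
weights = [1.0, -2.0]
local_grads = [[0.5, 0.0], [0.25, 0.5], [0.0, 1.0], [0.25, 0.5]]

avg_grad = all_reduce_mean(local_grads)
# Every rank applies the same averaged gradient, so all replicas end up identical
replicas = [sgd_step(weights, avg_grad, lr=0.5) for _ in local_grads]
print(avg_grad, replicas[0])
```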

#### Performance Benefits:

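* **Faster training**: up to \~8x with 8 GPUs; real speedup is lower because of gradient-communication overhead and depends on dataset size
* **8x larger effective batch size** (16 per GPU × 8 = 128 total)
* **Smoother gradient estimates** from larger batches (large batches may need learning-rate tuning to converge well)
* **No cluster configuration** needed
* **Automatic resource provisioning** and cleanup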
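The batch-size arithmetic above, plus the number of optimizer steps per epoch, can be checked with a few lines. This sketch makes three assumptions: `batch` is per GPU, as this notebook describes; the split has ~80 training images, as in the COCO128 split from the top of the notebook; and sharding pads the data without dropping the last batch. Under those assumptions, 8 GPUs leave each rank a single, partly filled batch per epoch. That is one reason COCO128 will not show the full multi-GPU speedup.

```python
import math


def effective_batch_size(batch_per_gpu: int, num_gpus: int) -> int:
    return batch_per_gpu * num_gpus


def optimizer_steps_per_epoch(num_train_images: int, batch_per_gpu: int, num_gpus: int) -> int:
    # Each rank sees ceil(N / num_gpus) images (padded sharding, no drop_last)
    # and takes ceil(images_per_rank / batch_per_gpu) optimizer steps
    images_per_rank = math.ceil(num_train_images / num_gpus)
    return math.ceil(images_per_rank / batch_per_gpu)


for gpus in (1, 4, 8, 16):
    print(f"{gpus:>2} GPUs: effective batch {effective_batch_size(16, gpus):>3}, "
          f"steps/epoch {optimizer_steps_per_epoch(80, 16, gpus)}")
```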

#### Customization:

Change the number of GPUs by setting `num_gpus` in the training cell. The cell passes it to the decorator as `@distributed(gpus=num_gpus, gpu_type=gpu_type, remote=True)`, and the effective batch size scales with it:
```python
num_gpus = 8                 # Number of A10 GPUs (default); e.g. 4 or 16
```

**No other configuration needed!** Just run the training cell and Serverless GPU handles everything.

#### Important Notes:

* **Data loading must be inside the decorated function** to avoid pickle errors
* **Only rank 0 saves artifacts** to avoid conflicts
* **MLflow tracking works automatically** from rank 0
* **YOLO detects DDP environment** via RANK, WORLD_SIZE, LOCAL_RANK env vars
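The rank-0 convention in these notes comes down to reading the environment variables that `@distributed` sets. The minimal sketch below uses the same defaults as the training cell (`RANK=0`, `WORLD_SIZE=1`, `LOCAL_RANK=0`). The helper names and the example environment are hypothetical. Ultralytics itself treats a missing `RANK` as `-1` (non-distributed), which is why the training cell also checks `RANK in (0, -1)`.

```python
import os
from typing import Mapping


def dist_info(env: Mapping[str, str] = os.environ) -> dict:
    """Read DDP environment variables with single-process defaults."""
    return {
        "rank": int(env.get("RANK", 0)),              # Global rank across all nodes
        "world_size": int(env.get("WORLD_SIZE", 1)),  # Total number of GPU processes
        "local_rank": int(env.get("LOCAL_RANK", 0)),  # GPU index on this node
    }


def is_main_process(env: Mapping[str, str] = os.environ) -> bool:
    # Only global rank 0 should write artifacts, log to MLflow, or register models
    return dist_info(env)["rank"] == 0


# Hypothetical: second GPU on the second of two 4-GPU nodes
worker_env = {"RANK": "5", "WORLD_SIZE": "8", "LOCAL_RANK": "1"}
print(dist_info(worker_env), is_main_process(worker_env))
```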

In [0]:
from datetime import datetime
import uuid
import shutil
from serverless_gpu import distributed

# Close any active MLflow runs
mlflow.end_run()

# Create unique timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# Model configuration
model_task = "detection"     # Task type
model_arch = "yolo11n"       # Model architecture
dataset_name = "coco128"     # Dataset name

# Distributed training configuration
num_gpus = 8                 # Number of A10 GPUs to use
gpu_type = 'A10'             # GPU type
batch_size_per_gpu = 16      # Batch size per GPU

print("=" * 60)
print("DISTRIBUTED TRAINING CONFIGURATION")
print("=" * 60)
print(f"\n[Serverless GPU Configuration]")
print(f"   GPUs requested: {num_gpus} {gpu_type} GPUs")
print(f"   Training mode: Multi-node distributed (remote=True)")
print(f"   Orchestration: Databricks Serverless GPU")
print(f"   Backend: PyTorch DDP over NCCL")

print(f"\n[Training Configuration]")
print(f"   Task: {model_task}")
print(f"   Model: {model_arch}")
print(f"   Dataset: {dataset_name}")
print(f"   Epochs: 100")
print(f"   Batch size per GPU: {batch_size_per_gpu}")
print(f"   Effective batch size: {batch_size_per_gpu * num_gpus} ({batch_size_per_gpu} × {num_gpus} GPUs)")
print()

# Get notebook context for MLflow source tracking
notebook_context = dbutils.notebook.entry_point.getDbutils().notebook().getContext()
notebook_path = notebook_context.notebookPath().get()
notebook_id = notebook_context.notebookId().get()

# Define distributed training function
@distributed(gpus=num_gpus, gpu_type=gpu_type, remote=True)
def run_distributed_training():
    """
    Distributed training function that runs once per GPU process across the A10 nodes.
    
    The @distributed decorator provisions the A10 GPUs (possibly across multiple nodes)
    and sets the RANK, WORLD_SIZE and LOCAL_RANK environment variables. This function
    then initializes the NCCL process group itself, which YOLO uses to synchronize
    gradients across all GPUs.
    """
    import os
    import shutil
    import torch
    import torch.distributed as dist
    import mlflow
    from ultralytics import YOLO
    from ultralytics.utils import RANK, LOCAL_RANK
    
    # Configure MLflow to use Unity Catalog
    mlflow.set_registry_uri('databricks-uc')
    
    # Get distributed training info
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    rank = int(os.environ.get('RANK', 0))
    
    # Initialize PyTorch distributed process group (required for YOLO DDP)
    dist.init_process_group("nccl")
    torch.cuda.set_device(local_rank)
    
    try:
        # Set MLflow environment variables
        os.environ['MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING'] = "true"
        os.environ['MLFLOW_EXPERIMENT_NAME'] = experiment_name
        
        # Only print from rank 0 to avoid duplicate logs
        if rank == 0:
            print(f"\n[Distributed Training Started]")
            print(f"   World size: {world_size} GPUs")
            print(f"   Effective batch size: {batch_size_per_gpu * world_size}")
            print(f"   RANK: {RANK}, LOCAL_RANK: {LOCAL_RANK}")
            print()
        
        # Create unique temp directory for this training run
        tmp_project_location_unique = f"/tmp/training_results_{uuid.uuid4().hex[:8]}/"
        os.makedirs(tmp_project_location_unique, exist_ok=True)
        
        # Model path (from Unity Catalog Volume)
        model_path = f"{project_location}raw_model/yolo11n.pt"
        
        # Create descriptive run name (without run_id since it doesn't exist yet)
        run_name = f"{model_task}_{model_arch}_{dataset_name}_{timestamp}"
        
        # Initialize model
        if rank == 0:
            print(f"Loading YOLO model on rank 0...")
            print(f"Data config: {data_yaml_path}")
            print(f"MLflow run name: {run_name}\n")
        
        model = YOLO(model_path)
        
        if rank == 0:
            print("Starting distributed training...\n")
            print("[INFO] YOLO will automatically:")
            print("   1. Detect DDP environment (RANK, WORLD_SIZE, LOCAL_RANK)")
            print("   2. Use existing PyTorch DDP process group")
            print("   3. Distribute data across all GPUs")
            print("   4. Synchronize gradients using all-reduce")
            print("   5. Update model weights consistently\n")
        
        # Train with DDP. Only rank 0 opens an MLflow run; the other ranks use a
        # no-op context so the experiment gets one run rather than one per GPU.
        import contextlib
        run_context = mlflow.start_run(run_name=run_name) if rank == 0 else contextlib.nullcontext()
        with run_context as run:
            # Log source notebook information (rank 0 only)
            if rank == 0:
                mlflow.set_tag("mlflow.source.name", notebook_path)
                mlflow.set_tag("mlflow.source.type", "NOTEBOOK")
                mlflow.set_tag("mlflow.databricks.notebookPath", notebook_path)
                mlflow.set_tag("mlflow.databricks.notebookID", notebook_id)
            
            results = model.train(
                task="detect",
                batch=batch_size_per_gpu,    # Batch size per GPU
                device=[LOCAL_RANK],         # Use LOCAL_RANK as list (required for DDP)
                data=data_yaml_path,
                epochs=100,
                project=tmp_project_location_unique,
                name=run_name,               # Folder name for training outputs
                exist_ok=True,
                fliplr=0.5,                  # Horizontal flip probability (1 would flip every image)
                flipud=0.5,                  # Vertical flip probability (1 would flip every image)
                perspective=0.001,
                degrees=0.45,
                amp=True,                    # Automatic Mixed Precision
                patience=50,
                dropout=0.2,
                weight_decay=0.0005,
                save=True,
                save_period=10,
                workers=8,                   # Data loading workers per GPU
                close_mosaic=10
            )
            
            # Validate on rank 0
            if RANK in (0, -1):
                success = model.val(
                    project=tmp_project_location_unique,
                    name="validation_metrics"
                )
            
            # Get run_id on rank 0 (run is None on the other ranks)
            active_run_id = run.info.run_id if rank == 0 else None
        
        # Only rank 0 handles model registration and artifact copying
        if rank == 0:
            print(f"\n[OK] Distributed training complete! MLflow Run ID: {active_run_id}")
            
            # Create full organized name with run_id
            organized_run_name = f"{model_task}_{model_arch}_{dataset_name}_{timestamp}_run_{active_run_id}"
            
            # Add organized name as MLflow tag for easy reference
            with mlflow.start_run(run_id=active_run_id):
                mlflow.set_tag("organized_run_name", organized_run_name)
                mlflow.set_tag("training_mode", "distributed")
                mlflow.set_tag("num_gpus", world_size)
                mlflow.set_tag("gpu_type", gpu_type)
                mlflow.log_artifact(data_yaml_path, "input_data_yaml")
            
            print(f"\n[Distributed Training Summary]")
            print(f"   GPUs used: {world_size}")
            print(f"   Effective batch size: {batch_size_per_gpu * world_size}")
            print(f"   Training mode: DDP (Distributed Data Parallel)")
            print(f"   MLflow run name: {run_name}")
            print(f"   Organized run name: {organized_run_name}")
            
            # Copy training results to Unity Catalog Volume
            print(f"\n[INFO] Copying training results to Unity Catalog Volume...")
            training_run_dir = os.path.join(tmp_project_location_unique, run_name)
            
            # Use organized name for volume folder
            volume_run_dir = os.path.join(project_location, "runs", organized_run_name)
            volume_train_dir = os.path.join(volume_run_dir, "train")
            
            if os.path.exists(training_run_dir):
                shutil.copytree(training_run_dir, volume_train_dir, dirs_exist_ok=True)
                print(f"   [OK] Training outputs copied to: {volume_train_dir}")
            
            # Copy validation metrics
            val_metrics_dir = os.path.join(tmp_project_location_unique, "validation_metrics")
            if os.path.exists(val_metrics_dir):
                volume_val_metrics_dir = os.path.join(volume_run_dir, "validation_metrics")
                shutil.copytree(val_metrics_dir, volume_val_metrics_dir, dirs_exist_ok=True)
                print(f"   [OK] Validation metrics copied to: {volume_val_metrics_dir}")
            
            # Save best model
            print("\n[INFO] Saving best model...")
            best_model = YOLO(str(model.trainer.best))
            best_model_path = f"/tmp/best_yolo_model_{timestamp}.pt"
            best_model.save(best_model_path)
            print(f"   Saved to: {best_model_path}")
            
            # Register model to Unity Catalog using model_name widget
            registered_model_name = register_yolo_model(
                run_id=active_run_id,
                model_path=best_model_path,
                catalog_name=catalog_name,
                schema_name=schema_name,
                model_name=model_name,  # Use widget parameter
                signature=signature,
                input_example=input_example,
                data_yaml_path=data_yaml_path
            )
            
            print(f"\n" + "=" * 60)
            print("[OK] DISTRIBUTED TRAINING COMPLETE")
            print("=" * 60)
            print(f"\n[Model Details]")
            print(f"   - Name: {registered_model_name}")
            print(f"   - Run ID: {active_run_id}")
            print(f"   - Location: Unity Catalog Model Registry")
            print(f"   - Format: Custom YOLO wrapper (base64 input, bbox output)")
            print(f"   - Training: Distributed across {world_size} {gpu_type} GPUs")
            print(f"\n[Training Artifacts]")
            print(f"   - Volume location: {volume_run_dir}")
            print(f"   - Run name: {organized_run_name}")
            print(f"   - Structure: train/, validation_metrics/")
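            print(f"\n[View Results]")
            # mlflow.get_tracking_uri() returns "databricks" on Databricks, not a URL
            print(f"   Experiment: {experiment_name} (ID: {experiment_id})")
            print(f"   Run ID: {active_run_id} (open it from the MLflow Experiments UI)")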
            
            # Return values for notebook access
            return active_run_id, registered_model_name, organized_run_name, volume_run_dir
        
        return None, None, None, None
    
    finally:
        # Clean up process group
        if dist.is_initialized():
            dist.destroy_process_group()


In [0]:
# Launch distributed training
print("[INFO] Launching distributed training on Serverless GPU...")
print(f"   Databricks will provision {num_gpus} {gpu_type} GPUs automatically")
print("   This may take a few minutes to start...\n")

In [0]:
result = run_distributed_training.distributed()

# Extract results from rank 0 (first element of the list)
if result and result[0]:
    run_id, registered_model_name, organized_run_name, volume_run_dir = result[0]
    print(f"\n[OK] Training artifacts available in notebook scope")
    print(f"   Variables: run_id, registered_model_name, organized_run_name, volume_run_dir")
else:
    print(f"\n[INFO] Training complete. Check MLflow for results.")

In [0]:
# result

## Model Evaluations

**Split Evaluation (Native YOLO):** Assess model accuracy using file paths from UC Volume. Validates model quality on validation/test sets.

**Local Serving Test (MLflow PyFunc):** Validate production serving format using base64-encoded images. Ensures endpoint compatibility before deployment.
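Outside the notebook, a client sends the same base64 input to the serving endpoint as JSON. The sketch below builds such a request body in MLflow's `dataframe_split` scoring format, with one `image_base64` column as in the local test below. The helper name is hypothetical, and the endpoint-test cell may construct its request differently.

```python
import base64
import json


def build_serving_payload(image_bytes_list: list[bytes]) -> str:
    """Encode raw image bytes as base64 strings in an `image_base64` column,
    using the MLflow `dataframe_split` request format."""
    rows = [[base64.b64encode(b).decode("utf-8")] for b in image_bytes_list]
    return json.dumps({"dataframe_split": {"columns": ["image_base64"], "data": rows}})


# In practice the bytes would come from open(path, "rb").read()
print(build_serving_payload([b"abc"]))
```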

### Split Evaluation

Evaluate model performance on validation and test sets before deployment.
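Ultralytics' `val()` reports precision, recall and mAP at IoU thresholds (mAP50 and mAP50-95). IoU measures how well a predicted box overlaps a ground-truth box. For reference, here is a minimal IoU for boxes in the `(x1, y1, x2, y2)` pixel format that the pyfunc wrapper returns (`bbox_x1`, `bbox_y1`, `bbox_x2`, `bbox_y2`). It is an illustration, not Ultralytics' implementation.

```python
def iou_xyxy(a: tuple[float, float, float, float],
             b: tuple[float, float, float, float]) -> float:
    """Intersection over union of two boxes given as (x1, y1, x2, y2)."""
    # Corners of the intersection rectangle
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    # Clamp at zero when the boxes do not overlap
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


print(iou_xyxy((0, 0, 2, 2), (1, 1, 3, 3)))  # overlap 1, union 7
```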

In [0]:
# Load model from MLflow and evaluate on validation set
model_uri = f"runs:/{run_id}/model"
model_path = mlflow.artifacts.download_artifacts(model_uri)

# Find the .pt file
import glob as glob_module
pt_files = glob_module.glob(f"{model_path}/**/*.pt", recursive=True)
if pt_files:
    loaded_model = YOLO(pt_files[0], task='detect')
    print(f"[OK] Model loaded from MLflow\n")
    
    # Evaluate on validation set
    val_image_dir = f"{project_location}data/coco128/images/val"
    val_output_dir = os.path.join(volume_run_dir, "validation_samples")
    
    evaluate_model_on_split(
        model=loaded_model,
        image_dir=val_image_dir,
        split_name="validation",
        output_dir=val_output_dir,
        run_id=run_id,
        registered_model_name=registered_model_name,
        organized_run_name=organized_run_name,
        num_samples=3
    )
else:
    print("[ERROR] Model file not found")

In [0]:
# Evaluate model on test set (uses loaded_model from validation cell)
test_image_dir = f"{project_location}data/coco128/images/test"
test_output_dir = os.path.join(volume_run_dir, "test_samples")

evaluate_model_on_split(
    model=loaded_model,
    image_dir=test_image_dir,
    split_name="test",
    output_dir=test_output_dir,
    run_id=run_id,
    registered_model_name=registered_model_name,
    organized_run_name=organized_run_name,
    num_samples=3
)


### (Registered Model) Local Serving Test 

In [0]:
import base64
import io

import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import pandas as pd
from mlflow.tracking import MlflowClient
from PIL import Image

# Test registered model locally with base64 input (serving format)
print("=" * 60)
print("LOCAL MODEL TEST - BASE64 INPUT")
print("=" * 60)

client = MlflowClient()

# Ensure we have the registered model name (same fallback as the deployment cell)
if 'registered_model_name' not in dir():
    registered_model_name = f"{catalog_name}.{schema_name}.{model_name}"

print(f"\nTesting model: {registered_model_name}\n")

try:
    # Get latest model version (search results are not guaranteed to be ordered)
    model_versions = client.search_model_versions(f"name='{registered_model_name}'")
    
    if model_versions:
        latest = max(model_versions, key=lambda v: int(v.version))
        latest_version = latest.version
        print(f"[OK] Found model version: {latest_version}")
        print(f"   Status: {latest.status}")
        
        # Load model using pyfunc (this is what serving endpoint uses)
        model_uri = f"models:/{registered_model_name}/{latest_version}"
        serving_model = mlflow.pyfunc.load_model(model_uri)
        print(f"[OK] MLflow pyfunc model loaded successfully\n")
        
        # Get test images, skipping the first 10 so these samples are unlikely to
        # overlap the 3 images shown in the test_samples evaluation above
        test_images = glob.glob(f"{project_location}data/coco128/images/test/*.jpg")
        sample_images = test_images[10:13]
        num_samples = len(sample_images)
        if num_samples > 0:
            print(f"Testing with {num_samples} sample images (different from test_samples)\n")
            
            # Create color map for different classes
            colors = plt.cm.tab20(np.linspace(0, 1, 20))  # 20 distinct colors
            
            # Create visualization
            fig, axes = plt.subplots(1, num_samples, figsize=(15, 5))
            if num_samples == 1:
                axes = [axes]
            
            for i, test_image_path in enumerate(sample_images):
                print(f"Sample {i+1}/{num_samples}: {test_image_path.split('/')[-1]}")
                
                # Encode image as base64
                with open(test_image_path, 'rb') as f:
                    image_bytes = f.read()
                image_base64 = base64.b64encode(image_bytes).decode('utf-8')
                
                # Test pyfunc wrapper with base64 input
                input_df = pd.DataFrame({"image_base64": [image_base64]})
                predictions = serving_model.predict(input_df)
                
                # Load and display image
                image = Image.open(test_image_path)
                axes[i].imshow(image)
                axes[i].axis('off')
                
                # Draw bounding boxes from pyfunc predictions
                if len(predictions) > 0:
                    num_detections = len(predictions)
                    
                    # Draw each bounding box with class-specific color
                    for idx, row in predictions.iterrows():
                        # Use xyxy coordinates
                        x1, y1, x2, y2 = row['bbox_x1'], row['bbox_y1'], row['bbox_x2'], row['bbox_y2']
                        width = x2 - x1
                        height = y2 - y1
                        
                        # Get color based on class number
                        color = colors[int(row['class_num']) % len(colors)]
                        
                        # Draw rectangle
                        rect = patches.Rectangle(
                            (x1, y1), width, height,
                            linewidth=2, edgecolor=color, facecolor='none'
                        )
                        axes[i].add_patch(rect)
                        
                        # Add label with matching color
                        label = f"{row['class_name']} {row['confidence']:.2f}"
                        axes[i].text(
                            x1, y1 - 5, label,
                            color='white', fontsize=8,
                            bbox=dict(facecolor=color, alpha=0.8, pad=2)
                        )
                    
                    axes[i].set_title(f"{test_image_path.split('/')[-1]}\n{num_detections} objects", fontsize=10)
                    print(f"   [OK] Detections: {num_detections} objects")
                    
                    for idx, row in predictions.head(3).iterrows():
                        print(f"      - {row['class_name']}: {row['confidence']:.3f}")
                else:
                    axes[i].set_title(f"{test_image_path.split('/')[-1]}\nNo objects", fontsize=10)
                    print(f"   [OK] No objects detected")
                print()
            
            plt.tight_layout()
            plt.suptitle(f"Local Serving Test - MLflow PyFunc with Base64 (v{latest_version})", fontsize=14, y=1.02)
            plt.show()
        
        print("=" * 60)
        print("[OK] MODEL READY FOR DEPLOYMENT")
        print("=" * 60)
        print(f"\n[Test Summary]")
        print(f"   - Model: {registered_model_name} (v{latest_version})")
        print(f"   - Input format: Base64-encoded images ✓")
        print(f"   - MLflow pyfunc wrapper: ✓")
        print(f"   - Bounding boxes: ✓ (color-coded by class)")
        print(f"   - Test images: Different from test_samples evaluation")
        print(f"   - Status: Validated and ready")
        print(f"\n[Key Difference from Split Tests]")
        print(f"   - Split tests: Native YOLO + file paths")
        print(f"   - This test: MLflow pyfunc wrapper + base64")
        print(f"   - This test validates actual serving endpoint format")
        print(f"\n   Next: Deploy to serving endpoint")
    else:
        print(f"[ERROR] No versions found for: {registered_model_name}")
        print(f"\nPlease register the model first.")
except Exception as e:
    print(f"[ERROR] {e}")
    import traceback
    traceback.print_exc()


## UC Registered Model Deployment CHECKPOINT

In [0]:
# ============================================================
# DEPLOYMENT CHECKPOINT
# ============================================================
# This cell acts as a safety gate before deployment cells.
# Set the 'Proceed with Deployment' widget to 'true' to continue.
# ============================================================

# Get deployment approval from the 'proceed_with_deployment' widget (top of notebook)
PROCEED_WITH_DEPLOYMENT = dbutils.widgets.get("proceed_with_deployment") == "true"

if not PROCEED_WITH_DEPLOYMENT:
    message = """
============================================================
⚠️  DEPLOYMENT PAUSED - MANUAL CONFIRMATION REQUIRED
============================================================

This checkpoint prevents accidental execution of deployment cells.

[To Proceed]
   1. Review the model validation results above
   2. Verify the model is ready for deployment
   3. Set 'Proceed with Deployment' widget to 'true' (top of notebook)
   4. Re-run this cell

[What Happens Next]
   - Model Deployment cell: Create/update serving endpoint (AI Gateway enabled automatically)
   - Endpoint test cell: Test deployed endpoint

[Safety Note]
   This checkpoint ensures you don't accidentally deploy
   an unvalidated model or overwrite a production endpoint.

[For 'Run All']
   Deployment cells will skip execution if not approved.
   No errors will be raised.

============================================================
⏸️  DEPLOYMENT PAUSED - AWAITING APPROVAL
============================================================
"""
    dbutils.notebook.exit(message)
else:
    message = """
============================================================
✓ DEPLOYMENT CHECKPOINT PASSED
============================================================

[Confirmation]
   User has manually approved deployment
   Execution will stop here for manual control

[Next Steps - Run Manually]
   1. Run the Model Deployment cell: Create/update serving endpoint (AI Gateway enabled automatically)
   2. Wait for endpoint to be ready (10-20 minutes)
   3. Run the endpoint test cell: Test deployed endpoint

[Why Manual Execution?]
   - Endpoint provisioning takes 10-20 minutes
   - You can monitor progress in the UI
   - Each step requires verification before proceeding
   - Prevents accidental 'Run All' through deployment

============================================================
⏸️  STOPPING HERE - RUN DEPLOYMENT CELLS MANUALLY
============================================================
"""
    dbutils.notebook.exit(message)

## Model Deployment

Deploy model to Databricks Model Serving endpoint with AI Gateway and inference table logging.
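The deployment cell below waits up to 2 minutes for an in-progress endpoint update by polling `config_update` every 10 seconds. The helper below applies the same pattern in a small, testable form. Its name is hypothetical, and `get_state` and `sleep` are injected so it can run without a workspace. In the notebook, `get_state` could be `lambda: w.serving_endpoints.get(endpoint_name).state.config_update.value`.

```python
import time
from typing import Callable, Iterable


def wait_for_config_update(
    get_state: Callable[[], str],
    terminal_states: Iterable[str] = ("NOT_UPDATING", "UPDATE_FAILED"),
    max_wait_s: int = 120,
    poll_interval_s: int = 10,
    sleep: Callable[[float], None] = time.sleep,
) -> tuple[str, int]:
    """Poll get_state() until it returns a terminal state or max_wait_s elapses.
    Returns (last_state, elapsed_seconds); a non-terminal state means timeout."""
    terminal = set(terminal_states)
    elapsed = 0
    state = get_state()  # Check first, so a finished update costs no sleep
    while state not in terminal and elapsed < max_wait_s:
        sleep(poll_interval_s)
        elapsed += poll_interval_s
        state = get_state()
    return state, elapsed


# Fake endpoint that finishes updating on the third poll; sleep is a no-op here
fake_states = iter(["IN_PROGRESS", "IN_PROGRESS", "NOT_UPDATING"])
print(wait_for_config_update(lambda: next(fake_states), sleep=lambda s: None))
```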

In [0]:
# Check deployment approval
if 'PROCEED_WITH_DEPLOYMENT' not in dir() or not PROCEED_WITH_DEPLOYMENT:
    print("\n" + "="*60)
    print("⚠️  DEPLOYMENT SKIPPED - NOT APPROVED")
    print("="*60)
    print("\n[Reason]")
    print("   PROCEED_WITH_DEPLOYMENT flag is not set to True")
    print("\n[To Enable Deployment]")
    print("   1. Set 'Proceed with Deployment' widget to 'true' (top of notebook)")
    print("   2. Re-run deployment checkpoint cell and this cell")
    print("\n" + "="*60)
else:
    # Deployment approved - proceed with endpoint creation
    from databricks.sdk import WorkspaceClient
    from databricks.sdk.service.serving import (
        ServedEntityInput, 
        EndpointCoreConfigInput,
        AiGatewayConfig,
        AiGatewayInferenceTableConfig
    )
    from mlflow.tracking import MlflowClient
    import time

    w = WorkspaceClient()
    client = MlflowClient()

    # Get latest model version
    if 'registered_model_name' not in dir():
        registered_model_name = f"{catalog_name}.{schema_name}.{model_name}"

    model_versions = client.search_model_versions(f"name='{registered_model_name}'")
    if not model_versions:
        raise ValueError(f"No model versions found for {registered_model_name}. Run training cell to register the model.")

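    # Pick the highest version number; search results are not guaranteed to be ordered
    model_version = max(model_versions, key=lambda v: int(v.version)).version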

    # Derive endpoint name from model name
    model_name_only = registered_model_name.split('.')[-1]
    endpoint_name = f"{model_name_only}_endpoint"

    print("=" * 60)
    print("CREATING MODEL SERVING ENDPOINT")
    print("=" * 60)

    print(f"\nEndpoint configuration:")
    print(f"   Name: {endpoint_name}")
    print(f"   Model: {registered_model_name}")
    print(f"   Version: {model_version}")
    print(f"   Workload size: Small")
    print(f"   Scale to zero: Enabled")
    print(f"   AI Gateway: Enabled with inference tables")
    print(f"   Inference table: {catalog_name}.{schema_name}.{endpoint_name}_payload")
    print()

    try:
        # Check if endpoint already exists
        endpoint_exists = False
        needs_update = True
        needs_ai_gateway_update = False
        
        try:
            existing_endpoint = w.serving_endpoints.get(endpoint_name)
            endpoint_exists = True
            print(f"[INFO] Endpoint '{endpoint_name}' already exists")
            
            # Check if endpoint is currently being updated
            if existing_endpoint.state.config_update.value != "NOT_UPDATING":
                print(f"[INFO] Endpoint is currently being updated (status: {existing_endpoint.state.config_update.value})")
                print(f"   Checking status briefly (will timeout after 2 minutes)...\n")
                
                # Brief check for current update status
                max_wait_time = 120  # Only wait 2 minutes
                poll_interval = 10
                elapsed_time = 0
                
                while elapsed_time < max_wait_time:
                    endpoint = w.serving_endpoints.get(endpoint_name)
                    
                    if endpoint.state.config_update.value == "NOT_UPDATING":
                        print(f"\n[OK] Current update completed (took {elapsed_time}s)")
                        existing_endpoint = endpoint
                        break
                    elif endpoint.state.config_update.value == "UPDATE_FAILED":
                        print(f"\n[WARNING] Current update failed")
                        existing_endpoint = endpoint
                        break
                    else:
                        if elapsed_time % 30 == 0:
                            print(f"   Status: {endpoint.state.config_update.value} ({elapsed_time}s elapsed)")
                        time.sleep(poll_interval)
                        elapsed_time += poll_interval
                
                if elapsed_time >= max_wait_time:
                    print(f"\n[INFO] Endpoint update still in progress after {max_wait_time}s")
                    print(f"   This cell will complete now to avoid blocking")
                    print(f"\n[NEXT STEP]")
                    print(f"   1. Wait a few minutes for the update to complete")
                    print(f"   2. Re-run this cell to check status")
                    print(f"   3. Once ready, proceed to testing")
                    
                    # Get final status and exit
                    endpoint = w.serving_endpoints.get(endpoint_name)
                    workspace_url = spark.conf.get("spark.databricks.workspaceUrl")
                    endpoint_url = f"https://{workspace_url}/ml/endpoints/{endpoint_name}"
                    
                    print(f"\n[View Endpoint]")
                    print(f"   {endpoint_url}")
                    print(f"\n[Current Status]")
                    print(f"   - Config State: {endpoint.state.config_update.value}")
                    print(f"   - Ready State: {endpoint.state.ready.value}")
                    
                    # Exit early - don't proceed to update logic
                    raise SystemExit("Endpoint update in progress - cell completed to avoid blocking")
            
            # Check if it's already serving the same model version
            current_config = existing_endpoint.config
            if current_config and current_config.served_entities:
                current_entity = current_config.served_entities[0]
                current_model = current_entity.entity_name
                current_version = current_entity.entity_version
                
                if current_model == registered_model_name and current_version == str(model_version):
                    print(f"   Already serving {registered_model_name} version {model_version}")
                    print(f"   No model update needed")
                    needs_update = False
                    
                    # Check if AI Gateway inference tables are enabled
                    ai_gateway = existing_endpoint.ai_gateway
                    if ai_gateway and ai_gateway.inference_table_config and ai_gateway.inference_table_config.enabled:
                        print(f"   AI Gateway inference tables already enabled")
                        print(f"   No AI Gateway update needed\n")
                    else:
                        print(f"   AI Gateway inference tables not enabled")
                        print(f"   Will enable AI Gateway\n")
                        needs_ai_gateway_update = True
                else:
                    print(f"   Currently serving: {current_model} v{current_version}")
                    print(f"   Updating to: {registered_model_name} v{model_version}")
                    print(f"   Note: update_config() cannot set AI Gateway, so it is configured in a separate step\n")
                    needs_ai_gateway_update = True
        except Exception as e:
            if "does not exist" in str(e).lower() or "RESOURCE_DOES_NOT_EXIST" in str(e):
                print(f"[INFO] Endpoint '{endpoint_name}' does not exist")
                print(f"   Creating new endpoint with AI Gateway enabled...\n")
            else:
                raise
        
        # Create/update endpoint if needed
        if needs_update:
            if endpoint_exists:
                # Update existing endpoint using SDK method
                # Note: update_config() doesn't support ai_gateway parameter
                w.serving_endpoints.update_config(
                    name=endpoint_name,
                    served_entities=[
                        ServedEntityInput(
                            entity_name=registered_model_name,
                            entity_version=str(model_version),
                            workload_size="Small",
                            scale_to_zero_enabled=True
                        )
                    ]
                )
                print(f"[OK] Endpoint update submitted")
            else:
                # Create new endpoint with AI Gateway enabled using SDK method
                w.serving_endpoints.create(
                    name=endpoint_name,
                    config=EndpointCoreConfigInput(
                        served_entities=[
                            ServedEntityInput(
                                entity_name=registered_model_name,
                                entity_version=str(model_version),
                                workload_size="Small",
                                scale_to_zero_enabled=True
                            )
                        ]
                    ),
                    ai_gateway=AiGatewayConfig(
                        inference_table_config=AiGatewayInferenceTableConfig(
                            catalog_name=catalog_name,
                            schema_name=schema_name,
                            table_name_prefix=endpoint_name,
                            enabled=True
                        )
                    )
                )
                print(f"[OK] Endpoint creation submitted (with AI Gateway enabled)")
            
            # Poll briefly (2 minutes) instead of blocking the notebook for the full provisioning time
            print(f"\n[INFO] Checking initial status (endpoint provisioning may take 10-20+ minutes)...")
            max_wait_time = 120  # Only wait 2 minutes here
            poll_interval = 10   # Check every 10 seconds
            elapsed_time = 0
            
            while elapsed_time < max_wait_time:
                endpoint = w.serving_endpoints.get(endpoint_name)
                
                if endpoint.state.config_update.value == "NOT_UPDATING" and endpoint.state.ready.value == "READY":
                    print(f"\n[OK] Endpoint is ready! (took {elapsed_time}s)")
                    break
                elif endpoint.state.config_update.value == "UPDATE_FAILED":
                    print(f"\n[ERROR] Endpoint update failed!")
                    print(f"   Check the endpoint UI for error details")
                    break
                else:
                    if elapsed_time % 30 == 0:  # Print status every 30 seconds
                        print(f"   Status: {endpoint.state.config_update.value} ({elapsed_time}s elapsed)")
                    time.sleep(poll_interval)
                    elapsed_time += poll_interval
            
            if elapsed_time >= max_wait_time:
                print(f"\n[INFO] Endpoint is still initializing (this may take several more minutes)")
                print(f"   This cell will complete now to avoid blocking")
                print(f"\n[NEXT STEP]")
                print(f"   1. Wait for endpoint to finish provisioning (check UI)")
                print(f"   2. Re-run this cell to verify status")
                print(f"   3. Once ready, proceed to testing")
        
        # Enable AI Gateway if needed (for existing endpoints that were updated)
        if needs_ai_gateway_update and endpoint_exists:
            print(f"\n[INFO] Enabling AI Gateway inference tables...")
            
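            # Verify the endpoint is ready and no model config update is still in progress
            # before submitting the AI Gateway change
            endpoint = w.serving_endpoints.get(endpoint_name)
            if endpoint.state.ready.value != "READY" or endpoint.state.config_update.value != "NOT_UPDATING":
                print(f"[WARNING] Endpoint not ready for AI Gateway configuration yet (ready: {endpoint.state.ready.value}, config: {endpoint.state.config_update.value})")
                print(f"   Re-run this cell once the endpoint is ready; AI Gateway will be enabled then")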
            else:
                # Enable AI Gateway (table will be created automatically by AI Gateway)
                w.serving_endpoints.put_ai_gateway(
                    name=endpoint_name,
                    inference_table_config=AiGatewayInferenceTableConfig(
                        catalog_name=catalog_name,
                        schema_name=schema_name,
                        table_name_prefix=endpoint_name,
                        enabled=True
                    )
                )
                print(f"[OK] AI Gateway configuration submitted")
                
                # Brief wait for configuration
                time.sleep(5)
                max_wait = 60
                elapsed = 0
                while elapsed < max_wait:
                    endpoint = w.serving_endpoints.get(endpoint_name)
                    if endpoint.state.config_update.value == "NOT_UPDATING":
                        break
                    time.sleep(5)
                    elapsed += 5
        
        # Get final status
        endpoint = w.serving_endpoints.get(endpoint_name)
        
        # Get workspace URL for link
        workspace_url = spark.conf.get("spark.databricks.workspaceUrl")
        endpoint_url = f"https://{workspace_url}/ml/endpoints/{endpoint_name}"
        
        print("\n" + "=" * 60)
        if endpoint.state.config_update.value == "NOT_UPDATING" and endpoint.state.ready.value == "READY":
            print("[OK] SERVING ENDPOINT READY")
        else:
            print("[INFO] SERVING ENDPOINT INITIALIZING")
        print("=" * 60)
        print(f"\n[Endpoint Details]")
        print(f"   - Name: {endpoint_name}")
        print(f"   - Model: {registered_model_name} (v{model_version})")
        print(f"   - Config State: {endpoint.state.config_update.value}")
        print(f"   - Ready State: {endpoint.state.ready.value}")
        
        # Check AI Gateway status
        if endpoint.ai_gateway and endpoint.ai_gateway.inference_table_config:
            ai_config = endpoint.ai_gateway.inference_table_config
            if ai_config.enabled:
                print(f"   - AI Gateway: Enabled")
                print(f"   - Inference Table: {ai_config.catalog_name}.{ai_config.schema_name}.{ai_config.table_name_prefix}_payload")
            else:
                print(f"   - AI Gateway: Disabled")
        else:
            print(f"   - AI Gateway: Not configured")
        
        print(f"\n[View Endpoint]")
        print(f"   {endpoint_url}")
        
        if endpoint.state.config_update.value == "NOT_UPDATING" and endpoint.state.ready.value == "READY":
            if endpoint.ai_gateway and endpoint.ai_gateway.inference_table_config and endpoint.ai_gateway.inference_table_config.enabled:
                print(f"\n[Next Step]")
                print(f"   Skip to 'Test Deployed Endpoint' cell")
            else:
                print(f"\n[Next Step]")
                print(f"   Run 'Enable AI Gateway Inference Tables' cell")
        else:
            print(f"\n[Next Step]")
            print(f"   Wait for endpoint to be ready, then re-run this cell")
        
    except SystemExit as e:
        # Clean exit when endpoint is still updating
        print(f"\n[INFO] Cell completed (endpoint update in progress)")
    except Exception as e:
        print(f"[ERROR] Failed to create/update endpoint: {e}")
        import traceback
        traceback.print_exc()
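
The cell above repeats the same bounded polling pattern three times: wait up to `max_wait_time`, check every `poll_interval`, and stop early on success or failure. A minimal sketch of factoring that into one reusable helper is below. `wait_for_state` is a hypothetical name, not an SDK API, and it takes a plain `get_state` callable so it runs without the Databricks SDK; in the notebook you might pass something like `lambda: w.serving_endpoints.get(endpoint_name).state.config_update.value`.

```python
import time
from typing import Callable, Iterable, Optional, Tuple


def wait_for_state(
    get_state: Callable[[], str],
    done_states: Iterable[str],
    failed_states: Iterable[str] = ("UPDATE_FAILED",),
    max_wait_time: int = 120,
    poll_interval: int = 10,
    sleep: Callable[[float], None] = time.sleep,
) -> Tuple[Optional[str], int]:
    """Poll get_state() until it reaches a done/failed state or the timeout expires.

    Hypothetical helper mirroring the notebook's wait loops (not part of the SDK).
    Returns (final_state, elapsed_seconds); final_state is None on timeout.
    """
    done, failed = set(done_states), set(failed_states)
    elapsed = 0
    while elapsed < max_wait_time:
        state = get_state()
        if state in done or state in failed:
            return state, elapsed
        if elapsed % 30 == 0:  # same 30-second progress cadence as the notebook
            print(f"   Status: {state} ({elapsed}s elapsed)")
        sleep(poll_interval)
        elapsed += poll_interval
    return None, elapsed
```

Returning `None` on timeout lets the caller print the same `[NEXT STEP]` guidance the cell prints today, and injecting `sleep` keeps the helper testable without real waits.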

## AI Gateway Inference Tables - Important Notes

**AI Gateway is configured automatically** when creating new endpoints (cell 40). For existing endpoints being updated, AI Gateway is enabled separately after the model update.

**Key behaviors:**

1. **Table Creation**: AI Gateway creates the inference table automatically; you do not create it yourself. The table receives no rows until the endpoint serves its first request, so an empty (or not-yet-visible) table immediately after enabling AI Gateway is expected.

2. **Logging Delay**: Requests are logged asynchronously, so there is typically a delay of 2-5 minutes between an inference request and the appearance of its request/response data in the payload table. This is expected behavior.

3. **Verification**: After running the test endpoint cell below, wait a few minutes, then query the table to see logged requests:
   ```sql
   SELECT * FROM `{catalog_name}`.`{schema_name}`.`{endpoint_name}_payload`
   ORDER BY request_time DESC LIMIT 10
   ```
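
The query above backtick-quotes each part of the table name. That matters here because endpoint names commonly contain hyphens, which are not valid in unquoted SQL identifiers. A small sketch for building the quoted name, assuming the `<table_name_prefix>_payload` naming used in this notebook; `quote_identifier` and `payload_table_fqn` are hypothetical helpers, and the example catalog, schema, and endpoint names are placeholders.

```python
def quote_identifier(name: str) -> str:
    """Backtick-quote a SQL identifier, doubling any embedded backticks."""
    return "`" + name.replace("`", "``") + "`"


def payload_table_fqn(catalog: str, schema: str, table_name_prefix: str) -> str:
    """Fully qualified, quoted name of the payload table (<prefix>_payload)."""
    return ".".join(
        quote_identifier(part)
        for part in (catalog, schema, f"{table_name_prefix}_payload")
    )


# Placeholder names: a hyphenated endpoint name stays one identifier once quoted
fqn = payload_table_fqn("main", "cv_demo", "yolo11n-endpoint")
query = f"SELECT * FROM {fqn} LIMIT 10"
print(query)
```

In the notebook the resulting string could be passed to `spark.sql(query)`.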

In [0]:
# Check deployment approval 
if 'PROCEED_WITH_DEPLOYMENT' not in dir() or not PROCEED_WITH_DEPLOYMENT:
    print("\n" + "="*60)
    print("⚠️  DEPLOYMENT SKIPPED - NOT APPROVED")
    print("="*60)
    print("\n[Reason]")
    print("   PROCEED_WITH_DEPLOYMENT flag is not set to True")
    print("\n[To Enable Deployment]")
    print("   1. Set 'Proceed with Deployment' widget to 'true' (top of notebook)")
    print("   2. Re-run cell 38 (deployment checkpoint)")
    print("   3. Then run deployment cells 40 and 42")
    print("\n" + "="*60)
else:
    # Deployment approved - proceed with endpoint testing
    import glob
    import json
    import base64

    print("=" * 60)
    print("TESTING DEPLOYED ENDPOINT")
    print("=" * 60)

    print(f"\nEndpoint: {endpoint_name}\n")

    try:
        test_images = glob.glob(f"{project_location}data/coco128/images/test/*.jpg")
        
        if test_images:
            test_image_path = test_images[-3]
            print(f"Test image: {test_image_path.split('/')[-1]}")
            
            # Encode image as base64
            print(f"[INFO] Encoding image as base64...")
            with open(test_image_path, 'rb') as f:
                image_bytes = f.read()
            image_base64 = base64.b64encode(image_bytes).decode('utf-8')
            
            # Test endpoint with base64 input
            print(f"[INFO] Testing endpoint with base64 input...\n")
            input_data = {"dataframe_records": [{"image_base64": image_base64}]}
            
            response = w.serving_endpoints.query(
                name=endpoint_name,
                dataframe_records=input_data["dataframe_records"]
            )
            
            print(f"[OK] Endpoint test successful!\n")
            print(f"Response preview:")
            response_dict = response.as_dict()
            print(json.dumps(response_dict, indent=2)[:500])
            
            print("\n" + "=" * 60)
            print("[OK] DEPLOYMENT COMPLETE")
            print("=" * 60)
            
            print(f"\n[Deployment Summary]")
            print(f"   - Endpoint: {endpoint_name}")
            print(f"   - Model: {registered_model_name} (v{model_version})")
            print(f"   - Status: Ready and tested")
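            ai_gw = w.serving_endpoints.get(endpoint_name).ai_gateway
            gw_enabled = bool(ai_gw and ai_gw.inference_table_config and ai_gw.inference_table_config.enabled)
            print(f"   - AI Gateway: {'Enabled' if gw_enabled else 'Not enabled'}")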
            print(f"   - Input format: Base64-encoded images")
            print(f"   - Inference table: {catalog_name}.{schema_name}.{endpoint_name}_payload")
            
            # Get workspace URL for links
            workspace_url = spark.conf.get("spark.databricks.workspaceUrl")
            endpoint_url = f"https://{workspace_url}/ml/endpoints/{endpoint_name}"
            table_url = f"https://{workspace_url}/explore/data/{catalog_name}/{schema_name}/{endpoint_name}_payload"
            
            print(f"\n[Links]")
            print(f"   - Endpoint: {endpoint_url}")
            print(f"   - Inference table: {table_url}")
            
            print(f"\n[Usage Example]")
            print(f"   import base64")
            print(f"   with open('image.jpg', 'rb') as f:")
            print(f"       img_b64 = base64.b64encode(f.read()).decode('utf-8')")
            print(f"   ")
            print(f"   w.serving_endpoints.query(")
            print(f"       name='{endpoint_name}',")
            print(f"       dataframe_records=[{{'image_base64': img_b64}}]")
            print(f"   )")
            
            print(f"\n[Monitor Inference]")
            print(f"   SELECT * FROM `{catalog_name}`.`{schema_name}`.`{endpoint_name}_payload`")
            print(f"   ORDER BY request_time DESC LIMIT 10")
            
        else:
            print("[WARNING] No test images found")
            
    except Exception as e:
        print(f"[ERROR] Endpoint test failed: {e}")
        import traceback
        traceback.print_exc()
        print(f"\n[INFO] Verify endpoint is ready and AI Gateway is configured")
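
The test above goes through the SDK's `w.serving_endpoints.query(...)`. Clients outside the notebook can send the same `dataframe_records` payload to the endpoint's REST invocations route (`POST https://<workspace-url>/serving-endpoints/<endpoint-name>/invocations`) with a bearer token. A stdlib-only sketch that builds, but does not send, such a request; `build_invocation_request` is hypothetical, the token is assumed to be one you supply (for example a personal access token), and the workspace, endpoint, and image bytes below are placeholders.

```python
import base64
import json
import urllib.request


def build_invocation_request(
    workspace_url: str, endpoint_name: str, token: str, image_bytes: bytes
) -> urllib.request.Request:
    """Build a POST request carrying one base64-encoded image as dataframe_records."""
    payload = {
        # Same record shape the notebook sends: one row with an image_base64 column
        "dataframe_records": [
            {"image_base64": base64.b64encode(image_bytes).decode("utf-8")}
        ]
    }
    return urllib.request.Request(
        url=f"https://{workspace_url}/serving-endpoints/{endpoint_name}/invocations",
        data=json.dumps(payload).encode("utf-8"),
        headers={
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        },
        method="POST",
    )


# Placeholder values; substitute your workspace URL, endpoint name, token, and image
req = build_invocation_request(
    "example.cloud.databricks.com", "yolo11n-endpoint", "dapi-placeholder", b"\xff\xd8fake-jpeg"
)
# To send it: json.load(urllib.request.urlopen(req))
```

Building the request separately from sending it keeps the payload shape easy to inspect and unit-test.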