In [None]:
# Install required packages
!pip install ultralytics supervision roboflow mlflow python-dotenv matplotlib seaborn --quiet

# Import necessary libraries
import os
import time
import glob
import torch
import mlflow
from typing import Dict, Union
from ultralytics import YOLO
from IPython.display import Image as IPyImage, display
from dotenv import load_dotenv

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import precision_recall_curve
import numpy as np

from dataclasses import dataclass
from typing import Optional

# Set up MLflow credentials using Colab secrets
from google.colab import userdata
import os

# User defined inputs
MLFLOW_TRACKING_URI = 'https://dagshub.com/erwincarlogonzales/defect-detection-yolov11.mlflow'
EXPERIMENT_NAME = 'defect-detection-yolov11s-jetson'

# Get credentials and set environment variables
try:
    os.environ.update({
        'MLFLOW_TRACKING_URI': MLFLOW_TRACKING_URI,
        'MLFLOW_TRACKING_USERNAME': userdata.get('MLFLOW_TRACKING_USERNAME'),
        'MLFLOW_TRACKING_PASSWORD': userdata.get('MLFLOW_TRACKING_PASSWORD')
    })

    # Verify MLflow connection
    mlflow.set_experiment(EXPERIMENT_NAME)
    print(f"Successfully connected to MLflow {MLFLOW_TRACKING_URI}")
    print(f"Using experiment: {EXPERIMENT_NAME}")

except Exception as error:
    if 'userdata.get' in str(error):
        raise Exception(f"Failed to get secrets from Colab: {error}")
    raise ConnectionError(f'MLflow setup failed: {error}')

# Check GPU availability
!nvidia-smi

# Configuration

# Settings shared by the training, validation and prediction cells below.
# NOTE(review): the dataset URL embeds a Roboflow export key, so anyone who can
# read this notebook can download the dataset; consider a Colab secret instead.
CONFIG = dict(
    dataset_url='https://app.roboflow.com/ds/JgwZ6J3BBV?key=Q4LLoOljfi',
    model_type='yolo11s.pt',    # weights passed to `yolo ... model=`
    epochs=40,                  # passed to `yolo ... epochs=`
    image_size=640,             # passed to `yolo ... imgsz=`
    conf_threshold=0.50,        # passed to `yolo ... conf=` at prediction time
    patience=3,                 # passed to `yolo ... patience=` during training
)

def generate_model_name(config: Dict[str, Union[str, int, float]], separator: str = "-") -> str:
    """Build a run/registry name such as ``yolo11s-40-50`` from a training config.

    Args:
        config: Mapping with ``model_type`` (weights file name with an
            extension, e.g. ``'yolo11s.pt'``), ``epochs`` (positive int) and an
            optional ``conf_threshold`` fraction (default 0.25).
        separator: String placed between the name components.

    Returns:
        ``"<model stem><separator><epochs><separator><threshold in percent>"``.

    Raises:
        ValueError: If ``epochs`` is not a positive integer, or ``model_type``
            is not a non-empty string containing a file extension.
    """
    epochs = config.get('epochs')
    # bool is a subclass of int; reject it so True is not accepted as 1 epoch.
    if isinstance(epochs, bool) or not isinstance(epochs, int) or epochs <= 0:
        raise ValueError("epochs must be a positive integer")

    model_type = config.get('model_type', '')
    if not isinstance(model_type, str) or '.' not in model_type:
        raise ValueError("model_type must be a non-empty string with a file extension")

    # Remove only the final extension so dotted stems ('yolo.v11.pt') survive.
    model_stem = model_type.rsplit('.', 1)[0]

    # round() instead of int(): 0.29 * 100 == 28.999... would truncate to 28.
    conf_value = str(round(config.get('conf_threshold', 0.25) * 100))

    return f"{model_stem}{separator}{epochs}{separator}{conf_value}"

# Set up working directory and dataset
HOME = os.getcwd()
print(f'Working Directory: {HOME}')

# Create dataset directory
!mkdir {HOME}/datasets
%cd {HOME}/datasets

# Download and extract Roboflow dataset
%%time
!curl -L '{CONFIG['dataset_url']}' > roboflow.zip 2>/dev/null
!unzip -q roboflow.zip
!rm roboflow.zip

# Return to home directory
%cd {HOME}

# Visualization Functions
def plot_confusion_matrix_heatmap(confusion_matrix, labels=('No Defect', 'Defect')):
    """Render a confusion matrix as an annotated heatmap and save it as a PNG.

    Args:
        confusion_matrix: 2-D array-like of counts.
        labels: Tick labels for both axes. Used only when there is exactly one
            label per row/column; otherwise seaborn's numeric ticks are shown.

    Returns:
        Path of the saved image, ``{HOME}/confusion_matrix_heatmap.png``.
    """
    matrix = np.asarray(confusion_matrix)
    labels = list(labels)

    # Format code 'd' only accepts integers; matrices loaded from disk may be float.
    annot_fmt = 'd' if np.issubdtype(matrix.dtype, np.integer) else 'g'

    # Avoid mislabelling (or rejecting) axes when the label count is wrong.
    if matrix.ndim == 2 and matrix.shape[0] == matrix.shape[1] == len(labels):
        tick_labels = labels
    else:
        tick_labels = 'auto'

    fig = plt.figure(figsize=(10, 8))
    try:
        sns.heatmap(matrix, annot=True, fmt=annot_fmt, cmap='Blues',
                    xticklabels=tick_labels, yticklabels=tick_labels)
        plt.title('Confusion Matrix Heatmap')
        # NOTE(review): axis titles assume rows = true class, columns =
        # predicted class; confirm this matches how the matrix was produced.
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')

        heatmap_path = f'{HOME}/confusion_matrix_heatmap.png'
        plt.savefig(heatmap_path)
    finally:
        # Release the figure even if plotting/saving fails.
        plt.close(fig)

    return heatmap_path

def plot_training_metrics(results_file):
    """Plot precision, recall and mAP per epoch from a YOLO ``results.csv``.

    Args:
        results_file: Path to the CSV written by YOLO training.

    Returns:
        Path of the saved figure, ``{HOME}/training_metrics.png``.
    """
    results = pd.read_csv(results_file)
    # Ultralytics results.csv headers may be space-padded; normalise them so
    # the column lookups below match.
    results.columns = results.columns.str.strip()

    metrics_to_plot = {
        'metrics/precision(B)': 'Precision',
        'metrics/recall(B)': 'Recall',
        'metrics/mAP50(B)': 'mAP@0.5',
        'metrics/mAP50-95(B)': 'mAP@0.5:0.95'
    }

    # Use the recorded epoch numbers for the x-axis when available.
    x_values = results['epoch'] if 'epoch' in results.columns else results.index

    fig = plt.figure(figsize=(15, 8))
    try:
        for col, label in metrics_to_plot.items():
            if col in results.columns:
                plt.plot(x_values, results[col], label=label, marker='o')

        plt.xlabel('Epoch')
        plt.ylabel('Value')
        plt.title('Training Metrics Over Time')
        plt.grid(True)
        plt.legend()

        metrics_path = f'{HOME}/training_metrics.png'
        plt.savefig(metrics_path)
    finally:
        # Release the figure even if plotting/saving fails.
        plt.close(fig)

    return metrics_path

def extract_yolo_metrics(results_file):
    """Return the final-epoch precision, recall and mAP from a YOLO ``results.csv``.

    Args:
        results_file: Path to the CSV written by YOLO training.

    Returns:
        Dict with keys ``final_precision``, ``final_recall``, ``final_mAP50``
        and ``final_mAP50-95``, each a plain Python float.

    Raises:
        ValueError: If the CSV has a header but no epoch rows.
        KeyError: If one of the expected metric columns is missing.
    """
    results = pd.read_csv(results_file)
    # Ultralytics results.csv headers may be space-padded; normalise them.
    results.columns = results.columns.str.strip()

    if results.empty:
        raise ValueError(f'No epochs recorded in {results_file}')

    # The last row holds the final epoch's metrics.
    final_epoch = results.iloc[-1]

    column_for_key = {
        'final_precision': 'metrics/precision(B)',
        'final_recall': 'metrics/recall(B)',
        'final_mAP50': 'metrics/mAP50(B)',
        'final_mAP50-95': 'metrics/mAP50-95(B)'
    }
    # float() turns numpy scalars into plain floats for MLflow/JSON.
    return {key: float(final_epoch[col]) for key, col in column_for_key.items()}
    
# Helper functions
def log_training_params(config):
    """Log the training hyper-parameters from ``config`` to the active MLflow run.

    Args:
        config: Training configuration; must contain ``model_type``,
            ``epochs``, ``image_size`` and ``conf_threshold``. ``patience`` is
            logged as well when present, because it is also forwarded to
            ``yolo train`` and so affects reproducibility.
    """
    params = {
        'model_type': config['model_type'],         # YOLO weights training starts from
        'epochs': config['epochs'],                 # Number of training epochs requested
        'image_size': config['image_size'],         # Image size used for training
        'conf_threshold': config['conf_threshold']  # Confidence threshold for detections
    }
    if config.get('patience') is not None:
        params['patience'] = config['patience']

    mlflow.log_params(params)

def train_yolo_model(config):
  
    start_time = time.time()

    # Disable MLflow callback before training to avoid conflicts
    from ultralytics.utils.callbacks.mlflow import callbacks as mlflow_callbacks
    mlflow_callbacks.clear()

    # Start YOLO training using the command-line interface
    !yolo task=detect mode=train \
        model={config['model_type']} \
        data=/content/datasets/data.yaml \
        epochs={config['epochs']} \
        imgsz={config['image_size']} \
        patience={config['patience']} \
        plots=True

    return time.time() - start_time

def get_latest_train_dir(home_dir):
    """Return the most recently modified ``runs/detect/train*`` directory.

    Args:
        home_dir: Base directory that contains ``runs/detect/``.

    Returns:
        Path of the newest matching directory (by modification time).

    Raises:
        FileNotFoundError: If no matching directory exists.
    """
    # Keep directories only: a stray file such as 'train.log' also matches the glob.
    train_dirs = [
        path for path in glob.glob(f'{home_dir}/runs/detect/train*')
        if os.path.isdir(path)
    ]

    if not train_dirs:
        raise FileNotFoundError("No training directories found")

    return max(train_dirs, key=os.path.getmtime)

def process_confusion_matrix(home_dir):
    """Plot and log the confusion matrix saved in the latest training run.

    Loads ``confusion_matrix.npy`` from the newest ``runs/detect/train*``
    directory and logs its heatmap under the ``visualizations`` artifact path.
    For a 2x2 matrix it also logs TN/FP/FN/TP counts as metrics.

    Args:
        home_dir: Base directory that contains ``runs/detect/``.

    Returns:
        The loaded matrix, or ``None`` if the ``.npy`` file does not exist.
    """
    latest_dir = get_latest_train_dir(home_dir)
    confusion_matrix_path = f'{latest_dir}/confusion_matrix.npy'

    # NOTE(review): confirm the training run actually writes this .npy file;
    # if it does not, this step is always skipped with the warning below.
    if not os.path.exists(confusion_matrix_path):
        print(f'WARNING: Confusion matrix not found at: {confusion_matrix_path}')
        return None

    confusion_matrix = np.load(confusion_matrix_path)

    heatmap_path = plot_confusion_matrix_heatmap(confusion_matrix)
    mlflow.log_artifact(heatmap_path, 'visualizations')

    # TN/FP/FN/TP unpacking assumes a binary [[TN, FP], [FN, TP]] layout; any
    # other shape (e.g. a multi-class matrix) would raise ValueError here.
    if confusion_matrix.shape != (2, 2):
        print(f'WARNING: Skipping TP/TN/FP/FN metrics for confusion matrix of shape {confusion_matrix.shape}')
        return confusion_matrix

    # float() turns numpy scalars into plain values for MLflow.
    tn, fp, fn, tp = (float(value) for value in confusion_matrix.ravel())
    mlflow.log_metrics({
        'true_positives': tp,
        'true_negatives': tn,
        'false_positives': fp,
        'false_negatives': fn
    })

    return confusion_matrix

def process_training_metrics(home_dir):
    """Log the training-curve plot and final-epoch metrics of the latest run.

    Args:
        home_dir: Base directory that contains ``runs/detect/``.

    Returns:
        None. When ``results.csv`` is missing, a warning is printed and
        nothing is logged.
    """
    results_csv = f'{get_latest_train_dir(home_dir)}/results.csv'

    if os.path.exists(results_csv):
        # Curve image first, then the final-epoch scalar metrics.
        mlflow.log_artifact(plot_training_metrics(results_csv), 'visualizations')
        mlflow.log_metrics(extract_yolo_metrics(results_csv))
    else:
        print(f'WARNING: Results file not found at: {results_csv}')

def log_yolo_artifacts(home_dir):
    """Log YOLO's summary images from the latest training run to MLflow.

    Each existing file is logged under an artifact directory named after its
    key (``confusion_matrix``, ``results``, ``validation_predictions``).
    Missing files are reported and skipped.

    Args:
        home_dir: Base directory that contains ``runs/detect/``.

    Returns:
        Number of artifacts that were logged.
    """
    latest_dir = get_latest_train_dir(home_dir)
    print(f"Logging artifacts from: {latest_dir}")

    artifact_paths = {
        'confusion_matrix': f'{latest_dir}/confusion_matrix.png',
        'results': f'{latest_dir}/results.png',
        'validation_predictions': f'{latest_dir}/val_batch0_pred.jpg'
    }

    logged_artifacts = 0
    for name, path in artifact_paths.items():
        if not os.path.exists(path):
            print(f'WARNING: Artifact not found: {path}')
            continue
        mlflow.log_artifact(path, name)
        logged_artifacts += 1
        print(f"Successfully logged {name}")

    print(f"Logged {logged_artifacts}/{len(artifact_paths)} artifacts")
    return logged_artifacts

def display_training_results(width: int = 800, height=None, show_titles: bool = True):
    """Show the latest run's confusion matrix, result curves and validation batch inline.

    Args:
        width: Display width passed to each ``IPython.display.Image``.
        height: Display height passed through to each image unchanged.
        show_titles: When ``True``, print a heading before each image;
            otherwise display all images in one call.
    """
    run_dir = get_latest_train_dir(HOME)
    print(f"Loading results from: {run_dir}")

    candidates = (
        ('Confusion Matrix', f'{run_dir}/confusion_matrix.png'),
        ('Training Results', f'{run_dir}/results.png'),
        ('Validation Batch Predictions', f'{run_dir}/val_batch0_pred.jpg'),
    )

    # Keep only images that exist, warning about the rest in order.
    found = []
    for caption, image_path in candidates:
        if not os.path.exists(image_path):
            print(f"Warning: {caption} not found at {image_path}")
            continue
        found.append((caption, image_path))

    if not found:
        print("No result images found to display")
        return

    rendered = [
        (caption, IPyImage(filename=image_path, width=width, height=height))
        for caption, image_path in found
    ]

    if not show_titles:
        display(*(picture for _, picture in rendered))
        return

    for caption, picture in rendered:
        print(f'\n{caption}')
        display(picture)

def display_predictions(num_images: int = 3, image_width: int = 600):
    """Show the first few annotated images from the latest YOLO prediction run.

    Args:
        num_images: Maximum number of images to display.
        image_width: Display width passed to ``IPython.display.Image``.
    """
    # Search under {HOME}/runs/detect (the same tree the training helpers
    # use) rather than a hard-coded /content path; only directories qualify.
    prediction_folders = [
        folder for folder in glob.glob(f'{HOME}/runs/detect/predict*')
        if os.path.isdir(folder)
    ]
    if not prediction_folders:
        print("No prediction folders found. Ensure the prediction step ran successfully.")
        return

    # Newest prediction run by modification time.
    latest_folder = max(prediction_folders, key=os.path.getmtime)

    # Saved predictions may keep the source image's extension, so accept the
    # common image formats rather than only .jpg.
    pred_images = sorted(
        path
        for pattern in ('*.jpg', '*.jpeg', '*.png', '*.bmp')
        for path in glob.glob(os.path.join(latest_folder, pattern))
    )
    if not pred_images:
        print(f'No prediction images found in {latest_folder}')
        return

    selected = pred_images[:num_images]
    # Images are taken in file-name order; they are not ranked by confidence.
    print(f'First {len(selected)} Predictions:\n')

    for img in selected:
        display(IPyImage(filename=img, width=image_width))
        print('\n')  # Spacing between images.
        
# Train model with MLflow tracking
%%time
with mlflow.start_run(run_name=generate_model_name(CONFIG)) as run:  # Start a new MLflow run with a generated run name
    try:
        print("Starting training process...")

        # 1. Log training parameters to MLflow
        log_training_params(CONFIG)
        print("Parameters logged successfully")

        # 2. Train the YOLO model and log training time
        print("Starting model training...")
        training_time = train_yolo_model(CONFIG)
        mlflow.log_metric('training_time', training_time)
        print("Training completed")

        # Give some time for files to be written (to avoid race conditions)
        time.sleep(2)

        # Save the best and last model files as MLflow artifacts
        print("Saving model files...")
        latest_dir = get_latest_train_dir(HOME)
        model_paths = {
            'best_model': f'{latest_dir}/weights/best.pt',
            'last_model': f'{latest_dir}/weights/last.pt'
        }
        for name, path in model_paths.items():
            if os.path.exists(path):
                mlflow.log_artifact(path, "models")
                print(f"Saved {name} successfully")
            else:
                print(f'WARNING: Model file not found: {path}')

        # Export and save ONNX model
        print("Exporting to ONNX format...")
        model = YOLO(f'{latest_dir}/weights/best.pt')
        model.export(format='onnx',
                    opset=11,
                    simplify=True,
                    dynamic=True,
                    half=False)

        # Log ONNX model to MLflow
        onnx_path = f'{latest_dir}/weights/best.onnx'
        if os.path.exists(onnx_path):
            mlflow.log_artifact(onnx_path, "models")
            print("ONNX model saved and logged successfully")
        else:
            print('WARNING: ONNX model file not found')

        # 3. Process and log confusion matrix metrics
        print("Processing confusion matrix...")
        confusion_matrix = process_confusion_matrix(HOME)

        # 4. Process and log training metrics (from results.csv)
        print("Processing training metrics...")
        process_training_metrics(HOME)

        # 5. Log other YOLO artifacts (e.g., confusion matrix image, results plot)
        print("Logging YOLO artifacts...")
        log_yolo_artifacts(HOME)

        print('Training and logging completed successfully!')

    except Exception as error:
        print(f'Error during training or logging: {str(error)}')
        raise  # Re-raise the exception to stop execution and provide the error message
    
# Show the latest training run's confusion matrix, result curves and
# validation-batch predictions inline.
display_training_results()

# Validate model with MLflow tracking
%%time
with mlflow.start_run(run_id=run.info.run_id):  # Start a new MLflow run using the existing run ID

    start_time = time.time()  # Record the start time for validation

    !yolo task=detect mode=val \
        model={HOME}/runs/detect/train/weights/best.pt \
        data=/content/datasets/data.yaml

    validation_time = time.time() - start_time  # Calculate the validation time

    mlflow.log_metric('validation_time', validation_time)  # Log the validation time to MLflow
    
# Run predictions with MLflow tracking
%%time
with mlflow.start_run(run_id=run.info.run_id):  # Start a new MLflow run using the existing run ID

    start_time = time.time()  # Record the start time for inference

    # Run YOLO predictions using the command-line interface
    !yolo task=detect mode=predict \
        model={HOME}/runs/detect/train/weights/best.pt \
        conf={CONFIG['conf_threshold']} \
        source=/content/datasets/test/images \
        save=True > /dev/null 2>&1

    inference_time = time.time() - start_time  # Calculate the inference time

    mlflow.log_metric('inference_time', inference_time)  # Log the inference time to MLflow

print('Predictions completed!')  # Print a message indicating that predictions are finished

# Show the first few annotated images from the most recent prediction run.
display_predictions()

# Register Model and Promote to Production
from mlflow.tracking import MlflowClient

def register_model_version(run_id: str, model_name: str, artifact_path: str = 'models'):
    """Register a run's artifact directory as a new version of ``model_name``.

    Args:
        run_id: ID of the MLflow run holding the artifacts.
        model_name: Registered-model name to create or add a version to.
        artifact_path: Run-relative artifact directory to register. Defaults to
            ``'models'``, the directory the training cell logs ``best.pt``,
            ``last.pt`` and the ONNX export into (nothing is logged under
            ``'model'``).

    Returns:
        The ``ModelVersion`` returned by ``mlflow.register_model``, or ``None``
        if registration failed (the error is printed).
    """
    model_uri = f'runs:/{run_id}/{artifact_path}'

    try:
        model_details = mlflow.register_model(
            model_uri=model_uri,
            name=model_name
        )
    except Exception as error:
        print(f'Error registering model: {str(error)}')
        return None

    print(f'Model registered successfully with version: {model_details.version}')

    # The description is cosmetic: a failure here must not discard the
    # version that was just registered successfully.
    try:
        client = mlflow.MlflowClient()
        client.update_registered_model(
            name=model_name,
            description='YOLO-based defect detection model'
        )
    except Exception as error:
        print(f'WARNING: Could not update model description: {error}')

    return model_details

def promote_challenger_to_production(model_name: str, prod_name: str):
    """Copy the version aliased ``@challenger`` of ``model_name`` into ``prod_name``.

    Args:
        model_name: Registered model whose ``challenger`` alias is promoted.
        prod_name: Registered-model name that receives the copied version.

    Returns:
        The new model version under ``prod_name``, or ``None`` if any step
        fails (the error is printed).
    """
    try:
        source_uri = f"models:/{model_name}@challenger"
        copied_version = MlflowClient().copy_model_version(
            src_model_uri=source_uri,
            dst_name=prod_name,
        )
        print(f'Successfully promoted challenger to Production as {prod_name}')
        return copied_version
    except Exception as error:
        print(f'Error promoting model: {str(error)}')
        return None

# Register Model Operation
def register_model_operation(run_id: str, config: dict):
    """Register ``run_id``'s artifacts under the name derived from ``config``.

    Args:
        run_id: MLflow run whose artifacts are registered.
        config: Training configuration used to build the model name.

    Returns:
        The registered model version (after printing a summary), or ``None``
        if registration failed.
    """
    model_name = generate_model_name(config)
    model_details = register_model_version(run_id, model_name)

    if not model_details:
        return None

    print('\nModel Registration Details:')
    print(f'Name: {model_details.name}')
    print(f'Version: {model_details.version}')
    # `status` is the registration status (e.g. READY), not a deployment stage.
    print(f'Status: {model_details.status}')

    return model_details

# Promote Challenger Model
def promote_challenger_operation(model_name: str, prod_name: str = 'defect-detection-production'):
    """Promote ``model_name@challenger`` to ``prod_name`` and print the result.

    Args:
        model_name: Registered model whose ``challenger`` alias is promoted.
        prod_name: Destination registered-model name.

    Returns:
        The production model version, or ``None`` if promotion failed.
    """
    promoted = promote_challenger_to_production(model_name, prod_name)
    if not promoted:
        return None

    print('\nProduction Model Details:')
    print(f'Name: {promoted.name}')
    print(f'Version: {promoted.version}')
    return promoted

# Register Model
_RUN_ID_PLACEHOLDER = 'run id from MLflow UI'
run_id = 'run id from MLflow UI'  # <- paste the run ID from the MLflow UI here
# If the placeholder was left unchanged, fall back to the run created by the
# training cell in this session instead of registering a non-existent run.
if run_id == _RUN_ID_PLACEHOLDER and 'run' in globals():
    run_id = run.info.run_id
model_details = register_model_operation(run_id, CONFIG)

# Promotion: the new version must already carry the 'challenger' alias (e.g.
# set in the MLflow UI); otherwise promotion fails and the error is printed.
# NOTE(review): the production name says "android" while the experiment name
# says "jetson" — confirm which target is intended.
if model_details:
    prod_name = 'defect-detection-v11-android-production'
    production_model = promote_challenger_operation(model_details.name, prod_name)