# Phase 1: Training Orchestration

This notebook orchestrates all training activities without performing local computation.

## Overview

- **Step P1-3.1**: Load Centralized Configs
- **Step P1-3.2**: Data Ingestion & Versioning (Asset Layer)
- **Step P1-3.3**: Environment Definition
- **Step P1-3.4**: The Dry Run
- **Step P1-3.5**: The Sweep (HPO)
- **Step P1-3.6**: Best Configuration Selection (Automated)
- **Step P1-3.7**: Final Training (Post-HPO, Single Run)
- **Step P1-4**: Model Conversion & Optimization
- **Step P1-5**: Model Registration (The Handover)

## Important

- This notebook **only submits and monitors Azure ML jobs**
- **No training logic** is executed locally
- All computation happens remotely on Azure ML compute
- The notebook must be **re-runnable end-to-end**

## Platform Adapter Architecture

The training and conversion scripts (`src/train.py` and `src/convert_to_onnx.py`) use a **platform adapter pattern** that automatically detects the execution environment (Azure ML vs local) and adapts accordingly:

- **Output paths**: Automatically resolves Azure ML output directories via `AZURE_ML_OUTPUT_*` environment variables
- **Logging**: Handles both MLflow and Azure ML native logging seamlessly
- **MLflow context**: Manages MLflow runs appropriately for each platform
- **Checkpoint resolution**: Handles Azure ML mounted inputs and local file paths

This architecture allows the same code to run consistently on both Azure ML and local setups. When jobs run in Azure ML, the adapters automatically detect the Azure ML environment and use Azure-specific implementations. See `docs/PLATFORM_ADAPTER_ARCHITECTURE.md` for details.
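
As a rough illustration of the detection idea described above, the following is a minimal sketch, not the code in `src/`. It assumes that the presence of any `AZURE_ML_OUTPUT_*` variable marks an Azure ML job; the function names `is_azure_ml` and `resolve_output_dir` are hypothetical, and the real adapters are documented in `docs/PLATFORM_ADAPTER_ARCHITECTURE.md`.

```python
import os
from pathlib import Path
from typing import Mapping, Optional


def is_azure_ml(env: Mapping[str, str]) -> bool:
    """Assumption: any AZURE_ML_OUTPUT_* variable signals an Azure ML job."""
    return any(key.startswith("AZURE_ML_OUTPUT_") for key in env)


def resolve_output_dir(
    name: str,
    local_root: Path,
    env: Optional[Mapping[str, str]] = None,
) -> Path:
    """Return the Azure ML output directory for `name`, or a local fallback."""
    env = os.environ if env is None else env
    mounted = env.get(f"AZURE_ML_OUTPUT_{name}")
    if mounted:
        # Running in Azure ML: use the directory the platform provides
        return Path(mounted)
    # Running locally: write under a project-relative folder instead
    return local_root / name
```

Passing `env` explicitly keeps the helper easy to test: the same call can be exercised with a fake environment mapping or with `os.environ`.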


## Step P1-3.1: Load Centralized Configs

Load and validate all configuration files. Configs are immutable and will be logged with each job for reproducibility.

**Note**: The training and conversion scripts executed by these jobs use platform adapters that automatically handle Azure ML-specific concerns (output paths, logging, MLflow context) without requiring explicit configuration in this notebook.
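
The hashing and immutability checks mentioned above can be pictured with the sketch below. It is an assumption about how such helpers might work, not the implementation in `orchestration.config_loader`; the names `hash_config`, `snapshot`, and `assert_unchanged` and the 16-character hash length are hypothetical.

```python
import copy
import hashlib
import json
from typing import Any, Dict


def hash_config(config: Dict[str, Any], length: int = 16) -> str:
    """Hash a config via canonical JSON so key order does not affect the result."""
    canonical = json.dumps(config, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:length]


def snapshot(configs: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """Deep-copy the configs so later mutations can be detected."""
    return copy.deepcopy(configs)


def assert_unchanged(
    configs: Dict[str, Dict[str, Any]],
    original: Dict[str, Dict[str, Any]],
) -> None:
    """Raise if any config differs from its snapshot."""
    for name, cfg in original.items():
        if configs.get(name) != cfg:
            raise ValueError(f"Config '{name}' was mutated at runtime")
```

Because the hash is computed from a sorted, whitespace-free JSON form, two configs with the same content always map to the same tag value, which is what makes the tag usable for reproducibility checks.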


In [None]:
# !pip install azureml-mlflow --quiet

In [None]:
import os
from pathlib import Path
from typing import Dict, Any

from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential
from dotenv import load_dotenv

# Ensure we can import the orchestration package and shared utilities
import sys
ROOT_DIR = Path("..").resolve()
SRC_DIR = ROOT_DIR / "src"
sys.path.append(str(ROOT_DIR))
sys.path.append(str(SRC_DIR))

from shared.yaml_utils import load_yaml
from shared.json_cache import save_json, load_json
from orchestration import (
    STAGE_SMOKE,
    STAGE_HPO,
    STAGE_TRAINING,
    EXPERIMENT_NAME,
    MODEL_NAME,
    PROD_STAGE,
    build_aml_experiment_name,
)
from orchestration.config_loader import (
    ExperimentConfig,
    create_config_metadata,
    load_all_configs,
    load_experiment_config,
    compute_config_hashes,
    snapshot_configs,
    validate_config_immutability,
)


env_path = Path("../config.env")
if env_path.exists():
    load_dotenv(env_path)


In [None]:
CONFIG_DIR = Path("../config")

# Experiment selection (switch to try different data/model/HPO/env combos)
# The concrete experiment definition lives in config/experiment/<EXPERIMENT_NAME>.yaml

# Resolve experiment-level config into concrete file paths
experiment_config: ExperimentConfig = load_experiment_config(CONFIG_DIR, EXPERIMENT_NAME)
configs = load_all_configs(experiment_config)
config_hashes = compute_config_hashes(configs)

# Immutable snapshots for runtime mutation checks
original_configs = snapshot_configs(configs)


In [None]:
# Reuse shared immutability validator from orchestration package
validate_config_immutability(configs, original_configs)


In [None]:
def get_workspace_name(configs: Dict[str, Any]) -> str:
    """Resolve the Azure ML workspace name from configuration files.

    Order of precedence:
    1. ``config/infrastructure.yaml`` (``workspace.name``)
    2. ``config/env/azure.yaml`` (``workspace.name`` under ``env`` config)

    This function is pure: it only reads configuration objects and files,
    and does not perform any network or Azure ML operations.
    """
    infrastructure_config_path = Path("../config/infrastructure.yaml")
    if infrastructure_config_path.exists():
        infrastructure_config = load_yaml(infrastructure_config_path)
        return infrastructure_config["workspace"]["name"]

    env_workspace = configs["env"].get("workspace", {}).get("name")
    if env_workspace:
        return env_workspace

    raise ValueError(
        "Workspace name must be configured in either "
        "config/infrastructure.yaml (workspace.name) or config/env/azure.yaml (workspace.name)."
    )


def create_ml_client(configs: Dict[str, Any]) -> MLClient:
    """Create an MLClient instance for the configured Azure ML workspace.

    This function is responsible for reading required environment variables
    and instantiating the Azure ML client. It assumes that configuration
    loading has already completed.
    """
    subscription_id = os.getenv("AZURE_SUBSCRIPTION_ID")
    resource_group = os.getenv("AZURE_RESOURCE_GROUP")

    if not subscription_id or not resource_group:
        raise ValueError("AZURE_SUBSCRIPTION_ID and AZURE_RESOURCE_GROUP must be set")

    workspace_name = get_workspace_name(configs)
    credential = DefaultAzureCredential()
    return MLClient(
        credential=credential,
        subscription_id=subscription_id,
        resource_group_name=resource_group,
        workspace_name=workspace_name,
    )


In [None]:
# Instantiate MLClient for the configured workspace
ml_client = create_ml_client(configs)



All configs and their hashes will be attached to each Azure ML job for full reproducibility.


In [None]:
# Build config metadata for job tagging using shared helper from
# `orchestration.config_loader`.
config_metadata = create_config_metadata(configs, config_hashes)


## Step P1-3.2: Data Ingestion & Versioning (Asset Layer)

Upload dataset to Blob Storage and register as an Azure ML Data Asset for versioned, immutable data access.

**Note**: The training script accepts data asset paths and can work with both Azure ML data assets (when running in Azure ML) and local file paths (when running locally), thanks to the platform adapter architecture.
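
Versioned access works because a job can point at an exact `name:version` pair. The sketch below builds the short-form reference string (`azureml:<name>:<version>`) that Azure ML SDK v2 accepts for registered assets. The helper name `data_asset_reference` is hypothetical, and whether `build_data_asset_reference` returns this exact form under `asset_uri` is an assumption.

```python
from typing import Union


def data_asset_reference(name: str, version: Union[str, int]) -> str:
    """Build an `azureml:<name>:<version>` reference for a registered data asset."""
    if not name or ":" in name:
        raise ValueError(f"Invalid data asset name: {name!r}")
    if version is None or str(version) == "":
        raise ValueError("A concrete version is required for reproducible runs")
    # Pinning the version (rather than using a 'latest' label) keeps reruns
    # tied to the same immutable snapshot of the data.
    return f"azureml:{name}:{version}"
```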


In [None]:
from orchestration.data_assets import (
    resolve_dataset_path,
    register_data_asset,
    ensure_data_asset_uploaded,
    build_data_asset_reference,
)

# Resolve local dataset path from data config (configs["data"]["local_path"])
DATASET_LOCAL_PATH = resolve_dataset_path(configs["data"])
DATA_ASSET_NAME = configs["data"]["name"]
DATA_ASSET_VERSION = configs["data"]["version"]


In [None]:
DATA_ASSET_OVERRIDE_PATH = None
blob_uri = DATA_ASSET_OVERRIDE_PATH or str(DATASET_LOCAL_PATH)


In [None]:
data_asset = register_data_asset(
    ml_client=ml_client,
    name=DATA_ASSET_NAME,
    version=DATA_ASSET_VERSION,
    uri=blob_uri,
    description=configs["data"]["description"],
)

# Best-effort upload of local content to the resolved data asset
data_asset = ensure_data_asset_uploaded(
    ml_client=ml_client,
    data_asset=data_asset,
    local_path=DATASET_LOCAL_PATH,
    description=configs["data"]["description"],
)

# Build shared references for downstream jobs
asset_paths = build_data_asset_reference(ml_client, data_asset)
asset_reference = asset_paths["asset_uri"]
datastore_path = asset_paths["datastore_path"]


### Troubleshooting

If you encounter `ScriptExecution.StreamAccess.NotFound`, verify that:
1. The compute cluster has a managed identity assigned
2. That managed identity has the "Storage Blob Data Reader" role on the storage account
3. The storage account firewall allows access from Azure services


In [None]:
# Save data asset info to a JSON file
data_asset_cache_file = Path("data_asset_cache.json")

if "data_asset" in globals() and data_asset is not None:
    data_asset_info = {
        "name": data_asset.name,
        "version": data_asset.version,
        "asset_paths": asset_paths,
    }

    save_json(data_asset_cache_file, data_asset_info)
    print(
        f"Saved data asset: {data_asset_info['name']} {data_asset_info['version']} "
        f"to {data_asset_cache_file}"
    )
else:
    print("No data asset to save")

### Logged Results
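
The save and reload cells in this notebook follow a simple cache pattern: write a small JSON record after a step succeeds, and on a later run read it back or fall back to a default. The sketch below shows that pattern with the standard library only. The `_sketch` names are hypothetical stand-ins and are not the `save_json` / `load_json` helpers from `shared.json_cache`, whose actual behavior may differ.

```python
import json
from pathlib import Path
from typing import Any, Optional


def save_json_sketch(path: Path, data: Any) -> None:
    """Write `data` as JSON, creating parent folders if needed."""
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(data, indent=2), encoding="utf-8")


def load_json_sketch(path: Path, default: Optional[Any] = None) -> Any:
    """Return the parsed JSON file, or `default` when the file does not exist."""
    if not path.exists():
        return default
    return json.loads(path.read_text(encoding="utf-8"))
```

Returning a default instead of raising lets the reload cell decide what to do when no cache exists, which is what keeps the notebook re-runnable from any step.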

In [None]:
from orchestration.data_assets import build_data_asset_reference

# Try to reload from cache
data_asset_cache_file = Path("data_asset_cache.json")

data_asset_info = load_json(data_asset_cache_file, default=None)

if data_asset_info is None:
    print(
        f"Cache file {data_asset_cache_file} not found. "
        "Will need to register data asset."
    )
    data_asset = None
else:
    try:
        # Reload Data asset object from ML client
        data_asset = ml_client.data.get(
            name=data_asset_info["name"],
            version=data_asset_info["version"],
        )

        # Rebuild asset_paths if they were saved, otherwise regenerate them
        asset_paths = data_asset_info.get("asset_paths") or build_data_asset_reference(
            ml_client, data_asset
        )

        asset_reference = asset_paths["asset_uri"]
        datastore_path = asset_paths["datastore_path"]

        print(f"Loaded data asset: {data_asset.name} v{data_asset.version}")
        print(f"Asset URI: {asset_reference}")
        print("Skipping data asset registration - using cached asset")
    except Exception as e:
        print(
            "Warning: Could not load data asset "
            f"{data_asset_info['name']} v{data_asset_info['version']}: {e}"
        )
        print("Will need to register data asset again")
        data_asset = None

## Step P1-3.3: Environment Definition

Define a stable execution environment (Docker image + Conda dependencies) for consistent behavior across all training jobs.


In [None]:
from orchestration.environment import (
    build_environment_config,
    create_training_environment,
    prepare_environment_image,
)

# Build environment configuration from env.yaml (with sensible defaults)
env_config = build_environment_config(CONFIG_DIR, configs["env"])

# Materialize or fetch the Azure ML Environment
training_environment = create_training_environment(ml_client, env_config)

# Trigger a small warm-up job so the image is built/cached before real work
prepare_environment_image(
    ml_client=ml_client,
    environment=training_environment,
    compute_cluster=configs["env"]["compute"]["training_cluster"],
    env_config=env_config,
)


In [None]:
# Save environment info to a JSON file
env_cache_file = Path("training_environment_cache.json")

if 'training_environment' in globals() and training_environment is not None:
    env_data = {
        "name": training_environment.name,
        "version": training_environment.version,
    }

    save_json(env_cache_file, env_data)
    print(f"Saved training environment: {env_data['name']} v{env_data['version']} to {env_cache_file}")
else:
    print("No training environment to save")

### Logged Results

In [None]:
# Try to reload from cache
env_cache_file = Path("training_environment_cache.json")

env_data = load_json(env_cache_file, default=None)

if env_data is None:
    print(f"Cache file {env_cache_file} not found. Will need to create environment.")
    training_environment = None
else:
    try:
        # Reload Environment object from ML client
        training_environment = ml_client.environments.get(
            name=env_data["name"],
            version=env_data["version"]
        )
        print(f"Loaded training environment: {training_environment.name} v{training_environment.version}")
        print("Skipping environment setup - using cached environment")
    except Exception as e:
        print(f"Warning: Could not load environment {env_data['name']} v{env_data['version']}: {e}")
        print("Will need to create environment again")
        training_environment = None

## Step P1-3.4: The Dry Run

Submit a minimal sweep job using `smoke.yaml` to validate the sweep mechanism and pipeline integrity before launching the production HPO sweep.


In [None]:
from orchestration.jobs import (
    create_dry_run_sweep_job_for_backbone,
    submit_and_wait_for_job,
    validate_sweep_job,
)

TRAINING_SCRIPT_PATH = Path("../src/train.py")


In [None]:
compute_cluster_name = configs["env"]["compute"]["training_cluster"]

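try:
    compute_cluster = ml_client.compute.get(compute_cluster_name)
except Exception as e:
    # Chain the original error so the root cause stays visible in the traceback
    raise RuntimeError(f"Compute cluster '{compute_cluster_name}' not accessible: {e}") from e

# Check readiness outside the try block so it is not re-wrapped as "not accessible"
if compute_cluster.provisioning_state != "Succeeded":
    raise ValueError(f"Compute cluster not ready: {compute_cluster.provisioning_state}")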

stage_name = STAGE_SMOKE
smoke_hpo_config = configs["hpo"]

# Backbones are controlled by the HPO config file (single source of truth)
backbone_values = smoke_hpo_config["search_space"]["backbone"]["values"]

dry_run_sweep_jobs = {}

for backbone in backbone_values:
    aml_experiment_name = build_aml_experiment_name(
        experiment_name=experiment_config.name,
        stage=stage_name,
        backbone=backbone,
    )
    dry_run_sweep_jobs[backbone] = create_dry_run_sweep_job_for_backbone(
        script_path=TRAINING_SCRIPT_PATH,
        data_asset=data_asset,
        environment=training_environment,
        compute_cluster=compute_cluster_name,
        backbone=backbone,
        smoke_hpo_config=smoke_hpo_config,
        configs=configs,
        config_metadata=config_metadata,
        aml_experiment_name=aml_experiment_name,
        stage=stage_name,
    )


In [None]:
for backbone, sweep_job in dry_run_sweep_jobs.items():
    completed_job = submit_and_wait_for_job(ml_client, sweep_job)
    validate_sweep_job(
        job=completed_job,
        backbone=backbone,
        job_type="Dry run sweep",
        ml_client=ml_client,
    )


## Step P1-3.5: The Sweep (HPO)

Submit a hyperparameter optimization sweep to systematically search for the best model configuration.

**Note**: Currently using `smoke.yaml` for demonstration purposes (CPU-only setup). For production with GPU, switch to `prod.yaml` in the configuration.

**Platform Adapter Note**: Each training trial in the sweep automatically uses the platform adapter to handle Azure ML-specific concerns. The adapter ensures consistent behavior across all trials regardless of the execution environment.
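
The cells below create one sweep per backbone, reading the backbone list from `search_space["backbone"]["values"]`. The sketch that follows illustrates that split and how a single trial could be drawn from the remaining parameters. It assumes every parameter is a discrete choice with a `values` list; the real `hpo` config may also contain continuous ranges, and the function names are hypothetical.

```python
import random
from typing import Any, Dict


def split_by_backbone(search_space: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """Map each backbone to the search space without the backbone entry."""
    backbones = search_space["backbone"]["values"]
    rest = {name: spec for name, spec in search_space.items() if name != "backbone"}
    return {backbone: rest for backbone in backbones}


def sample_trial(space: Dict[str, Dict[str, Any]], rng: random.Random) -> Dict[str, Any]:
    """Draw one value per parameter, as a random-sampling sweep would."""
    return {name: rng.choice(spec["values"]) for name, spec in space.items()}
```

Running a separate sweep per backbone keeps each Azure ML experiment comparable on its own, while the selection step later compares the winners across backbones.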


In [None]:
from orchestration.jobs import (
    create_hpo_sweep_job_for_backbone,
    submit_and_wait_for_job,
    validate_sweep_job,
)

TRAINING_SCRIPT_PATH = Path("../src/train.py")


In [None]:
compute_cluster_name = configs["env"]["compute"]["training_cluster"]

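try:
    compute_cluster = ml_client.compute.get(compute_cluster_name)
except Exception as e:
    # Chain the original error so the root cause stays visible in the traceback
    raise RuntimeError(f"Compute cluster '{compute_cluster_name}' not accessible: {e}") from e

# Check readiness outside the try block so it is not re-wrapped as "not accessible"
if compute_cluster.provisioning_state != "Succeeded":
    raise ValueError(f"Compute cluster not ready: {compute_cluster.provisioning_state}")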

stage_name = STAGE_HPO
hpo_config = configs["hpo"]
backbone_values = hpo_config["search_space"]["backbone"]["values"]
hpo_sweep_jobs = {}

for backbone in backbone_values:
    aml_experiment_name = build_aml_experiment_name(
        experiment_name=experiment_config.name,
        stage=stage_name,
        backbone=backbone,
    )
    hpo_sweep_jobs[backbone] = create_hpo_sweep_job_for_backbone(
        script_path=TRAINING_SCRIPT_PATH,
        data_asset=data_asset,
        environment=training_environment,
        compute_cluster=compute_cluster_name,
        hpo_config=hpo_config,
        backbone=backbone,
        aml_experiment_name=aml_experiment_name,
        stage=stage_name,
        configs=configs,
        config_metadata=config_metadata,
    )


In [None]:
hpo_completed_jobs = {}

for backbone, sweep_job in hpo_sweep_jobs.items():
    completed_job = submit_and_wait_for_job(ml_client, sweep_job)
    validate_sweep_job(
        job=completed_job,
        backbone=backbone,
        job_type="HPO sweep",
        min_expected_trials=2,
        ml_client=ml_client,
    )
    hpo_completed_jobs[backbone] = completed_job


In [None]:
# Save HPO job references to a JSON file
hpo_jobs_cache_file = Path("hpo_completed_jobs_cache.json")

if hpo_completed_jobs:
    hpo_jobs_data = {
        backbone: {
            "job_name": job.name,
            "job_id": job.id,
        }
        for backbone, job in hpo_completed_jobs.items()
    }

    save_json(hpo_jobs_cache_file, hpo_jobs_data)
    print(f"Saved {len(hpo_jobs_data)} HPO job references to {hpo_jobs_cache_file}")
else:
    print("No HPO completed jobs to save")

### Logged Results

In [None]:
# Try to reload HPO jobs from cache
hpo_jobs_cache_file = Path("hpo_completed_jobs_cache.json")

hpo_jobs_data = load_json(hpo_jobs_cache_file, default=None)

if hpo_jobs_data is None:
    print(f"Cache file {hpo_jobs_cache_file} not found. Will need to run HPO.")
    hpo_completed_jobs = {}
else:
    hpo_completed_jobs = {}
    for backbone, job_info in hpo_jobs_data.items():
        try:
            job = ml_client.jobs.get(job_info["job_name"])
            hpo_completed_jobs[backbone] = job
            print(f"Loaded HPO job for {backbone}: {job.name} (status: {job.status})")
        except Exception as e:
            print(f"Warning: Could not load job {job_info['job_name']} for {backbone}: {e}")

    if hpo_completed_jobs:
        print(f"\nSuccessfully reloaded {len(hpo_completed_jobs)} HPO completed jobs from cache")
    else:
        print("No valid jobs found in cache, will need to run HPO again")
        hpo_completed_jobs = {}

## Step P1-3.6: Best Configuration Selection (Automated)

Programmatically select the best configuration from all HPO sweep runs across all backbone models.
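
Conceptually, selection reduces to comparing one metric across all finished trials. The sketch below shows that comparison on plain dictionaries. It is an illustration only and not `select_best_configuration`, which queries Azure ML; the input shape, the `goal` field, and the name `pick_best_trial` are assumptions, while the output keys mirror those used later in this notebook.

```python
from typing import Any, Dict, List


def pick_best_trial(
    trials: List[Dict[str, Any]],
    metric: str,
    goal: str = "maximize",
) -> Dict[str, Any]:
    """Pick the trial with the best value of `metric` across all backbones."""
    if goal not in ("maximize", "minimize"):
        raise ValueError(f"Unknown goal: {goal}")
    # Ignore trials that never reported the metric (for example, failed runs)
    scored = [t for t in trials if metric in t.get("metrics", {})]
    if not scored:
        raise ValueError(f"No trial reported metric '{metric}'")
    choose = max if goal == "maximize" else min
    best = choose(scored, key=lambda t: t["metrics"][metric])
    return {
        "trial_name": best["trial_name"],
        "backbone": best["backbone"],
        "hyperparameters": best["hyperparameters"],
        "selection_criteria": {
            "metric": metric,
            "goal": goal,
            "best_value": best["metrics"][metric],
        },
    }
```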


In [None]:
from orchestration.jobs import select_best_configuration


In [None]:
# Select the best configuration from all HPO sweep runs
best_configuration = select_best_configuration(
    ml_client=ml_client,
    hpo_completed_jobs=hpo_completed_jobs,
    hpo_config=configs["hpo"],
    dataset_version=configs["data"]["version"],
)


In [None]:
# Save best configuration to a JSON file
best_config_cache_file = Path("best_configuration_cache.json")

if "best_configuration" in globals() and best_configuration is not None:
    # best_configuration contains trial_name, trial_id, backbone, hyperparameters, metrics, etc.
    # All of these are JSON-serializable
    save_json(best_config_cache_file, best_configuration)
    print(f"Saved best configuration to {best_config_cache_file}")
    print(f"  Backbone: {best_configuration.get('backbone')}")
    print(f"  Best metric value: {best_configuration.get('selection_criteria', {}).get('best_value')}")
else:
    print("No best configuration to save")


### Logged Results


In [None]:
# Try to reload from cache
best_config_cache_file = Path("best_configuration_cache.json")

best_configuration = load_json(best_config_cache_file, default=None)

if best_configuration is None:
    print(f"Cache file {best_config_cache_file} not found. Will need to run Step P1-3.6.")
else:
    print("Loaded best configuration from cache:")
    print(f"  Backbone: {best_configuration.get('backbone')}")
    print(f"  Trial: {best_configuration.get('trial_name')}")
    print(f"  Best metric value: {best_configuration.get('selection_criteria', {}).get('best_value')}")
    print(f"  Dataset version: {best_configuration.get('dataset_version')}")
    print("\nSkipping best configuration selection - using cached result")


## Step P1-3.7: Final Training (Post-HPO, Single Run)

Train the final production model using the best configuration from HPO with stable, controlled conditions.
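
One plausible way to combine the HPO winner with the `train.yaml` defaults is a shallow merge in which tuned values override defaults. The sketch below shows that idea; it is an assumption about the merge rule, not the behavior of `build_final_training_config`, and `merge_final_config` is a hypothetical name.

```python
from typing import Any, Dict


def merge_final_config(best: Dict[str, Any], train_defaults: Dict[str, Any]) -> Dict[str, Any]:
    """Overlay the best trial's hyperparameters on top of the training defaults."""
    final = dict(train_defaults)  # copy so the defaults stay untouched
    final.update(best.get("hyperparameters", {}))  # tuned values win
    final["backbone"] = best["backbone"]
    return final
```

Copying the defaults first matters here because the configs are treated as immutable elsewhere in the notebook.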


In [None]:
from orchestration.jobs import (
    build_final_training_config,
    create_final_training_job,
    validate_final_training_job,
    submit_and_wait_for_job
)

TRAINING_SCRIPT_PATH = Path("../src/train.py")


In [None]:
# Build final training config from best HPO result + train.yaml defaults
final_training_config = build_final_training_config(best_configuration, configs["train"])


In [None]:
final_training_config

In [None]:
compute_cluster_name = configs["env"]["compute"]["training_cluster"]

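try:
    compute_cluster = ml_client.compute.get(compute_cluster_name)
except Exception as e:
    # Chain the original error so the root cause stays visible in the traceback
    raise RuntimeError(f"Compute cluster '{compute_cluster_name}' not accessible: {e}") from e

# Check readiness outside the try block so it is not re-wrapped as "not accessible"
if compute_cluster.provisioning_state != "Succeeded":
    raise ValueError(f"Compute cluster not ready: {compute_cluster.provisioning_state}")
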
# Create and submit final training job
stage_name = STAGE_TRAINING
aml_experiment_name = build_aml_experiment_name(
    experiment_name=experiment_config.name,
    stage=stage_name,
    backbone=final_training_config["backbone"],
)

final_training_tags = {
    **config_metadata,
    "job_type": "final_training",
    "backbone": final_training_config["backbone"],
    "best_trial": best_configuration["trial_name"],
    "best_metric_value": str(best_configuration["selection_criteria"]["best_value"]),
    "stage": stage_name,
}

final_training_job = create_final_training_job(
    script_path=TRAINING_SCRIPT_PATH,
    data_asset=data_asset,
    environment=training_environment,
    compute_cluster=compute_cluster_name,
    final_config=final_training_config,
    aml_experiment_name=aml_experiment_name,
    tags=final_training_tags,
)


In [None]:
# Submit and validate final training job
final_training_completed_job = submit_and_wait_for_job(ml_client, final_training_job)
validate_final_training_job(final_training_completed_job)


In [None]:
final_training_cache_file = Path("final_training_job_cache.json")

if "final_training_completed_job" in globals() and final_training_completed_job is not None:
    data = {
        "job_name": final_training_completed_job.name,
        "job_id": final_training_completed_job.id,
    }
    save_json(final_training_cache_file, data)
    print(f"Saved final training job reference to {final_training_cache_file}")
else:
    print("No final training job to save")

### Logged Results

In [None]:
final_training_cache_file = Path("final_training_job_cache.json")

data = load_json(final_training_cache_file, default=None)

if data is None:
    print(f"Cache file {final_training_cache_file} not found. Will need to run Step P1-3.7: Final Training.")
    final_training_completed_job = None
else:
    try:
        final_training_completed_job = ml_client.jobs.get(data["job_name"])
        print(f"Loaded final training job: {final_training_completed_job.name} (status: {final_training_completed_job.status})")
        
        # Validate that the job has a checkpoint output
        if not hasattr(final_training_completed_job, "outputs") or "checkpoint" not in final_training_completed_job.outputs:
            print(f"\n⚠️  WARNING: Training job {final_training_completed_job.name} does not have a 'checkpoint' output.")
            print("   This job cannot be used for model conversion.")
            print("   Please re-run Step P1-3.7: Final Training to generate a new job with checkpoint output.")
            final_training_completed_job = None
        else:
            print("✓ Training job has checkpoint output")
    except Exception as e:
        print(f"Could not reload final training job {data['job_name']}: {e}")
        final_training_completed_job = None


## Step P1-4: Model Conversion & Optimization

Convert the final training checkpoint to an optimized ONNX model (int8 quantized) for production inference.

**Platform Adapter Note**: The conversion script (`src/convert_to_onnx.py`) uses the platform adapter to:
- Resolve checkpoint paths from Azure ML mounted inputs
- Handle output paths for the ONNX model (via `AZURE_ML_OUTPUT_onnx_model`)
- Manage logging and MLflow context appropriately

The adapter automatically detects the Azure ML environment and uses the appropriate implementations.
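
For readers unfamiliar with int8 quantization, the arithmetic behind it can be shown in a few lines. The sketch below uses symmetric per-tensor quantization in pure Python. It is only a worked illustration of the number format; `src/convert_to_onnx.py` most likely relies on a library routine, and its exact scheme is not shown in this notebook.

```python
from typing import List, Tuple


def quantize_int8(values: List[float]) -> Tuple[List[int], float]:
    """Map floats to integers in [-127, 127] using a single shared scale."""
    max_abs = max((abs(v) for v in values), default=0.0)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    quantized = [max(-127, min(127, round(v / scale))) for v in values]
    return quantized, scale


def dequantize_int8(quantized: List[int], scale: float) -> List[float]:
    """Recover approximate floats; each value is off by at most scale / 2."""
    return [q * scale for q in quantized]
```

The rounding error per value is bounded by half the scale, which is why quantized models usually lose only a small amount of accuracy while shrinking weights to one byte each.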


In [None]:
final_training_completed_job

In [None]:
from orchestration.jobs import (
    get_checkpoint_output_from_training_job,
    create_conversion_job,
    validate_conversion_job,
    submit_and_wait_for_job,
)

CONVERSION_SCRIPT_PATH = Path("../src/convert_to_onnx.py")


In [None]:
# Guard: ensure final_training_completed_job is set and has checkpoint output
if "final_training_completed_job" not in globals() or final_training_completed_job is None:
    raise ValueError(
        "final_training_completed_job is not set. "
        "Please run Step P1-3.7: Final Training first, or ensure the cached job has a checkpoint output."
    )

# Guard: ensure ml_client is defined (required for fetching checkpoint data asset)
if "ml_client" not in globals() or ml_client is None:
    raise ValueError(
        "ml_client is not defined. "
        "Please run the cells that set up ml_client (Step P1-3.1) before running this cell."
    )

checkpoint_output = get_checkpoint_output_from_training_job(final_training_completed_job, ml_client=ml_client)
print(f"✓ Retrieved checkpoint output: {checkpoint_output}")


In [None]:
conversion_cluster_name = configs["env"]["compute"]["conversion_cluster"]
conversion_experiment_name = configs["env"]["logging"]["experiment_name"]

conversion_tags = {
    **config_metadata,
    "job_type": "model_conversion",
    "backbone": best_configuration["backbone"],
    "source_training_job": final_training_completed_job.name,
    "quantization": "int8",
}

conversion_job = create_conversion_job(
    script_path=CONVERSION_SCRIPT_PATH,
    checkpoint_uri=str(checkpoint_output),
    environment=training_environment,
    compute_cluster=conversion_cluster_name,
    backbone=best_configuration["backbone"],
    experiment_name=conversion_experiment_name,
    tags=conversion_tags,
)


In [None]:
conversion_completed_job = submit_and_wait_for_job(ml_client, conversion_job)
validate_conversion_job(conversion_completed_job, ml_client=ml_client)


In [None]:
conversion_cache_file = Path("conversion_job_cache.json")

if "conversion_completed_job" in globals() and conversion_completed_job is not None:
    data = {
        "job_name": conversion_completed_job.name,
        "job_id": conversion_completed_job.id,
    }
    save_json(conversion_cache_file, data)
    print(f"Saved conversion job reference to {conversion_cache_file}")
else:
    print("No conversion job to save")


### Logged Results


In [None]:
conversion_cache_file = Path("conversion_job_cache.json")

data = load_json(conversion_cache_file, default=None)

if data is None:
    print(f"Cache file {conversion_cache_file} not found. Will need to run Step P1-4: Model Conversion.")
    conversion_completed_job = None
else:
    try:
        conversion_completed_job = ml_client.jobs.get(data["job_name"])
        print(f"Loaded conversion job: {conversion_completed_job.name} (status: {conversion_completed_job.status})")
        
        # Validate that the job has an onnx_model output
        if not hasattr(conversion_completed_job, "outputs") or "onnx_model" not in conversion_completed_job.outputs:
            print(f"\n⚠️  WARNING: Conversion job {conversion_completed_job.name} does not have an 'onnx_model' output.")
            print("   This job cannot be used for model registration.")
            print("   Please re-run Step P1-4: Model Conversion to generate a new job with ONNX model output.")
            conversion_completed_job = None
        else:
            print("✓ Conversion job has ONNX model output")
    except Exception as e:
        print(f"Could not reload conversion job {data['job_name']}: {e}")
        conversion_completed_job = None



## Step P1-5: Model Registration (The Handover)

Register the optimized ONNX model in Azure ML Model Registry with full metadata for production deployment.

**Platform Adapter Note**: The conversion job's ONNX model output is automatically handled by the platform adapter. The model path is resolved from the Azure ML job output and registered in the model registry with full traceability back to the training and conversion jobs.


In [None]:
import hashlib

from azure.ai.ml.entities import Job, Model
from azure.core.exceptions import ResourceNotFoundError


In [None]:
def get_onnx_model_path(conversion_job: Job) -> str:
    """
    Get ONNX model path from completed conversion job.
    
    Args:
        conversion_job: Completed conversion job
        
    Returns:
        str: ONNX model path (Azure ML datastore URI)
        
    Raises:
        ValueError: If ONNX model not found in job outputs
    """
    if not hasattr(conversion_job, "outputs") or not conversion_job.outputs:
        raise ValueError("Conversion job produced no outputs")
    
    if "onnx_model" not in conversion_job.outputs:
        raise ValueError("Conversion job missing 'onnx_model' output")
    
    onnx_output = conversion_job.outputs["onnx_model"]
    
    if hasattr(onnx_output, "path"):
        return onnx_output.path
    elif isinstance(onnx_output, str):
        return onnx_output
    else:
        raise ValueError(f"Unexpected ONNX output type: {type(onnx_output)}")


onnx_model_path = get_onnx_model_path(conversion_completed_job)


In [None]:
def compute_model_version(best_config: Dict[str, Any], config_hashes: Dict[str, str]) -> str:
    """
    Compute deterministic model version from configuration hashes.
    
    Args:
        best_config: Best configuration from HPO selection
        config_hashes: Configuration hashes dictionary
        
    Returns:
        str: Model version string
    """
    version_components = [
        config_hashes["data"],
        config_hashes["model"],
        config_hashes["train"],
        best_config["backbone"],
    ]
    version_str = "_".join(version_components)
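    # NOTE: CONFIG_HASH_LENGTH is not imported anywhere in this notebook; import it
    # from the module that defines it before calling this function.
    version_hash = hashlib.sha256(version_str.encode()).hexdigest()[:CONFIG_HASH_LENGTH]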
    return f"v{version_hash}"


model_version = compute_model_version(best_configuration, config_hashes)


In [None]:
def register_production_model(
    ml_client: MLClient,
    model_name: str,
    model_version: str,
    model_path: str,
    best_config: Dict[str, Any],
    configs: Dict[str, Any],
    config_metadata: Dict[str, str],
) -> Model:
    """
    Register optimized ONNX model in Azure ML Model Registry.
    
    Args:
        ml_client: MLClient instance
        model_name: Model name in registry
        model_version: Model version
        model_path: Path to ONNX model (Azure ML datastore URI)
        best_config: Best configuration from HPO selection
        configs: Configuration dictionaries
        config_metadata: Configuration metadata for tagging
        
    Returns:
        Model: Registered model instance
        
    Raises:
        ValueError: If model path is invalid
    """
    if not model_path or not model_path.endswith(".onnx"):
        raise ValueError(f"Invalid ONNX model path: {model_path}")
    
    selection_criteria = best_config["selection_criteria"]
    
    model_description = (
        f"Production ONNX model for Resume NER. "
        f"Backbone: {selection_criteria['backbone']}, "
        f"Metric: {selection_criteria['metric']}={selection_criteria['best_value']:.4f}"
    )
    
    model_tags = {
        **config_metadata,
        "stage": PROD_STAGE,
        "backbone": selection_criteria["backbone"],
        "metric": selection_criteria["metric"],
        "metric_value": str(selection_criteria["best_value"]),
        "dataset_version": best_config["dataset_version"],
        "model_format": "onnx",
        "quantization": "int8",
        "source_training_job": final_training_completed_job.name,
        "source_conversion_job": conversion_completed_job.name,
    }
    
    model = Model(
        name=model_name,
        version=model_version,
        description=model_description,
        path=model_path,
        tags=model_tags,
    )
    
    try:
        existing_model = ml_client.models.get(name=model_name, version=model_version)
        return existing_model
    except ResourceNotFoundError:
        return ml_client.models.create_or_update(model)


registered_model = register_production_model(
    ml_client=ml_client,
    model_name=MODEL_NAME,
    model_version=model_version,
    model_path=onnx_model_path,
    best_config=best_configuration,
    configs=configs,
    config_metadata=config_metadata,
)


In [None]:
def validate_registered_model(model: Model) -> None:
    """
    Validate registered model has required metadata and tags.
    
    Args:
        model: Registered model instance
        
    Raises:
        ValueError: If validation fails
    """
    required_tags = ["stage", "backbone", "metric", "dataset_version"]
    for tag in required_tags:
        if tag not in model.tags:
            raise ValueError(f"Registered model missing required tag: {tag}")
    
    if model.tags.get("stage") != PROD_STAGE:
        raise ValueError(f"Model stage must be '{PROD_STAGE}', got: {model.tags.get('stage')}")
    
    if not model.path or not model.path.endswith(".onnx"):
        raise ValueError(f"Invalid model path: {model.path}")


validate_registered_model(registered_model)
