# Production-Grade SageMaker Training & HPO Launcher

This notebook provides an interface for launching and managing VideoMAE pre-training jobs and Hyperparameter Optimization (HPO) tasks on Amazon SageMaker. It covers the full workflow: a single configuration cell, data delivery through SageMaker input channels, job launch and monitoring, and registration of the best model.

**Key Features:**

1.  **Centralized Configuration:** All parameters are defined in a single block for easy management.
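2.  **SageMaker Data Channels:** Training data is passed to `fit()` as a named input channel (`training`), so SageMaker copies it from S3 to the training instance and no manual syncing is required.
3.  **Cost-Effective Spot Training:** Managed Spot Training can be enabled with a single flag (`USE_SPOT_INSTANCES`).
4.  **Live Log Streaming:** When a training job is launched with `estimator.fit(..., wait=True)`, its logs are streamed into the notebook for real-time monitoring.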
5.  **Model Registry Integration:** A full workflow to register the best model from an HPO job into the SageMaker Model Registry.
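In File input mode (the default), SageMaker downloads each channel's S3 data to `/opt/ml/input/data/<channel>` on the training instance and exposes that path in an `SM_CHANNEL_<CHANNEL>` environment variable. A minimal sketch of how an entry point can locate the `training` channel, assuming the standard SageMaker training-toolkit environment variables; `resolve_channel_dir` is a hypothetical helper, not part of the SDK.

```python
import os


def resolve_channel_dir(channel_name: str) -> str:
    """Return the local directory that holds the files of a SageMaker input channel.

    Inside a SageMaker training container, each channel passed to fit() is exposed
    as SM_CHANNEL_<NAME> (upper-cased channel name). Outside the container, fall
    back to the conventional path so the code still resolves a location.
    """
    env_var = f"SM_CHANNEL_{channel_name.upper()}"
    return os.environ.get(env_var, f"/opt/ml/input/data/{channel_name}")


# 'training' matches the channel key used in fit({'training': S3_DATA_PATH})
print(resolve_channel_dir("training"))
```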

## 1. Global Configuration

All user-configurable parameters are defined in this cell. Adjust these values to match your environment and training requirements.

In [None]:
# Path to the tiny dataset used for Free Tier validation
S3_DATA_PATH = "s3://miqa-data/ssv2-tiny"

# Prefix for job outputs inside the session's default S3 bucket
S3_OUTPUT_PREFIX = 'videomae-free-tier-test'

PROJECT_NAME = 'videomae-free-tier-test'

# Critical: use a Free Tier eligible instance type
INSTANCE_TYPE = 'ml.m5.xlarge'

INSTANCE_COUNT = 1

FRAMEWORK_VERSION = '2.0'

PYTHON_VERSION = 'py310'

SOURCE_DIR = '../scripts'

ENTRY_POINT = 'train.py'

# Spot capacity is disabled for this test so jobs are not delayed or interrupted
USE_SPOT_INSTANCES = False

# Limit each training job to 1 hour to stay within budget
MAX_RUN_SECONDS = 3600

# Only used when USE_SPOT_INSTANCES is True; must be at least MAX_RUN_SECONDS
MAX_WAIT_SECONDS = 3600 * 2

# SageMaker passes each key to the entry point as a `--key value` argument,
# so these names MUST match the options parsed by train.py
# (argparse turns '--learning-rate' into args.learning_rate).
HYPERPARAMETERS = {
    'total-epochs': 1,
    'learning-rate': 1.5e-4,
    'batch-size': 2,
    'warmup-epochs': 0
}

# Hyperparameter tuning (HPO) limits
HPO_MAX_JOBS = 10
HPO_MAX_PARALLEL_JOBS = 2

# Model Registry group that receives the best model from the tuning job
MODEL_PACKAGE_GROUP_NAME = PROJECT_NAME
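Because SageMaker turns each `HYPERPARAMETERS` entry into a `--key value` command-line pair, `train.py` must declare a matching option for every key. The real parser in `../scripts/train.py` is not shown in this notebook, so the following is a hypothetical sketch of what it could look like; `SM_MODEL_DIR` is the standard SageMaker variable for the directory whose contents are packaged into `model.tar.gz`.

```python
import argparse
import os


def parse_args(argv=None):
    """Hypothetical train.py parser whose options mirror the HYPERPARAMETERS keys."""
    parser = argparse.ArgumentParser(description="VideoMAE pre-training (sketch)")
    # argparse converts '--total-epochs' into the attribute 'total_epochs', etc.
    parser.add_argument("--total-epochs", type=int, default=1)
    parser.add_argument("--learning-rate", type=float, default=1.5e-4)
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--warmup-epochs", type=int, default=0)
    # Files saved here are uploaded by SageMaker as the job's model.tar.gz
    parser.add_argument("--model-dir", default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
    return parser.parse_args(argv)


# Simulate the command line built from the HYPERPARAMETERS dict
example_argv = [
    "--total-epochs", "1",
    "--learning-rate", "0.00015",
    "--batch-size", "2",
    "--warmup-epochs", "0",
]
args = parse_args(example_argv)
print(args.total_epochs, args.learning_rate, args.batch_size, args.warmup_epochs)
```

A key that `train.py` does not declare makes `parse_args` exit with an error, which is why the names in the configuration cell have to match exactly.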


## 2. Session Setup

This section creates a SageMaker session (which wraps a Boto3 session) and retrieves the IAM execution role. The notebook assumes it runs inside a SageMaker environment, such as SageMaker Studio or a Notebook Instance, where `get_execution_role()` can resolve the execution role automatically; elsewhere, set `role` to an IAM role ARN explicitly.

In [None]:
import sagemaker
import boto3
import os
import time
from sagemaker.pytorch import PyTorch
from sagemaker.tuner import HyperparameterTuner, IntegerParameter, ContinuousParameter

sagemaker_session = sagemaker.Session()
role = sagemaker.get_execution_role()
region = sagemaker_session.boto_region_name
s3_output_path = f"s3://{sagemaker_session.default_bucket()}/{S3_OUTPUT_PREFIX}"

print(f"SageMaker SDK Version: {sagemaker.__version__}")
print(f"Region: {region}")
print(f"IAM Role: {role}")
print(f"S3 Data Input Path: {S3_DATA_PATH}")
print(f"S3 Model Output Path: {s3_output_path}")
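`default_bucket()` returns (creating it if needed) a bucket named `sagemaker-{region}-{account_id}`, and a finished training job uploads its packaged model to `{output_path}/{job_name}/output/model.tar.gz`. The sketch below reproduces that layout so you can predict where `model_data` will point; the region, account ID, and job name are placeholders, and both helpers are hypothetical.

```python
def default_bucket_name(region: str, account_id: str) -> str:
    """Bucket name that sagemaker.Session().default_bucket() uses by default."""
    return f"sagemaker-{region}-{account_id}"


def model_artifact_uri(output_path: str, job_name: str) -> str:
    """S3 location of the model.tar.gz that SageMaker builds from /opt/ml/model."""
    return f"{output_path.rstrip('/')}/{job_name}/output/model.tar.gz"


# Placeholder values; substitute your own region, account ID, and job name
example_output_path = f"s3://{default_bucket_name('us-east-1', '123456789012')}/videomae-free-tier-test"
print(model_artifact_uri(example_output_path, "example-training-job"))
```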

## 3. Define the SageMaker PyTorch Estimator

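The estimator defines the training environment: the PyTorch container version, instance type and count, entry point, static hyperparameters, and where artifacts are written. It can launch a standalone training job with `fit()` and also serves as the template for every trial of the hyperparameter tuning job below. `metric_definitions` tells SageMaker which log lines to parse into the `train:loss` metric that the tuner optimizes.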

In [None]:
# Define metrics that SageMaker will parse from the training job logs.
metric_definitions = [
    {'Name': 'train:loss', 'Regex': r'Training-Loss: ([0-9\.]+)'}
]

estimator = PyTorch(
    entry_point=ENTRY_POINT,
    source_dir=SOURCE_DIR,
    role=role,
    instance_count=INSTANCE_COUNT,
    instance_type=INSTANCE_TYPE,
    framework_version=FRAMEWORK_VERSION,
    py_version=PYTHON_VERSION,
    hyperparameters=HYPERPARAMETERS,
    output_path=s3_output_path,
    metric_definitions=metric_definitions,
    use_spot_instances=USE_SPOT_INSTANCES,
    max_run=MAX_RUN_SECONDS,
    max_wait=MAX_WAIT_SECONDS if USE_SPOT_INSTANCES else None
)
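SageMaker applies each `Regex` in `metric_definitions` to the job's log output and records the first capture group as the metric value. The sketch below approximates that with Python's `re` module so the pattern can be checked locally; the log lines are invented, assuming `train.py` prints `Training-Loss: <value>`.

```python
import re

# Same pattern as in metric_definitions (raw string, so "\." is passed to re unchanged)
TRAIN_LOSS_REGEX = r"Training-Loss: ([0-9\.]+)"


def extract_metric(log_lines, regex):
    """Return the float captured by the first group of `regex` on each matching line."""
    pattern = re.compile(regex)
    values = []
    for line in log_lines:
        match = pattern.search(line)
        if match:
            values.append(float(match.group(1)))
    return values


sample_logs = [
    "Epoch 1/1 step 10 Training-Loss: 0.9132",
    "Saving checkpoint...",
    "Epoch 1/1 step 20 Training-Loss: 0.8471",
    "Epoch 1/1 step 30 Training-Loss: 2.5e-05",  # only "2.5" is captured by this pattern
]
print(extract_metric(sample_logs, TRAIN_LOSS_REGEX))
```

The last line shows a limitation: `[0-9\.]+` stops at the exponent, so a loss printed in scientific notation is recorded incorrectly. If `train.py` may print such values, a pattern like `Training-Loss: ([0-9.]+(?:[eE][-+]?[0-9]+)?)` captures the exponent too.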

## 4. Launch a Hyperparameter Tuning Job

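To search for better hyperparameters automatically, we use SageMaker's `HyperparameterTuner`. It launches up to `HPO_MAX_JOBS` training jobs from the estimator above, with at most `HPO_MAX_PARALLEL_JOBS` running at a time, each using hyperparameters drawn from the ranges below, and ranks them by the `train:loss` metric.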

In [None]:
from sagemaker.utils import name_from_base

# Keys must be hyperparameter names that train.py accepts; for each trial,
# the tuned values override the static values in HYPERPARAMETERS.
hyperparameter_ranges = {
    'learning-rate': ContinuousParameter(1e-5, 1e-3),
    'batch-size': IntegerParameter(4, 16)
}

tuner = HyperparameterTuner(
    estimator=estimator,
    objective_metric_name='train:loss',
    hyperparameter_ranges=hyperparameter_ranges,
    objective_type='Minimize',
    max_jobs=HPO_MAX_JOBS,
    max_parallel_jobs=HPO_MAX_PARALLEL_JOBS,
    base_tuning_job_name=f"{PROJECT_NAME}-hpo"
)

# Tuning job names are limited to 32 characters; name_from_base trims the
# base name and appends a short timestamp to stay within that limit.
tuning_job_name = name_from_base(f"{PROJECT_NAME}-hpo", max_length=32, short=True)
print(f"Launching hyperparameter tuning job: {tuning_job_name}")
print(f"Check progress: https://{region}.console.aws.amazon.com/sagemaker/home?region={region}#/hyper-tuning-jobs/{tuning_job_name}")

try:
    tuner.fit({'training': S3_DATA_PATH}, job_name=tuning_job_name, wait=True)
except Exception as e:
    print(f"\nError launching tuning job: {e}")
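The tuner's default search strategy is Bayesian optimization, which is not reproduced here. As a rough intuition for how the ranges and job limits interact, the sketch below substitutes plain random search: it draws configurations from the same ranges and groups them into rounds of at most two concurrent jobs, so 10 trials need at least 5 job-durations of wall-clock time. It also shows why the sampling scale matters for `learning-rate`: drawn uniformly on a linear scale, about 91% of values in [1e-5, 1e-3] land above 1e-4, whereas log-uniform sampling puts about half there. All function names are hypothetical.

```python
import math
import random


def sample_trials(n_trials, rng, log_scale_lr=True):
    """Random-search stand-in for the tuner: draw hyperparameter sets from the notebook's ranges."""
    low, high = 1e-5, 1e-3  # same bounds as ContinuousParameter(1e-5, 1e-3)
    trials = []
    for _ in range(n_trials):
        if log_scale_lr:
            # Log-uniform: sample uniformly in log10 space, then map back
            lr = 10 ** rng.uniform(math.log10(low), math.log10(high))
        else:
            # Linear-uniform over the same interval
            lr = rng.uniform(low, high)
        # randint is inclusive on both ends, mirroring the [4, 16] batch-size range
        batch_size = rng.randint(4, 16)
        trials.append({"learning-rate": lr, "batch-size": batch_size})
    return trials


def schedule(trials, max_parallel):
    """Group trials into rounds of at most `max_parallel` concurrent jobs."""
    return [trials[i:i + max_parallel] for i in range(0, len(trials), max_parallel)]


# 10 trials, at most 2 running at once (matches HPO_MAX_JOBS / HPO_MAX_PARALLEL_JOBS)
rounds = schedule(sample_trials(10, random.Random(0)), max_parallel=2)
print(f"{len(rounds)} rounds of at most 2 concurrent jobs")

# Share of learning-rate draws above 1e-4 under each scaling
for log_scale in (False, True):
    draws = sample_trials(20000, random.Random(1), log_scale_lr=log_scale)
    share = sum(t["learning-rate"] > 1e-4 for t in draws) / len(draws)
    print(f"log_scale={log_scale}: {share:.2f} of draws above 1e-4")
```

To have the tuner itself search learning rates on a log scale, `ContinuousParameter` accepts `scaling_type='Logarithmic'`; the default, `'Auto'`, lets SageMaker choose.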

## 5. Analyze Tuning Results and Register the Best Model

Once the tuning job is complete, we find the best-performing training job and register its model artifacts into the SageMaker Model Registry for versioning and deployment.

In [None]:
sm_client = sagemaker_session.sagemaker_client

try:
    # DescribeHyperParameterTuningJob reports the winning trial in 'BestTrainingJob'.
    tuning_job = sm_client.describe_hyper_parameter_tuning_job(
        HyperParameterTuningJobName=tuning_job_name
    )
    best_job_name = tuning_job['BestTrainingJob']['TrainingJobName']
    print(f"Best training job found: {best_job_name}")

    # Re-create an estimator from the finished job to access its artifacts.
    best_estimator = PyTorch.attach(best_job_name, sagemaker_session=sagemaker_session)
    model_artifacts = best_estimator.model_data
    print(f"Model artifacts for the best job are at: {model_artifacts}")

    # Create the Model Package Group on the first run; reuse it on later runs.
    try:
        sm_client.create_model_package_group(
            ModelPackageGroupName=MODEL_PACKAGE_GROUP_NAME,
            ModelPackageGroupDescription=f"Models for {PROJECT_NAME}"
        )
        print(f"Created Model Package Group: {MODEL_PACKAGE_GROUP_NAME}")
    except sm_client.exceptions.ClientError as e:
        if 'already exists' in str(e):
            print(f"Model Package Group '{MODEL_PACKAGE_GROUP_NAME}' already exists.")
        else:
            raise

    model_package = best_estimator.register(
        content_types=["application/x-video"],
        response_types=["application/json"],
        inference_instances=["ml.g4dn.xlarge"],
        transform_instances=["ml.m5.large"],
        model_package_group_name=MODEL_PACKAGE_GROUP_NAME,
        approval_status="PendingManualApproval"
    )
    print(f"\nSuccessfully registered model version: {model_package.model_package_arn}")

except Exception as e:
    print(f"Could not analyze the tuning job or register the best model: {e}")
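SageMaker reports the winning trial in the `BestTrainingJob` field of `DescribeHyperParameterTuningJob`. To inspect or customize the selection rule, for example to consider only trials that completed, the sketch below ranks job summaries shaped like those returned by `ListTrainingJobsForHyperParameterTuningJob`. The field names follow that API's response as an assumption, the sample data is invented, and `pick_best_job` is a hypothetical helper.

```python
def pick_best_job(summaries, objective_type="Minimize"):
    """Return the name of the best completed trial, or None if no trial reported the objective."""
    candidates = [
        s for s in summaries
        if s.get("TrainingJobStatus") == "Completed"
        and "FinalHyperParameterTuningJobObjectiveMetric" in s
    ]
    if not candidates:
        return None
    choose = min if objective_type == "Minimize" else max
    best = choose(
        candidates,
        key=lambda s: s["FinalHyperParameterTuningJobObjectiveMetric"]["Value"],
    )
    return best["TrainingJobName"]


# Invented example summaries; only the fields used above are included
example_summaries = [
    {"TrainingJobName": "trial-001", "TrainingJobStatus": "Completed",
     "FinalHyperParameterTuningJobObjectiveMetric": {"MetricName": "train:loss", "Value": 0.82}},
    {"TrainingJobName": "trial-002", "TrainingJobStatus": "Failed"},
    {"TrainingJobName": "trial-003", "TrainingJobStatus": "Completed",
     "FinalHyperParameterTuningJobObjectiveMetric": {"MetricName": "train:loss", "Value": 0.77}},
]
print(pick_best_job(example_summaries))
```

With real data, the summaries would come from `sm_client.list_training_jobs_for_hyper_parameter_tuning_job(HyperParameterTuningJobName=tuning_job_name)['TrainingJobSummaries']`, which is paginated, so larger tuning jobs need to follow `NextToken`.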