# Production-Grade SageMaker Training Launcher for VideoMAE

This notebook launches and manages VideoMAE pre-training jobs on Amazon SageMaker. It is intended for production workflows and covers configuration, cost management, error handling, and model registration.

**Key Features:**

1.  **Centralized Configuration:** All parameters are defined in a single block for easy management.
2.  **Cost-Effective Spot Training:** Integrated support for SageMaker Managed Spot Training.
3.  **Live Log Streaming:** Training job logs are streamed directly into the notebook for real-time monitoring.
4.  **Model Registry Integration:** A full workflow to register the trained model into the SageMaker Model Registry for versioning and deployment.
5.  **Robust Error Handling & Recovery:** Jobs are launched within `try...except` blocks, and examples are provided for attaching to existing jobs.
6.  **Framework Synchronization:** PyTorch and Python versions are aligned with the project's dependencies.

## 1. Setup and Session Configuration

In [None]:
import sagemaker
import boto3
import os
import time
from sagemaker.pytorch import PyTorch
from sagemaker.tuner import HyperparameterTuner, IntegerParameter, ContinuousParameter
from sagemaker.inputs import TrainingInput
from sagemaker.model_metrics import ModelMetrics, MetricsSource
from sagemaker.workflow.parameters import ParameterString

print(f"SageMaker SDK Version: {sagemaker.__version__}")

In [None]:
sagemaker_session = sagemaker.Session()
boto_session = boto3.Session()

region = sagemaker_session.boto_region_name
bucket = sagemaker_session.default_bucket()
# Look up the account ID of the current credentials through STS
account_id = boto_session.client('sts').get_caller_identity()['Account']

role = sagemaker.get_execution_role()

print(f"Region: {region}")
print(f"Account ID: {account_id}")
print(f"IAM Role: {role}")
print(f"S3 Bucket: {bucket}")

## 2. Training Job Configuration

All job-related parameters are defined here. This centralized approach simplifies customization.

In [None]:
project_name = 'videomae-pretraining-production'

# --- S3 Paths ---
# Your dataset should be uploaded here before running the notebook.
s3_data_path = f's3://{bucket}/{project_name}/data'
s3_output_path = f's3://{bucket}/{project_name}/output'

# --- Training Script ---
source_dir = '../scripts/'
entry_point = 'train.py'

# --- Framework and Instance ---
# mmcv==2.1.0 is compatible with torch==2.0 and cuda 11.8
framework_version = '2.0'
py_version = 'py310'
instance_type = 'ml.p3.2xlarge'
input_mode = 'File' # Use 'File', 'Pipe', or 'FastFile' depending on dataset size

# --- Cost-Saving: Spot Instance Configuration ---
use_spot_instances = True
max_run_seconds = 3600 * 4  # Max training time (4 hours)
max_wait_seconds = 3600 * 6  # Max total time, spot wait plus training (6 hours); must be >= max_run_seconds

# --- Model Registry Configuration ---
model_package_group_name = project_name
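
The two spot time limits above constrain each other: with spot training enabled, SageMaker requires `max_wait` to be at least `max_run`. A minimal sketch of a pre-flight check follows. The `SpotConfig` class and its methods are hypothetical names introduced here for illustration; they are not part of the SageMaker SDK.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SpotConfig:
    """Hypothetical container for the spot-related settings above."""
    use_spot_instances: bool
    max_run_seconds: int
    max_wait_seconds: int

    def validate(self) -> None:
        if self.max_run_seconds <= 0:
            raise ValueError("max_run_seconds must be positive")
        # max_wait covers spot waiting time plus training time, so it can
        # never be shorter than max_run when spot training is enabled.
        if self.use_spot_instances and self.max_wait_seconds < self.max_run_seconds:
            raise ValueError("max_wait_seconds must be >= max_run_seconds")

    def estimator_kwargs(self) -> dict:
        """Return keyword arguments named like the estimator's parameters."""
        self.validate()
        kwargs = {
            "use_spot_instances": self.use_spot_instances,
            "max_run": self.max_run_seconds,
        }
        if self.use_spot_instances:
            kwargs["max_wait"] = self.max_wait_seconds
        return kwargs


spot_config = SpotConfig(
    use_spot_instances=True,
    max_run_seconds=3600 * 4,
    max_wait_seconds=3600 * 6,
)
print(spot_config.estimator_kwargs())
```

Validating before the estimator is built surfaces a bad combination in the notebook, rather than as an API validation error at launch time.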

### IAM Role Permissions

Ensure the SageMaker execution role (the `role` value printed in Section 1) has the following permissions:
- `AmazonS3FullAccess` (or scoped-down access to the specific `s3_data_path` and `s3_output_path`).
- `AmazonSageMakerFullAccess`.
- `iam:PassRole` on the role itself.
- If using the Model Registry: `sagemaker:CreateModelPackageGroup`, `sagemaker:CreateModelPackage`, and the related model package actions.
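
As an illustration of the scoped-down S3 option, the sketch below builds a policy document limited to a single bucket prefix. The function name, the `example-bucket` placeholder, and the action list are assumptions for illustration; the actions a real job needs may differ.

```python
import json


def build_scoped_s3_policy(bucket: str, prefix: str) -> dict:
    """Build an IAM policy document limited to one bucket prefix.

    The function name and the action list are illustrative assumptions;
    adjust them to what the training job actually needs.
    """
    prefix = prefix.strip("/")
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                # Listing is granted on the bucket, restricted to the prefix.
                "Effect": "Allow",
                "Action": ["s3:ListBucket"],
                "Resource": [f"arn:aws:s3:::{bucket}"],
                "Condition": {"StringLike": {"s3:prefix": [f"{prefix}/*"]}},
            },
            {
                # Object reads and writes are granted on keys under the prefix.
                "Effect": "Allow",
                "Action": ["s3:GetObject", "s3:PutObject"],
                "Resource": [f"arn:aws:s3:::{bucket}/{prefix}/*"],
            },
        ],
    }


# "example-bucket" is a placeholder, not a bucket from this notebook.
policy = build_scoped_s3_policy("example-bucket", "videomae-pretraining-production")
print(json.dumps(policy, indent=2))
```

Because `s3_data_path` and `s3_output_path` both sit under the project prefix, one prefix-level policy of this shape would cover both.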

## 3. Launch a Single, Monitored Training Job

In [None]:
hyperparameters = {
    'epochs': 15,
    'lr': 1.5e-4,
    'batch-size': 8,
    'warmup-epochs': 2
}

base_job_name = f"{project_name}-single"
training_job_name = f"{base_job_name}-{int(time.time())}"

metric_definitions = [
    # Raw string so the backslash in the regex is not treated as a string escape
    {'Name': 'train:loss', 'Regex': r'Training-Loss: ([0-9\.]+)'}
]

estimator = PyTorch(
    entry_point=entry_point,
    source_dir=source_dir,
    role=role,
    instance_count=1,
    instance_type=instance_type,
    framework_version=framework_version,
    py_version=py_version,
    hyperparameters=hyperparameters,
    output_path=s3_output_path,
    metric_definitions=metric_definitions,
    input_mode=input_mode,
    use_spot_instances=use_spot_instances,
    max_run=max_run_seconds,
    max_wait=max_wait_seconds if use_spot_instances else None
)

training_input = TrainingInput(s3_data=s3_data_path, content_type='application/x-video')

print(f"Launching single training job: {training_job_name}")

try:
    estimator.fit({'training': training_input}, job_name=training_job_name, wait=False)
    print("Job launched successfully. You can view it in the SageMaker console.")
    print(f"Streaming logs for job '{training_job_name}':")
    sagemaker_session.logs_for_job(job_name=training_job_name, wait=True)
except Exception as e:
    print(f"\nError launching or monitoring training job: {e}")
    print(f"Check CloudWatch logs for job '{training_job_name}' for more details.")
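
The `train:loss` metric is only populated if the training script's log output matches the regex in `metric_definitions`. The sketch below checks the pattern locally before a job is launched. The sample log lines are hypothetical; the real format depends on what `train.py` prints.

```python
import re

# Same pattern as the metric definition above.
LOSS_PATTERN = re.compile(r"Training-Loss: ([0-9\.]+)")

# Hypothetical log lines; the real format depends on what train.py prints.
sample_log_lines = [
    "Epoch 1/15 Training-Loss: 0.8421",
    "Epoch 2/15 Training-Loss: 0.6130",
    "Epoch 3/15 finished without a loss line",
]


def extract_losses(lines):
    """Return the float captured from every line that matches the pattern."""
    losses = []
    for line in lines:
        match = LOSS_PATTERN.search(line)
        if match:
            losses.append(float(match.group(1)))
    return losses


print(extract_losses(sample_log_lines))
```

The character class `[0-9\.]` does not include `e` or `-`, so a loss printed in scientific notation such as `1.5e-04` would be captured as `1.5`. If the script can emit that format, either print with a fixed-point format or widen the pattern.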

## 4. Launch a Hyperparameter Tuning Job

In [None]:
hyperparameter_ranges = {
    'lr': ContinuousParameter(1e-5, 1e-3)
}

tuner = HyperparameterTuner(
    estimator=estimator, # Re-uses the same estimator configuration
    objective_metric_name='train:loss',
    hyperparameter_ranges=hyperparameter_ranges,
    metric_definitions=metric_definitions, # Pass the regex to the tuner so it can parse the objective metric
    objective_type='Minimize',
    max_jobs=10,
    max_parallel_jobs=2,
    base_tuning_job_name='videomae-hpo'
)

# Tuning job names are limited to 32 characters, so use a short prefix
# rather than the full project name.
tuning_job_name = f"videomae-hpo-{int(time.time())}"
print(f"Launching hyperparameter tuning job: {tuning_job_name}")

try:
    tuner.fit({'training': training_input}, job_name=tuning_job_name, wait=True)
except Exception as e:
    print(f"\nError launching tuning job: {e}")

## 5. Analyze Tuning Results and Register the Best Model

In [None]:
try:
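    # Re-attach to the tuning job by name so this cell also works after a kernel restart
    finished_tuner = HyperparameterTuner.attach(tuning_job_name, sagemaker_session=sagemaker_session)
    best_job_name = finished_tuner.best_training_job()  # Returns the job name as a string
    print(f"Best training job found: {best_job_name}")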

    # Attach to the best estimator
    best_estimator = PyTorch.attach(best_job_name)
    model_artifacts = best_estimator.model_data
    print(f"Model artifacts for the best job are at: {model_artifacts}")

    # Create a SageMaker Model Package Group for versioning
    sm_client = sagemaker_session.sagemaker_client
    try:
        sm_client.create_model_package_group(
            ModelPackageGroupName=model_package_group_name,
            ModelPackageGroupDescription=f"Models for {project_name}"
        )
        print(f"Created Model Package Group: {model_package_group_name}")
    except sm_client.exceptions.ClientError as e:
        if e.response['Error']['Code'] == 'ValidationException':
            print(f"Model Package Group '{model_package_group_name}' already exists.")
        else:
            raise

    # Register the model
    model_package = best_estimator.register(
        content_types=["application/x-video"],
        response_types=["application/json"],
        inference_instances=["ml.g4dn.xlarge"], # Example instance for deployment
        transform_instances=["ml.m5.large"],
        model_package_group_name=model_package_group_name,
        approval_status="PendingManualApproval"
    )
    print(f"\nSuccessfully registered model version: {model_package.model_package_arn}")

except Exception as e:
    print(f"Could not analyze or register the model. The tuning job may have failed or produced no completed training jobs. Error: {e}")
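
The model version is registered with `PendingManualApproval`, so it stays unapproved until its status is changed. One way to do that, sketched below, is the boto3 `update_model_package` call. The function name and the default description are assumptions for illustration.

```python
def approve_model_package(sm_client, model_package_arn, description="Approved after manual review"):
    """Set a registered model version to Approved.

    `sm_client` is expected to be a boto3 SageMaker client. The function name
    and the default description are assumptions made for this sketch.
    """
    return sm_client.update_model_package(
        ModelPackageArn=model_package_arn,
        ModelApprovalStatus="Approved",
        ApprovalDescription=description,
    )
```

With the objects from the cell above, the call would look like `approve_model_package(sagemaker_session.sagemaker_client, model_package.model_package_arn)`.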

## 6. Cleanup (Optional)

To avoid incurring costs, delete the resources this notebook created. If you deployed an endpoint, delete it first. The commented-out cell below deletes every model version in the group and then the group itself.

In [None]:
# sm_client = sagemaker_session.sagemaker_client
# try:
#     # List all versions in the model group
#     model_packages = sm_client.list_model_packages(ModelPackageGroupName=model_package_group_name)['ModelPackageSummaryList']
#     for mp in model_packages:
#         print(f"Deleting model package: {mp['ModelPackageArn']}")
#         sm_client.delete_model_package(ModelPackageName=mp['ModelPackageArn'])
    
#     # Delete the model group itself
#     print(f"Deleting model package group: {model_package_group_name}")
#     sm_client.delete_model_package_group(ModelPackageGroupName=model_package_group_name)
#     print("Cleanup complete.")
# except sm_client.exceptions.ClientError as e:
#     print(f"Could not perform cleanup. Error: {e}")
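
`list_model_packages` returns results in pages, so a single call may not see every version in a large group. The sketch below follows `NextToken` until all ARNs are collected. The function name is an assumption for illustration.

```python
def list_all_model_package_arns(sm_client, model_package_group_name):
    """Collect the ARN of every model version in a group, following pagination.

    `sm_client` is expected to be a boto3 SageMaker client. The function name
    is an assumption made for this sketch.
    """
    arns = []
    request = {"ModelPackageGroupName": model_package_group_name}
    while True:
        response = sm_client.list_model_packages(**request)
        arns.extend(mp["ModelPackageArn"] for mp in response["ModelPackageSummaryList"])
        next_token = response.get("NextToken")
        if not next_token:
            return arns
        # Ask for the next page on the following iteration.
        request["NextToken"] = next_token
```

Printing the returned list before deleting anything doubles as a dry run of the cleanup.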