# DeepSeek OCR Pipeline on SageMaker Training Jobs

This notebook runs a three-stage OCR pipeline using SageMaker Training Jobs:

1. **Extract** – Run DeepSeek OCR over a dataset, save Markdown and crop detected figures
2. **Describe** – Generate captions for extracted figures  
3. **Assemble** – Enrich Markdown with figure captions

This is the SageMaker equivalent of the Hugging Face Jobs pipeline. It uses the `ModelTrainer` API from the SageMaker Python SDK V3
with a vLLM container to run GPU-accelerated inference.

**Key difference from HF Jobs:** This notebook saves datasets to Amazon S3 instead of the Hugging Face Hub.

## Prerequisites

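- AWS credentials configured
- An IAM execution role with S3 access (looked up by name; `sagemaker-dlcs` by default)
- A Hugging Face token, exported as the `HF_TOKEN` environment variable, for accessing source models and datasets
- SageMaker Python SDK V3 installed (`pip install sagemaker --upgrade`)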

In [None]:
!pip3 install sagemaker --upgrade --quiet

In [None]:
!pip install -U "datasets>=4.0.0" "s3fs" "fsspec"

In [None]:
import os
import json
import shutil
import tempfile
import time
from pathlib import Path

import boto3
import sagemaker
from sagemaker.train.model_trainer import ModelTrainer
from sagemaker.train.configs import SourceCode, Compute, StoppingCondition, OutputDataConfig
from sagemaker.core.helper.session_helper import Session, get_execution_role

## Configuration

In [None]:
# Initialize the SageMaker session and resolve the identity used for the jobs
sagemaker_session = Session()
region = sagemaker_session.boto_region_name
account_id = boto3.client("sts").get_caller_identity()["Account"]

# IAM role the training jobs run as. The role name can be overridden with the
# SAGEMAKER_ROLE_NAME environment variable; it defaults to "sagemaker-dlcs".
ROLE_NAME = os.environ.get("SAGEMAKER_ROLE_NAME", "sagemaker-dlcs")
iam = boto3.client('iam')
role = iam.get_role(RoleName=ROLE_NAME)['Role']['Arn']

print(f"Region: {region}")
print(f"Account: {account_id}")
print(f"Role: {role}")

In [None]:
# Pipeline Configuration
PROJECT_NAME = "deepseek-ocr-sagemaker"
BUCKET_NAME = sagemaker_session.default_bucket()
S3_PREFIX = PROJECT_NAME

# S3 output path (single location for all stages - dataset gets updated in place)
S3_OUTPUT_URI = f"s3://{BUCKET_NAME}/{S3_PREFIX}"

# vLLM Container - use SageMaker vLLM DLC
# NOTE(review): 763104351884 is the AWS Deep Learning Containers registry account
# for most commercial regions; some regions use a different account - confirm for yours.
TRAINING_IMAGE = f"763104351884.dkr.ecr.{region}.amazonaws.com/vllm:0.12.0-gpu-py312-cu129-ubuntu22.04-sagemaker-v1.0"

# Instance configuration
INSTANCE_TYPE = "ml.g5.2xlarge"  # Single A10G GPU
# INSTANCE_TYPE = "ml.p4d.24xlarge"  # 8x A100 GPUs for larger scale
VOLUME_SIZE_GB = 100
MAX_RUNTIME_SECONDS = 3 * 60 * 60  # 3 hours

# Source dataset (from HuggingFace)
SOURCE_DATASET = "HuggingFaceM4/FineVision"
SOURCE_CONFIG = "olmOCR-mix-0225-documents"
MAX_SAMPLES = 20  # Start small for testing

# HuggingFace token for accessing source datasets
HF_TOKEN = os.environ.get("HF_TOKEN", "")
if not HF_TOKEN:
    # The token is only consumed inside the remote jobs, so without this check a
    # missing token would surface only after an instance has been provisioned.
    print("WARNING: HF_TOKEN is not set; access to gated or private Hugging Face repos will fail inside the jobs.")

print(f"S3 Bucket: s3://{BUCKET_NAME}/{S3_PREFIX}")
print(f"S3 Output URI: {S3_OUTPUT_URI}")
print(f"Instance: {INSTANCE_TYPE}")
print(f"Source: {SOURCE_DATASET}/{SOURCE_CONFIG} ({MAX_SAMPLES} samples)")

## Bundle Pipeline Code

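The cell below copies the pipeline code (`entry.sh`, `sm_job_runner.py`, and the `llm_ocr` package) into a temporary directory. When a job is launched, SageMaker uploads this bundle to S3 and makes it available inside the container at `/opt/ml/input/data/code`.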

In [None]:
# Paths to pipeline code, relative to the notebook's working directory
CODE_PATHS = [
    Path("entry.sh"),
    Path("sm_job_runner.py"),
    Path("../llm_ocr"),
]

# Byte-code caches and notebook checkpoints are never needed in the job;
# skipping them keeps the uploaded bundle small.
_BUNDLE_IGNORE = shutil.ignore_patterns("__pycache__", "*.pyc", ".ipynb_checkpoints")

# Create a fresh source directory bundle (a new temp dir on every run of this cell)
source_dir = Path(tempfile.mkdtemp(prefix="sm-ocr-code-"))

for path in CODE_PATHS:
    src = path if path.is_absolute() else Path.cwd() / path
    if not src.exists():
        # Fail early with a clear message instead of a bare copy error.
        raise FileNotFoundError(
            f"Pipeline code not found: {src} (CODE_PATHS are resolved against {Path.cwd()})"
        )
    dest = source_dir / path.name
    if src.is_dir():
        shutil.copytree(src, dest, dirs_exist_ok=True, ignore=_BUNDLE_IGNORE)
    else:
        shutil.copy2(src, dest)

print(f"Source directory: {source_dir}")
print(f"Contents: {list(source_dir.iterdir())}")

In [None]:
# Dependencies are declared in sm_job_runner.py inline metadata (PEP 723)
# entry.sh installs uv and runs: uv run sm_job_runner.py
# This automatically installs all dependencies

## Define Base Environment Variables

In [None]:
# Environment shared by every stage. launch_stage() merges the stage-specific
# values on top and passes the result to the job as environment variables.
# NOTE: a training job's environment is returned by DescribeTrainingJob, so the
# HF_TOKEN value is readable by anyone allowed to describe these jobs.
BASE_ENV = dict(
    # vLLM server settings
    MODEL_ID="deepseek-ai/DeepSeek-OCR",
    SERVED_MODEL_NAME="deepseek-ocr",
    HOST="0.0.0.0",
    PORT="8000",
    MAX_MODEL_LEN="8192",
    GPU_MEMORY_UTILIZATION="0.90",
    TENSOR_PARALLEL_SIZE="1",
    # Hugging Face Hub access (source datasets)
    HF_TOKEN=HF_TOKEN,
    HF_HUB_ENABLE_HF_TRANSFER="1",
    # Prompts and sampling settings for document OCR and figure captioning
    DOC_PROMPT="<image>\n<|grounding|>Convert this document to Markdown.",
    DOC_MAX_TOKENS="4096",
    DOC_TEMPERATURE="0.1",
    FIGURE_PROMPT="<image>\nDescribe this image in detail.",
    FIGURE_MAX_TOKENS="512",
    FIGURE_TEMPERATURE="0.6",
)

## Helper Functions

In [None]:
def launch_stage(stage: str, env: dict = None, instance_type: str = None):
    """Launch a pipeline stage as a SageMaker Training Job without waiting for it.

    Args:
        stage: Pipeline stage (extract, describe, assemble); exported to the job
            as ``PIPELINE_STAGE``.
        env: Stage-specific environment variables (optional). They are merged
            over ``BASE_ENV`` and ``PIPELINE_STAGE``; on key clashes ``env`` wins.
        instance_type: Instance type for this stage (optional). Defaults to the
            notebook-wide ``INSTANCE_TYPE``; pass a value to size a single stage
            differently.

    Returns:
        ModelTrainer object whose job has been started (pass it to ``wait_for_job``).
    """
    import uuid

    # Used as ``base_job_name``. NOTE(review): SageMaker derives the final job
    # name from this base (typically by trimming it and appending a timestamp),
    # so the printed value may be a prefix of the real job name.
    job_name = f"{PROJECT_NAME}-{stage}-{uuid.uuid4().hex[:8]}"

    # Later mappings win: BASE_ENV < PIPELINE_STAGE < stage-specific env.
    full_env = {**BASE_ENV, "PIPELINE_STAGE": stage, **(env or {})}

    trainer = ModelTrainer(
        sagemaker_session=sagemaker_session,
        role=role,
        image_uri=TRAINING_IMAGE,
        # NOTE(review): confirm "Heterogeneous" is an accepted training_mode for the
        # installed ModelTrainer version; only one instance group is configured here.
        training_mode="Heterogeneous",
        source_code=SourceCode(
            source_dir=str(source_dir),
            entry_script="entry.sh",
        ),
        compute=Compute(
            instance_type=instance_type or INSTANCE_TYPE,
            instance_count=1,
            volume_size_in_gb=VOLUME_SIZE_GB,
        ),
        stopping_condition=StoppingCondition(
            max_runtime_in_seconds=MAX_RUNTIME_SECONDS,
        ),
        output_data_config=OutputDataConfig(
            s3_output_path=f"s3://{BUCKET_NAME}/{S3_PREFIX}/output/",
        ),
        base_job_name=job_name,
        environment=full_env,
    )

    print(f"Launching {stage} stage: {job_name}")
    trainer.train(wait=False)

    return trainer


def wait_for_job(trainer, poll_interval: int = 30, timeout: int = 10800,
                 raise_on_failure: bool = False):
    """Block until a SageMaker Training Job reaches a terminal state.

    Polls ``DescribeTrainingJob`` every ``poll_interval`` seconds and prints the
    status on each poll.

    Args:
        trainer: ModelTrainer returned by ``launch_stage``; the job name is read
            from ``trainer._latest_job.name``.
        poll_interval: Seconds to sleep between status checks.
        timeout: Maximum number of seconds to wait.
        raise_on_failure: When True, raise ``RuntimeError`` if the job ends as
            ``Failed`` or ``Stopped`` instead of returning its description.

    Returns:
        The final ``DescribeTrainingJob`` response dict.

    Raises:
        TimeoutError: The job had not finished after ``timeout`` seconds.
        RuntimeError: ``raise_on_failure`` is True and the job did not complete.
    """
    # NOTE(review): ``_latest_job`` is a private ModelTrainer attribute - confirm
    # it exists in the installed SDK version.
    job_name = trainer._latest_job.name
    # Pin the client to the SageMaker session's region instead of relying on
    # whatever region the default boto3 configuration resolves to.
    sm_client = boto3.client('sagemaker', region_name=region)
    # Monotonic clock: unaffected by wall-clock adjustments during long waits.
    deadline = time.monotonic() + timeout

    print(f"Waiting for job {job_name}...")

    while time.monotonic() < deadline:
        response = sm_client.describe_training_job(TrainingJobName=job_name)
        status = response['TrainingJobStatus']

        if status == 'Completed':
            print(f"  {job_name}: Completed ✓")
            return response
        if status == 'Failed':
            print(f"  {job_name}: Failed ✗")
            print(f"  Reason: {response.get('FailureReason', 'Unknown')}")
        elif status == 'Stopped':
            print(f"  {job_name}: Stopped")
        else:
            print(f"  {job_name}: {status}...")
            time.sleep(poll_interval)
            continue

        # Reached only for the unsuccessful terminal states (Failed / Stopped).
        if raise_on_failure:
            raise RuntimeError(f"Training job {job_name} ended with status {status}")
        return response

    raise TimeoutError(f"Job {job_name} did not complete within {timeout}s")


# Import IO and rendering utilities from llm_ocr, which lives in the parent
# directory. Resolve that directory to an absolute path (so a later chdir does
# not break imports) and add it only once (so re-running this cell does not keep
# growing sys.path).
import sys
_LLM_OCR_PARENT = str(Path.cwd().parent.resolve())
if _LLM_OCR_PARENT not in sys.path:
    sys.path.insert(0, _LLM_OCR_PARENT)
from llm_ocr.sm_io import load_dataset_from_s3
from llm_ocr.document import render_sample_markdown, display_markdown


def display_samples(dataset, num_samples: int = 2, preview_chars: int = 500,
                    max_figures: int = 2):
    """Print a short preview of the first rows of a pipeline dataset.

    For each shown row this renders the source image (if any), truncated
    previews of the ``document_markdown*`` and ``document_final_markdown*``
    columns, and up to ``max_figures`` of the ``extracted_figures``.

    Args:
        dataset: Indexable dataset supporting ``len()`` and ``column_names``
            whose rows behave like dicts (presumably a ``datasets.Dataset``).
        num_samples: Maximum number of rows to show.
        preview_chars: Markdown longer than this is cut and suffixed with ``...``.
        max_figures: Maximum number of extracted figures rendered per row.
    """
    from IPython.display import display

    def _preview(text: str) -> str:
        # Truncate long documents so a single row doesn't flood the output.
        return text[:preview_chars] + '...' if len(text) > preview_chars else text

    print(f"Dataset: {len(dataset)} samples")
    print(f"Columns: {list(dataset.column_names)}")
    print()

    for i in range(min(num_samples, len(dataset))):
        sample = dataset[i]
        # Fall back to the row index when the dataset has no sample_id column.
        print(f"=== Sample {i}: {sample.get('sample_id', i)} ===")

        if sample.get('source_image'):
            print("Source image:")
            display(sample['source_image'])

        md = sample.get('document_markdown') or sample.get('document_markdown_text', '')
        if md:
            print(f"\nMarkdown preview ({len(md)} chars):")
            print(_preview(md))

        final_md = sample.get('document_final_markdown') or sample.get('document_final_markdown_text', '')
        if final_md:
            print(f"\nFinal markdown preview ({len(final_md)} chars):")
            print(_preview(final_md))

        figures = sample.get('extracted_figures') or []
        if figures:
            print(f"\nExtracted figures: {len(figures)}")
            for fig in figures[:max_figures]:
                display(fig)
        print()


## Stage 1: Extract

Run OCR on the source dataset to extract markdown and figures.
Output is saved to S3 (not HF Hub).

In [None]:
# Stage 1: Extract - OCR the source Hugging Face dataset; results are written to S3
stage1_env = dict(
    # Source dataset (from HuggingFace)
    DATASET_NAME=SOURCE_DATASET,
    DATASET_CONFIG=SOURCE_CONFIG,
    DATASET_SPLIT="train",
    MAX_SAMPLES=str(MAX_SAMPLES),
    # Local output directory inside the job
    OUTPUT_DIR="./outputs",
    # Batching / concurrency for extraction requests
    EXTRACT_BATCH_SIZE="16",
    EXTRACT_MAX_CONCURRENCY="4",
    # S3 output (single location shared by all stages)
    S3_OUTPUT_URI=S3_OUTPUT_URI,
)

stage1_trainer = launch_stage("extract", stage1_env)

In [None]:
# Wait for Stage 1 to complete. Stop here if it did not succeed: the following
# cells and stages read the same S3 dataset location, so continuing would show
# or process stale/missing data from an earlier run.
stage1_result = wait_for_job(stage1_trainer)
print(f"Extract stage completed: {stage1_result['TrainingJobStatus']}")
if stage1_result['TrainingJobStatus'] != 'Completed':
    raise RuntimeError(
        f"Extract stage ended with status {stage1_result['TrainingJobStatus']}; "
        "not continuing with later stages."
    )

In [None]:
# Inspect the dataset as written by the Extract stage
ds_extract = load_dataset_from_s3(S3_OUTPUT_URI + "/dataset")
display_samples(dataset=ds_extract, num_samples=2)

## Stage 2: Describe

Generate captions for extracted figures.
Input is read from S3 (the output of Stage 1), and the dataset is updated in place at the same S3 location.

In [None]:
# Stage 2: Describe - caption the extracted figures.
# Reads and rewrites the dataset at the same S3 location used by Extract.
stage2_env = dict(
    # Local output directory inside the job
    OUTPUT_DIR="./outputs",
    # Batching / concurrency for captioning requests
    DESCRIBE_BATCH_SIZE="8",
    DESCRIBE_MAX_CONCURRENCY="4",
    # S3 input and output point at the same place (in-place update)
    S3_INPUT_URI=f"{S3_OUTPUT_URI}/dataset",
    S3_OUTPUT_URI=S3_OUTPUT_URI,
)

stage2_trainer = launch_stage("describe", stage2_env)

In [None]:
# Wait for Stage 2 to complete. Stop here if it did not succeed: Assemble reads
# the same in-place S3 dataset, so continuing would build on incomplete data.
stage2_result = wait_for_job(stage2_trainer)
print(f"Describe stage completed: {stage2_result['TrainingJobStatus']}")
if stage2_result['TrainingJobStatus'] != 'Completed':
    raise RuntimeError(
        f"Describe stage ended with status {stage2_result['TrainingJobStatus']}; "
        "not continuing with later stages."
    )

In [None]:
# Inspect the dataset after the Describe stage has updated it in place
ds_describe = load_dataset_from_s3(S3_OUTPUT_URI + "/dataset")
display_samples(dataset=ds_describe, num_samples=2)

## Stage 3: Assemble

Enrich markdown with figure captions to create the final dataset.
Input is read from S3 (the output of Stage 2), and the dataset is updated in place at the same S3 location.

In [None]:
# Stage 3: Assemble - merge figure captions into the Markdown.
# Updates the dataset in place and saves the final markdown files.
stage3_env = dict(
    # Local output directory inside the job
    OUTPUT_DIR="./outputs",
    # S3 input and output point at the same place (in-place update)
    S3_INPUT_URI=f"{S3_OUTPUT_URI}/dataset",
    S3_OUTPUT_URI=S3_OUTPUT_URI,
    # Assemble stage doesn't need the vLLM server.
    # NOTE(review): this call still provisions INSTANCE_TYPE (a GPU instance).
    SKIP_SERVER_LAUNCH="true",
)

stage3_trainer = launch_stage("assemble", stage3_env)

In [None]:
# Wait for Stage 3 to complete. Stop here if it did not succeed so the next
# cells don't present a stale dataset from an earlier run as the final result.
stage3_result = wait_for_job(stage3_trainer)
print(f"Assemble stage completed: {stage3_result['TrainingJobStatus']}")
if stage3_result['TrainingJobStatus'] != 'Completed':
    raise RuntimeError(
        f"Assemble stage ended with status {stage3_result['TrainingJobStatus']}; "
        "the dataset in S3 is not the final assembled version."
    )

In [None]:
# Inspect the final dataset produced by the Assemble stage
ds_final = load_dataset_from_s3(S3_OUTPUT_URI + "/dataset")
display_samples(dataset=ds_final, num_samples=2)

## Pipeline Complete!

The OCR pipeline has finished. The cell below prints where the outputs are stored in S3 and the final status of each job:

In [None]:
# Print where the outputs live and how each stage's job ended
banner = "=" * 60
print("\n" + banner)
print("Pipeline Complete!")
print(banner)
print(f"\nS3 Output Location: {S3_OUTPUT_URI}")
print(f"  - Dataset: {S3_OUTPUT_URI}/dataset/")
print(f"  - Files: {S3_OUTPUT_URI}/outputs/")
print(f"\nS3 Job Output: s3://{BUCKET_NAME}/{S3_PREFIX}/output/")
print("\nJob Summary:")
stage_results = (
    ("Extract", stage1_result),
    ("Describe", stage2_result),
    ("Assemble", stage3_result),
)
for idx, (stage_name, stage_result) in enumerate(stage_results, start=1):
    print(f"  {idx}. {stage_name}: {stage_result['TrainingJobStatus']}")

## Load Final Dataset from S3

In [None]:
# Load the final assembled dataset from S3 by downloading its prefix into a local
# temporary directory and opening it with datasets.load_from_disk.
from datasets import load_from_disk

s3 = boto3.client('s3')
local_dataset_dir = Path(tempfile.mkdtemp(prefix="final-dataset-"))

# Split "s3://bucket/some/prefix" into bucket and key prefix. Stripping a
# trailing slash keeps the dataset prefix free of "//".
s3_path = S3_OUTPUT_URI.removeprefix("s3://").rstrip("/")
bucket, _, base_prefix = s3_path.partition("/")
prefix = f"{base_prefix}/dataset/" if base_prefix else "dataset/"

# List and download all files under the dataset prefix
paginator = s3.get_paginator('list_objects_v2')
downloaded = 0

for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
    for obj in page.get('Contents', []):
        key = obj['Key']
        rel_path = key[len(prefix):]
        # Skip the prefix itself and zero-byte "folder" marker objects, which
        # would otherwise be written as files and collide with real directories.
        if not rel_path or key.endswith('/'):
            continue
        local_path = local_dataset_dir / rel_path
        local_path.parent.mkdir(parents=True, exist_ok=True)
        s3.download_file(bucket, key, str(local_path))
        downloaded += 1

if downloaded == 0:
    raise FileNotFoundError(f"No dataset files found under s3://{bucket}/{prefix}")

# Load the dataset
final_dataset = load_from_disk(str(local_dataset_dir))
print(f"Loaded {len(final_dataset)} samples from {S3_OUTPUT_URI}/dataset/")
print(final_dataset)

## Cleanup (Optional)

In [None]:
# Clean up temporary source directory
try:
    shutil.rmtree(source_dir)
    print(f"Cleaned up: {source_dir}")
except Exception as e:
    print(f"Could not clean up: {e}")