# Deploy SmolDocling on AWS SageMaker

This notebook demonstrates how to deploy DS4SD's SmolDocling model on Amazon SageMaker for document understanding and conversion.

## What is SmolDocling?
SmolDocling is a vision-language model designed for document understanding tasks:
- Convert document images to Markdown, HTML, or DocTags
- Extract structured content from PDFs, scans, and images
- Preserve document layout and formatting
- Support for tables, figures, and complex layouts

## Prerequisites
- AWS Account with SageMaker access
- IAM role with SageMaker permissions
- Document images (e.g. PNG) for testing
- Custom inference code in `../smoldocling_code` (packaged and uploaded in step 2)

## 1. Setup and Configuration

In [None]:
import boto3
import sagemaker
from sagemaker.huggingface import HuggingFaceModel
import json
import base64
from pathlib import Path
import tarfile
import os

# Configuration: open the SageMaker session, then read the region from it and
# resolve the execution role used by the rest of the notebook.
sess = sagemaker.Session()
region = sess.boto_region_name
role = sagemaker.get_execution_role()

# Echo the resolved values so a misconfigured environment is visible up front.
print(f"SageMaker role: {role}", f"Region: {region}", sep="\n")

## 2. Prepare Custom Inference Code

In [None]:
# Create model artifact with custom inference code and stage it in S3.
code_dir = "../smoldocling_code"
model_artifact = "model.tar.gz"

# Upload destination inside the session's default bucket.
s3_client = boto3.client("s3")
bucket = sess.default_bucket()
prefix = "smoldocling-model"
s3_key = f"{prefix}/{model_artifact}"
s3_path = f"s3://{bucket}/{s3_key}"

try:
    print(f"Creating model artifact from {code_dir}...")
    # The whole directory is stored under the top-level name "code" in the archive.
    with tarfile.open(model_artifact, "w:gz") as tar:
        tar.add(code_dir, arcname="code")

    print(f"Uploading to {s3_path}...")
    s3_client.upload_file(model_artifact, bucket, s3_key)
finally:
    # Clean up the local artifact even when packaging or the upload fails, so
    # a stale or partial archive is never left behind in the working directory.
    if os.path.exists(model_artifact):
        os.remove(model_artifact)

print("✓ Model artifact uploaded")

## 3. Deploy the Model

In [None]:
# Model configuration, passed to the container as environment variables.
# NOTE(review): the previous value "ds4sd/docling-project__SmolDocling-v1.0"
# does not appear to match any published Hugging Face Hub repository; the
# public SmolDocling checkpoint is "ds4sd/SmolDocling-256M-preview". Confirm
# this is the checkpoint the custom inference code expects.
hub = {
    "HF_MODEL_ID": "ds4sd/SmolDocling-256M-preview",
    "HF_TASK": "image-to-text",
}

# Create Hugging Face Model
# NOTE(review): SmolDocling presumably needs a newer transformers release than
# the 4.37 container provides, unless the custom code pins its own version via
# a requirements file - confirm before deploying.
huggingface_model = HuggingFaceModel(
    model_data=s3_path,  # archive with the custom inference code uploaded above
    role=role,
    transformers_version="4.37",
    pytorch_version="2.1",
    py_version="py310",
    env=hub,
)

print("Model configuration complete")

In [None]:
# Endpoint settings, gathered in one place so they are easy to adjust.
deploy_settings = {
    "initial_instance_count": 1,
    "instance_type": "ml.g5.2xlarge",
    "endpoint_name": "smoldocling-endpoint",
    "container_startup_health_check_timeout": 600,
}

# Provision the endpoint (this takes 5-10 minutes).
predictor = huggingface_model.deploy(**deploy_settings)

print(f"✓ Endpoint deployed: {predictor.endpoint_name}")

## 4. Test the Endpoint

In [None]:
# Helper function to load and encode image
def encode_image(image_path):
    """Read the file at *image_path* and return its bytes as a base64 string."""
    with open(image_path, "rb") as handle:
        encoded = base64.b64encode(handle.read())
    return encoded.decode("utf-8")

# Test with a document image
image_path = "your_document.png"  # Replace with your image path

# Initialize the runtime client unconditionally: later cells reuse `runtime`,
# and creating it only inside the branch below left it undefined whenever the
# sample image was missing.
runtime = boto3.client("sagemaker-runtime")

if Path(image_path).exists():
    image_base64 = encode_image(image_path)

    payload = {
        "image": image_base64,
        "prompt": "Convert this page to docling.",
        "output_format": "markdown",  # Options: markdown, html, doctags
        "max_new_tokens": 8192
    }

    response = runtime.invoke_endpoint(
        EndpointName=predictor.endpoint_name,
        ContentType="application/json",
        Body=json.dumps(payload)
    )

    # The response keys read below ("markdown", "doctags") are presumably
    # produced by the custom inference code - verify against that handler.
    result = json.loads(response["Body"].read().decode())
    print("\n=== Markdown Output ===")
    print(result.get("markdown", ""))
    print("\n=== DocTags ===")
    doctags = result.get("doctags") or ""
    # Only mark the preview as truncated when something was actually cut off.
    print(doctags[:500] + ("..." if len(doctags) > 500 else ""))
else:
    print(f"Image not found: {image_path}")
    print("Please provide a document image to test.")

## 5. Document Processing Use Cases

In [None]:
# Example 1: Convert to HTML
# Build the inputs in this cell and guard on the sample image, so the cell
# does not fail with a NameError when `image_base64` / `runtime` were never
# created because the image was missing in the test cell.
if Path(image_path).exists():
    image_base64 = encode_image(image_path)
    runtime = boto3.client("sagemaker-runtime")

    payload = {
        "image": image_base64,
        "prompt": "Convert this page to docling.",
        "output_format": "html",
        "max_new_tokens": 8192
    }

    response = runtime.invoke_endpoint(
        EndpointName=predictor.endpoint_name,
        ContentType="application/json",
        Body=json.dumps(payload)
    )

    result = json.loads(response["Body"].read().decode())
    print("\n=== HTML Output ===")
    html = result.get("html") or ""
    # Only mark the preview as truncated when something was actually cut off.
    print(html[:500] + ("..." if len(html) > 500 else ""))
else:
    print(f"Image not found: {image_path}")
    print("Please provide a document image to test.")

In [None]:
# Example 2: Extract DocTags only (structured format)
# Build the inputs in this cell and guard on the sample image, so the cell
# does not fail with a NameError when `image_base64` / `runtime` were never
# created because the image was missing in the test cell.
if Path(image_path).exists():
    image_base64 = encode_image(image_path)
    runtime = boto3.client("sagemaker-runtime")

    payload = {
        "image": image_base64,
        "prompt": "Convert this page to docling.",
        "output_format": "doctags",
        "max_new_tokens": 8192
    }

    response = runtime.invoke_endpoint(
        EndpointName=predictor.endpoint_name,
        ContentType="application/json",
        Body=json.dumps(payload)
    )

    result = json.loads(response["Body"].read().decode())
    print("\n=== DocTags Output ===")
    print(result.get("doctags", ""))
else:
    print(f"Image not found: {image_path}")
    print("Please provide a document image to test.")

## 6. Batch Processing Example

In [None]:
# Process multiple documents
document_dir = Path("documents")  # Directory with document images

if document_dir.exists():
    # Create the client here rather than relying on the test cell, which only
    # defines `runtime` when the sample image exists.
    runtime = boto3.client("sagemaker-runtime")
    results = []

    # sorted() gives a stable processing order; glob order depends on the
    # filesystem.
    for image_file in sorted(document_dir.glob("*.png")):
        print(f"Processing {image_file.name}...")

        image_base64 = encode_image(str(image_file))
        payload = {
            "image": image_base64,
            "prompt": "Convert this page to docling.",
            "output_format": "markdown",
            "max_new_tokens": 8192
        }

        response = runtime.invoke_endpoint(
            EndpointName=predictor.endpoint_name,
            ContentType="application/json",
            Body=json.dumps(payload)
        )

        # Keep one record per file with both representations from the response.
        result = json.loads(response["Body"].read().decode())
        results.append({
            "filename": image_file.name,
            "markdown": result.get("markdown", ""),
            "doctags": result.get("doctags", "")
        })

    print(f"\n✓ Processed {len(results)} documents")
else:
    print(f"Directory not found: {document_dir}")

## 7. Performance Monitoring

In [None]:
# Get endpoint metrics from CloudWatch
import datetime

cloudwatch = boto3.client('cloudwatch')

# Query the last hour using timezone-aware UTC timestamps
# (datetime.utcnow() is deprecated since Python 3.12).
end_time = datetime.datetime.now(datetime.timezone.utc)
start_time = end_time - datetime.timedelta(hours=1)

metrics = cloudwatch.get_metric_statistics(
    Namespace='AWS/SageMaker',
    MetricName='ModelLatency',
    Dimensions=[
        {'Name': 'EndpointName', 'Value': predictor.endpoint_name},
        {'Name': 'VariantName', 'Value': 'AllTraffic'}
    ],
    StartTime=start_time,
    EndTime=end_time,
    Period=300,  # 5-minute buckets
    Statistics=['Average', 'Maximum']
)

# CloudWatch does not guarantee chronological order, so sort before printing.
datapoints = sorted(metrics['Datapoints'], key=lambda dp: dp['Timestamp'])

print("Model Latency Metrics:")
if not datapoints:
    print("  No datapoints recorded in the last hour.")
for datapoint in datapoints:
    # ModelLatency is reported in microseconds; convert to milliseconds so the
    # printed unit is correct.
    average_ms = datapoint.get('Average', 0) / 1000
    maximum_ms = datapoint.get('Maximum', 0) / 1000
    print(f"  Time: {datapoint['Timestamp']}")
    print(f"  Average: {average_ms:.2f}ms")
    print(f"  Maximum: {maximum_ms:.2f}ms")
    print()

## 8. Cleanup (Optional)

In [None]:
# Delete the endpoint to avoid charges
# predictor.delete_endpoint()
# print("Endpoint deleted")