# Draw-and-Learn Custom Model Training with SageMaker

This notebook demonstrates how to train a custom drawing recognition model using Amazon SageMaker.

## Overview
- Load and preprocess drawing data
- Set up a SageMaker training job
- Train a custom model
- Deploy the model for inference

## Setup and Imports

In [None]:
import boto3
import sagemaker
from sagemaker.pytorch import PyTorch
from sagemaker.processing import ProcessingInput, ProcessingOutput
from sagemaker.sklearn.processing import SKLearnProcessor
import os
import json
import pandas as pd
import matplotlib.pyplot as plt

# Import custom utilities
import sys
sys.path.append('../src')
from sagemaker_utils import SageMakerHelper

## Initialize SageMaker Session

In [None]:
# Initialize SageMaker session
sagemaker_session = sagemaker.Session()
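# get_execution_role() resolves the role inside SageMaker notebooks and Studio;
# when running elsewhere, set `role` to an IAM role ARN with SageMaker permissions
role = sagemaker.get_execution_role()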
bucket = sagemaker_session.default_bucket()
region = sagemaker_session.boto_region_name

print(f"SageMaker role: {role}")
print(f"S3 bucket: {bucket}")
print(f"AWS region: {region}")

# Initialize helper
helper = SageMakerHelper('../config/sagemaker_config.json')

## Data Preparation

In [None]:
# Upload your data to S3 (replace with your data path)
local_data_path = '../data/raw'

if os.path.exists(local_data_path):
    input_data_uri = helper.upload_data_to_s3(local_data_path, 'data/raw')
    print(f"Data uploaded to: {input_data_uri}")
else:
    print("Please add your training data to the ../data/raw directory")
    input_data_uri = None
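
Before uploading, it can help to confirm what is actually in the local directory. The sketch below counts image files per class. It assumes, and the notebook does not state this, that `../data/raw` holds one subdirectory per class; the function name `count_images_per_class` and the extension list are hypothetical.

```python
from pathlib import Path

# Assumed image extensions; adjust to match your dataset.
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg"}


def count_images_per_class(data_dir):
    """Return {class_name: image_count} for each subdirectory of data_dir.

    Assumes one subdirectory per class (hypothetical layout).
    Returns an empty dict if data_dir does not exist.
    """
    counts = {}
    root = Path(data_dir)
    if not root.is_dir():
        return counts
    for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1
            for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
        )
    return counts
```

Calling `count_images_per_class('../data/raw')` and printing the result gives a quick view of class balance, and an empty result signals that the directory is missing or has no class folders.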

## Data Preprocessing with SageMaker Processing

In [None]:
# Set up SageMaker processing job
if input_data_uri:
    processor = SKLearnProcessor(
        framework_version="0.23-1",
        instance_type="ml.m5.large",
        instance_count=1,
        role=role
    )
    
    # Run preprocessing
    processor.run(
        code="../scripts/processing/preprocess.py",
        inputs=[
            ProcessingInput(source=input_data_uri, destination="/opt/ml/processing/input")
        ],
        outputs=[
            ProcessingOutput(output_name="processed_data", source="/opt/ml/processing/output")
        ],
        arguments=["--image-size", "224", "224"]
    )
    
    # Get processed data location
    preprocessing_job_description = processor.jobs[-1].describe()
    processed_data_uri = preprocessing_job_description['ProcessingOutputConfig']['Outputs'][0]['S3Output']['S3Uri']
    print(f"Processed data available at: {processed_data_uri}")
else:
    print("Skipping preprocessing - no input data available")
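
The `arguments` list is passed to the processing script as command-line arguments, so `preprocess.py` has to accept `--image-size` followed by two values. The notebook does not show that script; the following is a minimal sketch of how its argument parsing might look. The function `parse_args`, the `--input-dir` and `--output-dir` options, and the height-then-width order are assumptions; the default paths mirror the container paths used in `ProcessingInput` and `ProcessingOutput`.

```python
import argparse


def parse_args(argv=None):
    """Parse the arguments a processing script like preprocess.py might accept."""
    parser = argparse.ArgumentParser(description="Preprocess drawing images")
    # Two integers, matching arguments=["--image-size", "224", "224"].
    # The (height, width) order is an assumption.
    parser.add_argument(
        "--image-size",
        type=int,
        nargs=2,
        default=[224, 224],
        metavar=("HEIGHT", "WIDTH"),
    )
    # Hypothetical options; defaults are the container paths from the processing job.
    parser.add_argument("--input-dir", default="/opt/ml/processing/input")
    parser.add_argument("--output-dir", default="/opt/ml/processing/output")
    return parser.parse_args(argv)


args = parse_args(["--image-size", "224", "224"])
print(args.image_size, args.input_dir, args.output_dir)
```

Because the values arrive as strings, `type=int` with `nargs=2` is what turns `"224", "224"` into a list of two integers.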

## Model Training

In [None]:
# Set up training job
if 'processed_data_uri' in locals():
    # Training hyperparameters
    hyperparameters = {
        'epochs': 10,
        'batch-size': 32,
        'learning-rate': 0.001,
        'num-classes': 10  # Adjust based on your dataset
    }
    
    # Create PyTorch estimator
    estimator = PyTorch(
        entry_point='train.py',
        source_dir='../scripts/training',
        role=role,
        instance_type='ml.m5.xlarge',
        instance_count=1,
        framework_version='1.12',
        py_version='py38',
        hyperparameters=hyperparameters,
        output_path=f's3://{bucket}/draw-learn-model/output/',
        code_location=f's3://{bucket}/draw-learn-model/code/'
    )
    
    # Start training
    train_data_uri = f"{processed_data_uri}/train"
    val_data_uri = f"{processed_data_uri}/validation"
    
    estimator.fit({
        'training': train_data_uri,
        'validation': val_data_uri
    })
    
    print("Training completed!")
    print(f"Model artifacts: {estimator.model_data}")
else:
    print("Skipping training - no processed data available")
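
On the script side, SageMaker passes each entry of `hyperparameters` to `train.py` as a command-line argument (for example `--batch-size 32`) and exposes each input channel through an environment variable such as `SM_CHANNEL_TRAINING`. The notebook does not show `train.py`, so this is only a sketch of how it might read those values; the function name, the `--train-dir` and `--val-dir` options, and the local fallback paths are hypothetical.

```python
import argparse
import os


def parse_training_args(argv=None):
    """Parse hyperparameters and SageMaker-provided paths for a training script."""
    parser = argparse.ArgumentParser()
    # Hyperparameters: option names match the keys of the `hyperparameters` dict.
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--learning-rate", type=float, default=0.001)
    parser.add_argument("--num-classes", type=int, default=10)
    # Paths: SageMaker sets these environment variables in the training container.
    # The fallbacks are hypothetical local directories for runs outside SageMaker.
    parser.add_argument(
        "--model-dir", default=os.environ.get("SM_MODEL_DIR", "./model")
    )
    parser.add_argument(
        "--train-dir", default=os.environ.get("SM_CHANNEL_TRAINING", "./data/train")
    )
    parser.add_argument(
        "--val-dir",
        default=os.environ.get("SM_CHANNEL_VALIDATION", "./data/validation"),
    )
    return parser.parse_args(argv)


args = parse_training_args(
    ["--epochs", "10", "--batch-size", "32", "--learning-rate", "0.001", "--num-classes", "10"]
)
print(args.epochs, args.batch_size, args.learning_rate, args.num_classes)
```

The channel names `training` and `validation` passed to `estimator.fit` are what determine the `SM_CHANNEL_TRAINING` and `SM_CHANNEL_VALIDATION` variable names.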

## Model Deployment

In [None]:
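# Deploy the trained model. Endpoint names must be unique within an account and
# region, so re-running this cell fails while 'draw-learn-endpoint' still exists.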
if 'estimator' in locals():
    predictor = estimator.deploy(
        initial_instance_count=1,
        instance_type='ml.m5.large',
        endpoint_name='draw-learn-endpoint'
    )
    
    print(f"Model deployed to endpoint: {predictor.endpoint_name}")
else:
    print("Skipping deployment - no trained model available")
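
A fixed endpoint name makes the deployment cell fail on a second run if the first endpoint still exists. One possible workaround, shown here as a sketch rather than as part of the original workflow, is to append a timestamp to the name. The helper `unique_endpoint_name` is hypothetical; the 63-character limit is the SageMaker limit on endpoint names.

```python
import time

MAX_ENDPOINT_NAME_LENGTH = 63  # SageMaker limit on endpoint name length


def unique_endpoint_name(prefix="draw-learn-endpoint"):
    """Append a UTC timestamp so repeated deployments do not collide."""
    suffix = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())
    # Trim the prefix if needed so the full name stays within the limit.
    trimmed = prefix[: MAX_ENDPOINT_NAME_LENGTH - len(suffix) - 1].rstrip("-")
    return f"{trimmed}-{suffix}"


name = unique_endpoint_name()
print(name)
```

The result could be passed as `endpoint_name=` to `estimator.deploy(...)`. Two calls within the same second would still produce the same name, and each deployment leaves a separate endpoint that has to be deleted.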

## Test Inference

In [None]:
# Test the deployed model (example)
if 'predictor' in locals():
    import numpy as np
    
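    # Create dummy test data (replace with a real, preprocessed test image)
    test_data = np.random.rand(1, 3, 224, 224).astype(np.float32)
    
    # Make prediction. The PyTorch predictor serializes NumPy arrays by default,
    # so pass the array directly; this assumes the default inference handlers.
    result = predictor.predict(test_data)
    print(f"Prediction result: {result}")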
else:
    print("Skipping inference test - no deployed model available")
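
The raw prediction is usually easier to read as a class label with a probability. The sketch below assumes the endpoint returns one row of raw class scores (logits) per image, which depends on training and inference code not shown in this notebook. The helper names, the example scores, and the class names are made up for illustration.

```python
import math


def softmax(logits):
    """Convert raw scores to probabilities (numerically stable)."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_prediction(logits, class_names=None):
    """Return (label, probability) for the highest-scoring class.

    If class_names is omitted, the label is the class index.
    """
    probs = softmax(logits)
    index = max(range(len(probs)), key=probs.__getitem__)
    label = class_names[index] if class_names else index
    return label, probs[index]


# Made-up scores for a single image, for illustration only.
example_logits = [0.1, 2.0, -1.0]
print(top_prediction(example_logits, ["circle", "square", "triangle"]))
```

With a real response, each row of `result` would be converted to a plain list (for example with `.tolist()`) before being passed to `top_prediction`.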

## Cleanup

In [None]:
# Clean up resources. A running endpoint is billed until it is deleted
# (uncomment when ready to delete).
# if 'predictor' in locals():
#     predictor.delete_endpoint()
#     print("Endpoint deleted")

## Next Steps

1. **Data Collection**: Add your drawing/image data to the `data/raw` directory
2. **Model Customization**: Modify the model architecture in `src/models/custom_model.py`
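3. **Hyperparameter Tuning**: Use SageMaker automatic model tuning to search over values such as the learning rate and batch size
4. **Model Monitoring**: Track the endpoint's invocation, latency, and error metrics in Amazon CloudWatch
5. **Batch Transform**: Use SageMaker Batch Transform to run predictions over a whole dataset without keeping an endpoint running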

## Additional Resources

- [SageMaker Developer Guide](https://docs.aws.amazon.com/sagemaker/latest/dg/)
- [SageMaker Python SDK](https://sagemaker.readthedocs.io/)
- [PyTorch on SageMaker](https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/index.html)