In [None]:
!pip install -U sagemaker

In [None]:
import sys

import boto3
import sagemaker
from sagemaker.workflow.pipeline_context import PipelineSession

# Two sessions: a regular one for direct calls made from this notebook (bucket
# lookup, S3 upload/download) and a PipelineSession that is handed to every
# processor / estimator / model defined further down.
sagemaker_session = sagemaker.session.Session()
pipeline_session = PipelineSession()

# Values derived from the execution environment.
role = sagemaker.get_execution_role()
region = sagemaker_session.boto_region_name
default_bucket = sagemaker_session.default_bucket()

# Model package group that the register step publishes into.
model_package_group_name = "Cifar10ModelPackageGroupName"

In [None]:
# Local working directory that the next cell downloads the CIFAR-10 archive into.
!mkdir -p data

In [None]:
# Download and upload CIFAR-10 data to S3
import os
import urllib.request
import tarfile
from pathlib import Path

# Every S3 object written by this notebook lives under this prefix.
base_uri = f"s3://{default_bucket}/cifar10"

# Ensure data directory exists
data_dir = Path("data")
data_dir.mkdir(exist_ok=True)

# Download CIFAR-10 if not present locally
cifar_local_path = data_dir / "cifar-10-batches-py"
if not cifar_local_path.exists():
    print("Downloading CIFAR-10 dataset...")
    cifar_url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    tar_path = data_dir / "cifar-10-python.tar.gz"

    try:
        # Download the tar file
        urllib.request.urlretrieve(cifar_url, tar_path)
        print("Downloaded CIFAR-10 tar file")

        # Extract the tar file. The archive comes from the network, so use the
        # "data" extraction filter (rejects absolute paths, path traversal and
        # special files) whenever this Python version provides it.
        with tarfile.open(tar_path, "r:gz") as tar:
            if hasattr(tarfile, "data_filter"):
                tar.extractall(data_dir, filter="data")
            else:
                tar.extractall(data_dir)
        print("Extracted CIFAR-10 data")
    finally:
        # Clean up the tar file, including a partial one left by a failed download.
        if tar_path.exists():
            tar_path.unlink()

# Verify the data exists and upload to S3
if cifar_local_path.exists():
    print("Uploading CIFAR-10 data to S3...")
    input_data_uri = sagemaker.s3.S3Uploader.upload(
        local_path=str(cifar_local_path),
        desired_s3_uri=base_uri,
    )
    print(f"CIFAR-10 data uploaded to: {input_data_uri}")
else:
    raise FileNotFoundError(
        "CIFAR-10 data not found even after download attempt. Please check your internet connection.")

In [None]:
# S3 prefix reserved for batch-inference input. It is only a placeholder: no cell
# shown in this notebook uploads anything under it.
batch_data_uri = base_uri + "/batch-inference-data"
print("Batch data URI placeholder: " + batch_data_uri)

In [None]:
from sagemaker.workflow.parameters import (
    ParameterInteger,
    ParameterString,
    ParameterFloat,
)

# --- Compute sizing ---------------------------------------------------------
processing_instance_count = ParameterInteger(
    name="ProcessingInstanceCount",
    default_value=1,
)
instance_type = ParameterString(
    name="TrainingInstanceType",
    default_value="ml.m5.xlarge",
)

# --- Model registry ---------------------------------------------------------
model_approval_status = ParameterString(
    name="ModelApprovalStatus",
    default_value="PendingManualApproval",
)

# --- Data locations ---------------------------------------------------------
input_data = ParameterString(name="InputData", default_value=input_data_uri)
batch_data = ParameterString(name="BatchData", default_value=batch_data_uri)

# --- Experiment knobs -------------------------------------------------------
# Fraction of the combined dataset that the preprocessing script keeps
# (passed to it as --dataset-percentage).
dataset_percentage = ParameterFloat(
    name="DatasetPercentage",
    default_value=0.5,
)
# Minimum test-set accuracy the condition step requires before the model is
# registered; compared against the accuracy value in evaluation.json.
accuracy_threshold = ParameterFloat(
    name="AccuracyThreshold",
    default_value=0.7,
)

In [None]:
# Directory that the %%writefile cells below write the job scripts into.
!mkdir -p code

In [None]:
%%writefile code/preprocessing.py
import argparse
import os
import pickle
import numpy as np
import random
from pathlib import Path

# CIFAR-10 class names.
# NOTE(review): presumably list index == integer label id in the batch files —
# confirm against the dataset description. Nothing else in this script reads
# this constant.
CIFAR10_CLASSES = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck'
]


def unpickle(file):
    """Deserialize one CIFAR-10 batch file and return the object stored in it.

    ``encoding='bytes'`` is passed to ``pickle.load``, which is why callers in
    this script index the result with ``bytes`` keys such as ``b'data'``.

    NOTE: ``pickle`` can execute arbitrary code; only use this on trusted files.
    """
    with open(file, 'rb') as handle:
        return pickle.load(handle, encoding='bytes')


def load_cifar10_data(data_dir):
    """Read the five training batches and the test batch found in ``data_dir``.

    Expects files named ``data_batch_1`` .. ``data_batch_5`` and ``test_batch``,
    each a pickle holding ``b'data'`` and ``b'labels'`` entries.

    Returns:
        tuple: ``(train_data, train_labels, test_data, test_labels)`` where the
        training rows are stacked in batch order and both label sequences are
        returned as numpy arrays.
    """
    batches = [
        unpickle(os.path.join(data_dir, f'data_batch_{i}'))
        for i in range(1, 6)
    ]

    # Stack the per-batch row blocks and flatten the per-batch label lists.
    train_data = np.vstack([batch[b'data'] for batch in batches])
    train_labels = np.array(
        [label for batch in batches for label in batch[b'labels']]
    )

    held_out = unpickle(os.path.join(data_dir, 'test_batch'))

    return train_data, train_labels, held_out[b'data'], np.array(held_out[b'labels'])


def process_cifar10_data(images, labels, dataset_percentage=1.0):
    """Reshape flat image rows to height-width-channel layout and optionally subsample.

    Args:
        images: array whose rows each hold 3 * 32 * 32 values in channel-major order.
        labels: array of labels, indexable with the same row indices as ``images``.
        dataset_percentage: fraction of rows to keep. Values below 1.0 trigger a
            random draw without replacement (at least one row is always kept)
            using numpy's global random state; anything else keeps every row.

    Returns:
        tuple: ``(images, labels)`` with images shaped ``(N, 32, 32, 3)``.
    """
    # (N, 3072) -> (N, 3, 32, 32) -> (N, 32, 32, 3)
    images = images.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)

    if not dataset_percentage < 1.0:
        return images, labels

    total = len(images)
    keep = max(1, int(total * dataset_percentage))
    chosen = np.random.choice(total, keep, replace=False)
    images, labels = images[chosen], labels[chosen]
    print(f"📊 Sampled {len(images)} images ({dataset_percentage * 100:.1f}%) from {total} total")

    return images, labels


def save_data_splits(images, labels, base_dir, split_ratios=(0.7, 0.15, 0.15)):
    """Cut the data into contiguous train/validation/test parts and save each one.

    The first ``split_ratios[0]`` share of rows becomes ``train``, the next
    ``split_ratios[1]`` share becomes ``validation`` and whatever remains
    becomes ``test`` (the third ratio is never read). No shuffling happens
    here. Each part is written to ``<base_dir>/<split>/images.npy`` and
    ``<base_dir>/<split>/labels.npy``; directories are created as needed.
    """
    total = len(images)
    train_stop = int(total * split_ratios[0])
    val_stop = train_stop + int(total * split_ratios[1])

    windows = {
        'train': slice(None, train_stop),
        'validation': slice(train_stop, val_stop),
        'test': slice(val_stop, None),
    }

    for split_name, window in windows.items():
        split_dir = os.path.join(base_dir, split_name)
        os.makedirs(split_dir, exist_ok=True)

        split_images = images[window]
        np.save(os.path.join(split_dir, 'images.npy'), split_images)
        np.save(os.path.join(split_dir, 'labels.npy'), labels[window])

        print(f"Saved {split_name}: {len(split_images)} samples")


if __name__ == "__main__":
    # Entry point of the processing job: load the raw CIFAR-10 batches, optionally
    # subsample, shuffle, and write train/validation/test splits.
    parser = argparse.ArgumentParser()
    # argparse %-formats help strings, so a literal percent sign must be written
    # as "%%"; a bare "%" makes help rendering fail.
    parser.add_argument("--dataset-percentage", type=float, default=1.0,
                        help="Percentage of dataset to use (0.1 = 10%%, 1.0 = 100%%)")
    args = parser.parse_args()

    # Raw batches are read from <base_dir>/input; the split folders are written
    # directly under <base_dir>.
    base_dir = "/opt/ml/processing"
    input_dir = os.path.join(base_dir, "input")

    # Set random seed for reproducibility
    np.random.seed(42)
    random.seed(42)

    print(f"Processing CIFAR-10 data with {args.dataset_percentage * 100:.1f}% of dataset...")

    # Load CIFAR-10 data
    train_data, train_labels, test_data, test_labels = load_cifar10_data(input_dir)

    # Pool the loaded train and test batches; the splits are re-cut below.
    all_images = np.vstack([train_data, test_data])
    all_labels = np.hstack([train_labels, test_labels])

    print(f"Loaded {len(all_images)} total CIFAR-10 images")

    # Process data with dataset percentage
    images, labels = process_cifar10_data(all_images, all_labels, args.dataset_percentage)

    # Shuffle data (save_data_splits slices contiguously and does not shuffle)
    indices = np.random.permutation(len(images))
    images = images[indices]
    labels = labels[indices]

    # Save processed splits
    save_data_splits(images, labels, base_dir)

    print("✅ CIFAR-10 preprocessing completed!")

In [None]:
from sagemaker.sklearn.processing import SKLearnProcessor

# Version string passed to SKLearnProcessor for the preprocessing job.
framework_version = "1.2-1"

sklearn_processor = SKLearnProcessor(
    framework_version=framework_version,
    instance_type="ml.m5.xlarge",
    instance_count=processing_instance_count,
    # Job-name prefix. Previously "sklearn-abalone-process", a leftover from a
    # different example; renamed to match this pipeline's CIFAR-10 naming.
    base_job_name="sklearn-cifar10-process",
    role=role,
    sagemaker_session=pipeline_session,
)

In [None]:
from sagemaker.processing import ProcessingInput, ProcessingOutput
from sagemaker.workflow.steps import ProcessingStep

from sagemaker.workflow.functions import Join

# Preprocessing job: reads the raw CIFAR-10 prefix from /opt/ml/processing/input
# and publishes one output per split folder written by code/preprocessing.py.
processor_args = sklearn_processor.run(
    code="code/preprocessing.py",
    # NOTE(review): Join(on="", ...) presumably renders the float pipeline
    # parameter as the string value of the CLI flag — confirm.
    arguments=["--dataset-percentage", Join(on="", values=[dataset_percentage])],
    inputs=[
        ProcessingInput(source=input_data, destination="/opt/ml/processing/input"),
    ],
    outputs=[
        ProcessingOutput(output_name=split, source=f"/opt/ml/processing/{split}")
        for split in ("train", "validation", "test")
    ],
)

step_process = ProcessingStep(
    name="Cifar10Process",
    step_args=processor_args,
)

In [None]:
from sagemaker.pytorch import PyTorch
from sagemaker.inputs import TrainingInput

# S3 prefix handed to the estimator as its output location.
model_path = f"s3://{default_bucket}/Cifar10Train"

# PyTorch training job for the image classifier.
# NOTE(review): entry_point "train.py" in source_dir "code" is only written by the
# %%writefile cell that comes after this one; on a fresh kernel the .fit(...) call
# below may not find the script — confirm the intended cell order.
pytorch_estimator = PyTorch(
    entry_point="train.py",
    source_dir="code",
    framework_version="2.7",
    py_version="py312",
    role=role,
    instance_type=instance_type,
    instance_count=1,
    output_path=model_path,
    hyperparameters={
        "epochs": 5,
        "batch-size": 64,
        "learning-rate": 0.001,
        "dataset-percentage": Join(on="", values=[dataset_percentage]),
    },
    sagemaker_session=pipeline_session,
)

# One input channel per split produced by the processing step.
train_args = pytorch_estimator.fit(
    inputs={
        channel: TrainingInput(
            s3_data=step_process.properties.ProcessingOutputConfig.Outputs[channel].S3Output.S3Uri,
            content_type="application/x-npy",
        )
        for channel in ("train", "validation")
    }
)

In [None]:
%%writefile code/train.py
import argparse
import json
import logging
import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ImprovedCNN(nn.Module):
    """CIFAR-10 classifier: four conv stages, global average pooling, small head.

    Each stage is two 3x3 convolutions (padding 1), each followed by batch
    norm and ReLU. Stages 1-3 end with a 2x2 max-pool; stage 4 feeds an
    adaptive average pool that collapses the spatial dimensions to 1x1. The
    head is Linear(512, 256) -> BatchNorm1d -> ReLU -> Dropout ->
    Linear(256, num_classes).

    ``forward`` returns raw logits (no softmax), as expected by
    ``nn.CrossEntropyLoss``.

    Attribute names and their creation order determine the ``state_dict``
    layout, so keep them unchanged if saved checkpoints must stay loadable.
    """

    def __init__(self, num_classes=10, dropout_rate=0.3):
        super().__init__()

        # Stage 1: 3 -> 64 channels
        self.conv1a = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1a = nn.BatchNorm2d(64)
        self.conv1b = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn1b = nn.BatchNorm2d(64)

        # Stage 2: 64 -> 128 channels
        self.conv2a = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2a = nn.BatchNorm2d(128)
        self.conv2b = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn2b = nn.BatchNorm2d(128)

        # Stage 3: 128 -> 256 channels
        self.conv3a = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3a = nn.BatchNorm2d(256)
        self.conv3b = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn3b = nn.BatchNorm2d(256)

        # Stage 4: 256 -> 512 channels (not followed by pooling in forward)
        self.conv4a = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.bn4a = nn.BatchNorm2d(512)
        self.conv4b = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn4b = nn.BatchNorm2d(512)

        # Collapse each 512-channel feature map to a single value.
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)

        # Classification head
        self.fc1 = nn.Linear(512, 256)
        self.bn_fc1 = nn.BatchNorm1d(256)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(256, num_classes)

        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialise parameters layer by layer.

        Conv2d: Kaiming-normal weights (fan_out, ReLU), zero bias.
        BatchNorm2d: weight 1, bias 0.
        Linear: weights drawn from N(0, 0.01), zero bias.
        Other module types (e.g. BatchNorm1d) keep PyTorch's defaults.
        """
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.BatchNorm2d):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, 0, 0.01)
                nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Map a batch of channel-first images to per-class logits."""
        stages = (
            (self.conv1a, self.bn1a, self.conv1b, self.bn1b),
            (self.conv2a, self.bn2a, self.conv2b, self.bn2b),
            (self.conv3a, self.bn3a, self.conv3b, self.bn3b),
            (self.conv4a, self.bn4a, self.conv4b, self.bn4b),
        )
        final_stage = len(stages) - 1

        for idx, (conv_a, bn_a, conv_b, bn_b) in enumerate(stages):
            x = F.relu(bn_a(conv_a(x)))
            x = F.relu(bn_b(conv_b(x)))
            if idx != final_stage:
                # Halve the spatial size after stages 1-3 only.
                x = F.max_pool2d(x, 2)

        # Global average pooling, then flatten everything except the batch dim.
        x = torch.flatten(self.global_avg_pool(x), 1)

        # Classification head
        x = self.dropout(F.relu(self.bn_fc1(self.fc1(x))))
        return self.fc2(x)  # logits (use CrossEntropyLoss)


def load_data(data_dir):
    """Load the saved train and validation arrays from under ``data_dir``.

    Reads ``<data_dir>/<split>/images.npy`` and ``<data_dir>/<split>/labels.npy``
    for the ``train`` and ``validation`` splits.

    Returns:
        tuple: ``(train_images, train_labels, val_images, val_labels)``.
    """
    arrays = []
    for split in ("train", "validation"):
        for stem in ("images", "labels"):
            arrays.append(np.load(os.path.join(data_dir, split, stem + ".npy")))

    train_images, train_labels, val_images, val_labels = arrays
    return train_images, train_labels, val_images, val_labels


def create_data_loaders(train_images, train_labels, val_images, val_labels, batch_size):
    """Wrap the image/label arrays in ``DataLoader`` objects.

    Images are converted to float32, divided by 255 and permuted from
    channel-last ``(N, H, W, C)`` to channel-first ``(N, C, H, W)``; labels
    become int64 tensors. Only the training loader shuffles.

    Returns:
        tuple: ``(train_loader, val_loader)``.
    """

    def _to_dataset(images, labels):
        pixels = torch.tensor(images, dtype=torch.float32) / 255.0
        targets = torch.tensor(labels, dtype=torch.long)
        # NHWC -> NCHW, the layout the conv layers consume.
        return TensorDataset(pixels.permute(0, 3, 1, 2), targets)

    train_loader = DataLoader(
        _to_dataset(train_images, train_labels), batch_size=batch_size, shuffle=True
    )
    val_loader = DataLoader(
        _to_dataset(val_images, val_labels), batch_size=batch_size, shuffle=False
    )

    return train_loader, val_loader


def train_epoch(model, device, train_loader, optimizer, scheduler, epoch, criterion=None, grad_clip=None):
    """
    Run one full pass over ``train_loader`` and update the model in place.

    Args:
        model: Network to train; switched to training mode here.
        device: Device that each batch is moved to.
        train_loader: DataLoader yielding ``(data, target)`` batches.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once after the last batch;
            a falsy value disables scheduling.
        epoch: Epoch number, used only in log messages.
        criterion: Loss function; ``nn.CrossEntropyLoss()`` when ``None``.
        grad_clip: If not ``None``, maximum gradient norm applied before each
            optimizer step.

    Returns:
        tuple: ``(average_loss, accuracy)`` — mean of the per-batch losses and
        the percentage (0-100) of correctly classified training samples.
    """
    model.train()

    loss_fn = nn.CrossEntropyLoss() if criterion is None else criterion

    def _learning_rate():
        # Prefer the scheduler's view of the LR; fall back to the optimizer's.
        if scheduler:
            return scheduler.get_last_lr()[0]
        return optimizer.param_groups[0]['lr']

    n_batches = len(train_loader)
    n_samples = len(train_loader.dataset)
    running_loss = 0
    hits = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device)
        target = target.to(device)

        # Standard step: clear grads, forward, loss, backward.
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()

        # Optional gradient-norm clipping before the parameter update.
        if grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

        optimizer.step()

        # Running statistics for the epoch summary.
        running_loss += loss.item()
        predicted = output.argmax(dim=1, keepdim=True)
        hits += predicted.eq(target.view_as(predicted)).sum().item()

        # Progress line every 100 batches (including the first).
        if batch_idx % 100 == 0:
            current_lr = _learning_rate()
            logger.info(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{n_samples} '
                        f'({100. * batch_idx / n_batches:.0f}%)]\t'
                        f'Loss: {loss.item():.6f}\tLR: {current_lr:.6f}')

    # The schedule advances once per epoch, after all batches.
    if scheduler:
        scheduler.step()

    avg_loss = running_loss / n_batches
    accuracy = 100. * hits / n_samples

    current_lr = _learning_rate()
    logger.info(f'Train Epoch: {epoch} Complete - '
                f'Average Loss: {avg_loss:.6f}, '
                f'Accuracy: {accuracy:.2f}%, '
                f'Learning Rate: {current_lr:.6f}')

    return avg_loss, accuracy


def validate(model, device, val_loader, criterion=None):
    """
    Measure loss and accuracy on ``val_loader`` without updating the model.

    Args:
        model: Network to evaluate; switched to eval mode here.
        device: Device that each batch is moved to.
        val_loader: DataLoader yielding ``(data, target)`` batches.
        criterion: Loss function; ``nn.CrossEntropyLoss()`` when ``None``.

    Returns:
        tuple: ``(average_loss, accuracy)`` — mean of the per-batch losses and
        the percentage (0-100) of correctly classified samples.
    """
    model.eval()

    loss_fn = nn.CrossEntropyLoss() if criterion is None else criterion

    running_loss = 0
    hits = 0

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for data, target in val_loader:
            data = data.to(device)
            target = target.to(device)
            output = model(data)

            running_loss += loss_fn(output, target).item()

            predicted = output.argmax(dim=1, keepdim=True)
            hits += predicted.eq(target.view_as(predicted)).sum().item()

    avg_loss = running_loss / len(val_loader)
    accuracy = 100. * hits / len(val_loader.dataset)

    logger.info(f'Validation - Average Loss: {avg_loss:.6f}, Accuracy: {accuracy:.2f}%')

    return avg_loss, accuracy


def save_model(model, model_dir):
    """Write the model's ``state_dict`` to ``<model_dir>/model.pth``."""
    logger.info("Saving the model.")
    torch.save(model.state_dict(), os.path.join(model_dir, "model.pth"))


def model_fn(model_dir):
    """Rebuild the network from ``<model_dir>/model.pth`` for inference.

    The weights are mapped onto the CPU. The model is returned in eval mode so
    that dropout is disabled and batch norm uses its running statistics; a
    freshly constructed ``nn.Module`` is otherwise in training mode.

    Args:
        model_dir: Directory containing the ``model.pth`` written by ``save_model``.

    Returns:
        ImprovedCNN: The network with the saved weights loaded, in eval mode.
    """
    model = ImprovedCNN()
    model_path = os.path.join(model_dir, "model.pth")
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    model.eval()
    return model


if __name__ == "__main__":
    # Entry point of the training job: parse arguments, load the preprocessed
    # splits, train, and keep the checkpoint with the best validation accuracy.
    parser = argparse.ArgumentParser()

    # SageMaker specific arguments (defaults come from the SM_* environment variables)
    parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR"))
    parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN"))
    parser.add_argument("--validation", type=str, default=os.environ.get("SM_CHANNEL_VALIDATION"))

    # Hyperparameters
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--learning-rate", type=float, default=0.001)
    # Accepted so the hyperparameter passed by the pipeline parses; not read below.
    parser.add_argument("--dataset-percentage", type=float, default=1.0)

    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # Load data
    logger.info("Loading data...")
    # load_data() expects the directory that CONTAINS the "train" and
    # "validation" folders. Strip only a trailing "train" path component; a
    # plain str.replace("/train", "") would also mangle any other "/train..."
    # segment earlier in the path.
    if args.train:
        train_dir = os.path.normpath(args.train)
        if os.path.basename(train_dir) == "train":
            data_dir = os.path.dirname(train_dir) or "."
        else:
            data_dir = train_dir
    else:
        data_dir = "."
    train_images, train_labels, val_images, val_labels = load_data(data_dir)

    logger.info(f"Train set: {len(train_images)} samples")
    logger.info(f"Validation set: {len(val_images)} samples")

    # Create data loaders
    train_loader, val_loader = create_data_loaders(
        train_images, train_labels, val_images, val_labels, args.batch_size
    )

    # Initialize model
    model = ImprovedCNN().to(device)
    optimizer = optim.Adam(
        model.parameters(),
        lr=args.learning_rate,
        weight_decay=1e-4,  # L2 regularization
        betas=(0.9, 0.999)
    )
    criterion = nn.CrossEntropyLoss()
    # LR is multiplied by 0.1 every 30 epochs (no effect for runs shorter than that).
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

    # Training loop
    # Start from "no result yet" instead of 0 so the first epoch always writes a
    # checkpoint, even if its validation accuracy is exactly 0.
    best_accuracy = None
    for epoch in range(1, args.epochs + 1):
        train_loss, train_acc = train_epoch(
            model=model,
            device=device,
            train_loader=train_loader,
            optimizer=optimizer,
            scheduler=scheduler,
            epoch=epoch,
            criterion=criterion,
            grad_clip=1.0
        )
        val_loss, val_acc = validate(model, device, val_loader, criterion)

        if best_accuracy is None or val_acc > best_accuracy:
            best_accuracy = val_acc
            if args.model_dir:
                save_model(model, args.model_dir)

    reported_accuracy = 0 if best_accuracy is None else best_accuracy
    logger.info(f"Training completed. Best validation accuracy: {reported_accuracy:.2f}%")


In [None]:
from sagemaker.inputs import TrainingInput
from sagemaker.workflow.steps import TrainingStep

# Pipeline step that runs the training job described by `train_args`.
step_train = TrainingStep(step_args=train_args, name="Cifar10Train")

In [None]:
%%writefile code/evaluation.py
import json
import pathlib
import tarfile
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, classification_report


class SimpleCNN(nn.Module):
    """Small two-convolution CNN that returns log-probabilities.

    This class is not instantiated anywhere in this script: ``load_model``
    builds ``ImprovedCNN``. Unlike ``ImprovedCNN``, ``forward`` here ends with
    ``log_softmax``. ``fc1`` is sized for a 64 x 6 x 6 feature map.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(64 * 6 * 6, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        # Two conv -> ReLU -> 2x2 max-pool stages.
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)

        # Flatten and classify.
        x = torch.flatten(self.dropout1(x), 1)
        x = self.dropout2(F.relu(self.fc1(x)))
        return F.log_softmax(self.fc2(x), dim=1)


class ImprovedCNN(nn.Module):
    """CIFAR-10 classifier: four conv stages, global average pooling, small head.

    Each stage is two 3x3 convolutions (padding 1), each followed by batch
    norm and ReLU. Stages 1-3 end with a 2x2 max-pool; stage 4 feeds an
    adaptive average pool that collapses the spatial dimensions to 1x1. The
    head is Linear(512, 256) -> BatchNorm1d -> ReLU -> Dropout ->
    Linear(256, num_classes).

    ``forward`` returns raw logits (no softmax).

    ``load_model`` fills this class from a saved ``state_dict``, so attribute
    names and their creation order must match the class the weights were
    saved from.
    """

    def __init__(self, num_classes=10, dropout_rate=0.3):
        super().__init__()

        # Stage 1: 3 -> 64 channels
        self.conv1a = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1a = nn.BatchNorm2d(64)
        self.conv1b = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn1b = nn.BatchNorm2d(64)

        # Stage 2: 64 -> 128 channels
        self.conv2a = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2a = nn.BatchNorm2d(128)
        self.conv2b = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn2b = nn.BatchNorm2d(128)

        # Stage 3: 128 -> 256 channels
        self.conv3a = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3a = nn.BatchNorm2d(256)
        self.conv3b = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn3b = nn.BatchNorm2d(256)

        # Stage 4: 256 -> 512 channels (not followed by pooling in forward)
        self.conv4a = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.bn4a = nn.BatchNorm2d(512)
        self.conv4b = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn4b = nn.BatchNorm2d(512)

        # Collapse each 512-channel feature map to a single value.
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)

        # Classification head
        self.fc1 = nn.Linear(512, 256)
        self.bn_fc1 = nn.BatchNorm1d(256)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(256, num_classes)

        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialise parameters layer by layer.

        Conv2d: Kaiming-normal weights (fan_out, ReLU), zero bias.
        BatchNorm2d: weight 1, bias 0.
        Linear: weights drawn from N(0, 0.01), zero bias.
        Other module types (e.g. BatchNorm1d) keep PyTorch's defaults.
        """
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.BatchNorm2d):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, 0, 0.01)
                nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Map a batch of channel-first images to per-class logits."""
        stages = (
            (self.conv1a, self.bn1a, self.conv1b, self.bn1b),
            (self.conv2a, self.bn2a, self.conv2b, self.bn2b),
            (self.conv3a, self.bn3a, self.conv3b, self.bn3b),
            (self.conv4a, self.bn4a, self.conv4b, self.bn4b),
        )
        final_stage = len(stages) - 1

        for idx, (conv_a, bn_a, conv_b, bn_b) in enumerate(stages):
            x = F.relu(bn_a(conv_a(x)))
            x = F.relu(bn_b(conv_b(x)))
            if idx != final_stage:
                # Halve the spatial size after stages 1-3 only.
                x = F.max_pool2d(x, 2)

        # Global average pooling, then flatten everything except the batch dim.
        x = torch.flatten(self.global_avg_pool(x), 1)

        # Classification head
        x = self.dropout(F.relu(self.bn_fc1(self.fc1(x))))
        return self.fc2(x)  # logits (use CrossEntropyLoss)


def load_model(model_path):
    """Load the trained PyTorch model."""
    model = ImprovedCNN()

    # Extract model from tar.gz
    with tarfile.open(model_path) as tar:
        tar.extractall(path=".")

    # Load model state dict
    model.load_state_dict(torch.load("model.pth", map_location=torch.device('cpu')))
    model.eval()

    return model


def load_test_data(test_path):
    """Return ``(images, labels)`` loaded from ``images.npy`` / ``labels.npy`` in ``test_path``."""
    images_file, labels_file = (
        f"{test_path}/{stem}.npy" for stem in ("images", "labels")
    )
    return np.load(images_file), np.load(labels_file)


def evaluate_model(model, test_images, test_labels, batch_size=128):
    """Score ``model`` on the test split and return summary metrics.

    Args:
        model: Network in eval mode whose forward pass returns raw class logits
            (as ``ImprovedCNN`` does).
        test_images: Channel-last image array ``(N, H, W, C)`` with pixel values
            on a 0-255 scale; it is scaled to 0-1 and permuted to ``(N, C, H, W)``.
        test_labels: Integer class labels, one per image.
        batch_size: Number of images per forward pass.

    Returns:
        dict: ``accuracy``, weighted ``precision`` / ``recall`` / ``f1_score``,
        ``avg_loss`` (mean cross-entropy per sample) and per-class precision,
        recall and F1 lists under ``per_class_*`` keys.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Convert to tensors and normalize
    test_images = torch.tensor(test_images, dtype=torch.float32) / 255.0
    test_labels = torch.tensor(test_labels, dtype=torch.long)

    # Reshape to NCHW format for PyTorch
    test_images = test_images.permute(0, 3, 1, 2)

    # Create data loader
    test_dataset = TensorDataset(test_images, test_labels)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Evaluate
    all_predictions = []
    all_labels = []
    total_loss = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # The model returns raw logits, so the loss must be cross-entropy
            # (log-softmax + NLL). F.nll_loss on its own would treat the logits
            # as log-probabilities and report a meaningless value.
            loss = F.cross_entropy(output, target, reduction='sum')
            total_loss += loss.item()

            pred = output.argmax(dim=1)
            all_predictions.extend(pred.cpu().numpy())
            all_labels.extend(target.cpu().numpy())

    all_predictions = np.array(all_predictions)
    all_labels = np.array(all_labels)

    # Calculate metrics
    accuracy = accuracy_score(all_labels, all_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_predictions, average='weighted')

    # Per-class metrics
    precision_per_class, recall_per_class, f1_per_class, _ = precision_recall_fscore_support(
        all_labels, all_predictions, average=None
    )

    # Losses were summed per batch, so divide by the sample count.
    avg_loss = total_loss / len(test_loader.dataset)

    return {
        'accuracy': float(accuracy),
        'precision': float(precision),
        'recall': float(recall),
        'f1_score': float(f1),
        'avg_loss': float(avg_loss),
        'per_class_precision': precision_per_class.tolist(),
        'per_class_recall': recall_per_class.tolist(),
        'per_class_f1': f1_per_class.tolist()
    }


if __name__ == "__main__":
    # Fixed locations inside the processing container: model archive and test
    # split come in, evaluation.json goes out.
    model_path = "/opt/ml/processing/model/model.tar.gz"
    test_path = "/opt/ml/processing/test"
    output_dir = "/opt/ml/processing/evaluation"

    print("Loading model...")
    model = load_model(model_path)

    print("Loading test data...")
    test_images, test_labels = load_test_data(test_path)
    print(f"Test set: {len(test_images)} samples")

    print("Evaluating model...")
    metrics = evaluate_model(model, test_images, test_labels)

    # Headline metrics, each wrapped as {"value": ...}.
    headline_keys = ("accuracy", "precision", "recall", "f1_score", "avg_loss")
    report_dict = {
        "classification_metrics": {
            key: {"value": metrics[key]} for key in headline_keys
        }
    }

    # Per-class breakdown, stored as plain floats so it serialises to JSON.
    if 'per_class_precision' in metrics:
        report_dict["additional_metrics"] = {
            "per_class_precision": list(map(float, metrics['per_class_precision'])),
            "per_class_recall": list(map(float, metrics['per_class_recall'])),
            "per_class_f1_score": list(map(float, metrics['per_class_f1'])),
        }

    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print(f"Precision: {metrics['precision']:.4f}")
    print(f"Recall: {metrics['recall']:.4f}")
    print(f"F1 Score: {metrics['f1_score']:.4f}")

    # Save evaluation results
    pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
    evaluation_path = f"{output_dir}/evaluation.json"
    with open(evaluation_path, "w") as f:
        json.dump(report_dict, f, indent=2)

    print(f"Evaluation complete! Results saved to {evaluation_path}")

In [None]:
from sagemaker.pytorch.processing import PyTorchProcessor

# Processor that runs code/evaluation.py with the same framework / Python
# versions as the training estimator.
script_eval = PyTorchProcessor(
    framework_version="2.7",
    py_version="py312",
    role=role,
    instance_type="ml.m5.xlarge",
    instance_count=1,
    base_job_name="script-cifar10-eval",
    sagemaker_session=pipeline_session,
)

# Inputs: the trained model artifact and the test split; output: the folder
# that evaluation.py writes evaluation.json into.
eval_args = script_eval.run(
    code="code/evaluation.py",
    inputs=[
        ProcessingInput(
            destination="/opt/ml/processing/model",
            source=step_train.properties.ModelArtifacts.S3ModelArtifacts,
        ),
        ProcessingInput(
            destination="/opt/ml/processing/test",
            source=step_process.properties.ProcessingOutputConfig.Outputs["test"].S3Output.S3Uri,
        ),
    ],
    outputs=[
        ProcessingOutput(source="/opt/ml/processing/evaluation", output_name="evaluation"),
    ],
)

In [None]:
from sagemaker.workflow.properties import PropertyFile

# Exposes evaluation.json (from the "evaluation" output) to later steps so the
# condition step can read the accuracy value out of it.
evaluation_report = PropertyFile(
    name="EvaluationReport",
    output_name="evaluation",
    path="evaluation.json",
)

step_eval = ProcessingStep(
    name="Cifar10Eval",
    property_files=[evaluation_report],
    step_args=eval_args,
)

In [None]:
from sagemaker.pytorch.model import PyTorchModel

# Serving model built from the training step's artifact; code/train.py supplies
# model_fn for loading the weights.
# NOTE(review): framework_version here is "2.6.0" while the estimator and the
# evaluation processor use "2.7" — confirm the difference is intentional.
model = PyTorchModel(
    entry_point="train.py",
    source_dir="code",
    framework_version="2.6.0",
    py_version="py312",
    model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
    role=role,
    sagemaker_session=pipeline_session,
)

from sagemaker.inputs import CreateModelInput
from sagemaker.workflow.model_step import ModelStep

step_create_model = ModelStep(
    step_args=model.create(instance_type="ml.m5.large"),
    name="Cifar10CreateModel",
)

In [None]:
from sagemaker.model_metrics import MetricsSource, ModelMetrics


# Attach the evaluation report to the registered model package. The URI is the
# evaluation step's first processing output plus the report file name.
model_metrics = ModelMetrics(
    model_statistics=MetricsSource(
        content_type="application/json",
        s3_uri="{}/evaluation.json".format(
            step_eval.arguments["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
        ),
    )
)

register_args = model.register(
    model_package_group_name=model_package_group_name,
    approval_status=model_approval_status,
    model_metrics=model_metrics,
    content_types=["application/json"],
    response_types=["application/json"],
    inference_instances=["ml.t2.medium", "ml.m5.large"],
)

# Registration is explicitly ordered after the evaluation step.
step_register = ModelStep(
    name="Cifar10RegisterModel",
    step_args=register_args,
    depends_on=[step_eval],
)

In [None]:
from sagemaker.workflow.fail_step import FailStep
from sagemaker.workflow.functions import Join

# Terminal step for the "accuracy too low" branch; the message embeds the
# configured threshold parameter.
step_fail = FailStep(
    error_message=Join(on=" ", values=["Execution failed due to accuracy <", accuracy_threshold]),
    name="Cifar10AccuracyFail",
)

In [None]:
from sagemaker.workflow.conditions import ConditionGreaterThanOrEqualTo
from sagemaker.workflow.condition_step import ConditionStep
from sagemaker.workflow.functions import JsonGet

# accuracy (read from evaluation.json) >= AccuracyThreshold
cond_gte = ConditionGreaterThanOrEqualTo(
    left=JsonGet(
        property_file=evaluation_report,
        step_name=step_eval.name,
        json_path="classification_metrics.accuracy.value",
    ),
    right=accuracy_threshold,
)

# Pass: register the model and create it. Fail: stop with step_fail.
step_cond = ConditionStep(
    name="Cifar10AccuracyCond",
    conditions=[cond_gte],
    else_steps=[step_fail],
    if_steps=[step_register, step_create_model],
)

In [None]:
from sagemaker.workflow.pipeline import Pipeline

pipeline_name = "Cifar10Pipeline"

# Top-level steps only; the register / create-model / fail steps are reached
# through the condition step's branches.
pipeline = Pipeline(
    name=pipeline_name,
    steps=[step_process, step_train, step_eval, step_cond],
    parameters=[
        processing_instance_count,
        instance_type,
        model_approval_status,
        input_data,
        batch_data,
        dataset_percentage,
        accuracy_threshold,
    ],
)



In [None]:
# Submit the pipeline definition to SageMaker, using the notebook's execution role.
pipeline.upsert(role_arn=role)

In [None]:
# Start an execution; no parameter overrides are passed here.
execution = pipeline.start()

In [None]:
# Show the description of the execution that was just started.
execution.describe()

In [None]:
# Block this cell until the execution finishes.
execution.wait()


In [None]:
# List the steps of the execution.
execution.list_steps()


In [None]:
import json
from pprint import pprint


# Fetch evaluation.json from the evaluation step's first processing output and
# pretty-print the parsed report.
evaluation_json = sagemaker.s3.S3Downloader.read_file(
    "%s/evaluation.json"
    % step_eval.arguments["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
)
pprint(json.loads(evaluation_json))