In [None]:
# https://sagemaker-examples.readthedocs.io/en/latest/sagemaker-python-sdk/pytorch_mnist/pytorch_mnist.html

import gzip
import os
import random
import time

import numpy as np
import sagemaker
import torchvision
from sagemaker.pytorch import PyTorch

# Shared configuration for the cells below.
# experiment_name prefixes the S3 key, training job, model and endpoint names.
experiment_name = "amazon-sagemaker-pytorch-mnist"
sagemaker_session = sagemaker.Session()
s3_bucket = sagemaker_session.default_bucket()
# Hand over the session created above so the role lookup reuses it instead of
# constructing a second default Session internally.
iam_role_irn = sagemaker.get_execution_role(sagemaker_session=sagemaker_session)

In [None]:
# Get the data
# Transform pipeline handed to the dataset: tensor conversion followed by
# Normalize((0.1307,), (0.3081,)).
mnist_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Kept as a bare expression so the notebook displays the dataset object; the
# download into data/ is the side effect the later cells depend on.
torchvision.datasets.MNIST(root="data/", download=True, transform=mnist_transform)

In [None]:
# Upload the data to S3
# Everything under the local data/ directory goes to the session's default
# bucket under a key prefix derived from the experiment name.
data_key_prefix = f"{experiment_name}-data"
data_s3_uri = sagemaker_session.upload_data(
    path="data/",
    bucket=s3_bucket,
    key_prefix=data_key_prefix,
)
print(data_s3_uri)

In [None]:
# Train
# Hyperparameters forwarded to the entry point script (src/main.py).
training_hyperparameters = {"epochs": 1, "backend": "gloo"}

estimator = PyTorch(
    entry_point="main.py",
    source_dir="src/",
    role=iam_role_irn,
    framework_version="2.0.0",
    py_version="py310",
    instance_type="ml.c5.2xlarge",
    instance_count=2,
    hyperparameters=training_hyperparameters,
)

# The job name is built right before fit() so its timestamp suffix reflects
# the moment training is launched.
training_job_name = f"{experiment_name}-job-{int(time.time())}"
estimator.fit(
    inputs={"training": data_s3_uri},
    job_name=training_job_name,
)

In [None]:
# Deploy
# Take the timestamp once so the model and its endpoint always share the same
# suffix; two separate time.time() calls can straddle a second boundary and
# produce mismatched names.
deploy_timestamp = int(time.time())
predictor = estimator.deploy(
    initial_instance_count=1,
    instance_type="ml.m5.xlarge",
    model_name=f"{experiment_name}-model-{deploy_timestamp}",
    endpoint_name=f"{experiment_name}-endpoint-{deploy_timestamp}",
)

In [None]:
# Evaluate
# How many test images are sent to the endpoint, and the seed that picks them.
# Change EVAL_SEED to look at a different (still reproducible) selection.
NUM_EVAL_IMAGES = 16
EVAL_SEED = 0

data_dir = "data/MNIST/raw"
path = os.path.join(data_dir, "t10k-images-idx3-ubyte.gz")
with gzip.open(path, "rb") as f:
    # Skip the first 16 bytes of the file, then view the rest as 28x28 images.
    images = (
        np.frombuffer(f.read(), np.uint8, offset=16)
        .reshape(-1, 28, 28)
        .astype(np.float32)
    )

# Randomly select some of the test images. A dedicated seeded generator keeps
# the selection identical across Restart & Run All without touching the
# global `random` state.
mask = random.Random(EVAL_SEED).sample(range(len(images)), NUM_EVAL_IMAGES)
mask = np.array(mask, dtype=np.int_)
data = images[mask]

# NOTE(review): pixels are sent as raw 0-255 floats, while the transform in the
# download cell normalizes with (0.1307,) / (0.3081,). Confirm that the
# inference code in src/main.py expects raw pixels or normalizes on its own.
response = predictor.predict(np.expand_dims(data, axis=1))
print("Raw prediction result:", response)

# Summarise every submitted image, not just the first one.
# Assumes `response` holds one row of 10 per-class scores per image -- the same
# assumption the `response[0]` lookup below already makes.
predicted_digits = np.argmax(np.asarray(response), axis=1)
print("Predicted digits:", predicted_digits)

# Detailed view for the first image: pair each digit 0-9 with its score.
labeled_predictions = list(zip(range(10), response[0]))
print("Labeled predictions:", labeled_predictions)

# Highest score first, so index 0 is the most likely digit.
labeled_predictions.sort(key=lambda label_and_score: label_and_score[1], reverse=True)
print("Most likely answer:", labeled_predictions[0])

In [None]:
# Clean up
# Remove everything the deploy step created, not only the endpoint itself.
# The model is deleted first, while the endpoint configuration that references
# it still exists; delete_endpoint() then removes the endpoint together with
# its endpoint configuration.
predictor.delete_model()
predictor.delete_endpoint()