Skip to content

Commit

Permalink
feat(sagemaker): add pytorch mnist (#10354)
Browse files Browse the repository at this point in the history
  • Loading branch information
hongbo-miao committed Aug 13, 2023
1 parent eccac3a commit 0668f58
Show file tree
Hide file tree
Showing 19 changed files with 2,032 additions and 13 deletions.
8 changes: 4 additions & 4 deletions .editorconfig
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ indent_style = space
indent_size = 2
indent_style = space

[{*.ipynb,*.py}]
indent_size = 4
indent_style = space

[*.java]
indent_size = 2
indent_style = space
Expand All @@ -34,10 +38,6 @@ indent_style = space
indent_size = 4
indent_style = space

[*.py]
indent_size = 4
indent_style = space

[*.rego]
indent_size = 4
indent_style = tab
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/.static-type-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ jobs:
- name: Static type check Python
run: |
poetry run poe static-type-check-python -- --package=api-python
poetry run poe static-type-check-python -- --package=aws.amazon-sagemaker.pytorch-mnist
poetry run poe static-type-check-python -- --package=chatbot
poetry run poe static-type-check-python -- --package=convolutional-neural-network
poetry run poe static-type-check-python -- --package=data-distribution-service
Expand Down
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ poetry-cache-clear:

clean-jupyter-notebook:
poetry run poe clean-jupyter-notebook -- aws/amazon-emr/studio/hm-studio/hm-workspace.ipynb
poetry run poe clean-jupyter-notebook -- aws/amazon-sagemaker/pytorch-mnist/notebook.ipynb
lint-ansible:
poetry run poe lint-ansible
lint-matlab:
Expand Down Expand Up @@ -227,6 +228,7 @@ lint-yaml:
poetry run poe lint-yaml
static-type-check-python:
poetry run poe static-type-check-python -- --package=api-python
poetry run poe static-type-check-python -- --package=aws.amazon-sagemaker.pytorch-mnist
poetry run poe static-type-check-python -- --package=chatbot
poetry run poe static-type-check-python -- --package=convolutional-neural-network
poetry run poe static-type-check-python -- --package=data-distribution-service
Expand Down
8 changes: 0 additions & 8 deletions aws/amazon-emr/studio/hm-studio/hm-workspace.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,6 @@
"%%sql\n",
"select * from hm_motor"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4c562d57",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,5 @@
}
}
],
"serviceExecutionRoleArn": "arn:aws:iam::xxx:role/xxx"
"serviceExecutionRoleArn": "arn:aws:iam::xxxxxxxxxxxx:role/xxx"
}
7 changes: 7 additions & 0 deletions aws/amazon-sagemaker/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Create a SageMaker notebook instance named hm-sagemaker-notebook via the AWS CLI
# (instance type ml.m5.xlarge, platform notebook-al2-v2, root access enabled).
# NOTE(review): the account ID and role name in --role-arn look like redacted
# placeholders; substitute real values before running.
sagemaker-create-notebook-instance:
aws sagemaker create-notebook-instance \
--notebook-instance-name=hm-sagemaker-notebook \
--instance-type=ml.m5.xlarge \
--role-arn=arn:aws:iam::xxxxxxxxxxxx:role/service-role/xxx \
--platform-identifier=notebook-al2-v2 \
--root-access=Enabled
23 changes: 23 additions & 0 deletions aws/amazon-sagemaker/pytorch-mnist/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
poetry-env-use:
poetry env use 3.10
poetry-update-lock-file:
poetry lock --no-update
poetry-install:
poetry install --no-root
poetry-add:
poetry add xxx
poetry-add-dev:
poetry add xxx --group=dev

poetry-run-dev:
poetry run poe dev


# Archive the pytorch-mnist directory into pytorch-mnist.zip. The command runs from
# the parent directory, so the archive is written next to (not inside) pytorch-mnist.
# The local .venv directory is excluded from the archive.
zip-pytorch-mnist:
cd .. && \
zip -r pytorch-mnist.zip pytorch-mnist \
-x 'pytorch-mnist/.venv/*'
# Run in the SageMaker Notebook instance's JupyterLab terminal.
# Extracts pytorch-mnist.zip inside the SageMaker/ directory. The `cd` and `unzip`
# are chained with `&&` because make runs every recipe line in its own shell, so a
# `cd` on a separate line would not affect the following command.
unzip-pytorch-mnist:
	cd SageMaker/ && \
	unzip pytorch-mnist.zip
163 changes: 163 additions & 0 deletions aws/amazon-sagemaker/pytorch-mnist/notebook.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# https://sagemaker-examples.readthedocs.io/en/latest/sagemaker-python-sdk/pytorch_mnist/pytorch_mnist.html\n",
"\n",
"# Set up\n",
"import sagemaker\n",
"\n",
"sagemaker_session = sagemaker.Session()\n",
"aws_region = sagemaker_session.boto_region_name\n",
"s3_bucket = sagemaker_session.default_bucket()\n",
"s3_key_prefix = \"amazon-sagemaker/pytorch-mnist\"\n",
"iam_role_irn = sagemaker.get_execution_role()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Get the data\n",
"from torchvision.datasets import MNIST\n",
"from torchvision import transforms\n",
"\n",
"MNIST.mirrors = [\n",
" f\"https://sagemaker-example-files-prod-{aws_region}.s3.amazonaws.com/datasets/image/MNIST/\"\n",
"]\n",
"MNIST(\n",
" \"data\",\n",
" download=True,\n",
" transform=transforms.Compose(\n",
" [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]\n",
" ),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Upload the data to S3\n",
"data_s3_uri = sagemaker_session.upload_data(\n",
" path=\"data\", bucket=s3_bucket, key_prefix=s3_key_prefix\n",
")\n",
"print(data_s3_uri)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Train\n",
"from sagemaker.pytorch import PyTorch\n",
"\n",
"estimator = PyTorch(\n",
" source_dir=\"src/\",\n",
" entry_point=\"main.py\",\n",
" role=iam_role_irn,\n",
" py_version=\"py310\",\n",
" framework_version=\"2.0.0\",\n",
" instance_count=2,\n",
" instance_type=\"ml.c5.2xlarge\",\n",
" hyperparameters={\"epochs\": 1, \"backend\": \"gloo\"},\n",
")\n",
"estimator.fit(\n",
" inputs={\"training\": data_s3_uri},\n",
" job_name=\"amazon-sagemaker-pytorch-mnist-job\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Deploy\n",
"predictor = estimator.deploy(\n",
" initial_instance_count=1,\n",
" instance_type=\"ml.m5.xlarge\",\n",
" model_name=\"amazon-sagemaker-pytorch-mnist-model\",\n",
" endpoint_name=\"amazon-sagemaker-pytorch-mnist-endpoint\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Evaluate\n",
"import gzip\n",
"import numpy as np\n",
"import random\n",
"import os\n",
"\n",
"data_dir = \"data/MNIST/raw\"\n",
"with gzip.open(os.path.join(data_dir, \"t10k-images-idx3-ubyte.gz\"), \"rb\") as f:\n",
" images = (\n",
" np.frombuffer(f.read(), np.uint8, offset=16)\n",
" .reshape(-1, 28, 28)\n",
" .astype(np.float32)\n",
" )\n",
"\n",
"# Randomly select some of the test images\n",
"mask = random.sample(range(len(images)), 16)\n",
"mask = np.array(mask, dtype=np.int_)\n",
"data = images[mask]\n",
"\n",
"response = predictor.predict(np.expand_dims(data, axis=1))\n",
"print(\"Raw prediction result:\", response)\n",
"\n",
"labeled_predictions = list(zip(range(10), response[0]))\n",
"print(\"Labeled predictions:\", labeled_predictions)\n",
"\n",
"labeled_predictions.sort(key=lambda label_and_prob: 1.0 - label_and_prob[1])\n",
"print(\"Most likely answer:\", labeled_predictions[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Cleanup\n",
"# Delete the model first: the predictor looks up the model name through the\n",
"# endpoint, so this must run while the endpoint still exists.\n",
"predictor.delete_model()\n",
"# Delete the endpoint together with its endpoint configuration so a later\n",
"# deploy with the same fixed names does not collide with leftovers.\n",
"predictor.delete_endpoint(delete_endpoint_config=True)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (PyTorch 1.13 Python 3.9 CPU Optimized)",
"language": "python",
"name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-west-2:236514542706:image/pytorch-1.13-cpu-py39"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2"
},
"notice": "Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
},
"nbformat": 4,
"nbformat_minor": 4
}
Loading

0 comments on commit 0668f58

Please sign in to comment.