### Introduction

The following code is meant to be run in a SageMaker Notebook Instance. We will train a Convolutional Neural Network on the MNIST dataset with SageMaker and then log the data from the training run to Comet.

### Install Comet 

In [None]:
!pip install comet_ml

### Fetch SageMaker Credentials

In [None]:
import sagemaker

# S3 key prefix under which the MNIST data is uploaded below.
prefix = "sagemaker/DEMO-pytorch-mnist"

# Session used below for the S3 upload; `default_bucket()` resolves the
# session's default S3 bucket (creating it if it does not exist yet).
sagemaker_session = sagemaker.Session()
bucket = sagemaker_session.default_bucket()

# IAM role handed to the PyTorch estimator, taken from the notebook's execution role.
role = sagemaker.get_execution_role()

### Fetch the Data

In [None]:
from torchvision import transforms
from torchvision.datasets import MNIST

# Download from the SageMaker sample-files mirror instead of torchvision's default hosts.
MNIST.mirrors = ["https://sagemaker-sample-files.s3.amazonaws.com/datasets/image/MNIST/"]

# Only the download side effect matters here: the dataset object is not kept,
# the files written under ./data are what get uploaded to S3 next.
MNIST(
    "data",
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]
    ),
    download=True,
)

# Upload the local ./data directory to s3://<bucket>/<prefix>; `inputs` holds the resulting S3 URI.
inputs = sagemaker_session.upload_data(
    path="data",
    bucket=bucket,
    key_prefix=prefix,
)

### Set Training Parameters

In [None]:
# Compute resources for the training job: instance type and number of instances.
AWS_INSTANCE_TYPE = "ml.c5.2xlarge"
AWS_INSTANCE_COUNT = 2

# Hyperparameters that SageMaker forwards to the entry-point script (mnist.py).
# NOTE(review): "backend" presumably selects the torch.distributed backend used by
# mnist.py to coordinate the instances (gloo does not require GPUs) — confirm in the script.
HYPERPARAMETERS = dict(
    [
        ("epochs", 5),
        ("batch-size", 32),
        ("log-interval", 1),
        ("backend", "gloo"),
    ]
)

### Set Up the SageMaker Estimator

In [13]:
from sagemaker.pytorch import PyTorch

# SageMaker applies these regexes to the training job's log output to extract
# metric values, so the lines printed by mnist.py must match them exactly.
METRIC_DEFINITIONS = [
    {"Name": "train:loss", "Regex": "Train Loss: (.*?);"},
    {"Name": "test:loss", "Regex": "Test Average Loss: (.*?);"},
    {"Name": "test:accuracy", "Regex": "Test Accuracy: (.*?)%;"},
]

# PyTorch 1.11 / Python 3.8 framework container running mnist.py on the
# instances configured in the "Set Training Parameters" cell.
estimator = PyTorch(
    entry_point="mnist.py",
    role=role,
    framework_version="1.11.0",
    py_version="py38",
    instance_type=AWS_INSTANCE_TYPE,
    instance_count=AWS_INSTANCE_COUNT,
    hyperparameters=HYPERPARAMETERS,
    metric_definitions=METRIC_DEFINITIONS,
)

### Run the Training Job

In [None]:
# Launch the training job with the uploaded S3 data as the "training" channel.
# wait=True (the SDK default) blocks until the job finishes; the Comet logging
# cells below presumably need the finished job.
estimator.fit(inputs={"training": inputs}, wait=True)

### Initialize Comet

In [None]:
import comet_ml

# Comet workspace and project that the SageMaker run will be logged into.
COMET_WORKSPACE = "team-comet-ml"
COMET_PROJECT_NAME = "sagemaker-pytorch-mnist"

# Configure the Comet SDK for this session with the workspace/project above.
# NOTE(review): no API key is passed here; presumably comet_ml.init reads it from
# existing Comet configuration or prompts for it interactively — confirm.
comet_ml.init(
    workspace=COMET_WORKSPACE,
    project_name=COMET_PROJECT_NAME,
)

### Log the SageMaker Run to Comet Using the Estimator

In [None]:
from comet_ml.integration.sagemaker import (
    log_sagemaker_training_job_by_name_v1,
    log_sagemaker_training_job_v1,
)

# Load the Comet configuration once instead of re-reading it for every key.
comet_config = comet_ml.config.get_config()

COMET_API_KEY = comet_config["comet.api_key"]
# Prefer the values stored in the Comet config, but fall back to the constants set
# in the "Initialize Comet" cell rather than silently replacing them with None.
COMET_WORKSPACE = comet_config["comet.workspace"] or COMET_WORKSPACE
COMET_PROJECT_NAME = comet_config["comet.project_name"] or COMET_PROJECT_NAME

# Fail early with an actionable message instead of an opaque error from the logging call.
if not COMET_API_KEY:
    raise ValueError(
        "No Comet API key found in the Comet configuration; "
        "run comet_ml.init() or set COMET_API_KEY before logging the training job."
    )

# Log the training job referenced by the estimator to Comet.
log_sagemaker_training_job_v1(
    estimator,
    api_key=COMET_API_KEY,
    workspace=COMET_WORKSPACE,
    project_name=COMET_PROJECT_NAME,
)

### Log the SageMaker Run to Comet Using the Job Name

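You can also log runs to Comet by job name. The job name can be found in the `Training Jobs` section of the SageMaker UI, or read from the Estimator object as shown below. This is an alternative to the previous cell: running both logs the same training job to Comet twice.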

In [None]:
# Name of the most recent job launched by `estimator.fit`; the same name can be
# copied from the SageMaker UI to log a job without an Estimator object.
# NOTE: this logs the same job as the previous cell; run only one of the two.
training_job_name = estimator.latest_training_job.job_name

log_sagemaker_training_job_by_name_v1(
    training_job_name,
    api_key=COMET_API_KEY,
    workspace=COMET_WORKSPACE,
    project_name=COMET_PROJECT_NAME,
)