### Install Comet

In [None]:
!pip install comet_ml

### Initialize Comet

In [None]:
import comet_ml

# Name of the Comet project used by this example.
PROJECT_NAME = "comet-example-sagemaker-tensorflow-custom-mnist"

# Initialise Comet for this kernel session with the project name above.
# The estimator cell further down reads Comet's config (API key, project name)
# to forward it to the training job.
comet_ml.init(project_name=PROJECT_NAME)

### Fetch SageMaker Session and Credentials

In [None]:
import sagemaker

# SageMaker session; used in the data cell to upload the training files to S3.
sagemaker_session = sagemaker.Session()

# The session's default S3 bucket. upload_data below is called without an explicit
# bucket; this value is kept here for reference.
bucket = sagemaker_session.default_bucket()

# IAM role the training job runs under; passed to the estimator.
# NOTE(review): get_execution_role() is expected to resolve the role from the
# SageMaker notebook environment; outside SageMaker a role ARN likely has to be
# supplied explicitly — confirm for your setup.
role = sagemaker.get_execution_role()

### Fetch the Data and Upload It to S3

In [None]:
import os
import keras
import numpy as np
from keras.datasets import mnist

# Keras returns MNIST as (train, test); the test split is used as validation here.
(x_train, y_train), (x_val, y_val) = mnist.load_data()

# Local staging directory and the exact files written and uploaded below.
# The ".npz" extension is spelled out so the path passed to np.savez and the path
# passed to upload_data are the same string (np.savez only appends ".npz" when
# it is missing).
DATA_DIR = "./data"
TRAINING_FILE = os.path.join(DATA_DIR, "training.npz")
VALIDATION_FILE = os.path.join(DATA_DIR, "validation.npz")

os.makedirs(DATA_DIR, exist_ok=True)

# Arrays are stored under the keys "image" and "label".
# NOTE(review): presumably these are the keys src/mnist.py reads — confirm.
np.savez(TRAINING_FILE, image=x_train, label=y_train)
np.savez(VALIDATION_FILE, image=x_val, label=y_val)

# S3 key prefix under which both files are uploaded.
prefix = "keras-mnist"

training_input_path = sagemaker_session.upload_data(TRAINING_FILE, key_prefix=prefix + "/training")
validation_input_path = sagemaker_session.upload_data(VALIDATION_FILE, key_prefix=prefix + "/validation")

### Set Training Parameters

In [None]:
# Compute resources for the SageMaker training job (consumed by the estimator).
AWS_INSTANCE_TYPE = "ml.c5.2xlarge"
AWS_INSTANCE_COUNT = 1

# Passed as-is to the estimator's `hyperparameters`; note the hyphenated "batch-size" key.
HYPERPARAMETERS = {"epochs": 1, "batch-size": 32}

### Set Up the SageMaker Estimator

In [None]:
from sagemaker.tensorflow import TensorFlow

# Read Comet's session config once. The API key and project name are forwarded to
# the training container as environment variables (presumably read by src/mnist.py
# to log to Comet — confirm).
comet_config = comet_ml.config.get_config()
COMET_API_KEY = comet_config["comet.api_key"]
# Fall back to the project name given to comet_ml.init() in the "Initialize Comet"
# cell if the config does not carry one.
COMET_PROJECT_NAME = comet_config["comet.project_name"] or PROJECT_NAME

# SageMaker environment values must be strings; fail here with a clear message
# instead of later inside the training-job request.
if not COMET_API_KEY:
    raise ValueError(
        "No Comet API key found. Run the 'Initialize Comet' cell (comet_ml.init) "
        "or set the COMET_API_KEY environment variable before creating the estimator."
    )

# NOTE(review): `environment` values become part of the training job definition
# and are likely visible to anyone who can describe the job; consider a secrets
# manager in shared AWS accounts.
estimator = TensorFlow(
    source_dir="src",
    entry_point="mnist.py",
    role=role,
    instance_count=AWS_INSTANCE_COUNT,
    instance_type=AWS_INSTANCE_TYPE,
    hyperparameters=HYPERPARAMETERS,
    # NOTE(review): TF 2.2 / py37 containers are old and may be deprecated — confirm availability.
    framework_version="2.2",
    py_version="py37",
    environment={
        "COMET_API_KEY": COMET_API_KEY,
        "COMET_PROJECT_NAME": COMET_PROJECT_NAME,
    },
)

### Run the Training Job

In [None]:
# Launch the training job. Keys are SageMaker input channel names; values are the
# paths returned by upload_data in the data cell.
input_channels = {
    "training": training_input_path,
    "validation": validation_input_path,
}
estimator.fit(input_channels)