In [None]:
import time

# Run configuration: dataset identifiers and a unique, timestamped run name.
train_dataset_name = 'jan23_d_20_15'
eval_dataset_name = 'jan23_d_20_15'

# UTC timestamp suffix so that every execution of this cell yields a fresh name.
_run_timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

# Hyphenated dataset name + timestamp; the later cells reuse it as the
# experiment name, trial name and trial-component display name.
dataset_hyphened_name = f"{train_dataset_name.replace('_', '-')}-{_run_timestamp}"

# NOTE(review): `layers` is not referenced by any cell visible here — confirm it is still needed.
layers = 5

### Boto3 client & SageMaker sessions

In [None]:
import sagemaker, boto3

# A single boto3 session backs both the low-level SageMaker client (`sm`)
# and the SageMaker SDK session, so they share the same credentials/region.
sess = boto3.Session()
sm = sess.client(service_name='sagemaker')
sagemaker_session = sagemaker.Session(boto_session=sess)

# Execution role resolved by the SDK; it is handed to the estimator further down.
role = sagemaker.get_execution_role()

### Create experiment

In [None]:
from smexperiments.experiment import Experiment
from smexperiments.trial import Trial
from smexperiments.trial_component import TrialComponent

# Create a new experiment for this run. The timestamped run name is used
# both as the experiment name and as its description.
training_experiment = Experiment.create(
    experiment_name=dataset_hyphened_name,
    description=dataset_hyphened_name,
    sagemaker_boto_client=sm,
)

### Create trial

In [None]:
# Create one trial inside the experiment, named after the timestamped run.
training_trial = Trial.create(
    trial_name=dataset_hyphened_name,
    experiment_name=training_experiment.experiment_name,
    sagemaker_boto_client=sm,
)

# Display name under which the training job's trial component will appear.
training_trial_comp_name = dataset_hyphened_name

# Mapping passed to `fit(..., experiment_config=...)` to attach the training
# job to the experiment/trial created above.
experiment_config = dict(
    ExperimentName=training_experiment.experiment_name,
    TrialName=training_trial.trial_name,
    TrialComponentDisplayName=training_trial_comp_name,
)

### Run training job & visualize results

In [None]:
from sagemaker.tensorflow import TensorFlow

# S3 prefix of the training dataset; forwarded to the training script as a hyperparameter.
dataframe_dir = f's3://obstacles-classification/{train_dataset_name}'

# Job artifacts (and uploaded code) go under the session's default bucket,
# in a per-dataset prefix.
bucket_name = sagemaker_session.default_bucket()
output_path = f's3://{bucket_name}/obstacles_classification/jobs/{train_dataset_name}'

# Hyperparameters handed to the entry-point script.
hyperparams = {
    'epochs': 1,
    'learning-rate': 0.00005,
    'batch-size': 32,
    'optimizer': 'adam',
    # Disabled option — presumably resumes training from a stored checkpoint;
    # confirm against the entry-point script before re-enabling.
    #'from_chp'     : 's3://obstacles-classification-model-checkpoints/jan23_d_20_15/2023-01-17-21-16-16/epoch-30/',
    'dataframe_dir': dataframe_dir,
}

# (metric name, log regex) pairs. Each regex has one capture group holding the
# numeric value that is extracted from the training job's log output.
# NOTE(review): the un-prefixed patterns ('auc: ', 'loss: ', 'accuracy: ', ...)
# are also substrings of the 'val_...' / 'test_...' log tokens, so the train
# series may pick up validation/test values — confirm against the script's logs.
_metric_patterns = (
    # training metrics
    ('auc', 'auc: ([0-9\\.]+)'),
    ('recall', 'recall: ([0-9\\.]+)'),
    ('specifity', 'specifity: ([0-9\\.]+)'),
    ('accuracy', 'accuracy: ([0-9\\.]+)'),
    ('loss', 'loss: ([0-9\\.]+)'),
    # validation metrics
    ('validation auc', 'val_auc: ([0-9\\.]+)'),
    ('validation recall', 'val_recall: ([0-9\\.]+)'),
    ('validation specifity', 'val_specifity: ([0-9\\.]+)'),
    ('validation accuracy', 'val_categorical_accuracy: ([0-9\\.]+)'),
    ('validation loss', 'val_loss: ([0-9\\.]+)'),
    # test metrics
    ('test auc', 'test_auc: ([0-9\\.]+)'),
    ('test recall', 'test_recall: ([0-9\\.]+)'),
    ('test specifity', 'test_specifity: ([0-9\\.]+)'),
    ('test accuracy', 'test_accuracy: ([0-9\\.]+)'),
    ('test loss', 'test_loss: ([0-9\\.]+)'),
    # epoch counter
    ('epoch', 'Epoch ([0-9]+)'),
)

# Expand the pairs into the list-of-dicts shape passed as `metric_definitions`.
metric_definitions = [
    {'Name': metric_name, 'Regex': metric_regex}
    for metric_name, metric_regex in _metric_patterns
]

# Configure the TensorFlow training job (script mode, single CPU instance).
# NOTE(review): SageMaker Python SDK v2 renamed `train_instance_count` /
# `train_instance_type` to `instance_count` / `instance_type` — confirm the
# installed SDK version before upgrading.
tf_estimator = TensorFlow(
    # What to run
    entry_point='train-with-cpoint.py',
    hyperparameters=hyperparams,
    script_mode=True,
    # Runtime
    framework_version='2.3',
    py_version='py37',
    # Where to run
    role=role,
    sagemaker_session=sagemaker_session,
    train_instance_count=1,
    train_instance_type='ml.c5.xlarge',
    # Where artifacts and uploaded code are stored
    output_path=f'{output_path}/',
    code_location=output_path,
    # Monitoring
    metric_definitions=metric_definitions,
    debugger_hook_config=False,
)

# Unique training-job name, suffixed with the current UTC timestamp.
job_name = f'obstacles-classification-{time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())}'

# S3 input prefixes. The eval prefix is derived from `eval_dataset_name`
# (previously it was built from `train_dataset_name`, which left the eval
# setting in the configuration cell without any effect).
training_dataset = f's3://obstacles-classification/{train_dataset_name}/train'
eval_dataset = f's3://obstacles-classification/{eval_dataset_name}/eval'

# Input channels handed to the training job.
# NOTE(review): the 'validation' channel points at the *training* prefix, so
# validation metrics are computed on training data unless the script splits it
# itself — confirm this is intentional.
input_channels = {
    'training': training_dataset,
    'validation': training_dataset,
    'eval': eval_dataset,
}

# Launch the job and attach it to the experiment/trial configured earlier.
tf_estimator.fit(
    input_channels,
    job_name=job_name,
    experiment_config=experiment_config,
)