## Minor changes to the TensorFlow estimator

To use Horovod for distributed training, pass a larger `train_instance_count` and set the `distributions` parameter so that it enables MPI and defines the number of Horovod processes per host.


In [None]:
from sagemaker.tensorflow import TensorFlow
import logging

# Training instance settings for the Horovod job.
hvd_instance_type = 'ml.p3.8xlarge'
# Keep this at 1 (single threaded) to avoid the MobileNetV2 download
# corruption issue (#4).
hvd_processes_per_host = 1
hvd_instance_count = 1

serve_instance_type = 'ml.c5.xlarge'

# MPI configuration for Horovod.
# NOTE: this dict is defined but not passed to the estimator below; the
# `distributions=distributions` argument is commented out there.
distributions = {
    'mpi': {
        'enabled': True,
        'processes_per_host': hvd_processes_per_host,
    },
}

# Hyperparameters handed to the estimator.
hyperparameters = {
    'initial_epochs': 10,
    'tuning_epochs': 50,
    'data_dir': '/opt/ml/input/data',
    'dropout': 0.8,
    'num_fully_connected_layers': 1,
}

# Metric names paired with the regexes used to capture their values.
metric_definitions = [
    {
        'Name': 'validation:acc',
        'Regex': '.*step.* - val_acc: (.*$)',
    },
    {
        'Name': 'validation:loss',
        'Regex': '- val_loss: (.*?) ',
    },
    {
        'Name': 'acc',
        'Regex': '.*step.* - acc: (.*?) ',
    },
    {
        'Name': 'loss',
        'Regex': '.*step.* - loss: (.*?) ',
    },
]

# Build the script-mode TensorFlow estimator from the settings above.
estimator = TensorFlow(
    entry_point='train-mobilenet-horovod.py',
    source_dir='code',
    train_instance_type=hvd_instance_type,
    train_instance_count=hvd_instance_count,
    hyperparameters=hyperparameters,
    metric_definitions=metric_definitions,
    role=sagemaker.get_execution_role(),  # Pass notebook role to container
    framework_version='1.13.1',
    py_version='py3',
    base_job_name=JOB_PREFIX,
    script_mode=True,
    # distributions=distributions,  # currently disabled
)

# Version 1.12 raises an error about image sizes when loading the pretrained MobileNetV2 model.
# Loading works fine on 1.13.1.
# However, batch fails on 1.13.1 with GPUs.