# Train a BERT Model with TensorFlow on SageMaker Using a Custom Docker Image
- To use script mode, refer to the API documentation below.
- Script Mode Ref:
    - https://sagemaker.readthedocs.io/en/stable/frameworks/tensorflow/using_tf.html#train-a-model-with-tensorflow

## Script Training (Local Mode) <a class="anchor" id="LocalModeTraining">

Local mode in SageMaker is a convenient way to verify locally that your code behaves as expected before you run it on the more powerful clusters that SageMaker manages. Local mode training requires docker-compose, or nvidia-docker-compose on GPU instances, to be installed and configured in this notebook environment.
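Before using local mode, you can check whether the required compose tool is on the PATH. This is a minimal sketch using only the standard library; the helper `local_mode_compose_path` is hypothetical, and it only looks the executable up. It does not install or configure anything.

```python
import shutil
from typing import Optional


def local_mode_compose_path(gpu: bool = False) -> Optional[str]:
    # Hypothetical helper: local mode needs docker-compose,
    # or nvidia-docker-compose on GPU instances
    name = "nvidia-docker-compose" if gpu else "docker-compose"
    # shutil.which returns the executable's full path, or None if it is not on PATH
    return shutil.which(name)


for gpu in (False, True):
    path = local_mode_compose_path(gpu)
    print(f"gpu={gpu}: {path or 'not found, install it before using local mode'}")
```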

In [7]:
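# Restore variables stored by earlier notebooks (e.g. the processed_*_data_s3_uri paths and max_seq_length used below)
%store -r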

In [8]:
import os
import sagemaker
import boto3
from sagemaker.tensorflow import TensorFlow

sess = sagemaker.Session()
bucket = sess.default_bucket()
role = sagemaker.get_execution_role()
region = boto3.Session().region_name

In [9]:
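# 'ShardedByS3Key' gives each training instance a different subset of the S3 objects.
# sagemaker.s3_input is the SDK v1 API; SDK v2 renames it to sagemaker.inputs.TrainingInput.
s3_input_train_data = sagemaker.s3_input(s3_data=processed_train_data_s3_uri,
                                         distribution='ShardedByS3Key')
s3_input_validation_data = sagemaker.s3_input(s3_data=processed_validation_data_s3_uri,
                                              distribution='ShardedByS3Key')
s3_input_test_data = sagemaker.s3_input(s3_data=processed_test_data_s3_uri,
                                        distribution='ShardedByS3Key')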

print(s3_input_train_data.config)
print(s3_input_validation_data.config)
print(s3_input_test_data.config)

's3_input' class will be renamed to 'TrainingInput' in SageMaker Python SDK v2.
's3_input' class will be renamed to 'TrainingInput' in SageMaker Python SDK v2.
's3_input' class will be renamed to 'TrainingInput' in SageMaker Python SDK v2.


{'DataSource': {'S3DataSource': {'S3DataType': 'S3Prefix', 'S3Uri': 's3://sagemaker-us-east-2-057716757052/sagemaker-scikit-learn-2020-06-27-06-07-25-298/output/bert-train', 'S3DataDistributionType': 'ShardedByS3Key'}}}
{'DataSource': {'S3DataSource': {'S3DataType': 'S3Prefix', 'S3Uri': 's3://sagemaker-us-east-2-057716757052/sagemaker-scikit-learn-2020-06-27-06-07-25-298/output/bert-validation', 'S3DataDistributionType': 'ShardedByS3Key'}}}
{'DataSource': {'S3DataSource': {'S3DataType': 'S3Prefix', 'S3Uri': 's3://sagemaker-us-east-2-057716757052/sagemaker-scikit-learn-2020-06-27-06-07-25-298/output/bert-test', 'S3DataDistributionType': 'ShardedByS3Key'}}}
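`ShardedByS3Key` gives each training instance a subset of the S3 objects under the prefix instead of a full copy (the default, `FullyReplicated`). SageMaker's exact assignment is internal; the sketch below uses a hypothetical round-robin split purely to illustrate why an instance count larger than the number of files leaves some instances with nothing to read. The file names are made up.

```python
from typing import Dict, List


def shard_round_robin(keys: List[str], instance_count: int) -> Dict[int, List[str]]:
    """Illustrative only: distribute S3 object keys across instances round-robin.

    SageMaker's real ShardedByS3Key assignment is internal; this just shows that
    each instance receives a disjoint subset of the objects.
    """
    shards: Dict[int, List[str]] = {i: [] for i in range(instance_count)}
    for index, key in enumerate(sorted(keys)):
        shards[index % instance_count].append(key)
    return shards


def empty_shards(keys: List[str], instance_count: int) -> List[int]:
    # Instances that would receive no files, which fails the training job
    return [i for i, shard in shard_round_robin(keys, instance_count).items() if not shard]


train_keys = ["bert-train/part-0.tfrecord", "bert-train/part-1.tfrecord"]  # hypothetical names
print(shard_round_robin(train_keys, 2))
print(empty_shards(train_keys, 3))  # [2]
```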


In [10]:
import uuid

checkpoint_s3_prefix = 'checkpoints/{}'.format(str(uuid.uuid4()))
checkpoint_s3_uri = 's3://{}/{}/'.format(bucket, checkpoint_s3_prefix)

print(checkpoint_s3_uri)

s3://sagemaker-us-east-2-057716757052/checkpoints/c70a8018-7619-4b06-88ec-b24dd629fa63/
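SageMaker syncs the container's local checkpoint directory (by default `/opt/ml/checkpoints`) with `checkpoint_s3_uri`, so a restarted job can resume from files written earlier. The training script inside this image is not shown here; the sketch assumes a hypothetical naming scheme `ckpt-<step>.h5` and only illustrates how a script might pick the newest checkpoint from that directory.

```python
import re
from pathlib import Path
from typing import Iterable, Optional

# Hypothetical naming scheme: ckpt-<step>.h5 (the real training script may differ)
CHECKPOINT_PATTERN = re.compile(r"^ckpt-(\d+)\.h5$")


def newest_checkpoint_name(names: Iterable[str]) -> Optional[str]:
    """Return the file name with the highest step number, comparing steps numerically."""
    best_step, best_name = -1, None
    for name in names:
        match = CHECKPOINT_PATTERN.match(name)
        if match and int(match.group(1)) > best_step:
            best_step, best_name = int(match.group(1)), name
    return best_name


def latest_checkpoint(checkpoint_dir: str = "/opt/ml/checkpoints") -> Optional[Path]:
    directory = Path(checkpoint_dir)
    if not directory.is_dir():
        return None
    name = newest_checkpoint_name(p.name for p in directory.iterdir())
    return directory / name if name else None


resume_from = latest_checkpoint()
print("Resuming from:", resume_from if resume_from else "nothing, starting fresh")
```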


In [11]:
metrics_definitions = [
     {'Name': 'train:loss', 'Regex': 'loss: ([0-9\\.]+)'},
     {'Name': 'train:accuracy', 'Regex': 'accuracy: ([0-9\\.]+)'},
     {'Name': 'validation:loss', 'Regex': 'val_loss: ([0-9\\.]+)'},
     {'Name': 'validation:accuracy', 'Regex': 'val_accuracy: ([0-9\\.]+)'},
]
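SageMaker scans the training log with these regular expressions and records the first capture group as the metric value. The sketch below uses Python's `re.search` as a stand-in for SageMaker's matcher on a made-up Keras progress line (not real output from this job). It also shows a pitfall of these particular patterns: `loss: ` and `accuracy: ` also match inside `val_loss: ` and `val_accuracy: `, so a line that reports only validation metrics would also feed the train metrics.

```python
import re

# Same patterns as metrics_definitions above
metrics_definitions = [
    {'Name': 'train:loss', 'Regex': 'loss: ([0-9\\.]+)'},
    {'Name': 'train:accuracy', 'Regex': 'accuracy: ([0-9\\.]+)'},
    {'Name': 'validation:loss', 'Regex': 'val_loss: ([0-9\\.]+)'},
    {'Name': 'validation:accuracy', 'Regex': 'val_accuracy: ([0-9\\.]+)'},
]


def extract_metrics(log_line: str) -> dict:
    """Apply each regex to one log line with re.search (a stand-in for SageMaker)."""
    values = {}
    for metric in metrics_definitions:
        match = re.search(metric['Regex'], log_line)
        if match:
            values[metric['Name']] = float(match.group(1))
    return values


# Made-up Keras progress line, not real output from this job
line = "1000/1000 - 500s - loss: 0.6931 - accuracy: 0.5012 - val_loss: 0.6900 - val_accuracy: 0.5100"
print(extract_metrics(line))

# Pitfall: the train patterns also match inside "val_loss" / "val_accuracy"
print(extract_metrics("val_loss: 0.6900 - val_accuracy: 0.5100"))
```

A word boundary such as `\\bloss: ([0-9\\.]+)` avoids the overlap in Python's regex engine, because `_` is a word character and so there is no boundary inside `val_loss`; check that SageMaker evaluates the pattern the same way before relying on it.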

In [12]:
epochs=1
train_steps_per_epoch=1000

learning_rate=0.00001
epsilon=0.00000001
train_batch_size=128
validation_batch_size=128
test_batch_size=128


# train_steps_per_epoch=10
validation_steps=100
test_steps=100

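train_instance_count=2  # with ShardedByS3Key, each channel needs at least this many input files
train_instance_type='ml.p3.2xlarge'
train_volume_size=1024  # EBS volume size in GB

use_xla=True
use_amp=True
freeze_bert_layer=False
enable_checkpointing=True
input_mode='Pipe'  # stream data from S3 instead of downloading it first ('File' mode)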
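A quick back-of-the-envelope check of how many examples one run touches with these settings, which helps confirm that each shard holds enough records. It assumes, as an unverified simplification about the container's training script, that every instance independently runs `train_steps_per_epoch` steps of `train_batch_size` examples on its own shard.

```python
epochs = 1
train_steps_per_epoch = 1000
train_batch_size = 128
validation_steps = 100
validation_batch_size = 128
train_instance_count = 2

# Assumption: each instance processes train_batch_size examples per step on its own shard
examples_per_instance = epochs * train_steps_per_epoch * train_batch_size
examples_total = examples_per_instance * train_instance_count
validation_examples = validation_steps * validation_batch_size

print(f"training examples per instance: {examples_per_instance:,}")    # 128,000
print(f"training examples across instances: {examples_total:,}")      # 256,000
print(f"validation examples per instance: {validation_examples:,}")    # 12,800
```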

In [13]:
ecr_image = "057716757052.dkr.ecr.us-east-2.amazonaws.com/bert2tweet:latest"
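Rather than hard-coding the account ID, the image URI can be assembled from its parts, since ECR image URIs follow the pattern `<account>.dkr.ecr.<region>.amazonaws.com/<repository>:<tag>`. In the notebook, the account ID could come from `boto3.client('sts').get_caller_identity()['Account']` and the region from the `region` variable above; this sketch keeps the values literal so it runs without AWS credentials. The helper name `ecr_image_uri` is hypothetical.

```python
def ecr_image_uri(account_id: str, region: str, repository: str, tag: str = "latest") -> str:
    """Hypothetical helper: <account>.dkr.ecr.<region>.amazonaws.com/<repository>:<tag>."""
    return f"{account_id}.dkr.ecr.{region}.amazonaws.com/{repository}:{tag}"


ecr_image = ecr_image_uri("057716757052", "us-east-2", "bert2tweet")
print(ecr_image)  # 057716757052.dkr.ecr.us-east-2.amazonaws.com/bert2tweet:latest
```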

In [14]:
from sagemaker.estimator import Estimator


estimator = Estimator(image_name=ecr_image,
                      role=role,
                      # With ShardedByS3Key, provide at least train_instance_count input files per
                      # channel; otherwise an instance receives no data and the job fails
                      train_instance_count=train_instance_count,
                      train_instance_type=train_instance_type,
                      train_volume_size=train_volume_size,
                      checkpoint_s3_uri=checkpoint_s3_uri,  # not supported in local mode
                      hyperparameters={'epochs': epochs,
                                       'learning_rate': learning_rate,
                                       'epsilon': epsilon,
                                       'train_batch_size': train_batch_size,
                                       'validation_batch_size': validation_batch_size,
                                       'test_batch_size': test_batch_size,
                                       'train_steps_per_epoch': train_steps_per_epoch,
                                       'validation_steps': validation_steps,
                                       'test_steps': test_steps,
                                       'use_xla': use_xla,
                                       'use_amp': use_amp,
                                       'max_seq_length': max_seq_length,  # restored by %store -r
                                       'freeze_bert_layer': freeze_bert_layer,
                                       'enable_checkpointing': enable_checkpointing},
                      input_mode=input_mode,
                      metric_definitions=metrics_definitions)
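SageMaker hands hyperparameters to the container as strings: the `describe()` output of this job shows values such as `'1e-05'` and `'True'`, which match Python's `str()` of the notebook values. For custom containers they are written to `/opt/ml/input/config/hyperparameters.json`. How the `bert2tweet` image actually reads them is not shown here, so the sketch below is an assumption covering only a subset of the keys. It demonstrates the main pitfall: `bool('False')` is `True`, so boolean flags need explicit parsing. The helpers `str2bool` and `parse_hyperparameters` are hypothetical.

```python
import json
from pathlib import Path

HYPERPARAMETERS_PATH = Path("/opt/ml/input/config/hyperparameters.json")


def str2bool(value: str) -> bool:
    # bool("False") is True in Python, so compare the text explicitly
    text = str(value).strip().lower()
    if text in ("true", "1", "yes"):
        return True
    if text in ("false", "0", "no"):
        return False
    raise ValueError(f"not a boolean: {value!r}")


def parse_hyperparameters(raw: dict) -> dict:
    """Hypothetical: convert a subset of the string values back into Python types."""
    return {
        "epochs": int(raw["epochs"]),
        "learning_rate": float(raw["learning_rate"]),
        "epsilon": float(raw["epsilon"]),
        "train_batch_size": int(raw["train_batch_size"]),
        "max_seq_length": int(raw["max_seq_length"]),
        "use_xla": str2bool(raw["use_xla"]),
        "use_amp": str2bool(raw["use_amp"]),
        "freeze_bert_layer": str2bool(raw["freeze_bert_layer"]),
        "enable_checkpointing": str2bool(raw["enable_checkpointing"]),
    }


# Notebook side: values as set above, converted with str() as in the describe() output
notebook_side = {"epochs": 1, "learning_rate": 0.00001, "epsilon": 0.00000001,
                 "train_batch_size": 128, "max_seq_length": 128, "use_xla": True,
                 "use_amp": True, "freeze_bert_layer": False, "enable_checkpointing": True}
as_strings = {k: str(v) for k, v in notebook_side.items()}
print(as_strings["learning_rate"], as_strings["freeze_bert_layer"])  # 1e-05 False

# Container side: read the JSON file if present, else fall back to the strings above
raw = json.loads(HYPERPARAMETERS_PATH.read_text()) if HYPERPARAMETERS_PATH.exists() else as_strings
print(parse_hyperparameters(raw))
```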



In [15]:
train_dir = 'data/output/bert/train'
validation_dir = 'data/output/bert/validation'
test_dir = 'data/output/bert/test'
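If these local directories are meant for local mode, note that local mode accepts `file://` URIs as channel inputs in place of S3 URIs. A sketch converting the paths; the resulting dictionary would be passed to `fit()` on an estimator created with `train_instance_type='local'` (or `'local_gpu'`) in SDK v1. It assumes the directories are relative to the notebook's working directory.

```python
from pathlib import Path

train_dir = 'data/output/bert/train'
validation_dir = 'data/output/bert/validation'
test_dir = 'data/output/bert/test'

# Local mode reads channels from file:// URIs; as_uri() needs an absolute path, hence resolve()
local_inputs = {
    channel: Path(directory).resolve().as_uri()
    for channel, directory in [('train', train_dir),
                               ('validation', validation_dir),
                               ('test', test_dir)]
}
print(local_inputs)
```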


In [16]:
inputs = {'train': s3_input_train_data,
          'validation': s3_input_validation_data,
          'test': s3_input_test_data}

# Launch the training job asynchronously; wait=False returns immediately
estimator.fit(inputs, wait=False)
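Because `fit()` is called with `wait=False`, the cell returns as soon as the job is created. A minimal polling sketch: it accepts any zero-argument function returning a `describe()`-style dictionary, for example `estimator.latest_training_job.describe` or a lambda around boto3's `describe_training_job(TrainingJobName=...)`. The SDK's own `estimator.latest_training_job.wait()` also blocks until the job ends. The helper `wait_for_training_job` is hypothetical; `Completed`, `Failed`, and `Stopped` are SageMaker's terminal `TrainingJobStatus` values.

```python
import time
from typing import Callable

TERMINAL_STATES = {"Completed", "Failed", "Stopped"}


def wait_for_training_job(describe: Callable[[], dict], poll_seconds: float = 60.0,
                          sleep: Callable[[float], None] = time.sleep) -> dict:
    """Hypothetical helper: call describe() until the job reaches a terminal state."""
    while True:
        description = describe()
        status = description["TrainingJobStatus"]
        print(status, description.get("SecondaryStatus", ""))
        if status in TERMINAL_STATES:
            if status == "Failed":
                print("FailureReason:", description.get("FailureReason"))
            return description
        sleep(poll_seconds)


# Stand-in for latest_training_job.describe so the sketch runs without AWS access
fake_statuses = iter(["InProgress", "InProgress", "Completed"])
final = wait_for_training_job(lambda: {"TrainingJobStatus": next(fake_statuses)},
                              poll_seconds=0)
print(final["TrainingJobStatus"])  # Completed
```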

In [19]:
latest_training_job = estimator.latest_training_job

In [22]:
latest_training_job.describe()

{'TrainingJobName': 'bert2tweet-2020-06-27-07-11-20-159',
 'TrainingJobArn': 'arn:aws:sagemaker:us-east-2:057716757052:training-job/bert2tweet-2020-06-27-07-11-20-159',
 'ModelArtifacts': {'S3ModelArtifacts': 's3://sagemaker-us-east-2-057716757052/bert2tweet-2020-06-27-07-11-20-159/output/model.tar.gz'},
 'TrainingJobStatus': 'Completed',
 'SecondaryStatus': 'Completed',
 'HyperParameters': {'enable_checkpointing': 'True',
  'epochs': '1',
  'epsilon': '1e-08',
  'freeze_bert_layer': 'False',
  'learning_rate': '1e-05',
  'max_seq_length': '128',
  'test_batch_size': '128',
  'test_steps': '100',
  'train_batch_size': '128',
  'train_steps_per_epoch': '1000',
  'use_amp': 'True',
  'use_xla': 'True',
  'validation_batch_size': '128',
  'validation_steps': '100'},
 'AlgorithmSpecification': {'TrainingImage': '057716757052.dkr.ecr.us-east-2.amazonaws.com/bert2tweet:latest',
  'TrainingInputMode': 'Pipe',
  'MetricDefinitions': [{'Name': 'train:loss', 'Regex': 'loss: ([0-9\\.]+)'},
   {'N
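The `describe()` result is a plain dictionary, so individual fields can be read directly. A sketch splitting the model artifact location into bucket and key with `urllib.parse`; the dictionary literal copies only a few fields from the output printed above, and `split_s3_uri` is a hypothetical helper.

```python
from urllib.parse import urlparse

# A few fields copied from the describe() output above
description = {
    'TrainingJobName': 'bert2tweet-2020-06-27-07-11-20-159',
    'ModelArtifacts': {'S3ModelArtifacts': 's3://sagemaker-us-east-2-057716757052/bert2tweet-2020-06-27-07-11-20-159/output/model.tar.gz'},
    'TrainingJobStatus': 'Completed',
    'HyperParameters': {'learning_rate': '1e-05', 'use_xla': 'True'},
}


def split_s3_uri(uri: str) -> tuple:
    """Hypothetical helper: split s3://bucket/key into (bucket, key)."""
    parsed = urlparse(uri)
    if parsed.scheme != 's3':
        raise ValueError(f"not an S3 URI: {uri}")
    return parsed.netloc, parsed.path.lstrip('/')


model_uri = description['ModelArtifacts']['S3ModelArtifacts']
bucket_name, key = split_s3_uri(model_uri)
print(bucket_name)  # sagemaker-us-east-2-057716757052
print(key)          # bert2tweet-2020-06-27-07-11-20-159/output/model.tar.gz

# Hyperparameters come back as strings and need converting
learning_rate = float(description['HyperParameters']['learning_rate'])
```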

In [23]:
training_job_name = estimator.latest_training_job.name
print('Training Job Name:  {}'.format(training_job_name))

Training Job Name:  bert2tweet-2020-06-27-07-11-20-159


In [24]:
%store training_job_name

Stored 'training_job_name' (str)
