In [None]:
! pip install -U sagemaker

In [None]:
import sagemaker

# SageMaker SDK session object; in the cells shown it is only used to look up the default bucket.
session = sagemaker.Session()
# Bucket name from which the train/validation S3 URIs passed to fit() are built.
session_bucket = session.default_bucket()
# Execution role that is handed to the HuggingFace estimator through its `role` argument.
role = sagemaker.get_execution_role()

In [None]:
from sagemaker.huggingface import HuggingFace

# Source of the training script: the Transformers repository at tag v4.6.1.
# Per the original author, v4.6.1 was the most recent Transformers version
# natively supported by SageMaker at the time of writing.
git_config = dict(
    repo='https://github.com/huggingface/transformers.git',
    branch='v4.6.1',
)

# Framework and interpreter versions requested from the estimator.
pytorch_version, python_version = '1.7.1', 'py36'

# S3 key prefix under which the data files are stored.
s3_prefix = 'xsum-dataset'

# Optional data-parallel training: uncomment this dict and pass it to the
# estimator as `distribution=distribution`.
# distribution = {"smdistributed": {"dataparallel": {"enabled": True}}}

# Hyperparameters handed to the HuggingFace estimator (entry point: run_summarization.py).
hyperparameters = dict(
    # Data: CSV files under the 'train' / 'validation' channel directories,
    # and the columns holding the target summary and the source text.
    train_file='/opt/ml/input/data/train/train.csv',
    validation_file='/opt/ml/input/data/validation/validation.csv',
    summary_column='summary',
    text_column='text',
    # Batch sizes per device.
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    # Pre-trained Pegasus checkpoint to fine-tune.
    model_name_or_path='google/pegasus-large',
    # Run both the training and the evaluation phase.
    do_train=True,
    do_eval=True,
    # Where the script writes its output.
    output_dir='/opt/ml/model',
    # Optimisation settings.
    num_train_epochs=2,
    learning_rate=5e-5,
    seed=7,
    # Source sequence length limit. NOTE(review): the original comment states that 512 is
    # the maximum sequence length supported by Pegasus - confirm against the
    # google/pegasus-large model configuration.
    max_source_length=512,
)
 
# Training job definition: runs the summarization example script from the
# Transformers repository on a single ml.p3.2xlarge instance.
huggingface_estimator = HuggingFace(
    # What to run.
    git_config=git_config,
    source_dir='./examples/pytorch/summarization',
    entry_point='run_summarization.py',
    hyperparameters=hyperparameters,
    # Container versions.
    transformers_version='4.6.1',
    pytorch_version=pytorch_version,
    py_version=python_version,
    # Infrastructure.
    role=role,
    instance_count=1,
    instance_type='ml.p3.2xlarge',  # ml.p3.16xlarge needed for DDP
    volume_size=200,
    # distribution=distribution,  # uncomment (with the `distribution` dict) for data-parallel training
)

# Start the training job. Each input channel points at a single CSV file stored
# under `s3_prefix` in the session's bucket.
huggingface_estimator.fit(
    inputs={
        'train': f's3://{session_bucket}/{s3_prefix}/train.csv',
        'validation': f's3://{session_bucket}/{s3_prefix}/validation.csv',
    }
)