# KorGPT2 Lyric Fine-Tuning with SageMaker
Author: https://github.com/MrBananaHuman/KorGPT2Tutorial

In [4]:
!pip -q install sagemaker sagemaker[local]
!pip -q install gdown

[33mYou are using pip version 19.0.3, however version 20.2b1 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.[0m
[33mYou are using pip version 19.0.3, however version 20.2b1 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.[0m


In [5]:
import os

import boto3
import sagemaker

sagemaker_session = sagemaker.Session()

# S3 destination for the upload cell: the session's default bucket plus a key prefix.
bucket = sagemaker_session.default_bucket()
prefix = 'sagemaker/hunkim-kogpg2-data'

# IAM role handed to both the PyTorch estimator and the PyTorchModel below.
# Set SAGEMAKER_ROLE_ARN to run this notebook from a different AWS account;
# without it the original author's role is used, exactly as before.
role = os.environ.get('SAGEMAKER_ROLE_ARN',
                      'arn:aws:iam::294038372338:role/hunkimSagemaker')



In [6]:
!mkdir -p local_data/KorGPT-2SampleModel
!gdown -O ./local_data/KorGPT-2SampleModel/pytorch_model.bin --id 1kX_dB05dkLRgxJkqoHidrT2OFYHGYWPF

Downloading...
From: https://drive.google.com/uc?id=1kX_dB05dkLRgxJkqoHidrT2OFYHGYWPF
To: /Users/hunkim/work/sagemaker-aihub/kogpt2/local_data/KorGPT-2SampleModel/pytorch_model.bin
516MB [00:18, 28.5MB/s]


In [7]:
# Upload the local_data directory to the bucket/prefix chosen earlier; the value
# returned by upload_data is kept in `inputs` and fed to estimator.fit later.
upload_kwargs = dict(path='local_data', bucket=bucket, key_prefix=prefix)
inputs = sagemaker_session.upload_data(**upload_kwargs)

message_template = 'input spec (in this case, just an S3 path): {}'
print(message_template.format(inputs))

input spec (in this case, just an S3 path): s3://sagemaker-us-west-2-294038372338/sagemaker/hunkim-kogpg2-data


In [29]:
from sagemaker.pytorch import PyTorch

# Values forwarded to the estimator through its `hyperparameters` argument.
training_hyperparameters = {
    'epochs': 5,
    'batch-size': 4,
}

# Training job definition: code/lyric_train.py on a single ml.g4dn.xlarge instance.
estimator = PyTorch(
    entry_point='lyric_train.py',
    source_dir='code',
    role=role,
    framework_version='1.5.0',
    train_instance_count=1,
    train_instance_type='ml.g4dn.xlarge',
    hyperparameters=training_hyperparameters,
)

In [32]:
# Launch training, supplying the uploaded S3 location under the 'training' key.
training_channels = {'training': inputs}
estimator.fit(training_channels)

2020-06-15 00:00:35 Starting - Starting the training job...
2020-06-15 00:00:37 Starting - Launching requested ML instances............
2020-06-15 00:02:46 Starting - Preparing the instances for training...
2020-06-15 00:03:42 Downloading - Downloading input data
2020-06-15 00:03:42 Training - Downloading the training image........[34mbash: cannot set terminal process group (-1): Inappropriate ioctl for device[0m
[34mbash: no job control in this shell[0m
[34m2020-06-15 00:05:05,529 sagemaker-containers INFO     Imported framework sagemaker_pytorch_container.training[0m
[34m2020-06-15 00:05:05,551 sagemaker_pytorch_container.training INFO     Block until all host DNS lookups succeed.[0m
[34m2020-06-15 00:05:05,555 sagemaker_pytorch_container.training INFO     Invoking user training script.[0m
[34m2020-06-15 00:05:05,822 sagemaker-containers INFO     Module default_user_module_name does not provide a setup.py. [0m
[34mGenerating setup.py[0m
[34m2020-06-15 00:05:05,823 sage

In [33]:
# Ask SageMaker to describe the most recent training job and pull the
# S3 location of its model artifacts out of the response.
training_job_name = estimator.latest_training_job.name
sm_client = sagemaker_session.sagemaker_client
desc = sm_client.describe_training_job(TrainingJobName=training_job_name)
model_artifacts = desc['ModelArtifacts']
trained_model_location = model_artifacts['S3ModelArtifacts']
print(trained_model_location)
# Example value from the recorded run:
# s3://sagemaker-us-west-2-294038372338/pytorch-training-2020-06-15-00-00-33-479/output/model.tar.gz

s3://sagemaker-us-west-2-294038372338/pytorch-training-2020-06-15-00-00-33-479/output/model.tar.gz


In [54]:
from sagemaker.pytorch import PyTorchModel

# Resolve the artifact location again so this cell prints it alongside the model definition.
training_job_name = estimator.latest_training_job.name
desc = sagemaker_session.sagemaker_client.describe_training_job(
    TrainingJobName=training_job_name
)
artifact_info = desc['ModelArtifacts']
trained_model_location = artifact_info['S3ModelArtifacts']
print(trained_model_location)

# Inference model: the trained artifacts paired with code/lyric_gen.py as entry point.
model = PyTorchModel(
    model_data=trained_model_location,
    role=role,
    framework_version='1.5.0',
    entry_point='lyric_gen.py',
    source_dir='code',
)

s3://sagemaker-us-west-2-294038372338/pytorch-training-2020-06-15-00-00-33-479/output/model.tar.gz


In [60]:
%%time
predictor = model.deploy(initial_instance_count=1, instance_type='ml.g4dn.xlarge')

----------------!CPU times: user 28.3 s, sys: 9.5 s, total: 37.8 s
Wall time: 9min 17s


In [61]:
# Endpoint name behind `predictor`; the invocation cell reads `endpoint`.
endpoint = predictor.endpoint
print(endpoint)

# Build the CloudWatch link from the session's region instead of assuming
# us-west-2, so the URL is also correct when the notebook runs elsewhere.
region = sagemaker_session.boto_region_name
print("See the logs at",
      "https://{0}.console.aws.amazon.com/cloudwatch/home?region={0}"
      "#logEventViewer:group=/aws/sagemaker/Endpoints/{1}".format(region, endpoint))

pytorch-inference-2020-06-15-01-47-54-313
See the logs at https://us-west-2.console.aws.amazon.com/cloudwatch/home?region=us-west-2#logEventViewer:group=/aws/sagemaker/Endpoints/pytorch-inference-2020-06-15-01-47-54-313


In [64]:
%%time
import boto3
import json

client = boto3.client('sagemaker-runtime')

input = {
    'text': "하늘을 날자"
}
payload = json.dumps(input)

response = client.invoke_endpoint(
    EndpointName=endpoint, 
    ContentType="application/json",
    Accept="application/json" ,
    Body=payload
)

print(response['Body'].read())  

b'"<song><s> \\ud558\\ub298\\uc744 \\ub0a0\\uc790 \\uadf8 \\ub204\\uad6c\\ub3c4 \\uadf8 \\ub204\\uad6c\\ub3c4 \\uadf8 \\ub204\\uad6c\\ub3c4 \\uc774 \\uc138\\uc0c1\\uc758 \\ub204\\uad6c\\ub3c4 \\uc774 \\uc138\\uc0c1\\uc744 \\uc0ac\\ub791\\ud558\\uc9c0 \\uc54a\\uc744 \\uc218 \\uc5c6\\ub2e4."'
CPU times: user 21.1 ms, sys: 3.2 ms, total: 24.3 ms
Wall time: 914 ms


In [65]:
# Tear down the endpoint that model.deploy created for `predictor`.
endpoint_to_delete = predictor.endpoint
sagemaker_session.delete_endpoint(endpoint_to_delete)