### AWS Account Info

In [None]:
import sagemaker

# SageMaker session built on the default boto3 credentials/profile.
sess = sagemaker.Session()

# Resolve the account and region from the active credentials instead of
# hard-coding them.
account = sess.boto_session.client('sts').get_caller_identity()['Account']
region = sess.boto_session.region_name

# IAM role passed to the training jobs below. The ARN is built from the
# caller's account rather than a hard-coded account ID, so the notebook is
# not tied to a single AWS account.
# NOTE(review): assumes a role named 'Developer' exists in the active account — confirm.
role_name = 'Developer'
role = f'arn:aws:iam::{account}:role/{role_name}'

print(region)
print(account)
print(role)
print(sess.boto_session.profile_name)

### Build & Push Docker Image

#### Variables for Docker Image

In [None]:
image = 'cog_verse'
bucket_name   = sess.default_bucket()
base_job_name = 'cog-verse-training'
%env image {image}
%env account {account}
%env region {region}
%env bucket_name {bucket_name}
%env base_job_name = {base_job_name}

#### Build Image

In [None]:
%%sh 
bash ./build_and_push.sh $image

#### Push Image to ECR

In [None]:
!docker push $account.dkr.ecr.$region.amazonaws.com/${image}:latest

### Training

#### Local Test (SageMaker local mode)

In [None]:
# Training setup for a local smoke test (instance_type='local') using the
# ./input/data/training directory as the 'training' input channel.
# `base_job_name` is intentionally NOT reassigned here: overwriting the
# notebook-level value in this cell would silently change the job name used
# by any later cell that reads it.
output_path = f"s3://{bucket_name}/{image}/output"
image_name = f"{account}.dkr.ecr.{region}.amazonaws.com/{image}:latest"
train_input_path_local = 'file://./input/data/training'

estimator = sagemaker.estimator.Estimator(image_uri=image_name,
                                          base_job_name=base_job_name,
                                          role=role,
                                          instance_count=1,
                                          output_path=output_path,
                                          instance_type='local')
estimator.set_hyperparameters(bucket_name=bucket_name, source_dir=f"{image}/src/")
estimator.fit({'training': train_input_path_local})

# Verification
print(f"output_path: {output_path}")
print(f"image_name: {image_name}")

#### AWS Run (managed SageMaker training job)

In [None]:
# Training setup for a managed SageMaker training job on an ml.m5.large instance.
output_path = f"s3://{bucket_name}/{image}/output"
image_name = f"{account}.dkr.ecr.{region}.amazonaws.com/{image}:latest"
tag_name = [{'Key': 'lavo', 'Value': 'lstm-training'}]

estimator = sagemaker.estimator.Estimator(image_uri=image_name,
                                          base_job_name=base_job_name,
                                          role=role,
                                          instance_count=1,
                                          instance_type='ml.m5.large',
                                          volume_size=16,
                                          tags=tag_name,
                                          # NOTE(review): '.cogment_verse' differs from the
                                          # f"{image}/src/" source_dir hyperparameter used in
                                          # the local test cell — confirm which is intended.
                                          source_dir='.cogment_verse',
                                          output_path=output_path,
                                          sagemaker_session=sess)

# Hyperparameters must be set on the Estimator itself: Estimator.fit() has no
# `hyperparameters` argument, and passing one raises TypeError.
hyperparameters = {
    'main_args': "ppo_atari_pz/pong_pz"
}
estimator.set_hyperparameters(**hyperparameters)

# NOTE(review): unlike the local test, no input channels are passed here —
# confirm the container does not need a 'training' channel on AWS.
estimator.fit()

# Verification
print(f"output_path: {output_path}")
print(f"image_name: {image_name}")
