In [26]:
import shutil

import sagemaker
from sagemaker.pytorch import PyTorch

# Open a SageMaker session and resolve the IAM role the training job runs as
# (passed to the estimator as `role=role`).
sagemaker_session = sagemaker.Session()
# Reuse the session above; without the argument get_execution_role() builds a
# second Session of its own.
# NOTE(review): this session is not passed to the PyTorch estimator, presumably
# because 'local'/'local_gpu' modes need a LocalSession there — confirm.
role = sagemaker.get_execution_role(sagemaker_session=sagemaker_session)

In [68]:
# --- Job identity and placement ---------------------------------------------
base_job_name = 'transformer-featureless'

# Accepted values: 'local' (local mode, CPU requirements), 'local_gpu'
# (local mode, GPU requirements) or a managed SageMaker instance type.
instance_type = 'ml.p3.2xlarge'

# S3 prefix handed to the estimator as `output_path`.
output_path = 's3://grab-aws-graphml-datadrop/transformer-models'

# --- Hyperparameters handed to the estimator (entry point:
# train_transformer_featureless.py); key order is kept stable -----------------
hyperparameters = {
    # model / neighbour-sampling sizes
    'input-dim': 16,
    'hid-dim': 16,
    'num-attn-head': 8,
    'gnn-n-layer': 2,
    'fanouts': 20,
    # training schedule
    'epochs': 10,
    'batch-size': 512,
    'learning-rate': 1e-4,
    'dropout': 0.2,
    # graph file name (also used to build the dataset URIs below);
    # 'embed-init' presumably selects the script's embedding initialisation
    'graph-fname': 'v2_nodefeature-graph.bin',
    'embed-init': 'constant',
}

# --- Training data location and blocking behaviour --------------------------
train_ds_location_local = f"file:///home/ec2-user/SageMaker/data_small/{hyperparameters['graph-fname']}"
train_ds_location_s3 = f"s3://grab-aws-graphml-datadrop/data_small/{hyperparameters['graph-fname']}"

# Local-mode targets ('local', 'local_gpu') read the graph from the notebook's
# disk and wait for fit() to finish; managed instances read from S3 and return
# right after submission.
is_wait = 'local' in instance_type
train_ds_location = (
    train_ds_location_local if 'local' in instance_type else train_ds_location_s3
)

# Overwrite ./requirements.txt with the variant for the chosen target.
# source_dir='./' is uploaded with the job, and the container presumably
# installs requirements.txt from it. Per the file names: CPU wheels for
# 'local', cu101 for 'local_gpu', cu111 for every other (managed) instance type.
shutil.copy(
    {
        'local': './requirements-cpu.txt',
        'local_gpu': './requirements-cu101.txt',
    }.get(instance_type, './requirements-cu111.txt'),
    './requirements.txt',
)

# PyTorch 1.8.0 estimator running train_transformer_featureless.py from the
# current directory on `instance_type`. SageMaker profiling and debugger hooks
# are turned off.
estimator = PyTorch(
    # code and framework image
    entry_point='train_transformer_featureless.py',
    source_dir='./',
    framework_version='1.8.0',
    py_version='py3',
    # compute
    instance_type=instance_type,
    instance_count=1,
    volume_size=500,  # storage volume size in GB (per the SageMaker SDK)
    # identity, naming and outputs
    role=role,
    base_job_name=base_job_name,
    output_path=output_path,
    hyperparameters=hyperparameters,
    # tooling switched off
    disable_profiler=True,
    debugger_hook_config=False,
)

In [69]:
# Submit the job with the dataset exposed as the 'train' channel. `is_wait` is
# True only for local-mode targets, so managed jobs return immediately.
estimator.fit(inputs={'train': train_ds_location}, wait=is_wait)

In [70]:
# Show the name of the training job just submitted (base_job_name plus a
# timestamp suffix, as in the output below) so it can be found later in
# SageMaker.
print(estimator.latest_training_job.job_name)

transformer-featureless-2022-04-08-15-50-50-720
