## PyKEEN on Amazon SageMaker (no container)

In [None]:
# Uncomment the line below if the SageMaker Python SDK is not installed
# !pip install sagemaker

In [None]:
# Setup: create the SageMaker session and resolve the execution role ARN.
import json
import os

import sagemaker
from sagemaker import get_execution_role

sagemaker_session = sagemaker.Session()

# On a SageMaker notebook instance the role attached to the instance can be
# retrieved directly instead of being assembled by hand:
# role = get_execution_role()

# When running locally, the execution role ARN is assembled from two
# identifiers. They are read from the environment so that the cell runs on a
# fresh kernel; export AWS_ACCOUNT_ID and SAGEMAKER_EXECUTION_ROLE_ID before
# starting Jupyter, or replace the two lookups below with literal values.
aws_account_id = os.environ.get("AWS_ACCOUNT_ID", "")
execution_role_id = os.environ.get("SAGEMAKER_EXECUTION_ROLE_ID", "")
if not aws_account_id or not execution_role_id:
    # Fail here with an actionable message rather than later with an
    # undefined-name or malformed-ARN error.
    raise ValueError(
        "Set the AWS_ACCOUNT_ID and SAGEMAKER_EXECUTION_ROLE_ID environment "
        "variables (or edit this cell) to build the SageMaker execution role ARN."
    )

role = 'arn:aws:iam::{}:role/service-role/AmazonSageMaker-ExecutionRole-{}'.format(aws_account_id, execution_role_id)

In [None]:
# Show the bucket name reported by the session's default_bucket() call.
default_bucket_name = sagemaker_session.default_bucket()
print(default_bucket_name)

## Train
This notebook demonstrates a PyKEEN example [`pykeen-vanilla-run.py`] migrated to Amazon SageMaker. First, upload the data to Amazon S3. Then, create a [PyTorch estimator](https://sagemaker.readthedocs.io/en/stable/sagemaker.pytorch.html#pytorch-estimator). Training is started by calling the estimator's `fit` method (in parallel here). 

### Upload the PyKEEN config file to S3

In [None]:
# Local directory (relative to this notebook) that holds the input files.
# The trailing slash matters: the file name is appended by plain string
# concatenation when the config is opened.
input_subdir = "input/public-datasets/"

# Name of the training config file inside `input_subdir`.
run_config_fname = "20220903_train_configs_biokg.json"

# To run HyperParameter Optimization (HPO) instead, point at the HPO config:
# run_config_fname = "20220903_hpo_configs_biokg.json"

In [None]:
# Parse the selected config file and print it as a sanity check before upload.
config_path = input_subdir + run_config_fname
with open(config_path) as config_fp:
    train_config = json.load(config_fp)
print(train_config)

In [None]:
# Upload everything under `input_subdir` to S3 beneath the given key prefix,
# then print the value returned by `upload_data` (passed to `fit` later).
s3_key_prefix = 'data/pykeen-biokg'
input_data = sagemaker_session.upload_data(path=input_subdir, key_prefix=s3_key_prefix)
print(input_data)

### Initialize the PyTorch estimator and start the training job

In [None]:
# Configure the SageMaker PyTorch estimator that runs the PyKEEN entry point.
from sagemaker.pytorch.estimator import PyTorch

# Values handed to the estimator as `hyperparameters`; the config file name
# selected earlier is the only one currently in use.
estimator_hyperparameters = {
    # 'data-version': load_data_version,
    'config-fname': run_config_fname,
}

pytorch_estimator = PyTorch(
    entry_point='pykeen-vanilla-run.py',
    source_dir="src",
    role=role,
    framework_version='1.10.0',
    py_version='py38',
    instance_type='ml.g4dn.2xlarge',  # alternatively: 'ml.p3.2xlarge'
    instance_count=1,
    max_run=1200,  # 172800,
    hyperparameters=estimator_hyperparameters,
)

In [None]:
# Start the training job, using the uploaded S3 data as the job's input.
pytorch_estimator.fit(inputs=input_data)