## Training Knowledge Graph Embeddings with DGL and the MXNet Backend
The **SageMaker Python SDK** makes it easy to train DGL models. In this example, we generate knowledge graph embeddings using the [DMLC DGL API](https://github.com/dmlc/dgl.git) and the FB15k dataset.

For more details about knowledge graph embedding and this example, see https://github.com/dmlc/dgl/tree/master/apps/kg
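As rough intuition for what the `DistMult` model configured later in this notebook learns: every entity and every relation gets a vector, and a triple (head, relation, tail) is scored by the trilinear product, the sum over dimensions of `h * r * t`. The sketch below uses tiny random NumPy vectors purely for illustration; it is not DGL's implementation, and the entity names and dimension are hypothetical.

```python
import numpy as np


def distmult_score(head, relation, tail):
    """DistMult score: sum over the last axis of head * relation * tail."""
    return np.sum(head * relation * tail, axis=-1)


rng = np.random.default_rng(0)
dim = 8  # illustrative only; the training job in this notebook uses hidden_dim=2000

# Hypothetical (untrained) embeddings for three entities and one relation.
entities = {name: rng.normal(size=dim) for name in ["paris", "france", "tokyo"]}
capital_of = rng.normal(size=dim)

# Score a "true" triple and a corrupted one whose head entity was replaced.
pos = distmult_score(entities["paris"], capital_of, entities["france"])
neg = distmult_score(entities["tokyo"], capital_of, entities["france"])
print(f"score(paris, capital_of, france) = {pos:.3f}")
print(f"score(tokyo, capital_of, france) = {neg:.3f}")

# DistMult is symmetric: swapping head and tail leaves the score unchanged.
swapped = distmult_score(entities["france"], capital_of, entities["paris"])
assert np.isclose(pos, swapped)
```

With random vectors the two scores are arbitrary; training adjusts the embeddings so that observed triples score higher than corrupted ones, which is what the negative samples (`neg_sample_size`) are generated for.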


### Setup
First, define a few variables that are used later in the example.

```python
import sagemaker
from sagemaker import get_execution_role

# Set up a SageMaker session
sess = sagemaker.Session()

# S3 bucket for saving code and model artifacts.
# Feel free to specify a different bucket here if you wish.
bucket = sess.default_bucket()

# Location to put your custom code.
custom_code_upload_location = 'customcode'

# IAM execution role that gives SageMaker access to resources in your AWS account.
# We can use the SageMaker Python SDK to get the role from our notebook environment.
role = get_execution_role()
```

### Build the KGE-MXNet Docker image
AWS provides base Docker images, listed at https://docs.aws.amazon.com/dlami/latest/devguide/deep-learning-containers-images.html. DGL comes preinstalled in both the PyTorch 1.3 and MXNet 1.6 images. Because this example needs additional dependencies, we provide a Dockerfile that builds a new image with them. Build this KGE-specific image and push it to your Amazon ECR registry.

Note: KGE_mxnet.Dockerfile is region-specific; if you are working in a different AWS region, update it accordingly.

```sh
%%sh
# Build the KGE Docker image first
docker_name=sagemaker-dgl-kge-mxnet
account=$(aws sts get-caller-identity --query Account --output text)
echo $account
region=$(aws configure get region)

# Log in to the AWS Deep Learning Containers registry so the base image can be pulled
aws ecr get-login-password --region ${region} | docker login --username AWS --password-stdin 763104351884.dkr.ecr.${region}.amazonaws.com
docker build -t ${docker_name} -f KGE_mxnet.Dockerfile .

# Log in to your private ECR registry so the KGE image can be pushed to it
aws ecr get-login-password --region ${region} | docker login --username AWS --password-stdin ${account}.dkr.ecr.${region}.amazonaws.com

fullname="${account}.dkr.ecr.${region}.amazonaws.com/${docker_name}:latest"
# If the repository doesn't exist in ECR, create it.
aws ecr describe-repositories --repository-names "${docker_name}" > /dev/null 2>&1
if [ $? -ne 0 ]
then
    aws ecr create-repository --repository-name "${docker_name}" > /dev/null
fi

docker tag ${docker_name} ${fullname}
docker push ${fullname}
```
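The "create the repository if it doesn't exist" step in the shell cell above can also be written in Python with boto3. The sketch below is an optional alternative, not part of the original example: it uses the ECR client's `describe_repositories` and `create_repository` calls and returns the repository URI, which has the same `<account>.dkr.ecr.<region>.amazonaws.com/<name>` form as `fullname` without the `:latest` tag. The helper name `ensure_ecr_repository` is hypothetical, and building and pushing the image still require the Docker CLI.

```python
def ensure_ecr_repository(name, ecr_client=None):
    """Return the URI of the ECR repository `name`, creating it if it is missing."""
    if ecr_client is None:
        # Imported lazily so the helper can also be exercised with a stub client.
        import boto3
        ecr_client = boto3.client("ecr")
    try:
        response = ecr_client.describe_repositories(repositoryNames=[name])
        return response["repositories"][0]["repositoryUri"]
    except ecr_client.exceptions.RepositoryNotFoundException:
        # The repository does not exist yet, so create it.
        response = ecr_client.create_repository(repositoryName=name)
        return response["repository"]["repositoryUri"]


# Usage (needs AWS credentials and a configured default region):
# image_uri = ensure_ecr_repository("sagemaker-dgl-kge-mxnet")
# print(image_uri)  # <account>.dkr.ecr.<region>.amazonaws.com/sagemaker-dgl-kge-mxnet
```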

### SageMaker's Estimator class
The SageMaker Estimator lets us run single-machine training jobs in SageMaker on CPU- or GPU-based instances.

When we create the estimator, we pass in the filename of our training script and the name of our IAM execution role, along with a few other parameters. `train_instance_count` and `train_instance_type` determine the number and type of SageMaker instances used for the training job, and `image_name` points to the KGE image pushed to ECR above. The `hyperparameters` parameter is a dict of values passed to the training script as command-line arguments, which the script can parse with argparse.
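To make the argparse point concrete, here is a minimal sketch of how a training script could read these hyperparameters, each of which arrives as a `--name value` pair of strings. It is hypothetical: the argument names mirror the `params` dict in the next cell, but the actual `train.py` in the DGL example may define them differently. Because boolean values arrive as strings, a plain `type=bool` would treat `"False"` as true, so a small converter is used. The `--save_path` default assumes the container sets `SM_MODEL_DIR` (usually `/opt/ml/model`), the directory SageMaker packages into the S3 model artifact.

```python
import argparse
import os


def str2bool(value):
    """Interpret common string spellings of a boolean hyperparameter."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser():
    # Hypothetical subset of the arguments a KGE train.py might accept.
    parser = argparse.ArgumentParser(description="KGE training (sketch)")
    parser.add_argument("--dataset", type=str, default="FB15k")
    parser.add_argument("--model", type=str, default="DistMult")
    parser.add_argument("--batch_size", type=int, default=1024)
    parser.add_argument("--neg_sample_size", type=int, default=256)
    parser.add_argument("--hidden_dim", type=int, default=2000)
    parser.add_argument("--gamma", type=float, default=500.0)
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument("--max_step", type=int, default=100000)
    parser.add_argument("--valid", type=str2bool, default=False)
    parser.add_argument("--test", type=str2bool, default=False)
    # Assumed output location: files written here end up in the model artifact.
    parser.add_argument("--save_path", type=str,
                        default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
    return parser


# Roughly the arguments the script would see for part of the params dict.
argv = ["--dataset", "FB15k", "--batch_size", "1024", "--lr", "0.1", "--valid", "True"]
args = build_parser().parse_args(argv)
print(args.dataset, args.batch_size, args.lr, args.valid)  # prints: FB15k 1024 0.1 True
```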

```python
from sagemaker.mxnet.estimator import MXNet

ENTRY_POINT = 'train.py'
CODE_PATH = './'

# Address of the KGE image pushed to your private ECR repository
account = sess.boto_session.client('sts').get_caller_identity()['Account']
region = sess.boto_session.region_name
docker_name = "sagemaker-dgl-kge-mxnet"
image = '{}.dkr.ecr.{}.amazonaws.com/{}:latest'.format(account, region, docker_name)

# Hyperparameters forwarded to train.py as command-line arguments
params = {
    'dataset': 'FB15k',
    'model': 'DistMult',
    'batch_size': 1024,
    'neg_sample_size': 256,
    'hidden_dim': 2000,
    'gamma': 500.0,
    'lr': 0.1,
    'max_step': 100000,
    'batch_size_eval': 16,
    'valid': True,
    'test': True,
    'neg_adversarial_sampling': True,
}

estimator = MXNet(entry_point=ENTRY_POINT,
                  source_dir=CODE_PATH,
                  role=role,
                  train_instance_count=1,
                  train_instance_type='ml.p3.2xlarge',
                  image_name=image,
                  hyperparameters=params,
                  sagemaker_session=sess)
```

### Running the Training Job
After constructing the Estimator object, we call `fit` to start the training job; the dataset is downloaded automatically during training. No data channels are passed to `fit`: the only input the job receives is the training code itself.

```python
estimator.fit(logs=True, wait=True)
```

### Output
When the job finishes, you can find the resulting embeddings in the SageMaker console: open the training job and look for the 'S3 model artifact' location.
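If you prefer to fetch the embeddings programmatically, the SageMaker Python SDK exposes the same S3 location as `estimator.model_data` once `fit` has completed. The sketch below splits that URI into bucket and key, and loads any NumPy `.npy` arrays found in the downloaded `model.tar.gz`. It assumes the training script saves the entity and relation embeddings as `.npy` files; their exact file names are not stated here, so the helper simply loads every `.npy` member. The helper names `split_s3_uri` and `load_npy_from_tarball` are hypothetical.

```python
import io
import tarfile
from urllib.parse import urlparse

import numpy as np


def split_s3_uri(uri):
    """Split 's3://bucket/some/key' into ('bucket', 'some/key')."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"not an S3 URI: {uri!r}")
    return parsed.netloc, parsed.path.lstrip("/")


def load_npy_from_tarball(tar_path):
    """Load every .npy file inside a (possibly compressed) tarball, keyed by archive path."""
    arrays = {}
    with tarfile.open(tar_path, "r:*") as tar:
        for member in tar.getmembers():
            if member.isfile() and member.name.endswith(".npy"):
                data = tar.extractfile(member).read()
                arrays[member.name] = np.load(io.BytesIO(data))
    return arrays


# Usage once estimator.fit(...) has finished (needs boto3 and AWS credentials):
# import boto3
# bucket, key = split_s3_uri(estimator.model_data)
# boto3.client("s3").download_file(bucket, key, "model.tar.gz")
# for name, array in load_npy_from_tarball("model.tar.gz").items():
#     print(name, array.shape)
```

Reading members straight from the archive avoids extracting files to disk, which keeps the helper free of path-handling concerns.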