In [None]:
import logging
import os
import sys

import numpy as np
import pandas as pd
import sagemaker
from sagemaker.local import LocalSession
from sagemaker.pytorch import PyTorch
from sagemaker.pytorch.model import PyTorchModel

In [None]:
# Execution mode switch: True -> SageMaker local mode (LocalSession, "local"
# instances); False -> the managed SageMaker service.
LOCAL: bool = True

In [None]:
# Configure root logging: INFO and above, timestamped, written to stdout so it
# shows up in the notebook output.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

In [None]:
# Build the SageMaker session, bucket, IAM role, region and input-channel paths
# for either local mode or the managed service, depending on LOCAL.
if LOCAL:
    session = LocalSession()
    # Use the local ./code directory directly instead of packaging it to S3.
    session.config = {"local": {"local_code": True}}
    bucket = "."
    # Placeholder ARN; presumably only needs to be well-formed in local mode — TODO confirm.
    role = "arn:aws:iam::111111111111:role/service-role/AmazonSageMaker-ExecutionRole-20200101T000001"
    region = "local"
    train_input_path = "file://./data/train"
    val_input_path = "file://./data/val"
    test_input_path = "file://./data/test"
else:
    # Bucket used for uploading data, models and logs. Set to None to fall back
    # to the session's default bucket, which SageMaker creates if it does not exist.
    bucket = "quantsagemaker"
    # A single session, pinned to the chosen bucket (None -> SDK default).
    session = sagemaker.Session(default_bucket=bucket)
    if bucket is None:
        bucket = session.default_bucket()

    role = sagemaker.get_execution_role()
    region = session.boto_region_name
    # NOTE(review): assumes the datasets live under s3://<bucket>/data/, mirroring
    # the local ./data layout — TODO confirm the real S3 prefixes.
    train_input_path = f"s3://{bucket}/data/train"
    val_input_path = f"s3://{bucket}/data/val"
    test_input_path = f"s3://{bucket}/data/test"

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {bucket}")
print(f"sagemaker session region: {region}")

## Training

In [None]:
# Fine-tune with ./code/train.py. Passing `sagemaker_session=session` makes the
# session configured above (including local_code in local mode) actually apply.
hyperparameters = {
    "epochs": 1,
    "train_batch_size": 32,
    "model_name": "distilbert-base-uncased",
}

# Follow the LOCAL flag rather than hard-coding local mode.
# NOTE(review): the hosted instance type is a placeholder — size it for the job.
train_instance_type = "local" if LOCAL else "ml.p3.2xlarge"

estimator = PyTorch(
    entry_point="train.py",
    source_dir="./code",
    role=role,
    framework_version="1.7.1",
    py_version="py3",
    instance_count=1,
    instance_type=train_instance_type,
    hyperparameters=hyperparameters,
    sagemaker_session=session,
)
# Channel names must match what train.py reads (presumably via SM_CHANNEL_TRAINING,
# SM_CHANNEL_VALIDATING, SM_CHANNEL_TESTING — TODO confirm against train.py).
estimator.fit({"training": train_input_path, "validating": val_input_path, "testing": test_input_path})

## Inference

In [None]:
# URI of the model artifact written by estimator.fit(); used to build the
# PyTorchModel for inference. The bare last expression is displayed by Jupyter.
model_data = estimator.model_data
model_data

In [None]:
# Wrap the trained artifact with the inference script and deploy it.
# `sagemaker_session=session` reuses the session configured earlier (local_code
# in local mode), and the instance type follows the LOCAL flag.
# NOTE(review): the hosted instance type is a placeholder — size it for the model.
inference_instance_type = "local" if LOCAL else "ml.m5.xlarge"

pytorch_model = PyTorchModel(
    model_data=model_data,
    role=role,
    framework_version="1.7.1",
    py_version="py3",
    source_dir="./code",
    entry_point="inference.py",
    sagemaker_session=session,
)

predictor = pytorch_model.deploy(initial_instance_count=1, instance_type=inference_instance_type)

In [None]:
# Encode request payloads as JSON and decode responses into a str.
json_serializer = sagemaker.serializers.JSONSerializer()
string_deserializer = sagemaker.deserializers.StringDeserializer()
predictor.serializer, predictor.deserializer = json_serializer, string_deserializer

In [None]:
# The model is already deployed above via `pytorch_model.deploy(...)` with
# inference.py and a JSON serializer/string deserializer. Calling
# `estimator.deploy(...)` here would build a second model from the estimator's
# entry point (train.py), start another endpoint, and rebind `predictor`,
# discarding that serializer setup. Reuse the existing predictor and show which
# endpoint it targets.
predictor.endpoint_name

In [None]:
# Score one example review, then tear down the endpoint so it does not keep
# running (a local serving container, or a billed hosted instance when LOCAL
# is False). The finally-clause cleans up even if the request fails.
sample_reviews = ["this is a very good movie"]
try:
    prediction = predictor.predict(sample_reviews)
finally:
    predictor.delete_endpoint()
prediction