In [None]:
import sagemaker
import json
import boto3

In [None]:
sess = sagemaker.Session()

# The role that SageMaker would use to leverage AWS resources (S3, CloudWatch) on your behalf
role = sagemaker.get_execution_role()
print(role)

# Replace with your own bucket name if needed
bucket = sess.default_bucket()
print(bucket)

# Replace with the prefix under which you want to store the data if needed
prefix = "blazingtext/supervised"

In [None]:
! aws s3 cp s3://aws-mls-c01/sagemaker/blazingtext/dbpedia.train ./data/
! aws s3 cp s3://aws-mls-c01/sagemaker/blazingtext/dbpedia.validation ./data/

In [None]:
# Region configured for the default boto3 session; used to look up the algorithm image
boto_session = boto3.Session()
region_name = boto_session.region_name

In [None]:
# Resolve the BlazingText container image URI for this region
container = sagemaker.image_uris.retrieve(
    framework="blazingtext",
    region=region_name,
    version="latest",
)

In [None]:
train_channel = f"{prefix}/train"
validation_channel = f"{prefix}/validation"

# Upload each local data file under its own channel prefix in the bucket
for local_path, channel in (
    ("./data/dbpedia.train", train_channel),
    ("./data/dbpedia.validation", validation_channel),
):
    sess.upload_data(path=local_path, bucket=bucket, key_prefix=channel)

# S3 prefixes (not individual objects) that the training channels point at
s3_train_data = f"s3://{bucket}/{train_channel}"
s3_validation_data = f"s3://{bucket}/{validation_channel}"

In [None]:
# S3 location passed to the estimator as its output path
s3_output_location = f"s3://{bucket}/{prefix}/output"

In [None]:
# Hyperparameters handed to the BlazingText container (supervised = text classification mode).
# NOTE(review): "epochs" (1) is lower than "min_epochs" (5), so the early-stopping
# settings presumably never take effect — confirm this is intended.
bt_hyperparameters = {
    "mode": "supervised",
    "epochs": 1,
    "min_count": 2,
    "learning_rate": 0.05,
    "vector_dim": 10,
    "early_stopping": True,
    "patience": 4,
    "min_epochs": 5,
    "word_ngrams": 2,
}

# Estimator describing the training job: image, role, hardware, and where to write output
bt_model = sagemaker.estimator.Estimator(
    image_uri=container,
    role=role,
    instance_count=1,
    instance_type="ml.c4.4xlarge",
    volume_size=30,
    max_run=360000,
    input_mode="File",
    output_path=s3_output_location,
    hyperparameters=bt_hyperparameters,
)

In [None]:
# Both channels use identical input settings; only the S3 location differs
channel_settings = {
    "distribution": "FullyReplicated",
    "content_type": "text/plain",
    "s3_data_type": "S3Prefix",
}

train_data = sagemaker.inputs.TrainingInput(s3_train_data, **channel_settings)
validation_data = sagemaker.inputs.TrainingInput(s3_validation_data, **channel_settings)

# Channel name -> input definition, as passed to fit()
data_channels = dict(train=train_data, validation=validation_data)

In [None]:
%%time
bt_model.fit(inputs=data_channels, logs=True)

In [None]:
%%time

from sagemaker.serializers import JSONSerializer

text_classifier = bt_model.deploy(
    initial_instance_count=1, instance_type="ml.m5.xlarge", serializer=JSONSerializer()
)

In [None]:
import nltk

# Tokenizer data needed by nltk.word_tokenize in the prediction cell.
# Older NLTK releases load the pickled "punkt" model, while NLTK >= 3.8.2 loads
# "punkt_tab" instead and raises LookupError if only "punkt" is present.
# Fetching both keeps the notebook working regardless of the installed version.
for tokenizer_resource in ("punkt", "punkt_tab"):
    nltk.download(tokenizer_resource)

In [None]:
# Sample inputs to classify
sentences = [
    "Convair was an american aircraft manufacturing company which later expanded into rockets and spacecraft.",
    "Berwick secondary college is situated in the outer melbourne metropolitan suburb of berwick .",
    "Sparky is a dog of the canis familiaris family. He was lives in our house. Which is in Australia.",
]

# Tokenize with nltk and re-join the tokens with single spaces. Per the original note this
# mirrors the tokenizer used during data preparation (that step is not shown in this notebook).
tokenized_sentences = []
for sentence in sentences:
    tokens = nltk.word_tokenize(sentence)
    tokenized_sentences.append(" ".join(tokens))

# Request body: all tokenized sentences under the "instances" key
payload = dict(instances=tokenized_sentences)

response = text_classifier.predict(payload)

# Parse the JSON response and pretty-print it
predictions = json.loads(response)
print(json.dumps(predictions, indent=2))

In [None]:
# Tear down the endpoint (and its endpoint configuration, which is the SDK default)
# once predictions are done, so it does not keep running
text_classifier.delete_endpoint(delete_endpoint_config=True)