Let’s start by specifying:

The S3 bucket and prefix that you want to use for training and model data. The bucket should be in the same region as the notebook instance, training, and hosting. If you don’t specify a bucket, the SageMaker SDK creates a default bucket in the same region, following a predefined naming convention.

The IAM role ARN used to give SageMaker access to your data. It can be fetched with the get_execution_role method from the SageMaker Python SDK.
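
As a rough illustration of the naming convention, the default bucket printed by the cell below has the form sagemaker-<region>-<account id>. The sketch here rebuilds that name from a role ARN. Both helper functions are hypothetical and are not part of the SageMaker SDK; in practice sess.default_bucket() remains the call to use.

```python
def account_id_from_role_arn(role_arn):
    """Extract the account ID (fifth colon-separated field) from an IAM role ARN."""
    parts = role_arn.split(":")
    if len(parts) < 6 or parts[0] != "arn" or parts[2] != "iam":
        raise ValueError("not an IAM ARN: {!r}".format(role_arn))
    return parts[4]


def default_bucket_name(region, account_id):
    """Build the bucket name the SDK is assumed to pick when none is given."""
    return "sagemaker-{}-{}".format(region, account_id)


# Role ARN and region as they appear in this notebook's output.
role_arn = "arn:aws:iam::674581170738:role/service-role/AmazonSageMaker-ExecutionRole-20210510T182202"
account_id = account_id_from_role_arn(role_arn)
bucket_name = default_bucket_name("us-east-2", account_id)
print(bucket_name)
```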

In [1]:
import sagemaker
from sagemaker import get_execution_role
import json
import boto3

sess = sagemaker.Session()

role = get_execution_role()
print(role) # This is the role that SageMaker would use to leverage AWS resources (S3, CloudWatch) on your behalf

bucket = sess.default_bucket() # Replace with your own bucket name if needed
print(bucket)
prefix = 'blazingtext/supervised' #Replace with the prefix under which you want to store the data if needed

Couldn't call 'get_role' to get Role ARN from role name AmazonSageMaker-ExecutionRole-20210510T182202 to get Role path.
Assuming role was created in SageMaker AWS console, as the name contains `AmazonSageMaker-ExecutionRole`. Defaulting to Role ARN with service-role in path. If this Role ARN is incorrect, please add IAM read permissions to your role or supply the Role Arn directly.


arn:aws:iam::674581170738:role/service-role/AmazonSageMaker-ExecutionRole-20210510T182202
sagemaker-us-east-2-674581170738


Data Preparation

Now we’ll download a dataset from the web on which we want to train the text classification model. BlazingText expects a single preprocessed text file with space-separated tokens. Each line of the file should contain a single sentence and the corresponding label(s), each prefixed by “__label__”.
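
To make the expected line format concrete, here is a minimal sketch that builds such lines from a list of labels and a list of tokens. The helper name to_blazingtext_line and the multi-label sentence are made up for illustration; only the __label__ prefix and the space-separated layout come from the description above.

```python
def to_blazingtext_line(labels, tokens):
    """Join one or more prefixed labels and the sentence tokens into one line."""
    prefixed = ["__label__" + label for label in labels]
    return " ".join(prefixed + list(tokens))


# One label per sentence, as in the DBpedia data used in this notebook.
single = to_blazingtext_line(
    ["Company"], ["schwan-stabilo", "is", "a", "german", "maker", "of", "pens"]
)
# A hypothetical sentence carrying two labels.
multi = to_blazingtext_line(
    ["Film", "WrittenWork"], ["a", "hypothetical", "multi-label", "example"]
)
print(single)
print(multi)
```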

In this example, let us train the text classification model on the DBPedia Ontology Dataset, as done by Zhang et al. The DBpedia ontology dataset is constructed by picking 14 non-overlapping classes from DBpedia 2014. It has 560,000 training samples and 70,000 testing samples. The fields we use from this dataset are the title and abstract of each Wikipedia article.
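
As a quick check on those numbers, and assuming every class contributes the same number of samples, the sizes divide evenly over the 14 classes:

```python
NUM_CLASSES = 14
TRAIN_SAMPLES = 560_000
TEST_SAMPLES = 70_000

# divmod returns (samples per class, leftover); a leftover of 0 means an even split.
train_per_class, train_rest = divmod(TRAIN_SAMPLES, NUM_CLASSES)
test_per_class, test_rest = divmod(TEST_SAMPLES, NUM_CLASSES)
print(train_per_class, train_rest)
print(test_per_class, test_rest)
```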

In [2]:
!wget https://github.com/saurabh3949/Text-Classification-Datasets/raw/master/dbpedia_csv.tar.gz

--2021-05-10 23:44:53--  https://github.com/saurabh3949/Text-Classification-Datasets/raw/master/dbpedia_csv.tar.gz
Resolving github.com (github.com)... 140.82.114.4
Connecting to github.com (github.com)|140.82.114.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/saurabh3949/Text-Classification-Datasets/master/dbpedia_csv.tar.gz [following]
--2021-05-10 23:44:53--  https://raw.githubusercontent.com/saurabh3949/Text-Classification-Datasets/master/dbpedia_csv.tar.gz
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.110.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 68431223 (65M) [application/octet-stream]
Saving to: ‘dbpedia_csv.tar.gz’


2021-05-10 23:44:55 (67.9 MB/s) - ‘dbpedia_csv.tar.gz’ saved [68431223/68431223]



In [3]:
!tar -xzvf dbpedia_csv.tar.gz

dbpedia_csv/
dbpedia_csv/test.csv
tar: dbpedia_csv/test.csv: Cannot change ownership to uid 1000, gid 1000: Operation not permitted
dbpedia_csv/classes.txt
tar: dbpedia_csv/classes.txt: Cannot change ownership to uid 1000, gid 1000: Operation not permitted
dbpedia_csv/train.csv
tar: dbpedia_csv/train.csv: Cannot change ownership to uid 1000, gid 1000: Operation not permitted
dbpedia_csv/readme.txt
tar: dbpedia_csv/readme.txt: Cannot change ownership to uid 1000, gid 1000: Operation not permitted
tar: dbpedia_csv: Cannot change ownership to uid 1000, gid 1000: Operation not permitted
tar: Exiting with failure status due to previous errors


In [4]:
!head dbpedia_csv/train.csv -n 3

1,"E. D. Abbott Ltd"," Abbott of Farnham E D Abbott Limited was a British coachbuilding business based in Farnham Surrey trading under that name from 1929. A major part of their output was under sub-contract to motor vehicle manufacturers. Their business closed in 1972."
1,"Schwan-Stabilo"," Schwan-STABILO is a German maker of pens for writing colouring and cosmetics as well as markers and highlighters for office use. It is the world's largest manufacturer of highlighter pens Stabilo Boss."
1,"Q-workshop"," Q-workshop is a Polish company located in Poznań that specializes in designand production of polyhedral dice and dice accessories for use in various games (role-playing gamesboard games and tabletop wargames). They also run an online retail store and maintainan active forum community.Q-workshop was established in 2001 by Patryk Strzelewicz – a student from Poznań. Initiallythe company sold its products via online auction services but in 2005 a website and online store wereestablishe

In [5]:
!cat dbpedia_csv/classes.txt

Company
EducationalInstitution
Artist
Athlete
OfficeHolder
MeanOfTransportation
Building
NaturalPlace
Village
Animal
Plant
Album
Film
WrittenWork


In [6]:
index_to_label = {}
with open("dbpedia_csv/classes.txt") as f:
    for i,label in enumerate(f.readlines()):
        index_to_label[str(i+1)] = label.strip()
print(index_to_label)

{'1': 'Company', '2': 'EducationalInstitution', '3': 'Artist', '4': 'Athlete', '5': 'OfficeHolder', '6': 'MeanOfTransportation', '7': 'Building', '8': 'NaturalPlace', '9': 'Village', '10': 'Animal', '11': 'Plant', '12': 'Album', '13': 'Film', '14': 'WrittenWork'}


Data Preprocessing

We need to preprocess the training data into a space-separated, tokenized text format that the BlazingText algorithm can consume. Also, as mentioned previously, each class label should be prefixed with __label__ and placed on the same line as the original sentence. We’ll use the nltk library to tokenize the input sentences from the DBPedia dataset.

Download the nltk tokenizer and other libraries

In [7]:
from random import shuffle
import multiprocessing
from multiprocessing import Pool
import csv
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [8]:
def transform_instance(row):
    cur_row = []
    label = "__label__" + index_to_label[row[0]]  #Prefix the index-ed label with __label__
    cur_row.append(label)
    cur_row.extend(nltk.word_tokenize(row[1].lower()))
    cur_row.extend(nltk.word_tokenize(row[2].lower()))
    return cur_row

In [9]:
#The transform_instance will be applied to each data instance in parallel using python’s multiprocessing module

def preprocess(input_file, output_file, keep=1):
    all_rows = []
    with open(input_file, 'r') as csvinfile:
        csv_reader = csv.reader(csvinfile, delimiter=',')
        for row in csv_reader:
            all_rows.append(row)
    shuffle(all_rows)
    all_rows = all_rows[:int(keep*len(all_rows))]
    pool = Pool(processes=multiprocessing.cpu_count())
    transformed_rows = pool.map(transform_instance, all_rows)
    pool.close()
    pool.join()

    with open(output_file, 'w') as csvoutfile:
        csv_writer = csv.writer(csvoutfile, delimiter=' ', lineterminator='\n')
        csv_writer.writerows(transformed_rows)

In [10]:
%%time

# Preparing the training dataset

# Since preprocessing the whole dataset might take a couple of minutes,
# we keep 20% of the training dataset for this demo.
# Set keep to 1 if you want to use the complete dataset
preprocess('dbpedia_csv/train.csv', 'dbpedia.train', keep=.2)

# Preparing the validation dataset
preprocess('dbpedia_csv/test.csv', 'dbpedia.validation')

CPU times: user 7.77 s, sys: 1.12 s, total: 8.89 s
Wall time: 1min 20s
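
Before uploading, it can be worth checking that every line of a preprocessed file really starts with a label. The sketch below is one way to do that with the standard library only. The function summarize_labels is a hypothetical helper, and the demo runs on a small temporary file with made-up lines rather than on dbpedia.train.

```python
import os
import tempfile
from collections import Counter

LABEL_PREFIX = "__label__"


def summarize_labels(path):
    """Count lines per leading label and collect line numbers that have no label."""
    counts = Counter()
    bad_lines = []
    with open(path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            labels = []
            for token in line.split():
                if not token.startswith(LABEL_PREFIX):
                    break  # labels are expected at the start of the line only
                labels.append(token[len(LABEL_PREFIX):])
            if labels:
                counts.update(labels)
            else:
                bad_lines.append(line_no)
    return counts, bad_lines


# Made-up sample data; the third line is deliberately missing its label.
sample = (
    "__label__Company schwan-stabilo is a german maker of pens\n"
    "__label__Film a hypothetical film abstract\n"
    "missing label on this line\n"
    "__label__Company q-workshop is a polish company\n"
)
with tempfile.NamedTemporaryFile("w", suffix=".train", delete=False, encoding="utf-8") as tmp:
    tmp.write(sample)
    sample_path = tmp.name

counts, bad_lines = summarize_labels(sample_path)
os.remove(sample_path)
print(counts.most_common())
print(bad_lines)
```

Pointing summarize_labels at 'dbpedia.train' or 'dbpedia.validation' instead of the temporary file should report an empty list of bad lines if preprocessing worked as intended.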


In [11]:
%%time

train_channel = prefix + '/train'
validation_channel = prefix + '/validation'

sess.upload_data(path='dbpedia.train', bucket=bucket, key_prefix=train_channel)
sess.upload_data(path='dbpedia.validation', bucket=bucket, key_prefix=validation_channel)

s3_train_data = 's3://{}/{}'.format(bucket, train_channel)
s3_validation_data = 's3://{}/{}'.format(bucket, validation_channel)

CPU times: user 473 ms, sys: 273 ms, total: 745 ms
Wall time: 1.27 s


In [12]:
s3_output_location = 's3://{}/{}/example_output'.format(bucket, prefix)

Training

Now that we are done with all the setup that is needed, we are ready to train our text classifier. To begin, let us create a sagemaker.estimator.Estimator object. This estimator will launch the training job.

In [13]:
region_name = boto3.Session().region_name

In [28]:
container = sagemaker.image_uris.retrieve("blazingtext",region_name, "latest")
print('Using SageMaker BlazingText container: {} ({})'.format(container, region_name))

Defaulting to the only supported framework/algorithm version: 1. Ignoring framework/algorithm version: latest.


Using SageMaker BlazingText container: 825641698319.dkr.ecr.us-east-2.amazonaws.com/blazingtext:1 (us-east-2)


In [16]:
bt_model = sagemaker.estimator.Estimator(container,
                                         role,
                                         instance_count=1,
                                         instance_type='ml.c5.4xlarge',
                                         volume_size = 30,
                                         max_run = 360000,
                                         input_mode= 'File',
                                         output_path=s3_output_location,
                                         hyperparameters = {
                                           "mode":"supervised",
                                           "epochs":1,
                                            "min_count":2,
                                            "learning_rate": 0.05,
                                            "vector_dim":10,
                                            "early_stopping":True,
                                            "patience":4,
                                            "min_epochs":5,
                                            "word_ngrams":2
                                         })
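
The hyperparameters above combine epochs=1 with early_stopping=True, patience=4 and min_epochs=5. Assuming min_epochs is the number of epochs that must run before early stopping is considered at all, a one-epoch budget would never reach that point. The check below sketches this reasoning; early_stopping_can_trigger is a hypothetical helper, not part of the SageMaker SDK, and the comparison it uses is an assumption about the algorithm's behaviour.

```python
def early_stopping_can_trigger(hyperparameters):
    """Return True if the epoch budget leaves room for early stopping to apply.

    Assumption: early stopping is only considered after min_epochs epochs have run.
    """
    if not hyperparameters.get("early_stopping", False):
        return False
    return hyperparameters["epochs"] > hyperparameters["min_epochs"]


# Values from the estimator above, and a variant with a larger epoch budget.
notebook_hp = {"epochs": 1, "early_stopping": True, "patience": 4, "min_epochs": 5}
longer_hp = dict(notebook_hp, epochs=10)

print(early_stopping_can_trigger(notebook_hp))
print(early_stopping_can_trigger(longer_hp))
```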

In [17]:
train_data = sagemaker.inputs.TrainingInput(s3_train_data, distribution='FullyReplicated',
                        content_type='text/plain', s3_data_type='S3Prefix')
validation_data = sagemaker.inputs.TrainingInput(s3_validation_data, distribution='FullyReplicated',
                             content_type='text/plain', s3_data_type='S3Prefix')
data_channels = {'train': train_data, 'validation': validation_data}

In [20]:
bt_model.fit(inputs=data_channels, logs=True)

2021-05-10 23:53:48 Starting - Starting the training job...
2021-05-10 23:54:11 Starting - Launching requested ML instancesProfilerReport-1620690827: InProgress
......
2021-05-10 23:55:12 Starting - Preparing the instances for training......
2021-05-10 23:56:13 Downloading - Downloading input data
2021-05-10 23:56:13 Training - Downloading the training image..
Arguments: train
[05/10/2021 23:56:28 INFO 139828887017088] nvidia-smi took: 0.025152921676635742 secs to identify 0 gpus
[05/10/2021 23:56:28 INFO 139828887017088] Running single machine CPU BlazingText training using supervised mode.
Number of CPU sockets found in instance is  1
[05/10/2021 23:56:28 INFO 139828887017088] Processing /opt/ml/input/data/train/dbpedia.train . File size: 35.03332328796387 MB
[05/10/2021 23:56:28 INFO 139828887017088] Processing /opt/ml/input/data/validation/dbpedia.validation . File size: 21.88746166229248 MB
Read 6M words
Number of 

Hosting / Inference

Once the training is done, we can deploy the trained model as an Amazon SageMaker real-time hosted endpoint. This will allow us to make predictions (or inference) from the model. Note that we don’t have to host on the same type of instance that we used to train. Because endpoints stay up and running for a long time, it’s advisable to choose a cheaper instance type for inference.

In [25]:
from sagemaker.serializers import JSONSerializer

text_classifier = bt_model.deploy(
    initial_instance_count = 1,
    instance_type = 'ml.m4.xlarge',
    serializer=JSONSerializer()
)

-------------!

In [26]:
sentences = ["Convair was an american aircraft manufacturing company which later expanded into rockets and spacecraft.",
            "Berwick secondary college is situated in the outer melbourne metropolitan suburb of berwick ."]

# using the same nltk tokenizer that we used during data preparation for training
tokenized_sentences = [' '.join(nltk.word_tokenize(sent)) for sent in sentences]

payload = {"instances" : tokenized_sentences}

response = text_classifier.predict(payload)

predictions = json.loads(response)
print(json.dumps(predictions, indent=2))

[
  {
    "label": [
      "__label__Artist"
    ],
    "prob": [
      0.4133935272693634
    ]
  },
  {
    "label": [
      "__label__EducationalInstitution"
    ],
    "prob": [
      0.42267635464668274
    ]
  }
]


In [27]:
payload = {"instances" : tokenized_sentences,
          "configuration": {"k": 2}}

response = text_classifier.predict(payload)

predictions = json.loads(response)
print(json.dumps(predictions, indent=2))

[
  {
    "label": [
      "__label__Artist",
      "__label__MeanOfTransportation"
    ],
    "prob": [
      0.4133935272693634,
      0.19562363624572754
    ]
  },
  {
    "label": [
      "__label__EducationalInstitution",
      "__label__Building"
    ],
    "prob": [
      0.42267635464668274,
      0.1869693249464035
    ]
  }
]
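
The responses above are lists with one entry per input sentence, each holding parallel "label" and "prob" lists. A small helper can turn that into (label, probability) pairs with the __label__ prefix removed. This is a sketch based only on the output shown in this notebook; top_predictions is a hypothetical name, and the sample response is copied from the k=2 output above rather than fetched from a live endpoint.

```python
import json

LABEL_PREFIX = "__label__"


def top_predictions(response_body):
    """Convert the endpoint's JSON response into [(label, prob), ...] per sentence."""
    if isinstance(response_body, bytes):
        response_body = response_body.decode("utf-8")
    results = []
    for item in json.loads(response_body):
        pairs = []
        for label, prob in zip(item["label"], item["prob"]):
            if label.startswith(LABEL_PREFIX):
                label = label[len(LABEL_PREFIX):]
            pairs.append((label, prob))
        results.append(pairs)
    return results


# Sample body copied from the k=2 response printed above.
sample_response = json.dumps([
    {"label": ["__label__Artist", "__label__MeanOfTransportation"],
     "prob": [0.4133935272693634, 0.19562363624572754]},
    {"label": ["__label__EducationalInstitution", "__label__Building"],
     "prob": [0.42267635464668274, 0.1869693249464035]},
]).encode("utf-8")

results = top_predictions(sample_response)
for pairs in results:
    print(pairs)
```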


In [29]:
# In sagemaker>=2 the predictor attribute is endpoint_name (endpoint is deprecated).
# See: https://sagemaker.readthedocs.io/en/stable/v2.html for details.
sess.delete_endpoint(text_classifier.endpoint_name)
