## Introduction

Text classification can be used to solve use cases such as sentiment analysis, spam detection, and hashtag prediction. This notebook demonstrates how to use Amazon Comprehend custom classification to classify text.

## Setup

Let's start by specifying:

- The S3 bucket and prefix that you want to use for training and model data. This should be within the same region as the notebook instance, training, and hosting. If you don't specify a bucket, the SageMaker SDK will create a default bucket following a pre-defined naming convention in the same region.
- The IAM role ARN used to give SageMaker access to your data. It can be fetched using the **get_execution_role** method from the SageMaker Python SDK. **Note: This role should have AmazonComprehendFullAccess, so it can create and run custom classification jobs.**

In [1]:
import sagemaker
from sagemaker import get_execution_role
import json
import boto3
import time
import pytz
from datetime import datetime

sess = sagemaker.Session()
role = get_execution_role()

print(role) # This is the role that SageMaker would use to leverage AWS resources (S3, CloudWatch, Comprehend) on your behalf
bucket='comprehend-demolm' # customize to your bucket
prefix = 'dbpedia/' #Replace with the prefix under which you want to store the data if needed
region = 'us-east-1'

arn:aws:iam::625941629713:role/service-role/AmazonSageMaker-ExecutionRole-20190415T173068


### Data Preparation

Now we'll download a dataset from the web on which we want to train the text classification model. For multi-class training, Amazon Comprehend expects a CSV file in which each line holds one document: the class label in the first column and the document text in the second.

In this example, let us train the text classification model on the [DBPedia Ontology Dataset](https://wiki.dbpedia.org/services-resources/dbpedia-data-set-2014#2) as done by [Zhang et al](https://arxiv.org/pdf/1509.01626.pdf). The DBpedia ontology dataset is constructed by picking 14 nonoverlapping classes from DBpedia 2014. It has 560,000 training samples and 70,000 testing samples. The fields we used for this dataset contain title and abstract of each Wikipedia article. 

In [2]:
!wget https://github.com/saurabh3949/Text-Classification-Datasets/raw/master/dbpedia_csv.tar.gz

--2019-07-20 18:28:53--  https://github.com/saurabh3949/Text-Classification-Datasets/raw/master/dbpedia_csv.tar.gz
Resolving github.com (github.com)... 192.30.253.113
Connecting to github.com (github.com)|192.30.253.113|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/saurabh3949/Text-Classification-Datasets/master/dbpedia_csv.tar.gz [following]
--2019-07-20 18:28:53--  https://raw.githubusercontent.com/saurabh3949/Text-Classification-Datasets/master/dbpedia_csv.tar.gz
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.200.133
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.200.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 68431223 (65M) [application/octet-stream]
Saving to: ‘dbpedia_csv.tar.gz’


2019-07-20 18:28:55 (87.8 MB/s) - ‘dbpedia_csv.tar.gz’ saved [68431223/68431223]



In [3]:
!tar -xzvf dbpedia_csv.tar.gz

dbpedia_csv/
dbpedia_csv/test.csv
dbpedia_csv/classes.txt
dbpedia_csv/train.csv
dbpedia_csv/readme.txt


Let us inspect the dataset and the classes to understand how the data and the labels are provided.

In [4]:
!head dbpedia_csv/train.csv -n 3

1,"E. D. Abbott Ltd"," Abbott of Farnham E D Abbott Limited was a British coachbuilding business based in Farnham Surrey trading under that name from 1929. A major part of their output was under sub-contract to motor vehicle manufacturers. Their business closed in 1972."
1,"Schwan-Stabilo"," Schwan-STABILO is a German maker of pens for writing colouring and cosmetics as well as markers and highlighters for office use. It is the world's largest manufacturer of highlighter pens Stabilo Boss."
1,"Q-workshop"," Q-workshop is a Polish company located in Poznań that specializes in designand production of polyhedral dice and dice accessories for use in various games (role-playing gamesboard games and tabletop wargames). They also run an online retail store and maintainan active forum community.Q-workshop was established in 2001 by Patryk Strzelewicz – a student from Poznań. Initiallythe company sold its products via online auction services but in 2005 a website and online store wereestablis

As can be seen from the above output, the CSV has 3 fields: label index, title, and abstract. Let us first create a label index to label name mapping and then proceed to preprocess the dataset for ingestion by Amazon Comprehend.

Next we will print the labels file (`classes.txt`) to see all possible labels followed by creating an index to label mapping.

In [5]:
!cat dbpedia_csv/classes.txt

Company
EducationalInstitution
Artist
Athlete
OfficeHolder
MeanOfTransportation
Building
NaturalPlace
Village
Animal
Plant
Album
Film
WrittenWork


The following code creates the mapping from integer indices to class label which will later be used to retrieve the actual class name during inference. 

In [6]:
index_to_label = {} 
with open("dbpedia_csv/classes.txt") as f:
    for i,label in enumerate(f.readlines()):
        index_to_label[str(i+1)] = label.strip()
print(index_to_label)

{'1': 'Company', '2': 'EducationalInstitution', '3': 'Artist', '4': 'Athlete', '5': 'OfficeHolder', '6': 'MeanOfTransportation', '7': 'Building', '8': 'NaturalPlace', '9': 'Village', '10': 'Animal', '11': 'Plant', '12': 'Album', '13': 'Film', '14': 'WrittenWork'}


## Data Preprocessing
We need to preprocess the training data into a two-column **CSV** format (class label, then document text) that Amazon Comprehend can consume. As mentioned previously, the numeric label index in each row is replaced with the class name from classes.txt.
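
As a hedged illustration of the target layout, the sketch below converts `label_index,title,abstract` rows into `label_name,abstract` rows with Python's `csv` module, which copes with commas inside quoted fields. The helper name `to_comprehend_rows` and the two sample rows are invented for this example; it is an alternative sketch, not the preprocessing used in the cells that follow.

```python
import csv
import io

def to_comprehend_rows(reader, index_to_label):
    """Yield (label_name, abstract) pairs from (label_index, title, abstract) rows."""
    for label_index, _title, abstract in reader:
        yield index_to_label[label_index], abstract.strip()

# Tiny in-memory stand-in for dbpedia_csv/train.csv (illustrative rows only)
sample = io.StringIO(
    '1,"Acme, Inc."," Acme is a fictional company, used here as an example."\n'
    '3,"Jane Doe"," Jane Doe is a fictional painter."\n'
)
labels = {"1": "Company", "3": "Artist"}

out = io.StringIO()
writer = csv.writer(out, lineterminator="\n")
writer.writerows(to_comprehend_rows(csv.reader(sample), labels))
converted = out.getvalue()
print(converted)
```

Note that the title in the first sample row contains a comma; the `csv` reader still treats it as a single field because it is quoted.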

In [7]:
def transform_instance(row):
    # Map a label index such as '1' to its class name such as 'Company'
    return index_to_label[row]


The `transform_instance` function is applied to each row as the preprocessing functions below read the input file line by line.

In [8]:
def preprocess(input_file, output_file, testfile=1):
    # Rewrite "label_index,title,abstract" rows as "label_name,abstract" rows.
    # Passing testfile=0 keeps only the first 199 rows for a quick trial run.
    all_rows = ''
    with open(input_file, 'r') as csvinfile:
        count = 0
        for row in csvinfile:
            if testfile == 0:
                count += 1
                if count == 200:
                    break
            category = row.split(',')[0]
            title = row.split(',')[1]
            # Everything after the title field is the abstract, quotes and newline included.
            # Note: this simple split assumes the title itself contains no comma.
            document = row.split(title + ',')[1]
            all_rows += transform_instance(category) + ',' + document

    with open(output_file, 'w') as csvoutfile:
        csvoutfile.write(all_rows)

def preprocesstest(input_file, output_file, testfile=1):
    # Keep only the abstract of each row, one document per line, with no label.
    # Passing testfile=0 keeps only the first 199 rows for a quick trial run.
    all_rows = ''
    with open(input_file, 'r') as csvinfile:
        count = 0
        for row in csvinfile:
            if testfile == 0:
                count += 1
                if count == 200:
                    break
            title = row.split(',')[1]
            document = row.split(title + ',')[1]
            all_rows += document

    with open(output_file, 'w') as csvoutfile:
        csvoutfile.write(all_rows)

In [9]:
%%time

# Preparing the training dataset

preprocess('dbpedia_csv/train.csv', 'dbpedia.train')
        
# Preparing the test dataset        
preprocesstest('dbpedia_csv/test.csv', 'dbpedia.test')

CPU times: user 2.58 s, sys: 632 ms, total: 3.21 s
Wall time: 3.18 s


In [10]:
!head dbpedia.train -n 3

Company," Abbott of Farnham E D Abbott Limited was a British coachbuilding business based in Farnham Surrey trading under that name from 1929. A major part of their output was under sub-contract to motor vehicle manufacturers. Their business closed in 1972."
Company," Schwan-STABILO is a German maker of pens for writing colouring and cosmetics as well as markers and highlighters for office use. It is the world's largest manufacturer of highlighter pens Stabilo Boss."
Company," Q-workshop is a Polish company located in Poznań that specializes in designand production of polyhedral dice and dice accessories for use in various games (role-playing gamesboard games and tabletop wargames). They also run an online retail store and maintainan active forum community.Q-workshop was established in 2001 by Patryk Strzelewicz – a student from Poznań. Initiallythe company sold its products via online auction services but in 2005 a website and online store wereestablished."
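
Before uploading, it can be useful to check how many rows each label received. The sketch below is one hedged way to do that; `count_labels` is a hypothetical helper, and the sample rows are shortened or invented stand-ins for `dbpedia.train`.

```python
from collections import Counter
import io

def count_labels(lines):
    """Count how many rows carry each label in 'label,document' lines."""
    return Counter(line.split(",", 1)[0] for line in lines if line.strip())

# In-memory stand-in for the first rows of dbpedia.train (shortened text)
sample_train = io.StringIO(
    'Company," Abbott of Farnham was a British coachbuilding business."\n'
    'Company," Schwan-STABILO is a German maker of pens."\n'
    'Artist," A fictional painter used for illustration."\n'
)
label_counts = count_labels(sample_train)
print(label_counts.most_common())
```

On the real file, the same function could be called with an open file handle, for example `with open('dbpedia.train') as f: count_labels(f)`.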


In [11]:
def upload_to_s3(channel, file):
    # Upload a local file to s3://<bucket>/<channel>/<file>
    s3 = boto3.resource('s3')
    key = channel + '/' + file
    with open(file, "rb") as data:
        s3.Bucket(bucket).put_object(Key=key, Body=data)


# Keys must match the prefix ('dbpedia/') used to build the S3 locations below
s3_train_key = "dbpedia/train"
s3_test_key = "dbpedia/test"

upload_to_s3(s3_train_key, 'dbpedia.train')
upload_to_s3(s3_test_key, 'dbpedia.test')

The data preprocessing cell might take a minute to run. Once preprocessing is complete, the two files must be in S3 so that Amazon Comprehend can read them for training and inference. The cell above uses Boto3 to upload them to the bucket and prefix set earlier.

Next we need to set up the S3 locations: where the training and test files live, and an output location where Amazon Comprehend will write the results of the classification job.

In [12]:
s3_output_location = 's3://{}/{}output'.format(bucket, prefix)
s3_train_location = 's3://{}/{}train'.format(bucket, prefix)+'/'+'dbpedia.train'
s3_test_location = 's3://{}/{}test'.format(bucket, prefix)+'/'+'dbpedia.test'

In [13]:

print (s3_output_location)
print (s3_train_location)
print (s3_train_key)

s3://comprehend-demolm/dbpedia/output
s3://comprehend-demolm/dbpedia/train/dbpedia.train
dbpedia/train
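
The upload keys and the `s3://` locations above are built separately, so a stray or missing slash in `prefix` can make them disagree. One hedged way to avoid that is a small joining helper; `s3_uri` is a hypothetical function written for this sketch, not part of Boto3 or the SageMaker SDK.

```python
def s3_uri(bucket, *parts):
    """Join a bucket name and key parts into an s3:// URI with single slashes."""
    cleaned = [p.strip("/") for p in parts if p and p.strip("/")]
    return "s3://" + "/".join([bucket] + cleaned)

# Same bucket and prefix values as in the setup cell
print(s3_uri("comprehend-demolm", "dbpedia/", "train", "dbpedia.train"))
```

Because empty parts and surrounding slashes are dropped, the helper behaves the same whether the prefix is written as `dbpedia`, `dbpedia/`, or left empty.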


## Training Comprehend for custom classification

First, create an IAM policy for the Comprehend service role.

In [15]:
iam = boto3.client("iam")
policy_name = "Comprehendpolicy"
policy_document = {
    "Version": "2012-10-17",
    "Statement": [
        {
            "Sid": "VisualEditor0",
            "Effect": "Allow",
            "Action": [
                "s3:PutObject",
                "s3:GetObject",
                "s3:ListBucket",
                "s3:DeleteObject"
            ],
            "Resource": [
                "arn:aws:s3:::*Comprehend*",
                "arn:aws:s3:::*comprehend*"
            ]
        }
    ]
}
create_policy_response = iam.create_policy(
    PolicyName = policy_name,
    PolicyDocument = json.dumps(policy_document),
    Description='Comprehend Policy'
)
PolicyArn=create_policy_response["Policy"]["Arn"]
print(PolicyArn)

arn:aws:iam::625941629713:policy/Comprehendpolicy
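
The policy above grants access to every bucket whose name contains "comprehend". As a hedged alternative, the sketch below builds a policy document limited to a single bucket. `build_comprehend_s3_policy` is a hypothetical helper, and leaving out `s3:DeleteObject` is an assumption of this sketch rather than a requirement stated in this notebook.

```python
import json

def build_comprehend_s3_policy(bucket_name):
    """Return a policy document limited to one bucket (hypothetical helper)."""
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                # Listing is a bucket-level action, so it targets the bucket ARN
                "Effect": "Allow",
                "Action": ["s3:ListBucket"],
                "Resource": ["arn:aws:s3:::{}".format(bucket_name)],
            },
            {
                # Object reads and writes target keys inside the bucket
                "Effect": "Allow",
                "Action": ["s3:GetObject", "s3:PutObject"],
                "Resource": ["arn:aws:s3:::{}/*".format(bucket_name)],
            },
        ],
    }

scoped_policy = build_comprehend_s3_policy("comprehend-demolm")
print(json.dumps(scoped_policy, indent=2))
```

The resulting dictionary could be passed through `json.dumps` as the `PolicyDocument` argument of `iam.create_policy`, in the same way as the policy document in the cell above.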


In [18]:
role_name = "ComprehendRole"
assume_role_policy_document = {
    "Version": "2012-10-17",
    "Statement": [
        {
          "Effect": "Allow",
          "Principal": {
            "Service": "comprehend.amazonaws.com"
          },
          "Action": "sts:AssumeRole"
        }
    ]
}
create_role_response = iam.create_role(
    RoleName = role_name,
    AssumeRolePolicyDocument = json.dumps(assume_role_policy_document),
    Description='Amazon Comprehend service role for classifier.'
)

iam.attach_role_policy(
    RoleName = role_name,
    PolicyArn = create_policy_response["Policy"]["Arn"]
)

time.sleep(30) # wait 30 seconds to allow the IAM role policy attachment to propagate

role_arn = create_role_response["Role"]["Arn"]
print(role_arn)

arn:aws:iam::625941629713:role/ComprehendRole


We will create a custom classification training job using the Boto3 SDK

In [19]:
# Instantiate Boto3 SDK:
client = boto3.client('comprehend', region_name=(region))

# Create a document classifier
create_response = client.create_document_classifier(
    InputDataConfig={
        'S3Uri': (s3_train_location)
    },
    DataAccessRoleArn=(role_arn),
    DocumentClassifierName='dbpedia-classifier',
    LanguageCode='en'
)
print("Create response: {}\n".format(create_response))

Create response: {'DocumentClassifierArn': 'arn:aws:comprehend:us-east-1:625941629713:document-classifier/dbpedia-classifier', 'ResponseMetadata': {'RequestId': '3dacfe45-41a1-4c43-8343-706050089613', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '3dacfe45-41a1-4c43-8343-706050089613', 'content-type': 'application/x-amz-json-1.1', 'content-length': '108', 'date': 'Sat, 20 Jul 2019 18:32:27 GMT'}, 'RetryAttempts': 0}}


In [20]:
status = None
max_time = time.time() + 3*60*60 # 3 hours
while time.time() < max_time:
    describe_response = client.describe_document_classifier(
        DocumentClassifierArn=create_response['DocumentClassifierArn'])
    status = describe_response["DocumentClassifierProperties"]["Status"]
    now = datetime.now(pytz.utc)
    elapsed = now - describe_response["DocumentClassifierProperties"]["SubmitTime"]
    print("DocumentClassifierProperties: {}   (elapsed = {})".format(status, elapsed))
    
    # Stop polling once training succeeds or fails (Comprehend reports failure as IN_ERROR)
    if status == "TRAINED" or status == "IN_ERROR":
        break
        
    time.sleep(15)

DocumentClassifierProperties: SUBMITTED   (elapsed = 0:00:09.765044)
DocumentClassifierProperties: SUBMITTED   (elapsed = 0:00:24.881853)
DocumentClassifierProperties: TRAINING   (elapsed = 0:00:39.955981)
DocumentClassifierProperties: TRAINING   (elapsed = 0:00:55.020844)
DocumentClassifierProperties: TRAINING   (elapsed = 0:01:10.088692)
DocumentClassifierProperties: TRAINING   (elapsed = 0:01:25.163460)
DocumentClassifierProperties: TRAINING   (elapsed = 0:01:40.247051)
DocumentClassifierProperties: TRAINING   (elapsed = 0:01:55.284898)
DocumentClassifierProperties: TRAINING   (elapsed = 0:02:10.351946)
DocumentClassifierProperties: TRAINING   (elapsed = 0:02:25.430562)
DocumentClassifierProperties: TRAINING   (elapsed = 0:02:40.511459)
DocumentClassifierProperties: TRAINING   (elapsed = 0:02:55.589646)
DocumentClassifierProperties: TRAINING   (elapsed = 0:03:10.677079)
DocumentClassifierProperties: TRAINING   (elapsed = 0:03:25.719026)
DocumentClassifierProperties: TRAINING   (elap

The loop above checks the status of the classifier every 15 seconds until training finishes.
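
The same polling pattern is needed again for the classification job, so it can be factored into a small helper. The sketch below is one hedged way to do that; `wait_for_status` is a hypothetical function, and the status source is stubbed so the example runs without AWS access.

```python
import time

def wait_for_status(fetch_status, terminal_states, poll_seconds=15,
                    timeout_seconds=3 * 60 * 60,
                    sleep=time.sleep, clock=time.time):
    """Poll fetch_status() until it returns a terminal state or the timeout expires."""
    deadline = clock() + timeout_seconds
    status = fetch_status()
    while status not in terminal_states and clock() < deadline:
        sleep(poll_seconds)
        status = fetch_status()
    return status

# Stubbed status source so the sketch runs without AWS access
fake_statuses = iter(["SUBMITTED", "TRAINING", "TRAINING", "TRAINED"])
final_status = wait_for_status(
    fetch_status=lambda: next(fake_statuses),
    terminal_states={"TRAINED", "IN_ERROR"},
    sleep=lambda seconds: None,  # skip real waiting in this demo
)
print(final_status)
```

With the real client, `fetch_status` could be a lambda that calls `client.describe_document_classifier(...)` and returns `["DocumentClassifierProperties"]["Status"]`, as the loop above does.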

## Inference
Once the training is done, we can create a job to classify documents with the Amazon Comprehend custom classifier. We will run this against our test data.

In [21]:
start_response = client.start_document_classification_job(
    InputDataConfig={
        'S3Uri': (s3_test_location),
        'InputFormat': 'ONE_DOC_PER_LINE'
    },
    OutputDataConfig={
        'S3Uri': (s3_output_location)
    },
    # Role ARN hard-coded from the original run; substitute role_arn (created above) in your own account
    DataAccessRoleArn='arn:aws:iam::625941629713:role/service-role/AmazonComprehendServiceRole-dbpedia',
    DocumentClassifierArn=create_response['DocumentClassifierArn']
)

print("Start response: {}\n".format(start_response))

Start response: {'JobId': '1d61584cb891b55c49c1864d2d356081', 'JobStatus': 'SUBMITTED', 'ResponseMetadata': {'RequestId': '93c08869-abc6-49c7-af20-eec1173c582f', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '93c08869-abc6-49c7-af20-eec1173c582f', 'content-type': 'application/x-amz-json-1.1', 'content-length': '68', 'date': 'Sat, 20 Jul 2019 18:50:02 GMT'}, 'RetryAttempts': 0}}


Check the status of the job

In [23]:
status = None
max_time = time.time() + 3*60*60 # 3 hours
while time.time() < max_time:
    describe_response = client.describe_document_classification_job(JobId=start_response['JobId'])
    status = describe_response["DocumentClassificationJobProperties"]["JobStatus"]
    now = datetime.now(pytz.utc)
    elapsed = now - describe_response["DocumentClassificationJobProperties"]["SubmitTime"]
    print("DocumentClassificationJobProperties: {}   (elapsed = {})".format(status, elapsed))
    
    if status == "COMPLETED" or status == "FAILED":
        break
        
    time.sleep(15)
output_location = describe_response["DocumentClassificationJobProperties"]["OutputDataConfig"]
outputs3=output_location["S3Uri"]

DocumentClassificationJobProperties: IN_PROGRESS   (elapsed = 0:00:08.044773)
DocumentClassificationJobProperties: IN_PROGRESS   (elapsed = 0:00:23.101194)
DocumentClassificationJobProperties: IN_PROGRESS   (elapsed = 0:00:38.145841)
DocumentClassificationJobProperties: IN_PROGRESS   (elapsed = 0:00:53.208978)
DocumentClassificationJobProperties: IN_PROGRESS   (elapsed = 0:01:08.278999)
DocumentClassificationJobProperties: IN_PROGRESS   (elapsed = 0:01:23.346907)
DocumentClassificationJobProperties: IN_PROGRESS   (elapsed = 0:01:38.421726)
DocumentClassificationJobProperties: IN_PROGRESS   (elapsed = 0:01:53.501449)
DocumentClassificationJobProperties: IN_PROGRESS   (elapsed = 0:02:08.559914)
DocumentClassificationJobProperties: IN_PROGRESS   (elapsed = 0:02:23.602743)
DocumentClassificationJobProperties: IN_PROGRESS   (elapsed = 0:02:38.673054)
DocumentClassificationJobProperties: IN_PROGRESS   (elapsed = 0:02:53.728880)
DocumentClassificationJobProperties: IN_PROGRESS   (elapsed = 0:

In [29]:
output_location = describe_response["DocumentClassificationJobProperties"]["OutputDataConfig"]
print(describe_response)
outputs3=output_location["S3Uri"]
print(outputs3)

{'DocumentClassificationJobProperties': {'JobId': '1d61584cb891b55c49c1864d2d356081', 'JobStatus': 'COMPLETED', 'SubmitTime': datetime.datetime(2019, 7, 20, 18, 50, 3, 398000, tzinfo=tzlocal()), 'EndTime': datetime.datetime(2019, 7, 20, 18, 54, 39, 268000, tzinfo=tzlocal()), 'DocumentClassifierArn': 'arn:aws:comprehend:us-east-1:625941629713:document-classifier/dbpedia-classifier', 'InputDataConfig': {'S3Uri': 's3://comprehend-demolm/dbpedia/test/dbpedia.test', 'InputFormat': 'ONE_DOC_PER_LINE'}, 'OutputDataConfig': {'S3Uri': 's3://comprehend-demolm/dbpedia/output/625941629713-CLN-1d61584cb891b55c49c1864d2d356081/output/output.tar.gz'}, 'DataAccessRoleArn': 'arn:aws:iam::625941629713:role/service-role/AmazonComprehendServiceRole-dbpedia'}, 'ResponseMetadata': {'RequestId': '67e90f01-ab2b-4824-a3d3-2956d9a26d38', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '67e90f01-ab2b-4824-a3d3-2956d9a26d38', 'content-type': 'application/x-amz-json-1.1', 'content-length': '626', 'date'

Once the classification job has run, let's download and view the results.

In [39]:
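# outputs3 holds the S3 URI of the archive written by the classification job
!echo $outputs3

# Uncomment to download the archive into the working directory
#!aws s3 cp $outputs3 .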

s3://comprehend-demolm/dbpedia/output/625941629713-CLN-1d61584cb891b55c49c1864d2d356081/output/output.tar.gz


In [None]:
!tar -xzvf output.tar.gz

In [None]:
!head predictions.jsonl
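
To turn the raw predictions into one label per document, each JSON line can be reduced to its highest-scoring class. The sketch below assumes each record has `File`, `Line`, and `Classes` fields, with each class carrying `Name` and `Score`; check that layout against your own `predictions.jsonl`. The sample record and the helper `top_class` are made up for illustration.

```python
import json

def top_class(prediction_line):
    """Return (line_number, best_label, score) for one predictions.jsonl record."""
    record = json.loads(prediction_line)
    best = max(record["Classes"], key=lambda c: c["Score"])
    return int(record["Line"]), best["Name"], best["Score"]

# Made-up record in the assumed layout; real scores will differ
sample_line = json.dumps({
    "File": "dbpedia.test",
    "Line": "0",
    "Classes": [
        {"Name": "Company", "Score": 0.91},
        {"Name": "Building", "Score": 0.06},
        {"Name": "Artist", "Score": 0.03},
    ],
})
print(top_class(sample_line))
```

Applied to every line of the file, this yields predicted labels that can be compared with the label column of `dbpedia_csv/test.csv`.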