# Train Amazon Comprehend Custom Classifier Model

<img src="img/comprehend.png" width="80%" align="left">

## Note that Amazon Comprehend is currently only supported in a subset of regions: 

* US East (N. Virginia), US East (Ohio), US West (Oregon)
* Canada (Central)
* Europe (London), Europe (Ireland), Europe (Frankfurt)
* Asia Pacific (Mumbai), Asia Pacific (Seoul), Asia Pacific (Tokyo), Asia Pacific (Singapore), Asia Pacific (Sydney)

You can check https://aws.amazon.com/about-aws/global-infrastructure/regional-product-services/ for details and updates. 

In [1]:
import boto3
import sagemaker
import pandas as pd

# SageMaker session, its default bucket, the execution role and the current region.
sess = sagemaker.Session()
bucket = sess.default_bucket()
role = sagemaker.get_execution_role()
region = boto3.Session().region_name

from botocore.config import Config

# Retry settings (adaptive mode, up to 10 attempts) passed to the IAM client below.
config = Config(
    retries={
        "max_attempts": 10,
        "mode": "adaptive",
    }
)

iam = boto3.client(service_name="iam", config=config)
sm = boto3.Session().client("sagemaker", region_name=region)



### Check if your current region supports Comprehend

In [2]:
# Report whether the current region is one of the regions this notebook lists as
# supporting Amazon Comprehend; the unsupported case is handled first.
if region not in {
    "us-east-1",
    "us-east-2",
    "us-west-2",
    "ca-central-1",
    "eu-west-1",
    "eu-west-2",
    "eu-central-1",
    "ap-south-1",
    "ap-northeast-1",
    "ap-northeast-2",
    "ap-southeast-1",
    "ap-southeast-2",
}:
    print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
    print(" [ERROR] COMPREHEND IS NOT YET SUPPORTED IN {}.".format(region))
    print(" [INFO] This is OK. Skip this notebook and continue with the next use case.")
    print(" [INFO] This notebook is not required for the rest of this workshop.")
    print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
else:
    print(" [OK] COMPREHEND IS SUPPORTED IN {}".format(region))
    print(" [OK] Please proceed with this notebook.")

 [OK] COMPREHEND IS SUPPORTED IN us-east-1
 [OK] Please proceed with this notebook.


In [3]:
# Amazon Comprehend client used for training, status polling and endpoint creation.
comprehend = boto3.client("comprehend")

### Retrieve S3 location of training data

In [4]:
# Restore the training-data S3 URI previously saved with %store
# (the data preparation notebook is expected to have stored it).
%store -r comprehend_train_s3_uri

In [5]:
# `%store -r` leaves the name undefined when nothing was stored, so look it up in the
# notebook namespace instead of referencing it directly; a direct reference would raise
# NameError and hide the guidance below in exactly the case it is meant for.
if not globals().get("comprehend_train_s3_uri"):
    print("****************************************************************************************")
    print("**************** PLEASE RE-RUN THE PREVIOUS DATA PREPARATION NOTEBOOK ******************")
    print("**************** THIS NOTEBOOK WILL NOT RUN PROPERLY ***********************************")
    print("****************************************************************************************")

In [6]:
# Show the S3 URI of the training data that will be passed to Comprehend.
print(comprehend_train_s3_uri)

s3://sagemaker-us-east-1-211125778552/data/amazon_reviews_us_Digital_Software_v1_00_comprehend.csv


In [7]:
# List the training object with the AWS CLI; IPython expands $comprehend_train_s3_uri
# from the Python variable of the same name.
!aws s3 ls $comprehend_train_s3_uri

2024-02-15 23:27:40     397534 amazon_reviews_us_Digital_Software_v1_00_comprehend.csv


## See our prepared training data which we use as input for Comprehend

In [8]:
# Copy the training file from S3 into the local ./tmp/ directory for inspection.
!aws s3 cp $comprehend_train_s3_uri ./tmp/

download: s3://sagemaker-us-east-1-211125778552/data/amazon_reviews_us_Digital_Software_v1_00_comprehend.csv to tmp/amazon_reviews_us_Digital_Software_v1_00_comprehend.csv


In [9]:
import csv

# Load the downloaded training file without treating its first row as a header,
# then preview the first rows.
df = pd.read_csv(
    "./tmp/amazon_reviews_us_Digital_Software_v1_00_comprehend.csv",
    header=None,
)
df.head()

Unnamed: 0,0,1
0,3,Works great for the Federal. But I had to mail...
1,5,The bit defender antivirus plus is great!
2,2,"Basically, I paid for the ability to download ..."
3,3,Longtime user. Would give 5 stars to the 2002-...
4,2,"Quicken for Mac 2015 is a nice idea, but it's ..."


# Create Data Access Role for Comprehend

## Create Policy

In [10]:
# Trust policy: allows the Amazon Comprehend service principal to assume the role.
assume_role_policy_doc = {
    "Version": "2012-10-17",
    "Statement": [
        {
            "Effect": "Allow",
            "Principal": {"Service": "comprehend.amazonaws.com"},
            "Action": "sts:AssumeRole",
        },
    ],
}

## Create Role and Attach Policies

In [11]:
# Name of the IAM role that Comprehend uses to access the training data in S3.
iam_comprehend_role_name = "DSOAWS_Comprehend"

In [12]:
import json
import time

from botocore.exceptions import ClientError

# Create the Comprehend data-access role, or fetch it when it already exists.
try:
    iam_role_comprehend = iam.create_role(
        RoleName=iam_comprehend_role_name,
        AssumeRolePolicyDocument=json.dumps(assume_role_policy_doc),
        Description="DSOAWS Comprehend Role",
    )
except ClientError as e:
    if e.response["Error"]["Code"] != "EntityAlreadyExists":
        # Any other failure leaves `iam_role_comprehend` undefined. Re-raise so the
        # notebook stops here instead of failing later with an unrelated NameError.
        print("Unexpected error: %s" % e)
        raise
    iam_role_comprehend = iam.get_role(RoleName=iam_comprehend_role_name)
    print("Role already exists")

# Pause before later cells use the role (presumably to let IAM changes propagate).
time.sleep(30)

In [13]:
# Inline policy for the Comprehend role: read, list and write access limited to
# the notebook's bucket and the objects inside it.
comprehend_s3_policy_doc = {
    "Version": "2012-10-17",
    "Statement": [
        {
            "Action": ["s3:GetObject"],
            "Resource": ["arn:aws:s3:::{}/*".format(bucket)],
            "Effect": "Allow",
        },
        {
            "Action": ["s3:ListBucket"],
            "Resource": ["arn:aws:s3:::{}".format(bucket)],
            "Effect": "Allow",
        },
        {
            "Action": ["s3:PutObject"],
            "Resource": ["arn:aws:s3:::{}/*".format(bucket)],
            "Effect": "Allow",
        },
    ],
}

print(comprehend_s3_policy_doc)

{'Version': '2012-10-17', 'Statement': [{'Action': ['s3:GetObject'], 'Resource': ['arn:aws:s3:::sagemaker-us-east-1-211125778552/*'], 'Effect': 'Allow'}, {'Action': ['s3:ListBucket'], 'Resource': ['arn:aws:s3:::sagemaker-us-east-1-211125778552'], 'Effect': 'Allow'}, {'Action': ['s3:PutObject'], 'Resource': ['arn:aws:s3:::sagemaker-us-east-1-211125778552/*'], 'Effect': 'Allow'}]}


# Attach Policy to Role

In [14]:
import time

# Attach the S3 access policy to the Comprehend role as an inline policy.
response = iam.put_role_policy(
    PolicyDocument=json.dumps(comprehend_s3_policy_doc),
    PolicyName="DSOAWS_ComprehendPolicyToS3",
    RoleName=iam_comprehend_role_name,
)
print(response)

# Pause before the role is used for training.
time.sleep(30)

{'ResponseMetadata': {'RequestId': '5ba62ed9-a914-4655-ba91-0158b126e274', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '5ba62ed9-a914-4655-ba91-0158b126e274', 'content-type': 'text/xml', 'content-length': '206', 'date': 'Thu, 15 Feb 2024 23:28:29 GMT'}, 'RetryAttempts': 0}}


# Train the Model

In [15]:
# Key prefix under which model artifacts are written in the bucket.
prefix = "models"

# S3 location passed to Comprehend as the training output destination.
s3_output_job = "s3://{}/{}/{}".format(bucket, prefix, "comprehend/output")
print(s3_output_job)

s3://sagemaker-us-east-1-211125778552/models/comprehend/output


In [16]:
# ARN of the Comprehend data-access role, taken from the IAM role response.
iam_role_comprehend_arn = iam_role_comprehend["Role"]["Arn"]

In [17]:
import datetime
import time

# Epoch seconds as a name suffix. `time.time()` is used instead of
# `strftime("%s")` because "%s" is a platform-specific directive that Python does not
# document and that is unavailable on some platforms.
timestamp = str(int(time.time()))

comprehend_training_job_name = "Amazon-Customer-Reviews-Classifier-{}".format(timestamp)

print(comprehend_training_job_name)

Amazon-Customer-Reviews-Classifier-1708039740


In [18]:
# Start training the custom classifier: read the training data from S3 using the
# data-access role and write results under the output location.
training_job = comprehend.create_document_classifier(
    LanguageCode="en",
    DocumentClassifierName=comprehend_training_job_name,
    InputDataConfig={"S3Uri": comprehend_train_s3_uri},
    OutputDataConfig={"S3Uri": s3_output_job},
    DataAccessRoleArn=iam_role_comprehend_arn,
)

# Pause before the following cells query the new classifier.
time.sleep(30)

In [19]:
# ARN of the classifier, taken from the create_document_classifier response.
comprehend_training_job_arn = training_job["DocumentClassifierArn"]

print(comprehend_training_job_arn)

arn:aws:comprehend:us-east-1:211125778552:document-classifier/Amazon-Customer-Reviews-Classifier-1708039740


In [20]:
# Import from the public `IPython.display` module; importing `display` from
# `IPython.core.display` is deprecated.
from IPython.display import HTML, display

# Render a link to the classifier's page in the AWS console. `target="_blank"` is the
# HTML keyword for opening a new tab ("blank" would only name a reusable window).
display(
    HTML(
        '<b>Review <a target="_blank" href="https://console.aws.amazon.com/comprehend/v2/home?region={}#classifier-details/{}">Comprehend Training Job</a></b>'.format(
            region, comprehend_training_job_arn
        )
    )
)

# This Next Cell Will Take Some Time
# _Please be patient._

In [21]:
import time

# Poll the classifier every 10 seconds until it reaches a terminal status or the
# 3-hour budget is used up.
max_time = time.time() + 3 * 60 * 60  # 3 hours
while time.time() < max_time:
    describe_custom_classifier = comprehend.describe_document_classifier(
        DocumentClassifierArn=comprehend_training_job_arn
    )
    status = describe_custom_classifier["DocumentClassifierProperties"]["Status"]
    print("Custom classifier: {}".format(status))

    # Stop on every status the classifier cannot leave on its own. Besides TRAINED and
    # IN_ERROR this covers TRAINED_WITH_WARNING and STOPPED, which would otherwise keep
    # the loop polling until the timeout.
    if status in ("TRAINED", "TRAINED_WITH_WARNING", "IN_ERROR", "STOPPED"):
        print("")
        print("Status {}".format(status))
        print("")
        print(describe_custom_classifier["DocumentClassifierProperties"])
        break

    time.sleep(10)
else:
    # Reached only when the loop ended without `break`, i.e. the time budget ran out.
    print("")
    print("Timed out waiting for the classifier; last status: {}".format(status))

Custom classifier: SUBMITTED
Custom classifier: SUBMITTED
Custom classifier: SUBMITTED
Custom classifier: SUBMITTED
Custom classifier: SUBMITTED
Custom classifier: SUBMITTED
Custom classifier: SUBMITTED
Custom classifier: SUBMITTED
Custom classifier: SUBMITTED
Custom classifier: SUBMITTED
Custom classifier: TRAINING
Custom classifier: TRAINING
Custom classifier: TRAINING
Custom classifier: TRAINING
Custom classifier: TRAINING
Custom classifier: TRAINING
Custom classifier: TRAINING
Custom classifier: TRAINING
Custom classifier: TRAINING
Custom classifier: TRAINING
Custom classifier: TRAINING
Custom classifier: TRAINING
Custom classifier: TRAINING
Custom classifier: TRAINING
Custom classifier: TRAINING
Custom classifier: TRAINING
Custom classifier: TRAINING
Custom classifier: TRAINING
Custom classifier: TRAINING
Custom classifier: TRAINING
Custom classifier: TRAINING
Custom classifier: TRAINING
Custom classifier: TRAINING
Custom classifier: TRAINING
Custom classifier: TRAINING
Custom cla

# _Please Wait Until the ^^ Classifier ^^ is Trained Above._

# [INFO] _Feel free to continue to the next workshop section while this notebook is running._

# Show Results of the Classifier

In [22]:
# Print the classifier properties from the most recent describe call.
print(describe_custom_classifier["DocumentClassifierProperties"])

{'DocumentClassifierArn': 'arn:aws:comprehend:us-east-1:211125778552:document-classifier/Amazon-Customer-Reviews-Classifier-1708039740', 'LanguageCode': 'en', 'Status': 'TRAINED', 'SubmitTime': datetime.datetime(2024, 2, 15, 23, 29, 0, 280000, tzinfo=tzlocal()), 'EndTime': datetime.datetime(2024, 2, 16, 0, 7, 58, 325000, tzinfo=tzlocal()), 'TrainingStartTime': datetime.datetime(2024, 2, 15, 23, 33, 26, 926000, tzinfo=tzlocal()), 'TrainingEndTime': datetime.datetime(2024, 2, 16, 0, 6, 31, 808000, tzinfo=tzlocal()), 'InputDataConfig': {'DataFormat': 'COMPREHEND_CSV', 'S3Uri': 's3://sagemaker-us-east-1-211125778552/data/amazon_reviews_us_Digital_Software_v1_00_comprehend.csv'}, 'OutputDataConfig': {'S3Uri': 's3://sagemaker-us-east-1-211125778552/models/comprehend/output/211125778552-CLR-5615bca1c68c5c1c05bb416ca7baa6c1/output/output.tar.gz'}, 'ClassifierMetadata': {'NumberOfLabels': 5, 'NumberOfTrainedDocuments': 810, 'NumberOfTestDocuments': 90, 'EvaluationMetrics': {'Accuracy': 0.4444, 

In [23]:
# ARN of the classifier, taken from the describe response; used as the model ARN
# when creating the endpoint.
model_arn = describe_custom_classifier["DocumentClassifierProperties"]["DocumentClassifierArn"]
print(model_arn)

arn:aws:comprehend:us-east-1:211125778552:document-classifier/Amazon-Customer-Reviews-Classifier-1708039740


In [24]:
import os

from urllib.parse import urlparse

# Retrieve the S3 URI of the model output from the describe response.
job_output = describe_custom_classifier["DocumentClassifierProperties"]["OutputDataConfig"]["S3Uri"]
print(job_output)

path_prefix = "s3://{}/".format(bucket)

# The object key is the URI path without its leading slash. Parsing the URI avoids
# `os.path.relpath`, which treats the URI as a local filesystem path and therefore
# depends on the working directory and on the operating system's path rules.
job_key = urlparse(job_output).path.lstrip("/")

print(job_key)

s3://sagemaker-us-east-1-211125778552/models/comprehend/output/211125778552-CLR-5615bca1c68c5c1c05bb416ca7baa6c1/output/output.tar.gz
models/comprehend/output/211125778552-CLR-5615bca1c68c5c1c05bb416ca7baa6c1/output/output.tar.gz


# Download Model Artifacts including Training Metrics

In [25]:
# Download the training output archive from the bucket into the working directory.
s3 = boto3.resource("s3")

s3.Bucket(bucket).download_file(Key=job_key, Filename="./output.tar.gz")



In [26]:
# Unpack the gzip-compressed tar archive downloaded above (verbose listing of extracted files).
!tar xvzf ./output.tar.gz

tar: Ignoring unknown extended header keyword 'LIBARCHIVE.creationtime'
output/
tar: Ignoring unknown extended header keyword 'LIBARCHIVE.creationtime'
output/confusion_matrix.json


In [27]:
import json

# Parse the extracted confusion-matrix file and pretty-print its contents.
with open("./output/confusion_matrix.json") as confusion_matrix_file:
    data = json.load(confusion_matrix_file)

print(json.dumps(data, indent=2, default=str))

{
  "confusion_matrix": [
    [
      7,
      10,
      2,
      0,
      0
    ],
    [
      2,
      15,
      0,
      1,
      0
    ],
    [
      1,
      13,
      2,
      2,
      0
    ],
    [
      0,
      4,
      3,
      8,
      3
    ],
    [
      0,
      4,
      0,
      5,
      8
    ]
  ],
  "labels": [
    "1",
    "2",
    "3",
    "4",
    "5"
  ],
  "type": "multi_class",
  "all_labels": [
    "1",
    "2",
    "3",
    "4",
    "5"
  ]
}


In [28]:
!pip install tabulate

Collecting tabulate
  Downloading tabulate-0.9.0-py3-none-any.whl (35 kB)
Installing collected packages: tabulate
Successfully installed tabulate-0.9.0


In [29]:
from IPython.display import HTML, display
import tabulate

# Build the display table: a header row of predicted labels, one row per actual
# label "1".."5" holding that row's five confusion-matrix counts, and a footer row.
table = (
    [["", "1", "2", "3", "4", "5", "(Predicted)"]]
    + [
        [str(actual + 1)] + [data["confusion_matrix"][actual][predicted] for predicted in range(5)]
        for actual in range(5)
    ]
    + [["(Actual)"]]
)
display(HTML(tabulate.tabulate(table, tablefmt="html")))

0,1,2,3,4,5,6
,1.0,2.0,3.0,4.0,5.0,(Predicted)
1,7.0,10.0,2.0,0.0,0.0,
2,2.0,15.0,0.0,1.0,0.0,
3,1.0,13.0,2.0,2.0,0.0,
4,0.0,4.0,3.0,8.0,3.0,
5,0.0,4.0,0.0,5.0,8.0,
(Actual),,,,,,


# Deploy Endpoint

In [30]:
from time import gmtime, strftime, sleep

# Name the endpoint with the current UTC day-of-month and time, then create it
# for the trained model with a single inference unit.
timestamp_suffix = strftime("%d-%H-%M-%S", gmtime())
comprehend_endpoint_name = "".join(["comprehend-inference-ep-", timestamp_suffix])

inference_endpoint_response = comprehend.create_endpoint(
    EndpointName=comprehend_endpoint_name,
    ModelArn=model_arn,
    DesiredInferenceUnits=1,
)

In [31]:
# ARN of the new endpoint, taken from the create_endpoint response.
comprehend_endpoint_arn = inference_endpoint_response["EndpointArn"]
print(comprehend_endpoint_arn)

arn:aws:comprehend:us-east-1:211125778552:document-classifier-endpoint/comprehend-inference-ep-16-00-08-02


# Pass Variables to the Next Notebook(s)

In [32]:
# Persist the classifier ARN with %store so another notebook can restore it via `%store -r`.
%store comprehend_training_job_arn

Stored 'comprehend_training_job_arn' (str)


In [33]:
# Persist the endpoint ARN with %store so another notebook can restore it via `%store -r`.
%store comprehend_endpoint_arn

Stored 'comprehend_endpoint_arn' (str)


# Release Resources

In [34]:
%%html

<!-- Renders a notice plus a hidden button bound to the "kernelmenu:shutdown" command;
     the script below clicks that button programmatically. -->
<p><b>Shutting down your kernel for this notebook to release resources.</b></p>
<button class="sm-command-button" data-commandlinker-command="kernelmenu:shutdown" style="display:none;">Shutdown Kernel</button>
        
<script>
try {
    // Click the first element carrying the "sm-command-button" class.
    els = document.getElementsByClassName("sm-command-button");
    els[0].click();
}
catch(err) {
    // NoOp: any error raised while looking up or clicking the button is ignored.
}    
</script>

In [35]:
%%javascript

// Save a checkpoint and delete the notebook session through the `Jupyter` front-end object.
// Errors are swallowed (presumably for front-ends where `Jupyter` is not defined).
try {
    Jupyter.notebook.save_checkpoint();
    Jupyter.notebook.session.delete();
}
catch(err) {
    // NoOp
}

<IPython.core.display.Javascript object>