In [1]:
import boto3
import pandas as pd
import numpy as np
import json
from urllib.parse import urlparse
import shortuuid
import larry as lry


cognito = boto3.client("cognito-idp")
sagemaker = boto3.client("sagemaker")
bucket_name = "a9poc"

# Data prep
We'll read the data in from Excel, format it, and then group the ASINs by query string. The generated data structure will be our input to Ground Truth.

In [2]:
df = pd.read_excel("Sample+for+SMGT.xls")
records = df.to_dict("records")
records = [{k.split(":")[1]: v for k, v in record.items() if k.startswith("INPUT") and not (isinstance(v, float) and np.isnan(v))} for record in records]
print(f"Found {len(records)} records")
records[0]

Found 500 records


{'asin': 'B07SX5RM12',
 'detail_page_url': 'https://www.amazon.com/dp/B07SX5RM12',
 'image': 'https://m.media-amazon.com/images/I/71P8X7JIc-L._AC_UY879_.jpg',
 'query_string': 'guess jacket men',
 'search_alias': 'aps',
 'search_on_google_url': 'https://www.google.com/search?q=guess+jacket+men',
 'search_page_url': 'https://www.amazon.com/s?k=guess+jacket+men',
 'title': "GUESS Men's Wind & Water Resistant Hooded Puffer Jacket with Side Stretch Panels, Light Grey, XX-Lar"}

In [3]:
queries = {}
for record in records:
    qs = record["query_string"]
    asin = {k: v for k, v in record.items() if k not in ["query_string", "search_alias", "search_on_google_url", "search_page_url"]}
    if qs in queries:
        queries[qs]["asins"].append(asin)
    else:
        query = {k: v for k, v in record.items() if k in ["query_string", "search_alias", "search_on_google_url", "search_page_url"]}
        query["marketplace"] = "amazon.com"
        queries[qs] = {
            "query": query,
            "asins": [asin]
        }
queries = list(queries.values())
print(f"Consolidated into {len(queries)} queries")
queries[0]

Consolidated into 373 queries


{'query': {'query_string': 'guess jacket men',
  'search_alias': 'aps',
  'search_on_google_url': 'https://www.google.com/search?q=guess+jacket+men',
  'search_page_url': 'https://www.amazon.com/s?k=guess+jacket+men',
  'marketplace': 'amazon.com'},
 'asins': [{'asin': 'B07SX5RM12',
   'detail_page_url': 'https://www.amazon.com/dp/B07SX5RM12',
   'image': 'https://m.media-amazon.com/images/I/71P8X7JIc-L._AC_UY879_.jpg',
   'title': "GUESS Men's Wind & Water Resistant Hooded Puffer Jacket with Side Stretch Panels, Light Grey, XX-Lar"},
  {'asin': 'B07SW1D6HX',
   'detail_page_url': 'https://www.amazon.com/dp/B07SW1D6HX',
   'image': 'https://m.media-amazon.com/images/I/411Sli+HdhL.jpg',
   'title': "GUESS Men's Color Block Hooded Puffer Jacket, Navy, Medium"},
  {'asin': 'B08KRNDL3P',
   'detail_page_url': 'https://www.amazon.com/dp/B08KRNDL3P',
   'image': 'https://m.media-amazon.com/images/I/71JJtROtlBL._AC_UY679_.jpg',
   'title': "GUESS Men's Arctic Puffer Jacket, Reflective Prism, 

# Configure IAM roles
We'll need to IAM roles to use for this workflow, one for SageMaker itself, and a second for use by the Lambdas when executing.

In [4]:
# role and policy names
sagemaker_s3_policy = 'A9POCS3Policy'
sagemaker_role = 'A9POCSageMakerRole'
lambda_role = 'A9POCLambdaRole'

In [5]:
s3_policy_arn = lry.iam.create_or_update_policy(
    sagemaker_s3_policy,
    {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Action": [
                    "s3:GetObject",
                    "s3:PutObject",
                    "s3:DeleteObject",
                    "s3:ListBucket",
                    "sagemaker:*"
                ],
                "Resource": [
                    "arn:aws:s3:::*"
                ]
            }
        ]
    }
)
sagemaker_role_arn = lry.iam.create_or_update_service_role(sagemaker_role, 'sagemaker', [lry.iam.aws_policies.AmazonSageMakerFullAccess, s3_policy_arn])
print('Created role: {}'.format(sagemaker_role_arn))

Created role: arn:aws:iam::981332165467:role/A9POCSageMakerRole


The Lambda role will simply provide the access needed to execute the Lambdas and access S3. 
Use the following to create it if it doesn't exist.

In [6]:
lambda_role_arn = lry.iam.create_or_update_service_role(lambda_role, 'lambda', lry.iam.aws_policies.AWSLambdaExecute)
print('Created role: {}'.format(lambda_role_arn))

Created role: arn:aws:iam::981332165467:role/A9POCLambdaRole


# Create handler lambdas
Ground Truth labeling jobs include lambda functions that run before and after each individual labeling task to 
pre/post-process the data. For the POC we are just going to do assign simple passthrough functions since we'll
be doing most of the data management external to Ground Truth. Note that because Ground Truth doesn't allow
nested JSON objects as input data, we'll be writing them to strings before passing them through. We'll then 
decode them within the pre-lambda.

In [7]:
def pre_ground_truth_handler(event, context):
    print(event)
    data = event['dataObject']
    if "json" in data:
        mapped_data = json.loads(data["json"])
    elif isinstance(data, str):
        mapped_data = json.loads(data)
    elif "source" in data:
        mapped_data = json.loads(data["source"])
    response = {
        "taskInput": mapped_data
    }
    print(response)
    return response

def post_ground_truth_handler(event, context):
    """
    Groups all of the annotations provided by workers into a consolidated dataset.
    """
    payload = get_payload(event)
    
    consolidated_response = []
    
    # Iterate through the items to be consolidated
    for dataset in payload:
        results = []
        
        # Iterate through the annotations provided for each item
        for annotation in dataset['annotations']:
            result = json.loads(annotation['annotationData']['content'])
            result = result.get('annotatedResult', result)
            results.append({
                'workerId': annotation['workerId'],
                'annotation': result
            })

        # Add the results to the consolidated response
        consolidated_response.append({
            'datasetObjectId': dataset['datasetObjectId'],
            'consolidatedAnnotation' : {
                'content': {
                    event['labelAttributeName']: {
                        'responses': results
                    }
                }
            }
        })
    return consolidated_response


def get_payload(event):
    """
    Returns the annotation payload from S3 if it's present. The payload can also be passed
    directly for testing by providing it as a 'content' key within the payload.
    """
    payload = event.get('payload', {})
    if 's3Uri' in payload:
        s3 = boto3.resource('s3')
        parsed_url = urlparse(payload['s3Uri'])
        return json.loads(s3.Bucket(parsed_url.netloc).Object(parsed_url.path[1:]).get()['Body'].read())
    else:
        return payload.get('content',[])

In [8]:
# Note that 'SageMaker' must be in the name for the SageMaker role to be able to access these
pre_lambda = 'SageMakerA9POC-Pre'
post_lambda = 'SageMakerA9POC-Post'

pre_lambda_obj = lry.lmbda.create_or_update(pre_lambda, 
                                            *lry.lmbda.package_function(pre_ground_truth_handler),
                                            lambda_role_arn,
                                            runtime='python3.9',
                                            timeout=60,
                                            memory_size=128)
pre_lambda_arn = pre_lambda_obj.arn
print('Created Pre Lambda: {}'.format(pre_lambda_arn))

post_lambda_obj = lry.lmbda.create_or_update(post_lambda, 
                                            *lry.lmbda.package_function(
                                                post_ground_truth_handler, 
                                                imports=['json','boto3','urllib.parse>urlparse'],
                                                functions=[get_payload],
                                            ),
                                            lambda_role_arn,
                                            runtime='python3.9',
                                            timeout=60,
                                            memory_size=128)
post_lambda_arn = post_lambda_obj.arn
print('Created Post Lambda: {}'.format(post_lambda_arn))

Created Pre Lambda: arn:aws:lambda:us-west-2:981332165467:function:SageMakerA9POC-Pre
Created Post Lambda: arn:aws:lambda:us-west-2:981332165467:function:SageMakerA9POC-Post


# Configure the Template
I've taken the SDAT task interface and updated it to use the templating language used by Ground Truth. The
preview below is generated using that template and one of the data items.

In [9]:
with open("template.html") as fp:
    template_html = fp.read()
template_uri = lry.s3.write(template_html, bucket_name, "template.html")

In [10]:
lry.sagemaker.labeling.display_task_preview(
    template_uri, 
    sagemaker_role_arn, 
    pre_lambda=pre_lambda_arn,
    lambda_input={"json": json.dumps(queries[0])}
)

# Setup a work team
Work teams define which set of workers can view and submit a task. For this POC, let's assume that we are
members of a vendor (or internal) workteam that we'll be assigning tasks to. In the cells below we'll define
a team and then add users to it.

In [11]:
# existing private workforce in this account
cognito_user_pool = "us-west-2_S1fy4adPH"
cognito_app_client = "5ot8srmhefiisljfddolj6a0h9"
labeling_portal_url = "https://2llam9evkp.labeling.us-west-2.sagemaker.aws"

# The name of our work group/team
group_name = "A9POC"

In [None]:
# define a work group/team (this is only run once to initialize the group)
create_group_response = cognito.create_group(
    GroupName=group_name,
    UserPoolId=cognito_user_pool
)
create_workteam_response = sagemaker.create_workteam(
    WorkteamName=group_name,
    MemberDefinitions=[
        {
            "CognitoMemberDefinition": {
                "UserPool": cognito_user_pool,
                "UserGroup": group_name,
                "ClientId": cognito_app_client
            }
        }
    ],
    Description="Members of the A9 SMGT POC"
)
workteam_arn = create_workteam_response["WorkteamArn"]

In [13]:
workteam_arn = "arn:aws:sagemaker:us-west-2:981332165467:workteam/private-crowd/A9POC"

## Add users

To keep things simple we just have a simple function that will create a user in cognito with the password
set to match the username. To do this we've turned off the default password rules for this user group which 
would normally require numbers, uppercase, etc. We then add the new user to the team.

In [14]:
def add_user(username, user_pool, group_name):
    if len(username) < 6:
        raise ValueError("Username must have at least six characters")
    if len(cognito.list_users(UserPoolId=cognito_user_pool, Filter=f'username="{username}"')["Users"]) == 0:
        cognito.admin_create_user(
            UserPoolId=cognito_user_pool,
            Username=username,
            MessageAction="SUPPRESS"
        )
    cognito.admin_set_user_password(
        UserPoolId=cognito_user_pool,
        Username=username,
        Password=username,
        Permanent=True
    )
    cognito.admin_add_user_to_group(
        UserPoolId=cognito_user_pool,
        Username=username,
        GroupName=group_name
    )

In [110]:
add_user("schultz", cognito_user_pool, group_name)

In [111]:
team_users = [
    "pasanh",
    "posolga",
    "zumucodo",
    "shjulaka"
]
for user in team_users:
    add_user(user, cognito_user_pool, group_name)

# Create a labeling job

In [15]:
# Grab a subset of the data
multi_asin_queries = [query for query in queries if len(query["asins"]) >= 3]
len(multi_asin_queries)

19

In [152]:
# Write it to S3 in the appropriate format
data_uri = lry.s3.write_as(
    [{"source": json.dumps(query)} for query in multi_asin_queries], 
    [dict], 
    bucket_name, 
    "multi_asin_queries.jsonl")
data_uri

's3://a9poc/multi_asin_queries.jsonl'

In [130]:
task_config = lry.sagemaker.labeling.build_human_task_config(
    template_uri=template_uri,
    pre_lambda=pre_lambda,
    consolidation_lambda=post_lambda,
    title="Perform ESCI annotation",
    description="Perform ESCI annotation",
    workers=1,
    workteam_arn=workteam_arn,
    time_limit=1800, # 30 minutes
    lifetime=259200) # 3 days

In [172]:
job_name = "a9poc-" + shortuuid.uuid()[:10]
print(f"Creating labeling job '{job_name}'")
output_uri = lry.s3.join_uri(bucket_name)
lry.sagemaker.labeling.create_job(name=job_name,
                                  manifest_uri=data_uri,
                                  output_uri=output_uri,
                                  role_arn=sagemaker_role_arn,
                                  task_config=task_config,
                                  label_attribute_name="esci")

Creating labeling job 'a9poc-a72g2mujoD'


'arn:aws:sagemaker:us-west-2:981332165467:labeling-job/a9poc-a72g2mujod'

In [57]:
lry.sagemaker.labeling.get_job_state(job_name)

'InProgress (4/19)'

# Retrieve Results

In [53]:
results = lry.sagemaker.labeling.get_results(output_uri, job_name)

In [54]:
# Filter for completed items
results = [result for result in results if "esci" in result]
len(results)

2

In [55]:
query_results = []
for result in results:
    query = json.loads(results[0]["source"])["query"]
    annotation = result["esci"]["responses"][0]["annotation"]
    asins = {}
    for t in ["complement[]", "exact[]", "irrelevant[]", "other[]", "substitute[]"]:
        for asin, selected in annotation[t].items():
            if selected:
                asins[asin] = t[:-2]
    query["asins"] = asins
    query_results.append(query)
query_results

[{'query_string': 'carpet rake',
  'search_alias': 'aps',
  'search_on_google_url': 'https://www.google.com/search?q=carpet+rake',
  'search_page_url': 'https://www.amazon.com/s?k=carpet+rake',
  'marketplace': 'amazon.com',
  'asins': {'B09B2L3L4Y': 'exact',
   'B09G68GNQ3': 'exact',
   'B09GXJ9HVX': 'exact'}},
 {'query_string': 'carpet rake',
  'search_alias': 'aps',
  'search_on_google_url': 'https://www.google.com/search?q=carpet+rake',
  'search_page_url': 'https://www.amazon.com/s?k=carpet+rake',
  'marketplace': 'amazon.com',
  'asins': {'B003WZKMQ2': 'exact',
   'B07VLLGHLS': 'exact',
   'B084JSJR6Q': 'irrelevant'}}]