### Test SageMaker endpoint


In [None]:
import json
import re
from io import StringIO
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sn

import boto3
import sagemaker

from aws_profiles import UserProfiles
# Helper from the local `aws_profiles` module; its get_profile_id() is used
# below to turn the chosen profile name into a profile id.
profiles = UserProfiles()

#### Pick the AWS profile that you want to use


In [None]:
# Name of the AWS profile to run against and the id that `profiles` maps it to.
profile = "prod"
profile_id = profiles.get_profile_id(profile)

# A single boto3 session for that profile; both clients are created from it.
session = boto3.Session(profile_name=profile)
dev_s3_client, sm_client = (
    session.client(service_name) for service_name in ("s3", "sagemaker-runtime")
)

In [None]:
# SageMaker session layered on the boto3 session, plus its default S3 bucket.
sm_session = sagemaker.Session(boto_session=session)
default_bucket = sm_session.default_bucket()

# Fetch the IAM role named "<profile_id>-sagemaker-exec" and keep its ARN.
iam = session.client("iam")
execution_role = iam.get_role(RoleName=f"{profile_id}-sagemaker-exec")
role_arn = execution_role["Role"]["Arn"]

### 1. Run Inference on deployed endpoint
#### 1.1 Load eval data from AWS S3 bucket


In [None]:
# Download the validation CSV from the default bucket and parse it with pandas.
file_path = "data/val.csv"
s3_client = session.client("s3")

s3_object = s3_client.get_object(Bucket=default_bucket, Key=file_path)
body = s3_object["Body"]
csv_string = body.read().decode("utf-8")

# The first CSV column is read as the index and then replaced by a fresh
# 0..n-1 index, so the stored index values are discarded.
df = pd.read_csv(StringIO(csv_string), index_col=0).reset_index(drop=True)
df.head()

#### 1.2 Show examples


In [None]:
# Model inputs (transcriptions) and their labels (medical specialties).
inputs = df.transcription.tolist()
targets = df.medical_specialty.tolist()

# Print the first few label/text pairs, inserting a line break after every
# 120 characters so long transcriptions stay readable in the cell output.
n_prints = 10
# `flags` is passed by keyword: handing `count`/`flags` to re.sub positionally
# is deprecated since Python 3.13 and is easy to mix up. DOTALL makes "." match
# newlines as well, so every character counts towards the 120.
wrap_pattern = re.compile(r"(.{120})", flags=re.DOTALL)
for target, text in zip(targets[:n_prints], inputs[:n_prints]):
    text_block = wrap_pattern.sub("\\1\n", text)
    # Only the first 500 characters of the wrapped text are shown.
    print(f"'{target}': \n {text_block[:500]} ... \n")

#### 1.3 Run Prediction on endpoint


In [None]:
# Send all validation inputs to the deployed endpoint in one JSON request.
endpoint_name = f"{profile_id}-endpoint"
CONTENT_TYPE_JSON = "application/json"
payload = json.dumps({"instances": inputs})

response = sm_client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType=CONTENT_TYPE_JSON,
    Accept=CONTENT_TYPE_JSON,
    Body=payload,
)

# The response body is read and parsed as JSON; its "prediction" entry is kept.
prediction = json.loads(response["Body"].read())["prediction"]

# One row per example: predicted label, true label, and whether they match.
# NOTE(review): assumes the endpoint returns one scalar label per input, in
# input order - confirm against the endpoint's inference code.
results = pd.DataFrame({"pred": prediction, "target": targets})
# Vectorised comparison instead of a row-wise apply(): faster, and it still
# produces a boolean column when there are zero rows.
results["correct"] = results["pred"] == results["target"]
results.head()

#### 1.4 Evaluate results

In [None]:
# Number of examples per label, for the true labels and for the predictions.
counts_tar = results["target"].value_counts()
counts_pred = results["pred"].value_counts()

# Two horizontal bar charts side by side on shared axes: targets on the left,
# predictions on the right.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, sharex=True, sharey=True)
panels = ((ax1, counts_tar, "target"), (ax2, counts_pred, "pred"))
for axis, counts, title in panels:
    axis.barh(counts.index, counts.values)
    axis.set_title(title)

# The mean of the boolean "correct" column is the fraction of correct rows.
print(f"Accuracy: {results.correct.mean()*100:.3f}%")
plt.show()