![WhyLabs logo](https://whylabs-public.s3.us-west-2.amazonaws.com/assets/whylabs-logo-night-blue.svg)

*Run AI with Certainty*

# **Using WhyLabs with Sagemaker** 

In [None]:
# The [torch] extra pulls in PyTorch if it isn't already installed outside the notebook
%pip install 'transformers[torch]' python-dotenv ipywidgets

## AWS Authentication

Set up AWS authentication by creating an execution role for Sagemaker and making sure the AWS CLI is configured with credentials you can use.

In [2]:
# Keep sensitive values out of the notebook by storing them in a .env file.
from dotenv import load_dotenv

# Create a sagemaker.env file with these vars
# SAGEMAKER_ROLE=
# WHYLABS_API_KEY=
# WHYLABS_DEFAULT_DATASET_ID=
# BUCKET_NAME=

load_dotenv(dotenv_path='sagemaker.env')

True
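Later cells index `os.environ['WHYLABS_API_KEY']` and `os.environ['WHYLABS_DEFAULT_DATASET_ID']` directly and pass `SAGEMAKER_ROLE` as the execution role, so a missing entry in `sagemaker.env` would otherwise only surface as a `KeyError` or a failed deploy much later. A small check right after `load_dotenv` catches it early. This is a minimal sketch; `missing_vars` is a hypothetical helper, not part of python-dotenv.

```python
import os

# Variables that later cells require. BUCKET_NAME is left out because the
# upload cell only reads it with os.getenv().
REQUIRED_VARS = ("SAGEMAKER_ROLE", "WHYLABS_API_KEY", "WHYLABS_DEFAULT_DATASET_ID")


def missing_vars(required=REQUIRED_VARS, env=None):
    """Return the names in `required` that are unset or empty in `env` (defaults to os.environ)."""
    env = os.environ if env is None else env
    return [name for name in required if not env.get(name)]


missing = missing_vars()
if missing:
    print(f"Missing from sagemaker.env: {', '.join(missing)}")
```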

In [3]:
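import boto3
import sagemaker
import os

# A sagemaker execution role that you previously created
aws_role = os.getenv("SAGEMAKER_ROLE")
aws_region = "us-west-2"
# Pin the session to aws_region instead of relying on the AWS CLI's default region
session = sagemaker.Session(boto_session=boto3.Session(region_name=aws_region))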

## Prepare model
For this example we'll package up an existing model: the pretrained GPT-2 model and tokenizer from Hugging Face `transformers`.

In [4]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import tarfile
import os

model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

model_path = 'model/'
code_path = 'code/'

os.makedirs(model_path, exist_ok=True)

model.save_pretrained(save_directory=model_path)
tokenizer.save_vocabulary(save_directory=model_path)

bucket = os.getenv("BUCKET_NAME")
key_prefix = 'sagemaker_models/gpt2'

# Archive the model/ directory so its files land under model/ inside the tarball,
# which is where model_fn in inference.py looks for them.
archive_path = os.path.join(model_path, "model.tar.gz")
with tarfile.open(archive_path, 'w:gz') as tar:
    tar.add(model_path)
    # tar.add(code_path)

upload_path = session.upload_data(path=archive_path, bucket=bucket, key_prefix=key_prefix)
print(f"Model artifact uploaded to: {upload_path}")
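`model_fn` in `inference.py` reads `model/vocab.json` and `model/merges.txt` (and `from_pretrained` reads `model/config.json`) from the extracted model directory, so checking the archive before uploading can save a failed deployment. A minimal sketch using the standard-library `tarfile` module; the helper `missing_members` and the exact list of required files are assumptions based on what the cells above write and read.

```python
import os
import tarfile

# Paths the notebook writes and inference.py reads, relative to the extracted
# model directory: config.json comes from save_pretrained, vocab.json and
# merges.txt from save_vocabulary. Weights are not listed because their file
# name depends on the transformers version.
REQUIRED_MEMBERS = ("model/config.json", "model/vocab.json", "model/merges.txt")


def missing_members(archive_path, required=REQUIRED_MEMBERS):
    """Return the entries in `required` that are absent from a .tar.gz archive."""
    with tarfile.open(archive_path, "r:gz") as tar:
        # normpath makes "./model/x" and "model/x" compare equal.
        names = {os.path.normpath(name) for name in tar.getnames()}
    return [member for member in required if os.path.normpath(member) not in names]


# Usage before uploading: an empty list means model_fn should find its files.
# print(missing_members("model/model.tar.gz"))
```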

## Prepare the requirements file
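You'll need whylogs and langkit installed on the Sagemaker host. The PyTorch serving container installs whatever is listed in a `requirements.txt` inside the `source_dir` (here `code/`). The commented-out cell below shows how to generate that file by freezing a throwaway virtual env; in this example the file is already bundled in `code/`.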

In [6]:
# requirements.txt is already bundled in code/, so generating it here is left commented out

# ! mkdir -p code 
# ! bash -c "virtualenv ./code/.venv && source ./code/.venv/bin/activate && pip install whylogs[image,proc]==1.3.11 langkit[all] && pip freeze > code/requirements.txt"
# ! rm -rf ./code/.venv
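If you maintain `code/requirements.txt` by hand instead of regenerating it, a quick check that whylogs and langkit are still declared catches a missing dependency before deployment rather than as an import error when the endpoint starts. A minimal sketch; `missing_packages` is hypothetical, and it only parses simple `name[extras]==version` lines (pip options such as `-r` or `-e` are skipped).

```python
import re


def _normalize(name):
    # PEP 503 normalization: case-insensitive, runs of "-", "_", "." are equivalent.
    return re.sub(r"[-_.]+", "-", name).lower()


def missing_packages(requirements_text, required=("whylogs", "langkit")):
    """Return required package names not declared in a requirements.txt body."""
    declared = set()
    for line in requirements_text.splitlines():
        line = line.split("#", 1)[0].strip()  # drop comments and surrounding whitespace
        if not line or line.startswith("-"):  # skip blank lines and pip options
            continue
        match = re.match(r"[A-Za-z0-9][A-Za-z0-9._-]*", line)
        if match:
            declared.add(_normalize(match.group(0)))
    return [pkg for pkg in required if _normalize(pkg) not in declared]


# Usage with the bundled file:
# with open("code/requirements.txt") as f:
#     print(missing_packages(f.read()))
```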

## Create an inference.py file
The integration happens in the custom inference logic for the Sagemaker container. The important parts are captured below. This cell is written to a file and deployed along with the model further down. Here it logs each LLM prompt and generated response as text, but the same approach works for other kinds of data as well.

In [7]:
%%writefile code/inference.py
import traceback
import json
import os
import logging

import whylogs as why
from whylogs.api.writer import Writer, Writers
from whylogs.api.logger.experimental.logger.actor.thread_rolling_logger import ThreadRollingLogger
from langkit import llm_metrics # alternatively use 'light_metrics'
from whylogs.api.logger.experimental.logger.actor.time_util import Schedule, TimeGranularity
from transformers import GPT2Tokenizer, TextGenerationPipeline, GPT2LMHeadModel


logging.basicConfig(level=logging.DEBUG)

# Initialize whylogs with your WhyLabs API key and target dataset ID. You can get an API key from the
# settings menu of your WhyLabs account.
why.init() # This loads credentials from the env directly

def create_logger():
    logger = ThreadRollingLogger(
        # This should match the model type in WhyLabs. We're using a daily model here.
        aggregate_by=TimeGranularity.Day,
        # The profiles will be uploaded from the rolling logger to WhyLabs every 5 minutes. Data
        # accumulates in the logger during that time.
        write_schedule=Schedule(cadence=TimeGranularity.Minute, interval=5),
        writers=[Writers.get('whylabs')],
        schema=llm_metrics.init(),
    )

    return logger


logger = create_logger()


def model_fn(model_dir):
    """
    Load the model for inference
    """

    # Load GPT2 tokenizer from disk.
    vocab_path = os.path.join(model_dir, 'model/vocab.json')
    merges_path = os.path.join(model_dir, 'model/merges.txt')
    
    tokenizer = GPT2Tokenizer(vocab_file=vocab_path,
                              merges_file=merges_path)

    # Load GPT2 model from disk.
    model_path = os.path.join(model_dir, 'model/')
    model = GPT2LMHeadModel.from_pretrained(model_path)

    return TextGenerationPipeline(model=model, tokenizer=tokenizer)


def input_fn(request_body, request_content_type):
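    # Reject non-JSON requests explicitly; an assert would be stripped under python -O.
    if request_content_type != 'application/json':
        raise ValueError(f"Unsupported content type: {request_content_type}")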
    return json.loads(request_body)

def predict_fn(input_data, model):
    # Match the whole payload so a prompt that merely contains "flush" or "close"
    # is not mistaken for a control message.
    if input_data == 'flush':
        # Utility for flushing the logger, which forces it to upload any pending profiles synchronously.
        print('>> flushing')
        logger.flush()
        return 'flushed'

    if input_data == 'close':
        print('>> closing')
        logger.close()
        return 'closed'

    output = model(input_data, max_length=100)
    output = output[0]['generated_text']

    try:
        # Log the prompt/response pair asynchronously with whylogs. This won't hold up predictions.
        row = {'prompt': input_data, 'response': output}
        print(f'Row: {row}')
        logger.log(row)
    except Exception as e:
        print(f"Failed to log row: {e}")
        print(traceback.format_exc())

    return output

def output_fn(prediction, content_type):
    return str(prediction)



Overwriting code/inference.py
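The Sagemaker PyTorch serving container calls `model_fn` once when a worker starts, then runs each request through `input_fn`, `predict_fn`, and `output_fn` in that order. The sketch below mimics that chain locally so the wiring can be exercised without deploying. `StubLogger`, the echo model, and `handle` are hypothetical stand-ins for the rolling logger, the GPT-2 pipeline, and the container's dispatch; nothing here imports whylogs or transformers.

```python
import json


class StubLogger:
    """Hypothetical stand-in for ThreadRollingLogger: just records logged rows."""

    def __init__(self):
        self.rows = []

    def log(self, row):
        self.rows.append(row)


logger = StubLogger()


def model_fn(model_dir):
    # Stand-in for TextGenerationPipeline: appends a marker instead of generating text.
    return lambda text, max_length=100: [{"generated_text": text + " [generated]"}]


def input_fn(request_body, request_content_type):
    if request_content_type != "application/json":
        raise ValueError(f"Unsupported content type: {request_content_type}")
    return json.loads(request_body)


def predict_fn(input_data, model):
    output = model(input_data, max_length=100)[0]["generated_text"]
    logger.log({"prompt": input_data, "response": output})
    return output


def output_fn(prediction, content_type):
    return str(prediction)


def handle(request_body, model, content_type="application/json", accept="application/json"):
    """Run one request through the hooks in the order the container calls them."""
    data = input_fn(request_body, content_type)
    prediction = predict_fn(data, model)
    return output_fn(prediction, accept)


model = model_fn("/opt/ml/model")  # loaded once per worker; the stub ignores the path
print(handle(json.dumps("LLMs are... "), model))
```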


# Create a PyTorch deployment

In [8]:
from sagemaker.pytorch import PyTorchModel
from sagemaker import Predictor

instance_type='ml.m5.xlarge'
# instance_type='ml.p3.2xlarge'  # For gpu

sagemaker_model = PyTorchModel(
    source_dir='code',
    entry_point='inference.py',
    predictor_cls=Predictor,
    model_data=upload_path,
    framework_version='2.0',
    py_version='py310',
    role=aws_role,
    env={
        'WHYLABS_API_KEY': os.environ['WHYLABS_API_KEY'],
        'WHYLABS_DEFAULT_DATASET_ID': os.environ['WHYLABS_DEFAULT_DATASET_ID']
    },
)

In [9]:
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import StringDeserializer

predictor = sagemaker_model.deploy(initial_instance_count=1, instance_type=instance_type)
predictor.serializer = JSONSerializer()
predictor.deserializer = StringDeserializer()


# Make predictions

In [10]:
def predict(text):
    query_response = predictor.predict(
        text,
        {
            "ContentType": "application/json",
            "Accept": "application/json",
        },
    )
    return query_response

In [18]:
predict("LLMs are... ")

"LLMs are... \xa0I'm not sure if they're the same as the ones I've seen in the past, but I'm sure they're the same.\nI'm not sure if they're the same as the ones I've seen in the past, but I'm sure they're the same.\nI'm not sure if they're the same as the ones I've seen in the past, but I'm sure they're the same.\nI'm not sure if they're"

## Force the logger to upload

> ⚠️ These control requests only work reliably if you have a single instance behind your prediction endpoint. With several instances, Sagemaker routes each request to just one of them, so you'd have to make sure a flush or close request reaches every instance individually.

This request forces the logger to upload any data it is still holding (see `predict_fn` in `inference.py`) before we shut down the Sagemaker endpoint. The rolling logger normally uploads on a fixed schedule (every 5 minutes here), so anything logged since the last scheduled upload would be lost if the endpoint went away first. Sagemaker doesn't provide any "on close" hook that could do this transparently.
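To see why the explicit flush matters, here is a toy model of a scheduled writer. It is not how whylogs implements `ThreadRollingLogger`, which writes from a background thread on the `write_schedule` timer; the hypothetical `ToyRollingBuffer` only checks the schedule when a row is logged, and time is passed in explicitly so the behavior is deterministic.

```python
class ToyRollingBuffer:
    """Toy scheduled writer (hypothetical; not whylogs' ThreadRollingLogger)."""

    def __init__(self, interval_seconds, start_time):
        self.interval = interval_seconds
        self.last_write = start_time
        self.pending = []  # rows logged since the last write
        self.written = []  # batches that have been "uploaded"

    def log(self, row, now):
        self.pending.append(row)
        # The real logger uses a background timer; the toy checks the schedule here.
        if now - self.last_write >= self.interval:
            self._write(now)

    def flush(self, now):
        """Write pending rows immediately, regardless of the schedule."""
        if self.pending:
            self._write(now)

    def _write(self, now):
        self.written.append(list(self.pending))
        self.pending.clear()
        self.last_write = now


# A 5-minute cadence, matching write_schedule in inference.py.
buf = ToyRollingBuffer(interval_seconds=300, start_time=0)
buf.log({"prompt": "a"}, now=10)
buf.log({"prompt": "b"}, now=120)
print(len(buf.written), len(buf.pending))  # 0 2: tearing down now would drop both rows
buf.flush(now=200)
print(len(buf.written), len(buf.pending))  # 1 0: both rows went out in one batch
```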


In [19]:
predict('flush')

'flushed'

You can also close the logger. This forces an upload as well but leaves the logger inactive afterward, so only do it right before you tear down the endpoint. The benefit is that it lets you wait synchronously for any pending uploads to finish.

Remember, these payloads aren't handled automatically. They only work because we set up the `inference.py` file to check for them and call the right methods on the logger.
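Because control messages travel over the same channel as prompts, it is safer to match the whole payload than to search for a substring; otherwise a prompt like "How do I flush a cache?" would flush the logger instead of generating text. A minimal sketch of such a dispatch, assuming `predict_fn` calls it first and returns early when it yields a status. `StubLogger` and `handle_control` are hypothetical names for illustration.

```python
class StubLogger:
    """Hypothetical stand-in for the rolling logger; records which methods were called."""

    def __init__(self):
        self.calls = []

    def flush(self):
        self.calls.append("flush")

    def close(self):
        self.calls.append("close")


def handle_control(input_data, logger):
    """Return a status string if input_data is a control message, else None."""
    commands = {
        "flush": (logger.flush, "flushed"),
        "close": (logger.close, "closed"),
    }
    # Exact match on a string payload only; dicts, lists, and prompts that merely
    # contain "flush" or "close" fall through to normal generation.
    if isinstance(input_data, str) and input_data in commands:
        action, status = commands[input_data]
        action()
        return status
    return None


# Inside predict_fn:
#     status = handle_control(input_data, logger)
#     if status is not None:
#         return status
```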

In [13]:
# predict('close')

# Clean up endpoint

In [14]:
# predictor.delete_endpoint()