# PyTorch-Native Transformer Text Generation Using TorchServe on SageMaker with Llama-2

## Overview

[gpt-fast](https://github.com/pytorch-labs/gpt-fast) is a simple and efficient PyTorch-native implementation of transformer text generation.

It features:
* Very low latency
* <1000 lines of Python
* No dependencies other than PyTorch and sentencepiece
* int8/int4 quantization
* Speculative decoding
* Tensor parallelism
* Supports NVIDIA and AMD GPUs

More details about gpt-fast can be found in this [blog post](https://pytorch.org/blog/accelerating-generative-ai-2/).
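Speculative decoding, listed among the features above, lets a small draft model propose several tokens that the large target model then verifies. The following is a minimal, framework-free sketch of the greedy variant, assuming each "model" is just a function from a token list to its next token. `speculative_generate`, `greedy_generate`, and the toy models are hypothetical names for illustration, not gpt-fast's API. With greedy verification the output matches decoding with the target alone; in a real system the speedup comes from scoring all draft tokens in one batched forward pass of the target, which this sketch only simulates with a loop.

```python
from typing import Callable, List

# A toy "model": maps the token sequence so far to its greedy next token.
NextToken = Callable[[List[int]], int]


def greedy_generate(model: NextToken, prompt: List[int], max_new_tokens: int) -> List[int]:
    """Reference: plain autoregressive greedy decoding with a single model."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        tokens.append(model(tokens))
    return tokens


def speculative_generate(target: NextToken, draft: NextToken, prompt: List[int],
                         max_new_tokens: int, k: int) -> List[int]:
    """Greedy speculative decoding: the draft proposes up to k tokens, the target verifies them."""
    tokens = list(prompt)
    goal = len(prompt) + max_new_tokens
    while len(tokens) < goal:
        # 1. The cheap draft model proposes up to k tokens autoregressively.
        proposal: List[int] = []
        for _ in range(min(k, goal - len(tokens))):
            proposal.append(draft(tokens + proposal))
        # 2. The target checks each proposed position (one batched pass in practice).
        accepted: List[int] = []
        for tok in proposal:
            expected = target(tokens + accepted)
            accepted.append(expected)  # the target's token is always the one kept
            if expected != tok:
                break  # first mismatch: discard the rest of the draft
        else:
            # Every draft token matched, so the target's next prediction is a bonus token.
            if len(tokens) + len(accepted) < goal:
                accepted.append(target(tokens + accepted))
        tokens.extend(accepted)
    return tokens


def target_model(seq: List[int]) -> int:
    return (seq[-1] * 3 + 1) % 17


def draft_model(seq: List[int]) -> int:
    # Agrees with the target unless the last token is a multiple of 4.
    return target_model(seq) if seq[-1] % 4 else 0


print(speculative_generate(target_model, draft_model, [1], 20, k=8))
```

The `speculate_k: 8` setting in `model-config.yaml` later in this notebook plays the role of `k` here.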

## Step 1: Upgrade the SageMaker SDK and import dependencies

The wheel installed here is a private preview wheel; your account must be added to the allowlist to run this function.

In [None]:
!python --version

In [None]:
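# Download the latest AWS CLI v2 installer if the CLI is not already installed
!curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"
# -q: quiet, -o: overwrite existing files without prompting (a prompt would hang the cell)
!unzip -q -o awscliv2.zip
# Follow the instructions in aws/README.md to finish installing AWS CLI v2 from a terminal
!cat aws/README.md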

In [None]:
# Confirm the output starts with aws-cli/2.xx.xx
!aws --version
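The version check above can also be done programmatically, for example to fail fast in an automated run. This is a minimal sketch with hypothetical helper names; it assumes only that `aws --version` prints a string beginning with `aws-cli/<major>.<minor>.<patch>`.

```python
import re
import subprocess


def aws_cli_major_version(version_output: str) -> int:
    """Return the major version from `aws --version` output, e.g. 'aws-cli/2.15.0 ...' -> 2."""
    match = re.search(r"aws-cli/(\d+)\.", version_output)
    if match is None:
        raise ValueError(f"Unexpected `aws --version` output: {version_output!r}")
    return int(match.group(1))


def installed_aws_cli_major_version() -> int:
    """Run `aws --version` from the PATH and parse the major version."""
    result = subprocess.run(["aws", "--version"], capture_output=True, text=True, check=True)
    # Older CLI releases may print the version to stderr, so inspect both streams.
    return aws_cli_major_version(result.stdout + result.stderr)
```

In the notebook, `assert installed_aws_cli_major_version() == 2, "AWS CLI v2 is required"` would stop execution early if the installation above did not take effect.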

In [None]:
#%pip install sagemaker pip --upgrade  --quiet
# %pip installs into the environment of the running kernel
%pip install numpy pillow torch-model-archiver
%pip install -U sagemaker boto botocore boto3

In [None]:
import boto3
import sagemaker
from sagemaker import Model, image_uris, serializers, deserializers

boto3_session = boto3.session.Session(region_name="us-west-2")
smr = boto3_session.client("sagemaker-runtime")  # runtime client used to invoke endpoints
sm = boto3_session.client("sagemaker")  # control-plane client used to create and delete resources
role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session(boto3_session, sagemaker_client=sm, sagemaker_runtime_client=smr)  # SageMaker session for interacting with different AWS APIs
region = sess.boto_region_name  # region of this session (us-west-2 above)
account = sess.account_id()  # account ID of the current environment

# Configuration:
bucket_name = sess.default_bucket()
prefix = "torchserve"
output_path = f"s3://{bucket_name}/{prefix}"
print(f'account={account}, region={region}, role={role}, output_path={output_path}')

## Step 2: Build a BYOC (bring-your-own-container) TorchServe Docker image and push it to Amazon ECR

In [None]:
# Inspect the Dockerfile that installs our own dependencies
!cat workspace/docker/Dockerfile

In [2]:
%%capture build_output

baseimage = "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference:2.1.0-gpu-py310-cu118-ubuntu20.04-sagemaker"
reponame = "torchserve-gpt-fast"
versiontag = "0.1"

# Build our own Docker image and push it to Amazon ECR
!cd workspace/docker && ./build_and_push.sh {reponame} {versiontag} {baseimage} {region} {account}

In [None]:
if 'Error response from daemon' in str(build_output):
    print(build_output)
    raise SystemExit('\n\n!!There was an error with the container build!!')
else:
    container = str(build_output).strip().split('\n')[-1]

print(container)
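The cell above takes the image URI from the last line that the build script prints. If `build_and_push.sh` pushes to `<reponame>:<versiontag>` in this account and region (an assumption; the script is not shown here), the URI can be reconstructed from ECR's standard naming scheme as a cross-check. `ecr_image_uri` is a hypothetical helper, and the format assumes the standard AWS partition (China regions use the `amazonaws.com.cn` domain).

```python
def ecr_image_uri(account: str, region: str, repository: str, tag: str) -> str:
    """Build an ECR image URI: <account>.dkr.ecr.<region>.amazonaws.com/<repository>:<tag>."""
    # Assumes the standard AWS partition; China regions use amazonaws.com.cn instead.
    return f"{account}.dkr.ecr.{region}.amazonaws.com/{repository}:{tag}"


expected = ecr_image_uri("123456789012", "us-west-2", "torchserve-gpt-fast", "0.1")
print(expected)
```

In the notebook, `ecr_image_uri(account, region, reponame, versiontag) == container` should hold if the assumption about the build script is right.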

## Step 3: Build model artifacts

In [1]:
# model configuration
!cat workspace/model-config.yaml

#frontend settings
minWorkers: 1
maxWorkers: 1
maxBatchDelay: 200
responseTimeout: 10800
parallelType: "tp"
deviceType: "gpu"
continuousBatching: false

handler:
  model_name: "meta-llama/Llama-2-70b-chat-hf"
  converted_ckpt_dir: "checkpoints/meta-llama/Llama-2-70b-chat-hf/model.pth"
  draft_model_name: "meta-llama/Llama-2-7b-chat-hf"
  draft_checkpoint_dir: "checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"
  quantization: int8
  max_new_tokens: 50
  compile: true
  speculate_k: 8
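The `quantization: int8` setting stores weights as 8-bit integers plus a floating-point scale, roughly halving weight memory compared with 16-bit weights. Below is a generic sketch of symmetric per-output-channel (per-row) int8 weight quantization in plain Python. It illustrates the idea only and is not gpt-fast's code, which operates on tensors and may compute scales slightly differently; the function names are hypothetical.

```python
from typing import List, Tuple


def quantize_int8_per_row(weight: List[List[float]]) -> Tuple[List[List[int]], List[float]]:
    """Symmetric per-row int8 quantization: w ~= q * scale, with q in [-128, 127]."""
    q_rows: List[List[int]] = []
    scales: List[float] = []
    for row in weight:
        max_abs = max(abs(w) for w in row)
        # One scale per output channel maps the largest magnitude in the row to 127.
        scale = max_abs / 127 if max_abs > 0 else 1.0
        q_rows.append([max(-128, min(127, round(w / scale))) for w in row])
        scales.append(scale)
    return q_rows, scales


def dequantize_int8_per_row(q_rows: List[List[int]], scales: List[float]) -> List[List[float]]:
    """Recover approximate float weights from int8 values and per-row scales."""
    return [[q * s for q in row] for row, s in zip(q_rows, scales)]


weights = [[0.5, -1.0, 0.25], [2.0, 0.0, -0.02], [0.0, 0.0, 0.0]]
q, s = quantize_int8_per_row(weights)
print(q, s)
```

Rounding to the nearest integer bounds each reconstruction error by half the row's scale, which is why a per-row scale (rather than one scale for the whole matrix) keeps rows with small weights accurate.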

In [None]:
# Create the model artifact folder gpt-fast/ (no-archive: a plain directory instead of a .mar file)
!torch-model-archiver --model-name gpt-fast --version 1.0 --handler handler.py --config-file model-config.yaml --archive-format no-archive
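Because `--archive-format no-archive` produces a plain directory that is uploaded as-is, a quick check before the upload can catch a missing handler or config. This is a sketch; the helper name is hypothetical, and the required-file list is an assumption based on the command above plus the manifest that the archiver typically writes under `MAR-INF/`.

```python
from pathlib import Path
from typing import Iterable, List

DEFAULT_REQUIRED = ("handler.py", "model-config.yaml", "MAR-INF/MANIFEST.json")


def missing_artifacts(model_dir: str, required: Iterable[str] = DEFAULT_REQUIRED) -> List[str]:
    """Return the required relative paths that are not files inside model_dir."""
    root = Path(model_dir)
    return [rel for rel in required if not (root / rel).is_file()]


# Usage in the notebook, assuming the archiver ran in the current directory:
# problems = missing_artifacts("gpt-fast")
# assert not problems, f"missing model artifacts: {problems}"
```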

## Step 4: Upload model artifacts to S3

In [None]:
!aws s3 cp gpt-fast {output_path}/gpt-fast --recursive

In [None]:
s3_uri = f"{output_path}/gpt-fast/"
print(s3_uri)
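`model_data` in Step 5 points SageMaker at an S3 prefix of uncompressed artifacts rather than at a `model.tar.gz`. The sketch below builds that same dictionary from `s3_uri` and enforces the trailing slash the notebook already uses, so the prefix cannot accidentally match sibling keys such as `gpt-fast-old/...`. The helper name is hypothetical.

```python
def uncompressed_model_data(s3_uri: str) -> dict:
    """Return a model_data dict for uncompressed artifacts stored under an S3 prefix."""
    if not s3_uri.startswith("s3://"):
        raise ValueError(f"expected an s3:// URI, got {s3_uri!r}")
    if not s3_uri.endswith("/"):
        # Without the slash, the prefix would also match keys like 'gpt-fast-old/...'
        raise ValueError(f"S3 prefix should end with '/': {s3_uri!r}")
    return {
        "S3DataSource": {
            "S3Uri": s3_uri,
            "S3DataType": "S3Prefix",  # load every object under the prefix
            "CompressionType": "None",  # artifacts are not packed into a tar.gz
        }
    }
```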

## Step 5: Build the SageMaker endpoint
In this step, we build a SageMaker endpoint from scratch.

### Create SageMaker endpoint

Specify the instance type and the endpoint name.
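The next cell names the model by appending a timestamp to a fixed prefix. SageMaker model and endpoint names are limited to 63 characters of letters, digits, and hyphens, so a small helper that adds a separator and truncates the base name keeps generated names valid. `timestamped_name` is a hypothetical helper; the SDK's `sagemaker.utils.name_from_base`, already used for `endpoint_name`, serves a similar purpose.

```python
import re
from datetime import datetime, timezone
from typing import Optional

MAX_NAME_LENGTH = 63
# Letters, digits, and hyphens; must start and end with a letter or digit.
NAME_PATTERN = re.compile(r"^[a-zA-Z0-9](-*[a-zA-Z0-9])*$")


def timestamped_name(base: str, now: Optional[datetime] = None) -> str:
    """Append '-YYYY-mm-dd-HH-MM-SS', truncating base so the result fits in 63 characters."""
    now = now or datetime.now(timezone.utc)
    suffix = now.strftime("%Y-%m-%d-%H-%M-%S")
    trimmed = base[: MAX_NAME_LENGTH - len(suffix) - 1].rstrip("-")
    name = f"{trimmed}-{suffix}"
    if not NAME_PATTERN.match(name):
        raise ValueError(f"not a valid SageMaker resource name: {name!r}")
    return name


print(timestamped_name("torchserve-gpt-fast"))
```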

In [20]:
from datetime import datetime

instance_type = "ml.inf2.24xlarge"
endpoint_name = sagemaker.utils.name_from_base("ts-gpt-fast")
hf_token = "YOUR_TOKEN"  # Hugging Face access token with access to the gated Llama-2 models

model = Model(
    name="torchserve-gpt-fast" + datetime.now().strftime("%Y-%m-%d-%H-%M-%S"),
    # Enable SageMaker uncompressed model artifacts
    model_data={
        "S3DataSource": {
            "S3Uri": s3_uri,
            "S3DataType": "S3Prefix",
            "CompressionType": "None",
        }
    },
    image_uri=container,
    role=role,
    sagemaker_session=sess,
    env={"HUGGING_FACE_HUB_TOKEN": hf_token,},
)
print(model)

<sagemaker.model.Model object at 0x7fda64bf7040>


In [21]:
model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    endpoint_name=endpoint_name,
    volume_size=512,  # larger storage volume to hold the large model files
    model_data_download_timeout=3600,  # longer timeout to download the large model
    container_startup_health_check_timeout=600,  # longer timeout to load the large model
)

Your model is not compiled. Please compile your model before using Inferentia.


---------------!
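A back-of-the-envelope estimate explains the large `volume_size` and the multi-GPU tensor-parallel setup. Weight memory is roughly parameter count × bytes per parameter. The sketch uses nominal parameter counts (70B target, 7B draft) and ignores the KV cache, activations, and any extra checkpoint copies the container may keep on disk, so treat the numbers as lower bounds. The helper name is hypothetical.

```python
def weight_memory_gib(num_params: float, bytes_per_param: float) -> float:
    """Approximate weight storage in GiB: parameters x bytes per parameter."""
    return num_params * bytes_per_param / 2**30


for label, params in [("Llama-2-70B target", 70e9), ("Llama-2-7B draft", 7e9)]:
    sixteen_bit = weight_memory_gib(params, 2)  # fp16/bf16 weights
    int8 = weight_memory_gib(params, 1)  # int8 weight-only quantization
    print(f"{label}: ~{sixteen_bit:.0f} GiB in 16-bit, ~{int8:.0f} GiB in int8")
```

Even with int8 weights, the 70B target alone needs about 65 GiB, more than a single common data-center GPU holds, which is why the config uses `parallelType: "tp"`.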

## Step 6: Run inference with a streaming response

### SageMaker streaming response

In [22]:
import io

class Parser:
    """
    A helper class for parsing the byte stream input. 
    
    The output of the model will be in the following format:
    ```
    b'{"outputs": [" a"]}\n'
    b'{"outputs": [" challenging"]}\n'
    b'{"outputs": [" problem"]}\n'
    ...
    ```
    
    While usually each PayloadPart event from the event stream will contain a byte array 
    with a full json, this is not guaranteed and some of the json objects may be split across
    PayloadPart events. For example:
    ```
    {'PayloadPart': {'Bytes': b'{"outputs": '}}
    {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
    ```
    
    This class accounts for this by concatenating bytes written via the 'write' function
    and then exposing a method which will return lines (ending with a '\n' character) within
    the buffer via the 'scan_lines' function. It maintains the position of the last read 
    position to ensure that previous bytes are not exposed again. 
    """
    
    def __init__(self):
        self.buff = io.BytesIO()
        self.read_pos = 0
        
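    def write(self, content):
        # Append the new bytes at the end of the buffer
        self.buff.seek(0, io.SEEK_END)
        self.buff.write(content)
        
    def scan_lines(self):
        self.buff.seek(self.read_pos)
        for line in self.buff.readlines():
            # Indexing bytes yields an int, so compare a one-byte slice with b'\n'
            if line[-1:] != b'\n':
                break  # incomplete line: wait for the rest in a later write()
            self.read_pos += len(line)
            yield line[:-1]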
                
    def reset(self):
        self.read_pos = 0
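The splitting problem described in the `Parser` docstring can be exercised offline, without an endpoint. The sketch below is an alternative, hypothetical `iter_complete_lines` generator built on a `bytearray`. It yields only newline-terminated lines and holds a trailing partial line until later chunks complete it; a partial line still buffered when the stream ends is dropped.

```python
import json
from typing import Iterable, Iterator


def iter_complete_lines(chunks: Iterable[bytes]) -> Iterator[bytes]:
    """Yield each newline-terminated line (without the newline) from arbitrarily split chunks."""
    buffer = bytearray()
    for chunk in chunks:
        buffer.extend(chunk)
        while (newline := buffer.find(b"\n")) != -1:
            yield bytes(buffer[:newline])
            del buffer[: newline + 1]  # drop the consumed line and its newline
        # Anything left in the buffer is an incomplete line; wait for the next chunk.


# Simulated PayloadPart bytes where a JSON object is split across events
chunks = [b'{"outputs": ', b'[" problem"]}\n{"outputs"', b': [" a"]}\n']
for line in iter_complete_lines(chunks):
    print(json.loads(line)["outputs"])
```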

In [23]:
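import json

# Serialize the request to JSON bytes (an encoded bytes prompt would not be JSON-serializable)
body = json.dumps({
    "prompt": "The capital of France",
    "max_new_tokens": 50,
}).encode("utf-8")
resp = smr.invoke_endpoint_with_response_stream(EndpointName=endpoint_name, Body=body, ContentType="application/json")
event_stream = resp['Body']
parser = Parser()
for event in event_stream:
    if 'PayloadPart' not in event:
        # Error events such as ModelStreamError carry no payload bytes
        raise RuntimeError(f"Streaming inference failed: {event}")
    parser.write(event['PayloadPart']['Bytes'])
    for line in parser.scan_lines():
        print(line.decode("utf-8"), end=' ')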

Today the weather is really nice and I am planning on going to the beach. I am going to take my camera and take some pictures of the beach. I am going to take pictures of the sand, the water, and the people. I am also going to take pictures of the sunset. I am really excited to go to the beach and take pictures. The beach is a great place to take pictures. The sand, the water, and the people are all great subjects for pictures. The sunset is also a great subject for pictures 
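Besides `PayloadPart`, the response stream of `invoke_endpoint_with_response_stream` can deliver `ModelStreamError` or `InternalStreamFailure` events. The sketch below splits request building and event handling into two hypothetical helpers that can be tested with fake events. The `prompt` and `max_new_tokens` keys mirror what this notebook sends and are assumed to be what the custom handler reads.

```python
import json
from typing import Iterable, Iterator, Mapping


def build_request(prompt: str, max_new_tokens: int) -> bytes:
    """Serialize the request body as UTF-8 JSON bytes."""
    return json.dumps({"prompt": prompt, "max_new_tokens": max_new_tokens}).encode("utf-8")


def payload_bytes(event_stream: Iterable[Mapping]) -> Iterator[bytes]:
    """Yield the raw bytes of each PayloadPart event; raise on error events."""
    for event in event_stream:
        if "PayloadPart" in event:
            yield event["PayloadPart"]["Bytes"]
        elif "ModelStreamError" in event or "InternalStreamFailure" in event:
            raise RuntimeError(f"Streaming inference failed: {dict(event)}")
        # Any other event type is ignored in this sketch.


# Offline usage with fake events shaped like the boto3 response stream
fake_events = [{"PayloadPart": {"Bytes": b"Paris"}}, {"PayloadPart": {"Bytes": b" is\n"}}]
print(b"".join(payload_bytes(fake_events)).decode("utf-8"))

# Against the real endpoint (not run here):
# resp = smr.invoke_endpoint_with_response_stream(
#     EndpointName=endpoint_name, Body=build_request("The capital of France", 50),
#     ContentType="application/json")
# for chunk in payload_bytes(resp["Body"]):
#     parser.write(chunk)
#     for line in parser.scan_lines():
#         print(line.decode("utf-8"), end=" ")
```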

## Step 7: Clean up the environment

In [None]:
sess.delete_endpoint(endpoint_name)
sess.delete_endpoint_config(endpoint_name)
model.delete_model()
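The cleanup cell above relies on the endpoint configuration having the same name as the endpoint. If the resources were created another way, a more general sketch with the low-level boto3 SageMaker client looks up the configuration and model names from the endpoint itself before deleting them. `delete_endpoint_resources` is a hypothetical helper; `describe_endpoint`, `describe_endpoint_config`, `delete_endpoint`, `delete_endpoint_config`, and `delete_model` are standard SageMaker client calls.

```python
from typing import List


def delete_endpoint_resources(sm_client, endpoint_name: str) -> List[str]:
    """Delete an endpoint, its endpoint config, and the models it serves; return the model names."""
    endpoint = sm_client.describe_endpoint(EndpointName=endpoint_name)
    config_name = endpoint["EndpointConfigName"]
    config = sm_client.describe_endpoint_config(EndpointConfigName=config_name)
    model_names = [variant["ModelName"] for variant in config["ProductionVariants"]]

    # Delete the endpoint first, then its config, then the models it referenced.
    sm_client.delete_endpoint(EndpointName=endpoint_name)
    sm_client.delete_endpoint_config(EndpointConfigName=config_name)
    for name in model_names:
        sm_client.delete_model(ModelName=name)
    return model_names


# Usage in this notebook, as an alternative to the three calls above:
# delete_endpoint_resources(sm, endpoint_name)
```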