# Running a LLaMA model on a SageMaker inference endpoint
In this tutorial, you will use the Large Model Inference (LMI) container from AWS Deep Learning Containers (DLC) to deploy a LLaMA model to SageMaker and run inference with it.

Before running the notebook, make sure the following permissions are granted:

- S3 bucket push access
- SageMaker access

Please review the [LLaMA License](https://github.com/facebookresearch/llama/blob/main/LICENSE) before running this example.

## Step 1: Upgrade the SageMaker SDK and import dependencies

In [None]:
%pip install sagemaker boto3 awscli --upgrade --quiet

In [None]:
import boto3
import sagemaker
from sagemaker import Model, image_uris, serializers, deserializers

role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
region = sess.boto_region_name  # region name of the current SageMaker Studio environment (public property instead of the private _region_name)
account_id = sess.account_id()  # account_id of the current SageMaker Studio environment

Here, we assume your LLaMA model is stored in a folder of an S3 bucket.

**Note: We expect the LLaMA model in its original form, so please make sure you have the original checkpoints.**

```
aws s3 ls s3://bucket/some_llama_model
- tokenizer.model
- params.json
- consolidated.00.pth
- consolidated.01.pth
- consolidated.02.pth
- consolidated.03.pth
```

The listing above shows the layout we expect for a LLaMA model with 4 checkpoint shards (such as the 33B model).
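Before deploying, it can help to confirm that the folder is complete. `model.py` (Step 2) reads `tokenizer.model` and `params.json` from the checkpoint folder and expects one `consolidated.*.pth` shard per tensor-parallel worker. The sketch below checks a list of file names (for example, the keys returned by an S3 listing) and returns the shard count; `count_llama_shards` is a hypothetical helper, and the file-name pattern assumes the original LLaMA naming shown above.

```python
import re
from typing import Iterable

# Shards in the original LLaMA release are named consolidated.NN.pth
SHARD_PATTERN = re.compile(r"^consolidated\.\d+\.pth$")
# model.py opens these two files in addition to the shards
REQUIRED_FILES = {"tokenizer.model", "params.json"}


def count_llama_shards(file_names: Iterable[str]) -> int:
    """Return the number of checkpoint shards, raising if required files are missing."""
    # Keep only the base name so full S3 keys like "some_llama_model/params.json" work too
    names = {name.rsplit("/", 1)[-1] for name in file_names}
    missing = REQUIRED_FILES - names
    if missing:
        raise ValueError(f"missing required files: {sorted(missing)}")
    shards = [n for n in names if SHARD_PATTERN.match(n)]
    if not shards:
        raise ValueError("no consolidated.*.pth checkpoint shards found")
    return len(shards)


listing = [
    "tokenizer.model",
    "params.json",
    "consolidated.00.pth",
    "consolidated.01.pth",
    "consolidated.02.pth",
    "consolidated.03.pth",
]
print(count_llama_shards(listing))  # 4
```

The returned count is the value to use for `option.tensor_parallel_degree`.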

In [None]:
model_artifact = "s3://bucket/some_llama_model/"
print(f"You can set option.s3url={model_artifact}")
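If you want to inspect the folder programmatically (for example with the boto3 S3 client's `list_objects_v2(Bucket=..., Prefix=...)`), you first need the bucket and prefix separately. A minimal sketch using only the standard library; `split_s3_uri` is a hypothetical helper name:

```python
from typing import Tuple
from urllib.parse import urlparse


def split_s3_uri(uri: str) -> Tuple[str, str]:
    """Split 's3://bucket/prefix/' into ('bucket', 'prefix/'), rejecting non-S3 URIs."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"not an S3 URI: {uri!r}")
    # urlparse keeps the leading slash on the path; S3 keys do not have one
    return parsed.netloc, parsed.path.lstrip("/")


artifact_bucket, artifact_prefix = split_s3_uri("s3://bucket/some_llama_model/")
print(artifact_bucket, artifact_prefix)  # bucket some_llama_model/
```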

## Step 2: Prepare the model artifacts
The LMI container expects the following artifacts to set up the model:
- serving.properties: Defines the model server settings
- model.py: A Python file that defines the core inference logic
- requirements.txt: Any additional pip packages to install

**Remember to change the S3 path to the location of your model artifacts.**

**Set `option.tensor_parallel_degree` below to the number of checkpoint files (`consolidated.*.pth`) in your model, e.g. 8 for a model with 8 checkpoints.**

In [None]:
%%writefile serving.properties
engine=DeepSpeed
option.tensor_parallel_degree=4
option.s3url=s3://bucket/some_llama_model/
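Because the S3 path and the tensor parallel degree are the two values you are most likely to change, you may prefer to generate the file instead of editing it by hand. The sketch below emits exactly the three keys used above and nothing else; `render_serving_properties` is a hypothetical helper, not part of the LMI tooling.

```python
def render_serving_properties(s3url: str, tensor_parallel_degree: int, engine: str = "DeepSpeed") -> str:
    """Build serving.properties text with the same keys as the cell above."""
    if tensor_parallel_degree < 1:
        raise ValueError("tensor_parallel_degree must be >= 1")
    if not s3url.startswith("s3://"):
        raise ValueError("s3url must start with s3://")
    lines = [
        f"engine={engine}",
        f"option.tensor_parallel_degree={tensor_parallel_degree}",
        f"option.s3url={s3url}",
    ]
    return "\n".join(lines) + "\n"


text = render_serving_properties("s3://bucket/some_llama_model/", 4)
print(text, end="")
# To use it in place of %%writefile:
# with open("serving.properties", "w") as f:
#     f.write(text)
```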

In [None]:
%%writefile model.py
from typing import Tuple
import os
import sys
import torch
import time
import json

from pathlib import Path

from fairscale.nn.model_parallel.initialize import initialize_model_parallel

from llama import ModelArgs, Transformer, Tokenizer, LLaMA
from djl_python import Input, Output

def setup_model_parallel() -> Tuple[int, int]:
    local_rank = int(os.environ.get("LOCAL_RANK", -1))
    world_size = int(os.environ.get("WORLD_SIZE", -1))

    torch.distributed.init_process_group("nccl")
    initialize_model_parallel(world_size)
    torch.cuda.set_device(local_rank)

    # seed must be the same in all processes
    torch.manual_seed(1)
    return local_rank, world_size


def load(
    ckpt_dir: str,
    tokenizer_path: str,
    local_rank: int,
    world_size: int,
    max_seq_len: int,
    max_batch_size: int,
) -> LLaMA:
    start_time = time.time()
    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
    assert world_size == len(
        checkpoints
    ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
    ckpt_path = checkpoints[local_rank]
    print("Loading")
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    with open(Path(ckpt_dir) / "params.json", "r") as f:
        params = json.loads(f.read())

    model_args: ModelArgs = ModelArgs(
        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
    )
    tokenizer = Tokenizer(model_path=tokenizer_path)
    model_args.vocab_size = tokenizer.n_words
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
    model = Transformer(model_args)
    torch.set_default_tensor_type(torch.FloatTensor)
    model.load_state_dict(checkpoint, strict=False)

    generator = LLaMA(model, tokenizer)
    print(f"Loaded in {time.time() - start_time:.2f} seconds")
    return generator


def load_model(properties):
    if "model_id" in properties:
        ckpt_dir = properties["model_id"]
        tokenizer_path = os.path.join(properties["model_id"], "tokenizer.model")
    else:
        ckpt_dir = properties["ckpt_dir"]
        tokenizer_path = properties["tokenizer_path"]
    max_seq_len = 512
    max_batch_size = 32
    local_rank, world_size = setup_model_parallel()
    if local_rank > 0:
        sys.stdout = open(os.devnull, "w")

    return load(
        ckpt_dir, tokenizer_path, local_rank, world_size, max_seq_len, max_batch_size
    )

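def infer(generator, inputs):
    temperature = 0.8
    top_p = 0.95
    prompts = inputs["prompts"]
    # generate() expects a list of prompts (at most max_batch_size of them);
    # wrap a single string so it is not treated as a list of characters
    if isinstance(prompts, str):
        prompts = [prompts]
    results = generator.generate(
        prompts, max_gen_len=256, temperature=temperature, top_p=top_p
    )
    return results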

generator = None

def handle(inputs: Input):
    global generator
    # Load the model once per worker process, on the first call
    if generator is None:
        generator = load_model(inputs.get_properties())

    if inputs.is_empty():
        # Model server makes an empty call to warm up the model on startup
        return None

    data = inputs.get_as_json()
    result = infer(generator, data)
    # Return the list of generated strings as JSON for the client's JSONDeserializer
    return Output().add_as_json(result)
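Each worker process in `load()` sorts the `*.pth` files and loads the one at index `local_rank`, which is why the assert requires the number of shards to equal `WORLD_SIZE`; in this setup that is expected to match `option.tensor_parallel_degree`. The pure-Python sketch below reproduces that mapping without torch so you can check it locally. `shard_for_rank` is a hypothetical helper that mirrors the logic of `load()`; it is not part of LMI.

```python
from pathlib import PurePosixPath
from typing import List


def shard_for_rank(shard_names: List[str], local_rank: int, world_size: int) -> str:
    """Sort the .pth files and return the shard that worker `local_rank` would load."""
    checkpoints = sorted(n for n in shard_names if PurePosixPath(n).suffix == ".pth")
    if world_size != len(checkpoints):
        # Same condition (and message) as the assert in load()
        raise ValueError(
            f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
        )
    return checkpoints[local_rank]


names = ["consolidated.01.pth", "consolidated.00.pth", "tokenizer.model",
         "consolidated.03.pth", "consolidated.02.pth"]
for rank in range(4):
    print(rank, shard_for_rank(names, rank, world_size=4))
```

The zero-padded shard numbers are what make a plain lexicographic sort line up with the rank order.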

In [None]:
%%sh
git clone https://github.com/facebookresearch/llama
mkdir mymodel
mv serving.properties mymodel/
mv model.py mymodel/
mv llama/requirements.txt mymodel/
mv llama/llama mymodel/llama
tar czvf mymodel.tar.gz mymodel/
rm -rf mymodel llama
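If you prefer to avoid shell cells, the `tar czvf` step can be reproduced with Python's standard `tarfile` module. This is a sketch covering only the archiving step (not the `git clone` and `mv` commands); `package_model_dir` is a hypothetical helper. It keeps `mymodel/` as the top-level folder inside the archive, as the shell command does.

```python
import tarfile
import tempfile
from pathlib import Path
from typing import List


def package_model_dir(src_dir: str, out_path: str) -> List[str]:
    """Write src_dir into a .tar.gz and return the sorted member names."""
    src = Path(src_dir)
    with tarfile.open(out_path, "w:gz") as tar:
        # arcname keeps the folder name (e.g. "mymodel") as the archive root
        tar.add(str(src), arcname=src.name)
    with tarfile.open(out_path, "r:gz") as tar:
        return sorted(tar.getnames())


# Demo on a throwaway directory that mimics part of the mymodel/ layout
with tempfile.TemporaryDirectory() as tmp:
    model_dir = Path(tmp) / "mymodel"
    model_dir.mkdir()
    (model_dir / "serving.properties").write_text("engine=DeepSpeed\n")
    (model_dir / "model.py").write_text("# inference handler\n")
    archive_names = package_model_dir(str(model_dir), str(Path(tmp) / "mymodel.tar.gz"))

print(archive_names)  # ['mymodel', 'mymodel/model.py', 'mymodel/serving.properties']
```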

## Step 3: Build the SageMaker endpoint
In this step, we build the SageMaker endpoint from scratch.

### Getting the container image URI


In [None]:
image_uri = image_uris.retrieve(
    framework="djl-deepspeed",
    region=sess.boto_session.region_name,
    version="0.21.0",
)

### Upload artifacts to S3 and create a SageMaker model

In [None]:
s3_code_prefix = "large-model-lmi/code"
bucket = sess.default_bucket()  # bucket to house artifacts
code_artifact = sess.upload_data("mymodel.tar.gz", bucket, s3_code_prefix)
print(f"S3 Code or Model tar ball uploaded to --- > {code_artifact}")

model = Model(image_uri=image_uri, model_data=code_artifact, role=role)

## Step 4: Create the SageMaker endpoint

You need to specify the instance type and the endpoint name.

The recommended instances for each LLaMA model size are:

| Model | MP (checkpoint shards) | Instance (min) | Instance (best) |
|-------|------------------------|----------------|-----------------|
| 7B    | 1                      | g5.2xlarge     | p4d.24xlarge    |
| 13B   | 2                      | g5.12xlarge    | p4d.24xlarge    |
| 33B   | 4                      | g5.12xlarge    | p4d.24xlarge    |
| 65B   | 8                      | g5.48xlarge    | p4d.24xlarge    |

A G5 instance is the minimum requirement for hosting a LLaMA model. To run long sequences (e.g. 2048 or 4096 tokens), consider using a P4D instance; you will also need to raise `max_seq_len`, which `model.py` sets to 512.
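A quick way to sanity-check the table is to estimate the fp16 weight footprint: about 2 bytes per parameter, split evenly across the MP degree. The sketch below is a rough lower bound under stated assumptions: it uses nominal model sizes and nominal GPU memory (24 GB per A10G on G5, 40 GB per A100 on p4d.24xlarge) and ignores the KV cache, activations, and CUDA overhead, which is exactly what grows with longer sequences. The helper names are hypothetical.

```python
from typing import Dict, Tuple

# (GPU count, nominal memory per GPU in GB) for the instances in the table above
INSTANCE_GPUS: Dict[str, Tuple[int, int]] = {
    "g5.2xlarge": (1, 24),
    "g5.12xlarge": (4, 24),
    "g5.48xlarge": (8, 24),
    "p4d.24xlarge": (8, 40),
}


def fp16_weights_per_gpu_gb(params_billion: float, tp_degree: int) -> float:
    """Rough fp16 weight memory per GPU: 2 bytes per parameter, split across tp_degree GPUs."""
    return params_billion * 2 / tp_degree


def weights_fit(params_billion: float, tp_degree: int, instance: str) -> bool:
    """True if the instance has enough GPUs and each weight shard fits (no KV cache or activations)."""
    gpus, mem_gb = INSTANCE_GPUS[instance]
    return tp_degree <= gpus and fp16_weights_per_gpu_gb(params_billion, tp_degree) < mem_gb


for size, mp, instance in [(7, 1, "g5.2xlarge"), (13, 2, "g5.12xlarge"),
                           (33, 4, "g5.12xlarge"), (65, 8, "g5.48xlarge")]:
    per_gpu = fp16_weights_per_gpu_gb(size, mp)
    print(f"{size}B, MP={mp} on {instance}: ~{per_gpu:.1f} GB weights/GPU, "
          f"fits={weights_fit(size, mp, instance)}")
```

For example, 33B with MP=4 needs roughly 33 × 2 / 4 = 16.5 GB of weights per GPU, leaving only a few GB per 24 GB A10G for the KV cache, which is why long sequences call for the larger A100 GPUs.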




In [None]:
instance_type = "ml.g5.12xlarge"
endpoint_name = sagemaker.utils.name_from_base("lmi-model")

model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    endpoint_name=endpoint_name,
    # Uncomment to give the container more time (in seconds) to download and load a large model
    # container_startup_health_check_timeout=3600,
)

# Requests and responses are JSON, so we set the JSON serializer and deserializer
predictor = sagemaker.Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sess,
    serializer=serializers.JSONSerializer(),
    deserializer=deserializers.JSONDeserializer(),
)

## Step 5: Test and benchmark the inference

In [None]:
%%timeit -n3 -r1
predictor.predict(
    {"prompts": "Large model inference is"}
)
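`%%timeit` reports aggregate timings only. If you want per-request latency numbers you can log or compare across instance types, a small wrapper around any callable is enough. This is a sketch using only the standard library; `benchmark` is a hypothetical helper, and the commented line shows how it could wrap the `predictor` created above.

```python
import statistics
import time
from typing import Callable, Dict, List


def benchmark(call: Callable[[], object], runs: int = 3) -> Dict[str, float]:
    """Time `runs` sequential calls and summarize latency in seconds."""
    if runs < 1:
        raise ValueError("runs must be >= 1")
    latencies: List[float] = []
    for _ in range(runs):
        start = time.perf_counter()
        call()
        latencies.append(time.perf_counter() - start)
    return {
        "mean_s": statistics.mean(latencies),
        "p50_s": statistics.median(latencies),
        "max_s": max(latencies),
    }


# Against the deployed endpoint:
# stats = benchmark(lambda: predictor.predict({"prompts": ["Large model inference is"]}), runs=3)
stats = benchmark(lambda: sum(range(1000)), runs=5)  # local stand-in workload
print(stats)
```

Requests are sent one after another, so this measures single-request latency, not throughput under concurrent load.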

## Clean up the environment

In [None]:
sess.delete_endpoint(endpoint_name)
sess.delete_endpoint_config(endpoint_name)
model.delete_model()