# Deploy a Trained PyTorch Model from a Checkpoint File


This notebook is an example implementation of deploying our MalConv model in AWS SageMaker. It is designed to run in the Jupyter environment of a SageMaker notebook instance, but the same deployment can also be done with the AWS CLI or Boto3 (not covered in this example).

In the very first step, we need to create what SageMaker calls a "model artifact": a .tar.gz file (a tar archive compressed with gzip, similar in purpose to a .zip file) that follows a predefined structure. For PyTorch models, the model file saved during training (here `model.pt`; `.pth` is another common extension for PyTorch checkpoints) must sit at the root of the archive. In other words, when the archive is extracted, the model file should appear directly in the extraction folder, not inside a subfolder. Its name must also match the file name that `model_fn` loads (see below). I understand that this may seem like overemphasis on a simple matter, but trust me: a large share of SageMaker questions and support requests come down to an incorrectly structured .tar.gz file.

Below is sample Python code for creating this .tar.gz file, assuming that your model file was uploaded to the home directory of the SageMaker Jupyter notebook:

In [1]:
import tarfile

# Put model.pt at the root of the archive; the with-block closes the file automatically
with tarfile.open('model.tar.gz', 'w:gz') as f:
    f.add('model.pt', arcname='model.pt')
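The cell above works because `model.pt` sits in the working directory, so its archive name has no folder prefix. If the checkpoint lives in a subfolder, passing `arcname` keeps it at the archive root. The sketch below (with hypothetical helpers `build_model_archive` and `archive_members`, and a placeholder file instead of a real checkpoint) builds such an archive and then lists its members to confirm the layout before uploading.

```python
import os
import tarfile
import tempfile


def build_model_archive(model_path, output_archive):
    """Pack one model file at the root of a gzip-compressed tar archive."""
    with tarfile.open(output_archive, "w:gz") as tar:
        # arcname drops the directory part, so the file lands at the archive root
        tar.add(model_path, arcname=os.path.basename(model_path))


def archive_members(archive_path):
    """Return the member names stored in a .tar.gz archive."""
    with tarfile.open(archive_path, "r:gz") as tar:
        return tar.getnames()


# Demo: a placeholder model.pt stored in a subfolder, as it might be after training
with tempfile.TemporaryDirectory() as tmp:
    model_path = os.path.join(tmp, "checkpoints", "model.pt")
    os.makedirs(os.path.dirname(model_path))
    with open(model_path, "wb") as f:
        f.write(b"placeholder bytes, not a real checkpoint")
    archive = os.path.join(tmp, "model.tar.gz")
    build_model_archive(model_path, archive)
    members = archive_members(archive)

print(members)  # ['model.pt']
# A root-level layout means no member name contains a directory separator
print(all("/" not in name for name in members))  # True
```

On the archive created in this notebook, `archive_members('model.tar.gz')` should likewise return `['model.pt']`.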

In [14]:
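# Caution: this variant stores the files under a model/ folder inside the archive,
# which is not the root-level layout SageMaker expects
# with tarfile.open('model.tar.gz', 'w:gz') as tar:
#     tar.add('model', arcname=os.path.basename('model'))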

In [15]:
# import tarfile
# import os

# def create_tar_gz_of_directory(directory_path, output_archive):
#     with tarfile.open(output_archive, "w:gz") as tar:
#         # Walk through the directory
#         for root, dirs, files in os.walk(directory_path):
#             for file in files:
#                 # Create the path to your file
#                 file_path = os.path.join(root, file)
#                 # Calculate the arcname (name within the archive)
#                 arcname = os.path.relpath(file_path, directory_path)
#                 # Add the file to the archive; arcname controls the name inside the archive
#                 tar.add(file_path, arcname=arcname)

# # Example usage
# directory_path = 'model'  # The directory to tar.gz
# output_archive = 'model.tar.gz'  # The output archive path
# create_tar_gz_of_directory(directory_path, output_archive)

# print(f"Archive created at: {output_archive}")

In [2]:
# setup

import os
import json

import boto3
import sagemaker
from sagemaker.pytorch import PyTorchModel
from sagemaker import get_execution_role, Session

sess = Session()

role = get_execution_role()

# Restore the S3 URI of the model artifact if it was saved earlier with %store
%store -r pt_malconv_model_data

try:
    pt_malconv_model_data
except NameError:
    # Not stored yet: upload the local model.tar.gz to the session's default bucket
    pt_malconv_model_data = sess.upload_data(
        path="model.tar.gz", bucket=sess.default_bucket(), key_prefix="model/pytorch"
    )

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml
no stored variable or alias pt_malconv_model_data


In [3]:
print(pt_malconv_model_data)

s3://sagemaker-us-east-1-992382615242/model/pytorch/model.tar.gz


## PyTorch Model Object

The `PyTorchModel` class allows you to define an environment for running inference with your
model artifact. Like the `PyTorch` class discussed
[in this notebook for training a PyTorch model](get_started_mnist_train.ipynb), it is a high-level API used to set up a Docker image for your model hosting service.

Once it is properly configured, it can be used to create a SageMaker
endpoint on an EC2 instance. The SageMaker endpoint is a containerized environment that uses your trained model
to run inference on incoming data via RESTful API calls.

Some common parameters used to instantiate the `PyTorchModel` class are:
- `entry_point`: a user-defined Python file used by the inference image to handle incoming requests
- `source_dir`: the directory containing the `entry_point` file
- `role`: an IAM role used to make AWS service requests
- `model_data`: the S3 location of the compressed model artifact. It can be a local file (given as a `file://` URI) if the endpoint is to be deployed on the SageMaker instance you are using to run this notebook (local mode)
- `framework_version`: the version of the PyTorch package to be used
- `py_version`: the Python version to be used

We elaborate on the `entry_point` below.



In [4]:
model = PyTorchModel(
    entry_point="inference.py",
    source_dir="code",
    role=role,
    model_data=pt_malconv_model_data,
    framework_version="1.5.0",
    py_version="py3",
)

### Entry Point for the Inference Image

The model artifact pointed to by `model_data` is pulled by `PyTorchModel`, decompressed, and saved in
the Docker image it defines. The extracted files are regular model checkpoint files, the same as you would produce outside SageMaker. This means that, in order to use your trained model for serving,
you need to tell the `PyTorchModel` class how to recover a PyTorch model from the static checkpoint.

Also, because the deployed endpoint is invoked through RESTful API calls, you need to tell it how to parse an incoming
request into input for your model.

These instructions need to be defined as functions in the Python file pointed to by `entry_point`.

By convention, we name this entry point file `inference.py` and we put it in the `code` directory.
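Before deploying, a quick local check can catch a misnamed handler in the entry point. The sketch below parses `code/inference.py` with Python's `ast` module instead of importing it (so PyTorch need not be installed locally) and reports which of the four handlers described in this section are not defined at the top level. `missing_handlers` is a hypothetical helper, and the demo writes a stub file rather than the real entry point. The SageMaker PyTorch serving container supplies default implementations for some handlers, so a missing name is not necessarily an error; this only checks that the handlers this notebook relies on are present.

```python
import ast
import os
import tempfile

# Handlers described in this notebook
REQUIRED_HANDLERS = {"model_fn", "input_fn", "predict_fn", "output_fn"}


def missing_handlers(entry_point_path):
    """Return the handler names not defined as top-level functions in the file."""
    with open(entry_point_path, encoding="utf-8") as f:
        tree = ast.parse(f.read(), filename=entry_point_path)
    defined = {node.name for node in tree.body if isinstance(node, ast.FunctionDef)}
    return sorted(REQUIRED_HANDLERS - defined)


# Demo: a stub code/inference.py that defines only model_fn and input_fn
with tempfile.TemporaryDirectory() as tmp:
    os.makedirs(os.path.join(tmp, "code"))
    path = os.path.join(tmp, "code", "inference.py")
    with open(path, "w", encoding="utf-8") as f:
        f.write(
            "def model_fn(model_dir):\n    pass\n\n"
            "def input_fn(request_body, request_content_type):\n    pass\n"
        )
    result = missing_handlers(path)

print(result)  # ['output_fn', 'predict_fn']
```

In the notebook itself, `missing_handlers("code/inference.py")` would run the same check on the real entry point.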

To tell the inference image how to load the model checkpoint, you need to implement a function called
`model_fn`. This function takes one positional argument:

- `model_dir`: the directory of the static model checkpoints in the inference image.

The return value of `model_fn` is a PyTorch model. In this example, `model_fn`
looks like:

```python
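import os
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def model_fn(model_dir):
    model = Net()  # Net: the model class used in training, defined in or imported by inference.py
    # The file name must match the file packed into model.tar.gz (model.pt above)
    with open(os.path.join(model_dir, "model.pt"), "rb") as f:
        # map_location lets a GPU-trained checkpoint load on a CPU-only instance
        model.load_state_dict(torch.load(f, map_location=device))
    model.to(device).eval()
    return model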
```

Next, you need to tell the hosting service how to handle the incoming data. This includes:

* How to parse the incoming request
* How to use the trained model to make inference
* How to return the prediction to the caller of the service


You do this by implementing three functions:

#### `input_fn` function

The SageMaker PyTorch model server invokes the `input_fn` function in your inference entry point to decode incoming data. `input_fn` has the following signature:
```python
def input_fn(request_body, request_content_type)
```
The two positional arguments are:
- `request_body`: the payload of the incoming request
- `request_content_type`: the content type of the incoming request

The return value of `input_fn` is an object that can be passed to `predict_fn`.

In this example, the `input_fn` looks like:
```python
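import json
import torch

def input_fn(request_body, request_content_type):
    # Accept only JSON payloads of the form {"inputs": [...]}
    if request_content_type != "application/json":
        raise ValueError(f"Unsupported content type: {request_content_type}")
    data = json.loads(request_body)["inputs"]
    # device is the module-level torch.device defined next to model_fn
    return torch.tensor(data, dtype=torch.float32, device=device)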
```
It requires the request payload to be encoded as a JSON string and
assumes that the decoded payload contains an `inputs` key
that maps to the input data to be consumed by the model.
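On the client side, a request that satisfies this contract is a JSON object with an `inputs` key. The sketch below shows how such a body could be built and how the parsing step of `input_fn` reads it back; `build_json_request` and `decode_json_request` are hypothetical names, and the decoder deliberately stops before the `torch.tensor` conversion so it runs without PyTorch.

```python
import json


def build_json_request(inputs):
    """Encode model inputs as the JSON body input_fn expects: {"inputs": [...]}."""
    return json.dumps({"inputs": inputs})


def decode_json_request(request_body, request_content_type):
    """Mirror input_fn's parsing step, without converting to a tensor."""
    if request_content_type != "application/json":
        raise ValueError(f"Unsupported content type: {request_content_type}")
    return json.loads(request_body)["inputs"]


body = build_json_request([[0.1, 0.2, 0.3]])
print(body)  # {"inputs": [[0.1, 0.2, 0.3]]}
print(decode_json_request(body, "application/json"))  # [[0.1, 0.2, 0.3]]
```

The outer list is the batch dimension: each inner list is one sample.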



#### `predict_fn` 
After the inference request has been deserialized by `input_fn`, the SageMaker PyTorch model server invokes `predict_fn` on the return value of `input_fn`.

The `predict_fn` function has the following signature:
```python
def predict_fn(input_object, model)
```
The two positional arguments are:
- `input_object`: the return value from `input_fn`
- `model`: the return value from `model_fn`

The return value of `predict_fn` is passed as the first argument to `output_fn`.

In this example, the `predict_fn` function looks like:

```python
def predict_fn(input_object, model):
    with torch.no_grad():
        prediction = model(input_object)
    return prediction
```

Note that we feed the return value of `input_fn` directly to `predict_fn`.
This means you should invoke the SageMaker PyTorch model server with data that
the model can consume as-is, i.e., already preprocessed (for example, normalized) and shaped with the batch dimension and any other dimensions the model expects.


#### `output_fn` 
After invoking `predict_fn`, the model server invokes `output_fn` for post-processing.
The `output_fn` has the following signature:

```python
def output_fn(prediction, content_type)
```

The two positional arguments are:
- `prediction`: the return value from `predict_fn`
- `content_type`: the content type of the response

The return value of `output_fn` should be the prediction serialized to `content_type` (the example below returns a JSON string).

In this example, the `output_fn` function looks like:

```python
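import json

def output_fn(prediction, content_type):
    # Only JSON responses are supported
    if content_type != "application/json":
        raise ValueError(f"Unsupported accept type: {content_type}")
    # Move the tensor to the CPU and convert it to nested Python lists before encoding
    res = prediction.cpu().numpy().tolist()
    return json.dumps(res)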
```

After inference, the function uses `content_type` to encode the
prediction into the content type of the response. In this example,
the function requires the caller of the service to accept a JSON string.
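To see how the model server chains these handlers, the sketch below runs them in sequence on one request: `input_fn`, then `predict_fn`, then `output_fn`. It is a simplified assumption of the server's flow (the real server also handles HTTP, model loading, and errors). Plain Python lists stand in for torch tensors so it runs without PyTorch, the "model" returned by this stand-in `model_fn` is a hypothetical function that sums each input row, and `handle_request` is a hypothetical name.

```python
import json


def model_fn(model_dir):
    # Stand-in model: ignores model_dir and returns the sum of each input row
    return lambda batch: [sum(row) for row in batch]


def input_fn(request_body, request_content_type):
    if request_content_type != "application/json":
        raise ValueError(f"Unsupported content type: {request_content_type}")
    return json.loads(request_body)["inputs"]


def predict_fn(input_object, model):
    return model(input_object)


def output_fn(prediction, content_type):
    if content_type != "application/json":
        raise ValueError(f"Unsupported accept type: {content_type}")
    return json.dumps(prediction)


def handle_request(model, request_body, content_type, accept):
    """Call the handlers in order: input_fn -> predict_fn -> output_fn."""
    data = input_fn(request_body, content_type)
    prediction = predict_fn(data, model)
    return output_fn(prediction, accept)


model = model_fn("/opt/ml/model")
response = handle_request(
    model, '{"inputs": [[1, 2], [3, 4]]}', "application/json", "application/json"
)
print(response)  # [3, 7]
```

Because each handler only consumes the previous one's return value, a mismatch (for example, a content type that `input_fn` rejects) surfaces at the first step rather than inside the model.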

For more information on handler functions, see the [SageMaker Python SDK documentation](https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#process-model-output).

## Execute the inference container
Once the `PyTorchModel` object is instantiated, we can call its `deploy` method to run the container for the hosting
service. Some common parameters for `deploy` are:

- `initial_instance_count`: the number of SageMaker instances used to run the hosting service.
- `instance_type`: the type of SageMaker instance used to run the hosting service. Set it to `local` to run the hosting service on the local SageMaker instance. Local mode is typically used for debugging.
- `serializer`: an object used to serialize (encode) the request data.
- `deserializer`: an object used to deserialize (decode) the response data.

Commonly used serializers and deserializers are implemented in the `sagemaker.serializers` and `sagemaker.deserializers`
submodules of the SageMaker Python SDK.

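Since `input_fn` above declares that incoming requests are JSON-encoded, the matching choice is a JSON serializer (`JSONSerializer`)
to encode the request data into a JSON string.
Likewise, since `output_fn` returns a JSON string, we use a JSON deserializer (`JSONDeserializer`) to parse the response back into a Python object containing the model's prediction. Note that the cell below instead sends raw bytes with `IdentitySerializer(content_type='application/octet-stream')`; in that case, the `input_fn` in `inference.py` must accept `application/octet-stream` requests rather than JSON.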
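The deploy cell in this notebook sends raw bytes with `IdentitySerializer(content_type='application/octet-stream')`, while the `input_fn` shown earlier accepts only JSON. One way to reconcile them is an input parser that dispatches on the content type. The sketch below is a hypothetical `parse_request`, assuming the model consumes the raw byte values of one file per request; it returns plain lists so it runs without PyTorch, whereas a real `input_fn` would convert the result to a tensor with the dtype and device the model expects.

```python
import json


def parse_request(request_body, request_content_type):
    """Hypothetical input parser accepting JSON or raw bytes; returns a batch of rows."""
    if request_content_type == "application/json":
        # Same contract as the JSON input_fn: {"inputs": [...]}
        return json.loads(request_body)["inputs"]
    if request_content_type == "application/octet-stream":
        # Treat the whole payload as one sample; each byte becomes an int in 0..255
        return [list(bytes(request_body))]
    raise ValueError(f"Unsupported content type: {request_content_type}")


print(parse_request(b"MZ\x90\x00", "application/octet-stream"))  # [[77, 90, 144, 0]]
print(parse_request('{"inputs": [[1, 2]]}', "application/json"))  # [[1, 2]]
```

Whichever branch the deployed `inference.py` implements, the serializer passed to `deploy` must send the content type that branch accepts.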

<span style="color:red"> Note: local mode is not supported in SageMaker Studio </span>

In [5]:
from sagemaker.serializers import IdentitySerializer
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

# set local_mode to True to run the hosting service on this notebook instance;
# keep it False to deploy on a remote SageMaker instance

local_mode = False

if local_mode:
    instance_type = "local"
else:
    instance_type = "ml.c4.xlarge"

predictor = model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    # Raw bytes are sent as-is, so inference.py must accept application/octet-stream.
    # To send JSON instead (matching the input_fn shown above), use serializer=JSONSerializer()
    serializer=IdentitySerializer(content_type='application/octet-stream'),
    deserializer=JSONDeserializer(),
)

------!

The `predictor` we get above can be used to make prediction requests against the SageMaker endpoint.
For more information, check [the API reference for SageMaker Predictor](https://sagemaker.readthedocs.io/en/stable/api/inference/predictors.html#sagemaker.predictor.predictor).
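Outside this notebook (for example, from another service), the same endpoint can be called through the `sagemaker-runtime` client in Boto3. The sketch below assumes the serializers used in the deploy cell: the request body is raw bytes sent as `application/octet-stream`, and the response is JSON. `build_invoke_request` and `invoke` are hypothetical helpers, and `invoke` needs AWS credentials and a running endpoint; `boto3` is imported inside it so that the request builder can be used without Boto3 installed.

```python
import json


def build_invoke_request(endpoint_name, payload):
    """Keyword arguments for invoke_endpoint, matching the deploy cell's serializers."""
    return {
        "EndpointName": endpoint_name,
        "ContentType": "application/octet-stream",  # what IdentitySerializer sends
        "Accept": "application/json",  # what JSONDeserializer expects back
        "Body": payload,
    }


def invoke(endpoint_name, payload):
    """Send raw bytes to a deployed endpoint and decode its JSON response."""
    import boto3  # imported here so build_invoke_request works without boto3

    runtime = boto3.client("sagemaker-runtime")
    response = runtime.invoke_endpoint(**build_invoke_request(endpoint_name, payload))
    return json.loads(response["Body"].read())
```

Within the notebook, `predictor.predict(payload)` sends the same request through the SDK, and `invoke(predictor.endpoint_name, payload)` is the equivalent Boto3 call, where `payload` is the raw bytes of a sample file.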

## (Optional) Clean up 

If you do not plan to use the endpoint, you should delete it to free up compute
resources. If you use local mode, you will need to manually delete the Docker container bound
to port 8080 (the port that listens for incoming requests).


In [8]:
# import os

# if not local_mode:
#    predictor.delete_endpoint()
# else:
#    os.system("docker container ls | grep 8080 | awk '{print $1}' | xargs docker container rm -f")