# RVAI TensorRT Inference Server Example

In this example we will reuse the `ImageClassificationCell` from the `mnist_training.ipynb` tutorial and extend it with TensorRT Inference Server capabilities.

The first part of this tutorial repeats many of the steps from the training notebook.

**Important note:** This tutorial requires TensorRT Inference Server to be installed on your system. You can use the tutorial Docker environment for this. To use your GPU in this environment, your host's NVIDIA drivers must be at least version 440.xx. Otherwise, you will only be able to run this tutorial on CPU.
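A quick way to check the driver requirement is sketched below. It assumes that `nvidia-smi` is on the `PATH` of the machine where the check runs; the helper names `meets_minimum_driver` and `query_driver_version` are illustrative and are not part of RVAI.

```python
import shutil
import subprocess
from typing import Optional

MIN_DRIVER_MAJOR = 440  # minimum major driver version named in the note above


def meets_minimum_driver(version: str, minimum_major: int = MIN_DRIVER_MAJOR) -> bool:
    """Return True if a driver version string such as '440.33.01' is recent enough."""
    major = version.strip().split(".")[0]
    return major.isdigit() and int(major) >= minimum_major


def query_driver_version() -> Optional[str]:
    """Ask nvidia-smi for the driver version; return None if it is unavailable."""
    if shutil.which("nvidia-smi") is None:
        return None
    try:
        completed = subprocess.run(
            ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            universal_newlines=True,
            timeout=10,
            check=True,
        )
    except (subprocess.SubprocessError, OSError):
        return None
    lines = completed.stdout.strip().splitlines()
    return lines[0].strip() if lines else None


if __name__ == "__main__":
    version = query_driver_version()
    if version is None:
        print("No NVIDIA driver found: run the tutorial on CPU.")
    elif meets_minimum_driver(version):
        print(f"Driver {version} is recent enough for GPU use.")
    else:
        print(f"Driver {version} is older than {MIN_DRIVER_MAJOR}.xx: run on CPU.")
```

Only the major version is compared, which matches the "440.xx" wording of the requirement.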

## Prerequisites
First, let's install all the prerequisites:

In [None]:
!pip install -qqq rvai==1.1.0rc51 pygraphviz

In [None]:
# some global notebook configuration
%matplotlib inline
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

## Create a TRTISCell
Let us create a Cell now. A Cell represents the smallest building block in RVAI. Since our Cell should have TensorRT Inference Server support, we select the `TRTISCell` base class. The basic skeleton of a `TRTISCell` can be found in the [docs [1]](https://base.rvai.dev/rvai.base.trtis.html#rvai.base.trtis.cell.TRTISCell). The `TRTISCell` extends the `TrainableCell` with two methods, `convert_to_trtis_model` and `trtis_predict`. The first converts a model (loaded by `load_model`) into a model that is compatible with TensorRT Inference Server; the second performs a prediction step using a `TRTISClient`.

- [1] https://base.rvai.dev/rvai.base.trtis.html#rvai.base.trtis.cell.TRTISCell

### Cell IO
We can reuse the IO of the `ImageClassificationCell` from the `mnist_training.ipynb` notebook. For details of the code in this section, we refer to that notebook.

In [None]:
from dataclasses import dataclass
from typing import Optional
from rvai.base.data import Inputs, Outputs, Samples, Annotations, Parameters, Metrics
from rvai.types import Float, Image, Integer

In [None]:
# Inference mode IO

@dataclass
class ImageClassificationInputs(Inputs):
    image: Image = Inputs.field(
        name="Image", description="The image to be classified.")

@dataclass
class ImageClassificationOutputs(Outputs):
    label: Integer = Outputs.field(
        name="Class", description="The class of the image.")

# Training mode IO
        
@dataclass
class ImageClassificationSamples(Samples, ImageClassificationInputs):
    """Inherits from ImageClassificationInputs because the Samples this Cell expects during training are the same as its inputs."""

@dataclass
class ImageClassificationAnnotations(Annotations, ImageClassificationOutputs):
    """Inherits from ImageClassificationOutputs because the Annotations this Cell expects during training are the same as its outputs."""
    
@dataclass
class ImageClassificationMetrics(Metrics):
    acc: Float = Metrics.field(name="Accuracy", short_name="acc", performance=True)
    loss: Float = Metrics.field(name="Loss")
    val_acc: Optional[Float] = Metrics.field(
        name="Validation Accuracy", default=None
    )
    val_loss: Optional[Float] = Metrics.field(
        name="Validation Loss", default=None
    )
    
# Parameters

@dataclass
class ImageClassificationParameters(Parameters):
    epochs: Integer = Parameters.field(default=Integer(2), name="Epochs", description="The number of times the training loop should process the data.")
    batch_size: Integer = Parameters.field(default=Integer(4), name="Batch Size", description="SGD mini-batch size.")

### Cell Body
Now, let's actually create the Cell! This section also duplicates a fair amount of code from the `mnist_training.ipynb` notebook. The updated parts are clearly marked.

In [None]:
# necessary RVAI imports:
from rvai.base.cell import cell # used as a decorator to register a cell in RVAI
# ==================
# BEGIN TRTIS UPDATE
from rvai.base.trtis.cell import TRTISCell # base class, defines main functionality
from rvai.base.trtis.model import TRTISModel # wrapper for TRTIS compatible model
from rvai.base.trtis.client import TRTISClient # client that allows us to use the deployed models
from rvai.base.trtis.utils import keras_to_trtismodel # helper function to convert a keras model
# END TRTIS UPDATE
# ================


# used for typing:
from rvai.base.cell import CellMode # enum, defines what mode the cell is running in
from rvai.base.data import Example, Dataset, DatasetConfig, Metrics
from rvai.base.context import InferenceContext, ModelContext, ParameterContext, TestContext, TrainingContext # required argument for most cell methods
from rvai.base.training import Model, ModelConfig, ModelPath
from typing import Optional, Tuple, Sequence

# used in implementation:
# ==================
# BEGIN TRTIS UPDATE
import cv2  # image manipulation
import tempfile  # needed to create temporary conversion folders
# END TRTIS UPDATE
# ================
import numpy as np
import tensorflow as tf
tf.autograph.set_verbosity(1)
tf.logging.set_verbosity(tf.logging.ERROR)
from rvai.base import compat
from rvai.base.training import TrainingSession
from rvai.base.test import TestSession

In [None]:
@cell
class ImageClassificationCell(TRTISCell):
        
    # ==================
    # BEGIN TRTIS UPDATE
    @classmethod
    def load_model(
        cls,
        context: ModelContext,
        parameters: ImageClassificationParameters,
        model_path: Optional[str],
        dataset_config: Optional[DatasetConfig],
    ) -> Tuple[tf.keras.models.Model, ModelConfig]:
        # When we load the model for TRTIS mode, we try to avoid using a GPU
        # because the model will only be used for conversion
        if context.trtis_mode:
            with tf.device('/cpu:0'):
                model = cls._do_load_model(model_path=model_path)
        else:
            model = cls._do_load_model(model_path=model_path)
        return model, None
    
    @classmethod
    def _do_load_model(cls, model_path: Optional[str]) -> tf.keras.models.Model:
        if model_path is not None:
            return tf.keras.models.load_model(model_path)
        else:
            model = tf.keras.Sequential()
            model.add(tf.keras.layers.Conv2D(filters=64, kernel_size=2, padding='same', activation='relu', input_shape=(28,28,1)))
            model.add(tf.keras.layers.MaxPooling2D(pool_size=2))
            model.add(tf.keras.layers.Dropout(0.3))
            model.add(tf.keras.layers.Conv2D(filters=32, kernel_size=2, padding='same', activation='relu'))
            model.add(tf.keras.layers.MaxPooling2D(pool_size=2))
            model.add(tf.keras.layers.Dropout(0.3))
            model.add(tf.keras.layers.Flatten())
            model.add(tf.keras.layers.Dense(256, activation='relu'))
            model.add(tf.keras.layers.Dropout(0.5))
            model.add(tf.keras.layers.Dense(10, activation='softmax'))
            model.compile(loss='categorical_crossentropy',
                          optimizer='adam',
                          metrics=['accuracy'])

        return model
    # END TRTIS UPDATE
    # ================

    @classmethod
    def _unpack_example(
        cls,
        example: Example[ImageClassificationSamples, ImageClassificationAnnotations],
    ) -> Tuple[np.ndarray, int]:

        samples: ImageClassificationSamples = example[0]
        annotations: ImageClassificationAnnotations = example[1]

        # standardize image input
        image = np.atleast_3d(samples.image)
        label = int(annotations.label)

        return image, label

    @classmethod
    def _collate_batch(
        cls,
        examples: Sequence[Tuple[np.ndarray, int]],
    ) -> Tuple[np.ndarray, np.ndarray]:

        x, y = zip(*examples)

        images: np.ndarray = np.stack(arrays=x, axis=0)
        labels: np.ndarray = tf.keras.utils.to_categorical(
            y=y, num_classes=10, dtype=np.float32
        )

        return images, labels

    @classmethod
    def train(
        cls,
        context: TrainingContext,
        parameters: ImageClassificationParameters,
        model: tf.keras.models.Model,
        model_config: ModelConfig,
        train_dataset: Dataset[
            ImageClassificationSamples, ImageClassificationAnnotations
        ],
        validation_dataset: Dataset[
            ImageClassificationSamples, ImageClassificationAnnotations
        ],
        dataset_config: Optional[DatasetConfig]
    ) -> TrainingSession[ImageClassificationMetrics]:
        # Integer -> int
        batch_size = int(parameters.batch_size)

        train_generator = compat.keras.as_generator(
            train_dataset,
            batch_size=batch_size,
            process_example=cls._unpack_example,
            process_batch=cls._collate_batch,
        )

        validation_generator = compat.keras.as_generator(
            validation_dataset,
            batch_size=batch_size,
            process_example=cls._unpack_example,
            process_batch=cls._collate_batch,
        )

        nb_epochs = int(parameters.epochs)
        nb_training_batches = int(len(train_dataset) // batch_size)
        nb_validation_batches = int(len(validation_dataset) // batch_size)

        model.fit_generator(
            generator=train_generator,
            steps_per_epoch=nb_training_batches,
            validation_data=validation_generator,
            validation_steps=nb_validation_batches,
            epochs=nb_epochs,
            verbose=0,
            callbacks=[compat.keras.training_update_callback(
                context=context,
                metrics=ImageClassificationMetrics,
            )],
        )

        model_path = context.get_model_path()

        tf.keras.models.save_model(model=model, filepath=model_path)

        return model_path

    @classmethod
    def test(
        cls,
        context: TestContext,
        parameters: ImageClassificationParameters,
        model: tf.keras.models.Model,
        model_config: ModelConfig,
        test_dataset: Dataset[
            ImageClassificationSamples, ImageClassificationAnnotations
        ],
        dataset_config: Optional[DatasetConfig]
    ) -> TestSession[ImageClassificationMetrics]:
        raise NotImplementedError

    # ==================
    # BEGIN TRTIS UPDATE 
    @classmethod
    def predict(
        cls,
        context: InferenceContext,
        parameters: ImageClassificationParameters,
        model: tf.keras.models.Model,
        model_config: ModelConfig,
        inputs: ImageClassificationInputs,
    ) -> ImageClassificationOutputs:
        image = inputs.image
        # Get the required input size
        input_shape = model.inputs[0].get_shape().as_list()
        h, w = input_shape[1], input_shape[2]
        # Resize the input image
        inp = cv2.resize(image, (w, h))
        # Make sure that the input still has 3 dimensions
        inp = np.atleast_3d(inp)
        # Perform prediction
        predictions = model.predict(np.array([inp]))
        print('Prediction done')
        # Get the label
        label = predictions[0].argmax()
        # Return output
        return ImageClassificationOutputs(label=Integer(label))
        
    @classmethod
    def convert_to_trtis_model(
        cls,
        context: ModelContext,
        parameters: ImageClassificationParameters,
        model: tf.keras.models.Model,
        model_config: ModelConfig,
    ) -> TRTISModel:
        # Create a temporary directory to hold the converted model
        outfolder = tempfile.mkdtemp()
        return keras_to_trtismodel(
            model=model,  # Keras model to convert
            model_path=outfolder,  # output folder where the converted model data is stored
            max_batch_size=16  # maximum allowed batch size for the model
        )

    @classmethod
    def trtis_predict(
        cls,
        context: InferenceContext,
        parameters: ImageClassificationParameters,
        trtis_client: TRTISClient,
        model_config: ModelConfig,
        inputs: ImageClassificationInputs,
    ) -> ImageClassificationOutputs:
        image = inputs.image
        # Get the input and output layer of the model
        input_layer = trtis_client.get_model_spec().input_layers[0]
        output_layer = trtis_client.get_model_spec().output_layers[0]
        # Get the input shape
        h, w = input_layer.dims[0], input_layer.dims[1]  # dims don't have batch dimension
        # Resize the input and convert to required data format
        inp = cv2.resize(image, (w, h)).astype(input_layer.data_type.to_np())
        # Make sure that the input still has 3 dimensions
        inp = np.atleast_3d(inp)
        # Perform inference. This takes a mapping from layer name to input data
        result = trtis_client.infer({input_layer.name: inp})
        # Get the output data
        predictions = result.get(output_layer.name)
        # Get the label
        label = predictions.argmax()
        # Return output
        return ImageClassificationOutputs(label=Integer(label))
    # END TRTIS UPDATE
    # ================

Let's discuss the two new methods.

#### `convert_to_trtis_model`
TensorRT Inference Server supports several model formats, including TensorFlow SavedModel, TensorFlow frozen graphs, ONNX, PyTorch, and others. Next to the actual model data, the inference server also needs a configuration file specifying some details of the model. The `convert_to_trtis_model` method converts your normal model, loaded via `load_model`, into a model that can be used by the inference server.

In `rvai.base.trtis.utils`, we provide some convenience methods to facilitate this conversion, for example for Keras models or TensorFlow frozen graphs. In this tutorial we use the `keras_to_trtismodel` convenience method. It takes a Keras model, a model path where the converted model can be stored, and a maximum batch size, and returns the converted model.
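For TensorRT Inference Server, the configuration file mentioned above is a `config.pbtxt` text file stored next to the model data. The sketch below only illustrates roughly what such a file could contain for the model in this tutorial; it does not show what `keras_to_trtismodel` actually writes. The helpers `tensor_block` and `render_config`, the model name, the platform choice, and the layer names are all assumptions made for illustration.

```python
from typing import Sequence


def tensor_block(name: str, data_type: str, dims: Sequence[int]) -> str:
    """Render one input or output tensor entry of a config.pbtxt file."""
    dims_text = ", ".join(str(d) for d in dims)
    return (
        "  {\n"
        f'    name: "{name}"\n'
        f"    data_type: {data_type}\n"
        f"    dims: [ {dims_text} ]\n"
        "  }"
    )


def render_config(
    name: str,
    platform: str,
    max_batch_size: int,
    input_block: str,
    output_block: str,
) -> str:
    """Render a minimal model configuration with one input and one output."""
    return (
        f'name: "{name}"\n'
        f'platform: "{platform}"\n'
        f"max_batch_size: {max_batch_size}\n"
        f"input [\n{input_block}\n]\n"
        f"output [\n{output_block}\n]\n"
    )


config_text = render_config(
    name="classifier",                 # hypothetical model name
    platform="tensorflow_savedmodel",  # assumed target format
    max_batch_size=16,                 # same value as in convert_to_trtis_model
    # Layer names are made up; dims exclude the batch dimension.
    input_block=tensor_block("conv2d_input", "TYPE_FP32", [28, 28, 1]),
    output_block=tensor_block("dense_1", "TYPE_FP32", [10]),
)
print(config_text)
```

Note that the `dims` entries leave out the batch dimension, which is why `trtis_predict` reads the height and width from `dims[0]` and `dims[1]`.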


#### `trtis_predict`
The `trtis_predict` method has the same functionality as the `predict` method, but instead of your plain model it uses a `TRTISClient`, connected to your converted model on a TensorRT Inference Server, to do the predictions.

We use the `infer` method of the `TRTISClient` to perform the inference. This method takes as its argument a mapping from input layer names (which can be fetched from the `TRTISClient`, as seen in this tutorial) to their input data. The `infer` call returns a mapping from output layer names to the resulting output data. Similarly, the `infer_batch` method can be used to perform predictions on batches of inputs.
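The calling convention can be illustrated without a running server. The sketch below uses a hypothetical stand-in, `FakeClient`, that mimics only the mapping-in, mapping-out contract described above. It is not the real `TRTISClient`, the layer names `input_1` and `output_1` are made up, and the "model" is a toy scoring rule.

```python
from typing import Dict, List, Sequence


class FakeClient:
    """Hypothetical stand-in for a TRTIS client (not the real TRTISClient).

    It accepts a mapping from input layer name to data and returns a mapping
    from output layer name to data.
    """

    input_name = "input_1"    # made-up layer name
    output_name = "output_1"  # made-up layer name

    def infer(self, inputs: Dict[str, Sequence[float]]) -> Dict[str, List[float]]:
        pixels = inputs[self.input_name]
        # Toy "model": class i is scored by the sum of the pixels whose index
        # is congruent to i modulo 10. A real server would run the network here.
        scores = [0.0] * 10
        for index, value in enumerate(pixels):
            scores[index % 10] += value
        return {self.output_name: scores}


def argmax(values: Sequence[float]) -> int:
    """Return the index of the largest value."""
    return max(range(len(values)), key=lambda i: values[i])


client = FakeClient()

# A flattened 28x28 "image" with a single bright pixel at index 3.
flat_image = [0.0] * 784
flat_image[3] = 1.0

# Same three steps as in trtis_predict: infer, look up the output, take the argmax.
result = client.infer({client.input_name: flat_image})
predictions = result.get(client.output_name)
label = argmax(predictions)
print(label)  # prints 3
```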

## Creating a Pipeline
This part also looks exactly the same as in the `mnist_training.ipynb` notebook.

In [None]:
from rvai.base.pipeline import DeclarativePipeline, PipelineCells, DeclarativeTrainingPipeline, pipeline

### Training Pipeline

In [None]:
class TrainingCells(PipelineCells):
    classifier: ImageClassificationCell

@pipeline
class ImageClassificationTrainingPipeline(DeclarativeTrainingPipeline):
    cells = TrainingCells
    train = cells.classifier
    samples = [cells.classifier.samples.image]
    annotations = [cells.classifier.annotations.label]

### Inference Pipeline

In [None]:
class InferenceCells(PipelineCells):
    classifier: ImageClassificationCell

@pipeline
class ImageClassificationPipeline(DeclarativePipeline):
    cells = InferenceCells
    inputs = {"image": cells.classifier.inputs.image}
    outputs = {"label": cells.classifier.outputs.label}
    training_pipelines = {
        cells.classifier: ImageClassificationTrainingPipeline
    }

In [None]:
inference_pipeline = ImageClassificationPipeline.build()
training_pipeline = inference_pipeline.get_training_pipeline("classifier")

## Training
The training part is exactly the same as in the `mnist_training.ipynb` notebook.

In [None]:
# required RVAI base class
from rvai.base.data import Dataset

# used for typing
from rvai.types import Image, Integer
from typing import Sequence, Tuple
import numpy as np

# actual data
from tensorflow.keras.datasets import fashion_mnist

# some imports for displaying data
from IPython.display import display, HTML
import PIL

In [None]:
class FashionMNISTDataset(
    Dataset[ImageClassificationSamples, ImageClassificationAnnotations]
):
    def __init__(
        self, images: Sequence[np.ndarray], labels: Sequence[np.ndarray]
    ):
        self.images = images
        self.labels = labels

    def __getitem__(
        self, index
    ) -> Tuple[ImageClassificationSamples, ImageClassificationAnnotations]:
        return (
            ImageClassificationSamples(image=Image(self.images[index])),
            ImageClassificationAnnotations(label=Integer(self.labels[index])),
        )

    def __len__(self):
        return len(self.images)
    
# Class names for FashionMNIST
class_names = [
    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'
]

train_dataset, validation_dataset = (FashionMNISTDataset(images=images, labels=labels) for images, labels in fashion_mnist.load_data())

# display an example image and its label
samples, annotations = train_dataset[0]
display(PIL.Image.fromarray(samples.image))
print(class_names[annotations.label])

In [None]:
from rvai.base.runtime import init, Training, Inference
from rvai.base.training import Tensorboard

In [None]:
# create a runtime, we choose the debug runtime
runtime = init("debug")

# generate a training pipeline
training_pipeline = inference_pipeline.get_training_pipeline("classifier")

# configure a training task
training = Training(
    pipeline=training_pipeline,
    models={}, # no previous models yet
    parameters={"classifier": ImageClassificationParameters()}, # defaults are fine for us 
    train_dataset=train_dataset,
    validation_dataset=validation_dataset,
)


training_loop = runtime.start_training(training)

print('Starting training')
for update in training_loop.updates():
    print(f"\r[{update.progress * 100:.1f}%] - accuracy: {update.metrics.acc}", end='')
trained_model_path = training_loop.result()
print('\nTraining done. Model can be found at:', trained_model_path)
# Stop the training process
training_loop.stop()

## Inference
Now that we have a trained model, we can start doing inference tasks.

### Normal Inference (no TensorRT Inference Server)

In [None]:
# create inference task
inference = Inference(
    pipeline=inference_pipeline, 
    models={"classifier": trained_model_path},  # use the trained model
    parameters={"classifier": ImageClassificationParameters()},  # defaults are fine for us 
)
proc = runtime.start_inference(inference)

In [None]:
# Get a random sample from the validation dataset. Run this cell multiple times to get different samples
idx = np.random.randint(len(validation_dataset))
sample, annotation = validation_dataset[idx]
# Perform inference and get the result
pred = proc.predict({"image": sample.image})
result = pred.result()

# Display the image and the label vs prediction
display(PIL.Image.fromarray(sample.image))
print(f'Label: {class_names[annotation.label]}, prediction: {class_names[result.get("label")]}')

In [None]:
# Stop the inference process
proc.stop()

### TensorRT Inference Server Mode
This works exactly the same as a normal inference task, except that the `Inference` task is given a `trtis_resources` argument.

In [None]:
from rvai.base.trtis.resources import TRTISResources

# create inference task as before, but enable TRTIS by providing a TRTISResources object with the required CPUs and GPUs
inference = Inference(
    pipeline=inference_pipeline, 
    models={"classifier": trained_model_path},  # use the trained model
    parameters={"classifier": ImageClassificationParameters()},  # defaults are fine for us 
    trtis_resources=TRTISResources(gpus=1.0),  # enable TensorRT Inference Server mode
)
proc = runtime.start_inference(inference)

In [None]:
# Get a random sample from the validation dataset. Run this cell multiple times to get different samples
idx = np.random.randint(len(validation_dataset))
sample, annotation = validation_dataset[idx]
# Perform inference and get the result
pred = proc.predict({"image": sample.image})
result = pred.result()

# Display the image and the label vs prediction
display(PIL.Image.fromarray(sample.image))
print(f'Label: {class_names[annotation.label]}, prediction: {class_names[result.get("label")]}')

In [None]:
# Stop the inference process and runtime
proc.stop()
runtime.stop()