### Building Custom Triton Server Image

In [None]:
# Build the custom Triton server image from ./Dockerfile and tag it
# `triton-server` (the image name the `docker run` cell starts).
!docker build -f Dockerfile -t triton-server .

### Download the model from Hugging Face and convert it to OpenVINO

In [None]:
# Install the Python packages listed in ./requirements.txt
# (presumably including optimum-intel, which the next cell imports — not visible here).
!pip install -r requirements.txt
  

In [None]:
from optimum.intel.openvino import OVModelForCausalLM

print('Downloading and converting...')
# OpenVINO runtime properties stored with the model for when it is compiled.
ov_config = dict(PERFORMANCE_HINT='LATENCY', NUM_STREAMS='1')
ov_model = OVModelForCausalLM.from_pretrained(
    'togethercomputer/RedPajama-INCITE-Chat-3B-v1',
    export=True,    # convert the Transformers checkpoint to OpenVINO IR
    device='CPU',
    compile=False,  # this cell only converts and saves; skip compiling
    ov_config=ov_config,
)

print('Saving to ./models ...')
# ./models is the host directory the `docker run` cell mounts at /model.
ov_model.save_pretrained('./models/')
print('Done.')

In [None]:
# Start Triton in the background (-d) from the `triton-server` image:
#   - publishes gRPC on 12337 and HTTP on 12338 (matching --grpc-port / --http-port)
#   - mounts ./models (the converted model) at /model and ./endpoints
#     (the Triton model repository) at /repository
#   - forwards the host's proxy environment variables into the container
# NOTE(review): the --shm-size / --ulimit values look like the usual Triton
# container settings; confirm they are still needed for this CPU-only model.
# The container is started without --name; use `docker ps` to find it and stop it.
!docker run -d \
    --shm-size=1g \
    --ulimit memlock=-1 \
    --ulimit stack=67108864 \
    -p 12337:12337 -p 12338:12338 \
    -v $(pwd)/models:/model:rw \
    -v $(pwd)/endpoints:/repository:rw \
    -e http_proxy=$http_proxy \
    -e https_proxy=$https_proxy \
    -e no_proxy=$no_proxy \
    triton-server \
    tritonserver \
        --model-repository /repository \
        --grpc-port 12337 \
        --http-port 12338

### Wait for the endpoint to become ready
Model loading might take a while. Re-run the cell below until it returns status `200 OK`.

In [None]:
# Probe the readiness endpoint of the `llm-streaming` model over Triton's HTTP port.
# -i prints the response status line and headers; re-run until the model is ready.
!curl -i http://localhost:12338/v2/models/llm-streaming/ready

In [None]:
# Client-side dependencies for the chat UI. The brackets are escaped so the
# shell does not treat `tritonclient[grpc]` as a glob pattern.
!pip install tritonclient\[grpc\] numpy gradio

In [None]:
import queue
from functools import partial

import gradio as gr
import numpy as np
import tritonclient.grpc as grpcclient
from tritonclient.utils import *
from tritonclient.utils import InferenceServerException


# gRPC endpoint of the Triton server. 12337 is the host port the `docker run`
# cell publishes and passes to tritonserver as --grpc-port.
TRITON_SERVER_URL = "localhost:12337"


###### Everything below is red-pajama specific #######
# Prompt pieces for the RedPajama-INCITE-Chat turn format.
start_message = ""
history_template = "\n<human>:{user}\n<bot>:{assistant}"
current_message_template = "\n<human>:{user}\n<bot>:{assistant}"


def convert_history_to_text(history):
    """Render a chat history as a single RedPajama-INCITE-Chat prompt string.

    Args:
        history: sequence of ``[user_message, assistant_message]`` pairs, oldest
            first. The last pair is the turn being generated; callers in this
            notebook pass an empty assistant slot there, so the prompt ends in
            ``<bot>:``.

    Returns:
        ``start_message`` followed by every earlier turn formatted with
        ``history_template`` and the last turn formatted with
        ``current_message_template``. An empty history yields just
        ``start_message`` (previously this raised IndexError).
    """
    if not history:
        return start_message

    past_turns = "".join(
        history_template.format(user=turn[0], assistant=turn[1])
        for turn in history[:-1]
    )
    last_turn = history[-1]
    current_turn = current_message_template.format(
        user=last_turn[0], assistant=last_turn[1])
    return start_message + past_turns + current_turn


def red_pajama_partial_text_processor(partial_text, new_text):
    """Fold one streamed text chunk into the reply shown to the user.

    A chunk that is exactly ``'<'`` is discarded and ``partial_text`` is
    returned unchanged (presumably to keep role tags such as ``<human>`` out
    of the displayed reply — TODO confirm against the server's tokenization).
    Any other chunk is appended, and only the text after the last ``'<bot>:'``
    marker is kept (the whole text if the marker does not occur).
    """
    if new_text != '<':
        # rpartition()[2] is the tail after the last marker, or the entire
        # string when the marker is absent — same as split(marker)[-1].
        partial_text = (partial_text + new_text).rpartition('<bot>:')[2]
    return partial_text
################################################


def string_to_triton_inputs(payload):
    """Wrap a prompt string as the input list for a Triton gRPC request.

    Args:
        payload: prompt text to send.

    Returns:
        A one-element list holding a ``grpcclient.InferInput`` named
        ``'text'`` with shape ``[1]`` and datatype ``BYTES``.
    """
    # TODO: Reuse inputs?
    text_input = grpcclient.InferInput('text', [1], "BYTES")
    # BYTES tensors are supplied as numpy object arrays. `np` comes from the
    # explicit `import numpy as np`, not from `tritonclient.utils import *`.
    text_input.set_data_from_numpy(np.array([payload], dtype=np.object_))
    return [text_input]


def triton_output_to_string(output):
    """Decode the first element of a BYTES output tensor as UTF-8 text.

    Bytes that are not valid UTF-8 are silently dropped (``errors='ignore'``)
    instead of raising.
    """
    first_item = output[0]  # e.g. b'abcd'
    return first_item.decode(encoding='utf-8', errors='ignore')


def on_message_recv(message, history):
    """Stream a chat reply for ``message`` from the Triton ``llm-streaming`` model.

    Used as the Gradio ``ChatInterface`` callback. It builds a RedPajama prompt
    from ``history`` plus the new message and sends it over a gRPC stream. It
    then yields the accumulated reply text after each streamed chunk.

    Args:
        message: the user's new message.
        history: earlier ``[user, assistant]`` pairs as passed by Gradio
            (assumes the pair-style history format — TODO confirm for the
            installed Gradio version).

    Yields:
        The reply text accumulated so far.

    Raises:
        InferenceServerException: if an error is delivered on the stream.
    """
    print('Requesting question:', message)
    new_question_with_context = convert_history_to_text(history + [[message, ""]])

    with grpcclient.InferenceServerClient(url=TRITON_SERVER_URL, verbose=False) as triton_client:
        response_queue = queue.Queue()

        def callback(response_queue, result, error):
            # Invoked by tritonclient for each streamed result or error
            # (presumably on its stream thread); hand it to this generator.
            response_queue.put(error if error else result)

        triton_client.start_stream(callback=partial(callback, response_queue))
        triton_client.async_stream_infer(
            model_name='llm-streaming',
            inputs=string_to_triton_inputs(new_question_with_context),
            request_id="0",  # TODO: Allow multiple chat users
            enable_empty_final_response=True)

        print("Waiting for result...")
        result = ""
        while True:
            response = response_queue.get()
            # isinstance (not type ==) so subclasses of the exception are
            # raised too instead of failing later on `.as_numpy`.
            if isinstance(response, InferenceServerException):
                raise response
            output = response.as_numpy("partial_response")
            # A response without 'partial_response' ends the reply; presumably
            # the empty final response requested via enable_empty_final_response.
            if output is None:
                break
            partial_response = triton_output_to_string(output)
            print(partial_response, flush=True, end='')
            result = red_pajama_partial_text_processor(result, partial_response)
            yield result


# Serve the chat UI on all network interfaces (0.0.0.0), port 7779, with
# on_message_recv as the streaming (generator) reply callback. `.queue()`
# enables Gradio's request queue (presumably required for streaming callbacks
# in older Gradio releases — TODO confirm for the installed version).
gr.ChatInterface(on_message_recv).queue().launch(server_name="0.0.0.0", server_port=7779)


### Open the chatbot in a browser
[http://localhost:7779](http://localhost:7779)