# Getting Started with the Python Backend

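ONNX models accept only tensors as input. To send images or text to a model served by Triton, the raw input must first be preprocessed into tensors before it is passed to the ONNX model.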

To do image or text preprocessing inside Triton, use the [Python Backend](https://github.com/triton-inference-server/python_backend).
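As a rough illustration of what such preprocessing involves, the sketch below turns an 8-bit RGB image (a hypothetical `H x W x 3` `uint8` array) into a `float32` tensor in `NCHW` layout with values scaled to `[0, 1]`. The layout and scaling are assumptions; the real steps depend on what the ONNX model was trained with.

```python
import numpy as np


def image_to_tensor(image: np.ndarray) -> np.ndarray:
    """Convert an HxWx3 uint8 image into a 1x3xHxW float32 tensor.

    Assumes the target model expects NCHW layout and pixel values in [0, 1];
    adjust both for the model you actually serve.
    """
    if image.dtype != np.uint8 or image.ndim != 3 or image.shape[2] != 3:
        raise ValueError("expected an HxWx3 uint8 image")
    tensor = image.astype(np.float32) / 255.0   # scale pixel values to [0, 1]
    tensor = np.transpose(tensor, (2, 0, 1))    # HWC -> CHW
    return tensor[np.newaxis, ...]              # add a batch axis -> NCHW


fake_image = np.full((4, 5, 3), 255, dtype=np.uint8)
batch = image_to_tensor(fake_image)
print(batch.shape, batch.dtype)  # (1, 3, 4, 5) float32
```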

```python
import os

# Route Hugging Face downloads through a mirror endpoint
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
```

```python
import numpy as np
import requests

import utils  # local helper module (not shown in this section)
```

## 1. Prepare the Model Repository

The repository is laid out as follows:

```
model_repository/
  └── simple_python_backend/
      ├── 1/
      │   └── model.py
      └── config.pbtxt
```
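If you prefer to create this layout from Python, a convenience sketch follows. It is not part of Triton; the helper name `create_repository` is hypothetical. It only creates empty placeholder files, which you then fill with the model code and configuration.

```python
import tempfile
from pathlib import Path


def create_repository(root, model_name="simple_python_backend"):
    """Create the directory layout with empty placeholder files; return the model directory."""
    model_dir = Path(root) / model_name
    version_dir = model_dir / "1"                 # numeric version subdirectory
    version_dir.mkdir(parents=True, exist_ok=True)
    (version_dir / "model.py").touch()            # placeholder for the model code
    (model_dir / "config.pbtxt").touch()          # placeholder for the model configuration
    return model_dir


# Try it in a throwaway directory
with tempfile.TemporaryDirectory() as tmp:
    model_dir = create_repository(tmp)
    for path in sorted(model_dir.rglob("*")):
        print(path.relative_to(tmp))
```

For the layout shown above, call `create_repository("model_repository")`.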

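## 2. Write the Model File

Create a very simple data-processing routine that stands in for a real model:

- The model takes two inputs: a `(2, 3)` matrix and a `(3, 3)` matrix
- The model returns their matrix product, a `(2, 3)` matrix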
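Written out, the model computes an ordinary matrix product. The shared dimension of size 3 is summed away, which is why the output keeps the 2 rows of the first input and the 3 columns of the second:

$$
C = AB, \qquad C_{ij} = \sum_{k=1}^{3} A_{ik} B_{kj}, \qquad A \in \mathbb{R}^{2 \times 3},\ B \in \mathbb{R}^{3 \times 3} \;\Rightarrow\; C \in \mathbb{R}^{2 \times 3}
$$

For example, if the first row of $A$ is $(1, 2, 3)$ and every row of $B$ is $(5, 6, 7)$, then $C_{11} = 1 \cdot 5 + 2 \cdot 5 + 3 \cdot 5 = 30$.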

```python
def matrix_dot(matrix_a, matrix_b):
    """Compute the matrix product of two matrices."""
    return np.dot(matrix_a, matrix_b)

a = [[1, 2, 3], [3, 4, 5]]
b = [[5, 6, 7], [5, 6, 7], [5, 6, 7]]
matrix_dot(a, b)
```

```
array([[30, 36, 42],
       [60, 72, 84]])
```

```python
np.array(a).shape, np.array(b).shape, matrix_dot(a, b).shape
```

```
((2, 3), (3, 3), (2, 3))
```

**Model file** `simple_python_backend/1/model.py`

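```python
import json

import numpy as np
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    def initialize(self, args):
        """Called once when the model is loaded."""
        # args["model_config"] holds config.pbtxt serialized as JSON
        model_config = json.loads(args["model_config"])
        # Look up the numpy dtype declared for the output (TYPE_FP32 -> np.float32)
        output0_config = pb_utils.get_output_config_by_name(model_config, "dot_output")
        self.output0_dtype = pb_utils.triton_string_to_numpy(output0_config["data_type"])

    def execute(self, requests):
        """Called with a list of requests; must return one response per request."""
        output0_dtype = self.output0_dtype

        responses = []
        for request in requests:
            # Fetch the input tensors by the names declared in config.pbtxt
            in_0 = pb_utils.get_input_tensor_by_name(request, "matrix_a_input")
            in_1 = pb_utils.get_input_tensor_by_name(request, "matrix_b_input")

            # NOTE: the inputs carry a leading batch axis. For 3-D arrays, np.dot
            # pairs every batch item of A with every batch item of B, so
            # (N, 2, 3) x (N, 3, 3) yields shape (N, 2, N, 3), not (N, 2, 3).
            matrix_out = np.dot(in_0.as_numpy(), in_1.as_numpy())

            out_tensor_0 = pb_utils.Tensor("dot_output", matrix_out.astype(output0_dtype))
            inference_response = pb_utils.InferenceResponse(output_tensors=[out_tensor_0])
            responses.append(inference_response)
        return responses

    def finalize(self):
        """Called once when the model is unloaded."""
        print("Cleaning up...")
```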


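## 3. Write the Configuration File

Create a configuration file, `config.pbtxt`, that describes the model's inputs, outputs, execution devices (instance groups), dynamic batching, and other settings.

**Model configuration** `simple_python_backend/config.pbtxt`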

```
name: "simple_python_backend"
backend: "python"
max_batch_size: 256
input [
{
    name: "matrix_a_input"
    data_type: TYPE_FP32
    dims: [ 2, 3 ]
},
{
    name: "matrix_b_input"
    data_type: TYPE_FP32
    dims: [ 3, 3 ]
}
]

output [
{
    name: "dot_output"
    data_type: TYPE_FP32
    dims: [ 2, 3 ]
}
]

instance_group [
  {
    count: 2
    kind: KIND_CPU
  },
  {
    count: 1
    kind: KIND_GPU
    gpus: [ 0 ]
  }
]

dynamic_batching {
    max_queue_delay_microseconds: 100
}
```
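Because `max_batch_size` is greater than 0, Triton treats the first axis of every input as the batch dimension, and `dims` describes a single batch item. The hypothetical helper `check_batched_input` below spells out that rule on the client side for the two inputs declared above. It is an illustrative sketch only; Triton performs its own validation on the server.

```python
import numpy as np

# Per-item dims copied from config.pbtxt; the batch axis is NOT listed there.
INPUT_DIMS = {"matrix_a_input": (2, 3), "matrix_b_input": (3, 3)}
MAX_BATCH_SIZE = 256


def check_batched_input(name: str, array: np.ndarray) -> int:
    """Return the batch size of `array` if its shape is [batch] + dims, else raise ValueError."""
    expected_dims = INPUT_DIMS[name]
    if array.ndim != len(expected_dims) + 1:
        raise ValueError(f"{name}: expected {len(expected_dims) + 1} axes, got {array.ndim}")
    batch, item_shape = array.shape[0], array.shape[1:]
    if item_shape != expected_dims:
        raise ValueError(f"{name}: expected item shape {expected_dims}, got {item_shape}")
    if not 1 <= batch <= MAX_BATCH_SIZE:
        raise ValueError(f"{name}: batch size {batch} outside 1..{MAX_BATCH_SIZE}")
    return batch


print(check_batched_input("matrix_a_input", np.zeros((2, 2, 3), dtype=np.float32)))  # 2
```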

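## 4. Start Triton

Start Triton with the command below. `--model-control-mode=poll` makes Triton poll the model repository for changes (here every 20 seconds, set by `--repository-poll-secs=20`) and reload modified models, which is convenient while debugging.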

```bash
tritonserver --model-repository=/models --model-control-mode=poll --repository-poll-secs=20
```

If the server starts successfully, the log includes a model status table like this:

```
+-----------------------+---------+--------+
| Model                 | Version | Status |
+-----------------------+---------+--------+
| simple_python_backend | 1       | READY  |
+-----------------------+---------+--------+
```

## 5. Calling the Model from a Client

```python
TRITON_URL = "http://localhost:8000"

PYTHON_MODEL = "simple_python_backend"
```

```python
utils.check_triton_health()
```

```
Triton server is ready.
```

```python
utils.check_model_health(model_name=PYTHON_MODEL)
```

```
Model 'simple_python_backend' is ready.
```
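The `utils` helpers are not shown in this section. A minimal stand-in, assuming they call the standard readiness endpoints of Triton's HTTP/REST (KServe v2) protocol, `GET /v2/health/ready` for the server and `GET /v2/models/<name>/ready` for a model, might look like the sketch below. The function names are hypothetical, and it uses only the standard library.

```python
import urllib.request


def is_ready(url: str, timeout: float = 2.0) -> bool:
    """Return True if GET `url` answers with HTTP 200, False on any error or other status."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            return resp.status == 200
    except OSError:
        # URLError/HTTPError, connection refused, and timeouts are all OSError subclasses
        return False


def server_ready(triton_url: str = "http://localhost:8000") -> bool:
    return is_ready(f"{triton_url}/v2/health/ready")


def model_ready(model_name: str, triton_url: str = "http://localhost:8000") -> bool:
    return is_ready(f"{triton_url}/v2/models/{model_name}/ready")
```

Once the log shows the `READY` table, `server_ready()` and `model_ready("simple_python_backend")` should both return `True`.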


```python
def get_dot_result(matrix_a, matrix_b, model_name,
                   triton_url=TRITON_URL, model_version='1'):
    """Send both matrices to Triton's HTTP/REST inference endpoint and return the parsed JSON."""
    url = f"{triton_url}/v2/models/{model_name}/versions/{model_version}/infer"

    # Build the JSON request body: each input carries its name, shape, datatype, and flattened data
    input_data_json = {
        "inputs": [
            {
                "name": 'matrix_a_input',
                "shape": list(matrix_a.shape),
                "datatype": "FP32",
                "data": matrix_a.flatten().tolist()
            },
            {
                "name": 'matrix_b_input',
                "shape": list(matrix_b.shape),
                "datatype": "FP32",
                "data": matrix_b.flatten().tolist()
            }
        ]
    }

    # Send the POST request
    response = requests.post(url, json=input_data_json)

    # Check the response status code
    if response.status_code != 200:
        raise RuntimeError(
            f"Inference request failed with status code {response.status_code}: {response.text}")

    # Parse the JSON response
    return response.json()
```

```python
# Sample input data; the leading axis (size 2) is the batch dimension
matrix_a = np.random.uniform(0, 99, (2, 2, 3)).astype(np.float32)
matrix_b = np.random.uniform(0, 99, (2, 3, 3)).astype(np.float32)

matrix_a, matrix_b
```

```
(array([[[76.198975 , 46.66827  ,  3.8670235],
         [37.84335  , 84.079155 , 18.604324 ]],

        [[36.68547  , 12.755604 , 81.027504 ],
         [19.381435 , 34.726177 , 90.25347  ]]], dtype=float32),
 array([[[54.96951 , 62.9187  , 33.921318],
         [50.967373, 66.616615, 53.998264],
         [46.455822, 23.68376 , 71.7238  ]],

        [[28.173761, 76.02462 , 97.00841 ],
         [76.86671 , 18.352785, 59.100372],
         [ 6.008605, 48.57268 , 78.99156 ]]], dtype=float32))
```

```python
# Send the inference request
output = get_dot_result(matrix_a,
                        matrix_b,
                        model_name=PYTHON_MODEL)
output
```

```
{'model_name': 'simple_python_backend',
 'model_version': '1',
 'outputs': [{'name': 'dot_output',
   'datatype': 'FP32',
   'shape': [2, 2, 2, 3],
   'data': [6746.8251953125,
    7994.80810546875,
    5382.1328125,
    5757.283203125,
    6837.322265625,
    10455.515625,
    7229.802734375,
    8422.7431640625,
    7158.197265625,
    7640.86328125,
    5323.775390625,
    10109.8173828125,
    6430.9013671875,
    5076.97314453125,
    7744.80078125,
    2500.911376953125,
    6958.8232421875,
    10713.1494140625,
    7028.08935546875,
    5670.33642578125,
    9005.9189453125,
    3757.63232421875,
    6494.6416015625,
    11061.7548828125]}]}
```

```python
np.array(output['outputs'][0]['data']).reshape(-1, 2, 2, 3)
```

```
array([[[[ 6746.82519531,  7994.80810547,  5382.1328125 ],
         [ 5757.28320312,  6837.32226562, 10455.515625  ]],

        [[ 7229.80273438,  8422.74316406,  7158.19726562],
         [ 7640.86328125,  5323.77539062, 10109.81738281]]],


       [[[ 6430.90136719,  5076.97314453,  7744.80078125],
         [ 2500.91137695,  6958.82324219, 10713.14941406]],

        [[ 7028.08935547,  5670.33642578,  9005.91894531],
         [ 3757.63232422,  6494.64160156, 11061.75488281]]]])
```
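Instead of hard-coding the shape in `reshape`, the response already carries each output's `shape` and `datatype`. The sketch below rebuilds a named output from those fields. The helper name `output_to_numpy` is hypothetical, and the datatype table covers only a few of the protocol's type strings, for illustration.

```python
import numpy as np

# Subset of HTTP/REST (KServe v2) datatype strings mapped to numpy dtypes
_V2_TO_NUMPY = {"FP32": np.float32, "FP64": np.float64, "INT32": np.int32, "INT64": np.int64}


def output_to_numpy(result: dict, name: str) -> np.ndarray:
    """Rebuild the named output tensor from a JSON inference response."""
    for out in result["outputs"]:
        if out["name"] == name:
            dtype = _V2_TO_NUMPY[out["datatype"]]
            return np.asarray(out["data"], dtype=dtype).reshape(out["shape"])
    raise KeyError(f"output {name!r} not found in response")


# A tiny hand-written response in the same format as the one returned above
sample = {"outputs": [{"name": "dot_output", "datatype": "FP32",
                       "shape": [1, 2, 3], "data": [1, 2, 3, 4, 5, 6]}]}
print(output_to_numpy(sample, "dot_output").shape)  # (1, 2, 3)
```

Applied to the real response, `output_to_numpy(output, "dot_output")` returns a `float32` array of shape `(2, 2, 2, 3)`, matching the reshape above.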

```python
np.dot(matrix_a, matrix_b)
```

```
array([[[[ 6746.825 ,  7994.808 ,  5382.133 ],
         [ 5757.283 ,  6837.3223, 10455.516 ]],

        [[ 7229.8027,  8422.743 ,  7158.1978],
         [ 7640.8633,  5323.7754, 10109.817 ]]],


       [[[ 6430.9014,  5076.973 ,  7744.8003],
         [ 2500.9114,  6958.8228, 10713.149 ]],

        [[ 7028.0894,  5670.3364,  9005.919 ],
         [ 3757.632 ,  6494.6416, 11061.755 ]]]], dtype=float32)
```
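The server returned `dot_output` with shape `[2, 2, 2, 3]`, although `config.pbtxt` declares `dims: [ 2, 3 ]`, i.e. `[batch, 2, 3]` = `[2, 2, 3]` for this request. The extra axis comes from `np.dot`: for arrays with more than two dimensions it takes the sum-product over the last axis of the first array and the second-to-last axis of the second, so every batch item of `matrix_a` is multiplied with every batch item of `matrix_b`. `np.matmul` (the `@` operator) instead treats leading axes as a stack of matrices and multiplies them item by item, which is presumably what a batched model should do. The local comparison below illustrates the difference; replacing `np.dot` with `np.matmul` in `model.py` is a suggested change, not something run against the server here.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.uniform(0, 99, (2, 2, 3)).astype(np.float32)  # [batch, 2, 3]
b = rng.uniform(0, 99, (2, 3, 3)).astype(np.float32)  # [batch, 3, 3]

dot_result = np.dot(a, b)        # every batch item of a with every batch item of b
matmul_result = np.matmul(a, b)  # batch item i of a with batch item i of b

print(dot_result.shape)     # (2, 2, 2, 3)
print(matmul_result.shape)  # (2, 2, 3)

# The per-item products sit on the "diagonal" of the np.dot result:
# matmul_result[i] == dot_result[i, :, i, :]
for i in range(a.shape[0]):
    assert np.allclose(matmul_result[i], dot_result[i, :, i, :])
```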