In [79]:
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
import tritonclient.http as httpclient
from tritonclient.utils import triton_to_np_dtype
from cn_clip import clip
from cn_clip.clip.utils import image_transform, tokenize

In [11]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

In [92]:
def text_tokenize_int32(batch_text):
    return tokenize(batch_text).int().numpy()

def image_batch_transform(batch_images):
    batch_transformed_images = []
    for inp in batch_images:
        # 默认preprocess 返回 tensor类型，需要转为numpy
        batch_transformed_images.append(image_transform()(inp).numpy())
    return np.array(batch_transformed_images)


# 预处理方法集合
input_preprocess = {
    "image": image_batch_transform,
    "text": text_tokenize_int32
}

# 模型推理类

In [93]:
class TritonHttpClient:
    def __init__(self, server, preprocess={}):
        """
        server: triton service ip:port
        preprocess: preprocess function, key value pair, key:input_name, value: preprocess function
        """
        triton_client = httpclient.InferenceServerClient(url=server)
        self.input_preprocess = preprocess

    def inf(self, model_name, 
            input_name, input_batch, input_data_type,
            output_name, is_binary_data=False):
        input_features = input_batch
        # 输入
        if self.input_preprocess.get(input_name) is not None:
            input_features = self.input_preprocess.get(input_name)(input_batch)
        input = httpclient.InferInput(input_name, input_features.shape, datatype=input_data_type)
        input.set_data_from_numpy(input_features, binary_data=is_binary_data)
        
        #输出
        output = httpclient.InferRequestedOutput(output_name, binary_data=is_binary_data)

        # 请求
        result = triton_client.infer(model_name=model_name, inputs=[input], outputs=[output])
        return result.as_numpy(output_name)

# 处理图像信息

In [127]:

image_batch = [Image.open('pokemon.jpeg')]

client = TritonHttpClient("10.208.62.27:8000", preprocess=input_preprocess)
image_features = client.inf("clip_vitb16_img", "image", image_batch, "FP32", "unnorm_image_features", is_binary_data=True)
image_features.shape

(1, 512)

# 处理文本信息

In [128]:
#input_batch=["皮卡丘"]
input_batch=["杰尼龟"]

client = TritonHttpClient("10.208.62.27:8000", preprocess=input_preprocess)
text_features=client.inf("clip_vitb16_txt", "text", input_batch, "INT32", "unnorm_text_features")
text_features.shape

(1, 512)

# 相关性计算

In [129]:
image_features = torch.tensor(image_features)
text_features = torch.tensor(text_features)

In [130]:
from torch import nn

image_features /= image_features.norm(dim=-1, keepdim=True) 
text_features /= text_features.norm(dim=-1, keepdim=True)    

def get_similarity(image_features, text_features):
    logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)).exp()
    logits_per_image = logit_scale * image_features @ text_features.t()
    logits_per_text = logits_per_image.t()
    return logits_per_image, logits_per_text

logits_per_image, logits_per_text = get_similarity(image_features, text_features)
probs = logits_per_image.softmax(dim=-1).detach().numpy()

In [131]:
probs

array([[1.]], dtype=float32)

In [133]:
text_features

tensor([[ 1.1745e-01, -1.6243e-02,  3.9839e-03, -2.1405e-02,  3.3110e-03,
          3.6907e-02,  5.6607e-02,  5.2882e-03,  5.6576e-02,  5.9038e-03,
         -4.9567e-02, -7.6655e-02, -2.1674e-02, -1.4318e-02, -2.5652e-02,
          2.7704e-02, -1.0576e-02, -6.5353e-02,  9.4903e-02,  4.6283e-02,
         -4.4279e-03,  5.7618e-02, -3.5802e-02,  1.0868e-02,  1.1973e-02,
         -3.0766e-02,  2.9638e-03,  9.8581e-03, -9.6056e-03,  2.2479e-02,
          2.3540e-03,  5.3474e-03,  8.2954e-03,  8.1328e-02,  2.7357e-02,
          3.1066e-02, -3.8379e-03, -5.0451e-02,  4.3411e-03, -3.0829e-02,
          2.1137e-02, -2.0348e-02, -2.4373e-02,  1.8864e-02,  5.8525e-03,
         -2.7743e-03,  5.5976e-02,  3.8422e-02,  2.5099e-02, -7.8021e-03,
          2.1468e-02,  4.7073e-02, -5.0325e-02, -1.2802e-02, -4.9093e-02,
         -3.7349e-02, -8.4485e-02,  1.6370e-02,  1.7854e-02, -2.4014e-03,
          1.4775e-02, -3.0387e-03, -9.4019e-02, -3.1698e-02, -2.3817e-03,
         -6.1311e-02, -4.2684e-02,  3.