In [None]:
Chinese-CLIP 项目地址：https://github.com/OFA-Sys/Chinese-CLIP

In [1]:
!python -m torch.utils.collect_env

Collecting environment information...
PyTorch version: 1.13.1+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.31

Python version: 3.8.10 (default, Nov 14 2022, 12:59:47)  [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-3.10.0-957.el7.x86_64-x86_64-with-glibc2.29
Is CUDA available: False
CUDA runtime version: 12.0.140
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudn

In [2]:
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
import tritonclient.http as httpclient
from tritonclient.utils import triton_to_np_dtype
from cn_clip import clip
from cn_clip.clip.utils import image_transform, tokenize

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cpu'

In [4]:
def text_tokenize_int32(batch_text):
    """Tokenize a batch of texts and return the token ids as a NumPy array.

    batch_text: the texts to tokenize; passed unchanged to ``tokenize``.

    The tokenizer output is cast with ``.int()`` before the NumPy conversion,
    which matches the "INT32" datatype this notebook declares for the text
    input of the Triton model.
    """
    token_ids = tokenize(batch_text)
    token_ids_int32 = token_ids.int()
    return token_ids_int32.numpy()

def image_batch_transform(batch_images):
    """Preprocess a batch of images into a single NumPy array.

    batch_images: iterable of images accepted by the pipeline that
        ``image_transform()`` returns.

    Returns ``np.array`` built from the per-image results, so the first axis
    indexes the images in the order they were given. An empty batch yields an
    empty array.
    """
    # Build the preprocessing pipeline once per batch; it does not depend on
    # the individual image, so recreating it inside the loop was wasted work.
    preprocess = image_transform()
    # The pipeline returns a tensor, so each result is converted to NumPy
    # before the batch is assembled.
    return np.array([preprocess(image).numpy() for image in batch_images])


# Registry of preprocessing functions, keyed by the model input name.
input_preprocess = dict(
    image=image_batch_transform,
    text=text_tokenize_int32,
)

# 模型推理类

In [5]:
class TritonHttpClient:
    """Thin wrapper around the Triton HTTP client for single-input, single-output inference."""

    def __init__(self, server, preprocess=None):
        """
        server: triton service ip:port
        preprocess: optional mapping of input_name -> preprocess function.
            When omitted, no preprocessing is applied to any input.
        """
        self.triton_client = httpclient.InferenceServerClient(url=server)
        # Use a fresh dict per instance instead of a mutable default argument,
        # which would be shared by every client created without `preprocess`.
        self.input_preprocess = {} if preprocess is None else preprocess

    def inf(self, model_name,
            input_name, input_batch, input_data_type,
            output_name, is_binary_data=False):
        """Run one inference request and return the requested output.

        model_name: name of the model to call on the server.
        input_name: name of the model input; also the key used to look up a
            preprocess function.
        input_batch: raw batch. If a preprocess function is registered for
            `input_name` it is applied first; otherwise the batch is sent
            as-is and must already expose `.shape` and be accepted by
            `set_data_from_numpy`.
        input_data_type: datatype string passed to `InferInput` (this notebook
            uses "FP32" and "INT32").
        output_name: name of the output to request and return.
        is_binary_data: forwarded as `binary_data` for both the input payload
            and the requested output.

        Returns `result.as_numpy(output_name)`.
        """
        # Input: preprocess only when a function is registered for this input.
        preprocess_fn = self.input_preprocess.get(input_name)
        input_features = input_batch if preprocess_fn is None else preprocess_fn(input_batch)
        infer_input = httpclient.InferInput(input_name, input_features.shape, datatype=input_data_type)
        infer_input.set_data_from_numpy(input_features, binary_data=is_binary_data)

        # Output
        requested_output = httpclient.InferRequestedOutput(output_name, binary_data=is_binary_data)

        # Request
        result = self.triton_client.infer(
            model_name=model_name, inputs=[infer_input], outputs=[requested_output]
        )
        return result.as_numpy(output_name)

In [6]:
# Triton HTTP endpoint as host:port; change this to point at your own server.
TRITON_SERVER_URL = "10.208.62.27:8000"
client = TritonHttpClient(TRITON_SERVER_URL, preprocess=input_preprocess)

# 处理图像信息

In [67]:
# Image.open is lazy and keeps the file handle open. Copy the pixels into
# memory inside a `with` block so the handle is closed straight away while
# `image_batch` stays usable afterwards.
with Image.open('pokemon.jpeg') as image_file:
    image_batch = [image_file.copy()]

image_features = client.inf("clip_vitb16_img", "image", image_batch, "FP32", "unnorm_image_features", is_binary_data=True)
image_features.shape

(1, 512)

In [68]:
# L2-normalise along the last axis; keepdims lets the norm broadcast in the division.
image_norms = np.linalg.norm(image_features, axis=-1, keepdims=True)
normed_image_features = image_features / image_norms
normed_image_features.shape

(1, 512)

# 处理文本信息

In [69]:
# Candidate labels to score against the image:
# Squirtle, Bulbasaur, Charmander, Pikachu.
input_batch = ["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"]

text_features = client.inf(
    model_name="clip_vitb16_txt",
    input_name="text",
    input_batch=input_batch,
    input_data_type="INT32",
    output_name="unnorm_text_features",
)
text_features.shape

(4, 512)

In [70]:
# L2-normalise along the last axis; keepdims lets the norm broadcast in the division.
text_norms = np.linalg.norm(text_features, axis=-1, keepdims=True)
normed_text_features = text_features / text_norms
normed_text_features.shape

(4, 512)

# 相关性计算

In [71]:
# This cell rebinds the same names, so on a re-run they already hold tensors.
# torch.as_tensor accepts both ndarrays and tensors, which keeps the cell
# idempotent and avoids the copy-construct warning torch.tensor emits when
# handed a tensor.
normed_image_features = torch.as_tensor(normed_image_features)
normed_text_features = torch.as_tensor(normed_text_features)

In [72]:
from torch import nn
  

def get_similarity(image_features, text_features, temperature=0.07):
    """Compute scaled similarity logits between image and text embeddings.

    image_features: 2-D tensor whose rows are image embeddings.
    text_features: 2-D tensor whose rows are text embeddings; it is transposed
        before the matrix product, so its last dimension must match that of
        `image_features`.
    temperature: positive scaling temperature. The logits are multiplied by
        exp(log(1 / temperature)); the default 0.07 reproduces the constant
        that used to be hard-coded here.

    Returns (logits_per_image, logits_per_text), where `logits_per_text` is
    the transpose of `logits_per_image`. If both inputs are L2-normalised, the
    unscaled product is their cosine similarity.

    Raises ValueError if `temperature` is not positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    # A plain tensor is enough for a fixed scale. Wrapping it in nn.Parameter
    # on every call only attached an autograd graph to the outputs.
    logit_scale = (torch.ones([]) * np.log(1 / temperature)).exp()
    # Parentheses spell out the evaluation order (`*` and `@` share precedence).
    logits_per_image = (logit_scale * image_features) @ text_features.t()
    logits_per_text = logits_per_image.t()
    return logits_per_image, logits_per_text

# Softmax over the text axis turns the image-to-text logits into one
# probability per candidate label.
logits_per_image, logits_per_text = get_similarity(
    normed_image_features,
    normed_text_features,
)
label_distribution = torch.softmax(logits_per_image, dim=-1)
probs = label_distribution.detach().numpy()
probs

array([[0.16164693, 0.2764256 , 0.1477481 , 0.41417935]], dtype=float32)