# 简单的 CLIP 模型

In [1]:
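# One-time environment setup (uncomment to run). The CLIP package is installed
# straight from GitHub and is not version-pinned, so results may vary over time.
# !python -m pip install --upgrade pip setuptools -i https://mirrors.aliyun.com/pypi/simple/
# !python -m pip install git+https://github.com/openai/CLIP.git -i https://mirrors.aliyun.com/pypi/simple/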

In [2]:
import torch
import clip
from PIL import Image
import logging
import time

import utils

In [3]:
# Folder names used by this notebook for the model cache and the sample data.
MODEL_PATH = 'model'
DATA_PATH = 'data'

# Location handed to clip.load as its download directory.
# NOTE(review): presumably an absolute path built from './' and MODEL_PATH —
# confirm against utils.gen_abspath.
model_path = utils.gen_abspath(rel_path=MODEL_PATH, directory='./')

# Logging setup: INFO-level basic configuration plus a logger for this notebook.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logger.info(f"Using device: {device}")

INFO:__main__:Using device: cuda


In [4]:
# Load the ViT-B/32 CLIP model and its matching image preprocessing transform
# onto `device`, using `model_path` as the download directory for the weights.
model, preprocess = clip.load(name="ViT-B/32", device=device, download_root=model_path)

In [5]:
def generate_image_embedding(image_path):
    """Encode a single image file into an L2-normalized CLIP embedding.

    Args:
        image_path: Location of the image, passed directly to
            ``PIL.Image.open``.

    Returns:
        A tensor on ``device`` with a leading batch dimension of 1 (added via
        ``unsqueeze(0)``), scaled so that its norm along the last dimension
        is 1.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist (logged, then
            re-raised).
        Exception: Any other error raised while opening or preprocessing the
            image is logged and re-raised unchanged.
    """
    # Open, preprocess and move the image to the target device.
    try:
        # The context manager releases the underlying file handle as soon as
        # `preprocess` has turned the pixels into a tensor; a bare
        # Image.open() would leave the file open until garbage collection.
        with Image.open(image_path) as img:
            image = preprocess(img).unsqueeze(0).to(device)
    except FileNotFoundError:
        logger.error(f"The image file {image_path} was not found.")
        raise
    except Exception as e:
        logger.error(f"An error occurred while opening the image: {e}")
        raise

    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        image_features = model.encode_image(image)

    # Scale to unit length so that dot products between embeddings are
    # cosine similarities.
    image_features /= image_features.norm(dim=-1, keepdim=True)
    return image_features

def generate_text_embedding(text_list):
    """Encode text prompts into L2-normalized CLIP embeddings.

    Args:
        text_list: The text input handed to ``clip.tokenize`` (this notebook
            passes a list of strings).

    Returns:
        A tensor on ``device`` containing the encoded prompts, scaled so that
        the norm along the last dimension is 1.
    """
    # Tokenize the prompts and move the token ids to the target device.
    tokens = clip.tokenize(text_list).to(device)

    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        embeddings = model.encode_text(tokens)

    # In-place division by the per-row norm yields unit-length vectors.
    row_norms = embeddings.norm(dim=-1, keepdim=True)
    return embeddings.div_(row_norms)

In [12]:
# Compute the text embeddings for the two candidate captions.
# NOTE(review): the OpenAI CLIP text encoder was trained mostly on English
# captions, so these Chinese prompts may separate classes poorly — consider
# the English alternative: ["A photo of a cat", "A photo of a dog"].
prompts = ["一只猫", "一条狗"]
text_features = generate_text_embedding(prompts)
(text_features.shape, text_features.device)

(torch.Size([2, 512]), device(type='cuda', index=0))

In [13]:
# Compute the embedding of the sample image 'cat.JPG' located via DATA_PATH.
img_path = utils.gen_abspath(rel_path='cat.JPG', directory=DATA_PATH)
image_features = generate_image_embedding(image_path=img_path)
(image_features.shape, image_features.device)

(torch.Size([1, 512]), device(type='cuda', index=0))

In [14]:
# Similarity scores between the image and each text prompt. Both feature sets
# are unit-normalized by the embedding helpers, so this matrix product gives
# the cosine similarity of the image with every prompt.
eventual_similarity = image_features @ text_features.T
eventual_similarity.cpu().numpy()

array([[0.281 , 0.2598]], dtype=float16)