<a href="https://colab.research.google.com/github/cedro3/others/blob/master/CLIP_search.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# セットアップ

In [None]:
# Pin PyTorch / torchvision to the versions listed below (cu110 wheels)
! pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 -f https://download.pytorch.org/whl/torch_stable.html 

# Clone the CLIP repository and change into it
# (presumably so that the `import clip` below resolves to the cloned package)
! git clone https://github.com/openai/CLIP.git
%cd /content/CLIP/

# Install the ftfy and regex packages, then load the CLIP ViT-B/32 model
! pip install ftfy regex
import clip
model, preprocess = clip.load('ViT-B/32', jit=True)  
# Switch the model to evaluation mode
model = model.eval()  

# Download the sample image archive and unpack it
import gdown
gdown.download('https://drive.google.com/uc?id=1xIYYYzw9aZhjhyjMM12nz4XjnWUzpp6v', 'img.zip', quiet=False)
! unzip img.zip

# 検索する画像の読み込み


In [None]:
# --- Image preprocessing ----
import torch
import numpy as np
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from PIL import Image
import glob
from tqdm import tqdm

# Preprocessing pipeline: bicubic resize to 224, centre crop to 224x224,
# then conversion to a tensor.
# NOTE: this rebinds `preprocess`, replacing the transform returned by
# clip.load() in the setup cell.
preprocess = Compose([
    Resize(224, interpolation=Image.BICUBIC),
    CenterCrop(224),
    ToTensor()
])

# Per-channel mean / std used to normalise the batch on the GPU below.
image_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).cuda()
image_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).cuda()


# Collect the image paths in a deterministic (sorted) order. The position of
# a path in `files` is the index of that image in `image_input`.
files = sorted(glob.glob('./img/*.png'))
if not files:
    # Fail with a clear message instead of a cryptic stacking error.
    raise FileNotFoundError("no PNG images found under ./img/")

# Load and preprocess every image.
images = []
for file in tqdm(files):
    # Use a context manager so each file handle is closed right after use.
    with Image.open(file) as img:
        images.append(preprocess(img.convert("RGB")))

# Stack into one batch tensor and move it to the GPU.
image_input = torch.stack(images).cuda()
# Normalise in place; [:, None, None] lets the per-channel values broadcast
# over the two trailing image dimensions.
image_input -= image_mean[:, None, None]
image_input /= image_std[:, None, None]

print('image_input.shape = ', image_input.shape)

# 検索テキストの入力


In [None]:
# Natural-language query used to search the images
text = 'She is a charming woman with blonde hair and blue eyes'

# Tokenize the query with CLIP and move the resulting tensor to the GPU
text_input = clip.tokenize(text).cuda()

print('text_input = ', text_input)
print('text_input.shape = ', text_input.shape)

# 画像とテキストのcos類似度の計算



In [None]:
# --- Cosine similarity between the images and the text ----

# Extract the feature vectors; gradients are not needed here
with torch.no_grad():
    # Text embedding, scaled to unit length along the last dimension
    text_features = model.encode_text(text_input).float()
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    # Image embeddings for the whole batch
    image_features = model.encode_image(image_input).float()

# Cosine similarity of every image embedding against the text embedding
text_probs = torch.cosine_similarity(image_features, text_features, dim=1)

print('image_features.shape = ', image_features.shape)
print('text_features.shape = ', text_features.shape)
print('text_probs.shape = ', text_probs.shape)

# 検索結果の表示

In [None]:
# --- Display the search results ---

import matplotlib.pyplot as plt

# Show the query text
print('text = ', text)
print()

# Rank the image indices from highest to lowest cosine similarity
scores = text_probs.cpu()
x = torch.argsort(scores, descending=True)

# Show the top 3 matches (or fewer when fewer images were loaded)
top_k = min(3, len(x))
fig = plt.figure(figsize=(30, 40))
for i in range(top_k):
    idx = x[i].item()
    # Look the path up in `files` instead of rebuilding it from the index:
    # the similarity scores follow the order of `files`, so this stays
    # correct whatever the images are named.
    with Image.open(files[idx]) as img:
        image_plt = np.array(img)
    ax = fig.add_subplot(10, 10, i+1, xticks=[], yticks=[])
    ax.imshow(image_plt)
    # Label each image with its similarity, rounded to 3 decimals
    cos_value = round(scores[idx].item(), 3)
    ax.set_xlabel(cos_value, fontsize=12)
plt.show()
plt.close()