<a href="https://colab.research.google.com/github/cedro3/others/blob/master/CLIP_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# セットアップ

In [None]:
# --- Setup ---

# 1. Change the PyTorch version (pins torch 1.7.1 / torchvision 0.8.2, +cu101 builds)
! pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 -f https://download.pytorch.org/whl/torch_stable.html #ftfy regex

# 2. Clone CLIP from GitHub and change into CLIP/clip so that simple_tokenizer can be imported below
! git clone https://github.com/openai/CLIP.git
%cd CLIP/clip

# 3. Download the CLIP model weights (only the "ViT-B/32" entry is fetched, saved as model.pt)
MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",    
}
! wget {MODELS["ViT-B/32"]} -O model.pt

# 4. Install the packages needed by simple_tokenizer, then create the tokenizer instance
! pip install ftfy regex
from simple_tokenizer import *
tokenizer = SimpleTokenizer()

# 5. Download the sample images (food_101.zip from Google Drive) and unpack them
import gdown
gdown.download('https://drive.google.com/uc?id=1vcxH6JOtwh_-FoZ8SNXYlHF9qCi3YoDH', 'food_101.zip', quiet=False)
! unzip food_101.zip

# CLIPモデルの仕様確認

In [None]:
# --- Inspect the CLIP model specification ---

import numpy as np
import torch

# Load the TorchScript checkpoint saved as model.pt, then move it to the GPU
# and switch it to evaluation mode.
model = torch.jit.load("model.pt")
model = model.cuda().eval()

# The model exposes these values as attributes; .item() turns each one into a
# plain Python scalar so later cells can use them directly.
input_resolution = model.input_resolution.item()
context_length = model.context_length.item()
vocab_size = model.vocab_size.item()

# Total element count over every parameter tensor of the model.
param_sizes = [int(np.prod(p.shape)) for p in model.parameters()]
param_count = np.sum(param_sizes)

print("Model parameters:", f"{param_count:,}")
for caption, value in (
    ("Input resolution:", input_resolution),
    ("Context length:", context_length),
    ("Vocab size:", vocab_size),
):
    print(caption, value)

# simple_tokenizer の動作確認

In [None]:
# Convert text to tokens (example 1): encode a short sentence and show the ids.
sample_text = 'I ate an apple'
index = tokenizer.encode(sample_text)
print(index)

In [None]:
# Convert text to tokens (example 2): encode a two-word phrase and show the ids.
sample_text = 'image segmentation'
index = tokenizer.encode(sample_text)
print(index)

# 画像の前処理


In [None]:
# --- Image preprocessing ---

from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from PIL import Image
import glob

# Settings: resize to the model's input resolution (bicubic), center-crop to a
# square of that size, then convert the PIL image to a tensor.
preprocess = Compose([
    Resize(input_resolution, interpolation=Image.BICUBIC),
    CenterCrop(input_resolution),
    ToTensor(),
])

# Per-channel mean / std applied to the batch below (kept on the GPU so they
# can be broadcast against image_input).
image_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).cuda()
image_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).cuda()

# Run the preprocessing over every sample image, in sorted filename order.
files = glob.glob('./food_101/*.jpg')
files.sort()
if not files:
    # Fail early with a clear message instead of an obscure stacking error.
    raise FileNotFoundError("no .jpg files found under ./food_101/")

images = []
for file in files:
    # The context manager guarantees the underlying file handle is released.
    with Image.open(file) as img:
        image = preprocess(img.convert("RGB"))
    images.append(image)

# Stack directly in torch (no NumPy round-trip). torch.stack allocates a new
# tensor, so the in-place normalization below leaves the tensors in `images`
# untouched; they are reused later for display.
image_input = torch.stack(images).cuda()
image_input -= image_mean[:, None, None]
image_input /= image_std[:, None, None]

print('image_input.shape = ', image_input.shape)

# テキストの前処理

In [None]:
# --- Text preprocessing ---

# Classification labels. They are inserted verbatim into the English prompt
# below and shown as tick labels in the result plots, so they must be spelled
# as the actual English words (previously 'susi', 'spagetti', 'humburger').
labels = ['takoyaki', 'sushi', 'spaghetti', 'ramen', 'pizza', 'omelette', 'hamburger', 'gyoza']

# Turn each label into a sentence, then into token ids wrapped in the
# start-of-text / end-of-text markers.
text_descriptions = [f"This is a photo of a {label}" for label in labels]
sot_token = tokenizer.encoder['<|startoftext|>']
eot_token = tokenizer.encoder['<|endoftext|>']
text_tokens = [[sot_token] + tokenizer.encode(desc) + [eot_token] for desc in text_descriptions]

# One row per description, zero-padded on the right up to the model's context
# length (the plain int extracted from the model in the inspection cell).
text_input = torch.zeros(len(text_tokens), context_length, dtype=torch.long)

# Copy each token sequence into the start of its row.
for i, tokens in enumerate(text_tokens):
    text_input[i, :len(tokens)] = torch.tensor(tokens)

text_input = text_input.cuda()

In [None]:
# Show the first entry of each intermediate result (sentence, token ids,
# padded tensor row), followed by the shape of the whole text batch.
for preview in (
    text_descriptions[0],
    text_tokens[0],
    text_input[0],
    text_input.shape,
):
    print(preview)

# 画像とテキストのcos類似度を計算

In [None]:
# --- Compute the cosine similarity between images and texts ---

# Extract image and text features with the CLIP model (no gradients needed).
with torch.no_grad():
    image_features = model.encode_image(image_input).float()
    text_features = model.encode_text(text_input).float()
    # Both sides must be L2-normalized for the dot product below to be a
    # cosine similarity; previously only the text features were normalized.
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

# Similarity of every image against every text, scaled by 100 and turned into
# per-image probabilities over the labels with a softmax.
text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

# Keep the 5 most probable labels per image, or all of them if fewer exist
# (topk raises when k exceeds the number of columns).
top_k = min(5, text_probs.shape[-1])
top_probs, top_labels = text_probs.cpu().topk(top_k, dim=-1)

print(image_features.shape)
print(text_features.shape)

In [None]:
# Print the result tensor as-is. Note: text_probs holds the softmax output of
# the scaled image-text similarities (per-image probabilities over the labels),
# not the raw cosine similarities.
print(text_probs)

# 予測結果の表示

In [None]:
# --- 予測結果の表示 ---

import matplotlib.pyplot as plt

def pred_disp(i, image):
    """Show one image beside a horizontal bar chart of its top predictions.

    Reads the module-level ``top_probs``, ``top_labels`` and ``labels``.

    Args:
        i: Row index into ``top_probs`` / ``top_labels`` for this image.
        image: Tensor to display; it is permuted with ``(1, 2, 0)`` before
            being handed to ``imshow`` (presumably channels-first input).
    """
    bar_positions = np.arange(top_probs.shape[-1])

    plt.figure(figsize=(8, 4))

    # Left panel: the image itself, without axes.
    plt.subplot(1, 2, 1)
    displayable = image.permute(1, 2, 0)
    plt.imshow(displayable)
    plt.axis("off")

    # Right panel: probabilities as horizontal bars, best prediction on top.
    plt.subplot(1, 2, 2)
    plt.grid()
    plt.barh(bar_positions, top_probs[i])
    bar_axes = plt.gca()
    bar_axes.invert_yaxis()
    bar_axes.set_axisbelow(True)  # draw the grid behind the bars
    tick_names = [labels[label_id] for label_id in top_labels[i].numpy()]
    plt.yticks(bar_positions, tick_names)
    plt.xlabel("probability")

    plt.subplots_adjust(wspace=0.5)
    plt.show()

# Render the prediction panel for every preprocessed image, in list order.
for i, image in enumerate(images):
    pred_disp(i, image)