論文<br>
https://arxiv.org/abs/2112.10003<br>
<br>
GitHub<br>
https://github.com/timojl/clipseg<br>

# 環境セットアップ

## GitHubからコードをクローン

In [None]:
# Clone the clipseg repository into /content.
%cd /content
!git clone https://github.com/timojl/clipseg.git

# Pin the working tree to one fixed commit (noted as dated Sep 27, 2022)
# so the later cells run against a known revision of the repository.
%cd /content/clipseg
!git checkout 515ca6ec2d066d447240c1dd79f3bbbee685bd29

## ライブラリのインストール

In [None]:
# Install OpenAI's CLIP package straight from GitHub, pinned to a specific
# commit hash rather than the moving default branch.
!pip install git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1

## ライブラリのインポート

In [None]:
%cd /content/clipseg

import os

import torch
import requests

from models.clipseg import CLIPDensePredT
from PIL import Image
from torchvision import transforms
from matplotlib import pyplot as plt

# Select the compute device: CUDA when PyTorch reports it as available,
# otherwise the CPU. The chosen name is echoed so it shows in the cell output.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("using device is", device)

# 学習済みモデルのセットアップ

In [None]:
%cd /content/clipseg
!mkdir pretrained

if not os.path.exists('pretrained/weights.zip'):
  !wget https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download -O pretrained/weights.zip
  !unzip -d pretrained/weights -j pretrained/weights.zip

# モデルのロード

In [None]:
%cd /content/clipseg

# load model
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
model.eval();

model.load_state_dict(torch.load('pretrained/weights/rd64-uni.pth', map_location=torch.device(device)), strict=False);

# データの前処理

In [None]:
# Download the sample photo and store it as test_01.jpg in the current
# working directory (-c asks wget to continue a partially fetched file).
!wget -c https://www.pakutaso.com/shared/img/thumb/smIMGL4174_TP_V4.jpg \
      -O test_01.jpg

In [None]:
%cd /content/clipseg

# 画像のロード
input_image = Image.open('test_01.jpg')

# Normalize
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.Resize((352, 352)),
])
img = transform(input_image).unsqueeze(0)

input_image

# Segmentation

In [None]:
# Text prompts to segment; one mask is predicted per prompt.
prompts = ['a lemon', 'a girl', 'wood']
num_of_p = len(prompts)

# Predict: repeat the single preprocessed image once per prompt so the batch
# size matches the number of prompts, and take the first element of the
# model's output. Gradients are not needed for inference.
with torch.no_grad():
  preds = model(img.repeat(num_of_p, 1, 1, 1), prompts)[0]

# Visualize: one panel for the input image plus one per prompt. The panel
# count is derived from the prompt list (it used to be hard-coded to 5, which
# left an empty panel for 3 prompts and raised IndexError beyond 4 prompts).
# The figure is sized per panel so its height no longer shrinks or grows with
# the number of prompts.
num_panels = num_of_p + 1
_, ax = plt.subplots(1, num_panels, figsize=(4 * num_panels, 4))
for a in ax.flatten():
  a.axis('off')
ax[0].imshow(input_image)
for panel, prompt, pred in zip(ax[1:], prompts, preds):
  # Sigmoid squashes the raw prediction into the (0, 1) range for display.
  panel.imshow(torch.sigmoid(pred[0]))
  # Label each mask with its prompt, placed just above the panel.
  panel.text(0, -15, prompt)