In [14]:
!pip install git+https://github.com/openai/CLIP.git
!pip install torchvision


Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-ikw2eys5
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-ikw2eys5
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [15]:
# 라이브러리 임포트
import torch
import clip
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt


In [16]:
# Data transform setup.
# CLIP's image normalization statistics (per-channel RGB mean / std).
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
BATCH_SIZE = 64

# CIFAR-10 images are upsampled to the 224x224 input size ViT-B/32 expects.
# BICUBIC interpolation matches the `preprocess` transform shipped with CLIP;
# torchvision's Resize would otherwise default to BILINEAR.
transform = transforms.Compose([
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(CLIP_MEAN, CLIP_STD),
])

# Load the CIFAR-10 test split (downloaded into ./data on first run).
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
# shuffle=False keeps the evaluation order identical across prompt runs.
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:04<00:00, 41.9MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


In [17]:
# Load the CLIP ViT-B/32 model, preferring the GPU when one is available.
# `preprocess` is CLIP's own image transform; the visible cells use the
# hand-built `transform` instead.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model, preprocess = clip.load('ViT-B/32', device=device)


100%|███████████████████████████████████████| 338M/338M [00:07<00:00, 49.5MiB/s]


In [18]:
# CIFAR-10 class names. Their order must match the dataset's integer labels,
# because predicted class indices are compared directly against those labels.
classes = 'airplane automobile bird cat deer dog frog horse ship truck'.split()

# Prompt templates; each `{}` is filled with a class name via str.format.
prompts = [
    'a photo of a {}.', 'an image of a {}.',
    'a drawing of a {}.', 'a sketch of a {}.',
    'a picture of a {}.', 'a blurry photo of a {}.',
    'a black and white photo of a {}.', 'a cartoon {}.',
    'a painting of a {}.',
]


In [19]:
def _encode_class_prompts(model, classes, prompt_template, device):
    """Fill `prompt_template` with each class name and return unit-norm text embeddings (one row per class)."""
    texts = [prompt_template.format(name) for name in classes]
    tokens = clip.tokenize(texts).to(device)
    features = model.encode_text(tokens)
    return features / features.norm(dim=-1, keepdim=True)


def _encode_test_images(model, test_loader, device):
    """Encode every batch of `test_loader`; return (unit-norm image embeddings, labels) concatenated over all batches."""
    feature_batches, label_batches = [], []
    for images, labels in tqdm(test_loader, desc="encoding images"):
        features = model.encode_image(images.to(device))
        feature_batches.append(features / features.norm(dim=-1, keepdim=True))
        label_batches.append(labels.to(device))
    if not feature_batches:
        raise ValueError("test_loader yielded no batches; cannot compute accuracy")
    return torch.cat(feature_batches), torch.cat(label_batches)


def evaluate_model(model, test_loader, classes, prompt_template, device=None, image_cache=None):
    """Zero-shot classification accuracy of a CLIP model for one prompt template.

    Each class name is substituted into `prompt_template`, the prompts are
    embedded with `model.encode_text`, and every image is assigned the class
    whose unit-normalised text embedding has the largest dot product (cosine
    similarity) with the image's unit-normalised embedding.

    Args:
        model: CLIP model exposing `encode_text` / `encode_image` (as returned
            by `clip.load`).
        test_loader: iterable of `(images, labels)` batches; labels are integer
            class indices aligned with `classes`.
        classes: class names, in label-index order.
        prompt_template: format string with one `{}` placeholder for the class name.
        device: where tokens and images are placed. Defaults to the device of
            the model's first parameter (CPU if the model has no parameters).
        image_cache: optional dict. When given, image embeddings and labels are
            stored in it on the first call and reused on later calls, so a sweep
            over many templates encodes the test set only once. Use one dict per
            (model, test_loader) pair; `test_loader` is not read on a cache hit.

    Returns:
        Fraction of correctly classified samples, in [0, 1].

    Raises:
        ValueError: if the loader yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    with torch.no_grad():
        text_features = _encode_class_prompts(model, classes, prompt_template, device)

        if image_cache is not None and "features" in image_cache:
            image_features, labels = image_cache["features"], image_cache["labels"]
        else:
            image_features, labels = _encode_test_images(model, test_loader, device)
            if image_cache is not None:
                image_cache["features"] = image_features
                image_cache["labels"] = labels

        # argmax over raw similarities equals argmax over temperature-scaled
        # softmax probabilities, so neither the x100 scale nor softmax is needed.
        predictions = (image_features @ text_features.T).argmax(dim=-1)
        correct = (predictions == labels).sum().item()

    total = labels.numel()
    if total == 0:
        raise ValueError("test_loader yielded no samples; cannot compute accuracy")
    return correct / total


In [None]:
# Sweep every prompt template and record its zero-shot accuracy, keyed by template.
results = {}
for prompt in prompts:
    print(f'프롬프트: "{prompt}" 평가 중...')
    results[prompt] = accuracy = evaluate_model(model, test_loader, classes, prompt)
    print(f'정확도: {accuracy:.4f}\n')


프롬프트: "a photo of a {}." 평가 중...


  2%|▏         | 3/157 [00:33<28:27, 11.09s/it]

In [None]:
# Rank prompts by accuracy (highest first) so the chart reads best-to-worst.
ranked = sorted(results.items(), key=lambda item: item[1], reverse=True)
prompt_list = [prompt for prompt, _ in ranked]
accuracy_list = [acc for _, acc in ranked]

# Horizontal bar chart of accuracy per prompt template.
fig, ax = plt.subplots(figsize=(10, 6))
ax.barh(prompt_list, accuracy_list, color='skyblue')
ax.set_xlabel('Accuracy')
ax.set_ylabel('Prompt template')
ax.set_xlim(0, 1)  # accuracy is a fraction in [0, 1]
ax.set_title('CLIP Zero-shot Performance with Different Prompts on CIFAR-10')
# barh draws the first item at the bottom; invert so the most accurate prompt is on top.
ax.invert_yaxis()
plt.show()
