In [None]:
import io

import requests
import torch
from IPython.display import display
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

# Load the pretrained CLIP checkpoint and its paired processor; the model is
# moved to the GPU when CUDA is available, otherwise it stays on the CPU.
checkpoint_name = "openai/clip-vit-base-patch32"

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = CLIPModel.from_pretrained(checkpoint_name).to(device)
processor = CLIPProcessor.from_pretrained(checkpoint_name)

# Download the test images.
urls = [
    "https://images.pexels.com/photos/103123/pexels-photo-103123.jpeg",
    "https://images.pexels.com/photos/414712/pexels-photo-414712.jpeg",
    "https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg",
    "https://images.pexels.com/photos/17474740/pexels-photo-17474740.jpeg",
    "https://images.pexels.com/photos/10875195/pexels-photo-10875195.jpeg",
    "https://images.pexels.com/photos/34026276/pexels-photo-34026276.jpeg"
]

images = []
for url in urls:
    # The timeout keeps a stalled connection from hanging the cell forever;
    # raise_for_status() reports HTTP errors (404, 5xx) directly instead of
    # letting PIL fail later with an opaque "cannot identify image file".
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    # Decode from the fully downloaded body: unlike `response.raw`, `.content`
    # has any Content-Encoding already undone and leaves no open stream
    # behind. convert("RGB") forces 3 channels so grayscale/CMYK/RGBA files
    # are handled uniformly by the processor and the thumbnail display.
    img = Image.open(io.BytesIO(response.content)).convert("RGB")
    images.append(img)

# Zero-shot setup: the candidate class labels and the prompt templates that
# phrase each label as a caption.
templates = [
    "a photo of a {}",
    "a close-up photo of a {}",
    "a blurry photo of a {}",
    "a cropped photo of a {}",
    "a photo of a {} in the wild",
    "a photo of a {} outdoors",
]
labels = [
    "animal",
    "object",
    "banana",
    "bird",
    "person",
    "flower",
    "food",
    "scenery",
]

# Build every prompt in label-major order: the len(templates) prompts for
# labels[0] come first, then those for labels[1], and so on. The ensembling
# step depends on this order when it reshapes the embeddings per label.
texts = []
for label in labels:
    texts.extend(template.format(label) for template in templates)

# Precompute the text embeddings once. The tokenizer pads all prompts to a
# common length (padding=True), so every prompt is encoded in one batch.
inputs_text = processor(text=texts, return_tensors="pt", padding=True).to(device)

with torch.no_grad():
    text_embeds = model.get_text_features(**inputs_text)
    # L2-normalize so dot products with (normalized) image embeddings are
    # cosine similarities.
    text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)

num_templates = len(templates)

# Prompt ensembling: `texts` is label-major (all templates for labels[0],
# then labels[1], ...), so this view groups each label's template embeddings
# along dim 1, and the mean collapses them into one embedding per label.
text_embeds_ensemble = text_embeds.view(len(labels), num_templates, -1).mean(dim=1)  # [num_labels, dim]
# The mean of unit vectors is shorter than unit length, and by a different
# amount for each label. Renormalize so every label is scored by plain cosine
# similarity, as CLIP's reference zero-shot classifier does.
text_embeds_ensemble = text_embeds_ensemble / text_embeds_ensemble.norm(dim=-1, keepdim=True)


# Per-image zero-shot classification.
topk = 3
thumb_size = (400, 500)

# CLIP's learned temperature. CLIPModel's own forward pass multiplies the
# cosine similarities by logit_scale.exp() before softmax. Without it, the
# similarities lie in [-1, 1], the softmax is nearly uniform, and the printed
# probabilities carry no meaning. Computed under no_grad because logit_scale
# is a trainable parameter.
with torch.no_grad():
    logit_scale = model.logit_scale.exp()

for idx, image in enumerate(images):
    print(f"\n===== 이미지 {idx + 1}/{len(images)} =====")

    # Encode this image on its own so its results are printed right before
    # its thumbnail is displayed.
    inputs_image = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        image_embeds = model.get_image_features(**inputs_image)
        image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)

    # Without prompt ensembling: score the image against every individual prompt.
    logits_templates = logit_scale * (image_embeds @ text_embeds.T)  # [1, num_labels * num_templates]
    probs_templates = logits_templates.softmax(dim=1)

    # Top-k prompts; k is clamped so topk() cannot exceed the candidate count.
    k = min(topk, probs_templates.shape[1])
    values, indices = probs_templates.topk(k, dim=1)
    print(f"\n프롬프트 앙상블 미적용: Top-{k} 템플릿")
    for rank, (v, i) in enumerate(zip(values[0], indices[0]), 1):
        print(f"{rank}. '{texts[i]}' - 확률: {v.item():.3f}")

    # With prompt ensembling: score the image against one averaged embedding per label.
    logits_labels = logit_scale * (image_embeds @ text_embeds_ensemble.T)  # [1, num_labels]
    probs_labels = logits_labels.softmax(dim=1)

    # Top-k labels.
    k = min(topk, probs_labels.shape[1])
    values, indices = probs_labels.topk(k, dim=1)
    print(f"\n프롬프트 앙상블 적용: Top-{k} 라벨")
    for rank, (v, i) in enumerate(zip(values[0], indices[0]), 1):
        print(f"{rank}. '{labels[i]}' - 확률: {v.item():.3f}")

    # Show a downsized copy. thumbnail() resizes in place, so copying first
    # keeps the full-size image in `images` untouched.
    image_copy = image.copy()
    image_copy.thumbnail(thumb_size)
    display(image_copy)
