In [None]:
import os
import torch
import matplotlib.pyplot as plt
from torchvision.datasets import OxfordIIITPet
from transformers import CLIPModel, CLIPProcessor
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import numpy as np

# Load the pretrained CLIP model and its matching processor.
# Inference is pinned to the CPU here; to use a GPU when one is available, set
# device = "cuda" if torch.cuda.is_available() else "cpu" instead.
device = "cpu"
print(f"Using device: {device}")

# Single source of truth for the checkpoint so model and processor always match.
clip_checkpoint = "openai/clip-vit-base-patch32"
# eval() returns the module itself, so load / move / eval can be chained.
model = CLIPModel.from_pretrained(clip_checkpoint).to(device).eval()
processor = CLIPProcessor.from_pretrained(clip_checkpoint)

# Load the Oxford-IIIT Pet test split with breed ("category") labels,
# downloading it under ~/.cache if it is not already there.
pet_dataset = OxfordIIITPet(
    root=os.path.expanduser("~/.cache"),
    split='test',
    target_types='category',
    download=True
)

# Pick up to 2000 random samples (fewer if the split is smaller).
indices = torch.randperm(len(pet_dataset))[:2000].tolist()
# Fetch every sampled item exactly once: each pet_dataset[i] lookup returns an
# (image, label) pair, so indexing separately for the image and for the label
# would fetch each sample twice.
samples = [pet_dataset[i] for i in indices]
images = [image for image, _ in samples]
# Ground truth is kept as the human-readable class name.
gt_labels = [pet_dataset.classes[label] for _, label in samples]

# 1. SINGLE PROMPT (BASELINE)
# Zero-shot classification using exactly one text prompt per class.
print("\nSINGLE PROMPT EVALUATION")
single_template = ["a photo of a {}"]
single_texts_per_class = [
    [template.format(class_name) for template in single_template]
    for class_name in pet_dataset.classes
]
single_flat_texts = [
    prompt for class_prompts in single_texts_per_class for prompt in class_prompts
]

# Build one CLIP batch holding every image and every class prompt.
single_inputs = processor(
    text=single_flat_texts,
    images=images,
    padding=True,
    return_tensors="pt",
).to(device)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    single_outputs = model(**single_inputs)
single_image_features = single_outputs.image_embeds
single_text_features = single_outputs.text_embeds

# Scale every embedding to unit length so dot products are cosine similarities.
single_image_features = single_image_features.div(
    single_image_features.norm(dim=-1, keepdim=True)
)
single_text_features = single_text_features.div(
    single_text_features.norm(dim=-1, keepdim=True)
)

# Per-class mean over prompts: (num_classes, 1, dim) -> (num_classes, dim).
# With a single prompt per class this leaves the values unchanged.
num_classes = len(pet_dataset.classes)
single_text_features = single_text_features.view(num_classes, 1, -1).mean(dim=1)
single_similarity = torch.matmul(single_image_features, single_text_features.T)

# Score the baseline: top-1 accuracy and the top-1 minus top-2 similarity margin.
similarity_topk = 2
single_values, single_indices_pred = single_similarity.topk(similarity_topk, dim=1)

single_correct = sum(
    int(predicted == pet_dataset.class_to_idx[label_name])
    for predicted, label_name in zip(single_indices_pred[:, 0].tolist(), gt_labels)
)
# A margin only exists when at least two candidates were requested.
if similarity_topk > 1:
    single_margins = (single_values[:, 0] - single_values[:, 1]).tolist()
else:
    single_margins = []

single_acc = single_correct / len(images) * 100
single_avg_margin = np.mean(single_margins) if single_margins else 0

# 2. TEMPLATE ENSEMBLE
# Zero-shot classification where each class is represented by the average of
# several prompt embeddings instead of a single prompt.
print("\nTEMPLATE ENSEMBLE EVALUATION")
# Ten prompt templates aimed at Oxford-IIIT Pet / animal classification.
templates = [
    "a photo of a {}, a type of pet.",               # strongest general-purpose form
    "a photo of the {}, a type of cat or dog.",      # states the coarse category (cat/dog)
    # NOTE(review): the next template says "dog" but is applied to every class,
    # cat breeds included - confirm this is intended.
    "a photo of a {}, a breed of dog.",              # adds breed context
    "a close-up photo of a {}.",                     # covers close-up shots
    "a photo of a sitting {}.",                      # adds pose information
    "a pet portrait of a {}.",                       # portrait-style composition
    "the {} is shown in the image.",                 # object-centric description
    "a blurry photo of a {}.",                       # covers low-quality / noisy images
    "a photo of a {} looking at the camera.",        # covers gaze direction
    "a high quality photo of a {}."                  # emphasizes high-quality images
]

num_templates = len(templates)
# Prompts are laid out class-major (all templates of class 0, then class 1, ...)
# so the flat text batch can be reshaped to (class, template, dim) below.
texts_per_class = [[template.format(c) for template in templates] for c in pet_dataset.classes]
flat_texts = [t for texts in texts_per_class for t in texts]

# Build one CLIP batch holding every image and every (class, template) prompt.
inputs = processor(
    images=images,
    text=flat_texts,
    return_tensors="pt",
    padding=True
).to(device)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    outputs = model(**inputs)
    image_features = outputs.image_embeds
    text_features = outputs.text_embeds

# Normalize each embedding, then average the prompts of every class:
# (num_classes * num_templates, dim) -> (num_classes, num_templates, dim) -> (num_classes, dim).
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
text_features = text_features.view(num_classes, num_templates, -1).mean(dim=1)
# The mean of unit vectors is shorter than unit length, and by a different
# amount for each class (the more a class's prompts disagree, the shorter it
# gets). Re-normalize so the product below is a true cosine similarity and no
# class is penalized merely because its prompt embeddings are more spread out.
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
ensemble_similarity = image_features @ text_features.T

# Score the ensemble: top-1 accuracy and the top-1 minus top-2 similarity margin.
ensemble_values, ensemble_indices_pred = ensemble_similarity.topk(similarity_topk, dim=1)
ensemble_correct = sum(
    pred == pet_dataset.class_to_idx[label]
    for pred, label in zip(ensemble_indices_pred[:, 0].tolist(), gt_labels)
)
# A margin only exists when at least two candidates were requested.
if similarity_topk > 1:
    ensemble_margins = (ensemble_values[:, 0] - ensemble_values[:, 1]).tolist()
else:
    ensemble_margins = []

ensemble_acc = ensemble_correct / len(images) * 100
ensemble_avg_margin = np.mean(ensemble_margins) if ensemble_margins else 0

# Final comparison: single prompt vs. template ensemble
print("\n" + "="*80)
print(f"SINGLE PROMPT vs {len(templates)} TEMPLATE ENSEMBLE COMPARISON")
print("="*80)
print(f"{'Metric':<25} {'Single':<12} {'Ensemble':<12} {'Improvement':<12}")
print("-"*80)
print(f"Top-1 Accuracy     : {single_acc:6.1f}% ({single_correct:3d})  {ensemble_acc:6.1f}% ({ensemble_correct:3d})  {ensemble_acc-single_acc:+6.1f}%")
print(f"Avg Margin         : {single_avg_margin:8.3f}       {ensemble_avg_margin:8.3f}       {ensemble_avg_margin-single_avg_margin:+7.3f}")
# Relative change in the number of correct predictions; guarded against a
# zero baseline to avoid dividing by zero.
print(f"Improvement Rate   : {'':<18} {'':<12} {((ensemble_correct/single_correct-1)*100 if single_correct>0 else 0):+6.1f}%")
print("="*80)

# Detailed results (first 5 images)
print("\nDETAILED RESULTS (first 5 images):")
for img_idx, gt_label in enumerate(gt_labels[:5]):
    gt_idx = pet_dataset.class_to_idx[gt_label]
    single_pred_idx = int(single_indices_pred[img_idx, 0])
    ensemble_pred_idx = int(ensemble_indices_pred[img_idx, 0])
    # Use dedicated marker names: single_correct / ensemble_correct hold the
    # integer hit counts printed above and must not be rebound to "O"/"X".
    single_mark = "O" if single_pred_idx == gt_idx else "X"
    ensemble_mark = "O" if ensemble_pred_idx == gt_idx else "X"

    print(f"\nImage {img_idx+1} | Answer: {gt_label}")
    print(f"  Single:  {pet_dataset.classes[single_pred_idx]} {single_mark} ({single_values[img_idx, 0]:.3f})")
    print(f"  Ensemble:{pet_dataset.classes[ensemble_pred_idx]} {ensemble_mark} ({ensemble_values[img_idx, 0]:.3f})")

# Visualize the first images with the ground truth and both predictions.
n_show = min(8, len(images))
fig, axes = plt.subplots(2, 4, figsize=(20, 10))
# Flatten the 2x4 grid so it can be walked in step with the images.
axes = axes.ravel()

for idx, (ax, img, gt) in enumerate(zip(axes, images[:n_show], gt_labels[:n_show])):
    ax.imshow(img)
    gt_idx = pet_dataset.class_to_idx[gt]
    single_pred = pet_dataset.classes[single_indices_pred[idx, 0]]
    ensemble_pred = pet_dataset.classes[ensemble_indices_pred[idx, 0]]
    # "O" marks a correct top-1 prediction, "X" an incorrect one.
    single_status = "O" if single_indices_pred[idx, 0] == gt_idx else "X"
    ensemble_status = "O" if ensemble_indices_pred[idx, 0] == gt_idx else "X"

    title = f"Answer: {gt}\nSingle: {single_pred} {single_status}\nEnsemble: {ensemble_pred} {ensemble_status}"
    ax.set_title(title, fontsize=14)
    ax.axis("off")

# The grid always has 8 cells; when fewer images are available, hide the
# leftover cells instead of leaving empty framed axes in the figure.
for unused_ax in axes[n_show:]:
    unused_ax.axis("off")

plt.suptitle(f'Single Prompt vs {len(templates)} Prompt', fontsize=17)
plt.tight_layout()
