In [None]:
import os
import glob
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import pickle

class ImageSimilarityFinder:
    """Find visually similar images using CLIP image embeddings.

    Typical use: call ``build_index`` once on a folder of images to compute
    and persist one embedding per image, then call ``find_similar`` with a
    query image to rank the indexed images by cosine similarity.
    """

    def __init__(self):
        """Load the pretrained CLIP model/processor and pick the device.

        The model is moved to CUDA when ``torch.cuda.is_available()`` is
        true, otherwise it stays on the CPU.
        """
        self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def get_image_embedding(self, image_path):
        """Return the CLIP image-feature embedding of one image.

        Args:
            image_path: Path of the image file to embed.

        Returns:
            The image features as a NumPy array on the CPU. Presumably of
            shape ``(1, dim)`` (one row per input image) -- ``find_similar``
            relies on it being 2-D with a single row.

        Raises:
            Whatever ``PIL.Image.open`` / the model raise for unreadable or
            unsupported files; ``build_index`` catches these per image.
        """
        # Close the file handle deterministically; convert() returns a new,
        # fully loaded image that stays valid after the file is closed.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        inputs = self.processor(images=image, return_tensors="pt").to(self.device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            embedding = self.model.get_image_features(**inputs)

        return embedding.cpu().numpy()

    def build_index(self, folder_path, index_path='image_index.pkl'):
        """Embed every image under ``folder_path`` and persist the index.

        Searches the folder recursively for ``.jpg``, ``.jpeg``, ``.png``
        and ``.bmp`` files. Images that cannot be embedded are reported and
        left out of the index entirely.

        Args:
            folder_path: Root folder to scan for images.
            index_path: File the pickled index is written to.

        Returns:
            A dict with ``'paths'`` (list of indexed image paths) and
            ``'embeddings'`` (array with one row per path, in the same
            order).

        Raises:
            ValueError: If no image could be embedded (empty folder or every
                image failed), since an empty index is unusable.
        """
        image_paths = []
        for ext in ['*.jpg', '*.jpeg', '*.png', '*.bmp']:
            image_paths.extend(glob.glob(f'{folder_path}/**/{ext}', recursive=True))

        print(f"총 {len(image_paths)}개 이미지 처리 중...")

        # Only paths whose embedding succeeded are recorded, so that row k of
        # the embedding matrix always belongs to indexed_paths[k].
        indexed_paths = []
        embeddings = []
        for i, path in enumerate(image_paths):
            try:
                embedding = self.get_image_embedding(path)
            except Exception as e:
                # Best effort: one unreadable file must not abort the build.
                print(f"오류 발생 ({path}): {e}")
                continue
            indexed_paths.append(path)
            embeddings.append(embedding)
            if (i + 1) % 100 == 0:
                print(f"{i + 1}/{len(image_paths)} 완료")

        if not embeddings:
            # np.vstack([]) would raise an opaque ValueError; fail clearly
            # (same exception type) before writing a useless index file.
            raise ValueError(
                f"no images could be indexed under {folder_path!r}"
            )

        # Persist the index.
        data = {
            'paths': indexed_paths,
            'embeddings': np.vstack(embeddings)
        }

        with open(index_path, 'wb') as f:
            pickle.dump(data, f)

        return data

    def find_similar(self, query_image_path, top_k=10, index_path='image_index.pkl'):
        """Rank indexed images by cosine similarity to a query image.

        Args:
            query_image_path: Path of the query image.
            top_k: Maximum number of results to return.
            index_path: Pickled index file produced by ``build_index``.

        Returns:
            A list of at most ``top_k`` dicts, most similar first, each with
            ``'path'`` (indexed image path) and ``'similarity'`` (cosine
            similarity to the query).
        """
        # Load the index.
        # NOTE(security): pickle.load executes arbitrary code from the file;
        # only load index files that this program wrote itself.
        with open(index_path, 'rb') as f:
            data = pickle.load(f)

        # Embed the query image.
        query_embedding = self.get_image_embedding(query_image_path)

        # One query row against all indexed rows -> take the single result row.
        similarities = cosine_similarity(query_embedding, data['embeddings'])[0]

        # Indices of the top-k scores, highest first.
        top_indices = np.argsort(similarities)[::-1][:top_k]

        return [
            {
                'path': data['paths'][idx],
                'similarity': similarities[idx]
            }
            for idx in top_indices
        ]

# Usage example
if __name__ == "__main__":
    finder = ImageSimilarityFinder()

    # Step 1: build the index (only needed the first time)
    finder.build_index('./PHOTO')

    # Step 2: search for images similar to the query
    results = finder.find_similar('mint.jpg', top_k=10)

    # Print a 1-based ranking with each match's path and similarity score.
    for rank, match in enumerate(results, start=1):
        print(f"{rank}. {match['path']} (유사도: {match['similarity']:.4f})")

총 250개 이미지 처리 중...
100/250 완료
200/250 완료
1. ./PHOTO\DSC00804.jpg (유사도: 0.8532)
2. ./PHOTO\DSC00819.jpg (유사도: 0.8442)
3. ./PHOTO\DSC01103.jpg (유사도: 0.8293)
4. ./PHOTO\DSC00165.jpg (유사도: 0.8258)
5. ./PHOTO\DSC01102.jpg (유사도: 0.8215)
6. ./PHOTO\DSC00772.jpg (유사도: 0.8128)
7. ./PHOTO\DSC01063.jpg (유사도: 0.8124)
8. ./PHOTO\DSC00813.jpg (유사도: 0.8123)
9. ./PHOTO\DSC00653.jpg (유사도: 0.8097)
10. ./PHOTO\DSC00580.jpg (유사도: 0.8075)
