In [None]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image, UnidentifiedImageError
import numpy as np
import os
import faiss
import matplotlib.pyplot as plt
import cv2

# Allow more than one copy of the Intel OpenMP runtime to be loaded in this process.
# NOTE(review): presumably needed because torch and faiss each bring their own OpenMP
# library ("OMP: Error #15 ... already initialized", common on Windows/conda). Intel's
# own error text calls this an unsafe workaround that may crash or silently give wrong
# results — prefer fixing the environment. The runtime may check this variable when it
# is first loaded, so it may need to be set BEFORE `import torch` / `import faiss` — confirm.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# 1. 유사도 모델 (image similarity grouping)

In [None]:
# Folder whose images will be grouped by visual similarity.
# NOTE(review): empty placeholder — set this before running; os.listdir("") raises an error.
folder_path = r""

# ImageNet-pretrained EfficientNet-B0 used as a feature extractor: swapping the final
# classifier layer for Identity makes the model emit its feature vector instead of
# class logits.
base_model = models.efficientnet_b0(pretrained=True)
base_model.classifier[1] = nn.Identity()

# Per-image preprocessing for the embedding model: resize, convert to a tensor, and
# normalize with the standard ImageNet channel mean/std.
# NOTE(review): Resize((256, 224)) yields height 256 x width 224, which stretches images
# whose aspect ratio differs; the usual EfficientNet-B0 recipe is Resize(256) +
# CenterCrop(224). Left unchanged because altering it shifts every embedding and hence
# the meaning of the similarity threshold used for grouping.
preprocess = transforms.Compose(
    [
        transforms.Resize((256, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def get_image_embedding(image_path, model, device):
    """Embed a single image file with ``model``.

    The image is converted to RGB, passed through the module-level ``preprocess``
    transform, given a leading batch dimension, and run through ``model`` in eval
    mode with gradients disabled.

    Args:
        image_path: path to an image file readable by PIL.
        model: torch module applied to a batch of one preprocessed image.
        device: torch device the input tensor is moved to (should match ``model``'s).

    Returns:
        The model output moved to CPU as a flattened 1-D numpy array, or ``None``
        if PIL cannot identify the file as an image.
    """
    try:
        # Context manager releases the file handle immediately; convert() returns an
        # independent copy, so ``img`` stays usable after the source file is closed.
        with Image.open(image_path) as raw_img:
            img = raw_img.convert('RGB')
    except UnidentifiedImageError:
        print(f"Cannot identify image file {image_path}. Skipping.")
        return None

    img_tensor = preprocess(img).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        embedding_vector = model(img_tensor).cpu().numpy().flatten()
    return embedding_vector

def get_embeddings_from_folder(folder_path, model, device):
    """Embed every JPEG/PNG image directly inside ``folder_path``.

    Files are visited in sorted name order so results are reproducible across runs,
    and the extension check is case-insensitive so camera-style names such as
    ``IMG_0001.JPG`` are not silently skipped. Files PIL cannot identify are skipped.

    Args:
        folder_path: directory to scan (not recursive).
        model: embedding model passed through to ``get_image_embedding``.
        device: torch device passed through to ``get_image_embedding``.

    Returns:
        ``(embeddings, image_paths)``: a float32 array with one row per embedded image
        (shape ``(0,)`` when nothing was embedded), and the list of file paths aligned
        with those rows.
    """
    embeddings = []
    image_paths = []
    for file_name in sorted(os.listdir(folder_path)):
        # Lowercase and require the dot: accepts "IMG.JPG", rejects e.g. "notapng".
        if not file_name.lower().endswith(('.jpg', '.jpeg', '.png')):
            continue
        image_path = os.path.join(folder_path, file_name)
        embedding = get_image_embedding(image_path, model, device)
        if embedding is not None:
            embeddings.append(embedding)
            image_paths.append(image_path)
    # faiss requires float32 input.
    return np.array(embeddings, dtype=np.float32), image_paths

def plot_image_groups(groups, max_groups=200):
    """Display each group of images as one row of thumbnails.

    Args:
        groups: list of groups, each a non-empty list of image file paths.
        max_groups: only the first ``max_groups`` groups are drawn (default 200,
            the previously hard-coded limit) to keep notebook output manageable.
    """
    for idx, group in enumerate(groups):
        if idx >= max_groups:
            break  # nothing past the limit is drawn, so stop iterating
        print("=================group {}=================".format(idx))
        # squeeze=False keeps ``axes`` 2-D even for a single-image group.
        fig, axes = plt.subplots(1, len(group), figsize=(10, 10), squeeze=False)
        for ax, image_path in zip(axes[0], group):
            # imshow copies the pixels into an array, so the file can be closed right after.
            with Image.open(image_path) as img:
                ax.imshow(img)
            ax.axis('off')
        plt.show()
        plt.close(fig)  # avoid accumulating open figures when plotting many groups


# Use the GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_model = base_model.to(device)


def group_by_similarity(D, I, image_paths, threshold):
    """Greedily cluster images by cosine similarity to a seed image.

    Images are taken as seeds in index order; each not-yet-grouped image starts a new
    group, and every still-ungrouped neighbour of the seed whose similarity is
    ``>= threshold`` joins that group.

    Args:
        D: similarity scores from ``index.search``; each row sorted in descending order.
        I: neighbour indices aligned with ``D``.
        image_paths: paths aligned with the embedding rows.
        threshold: minimum cosine similarity to join a seed's group.

    Returns:
        List of groups (lists of image paths); each group starts with its seed.
    """
    groups = []
    visited = set()
    for seed in range(len(image_paths)):
        if seed in visited:
            continue
        visited.add(seed)
        group = [image_paths[seed]]
        # Scan from position 0 rather than 1: with exact-duplicate images the seed is
        # not guaranteed to be its own first hit, so skipping position 0 could drop a
        # duplicate. The seed itself is already in ``visited`` and is skipped below.
        for score, neighbour in zip(D[seed], I[seed]):
            if score < threshold:
                break  # rows are sorted by descending similarity; nothing further qualifies
            neighbour = int(neighbour)
            if neighbour < 0 or neighbour in visited:
                continue
            group.append(image_paths[neighbour])
            visited.add(neighbour)
        groups.append(group)
    return groups


# Embed every image in the folder.
embeddings, image_paths = get_embeddings_from_folder(folder_path, base_model, device)

# Cosine-similarity cutoff for grouping.
# NOTE(review): magic number; how 0.775 was chosen is not documented here.
threshold = 0.775

if len(image_paths) == 0:
    print(f"No readable images found in {folder_path!r}.")
    groups = []
else:
    # faiss needs a C-contiguous float32 matrix.
    embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
    index = faiss.IndexFlatIP(embeddings.shape[1])
    # Normalize in place: inner product of unit vectors equals cosine similarity.
    faiss.normalize_L2(embeddings)
    index.add(embeddings)
    # Compare every image against all images.
    D, I = index.search(embeddings, k=len(embeddings))
    groups = group_by_similarity(D, I, image_paths, threshold)

# Show results.
print(device)
plot_image_groups(groups)

# 2. 블러 판별 모델 (blur detection)

In [None]:
# Blur classifier architecture; it must match the saved checkpoint's state_dict keys.
class MobileNetClassifier(nn.Module):
    """Two-class image classifier built on torchvision's MobileNetV2.

    The backbone is created with ImageNet-pretrained weights, and its final
    classifier Linear layer is replaced by a fresh ``Linear(last_channel, 2)``.
    NOTE(review): the original comment listed the classes as "blur and sharp", while
    ``predict_blur`` maps index 0 -> 'sharp' and 1 -> 'blur'; confirm the order.
    """

    def __init__(self):
        super().__init__()
        backbone = models.mobilenet_v2(pretrained=True)
        backbone.classifier[1] = nn.Linear(backbone.last_channel, 2)
        self.model = backbone

    def forward(self, x):
        """Return the 2-way class scores produced by the wrapped MobileNetV2."""
        return self.model(x)

# Build the blur classifier. Its ImageNet backbone weights are fully replaced by the
# fine-tuned checkpoint loaded below (load_state_dict is strict by default).
blur_model = MobileNetClassifier()

# Fine-tuned checkpoint.
# NOTE(review): the original notes also mention
# "mobilenet_centered_softblurred_scene_laplacian", presumably an alternative checkpoint.
model_path = './mobilenet_blurred.pth'

# Choose the device BEFORE loading so tensors saved on a GPU are remapped correctly;
# without map_location, torch.load fails on a CPU-only machine for a GPU-saved file.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): torch.load unpickles the file — only load checkpoints from trusted sources.
blur_model.load_state_dict(torch.load(model_path, map_location=device))
blur_model.to(device)
blur_model.eval()

# Preprocessing for the blur classifier: fixed 256x256 resize + ImageNet normalization.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def predict_blur(blur_model, image_path):
    """Classify one image as blurred or sharp.

    The image is converted to RGB, preprocessed with the module-level ``transform``,
    moved to the module-level ``device``, and scored by ``blur_model``; the class with
    the highest score wins.

    Args:
        blur_model: classifier returning one score per class, where index 0 is
            'sharp' and index 1 is 'blur'.
            NOTE(review): this order must match the labels used in training — confirm.
        image_path: path to the image file.

    Returns:
        ``image_path`` if the image is predicted blurred, otherwise ``None``.
    """
    # Close the file handle promptly; convert() returns an independent copy.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(device)

    blur_model.eval()
    with torch.no_grad():
        output = blur_model(image_tensor)
        _, predicted = torch.max(output, 1)

    classes = ['sharp', 'blur']
    predicted_class = classes[predicted.item()]
    if predicted_class == 'blur':
        return image_path
    return None


# Scan a folder and collect the paths the classifier marks as blurred.
# TODO: hard-coded absolute local path — move to a config constant or environment variable.
path = r"C:\Users\ben81\zflip_camera"
blurred_files = []

for file_name in sorted(os.listdir(path)):
    # Case-insensitive so camera files such as "IMG_0001.JPG" are not skipped.
    if file_name.lower().endswith(('.png', '.jpg', '.jpeg')):
        image_path = os.path.join(path, file_name)
        try:
            blurred_path = predict_blur(blur_model, image_path)
            if blurred_path is not None:
                blurred_files.append(blurred_path)
        except Exception as exc:
            # Skip unreadable/corrupt files, but report which one failed and why.
            # (A bare except would also swallow KeyboardInterrupt.)
            print(f'error: {image_path}: {exc}')

print(blurred_files)


# 3. Eye closing 모델