# 필요 라이브러리

In [7]:
# %% 표준 라이브러리
import os
import time
import random
from datetime import datetime, timedelta
# %% 수치·데이터 처리
import numpy as np
import pandas as pd
from collections import Counter

# %% 시각화
import matplotlib.pyplot as plt
from PIL import Image, ImageEnhance, ImageFilter
import cv2
# %% PyTorch 및 관련
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal
from torch.utils.data import DataLoader, Dataset, Subset, WeightedRandomSampler
from torch.optim.lr_scheduler import StepLR, OneCycleLR, ReduceLROnPlateau
from torch.cuda.amp import GradScaler
from torch.amp import autocast


# %% torchvision
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.models import efficientnet_b3, EfficientNet_B3_Weights, efficientnet_b0, EfficientNet_B0_Weights


from transformers import CLIPProcessor, CLIPModel

# %% 유틸리티
from torchinfo import summary
from tqdm import tqdm

# %% 데이터 분할
from sklearn.model_selection import StratifiedShuffleSplit


plt.rcParams['font.family'] = 'Malgun Gothic'  # Windows의 경우, 한글 지원 폰트로 설정
plt.rcParams['axes.unicode_minus'] = False       # 음수 기호가 깨지지 않도록 설정
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE

device(type='cuda')

# 파라미터

In [8]:
IMAGE_SIZE = (224, 224)

# 모델정의

In [None]:
# 6개 질환 모델과 전처리 함수
pizi_model_path = r'C:\Users\user1\Desktop\Code\Scalp_Disease_Classifier\result\model\compressed\biddem_B0_compressed.pt'
# pizi_model_path = r'C:\Users\user1\Desktop\Code\Scalp_Disease_Classifier\result\model\compressed\biddem_compressed.pt' # b3

talmo_model_path = r'C:\Users\user1\Desktop\Code\Scalp_Disease_Classifier\result\model\compressed\talmo_B0_compressed.pt'
mosa_model_path = r'C:\Users\user1\Desktop\Code\Scalp_Disease_Classifier\result\model\compressed\mosa_B0_compressed.pt'
mono_model_path = r'C:\Users\user1\Desktop\Code\Scalp_Disease_Classifier\result\model\compressed\mono_B0_compressed.pt'
# mono_model_path = r'C:\Users\user1\Desktop\Code\Scalp_Disease_Classifier\result\model\compressed\mono_compressed.pt' # b3
mise_model_path = r'C:\Users\user1\Desktop\Code\Scalp_Disease_Classifier\result\model\compressed\mise_B0_compressed.pt'
biddem_model_path = r'C:\Users\user1\Desktop\Code\Scalp_Disease_Classifier\result\model\compressed\biddem_B0_compressed.pt'

def load_model(model_path):
    model = efficientnet_b0()  
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Sequential(nn.Linear(in_features, 512),
                                        nn.BatchNorm1d(512),       # 배치 정규화 추가
                                        nn.ReLU(),                 # 혹은 nn.ReLU()
                                        nn.Linear(512, 3))         # 최종 클래스 수
    state_dict = torch.load(model_path)
    model.load_state_dict(state_dict)
    model = model.to(DEVICE)
    model.eval()
    return model

def load_model_b3(model_path):
    model = efficientnet_b3()  
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Sequential(nn.Linear(in_features, 512),
                                        nn.BatchNorm1d(512),       # 배치 정규화 추가
                                        nn.ReLU(),                 # 혹은 nn.ReLU()
                                        nn.Linear(512, 3))         # 최종 클래스 수
    state_dict = torch.load(model_path)
    model.load_state_dict(state_dict)
    model = model.to(DEVICE)
    model.eval()
    return model


pizi_model    = load_model(pizi_model_path)
talmo_model   = load_model(talmo_model_path)
mosa_model    = load_model(mosa_model_path)
mono_model    = load_model(mono_model_path)
mise_model    = load_model(mise_model_path)
biddem_model  = load_model(biddem_model_path)

disease_models = [
    mise_model, pizi_model, mosa_model, mono_model, biddem_model, talmo_model  # 이미 load된 상태 (eval)
]

# 미세각질
class UnsharpMaskTransform(object):
    def __call__(self, img):
        return img.filter(ImageFilter.UnsharpMask(radius=2, percent=150, threshold=3))

class LaplacianEnhanceTransform(object):
    def __call__(self, img):
        img_np = np.array(img)
        gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
        lap = cv2.Laplacian(gray, cv2.CV_64F, ksize=3)
        lap = np.clip(lap, 0, 255).astype(np.uint8)
        lap_rgb = cv2.cvtColor(lap, cv2.COLOR_GRAY2RGB)
        # 원본과 Laplacian 결과를 합성
        enhanced = cv2.addWeighted(img_np, 0.8, lap_rgb, 0.5, 0)
        return Image.fromarray(enhanced)

mise_preprocess = transforms.Compose([
    UnsharpMaskTransform(),
    LaplacianEnhanceTransform(),
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 피지과다
class RemoveReflection(object):
    def __call__(self, img):
        return img.filter(ImageFilter.MedianFilter(size=3))

class EnhanceContrastSebaceous(object):
    def __call__(self, img):
        enhancer = ImageEnhance.Contrast(img)
        return enhancer.enhance(1.5)

pizi_preprocess = transforms.Compose([
    RemoveReflection(),
    EnhanceContrastSebaceous(),
    transforms.Resize((300, 300)),  # 피지과다 이미지 크기 조정
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 모낭사이홍반
class ErythemaRednessEnhanceTransform:
    def __init__(self, apply_mask: bool=False):
        # a* 채널 CLAHE
        self.clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(4,4))
        self.apply_mask = apply_mask

    def __call__(self, img: Image.Image) -> Image.Image:
        # 1) RGB → LAB → a* 채널만 CLAHE
        img_np    = np.array(img)
        lab       = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
        l, a, b   = cv2.split(lab)
        a_clahe   = self.clahe.apply(a)
        lab_clahe = cv2.merge((l, a_clahe, b))
        img_clahe = cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2RGB)

        if self.apply_mask:
            # 2) HSV 마스크로 붉은 영역 분리
            hsv    = cv2.cvtColor(img_clahe, cv2.COLOR_RGB2HSV)
            lower1 = (0,   50, 50); upper1 = (15, 255,255)
            lower2 = (160, 50, 50); upper2 = (180,255,255)
            m1     = cv2.inRange(hsv, lower1, upper1)
            m2     = cv2.inRange(hsv, lower2, upper2)
            mask   = cv2.bitwise_or(m1, m2)

            # 3) 배경은 회색조로 변환
            gray     = cv2.cvtColor(img_clahe, cv2.COLOR_RGB2GRAY)
            gray_rgb = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)

            # 4) 마스크 영역만 컬러, 나머지는 gray
            mask_bool = mask.astype(bool)[..., None]
            result    = np.where(mask_bool, img_clahe, gray_rgb)
            return Image.fromarray(result.astype(np.uint8))

        # apply_mask=False: CLAHE 처리된 컬러만 반환
        return Image.fromarray(img_clahe)
    
mosa_preprocess = transforms.Compose([
    ErythemaRednessEnhanceTransform(apply_mask=True),
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 모낭농포
class EmphasizeRedTransform(object):
    """LAB a* 채널(JET colormap)로 염증(적색) 강조"""
    def __call__(self, img):
        img_np = np.array(img)
        lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
        l, a, b = cv2.split(lab)
        # a* 채널 정규화 후 pseudocolor
        a_norm = cv2.normalize(a, None, 0, 255, cv2.NORM_MINMAX)
        a_color = cv2.applyColorMap(a_norm, cv2.COLORMAP_JET)
        # cv2는 BGR 반환이므로 RGB 변환 필요
        a_color_rgb = cv2.cvtColor(a_color, cv2.COLOR_BGR2RGB)
        return Image.fromarray(a_color_rgb)
    
mono_preprocess = transforms.Compose([
    EmphasizeRedTransform(),             # 염증 강조 (a* pseudocolor)
    transforms.Resize((300, 300)),
    transforms.ToTensor(),
    # 시각화용이라면 Normalize 생략, 모델 추론용이면 Normalize 적용
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 비듬
class CLAHETransform(object):
    def __init__(self, clip_limit=0.4, grid_size=(4,4)):
        self.clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=grid_size)

    def __call__(self, img):
        img_np = np.array(img)
        lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
        l, a, b = cv2.split(lab)
        l_clahe = self.clahe.apply(l)
        lab_clahe = cv2.merge((l_clahe, a, b))
        img_rgb = cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2RGB)
        return Image.fromarray(img_rgb)

# Sharpen transform
class SharpenTransform(object):
    def __call__(self, img):
        return img.filter(ImageFilter.UnsharpMask(radius=2, percent=150, threshold=3))
    
biddem_preprocess = transforms.Compose([
    CLAHETransform(clip_limit=0.4, grid_size=(4,4)),             # 염증 강조 (a* pseudocolor)
    SharpenTransform(),
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    # 시각화용이라면 Normalize 생략, 모델 추론용이면 Normalize 적용
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 탈모
class AlopeciaHLFTransform(object):
    def __init__(self, grid_size=8):
        self.grid_size = grid_size

    def __call__(self, img):
        # 1. CLAHE (contrast enhancement)
        img_np = np.array(img)
        lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
        l, a, b = cv2.split(lab)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(4,4))
        l_clahe = clahe.apply(l)
        lab_clahe = cv2.merge((l_clahe, a, b))
        img_clahe = cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2RGB)

        # 2. Sharpening
        img_sharp = cv2.addWeighted(img_clahe, 1.5, cv2.GaussianBlur(img_clahe, (0,0), 3), -0.5, 0)

        return Image.fromarray(img_sharp)
    
talmo_preprocess = transforms.Compose([
    AlopeciaHLFTransform(grid_size=8),
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    # 시각화용이라면 Normalize 생략, 모델 추론용이면 Normalize 적용
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

preprocess_funcs = [
    mise_preprocess, pizi_preprocess, mosa_preprocess, mono_preprocess, biddem_preprocess, talmo_preprocess  # 각 질환별 전처리
]

disease_names = ["미세각질", "피지과다", "모낭사이홍반", "모낭농포", "비듬", "탈모"]

# ────────────── 함수 정의 ──────────────

# ────────────── 2. 질환별 추론 ──────────────
def disease_inference(img_pil, disease_models, preprocess_funcs, disease_names):
    """
    질환별 전처리 & 모델 추론
    결과: {질환명: {'pred_class': int, 'probs': np.array}}
    """
    results = {}
    for model, preprocess, name in zip(disease_models, preprocess_funcs, disease_names):
        tensor = preprocess(img_pil).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            out = model(tensor)
            probs = torch.softmax(out, dim=1)[0].cpu().numpy()
            pred = int(probs.argmax())
        results[name] = {"pred_class": pred, "probs": probs}
        torch.cuda.empty_cache()
    return results

# ────────────── 3. 결과 깔끔하게 출력 ──────────────
def pretty_print_disease_result_top1(disease_result, class_labels):
    """
    Top-1 예측 결과를 콘솔에 깔끔히 출력
    class_labels: [[클래스명1, ...], ...]
    """
    print("========== 두피 질환별 예측 결과 ==========")
    for idx, (disease, res) in enumerate(disease_result.items()):
        pred = res["pred_class"]
        prob = res["probs"][pred]
        name = class_labels[idx][pred]
        print(f"● {disease}: {name} ({prob:.2%})")
    print("==========================================")

# ────────────── 6. 메인 파이프라인 ──────────────
def main_pipeline(img_path):
    img = Image.open(img_path).convert("RGB")
    img = img.resize(IMAGE_SIZE)

    print("질환 예측 시작...")
    res = disease_inference(img, disease_models, preprocess_funcs, disease_names)
    pretty_print_disease_result_top1(res, class_labels)
    

# ────────────── 실행 예시 ──────────────

# 클래스 라벨 설정
class_labels = [
    ["정상","경증","중증"],  # 미세각질
    ["정상","경증","중증"],  # 피지과다
    ["정상","경증","중증"],  # 모낭사이홍반
    ["정상","경증","중증"],  # 모낭농포
    ["정상","경증","중증"],  # 비듬
    ["정상","경증","중증"],  # 탈모
]

# 이미지 경로 지정
img_path = r"C:\Users\user1\Desktop\Code\Scalp_Disease_Classifier\data\pizi_org_3_preprocess\train\중증\0013_A2LEBJJDE00060O_1603257373184_2_TH.jpg"
main_pipeline(img_path)

질환 예측 시작...
● 미세각질: 경증 (77.32%)
● 피지과다: 중증 (43.49%)
● 모낭사이홍반: 경증 (84.21%)
● 모낭농포: 경증 (85.46%)
● 비듬: 경증 (86.81%)
● 탈모: 경증 (69.89%)
