In [1]:
import torch
import torch.nn as nn
import cv2
import numpy as np
from PIL import Image
from torchvision import transforms
import types
# -----------------------------------
# 1. Define the model architecture and load the weights
# -----------------------------------
print("1. Loading the fine-tuned model structure and weights...")

# --- Settings ---
MODEL_REPO = "facebookresearch/dinov2"
MODEL_NAME = "dinov2_vits14"

# Classifier output index i is reported as CLASS_NAMES[i]; presumably this is
# the label order used during fine-tuning -- confirm against the training code.
CLASS_NAMES = ['downy', 'healthy', 'powdery']
# Derived from CLASS_NAMES so the head size and the label list cannot disagree.
NUM_CLASSES = len(CLASS_NAMES)
DEVICE = torch.device("cpu")
print(f"--> Using device: {DEVICE}")
# --- Build the model skeleton ---
weights_path = 'dinov2_hub_finetuned_model.pth'
try:
    # Fetch only the DINOv2 architecture from torch.hub (pretrained=False);
    # every weight comes from the fine-tuned checkpoint loaded below.
    model = torch.hub.load(MODEL_REPO, MODEL_NAME, pretrained=False)

    # Attach a linear classification head. Its input width is read from the
    # backbone's `embed_dim` (384 for ViT-Small) instead of being hard-coded,
    # falling back to 384 if the attribute is absent.
    num_features = getattr(model, "embed_dim", 384)
    model.head = nn.Linear(num_features, NUM_CLASSES)

    # --- Load the saved weights ---
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint cannot execute arbitrary code on load.
    state_dict = torch.load(weights_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)

    model = model.to(DEVICE)
    model.eval()  # inference mode (disables dropout)
    print("--> Model loaded successfully.")

except Exception as e:
    print("\nERROR: Failed to load the model.")
    print(f"--> Ensure '{weights_path}' is in the same directory.")
    print(f"--> Original Error: {e}")
    # Re-raise so "Run All" stops here with a full traceback; unlike a process
    # exit this keeps the Jupyter kernel alive and needs no extra import.
    raise RuntimeError(f"Could not load the fine-tuned model from '{weights_path}'") from e

1. Loading the fine-tuned model structure and weights...
--> Using device: cpu


Using cache found in C:\Users\51100/.cache\torch\hub\facebookresearch_dinov2_main


--> Model loaded successfully.


In [2]:
# ----------------------------------
# 2. Image preprocessing
# ----------------------------------
# Pipeline applied to every input image before it is fed to the model.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),      # scale so the shorter side is 256 px
        transforms.CenterCrop(224),  # keep the central 224x224 region
        transforms.ToTensor(),       # PIL image -> float CHW tensor in [0, 1]
        # Per-channel normalization with the commonly used ImageNet RGB statistics.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [3]:
# ----------------------------------
# 3. Attention extraction via monkey-patching, and visualization
# ----------------------------------

def _model_view_bgr(pil_img, pipeline):
    """Return the image as the model sees it (before tensor conversion), as a BGR uint8 array.

    Applies every step of ``pipeline.transforms`` up to, but not including, the
    first ``transforms.ToTensor``. The display image therefore has exactly the
    same resize/crop as the model input, so the attention heatmap lines up with it.
    """
    view = pil_img
    for step in pipeline.transforms:
        if isinstance(step, transforms.ToTensor):
            break
        view = step(view)
    return cv2.cvtColor(np.array(view), cv2.COLOR_RGB2BGR)


def _forward_with_last_attention(net, img_tensor):
    """Run ``net`` on ``img_tensor`` and also return the softmax attention of its last block.

    Temporarily replaces ``net.blocks[-1].attn.forward`` with an explicit
    attention computation that records the attention probabilities. The module
    is restored in a ``finally`` clause, so a failure cannot leave it patched.
    The map is kept in a local closure rather than on the module, so no stale
    tensor survives between calls.

    Returns:
        (logits, attn): the model output and the attention probabilities of
        shape (B, num_heads, N, N) from the last block.

    Raises:
        RuntimeError: if the patched forward was never invoked.
    """
    attn_block = net.blocks[-1].attn
    captured = {}

    def forward_recording_attention(self, x, attn_bias=None):
        # Explicit attention so the probabilities are materialized and can be
        # recorded. Attention dropout is omitted: it is an identity once
        # model.eval() has been called.
        if attn_bias is not None:
            raise NotImplementedError("attention capture does not support attn_bias")
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = ((q @ k.transpose(-2, -1)) * self.scale).softmax(dim=-1)
        captured["attn"] = attn
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(out))

    # Remember whether `forward` was already an instance attribute so the module
    # can be restored exactly (normally only the class method exists).
    instance_attrs = vars(attn_block)
    had_instance_forward = "forward" in instance_attrs
    previous_forward = instance_attrs.get("forward")
    attn_block.forward = types.MethodType(forward_recording_attention, attn_block)
    try:
        with torch.no_grad():
            logits = net(img_tensor)
    finally:
        if had_instance_forward:
            attn_block.forward = previous_forward
        else:
            del attn_block.forward  # fall back to the class-level forward again

    if "attn" not in captured:
        raise RuntimeError("The patched attention forward was never called; no attention map was recorded.")
    return logits, captured["attn"]


def _cls_attention_heatmap(attn, num_prefix_tokens, input_hw, out_size):
    """Turn the CLS-token attention row of sample 0 into a JET-coloured BGR heatmap.

    Args:
        attn: attention probabilities of shape (B, heads, N, N).
        num_prefix_tokens: number of tokens preceding the patch tokens
            (the CLS token plus any register tokens).
        input_hw: (height, width) in pixels of the model input, used to
            recover the patch-grid shape for non-square inputs.
        out_size: (width, height) of the returned heatmap, as cv2.resize expects.

    Raises:
        ValueError: if the patch tokens cannot be arranged into a grid.
    """
    head_mean = attn[0].mean(dim=0).cpu().numpy()       # (N, N), averaged over heads
    cls_to_patches = head_mean[0, num_prefix_tokens:]   # CLS query row, patch keys only
    n_patches = cls_to_patches.shape[0]
    h_px, w_px = input_hw
    grid_h = int(round(float(np.sqrt(n_patches * h_px / w_px))))
    grid_w = n_patches // grid_h if grid_h else 0
    if grid_h * grid_w != n_patches:
        raise ValueError(f"Cannot arrange {n_patches} patch tokens into a grid for a {h_px}x{w_px} input.")
    grid = cls_to_patches.reshape(grid_h, grid_w)
    heat = cv2.normalize(grid, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
    heat = cv2.applyColorMap(heat, cv2.COLORMAP_JET)
    return cv2.resize(heat, out_size)


def predict_and_visualize_attention(image_path):
    """Classify one image and show the last block's CLS attention as an overlay.

    Uses the notebook-level ``model``, ``preprocess``, ``DEVICE`` and ``CLASS_NAMES``.
    Prints the predicted class and confidence. It then opens two OpenCV windows
    (the overlay, and the image as cropped for the model) and blocks until a key
    is pressed. Errors are printed rather than raised, and the model is always
    left un-patched.

    Args:
        image_path: path to the image file to classify.
    """
    try:
        # Read the file once with PIL. cv2.imread returns None instead of raising
        # when it cannot read a file (e.g. some non-ASCII paths on Windows). The
        # display image is derived from the same crop the model sees, so the
        # heatmap aligns with what is shown.
        with Image.open(image_path) as img:
            pil_img = img.convert("RGB")
        img_tensor = preprocess(pil_img).unsqueeze(0).to(DEVICE)
        display_img = _model_view_bgr(pil_img, preprocess)

        print("\n3. Running inference and extracting attention via monkey-patching...")
        logits, attn = _forward_with_last_attention(model, img_tensor)

        probabilities = torch.nn.functional.softmax(logits[0], dim=0)
        top_prob, top_catid = torch.max(probabilities, 0)
        predicted_class = CLASS_NAMES[top_catid.item()]
        confidence = top_prob.item()

        # Patch tokens follow the CLS token and any register tokens the backbone uses.
        num_prefix_tokens = 1 + int(getattr(model, "num_register_tokens", 0))
        heatmap = _cls_attention_heatmap(
            attn,
            num_prefix_tokens,
            tuple(img_tensor.shape[-2:]),
            (display_img.shape[1], display_img.shape[0]),
        )
        superimposed_img = cv2.addWeighted(display_img, 0.6, heatmap, 0.4, 0)

        # --- Report results ---
        print("\n--- Inference Result ---")
        print(f"Predicted Class: {predicted_class}")
        print(f"Confidence: {confidence:.2%}")
        cv2.imshow("Superimposed Image", superimposed_img)
        cv2.imshow("Original", display_img)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

    except Exception as e:
        # Best-effort notebook helper: report and continue. The model cannot be
        # left patched because _forward_with_last_attention restores it in `finally`.
        print(f"ERROR: An error occurred: {e}")

In [4]:
# --- Image to classify ---
# Absolute, machine-specific path: adjust it for your environment.
image_to_predict = (
    "C:/blooming_AI/classification_dataset/test/powdery/"
    "306942_20210914_4_1_a4_3_2_12_2_0.jpg"
)
predict_and_visualize_attention(image_to_predict)


3. Running inference and extracting attention via monkey-patching...

--- Inference Result ---
Predicted Class: powdery
Confidence: 100.00%
