In [None]:
import os
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
import numpy as np

# -----------------------------
# 1. Environment setup and model loading
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fine-tuned checkpoint; its contents are passed straight to load_state_dict,
# so it is expected to be a state_dict for the architecture built below.
checkpoint_path = "models/best_resnet34_addval.pth"
if not os.path.isfile(checkpoint_path):
    # Fail early with an actionable message (resolved path + cwd) instead of a
    # bare traceback from torch.load; relative paths depend on the notebook cwd.
    raise FileNotFoundError(
        f"Checkpoint not found: {os.path.abspath(checkpoint_path)} "
        f"(current working directory: {os.getcwd()})"
    )

# ResNet34 architecture with a 2-output head (NORMAL / PNEUMONIA);
# no pretrained weights are downloaded, everything comes from the checkpoint.
model = models.resnet34(weights=None)
model.fc = nn.Linear(model.fc.in_features, 2)
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
# On torch >= 1.13, weights_only=True restricts loading to tensors/state_dicts.
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()  # inference mode: fixes BatchNorm statistics / disables dropout

# -----------------------------
# 2. Image preprocessing
# -----------------------------
# ImageNet per-channel statistics.
# NOTE(review): presumably the same normalization used at training time -- confirm.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

preprocess_steps = [
    transforms.Resize((224, 224)),  # fixed square resize (aspect ratio not kept)
    transforms.ToTensor(),          # PIL image -> float CHW tensor in [0, 1]
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
]
transform = transforms.Compose(preprocess_steps)

# -----------------------------
# 3. Load test data and run inference
# -----------------------------
test_dir = "test"
class_to_idx = {"NORMAL": 0, "PNEUMONIA": 1}

y_true = []   # ground-truth label index per successfully loaded image
y_score = []  # model probability of PNEUMONIA (class index 1) per image

with torch.no_grad():  # inference only: no autograd graph needed
    for label, label_idx in class_to_idx.items():
        folder_path = os.path.join(test_dir, label)
        # Sorted so the processing order is deterministic across filesystems.
        for img_name in sorted(os.listdir(folder_path)):
            img_path = os.path.join(folder_path, img_name)
            # Guard only image decoding: unreadable / non-image entries are
            # reported and skipped (PIL's UnidentifiedImageError is an OSError),
            # while genuine model or device errors still propagate.
            try:
                with Image.open(img_path) as img:  # closes the file handle
                    image = img.convert("RGB")
            except OSError as e:
                print(f"Error loading {img_path}: {e}")
                continue

            input_tensor = transform(image).unsqueeze(0).to(device)
            output = model(input_tensor)
            # Softmax over the 2 logits; column 1 is the pneumonia probability.
            prob = torch.softmax(output, dim=1)[0, 1].item()

            y_true.append(label_idx)
            y_score.append(prob)

# -----------------------------
# 4. Compute and save the ROC curve
# -----------------------------
# An ROC curve is undefined unless both classes are present; without this
# check sklearn only warns and a NaN curve would be saved silently.
present_labels = sorted(set(y_true))
if len(present_labels) < 2:
    raise ValueError(
        f"ROC requires both classes in y_true; got {len(y_true)} samples "
        f"with labels {present_labels}"
    )

fpr, tpr, thresholds = roc_curve(y_true, y_score)
roc_auc = auc(fpr, tpr)

fig = plt.figure(figsize=(6, 6))
plt.plot(fpr, tpr, label=f"AUC = {roc_auc:.3f}", linewidth=2)
plt.plot([0, 1], [0, 1], linestyle='--', color='gray')  # chance-level diagonal
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curve (Test Set)")
plt.legend(loc="lower right")
plt.grid(True)

# Save the figure and release it (avoids accumulating open figures in the notebook).
os.makedirs("roc", exist_ok=True)
fig.savefig("roc/roc_curve.png", dpi=300)
plt.close(fig)

print("ROC 커브 저장 완료 → 'roc/roc_curve.png'")

FileNotFoundError: [Errno 2] No such file or directory: 'model.pth'