In [3]:
ls

 [0m[01;34mdata[0m/               [01;34mruns_multiscale_resnet[0m/        [01;34m'각종 ipynb'[0m/
 [01;34mgradcam_results[0m/    [01;34mruns_severity_classification[0m/   [01;34m김주형[0m/
 nih_train.ipynb     [01;34mruns_severity_regression[0m/      [01;34m'이전 버전'[0m/
 [01;34mruns_independent[0m/   [01;34mruns_simplified[0m/


In [7]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms
from glob import glob


class NIHChestXrayDataset(Dataset):
    """NIH Chest X-ray Dataset"""
    
    # 15가지 질병 라벨
    DISEASE_LABELS = [
        'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration',
        'Mass', 'Nodule', 'Pneumonia', 'Pneumothorax',
        'Consolidation', 'Edema', 'Emphysema', 'Fibrosis',
        'Pleural_Thickening', 'Hernia', 'COVID-19'
    ]
    
    def __init__(self, data_root, metadata_path, split_file_path, transform=None):
        """
        Args:
            data_root: 'data/nih-chest-xrays' 경로
            metadata_path: 'total_metadata.csv' 경로
            split_file_path: 'total_train.txt' 등의 경로
            transform: 이미지 변환
        """
        self.data_root = data_root
        self.transform = transform
        
        # 메타데이터 로드
        self.metadata = pd.read_csv(metadata_path)
        self.metadata.set_index('Image Index', inplace=True)
        
        # Split 파일 로드 (train/val/test)
        with open(split_file_path, 'r') as f:
            self.image_files = [line.strip() for line in f if line.strip()]
        
        # 모든 이미지 경로를 미리 매핑
        print("Building image path mapping...")
        self.image_path_dict = self._build_image_paths()
        print(f"Loaded {len(self.image_files)} images")
        
    def _build_image_paths(self):
        """data_root 하위의 모든 png 파일 찾아서 딕셔너리로 매핑"""
        image_paths = {}
        
        # images_001 ~ images_012 탐색
        for i in range(1, 13):
            dir_name = f"images_{i:03d}"
            pattern = os.path.join(self.data_root, dir_name, "images", "*.png")
            for path in glob(pattern):
                filename = os.path.basename(path)
                image_paths[filename] = path
        
        # row_png 탐색
        pattern = os.path.join(self.data_root, "row_png", "*.png")
        for path in glob(pattern):
            filename = os.path.basename(path)
            image_paths[filename] = path
        
        print(f"Found {len(image_paths)} total images in directories")
        return image_paths
    
    def _parse_labels(self, label_string):
        """라벨 문자열을 multi-hot 벡터로 변환"""
        labels = torch.zeros(len(self.DISEASE_LABELS), dtype=torch.float32)
        
        if pd.isna(label_string) or label_string == 'No Finding':
            return labels
        
        # '|'로 구분된 라벨들 파싱
        diseases = label_string.split('|')
        for disease in diseases:
            disease = disease.strip()
            if disease in self.DISEASE_LABELS:
                idx = self.DISEASE_LABELS.index(disease)
                labels[idx] = 1.0
        
        return labels
    
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        image_file = self.image_files[idx]
        
        # 이미지 로드
        if image_file in self.image_path_dict:
            img_path = self.image_path_dict[image_file]
        else:
            print(f"Warning: {image_file} not found in directories")
            # 검은 이미지 반환
            image = Image.new('RGB', (224, 224), color='black')
            labels = torch.zeros(len(self.DISEASE_LABELS), dtype=torch.float32)
            if self.transform:
                image = self.transform(image)
            return image, labels
        
        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            image = Image.new('RGB', (224, 224), color='black')
        
        # Transform 적용
        if self.transform:
            image = self.transform(image)
        
        # 라벨 추출
        if image_file in self.metadata.index:
            label_string = self.metadata.loc[image_file, 'Finding Labels']
            labels = self._parse_labels(label_string)
        else:
            print(f"Warning: {image_file} not in metadata")
            labels = torch.zeros(len(self.DISEASE_LABELS), dtype=torch.float32)
        
        return image, labels


def get_transforms(is_training=True, img_size=224):
    """기본 Transform"""
    if is_training:
        return transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=10),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                               std=[0.229, 0.224, 0.225])
        ])
    else:
        return transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                               std=[0.229, 0.224, 0.225])
        ])


def create_dataloaders(data_root, metadata_path, 
                       train_split, val_split, test_split,
                       batch_size=32, num_workers=4, img_size=224):
    """DataLoader 생성"""
    
    train_dataset = NIHChestXrayDataset(
        data_root=data_root,
        metadata_path=metadata_path,
        split_file_path=train_split,
        transform=get_transforms(is_training=True, img_size=img_size)
    )
    
    val_dataset = NIHChestXrayDataset(
        data_root=data_root,
        metadata_path=metadata_path,
        split_file_path=val_split,
        transform=get_transforms(is_training=False, img_size=img_size)
    )
    
    test_dataset = NIHChestXrayDataset(
        data_root=data_root,
        metadata_path=metadata_path,
        split_file_path=test_split,
        transform=get_transforms(is_training=False, img_size=img_size)
    )
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )
    
    return train_loader, val_loader, test_loader


# 사용 예제
if __name__ == "__main__":
    train_loader, val_loader, test_loader = create_dataloaders(
        data_root="data/nih-chest-xrays",
        metadata_path="data/nih-chest-xrays/total_metadata.csv",
        train_split="data/nih-chest-xrays/total_train.txt",
        val_split="data/nih-chest-xrays/total_val.txt",
        test_split="data/nih-chest-xrays/total_test.txt",
        batch_size=32,
        num_workers=4
    )
    
    print(f"\nDataset sizes:")
    print(f"Train: {len(train_loader.dataset)}")
    print(f"Val: {len(val_loader.dataset)}")
    print(f"Test: {len(test_loader.dataset)}")
    
    # 샘플 확인
    images, labels = next(iter(train_loader))
    print(f"\nBatch shape: {images.shape}")
    print(f"Labels shape: {labels.shape}")
    print(f"Sample labels: {labels[0]}")

Building image path mapping...
Found 105215 total images in directories
Loaded 72571 images
Building image path mapping...
Found 105215 total images in directories
Loaded 17731 images
Building image path mapping...
Found 105215 total images in directories
Loaded 26513 images

Dataset sizes:
Train: 72571
Val: 17731
Test: 26513


Batch shape: torch.Size([32, 3, 224, 224])
Labels shape: torch.Size([32, 15])
Sample labels: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])


In [5]:
from pathlib import Path

file_path = Path('data/nih-chest-xrays/row_png/108115246579239728.png')

if file_path.exists():
    print(f"✅ 파일이 존재합니다: {file_path}")
else:
    print(f"❌ 파일이 존재하지 않습니다: {file_path}")

# 파일인지 디렉토리인지 확인
if file_path.is_file():
    print("이것은 파일입니다.")
elif file_path.is_dir():
    print("이것은 디렉토리입니다.")

❌ 파일이 존재하지 않습니다: data/nih-chest-xrays/row_png/108115246579239728.png


In [10]:
import os
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from PIL import Image
from torchvision import transforms
from glob import glob
import torchxrayvision as xrv
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np


class NIHChestXrayDataset(Dataset):
    """NIH + Brixia(COVID-19) Chest X-ray Dataset"""
    
    # 15가지 질병 라벨 (COVID-19 추가!)
    DISEASE_LABELS = [
        'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration',
        'Mass', 'Nodule', 'Pneumonia', 'Pneumothorax',
        'Consolidation', 'Edema', 'Emphysema', 'Fibrosis',
        'Pleural_Thickening', 'Hernia', 'COVID-19'
    ]
    
    def __init__(self, data_root, metadata_path, split_file_path, transform=None):
        """
        Args:
            data_root: 'data/nih-chest-xrays' 경로
            metadata_path: 'total_metadata.csv' 경로
            split_file_path: 'total_train.txt' 등의 경로
            transform: 이미지 변환
        """
        self.data_root = data_root
        self.transform = transform
        
        # 메타데이터 로드
        self.metadata = pd.read_csv(metadata_path, low_memory=False)
        self.metadata.set_index('Image Index', inplace=True)
        
        # 모든 이미지 경로를 미리 매핑
        print("Building image path mapping...")
        self.image_path_dict = self._build_image_paths()
        print(f"Found {len(self.image_path_dict)} total images in directories")
        
        # Split 파일 로드 후 실제 존재하는 파일만 필터링
        with open(split_file_path, 'r') as f:
            all_files = [line.strip() for line in f if line.strip()]
        
        # 실제 존재하는 파일만 남김
        self.image_files = [f for f in all_files if f in self.image_path_dict]
        
        missing = len(all_files) - len(self.image_files)
        if missing > 0:
            print(f"Warning: {missing} files from split not found in directories")
        
        print(f"Loaded {len(self.image_files)} valid images")
        
    def _build_image_paths(self):
        """data_root 하위의 모든 png 파일 찾아서 딕셔너리로 매핑"""
        image_paths = {}
        
        # images_001 ~ images_012 탐색
        for i in range(1, 13):
            dir_name = f"images_{i:03d}"
            pattern = os.path.join(self.data_root, dir_name, "images", "*.png")
            for path in glob(pattern):
                filename = os.path.basename(path)
                image_paths[filename] = path
        
        # row_png 탐색 (Brixia COVID-19)
        pattern = os.path.join(self.data_root, "row_png", "*.png")
        for path in glob(pattern):
            filename = os.path.basename(path)
            image_paths[filename] = path
        
        return image_paths
    
    def _parse_labels(self, label_string):
        """라벨 문자열을 multi-hot 벡터로 변환"""
        labels = torch.zeros(len(self.DISEASE_LABELS), dtype=torch.float32)
        
        if pd.isna(label_string) or label_string == 'No Finding':
            return labels
        
        # '|'로 구분된 라벨들 파싱
        diseases = label_string.split('|')
        for disease in diseases:
            disease = disease.strip()
            if disease in self.DISEASE_LABELS:
                idx = self.DISEASE_LABELS.index(disease)
                labels[idx] = 1.0
        
        return labels
    
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        image_file = self.image_files[idx]
        
        # 이미지 로드 (GRAYSCALE for TorchXRayVision!)
        img_path = self.image_path_dict[image_file]
        
        try:
            image = Image.open(img_path).convert('L')  # Grayscale!
        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            image = Image.new('L', (224, 224), color=0)
        
        # Transform 적용
        if self.transform:
            image = self.transform(image)
        
        # 라벨 추출
        if image_file in self.metadata.index:
            label_string = self.metadata.loc[image_file, 'Finding Labels']
            labels = self._parse_labels(label_string)
        else:
            # 메타데이터에 없으면 0 벡터
            labels = torch.zeros(len(self.DISEASE_LABELS), dtype=torch.float32)
        
        return image, labels


def get_transforms(is_training=True, img_size=224):
    """Grayscale Transform for TorchXRayVision"""
    if is_training:
        return transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=10),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            transforms.ToTensor(),
            # TorchXRayVision expects normalized images
        ])
    else:
        return transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
        ])


def create_dataloaders(data_root, metadata_path, 
                       train_split, val_split, test_split,
                       batch_size=32, num_workers=4, img_size=224):
    """DataLoader 생성"""
    
    print("\n=== Creating Train Dataset ===")
    train_dataset = NIHChestXrayDataset(
        data_root=data_root,
        metadata_path=metadata_path,
        split_file_path=train_split,
        transform=get_transforms(is_training=True, img_size=img_size)
    )
    
    print("\n=== Creating Val Dataset ===")
    val_dataset = NIHChestXrayDataset(
        data_root=data_root,
        metadata_path=metadata_path,
        split_file_path=val_split,
        transform=get_transforms(is_training=False, img_size=img_size)
    )
    
    print("\n=== Creating Test Dataset ===")
    test_dataset = NIHChestXrayDataset(
        data_root=data_root,
        metadata_path=metadata_path,
        split_file_path=test_split,
        transform=get_transforms(is_training=False, img_size=img_size)
    )
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )
    
    return train_loader, val_loader, test_loader


class XRayMultiLabelClassifier(nn.Module):
    """TorchXRayVision 기반 다중라벨 분류 모델"""
    
    def __init__(self, num_classes=15, pretrained=True, model_name='densenet121-res224-all'):
        """
        Args:
            num_classes: 출력 클래스 수 (NIH 14 + COVID-19 = 15)
            pretrained: TorchXRayVision 사전학습 가중치 사용 여부
            model_name: 모델 종류
        """
        super().__init__()
        
        if pretrained:
            # TorchXRayVision 사전학습 모델 로드
            self.backbone = xrv.models.DenseNet(weights=model_name)
            print(f"Loaded pretrained model: {model_name}")
            print(f"Original output classes: {self.backbone.classifier.out_features}")
        else:
            # 사전학습 없이 초기화
            self.backbone = xrv.models.DenseNet(weights=None)
        
        # 기존 classifier 교체 (15 classes)
        in_features = self.backbone.classifier.in_features
        self.backbone.classifier = nn.Linear(in_features, num_classes)
        self.backbone.op_threshs = None
        
        print(f"Model created with {num_classes} output classes")
    
    def forward(self, x):
        return self.backbone(x)


def calculate_metrics(outputs, labels, threshold=0.5):
    """정확도 계산"""
    preds = (torch.sigmoid(outputs) > threshold).float()
    correct = (preds == labels).float()
    
    # Per-sample accuracy (모든 라벨이 정확히 맞은 비율)
    exact_match = (correct.sum(dim=1) == labels.size(1)).float().mean()
    
    # Per-label accuracy (각 라벨별 정확도 평균)
    per_label_acc = correct.mean()
    
    return exact_match.item(), per_label_acc.item()


def train_one_epoch(model, train_loader, criterion, optimizer, scaler, device):
    """1 epoch 학습 with AMP"""
    model.train()
    total_loss = 0.0
    total_exact_match = 0.0
    total_per_label_acc = 0.0
    
    pbar = tqdm(train_loader, desc="Training", leave=False)
    
    for images, labels in pbar:
        images = images.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        
        # Mixed Precision Training
        with autocast():
            outputs = model(images)
            loss = criterion(outputs, labels)
        
        # Backward with scaler
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        # Metrics
        exact_match, per_label_acc = calculate_metrics(outputs, labels)
        
        total_loss += loss.item()
        total_exact_match += exact_match
        total_per_label_acc += per_label_acc
        
        # Progress bar 업데이트
        pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{per_label_acc:.4f}'
        })
    
    avg_loss = total_loss / len(train_loader)
    avg_exact_match = total_exact_match / len(train_loader)
    avg_per_label_acc = total_per_label_acc / len(train_loader)
    
    return avg_loss, avg_exact_match, avg_per_label_acc


def validate(model, val_loader, criterion, device):
    """검증"""
    model.eval()
    total_loss = 0.0
    total_exact_match = 0.0
    total_per_label_acc = 0.0
    
    pbar = tqdm(val_loader, desc="Validation", leave=False)
    
    with torch.no_grad():
        for images, labels in pbar:
            images = images.to(device)
            labels = labels.to(device)
            
            with autocast():
                outputs = model(images)
                loss = criterion(outputs, labels)
            
            # Metrics
            exact_match, per_label_acc = calculate_metrics(outputs, labels)
            
            total_loss += loss.item()
            total_exact_match += exact_match
            total_per_label_acc += per_label_acc
            
            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{per_label_acc:.4f}'
            })
    
    avg_loss = total_loss / len(val_loader)
    avg_exact_match = total_exact_match / len(val_loader)
    avg_per_label_acc = total_per_label_acc / len(val_loader)
    
    return avg_loss, avg_exact_match, avg_per_label_acc


def plot_training_history(history, save_path='training_history.png'):
    """학습 히스토리 시각화"""
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    
    # Loss plot
    axes[0].plot(history['train_loss'], label='Train Loss', marker='o')
    axes[0].plot(history['val_loss'], label='Val Loss', marker='s')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Training & Validation Loss')
    axes[0].legend()
    axes[0].grid(True)
    
    # Accuracy plot
    axes[1].plot(history['train_acc'], label='Train Acc', marker='o')
    axes[1].plot(history['val_acc'], label='Val Acc', marker='s')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Accuracy')
    axes[1].set_title('Training & Validation Accuracy (Per-Label)')
    axes[1].legend()
    axes[1].grid(True)
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"\n✓ Training history plot saved to: {save_path}")
    plt.close()


class EarlyStopping:
    """Early Stopping"""
    def __init__(self, patience=5, min_delta=0.0, mode='min'):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        
    def __call__(self, score):
        if self.best_score is None:
            self.best_score = score
            return False
        
        if self.mode == 'min':
            improved = score < (self.best_score - self.min_delta)
        else:
            improved = score > (self.best_score + self.min_delta)
        
        if improved:
            self.best_score = score
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
                return True
        
        return False


def main():
    # ==================== 설정 ====================
    DATA_ROOT = "data/nih-chest-xrays"
    METADATA_PATH = "data/nih-chest-xrays/total_metadata.csv"
    TRAIN_SPLIT = "data/nih-chest-xrays/total_train.txt"
    VAL_SPLIT = "data/nih-chest-xrays/total_val.txt"
    TEST_SPLIT = "data/nih-chest-xrays/total_test.txt"
    
    BATCH_SIZE = 64  # V100 32GB -> 배치 크기 증가
    NUM_WORKERS = 8
    IMG_SIZE = 224
    NUM_EPOCHS = 50
    LEARNING_RATE = 0.001
    NUM_CLASSES = 15  # NIH 14 + COVID-19 1
    EARLY_STOPPING_PATIENCE = 7
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    
    # ==================== DataLoader 생성 ====================
    train_loader, val_loader, test_loader = create_dataloaders(
        data_root=DATA_ROOT,
        metadata_path=METADATA_PATH,
        train_split=TRAIN_SPLIT,
        val_split=VAL_SPLIT,
        test_split=TEST_SPLIT,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        img_size=IMG_SIZE
    )
    
    print(f"\n{'='*60}")
    print(f"Dataset Summary:")
    print(f"  Train: {len(train_loader.dataset):,} images ({len(train_loader)} batches)")
    print(f"  Val:   {len(val_loader.dataset):,} images ({len(val_loader)} batches)")
    print(f"  Test:  {len(test_loader.dataset):,} images ({len(test_loader)} batches)")
    print(f"  Batch Size: {BATCH_SIZE}")
    print(f"{'='*60}\n")
    
    # ==================== 모델 생성 ====================
    print("=== Creating Model ===")
    model = XRayMultiLabelClassifier(
        num_classes=NUM_CLASSES,
        pretrained=True,
        model_name='densenet121-res224-all'
    )
    model = model.to(device)
    
    # ==================== Loss & Optimizer ====================
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3, verbose=True
    )
    
    # Mixed Precision Scaler
    scaler = GradScaler()
    
    # Early Stopping
    early_stopping = EarlyStopping(patience=EARLY_STOPPING_PATIENCE, mode='min')
    
    # ==================== 학습 ====================
    print("\n" + "="*60)
    print("=== Training Start ===")
    print("="*60 + "\n")
    
    best_val_loss = float('inf')
    history = {
        'train_loss': [], 'val_loss': [],
        'train_acc': [], 'val_acc': [],
        'train_exact_match': [], 'val_exact_match': []
    }
    
    for epoch in range(NUM_EPOCHS):
        print(f"\n{'='*60}")
        print(f"Epoch [{epoch+1}/{NUM_EPOCHS}]")
        print(f"{'='*60}")
        
        # Train
        train_loss, train_exact, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, scaler, device
        )
        
        # Validation
        val_loss, val_exact, val_acc = validate(
            model, val_loader, criterion, device
        )
        
        # Learning rate scheduling
        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]['lr']
        
        # History 저장
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)
        history['train_exact_match'].append(train_exact)
        history['val_exact_match'].append(val_exact)
        
        # 결과 출력
        print(f"\n✓ Train - Loss: {train_loss:.4f}, Acc: {train_acc:.4f}, Exact Match: {train_exact:.4f}")
        print(f"✓ Val   - Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, Exact Match: {val_exact:.4f}")
        print(f"✓ LR: {current_lr:.6f}")
        
        # 모델 저장
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
                'val_acc': val_acc,
            }, 'best_model.pth')
            print(f"★ Best model saved! (Val Loss: {val_loss:.4f})")
        
        # Early Stopping 체크
        if early_stopping(val_loss):
            print(f"\n⚠ Early stopping triggered at epoch {epoch+1}")
            break
    
    # ==================== 학습 완료 ====================
    print("\n" + "="*60)
    print("=== Training Completed ===")
    print(f"Best Val Loss: {best_val_loss:.4f}")
    print("="*60)
    
    # 학습 히스토리 플롯
    plot_training_history(history)
    
    # ==================== Test 평가 ====================
    print("\n=== Testing on Best Model ===")
    checkpoint = torch.load('best_model.pth')
    model.load_state_dict(checkpoint['model_state_dict'])
    
    test_loss, test_exact, test_acc = validate(model, test_loader, criterion, device)
    
    print(f"\n{'='*60}")
    print("Test Results:")
    print(f"  Loss: {test_loss:.4f}")
    print(f"  Per-Label Accuracy: {test_acc:.4f}")
    print(f"  Exact Match Accuracy: {test_exact:.4f}")
    print("="*60)
    
    # 최종 결과 요약
    print(f"\n{'='*60}")
    print("Final Summary:")
    print(f"  Best Val Loss: {best_val_loss:.4f}")
    print(f"  Test Loss: {test_loss:.4f}")
    print(f"  Test Accuracy: {test_acc:.4f}")
    print("="*60 + "\n")


if __name__ == "__main__":
    # 질병 라벨 출력
    print("\n" + "="*60)
    print("Disease Labels (15 classes):")
    for i, label in enumerate(NIHChestXrayDataset.DISEASE_LABELS):
        print(f"  {i:2d}. {label}")
    print("="*60 + "\n")
    
    main()


Disease Labels (15 classes):
   0. Atelectasis
   1. Cardiomegaly
   2. Effusion
   3. Infiltration
   4. Mass
   5. Nodule
   6. Pneumonia
   7. Pneumothorax
   8. Consolidation
   9. Edema
  10. Emphysema
  11. Fibrosis
  12. Pleural_Thickening
  13. Hernia
  14. COVID-19

Using device: cuda
GPU: Tesla V100-SXM2-32GB
GPU Memory: 34.07 GB

=== Creating Train Dataset ===
Building image path mapping...
Found 105215 total images in directories
Loaded 65716 valid images

=== Creating Val Dataset ===
Building image path mapping...
Found 105215 total images in directories
Loaded 15784 valid images

=== Creating Test Dataset ===
Building image path mapping...
Found 105215 total images in directories
Loaded 23715 valid images

Dataset Summary:
  Train: 65,716 images (1026 batches)
  Val:   15,784 images (247 batches)
  Test:  23,715 images (371 batches)
  Batch Size: 64

=== Creating Model ===
Loaded pretrained model: densenet121-res224-all
Original output classes: 18
Model created with 15 o

                                                                                      


✓ Train - Loss: 0.1430, Acc: 0.9579, Exact Match: 0.5799
✓ Val   - Loss: 0.1288, Acc: 0.9597, Exact Match: 0.5920
✓ LR: 0.001000
★ Best model saved! (Val Loss: 0.1288)

Epoch [2/50]


                                                                                      


✓ Train - Loss: 0.1294, Acc: 0.9588, Exact Match: 0.5827
✓ Val   - Loss: 0.1264, Acc: 0.9599, Exact Match: 0.5910
✓ LR: 0.001000
★ Best model saved! (Val Loss: 0.1264)

Epoch [3/50]


                                                                                      


✓ Train - Loss: 0.1276, Acc: 0.9589, Exact Match: 0.5835
✓ Val   - Loss: 0.1280, Acc: 0.9598, Exact Match: 0.5902
✓ LR: 0.001000

Epoch [4/50]


                                                                                      


✓ Train - Loss: 0.1263, Acc: 0.9590, Exact Match: 0.5834
✓ Val   - Loss: 0.1277, Acc: 0.9600, Exact Match: 0.5933
✓ LR: 0.001000

Epoch [5/50]


                                                                                      


✓ Train - Loss: 0.1252, Acc: 0.9592, Exact Match: 0.5848
✓ Val   - Loss: 0.1252, Acc: 0.9599, Exact Match: 0.5908
✓ LR: 0.001000
★ Best model saved! (Val Loss: 0.1252)

Epoch [6/50]


                                                                                      


✓ Train - Loss: 0.1241, Acc: 0.9593, Exact Match: 0.5853
✓ Val   - Loss: 0.1285, Acc: 0.9597, Exact Match: 0.5907
✓ LR: 0.001000

Epoch [7/50]


                                                                                      


✓ Train - Loss: 0.1234, Acc: 0.9593, Exact Match: 0.5851
✓ Val   - Loss: 0.1311, Acc: 0.9595, Exact Match: 0.5889
✓ LR: 0.001000

Epoch [8/50]


                                                                                      


✓ Train - Loss: 0.1224, Acc: 0.9596, Exact Match: 0.5869
✓ Val   - Loss: 0.1244, Acc: 0.9596, Exact Match: 0.5872
✓ LR: 0.001000
★ Best model saved! (Val Loss: 0.1244)

Epoch [9/50]


                                                                                      


✓ Train - Loss: 0.1218, Acc: 0.9595, Exact Match: 0.5858
✓ Val   - Loss: 0.1256, Acc: 0.9597, Exact Match: 0.5883
✓ LR: 0.001000

Epoch [10/50]


                                                                                      


✓ Train - Loss: 0.1211, Acc: 0.9596, Exact Match: 0.5865
✓ Val   - Loss: 0.1293, Acc: 0.9596, Exact Match: 0.5887
✓ LR: 0.001000

Epoch [11/50]


                                                                                      


✓ Train - Loss: 0.1205, Acc: 0.9598, Exact Match: 0.5878
✓ Val   - Loss: 0.1234, Acc: 0.9596, Exact Match: 0.5880
✓ LR: 0.001000
★ Best model saved! (Val Loss: 0.1234)

Epoch [12/50]


                                                                                      


✓ Train - Loss: 0.1199, Acc: 0.9598, Exact Match: 0.5876
✓ Val   - Loss: 0.1236, Acc: 0.9597, Exact Match: 0.5889
✓ LR: 0.001000

Epoch [13/50]


                                                                                      


✓ Train - Loss: 0.1192, Acc: 0.9599, Exact Match: 0.5876
✓ Val   - Loss: 0.1255, Acc: 0.9595, Exact Match: 0.5860
✓ LR: 0.001000

Epoch [14/50]


                                                                                      


✓ Train - Loss: 0.1185, Acc: 0.9600, Exact Match: 0.5884
✓ Val   - Loss: 0.1357, Acc: 0.9596, Exact Match: 0.5909
✓ LR: 0.001000

Epoch [15/50]


                                                                                      


✓ Train - Loss: 0.1180, Acc: 0.9601, Exact Match: 0.5893
✓ Val   - Loss: 0.1260, Acc: 0.9598, Exact Match: 0.5897
✓ LR: 0.000500

Epoch [16/50]


                                                                                      


✓ Train - Loss: 0.1150, Acc: 0.9605, Exact Match: 0.5916
✓ Val   - Loss: 0.1242, Acc: 0.9597, Exact Match: 0.5874
✓ LR: 0.000500

Epoch [17/50]


                                                                                      


✓ Train - Loss: 0.1142, Acc: 0.9608, Exact Match: 0.5926
✓ Val   - Loss: 0.1246, Acc: 0.9597, Exact Match: 0.5880
✓ LR: 0.000500

Epoch [18/50]


                                                                                      


✓ Train - Loss: 0.1134, Acc: 0.9608, Exact Match: 0.5938
✓ Val   - Loss: 0.1255, Acc: 0.9596, Exact Match: 0.5879
✓ LR: 0.000500

⚠ Early stopping triggered at epoch 18

=== Training Completed ===
Best Val Loss: 0.1234

✓ Training history plot saved to: training_history.png

=== Testing on Best Model ===


                                                                                      


Test Results:
  Loss: 0.1941
  Per-Label Accuracy: 0.9303
  Exact Match Accuracy: 0.3792

Final Summary:
  Best Val Loss: 0.1234
  Test Loss: 0.1941
  Test Accuracy: 0.9303

