In [None]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split, Subset
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from sklearn.metrics import classification_report
import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
import wandb
# Environment setup. CUDA_LAUNCH_BLOCKING=1 forces synchronous kernel launches so that
# asynchronous CUDA errors surface at the op that caused them (a debugging aid), but it
# also slows down every training run in the sweep. Set DEBUG_CUDA = True only while
# debugging; it must take effect before the first CUDA call in the kernel.
DEBUG_CUDA = False
if DEBUG_CUDA:
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [None]:
# Dataset class definition
class CustomDataset(Dataset):
    """Chest X-ray images from ``data_dir/NORMAL`` (label 0) and ``data_dir/PNEUMONIA`` (label 1).

    File names are collected in sorted order because ``os.listdir`` order is
    arbitrary; a stable order keeps indices (and any seeded random split over
    them) reproducible. Only regular, non-hidden files whose extension is in
    ``extensions`` are kept, which skips stray entries such as ``.DS_Store``.

    Args:
        data_dir: directory containing the ``NORMAL`` and ``PNEUMONIA`` subfolders.
        extensions: accepted file-name suffixes (matched case-insensitively).

    ``__getitem__`` returns ``(image, label)`` where ``image`` is a numpy array of
    the picture converted to PIL mode ``'L'`` (grayscale) and ``label`` is an int.
    """

    # Class folder names in label order: index 0 -> NORMAL, index 1 -> PNEUMONIA.
    CLASS_NAMES = ('NORMAL', 'PNEUMONIA')
    # Default accepted image suffixes.
    IMAGE_EXTENSIONS = ('.jpeg', '.jpg', '.png', '.bmp', '.tif', '.tiff')

    def __init__(self, data_dir, extensions=IMAGE_EXTENSIONS):
        self.data_dir = data_dir
        self.image_paths = []
        self.labels = []

        allowed = tuple(ext.lower() for ext in extensions)
        for label, class_name in enumerate(self.CLASS_NAMES):
            class_dir = os.path.join(data_dir, class_name)
            for img_name in sorted(os.listdir(class_dir)):
                path = os.path.join(class_dir, img_name)
                # Hidden files (".DS_Store", macOS "._*" resource forks) are not images.
                if img_name.startswith('.') or not img_name.lower().endswith(allowed):
                    continue
                if not os.path.isfile(path):
                    continue
                self.image_paths.append(path)
                self.labels.append(label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # The context manager closes the file handle as soon as the pixels are decoded.
        with Image.open(self.image_paths[idx]) as img:
            image = np.array(img.convert('L'))
        return image, self.labels[idx]
    
# Dataset wrapper: restricts a base dataset to `indices` and applies a per-sample transform
class TransformedDataset(Dataset):
    """View of ``dataset`` limited to ``indices``, with an optional image transform.

    ``transform`` is invoked as ``transform(image=image)`` and its result's
    ``'image'`` entry replaces the image (the albumentations calling convention).
    Labels pass through untouched. No transform is applied when ``transform`` is
    falsy (for example ``None``).
    """

    def __init__(self, dataset, indices, transform=None):
        self.transform = transform
        self.dataset = Subset(dataset, indices)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample, target = self.dataset[idx]
        if not self.transform:
            return sample, target
        return self.transform(image=sample)['image'], target

In [None]:
# Model definition
class EnhancedCNN(nn.Module):
    """Three conv blocks (conv3x3 -> batch-norm -> ReLU -> 2x2 max-pool) plus a two-layer MLP head.

    Each 3x3 convolution (padding 1) keeps the spatial size and each pool halves
    it (floor), so a square ``input_size`` input reaches the head as a
    ``128 x (input_size // 8) x (input_size // 8)`` feature map.

    Args:
        input_size: side length of the square input images (>= 8). Default 224.
        in_channels: number of input channels (1 for grayscale). Default 1.
        num_classes: number of output logits. Default 2.

    Raises:
        ValueError: if ``input_size`` is smaller than 8.
    """

    def __init__(self, input_size=224, in_channels=1, num_classes=2):
        super().__init__()
        if input_size < 8:
            raise ValueError(f"input_size must be >= 8, got {input_size}")
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.dropout = nn.Dropout(0.5)
        feature_side = input_size // 8  # three floor-halving pools
        self.flat_features = 128 * feature_side * feature_side  # 128*28*28 for the default 224
        self.fc1 = nn.Linear(self.flat_features, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Map an ``(N, in_channels, input_size, input_size)`` batch to ``(N, num_classes)`` logits."""
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        # flatten(1) preserves the batch dimension, so a wrongly sized input fails in fc1
        # instead of being silently reshaped into a different number of samples.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

In [None]:
# Model training (one call per wandb sweep run)

# Class names in label order (CustomDataset assigns 0 to NORMAL and 1 to PNEUMONIA).
_CLASS_NAMES = ['NORMAL', 'PNEUMONIA']


def _build_transforms():
    """Return ``(train_transform, eval_transform)`` albumentations pipelines.

    Both resize to 224x224, normalize with mean=0.5 / std=0.5 and convert to a
    tensor; only the training pipeline adds random augmentation. The evaluation
    pipeline is shared by validation and test so the two stay identical.
    """
    train_transform = A.Compose([
        A.Resize(224, 224),
        A.Rotate(limit=5),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
        A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=0, p=0.5),
        A.GaussNoise(var_limit=(10.0, 50.0), p=0.2),
        A.Normalize(mean=(0.5,), std=(0.5,)),
        ToTensorV2(),
    ])
    eval_transform = A.Compose([
        A.Resize(224, 224),
        A.Normalize(mean=(0.5,), std=(0.5,)),
        ToTensorV2(),
    ])
    return train_transform, eval_transform


def _build_optimizer(name, parameters, lr):
    """Create the optimizer named by the sweep's ``optimizer`` parameter.

    Args:
        name: ``'adam'`` or ``'sgd'`` (case-insensitive).
        parameters: iterable of model parameters to optimize.
        lr: learning rate.

    Raises:
        ValueError: for any other optimizer name.
    """
    key = str(name).lower()
    if key == 'adam':
        return optim.Adam(parameters, lr=lr)
    if key == 'sgd':
        # momentum=0.9 is a common SGD default; it is not swept.
        return optim.SGD(parameters, lr=lr, momentum=0.9)
    raise ValueError(f"Unsupported optimizer {name!r}; expected 'adam' or 'sgd'")


def _run_epoch(model, loader, criterion, device, optimizer=None, progress=False):
    """Run ``model`` once over every batch of ``loader``.

    With an ``optimizer`` the model is put in train mode and updated after each
    batch; without one it is put in eval mode and gradients are disabled.

    Returns:
        ``(mean_loss, accuracy_percent, labels, predictions)``; ``labels`` and
        ``predictions`` are flat lists of per-sample class indices.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)
    batches = tqdm.tqdm(loader) if progress else loader

    loss_sum = 0.0
    num_batches = 0
    correct = 0
    all_labels = []
    all_predictions = []
    with torch.set_grad_enabled(training):
        for images, labels in batches:
            images, labels = images.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            all_labels.extend(labels.cpu().tolist())
            all_predictions.extend(predicted.cpu().tolist())
            loss_sum += loss.item()
            num_batches += 1

    if not all_labels:
        raise ValueError("DataLoader produced no samples")
    return loss_sum / num_batches, 100 * correct / len(all_labels), all_labels, all_predictions


def _class_report(labels, predictions, as_dict=True):
    """sklearn classification report that always covers both classes.

    ``labels=[0, 1]`` keeps ``target_names`` aligned even when a split contains
    or predicts only one class; ``zero_division=0`` silences the undefined-metric
    warning in that case.
    """
    return classification_report(labels, predictions, labels=[0, 1], target_names=_CLASS_NAMES,
                                 output_dict=as_dict, zero_division=0)


def train():
    """Train and evaluate one EnhancedCNN for a single wandb sweep run.

    Hyperparameters come from ``wandb.config``: ``batch_size``, ``learning_rate``
    and ``epochs`` (required); ``optimizer`` (``'adam'``/``'sgd'``, default
    ``'adam'``), ``num_workers`` (default 4) and ``split_seed`` (default 42).

    ``chest_xray/train`` is split 80/20 into train/validation. Training uses early
    stopping on validation loss (patience 7). The weights of the best validation
    epoch are restored before the model is evaluated on ``chest_xray/test``.
    Per-epoch and final test metrics are printed and logged to wandb.
    """
    wandb.init()
    config = wandb.config
    num_workers = config.get('num_workers', 4)
    split_seed = config.get('split_seed', 42)
    patience = 7

    # A seeded split gives every sweep run the same validation images (as long as the
    # dataset lists its files in a stable order), keeping runs comparable.
    full_dataset = CustomDataset(data_dir='chest_xray/train')
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_subset, val_subset = random_split(
        full_dataset, [train_size, val_size],
        generator=torch.Generator().manual_seed(split_seed),
    )

    train_transform, eval_transform = _build_transforms()
    train_dataset = TransformedDataset(full_dataset, train_subset.indices, transform=train_transform)
    val_dataset = TransformedDataset(full_dataset, val_subset.indices, transform=eval_transform)
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=num_workers)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = EnhancedCNN().to(device)
    # Honour the sweep's optimizer choice instead of always using Adam.
    optimizer = _build_optimizer(config.get('optimizer', 'adam'), model.parameters(), config.learning_rate)
    criterion = nn.CrossEntropyLoss()

    best_val_loss = float('inf')
    best_state = None
    early_stop_counter = 0

    for epoch in range(config.epochs):
        train_loss, train_accuracy, train_labels, train_predictions = _run_epoch(
            model, train_loader, criterion, device, optimizer=optimizer, progress=True)
        train_recall = _class_report(train_labels, train_predictions)['weighted avg']['recall']
        print(f'Epoch [{epoch+1}/{config.epochs}] - Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, recall: {train_recall:.4f}')

        val_loss, val_accuracy, val_labels, val_predictions = _run_epoch(model, val_loader, criterion, device)
        val_report = _class_report(val_labels, val_predictions)
        val_recall = val_report['weighted avg']['recall']
        print(f'Validation Loss: {val_loss:.4f}, Valid Accuracy: {val_accuracy:.2f}%, recall: {val_recall:.4f}')

        # A single log call per epoch keeps train and validation metrics on the same step.
        # Support-weighted recall is mathematically equal to accuracy, so the PNEUMONIA
        # recall (sensitivity) is logged as well.
        wandb.log({
            "epoch": epoch + 1,
            "Train Loss": train_loss,
            "Train Accuracy": train_accuracy,
            "Train Recall": train_recall,
            "Validation Loss": val_loss,
            "Validation Accuracy": val_accuracy,
            "Validation Recall": val_recall,
            "Validation PNEUMONIA Recall": val_report['PNEUMONIA']['recall'],
        })

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # CPU snapshot so the best weights survive the (worse) epochs that follow.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            early_stop_counter = 0
        else:
            early_stop_counter += 1
            if early_stop_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    # Evaluate the best-validation weights rather than the last epoch's.
    if best_state is not None:
        model.load_state_dict(best_state)

    test_base = CustomDataset(data_dir='chest_xray/test')
    test_dataset = TransformedDataset(test_base, range(len(test_base)), transform=eval_transform)
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False, num_workers=num_workers)
    _, accuracy, all_labels, all_predictions = _run_epoch(model, test_loader, criterion, device)

    report = _class_report(all_labels, all_predictions)
    weighted_avg_recall = report['weighted avg']['recall']
    print("Classification Report:")
    print(_class_report(all_labels, all_predictions, as_dict=False))
    print(f"Weighted Average Recall: {weighted_avg_recall:.4f}")
    print(f'Accuracy: {accuracy:.4f}%')

    wandb.log({
        "Test Accuracy": accuracy,
        "Test Weighted Avg Recall": weighted_avg_recall,
        "Test PNEUMONIA Recall": report['PNEUMONIA']['recall'],
    })



In [None]:
# wandb sweep configuration.
# The metric name must be a key that train() passes to wandb.log ("Validation Accuracy");
# both the Bayesian search and Hyperband early termination read it, and an unlogged
# name leaves them with nothing to optimise.
sweep_config = {
    'method': 'bayes',  # search strategy: grid, random or bayes
    'metric': {'name': 'Validation Accuracy', 'goal': 'maximize'},
    'parameters': {
        # The range spans two orders of magnitude, so sample it log-uniformly;
        # uniform sampling would almost never try values near 1e-4.
        'learning_rate': {'distribution': 'log_uniform_values', 'min': 0.0001, 'max': 0.01},
        'batch_size': {'values': [8, 16, 32]},  # batch size candidates
        'epochs': {'values': [10, 20, 30]},  # epoch count candidates
        'optimizer': {'values': ['adam', 'sgd']},  # optimizer candidates
    },
    # Early termination of unpromising runs
    'early_terminate': {
        'type': 'hyperband',
        'min_iter': 5,  # minimum number of logged iterations before a run can be stopped
    },
    'name': 'my-optimized-sweep-experiment',
}

# Register the sweep with wandb
sweep_id = wandb.sweep(sweep_config, project="pneumonia-detection")

In [None]:
# Run the sweep agent. Bayesian (and random) sweeps never run out of candidates, so
# without `count` the agent -- and therefore this cell -- keeps launching new runs forever.
SWEEP_RUN_COUNT = 20  # number of training runs this agent executes; adjust to the compute budget
wandb.agent(sweep_id, function=train, count=SWEEP_RUN_COUNT)