11일차 실습: U-Net으로 의료영상 분할

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.utils import save_image

import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
import os
from sklearn.metrics import jaccard_score
import time

# Runtime setup: compute device and Windows multiprocessing.
import os
import sys

# Training is pinned to the CPU on purpose (the "safe mode" reported below).
# The device is chosen once, so every message reports the device that is
# really used instead of announcing CUDA and then discarding it.
device = torch.device('cpu')
print(f"Using device: {device}")

if sys.platform.startswith('win'):
    import multiprocessing
    # Force the "spawn" start method for worker processes on Windows.
    multiprocessing.set_start_method('spawn', force=True)
    # Only report the Windows setup when it actually ran.
    print("Windows 환경 설정 완료")

print(f"Device: {device} (Windows 안전 모드)")

1. U-Net 모델 구현

In [None]:
class DoubleConv(nn.Module):
    """Basic U-Net block: (Conv3x3 -> BatchNorm -> ReLU) applied twice.

    Both convolutions use a 3x3 kernel with ``padding=1``, so the spatial
    size of the input is preserved and only the channel count changes.

    Args:
        in_channels: Channel count of the input feature map.
        out_channels: Channel count produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # The first stage maps in_channels -> out_channels, the second keeps
        # out_channels. Layers are created in execution order.
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv stages and return the result."""
        return self.double_conv(x)

class Down(nn.Module):
    """Encoder stage: 2x2 max pooling followed by a DoubleConv.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by the DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Downsample ``x`` and apply the convolution block."""
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Decoder stage: upsample -> pad to the skip size -> concat -> DoubleConv.

    Args:
        in_channels: Channel count the DoubleConv receives, i.e. of the
            concatenation ``[x2, upsampled x1]``.
        out_channels: Channel count produced by the DoubleConv.
        bilinear: If true, upsample with bilinear interpolation (channel
            count unchanged); otherwise use a stride-2 transposed
            convolution that halves ``in_channels``.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        # The upsampler is created before the conv block so modules are
        # registered (and initialised) in the same order as they are applied.
        self.up = (
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            if bilinear
            else nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        )
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Merge the deeper feature map ``x1`` with the skip feature map ``x2``.

        Args:
            x1: Feature map coming up from the previous (deeper) layer.
            x2: Feature map arriving over the skip connection.
        """
        x1 = self.up(x1)

        # When the input size is not a multiple of two, the upsampled map can
        # differ from the skip map; split the gap across both sides, with the
        # extra pixel going to the right / bottom.
        gap_h = x2.shape[2] - x1.shape[2]
        gap_w = x2.shape[3] - x1.shape[3]
        left, top = gap_w // 2, gap_h // 2
        x1 = F.pad(x1, [left, gap_w - left, top, gap_h - top])

        # Skip connection: concatenate along the channel dimension.
        return self.conv(torch.cat([x2, x1], dim=1))

class UNet(nn.Module):
    """Complete U-Net: four-level encoder, bottleneck, four-level decoder.

    Args:
        n_channels: Channel count of the input image.
        n_classes: Channel count of the output (one logit map per class).
        bilinear: Passed to every ``Up`` stage. When true, the bottleneck and
            the first three decoder outputs are built half as wide.
    """

    def __init__(self, n_channels=1, n_classes=2, bilinear=True):
        super().__init__()
        self.n_channels, self.n_classes, self.bilinear = n_channels, n_classes, bilinear

        factor = 2 if bilinear else 1

        # Encoder (contracting path).
        self.inc = DoubleConv(n_channels, 64)
        encoder_widths = [(64, 128), (128, 256), (256, 512), (512, 1024 // factor)]
        for level, (c_in, c_out) in enumerate(encoder_widths, start=1):
            setattr(self, f'down{level}', Down(c_in, c_out))

        # Decoder (expansive path).
        decoder_widths = [
            (1024, 512 // factor),
            (512, 256 // factor),
            (256, 128 // factor),
            (128, 64),
        ]
        for level, (c_in, c_out) in enumerate(decoder_widths, start=1):
            setattr(self, f'up{level}', Up(c_in, c_out, bilinear))

        # Output layer: 1x1 convolution down to one channel per class.
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        """Return the raw (un-normalised) class logits for ``x``."""
        # Encoder: keep every intermediate feature map for the skip links.
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            skips.append(down(skips[-1]))

        # Decoder: start from the bottleneck and consume the skips deepest-first.
        x = skips.pop()
        for up in (self.up1, self.up2, self.up3, self.up4):
            x = up(x, skips.pop())

        return self.outc(x)

2. 의료영상 데이터셋 클래스

In [None]:
class LungCTDataset(Dataset):
    """Synthetic stand-in for a lung CT segmentation dataset.

    No files are read: ``image_dir`` and ``mask_dir`` are only stored. Every
    item is generated on the fly as a 512x512 noisy image containing two
    bright discs ("lungs") together with the matching binary mask. The index
    only goes through a range check, so the same index yields a different
    random image on every access.

    Args:
        image_dir: Image directory (stored, not read).
        mask_dir: Mask directory (stored, not read).
        transform: Optional callable applied to the PIL image only.
        is_train: Selects the virtual dataset size (1000 if true, else 200).
    """

    # Side length of the generated square images and masks.
    IMAGE_SIZE = 512
    # One (center_x, center_y, radius) entry per lung disc: left, then right.
    LUNG_DISCS = ((180, 200, 80), (330, 200, 75))

    def __init__(self, image_dir, mask_dir, transform=None, is_train=True):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.is_train = is_train

        # A real implementation would read a file list here; this class only
        # pretends to have a fixed number of samples.
        self.length = 1000 if is_train else 200

        # Built once instead of on every item. Nearest-neighbour resizing is
        # used so that resizing a mask never invents intermediate label values.
        self.mask_transform = transforms.Compose([
            transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.NEAREST),
            transforms.ToTensor()
        ])

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Generate one ``(image, mask)`` pair.

        Returns:
            tuple: the image (output of ``transform``, or a PIL image when no
            transform is set) and the mask as an ``[H, W]`` long tensor
            holding 0 or 1.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        # Without this check, plain iteration (``for item in dataset``) would
        # never stop, because it relies on IndexError to end.
        if not -self.length <= idx < self.length:
            raise IndexError(f"index {idx} out of range for dataset of size {self.length}")

        image_array = self.generate_fake_ct_image()
        mask_array = self.generate_fake_lung_mask()

        image = Image.fromarray((image_array * 255).astype(np.uint8))
        mask = Image.fromarray((mask_array * 255).astype(np.uint8))

        if self.transform is not None:
            image = self.transform(image)

        # The mask is always converted, independently of ``transform``.
        # It used to be converted only when a transform was set, so with
        # ``transform=None`` the comparison below ran on a PIL image and failed.
        mask = self.mask_transform(mask)

        # Force the mask to exactly 0 / 1 and drop the channel dimension.
        mask = (mask > 0.5).long().squeeze(0)

        return image, mask

    def _lung_discs(self):
        """Return one boolean ``(IMAGE_SIZE, IMAGE_SIZE)`` array per lung disc."""
        y, x = np.ogrid[:self.IMAGE_SIZE, :self.IMAGE_SIZE]
        return [
            ((x - cx) ** 2 + (y - cy) ** 2) < radius ** 2
            for cx, cy, radius in self.LUNG_DISCS
        ]

    def generate_fake_ct_image(self):
        """Generate a fake CT image as a float array clipped to [0, 1]."""
        size = self.IMAGE_SIZE
        img = np.random.randn(size, size) * 0.1 + 0.3

        # Brighten each lung separately: the two discs overlap slightly and
        # the overlap is meant to receive the boost twice.
        for disc in self._lung_discs():
            img[disc] += 0.4

        # Extra noise on top of the shapes.
        img += np.random.randn(size, size) * 0.05

        return np.clip(img, 0, 1)

    def generate_fake_lung_mask(self):
        """Generate the lung mask: 1.0 inside either lung disc, 0.0 elsewhere."""
        mask = np.zeros((self.IMAGE_SIZE, self.IMAGE_SIZE))
        for disc in self._lung_discs():
            mask[disc] = 1
        return mask

3. 데이터 전처리 및 로더

In [None]:
def get_transforms():
    """Build the image preprocessing pipelines for training and validation.

    Both pipelines are identical: resize to 512x512, convert to a tensor,
    then normalise with mean 0.5 and std 0.5 (i.e. ``(x - 0.5) / 0.5``).

    Returns:
        tuple: ``(train_transform, val_transform)``, two separate ``Compose``
        objects.
    """
    def build_pipeline():
        steps = [
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ]
        return transforms.Compose(steps)

    train_transform = build_pipeline()
    val_transform = build_pipeline()
    return train_transform, val_transform

def create_data_loaders(batch_size=4, num_workers=None):
    """Create the training and validation data loaders.

    Args:
        batch_size: Batch size used by both loaders.
        num_workers: Worker processes per loader. ``None`` (the default)
            selects 0 on Windows and 2 elsewhere. On Windows, workers are
            started with "spawn" and must re-import the dataset class, which
            presumably fails when it was defined interactively (e.g. in a
            notebook cell) - loading in the main process avoids that.

    Returns:
        tuple: ``(train_loader, val_loader)``. Only the training loader
        shuffles.
    """
    if num_workers is None:
        num_workers = 0 if os.name == 'nt' else 2

    train_transform, val_transform = get_transforms()

    # The directories are placeholders: the dataset only stores them.
    train_dataset = LungCTDataset(
        image_dir="train/images",
        mask_dir="train/masks",
        transform=train_transform,
        is_train=True
    )

    val_dataset = LungCTDataset(
        image_dir="val/images",
        mask_dir="val/masks",
        transform=val_transform,
        is_train=False
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, val_loader

4. 손실함수 및 평가 지표

In [None]:
class DiceLoss(nn.Module):
    """Soft multi-class Dice loss, widely used for medical image segmentation.

    Args:
        smooth: Constant added to numerator and denominator of every
            per-class Dice score.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, predictions, targets):
        """Return ``1 - mean(per-class Dice)`` computed over the whole batch.

        Args:
            predictions: Logits of shape ``[B, C, H, W]``.
            targets: Class indices of shape ``[B, H, W]``. Values outside
                ``[0, C - 1]`` are clamped into range rather than rejected.
        """
        num_classes = predictions.shape[1]
        targets = targets.clamp(0, num_classes - 1)

        # Turn logits into per-class probabilities.
        probs = F.softmax(predictions, dim=1)

        def class_dice(c):
            prob_c = probs[:, c, :, :]          # [B, H, W]
            truth_c = (targets == c).float()    # [B, H, W]
            overlap = (prob_c * truth_c).sum()
            total = prob_c.sum() + truth_c.sum()
            return (2.0 * overlap + self.smooth) / (total + self.smooth)

        per_class = torch.stack([class_dice(c) for c in range(num_classes)])
        return 1.0 - torch.mean(per_class)

class CombinedLoss(nn.Module):
    """Weighted sum of cross-entropy loss and Dice loss.

    Args:
        ce_weight: Multiplier applied to the cross-entropy term.
        dice_weight: Multiplier applied to the Dice term.
    """

    def __init__(self, ce_weight=0.5, dice_weight=0.5):
        super().__init__()
        self.ce_weight, self.dice_weight = ce_weight, dice_weight
        self.ce_loss = nn.CrossEntropyLoss()
        self.dice_loss = DiceLoss()

    def forward(self, predictions, targets):
        """Return ``ce_weight * CE + dice_weight * Dice`` for the batch."""
        weighted_ce = self.ce_weight * self.ce_loss(predictions, targets)
        weighted_dice = self.dice_weight * self.dice_loss(predictions, targets)
        return weighted_ce + weighted_dice

def calculate_iou(pred, target, num_classes=2):
    """Compute the per-class IoU (Intersection over Union).

    Args:
        pred: Tensor of predicted class indices (any shape).
        target: Tensor of ground-truth class indices, same number of
            elements as ``pred``.
        num_classes: Classes ``0 .. num_classes - 1`` are evaluated.

    Returns:
        list[float]: One IoU per class. A class that appears in neither
        ``pred`` nor ``target`` scores 1.0.
    """
    # reshape (not view) so non-contiguous inputs are accepted as well.
    pred = pred.reshape(-1)
    target = target.reshape(-1)

    ious = []
    for c in range(num_classes):
        pred_c = pred == c
        target_c = target == c

        union = (pred_c | target_c).sum().item()
        if union == 0:
            # Class absent on both sides. The value is appended as a plain
            # float; the old code called .item() on it and crashed here.
            ious.append(1.0)
            continue

        intersection = (pred_c & target_c).sum().item()
        # Integer counts divided in Python, so the ratio is exact to double
        # precision.
        ious.append(intersection / union)

    return ious

5. 학습 및 검증 함수

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Network to optimise; switched to training mode.
        dataloader: Iterable of ``(images, masks)`` batches that supports ``len``.
        criterion: Loss called as ``criterion(outputs, masks)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        float: Mean of the per-batch loss values.
    """
    model.train()
    running_loss = 0
    steps_done = 0

    for step, (images, masks) in enumerate(dataloader):
        images, masks = images.to(device), masks.to(device)

        # Guard against stray label values: restrict masks to 0 / 1.
        masks = masks.clamp(0, 1)

        optimizer.zero_grad()
        loss = criterion(model(images), masks)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        running_loss += loss_value
        steps_done += 1

        if step % 50 != 0:
            continue

        # Progress report plus debugging info, every 50th batch.
        print(f'Batch {step}/{len(dataloader)}, Loss: {loss_value:.4f}')
        print(f'  Image range: [{images.min():.3f}, {images.max():.3f}]')
        print(f'  Mask range: [{masks.min():.0f}, {masks.max():.0f}]')
        print(f'  Mask unique values: {masks.unique()}')

    return running_loss / steps_done

def validate_epoch(model, dataloader, criterion, device):
    """Evaluate ``model`` for one pass over ``dataloader`` without gradients.

    Args:
        model: Network to evaluate; switched to eval mode.
        dataloader: Iterable of ``(images, masks)`` batches.
        criterion: Loss called as ``criterion(outputs, masks)``.
        device: Device the batches are moved to.

    Returns:
        tuple: ``(avg_loss, avg_iou)`` where ``avg_iou`` is
        ``[background, lung]``, each averaged over batches.
    """
    model.eval()
    loss_sum = 0
    iou_sums = [0, 0]  # [background, lung]
    batches_seen = 0

    with torch.no_grad():
        for images, masks in dataloader:
            images, masks = images.to(device), masks.to(device)

            logits = model(images)
            loss = criterion(logits, masks)

            # Hard prediction: most likely class per pixel.
            batch_iou = calculate_iou(logits.argmax(dim=1), masks)
            for cls in (0, 1):
                iou_sums[cls] += batch_iou[cls]

            loss_sum += loss.item()
            batches_seen += 1

    avg_loss = loss_sum / batches_seen
    avg_iou = [total / batches_seen for total in iou_sums]
    return avg_loss, avg_iou

6. 시각화 함수

In [None]:
def visualize_predictions(model, dataloader, device, num_samples=4):
    """Plot image / ground truth / prediction for samples of the first batch.

    Args:
        model: Trained network; switched to eval mode.
        dataloader: Iterable of ``(images, masks)`` batches; only the first
            batch is used.
        device: Device used for inference.
        num_samples: Maximum number of rows to draw. Fewer rows are drawn
            when the batch holds fewer images.
    """
    model.eval()

    first_batch = next(iter(dataloader), None)
    if first_batch is None:
        return  # Empty loader: nothing to draw.
    images, masks = first_batch

    with torch.no_grad():
        outputs = model(images.to(device))
        pred_masks = torch.argmax(outputs, dim=1)

    # Everything is plotted from CPU memory.
    images = images.cpu()
    masks = masks.cpu()
    pred_masks = pred_masks.cpu()

    # One row per sample actually available, so no blank rows are left over.
    rows = min(num_samples, images.shape[0])
    if rows < 1:
        return

    # squeeze=False keeps ``axes`` two-dimensional even for a single row, so
    # the ``axes[i, j]`` indexing below also works when rows == 1.
    fig, axes = plt.subplots(rows, 3, figsize=(12, 4 * rows), squeeze=False)

    for i in range(rows):
        # Original image.
        img = images[i, 0].numpy()
        img = (img + 1) / 2  # [-1,1] -> [0,1]
        axes[i, 0].imshow(img, cmap='gray')
        axes[i, 0].set_title('Original CT Image')
        axes[i, 0].axis('off')

        # Ground-truth mask.
        axes[i, 1].imshow(masks[i].numpy(), cmap='jet', alpha=0.8)
        axes[i, 1].set_title('Ground Truth Mask')
        axes[i, 1].axis('off')

        # Predicted mask.
        axes[i, 2].imshow(pred_masks[i].numpy(), cmap='jet', alpha=0.8)
        axes[i, 2].set_title('Predicted Mask')
        axes[i, 2].axis('off')

    plt.tight_layout()
    plt.show()

def plot_training_history(train_losses, val_losses, val_ious):
    """Plot the loss curves and the lung IoU curve side by side.

    Args:
        train_losses: Training loss per epoch.
        val_losses: Validation loss per epoch.
        val_ious: Per-epoch IoU pairs; element 1 of each pair (lung) is plotted.
    """
    epochs = range(1, len(train_losses) + 1)
    # Only the lung-region IoU is shown.
    lung_ious = [pair[1] for pair in val_ious]

    fig, (loss_ax, iou_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # Curves first, decoration afterwards.
    loss_ax.plot(epochs, train_losses, 'b-', label='Training Loss')
    loss_ax.plot(epochs, val_losses, 'r-', label='Validation Loss')
    iou_ax.plot(epochs, lung_ious, 'g-', label='Lung IoU')

    decorations = (
        (loss_ax, 'Training and Validation Loss', 'Loss'),
        (iou_ax, 'Validation IoU (Lung)', 'IoU'),
    )
    for ax, title, y_label in decorations:
        ax.set_title(title)
        ax.set_xlabel('Epochs')
        ax.set_ylabel(y_label)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.show()

7. 메인 학습 루프

In [None]:
def main():
    """Run the whole exercise: build data and model, train, plot, visualise.

    Trains for a fixed number of epochs, saves a checkpoint whenever the
    validation lung IoU improves, then reloads the best checkpoint and
    visualises its predictions.

    Returns:
        The model with the best checkpoint's weights loaded.
    """
    print("=== U-Net 의료영상 분할 실습 시작 ===")

    # Hyperparameters.
    BATCH_SIZE = 4
    NUM_EPOCHS = 20
    LEARNING_RATE = 1e-4
    CHECKPOINT_PATH = 'best_unet_model.pth'

    # Data loaders.
    print("데이터 로더 생성 중...")
    train_loader, val_loader = create_data_loaders(BATCH_SIZE)

    # Model.
    print("U-Net 모델 생성 중...")
    model = UNet(n_channels=1, n_classes=2, bilinear=True)
    model = model.to(device)

    # Report the parameter count.
    total_params = sum(p.numel() for p in model.parameters())
    print(f"모델 파라미터 수: {total_params:,}")

    # Loss, optimizer, and a scheduler driven by the validation loss.
    criterion = CombinedLoss(ce_weight=0.4, dice_weight=0.6)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5)

    # Training history.
    train_losses = []
    val_losses = []
    val_ious = []

    best_iou = 0.0

    print("학습 시작!")
    for epoch in range(NUM_EPOCHS):
        start_time = time.time()

        # Train.
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device)

        # Validate.
        val_loss, val_iou = validate_epoch(model, val_loader, criterion, device)

        # Learning-rate scheduler.
        scheduler.step(val_loss)

        # Record history.
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        val_ious.append(val_iou)

        # Save the best model. The first epoch is always saved, so the
        # checkpoint loaded after training belongs to this run even when the
        # lung IoU never rises above 0 (previously the load could fail or
        # pick up a stale file from an earlier run).
        if epoch == 0 or val_iou[1] > best_iou:
            best_iou = val_iou[1]
            torch.save(model.state_dict(), CHECKPOINT_PATH)

        epoch_time = time.time() - start_time

        print(f'Epoch [{epoch+1}/{NUM_EPOCHS}]')
        print(f'  Train Loss: {train_loss:.4f}')
        print(f'  Val Loss: {val_loss:.4f}')
        print(f'  Val IoU - Background: {val_iou[0]:.4f}, Lung: {val_iou[1]:.4f}')
        print(f'  Time: {epoch_time:.2f}s')
        print('-' * 50)

    print(f"학습 완료! 최고 Lung IoU: {best_iou:.4f}")

    # Plot the training curves.
    plot_training_history(train_losses, val_losses, val_ious)

    # Visualise predictions of the best checkpoint. map_location keeps the
    # load valid when the checkpoint was written from a different device.
    print("예측 결과 시각화...")
    model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
    visualize_predictions(model, val_loader, device)

    return model

# Script entry point: run the exercise.
if __name__ == "__main__":
    # Select the "spawn" start method on Windows before any worker starts.
    if os.name == 'nt':  # Windows
        import multiprocessing
        multiprocessing.set_start_method('spawn', force=True)

    model = main()

    # Closing summary, printed one line per entry.
    print("\n".join((
        "\n=== 실습 완료 ===",
        "주요 학습 포인트:",
        "1. U-Net의 Skip Connection 효과",
        "2. 의료영상에서 Dice Loss의 중요성",
        "3. IoU를 통한 분할 성능 평가",
        "4. 데이터 증강의 필요성",
    )))