<a href="https://colab.research.google.com/github/dansojo/Medical_CV/blob/main/mask%EC%83%9D%EC%84%B1(EfficientNetUNet).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import cv2
import numpy as np
import pandas as pd
import torch
import timm
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt

In [None]:
class Config:
    """Namespace of module-wide constants: dataset/output paths and training hyperparameters."""

    # Input locations (Google Drive paths as mounted in Colab).
    DATA_DIR = "/content/drive/MyDrive/Medical_CV/피부암 분류 및 Segmentation/main(A,B)/A(train)/images"
    MASKS_DIR = "/content/drive/MyDrive/Medical_CV/피부암 분류 및 Segmentation/main(A,B)/A(train)/masks"
    # NOTE(review): despite the _DIR suffix this value is passed to pd.read_csv via split_data — confirm it resolves to a file.
    METADATA_DIR = "/content/drive/MyDrive/Medical_CV/피부암 분류 및 Segmentation/HAM10000_metadata"
    # Output locations: model weights/checkpoints and generated masks.
    # NOTE(review): this path contains a doubled slash after "Medical_CV" — presumably harmless on POSIX; confirm.
    SAVE_MODEL_DIR = "/content/drive/MyDrive/Medical_CV//피부암 분류 및 Segmentation/part3_datasets"
    SAVE_MASKS_DIR = "/content/drive/MyDrive/Medical_CV/피부암 분류 및 Segmentation/main(A,B)/B(test)/mask이미지(Segmentation_EfficientNetUNet)"
    # Images for which masks are generated at inference time.
    TEST_DIR = "/content/drive/MyDrive/Medical_CV/피부암 분류 및 Segmentation/main(A,B)/B(test)/images"
    BATCH_SIZE = 16
    IMAGE_SIZE = (224, 224)  # Size passed to T.Resize for both images and masks
    NUM_CLASSES = 1  # Binary Segmentation
    EPOCHS = 20
    LR = 1e-4
    # Use the GPU when one is visible to PyTorch, otherwise fall back to CPU.
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
class ImageDataset(Dataset):
    """Dataset yielding ``(image, mask)`` pairs for binary segmentation.

    Each row of ``metadata`` must provide an ``image_id``; the image is read
    from ``<image_dir>/<image_id>.jpg`` and its mask from
    ``<mask_dir>/<image_id>_segmentation.png``.

    Args:
        image_dir: Directory containing the ``.jpg`` images.
        mask_dir: Directory containing the ``*_segmentation.png`` masks.
        metadata: pandas DataFrame with an ``image_id`` column.
        image_transform: Optional callable applied to the RGB PIL image.
        mask_transform: Optional callable applied to the mask, which is
            handed over as an 8-bit (mode ``"L"``) PIL image with values
            0 and 255.
    """

    def __init__(self, image_dir, mask_dir, metadata, image_transform=None, mask_transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.metadata = metadata
        self.image_transform = image_transform
        self.mask_transform = mask_transform

    def __len__(self):
        """Return the number of samples (rows in ``metadata``)."""
        return len(self.metadata)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, mask)``.

        The mask holds only the values 0.0 and 1.0. Tensor masks produced by
        ``mask_transform`` are re-binarized; 2-D masks get a leading channel
        dimension.
        """
        # Resolve file paths from the image id
        img_name = self.metadata.iloc[idx]['image_id']
        image_path = os.path.join(self.image_dir, img_name + ".jpg")
        mask_path = os.path.join(self.mask_dir, img_name + "_segmentation.png")

        # Load both files; the context managers close the file handles promptly
        with Image.open(image_path) as image_file:
            image = image_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            raw_mask = np.array(mask_file.convert("L"))  # grayscale

        # Binarize: pixels brighter than 128 are foreground
        foreground = raw_mask > 128

        if self.image_transform is not None:
            image = self.image_transform(image)

        if self.mask_transform is not None:
            # Hand the transform a uint8 image. ToTensor only rescales 8-bit
            # images to [0, 1]; a float (mode "F") image would keep the raw
            # 0/255 values and break BCE loss and the Dice score.
            mask_image = Image.fromarray(foreground.astype(np.uint8) * 255)
            mask = self.mask_transform(mask_image)
            if torch.is_tensor(mask):
                # Resizing may interpolate at the borders; snap back to {0, 1}
                mask = (mask > 0.5).to(torch.float32)
        else:
            mask = foreground.astype(np.float32)

        # Add the channel dimension when the mask is still 2-D
        if len(mask.shape) == 2:
            mask = torch.unsqueeze(torch.as_tensor(mask), dim=0)

        return image, mask

In [None]:
def split_data(metadata_path, image_dir):
    """Read the metadata CSV and split it into train/val/test frames (70/15/15).

    Only rows whose ``image_id`` matches a file name in ``image_dir`` (the part
    before the first dot) are kept. Both splits are stratified on the ``dx``
    column and use a fixed random seed.

    Returns:
        Tuple ``(train_data, val_data, test_data)`` of DataFrames.
    """
    records = pd.read_csv(metadata_path)

    # Keep only the rows that have a matching file on disk
    ids_on_disk = {file_name.partition('.')[0] for file_name in os.listdir(image_dir)}
    records = records.loc[records['image_id'].isin(ids_on_disk)]

    # First carve off 30 %, then halve that remainder into validation and test
    train_data, holdout = train_test_split(
        records,
        test_size=0.3,
        random_state=42,
        stratify=records['dx'],
    )
    val_data, test_data = train_test_split(
        holdout,
        test_size=0.5,
        random_state=42,
        stratify=holdout['dx'],
    )

    return train_data, val_data, test_data

In [None]:
def get_data_transforms():
    """Build the preprocessing pipelines for images and masks.

    Returns:
        Tuple ``(image_transform, mask_transform)``:

        * ``image_transform`` resizes to ``Config.IMAGE_SIZE``, converts to a
          tensor and normalizes with mean (0.485, 0.456, 0.406) and
          std (0.229, 0.224, 0.225).
        * ``mask_transform`` resizes to ``Config.IMAGE_SIZE`` with
          nearest-neighbour interpolation and converts to a tensor, without
          normalization.
    """
    # Image-only pipeline (includes normalization)
    image_transform = T.Compose([
        T.Resize(Config.IMAGE_SIZE),
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    ])

    # Mask-only pipeline (no normalization). Nearest-neighbour resizing keeps
    # label values intact; the default bilinear mode would blend foreground
    # and background along lesion borders.
    mask_transform = T.Compose([
        T.Resize(Config.IMAGE_SIZE, interpolation=T.InterpolationMode.NEAREST),
        # NOTE: ToTensor rescales to [0, 1] only for 8-bit inputs, so the mask
        # passed in should be a uint8 image.
        T.ToTensor()
    ])

    return image_transform, mask_transform

In [None]:
class EfficientNetUNet(nn.Module):
    """U-Net style segmentation network on top of a timm feature encoder.

    The encoder is created with ``features_only=True`` and its last five
    feature maps are consumed. Four decoder blocks each fuse the running
    decoder tensor with one encoder feature map and upsample by 2. A 1x1
    convolution then produces ``output_channels`` logit maps, which are
    resized to the spatial size of the input.

    Args:
        input_channels: Number of channels of the input image.
        output_channels: Number of output logit maps.
        encoder_name: timm model name used for the encoder.
    """

    def __init__(self, input_channels=3, output_channels=1, encoder_name='efficientnet_b0'):
        super().__init__()

        # Pretrained encoder that returns a list of intermediate feature maps
        self.encoder = timm.create_model(
            encoder_name,
            pretrained=True,
            features_only=True,
            in_chans=input_channels
        )

        # Channel count of every feature map returned by the encoder
        encoder_channels = self.encoder.feature_info.channels()

        # Decoder blocks, from the deepest feature map to the shallowest
        self.decoder4 = self._decoder_block(encoder_channels[-1], encoder_channels[-2], 512)
        self.decoder3 = self._decoder_block(512, encoder_channels[-3], 256)
        self.decoder2 = self._decoder_block(256, encoder_channels[-4], 128)
        self.decoder1 = self._decoder_block(128, encoder_channels[-5], 64)

        # Final 1x1 projection to the requested number of logit maps
        self.final_conv = nn.Conv2d(64, output_channels, kernel_size=1)

    def _decoder_block(self, in_channels, skip_channels, out_channels):
        """Return two Conv-BN-ReLU stages followed by a stride-2 transposed convolution.

        The block expects ``in_channels + skip_channels`` input channels (the
        decoder tensor concatenated with a skip feature map) and doubles the
        spatial resolution.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels + skip_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(out_channels, out_channels, kernel_size=2, stride=2)
        )

    def forward(self, x):
        """Return logits with the same spatial size as ``x``."""
        # Remember the input resolution so the output matches it for any
        # image size, not just 224x224.
        input_size = x.shape[2:]

        features = self.encoder(x)

        # Start from the deepest feature map and fuse one skip per decoder.
        # NOTE(review): each skip map is resized to the decoder tensor's
        # current resolution before concatenation, as in the original design.
        x = features[-1]
        stages = (
            (self.decoder4, features[-2]),
            (self.decoder3, features[-3]),
            (self.decoder2, features[-4]),
            (self.decoder1, features[-5]),
        )
        for decoder, skip in stages:
            skip = F.interpolate(skip, size=x.shape[2:], mode='bilinear', align_corners=False)
            x = decoder(torch.cat([x, skip], dim=1))

        # Project to logits and restore the input resolution
        x = self.final_conv(x)
        x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=False)
        return x

In [None]:
def _run_epoch(model, loader, criterion, device, desc, optimizer=None, metric_prefix=''):
    """Run one pass over ``loader``; train when ``optimizer`` is given, else evaluate.

    Returns:
        Tuple ``(mean_loss, dice)``: the loss averaged over batches and the
        Dice coefficient accumulated over the whole pass.
    """
    is_training = optimizer is not None
    model.train(is_training)

    loss_sum = 0.0
    intersection_sum = 0.0
    cardinality_sum = 0.0
    batch_count = 0

    progress_bar = tqdm(loader, desc=desc)
    with torch.set_grad_enabled(is_training):
        for images, masks in progress_bar:
            images, masks = images.to(device), masks.to(device)

            if is_training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            if is_training:
                loss.backward()
                optimizer.step()

            # Logit > 0 is equivalent to sigmoid probability > 0.5
            predicted_masks = (outputs > 0).float()
            intersection = (predicted_masks * masks).sum().item()
            cardinality = (predicted_masks.sum() + masks.sum()).item()
            batch_loss = loss.item()

            loss_sum += batch_loss
            intersection_sum += intersection
            cardinality_sum += cardinality
            batch_count += 1

            progress_bar.set_postfix({
                f'{metric_prefix}loss': f'{batch_loss:.4f}',
                f'{metric_prefix}dice': f'{2 * intersection / (cardinality + 1e-8):.4f}'
            })

    # max(..., 1) avoids a ZeroDivisionError when the loader is empty
    mean_loss = loss_sum / max(batch_count, 1)
    dice = (2 * intersection_sum) / (cardinality_sum + 1e-8)
    return mean_loss, dice


def train_model(model, train_loader, val_loader, device, num_epochs=20, learning_rate=1e-4, patience=5):
    """
    Train ``model`` with Adam and BCEWithLogitsLoss, validating after every epoch.

    The learning rate is halved by ReduceLROnPlateau when the validation loss
    stops improving, training stops early after ``patience`` epochs without
    improvement, and a checkpoint is written to ``Config.SAVE_MODEL_DIR``
    every 5 epochs.

    Args:
        model: Model to train; must output logits shaped like the masks.
        train_loader: Iterable of ``(images, masks)`` training batches.
        val_loader: Iterable of ``(images, masks)`` validation batches.
        device: Device to train on (cuda/cpu).
        num_epochs: Maximum number of epochs.
        learning_rate: Initial learning rate for Adam.
        patience: Early-stopping patience, in epochs.

    Returns:
        model: The model with the weights of the LAST epoch that ran.
        best_model_state: Independent copy of the state dict with the lowest
            validation loss, or ``None`` if no epoch improved on infinity.
        history: Dict with per-epoch ``train_loss``, ``val_loss``,
            ``train_dice`` and ``val_dice`` lists.
    """
    import copy  # local import keeps this block self-contained

    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.BCEWithLogitsLoss()
    # The deprecated ``verbose`` flag is not passed; learning-rate changes are
    # reported explicitly after ``scheduler.step`` below.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)

    # Training bookkeeping
    best_val_loss = float('inf')
    best_model_state = None
    early_stopping_counter = 0
    history = {
        'train_loss': [],
        'val_loss': [],
        'train_dice': [],
        'val_dice': []
    }

    print("Starting training process...")
    print(f"Training on device: {device}")

    for epoch in range(num_epochs):
        avg_train_loss, avg_train_dice = _run_epoch(
            model, train_loader, criterion, device,
            desc=f'Epoch {epoch+1}/{num_epochs} [Training]',
            optimizer=optimizer,
        )
        avg_val_loss, avg_val_dice = _run_epoch(
            model, val_loader, criterion, device,
            desc=f'Epoch {epoch+1}/{num_epochs} [Validation]',
            metric_prefix='val_',
        )

        # Record the epoch metrics
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['train_dice'].append(avg_train_dice)
        history['val_dice'].append(avg_val_dice)

        print(f"\nEpoch {epoch+1}/{num_epochs}:")
        print(f"Train Loss: {avg_train_loss:.4f}, Train Dice: {avg_train_dice:.4f}")
        print(f"Val Loss: {avg_val_loss:.4f}, Val Dice: {avg_val_dice:.4f}")

        # Let the scheduler react to the validation loss and report any change
        previous_lr = optimizer.param_groups[0]['lr']
        scheduler.step(avg_val_loss)
        current_lr = optimizer.param_groups[0]['lr']
        if current_lr != previous_lr:
            print(f"Learning rate reduced from {previous_lr:.2e} to {current_lr:.2e}")

        # Track the best model so far
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # state_dict() returns references to the live parameters, so a
            # deep copy is needed; otherwise later epochs overwrite the "best".
            best_model_state = copy.deepcopy(model.state_dict())
            early_stopping_counter = 0
            print(f"Best model updated at epoch {epoch+1}")
        else:
            early_stopping_counter += 1

        # Early stopping check
        if early_stopping_counter >= patience:
            print(f"\nEarly stopping triggered after {epoch+1} epochs")
            break

        # Periodic checkpoint (every 5 epochs)
        if (epoch + 1) % 5 == 0:
            # Make sure the target directory exists before writing
            os.makedirs(Config.SAVE_MODEL_DIR, exist_ok=True)
            checkpoint_path = os.path.join(
                Config.SAVE_MODEL_DIR,
                f'checkpoint_epoch_{epoch+1}.pth'
            )
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': avg_train_loss,
                'val_loss': avg_val_loss,
                'train_dice': avg_train_dice,
                'val_dice': avg_val_dice,
                'history': history
            }, checkpoint_path)
            print(f"Checkpoint saved at epoch {epoch+1}")

    print("Training completed!")
    return model, best_model_state, history


In [None]:
def save_model(save_dir, model_state, file_name="best_model.pth"):
    """Write ``model_state`` to ``save_dir/file_name`` with ``torch.save``.

    ``save_dir`` is created first if it does not exist, and the destination
    path is printed once the file has been written.
    """
    os.makedirs(save_dir, exist_ok=True)
    destination = os.path.join(save_dir, file_name)
    torch.save(model_state, destination)
    print(f"Model saved to {destination}")

In [None]:
# 1. Split the metadata into train / validation / test frames
train_data, val_data, test_data = split_data(Config.METADATA_DIR, Config.DATA_DIR)
full_data = pd.concat([train_data, val_data, test_data])

# 2. Build datasets and loaders that share the same transforms
image_transform, mask_transform = get_data_transforms()


def _build_dataset(frame):
    """Wrap a metadata frame in an ImageDataset using the shared transforms."""
    return ImageDataset(Config.DATA_DIR, Config.MASKS_DIR, frame, image_transform, mask_transform)


def _build_loader(dataset, shuffle):
    """Create a DataLoader with the configured batch size."""
    return DataLoader(dataset, batch_size=Config.BATCH_SIZE, shuffle=shuffle)


train_dataset = _build_dataset(train_data)
val_dataset = _build_dataset(val_data)
test_dataset = _build_dataset(test_data)
full_dataset = _build_dataset(full_data)

# Only the training loader shuffles its samples
train_loader = _build_loader(train_dataset, shuffle=True)
val_loader = _build_loader(val_dataset, shuffle=False)
test_loader = _build_loader(test_dataset, shuffle=False)
full_loader = _build_loader(full_dataset, shuffle=False)

In [None]:
# Sanity check: print the tensor shapes of the first training batch only
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    images, masks = first_batch
    print(f"Image batch shape: {images.shape}")  # image batch shape
    print(f"Mask batch shape: {masks.shape}")    # mask batch shape

Image batch shape: torch.Size([16, 3, 224, 224])
Mask batch shape: torch.Size([16, 1, 224, 224])


In [None]:
# IPython magic: sets the CUDA_LAUNCH_BLOCKING environment variable to 1 for this session.
# NOTE(review): presumably enabled to localize CUDA errors while debugging — confirm it is
# still needed, since it may slow training.
%env CUDA_LAUNCH_BLOCKING=1

env: CUDA_LAUNCH_BLOCKING=1


In [None]:
# 3. Train the model, taking the hyperparameters from Config instead of
#    relying on train_model's defaults
model = EfficientNetUNet(input_channels=3, output_channels=Config.NUM_CLASSES)
trained_model, best_model_state, history = train_model(
    model,
    train_loader,
    val_loader,
    Config.DEVICE,
    num_epochs=Config.EPOCHS,
    learning_rate=Config.LR,
)

# Save the best weights. save_model creates the target directory if needed.
save_path = os.path.join(Config.SAVE_MODEL_DIR, "best_model_Atten.pth")
if best_model_state is None:
    # Happens when no epoch improved the validation loss (e.g. NaN losses)
    print("No best model state was recorded; nothing saved.")
else:
    save_model(Config.SAVE_MODEL_DIR, best_model_state, file_name="best_model_Atten.pth")



Starting training process...
Training on device: cuda


Epoch 1/20 [Training]: 100%|██████████| 395/395 [1:40:33<00:00, 15.27s/it, loss=-2635.4277, dice=1.9693]
Epoch 1/20 [Validation]:  62%|██████▏   | 53/85 [14:30<08:00, 15.01s/it, val_loss=-2439.6760, val_dice=1.9659]

In [None]:
# One figure per metric: (title, y-axis label, train series, validation series)
curve_specs = (
    ('Train vs Validation Loss', 'Loss',
     ('train_loss', 'Train Loss'), ('val_loss', 'Val Loss')),
    ('Train vs Validation Dice Coefficient', 'Dice Coefficient',
     ('train_dice', 'Train Dice'), ('val_dice', 'Val Dice')),
)

for title, y_label, *series in curve_specs:
    plt.figure(figsize=(10, 5))
    # Draw the training curve first, then the validation curve
    for history_key, legend_label in series:
        plt.plot(history[history_key], label=legend_label)
    plt.xlabel('Epochs')
    plt.ylabel(y_label)
    plt.title(title)
    plt.legend()
    plt.show()

In [None]:
class TestImageDataset(Dataset):
    """Dataset over the ``.jpg`` files of one directory, for inference without masks.

    Items are ``(image, image_id)`` where ``image_id`` is the file name up to
    its first dot. Files are served in sorted file-name order.
    """

    def __init__(self, image_dir, transform=None):
        jpg_names = [name for name in os.listdir(image_dir) if name.endswith('.jpg')]
        jpg_names.sort()

        self.image_dir = image_dir
        self.image_files = jpg_names
        self.transform = transform

    def __len__(self):
        """Return the number of ``.jpg`` files found."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply the optional transform, and return it with its id."""
        file_name = self.image_files[idx]
        image = Image.open(os.path.join(self.image_dir, file_name)).convert("RGB")
        image_id = file_name.partition('.')[0]

        if self.transform:
            image = self.transform(image)

        return image, image_id

In [None]:
def generate_and_save_masks(model, data_loader, save_dir, device):
    """Predict a binary mask for every image in ``data_loader`` and write it as a PNG.

    Args:
        model: Segmentation model returning logits; it is put in eval mode.
        data_loader: Iterable of ``(images, image_ids)`` batches.
        save_dir: Output directory (created if missing). Each mask is written
            to ``<save_dir>/<image_id>_generated.png`` with values 0 and 255.
        device: Device the image batches are moved to before inference.
    """
    model.eval()
    os.makedirs(save_dir, exist_ok=True)

    with torch.no_grad():
        for images, image_ids in tqdm(data_loader, desc="Generating masks"):
            logits = model(images.to(device))
            # Apply sigmoid, then threshold at 0.5 to get a boolean mask
            foreground = torch.sigmoid(logits) > 0.5

            for img_id, pred in zip(image_ids, foreground):
                # Map True/False to 255/0 as an 8-bit image
                pixels = np.where(pred.squeeze().cpu().numpy(), 255, 0).astype(np.uint8)
                cv2.imwrite(os.path.join(save_dir, f"{img_id}_generated.png"), pixels)

    print(f"Masks saved to {save_dir}")

In [None]:
# Directory with the images to segment
test_image_dir = Config.TEST_DIR

# Preprocessing identical to the training image pipeline
preprocessing_steps = [
    T.Resize(Config.IMAGE_SIZE),
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
]
test_transform = T.Compose(preprocessing_steps)

# Dataset and loader over the test images.
# NOTE: this rebinds test_dataset / test_loader from the earlier held-out split.
test_dataset = TestImageDataset(test_image_dir, transform=test_transform)
loader_options = {
    "batch_size": Config.BATCH_SIZE,
    "shuffle": False,
    "num_workers": 2,
    "pin_memory": True,
}
test_loader = DataLoader(test_dataset, **loader_options)

# Generate and save masks with the trained model.
# NOTE(review): trained_model holds the last-epoch weights, not necessarily the best ones — confirm this is intended.
generate_and_save_masks(trained_model, test_loader, Config.SAVE_MASKS_DIR, Config.DEVICE)