# 🏗️ Training U-Net với EfficientNet Backbone

Notebook này thực hiện training model **U-Net với EfficientNet backbone** cho bài toán phân đoạn khối u da:

## 🎯 Model: U-Net + EfficientNet
- **Architecture**: U-Net với EfficientNet-B0 encoder
- **Backbone**: EfficientNet-B0 (pre-trained trên ImageNet)
- **Parameters**: ~5.3M (số tham số của EfficientNet-B0 gốc; tổng số tham số thực tế của cả model được in ra ở mục 3)
- **Ưu điểm**: Cân bằng tốt giữa hiệu suất và tốc độ, efficient architecture
- **Learning Rate Scheduler**: ReduceLROnPlateau (adaptive)

## 📊 Training Configuration:
- **Epochs**: 25
- **Learning Rate**: 1e-4
- **Batch Size**: 8
- **Optimizer**: Adam với weight decay 1e-4
- **Loss Function**: Combined Loss (BCE + Dice)
- **Metrics**: Dice Coefficient, Jaccard Index (IoU)
- **Scheduler**: ReduceLROnPlateau (patience=3, factor=0.5)

## 1. Import thư viện và setup

In [None]:
# Install the required libraries (notebook shell commands)
!pip install segmentation-models-pytorch timm torch torchvision
!pip install albumentations opencv-python-headless
!pip install matplotlib seaborn scikit-learn pillow tqdm

In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import cv2
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

# Segmentation models
import segmentation_models_pytorch as smp

# Timm for backbone
import timm

# Albumentations for augmentation
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Metrics
from sklearn.metrics import jaccard_score, f1_score

# Select the compute device: CUDA when it is available, CPU otherwise
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
print(f"Using device: {device}")

# Create the "models" directory if it does not exist yet
os.makedirs('models', exist_ok=True)
print("✅ Setup hoàn tất!")

ModuleNotFoundError: No module named 'segmentation_models_pytorch'

## 2. Dataset và Data Augmentation

In [None]:
class ISICDataset(Dataset):
    """Image / binary-mask pairs read from two directories with OpenCV.

    Every ``.jpg`` / ``.jpeg`` / ``.png`` file in ``images_dir`` is one sample.
    Its mask is looked up in ``masks_dir`` by appending, in order,
    ``_segmentation.png``, ``_mask.png`` and ``.png`` to the image stem.

    Args:
        images_dir: Directory containing the input images.
        masks_dir: Directory containing the ground-truth masks.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``'image'`` and ``'mask'`` entries.
        target_size: Size passed to ``cv2.resize`` as ``dsize``, i.e.
            ``(width, height)``.
    """

    # Suffixes appended to the image stem, tried in this order, to find the mask.
    _MASK_SUFFIXES = ("_segmentation.png", "_mask.png", ".png")

    def __init__(self, images_dir, masks_dir, transform=None, target_size=(512, 512)):
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.transform = transform
        self.target_size = target_size

        # Collect the image files; sorting keeps the sample order deterministic.
        self.image_files = sorted([f for f in os.listdir(images_dir)
                                   if f.endswith(('.jpg', '.jpeg', '.png'))])

        print(f"Found {len(self.image_files)} images in {images_dir}")

    def __len__(self):
        """Return the number of image files found in ``images_dir``."""
        return len(self.image_files)

    def _load_mask(self, base_name, fallback_shape):
        """Read the mask for ``base_name`` or return an all-zero mask.

        Only the first candidate file that exists is read. When no candidate
        exists, or the file cannot be decoded, a ``uint8`` zero array of
        ``fallback_shape`` is returned instead (best-effort behaviour).
        """
        mask = None
        for suffix in self._MASK_SUFFIXES:
            mask_path = os.path.join(self.masks_dir, base_name + suffix)
            if os.path.exists(mask_path):
                mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
                break

        if mask is None:
            mask = np.zeros(fallback_shape, dtype=np.uint8)
        return mask

    @staticmethod
    def _ensure_channel_dim(mask):
        """Return ``mask`` as a float tensor with a leading channel axis.

        Accepts a NumPy array or a tensor. A 2-D ``(H, W)`` mask becomes
        ``(1, H, W)``; masks with any other rank are returned with their shape
        unchanged. This matters because a tensor-producing transform can hand
        the mask back without a channel axis, which would otherwise batch to
        ``(B, H, W)`` and no longer match a ``(B, 1, H, W)`` prediction.
        """
        if not isinstance(mask, torch.Tensor):
            # ascontiguousarray also copes with negatively strided views
            # (e.g. flipped arrays), which torch.from_numpy rejects.
            mask = torch.from_numpy(np.ascontiguousarray(mask))
        mask = mask.float()
        if mask.dim() == 2:
            mask = mask.unsqueeze(0)
        return mask

    def __getitem__(self, idx):
        """Load, resize, binarise and (optionally) augment sample ``idx``.

        Returns:
            ``(image, mask)``. ``mask`` is a float tensor with a leading
            channel axis. ``image`` is whatever tensor the transform produced,
            or, when it is still a NumPy array, a float ``(C, H, W)`` tensor
            divided by 255.

        Raises:
            FileNotFoundError: If the image file cannot be read by OpenCV.
        """
        # Load image
        img_name = self.image_files[idx]
        img_path = os.path.join(self.images_dir, img_name)
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Find the corresponding mask (zero mask when none is available)
        base_name = os.path.splitext(img_name)[0]
        mask = self._load_mask(base_name, (image.shape[0], image.shape[1]))

        # Resize both to the configured size
        image = cv2.resize(image, self.target_size)
        mask = cv2.resize(mask, self.target_size)

        # Convert mask to binary {0.0, 1.0}
        mask = (mask > 127).astype(np.float32)

        # Apply augmentations
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # Convert to tensor if the transform did not already do so
        if not isinstance(image, torch.Tensor):
            image = torch.from_numpy(
                np.ascontiguousarray(image.transpose(2, 0, 1))
            ).float() / 255.0

        mask = self._ensure_channel_dim(mask)

        return image, mask

# Channel mean / std handed to A.Normalize (the commonly used ImageNet statistics)
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _normalize_and_to_tensor():
    """Build fresh Normalize + ToTensorV2 steps, the tail shared by both pipelines."""
    return [A.Normalize(mean=_NORM_MEAN, std=_NORM_STD), ToTensorV2()]


# Training pipeline: random geometric and colour augmentations, then the shared tail
train_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=10, p=0.5),
    *_normalize_and_to_tensor(),
])

# Validation pipeline: no random augmentation, only the shared tail
val_transform = A.Compose(_normalize_and_to_tensor())

print("✅ Dataset class và augmentations đã được định nghĩa!")

In [None]:
# Build the train and validation datasets (train first, then validation)
_target_size = (512, 512)

train_dataset = ISICDataset(
    images_dir='data/train/images',
    masks_dir='data/train/ground_truth',
    transform=train_transform,
    target_size=_target_size,
)
val_dataset = ISICDataset(
    images_dir='data/val/images',
    masks_dir='data/val/ground_truth',
    transform=val_transform,
    target_size=_target_size,
)

# Wrap them in data loaders; only the training loader shuffles
batch_size = 8
_loader_options = dict(batch_size=batch_size, num_workers=4, pin_memory=True)

train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)

# Report dataset and loader sizes
print(f"📊 Dataset Summary:")
for _label, _value in (
    ("Training samples", len(train_dataset)),
    ("Validation samples", len(val_dataset)),
    ("Batch size", batch_size),
    ("Training batches", len(train_loader)),
    ("Validation batches", len(val_loader)),
):
    print(f"   - {_label}: {_value}")

## 3. U-Net EfficientNet Model Definition

In [None]:
class UNetEfficientNet(nn.Module):
    """Wrapper around ``smp.Unet`` whose forward pass returns sigmoid outputs.

    Args:
        backbone_name: Value passed to ``smp.Unet`` as ``encoder_name``.
        num_classes: Value passed to ``smp.Unet`` as ``classes``.
        pretrained: When true the encoder weights are ``"imagenet"``,
            otherwise ``None``.
    """

    def __init__(self, backbone_name="efficientnet-b0", num_classes=1, pretrained=True):
        super().__init__()

        encoder_weights = "imagenet" if pretrained else None

        # The wrapped network emits raw logits (activation=None);
        # the sigmoid is applied in forward().
        self.model = smp.Unet(
            encoder_name=backbone_name,
            encoder_weights=encoder_weights,
            in_channels=3,
            classes=num_classes,
            activation=None,
        )

        for line in (
            f"✅ U-Net EfficientNet model created:",
            f"   - Backbone: {backbone_name}",
            f"   - Pre-trained: {pretrained}",
            f"   - Number of classes: {num_classes}",
        ):
            print(line)

    def forward(self, x):
        """Run the wrapped model on ``x`` and squash its logits with a sigmoid."""
        return self.model(x).sigmoid()

# Build the model and move it to the selected device
print("🏗️ Initializing U-Net EfficientNet model...")
model = UNetEfficientNet(
    backbone_name="efficientnet-b0",
    num_classes=1,
    pretrained=True,
).to(device)

# Count all parameters and the trainable subset in a single pass
total_params = 0
trainable_params = 0
for _param in model.parameters():
    _count = _param.numel()
    total_params += _count
    if _param.requires_grad:
        trainable_params += _count

print(f"📊 Model Statistics:")
print(f"   - Total parameters: {total_params:,}")
print(f"   - Trainable parameters: {trainable_params:,}")
print(f"   - Model size: ~{total_params/1e6:.1f}M parameters")

## 4. Loss Functions và Metrics

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss computed over all elements of the inputs at once.

    Args:
        smooth: Constant added to both numerator and denominator so the ratio
            stays defined when both inputs sum to zero.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        """Return ``1 - dice`` for ``pred`` and ``target``.

        Both tensors are flattened, so they must hold the same number of
        elements. ``reshape`` is used instead of ``view`` so that
        non-contiguous inputs (e.g. transposed or sliced tensors) are accepted
        rather than raising a RuntimeError.
        """
        pred = pred.reshape(-1)
        target = target.reshape(-1)

        intersection = (pred * target).sum()
        dice = (2. * intersection + self.smooth) / (pred.sum() + target.sum() + self.smooth)

        return 1 - dice

class CombinedLoss(nn.Module):
    """Weighted sum of binary cross-entropy and Dice loss.

    Args:
        alpha: Weight of the BCE term; the Dice term is weighted ``1 - alpha``.
    """

    def __init__(self, alpha=0.5):
        super().__init__()
        self.alpha = alpha
        # nn.BCELoss operates on probabilities, not logits.
        self.bce = nn.BCELoss()
        self.dice = DiceLoss()

    def forward(self, pred, target):
        """Return ``alpha * BCE(pred, target) + (1 - alpha) * Dice(pred, target)``."""
        bce_term = self.bce(pred, target)
        dice_term = self.dice(pred, target)
        return self.alpha * bce_term + (1 - self.alpha) * dice_term

def calculate_dice_batch(pred, target, threshold=0.5):
    """Return the Dice coefficient of a whole batch as a Python float.

    ``pred`` is binarised with ``pred > threshold``; ``target`` is cast to
    float as-is. All elements of the batch are pooled into a single score.
    """
    hard_pred = pred.gt(threshold).float()
    hard_target = target.float()

    overlap = (hard_pred * hard_target).sum()
    # The small constant keeps the division defined when both masks are empty.
    denominator = hard_pred.sum() + hard_target.sum() + 1e-6

    return (2. * overlap / denominator).item()

def calculate_jaccard_batch(pred, target, threshold=0.5):
    """Return the Jaccard index (IoU) of a whole batch as a Python float.

    Args:
        pred: Boolean prediction mask, or a non-boolean tensor that is
            binarised with ``pred > threshold``.
        target: Boolean ground-truth mask, or a non-boolean tensor whose
            non-zero entries count as foreground.
        threshold: Cut-off applied to a non-boolean ``pred``; ignored for
            boolean input.

    Returns:
        ``intersection / (union + 1e-6)`` pooled over every element, so two
        empty masks score 0.0.
    """
    # The bitwise & and | below are not defined for float tensors, so
    # anything that is not already boolean is converted first.
    if pred.dtype != torch.bool:
        pred = pred > threshold
    if target.dtype != torch.bool:
        target = target.bool()

    intersection = (pred & target).float().sum()
    union = (pred | target).float().sum()

    jaccard = intersection / (union + 1e-6)
    return jaccard.item()

# Status message emitted once the loss and metric definitions above have run
print("✅ Loss functions và metrics đã được định nghĩa!")

## 5. Training Function với ReduceLROnPlateau

In [None]:
def _run_segmentation_epoch(model, loader, criterion, run_device, desc, optimizer=None):
    """Run one pass over ``loader`` and return per-batch means (loss, dice, jaccard).

    The model is trained when ``optimizer`` is given; otherwise it is put in
    eval mode and run without gradient tracking.
    """
    is_training = optimizer is not None
    model.train(is_training)
    loss_sum = 0.0
    dice_sum = 0.0
    jaccard_sum = 0.0

    progress = tqdm(loader, desc=desc)
    with torch.set_grad_enabled(is_training):
        for images, masks in progress:
            images, masks = images.to(run_device), masks.to(run_device)

            if is_training:
                optimizer.zero_grad()
            outputs = model(images)

            # Masks delivered with one axis fewer than the outputs get a
            # channel axis so loss and metrics compare equal shapes.
            if masks.dim() == outputs.dim() - 1:
                masks = masks.unsqueeze(1)

            loss = criterion(outputs, masks)
            if is_training:
                loss.backward()
                optimizer.step()

            # Metrics never need gradients, in either mode.
            with torch.no_grad():
                dice = calculate_dice_batch(outputs, masks)
                jaccard = calculate_jaccard_batch(outputs > 0.5, masks.bool())

            batch_loss = loss.item()
            loss_sum += batch_loss
            dice_sum += dice
            jaccard_sum += jaccard
            progress.set_postfix({
                'Loss': f'{batch_loss:.4f}',
                'Dice': f'{dice:.3f}',
                'Jaccard': f'{jaccard:.3f}'
            })

    # Averages are taken over batches, so a smaller final batch weighs the
    # same as a full one.
    num_batches = len(loader)
    return loss_sum / num_batches, dice_sum / num_batches, jaccard_sum / num_batches


def train_unet_efficientnet(model, train_loader, val_loader, num_epochs=25, lr=1e-4,
                            checkpoint_path='models/unet_efficientnet_model_best.pth'):
    """Train ``model`` with Adam, a BCE + Dice loss and ReduceLROnPlateau.

    Args:
        model: Network whose outputs are probabilities (the loss uses
            ``nn.BCELoss``). Batches are moved to the device of its parameters.
        train_loader: Loader yielding ``(images, masks)`` training batches.
        val_loader: Loader yielding ``(images, masks)`` validation batches.
        num_epochs: Number of epochs to run.
        lr: Initial learning rate for Adam.
        checkpoint_path: Where the state dict with the lowest validation loss
            is written; its directory is created when missing.

    Returns:
        Dict of per-epoch lists keyed ``train_losses``, ``val_losses``,
        ``train_dice``, ``val_dice``, ``train_jaccard``, ``val_jaccard`` and
        ``learning_rates`` (the rate in effect during that epoch).

    Raises:
        ValueError: If either loader yields no batches.
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        # Fail early instead of dividing by zero when averaging the metrics.
        raise ValueError("train_loader and val_loader must each contain at least one batch")

    # Loss function and optimizer
    criterion = CombinedLoss(alpha=0.5)  # Combination of BCE and Dice loss
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)

    # ReduceLROnPlateau scheduler (adaptive). The deprecated ``verbose``
    # argument is not passed; learning-rate changes are printed below instead.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', patience=3, factor=0.5
    )

    # Follow the model's own device rather than relying on a global.
    run_device = next(model.parameters()).device

    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    # Training history
    history = {
        'train_losses': [],
        'val_losses': [],
        'train_dice': [],
        'val_dice': [],
        'train_jaccard': [],
        'val_jaccard': [],
        'learning_rates': []
    }

    best_val_loss = float('inf')

    print(f"🚀 Bắt đầu training U-Net EfficientNet...")
    print(f"📊 Configuration:")
    print(f"   - Epochs: {num_epochs}")
    print(f"   - Learning Rate: {lr}")
    print(f"   - Scheduler: ReduceLROnPlateau (patience=3, factor=0.5)")
    print(f"   - Loss: Combined (BCE + Dice)")
    print("=" * 70)

    for epoch in range(num_epochs):
        epoch_tag = f'Epoch {epoch+1}/{num_epochs}'

        # Training phase
        train_loss, train_dice, train_jaccard = _run_segmentation_epoch(
            model, train_loader, criterion, run_device,
            desc=f'{epoch_tag} - Training', optimizer=optimizer,
        )

        # Validation phase
        val_loss, val_dice, val_jaccard = _run_segmentation_epoch(
            model, val_loader, criterion, run_device,
            desc=f'{epoch_tag} - Validation',
        )

        # Update learning rate
        current_lr = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)  # ReduceLROnPlateau uses validation loss
        new_lr = optimizer.param_groups[0]['lr']

        # Store history
        history['train_losses'].append(train_loss)
        history['val_losses'].append(val_loss)
        history['train_dice'].append(train_dice)
        history['val_dice'].append(val_dice)
        history['train_jaccard'].append(train_jaccard)
        history['val_jaccard'].append(val_jaccard)
        history['learning_rates'].append(current_lr)

        # Print epoch summary
        if abs(current_lr - new_lr) > 1e-8:
            print(f"  Learning rate changed: {current_lr:.6f} -> {new_lr:.6f}")

        print(f"{epoch_tag}:")
        print(f"  Train Loss: {train_loss:.4f}, Train Dice: {train_dice:.4f}, Train Jaccard: {train_jaccard:.4f}")
        print(f"  Val Loss: {val_loss:.4f}, Val Dice: {val_dice:.4f}, Val Jaccard: {val_jaccard:.4f}")
        print(f"  LR: {current_lr:.6f}")
        print("-" * 70)

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)
            print(f"New best validation loss: {val_loss:.4f}")

    return history

# Status message emitted once the training function above has been defined
print("✅ Training function đã được định nghĩa!")

## 6. Bắt đầu Training

In [None]:
# Announce the run, then train the U-Net EfficientNet model
print("🚀 Bắt đầu training U-Net EfficientNet...")
print("⏰ Thời gian dự kiến: ~1-2 giờ (tùy thuộc vào GPU)")
print()

_run_config = dict(num_epochs=25, lr=1e-4)
unet_efficientnet_history = train_unet_efficientnet(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    **_run_config,
)

# Save the weights from the last epoch next to the best-validation checkpoint
torch.save(model.state_dict(), 'models/unet_efficientnet_model_final.pth')
for _line in (
    "\n🎉 U-Net EfficientNet training hoàn thành!",
    "💾 Model đã được lưu:",
    "   - models/unet_efficientnet_model_best.pth (best validation loss)",
    "   - models/unet_efficientnet_model_final.pth (final epoch)",
):
    print(_line)

## 7. Visualization và Evaluation

In [None]:
def plot_training_history(history, model_name="U-Net EfficientNet"):
    """Draw a 2x2 figure of loss, Dice, Jaccard and learning-rate curves.

    ``history`` must provide the keys ``train_losses``, ``val_losses``,
    ``train_dice``, ``val_dice``, ``train_jaccard``, ``val_jaccard`` and
    ``learning_rates``. ``model_name`` prefixes every panel title.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # One entry per panel:
    # (axes, title suffix, y-axis label, [(history key, legend label, colour), ...])
    panels = [
        (axes[0, 0], 'Training & Validation Loss', 'Loss',
         [('train_losses', 'Train Loss', 'blue'), ('val_losses', 'Val Loss', 'red')]),
        (axes[0, 1], 'Dice Coefficient', 'Dice Score',
         [('train_dice', 'Train Dice', 'blue'), ('val_dice', 'Val Dice', 'red')]),
        (axes[1, 0], 'Jaccard Index (IoU)', 'Jaccard Score',
         [('train_jaccard', 'Train Jaccard', 'blue'), ('val_jaccard', 'Val Jaccard', 'red')]),
        (axes[1, 1], 'Learning Rate Schedule', 'Learning Rate',
         [('learning_rates', 'Learning Rate', 'green')]),
    ]

    for ax, title, ylabel, curves in panels:
        for key, label, color in curves:
            ax.plot(history[key], label=label, color=color)
        ax.set_title(f'{model_name} - {title}')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)

    # Only the learning-rate panel uses a logarithmic y-axis.
    axes[1, 1].set_yscale('log')

    plt.tight_layout()
    plt.show()

# Plot the curves recorded in the history returned by the training run
print("📊 Hiển thị training history:")
plot_training_history(unet_efficientnet_history, "U-Net EfficientNet")

In [None]:
def evaluate_model_on_samples(model, val_loader, num_samples=6):
    """Plot image, true mask, predicted probability and binary prediction.

    Up to ``num_samples`` samples are taken from the first batches of
    ``val_loader`` and drawn as one column each in a 4-row figure.

    Args:
        model: Network returning per-pixel probabilities; switched to eval mode.
        val_loader: Loader yielding ``(images, masks)`` batches whose images
            were normalised with the mean/std used below.
        num_samples: Maximum number of samples to display.
    """
    model.eval()

    # Follow the model's own device; fall back to CPU for parameter-free models.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    # Statistics used to undo the input normalisation for display.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    # Collect samples, stopping as soon as enough have been gathered
    samples = []
    with torch.no_grad():
        for images, masks in val_loader:
            if len(samples) >= num_samples:
                break

            outputs = model(images.to(run_device)).cpu()
            images, masks = images.cpu(), masks.cpu()

            # Give channel-less masks a channel axis so [j, 0] selects a 2-D map.
            if masks.dim() == outputs.dim() - 1:
                masks = masks.unsqueeze(1)

            for j in range(min(images.shape[0], num_samples - len(samples))):
                # Denormalize image for visualization
                img = torch.clamp(images[j] * std + mean, 0, 1)

                samples.append({
                    'image': img.permute(1, 2, 0).numpy(),
                    'true_mask': masks[j, 0].numpy(),
                    'pred_mask': outputs[j, 0].numpy(),
                    'pred_binary': (outputs[j, 0] > 0.5).numpy().astype(np.uint8)
                })

    if not samples:
        # A figure with zero columns cannot be created.
        print("No validation samples available to visualize.")
        return

    # squeeze=False keeps ``axes`` two-dimensional even for a single column.
    _, axes = plt.subplots(4, len(samples), figsize=(20, 16), squeeze=False)

    for i, sample in enumerate(samples):
        # Original image
        axes[0, i].imshow(sample['image'])
        axes[0, i].set_title(f'Original {i+1}')
        axes[0, i].axis('off')

        # True mask
        axes[1, i].imshow(sample['true_mask'], cmap='gray')
        axes[1, i].set_title(f'True Mask {i+1}')
        axes[1, i].axis('off')

        # Predicted mask (probability)
        axes[2, i].imshow(sample['pred_mask'], cmap='hot', vmin=0, vmax=1)
        axes[2, i].set_title(f'Pred Prob {i+1}')
        axes[2, i].axis('off')

        # Predicted mask (binary)
        axes[3, i].imshow(sample['pred_binary'], cmap='gray')
        axes[3, i].set_title(f'Pred Binary {i+1}')
        axes[3, i].axis('off')

    plt.suptitle('U-Net EfficientNet - Prediction Results', fontsize=16)
    plt.tight_layout()
    plt.show()

# Visualize the model's predictions on 6 validation samples
print("🔍 Đánh giá model trên validation samples:")
evaluate_model_on_samples(model, val_loader, num_samples=6)

## 8. Final Results Summary

In [None]:
# Report the metrics of the last epoch
print("🎯 U-NET EFFICIENTNET TRAINING RESULTS")
print("=" * 50)
print(f"📊 Final Metrics (Last Epoch):")
for _label, _key in (
    ("Training Loss", 'train_losses'),
    ("Validation Loss", 'val_losses'),
    ("Training Dice", 'train_dice'),
    ("Validation Dice", 'val_dice'),
    ("Training Jaccard", 'train_jaccard'),
    ("Validation Jaccard", 'val_jaccard'),
):
    print(f"   - {_label}: {unet_efficientnet_history[_key][-1]:.4f}")

# Report the best validation scores and the (1-based) epoch they occurred in
print(f"\n🏆 Best Metrics:")
best_val_dice_idx = np.argmax(unet_efficientnet_history['val_dice'])
best_val_jaccard_idx = np.argmax(unet_efficientnet_history['val_jaccard'])
print(f"   - Best Validation Dice: {max(unet_efficientnet_history['val_dice']):.4f} (Epoch {best_val_dice_idx + 1})")
print(f"   - Best Validation Jaccard: {max(unet_efficientnet_history['val_jaccard']):.4f} (Epoch {best_val_jaccard_idx + 1})")

print(f"\n💾 Saved Models:")
for _path in ("models/unet_efficientnet_model_best.pth", "models/unet_efficientnet_model_final.pth"):
    print(f"   - {_path}")

# Grade the last-epoch validation Dice against fixed cut-offs, highest first
print(f"\n📈 Model Performance:")
final_dice = unet_efficientnet_history['val_dice'][-1]
final_jaccard = unet_efficientnet_history['val_jaccard'][-1]

_verdicts = (
    (0.85, "   ✅ Excellent performance! Dice > 0.85"),
    (0.80, "   ✅ Good performance! Dice > 0.80"),
    (0.75, "   ⚠️  Acceptable performance. Dice > 0.75"),
)
print(next(
    (message for cutoff, message in _verdicts if final_dice > cutoff),
    "   ❌ Performance needs improvement. Dice < 0.75",
))

print("\n".join((
    "\n🚀 Next Steps:",
    "   1. Run 04_train_unet_vit.ipynb to train U-Net ViT",
    "   2. Run 05_train_deeplabv3_resnet.ipynb to train DeepLabV3+ ResNet",
    "   3. Compare all models' performance",
)))

print("\n✅ U-Net EfficientNet training completed successfully!")