In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import jaccard_score
import cv2
from PIL import Image

# Import custom modules
from architectures.efficientnet_unet import build_efficientnet_v2_s_unet
from dataloaders.segmentation_dataset import SegmentationDataset, get_transforms

# Pick the compute device once: CUDA when PyTorch reports it usable, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

In [None]:
# Experiment configuration: every tunable value used by the cells below lives here.
CONFIG = dict(
    data_dir=r'path/to/your/dataset',  # Change this path to your dataset root
    image_size=256,
    batch_size=8,
    num_epochs=100,
    learning_rate=1e-4,
    num_classes=1,
    model_save_path='efficientnet_v2s_segmentation.pth',
    test_output_dir='test_predictions',
)

# Split directories, laid out as <data_dir>/<Split>/<Image|Mask>.
# The test split only has an image directory here (no mask path is built for it).
train_image_dir, train_mask_dir = (
    os.path.join(CONFIG['data_dir'], 'Train', 'Image'),
    os.path.join(CONFIG['data_dir'], 'Train', 'Mask'),
)
val_image_dir, val_mask_dir = (
    os.path.join(CONFIG['data_dir'], 'Val', 'Image'),
    os.path.join(CONFIG['data_dir'], 'Val', 'Mask'),
)
test_image_dir = os.path.join(CONFIG['data_dir'], 'Test', 'Image')

# Echo the resolved paths so a wrong data_dir is obvious before any data is loaded.
print(
    "Data directories:",
    f"Train images: {train_image_dir}",
    f"Train masks: {train_mask_dir}",
    f"Val images: {val_image_dir}",
    f"Val masks: {val_mask_dir}",
    f"Test images: {test_image_dir}",
    sep="\n",
)

In [None]:
# Transforms for the three splits, in the order train / val / test.
# Only the training split is requested with is_train=True.
train_transform, val_transform, test_transform = (
    get_transforms(CONFIG['image_size'], is_train=train_mode)
    for train_mode in (True, False, False)
)

# Datasets: train and val pair an image directory with a mask directory;
# the test split is built from images only and flagged with is_test=True.
train_dataset = SegmentationDataset(
    train_image_dir,
    train_mask_dir,
    train_transform,
)
val_dataset = SegmentationDataset(
    val_image_dir,
    val_mask_dir,
    val_transform,
)
test_dataset = SegmentationDataset(
    test_image_dir,
    transform=test_transform,
    is_test=True,
)

# Loaders: only the training loader shuffles; the test loader yields one sample per batch.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=4,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=4,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1,
)

# Split sizes, as a quick sanity check on the configured directories.
print(
    f"Train samples: {len(train_dataset)}",
    f"Val samples: {len(val_dataset)}",
    f"Test samples: {len(test_dataset)}",
    sep="\n",
)

In [None]:
# Build the network with the configured number of output channels and
# pretrained=True, then move it to the selected device.
model = build_efficientnet_v2_s_unet(
    pretrained=True,
    num_classes=CONFIG['num_classes'],
)
model = model.to(device)

# Training objective and optimisation:
# - BCE-with-logits, so the model is expected to emit raw logits
# - Adam at the configured learning rate
# - halve the learning rate once the monitored value has stopped decreasing for 10 epochs
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(params=model.parameters(), lr=CONFIG['learning_rate'])
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=10,
)

print(
    "Model created successfully!",
    f"Total parameters: {sum(param.numel() for param in model.parameters()):,}",
    sep="\n",
)

In [None]:
def dice_coefficient(pred, target, smooth=1e-6, threshold=0.5):
    """Dice coefficient between thresholded predictions and a target mask.

    The logits in ``pred`` are passed through a sigmoid and binarised at
    ``threshold``; the score is then computed from sums over *all* elements
    of the tensors at once (one global score, not a per-sample average).

    Args:
        pred: Tensor of raw logits (the sigmoid is applied here).
        target: Tensor that can be multiplied element-wise with ``pred``.
            Presumably holds 0/1 mask values -- confirm against the dataset.
        smooth: Small constant added to numerator and denominator so that two
            empty masks give 1.0 instead of a division by zero.
        threshold: Probability above which a pixel counts as foreground.
            Defaults to 0.5, the value that was previously hard-coded.

    Returns:
        The Dice score as a Python float.
    """
    probabilities = torch.sigmoid(pred)
    binary_pred = (probabilities > threshold).float()

    intersection = (binary_pred * target).sum()
    # Dice divides by the *sum of the two mask sizes*, not by their set union.
    total_area = binary_pred.sum() + target.sum()

    dice = (2. * intersection + smooth) / (total_area + smooth)
    return dice.item()

def iou_score(pred, target, smooth=1e-6, threshold=0.5):
    """Intersection-over-Union between thresholded predictions and a target mask.

    The logits in ``pred`` are passed through a sigmoid and binarised at
    ``threshold``; the score is then computed from sums over *all* elements
    of the tensors at once (one global score, not a per-sample average).

    Args:
        pred: Tensor of raw logits (the sigmoid is applied here).
        target: Tensor that can be multiplied element-wise with ``pred``.
            Presumably holds 0/1 mask values -- confirm against the dataset.
        smooth: Small constant added to numerator and denominator so that two
            empty masks give 1.0 instead of a division by zero.
        threshold: Probability above which a pixel counts as foreground.
            Defaults to 0.5, the value that was previously hard-coded.

    Returns:
        The IoU score as a Python float.
    """
    probabilities = torch.sigmoid(pred)
    binary_pred = (probabilities > threshold).float()

    intersection = (binary_pred * target).sum()
    # |A union B| = |A| + |B| - |A intersect B|
    union = binary_pred.sum() + target.sum() - intersection

    iou = (intersection + smooth) / (union + smooth)
    return iou.item()

In [None]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one training pass over ``loader`` and return the mean metrics.

    Each batch is expected to be a mapping with ``'image'`` and ``'mask'``
    entries; both are moved to ``device`` before the forward pass.

    Args:
        model: Network to optimise; put into training mode here.
        loader: Sized iterable of batches (``len(loader)`` must work).
        criterion: Loss callable invoked as ``criterion(outputs, masks)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(mean_loss, mean_dice, mean_iou)``. The means are taken over
        batches, so a smaller final batch weighs as much as a full one.

    Raises:
        ValueError: If ``loader`` yields no batches (previously this surfaced
            as a ``ZeroDivisionError`` after the loop).
    """
    num_batches = len(loader)
    if num_batches == 0:
        raise ValueError(
            "train_epoch received an empty loader; check the training data directories."
        )

    model.train()
    total_loss = 0.0
    total_dice = 0.0
    total_iou = 0.0

    pbar = tqdm(loader, desc='Training')
    for batch in pbar:
        images = batch['image'].to(device)
        masks = batch['mask'].to(device)

        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, masks)

        loss.backward()
        optimizer.step()

        # Metrics are for reporting only: detach so the sigmoid/threshold ops
        # inside the metric helpers are not tracked by autograd.
        logits = outputs.detach()
        dice = dice_coefficient(logits, masks)
        iou = iou_score(logits, masks)
        batch_loss = loss.item()

        total_loss += batch_loss
        total_dice += dice
        total_iou += iou

        pbar.set_postfix({
            'Loss': f'{batch_loss:.4f}',
            'Dice': f'{dice:.4f}',
            'IoU': f'{iou:.4f}'
        })

    return total_loss / num_batches, total_dice / num_batches, total_iou / num_batches

def validate_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients and return the mean metrics.

    Each batch is expected to be a mapping with ``'image'`` and ``'mask'``
    entries; both are moved to ``device`` before the forward pass.

    Args:
        model: Network to evaluate; put into eval mode here.
        loader: Sized iterable of batches (``len(loader)`` must work).
        criterion: Loss callable invoked as ``criterion(outputs, masks)``.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(mean_loss, mean_dice, mean_iou)``. The means are taken over
        batches, so a smaller final batch weighs as much as a full one.

    Raises:
        ValueError: If ``loader`` yields no batches (previously this surfaced
            as a ``ZeroDivisionError`` after the loop).
    """
    num_batches = len(loader)
    if num_batches == 0:
        raise ValueError(
            "validate_epoch received an empty loader; check the validation data directories."
        )

    model.eval()
    total_loss = 0.0
    total_dice = 0.0
    total_iou = 0.0

    with torch.no_grad():
        pbar = tqdm(loader, desc='Validation')
        for batch in pbar:
            images = batch['image'].to(device)
            masks = batch['mask'].to(device)

            outputs = model(images)
            batch_loss = criterion(outputs, masks).item()

            # Per-batch segmentation metrics on the thresholded predictions
            dice = dice_coefficient(outputs, masks)
            iou = iou_score(outputs, masks)

            total_loss += batch_loss
            total_dice += dice
            total_iou += iou

            pbar.set_postfix({
                'Loss': f'{batch_loss:.4f}',
                'Dice': f'{dice:.4f}',
                'IoU': f'{iou:.4f}'
            })

    return total_loss / num_batches, total_dice / num_batches, total_iou / num_batches

In [None]:
# Per-epoch metric history (one entry appended per epoch, train and validation).
train_losses, val_losses = [], []
train_dices, val_dices = [], []
train_ious, val_ious = [], []

# Highest validation Dice seen so far; a checkpoint is written whenever it improves.
best_val_dice = 0.0

for epoch in range(CONFIG['num_epochs']):
    print(f"\nEpoch {epoch+1}/{CONFIG['num_epochs']}", "-" * 50, sep="\n")

    # One optimisation pass over the training data, then one evaluation pass.
    train_loss, train_dice, train_iou = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_dice, val_iou = validate_epoch(model, val_loader, criterion, device)

    # The plateau scheduler is fed the validation loss.
    scheduler.step(val_loss)

    # Record this epoch's numbers, metric by metric.
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_dices.append(train_dice)
    val_dices.append(val_dice)
    train_ious.append(train_iou)
    val_ious.append(val_iou)

    print(
        f"Train - Loss: {train_loss:.4f}, Dice: {train_dice:.4f}, IoU: {train_iou:.4f}",
        f"Val   - Loss: {val_loss:.4f}, Dice: {val_dice:.4f}, IoU: {val_iou:.4f}",
        f"LR: {optimizer.param_groups[0]['lr']:.2e}",
        sep="\n",
    )

    # Nothing more to do unless validation Dice strictly improved.
    if not val_dice > best_val_dice:
        continue

    best_val_dice = val_dice
    torch.save(
        dict(
            epoch=epoch,  # zero-based index of the epoch that produced this checkpoint
            model_state_dict=model.state_dict(),
            optimizer_state_dict=optimizer.state_dict(),
            best_val_dice=best_val_dice,
            config=CONFIG,
        ),
        CONFIG['model_save_path'],
    )
    print(f"New best model saved! Val Dice: {best_val_dice:.4f}")

print(f"\nTraining completed! Best Val Dice: {best_val_dice:.4f}")

In [None]:
# Training curves: one panel per metric, train vs. validation over epochs.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# (panel title, short metric name used for labels, train history, validation history)
metric_panels = (
    ('Loss', 'Loss', train_losses, val_losses),
    ('Dice Coefficient', 'Dice', train_dices, val_dices),
    ('IoU Score', 'IoU', train_ious, val_ious),
)

for panel, (panel_title, metric_name, train_curve, val_curve) in zip(axes, metric_panels):
    panel.plot(train_curve, label=f'Train {metric_name}')
    panel.plot(val_curve, label=f'Val {metric_name}')
    panel.set_title(panel_title)
    panel.set_xlabel('Epoch')
    panel.set_ylabel(metric_name)
    panel.legend()
    panel.grid(True)

# Save a high-resolution copy next to the notebook before showing the figure.
plt.tight_layout()
plt.savefig('training_metrics.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Load the best checkpoint written by the training loop.
# map_location keeps the load working when the checkpoint tensors were saved
# from a different device than the one currently selected.
checkpoint = torch.load(CONFIG['model_save_path'], map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Create output directory
os.makedirs(CONFIG['test_output_dir'], exist_ok=True)

print("Generating predictions on test set...")

# cv2.imwrite reports failure through its return value rather than raising,
# so failed paths are collected and reported once the loop is done.
failed_writes = []

with torch.no_grad():
    for batch in tqdm(test_loader, desc='Predicting'):
        images = batch['image'].to(device)
        filenames = batch['filename']

        outputs = model(images)
        # Logits -> probabilities -> hard 0/1 mask at 0.5
        predictions = torch.sigmoid(outputs)
        predictions = (predictions > 0.5).float()

        for i, filename in enumerate(filenames):
            # Drop singleton dimensions and scale {0, 1} to {0, 255} for an 8-bit image
            pred_mask = predictions[i].cpu().numpy().squeeze()
            pred_mask = (pred_mask * 255).astype(np.uint8)

            # Always write a PNG, whatever the input extension: replacing only
            # the literal '.jpg' left '.jpeg' / '.JPG' inputs with a JPEG
            # extension, and cv2.imwrite would then store the mask lossily.
            # basename guards against the dataset returning a path instead of
            # a bare file name, which would escape the output directory.
            stem = os.path.splitext(os.path.basename(filename))[0]
            output_filename = stem + '.png'
            output_path = os.path.join(CONFIG['test_output_dir'], output_filename)
            if not cv2.imwrite(output_path, pred_mask):
                failed_writes.append(output_path)

print(f"Predictions saved to: {CONFIG['test_output_dir']}")
if failed_writes:
    print(f"WARNING: {len(failed_writes)} prediction(s) could not be written, e.g. {failed_writes[0]}")

In [None]:
# Visualize some predictions
def visualize_predictions(num_samples=5):
    """Plot input, predicted mask and overlay for the first test batches.

    Uses the notebook-level ``model``, ``test_loader`` and ``device``. One
    figure row is drawn per batch taken from ``test_loader`` (only the first
    item of each batch is shown). The figure is saved to
    ``test_predictions_visualization.png`` and then displayed.

    Args:
        num_samples: Maximum number of rows to draw. If the loader yields
            fewer batches, the remaining rows are left blank and hidden.

    Raises:
        ValueError: If ``num_samples`` is less than 1.
    """
    if num_samples < 1:
        raise ValueError("num_samples must be at least 1")

    model.eval()

    # squeeze=False always returns a 2-D axes array, so a single row needs no special case.
    fig, axes = plt.subplots(num_samples, 3, figsize=(15, num_samples * 5), squeeze=False)

    # ImageNet channel statistics used to undo input normalisation for display.
    # NOTE(review): assumes get_transforms normalises with these values -- confirm.
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])

    rows_drawn = 0
    with torch.no_grad():
        # zip stops at whichever runs out first: num_samples rows or the test loader.
        for row, batch in zip(range(num_samples), test_loader):
            image = batch['image'].to(device)
            filename = batch['filename'][0]

            # Logits -> probabilities -> hard 0/1 mask at 0.5
            prediction = torch.sigmoid(model(image))
            prediction = (prediction > 0.5).float()

            # Channels-first tensor -> channels-last array, denormalised and clipped to [0, 1]
            image_np = image[0].cpu().numpy().transpose(1, 2, 0)
            image_np = np.clip(image_np * channel_std + channel_mean, 0, 1)

            pred_np = prediction[0].cpu().numpy().squeeze()

            original_ax, mask_ax, overlay_ax = axes[row]

            original_ax.imshow(image_np)
            original_ax.set_title(f'Original: {filename}')
            original_ax.axis('off')

            mask_ax.imshow(pred_np, cmap='gray')
            mask_ax.set_title('Prediction')
            mask_ax.axis('off')

            overlay_ax.imshow(image_np)
            overlay_ax.imshow(pred_np, alpha=0.5, cmap='Reds')
            overlay_ax.set_title('Overlay')
            overlay_ax.axis('off')

            rows_drawn += 1

    # Hide rows that were never filled so they do not render as empty framed axes.
    for unused_row in axes[rows_drawn:]:
        for unused_ax in unused_row:
            unused_ax.axis('off')

    fig.tight_layout()
    fig.savefig('test_predictions_visualization.png', dpi=300, bbox_inches='tight')
    plt.show()

visualize_predictions(5)

In [None]:
# Final recap of the run: configuration highlights, best score and output locations.
print(
    "=" * 60,
    "TRAINING SUMMARY",
    "=" * 60,
    "Model: EfficientNet-V2-S UNet",
    "Dataset: Custom Segmentation",
    f"Total epochs: {CONFIG['num_epochs']}",
    f"Best validation Dice: {best_val_dice:.4f}",
    f"Image size: {CONFIG['image_size']}x{CONFIG['image_size']}",
    f"Batch size: {CONFIG['batch_size']}",
    f"Learning rate: {CONFIG['learning_rate']}",
    f"Model saved at: {CONFIG['model_save_path']}",
    f"Test predictions saved at: {CONFIG['test_output_dir']}",
    "=" * 60,
    sep="\n",
)