# ASP-VMUNet Training for Image Segmentation
This notebook trains an ASP-VMUNet model for segmentation tasks. The model is trained using a custom dataset with training, validation, and test splits.

In [None]:
# Import necessary libraries
import os
import copy
import time
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tensorboardX import SummaryWriter

from sklearn.metrics import confusion_matrix
from PIL import Image

# Adjust path to import modules from the project
import sys
sys.path.append('./') 

# Import model and utility functions
from models.Atrous.atrous_UL_CNN import atrous_ULPSR_basev3_CNN
from utils import save_imgs, get_optimizer, get_scheduler, set_seed, get_logger, cal_params_flops

# Choose the compute device: use the GPU whenever CUDA is available, else the CPU
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f"Using device: {device}")

In [None]:
# Configure paths and settings
class Config:
    """Paths and hyper-parameters for one ASP-VMUNet training run.

    Everything is a plain class attribute and is read as ``config.<name>``
    by the cells below.
    """

    # ---- Locations ----
    data_path = '../data/'  # Root folder of the dataset splits
    work_dir = '../results/aspvmunet_run/'  # Logs, checkpoints and predictions are written here

    # ---- Data loading ----
    input_size = 256  # Images and masks are resized to input_size x input_size
    train_bs = 8  # Batch size of the training loader
    val_bs = 4  # Batch size of the validation loader
    test_bs = 1  # Batch size of the test loader
    num_workers = 4  # Worker processes per DataLoader
    seed = 42  # Value handed to set_seed

    # ---- Optimisation schedule ----
    epochs = 100  # Last epoch index of the training loop (inclusive)
    lr = 1e-4  # Learning rate
    weight_decay = 1e-4  # Weight decay
    val_interval = 1  # Validation metrics are computed when epoch % val_interval == 0
    save_interval = 10  # A preview figure is written for every save_interval-th test batch
    print_interval = 10  # Training progress is logged every print_interval iterations
    threshold = 0.5  # Cut-off used to binarise predicted probabilities

    # ---- Model / training components ----
    network = 'atrous_UL_CNN'
    # NOTE(review): BCELoss expects probabilities in [0, 1]; this assumes the
    # model already applies a sigmoid to its output — confirm in the model code.
    criterion = nn.BCELoss()
    # NOTE(review): presumably read by utils.get_optimizer / get_scheduler —
    # confirm the attribute names those helpers actually expect.
    optimizer_name = 'AdamW'  # AdamW or SGD
    scheduler_name = 'CosineAnnealingLR'
    datasets = 'custom'  # Dataset name
    gpu_id = 0  # GPU ID

    # Keyword arguments forwarded to the atrous_ULPSR_basev3_CNN constructor
    model_config = dict(
        num_classes=1,
        input_channels=3,
        c_list=[8, 16, 24, 32, 48, 64],
        d_conv=3,
        split_att='fc',
        bridge=True,
        if_shifted_round=False,
        if_ss2d=True,
        forward_type='v1',
        encoder_atrous_step=[
            [2, 2],
            [2, 2],
            [2, 2],
            [2, 2, 2, 2, 2, 2],
            [2, 2],
        ],
        decoder_atrous_step=[
            [2, 2],
            [2, 2],
            [2, 2],
            [2, 2],
            [2, 2],
        ],
        if_CNN=True,
        if_SE=True,
        if_SK=True,
    )

config = Config()

# Create the run directory and every sub-folder the later cells write into
os.makedirs(config.work_dir, exist_ok=True)
for _sub_dir in ('log', 'checkpoints', 'outputs', 'test_predictions'):
    os.makedirs(os.path.join(config.work_dir, _sub_dir), exist_ok=True)

# Setup logger and tensorboard writer
logger = get_logger('train', os.path.join(config.work_dir, 'log'))
# Join the path instead of concatenating strings: concatenation silently
# produced '<work_dir>summary' next to the run folder whenever work_dir was
# configured without a trailing slash.
writer = SummaryWriter(os.path.join(config.work_dir, 'summary'))

# Set random seed
set_seed(config.seed)

## Create Custom Dataset and DataLoaders
We need to create custom datasets to load and preprocess the images and masks.

In [None]:
class SegmentationDataset(Dataset):
    """Dataset of ``.jpg`` images with optional ``.png`` masks of the same base name.

    Training/validation items are ``(image, mask)`` pairs; test items are
    ``(image, file_name)`` pairs.
    """

    def __init__(self, img_dir, mask_dir=None, transform=None, is_test=False):
        """
        Args:
            img_dir (string): Directory with all the images.
            mask_dir (string, optional): Directory with all the masks.
            transform (callable, optional): Optional transform to be applied on a sample.
                For non-test data it must return tensors, because the mask is
                binarised with tensor operations afterwards.
            is_test (bool): Whether this is a test dataset (without masks)

        Raises:
            ValueError: If ``is_test`` is False and no ``mask_dir`` is given.
        """
        # Fail at construction time rather than on the first __getitem__ call
        if not is_test and mask_dir is None:
            raise ValueError("mask_dir is required unless is_test=True")

        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.is_test = is_test
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable between runs and machines.
        self.img_names = sorted(f for f in os.listdir(img_dir) if f.endswith('.jpg'))

    def __len__(self):
        """Return the number of ``.jpg`` images found in ``img_dir``."""
        return len(self.img_names)

    def _transform_pair(self, image, mask):
        """Apply ``self.transform`` to image and mask using the same random draws."""
        import random  # local import: only needed to snapshot the Python RNG

        # Random flips/rotations must be identical for the image and its mask,
        # otherwise the pair is no longer spatially aligned. Snapshot both RNGs
        # and rewind them before transforming the mask.
        # NOTE(review): relies on the transform drawing its randomness from the
        # global torch RNG or Python's `random` module — confirm for custom transforms.
        torch_state = torch.get_rng_state()
        py_state = random.getstate()
        image = self.transform(image)
        torch.set_rng_state(torch_state)
        random.setstate(py_state)
        mask = self.transform(mask)
        return image, mask

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(image, img_name)`` when ``is_test`` is True, otherwise
            ``(image, mask)`` where ``mask`` is a float tensor of 0s and 1s.
        """
        img_name = self.img_names[idx]
        img_path = os.path.join(self.img_dir, img_name)

        # Load image
        image = Image.open(img_path).convert("RGB")

        if self.is_test:
            # For test data without masks
            if self.transform:
                image = self.transform(image)
            return image, img_name

        # For training/validation data with masks
        mask_path = os.path.join(self.mask_dir, img_name.replace('.jpg', '.png'))
        mask = Image.open(mask_path).convert("L")  # Load as grayscale

        if self.transform:
            image, mask = self._transform_pair(image, mask)

        # Any non-zero pixel counts as foreground
        mask = (mask > 0).float()

        return image, mask

# Define transformations
# Every split is resized to the same square resolution; only training adds
# random flips and a small rotation on top.
_target_hw = (config.input_size, config.input_size)

train_transform = transforms.Compose([
    transforms.Resize(_target_hw),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
])
val_transform = transforms.Compose([transforms.Resize(_target_hw), transforms.ToTensor()])
test_transform = transforms.Compose([transforms.Resize(_target_hw), transforms.ToTensor()])

# Create datasets
train_dataset = SegmentationDataset(
    os.path.join(config.data_path, 'train/images'),
    os.path.join(config.data_path, 'train/masks'),
    train_transform,
)

# NOTE(review): validation uses flat 'val_images' / 'val_masks' folders while
# train and test use '<split>/images' — confirm this matches the data layout.
val_dataset = SegmentationDataset(
    os.path.join(config.data_path, 'val_images'),
    os.path.join(config.data_path, 'val_masks'),
    val_transform,
)

# The test split has no masks, so the dataset yields (image, file name) pairs
test_dataset = SegmentationDataset(
    os.path.join(config.data_path, 'test/images'),
    transform=test_transform,
    is_test=True,
)

# Create data loaders
# Options shared by all three loaders; only training data is shuffled.
_loader_common = dict(pin_memory=True, num_workers=config.num_workers)

train_loader = DataLoader(train_dataset, batch_size=config.train_bs, shuffle=True, **_loader_common)
val_loader = DataLoader(val_dataset, batch_size=config.val_bs, shuffle=False, **_loader_common)
test_loader = DataLoader(test_dataset, batch_size=config.test_bs, shuffle=False, **_loader_common)

print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of validation samples: {len(val_dataset)}")
print(f"Number of test samples: {len(test_dataset)}")

## Initialize the Model, Loss Function, Optimizer, and Scheduler

In [None]:
# Initialize model
# Each key of config.model_config is the name of a constructor keyword, so the
# dict is unpacked directly; the model is then moved to the selected device.
model = atrous_ULPSR_basev3_CNN(**config.model_config).to(device)

# Display model size and FLOPS
# A deep copy is profiled, so the instance that gets trained is not the one
# passed to the profiling helper.
cal_params_flops(copy.deepcopy(model), config.input_size, logger)

# Define loss function, optimizer, and scheduler
criterion = config.criterion
optimizer = get_optimizer(config, model)
scheduler = get_scheduler(config, optimizer)

## Training and Validation Functions

In [None]:
def train_one_epoch(loader, model, criterion, optimizer, scheduler, epoch, step, logger, config, writer):
    """Train ``model`` for one pass over ``loader``.

    Args:
        loader: Iterable yielding ``(images, masks)`` batches.
        model: Network to optimise; switched to train mode here.
        criterion: Loss called as ``criterion(outputs, masks)``.
        optimizer: Optimizer stepped once per batch.
        scheduler: LR scheduler stepped once at the end of the epoch.
        epoch: Epoch number, used only in log messages.
        step: Global iteration counter before this epoch.
        logger: Object with an ``info`` method for log lines.
        config: Settings object; ``print_interval`` is read.
        writer: Summary writer receiving the per-iteration ``'loss'`` scalar.

    Returns:
        The global iteration counter after this epoch
        (``step`` plus the number of batches processed).
    """
    model.train()
    loss_list = []

    # `batch_idx` rather than `iter`, which would shadow the builtin
    for batch_idx, (images, masks) in enumerate(loader):
        step += 1
        optimizer.zero_grad()

        images = images.to(device).float()
        masks = masks.to(device).float()

        outputs = model(images)
        loss = criterion(outputs, masks)

        loss.backward()
        optimizer.step()

        # Convert once to a Python float so that neither the running list nor
        # the summary writer receives a tensor attached to the autograd graph.
        loss_value = loss.item()
        loss_list.append(loss_value)
        writer.add_scalar('loss', loss_value, global_step=step)

        if batch_idx % config.print_interval == 0:
            # Read the LR straight from the optimizer; building a full
            # state_dict every iteration is unnecessary work.
            now_lr = optimizer.param_groups[0]['lr']
            # The logged loss is the running mean over the epoch so far
            log_info = f'train: epoch {epoch}, iter:{batch_idx}, loss: {np.mean(loss_list):.4f}, lr: {now_lr}'
            print(log_info)
            logger.info(log_info)

    scheduler.step()
    return step

def val_one_epoch(loader, model, criterion, epoch, logger, config):
    """Evaluate ``model`` on validation data.

    Args:
        loader: Iterable yielding ``(image, mask)`` batches.
        model: Network to evaluate; switched to eval mode here.
        criterion: Loss called as ``criterion(output, mask)``.
        epoch: Epoch number; metrics are computed when
            ``epoch % config.val_interval == 0``.
        logger: Object with an ``info`` method for log lines.
        config: Settings object; ``val_interval`` and ``threshold`` are read.

    Returns:
        ``(mean_loss, f1_or_dsc)``. ``f1_or_dsc`` is ``nan`` on epochs where
        metrics are skipped (or when the loader is empty), so a caller's
        ``dsc > best`` comparison is simply False for those epochs.
    """
    model.eval()
    preds = []
    gts = []
    loss_list = []

    with torch.no_grad():
        for img, msk in tqdm(loader):
            img = img.to(device).float()
            msk = msk.to(device).float()

            out = model(img)
            loss = criterion(out, msk)

            loss_list.append(loss.item())
            gts.append(msk.squeeze(1).cpu().detach().numpy())
            preds.append(out.squeeze(1).cpu().detach().numpy())

    mean_loss = np.mean(loss_list)
    # Defined up front so the return below never hits an unbound name on
    # epochs where the metrics branch is skipped.
    f1_or_dsc = float('nan')

    # Calculate metrics (skipped when there is nothing to concatenate)
    if epoch % config.val_interval == 0 and preds:
        flat_preds = np.concatenate([p.flatten() for p in preds])
        flat_gts = np.concatenate([g.flatten() for g in gts])

        y_pre = np.where(flat_preds >= config.threshold, 1, 0)
        y_true = np.where(flat_gts >= 0.5, 1, 0)

        # labels=[0, 1] forces a 2x2 matrix even when only one class occurs
        # (e.g. all-background masks and predictions); without it the matrix
        # collapses to 1x1 and the indexing below raises IndexError.
        confusion = confusion_matrix(y_true, y_pre, labels=[0, 1])
        TN, FP, FN, TP = confusion[0, 0], confusion[0, 1], confusion[1, 0], confusion[1, 1]

        # Each ratio falls back to 0 when its denominator is 0
        accuracy = float(TN + TP) / float(np.sum(confusion)) if float(np.sum(confusion)) != 0 else 0
        sensitivity = float(TP) / float(TP + FN) if float(TP + FN) != 0 else 0
        specificity = float(TN) / float(TN + FP) if float(TN + FP) != 0 else 0
        f1_or_dsc = float(2 * TP) / float(2 * TP + FP + FN) if float(2 * TP + FP + FN) != 0 else 0
        # Logged as "miou", computed as the IoU of the positive class only
        miou = float(TP) / float(TP + FP + FN) if float(TP + FP + FN) != 0 else 0

        # The run of spaces before "specificity" reproduces the existing log
        # line format exactly.
        log_info = (
            f' val epoch: {epoch}, loss: {mean_loss:.4f}, miou: {miou:.4f}, '
            f'f1_or_dsc: {f1_or_dsc:.4f}, accuracy: {accuracy:.4f}, '
            f'                specificity: {specificity:.4f}, sensitivity: {sensitivity:.4f}, '
            f'confusion_matrix: {confusion}'
        )
        print(log_info)
        logger.info(log_info)
    else:
        log_info = f' val epoch: {epoch}, loss: {mean_loss:.4f}'
        print(log_info)
        logger.info(log_info)

    return mean_loss, f1_or_dsc

## Train the Model

In [None]:
# Training loop
# Bookkeeping for the best epoch so far (best = highest validation DSC);
# min_loss / min_epoch hold the validation loss and epoch of that best model.
min_loss = 999
max_dsc = -1
start_epoch = 1
min_epoch = 1
step = 0

# Resume from checkpoint if exists
resume_model = os.path.join(config.work_dir, 'checkpoints', 'latest.pth')
if os.path.exists(resume_model):
    print('#----------Resume Model and Other params----------#')
    checkpoint = torch.load(resume_model, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    saved_epoch = checkpoint['epoch']
    start_epoch = saved_epoch + 1
    min_loss, min_epoch, loss, max_dsc = checkpoint['min_loss'], checkpoint['min_epoch'], checkpoint['loss'], checkpoint['max_dsc']
    # Restore the global iteration counter so the TensorBoard 'loss' curve
    # continues instead of restarting at step 0. Checkpoints written before
    # 'step' was stored fall back to an estimate from the finished epochs.
    step = checkpoint.get('step', saved_epoch * len(train_loader))

    log_info = f'Resuming model from {resume_model}. resume_epoch: {saved_epoch}, min_loss: {min_loss:.4f}, min_epoch: {min_epoch}, max_dsc: {max_dsc:.4f}'
    print(log_info)
    logger.info(log_info)

print('#----------Training Started----------#')
for epoch in range(start_epoch, config.epochs + 1):
    t = time.time()
    torch.cuda.empty_cache()

    # Train for one epoch
    step = train_one_epoch(
        train_loader,
        model,
        criterion,
        optimizer,
        scheduler,
        epoch,
        step,
        logger,
        config,
        writer
    )
    print(f'Training time: {time.time() - t:.2f}s')

    # Validate
    loss, dsc = val_one_epoch(
        val_loader,
        model,
        criterion,
        epoch,
        logger,
        config
    )
    print(f'Total epoch time: {time.time() - t:.2f}s')

    # Save best model
    if dsc > max_dsc:
        torch.save(model.state_dict(), os.path.join(config.work_dir, 'checkpoints', 'best.pth'))
        min_loss = loss
        min_epoch = epoch
        max_dsc = dsc

    print(f'Best model at epoch {min_epoch} with DSC {max_dsc:.4f}')

    # Save latest model
    # Scalars are cast to plain Python floats: the validation loss is a NumPy
    # scalar, and NumPy objects inside a checkpoint are rejected by
    # torch.load's restricted (weights_only) unpickler on newer PyTorch.
    torch.save(
        {
            'epoch': epoch,
            'step': step,
            'min_loss': float(min_loss),
            'min_epoch': min_epoch,
            'loss': float(loss),
            'max_dsc': float(max_dsc),
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
        }, os.path.join(config.work_dir, 'checkpoints', 'latest.pth')
    )

print('#----------Training Completed----------#')

## Test and Generate Predictions

In [None]:
def save_prediction_mask(mask, filename, save_path, threshold=None):
    """Save the predicted mask as a PNG file.

    Args:
        mask: NumPy array of predicted values for one image.
        filename: File name of the PNG inside ``save_path``.
        save_path: Directory the PNG is written to.
        threshold: Cut-off for foreground; defaults to ``config.threshold``.
    """
    if threshold is None:
        threshold = config.threshold
    # Convert to binary mask (0-255). `>=` matches the comparison used when
    # validation metrics are computed, so a value exactly at the threshold is
    # classified the same way in both places.
    binary_mask = (mask >= threshold).astype(np.uint8) * 255
    # Save as PNG
    Image.fromarray(binary_mask).save(os.path.join(save_path, filename))

# Load best model for testing
print('#----------Testing----------#')
best_model_path = os.path.join(config.work_dir, 'checkpoints', 'best.pth')
if os.path.exists(best_model_path):
    best_weight = torch.load(best_model_path, map_location=device)
    model.load_state_dict(best_weight)
    model.eval()

    # Create directory for saving predictions
    prediction_dir = os.path.join(config.work_dir, 'test_predictions')
    os.makedirs(prediction_dir, exist_ok=True)

    # Generate and save predictions
    # Every sample of a batch is handled individually, so the loop no longer
    # depends on test_bs being 1 (previously only the first file name of a
    # batch was used and the whole batch was passed to the saver as one array).
    sample_idx = 0
    with torch.no_grad():
        for img, img_names in tqdm(test_loader):
            img = img.to(device).float()
            out = model(img)

            for sample, pred, img_name in zip(img, out, img_names):
                # Convert output to numpy
                # NOTE(review): assumes one output channel per sample, so that
                # squeeze() leaves a 2-D (H, W) array — confirm against the model.
                pred_mask = pred.squeeze().cpu().numpy()

                # Save prediction
                save_prediction_mask(pred_mask, img_name.replace('.jpg', '.png'), prediction_dir)

                # Optionally visualize some predictions
                if sample_idx % config.save_interval == 0:
                    plt.figure(figsize=(10, 5))
                    plt.subplot(1, 2, 1)
                    plt.imshow(sample.permute(1, 2, 0).cpu().numpy())
                    plt.title('Original Image')
                    plt.axis('off')

                    plt.subplot(1, 2, 2)
                    plt.imshow(pred_mask, cmap='gray')
                    plt.title('Predicted Mask')
                    plt.axis('off')

                    plt.savefig(os.path.join(config.work_dir, 'outputs', f'test_prediction_{sample_idx}.png'))
                    plt.close()

                sample_idx += 1

    print(f"Predictions saved to {prediction_dir}")
else:
    print("Best model not found. Cannot generate predictions.")

## Analyze Results

After generating predictions, you can analyze the results:

1. **Visualize some predictions**: Review the images saved in the outputs directory to see how well your model is performing.
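2. **Prepare Kaggle submission**: The predictions saved in the `test_predictions` folder are binary PNG masks (0/255) at the model input resolution (`config.input_size` × `config.input_size`), not at the original image size. Resize or re-encode them first if the Kaggle submission format requires it; otherwise they are ready to be submitted.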
3. **Evaluate model metrics**: Review the validation metrics from the training process to understand the model's performance.