# SAM2-UNet Training and Prediction

This notebook walks through:
1. Training the SAM2-UNet model for image segmentation
2. Evaluating the model on validation data
3. Making predictions on test data for Kaggle submission

In [None]:
# Import required libraries
import os
import argparse
import random
import numpy as np
import torch
import torch.optim as opt
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from PIL import Image
import shutil
import time

# Import our modules
from dataset import FullDataset, TestDataset
from SAM2UNet import SAM2UNet

## 1. Configuration Settings

Let's set up the configuration parameters for training.

In [None]:
# Configuration
class Config:
    """Central place for every tunable setting used by this notebook.

    All values are plain class attributes, so they can be read from the class
    itself or from an instance.
    """

    # --- Backbone weights ---
    # Change this to your sam2 pretrained hiera path
    hiera_path = "path/to/sam2/pretrained/hiera"

    # --- Dataset locations (update these to your data locations) ---
    train_image_path = "train/images/"
    train_mask_path = "train/masks/"
    val_image_path = "val_images/"
    val_mask_path = "val_masks/"
    test_image_path = "test/images/"

    # --- Where checkpoints and predicted masks are written ---
    save_path = "checkpoints/"
    prediction_output_path = "predictions/"

    # --- Optimisation hyper-parameters ---
    epoch = 20            # number of passes over the training loader
    lr = 0.001            # initial learning rate
    batch_size = 8
    weight_decay = 5e-4
    image_size = 352      # size handed to the dataset classes

    # --- Reproducibility ---
    seed = 1024

    # --- Compute device: CUDA when it is available, otherwise CPU ---
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

config = Config()

# Make sure both output folders are present before anything writes to them
for _output_dir in (config.save_path, config.prediction_output_path):
    os.makedirs(_output_dir, exist_ok=True)

## 2. Set Random Seeds for Reproducibility

In [None]:
def seed_torch(seed=1024):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and CUDA), exports the
    value as ``PYTHONHASHSEED``, and switches cuDNN to deterministic mode with
    benchmarking turned off.

    Args:
        seed: Integer seed applied to all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Apply the same seed to each generator in turn
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Deterministic kernels on, autotuner off
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Apply the configured seed up front, before datasets, loaders and the model are created
seed_torch(config.seed)

## 3. Define Loss Function

In [None]:
def structure_loss(pred, mask):
    """Weighted BCE plus weighted IoU loss for segmentation logits.

    Each pixel is weighted by how much the mask differs from its 31x31 local
    average, so pixels where the mask changes count more than flat regions.

    Args:
        pred: Raw logits; sums are taken over dims 2 and 3.
        mask: Target tensor with the same shape as ``pred``.

    Returns:
        Scalar tensor: mean over the remaining dims of (weighted BCE + weighted IoU).
    """
    spatial_dims = (2, 3)

    # Weight is 1 where mask equals its neighbourhood mean and grows with the difference
    local_mean = F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15)
    pixel_weight = 1 + 5 * torch.abs(local_mean - mask)

    # Weighted binary cross-entropy, normalised by the total weight
    bce_map = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
    weighted_bce = (pixel_weight * bce_map).sum(dim=spatial_dims) / pixel_weight.sum(dim=spatial_dims)

    # Weighted soft IoU on probabilities; the +1 terms smooth the ratio
    prob = torch.sigmoid(pred)
    overlap = ((prob * mask) * pixel_weight).sum(dim=spatial_dims)
    combined = ((prob + mask) * pixel_weight).sum(dim=spatial_dims)
    weighted_iou = 1 - (overlap + 1) / (combined - overlap + 1)

    return (weighted_bce + weighted_iou).mean()

## 4. Prepare Datasets and DataLoaders

In [None]:
# Build the training and validation datasets from the configured folders
train_dataset = FullDataset(
    config.train_image_path, config.train_mask_path, config.image_size, mode='train'
)
val_dataset = FullDataset(
    config.val_image_path, config.val_mask_path, config.image_size, mode='val'
)

# Options shared by both loaders; only shuffling differs between them
_loader_options = dict(batch_size=config.batch_size, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

## 5. Initialize Model, Optimizer, and Scheduler

In [None]:
# Build the network and place it on the configured device
model = SAM2UNet(config.hiera_path).to(config.device)

# AdamW over all model parameters in a single parameter group
_param_groups = [{"params": model.parameters(), "initial_lr": config.lr}]
optimizer = opt.AdamW(_param_groups, lr=config.lr, weight_decay=config.weight_decay)

# Cosine annealing over the full number of epochs, bottoming out at eta_min
scheduler = CosineAnnealingLR(optimizer, T_max=config.epoch, eta_min=1.0e-7)

## 6. Validation Function

In [None]:
def validate(model, val_loader, device):
    """Evaluate the model on the validation loader.

    The loss is the sum of ``structure_loss`` over the three model outputs,
    averaged over batches. Dice and IoU are computed per image on the first
    output thresholded at 0.5, and averaged over images (not over batches), so
    a smaller final batch does not get extra weight.

    Args:
        model: Network returning three logit tensors per forward pass.
        val_loader: Iterable of dict batches with 'image' and 'label' entries.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(val_loss, dice_score, iou_score)`` of Python floats.

    Raises:
        ValueError: If the loader yields no batches.
    """
    model.eval()
    eps = 1e-7
    loss_total = 0.0
    dice_total = 0.0
    iou_total = 0.0
    num_batches = 0
    num_samples = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validating"):
            x = batch['image'].to(device)
            target = batch['label'].to(device)

            # Forward pass
            pred0, pred1, pred2 = model(x)

            # Same three-output loss as used for training
            loss = (
                structure_loss(pred0, target)
                + structure_loss(pred1, target)
                + structure_loss(pred2, target)
            )
            loss_total += loss.item()
            num_batches += 1

            # Metrics use the first output only
            pred = torch.sigmoid(pred0) > 0.5
            target_binary = target > 0.5

            intersection = (pred & target_binary).float().sum((1, 2, 3))
            total = pred.float().sum((1, 2, 3)) + target_binary.float().sum((1, 2, 3))

            # eps in numerator and denominator: an empty prediction for an empty
            # mask scores 1 rather than 0, and the ratio stays finite
            dice = (2 * intersection + eps) / (total + eps)
            iou = (intersection + eps) / (total - intersection + eps)

            # Accumulate per-image sums so the final mean is over images
            dice_total += dice.sum().item()
            iou_total += iou.sum().item()
            num_samples += dice.numel()

    if num_batches == 0 or num_samples == 0:
        raise ValueError("Validation loader yielded no samples; check the validation data paths.")

    val_loss = loss_total / num_batches
    dice_score = dice_total / num_samples
    iou_score = iou_total / num_samples

    return val_loss, dice_score, iou_score

## 7. Training and Validation Loop

In [None]:
# Per-epoch history, consumed by the plotting cell
train_losses = []
val_losses = []
dice_scores = []
iou_scores = []
# Start below any reachable Dice value so the first validated epoch always
# writes 'SAM2-UNet-best.pth'; with 0.0 and a strict '>' comparison no best
# checkpoint would exist if Dice never rose above zero.
best_dice = float("-inf")

# Training loop
for epoch in range(config.epoch):
    # Training phase
    model.train()
    epoch_loss = 0.0

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.epoch}")
    for batch in progress_bar:
        x = batch['image'].to(config.device)
        target = batch['label'].to(config.device)

        # Zero gradients
        optimizer.zero_grad()

        # Forward pass
        pred0, pred1, pred2 = model(x)

        # Loss is summed over the three model outputs
        loss0 = structure_loss(pred0, target)
        loss1 = structure_loss(pred1, target)
        loss2 = structure_loss(pred2, target)
        loss = loss0 + loss1 + loss2

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for the running sum and the progress bar
        batch_loss = loss.item()
        epoch_loss += batch_loss
        progress_bar.set_postfix(loss=f"{batch_loss:.4f}")

    # Calculate average training loss (mean over batches)
    epoch_loss /= len(train_loader)
    train_losses.append(epoch_loss)

    # Validation phase
    val_loss, dice_score, iou_score = validate(model, val_loader, config.device)
    val_losses.append(val_loss)
    dice_scores.append(dice_score)
    iou_scores.append(iou_score)

    # Update learning rate once per epoch
    scheduler.step()

    # Print epoch results
    print(f"Epoch {epoch+1}/{config.epoch}")
    print(f"Training Loss: {epoch_loss:.4f}")
    print(f"Validation Loss: {val_loss:.4f}, Dice Score: {dice_score:.4f}, IoU: {iou_score:.4f}")

    # Save model if it has the best dice score so far
    if dice_score > best_dice:
        best_dice = dice_score
        torch.save(model.state_dict(), os.path.join(config.save_path, 'SAM2-UNet-best.pth'))
        print(f"Saved new best model with Dice score: {best_dice:.4f}")

    # Save checkpoint every 5 epochs or at the last epoch
    if (epoch+1) % 5 == 0 or (epoch+1) == config.epoch:
        torch.save(
            model.state_dict(),
            os.path.join(config.save_path, f'SAM2-UNet-{epoch+1}.pth')
        )
        print(f"Saved checkpoint: SAM2-UNet-{epoch+1}.pth")

## 8. Plot Training and Validation Metrics

In [None]:
# One row of three panels: losses, validation Dice, validation IoU
fig, (loss_ax, dice_ax, iou_ax) = plt.subplots(1, 3, figsize=(12, 4))

# Left panel: training vs. validation loss per epoch
loss_ax.plot(train_losses, label='Train Loss')
loss_ax.plot(val_losses, label='Val Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training and Validation Loss')
loss_ax.legend()

# Middle panel: validation Dice per epoch
dice_ax.plot(dice_scores, label='Dice Score')
dice_ax.set_xlabel('Epoch')
dice_ax.set_ylabel('Dice Score')
dice_ax.set_title('Validation Dice Score')

# Right panel: validation IoU per epoch
iou_ax.plot(iou_scores, label='IoU Score')
iou_ax.set_xlabel('Epoch')
iou_ax.set_ylabel('IoU')
iou_ax.set_title('Validation IoU Score')

# Keep a copy of the figure next to the checkpoints, then display it
fig.tight_layout()
fig.savefig(os.path.join(config.save_path, 'training_metrics.png'))
plt.show()

## 9. Make Predictions on Test Data

In [None]:
def predict_test_set():
    """Run the best checkpoint over the test set and save one PNG mask per image.

    Loads 'SAM2-UNet-best.pth' from ``config.save_path`` into the global
    ``model``, iterates the ``TestDataset`` built from
    ``config.test_image_path``, and writes each min-max normalised probability
    map as an 8-bit PNG into ``config.prediction_output_path``.
    """
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine
    checkpoint_path = os.path.join(config.save_path, 'SAM2-UNet-best.pth')
    state_dict = torch.load(checkpoint_path, map_location=config.device)
    model.load_state_dict(state_dict)
    model.eval()

    # Create test dataset loader
    test_loader = TestDataset(
        config.test_image_path,
        config.image_size
    )

    # Create prediction directory
    prediction_dir = config.prediction_output_path
    os.makedirs(prediction_dir, exist_ok=True)

    # Make predictions
    print(f"Making predictions on {test_loader.size} test images...")

    with torch.no_grad():
        for _ in tqdm(range(test_loader.size)):
            # Second returned value is unused here (test set has no ground truth)
            image, _unused, name = test_loader.load_data()
            image = image.to(config.device)

            # Forward pass; only the first output is used for the mask
            res, _, _ = model(image)

            # Convert logits to probabilities.
            # NOTE(review): nothing here resizes the map back to the source image
            # size; masks are saved at the model's output resolution. Confirm
            # that this matches what the competition expects.
            res = torch.sigmoid(res)

            # Convert to numpy, stretch to [0, 1] per image, then scale to [0, 255]
            res = res.detach().cpu().numpy().squeeze()
            res = (res - res.min()) / (res.max() - res.min() + 1e-8)
            res = (res * 255).astype(np.uint8)

            # splitext keeps the full stem whatever the extension length is
            stem = os.path.splitext(name)[0]
            Image.fromarray(res).save(os.path.join(prediction_dir, stem + ".png"))

    print(f"Predictions saved to {prediction_dir}")

# Run test predictions with the best saved checkpoint; PNG masks are written to config.prediction_output_path
predict_test_set()

## 10. Visualize Some Test Predictions

In [None]:
def visualize_predictions():
    """Show up to five test images next to their predicted masks.

    Reads the first five (sorted) '.jpg'/'.png' files from
    ``config.test_image_path``, pairs each with the same-named PNG in
    ``config.prediction_output_path`` (an all-zero mask is shown when that file
    is missing), and saves the grid as 'sample_predictions.png' in the
    prediction folder before displaying it.
    """
    max_samples = 5
    prediction_dir = config.prediction_output_path
    test_images = sorted(
        f for f in os.listdir(config.test_image_path)
        if f.endswith(('.jpg', '.png'))
    )[:max_samples]

    plt.figure(figsize=(15, 10))

    for i, img_name in enumerate(test_images):
        stem = os.path.splitext(img_name)[0]

        # Original image; the context manager releases the file handle promptly
        img_path = os.path.join(config.test_image_path, img_name)
        with Image.open(img_path) as img_file:
            image = np.array(img_file)

        # Prediction mask, or a blank placeholder when none was written
        pred_path = os.path.join(prediction_dir, stem + ".png")
        if os.path.exists(pred_path):
            with Image.open(pred_path) as pred_file:
                pred = np.array(pred_file)
        else:
            # shape[:2] works for grayscale (H, W) and colour (H, W, C) arrays,
            # whereas indexing a channel fails on 2-D images
            pred = np.zeros(image.shape[:2], dtype=image.dtype)

        # Plot original image (grayscale inputs get a gray colormap)
        plt.subplot(max_samples, 2, i*2+1)
        if image.ndim == 2:
            plt.imshow(image, cmap='gray')
        else:
            plt.imshow(image)
        plt.title(f"Original: {img_name}")
        plt.axis('off')

        # Plot prediction mask
        plt.subplot(max_samples, 2, i*2+2)
        plt.imshow(pred, cmap='gray')
        plt.title(f"Prediction: {stem}.png")
        plt.axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(prediction_dir, 'sample_predictions.png'))
    plt.show()

# Visualize the first few test images beside their predicted masks (also saves the figure to disk)
visualize_predictions()

## 11. Prepare Submission for Kaggle (if needed)

In [None]:
def prepare_kaggle_submission():
    """
    Create a submission folder and zip archive for Kaggle from the predicted masks.

    Copies every prediction PNG from ``config.prediction_output_path`` into
    'submission_kaggle/' and archives that folder as 'submission_kaggle.zip'.
    The preview figure that visualize_predictions() writes into the prediction
    folder is skipped, and PNGs left in the staging folder by an earlier run
    are removed, so the archive contains only the current masks.
    Adapt the layout to the competition requirements if they differ.
    """
    prediction_dir = config.prediction_output_path
    submission_path = 'submission_kaggle'
    # Files that live in the prediction folder but are not masks.
    # NOTE(review): a test image literally named 'sample_predictions' would be skipped too.
    non_prediction_files = {'sample_predictions.png'}

    # Create submission directory if it doesn't exist
    os.makedirs(submission_path, exist_ok=True)

    # Sorted for a stable copy order
    prediction_files = sorted(
        f for f in os.listdir(prediction_dir)
        if f.endswith('.png') and f not in non_prediction_files
    )
    current = set(prediction_files)

    # Remove stale PNGs from previous runs; non-PNG files are left alone
    for existing in os.listdir(submission_path):
        if existing.endswith('.png') and existing not in current:
            os.remove(os.path.join(submission_path, existing))

    for file in tqdm(prediction_files, desc="Preparing submission"):
        shutil.copy(
            os.path.join(prediction_dir, file),
            os.path.join(submission_path, file)
        )

    # Create a zip file if needed
    shutil.make_archive('submission_kaggle', 'zip', submission_path)

    print(f"Submission prepared: {len(prediction_files)} files")
    print(f"Submission files in: {submission_path}/")
    print(f"Submission zip: submission_kaggle.zip")

# Prepare submission: copy the prediction PNGs into submission_kaggle/ and zip them
prepare_kaggle_submission()

## 12. Summary and Conclusion

With all cells executed, this notebook has:
1. Trained the SAM2-UNet model on the provided training data
2. Validated it on the validation data after every epoch
3. Generated predictions on the test data
4. Saved the predictions and copied them into a new folder for Kaggle submission

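Key outputs from the run (a Markdown cell cannot interpolate Python variables, so read the actual values from the cell outputs above):
- Best validation Dice score: the last "Saved new best model with Dice score" line printed in Section 7 (also held in `best_dice`)
- Best model saved to: `SAM2-UNet-best.pth` inside `config.save_path` (default `checkpoints/`)
- Test predictions saved to: `config.prediction_output_path` (default `predictions/`)
- Kaggle submission prepared in: `submission_kaggle.zip`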