In [None]:
import torch.optim as optim
from tqdm import tqdm
from segment_anything import sam_model_registry

def train_lora_sam(config):
    """Fine-tune a LoRA-adapted SAM model and checkpoint its best epoch.

    The model named by ``config["model_type"]`` is built from
    ``config["checkpoint"]``, passed through ``add_lora_to_sam`` and trained
    with Adam and ``BCEWithLogitsLoss`` for ``config["epochs"]`` epochs.
    Batches are read from the module-level ``train_loader`` and ``val_loader``,
    which must already exist when this function is called.

    Args:
        config: Mapping providing ``model_type``, ``checkpoint``, ``rank``,
            ``device``, ``lr``, ``weight_decay``, ``epochs`` and ``save_path``.

    Returns:
        The model as it stands after the last epoch. Whenever the mean
        validation loss improves, the model and optimizer state dicts plus
        that loss are written to ``config["save_path"]``.
    """
    # Build SAM, attach the LoRA layers, then move everything to the device
    build_sam = sam_model_registry[config["model_type"]]
    sam = add_lora_to_sam(build_sam(checkpoint=config["checkpoint"]), rank=config["rank"])
    device = config["device"]
    sam.to(device)

    # Hand the optimizer only the parameters that still require gradients
    trainable_params = [param for param in sam.parameters() if param.requires_grad]
    optimizer = optim.Adam(
        trainable_params,
        lr=config["lr"],
        weight_decay=config["weight_decay"],
    )

    # Loss function - adjust based on your task
    criterion = nn.BCEWithLogitsLoss()

    lowest_val_loss = float('inf')
    for epoch in range(config["epochs"]):
        epoch_number = epoch + 1

        # ---- training pass ----
        sam.train()
        running_train_loss = 0.0
        for images, masks in tqdm(train_loader, desc=f"Epoch {epoch_number}"):
            images = images.to(device)
            masks = masks.to(device)

            predictions = sam(images, multimask_output=False)
            batch_loss = criterion(predictions['masks'], masks)

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            running_train_loss += batch_loss.item()

        # ---- validation pass (no gradients) ----
        sam.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for images, masks in val_loader:
                images = images.to(device)
                masks = masks.to(device)
                predictions = sam(images, multimask_output=False)
                running_val_loss += criterion(predictions['masks'], masks).item()

        # Per-batch means for reporting and model selection
        train_loss = running_train_loss / len(train_loader)
        val_loss = running_val_loss / len(val_loader)
        print(f"Epoch {epoch_number}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        # Checkpoint whenever validation improves
        if val_loss < lowest_val_loss:
            lowest_val_loss = val_loss
            snapshot = {
                'model_state_dict': sam.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
            }
            torch.save(snapshot, config["save_path"])

    return sam

# Hyperparameters and file locations for the single-split training run
config = dict(
    model_type="vit_b",  # or "vit_l", "vit_h" depending on your SAM version
    checkpoint="sam_vit_b_01ec64.pth",
    rank=8,
    lr=1e-4,
    weight_decay=1e-4,
    epochs=20,
    device="cuda" if torch.cuda.is_available() else "cpu",
    save_path="best_lora_sam.pth",
)

# Run training; the returned model holds the weights from the last epoch
trained_sam = train_lora_sam(config)

In [None]:
# Inference

def predict_with_lora_sam(model, image, device=None):
    """Run the model on a single image and return a binarised mask array.

    Args:
        model: Module called as ``model(tensor, multimask_output=False)``; the
            result is indexed with ``'masks'`` and treated as logits.
        image: Raw image passed to ``preprocess_image``.
        device: Device the preprocessed tensor is moved to. When ``None``
            (the default) the device of the model's first parameter is used,
            so the call also works on CPU-only machines; a model without
            parameters falls back to the CPU.

    Returns:
        NumPy array of 0.0/1.0 values: the sigmoid of the mask logits
        thresholded at 0.5.
    """
    if device is None:
        # Follow the model instead of assuming a GPU is present
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    with torch.no_grad():
        # Preprocess your image (resize, normalize, etc.)
        image_tensor = preprocess_image(image).to(device)

        # Get predictions
        outputs = model(image_tensor, multimask_output=False)

        # Post-process masks: logits -> probabilities -> hard 0/1 mask
        masks = torch.sigmoid(outputs['masks'])
        masks = (masks > 0.5).float()

    return masks.cpu().numpy()

# Load saved model
# The checkpoint is restored into `trained_sam`, the model returned by
# train_lora_sam (no notebook-level `sam` is defined by the training cell).
# map_location lets a checkpoint written on one device load on another.
checkpoint = torch.load("best_lora_sam.pth", map_location=config["device"])
trained_sam.load_state_dict(checkpoint['model_state_dict'])

# Example prediction
test_image = load_your_test_image()  # placeholder: supply your own image loader
predicted_mask = predict_with_lora_sam(trained_sam, test_image, device=config["device"])

In [None]:
# Evaluation

def calculate_iou(preds, targets):
    """Intersection-over-union of two binary masks.

    Args:
        preds: Predicted mask tensor of any dtype; non-zero entries count as
            foreground.
        targets: Ground-truth mask tensor, same convention as ``preds``.

    Returns:
        0-dim float tensor ``(intersection + 1e-6) / (union + 1e-6)``. The
        epsilon keeps the ratio defined (and equal to 1.0) when both masks
        are empty.
    """
    # `&` and `|` are not implemented for floating-point tensors, and the
    # callers in this notebook pass `.float()` predictions, so reduce both
    # inputs to bool before combining them.
    preds = preds.bool()
    targets = targets.bool()
    intersection = (preds & targets).float().sum()
    union = (preds | targets).float().sum()
    return (intersection + 1e-6) / (union + 1e-6)

def evaluate(model, dataloader, device):
    """Mean per-batch IoU of ``model`` over ``dataloader``.

    Args:
        model: Module called as ``model(images, multimask_output=False)``; the
            result is indexed with ``'masks'`` and treated as logits.
        dataloader: Iterable of ``(images, masks)`` batches that supports
            ``len()``.
        device: Device both tensors of each batch are moved to.

    Returns:
        Python float: the sum of the per-batch IoU values divided by the
        number of batches.
    """
    model.eval()
    total_iou = 0.0
    with torch.no_grad():
        for images, masks in dataloader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images, multimask_output=False)
            # Keep the prediction boolean: the bitwise ops used for IoU do not
            # accept float tensors.
            preds = torch.sigmoid(outputs['masks']) > 0.5
            # float() detaches the score from the device so the running total
            # is a plain number rather than a tensor.
            total_iou += float(calculate_iou(preds, masks.bool()))
    return total_iou / len(dataloader)

# Cross Validation

In [None]:
# Setup

from sklearn.model_selection import KFold
import numpy as np
from torch.utils.data import Subset

def prepare_kfold_data(image_paths, mask_paths, n_splits=5):
    """Prepare KFold splits while maintaining image-mask pairs.

    Args:
        image_paths: Sequence of image paths.
        mask_paths: Sequence of mask paths; entry ``i`` must belong to
            ``image_paths[i]``.
        n_splits: Number of folds handed to ``KFold``.

    Returns:
        List with one dict per fold. Each dict has the keys ``'train'`` and
        ``'val'``, each mapping to an ``(image_paths, mask_paths)`` tuple of
        lists for that split.

    Raises:
        ValueError: If the two sequences differ in length, since the pairing
            by position would then be wrong.
    """
    if len(image_paths) != len(mask_paths):
        raise ValueError(
            "image_paths and mask_paths must have the same length, "
            f"got {len(image_paths)} and {len(mask_paths)}"
        )

    # Shuffled with a fixed seed so the folds are the same on every run
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)

    # Convert once; the arrays are only needed for fancy indexing per fold
    images = np.array(image_paths)
    masks = np.array(mask_paths)
    indices = np.arange(len(image_paths))

    folds = []
    for train_idx, val_idx in kf.split(indices):
        folds.append({
            'train': (images[train_idx].tolist(), masks[train_idx].tolist()),
            'val': (images[val_idx].tolist(), masks[val_idx].tolist()),
        })

    return folds

In [None]:
# Cross Validation Training Loop

from copy import deepcopy
from segment_anything import sam_model_registry

def cross_validate_sam_lora(config, image_paths, mask_paths):
    """Train one LoRA-adapted SAM model per cross-validation fold.

    Args:
        config: Dict providing ``model_type``, ``checkpoint``, ``rank``,
            ``device`` and ``batch_size`` here, plus whatever
            ``train_single_fold`` reads. An optional ``n_splits`` entry sets
            the number of folds (default 5).
        image_paths: Sequence of image paths.
        mask_paths: Sequence of mask paths aligned with ``image_paths``.

    Returns:
        ``(fold_results, best_models)``: the metrics dict returned by
        ``train_single_fold`` for every fold, and a deep copy of the state
        dict of the model it returned.
    """
    folds = prepare_kfold_data(image_paths, mask_paths, n_splits=config.get("n_splits", 5))
    n_folds = len(folds)
    fold_results = []
    best_models = []

    for fold_idx, fold in enumerate(folds):
        # Report the real fold count instead of a hard-coded 5
        print(f"\n=== Processing Fold {fold_idx + 1}/{n_folds} ===")

        # Initialize fresh model for each fold
        sam = sam_model_registry[config["model_type"]](checkpoint=config["checkpoint"])
        sam = add_lora_to_sam(sam, rank=config["rank"])
        sam.to(config["device"])

        # Create datasets
        train_dataset = SAMDataset(fold['train'][0], fold['train'][1])
        val_dataset = SAMDataset(fold['val'][0], fold['val'][1])

        train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=config["batch_size"])

        # Train the model
        model, fold_metrics = train_single_fold(
            sam=sam,
            train_loader=train_loader,
            val_loader=val_loader,
            config=config,
            fold_idx=fold_idx
        )

        # Store results
        fold_results.append(fold_metrics)
        best_models.append(deepcopy(model.state_dict()))

        # Clean up. `model` is the object train_single_fold returned (the same
        # module as `sam`), so both names must go before the cache can shrink.
        del sam, model
        torch.cuda.empty_cache()

    return fold_results, best_models

def train_single_fold(sam, train_loader, val_loader, config, fold_idx):
    """Train ``sam`` on one fold and return it with its best-epoch weights.

    Args:
        sam: Model to train; only parameters with ``requires_grad`` set are
            given to the optimizer.
        train_loader: Iterable of ``(images, masks)`` training batches that
            supports ``len()``.
        val_loader: Validation batches, passed to ``evaluate_fold``.
        config: Mapping providing ``lr``, ``weight_decay``, ``epochs`` and
            ``device``.
        fold_idx: Zero-based fold index, used only in the progress message.

    Returns:
        ``(sam, fold_metrics)``. ``fold_metrics`` maps ``'train_loss'``,
        ``'val_loss'`` and ``'val_iou'`` to one float per epoch. ``sam``
        carries the weights of the epoch with the lowest validation loss; if
        no epoch improved on infinity (for example zero epochs) it is
        returned as trained.
    """
    device = config["device"]
    optimizer = optim.Adam(
        [p for p in sam.parameters() if p.requires_grad],
        lr=config["lr"],
        weight_decay=config["weight_decay"]
    )

    criterion = nn.BCEWithLogitsLoss()
    # Lower the learning rate when the validation loss stops improving
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)
    best_val_loss = float('inf')
    best_model_state = None
    fold_metrics = {'train_loss': [], 'val_loss': [], 'val_iou': []}

    for epoch in range(config["epochs"]):
        # Training phase
        sam.train()
        epoch_train_loss = 0.0
        for images, masks in train_loader:
            images, masks = images.to(device), masks.to(device)

            optimizer.zero_grad()
            outputs = sam(images, multimask_output=False)
            loss = criterion(outputs['masks'], masks)
            loss.backward()
            optimizer.step()

            epoch_train_loss += loss.item()

        # Validation phase
        val_loss, val_iou = evaluate_fold(sam, val_loader, device)
        # Store a plain number even if the IoU arrives as a 0-dim tensor, so
        # the history can be used by numpy/pandas/matplotlib later.
        val_iou = float(val_iou)

        # Update metrics
        epoch_train_loss /= len(train_loader)
        fold_metrics['train_loss'].append(epoch_train_loss)
        fold_metrics['val_loss'].append(val_loss)
        fold_metrics['val_iou'].append(val_iou)

        # Update scheduler
        scheduler.step(val_loss)

        print(f"Fold {fold_idx+1} | Epoch {epoch+1}: "
              f"Train Loss: {epoch_train_loss:.4f}, "
              f"Val Loss: {val_loss:.4f}, "
              f"Val IoU: {val_iou:.4f}")

        # Snapshot the weights of the best epoch so far
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_state = deepcopy(sam.state_dict())

    # Without this the snapshot is never used and the caller would receive
    # the last-epoch weights while treating them as the best ones.
    if best_model_state is not None:
        sam.load_state_dict(best_model_state)

    return sam, fold_metrics

def evaluate_fold(model, dataloader, device):
    """Mean validation loss and mean per-batch IoU of ``model``.

    Args:
        model: Module called as ``model(images, multimask_output=False)``; the
            result is indexed with ``'masks'`` and treated as logits.
        dataloader: Iterable of ``(images, masks)`` batches that supports
            ``len()``.
        device: Device both tensors of each batch are moved to.

    Returns:
        ``(mean_loss, mean_iou)`` as Python floats, each averaged over the
        number of batches. The loss is ``BCEWithLogitsLoss``.
    """
    model.eval()
    total_loss = 0.0
    total_iou = 0.0
    criterion = nn.BCEWithLogitsLoss()

    with torch.no_grad():
        for images, masks in dataloader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images, multimask_output=False)

            # Calculate loss
            loss = criterion(outputs['masks'], masks)
            total_loss += loss.item()

            # Calculate IoU on boolean masks: the bitwise ops used for IoU do
            # not accept float tensors.
            preds = torch.sigmoid(outputs['masks']) > 0.5
            batch_iou = calculate_iou(preds, masks.bool())
            # Accumulate as a plain float so the result is not a device tensor
            total_iou += float(batch_iou)

    return total_loss / len(dataloader), total_iou / len(dataloader)

In [None]:
# Running Cross Validation

# Hyperparameters for the cross-validation run (no save_path is set here)
config = dict(
    model_type="vit_b",
    checkpoint="sam_vit_b_01ec64.pth",
    rank=8,
    lr=1e-4,
    weight_decay=1e-4,
    epochs=10,
    batch_size=4,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# Example image and mask paths (replace with your actual paths).
# The trailing `...` is a literal Ellipsis placeholder, not a path.
image_paths = ["image1.npy", "image2.npy", ...]  
mask_paths = ["mask1.npy", "mask2.npy", ...]

# Train one model per fold and collect the metrics and state dicts
fold_results, best_models = cross_validate_sam_lora(config, image_paths, mask_paths)

In [None]:
# Analyzing Results

import pandas as pd
import matplotlib.pyplot as plt

def analyze_cv_results(fold_results):
    """Summarise cross-validation metrics and plot the learning curves.

    Args:
        fold_results: One dict per fold with the keys ``'train_loss'``,
            ``'val_loss'`` and ``'val_iou'``, each a per-epoch sequence of
            numbers or 0-dim tensors.

    Returns:
        DataFrame with one row per fold: the fold number, the train loss,
        validation loss and validation IoU at the epoch with the lowest
        validation loss, and the last-epoch validation IoU.
    """
    def _to_floats(values):
        # float() accepts Python numbers and 0-dim tensors alike; pandas
        # statistics and matplotlib need plain numbers.
        return [float(v) for v in values]

    histories = [
        {key: _to_floats(fold[key]) for key in ('train_loss', 'val_loss', 'val_iou')}
        for fold in fold_results
    ]

    # Convert to DataFrame for analysis
    metrics = []
    for i, fold in enumerate(histories):
        best_epoch = int(np.argmin(fold['val_loss']))
        metrics.append({
            'fold': i+1,
            'best_train_loss': fold['train_loss'][best_epoch],
            'best_val_loss': fold['val_loss'][best_epoch],
            'best_val_iou': fold['val_iou'][best_epoch],
            'final_val_iou': fold['val_iou'][-1]
        })

    df = pd.DataFrame(metrics)

    # Print summary statistics
    print("\n=== Cross-Validation Results ===")
    print(df)
    print("\nMean Validation IoU:", df['best_val_iou'].mean())
    print("Std Dev Validation IoU:", df['best_val_iou'].std())

    # Plot learning curves: training loss on the left, validation IoU on the right
    plt.figure(figsize=(12, 4))
    for i, fold in enumerate(histories):
        plt.subplot(1, 2, 1)
        plt.plot(fold['train_loss'], label=f'Fold {i+1}')
        plt.subplot(1, 2, 2)
        plt.plot(fold['val_iou'], label=f'Fold {i+1}')

    plt.subplot(1, 2, 1)
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.title('Validation IoU')
    plt.xlabel('Epoch')
    plt.legend()

    plt.tight_layout()
    plt.show()

    return df

# Summarise the cross-validation metrics in fold_results and plot the learning curves
results_df = analyze_cv_results(fold_results)

In [None]:
# Final Model Selection

# Option 1: Select best fold model
# Folds are ranked by their last-epoch validation IoU. float() turns 0-dim
# tensors (possibly on the GPU) into plain numbers so np.argmax can compare them.
best_fold_idx = int(np.argmax([float(res['val_iou'][-1]) for res in fold_results]))
best_model_state = best_models[best_fold_idx]

# Option 2: Create ensemble of all folds
class EnsembleSAM(nn.Module):
    """Ensemble that averages the mask outputs of several SAM models.

    Args:
        model_states: Iterable of state dicts, one per ensemble member.
        config: Mapping providing ``model_type``, ``rank`` and ``device``.
    """

    def __init__(self, model_states, config):
        super().__init__()
        members = []
        for state in model_states:
            model = sam_model_registry[config["model_type"]](checkpoint=None)
            # The state dicts stored by cross_validate_sam_lora come from
            # models that went through add_lora_to_sam, so rebuild the same
            # structure before loading (presumably it adds LoRA parameters
            # that a plain SAM does not have - confirm against its definition).
            model = add_lora_to_sam(model, rank=config["rank"])
            model.load_state_dict(state)
            members.append(model.to(config["device"]))
        # ModuleList registers the members as submodules, so .to(), .eval(),
        # .parameters() and .state_dict() on the ensemble reach them; a plain
        # Python list would hide them from nn.Module.
        self.models = nn.ModuleList(members)

    def forward(self, x):
        """Return the element-wise mean of every member's ``'masks'`` output."""
        outputs = []
        for model in self.models:
            outputs.append(model(x, multimask_output=False)['masks'])
        return torch.mean(torch.stack(outputs), dim=0)

# Initialize ensemble
# One member model is built for each state dict in best_models
ensemble_model = EnsembleSAM(best_models, config)