In [None]:
# Kaggle's default starter cell. The kaggle/python Docker image
# (https://github.com/kaggle/docker-python) ships with numpy, pandas and many
# other analytics libraries preinstalled.
import numpy as np  # linear algebra
import pandas as pd  # data processing, CSV file I/O (e.g. pd.read_csv)

# Input datasets are mounted read-only under /kaggle/input. Walk that tree and
# print the full path of every file so the available inputs are visible.
import os

for root, _subdirs, files in os.walk('/kaggle/input'):
    for name in files:
        print(os.path.join(root, name))

# Files written to /kaggle/working/ (up to 20GB) are kept as output when a
# version is created with "Save & Run All"; /kaggle/temp/ is scratch space that
# is not saved outside of the current session.

In [None]:
# ============================================
# CELL 1: SETUP & INSTALL (Kaggle Version)
# ============================================
import os
import sys

# Install required packages
!pip install roboflow -q
!pip install segmentation-models-pytorch -q
!pip install albumentations -q
!pip install gradio -q
!pip install pycocotools -q
!pip install torchmetrics -q
!pip install pandas tabulate -q

print("‚úÖ Kaggle setup complete!")

In [None]:
# ============================================
# CELL 2: DOWNLOAD DATASET
# ============================================
import getpass
import os

from roboflow import Roboflow

# Never hardcode the API key: it would be published with every saved notebook
# version. Read it from the ROBOFLOW_API_KEY environment variable (e.g. populated
# from a Kaggle secret) and fall back to an interactive prompt.
# NOTE(review): the key that was previously committed in this cell should be
# revoked/rotated in the Roboflow dashboard.
roboflow_api_key = os.environ.get("ROBOFLOW_API_KEY") or getpass.getpass("Roboflow API key: ")

rf = Roboflow(api_key=roboflow_api_key)
project = rf.workspace("studentdatasets").project("microscopy-cell-segmentation")
version = project.version(21)
dataset = version.download("coco-segmentation")

print("‚úÖ Dataset downloaded!")
dataset_path = dataset.location

In [None]:
# ============================================
# CELL 3: IMPORTS & GPU SETUP
# ============================================
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import cv2
import matplotlib.pyplot as plt
import json
import os
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
import pandas as pd
import torch.nn.functional as F
from torchmetrics.classification import BinaryJaccardIndex, BinaryF1Score
import warnings
warnings.filterwarnings('ignore')

%matplotlib inline

# Check GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"üöÄ Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    
# Create output directory
output_dir = '/kaggle/working/cell_segmentation_v1'
os.makedirs(output_dir, exist_ok=True)
print(f"üìÅ Output directory: {output_dir}")

In [None]:
# ============================================
# CELL 4: IMPROVED DATASET CLASS
# ============================================
class CellSegmentationDataset(Dataset):
    """Binary cell-segmentation dataset read from a COCO-format annotation file.

    Every annotated polygon of an image (regardless of category) is painted
    into one foreground mask. ``__getitem__`` returns ``(image, mask)`` after
    ``self.transform``: a resize to ``img_size`` x ``img_size``, optional
    augmentation, ImageNet normalization and ``ToTensorV2``.

    Args:
        json_path: Path to the COCO ``_annotations.coco.json`` file.
        img_dir: Directory holding the image files listed in ``json_path``.
        img_size: Side length of the square output image and mask.
        augment: If True, add random flips, 90-degree rotations,
            brightness/contrast, blur, noise, elastic deformation and coarse
            dropout before normalization.
    """
    def __init__(self, json_path, img_dir, img_size=512, augment=True):
        with open(json_path) as f:
            data = json.load(f)

        self.images = data['images']
        self.annotations = data['annotations']
        self.img_dir = img_dir
        self.img_size = img_size
        self.augment = augment

        # image_id -> list of its annotations, built once so each lookup is a dict hit.
        self.ann_map = {}
        for ann in self.annotations:
            self.ann_map.setdefault(ann['image_id'], []).append(ann)

        self.image_paths = [os.path.join(img_dir, img['file_name']) for img in self.images]

        # NOTE(review): the argument names below (always_apply, var_limit,
        # alpha_affine, max_holes/max_height/max_width/fill_value) follow the
        # older albumentations API; newer releases rename or drop some of them.
        if augment:
            self.transform = A.Compose([
                A.Resize(img_size, img_size, always_apply=True),
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.5),
                A.RandomRotate90(p=0.5),
                A.RandomBrightnessContrast(p=0.3, brightness_limit=0.1, contrast_limit=0.1),
                A.GaussianBlur(p=0.1, blur_limit=(3, 7)),
                A.GaussNoise(p=0.1, var_limit=(10.0, 50.0)),
                A.ElasticTransform(p=0.2, alpha=1, sigma=50, alpha_affine=50),
                A.CoarseDropout(p=0.1, max_holes=8, max_height=32, max_width=32, fill_value=0),
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                ToTensorV2(),
            ])
        else:
            self.transform = A.Compose([
                A.Resize(img_size, img_size, always_apply=True),
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                ToTensorV2(),
            ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Load image ``idx`` and rasterize its polygon annotations into a mask.

        Returns:
            (image_tensor, mask_tensor), the mask holding float 0/1 values.

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        img_path = self.image_paths[idx]
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread returns None (rather than raising) for missing/unreadable files.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_info = self.images[idx]

        # Rasterize the mask at the loaded image's own resolution so image and
        # mask share one spatial shape; the transform's Resize (and any
        # geometric augmentation) then acts on both identically. A mask drawn
        # directly at img_size x img_size would not match a full-size image.
        height, width = img.shape[:2]
        mask = np.zeros((height, width), dtype=np.float32)

        # Map annotation coordinates (given in the JSON width/height frame)
        # onto the pixels actually loaded; 1.0 when the two agree.
        scale_x = width / img_info.get('width', width)
        scale_y = height / img_info.get('height', height)

        for ann in self.ann_map.get(img_info['id'], []):
            segmentation = ann['segmentation']
            # Polygon annotations are lists of flat [x1, y1, x2, y2, ...] lists.
            # Non-list (e.g. RLE dict) segmentations are not handled here.
            if not isinstance(segmentation, list):
                continue
            for seg in segmentation:
                pts = np.asarray(seg, dtype=np.float64).reshape(-1, 2)
                if len(pts) == 0:
                    continue
                pts[:, 0] *= scale_x
                pts[:, 1] *= scale_y
                # Round rather than truncate to the nearest pixel.
                cv2.fillPoly(mask, [np.round(pts).astype(np.int32)], 1)

        transformed = self.transform(image=img, mask=mask)
        img_tensor = transformed['image']
        mask_tensor = transformed['mask']

        return img_tensor, mask_tensor.float()

# Build one dataset per Roboflow split (train / valid / test). Each split
# directory holds its images plus a COCO `_annotations.coco.json`; only the
# training split is augmented.
print("üìä Creating datasets...")
split_datasets = {
    split: CellSegmentationDataset(
        os.path.join(dataset_path, split, "_annotations.coco.json"),
        os.path.join(dataset_path, split),
        augment=(split == "train"),
    )
    for split in ("train", "valid", "test")
}
train_dataset = split_datasets["train"]
val_dataset = split_datasets["valid"]
test_dataset = split_datasets["test"]

print(f"‚úÖ Datasets created!")
for split_label, split_ds in (("Train", train_dataset), ("Validation", val_dataset), ("Test", test_dataset)):
    print(f"{split_label}: {len(split_ds)} images")

In [None]:
# ============================================
# CELL 5: CREATE 2 MODELS (Version 1)
# ============================================
print("üß† CREATING 2 MODELS FOR VERSION 1...")
print("="*50)

# Settings shared by both architectures: ImageNet-pretrained encoder, 3-channel
# input and one output channel of raw logits (activation=None).
shared_kwargs = dict(encoder_weights="imagenet", in_channels=3, classes=1, activation=None)

# NOTE(review): confirm the installed segmentation_models_pytorch accepts
# `decoder_dropout` for Unet and DeepLabV3Plus - it is not part of the
# constructor signatures in older smp releases (which would raise TypeError).

# 1. U-Net with an EfficientNet-B4 encoder and "scse" decoder attention
print("1. Creating U-Net EfficientNet-B4...")
model1 = smp.Unet(
    encoder_name="timm-efficientnet-b4",
    decoder_attention_type="scse",
    decoder_dropout=0.3,
    **shared_kwargs,
).to(device)

# 2. DeepLabV3+ with a ResNet50 encoder
print("2. Creating DeepLabV3+ ResNet50...")
model2 = smp.DeepLabV3Plus(
    encoder_name="resnet50",
    decoder_dropout=0.2,
    **shared_kwargs,
).to(device)

models_v1 = {'unet_effb4': model1, 'deeplabv3_r50': model2}

for name, model in models_v1.items():
    params = sum(map(torch.numel, model.parameters())) / 1e6
    print(f"{name}: {params:.1f}M parameters")
print("‚úÖ 2 Models created for Version 1")
print("="*50)

In [None]:
# ============================================
# CELL 6: ENHANCED TRAINER WITH COMPREHENSIVE METRICS
# ============================================
class EnhancedTrainer:
    """Trainer for binary segmentation models with comprehensive metric tracking.

    Optimizes a weighted BCE + Dice + Focal loss, reports IoU, F1, Dice,
    precision, recall and accuracy per epoch, early-stops on validation IoU,
    and writes best/final checkpoints plus a per-epoch metrics CSV to the
    global ``output_dir``.

    Args:
        device: Device that batches and the torchmetrics objects are moved to.
    """

    # Metrics accumulated per epoch and stored in the training history.
    _TRACKED_METRICS = ('loss', 'iou', 'f1', 'dice', 'precision', 'recall', 'accuracy')

    def __init__(self, device='cuda'):
        self.device = device
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.dice_loss = smp.losses.DiceLoss(mode='binary')
        self.focal_loss = smp.losses.FocalLoss(mode='binary')

        # torchmetrics objects used for IoU and F1
        self.iou_metric = BinaryJaccardIndex().to(device)
        self.f1_metric = BinaryF1Score().to(device)

    @staticmethod
    def _snapshot_state_dict(model):
        """Return an independent deep copy of ``model.state_dict()``.

        ``state_dict()`` returns tensors that share storage with the live
        parameters, so a shallow ``dict.copy()`` keeps changing as training
        continues. A deep copy freezes the weights at the time of the call.
        """
        import copy
        return copy.deepcopy(model.state_dict())

    def create_dataloaders(self, batch_size=8, num_workers=2):
        """Build train/val/test DataLoaders from the global split datasets.

        Args:
            batch_size: Samples per batch for all three loaders.
            num_workers: Worker processes per loader.

        Returns:
            (train_loader, val_loader, test_loader); only the train loader shuffles.
        """
        def make_loader(ds, shuffle):
            return DataLoader(
                ds,
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=num_workers,
                pin_memory=True
            )

        return (
            make_loader(train_dataset, shuffle=True),
            make_loader(val_dataset, shuffle=False),
            make_loader(test_dataset, shuffle=False),
        )

    def calculate_comprehensive_metrics(self, outputs, targets, threshold=0.5):
        """Compute batch-level segmentation metrics from raw logits.

        Args:
            outputs: Model logits; passed through a sigmoid and thresholded.
            targets: Ground-truth mask with 0/1 values, same shape as ``outputs``.
            threshold: Probability cut-off for the positive class.

        Returns:
            Dict of Python floats: iou and f1 (from the torchmetrics objects),
            dice, precision, recall, specificity, accuracy and the raw
            tp/fp/fn/tn pixel counts summed over the batch.
        """
        with torch.no_grad():
            preds_binary = (torch.sigmoid(outputs) > threshold).float()

            iou = self.iou_metric(preds_binary, targets)
            f1 = self.f1_metric(preds_binary, targets)

            # Pixel-level confusion counts over the whole batch.
            tp = (preds_binary * targets).sum()
            fp = (preds_binary * (1 - targets)).sum()
            fn = ((1 - preds_binary) * targets).sum()
            tn = ((1 - preds_binary) * (1 - targets)).sum()

            eps = 1e-7  # guards against division by zero on empty masks
            dice = (2 * tp) / (preds_binary.sum() + targets.sum() + eps)
            precision = tp / (tp + fp + eps)
            recall = tp / (tp + fn + eps)
            specificity = tn / (tn + fp + eps)
            accuracy = (tp + tn) / (tp + tn + fp + fn + eps)

            return {
                'iou': iou.item(),
                'f1': f1.item(),
                'dice': dice.item(),
                'precision': precision.item(),
                'recall': recall.item(),
                'specificity': specificity.item(),
                'accuracy': accuracy.item(),
                'tp': tp.item(),
                'fp': fp.item(),
                'fn': fn.item(),
                'tn': tn.item()
            }

    def combined_loss(self, outputs, targets):
        """Weighted combination: 0.4 * BCE + 0.4 * Dice + 0.2 * Focal (all on logits)."""
        bce = self.bce_loss(outputs, targets)
        dice = self.dice_loss(outputs, targets)
        focal = self.focal_loss(outputs, targets)
        return 0.4*bce + 0.4*dice + 0.2*focal

    def train_epoch(self, model, loader, optimizer, scaler=None, epoch=None):
        """Train ``model`` for one pass over ``loader``.

        Args:
            model: Segmentation model producing one logit channel.
            loader: Yields ``(images, masks)``; masks get a channel dim added.
            optimizer: Optimizer stepping ``model``'s parameters.
            scaler: Optional ``GradScaler``; when given, mixed precision is used.
            epoch: 0-based epoch index, only used for the progress-bar label.

        Returns:
            Dict of per-batch-averaged loss and metrics (``_TRACKED_METRICS``).
        """
        model.train()
        epoch_metrics = dict.fromkeys(self._TRACKED_METRICS, 0.0)

        desc = 'Training' if epoch is None else f'Training Epoch {epoch+1}'
        pbar = tqdm(loader, desc=desc)
        for images, masks in pbar:
            images, masks = images.to(self.device), masks.to(self.device).unsqueeze(1)

            optimizer.zero_grad()

            if scaler is not None:
                with torch.cuda.amp.autocast():
                    outputs = model(images)
                    loss = self.combined_loss(outputs, masks)

                scaler.scale(loss).backward()
                # Gradients are still multiplied by the loss scale at this point;
                # unscale them first so clipping acts on the true gradient norm.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(images)
                loss = self.combined_loss(outputs, masks)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

            metrics = self.calculate_comprehensive_metrics(outputs, masks)

            epoch_metrics['loss'] += loss.item()
            for key in metrics:
                if key in epoch_metrics:
                    epoch_metrics[key] += metrics[key]

            pbar.set_postfix({
                'loss': f"{loss.item():.4f}",
                'iou': f"{metrics['iou']:.4f}",
                'dice': f"{metrics['dice']:.4f}"
            })

        # Average over batches (guard against an empty loader).
        n_batches = max(len(loader), 1)
        for key in epoch_metrics:
            epoch_metrics[key] /= n_batches

        return epoch_metrics

    def validate(self, model, loader, split='Validation'):
        """Evaluate ``model`` on ``loader`` without gradient tracking.

        Returns:
            Dict of per-batch-averaged loss and metrics (``_TRACKED_METRICS``).
        """
        model.eval()
        val_metrics = dict.fromkeys(self._TRACKED_METRICS, 0.0)

        with torch.no_grad():
            for images, masks in tqdm(loader, desc=split):
                images, masks = images.to(self.device), masks.to(self.device).unsqueeze(1)
                outputs = model(images)

                loss = self.combined_loss(outputs, masks)
                metrics = self.calculate_comprehensive_metrics(outputs, masks)

                val_metrics['loss'] += loss.item()
                for key in metrics:
                    if key in val_metrics:
                        val_metrics[key] += metrics[key]

        n_batches = max(len(loader), 1)
        for key in val_metrics:
            val_metrics[key] /= n_batches

        return val_metrics

    def train_model(self, model, train_loader, val_loader, model_name,
                   epochs=30, lr=1e-4, patience=10):
        """Train with LR-on-plateau scheduling and early stopping on val IoU.

        Saves ``<model_name>_best.pth`` whenever validation IoU improves,
        restores the best weights at the end, then saves
        ``<model_name>_final.pth`` and ``<model_name>_metrics.csv`` to
        ``output_dir``.

        Returns:
            (history, best_iou): per-epoch train/val metrics, learning rates
            and the 0-based best epoch; and the best validation IoU.
        """
        print(f"\nüöÄ Training {model_name} for {epochs} epochs...")
        print("="*60)

        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
        # `verbose` is not passed: it is deprecated in recent PyTorch releases
        # and the learning rate is printed every epoch below anyway.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', factor=0.5, patience=5, min_lr=1e-6
        )

        scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None

        history = {
            'train': {k: [] for k in self._TRACKED_METRICS},
            'val': {k: [] for k in self._TRACKED_METRICS},
            'lr': [],
            'best_epoch': 0
        }

        best_iou = 0
        patience_counter = 0
        best_model_state = None
        last_epoch = -1  # index of the last epoch actually run (-1 if none ran)

        for epoch in range(epochs):
            last_epoch = epoch
            print(f"\n{'='*60}")
            print(f"Epoch {epoch+1}/{epochs}")
            print('='*60)

            train_metrics = self.train_epoch(model, train_loader, optimizer, scaler, epoch)
            for key in train_metrics:
                history['train'][key].append(train_metrics[key])

            val_metrics = self.validate(model, val_loader, 'Validation')
            for key in val_metrics:
                history['val'][key].append(val_metrics[key])

            current_lr = optimizer.param_groups[0]['lr']
            history['lr'].append(current_lr)

            print(f"Train - Loss: {train_metrics['loss']:.4f}, IoU: {train_metrics['iou']:.4f}, Dice: {train_metrics['dice']:.4f}")
            print(f"Val   - Loss: {val_metrics['loss']:.4f}, IoU: {val_metrics['iou']:.4f}, Dice: {val_metrics['dice']:.4f}")
            print(f"Metrics - Precision: {val_metrics['precision']:.4f}, Recall: {val_metrics['recall']:.4f}, Accuracy: {val_metrics['accuracy']:.4f}")
            print(f"Learning Rate: {current_lr:.6f}")

            scheduler.step(val_metrics['iou'])

            # Early stopping and model saving
            if val_metrics['iou'] > best_iou:
                best_iou = val_metrics['iou']
                patience_counter = 0
                history['best_epoch'] = epoch
                # Deep copy: a shallow dict copy would track later weight updates.
                best_model_state = self._snapshot_state_dict(model)

                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_iou': best_iou,
                    'history': history,
                    'val_metrics': val_metrics
                }, os.path.join(output_dir, f'{model_name}_best.pth'))
                print(f"üíæ Saved best model with IoU: {best_iou:.4f}")
            else:
                patience_counter += 1
                print(f"‚è≥ No improvement ({patience_counter}/{patience})")

            if patience_counter >= patience:
                print(f"‚èπÔ∏è Early stopping at epoch {epoch+1}")
                break

        # Restore the best weights seen during training.
        if best_model_state is not None:
            model.load_state_dict(best_model_state)

        torch.save({
            'model_state_dict': model.state_dict(),
            'history': history,
            'best_iou': best_iou,
            'final_epoch': last_epoch
        }, os.path.join(output_dir, f'{model_name}_final.pth'))

        metrics_df = pd.DataFrame({
            'epoch': list(range(1, len(history['train']['loss']) + 1)),
            'train_loss': history['train']['loss'],
            'val_loss': history['val']['loss'],
            'train_iou': history['train']['iou'],
            'val_iou': history['val']['iou'],
            'train_dice': history['train']['dice'],
            'val_dice': history['val']['dice'],
            'val_precision': history['val']['precision'],
            'val_recall': history['val']['recall'],
            'val_accuracy': history['val']['accuracy'],
            'learning_rate': history['lr']
        })
        metrics_df.to_csv(os.path.join(output_dir, f'{model_name}_metrics.csv'), index=False)

        print(f"\n‚úÖ Training completed for {model_name}!")
        print(f"üìä Best Validation IoU: {best_iou:.4f} at epoch {history['best_epoch'] + 1}")
        print(f"üíæ Models saved to: {output_dir}")

        return history, best_iou

# Build the shared trainer (losses + metric objects live on `device`).
trainer = EnhancedTrainer(device=device)
print("‚úÖ Enhanced trainer created with comprehensive metrics!")

# One loader per split, all with batch size 8; report batches per split.
train_loader, val_loader, test_loader = trainer.create_dataloaders(batch_size=8)
print(f"üìä Dataloaders created:")
for split_label, split_loader in (("Train", train_loader), ("Val", val_loader), ("Test", test_loader)):
    print(f"   {split_label} batches: {len(split_loader)}")

In [None]:
# ============================================
# CELL 7a: TRAIN MODEL 1 - U-Net EfficientNet-B4 (30 EPOCHS)
# ============================================
print("="*60)
print("1. TRAINING: U-Net EfficientNet-B4 (30 EPOCHS)")
print("="*60)

# Reuse the results dict if it already exists in this kernel, else start fresh.
all_results_v1 = globals().get('all_results_v1', {})

# Train Model 1 -> per-epoch history and best validation IoU
history1, best_iou1 = trainer.train_model(
    model1,
    train_loader,
    val_loader,
    "unet_effb4",
    epochs=30,
    lr=1e-4,
    patience=10,
)

param_millions_1 = sum(p.numel() for p in model1.parameters()) / 1e6
all_results_v1['unet_effb4'] = dict(
    history=history1,
    best_iou=best_iou1,
    model=model1,
    params=f"{param_millions_1:.1f}M",
)

print(f"\n‚úÖ Model 1 Training Completed!")
print(f"üìä Best Validation IoU: {best_iou1:.4f}")
print(f"üî¢ Parameters: {all_results_v1['unet_effb4']['params']}")

# Standalone checkpoint with the (restored best) weights and training history.
torch.save(
    dict(
        model_state_dict=model1.state_dict(),
        best_iou=best_iou1,
        history=history1,
        epochs=30,
    ),
    os.path.join(output_dir, 'model1_complete.pth'),
)

print(f"üíæ Model 1 saved to: {output_dir}/model1_complete.pth")

In [None]:
# ============================================
# CELL 7b: TRAIN MODEL 2 - DeepLabV3+ ResNet50 (30 EPOCHS)
# ============================================
print("="*60)
print("2. TRAINING: DeepLabV3+ ResNet50 (30 EPOCHS)")
print("="*60)

# Initialize results storage if not exists, so this cell does not depend on
# Cell 7a having run in the current kernel session (e.g. after a restart).
if 'all_results_v1' not in globals():
    all_results_v1 = {}

# Train Model 2
history2, best_iou2 = trainer.train_model(
    model=model2,
    train_loader=train_loader,
    val_loader=val_loader,
    model_name="deeplabv3_r50",
    epochs=30,
    lr=1e-4,
    patience=10
)

all_results_v1['deeplabv3_r50'] = {
    'history': history2,
    'best_iou': best_iou2,
    'model': model2,
    'params': f"{sum(p.numel() for p in model2.parameters()) / 1e6:.1f}M"
}

print(f"\n‚úÖ Model 2 Training Completed!")
print(f"üìä Best Validation IoU: {best_iou2:.4f}")
print(f"üî¢ Parameters: {all_results_v1['deeplabv3_r50']['params']}")

# Save individual model checkpoint
torch.save({
    'model_state_dict': model2.state_dict(),
    'best_iou': best_iou2,
    'history': history2,
    'epochs': 30
}, os.path.join(output_dir, 'model2_complete.pth'))

print(f"üíæ Model 2 saved to: {output_dir}/model2_complete.pth")

In [None]:
# ============================================
# CELL 7c: VERSION 1 TRAINING SUMMARY
# ============================================
import json

separator_line = "-" * 40

print("="*60)
print("‚úÖ VERSION 1 TRAINING COMPLETED!")
print("="*60)

print("\nüìä TRAINING SUMMARY:")
print(separator_line)
for name, data in all_results_v1.items():
    print(f"Model: {name}")
    print(f"  ‚Ä¢ Best Validation IoU: {data['best_iou']:.4f}")
    print(f"  ‚Ä¢ Parameters: {data['params']}")

    # Show the last and best epoch when validation history is available.
    if 'history' in data and 'val' in data['history']:
        val_history = data['history']['val']
        if 'iou' in val_history and val_history['iou']:
            print(f"  ‚Ä¢ Final Epoch IoU: {val_history['iou'][-1]:.4f}")
            print(f"  ‚Ä¢ Best Epoch: {data['history']['best_epoch'] + 1}")
    print(separator_line)

# JSON summary of this training run; 'final_epoch' is the number of epochs run.
summary_v1 = {
    'version': 'VERSION_1_MODELS_1_2',
    'models_trained': list(all_results_v1),
    'training_details': {
        name: {
            'best_iou': float(data['best_iou']),
            'params': data['params'],
            'final_epoch': len(data['history']['train']['loss']) if 'history' in data else 0
        }
        for name, data in all_results_v1.items()
    },
    'timestamp': pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S'),
    'epochs': 30,
    'device': str(device)
}

summary_path = os.path.join(output_dir, 'version_1_summary.json')
with open(summary_path, 'w') as f:
    json.dump(summary_v1, f, indent=2)

print(f"\nüíæ Version 1 summary saved to: {summary_path}")
print(f"üìÅ Output directory: {output_dir}")

# List checkpoint / summary / metrics files with their sizes in KB.
print("\nüìã Saved files in output directory:")
for file in os.listdir(output_dir):
    if not file.endswith(('.pth', '.json', '.csv')):
        continue
    size = os.path.getsize(os.path.join(output_dir, file)) / 1024
    print(f"  ‚Ä¢ {file} ({size:.1f} KB)")

In [None]:
# ============================================
# CELL 8: COMPREHENSIVE TEST SET EVALUATION
# ============================================
for banner_line in ("="*60, "üìä COMPREHENSIVE TEST SET EVALUATION", "="*60):
    print(banner_line)

def evaluate_model_test(model, test_loader, model_name):
    """Evaluate ``model`` on ``test_loader`` and collect a few sample predictions.

    Loss and metrics come from the global ``trainer``; tensors are moved to
    the global ``device``.

    Args:
        model: Segmentation model producing one logit channel.
        test_loader: Yields ``(images, masks)`` batches.
        model_name: Label for the progress bar.

    Returns:
        (results_mean, results_std, sample_images, sample_masks, sample_predictions):
        mean and standard deviation over per-batch values of loss, iou, f1,
        dice, precision, recall, accuracy and specificity (0 when the loader is
        empty); and up to 2 samples from each of the first 3 batches as CPU
        tensors, predictions given as {'prob': sigmoid output, 'binary': prob > 0.5}.
    """
    model.eval()

    metric_keys = ('loss', 'iou', 'f1', 'dice', 'precision', 'recall', 'accuracy', 'specificity')
    # Per-batch values for every metric; means and stds are both derived from these.
    batch_metrics = {k: [] for k in metric_keys}

    # Store predictions for visualization
    sample_predictions = []
    sample_images = []
    sample_masks = []

    with torch.no_grad():
        for batch_idx, (images, masks) in enumerate(tqdm(test_loader, desc=f'Testing {model_name}')):
            images, masks = images.to(device), masks.to(device).unsqueeze(1)
            outputs = model(images)

            loss = trainer.combined_loss(outputs, masks)
            metrics = trainer.calculate_comprehensive_metrics(outputs, masks)

            # Record the loss per batch too, so its std is computed like the others.
            batch_metrics['loss'].append(loss.item())
            for key in metric_keys[1:]:
                batch_metrics[key].append(metrics[key])

            if batch_idx < 3:  # Store predictions from first 3 batches
                preds = torch.sigmoid(outputs)
                preds_binary = (preds > 0.5).float()

                for i in range(min(2, len(images))):  # Store 2 samples per batch
                    sample_images.append(images[i].cpu())
                    sample_masks.append(masks[i].cpu())
                    sample_predictions.append({
                        'prob': preds[i].cpu(),
                        'binary': preds_binary[i].cpu()
                    })

    results_mean = {}
    results_std = {}
    for key, values in batch_metrics.items():
        if values:
            results_mean[key] = sum(values) / len(values)
            results_std[key] = np.std(values)
        else:
            # Empty loader: report zeros instead of dividing by zero.
            results_mean[key] = 0
            results_std[key] = 0

    return results_mean, results_std, sample_images, sample_masks, sample_predictions

# Evaluate every trained model on the test split and collect a summary row each.
print("\nüîç Evaluating models on test set...")
test_results_v1 = []
all_test_metrics = {}

for model_name, data in all_results_v1.items():
    divider = '=' * 40
    print(f"\n{divider}")
    print(f"Evaluating {model_name}...")
    print(divider)

    model = data['model']
    mean_metrics, std_metrics, sample_imgs, sample_msks, sample_preds = evaluate_model_test(
        model, test_loader, model_name
    )

    # Table row: metric values are stored as 4-decimal strings.
    row = {'Model': model_name}
    for metric_key, label in (('iou', 'IoU'), ('dice', 'Dice'), ('f1', 'F1')):
        row[f'Test_{label}_Mean'] = f"{mean_metrics[metric_key]:.4f}"
        row[f'Test_{label}_Std'] = f"{std_metrics[metric_key]:.4f}"
    for metric_key in ('precision', 'recall', 'accuracy', 'specificity'):
        row[metric_key.capitalize()] = f"{mean_metrics[metric_key]:.4f}"
    row['Val_IoU_Best'] = f"{data['best_iou']:.4f}"
    row['Parameters'] = data['params']
    test_results_v1.append(row)

    # Raw numbers plus sample tensors, used by the visualization cell.
    all_test_metrics[model_name] = {
        'mean': mean_metrics,
        'std': std_metrics,
        'samples': dict(images=sample_imgs, masks=sample_msks, predictions=sample_preds),
    }

    print(f"\nüìä Test Metrics for {model_name}:")
    print(f"  ‚Ä¢ IoU:        {mean_metrics['iou']:.4f} ¬± {std_metrics['iou']:.4f}")
    print(f"  ‚Ä¢ Dice:       {mean_metrics['dice']:.4f} ¬± {std_metrics['dice']:.4f}")
    print(f"  ‚Ä¢ F1-Score:   {mean_metrics['f1']:.4f} ¬± {std_metrics['f1']:.4f}")
    print(f"  ‚Ä¢ Precision:  {mean_metrics['precision']:.4f}")
    print(f"  ‚Ä¢ Recall:     {mean_metrics['recall']:.4f}")
    print(f"  ‚Ä¢ Accuracy:   {mean_metrics['accuracy']:.4f}")
    print(f"  ‚Ä¢ Specificity: {mean_metrics['specificity']:.4f}")
    print(f"  ‚Ä¢ Loss:       {mean_metrics['loss']:.4f}")
    print(f"  ‚Ä¢ Best Val IoU: {data['best_iou']:.4f}")

# Display results as table, save them, then show a ranking by test IoU.
wide_rule = "="*80
print("\n" + wide_rule)
print("üèÜ FINAL TEST EVALUATION RESULTS")
print(wide_rule)

import pandas as pd
from tabulate import tabulate

if not test_results_v1:
    print("‚ö†Ô∏è No evaluation results to display")
else:
    df_results = pd.DataFrame(test_results_v1)
    print(tabulate(df_results, headers='keys', tablefmt='pretty', showindex=False))

    results_csv_path = os.path.join(output_dir, 'test_evaluation_results.csv')
    df_results.to_csv(results_csv_path, index=False)
    print(f"\nüíæ Results saved to: {results_csv_path}")

    # Test IoU is stored as a formatted string, so sort on a numeric copy.
    df_sorted = (
        df_results
        .assign(IoU_Value=df_results['Test_IoU_Mean'].astype(float))
        .sort_values('IoU_Value', ascending=False)
    )

    print("\n" + wide_rule)
    print("üìà RANKING BY TEST IoU (Best to Worst)")
    print(wide_rule)
    print(tabulate(df_sorted.drop(columns='IoU_Value'),
                   headers='keys', tablefmt='pretty', showindex=False))

In [None]:
# ============================================
# CELL 9: COMPREHENSIVE VISUALIZATION
# ============================================
section_rule = "="*60
print(section_rule)
print("üìä COMPREHENSIVE VISUALIZATION")
print(section_rule)

import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import gridspec

# Shared styling for every figure produced below.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

# All figures from this cell are written under <output_dir>/visualizations.
vis_dir = os.path.join(output_dir, 'visualizations')
os.makedirs(vis_dir, exist_ok=True)

def visualize_training_history(all_results_dict):
    """Plot train/val curves per metric and the learning-rate schedule.

    Six metric panels (loss, IoU, Dice, precision, recall, accuracy) fill the
    top two rows of a 3x3 grid; the bottom row shows each model's learning
    rate on a log scale. The figure is saved to
    ``<vis_dir>/training_history_comparison.png`` and displayed.

    Args:
        all_results_dict: Mapping of model name -> dict whose ``'history'``
            holds ``'train'``/``'val'`` metric lists and an ``'lr'`` list.
    """
    fig = plt.figure(figsize=(18, 12))
    grid = gridspec.GridSpec(3, 3, figure=fig)

    panels = (
        ('loss', 'Loss'),
        ('iou', 'IoU'),
        ('dice', 'Dice Coefficient'),
        ('precision', 'Precision'),
        ('recall', 'Recall'),
        ('accuracy', 'Accuracy'),
    )

    for slot, (metric, title) in enumerate(panels):
        row, col = divmod(slot, 3)
        ax = fig.add_subplot(grid[row, col])

        for model_name, data in all_results_dict.items():
            if 'history' not in data or 'train' not in data['history']:
                continue
            train_vals = data['history']['train'].get(metric, [])
            val_vals = data['history']['val'].get(metric, [])
            if not (train_vals and val_vals):
                continue
            x_epochs = range(1, len(train_vals) + 1)
            ax.plot(x_epochs, train_vals, '--', linewidth=1.5, label=f'{model_name} Train')
            ax.plot(x_epochs, val_vals, '-', linewidth=2, label=f'{model_name} Val')

        ax.set_xlabel('Epoch')
        ax.set_ylabel(title)
        ax.set_title(f'Training vs Validation {title}')
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)

    # Bottom row: learning-rate schedule spanning all three columns.
    ax_lr = fig.add_subplot(grid[2, :])
    for model_name, data in all_results_dict.items():
        lr_vals = data['history']['lr'] if ('history' in data and 'lr' in data['history']) else None
        if lr_vals:
            ax_lr.plot(range(1, len(lr_vals) + 1), lr_vals, 'o-', linewidth=2, label=model_name)

    ax_lr.set_xlabel('Epoch')
    ax_lr.set_ylabel('Learning Rate')
    ax_lr.set_title('Learning Rate Schedule')
    ax_lr.set_yscale('log')
    ax_lr.legend()
    ax_lr.grid(True, alpha=0.3)

    plt.suptitle('Model Training History Comparison', fontsize=16, fontweight='bold', y=1.02)
    plt.tight_layout()

    save_path = os.path.join(vis_dir, 'training_history_comparison.png')
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"‚úÖ Training history saved to: {save_path}")

def visualize_test_metrics_comparison(test_metrics_dict):
    """Plot one bar chart per test metric, comparing every evaluated model.

    Args:
        test_metrics_dict: mapping ``model_name -> results``. A model appears in a
            panel only if ``results['mean']`` contains that metric; the optional
            ``results['std']`` supplies the error bar (0 when absent). Presumably
            both are per-metric scalars aggregated over the test set — TODO
            confirm against the evaluation cell.

    Side effects:
        Shows the figure and saves it as ``test_metrics_comparison.png`` inside
        the module-level ``vis_dir``.
    """
    if not test_metrics_dict:
        return

    metrics_to_plot = ['iou', 'dice', 'f1', 'precision', 'recall', 'accuracy']
    metric_names = ['IoU', 'Dice', 'F1-Score', 'Precision', 'Recall', 'Accuracy']

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    axes = axes.flatten()

    for ax, metric, metric_name in zip(axes, metrics_to_plot, metric_names):
        # Build (name, mean, std) only for models that report this metric so bar
        # positions, heights, error bars and tick labels always have equal length.
        entries = []
        for model_name, model_data in test_metrics_dict.items():
            mean_dict = model_data.get('mean', {})
            if metric not in mean_dict:
                continue
            std_val = model_data.get('std', {}).get(metric, 0.0)
            # float() also unwraps numpy scalars / 0-d tensors for formatting.
            entries.append((model_name, float(mean_dict[metric]), float(std_val)))

        if not entries:
            ax.set_title(f'Test {metric_name} Comparison (no data)')
            ax.axis('off')
            continue

        names = [name for name, _, _ in entries]
        means = [mean_val for _, mean_val, _ in entries]
        stds = [std_val for _, _, std_val in entries]

        x_pos = np.arange(len(names))
        bars = ax.bar(x_pos, means, yerr=stds, capsize=5, alpha=0.7)

        # Place each value label above its error bar so the two don't overlap.
        for bar, mean_val, std_val in zip(bars, means, stds):
            ax.text(bar.get_x() + bar.get_width() / 2., mean_val + std_val + 0.01,
                    f'{mean_val:.3f}', ha='center', va='bottom', fontsize=9)

        ax.set_xlabel('Models')
        ax.set_ylabel(metric_name)
        ax.set_title(f'Test {metric_name} Comparison')
        ax.set_xticks(x_pos)
        # Truncate long model names so rotated tick labels stay readable.
        ax.set_xticklabels([name[:15] for name in names], rotation=45, ha='right')
        ax.grid(True, alpha=0.3, axis='y')

    fig.suptitle('Model Performance Comparison on Test Set', fontsize=16, fontweight='bold', y=1.02)
    fig.tight_layout()

    save_path = os.path.join(vis_dir, 'test_metrics_comparison.png')
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    plt.close(fig)  # release the figure; this cell may be re-run many times
    print(f"✅ Test metrics comparison saved to: {save_path}")

def visualize_sample_predictions(test_metrics_dict, num_samples=3):
    """Draw a grid of input image, ground-truth mask and each model's prediction.

    Each row is one sample. Columns are: the original image, the ground-truth
    mask, then one column per model showing its binary mask titled with a
    per-sample Dice score. Images and masks come from the first model's
    ``'samples'`` entry.

    Args:
        test_metrics_dict: mapping ``model_name -> results``. ``results['samples']``
            is read for ``'images'``, ``'masks'`` and ``'predictions'`` (a list of
            dicts with a ``'binary'`` tensor). The code assumes every model
            stored its samples in the same order — TODO confirm, because ground
            truth is taken from the first model only.
        num_samples: maximum number of rows; clipped to the number of stored samples.

    Side effects:
        Shows the figure and saves ``sample_predictions_comparison.png`` inside
        the module-level ``vis_dir``.
    """
    if not test_metrics_dict:
        return

    first_model = next(iter(test_metrics_dict))
    if 'samples' not in test_metrics_dict[first_model]:
        return

    samples = test_metrics_dict[first_model]['samples']
    n_rows = min(num_samples, len(samples['images']))
    if n_rows <= 0:
        return
    n_cols = len(test_metrics_dict) + 2

    # squeeze=False keeps `axes` 2-D for every grid size, so axes[row, col] is
    # always a single Axes. The old single-row path indexed a whole row ndarray.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows), squeeze=False)

    # Assumes images were normalized with ImageNet mean/std — TODO confirm
    # against the dataset transforms.
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])

    for row in range(n_rows):
        # .detach().cpu() makes this work whether samples were kept on GPU or CPU.
        img_np = samples['images'][row].detach().cpu().numpy().transpose(1, 2, 0)
        img_np = np.clip(img_np * imagenet_std + imagenet_mean, 0, 1)
        # float64 avoids bool/uint8 arithmetic surprises in the Dice products below.
        true_mask_np = samples['masks'][row].detach().cpu().numpy().squeeze().astype(np.float64)

        axes[row, 0].imshow(img_np)
        axes[row, 0].set_title('Original Image', fontsize=10, fontweight='bold')
        axes[row, 0].axis('off')

        axes[row, 1].imshow(true_mask_np, cmap='gray')
        axes[row, 1].set_title('Ground Truth', fontsize=10, fontweight='bold')
        axes[row, 1].axis('off')

        for model_idx, (model_name, model_data) in enumerate(test_metrics_dict.items()):
            ax = axes[row, model_idx + 2]
            ax.axis('off')  # also hides the frame when this model has no prediction

            predictions = model_data.get('samples', {}).get('predictions', [])
            if len(predictions) <= row:
                continue
            pred_binary = predictions[row]['binary'].detach().cpu().numpy().squeeze().astype(np.float64)

            # Per-sample Dice = 2|P∩G| / (|P|+|G|). The score is reported as 0
            # when both masks are empty, as in the original code.
            intersection = (pred_binary * true_mask_np).sum()
            denom = pred_binary.sum() + true_mask_np.sum()
            dice = (2 * intersection) / (denom + 1e-7) if denom > 0 else 0

            ax.imshow(pred_binary, cmap='gray')
            ax.set_title(f'{model_name}\nDice: {dice:.3f}', fontsize=9)

    fig.suptitle('Model Predictions Comparison on Sample Images', fontsize=16, fontweight='bold', y=1.02)
    fig.tight_layout()

    save_path = os.path.join(vis_dir, 'sample_predictions_comparison.png')
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    plt.close(fig)  # release the figure; this cell may be re-run many times
    print(f"✅ Sample predictions saved to: {save_path}")

def create_confusion_matrix_heatmap(test_metrics_dict):
    """Draw one row-normalized 2x2 confusion-matrix heatmap per model.

    The matrix is laid out as ``[[TN, FP], [FN, TP]]`` (rows = actual,
    columns = predicted). Each cell shows the raw value and the fraction of
    its row.

    Args:
        test_metrics_dict: mapping ``model_name -> results``. ``results['mean']``
            is read for ``'tp'``, ``'fp'``, ``'fn'`` and ``'tn'``, each defaulting
            to 0. Presumably these are averaged per-image pixel counts — TODO
            confirm against the evaluation cell. Models without ``'mean'`` get a
            blank panel.

    Side effects:
        Shows the figure and saves ``confusion_matrices.png`` inside the
        module-level ``vis_dir``.
    """
    if not test_metrics_dict:
        return

    n_models = len(test_metrics_dict)
    # squeeze=False gives a (1, n_models) array even for a single model.
    fig, axes = plt.subplots(1, n_models, figsize=(5 * n_models, 4), squeeze=False)

    for ax, (model_name, model_data) in zip(axes[0], test_metrics_dict.items()):
        if 'mean' not in model_data:
            ax.axis('off')
            continue

        mean_metrics = model_data['mean']
        # float() also unwraps numpy scalars / 0-d tensors.
        tp = float(mean_metrics.get('tp', 0))
        fp = float(mean_metrics.get('fp', 0))
        fn = float(mean_metrics.get('fn', 0))
        tn = float(mean_metrics.get('tn', 0))

        cm = np.array([[tn, fp], [fn, tp]], dtype=float)

        # Normalize each row by its total. A row summing to 0 stays 0 instead of NaN.
        row_sums = cm.sum(axis=1, keepdims=True)
        cm_normalized = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums > 0)

        ax.imshow(cm_normalized, interpolation='nearest', cmap='Blues', vmin=0, vmax=1)

        for i in range(2):
            for j in range(2):
                # White text on dark (high-fraction) cells for contrast.
                color = "white" if cm_normalized[i, j] >= 0.7 else "black"
                ax.text(j, i, f"{cm[i, j]:,.0f}\n({cm_normalized[i, j]:.2%})",
                        ha="center", va="center", color=color)

        ax.set_title(f'{model_name}\nConfusion Matrix', fontsize=11)
        ax.set_xlabel('Predicted')
        ax.set_ylabel('Actual')
        ax.set_xticks([0, 1])
        ax.set_yticks([0, 1])
        ax.set_xticklabels(['Negative', 'Positive'])
        ax.set_yticklabels(['Negative', 'Positive'])

    fig.suptitle('Confusion Matrices Comparison', fontsize=14, fontweight='bold', y=1.05)
    fig.tight_layout()

    save_path = os.path.join(vis_dir, 'confusion_matrices.png')
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    plt.close(fig)  # release the figure; this cell may be re-run many times
    print(f"✅ Confusion matrices saved to: {save_path}")

# Run all visualizations; every figure is written into `vis_dir`.
print("\n📈 Generating visualizations...")

# 1. Training history
if all_results_v1:
    visualize_training_history(all_results_v1)

if all_test_metrics:
    # 2. Test metrics comparison
    visualize_test_metrics_comparison(all_test_metrics)
    # 3. Sample predictions
    visualize_sample_predictions(all_test_metrics, num_samples=3)
    # 4. Confusion matrices
    create_confusion_matrix_heatmap(all_test_metrics)

print("\n" + "="*60)
print("✅ ALL VISUALIZATIONS COMPLETED!")
print("="*60)
print(f"📁 Visualizations saved to: {vis_dir}")
print("\n📋 Generated files:")
# Sorted for reproducible output. Sub-directories are skipped because
# os.path.getsize on a directory is not a meaningful file size.
for file_name in sorted(os.listdir(vis_dir)):
    file_path = os.path.join(vis_dir, file_name)
    if not os.path.isfile(file_path):
        continue
    size_kb = os.path.getsize(file_path) / 1024
    print(f"  • {file_name} ({size_kb:.1f} KB)")