In [3]:
"""
TIL Detection and Segmentation Pipeline for Breast Cancer Histopathology
Author: [Your Name]
Date: [Current Date]
Task: TIGER Challenge Task 1 - Lymphocyte Detection and Segmentation
"""

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.models import efficientnet_b4
from PIL import Image
import cv2
import pandas as pd
from sklearn.model_selection import train_test_split
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Seed NumPy and PyTorch so repeated runs draw the same random numbers
np.random.seed(42)
torch.manual_seed(42)

# Run on CUDA when PyTorch reports a usable GPU, otherwise stay on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# ==================== CONFIGURATION ====================
class Config:
    """Hyper-parameters and filesystem locations shared by the pipeline.

    All settings are class attributes, so an instance's own ``__dict__`` is
    empty and values are read through normal attribute lookup.
    """

    # --- Data ---
    data_dir: str = "./wsi_roi_images"
    img_size: int = 512
    batch_size: int = 8
    num_workers: int = 2

    # --- Model ---
    backbone: str = 'efficientnet-b4'
    num_classes: int = 3  # Background, Lymphocyte, Other cells

    # --- Optimisation ---
    num_epochs: int = 30
    learning_rate: float = 1e-4
    weight_decay: float = 1e-5

    # --- Output locations ---
    checkpoint_dir: str = "./checkpoints"
    log_dir: str = "./logs"

    # --- Relative weight of each task in the combined loss ---
    seg_weight: float = 0.7
    det_weight: float = 0.3


# Module-level settings object handed to the model and trainer
config = Config()

# ==================== CUSTOM UNET IMPLEMENTATION ====================
class ConvBlock(nn.Module):
    """Two stacked 3x3 convolution -> batch-norm -> ReLU stages.

    A 3x3 kernel with ``padding=1`` leaves height and width unchanged, so the
    block only changes the channel count from ``in_channels`` to
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps out_channels
        for stage_inputs in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_inputs, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        # Attribute name and layer order determine the state_dict keys
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        features = self.conv(x)
        return features

class UpConv(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, skip concat, ConvBlock.

    The fusing ``ConvBlock`` takes ``2 * out_channels`` inputs, so the skip
    tensor passed to ``forward`` has to carry ``out_channels`` channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=2, stride=2
        )
        self.conv = ConvBlock(2 * out_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s spatial size and fuse both."""
        upsampled = self.up(x1)

        # Size gap between the skip tensor and the upsampled tensor
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]

        # F.pad order is (left, right, top, bottom); any odd pixel goes to the
        # right / bottom side
        left = gap_w // 2
        top = gap_h // 2
        upsampled = nn.functional.pad(
            upsampled, [left, gap_w - left, top, gap_h - top]
        )

        # Skip features first, then the upsampled ones, along the channel axis
        merged = torch.cat((x2, upsampled), dim=1)
        return self.conv(merged)

class CustomUNet(nn.Module):
    """Four-level U-Net producing per-pixel class logits.

    Submodules are registered as ``enc1..enc4`` / ``pool1..pool4``,
    ``bottleneck``, ``up4..up1`` and ``out_conv``; these names define the
    state_dict keys.
    """

    def __init__(self, in_channels=3, num_classes=3):
        super().__init__()

        level_widths = (64, 128, 256, 512)

        # Encoder: one ConvBlock plus a 2x2 max-pool per level
        previous = in_channels
        for level, width in enumerate(level_widths, start=1):
            setattr(self, f"enc{level}", ConvBlock(previous, width))
            setattr(self, f"pool{level}", nn.MaxPool2d(kernel_size=2, stride=2))
            previous = width

        # Bottleneck at the lowest resolution
        self.bottleneck = ConvBlock(512, 1024)

        # Decoder: built from the deepest level upwards (up4, up3, up2, up1)
        for level in (4, 3, 2, 1):
            width = level_widths[level - 1]
            setattr(self, f"up{level}", UpConv(2 * width, width))

        # 1x1 convolution mapping the 64 decoder channels to class logits
        self.out_conv = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        # Contracting path; every encoder output is kept as a skip connection
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool1(skip1))
        skip3 = self.enc3(self.pool2(skip2))
        skip4 = self.enc4(self.pool3(skip3))

        decoded = self.bottleneck(self.pool4(skip4))

        # Expanding path, pairing each decoder stage with its encoder output
        stages = (
            (self.up4, skip4),
            (self.up3, skip3),
            (self.up2, skip2),
            (self.up1, skip1),
        )
        for up_stage, skip in stages:
            decoded = up_stage(decoded, skip)

        return self.out_conv(decoded)

# ==================== DATA PREPROCESSING ====================
class StainNormalizer:
    """Reinhard stain normalization for H&E images.

    The per-channel mean and standard deviation of an image in LAB space are
    shifted and scaled to match those of a reference (target) image.

    Args:
        target_image: optional RGB ``uint8`` reference image; when given the
            normalizer is fitted immediately.
    """

    def __init__(self, target_image=None):
        self.target_means = None
        self.target_stds = None
        if target_image is not None:
            self.fit(target_image)

    def fit(self, target_image):
        """Store the per-channel LAB mean/std of ``target_image`` (RGB)."""
        target_lab = cv2.cvtColor(target_image, cv2.COLOR_RGB2LAB)
        self.target_means = np.mean(target_lab, axis=(0, 1))
        self.target_stds = np.std(target_lab, axis=(0, 1))

    @staticmethod
    def _match_channel_stats(channels, source_means, source_stds,
                             target_means, target_stds):
        """Shift/scale each channel from the source to the target statistics.

        Args:
            channels: ``(H, W, C)`` array of 8-bit channel values.
            source_means, source_stds: per-channel statistics of ``channels``.
            target_means, target_stds: per-channel statistics to match.

        Returns:
            ``uint8`` array of the same shape, rounded and clipped to 0..255.
        """
        # Work in float: writing the result back into a uint8 array would wrap
        # out-of-range values before they could be clipped
        values = np.asarray(channels, dtype=np.float64)
        source_stds = np.asarray(source_stds, dtype=np.float64)

        # A flat channel (std == 0) has nothing to rescale; with a unit divisor
        # (value - mean) stays 0 and the channel simply takes the target mean
        safe_stds = np.where(source_stds > 0, source_stds, 1.0)
        scale = np.asarray(target_stds, dtype=np.float64) / safe_stds

        matched = (values - source_means) * scale + target_means
        # Round instead of truncating so that matching an image to its own
        # statistics returns it unchanged
        return np.clip(np.rint(matched), 0, 255).astype(np.uint8)

    def transform(self, image):
        """Apply stain normalization to an RGB ``uint8`` image.

        Returns ``image`` unchanged when the normalizer has not been fitted.
        """
        if self.target_means is None:
            return image

        # For 8-bit input OpenCV stores all three LAB channels in 0..255
        image_lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
        image_means = np.mean(image_lab, axis=(0, 1))
        image_stds = np.std(image_lab, axis=(0, 1))

        matched_lab = self._match_channel_stats(
            image_lab, image_means, image_stds,
            self.target_means, self.target_stds
        )
        return cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB)

# ==================== DATASET CLASS ====================
class TILDataset(Dataset):
    """Dataset for TIL detection and segmentation.

    Each item is a dict with ``'image'`` (tensor produced by the albumentations
    pipeline) and ``'image_path'``. ``'mask'`` (long tensor) is added when a
    mask path is configured for the sample, and ``'bboxes'`` (float tensor of
    ``[x_min, y_min, x_max, y_max]`` rows divided by ``img_size``) when the
    annotation file yields at least one box.

    Args:
        image_paths: sequence of image file paths.
        mask_paths: optional sequence (entries may be ``None``) of mask paths.
        bbox_paths: optional sequence (entries may be ``None``) of bbox files.
        is_train: enables the random augmentation pipeline.
        stain_norm: optional object with a ``transform(image)`` method applied
            to the RGB image before augmentation.
        img_size: side length used for the placeholder image and for
            normalizing box coordinates (default 512, the previous fixed
            value).
    """

    def __init__(self, image_paths, mask_paths=None, bbox_paths=None,
                 is_train=True, stain_norm=None, img_size=512):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.bbox_paths = bbox_paths
        self.is_train = is_train
        self.stain_norm = stain_norm
        self.img_size = img_size

        # Define augmentations
        if is_train:
            self.transform = A.Compose([
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.5),
                A.RandomRotate90(p=0.5),
                A.ColorJitter(brightness=0.1, contrast=0.1,
                            saturation=0.1, hue=0.05, p=0.5),
                A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=0.3),
                A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
                A.CoarseDropout(max_holes=8, max_height=32,
                              max_width=32, fill_value=0, p=0.3),
                A.Normalize(mean=(0.485, 0.456, 0.406),
                          std=(0.229, 0.224, 0.225)),
                ToTensorV2()
            ])
        else:
            self.transform = A.Compose([
                A.Normalize(mean=(0.485, 0.456, 0.406),
                          std=(0.229, 0.224, 0.225)),
                ToTensorV2()
            ])

    def __len__(self):
        return len(self.image_paths)

    def _read_image(self, image_path):
        """Read an RGB image; fall back to a white placeholder on failure."""
        try:
            image = cv2.imread(image_path)
            # cv2.imread signals failure by returning None, not by raising
            if image is None:
                raise ValueError(f"Could not read image: {image_path}")
            return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        except Exception as e:  # best effort: one bad file must not stop a run
            print(f"Error loading image {image_path}: {e}")
            return np.full((self.img_size, self.img_size, 3), 255, dtype=np.uint8)

    def _read_mask(self, mask_path, height, width):
        """Read a binary mask sized ``(height, width)``; zeros on failure."""
        try:
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            # An unreadable mask comes back as None; treat it like any other
            # load error so the sample still carries a 'mask' entry
            if mask is None:
                raise ValueError(f"Could not read mask: {mask_path}")
            # Keep mask and image aligned; cv2.resize expects (width, height)
            if mask.shape != (height, width):
                mask = cv2.resize(mask, (width, height),
                                  interpolation=cv2.INTER_NEAREST)
            # NOTE(review): this collapses every non-zero label to class 1,
            # while Config.num_classes is 3 - confirm the intended encoding
            return (mask > 0).astype(np.uint8)
        except Exception as e:  # best effort, mirrors image loading
            print(f"Error loading mask {mask_path}: {e}")
            return np.zeros((height, width), dtype=np.uint8)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = self._read_image(image_path)

        # Apply stain normalization if provided
        if self.stain_norm is not None:
            image = self.stain_norm.transform(image)

        mask = None
        bboxes = []

        # Load mask if available
        if self.mask_paths is not None and self.mask_paths[idx] is not None:
            height, width = image.shape[:2]
            mask = self._read_mask(self.mask_paths[idx], height, width)

        # Load bounding boxes if available
        # NOTE(review): boxes are not passed through the augmentation pipeline,
        # so flips/rotations are not reflected in them - confirm whether
        # albumentations bbox_params should be used
        if self.bbox_paths is not None and self.bbox_paths[idx] is not None:
            bboxes = self._load_bboxes(self.bbox_paths[idx])

        # During training the mask goes through the same augmentations as the
        # image; at evaluation time only the image is normalized
        if self.is_train and mask is not None:
            transformed = self.transform(image=image, mask=mask)
            mask = transformed['mask']
        else:
            transformed = self.transform(image=image)
        image = transformed['image']

        sample = {
            'image': image,
            'image_path': image_path
        }

        if mask is not None:
            if not isinstance(mask, torch.Tensor):
                mask = torch.from_numpy(mask)
            sample['mask'] = mask.long()

        # NOTE(review): the number of boxes varies per image; presumably a
        # custom collate_fn is needed to batch them - verify against the
        # DataLoader setup
        if len(bboxes) > 0:
            sample['bboxes'] = torch.tensor(bboxes, dtype=torch.float32)

        return sample

    def _load_bboxes(self, bbox_path):
        """Load bounding boxes from an annotation file.

        Expected line format: ``x_min,y_min,x_max,y_max[,class]``. Only the
        first four fields are parsed, so a non-numeric class label is accepted.
        Blank or short lines are skipped, and a line with non-numeric
        coordinates is reported and skipped without discarding the rest of the
        file. A missing file yields an empty list.

        Returns:
            List of ``[x_min, y_min, x_max, y_max]`` lists divided by
            ``self.img_size``.
        """
        bboxes = []
        try:
            with open(bbox_path, 'r') as f:
                lines = f.readlines()
        except FileNotFoundError:
            return bboxes
        except (OSError, ValueError) as e:  # ValueError covers decode errors
            print(f"Error loading bboxes {bbox_path}: {e}")
            return bboxes

        scale = float(self.img_size)
        for line in lines:
            fields = line.strip().split(',')
            if len(fields) < 4:
                continue
            try:
                coords = [float(value) for value in fields[:4]]
            except ValueError as e:
                print(f"Error loading bboxes {bbox_path}: {e}")
                continue
            # Normalize coordinates to [0, 1]
            bboxes.append([coord / scale for coord in coords])
        return bboxes

# ==================== MODEL ARCHITECTURE ====================
class TILDetectionSegmentationModel(nn.Module):
    """Multi-task network: U-Net segmentation plus a single-box detector.

    ``forward`` returns a dict with ``'segmentation'`` (U-Net logits),
    ``'bboxes'`` (four regressed values per image) and ``'confidence'``
    (one sigmoid score per image).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Segmentation branch
        self.segmentation_model = CustomUNet(
            in_channels=3, num_classes=config.num_classes
        )

        # Detection trunk: the EfficientNet-B4 model minus its last two
        # top-level children (presumably pooling and classifier - confirm
        # against the installed torchvision)
        pretrained_net = efficientnet_b4(pretrained=True)
        trunk_layers = list(pretrained_net.children())[:-2]
        self.backbone = nn.Sequential(*trunk_layers)

        # Box regressor: two conv stages, global pooling, then a small MLP
        box_layers = [
            nn.Conv2d(1792, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 4),  # 4 coordinates for bounding box
        ]
        self.detection_head = nn.Sequential(*box_layers)

        # Confidence scorer working on globally pooled trunk features
        score_layers = [
            nn.Linear(1792, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.confidence_head = nn.Sequential(*score_layers)

    def forward(self, x):
        seg_logits = self.segmentation_model(x)

        # Both detection outputs share the same trunk features
        trunk_features = self.backbone(x)
        box_values = self.detection_head(trunk_features)

        # Global average pool, then flatten to (batch, channels) for the scorer
        pooled = nn.functional.adaptive_avg_pool2d(trunk_features, (1, 1))
        pooled = pooled.view(pooled.size(0), -1)
        score = self.confidence_head(pooled)

        return {
            'segmentation': seg_logits,
            'bboxes': box_values,
            'confidence': score
        }

# ==================== LOSS FUNCTIONS ====================
class DiceLoss(nn.Module):
    """Soft Dice loss averaged over the non-background classes.

    ``pred`` holds raw logits with classes on dim 1; ``target`` holds integer
    class indices. Sums run over the whole batch, and ``smooth`` is added to
    numerator and denominator of every per-class ratio.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        probs = torch.softmax(pred, dim=1)
        num_classes = probs.shape[1]

        # One-hot encode the target along the class dimension
        one_hot = torch.zeros_like(probs).scatter_(1, target.unsqueeze(1), 1)

        def class_dice(class_idx):
            class_probs = probs[:, class_idx]
            class_truth = one_hot[:, class_idx]
            overlap = (class_probs * class_truth).sum()
            total = class_probs.sum() + class_truth.sum()
            return (2. * overlap + self.smooth) / (total + self.smooth)

        # Class 0 (background) is left out of the average
        dice_total = sum(class_dice(c) for c in range(1, num_classes))
        return 1 - (dice_total / (num_classes - 1))

class MultiTaskLoss(nn.Module):
    """Combined loss for segmentation and detection.

    Segmentation: ``0.5 * Dice + 0.5 * cross-entropy``.
    Detection: ``0.7 * SmoothL1(boxes) + 0.3 * BCE(confidence)``.
    Total: ``seg_weight * segmentation + det_weight * detection``.

    ``forward`` returns a dict that always contains ``'total_loss'``,
    ``'segmentation_loss'`` and ``'detection_loss'``; a task whose prediction
    or target is missing contributes a constant zero. The component entries
    (``'dice_loss'``, ``'ce_loss'``, ``'bbox_loss'``, ``'confidence_loss'``)
    are present only when the task was actually evaluated.
    """

    def __init__(self, seg_weight=0.7, det_weight=0.3):
        super().__init__()
        self.seg_weight = seg_weight
        self.det_weight = det_weight

        # Segmentation losses
        self.dice_loss = DiceLoss()
        self.ce_loss = nn.CrossEntropyLoss()

        # Detection losses
        self.bbox_loss = nn.SmoothL1Loss()
        self.confidence_loss = nn.BCELoss()

    def forward(self, predictions, targets):
        losses = {}

        # Placeholder zeros go on the device of whichever prediction tensor is
        # present; 'segmentation' itself may be the missing entry
        ref_device = next(
            (value.device for value in predictions.values()
             if torch.is_tensor(value)),
            torch.device('cpu')
        )

        # Segmentation loss
        if 'segmentation' in predictions and 'mask' in targets:
            seg_pred = predictions['segmentation']
            seg_target = targets['mask']

            dice_loss = self.dice_loss(seg_pred, seg_target)
            ce_loss = self.ce_loss(seg_pred, seg_target)

            seg_total_loss = 0.5 * dice_loss + 0.5 * ce_loss
            losses['dice_loss'] = dice_loss
            losses['ce_loss'] = ce_loss
        else:
            seg_total_loss = torch.zeros((), device=ref_device)
        losses['segmentation_loss'] = seg_total_loss

        # Detection loss
        if 'bboxes' in predictions and 'bboxes' in targets:
            bbox_pred = predictions['bboxes']
            bbox_target = targets['bboxes']
            confidence_pred = predictions['confidence']

            # NOTE(review): prediction and target must be broadcast-compatible
            # here; confirm how several target boxes per image are meant to
            # map onto the single predicted box
            bbox_reg_loss = self.bbox_loss(bbox_pred, bbox_target)

            # Confidence loss (use target confidence if available, else 1)
            confidence_target = targets.get(
                'confidence', torch.ones_like(confidence_pred)
            )
            confidence_loss = self.confidence_loss(confidence_pred,
                                                   confidence_target)

            det_total_loss = 0.7 * bbox_reg_loss + 0.3 * confidence_loss
            losses['bbox_loss'] = bbox_reg_loss
            losses['confidence_loss'] = confidence_loss
        else:
            det_total_loss = torch.zeros((), device=ref_device)
        losses['detection_loss'] = det_total_loss

        # Total weighted loss
        losses['total_loss'] = (self.seg_weight * seg_total_loss
                                + self.det_weight * det_total_loss)
        return losses

# ==================== METRICS CALCULATION ====================
def calculate_metrics(predictions, targets):
    """Calculate evaluation metrics for one batch.

    Args:
        predictions: dict that may hold ``'segmentation'`` logits (classes on
            dim 1) and ``'bboxes'``.
        targets: dict that may hold ``'mask'`` class indices and ``'bboxes'``.

    Returns:
        Dict of Python floats:
        ``'iou_class_<k>'`` for every foreground class ``k`` that occurs in
        the prediction or the target; ``'dice'``, the mean Dice over those
        same classes (1.0 when no foreground class occurs at all, i.e.
        prediction and target agree that everything is background);
        ``'bbox_l1'``, the mean absolute box error, when both box tensors are
        non-empty.
    """
    metrics = {}

    with torch.no_grad():
        # Segmentation metrics
        if 'segmentation' in predictions and 'mask' in targets:
            seg_logits = predictions['segmentation']
            seg_pred = seg_logits.argmax(dim=1)
            seg_target = targets['mask']

            dice_scores = []
            # Class 0 is background; the class count comes from the logits
            for class_idx in range(1, seg_logits.shape[1]):
                pred_mask = (seg_pred == class_idx).float()
                target_mask = (seg_target == class_idx).float()

                intersection = (pred_mask * target_mask).sum()
                area_sum = pred_mask.sum() + target_mask.sum()
                union = area_sum - intersection

                # A class absent from both prediction and target has no
                # defined score, so it is left out of IoU and Dice alike
                if union > 0:
                    metrics[f'iou_class_{class_idx}'] = (intersection / union).item()
                    dice_scores.append(
                        ((2 * intersection) / (area_sum + 1e-6)).item()
                    )

            # Dice is averaged over the evaluated classes rather than taken
            # from whichever class happened to be processed last
            if dice_scores:
                metrics['dice'] = sum(dice_scores) / len(dice_scores)
            else:
                metrics['dice'] = 1.0

        # Detection metrics
        if 'bboxes' in predictions and 'bboxes' in targets:
            bbox_pred = predictions['bboxes']
            bbox_target = targets['bboxes']

            if bbox_pred.numel() > 0 and bbox_target.numel() > 0:
                l1_distance = torch.abs(bbox_pred - bbox_target).mean()
                metrics['bbox_l1'] = l1_distance.item()

    return metrics

# ==================== TRAINING PIPELINE ====================
class TILTrainer:
    """Main training pipeline.

    Wraps model, loss, AdamW optimizer and a ReduceLROnPlateau scheduler, and
    provides epoch training, validation, checkpointing and loss plotting.

    Args:
        config: settings object exposing ``seg_weight``, ``det_weight``,
            ``learning_rate``, ``weight_decay``, ``num_epochs`` and
            ``checkpoint_dir``.
        model: network returning the prediction dict expected by the loss.
        train_loader: loader yielding dict batches with an ``'image'`` entry
            and optional ``'mask'`` / ``'bboxes'`` entries.
        val_loader: optional loader of the same kind used for validation.
    """

    def __init__(self, config, model, train_loader, val_loader=None):
        self.config = config
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader

        # Loss function
        self.criterion = MultiTaskLoss(
            seg_weight=config.seg_weight,
            det_weight=config.det_weight
        )

        # Optimizer
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )

        # Learning rate scheduler. The `verbose` keyword is deprecated in
        # recent PyTorch releases, so LR changes are reported by train()
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.5, patience=5
        )

        # Metrics tracking
        self.train_losses = []
        self.val_losses = []
        self.train_metrics = []
        self.val_metrics = []
        self.best_val_loss = float('inf')

        # Create checkpoint directory
        os.makedirs(config.checkpoint_dir, exist_ok=True)

    @staticmethod
    def _batch_targets(batch):
        """Move the target entries present in ``batch`` to the active device."""
        return {key: batch[key].to(device)
                for key in ('mask', 'bboxes') if key in batch}

    @staticmethod
    def _as_float(value, default=0.0):
        """Convert an optional scalar loss to a float for display."""
        return default if value is None else float(value)

    @staticmethod
    def _config_to_dict(config):
        """Collect public, non-callable settings of ``config`` into a dict.

        ``config.__dict__`` alone misses settings declared as class
        attributes, so attribute lookup via ``dir`` is used instead.
        """
        snapshot = {}
        for name in dir(config):
            if name.startswith('_'):
                continue
            value = getattr(config, name)
            if not callable(value):
                snapshot[name] = value
        return snapshot

    @staticmethod
    def _mean_metrics(all_metrics):
        """Average each metric over the batches that reported it."""
        collected = {}
        for batch_metrics in all_metrics:
            for key, value in batch_metrics.items():
                collected.setdefault(key, []).append(value)
        return {key: np.mean(values) for key, values in collected.items()}

    def train_epoch(self, epoch):
        """Train for one epoch and return the average of each tracked loss."""
        self.model.train()
        epoch_losses = {
            'total_loss': 0,
            'segmentation_loss': 0,
            'detection_loss': 0,
            'dice_loss': 0,
            'ce_loss': 0
        }

        pbar = tqdm(self.train_loader, desc=f'Train Epoch {epoch}')
        for batch in pbar:
            images = batch['image'].to(device)
            targets = self._batch_targets(batch)

            # Forward pass
            self.optimizer.zero_grad()
            outputs = self.model(images)

            # Calculate loss
            loss_dict = self.criterion(outputs, targets)
            loss = loss_dict['total_loss']

            # A batch without usable targets yields a constant loss that is
            # not attached to the graph; backward() would raise on it
            if loss.requires_grad:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               max_norm=1.0)
                self.optimizer.step()

            # Update running sums
            for key in epoch_losses:
                if key in loss_dict:
                    epoch_losses[key] += loss_dict[key].item()

            # A task loss may be absent from loss_dict; show 0.0 then
            pbar.set_postfix({
                'loss': loss.item(),
                'seg': self._as_float(loss_dict.get('segmentation_loss')),
                'det': self._as_float(loss_dict.get('detection_loss'))
            })

        # Average over batches; max(..., 1) guards an empty loader
        num_batches = max(len(self.train_loader), 1)
        avg_losses = {key: value / num_batches
                      for key, value in epoch_losses.items()}
        self.train_losses.append(avg_losses['total_loss'])

        return avg_losses

    def validate(self):
        """Run validation.

        Returns:
            ``(avg_losses, avg_metrics)``, or ``(None, None)`` when no
            validation loader was supplied.
        """
        if self.val_loader is None:
            return None, None

        self.model.eval()
        val_losses = {
            'total_loss': 0,
            'segmentation_loss': 0,
            'detection_loss': 0
        }

        all_metrics = []

        with torch.no_grad():
            pbar = tqdm(self.val_loader, desc='Validation')
            for batch in pbar:
                images = batch['image'].to(device)
                targets = self._batch_targets(batch)

                outputs = self.model(images)
                loss_dict = self.criterion(outputs, targets)

                all_metrics.append(calculate_metrics(outputs, targets))

                for key in val_losses:
                    if key in loss_dict:
                        val_losses[key] += loss_dict[key].item()

                pbar.set_postfix({
                    'loss': loss_dict['total_loss'].item()
                })

        num_batches = max(len(self.val_loader), 1)
        avg_losses = {key: value / num_batches
                      for key, value in val_losses.items()}
        self.val_losses.append(avg_losses['total_loss'])

        # Metrics missing from the first batch are still aggregated
        avg_metrics = self._mean_metrics(all_metrics)

        return avg_losses, avg_metrics

    def train(self):
        """Main training loop: train, validate, schedule, checkpoint, plot."""
        print("Starting training...")
        print(f"Training on {len(self.train_loader.dataset)} samples")
        if self.val_loader:
            print(f"Validating on {len(self.val_loader.dataset)} samples")

        for epoch in range(self.config.num_epochs):
            train_losses = self.train_epoch(epoch)

            if self.val_loader:
                val_losses, val_metrics = self.validate()

                # Update learning rate and report any reduction
                lr_before = self.optimizer.param_groups[0]['lr']
                self.scheduler.step(val_losses['total_loss'])
                lr_after = self.optimizer.param_groups[0]['lr']
                if lr_after != lr_before:
                    print(f"Learning rate reduced from {lr_before:.2e} "
                          f"to {lr_after:.2e}")

                # Save best model
                if val_losses['total_loss'] < self.best_val_loss:
                    self.best_val_loss = val_losses['total_loss']
                    self.save_checkpoint('best_model.pth', epoch,
                                         val_losses['total_loss'])

                # Print epoch summary
                print(f"\nEpoch {epoch} Summary:")
                print(f"Train Loss: {train_losses['total_loss']:.4f} | "
                      f"Val Loss: {val_losses['total_loss']:.4f}")
                print(f"Train Seg: {train_losses.get('segmentation_loss', 0):.4f} | "
                      f"Train Det: {train_losses.get('detection_loss', 0):.4f}")

                if val_metrics:
                    print("Validation Metrics:")
                    for key, value in val_metrics.items():
                        print(f"  {key}: {value:.4f}")
            else:
                print(f"\nEpoch {epoch}: Train Loss: {train_losses['total_loss']:.4f}")

            # Save checkpoint every 5 epochs
            if (epoch + 1) % 5 == 0:
                self.save_checkpoint(f'checkpoint_epoch_{epoch+1}.pth', epoch)

        # Save final model
        self.save_checkpoint('final_model.pth', self.config.num_epochs - 1)
        print("\nTraining completed!")

        # Plot training curves
        self.plot_training_curves()

    def save_checkpoint(self, filename, epoch=None, val_loss=None):
        """Save model/optimizer state, loss history and settings.

        The file is written to ``config.checkpoint_dir``; ``'val_loss'`` is
        stored only when provided.
        """
        checkpoint_path = os.path.join(self.config.checkpoint_dir, filename)
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'config': self._config_to_dict(self.config)
        }

        if val_loss is not None:
            checkpoint['val_loss'] = val_loss

        torch.save(checkpoint, checkpoint_path)
        print(f"Checkpoint saved: {checkpoint_path}")

    def load_checkpoint(self, filename):
        """Restore model, optimizer and loss history; return the saved epoch."""
        checkpoint_path = os.path.join(self.config.checkpoint_dir, filename)
        checkpoint = torch.load(checkpoint_path, map_location=device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.train_losses = checkpoint['train_losses']
        self.val_losses = checkpoint['val_losses']

        print(f"Checkpoint loaded from epoch {checkpoint['epoch']}")
        return checkpoint['epoch']

    def plot_training_curves(self):
        """Plot training and validation losses and save them as a PNG."""
        plt.figure(figsize=(10, 5))

        plt.subplot(1, 2, 1)
        plt.plot(self.train_losses, label='Training Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training Loss')
        plt.legend()
        plt.grid(True)

        if self.val_losses:
            plt.subplot(1, 2, 2)
            plt.plot(self.val_losses, label='Validation Loss', color='orange')
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.title('Validation Loss')
            plt.legend()
            plt.grid(True)

        plt.tight_layout()
        plt.savefig(os.path.join(self.config.checkpoint_dir, 'training_curves.png'))
        plt.show()

# ==================== VISUALIZATION ====================
def visualize_predictions(model, dataloader, num_samples=3):
    """Visualize model predictions for the first image of each batch.

    Draws one row per batch (up to ``num_samples`` rows) with four panels:
    input image, predicted class map, overlay, and ground-truth mask when the
    batch provides one.

    Args:
        model: network returning a dict with a ``'segmentation'`` entry.
        dataloader: iterable of dict batches with an ``'image'`` entry and an
            optional ``'mask'`` entry.
        num_samples: number of rows (batches) to draw.
    """
    model.eval()

    # squeeze=False keeps `axes` two-dimensional even for a single row
    fig, axes = plt.subplots(num_samples, 4, figsize=(16, 4 * num_samples),
                             squeeze=False)

    rows_drawn = 0
    with torch.no_grad():
        for idx, batch in enumerate(dataloader):
            if idx >= num_samples:
                break

            images = batch['image'].to(device)
            outputs = model(images)

            # Per-pixel class with the highest logit
            seg_pred = outputs['segmentation'][0].argmax(dim=0).cpu().numpy()

            # Undo the mean/std normalization applied by the dataset pipeline
            img = images[0].cpu().permute(1, 2, 0).numpy()
            img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
            img = np.clip(img, 0, 1)

            axes[idx, 0].imshow(img)
            axes[idx, 0].set_title('Original Image')
            axes[idx, 0].axis('off')

            # Segmentation prediction
            axes[idx, 1].imshow(seg_pred, cmap='jet', alpha=0.7)
            axes[idx, 1].set_title('Segmentation Prediction')
            axes[idx, 1].axis('off')

            # Overlay
            axes[idx, 2].imshow(img)
            axes[idx, 2].imshow(seg_pred, cmap='jet', alpha=0.5)
            axes[idx, 2].set_title('Overlay')
            axes[idx, 2].axis('off')

            # Ground truth (if available)
            if 'mask' in batch:
                mask = batch['mask'][0].cpu().numpy()
                axes[idx, 3].imshow(mask, cmap='jet', alpha=0.7)
                axes[idx, 3].set_title('Ground Truth')
            axes[idx, 3].axis('off')

            rows_drawn = idx + 1

    # Hide rows left empty when the loader yields fewer batches than requested
    for row in axes[rows_drawn:]:
        for ax in row:
            ax.axis('off')

    plt.tight_layout()
    plt.show()

# ==================== DATA LOADING UTILITIES ====================
def load_dataset_paths(data_dir):
    """Load image, mask, and bbox paths from a directory tree.

    Every file with an image extension whose name does not contain ``_mask``
    or ``mask_`` counts as an image. For an image ``<name>.<ext>`` the mask is
    looked up as ``<name>_mask.png`` then ``mask_<name>.png``, and the boxes
    as ``<name>_bbox.txt`` then ``<name>.txt``, all in the same folder.

    Directories and files are visited in sorted order so the returned lists
    (and any seeded split made from them) do not depend on filesystem order.

    Args:
        data_dir: root directory to scan. When it does not exist, a synthetic
            dataset is generated through ``create_dummy_dataset``.

    Returns:
        ``(image_files, mask_files, bbox_files)`` - three parallel lists;
        mask and bbox entries are ``None`` when no matching file exists.
    """
    if not os.path.exists(data_dir):
        print(f"Warning: Data directory {data_dir} does not exist.")
        print("Creating dummy dataset for demonstration...")
        return create_dummy_dataset()

    image_files = []
    mask_files = []
    bbox_files = []
    image_extensions = ('.png', '.jpg', '.jpeg', '.tif', '.tiff')

    def first_existing(*candidates):
        """Return the first candidate path that exists, else None."""
        return next((path for path in candidates if os.path.exists(path)), None)

    for root, dirs, files in os.walk(data_dir):
        # Sorting in place makes os.walk descend in a deterministic order
        dirs.sort()
        for file in sorted(files):
            lowered = file.lower()
            if not lowered.endswith(image_extensions):
                continue
            # Skip mask files
            if '_mask' in lowered or 'mask_' in lowered:
                continue

            base_name = os.path.splitext(file)[0]
            image_files.append(os.path.join(root, file))
            mask_files.append(first_existing(
                os.path.join(root, f"{base_name}_mask.png"),
                os.path.join(root, f"mask_{base_name}.png"),
            ))
            bbox_files.append(first_existing(
                os.path.join(root, f"{base_name}_bbox.txt"),
                os.path.join(root, f"{base_name}.txt"),
            ))

    print(f"Found {len(image_files)} images")
    print(f"Found {sum(m is not None for m in mask_files)} masks")
    print(f"Found {sum(b is not None for b in bbox_files)} bbox files")

    return image_files, mask_files, bbox_files

def create_dummy_dataset():
    """Create a dummy dataset for demonstration when no real data is available.

    Writes 50 synthetic samples to ``./dummy_data``: a random-noise RGB image,
    a label mask (1 = small circles, 2 = rectangles, stored multiplied by 85)
    and a text file of random boxes in ``x_min,y_min,x_max,y_max,class``
    format.

    Returns:
        ``(image_files, mask_files, bbox_files)`` - parallel lists of the
        written file paths.
    """
    print("Creating dummy dataset with synthetic data...")

    size = 512  # side length of the generated images and masks

    # Create dummy data directory structure
    dummy_dir = "./dummy_data"
    os.makedirs(dummy_dir, exist_ok=True)

    image_files = []
    mask_files = []
    bbox_files = []

    # Create 50 dummy samples
    for i in range(50):
        # Create random image
        img = np.random.randint(0, 255, (size, size, 3), dtype=np.uint8)

        # Create random mask (simulating lymphocytes)
        mask = np.zeros((size, size), dtype=np.uint8)

        # Add some random "lymphocytes" (small circles). Values are converted
        # to plain ints so the OpenCV drawing calls do not depend on how the
        # bindings treat NumPy scalars
        num_cells = np.random.randint(5, 20)
        for _ in range(num_cells):
            x = int(np.random.randint(50, 462))
            y = int(np.random.randint(50, 462))
            radius = int(np.random.randint(3, 8))
            cv2.circle(mask, (x, y), radius, 1, -1)

        # Add some "tumor" regions
        num_tumors = np.random.randint(1, 3)
        for _ in range(num_tumors):
            x = int(np.random.randint(100, 412))
            y = int(np.random.randint(100, 412))
            width = int(np.random.randint(30, 80))
            height = int(np.random.randint(30, 80))
            cv2.rectangle(mask, (x, y), (x + width, y + height), 2, -1)

        # Save files
        img_path = os.path.join(dummy_dir, f"sample_{i:03d}.png")
        mask_path = os.path.join(dummy_dir, f"sample_{i:03d}_mask.png")
        bbox_path = os.path.join(dummy_dir, f"sample_{i:03d}_bbox.txt")

        cv2.imwrite(img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
        cv2.imwrite(mask_path, mask * 85)  # Scale for visibility

        # Create dummy bbox file
        with open(bbox_path, 'w') as f:
            for _ in range(np.random.randint(1, 5)):
                x1 = int(np.random.randint(0, 500))
                y1 = int(np.random.randint(0, 500))
                # Clamp to the last pixel so a box never leaves the image;
                # x1/y1 stay below 500, so the box keeps a positive size
                x2 = min(x1 + int(np.random.randint(10, 50)), size - 1)
                y2 = min(y1 + int(np.random.randint(10, 50)), size - 1)
                f.write(f"{x1},{y1},{x2},{y2},0\n")

        image_files.append(img_path)
        mask_files.append(mask_path)
        bbox_files.append(bbox_path)

    print(f"Created dummy dataset with {len(image_files)} samples")
    return image_files, mask_files, bbox_files

# ==================== MAIN EXECUTION ====================
def main():
    """Main execution function"""
    print("=" * 60)
    print("TIL Detection and Segmentation Pipeline")
    print("TIGER Challenge Task")

Using device: cpu
